nearestneighbors.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #!/usr/bin/python3
  2. from sklearn.model_selection import KFold
  3. from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
  4. from sklearn.ensemble import RandomForestClassifier
  5. import numpy as np
  6. import sys
  7. from Vector import *
  def main(argv=None):
      """Command-line entry point: classify test vectors with k-nearest neighbors.

      Usage::

          nearestneighbors.py datafile.bin classificationsfile.bin testdatafile.bin -(p/e)

      The three ``.bin`` files are loaded with ``readPickledData``. Objects in
      the training and test files must expose a ``features`` attribute; the
      classifications file supplies the labels for the training objects. The
      ``-(p/e)`` flag is required by the argument-count check but is not
      otherwise read here.

      :param argv: argument vector to use instead of ``sys.argv`` (e.g. when
          driving the CLI from code or tests); defaults to ``sys.argv``.
      """
      # Manual smoke tests (4 vectors with 3 features each):
      # nearestNeighbors([[1, 1, 0], [1, 0, 0], [0, 0, 0], [0, 5, 5]], [[1, 1, 4]])
      # kNearestNeighbors([[1, 1, 0], [1, 0, 0], [0, 0, 0], [0, 5, 5]],
      #                   ["three", 2, 3, "5"], [[1, 1, 0], [0, 5, 5]])
      if argv is None:
          argv = sys.argv
      if len(argv) != 5:
          # Report misuse on stderr with a non-zero status. The bare exit()
          # builtin exits with status 0 and is only installed by the site
          # module, so it is not meant for use in programs.
          print("Usage: nearestneighbors.py datafile.bin classificationsfile.bin "
                "testdatafile.bin -(p/e)", file=sys.stderr)
          sys.exit(1)

      data = readPickledData(argv[1])
      classifications = readPickledData(argv[2])
      testdata = readPickledData(argv[3])

      # The classifiers only need the raw feature vectors.
      newdata = [d.features for d in data]
      newtest = [d.features for d in testdata]
      print(newdata)
      print(classifications)
      print(newtest)

      # kNearestNeighbors indexes its inputs with the NumPy index arrays that
      # KFold produces, which Python lists reject, so hand it ndarrays.
      kNearestNeighbors(np.asarray(newdata), np.asarray(classifications),
                        np.asarray(newtest))
      # Alternative classifier, currently disabled:
      # print("Random Forest:")
      # randomForest(newdata, classifications, newtest)
  def kNearestNeighbors(data: list, classifications: list, test_data: list,
                        n_neighbors: int = 2, n_splits: int = 5):
      """Cross-validate a k-NN classifier on ``data``, then classify ``test_data``.

      The labelled samples (``data`` / ``classifications``) are split with
      ``KFold``. For each fold, a ``KNeighborsClassifier`` is fit on the
      training part and scored on the held-out part of ``data``, and the fold
      accuracy is printed. A final classifier is then fit on all labelled
      samples and used to predict a label for every row of ``test_data``.
      Those predictions are printed and returned.

      :param data: labelled feature vectors, one row per sample.
      :param classifications: label of each row of ``data`` (same length).
      :param test_data: unlabelled feature vectors to classify.
      :param n_neighbors: neighbors consulted per prediction (default 2).
          Clamped to the number of training rows available.
      :param n_splits: number of cross-validation folds (default 5). Clamped
          to the number of samples; cross-validation is skipped when fewer
          than two folds are possible.
      :returns: NumPy array of predicted labels, aligned with ``test_data``.
      :raises ValueError: if ``data`` is empty.
      """
      # KFold yields NumPy index arrays; fancy indexing with them only works
      # on ndarrays, not on plain Python lists.
      data = np.asarray(data)
      classifications = np.asarray(classifications)
      test_data = np.asarray(test_data)
      if len(data) == 0:
          raise ValueError("kNearestNeighbors needs at least one labelled sample")

      # KFold rejects n_splits below 2 or above the number of samples, so
      # clamp for small datasets instead of crashing.
      splits = min(n_splits, len(data))
      if splits >= 2:
          folds = KFold(n_splits=splits)
          for fold, (train_index, held_out_index) in enumerate(folds.split(data), 1):
              kn = KNeighborsClassifier(
                  n_neighbors=min(n_neighbors, len(train_index)))
              kn.fit(data[train_index], classifications[train_index])
              # Score against the held-out rows of *data*: these indices come
              # from splitting data, so they do not address rows of test_data.
              fold_pred = kn.predict(data[held_out_index])
              accuracy = np.mean(fold_pred == classifications[held_out_index])
              print("Fold {}: accuracy {:.3f}".format(fold, accuracy))

      # Final model uses every labelled sample to classify the unseen data.
      kn = KNeighborsClassifier(n_neighbors=min(n_neighbors, len(data)))
      kn.fit(data, classifications)
      p = kn.predict(test_data)
      print("Predictions, matching test_data by index: ")
      print(test_data)
      print(p)
      # TODO: persisting (test_data, p) was drafted here but never enabled:
      # pickled to results.bin when sys.argv[4] is "-p", otherwise written
      # as text to results.txt.
      return p
  def nearestNeighbors(data: list, test_data: list):
      """Find the single closest row of ``data`` for every row of ``test_data``.

      Builds a ball-tree ``NearestNeighbors`` index over ``data`` (converted
      to a NumPy array) and queries it with ``test_data``. The index and
      distance arrays are printed, then returned.

      :returns: ``(indices, distances)``, as produced by
          ``NearestNeighbors.kneighbors``, in that order.
      """
      index = NearestNeighbors(n_neighbors=1, algorithm='ball_tree')
      index.fit(np.array(data))
      distances, indices = index.kneighbors(test_data)
      for heading, values in (("Indicies:", indices), ("Distances:", distances)):
          print(heading)
          print(values)
      return indices, distances
  def randomForest(data: list, classifications: list, test_data: list,
                   n_estimators: int = 100):
      """Fit a random forest on labelled data and classify ``test_data``.

      :param data: labelled feature vectors, one row per sample.
      :param classifications: label of each row of ``data``.
      :param test_data: feature vectors to classify.
      :param n_estimators: number of trees in the forest (default 100,
          scikit-learn's own default).
      :returns: array of predicted labels aligned with ``test_data``; the
          predictions are also printed.
      """
      # Forest size is a model hyperparameter, not a property of the dataset.
      # Deriving it from len(data) meant one tree per sample: zero trees for
      # empty input and an impractically large forest for big datasets.
      rfc = RandomForestClassifier(n_estimators=n_estimators)
      rfc.fit(data, classifications)
      predictions = rfc.predict(test_data)
      print(predictions)
      return predictions
  if __name__ == '__main__':
      # Run the command-line interface only when executed as a script,
      # not when this module is imported.
      main()