# nearestneighbors.py (2.2 KB)

import pickle
import sys

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier

from Vector import *
  6. def main():
  7. # a test of this method using an arbitrarily generated list of 5 vectors with 3 features each
  8. # nearestNeighbors([[1, 1, 0], [1, 0, 0], [0, 0, 0], [0, 5, 5]], [[1, 1, 4]])
  9. print(len(sys.argv))
  10. if len(sys.argv) != 5:
  11. print("Usage: nearestneighbors.py datafile.bin classificationsfile.bin testdatafile.bin -(p/e)")
  12. exit()
  13. data = readPickledData(sys.argv[1])
  14. classifcations = readPickledData(sys.argv[2])
  15. testdata = readPickledData(sys.argv[3])
  16. newdata, newtest = [], []
  17. for d in data:
  18. newdata.append(d.features)
  19. for d in testdata:
  20. newtest.append(d.features)
  21. print(newdata)
  22. print(classifcations)
  23. print(newtest)
  24. kNearestNeighbors(newdata, classifcations, newtest)
  25. # print("Random Forest:")
  26. # randomForest(newdata, classifcations, newtest)
  27. # kNearestNeighbors([[1, 1, 0], [1, 0, 0], [0, 0, 0], [0, 5, 5]], ["three", 2, 3, "5"], [[1, 1, 0], [0, 5, 5]])
  28. def kNearestNeighbors(data: list, classifications: list, test_data: list):
  29. kn = KNeighborsClassifier(n_neighbors=2)
  30. kn.fit(data, classifications)
  31. p = kn.predict(test_data)
  32. print("Predictions, matching test_data by index: ")
  33. print(test_data)
  34. print(p)
  35. writestr = "Predictions, matching test_data by index:\n" + str(test_data) + "\n" + str(p)
  36. if sys.argv[4][1] == 'p':
  37. pickle.dump((test_data, p), open("results.bin", "wb"))
  38. else:
  39. with open("results.txt", "w+") as file:
  40. file.write(writestr)
  41. def nearestNeighbors(data: list, test_data: list):
  42. x = np.array(data)
  43. nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(x)
  44. dist, indicies = nbrs.kneighbors(test_data)
  45. print("Indicies:")
  46. print(indicies)
  47. print("Distances:")
  48. print(dist)
  49. return indicies, dist
  50. def randomForest(data: list, classifications: list, test_data: list):
  51. rfc = RandomForestClassifier(n_estimators=len(data))
  52. rfc.fit(data, classifications)
  53. print(rfc.predict(test_data))
  54. if __name__ == '__main__':
  55. main()