# nearestneighbors.py
  1. from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
  2. import numpy as np
  3. import sys
  4. from Vector import *
  5. def main():
  6. # a test of this method using an arbitrarily generated list of 5 vectors with 3 features each
  7. # nearestNeighbors([[1, 1, 0], [1, 0, 0], [0, 0, 0], [0, 5, 5]], [[1, 1, 4]])
  8. print(len(sys.argv))
  9. if len(sys.argv) != 4:
  10. print("Usage: nearestneighbors.py datafile.bin classificationsfile.bin testdatafile.bin")
  11. exit()
  12. data = readPickledData(sys.argv[1])
  13. classifcations = readPickledData(sys.argv[2])
  14. testdata = readPickledData(sys.argv[3])
  15. newdata, newtest = [], []
  16. for d in data:
  17. newdata.append(d.features)
  18. for d in testdata:
  19. newtest.append(d.features)
  20. print(newdata)
  21. print(classifcations)
  22. print(newtest)
  23. kNearestNeighbors(newdata, classifcations, newtest)
  24. # kNearestNeighbors([[1, 1, 0], [1, 0, 0], [0, 0, 0], [0, 5, 5]], ["three", 2, 3, "5"], [[1, 1, 0], [0, 5, 5]])
  25. def kNearestNeighbors(data: list, classifications: list, test_data: list):
  26. kn = KNeighborsClassifier(n_neighbors=2)
  27. kn.fit(data, classifications)
  28. print("Predictions, matching test_data by index: ")
  29. print(test_data)
  30. print(kn.predict(test_data))
  31. def nearestNeighbors(data: list, test_data: list):
  32. x = np.array(data)
  33. nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(x)
  34. dist, indicies = nbrs.kneighbors(test_data)
  35. print("Indicies:")
  36. print(indicies)
  37. print("Distances:")
  38. print(dist)
  39. return indicies, dist
  40. if __name__ == '__main__':
  41. main()