#!/usr/bin/python3
"""Nearest-neighbour and random-forest experiments on feature vectors.

Usage:
    nearestneighbors.py datafile.bin classificationsfile.bin testdatafile.bin -(p/e)
"""
import sys
from typing import Optional

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier

# NOTE(review): readPickledData is presumably provided by this star import;
# confirm against Vector.py.
from Vector import *


def main():
    """Load the files named on the command line and classify the test data.

    Expects exactly four arguments (training data, classifications, test
    data and an output-mode flag). The first three are read with
    ``readPickledData``; the ``features`` attribute of every training and
    test element is collected into plain lists, which are printed and handed
    to ``kNearestNeighbors``. The fourth argument is only counted, not read:
    the code that used it is currently commented out.
    """
    # a test of this method using an arbitrarily generated list of 5 vectors with
    # 3 features each
    # nearestNeighbors([[1, 1, 0], [1, 0, 0], [0, 0, 0], [0, 5, 5]], [[1, 1, 4]])
    print(len(sys.argv))
    if len(sys.argv) != 5:
        print("Usage: nearestneighbors.py datafile.bin classificationsfile.bin "
              "testdatafile.bin -(p/e)")
        # sys.exit instead of the site-provided exit(); non-zero because the
        # invocation was invalid.
        sys.exit(1)

    data = readPickledData(sys.argv[1])
    classifications = readPickledData(sys.argv[2])
    testdata = readPickledData(sys.argv[3])

    newdata = [d.features for d in data]
    newtest = [d.features for d in testdata]

    print(newdata)
    print(classifications)
    print(newtest)
    kNearestNeighbors(newdata, classifications, newtest)
    # print("Random Forest:")
    # randomForest(newdata, classifications, newtest)
    # kNearestNeighbors([[1, 1, 0], [1, 0, 0], [0, 0, 0], [0, 5, 5]],
    #                   ["three", 2, 3, "5"], [[1, 1, 0], [0, 5, 5]])


def _report_predictions(samples, predictions):
    """Print the classified samples followed by their predicted labels."""
    print("Predictions, matching test_data by index: ")
    print(samples)
    print(predictions)


def kNearestNeighbors(data: list, classifications: list,
                      test_data: Optional[list] = None):
    """Classify samples with a 2-nearest-neighbours classifier and print them.

    Args:
        data: Training feature vectors, one sequence of numbers per sample.
        classifications: Labels for ``data``, in the same order.
        test_data: Optional feature vectors to classify. When given, the
            classifier is fitted on all of ``data`` and predicts
            ``test_data`` once. When omitted, a 5-fold cross-validation is
            run over ``data`` and each held-out fold is predicted by a
            classifier fitted on the remaining folds.

    NOTE(review): the previous version accepted only two parameters although
    ``main`` passed three, and its body read an undefined ``test_data``; the
    two modes above are the reconstruction of that intent - confirm.
    """
    # Fold indices from KFold are NumPy integer arrays, which cannot index a
    # plain Python list, so work on array views of the inputs.
    samples = np.asarray(data)
    labels = np.asarray(classifications)

    if test_data is not None:
        model = KNeighborsClassifier(n_neighbors=2)
        model.fit(samples, labels)
        _report_predictions(test_data, model.predict(test_data))
    else:
        for train_index, test_index in KFold(n_splits=5).split(samples):
            # A fresh classifier per fold so no state leaks between folds.
            model = KNeighborsClassifier(n_neighbors=2)
            model.fit(samples[train_index], labels[train_index])
            held_out = samples[test_index]
            _report_predictions(held_out, model.predict(held_out))

    # TODO: result persistence is disabled; kept for reference.
    # writestr = "Predictions, matching test_data by index:\n" + str(test_data) \
    #     + "\n" + str(p)
    # if sys.argv[4][1] == 'p':
    #     pickle.dump((test_data, p), open("results.bin", "wb"))
    # else:
    #     with open("results.txt", "w+") as file:
    #         file.write(writestr)


def nearestNeighbors(data: list, test_data: list):
    """Find and print the single nearest neighbour in ``data`` of each query.

    Args:
        data: Reference feature vectors.
        test_data: Query feature vectors.

    Returns:
        A tuple ``(indices, distances)`` as produced by
        ``NearestNeighbors.kneighbors`` with ``n_neighbors=1``.
    """
    x = np.array(data)
    nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(x)
    dist, indices = nbrs.kneighbors(test_data)
    print("Indicies:")
    print(indices)
    print("Distances:")
    print(dist)
    return indices, dist


def randomForest(data: list, classifications: list, test_data: list):
    """Fit a random forest on ``data`` and print its predictions.

    The forest uses one estimator per training sample
    (``n_estimators=len(data)``).

    Args:
        data: Training feature vectors.
        classifications: Labels for ``data``, in the same order.
        test_data: Feature vectors to classify.
    """
    rfc = RandomForestClassifier(n_estimators=len(data))
    rfc.fit(data, classifications)
    print(rfc.predict(test_data))


if __name__ == '__main__':
    main()