#!/usr/bin/python3
"""Cross-validate a k-nearest-neighbours classifier on pickled samples.

The script unpickles a collection of samples, converts each one to a numeric
feature vector via ``FeatureVector``, and reports the mean K-fold
cross-validation accuracy of a ``KNeighborsClassifier`` that predicts each
sample's ``user`` attribute.
"""
import argparse
import pickle
import random
import sys

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier

from Vector import FeatureVector

# Features used when the user does not pass any -f/--feature option.
DEFAULT_FEATURES = ["average_iat", "high.avg_burst_size", "high.burst_count"]


def main():
    """Load the samples, build the data/label arrays and print the accuracy."""
    args = parse_args()

    # SECURITY: unpickling executes arbitrary code embedded in the file.
    # Only run this script on feature files from a trusted source.
    # The handle was opened by argparse; close it as soon as it is consumed.
    with args.features_file as features_file:
        samples = pickle.load(features_file)

    if len(samples) == 0:
        # zip(*[]) below would otherwise fail with an opaque unpacking error.
        sys.exit("No samples found in features file.")

    features = args.feature if args.feature else DEFAULT_FEATURES

    # KFold (as used below) splits in order, so randomise the order first.
    random.shuffle(samples)

    # One row per sample; the label is the sample's `user` attribute.
    # NOTE(review): FeatureVector(...).get() is assumed to return a flat,
    # fixed-length numeric sequence -- confirm against Vector.py.
    data, labels = zip(
        *[(FeatureVector(p, features).get(), p.user) for p in samples]
    )

    res = kNearestNeighbors(
        np.array(data), np.array(labels), n=args.folds, verbose=args.verbose
    )
    print("Overall Accuracy: %f" % res)


def parse_args():
    """Parse the command line and return the resulting namespace.

    Attributes of the returned namespace:
        features_file: binary file object opened for reading.
        verbose: number of times -v/--verbose was given (0 if absent).
        folds: number of cross-validation folds.
        feature: list of feature names given with -f, or None if absent.
    """
    parser = argparse.ArgumentParser(
        description='Run a data set through a kNearestNeighbors classifier.')
    parser.add_argument('features_file', type=argparse.FileType('rb'),
                        help='File of extracted features.')
    parser.add_argument('-v', '--verbose', action="count", default=0,
                        help='Show more information')
    parser.add_argument('-n', '--folds', type=int, default=5,
                        help='Number of cross-validation folds (default: 5)')
    parser.add_argument('-f', '--feature', action='append', type=str,
                        help='Add feature to list of features to test with.')
    return parser.parse_args()


def kNearestNeighbors(data: list, labels: list, n=5, verbose=0,
                      n_neighbors=2):
    """Return the mean K-fold accuracy of a k-nearest-neighbours classifier.

    Args:
        data: feature vectors, one row per sample (list or array).
        labels: class label of each row of ``data`` (list or array).
        n: number of cross-validation folds.
        verbose: 1 or more prints the round number and per-fold accuracy;
            2 or more additionally prints the training indices of each fold.
        n_neighbors: number of neighbours consulted by the classifier.

    Returns:
        The per-fold accuracies averaged over the ``n`` folds.
    """
    # Fancy indexing with the index arrays from KFold needs ndarrays;
    # converting here lets callers pass plain lists as the annotations say.
    data = np.asarray(data)
    labels = np.asarray(labels)

    folds = KFold(n_splits=n)
    total = 0.0
    for round_number, (train_index, test_index) in enumerate(
            folds.split(data), start=1):
        if verbose >= 1:
            print("Round %d:" % round_number)
        if verbose >= 2:
            print("Training on: ", train_index)

        kn = KNeighborsClassifier(n_neighbors=n_neighbors)
        kn.fit(data[train_index], labels[train_index])
        predictions = kn.predict(data[test_index])

        # Fraction of test samples whose predicted label matches the truth.
        accuracy = float(np.mean(labels[test_index] == predictions))
        if verbose >= 1:
            print(accuracy)
        total += accuracy

    # KFold yields exactly n splits, so n is the number of accuracies summed.
    return total / n


if __name__ == '__main__':
    main()