Pārlūkot izejas kodu

Added standard classifier interface.

Tom Flucke 6 gadi atpakaļ
vecāks
revīzija
28ad91d329
1 mainītis faili ar 30 papildinājumiem un 24 dzēšanām
  1. 30 24
      src/classifiers/nearestneighbors.py

+ 30 - 24
src/classifiers/nearestneighbors.py

@@ -1,43 +1,38 @@
 #!/usr/bin/python3
+
 from sklearn.model_selection import KFold
 import numpy as np
+import typing
 try:
     import sample
 except ImportError:
-    import os
-    import sys
+    import os, sys
     sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + \
                     '/../feature-extractor')
     import sample
 
 DEFAULT_FEATURES = ["average_iat", "high.avg_burst_size", "high.burst_count"]
 
def main(options: list):
    """Run the k-nearest-neighbors classifier on a pickled sample file.

    options: command-line argument strings (e.g. sys.argv[1:]),
    parsed by parse_args().  Prints overall accuracy and, when
    requested, the one-sided p-value against chance accuracy.
    """
    args = parse_args(options)
    # cPickle only exists on Python 2; fall back to the stdlib pickle.
    # Catch ImportError specifically -- a bare except would also hide
    # unrelated failures (KeyboardInterrupt, SystemExit, ...).
    try:
        import cPickle as pickle
    except ImportError:
        import pickle
    # NOTE(review): pickle.load on an external file executes arbitrary
    # code for malicious input -- only feed this trusted feature files.
    samples = pickle.load(args.features_file)
    num_users = len(np.unique([s.user for s in samples]))
    # Validate with an explicit exception instead of assert, which is
    # silently stripped when running under `python -O`.
    if num_users < args.min_users:
        raise ValueError(
            "sample file has %d unique users; at least %d required"
            % (num_users, args.min_users))
    features = args.feature if args.feature else DEFAULT_FEATURES
    from Vector import FeatureVector
    # Pair each sample's feature vector with its user label, then
    # convert both columns to numpy arrays for the classifier.
    data, labels = map(np.array,
                       zip(*[(FeatureVector(p, features).get(), p.user)
                             for p in samples]))
    avg, p = classify(data, labels, num_users, args)
    print("Overall Accuracy: %f" % avg)
    if args.p_value:
        print("P-Value: %f" % p)
 
-def parse_args():
+def parse_args(args: list):
     import argparse
     parser = argparse.ArgumentParser(
         description='Run a data set through a kNearestNeighbors classifier.')
@@ -59,8 +54,18 @@ def parse_args():
     parser.add_argument('-t', '--top', type=int, default=1,
                         help='Number of guesses to be considered \"correct\" \
                         (default: 1)')
-    return parser.parse_args()
-
+    parser.add_argument('-m', '--min-users', type=int, default=10,
+                        help='Minimum number of unique users to consider a sample\
+                        file valid. (default: 10)')
+    return parser.parse_args(args)
+
def classify(data, labels, num_users: int, args):
    """Shuffle the data set, cross-validate, and summarize the result.

    Returns a (mean accuracy, one-sided p-value) pair, where the
    p-value comes from a one-sample t-test of the per-fold accuracies
    against the random-guess accuracy for num_users users.
    """
    # Apply one random permutation to data and labels together so the
    # feature rows stay aligned with their user labels.
    order = np.arange(data.shape[0])
    np.random.shuffle(order)
    fold_accuracies = kNearestNeighbors(data[order], labels[order],
                                        n=args.folds,
                                        verbose=args.verbose,
                                        k=args.k_neighbors,
                                        weights=args.weight,
                                        guesses=args.top)
    mean_accuracy = np.average(fold_accuracies)
    # ttest_1samp is two-sided; halve the p-value for the one-sided test.
    two_sided_p = t_test(fold_accuracies, num_users)[1]
    return (mean_accuracy, two_sided_p / 2)
 
 def kNearestNeighbors(data: list, labels: list,
                       n=5, verbose=0, k=5, weights="uniform", guesses=1):
@@ -140,11 +145,12 @@ def find_in_predictions(probabilities: list, tests: int, labels: list):
             for probs, test in zip(probabilities, tests)]
 
 
def t_test(accuracy: list, num_users: int):
    """One-sample t-test of fold accuracies against chance level.

    With num_users equally likely users, a random guesser scores
    1/num_users on average; returns the (statistic, two-sided
    p-value) pair from scipy.stats.ttest_1samp.
    """
    from scipy import stats
    chance_level = 1.0 / num_users
    return stats.ttest_1samp(accuracy, chance_level)
 
 
if __name__ == '__main__':
    # CLI entry point: forward argv (minus the program name) to main().
    import sys
    main(sys.argv[1:])