瀏覽代碼

Fixed bug caused by move + added t-test.

Thomas Flucke 6 年之前
父節點
當前提交
757326f5f5
共有 2 個文件被更改,包括 15 次插入 和 4 次删除
  1. 1 1
      src/classifiers/Vector.py
  2. 14 3
      src/classifiers/nearestneighbors.py

+ 1 - 1
src/classifiers/Vector.py

@@ -8,7 +8,7 @@ import sys
 import typing
 from typing import List
 sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + \
-                '/../src/feature-extractor')
+                '/../feature-extractor')
 import sample
 
 class FeatureVector:

+ 14 - 3
src/classifiers/nearestneighbors.py

@@ -25,7 +25,10 @@ def main():
                          for p in samples])
     res = kNearestNeighbors(np.array(data), np.array(labels),
                             n=args.folds, verbose=args.verbose)
-    print("Overall Accuracy: %f" % res)
+    print("Overall Accuracy: %f" % np.average(res))
+    if args.p_value:
+        _, p = t_test(res, labels)
+        print("P-Value: %f" % (p / 2))
 
 def parse_args():
     import argparse
@@ -39,12 +42,15 @@ def parse_args():
                         help='Number of cross-validation folds (default: 5)')
     parser.add_argument('-f', '--feature', action='append', type=str,
                         help='Add feature to list of features to test with.')
+    parser.add_argument('-p', '--p-value', action='store_const', default=False,
+                        const=True, help='Calculate a p-value from a t-test.')
     return parser.parse_args()
 
 def kNearestNeighbors(data: list, labels: list, n=5, verbose=0):
     folds = KFold(n_splits=n)
     i = 1
     avg = 0
+    accuracies = []
     for train_index, test_index in folds.split(data):
         if verbose >= 1:
             print("Round %d:" % i)
@@ -58,8 +64,13 @@ def kNearestNeighbors(data: list, labels: list, n=5, verbose=0):
         accuracy = correct.count(True)/len(correct)
         if verbose >= 1:
             print(accuracy)
-        avg += accuracy
-    return avg/n
+        accuracies.append(accuracy)
+    return accuracies
+
+def t_test(accuracy: list, labels: list):
+    from scipy import stats
+    random_avg = 1.0/len(np.unique(labels))
+    return stats.ttest_1samp(accuracy, random_avg)
 
 if __name__ == '__main__':
     main()