|
@@ -0,0 +1,93 @@
|
|
|
|
|
+#!/usr/bin/python3
|
|
|
|
|
+
|
|
|
|
|
+from sklearn.model_selection import KFold
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import typing
|
|
|
|
|
+try:
|
|
|
|
|
+ import sample
|
|
|
|
|
+except ImportError:
|
|
|
|
|
+ import os, sys
|
|
|
|
|
+ sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + \
|
|
|
|
|
+ '/../feature-extractor')
|
|
|
|
|
+ import sample
|
|
|
|
|
+
|
|
|
|
|
# Feature names used by main() when no -f/--feature option is supplied.
DEFAULT_FEATURES = [
    "average_iat",
    "high.avg_burst_size",
    "high.burst_count",
]
|
|
|
|
|
+
|
|
|
|
|
def main(options: list):
    """Load pickled samples, classify them, and print the resulting accuracy.

    The features file named on the command line is unpickled into a
    collection of samples.  Each sample is turned into a feature vector via
    ``FeatureVector(sample, features).get()`` and labelled with
    ``sample.user``; the vectors and labels are then passed to ``classify``.

    Args:
        options: Command-line arguments without the program name, in the
            form accepted by ``parse_args``.

    Raises:
        SystemExit: If the arguments are invalid or the features file
            contains no samples.
    """
    # Python 3's pickle already uses the C accelerator when available, so the
    # Python 2 ``cPickle`` fallback (and its bare ``except``) is unnecessary.
    import pickle

    from Vector import FeatureVector

    args = parse_args(options)

    # SECURITY NOTE: unpickling can execute arbitrary code.  Only feed this
    # script feature files that come from a trusted source.
    with args.features_file as handle:  # make sure the file gets closed
        samples = pickle.load(handle)

    features = args.feature if args.feature else DEFAULT_FEATURES

    pairs = [(FeatureVector(sample, features).get(), sample.user)
             for sample in samples]
    if not pairs:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise SystemExit("No samples found in %s" % args.features_file.name)

    vectors, users = zip(*pairs)
    data = np.array(vectors)
    labels = np.array(users)

    # The original code referenced ``num_users`` without ever defining it.
    # The number of distinct labels is used as the class count for the
    # random-guess baseline.
    # NOTE(review): confirm this matches the intended definition.
    num_users = len(np.unique(labels))

    avg, p_value = classify(data, labels, num_users, args)
    print("Overall Accuracy: %f" % avg)
    if args.p_value:
        print("P-Value: %f" % p_value)
|
|
|
|
|
+
|
|
|
|
|
def parse_args(args: list):
    """Parse the command-line options of the classifier script.

    Args:
        args: Argument strings without the program name.

    Returns:
        An ``argparse.Namespace`` with the attributes ``features_file``
        (file object opened in binary read mode), ``verbose`` (int count),
        ``folds`` (int), ``estimators`` (int), ``criterion`` ("gini" or
        "entropy"), ``feature`` (list of str, or None when -f was not
        given) and ``p_value`` (bool).

    Raises:
        SystemExit: On invalid arguments (argparse's usual behaviour),
            including ``--folds`` below 2 or ``--estimators`` below 1.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Run a data set through a Random Forest classifier.')
    parser.add_argument('features_file', type=argparse.FileType('rb'),
                        help='File of extracted features.')
    parser.add_argument('-v', '--verbose', action="count", default=0,
                        help='Show more information')
    parser.add_argument('-n', '--folds', type=int, default=5,
                        help='Number of cross-validation folds (default: 5)')
    parser.add_argument('-e', '--estimators', type=int, default=100,
                        help='Number of random decision trees (default: 100)')
    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, which embedded a long run of spaces in the help.
    parser.add_argument('-c', '--criterion', choices=["gini", "entropy"],
                        default="gini",
                        help='Function to evaluate tree split value '
                             '(default: "gini")')
    parser.add_argument('-f', '--feature', action='append', type=str,
                        help='Add feature to list of features to test with.')
    parser.add_argument('-p', '--p-value', action='store_true',
                        help='Calculate a p-value from a t-test.')

    parsed = parser.parse_args(args)

    # Fail early with a usage message rather than deep inside scikit-learn
    # after the data set has already been loaded.
    if parsed.folds < 2:
        parser.error("--folds must be at least 2")
    if parsed.estimators < 1:
        parser.error("--estimators must be at least 1")
    return parsed
|
|
|
|
|
+
|
|
|
|
|
def classify(data, labels, num_users: int, args):
    """Shuffle the data set, cross-validate a random forest, and summarise.

    Args:
        data: NumPy array of feature vectors, one row per sample (it is
            indexed with a permutation of ``range(data.shape[0])``).
        labels: NumPy array of labels aligned with ``data``.
        num_users: Number of classes, forwarded to ``t_test`` to derive the
            random-guess baseline.
        args: Parsed options providing ``folds``, ``verbose``, ``criterion``
            and ``estimators``.

    Returns:
        A tuple ``(average_accuracy, p_value)`` where ``average_accuracy`` is
        the mean of the per-fold accuracies and ``p_value`` is half of the
        p-value reported by ``t_test``.
    """
    # Shuffle rows and labels with the same permutation so they stay aligned.
    order = np.arange(data.shape[0])
    np.random.shuffle(order)

    # The original call was missing the comma after ``fn=args.criterion``,
    # which made the whole module fail to compile.
    res = randomForest(data[order], labels[order],
                       n=args.folds, verbose=args.verbose, fn=args.criterion,
                       estimators=args.estimators)

    # NOTE(review): halving the p-value looks like a one-sided conversion; it
    # is applied regardless of the sign of the t statistic - confirm intent.
    return (np.average(res), t_test(res, num_users)[1] / 2)
|
|
|
|
|
+
|
|
|
|
|
def randomForest(data: list, labels: list, n=5, verbose=0, estimators=100,
                 fn="gini"):
    """Cross-validate a random forest classifier and return fold accuracies.

    Args:
        data: Feature vectors, one per sample (list or NumPy array).
        labels: Labels aligned with ``data`` (list or NumPy array).
        n: Number of K-fold splits.
        verbose: 1 or more prints the round number and its accuracy; 2 or
            more additionally prints the training indices of each round.
        estimators: Number of trees passed as ``n_estimators``.
        fn: Split criterion passed as ``criterion``.

    Returns:
        A list with one accuracy score (as returned by the classifier's
        ``score`` method on the held-out fold) per fold.
    """
    from sklearn.ensemble import RandomForestClassifier

    # The folds are selected with index arrays, which plain lists do not
    # support; converting here makes the ``list`` annotations truthful.
    data = np.asarray(data)
    labels = np.asarray(labels)

    accuracies = []
    splits = KFold(n_splits=n).split(data)
    for round_no, (train_index, test_index) in enumerate(splits, start=1):
        if verbose >= 1:
            print("Round %d:" % round_no)
        if verbose >= 2:
            print("Training on: ", train_index)

        # A fresh classifier per fold so no state leaks between rounds.
        rfc = RandomForestClassifier(n_estimators=estimators, criterion=fn)
        rfc.fit(data[train_index], labels[train_index])
        accuracy = rfc.score(data[test_index], labels[test_index])

        if verbose >= 1:
            print(accuracy)
        accuracies.append(accuracy)
    return accuracies
|
|
|
|
|
+
|
|
|
|
|
def t_test(accuracy: list, num_users: int):
    """Run a one-sample t-test of accuracies against the random-guess rate.

    The population mean tested against is ``1.0 / num_users``.

    Args:
        accuracy: Accuracy scores to test.
        num_users: Number of classes; must be positive.

    Returns:
        The result of ``scipy.stats.ttest_1samp`` (indexable as
        ``(statistic, pvalue)``), or the tuple ``(0, 1)`` when the statistic
        is NaN.

    Raises:
        ValueError: If ``num_users`` is not positive.
    """
    from scipy import stats

    if num_users <= 0:
        # Previously surfaced as a bare ZeroDivisionError for num_users == 0.
        raise ValueError(
            "num_users must be a positive integer, got %r" % (num_users,))

    random_avg = 1.0 / num_users
    res = stats.ttest_1samp(accuracy, random_avg, nan_policy="omit")

    # The statistic is NaN when it cannot be computed, e.g. when all
    # accuracies are identical and equal to the baseline (0/0).  Report that
    # case as "no evidence": statistic 0, p-value 1.
    return res if not np.isnan(res[0]) else (0, 1)
|
|
|
|
|
+
|
|
|
|
|
if __name__ == '__main__':
    # Run as a script: pass everything after the program name to main().
    import sys
    command_line = sys.argv[1:]
    main(command_line)
|