# randomforest.py
#!/home/tflucke/bin/bin/python3
from sklearn.model_selection import KFold
import numpy as np
import typing

# `sample` lives in the sibling feature-extractor directory; when the script
# is run from elsewhere, extend sys.path so the module can still be found.
try:
    import sample
except ImportError:
    import os, sys
    sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + \
                    '/../feature-extractor')
    import sample

# Feature names used when no -f/--feature flags are supplied on the CLI.
DEFAULT_FEATURES = ["average_iat", "high.avg_burst_size", "high.burst_count"]
  13. def main(options: list):
  14. args = parse_args(options)
  15. try:
  16. import cPickle as pickle
  17. except:
  18. import pickle
  19. samples = pickle.load(args.features_file)
  20. features = args.feature if args.feature else DEFAULT_FEATURES
  21. from Vector import FeatureVector
  22. data, labels = map(np.array,
  23. zip(*[(FeatureVector(p, features).get(), p.user)
  24. for p in samples]))
  25. num_users = len(np.unique([s.user for s in samples]))
  26. s = np.arange(data.shape[0])
  27. np.random.shuffle(s)
  28. res, matrix = random_forest(data[s], labels[s], fn=args.criterion,
  29. n=args.folds, verbose=args.verbose,
  30. estimators=args.estimators)
  31. print("Overall Accuracy: %f" % np.average(res))
  32. if args.p_value:
  33. print("P-Value: %f" % p)
  34. if args.graph:
  35. import seaborn as sns
  36. from pandas import DataFrame
  37. from matplotlib import pyplot as plt
  38. plt.figure()
  39. label_list = list(map(lambda l: l[0:6], sorted(np.unique(labels))))
  40. dataset = DataFrame(matrix, columns=label_list, index=label_list)
  41. sns.set(font_scale=0.8)
  42. graph = sns.heatmap(data=dataset, annot=True, cbar=False)
  43. graph.set_xticklabels(graph.get_xticklabels(), rotation=50,
  44. horizontalalignment="right")
  45. plt.subplots_adjust(left=0.15, bottom=0.2)
  46. plt.ylabel('True Label')
  47. plt.xlabel('Predicted Label')
  48. plt.title('K-Nearest Neighbor Confusion Matrix')
  49. graph.get_figure().savefig("random-forest.png")
  50. def parse_args(args: list):
  51. import argparse
  52. parser = argparse.ArgumentParser(
  53. description='Run a data set through a Random Forest classifier.')
  54. parser.add_argument('features_file', type=argparse.FileType('rb'),
  55. help='File of extracted features.')
  56. parser.add_argument('-v', '--verbose', action="count", default=0,
  57. help='Show more information')
  58. parser.add_argument('-n', '--folds', type=int, default=5,
  59. help='Number of cross-validation folds (default: 5)')
  60. parser.add_argument('-e', '--estimators', type=int, default=100,
  61. help='Number of random decision trees (default: 100)')
  62. parser.add_argument('-c', '--criterion', choices=["gini", "entropy"],
  63. default="gini", help='Function to evaluate tree split \
  64. value (default: \"gini\")')
  65. parser.add_argument('-f', '--feature', action='append', type=str,
  66. help='Add feature to list of features to test with.')
  67. parser.add_argument('-p', '--p-value', action='store_const', default=False,
  68. const=True, help='Calculate a p-value from a t-test.')
  69. parser.add_argument('-g', '--graph', action="store_true",
  70. help='Generates a confusion matrix.')
  71. return parser.parse_args(args)
  72. def classify(data, labels, num_users: int, args):
  73. s = np.arange(data.shape[0])
  74. np.random.shuffle(s)
  75. res, _ = random_forest(data[s], labels[s],
  76. n=args.folds, verbose=args.verbose, fn=args.criterion,
  77. estimators=args.estimators)
  78. return (np.average(res), t_test(res, num_users)[1] / 2)
  79. def random_forest(data: list, labels: list, n=5, verbose=0, estimators=100,
  80. fn="gini"):
  81. from sklearn.ensemble import RandomForestClassifier
  82. from sklearn.metrics import confusion_matrix
  83. folds = KFold(n_splits=n)
  84. i = 1
  85. avg = 0
  86. accuracies = []
  87. output = []
  88. truth = []
  89. label_list = sorted(np.unique(labels))
  90. for train_index, test_index in folds.split(data):
  91. if verbose >= 1:
  92. print("Round %d:" % i)
  93. i += 1
  94. if verbose >= 2:
  95. print("Training on: ", train_index)
  96. rfc = RandomForestClassifier(n_estimators=estimators, criterion=fn)
  97. rfc.fit(data[train_index], labels[train_index])
  98. predictions = rfc.predict(data[test_index])
  99. output.extend(predictions)
  100. truth.extend(labels[test_index])
  101. accuracy = [a == p
  102. for a, p in zip(labels[test_index], predictions)
  103. ].count(True)/len(predictions)
  104. if verbose >= 1:
  105. print(accuracy)
  106. accuracies.append(accuracy)
  107. return (accuracies, confusion_matrix(truth, output, labels=label_list))
  108. def t_test(accuracy: list, num_users: int):
  109. from scipy import stats
  110. random_avg = 1.0/num_users
  111. res = stats.ttest_1samp(accuracy, random_avg, nan_policy="omit")
  112. # If all numbers are identical, p-value = 1
  113. return res if not np.isnan(res[0]) else (0, 1)
if __name__ == '__main__':
    # Script entry: forward the command line (minus program name) to main().
    import sys
    main(sys.argv[1:])