randomforest.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. #!/home/tflucke/bin/bin/python3
  2. from sklearn.model_selection import KFold
  3. import numpy as np
  4. import typing
  5. try:
  6. import sample
  7. except ImportError:
  8. import os, sys
  9. sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + \
  10. '/../feature-extractor')
  11. import sample
  12. DEFAULT_FEATURES = ["average_iat", "high.avg_burst_size", "high.burst_count"]
  13. def main(options: list):
  14. args = parse_args(options)
  15. try:
  16. import cPickle as pickle
  17. except:
  18. import pickle
  19. samples = pickle.load(args.features_file)
  20. features = args.feature if args.feature else DEFAULT_FEATURES
  21. from Vector import FeatureVector
  22. data, labels = map(np.array,
  23. zip(*[(FeatureVector(p, features).get(), p.user)
  24. for p in samples]))
  25. num_users = len(np.unique([s.user for s in samples]))
  26. s = np.arange(data.shape[0])
  27. np.random.shuffle(s)
  28. res, matrix = random_forest(data[s], labels[s], fn=args.criterion,
  29. n=args.folds, verbose=args.verbose,
  30. estimators=args.estimators)
  31. print("Overall Accuracy: %f" % np.average(res))
  32. if args.p_value:
  33. print("P-Value: %f" % p)
  34. if args.graph:
  35. import seaborn as sns
  36. from pandas import DataFrame
  37. from matplotlib import pyplot as plt
  38. label_list = list(map(lambda l: l[0:6], sorted(np.unique(labels))))
  39. dataset = DataFrame(matrix, columns=label_list, index=label_list)
  40. sns.set(font_scale=0.8)
  41. graph = sns.heatmap(data=dataset, annot=True, cbar=False)
  42. graph.set_xticklabels(graph.get_xticklabels(), rotation=50,
  43. horizontalalignment="right")
  44. plt.subplots_adjust(left=0.15, bottom=0.2)
  45. plt.ylabel('True Label')
  46. plt.xlabel('Predicted Label')
  47. plt.title('K-Nearest Neighbor Confusion Matrix')
  48. graph.get_figure().savefig("random-forest.png")
  49. def parse_args(args: list):
  50. import argparse
  51. parser = argparse.ArgumentParser(
  52. description='Run a data set through a Random Forest classifier.')
  53. parser.add_argument('features_file', type=argparse.FileType('rb'),
  54. help='File of extracted features.')
  55. parser.add_argument('-v', '--verbose', action="count", default=0,
  56. help='Show more information')
  57. parser.add_argument('-n', '--folds', type=int, default=5,
  58. help='Number of cross-validation folds (default: 5)')
  59. parser.add_argument('-e', '--estimators', type=int, default=100,
  60. help='Number of random decision trees (default: 100)')
  61. parser.add_argument('-c', '--criterion', choices=["gini", "entropy"],
  62. default="gini", help='Function to evaluate tree split \
  63. value (default: \"gini\")')
  64. parser.add_argument('-f', '--feature', action='append', type=str,
  65. help='Add feature to list of features to test with.')
  66. parser.add_argument('-p', '--p-value', action='store_const', default=False,
  67. const=True, help='Calculate a p-value from a t-test.')
  68. parser.add_argument('-g', '--graph', action="store_true",
  69. help='Generates a confusion matrix.')
  70. return parser.parse_args(args)
  71. def classify(data, labels, num_users: int, args):
  72. s = np.arange(data.shape[0])
  73. np.random.shuffle(s)
  74. res, _ = random_forest(data[s], labels[s],
  75. n=args.folds, verbose=args.verbose, fn=args.criterion,
  76. estimators=args.estimators)
  77. return (np.average(res), t_test(res, num_users)[1] / 2)
  78. def random_forest(data: list, labels: list, n=5, verbose=0, estimators=100,
  79. fn="gini"):
  80. from sklearn.ensemble import RandomForestClassifier
  81. from sklearn.metrics import confusion_matrix
  82. folds = KFold(n_splits=n)
  83. i = 1
  84. avg = 0
  85. accuracies = []
  86. output = []
  87. truth = []
  88. label_list = sorted(np.unique(labels))
  89. for train_index, test_index in folds.split(data):
  90. if verbose >= 1:
  91. print("Round %d:" % i)
  92. i += 1
  93. if verbose >= 2:
  94. print("Training on: ", train_index)
  95. rfc = RandomForestClassifier(n_estimators=estimators, criterion=fn)
  96. rfc.fit(data[train_index], labels[train_index])
  97. predictions = rfc.predict(data[test_index])
  98. output.extend(predictions)
  99. truth.extend(labels[test_index])
  100. accuracy = [a == p
  101. for a, p in zip(labels[test_index], predictions)
  102. ].count(True)/len(predictions)
  103. if verbose >= 1:
  104. print(accuracy)
  105. accuracies.append(accuracy)
  106. return (accuracies, confusion_matrix(truth, output, labels=label_list))
  107. def t_test(accuracy: list, num_users: int):
  108. from scipy import stats
  109. random_avg = 1.0/num_users
  110. res = stats.ttest_1samp(accuracy, random_avg, nan_policy="omit")
  111. # If all numbers are identical, p-value = 1
  112. return res if not np.isnan(res[0]) else (0, 1)
  113. if __name__ == '__main__':
  114. import sys
  115. main(sys.argv[1:])