Преглед на файлове

Added nearest neighbors graphs

Also fixed a bug that emerged from generating multiple graphs at once.
Tom Flucke преди 6 години
родител
ревизия
43d87ff7b9
променени са 2 файла, в които са добавени 47 реда и е изтрит 1 ред
  1. 46 1
      src/classifiers/nearestneighbors.py
  2. 1 0
      src/classifiers/randomforest.py

+ 46 - 1
src/classifiers/nearestneighbors.py

@@ -30,6 +30,10 @@ def main(options: list):
     np.random.shuffle(s)
     if args.graph_top:
         graph_top(args, data[s], labels[s])
+    if args.graph_k:
+        graph_k(args, data[s], labels[s])
+    if args.graph_weights:
+        graph_w(args, data[s], labels[s])
     res, matrix = kNearestNeighbors(data[s], labels[s], n=args.folds,
                                     verbose=args.verbose, guesses=args.top,
                                     k=args.k_neighbors, weights=args.weight)
@@ -55,17 +59,54 @@ def graph_top(args, data, labels):
     import seaborn as sns
     from pandas import DataFrame
     from matplotlib import pyplot as plt
+    plt.figure()
     dataset = DataFrame(res, columns=["Top-N Guesses", "Accuracy (%)"])
-    graph = sns.lineplot("Top-N Guesses", "Accuracy", data=dataset)
+    graph = sns.lineplot("Top-N Guesses", "Accuracy (%)", data=dataset)
     graph.set_xticks(np.arange(1, len(label_list), 2))
     graph.set_yticks(np.arange(0, 1, 0.1))
     plt.title('K-Nearest Neighbor Accuracy on Nth Guess')
     graph.get_figure().savefig("nearest-neighbor-top-n.png")
 
+def graph_k(args, data, labels):  # Line-plot the averaged kNearestNeighbors result for each k tried, then save it as a PNG.
+    res = []  # collects one (k, averaged result) pair per run
+    MAX_K = 25  # exclusive upper bound: the loop below tries k = 1..24
+    for k in range(1, MAX_K):
+        res.append((k, np.average(  # average of element [0] of the kNearestNeighbors return value
+            kNearestNeighbors(data, labels, n=args.folds, guesses=args.top,
+                              verbose=args.verbose, k=k,  # only k varies between runs
+                              weights=args.weight)[0])
+        ))
+    import seaborn as sns  # plotting libraries are imported locally, inside the function
+    from pandas import DataFrame
+    from matplotlib import pyplot as plt
+    plt.figure()  # open a new figure before plotting; presumably keeps this graph separate from others drawn in the same run
+    dataset = DataFrame(res, columns=["Nearest K Neighbors", "Accuracy (%)"])  # column names double as the axis labels
+    graph = sns.lineplot("Nearest K Neighbors", "Accuracy (%)", data=dataset)
+    graph.set_xticks(np.arange(1, MAX_K, 2))  # x ticks at every second k: 1, 3, ..., 23
+    plt.title('K-Nearest Neighbor Accuracy')
+    graph.get_figure().savefig("nearest-neighbor-k.png")  # relative path: the file is written to the current working directory
+
+def graph_w(args, data, labels):  # Bar-plot the averaged kNearestNeighbors result for each weight setting, then save it as a PNG.
+    res = []  # collects one (weight name, averaged result) pair per run
+    for w in ["uniform", "distance"]:  # the two weight settings being compared
+        res.append((w, np.average(  # average of element [0] of the kNearestNeighbors return value
+            kNearestNeighbors(data, labels, n=args.folds, guesses=args.top,
+                              verbose=args.verbose, k=args.k_neighbors, weights=w)[0])  # only weights varies between runs
+        ))
+    import seaborn as sns  # plotting libraries are imported locally, inside the function
+    from pandas import DataFrame
+    from matplotlib import pyplot as plt
+    plt.figure()  # open a new figure before plotting; presumably keeps this graph separate from others drawn in the same run
+    dataset = DataFrame(res, columns=["Weight Formula", "Accuracy (%)"])  # column names double as the axis labels
+    graph = sns.barplot("Weight Formula", "Accuracy (%)", data=dataset)
+    plt.title('K-Nearest Neighbor Accuracy')
+    graph.get_figure().savefig("nearest-neighbor-w.png")  # relative path: the file is written to the current working directory
+
 def gen_confusion_matrix(matrix, labels):
     import seaborn as sns
     from pandas import DataFrame
     from matplotlib import pyplot as plt
+    plt.figure()
     label_list = list(map(lambda l: l[0:6], sorted(np.unique(labels))))
     dataset = DataFrame(matrix, columns=label_list, index=label_list)
     sns.set(font_scale=0.8)
@@ -104,6 +145,10 @@ def parse_args(args: list):
                         help='Generates a confusion matrix.')
     parser.add_argument('--graph-top', action="store_true",
                         help='Generates a graph of accuracy in top N guesses.')
+    parser.add_argument('--graph-k', action="store_true",
+                        help='Generates a graph of accuracy for k-nearest neighbors.')
+    parser.add_argument('--graph-weights', action="store_true",
+                        help='Generates a graph comparing weights.')
     return parser.parse_args(args)
 
 def classify(data, labels, num_users: int, args):

+ 1 - 0
src/classifiers/randomforest.py

@@ -38,6 +38,7 @@ def main(options: list):
         import seaborn as sns
         from pandas import DataFrame
         from matplotlib import pyplot as plt
+        plt.figure()
         label_list = list(map(lambda l: l[0:6], sorted(np.unique(labels))))
         dataset = DataFrame(matrix, columns=label_list, index=label_list)
         sns.set(font_scale=0.8)