hwRecProto2.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
########################################################################################
# Author: Thomas Flucke
# Date: 2017-05-13
# Abbreviations:
# vect = Vector
# ANN = Artificial Neural Network
# corr = Correct version
########################################################################################
  9. # Declare the data format
  10. import random
  11. class Datum:
  12. def __init__(self, label, img):
  13. self.label = [0] * (10 + 26 + 26)
  14. self.label[label - 1] = 1
  15. self.img = img
  16. class IAM:
  17. def __init__(self):
  18. print "Building dataset..."
  19. self.train = []
  20. self.test = []
  21. for x in range(1, (10 + 26 + 26) + 1):
  22. print "Preparing sample %d..." % x
  23. for f in os.listdir(DATA_FOLDER % x):
  24. img = scipy.ndimage.imread(IMG_TEMPLATE % (x, f), True)
  25. img = scipy.misc.imresize(img, 0.03)
  26. img = list(itertools.chain.from_iterable(img))
  27. if len(self.test) < (5 * x):
  28. self.test.append(Datum(x, img))
  29. else:
  30. self.train.append(Datum(x, img))
  31. def nextBatch(self, size):
  32. used = []
  33. res = []
  34. labels = []
  35. while len(used) < size:
  36. i = random.randint(0, len(self.train) - 1)
  37. if i in used:
  38. continue
  39. else:
  40. used.append(i)
  41. res.append(self.train[i].img)
  42. labels.append(self.train[i].label)
  43. return res, labels
  44. ########################################################################################
  45. # Load IAM dataset
  46. import cPickle as pickle
  47. with open('iamDataset.obj', 'rb') as input:
  48. iam = pickle.load(input)
  49. ########################################################################################
  50. # Import tensorflow library and define AAN
  51. import tensorflow as tf
  52. LEARNING_CONST = 0.5
  53. # Create a tensorflow placeholder for an array [nx784] float32's (a.k.a. n MNIST vectors)
  54. inVect = tf.placeholder(tf.float32, [None, 972])
  55. # Initalize the ANN with zero's
  56. # Define tensorflow variable for the weight matrix [784x10] so we can matrix multiply
  57. weights = tf.Variable(tf.zeros([972, (10 + 26 + 26)]))
  58. # Define tensorflow variable for the bias vector
  59. biases = tf.Variable(tf.zeros([(10 + 26 + 26)]))
  60. # Define formula for calculating output [nx10]
  61. outVect = tf.nn.softmax(tf.matmul(inVect, weights) + biases)
  62. # Create a tensorflow placeholder for the correct answer vector
  63. outVectCorr = tf.placeholder(tf.float32, [None, (10 + 26 + 26)])
  64. # Calculate how incorrect the solutions arrive were
  65. crossEntropy = tf.reduce_mean(
  66. -tf.reduce_sum(
  67. outVectCorr * tf.log(outVect),
  68. # Tells reduce_sum to use the 10-length array, and not the n-length
  69. reduction_indices=[1]
  70. )
  71. )
  72. trainStep = tf.train.GradientDescentOptimizer(LEARNING_CONST).minimize(crossEntropy)
  73. ########################################################################################
  74. # Define accuracy checking conditions
  75. # Define formula for determining correctness
  76. # Highest value in outVect 1st index == highest value in correct outVect 1st index
  77. predictionCorr = tf.equal(tf.argmax(outVect, 1), tf.argmax(outVectCorr, 1))
  78. # Calculate how accurate the system was
  79. accuracy = tf.reduce_mean(tf.cast(predictionCorr, tf.float32))
  80. ########################################################################################
  81. # Run the system
  82. # Create interactive session
  83. sess = tf.InteractiveSession()
  84. # Initialize variables
  85. tf.global_variables_initializer().run()
  86. for _ in range(1000) :
  87. # Get 100 random digits from training set
  88. batchIns, batchOuts = iam.nextBatch(100)
  89. # Run the training step in the interactive session with the given inputs/outputs
  90. sess.run(trainStep, feed_dict={inVect: batchIns, outVectCorr: batchOuts})
  91. # Check accuracy
  92. print(sess.run(accuracy, feed_dict={inVect: (o.img for o in iam.test), outVectCorr: (o.label for o in iam.test)}))