# hwRecProto2.py
  1. ########################################################################################
  2. # Author: Thomas Flucke
  3. # Date: 2017-05-13
# Abbreviations:
# vect = Vector
# ANN = Artificial Neural Network
# corr = Correct version
  8. ########################################################################################
  9. # Declare the data format
import itertools
import os
import random

import scipy.misc
import scipy.ndimage
# Flattened pixel count of one resized input image (produced by the 3%
# imresize in IAM.__init__) — assumes all source images share one size; TODO confirm.
N_INPUT = 972
# One output class per character: 10 digits + 26 uppercase + 26 lowercase letters.
N_OUTPUT = 10 + 26 + 26
  13. class Datum:
  14. def __init__(self, label, img):
  15. self.label = [0] * N_OUTPUT
  16. self.label[label - 1] = 1
  17. self.img = img
  18. class IAM:
  19. def __init__(self):
  20. print "Building dataset..."
  21. self.train = []
  22. self.test = []
  23. for x in range(1, N_OUTPUT + 1):
  24. print "Preparing sample %d..." % x
  25. for f in os.listdir(DATA_FOLDER % x):
  26. img = scipy.ndimage.imread(IMG_TEMPLATE % (x, f), True)
  27. img = scipy.misc.imresize(img, 0.03)
  28. img = list(itertools.chain.from_iterable(img))
  29. if len(self.test) < (5 * x):
  30. self.test.append(Datum(x, img))
  31. else:
  32. self.train.append(Datum(x, img))
  33. def nextBatch(self, size):
  34. used = []
  35. res = []
  36. labels = []
  37. while len(used) < size:
  38. i = random.randint(0, len(self.train) - 1)
  39. if i in used:
  40. continue
  41. else:
  42. used.append(i)
  43. res.append(self.train[i].img)
  44. labels.append(self.train[i].label)
  45. return res, labels
  46. def testSet(self):
  47. res = []
  48. labels = []
  49. for i in range(0, len(self.test)):
  50. res.append(self.test[i].img)
  51. labels.append(self.test[i].label)
  52. return res, labels
  53. ########################################################################################
  54. # Load IAM dataset
  55. import cPickle as pickle
  56. with open('iamDataset.obj', 'rb') as input:
  57. iam = pickle.load(input)
  58. ########################################################################################
  59. # Import tensorflow library and define AAN
  60. import tensorflow as tf
  61. LEARNING_CONST = 0.5
  62. # Create a tensorflow placeholder for an array [nx784] float32's (a.k.a. n MNIST vectors)
  63. inVect = tf.placeholder(tf.float32, [None, N_INPUT])
  64. # Initalize the ANN with zero's
  65. # Define tensorflow variable for the weight matrix [784x10] so we can matrix multiply
  66. weights = tf.Variable(tf.zeros([N_INPUT, N_OUTPUT]))
  67. # Define tensorflow variable for the bias vector
  68. biases = tf.Variable(tf.zeros([N_OUTPUT]))
  69. # Define formula for calculating output [nx10]
  70. outVect = tf.nn.softmax(tf.matmul(inVect, weights) + biases)
  71. # Create a tensorflow placeholder for the correct answer vector
  72. outVectCorr = tf.placeholder(tf.float32, [None, N_OUTPUT])
  73. # Calculate how incorrect the solutions arrive were
  74. crossEntropy = tf.reduce_mean(
  75. -tf.reduce_sum(
  76. outVectCorr * tf.log(outVect),
  77. # Tells reduce_sum to use the 10-length array, and not the n-length
  78. reduction_indices=[1]
  79. )
  80. )
  81. trainStep = tf.train.GradientDescentOptimizer(LEARNING_CONST).minimize(crossEntropy)
  82. ########################################################################################
  83. # Define accuracy checking conditions
  84. # Define formula for determining correctness
  85. # Highest value in outVect 1st index == highest value in correct outVect 1st index
  86. predictionCorr = tf.equal(tf.argmax(outVect, 1), tf.argmax(outVectCorr, 1))
  87. # Calculate how accurate the system was
  88. accuracy = tf.reduce_mean(tf.cast(predictionCorr, tf.float32))
  89. ########################################################################################
  90. # Run the system
  91. # Create interactive session
  92. sess = tf.InteractiveSession()
  93. # Initialize variables
  94. tf.global_variables_initializer().run()
  95. for _ in range(1000) :
  96. # Get 100 random digits from training set
  97. batchIns, batchOuts = iam.nextBatch(100)
  98. # Run the training step in the interactive session with the given inputs/outputs
  99. sess.run(trainStep, feed_dict={inVect: batchIns, outVectCorr: batchOuts})
  100. # Check accuracy
  101. testIn, testOuts = iam.testSet()
  102. print(sess.run(accuracy, feed_dict={inVect: testIn, outVectCorr: testOuts}))