Optical character recognition

This example uses a convolutional stack followed by a recurrent stack and a CTC logloss function to perform optical character recognition of generated text images. I have no evidence of whether it actually learns general shapes of text, or just is able to recognize all the different fonts thrown at it… the purpose is more to demonstrate CTC inside of Keras. Note that the font list may need to be updated for the particular OS in use.

This starts off with 4-letter words. For the first 12 epochs, the difficulty is gradually increased using the TextImageGenerator class, which is both a generator class for test/train data and a Keras callback class. After 20 epochs, longer sequences are thrown at it by recompiling the model to handle a wider image and rebuilding the word list to include two words separated by a space.

The table below shows normalized edit distance values. Theano uses a slightly different CTC implementation, hence the different results.

| Epoch | TF    | TH    |
|-------|-------|-------|
| 10    | 0.027 | 0.064 |
| 15    | 0.038 | 0.035 |
| 20    | 0.043 | 0.045 |
| 25    | 0.014 | 0.019 |

This requires the cairo (cairocffi) and editdistance packages:

  1. pip install cairocffi
  2. pip install editdistance

Created by Mike Henry https://github.com/mbhenry/

  1. import os
  2. import itertools
  3. import codecs
  4. import re
  5. import datetime
  6. import cairocffi as cairo
  7. import editdistance
  8. import numpy as np
  9. from scipy import ndimage
  10. import pylab
  11. from keras import backend as K
  12. from keras.layers.convolutional import Conv2D, MaxPooling2D
  13. from keras.layers import Input, Dense, Activation
  14. from keras.layers import Reshape, Lambda
  15. from keras.layers.merge import add, concatenate
  16. from keras.models import Model
  17. from keras.layers.recurrent import GRU
  18. from keras.optimizers import SGD
  19. from keras.utils.data_utils import get_file
  20. from keras.preprocessing import image
  21. import keras.callbacks
  22. OUTPUT_DIR = 'image_ocr'
  23. # character classes and matching regex filter
  24. regex = r'^[a-z ]+$'
  25. alphabet = u'abcdefghijklmnopqrstuvwxyz '
  26. np.random.seed(55)
  27. # this creates larger "blotches" of noise which look
  28. # more realistic than just adding gaussian noise
  29. # assumes greyscale with pixels ranging from 0 to 1
  30. def speckle(img):
  31. severity = np.random.uniform(0, 0.6)
  32. blur = ndimage.gaussian_filter(np.random.randn(*img.shape) * severity, 1)
  33. img_speck = (img + blur)
  34. img_speck[img_speck > 1] = 1
  35. img_speck[img_speck <= 0] = 0
  36. return img_speck
  37. # paints the string in a random location the bounding box
  38. # also uses a random font, a slight random rotation,
  39. # and a random amount of speckle noise
  40. def paint_text(text, w, h, rotate=False, ud=False, multi_fonts=False):
  41. surface = cairo.ImageSurface(cairo.FORMAT_RGB24, w, h)
  42. with cairo.Context(surface) as context:
  43. context.set_source_rgb(1, 1, 1) # White
  44. context.paint()
  45. # this font list works in CentOS 7
  46. if multi_fonts:
  47. fonts = [
  48. 'Century Schoolbook', 'Courier', 'STIX',
  49. 'URW Chancery L', 'FreeMono']
  50. context.select_font_face(
  51. np.random.choice(fonts),
  52. cairo.FONT_SLANT_NORMAL,
  53. np.random.choice([cairo.FONT_WEIGHT_BOLD, cairo.FONT_WEIGHT_NORMAL]))
  54. else:
  55. context.select_font_face('Courier',
  56. cairo.FONT_SLANT_NORMAL,
  57. cairo.FONT_WEIGHT_BOLD)
  58. context.set_font_size(25)
  59. box = context.text_extents(text)
  60. border_w_h = (4, 4)
  61. if box[2] > (w - 2 * border_w_h[1]) or box[3] > (h - 2 * border_w_h[0]):
  62. raise IOError(('Could not fit string into image.'
  63. 'Max char count is too large for given image width.'))
  64. # teach the RNN translational invariance by
  65. # fitting text box randomly on canvas, with some room to rotate
  66. max_shift_x = w - box[2] - border_w_h[0]
  67. max_shift_y = h - box[3] - border_w_h[1]
  68. top_left_x = np.random.randint(0, int(max_shift_x))
  69. if ud:
  70. top_left_y = np.random.randint(0, int(max_shift_y))
  71. else:
  72. top_left_y = h // 2
  73. context.move_to(top_left_x - int(box[0]), top_left_y - int(box[1]))
  74. context.set_source_rgb(0, 0, 0)
  75. context.show_text(text)
  76. buf = surface.get_data()
  77. a = np.frombuffer(buf, np.uint8)
  78. a.shape = (h, w, 4)
  79. a = a[:, :, 0] # grab single channel
  80. a = a.astype(np.float32) / 255
  81. a = np.expand_dims(a, 0)
  82. if rotate:
  83. a = image.random_rotation(a, 3 * (w - top_left_x) / w + 1)
  84. a = speckle(a)
  85. return a
  86. def shuffle_mats_or_lists(matrix_list, stop_ind=None):
  87. ret = []
  88. assert all([len(i) == len(matrix_list[0]) for i in matrix_list])
  89. len_val = len(matrix_list[0])
  90. if stop_ind is None:
  91. stop_ind = len_val
  92. assert stop_ind <= len_val
  93. a = list(range(stop_ind))
  94. np.random.shuffle(a)
  95. a += list(range(stop_ind, len_val))
  96. for mat in matrix_list:
  97. if isinstance(mat, np.ndarray):
  98. ret.append(mat[a])
  99. elif isinstance(mat, list):
  100. ret.append([mat[i] for i in a])
  101. else:
  102. raise TypeError('`shuffle_mats_or_lists` only supports '
  103. 'numpy.array and list objects.')
  104. return ret
  105. # Translation of characters to unique integer values
  106. def text_to_labels(text):
  107. ret = []
  108. for char in text:
  109. ret.append(alphabet.find(char))
  110. return ret
  111. # Reverse translation of numerical classes back to characters
  112. def labels_to_text(labels):
  113. ret = []
  114. for c in labels:
  115. if c == len(alphabet): # CTC Blank
  116. ret.append("")
  117. else:
  118. ret.append(alphabet[c])
  119. return "".join(ret)
  120. # only a-z and space..probably not to difficult
  121. # to expand to uppercase and symbols
  122. def is_valid_str(in_str):
  123. search = re.compile(regex, re.UNICODE).search
  124. return bool(search(in_str))
  125. # Uses generator functions to supply train/test with
  126. # data. Image renderings and text are created on the fly
  127. # each time with random perturbations
  128. class TextImageGenerator(keras.callbacks.Callback):
  129. def __init__(self, monogram_file, bigram_file, minibatch_size,
  130. img_w, img_h, downsample_factor, val_split,
  131. absolute_max_string_len=16):
  132. self.minibatch_size = minibatch_size
  133. self.img_w = img_w
  134. self.img_h = img_h
  135. self.monogram_file = monogram_file
  136. self.bigram_file = bigram_file
  137. self.downsample_factor = downsample_factor
  138. self.val_split = val_split
  139. self.blank_label = self.get_output_size() - 1
  140. self.absolute_max_string_len = absolute_max_string_len
  141. def get_output_size(self):
  142. return len(alphabet) + 1
  143. # num_words can be independent of the epoch size due to the use of generators
  144. # as max_string_len grows, num_words can grow
  145. def build_word_list(self, num_words, max_string_len=None, mono_fraction=0.5):
  146. assert max_string_len <= self.absolute_max_string_len
  147. assert num_words % self.minibatch_size == 0
  148. assert (self.val_split * num_words) % self.minibatch_size == 0
  149. self.num_words = num_words
  150. self.string_list = [''] * self.num_words
  151. tmp_string_list = []
  152. self.max_string_len = max_string_len
  153. self.Y_data = np.ones([self.num_words, self.absolute_max_string_len]) * -1
  154. self.X_text = []
  155. self.Y_len = [0] * self.num_words
  156. def _is_length_of_word_valid(word):
  157. return (max_string_len == -1 or
  158. max_string_len is None or
  159. len(word) <= max_string_len)
  160. # monogram file is sorted by frequency in english speech
  161. with codecs.open(self.monogram_file, mode='r', encoding='utf-8') as f:
  162. for line in f:
  163. if len(tmp_string_list) == int(self.num_words * mono_fraction):
  164. break
  165. word = line.rstrip()
  166. if _is_length_of_word_valid(word):
  167. tmp_string_list.append(word)
  168. # bigram file contains common word pairings in english speech
  169. with codecs.open(self.bigram_file, mode='r', encoding='utf-8') as f:
  170. lines = f.readlines()
  171. for line in lines:
  172. if len(tmp_string_list) == self.num_words:
  173. break
  174. columns = line.lower().split()
  175. word = columns[0] + ' ' + columns[1]
  176. if is_valid_str(word) and _is_length_of_word_valid(word):
  177. tmp_string_list.append(word)
  178. if len(tmp_string_list) != self.num_words:
  179. raise IOError('Could not pull enough words'
  180. 'from supplied monogram and bigram files.')
  181. # interlace to mix up the easy and hard words
  182. self.string_list[::2] = tmp_string_list[:self.num_words // 2]
  183. self.string_list[1::2] = tmp_string_list[self.num_words // 2:]
  184. for i, word in enumerate(self.string_list):
  185. self.Y_len[i] = len(word)
  186. self.Y_data[i, 0:len(word)] = text_to_labels(word)
  187. self.X_text.append(word)
  188. self.Y_len = np.expand_dims(np.array(self.Y_len), 1)
  189. self.cur_val_index = self.val_split
  190. self.cur_train_index = 0
  191. # each time an image is requested from train/val/test, a new random
  192. # painting of the text is performed
  193. def get_batch(self, index, size, train):
  194. # width and height are backwards from typical Keras convention
  195. # because width is the time dimension when it gets fed into the RNN
  196. if K.image_data_format() == 'channels_first':
  197. X_data = np.ones([size, 1, self.img_w, self.img_h])
  198. else:
  199. X_data = np.ones([size, self.img_w, self.img_h, 1])
  200. labels = np.ones([size, self.absolute_max_string_len])
  201. input_length = np.zeros([size, 1])
  202. label_length = np.zeros([size, 1])
  203. source_str = []
  204. for i in range(size):
  205. # Mix in some blank inputs. This seems to be important for
  206. # achieving translational invariance
  207. if train and i > size - 4:
  208. if K.image_data_format() == 'channels_first':
  209. X_data[i, 0, 0:self.img_w, :] = self.paint_func('')[0, :, :].T
  210. else:
  211. X_data[i, 0:self.img_w, :, 0] = self.paint_func('',)[0, :, :].T
  212. labels[i, 0] = self.blank_label
  213. input_length[i] = self.img_w // self.downsample_factor - 2
  214. label_length[i] = 1
  215. source_str.append('')
  216. else:
  217. if K.image_data_format() == 'channels_first':
  218. X_data[i, 0, 0:self.img_w, :] = (
  219. self.paint_func(self.X_text[index + i])[0, :, :].T)
  220. else:
  221. X_data[i, 0:self.img_w, :, 0] = (
  222. self.paint_func(self.X_text[index + i])[0, :, :].T)
  223. labels[i, :] = self.Y_data[index + i]
  224. input_length[i] = self.img_w // self.downsample_factor - 2
  225. label_length[i] = self.Y_len[index + i]
  226. source_str.append(self.X_text[index + i])
  227. inputs = {'the_input': X_data,
  228. 'the_labels': labels,
  229. 'input_length': input_length,
  230. 'label_length': label_length,
  231. 'source_str': source_str # used for visualization only
  232. }
  233. outputs = {'ctc': np.zeros([size])} # dummy data for dummy loss function
  234. return (inputs, outputs)
  235. def next_train(self):
  236. while 1:
  237. ret = self.get_batch(self.cur_train_index,
  238. self.minibatch_size, train=True)
  239. self.cur_train_index += self.minibatch_size
  240. if self.cur_train_index >= self.val_split:
  241. self.cur_train_index = self.cur_train_index % 32
  242. (self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
  243. [self.X_text, self.Y_data, self.Y_len], self.val_split)
  244. yield ret
  245. def next_val(self):
  246. while 1:
  247. ret = self.get_batch(self.cur_val_index,
  248. self.minibatch_size, train=False)
  249. self.cur_val_index += self.minibatch_size
  250. if self.cur_val_index >= self.num_words:
  251. self.cur_val_index = self.val_split + self.cur_val_index % 32
  252. yield ret
  253. def on_train_begin(self, logs={}):
  254. self.build_word_list(16000, 4, 1)
  255. self.paint_func = lambda text: paint_text(
  256. text, self.img_w, self.img_h,
  257. rotate=False, ud=False, multi_fonts=False)
  258. def on_epoch_begin(self, epoch, logs={}):
  259. # rebind the paint function to implement curriculum learning
  260. if 3 <= epoch < 6:
  261. self.paint_func = lambda text: paint_text(
  262. text, self.img_w, self.img_h,
  263. rotate=False, ud=True, multi_fonts=False)
  264. elif 6 <= epoch < 9:
  265. self.paint_func = lambda text: paint_text(
  266. text, self.img_w, self.img_h,
  267. rotate=False, ud=True, multi_fonts=True)
  268. elif epoch >= 9:
  269. self.paint_func = lambda text: paint_text(
  270. text, self.img_w, self.img_h,
  271. rotate=True, ud=True, multi_fonts=True)
  272. if epoch >= 21 and self.max_string_len < 12:
  273. self.build_word_list(32000, 12, 0.5)
  274. # the actual loss calc occurs here despite it not being
  275. # an internal Keras loss function
  276. def ctc_lambda_func(args):
  277. y_pred, labels, input_length, label_length = args
  278. # the 2 is critical here since the first couple outputs of the RNN
  279. # tend to be garbage:
  280. y_pred = y_pred[:, 2:, :]
  281. return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
  282. # For a real OCR application, this should be beam search with a dictionary
  283. # and language model. For this example, best path is sufficient.
  284. def decode_batch(test_func, word_batch):
  285. out = test_func([word_batch])[0]
  286. ret = []
  287. for j in range(out.shape[0]):
  288. out_best = list(np.argmax(out[j, 2:], 1))
  289. out_best = [k for k, g in itertools.groupby(out_best)]
  290. outstr = labels_to_text(out_best)
  291. ret.append(outstr)
  292. return ret
  293. class VizCallback(keras.callbacks.Callback):
  294. def __init__(self, run_name, test_func, text_img_gen, num_display_words=6):
  295. self.test_func = test_func
  296. self.output_dir = os.path.join(
  297. OUTPUT_DIR, run_name)
  298. self.text_img_gen = text_img_gen
  299. self.num_display_words = num_display_words
  300. if not os.path.exists(self.output_dir):
  301. os.makedirs(self.output_dir)
  302. def show_edit_distance(self, num):
  303. num_left = num
  304. mean_norm_ed = 0.0
  305. mean_ed = 0.0
  306. while num_left > 0:
  307. word_batch = next(self.text_img_gen)[0]
  308. num_proc = min(word_batch['the_input'].shape[0], num_left)
  309. decoded_res = decode_batch(self.test_func,
  310. word_batch['the_input'][0:num_proc])
  311. for j in range(num_proc):
  312. edit_dist = editdistance.eval(decoded_res[j],
  313. word_batch['source_str'][j])
  314. mean_ed += float(edit_dist)
  315. mean_norm_ed += float(edit_dist) / len(word_batch['source_str'][j])
  316. num_left -= num_proc
  317. mean_norm_ed = mean_norm_ed / num
  318. mean_ed = mean_ed / num
  319. print('\nOut of %d samples: Mean edit distance:'
  320. '%.3f Mean normalized edit distance: %0.3f'
  321. % (num, mean_ed, mean_norm_ed))
  322. def on_epoch_end(self, epoch, logs={}):
  323. self.model.save_weights(
  324. os.path.join(self.output_dir, 'weights%02d.h5' % (epoch)))
  325. self.show_edit_distance(256)
  326. word_batch = next(self.text_img_gen)[0]
  327. res = decode_batch(self.test_func,
  328. word_batch['the_input'][0:self.num_display_words])
  329. if word_batch['the_input'][0].shape[0] < 256:
  330. cols = 2
  331. else:
  332. cols = 1
  333. for i in range(self.num_display_words):
  334. pylab.subplot(self.num_display_words // cols, cols, i + 1)
  335. if K.image_data_format() == 'channels_first':
  336. the_input = word_batch['the_input'][i, 0, :, :]
  337. else:
  338. the_input = word_batch['the_input'][i, :, :, 0]
  339. pylab.imshow(the_input.T, cmap='Greys_r')
  340. pylab.xlabel(
  341. 'Truth = \'%s\'\nDecoded = \'%s\'' %
  342. (word_batch['source_str'][i], res[i]))
  343. fig = pylab.gcf()
  344. fig.set_size_inches(10, 13)
  345. pylab.savefig(os.path.join(self.output_dir, 'e%02d.png' % (epoch)))
  346. pylab.close()
  347. def train(run_name, start_epoch, stop_epoch, img_w):
  348. # Input Parameters
  349. img_h = 64
  350. words_per_epoch = 16000
  351. val_split = 0.2
  352. val_words = int(words_per_epoch * (val_split))
  353. # Network parameters
  354. conv_filters = 16
  355. kernel_size = (3, 3)
  356. pool_size = 2
  357. time_dense_size = 32
  358. rnn_size = 512
  359. minibatch_size = 32
  360. if K.image_data_format() == 'channels_first':
  361. input_shape = (1, img_w, img_h)
  362. else:
  363. input_shape = (img_w, img_h, 1)
  364. fdir = os.path.dirname(
  365. get_file('wordlists.tgz',
  366. origin='http://www.mythic-ai.com/datasets/wordlists.tgz',
  367. untar=True))
  368. img_gen = TextImageGenerator(
  369. monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
  370. bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),
  371. minibatch_size=minibatch_size,
  372. img_w=img_w,
  373. img_h=img_h,
  374. downsample_factor=(pool_size ** 2),
  375. val_split=words_per_epoch - val_words)
  376. act = 'relu'
  377. input_data = Input(name='the_input', shape=input_shape, dtype='float32')
  378. inner = Conv2D(conv_filters, kernel_size, padding='same',
  379. activation=act, kernel_initializer='he_normal',
  380. name='conv1')(input_data)
  381. inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)
  382. inner = Conv2D(conv_filters, kernel_size, padding='same',
  383. activation=act, kernel_initializer='he_normal',
  384. name='conv2')(inner)
  385. inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)
  386. conv_to_rnn_dims = (img_w // (pool_size ** 2),
  387. (img_h // (pool_size ** 2)) * conv_filters)
  388. inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)
  389. # cuts down input size going into RNN:
  390. inner = Dense(time_dense_size, activation=act, name='dense1')(inner)
  391. # Two layers of bidirectional GRUs
  392. # GRU seems to work as well, if not better than LSTM:
  393. gru_1 = GRU(rnn_size, return_sequences=True,
  394. kernel_initializer='he_normal', name='gru1')(inner)
  395. gru_1b = GRU(rnn_size, return_sequences=True,
  396. go_backwards=True, kernel_initializer='he_normal',
  397. name='gru1_b')(inner)
  398. gru1_merged = add([gru_1, gru_1b])
  399. gru_2 = GRU(rnn_size, return_sequences=True,
  400. kernel_initializer='he_normal', name='gru2')(gru1_merged)
  401. gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True,
  402. kernel_initializer='he_normal', name='gru2_b')(gru1_merged)
  403. # transforms RNN output to character activations:
  404. inner = Dense(img_gen.get_output_size(), kernel_initializer='he_normal',
  405. name='dense2')(concatenate([gru_2, gru_2b]))
  406. y_pred = Activation('softmax', name='softmax')(inner)
  407. Model(inputs=input_data, outputs=y_pred).summary()
  408. labels = Input(name='the_labels',
  409. shape=[img_gen.absolute_max_string_len], dtype='float32')
  410. input_length = Input(name='input_length', shape=[1], dtype='int64')
  411. label_length = Input(name='label_length', shape=[1], dtype='int64')
  412. # Keras doesn't currently support loss funcs with extra parameters
  413. # so CTC loss is implemented in a lambda layer
  414. loss_out = Lambda(
  415. ctc_lambda_func, output_shape=(1,),
  416. name='ctc')([y_pred, labels, input_length, label_length])
  417. # clipnorm seems to speeds up convergence
  418. sgd = SGD(lr=0.02, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
  419. model = Model(inputs=[input_data, labels, input_length, label_length],
  420. outputs=loss_out)
  421. # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
  422. model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
  423. if start_epoch > 0:
  424. weight_file = os.path.join(
  425. OUTPUT_DIR,
  426. os.path.join(run_name, 'weights%02d.h5' % (start_epoch - 1)))
  427. model.load_weights(weight_file)
  428. # captures output of softmax so we can decode the output during visualization
  429. test_func = K.function([input_data], [y_pred])
  430. viz_cb = VizCallback(run_name, test_func, img_gen.next_val())
  431. model.fit_generator(
  432. generator=img_gen.next_train(),
  433. steps_per_epoch=(words_per_epoch - val_words) // minibatch_size,
  434. epochs=stop_epoch,
  435. validation_data=img_gen.next_val(),
  436. validation_steps=val_words // minibatch_size,
  437. callbacks=[viz_cb, img_gen],
  438. initial_epoch=start_epoch)
  439. if __name__ == '__main__':
  440. run_name = datetime.datetime.now().strftime('%Y:%m:%d:%H:%M:%S')
  441. train(run_name, 0, 20, 128)
  442. # increase to wider images and start at epoch 20.
  443. # The learned weights are reloaded
  444. train(run_name, 20, 25, 512)