'''Trains a memory network on the bAbI dataset.

References:

- Jason Weston, Antoine Bordes, Sumit Chopra, Tomas Mikolov,
  Alexander M. Rush,
  "Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks",
  http://arxiv.org/abs/1502.05698

- Sainbayar Sukhbaatar, Arthur Szlam, Jason Weston, Rob Fergus,
  "End-To-End Memory Networks",
  http://arxiv.org/abs/1503.08895

Reaches 98.6% accuracy on task 'single_supporting_fact_10k' after 120 epochs.
Time per epoch: 3s on CPU (core i7).
'''

from __future__ import print_function

from keras.models import Sequential, Model
from keras.layers.embeddings import Embedding
from keras.layers import Input, Activation, Dense, Permute, Dropout
from keras.layers import add, dot, concatenate
from keras.layers import LSTM
from keras.utils.data_utils import get_file
from keras.preprocessing.sequence import pad_sequences
from functools import reduce
import tarfile
import numpy as np
import re

def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    # Split on runs of non-word characters, keeping them as tokens.
    # The pattern must not match the empty string (i.e. no trailing '?'),
    # otherwise re.split misbehaves on Python >= 3.7.
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]

def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format.

    If only_supporting is true,
    only the sentences that support the answer are kept.
    '''
    data = []
    story = []
    for line in lines:
        line = line.decode('utf-8').strip()
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if '\t' in line:
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]
            data.append((substory, q, a))
            story.append('')
        else:
            sent = tokenize(line)
            story.append(sent)
    return data
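
# For reference, the raw bAbI task files interleave story and question
# lines; the leading IDs restart at 1 for each new story, and question
# lines carry the answer and the supporting fact IDs after tabs, e.g.:
#   1 Mary moved to the bathroom.
#   2 John went to the hallway.
#   3 Where is Mary? <TAB> bathroom <TAB> 1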

def get_stories(f, only_supporting=False, max_length=None):
    '''Given a file name, read the file, retrieve the stories,
    and then convert the sentences into a single story.

    If max_length is supplied,
    any stories longer than max_length tokens will be discarded.
    '''
    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    flatten = lambda data: reduce(lambda x, y: x + y, data)
    data = [(flatten(story), q, answer) for story, q, answer in data
            if not max_length or len(flatten(story)) < max_length]
    return data

def vectorize_stories(data):
    # Relies on word_idx, story_maxlen and query_maxlen,
    # which are defined at module level below.
    inputs, queries, answers = [], [], []
    for story, query, answer in data:
        inputs.append([word_idx[w] for w in story])
        queries.append([word_idx[w] for w in query])
        answers.append(word_idx[answer])
    return (pad_sequences(inputs, maxlen=story_maxlen),
            pad_sequences(queries, maxlen=query_maxlen),
            np.array(answers))
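
# Note that pad_sequences left-pads each list of word indices with 0
# (the index reserved below for masking) so every row has the same
# length; e.g. with maxlen=5, [4, 7, 2] becomes [0, 0, 4, 7, 2].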

try:
    path = get_file('babi-tasks-v1-2.tar.gz',
                    origin='https://s3.amazonaws.com/text-datasets/'
                           'babi_tasks_1-20_v1-2.tar.gz')
except Exception:
    print('Error downloading dataset, please download it manually:\n'
          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2'
          '.tar.gz\n'
          '$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz')
    raise

challenges = {
    # QA1 with 10,000 samples
    'single_supporting_fact_10k': 'tasks_1-20_v1-2/en-10k/qa1_'
                                  'single-supporting-fact_{}.txt',
    # QA2 with 10,000 samples
    'two_supporting_facts_10k': 'tasks_1-20_v1-2/en-10k/qa2_'
                                'two-supporting-facts_{}.txt',
}
challenge_type = 'single_supporting_fact_10k'
challenge = challenges[challenge_type]

print('Extracting stories for the challenge:', challenge_type)
with tarfile.open(path) as tar:
    train_stories = get_stories(tar.extractfile(challenge.format('train')))
    test_stories = get_stories(tar.extractfile(challenge.format('test')))

vocab = set()
for story, q, answer in train_stories + test_stories:
    vocab |= set(story + q + [answer])
vocab = sorted(vocab)

# Reserve 0 for masking via pad_sequences
vocab_size = len(vocab) + 1
story_maxlen = max(map(len, (x for x, _, _ in train_stories + test_stories)))
query_maxlen = max(map(len, (x for _, x, _ in train_stories + test_stories)))

print('-')
print('Vocab size:', vocab_size, 'unique words')
print('Story max length:', story_maxlen, 'words')
print('Query max length:', query_maxlen, 'words')
print('Number of training stories:', len(train_stories))
print('Number of test stories:', len(test_stories))
print('-')
print('Here\'s what a "story" tuple looks like (input, query, answer):')
print(train_stories[0])
print('-')
print('Vectorizing the word sequences...')

word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
inputs_train, queries_train, answers_train = vectorize_stories(train_stories)
inputs_test, queries_test, answers_test = vectorize_stories(test_stories)

print('-')
print('inputs: integer tensor of shape (samples, max_length)')
print('inputs_train shape:', inputs_train.shape)
print('inputs_test shape:', inputs_test.shape)
print('-')
print('queries: integer tensor of shape (samples, max_length)')
print('queries_train shape:', queries_train.shape)
print('queries_test shape:', queries_test.shape)
print('-')
print('answers: integer tensor of shape (samples,), one word index per answer')
print('answers_train shape:', answers_train.shape)
print('answers_test shape:', answers_test.shape)
print('-')
print('Compiling...')

# placeholders
input_sequence = Input((story_maxlen,))
question = Input((query_maxlen,))

# encoders
# embed the input sequence into a sequence of vectors
input_encoder_m = Sequential()
input_encoder_m.add(Embedding(input_dim=vocab_size,
                              output_dim=64))
input_encoder_m.add(Dropout(0.3))
# output: (samples, story_maxlen, embedding_dim)

# embed the input into a sequence of vectors of size query_maxlen
# (this embedding size lets the result be added to `match` below)
input_encoder_c = Sequential()
input_encoder_c.add(Embedding(input_dim=vocab_size,
                              output_dim=query_maxlen))
input_encoder_c.add(Dropout(0.3))
# output: (samples, story_maxlen, query_maxlen)

# embed the question into a sequence of vectors
question_encoder = Sequential()
question_encoder.add(Embedding(input_dim=vocab_size,
                               output_dim=64,
                               input_length=query_maxlen))
question_encoder.add(Dropout(0.3))
# output: (samples, query_maxlen, embedding_dim)

# encode input sequence and questions (which are indices)
# to sequences of dense vectors
input_encoded_m = input_encoder_m(input_sequence)
input_encoded_c = input_encoder_c(input_sequence)
question_encoded = question_encoder(question)

# compute a 'match' between the first input vector sequence
# and the question vector sequence
# shape: `(samples, story_maxlen, query_maxlen)`
match = dot([input_encoded_m, question_encoded], axes=(2, 2))
match = Activation('softmax')(match)
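# in other words, match[i, j] is the dot product between the embeddings
# of story word i and question word j, and the softmax (applied over the
# last axis by default) turns each row of scores into attention weights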

# add the match matrix with the second input vector sequence
response = add([match, input_encoded_c])  # (samples, story_maxlen, query_maxlen)
response = Permute((2, 1))(response)  # (samples, query_maxlen, story_maxlen)

# concatenate the response with the question vector sequence
answer = concatenate([response, question_encoded])

# the original paper uses a matrix multiplication for this reduction step.
# we choose to use a RNN instead.
answer = LSTM(32)(answer)  # (samples, 32)

# one regularization layer -- more would probably be needed.
answer = Dropout(0.3)(answer)
answer = Dense(vocab_size)(answer)  # (samples, vocab_size)
# we output a probability distribution over the vocabulary
answer = Activation('softmax')(answer)

# build the final model
model = Model([input_sequence, question], answer)
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# train
model.fit([inputs_train, queries_train], answers_train,
          batch_size=32,
          epochs=120,
          validation_data=([inputs_test, queries_test], answers_test))
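
# Not part of the original example: a minimal post-training check on the
# same held-out split that served as validation data above.
loss, acc = model.evaluate([inputs_test, queries_test], answers_test)
print('Test loss / test accuracy = {:.4f} / {:.4f}'.format(loss, acc))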