Sequence to sequence example in Keras (character-level).

This script demonstrates how to implement a basic character-level sequence-to-sequence model. We apply it to translating short English sentences into short French sentences, character-by-character. Note that it is fairly unusual to do character-level machine translation, as word-level models are more common in this domain.

Summary of the algorithm

  • We start with input sequences from a domain (e.g. English sentences) and corresponding target sequences from another domain (e.g. French sentences).
  • An encoder LSTM turns the input sequences into 2 state vectors (we keep the last LSTM states and discard the outputs).
  • A decoder LSTM is trained to turn the target sequences into the same sequence but offset by one timestep in the future, a training process called "teacher forcing" in this context. It uses as initial state the state vectors from the encoder. Effectively, the decoder learns to generate targets[t+1…] given targets[…t], conditioned on the input sequence (a short sketch after this list makes the offset concrete).
  • In inference mode, when we want to decode unknown input sequences, we:
    • Encode the input sequence into state vectors
    • Start with a target sequence of size 1 (just the start-of-sequence character)
    • Feed the state vectors and 1-char target sequence to the decoder to produce predictions for the next character
    • Sample the next character using these predictions (we simply use argmax).
    • Append the sampled character to the target sequence
    • Repeat until we generate the end-of-sequence character or we hit the character limit.
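
The sketch below makes the teacher-forcing offset concrete. It is a minimal illustration only: the sentence pair is made up, and the '\t' / '\n' markers are the start-of-sequence and end-of-sequence characters that the script below adds to every target.

    # Decoder input vs. decoder target for one hypothetical training sample.
    target_text = '\t' + 'Va !' + '\n'             # the decoder input string: '\tVa !\n'
    decoder_input_chars = list(target_text)        # ['\t', 'V', 'a', ' ', '!', '\n']
    decoder_target_chars = list(target_text[1:])   # ['V', 'a', ' ', '!', '\n']
    # At timestep t the decoder is fed decoder_input_chars[t] and is trained to
    # predict decoder_target_chars[t], conditioned on the encoder's state vectors,
    # so the start character never appears in the targets.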

Data download

English to French sentence pairs.

Lots of neat sentence pairs datasets.
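
The script below splits each line of the data file on a tab, so every line is expected to hold an English sentence and its French translation separated by a single tab character. A minimal sketch of how one such line is parsed (the pair shown is illustrative; if your copy of the file carries extra tab-separated columns, the split needs adjusting):

    # One line of the tab-separated pairs file (content shown is illustrative).
    line = 'Go.\tVa !'
    input_text, target_text = line.split('\t')
    # input_text  -> 'Go.'
    # target_text -> 'Va !', which the script then wraps as '\t' + 'Va !' + '\n'
    # to add the start-of-sequence and end-of-sequence markers.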

from __future__ import print_function

from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np

batch_size = 64  # Batch size for training.
epochs = 100  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
num_samples = 10000  # Number of samples to train on.
# Path to the data txt file on disk.
data_path = 'fra-eng/fra.txt'

# Vectorize the data.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    encoder_input_data[i, t + 1:, input_token_index[' ']] = 1.
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.
    decoder_input_data[i, t + 1:, target_token_index[' ']] = 1.
    decoder_target_data[i, t:, target_token_index[' ']] = 1.

# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
# Save model
model.save('s2s.h5')

# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
#    and a "start of sequence" token as target.
#    Output will be the next target token
# 3) Repeat with the current target token and current states

# Define sampling models
encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())


def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
                len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence


for seq_index in range(100):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)
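
Because the trained model is saved to 's2s.h5', sampling can also be run in a later session without retraining. The sketch below is one way to rebuild the two sampling models from the saved file. It assumes the vectorization code above has been re-run (so that latent_dim, the token indices and decode_sequence exist) and that the training model's layers come out in the order input, input, encoder LSTM, decoder LSTM, Dense; that ordering matches the model as defined above, but it is an assumption worth checking with model.summary().

    from keras.models import Model, load_model
    from keras.layers import Input

    model = load_model('s2s.h5')

    # Encoder: reuse the saved encoder LSTM and expose its states as outputs.
    encoder_inputs = model.input[0]
    encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # assumed encoder LSTM
    encoder_model = Model(encoder_inputs, [state_h_enc, state_c_enc])

    # Decoder: re-wire the saved decoder LSTM and Dense layer for step-by-step sampling.
    decoder_inputs = model.input[1]
    decoder_state_input_h = Input(shape=(latent_dim,))
    decoder_state_input_c = Input(shape=(latent_dim,))
    decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
    decoder_lstm = model.layers[3]  # assumed decoder LSTM
    decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
        decoder_inputs, initial_state=decoder_states_inputs)
    decoder_outputs = model.layers[4](decoder_outputs)  # assumed softmax Dense layer
    decoder_model = Model(
        [decoder_inputs] + decoder_states_inputs,
        [decoder_outputs, state_h_dec, state_c_dec])
    # `decode_sequence` above can then be used unchanged with these two models.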