Restore a character-level sequence-to-sequence model from disk and use it to generate predictions.

This script loads the s2s.h5 model saved by lstm_seq2seq.py and generates sequences from it. It assumes that no changes have been made (for example: latent_dim is unchanged, and the input data and model architecture are unchanged).

See lstm_seq2seq.py for more details on the model architecture and how it is trained.
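The model file is expected to have been written at the end of training in lstm_seq2seq.py, presumably with something like the following (an assumption about the training script, not part of this one):

    model.save('s2s.h5')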

from __future__ import print_function

from keras.models import Model, load_model
from keras.layers import Input
import numpy as np

batch_size = 64  # Batch size for training.
epochs = 100  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
num_samples = 10000  # Number of samples to train on.
# Path to the data txt file on disk.
data_path = 'fra-eng/fra.txt'
# Vectorize the data. We use the same approach as the training script.
# NOTE: the data must be identical, in order for the character -> integer
# mappings to be consistent.
# We omit encoding target_texts since they are not needed.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)
input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
for i, input_text in enumerate(input_texts):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
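# Quick sanity check (illustrative only; the exact numbers depend on the data file):
# encoder_input_data has shape (len(input_texts), max_encoder_seq_length,
# num_encoder_tokens), with one 1.0 per encoded character position and all
# zeros for padding positions.
# print(encoder_input_data.shape)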
# Restore the model and construct the encoder and decoder.
model = load_model('s2s.h5')

encoder_inputs = model.input[0]   # input_1
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output   # lstm_1
encoder_states = [state_h_enc, state_c_enc]
encoder_model = Model(encoder_inputs, encoder_states)

decoder_inputs = model.input[1]   # input_2
decoder_state_input_h = Input(shape=(latent_dim,), name='input_3')
decoder_state_input_c = Input(shape=(latent_dim,), name='input_4')
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_lstm = model.layers[3]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.layers[4]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)
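# NOTE: the layer indices above (model.layers[2], [3], [4]) assume the exact
# layer ordering produced by lstm_seq2seq.py: two Input layers, the encoder
# LSTM, the decoder LSTM, and the final Dense layer. If the saved architecture
# differs, these indices must be adjusted accordingly.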
# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())


# Decodes an input sequence. Future work should support beam search.
def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence
for seq_index in range(100):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)
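# A minimal sketch (not part of the original script) of how one might encode an
# arbitrary new sentence with the same character mapping before decoding it.
# Characters that never appeared in the training inputs are skipped, and the
# sentence is truncated to max_encoder_seq_length; both choices are assumptions.
def encode_input_text(text):
    seq = np.zeros(
        (1, max_encoder_seq_length, num_encoder_tokens), dtype='float32')
    for t, char in enumerate(text[:max_encoder_seq_length]):
        if char in input_token_index:
            seq[0, t, input_token_index[char]] = 1.
    return seq

# Example usage:
# print('Decoded sentence:', decode_sequence(encode_input_text('Run!')))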