GreedyEmbeddingHelper

class paddle.fluid.layers.GreedyEmbeddingHelper(embedding_fn, start_tokens, end_token)[源代码]

GreedyEmbeddingHelper是 DecodeHelper 的子类。作为解码helper,它使用 argmax 进行采样,并将采样结果送入embedding层,以此作为下一解码步的输入。

参数

  • embedding_fn (callable) - 作用于 argmax 结果的函数,通常是一个将词id转换为词嵌入的embedding层,注意 ,这里要使用 embedding 而非 embedding,因为选中的id的形状是

    GreedyEmbeddingHelper - 图1

    ,如果使用后者则还需要在这里提供unsqueeze。

  • start_tokens (Variable) - 形状为

    GreedyEmbeddingHelper - 图2

    、数据类型为int64、 值为起始标记id的tensor。

  • end_token (int) - 结束标记id。

代码示例

  1. import paddle.fluid as fluid
  2. import paddle.fluid.layers as layers
  3. start_tokens = fluid.data(name="start_tokens",
  4. shape=[None],
  5. dtype="int64")
  6. trg_embeder = lambda x: fluid.embedding(
  7. x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding"))
  8. output_layer = lambda x: layers.fc(x,
  9. size=10000,
  10. num_flatten_dims=len(x.shape) - 1,
  11. param_attr=fluid.ParamAttr(name=
  12. "output_w"),
  13. bias_attr=False)
  14. helper = layers.GreedyEmbeddingHelper(trg_embeder, start_tokens=start_tokens, end_token=1)
  15. decoder_cell = layers.GRUCell(hidden_size=128)
  16. decoder = layers.BasicDecoder(decoder_cell, helper, output_fn=output_layer)
  17. outputs = layers.dynamic_decode(
  18. decoder=decoder, inits=decoder_cell.get_initial_states(start_tokens))

initialize()

GreedyEmbeddingHelper初始化,其使用构造函数中的 start_tokens 作为第一个解码步的输入,并给出每个序列是否结束的初始标识。这是 BasicDecoder 初始化的一部分。

返回:(initial_inputs, initial_finished) 的二元组, initial_inputs 同构造函数中的 start_tokensinitial_finished 是一个bool类型、值为False的tensor,其形状和 start_tokens 相同。

返回类型:tuple

sample(time, outputs, states)

使用 argmax 根据 outputs 进行采样。

参数:

  • time (Variable) - 调用者提供的形状为[1]的tensor,表示当前解码的时间步长。其数据类型为int64。
  • outputs (Variable) - tensor变量,通常其数据类型为float32或float64,形状为

    GreedyEmbeddingHelper - 图3

    ,表示当前解码步预测产生的logit(未归一化的概率),和由 BasicDecoder.output_fn(BasicDecoder.cell.call()) 返回的 outputs 是同一内容。

  • states (Variable) - 单个tensor变量或tensor变量组成的嵌套结构,和由 BasicDecoder.cell.call() 返回的 new_states 是同一内容。

返回:数据类型为int64形状为

GreedyEmbeddingHelper - 图4

的tensor,表示采样得到的id。

返回类型:Variable

next_inputs(time, outputs, states, sample_ids)

sample_ids 使用 embedding_fn ,以此作为下一解码步的输入;同时直接使用输入参数中的 states 作为下一解码步的状态;并通过判别 sample_ids 是否得到 end_token,依此产生每个序列是否结束的标识。

参数:

  • time (Variable) - 调用者提供的形状为[1]的tensor,表示当前解码的时间步长。其数据类型为int64。
  • outputs (Variable) - tensor变量,通常其数据类型为float32或float64,形状为

    GreedyEmbeddingHelper - 图5

    ,表示当前解码步预测产生的logit(未归一化的概率),和由 BasicDecoder.output_fn(BasicDecoder.cell.call()) 返回的 outputs 是同一内容。

  • states (Variable) - 单个tensor变量或tensor变量组成的嵌套结构,和由 BasicDecoder.cell.call() 返回的 new_states 是同一内容。
  • sample_ids (Variable) - 数据类型为int64形状为

    GreedyEmbeddingHelper - 图6

    的tensor,和由 sample() 返回的 sample_ids 是同一内容。

返回: (finished, next_inputs, next_states) 的三元组。 next_inputs, next_states 均是单个tensor变量或tensor变量组成的嵌套结构,tensor的形状是

GreedyEmbeddingHelper - 图7

next_states 和输入参数中的 states 相同; finished 是一个bool类型且形状为

GreedyEmbeddingHelper - 图8

的tensor。

返回类型:tuple