cond

注意:该API仅支持【静态图】模式

  • paddle.fluid.layers.cond(pred, true_fn=None, false_fn=None, name=None)[源代码]

如果 predTrue ,该API返回 true_fn() ,否则返回 false_fn() 。 用户如果不想在 callable 中做任何事,可以把 true_fnfalse_fn 设为 None ,此时本API会把该 callable 视为简单返回 None

true_fnfalse_fn 需要返回同样嵌套结构(nest structure)的Tensor,如果不想返回任何值也可都返回 None 。 PaddlePaddle里Tensor的嵌套结构是指一个Tensor,或者Tensor的元组(tuple),或者Tensor的列表(list)。

注解

  • 因为PaddlePaddle的静态图数据流, true_fnfalse_fn 返回的元组必须形状相同,但是里面的Tensor形状可以不同。

  • 不论运行哪个分支,在 true_fnfalse_fn 外创建的Tensor和Op都会被运行,即PaddlePaddle并不是惰性语法(lazy semantics)。例如

  1. import paddle.fluid as fluid
  2. a = fluid.data(name='a', shape=[-1, 1], dtype='float32')
  3. b = fluid.data(name='b', shape=[-1, 1], dtype='float32')
  4. c = a * b
  5. out = fluid.layers.cond(a < b, lambda: a + c, lambda: b * b)

不管 a < b 是否成立, c = a * b 都会被运行。

  • 参数:
    • pred (Variable) - 一个形状为[1]的布尔型(boolean)的Tensor,该布尔值决定要返回 true_fn 还是 false_fn 的运行结果。
    • true_fn (callable) - 一个当 predTrue 时被调用的callable,默认值: None
    • false_fn (callable) - 一个当 predFalse 时被调用的callable,默认值: None
    • name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值: None
  • 返回:
  • 如果 predTrue ,该API返回 true_fn() ,否则返回 false_fn()

返回类型:Variable|list(Variable)|tuple(Variable)

  • 抛出异常:
    • TypeError - 如果 true_fnfalse_fn 不是callable。
    • ValueError - 如果 true_fnfalse_fn 没有返回同样的嵌套结构(nest structure),对嵌套结构的解释见上文。

代码示例

  1. import paddle.fluid as fluid
  2. import paddle.fluid.layers as layers
  3. from paddle.fluid.executor import Executor
  4. from paddle.fluid.framework import Program, program_guard
  5.  
  6. #
  7. # pseudocode:
  8. # if 0.1 < 0.23:
  9. # return 1, True
  10. # else:
  11. # return 3, 2
  12. #
  13.  
  14. def true_func():
  15. return layers.fill_constant(
  16. shape=[1, 2], dtype='int32', value=1), layers.fill_constant(
  17. shape=[2, 3], dtype='bool', value=True)
  18.  
  19. def false_func():
  20. return layers.fill_constant(
  21. shape=[3, 4], dtype='float32', value=3), layers.fill_constant(
  22. shape=[4, 5], dtype='int64', value=2)
  23.  
  24. main_program = Program()
  25. startup_program = Program()
  26. with program_guard(main_program, startup_program):
  27. x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
  28. y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
  29. pred = layers.less_than(x, y)
  30. out = layers.cond(pred, true_func, false_func)
  31. # out is a tuple containing 2 tensors
  32.  
  33. place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
  34. ) else fluid.CPUPlace()
  35. exe = fluid.Executor(place)
  36. ret = exe.run(main_program, fetch_list=out)
  37. # ret[0] = [[1 1]]
  38. # ret[1] = [[ True True True]
  39. # [ True True True]]