Switch

注意:该API仅支持【静态图】模式

  • class paddle.fluid.layers.Switch(name=None)[源代码]

该类用于实现Switch分支控制功能。Switch分支包含多个case分支和一个default分支,Switch控制流会依次检查各case分支条件是否满足,并仅执行第一个满足条件的case分支后面的语句。若不存在满足条件的case分支,则仅执行default分支后面的语句。

注解

如果参数 cond 的形状为[1],强烈建议您使用新的OP case 而不是 Switch。 OP case 的使用方式更简单,并且调用该OP所用的代码更少且功能与 Switch 一样。

  • 成员函数:
    • case(cond) - Switch的case分支,其参数cond为bool型的标量Variable。只有当前case分支的cond为True,且之前的case分支的cond均为False,该case分支后的语句才会执行,且不再执行之后的case后的语句。
    • default() - Switch的default分支。当所有case分支的cond均为False时,执行default分支后的语句。

注意:case和default函数只能用于Switch的scope内部,示例如下:

  1. with fluid.layers.Switch() as switch:
  2. with switch.case(cond1):
  3. i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
  4. with switch.case(cond2):
  5. i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=2)
  6. with switch.default():
  7. i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
  • 参数:
    • name (str,可选) - 具体用法请参见 Name ,一般无需设置,默认值为None。

代码示例

  1. import paddle.fluid as fluid
  2.  
  3. lr = fluid.layers.create_global_var(
  4. shape=[1],
  5. value=0.0,
  6. dtype='float32',
  7. persistable=True,
  8. name="learning_rate")
  9. zero_var = fluid.layers.fill_constant(
  10. shape=[1], dtype='float32', value=0.0)
  11. one_var = fluid.layers.fill_constant(
  12. shape=[1], dtype='float32', value=1.0)
  13. two_var = fluid.layers.fill_constant(
  14. shape=[1], dtype='float32', value=2.0)
  15.  
  16. # 将参数中的begin设为非0值,则进入Switch的default分支,输出数组中的数字将为2
  17. global_step = fluid.layers.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)
  18.  
  19. with fluid.layers.control_flow.Switch() as switch:
  20. with switch.case(global_step == zero_var):
  21. fluid.layers.assign(input=one_var, output=lr)
  22. with switch.default():
  23. fluid.layers.assign(input=two_var, output=lr)
  24.  
  25. exe = fluid.Executor(fluid.CPUPlace())
  26. exe.run(fluid.default_startup_program())
  27.  
  28. res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])
  29. print(res) # [array([1.], dtype=float32)]