InputSpec

class paddle.static. InputSpec ( shape=None, dtype=’float32’, name=None ) [源代码]

用于描述模型输入的签名信息,包括shape、dtype和name。

此接口常用于指定高层API中模型的输入张量信息,或动态图转静态图时,指定被 paddle.jit.to_static 装饰的forward函数每个输入参数的张量信息。

参数:

  • shape (list|tuple)- 声明维度信息的list或tuple,默认值为None。

  • dtype (np.dtype|VarType|str,可选)- 数据类型,支持bool,float16,float32,float64,int8,int16,int32,int64,uint8。默认值为float32。

  • name (str)- 被创建对象的名字,具体用法请参见 Name

返回:初始化后的 InputSpec 对象

返回类型:InputSpec

代码示例

  1. from paddle.static import InputSpec
  2. input = InputSpec([None, 784], 'float32', 'x')
  3. label = InputSpec([None, 1], 'int64', 'label')
  4. print(input) # InputSpec(shape=(-1, 784), dtype=VarType.FP32, name=x)
  5. print(label) # InputSpec(shape=(-1, 1), dtype=VarType.INT64, name=label)

from_tensor ( tensor, name=None )

该接口将根据输入Tensor的shape、dtype等信息构建InputSpec对象。

参数:

  • tensor (Tensor) - 用于构建InputSpec的源Tensor

  • name (str): 被创建对象的名字,具体用法请参见 Name 。 默认为:None。

返回:根据Tensor信息构造的 InputSpec 对象

返回类型:InputSpec

代码示例

  1. import numpy as np
  2. import paddle
  3. from paddle.static import InputSpec
  4. x = paddle.to_tensor(np.ones([2, 2], np.float32))
  5. x_spec = InputSpec.from_tensor(x, name='x')
  6. print(x_spec) # InputSpec(shape=(2, 2), dtype=VarType.FP32, name=x)

from_numpy ( ndarray, name=None )

该接口将根据输入numpy ndarray的shape、dtype等信息构建InputSpec对象。

参数:

  • ndarray (Tensor) - 用于构建InputSpec的numpy ndarray

  • name (str): 被创建对象的名字,具体用法请参见 Name 。 默认为:None。

返回:根据ndarray信息构造的 InputSpec 对象

返回类型:InputSpec

代码示例

  1. import numpy as np
  2. from paddle.static import InputSpec
  3. x = np.ones([2, 2], np.float32)
  4. x_spec = InputSpec.from_numpy(x, name='x')
  5. print(x_spec) # InputSpec(shape=(2, 2), dtype=VarType.FP32, name=x)

batch ( batch_size )

该接口将batch_size插入到当前InputSpec对象的shape元组最前面。

参数:

  • batch_size (int) - 被插入的batch size整型数值

返回: 更新shape信息后的 InputSpec 对象

返回类型:InputSpec

代码示例

  1. from paddle.static import InputSpec
  2. x_spec = InputSpec(shape=[64], dtype='float32', name='x')
  3. x_spec.batch(4)
  4. print(x_spec) # InputSpec(shape=(4, 64), dtype=VarType.FP32, name=x)

unbatch ( )

该接口将当前InputSpec对象shape[0]值移除。

返回: 更新shape信息后的 InputSpec 对象

返回类型:InputSpec

代码示例

  1. from paddle.static import InputSpec
  2. x_spec = InputSpec(shape=[4, 64], dtype='float32', name='x')
  3. x_spec.unbatch()
  4. print(x_spec) # InputSpec(shape=(64,), dtype=VarType.FP32, name=x)

使用本API的教程文档