tf.TensorArray :TensorFlow 动态数组 *
在部分网络结构,尤其是涉及到时间序列的结构中,我们可能需要将一系列张量以数组的方式依次存放起来,以供进一步处理。当然,在Eager Execution下,你可以直接使用一个Python列表(List)存放数组。不过,如果你需要基于计算图的特性(例如使用 @tf.function
加速模型运行或者使用SavedModel导出模型),就无法使用这种方式了。因此,TensorFlow提供了 tf.TensorArray
,一种支持计算图特性的TensorFlow动态数组。
由于需要支持计算图, tf.TensorArray
的使用方式和一般编程语言中的列表/数组类型略有不同,包括4个方法:
一个简单的示例如下:
- import tensorflow as tf
- @tf.function
- def array_write_and_read():
- arr = tf.TensorArray(dtype=tf.float32, size=3)
- arr = arr.write(0, tf.constant(0.0))
- arr = arr.write(1, tf.constant(1.0))
- arr = arr.write(2, tf.constant(2.0))
- arr_0 = arr.read(0)
- arr_1 = arr.read(1)
- arr_2 = arr.read(2)
- return arr_0, arr_1, arr_2
- a, b, c = array_write_and_read()
- print(a, b, c)
输出:
- tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(2.0, shape=(), dtype=float32)