广播语义
许多pytorch
操作都支持NumPy广播语义
简而言之,如果Pytorch
操作支持广播,则其张量参数可以自动扩展为相同大小(不需要复制数据)。
一般语义
如果pytorch
张量满足以下条件,那么就可以广播:
- 每个张量至少有一个维度。
- 在遍历维度大小时, 从尾部维度开始遍历, 并且二者维度必须相等, 它们其中一个要么是
1
要么不存在.
例如:
>>> x=torch.FloatTensor(5,7,3)
>>> y=torch.FloatTensor(5,7,3)
# 相同形状的质量可以被广播(上述规则总是成立的)
>>> x=torch.FloatTensor()
>>> y=torch.FloatTensor(2,2)
# x和y不能被广播,因为x没有维度
# can line up trailing dimensions
>>> x=torch.FloatTensor(5,3,4,1)
>>> y=torch.FloatTensor( 3,1,1)
# x和y可以被广播
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
# 但是:
>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor( 3,1,1)
# x和y不能被广播,因为在`3rd`中
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
如果x
和y
可以被广播,得到的张量大小的计算方法如下:
- 如果维数
x
和y
不相等,在维度较少的张量的维度前加上1
使它们相等的长度。 - 然后,对于每个维度的大小,生成维度的大小是
x
和y
的最大值。
例如:
# 可以排列尾部维度,使阅读更容易
>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# 但不是必要的:
>>> x=torch.FloatTensor(1)
>>> y=torch.FloatTensor(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
直接语义(In-place语义)
一个复杂的问题是in-place
操作不允许in-place
张量像广播那样改变形状。
例如:
>>> x=torch.FloatTensor(5,3,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])
# but:
>>> x=torch.FloatTensor(1,3,1)
>>> y=torch.FloatTensor(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
向后兼容性
以前版本的PyTorch
只要张量中的元素数量是相等的, 便允许某些点状pointwise
函数在不同的形状的张量上执行, 其中点状操作是通过将每个张量视为1
维来执行。PyTorch
现在支持广播语义并且不推荐使用点状函数操作向量,并且张量不支持广播但具有相同数量的元素将生成一个Python
警告。
注意,广播的引入可能会导致向后不兼容,因为两个张量的形状不同,但是可以被广播且具有相同数量的元素。
例如:
>>> torch.add(torch.ones(4,1), torch.randn(4))
事先生成一个torch.Size([4,1])
的张量,然后再提供一个torch.Size([4,4])
的张量。为了帮助识别你代码中可能存在的由引入广播语义的向后不兼容情况, 您可以设置torch.utils.backcompat.broadcast_warning.enabled
为True
,这种情况下会生成一个python
警告。
例如:
>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.
译者署名
用户名 | 头像 | 职能 | 签名 |
---|---|---|---|
Song | 翻译 | 人生总要追求点什么 |