Theano tensor 模块:索引
In [1]:
- import theano
- import theano.tensor as T
- import numpy as np
- Using gpu device 1: Tesla C2075 (CNMeM is disabled)
简单索引
tensor
模块完全支持 numpy
中的简单索引:
In [2]:
- t = T.arange(9)
- print t[1::2].eval()
- [1 3 5 7]
numpy
结果:
In [3]:
- n = np.arange(9)
- print n[1::2]
- [1 3 5 7]
mask 索引
tensor
模块虽然支持简单索引,但并不支持 mask
索引,例如这样的做法是错误的:
In [4]:
- t = T.arange(9).reshape((3,3))
- print t[t > 4].eval()
- [[[0 1 2]
- [0 1 2]
- [0 1 2]]
- [[0 1 2]
- [0 1 2]
- [3 4 5]]
- [[3 4 5]
- [3 4 5]
- [3 4 5]]]
numpy
中的结果:
In [5]:
- n = np.arange(9).reshape((3,3))
- print n[n > 4]
- [5 6 7 8]
要想像 numpy
一样得到正确结果,我们需要使用这样的方法:
In [6]:
- print t[(t > 4).nonzero()].eval()
- [5 6 7 8]
使用索引进行赋值
tensor
模块不支持直接使用索引赋值,例如 a[5] = b, a[5]+=b
等是不允许的。
不过可以考虑用 set_subtensor
和 inc_subtensor
来实现类似的功能:
T.set_subtensor(x, y)
实现类似 r[10:] = 5 的功能:
In [7]:
- r = T.vector()
- new_r = T.set_subtensor(r[10:], 5)
T.inc_subtensor(x, y)
实现类似 r[10:] += 5 的功能:
In [8]:
- r = T.vector()
- new_r = T.inc_subtensor(r[10:], 5)