Theano tensor 模块:索引

In [1]:

  1. import theano
  2. import theano.tensor as T
  3. import numpy as np
  1. Using gpu device 1: Tesla C2075 (CNMeM is disabled)

简单索引

tensor 模块完全支持 numpy 中的简单索引:

In [2]:

  1. t = T.arange(9)
  2.  
  3. print t[1::2].eval()
  1. [1 3 5 7]

numpy 结果:

In [3]:

  1. n = np.arange(9)
  2.  
  3. print n[1::2]
  1. [1 3 5 7]

mask 索引

tensor 模块虽然支持简单索引,但并不支持 mask 索引,例如这样的做法是错误的:

In [4]:

  1. t = T.arange(9).reshape((3,3))
  2.  
  3. print t[t > 4].eval()
  1. [[[0 1 2]
  2. [0 1 2]
  3. [0 1 2]]
  4.  
  5. [[0 1 2]
  6. [0 1 2]
  7. [3 4 5]]
  8.  
  9. [[3 4 5]
  10. [3 4 5]
  11. [3 4 5]]]

numpy 中的结果:

In [5]:

  1. n = np.arange(9).reshape((3,3))
  2.  
  3. print n[n > 4]
  1. [5 6 7 8]

要想像 numpy 一样得到正确结果,我们需要使用这样的方法:

In [6]:

  1. print t[(t > 4).nonzero()].eval()
  1. [5 6 7 8]

使用索引进行赋值

tensor 模块不支持直接使用索引赋值,例如 a[5] = b, a[5]+=b 等是不允许的。

不过可以考虑用 set_subtensorinc_subtensor 来实现类似的功能:

T.set_subtensor(x, y)

实现类似 r[10:] = 5 的功能:

In [7]:

  1. r = T.vector()
  2.  
  3. new_r = T.set_subtensor(r[10:], 5)

T.inc_subtensor(x, y)

实现类似 r[10:] += 5 的功能:

In [8]:

  1. r = T.vector()
  2.  
  3. new_r = T.inc_subtensor(r[10:], 5)

原文: https://nbviewer.jupyter.org/github/lijin-THU/notes-python/blob/master/09-theano/09.16-tensor-indexing.ipynb