元算子:通过元算子实现自己的卷积层

元算子是jittor的关键概念,元算子的层次结构如下所示。

元算子由重索引算子,重索引化简算子和元素级算子组成。重索引算子,重索引化简算子都是一元算子。 重索引算子是其输入和输出之间的一对多映射。重索引简化算子是多对一映射。广播,填补, 切分算子是常见的重新索引算子。 而化简,累乘,累加算子是常见的索引化简算子。 元素级算子是元算子的第三部分,与前两个相比,元素算级子可能包含多个输入。 但是元素级算子的所有输入和输出形状必须相同,它们是一对一映射的。 例如,两个变量的加法是一个二进制的逐元素算子。

元算子:通过元算子实现自己的卷积层 - 图1



元算子的层级结构。元算子包含三类算子,重索引算子,重索引化简算子,元素级算子。元算
子的反向传播算子还是元算子。元算子可以组成常用的深度学习算子。而这些深度学习算子又
可以进一步组成深度学习模型。


在上一个教程中,我们演示了如何通过三个元算子实现矩阵乘法:

  1. def matmul(a, b):
  2. (n, m), k = a.shape, b.shape[-1]
  3. a = a.broadcast([n,m,k], dims=[2])
  4. b = b.broadcast([n,m,k], dims=[0])
  5. return (a*b).sum(dim=1)

在本教程中,我们将展示如何使用元算子实现自己的卷积。

首先,让我们实现一个朴素的Python卷积:

  1. import numpy as np
  2. import os
  3. def conv_naive(x, w):
  4. N,H,W,C = x.shape
  5. Kh, Kw, _C, Kc = w.shape
  6. assert C==_C, (x.shape, w.shape)
  7. y = np.zeros([N,H-Kh+1,W-Kw+1,Kc])
  8. for i0 in range(N):
  9. for i1 in range(H-Kh+1):
  10. for i2 in range(W-Kw+1):
  11. for i3 in range(Kh):
  12. for i4 in range(Kw):
  13. for i5 in range(C):
  14. for i6 in range(Kc):
  15. if i1-i3<0 or i2-i4<0 or i1-i3>=H or i2-i4>=W: continue
  16. y[i0, i1, i2, i6] += x[i0, i1 + i3, i2 + i4, i5] * w[i3,i4,i5,i6]
  17. return y

然后,让我们下载一个猫的图像,并使用conv_naive实现一个简单的水平滤波器。

  1. # %matplotlib inline
  2. import pylab as pl
  3. img_path="/tmp/cat.jpg"
  4. if not os.path.isfile(img_path):
  5. !wget -O - 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/4f/Felis_silvestris_catus_lying_on_rice_straw.jpg/220px-Felis_silvestris_catus_lying_on_rice_straw.jpg' > $img_path
  6. img = pl.imread(img_path)
  7. pl.subplot(121)
  8. pl.imshow(img)
  9. kernel = np.array([
  10. [-1, -1, -1],
  11. [0, 0, 0],
  12. [1, 1, 1],
  13. ])
  14. pl.subplot(122)
  15. x = img[np.newaxis,:,:,:1].astype("float32")
  16. w = kernel[:,:,np.newaxis,np.newaxis].astype("float32")
  17. y = conv_naive(x, w)
  18. print (x.shape, y.shape) # shape exists confusion
  19. pl.imshow(y[0,:,:,0])

看起来不错,我们的naive_conv运作良好。现在让我们用jittor替换我们的朴素实现。

  1. import jittor as jt
  2. def conv(x, w):
  3. N,H,W,C = x.shape
  4. Kh, Kw, _C, Kc = w.shape
  5. assert C==_C
  6. xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
  7. 'i0', # Nid
  8. 'i1+i3', # Hid+Khid
  9. 'i2+i4', # Wid+KWid
  10. 'i5', # Cid|
  11. ])
  12. ww = w.broadcast_var(xx)
  13. yy = xx*ww
  14. y = yy.sum([3,4,5]) # Kh, Kw, c
  15. return y
  16. # Let's disable tuner. This will cause jittor not to use mkl for convolution
  17. jt.flags.enable_tuner = 0
  18. jx = jt.array(x)
  19. jw = jt.array(w)
  20. jy = conv(jx, jw).fetch_sync()
  21. print (jx.shape, jy.shape)
  22. pl.imshow(jy[0,:,:,0])

他们的结果看起来一样。那么它们的性能如何?

  1. %time y = conv_naive(x, w)
  2. %time jy = conv(jx, jw).fetch_sync()

可以看出jittor的实现要快得多。 那么,为什么这两个实现在数学上等效,而jittor的实现运行速度更快? 我们将逐步进行解释:

首先,让我们看一下jt.reindex的帮助文档。

  1. help(jt.reindex)

遵循该文档,我们可以扩展重索引操作以便更好地理解:

  1. xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
  2. 'i0', # Nid
  3. 'i1+i3', # Hid+Khid
  4. 'i2+i4', # Wid+KWid
  5. 'i5', # Cid
  6. ])
  7. ww = w.broadcast_var(xx)
  8. yy = xx*ww
  9. y = yy.sum([3,4,5]) # Kh, Kw, c

扩展后:

  1. shape = [N,H+Kh-1,W+Kw-1,Kh,Kw,C,Kc]
  2. # expansion of x.reindex
  3. xx = np.zeros(shape, x.dtype)
  4. for i0 in range(shape[0]):
  5. for i1 in range(shape[1]):
  6. for i2 in range(shape[2]):
  7. for i3 in range(shape[3]):
  8. for i4 in range(shape[4]):
  9. for i5 in range(shape[5]):
  10. for i6 in range(shape[6]):
  11. if is_overflow(i0,i1,i2,i3,i4,i5,i6):
  12. xx[i0,i1,...,in] = 0
  13. else:
  14. xx[i0,i1,i2,i3,i4,i5,i6] = x[i0,i1+i3,i2+i4,i5]
  15. # expansion of w.broadcast_var(xx)
  16. ww = np.zeros(shape, x.dtype)
  17. for i0 in range(shape[0]):
  18. for i1 in range(shape[1]):
  19. for i2 in range(shape[2]):
  20. for i3 in range(shape[3]):
  21. for i4 in range(shape[4]):
  22. for i5 in range(shape[5]):
  23. for i6 in range(shape[6]):
  24. ww[i0,i1,i2,i3,i4,i5,i6] = w[i3,i4,i5,i6]
  25. # expansion of xx*ww
  26. yy = np.zeros(shape, x.dtype)
  27. for i0 in range(shape[0]):
  28. for i1 in range(shape[1]):
  29. for i2 in range(shape[2]):
  30. for i3 in range(shape[3]):
  31. for i4 in range(shape[4]):
  32. for i5 in range(shape[5]):
  33. for i6 in range(shape[6]):
  34. yy[i0,i1,i2,i3,i4,i5,i6] = xx[i0,i1,i2,i3,i4,i5,i6] * ww[i0,i1,i2,i3,i4,i5,i6]
  35. # expansion of yy.sum([3,4,5])
  36. shape2 = [N,H-Kh+1,W-Kw+1,Kc]
  37. y = np.zeros(shape2, x.dtype)
  38. for i0 in range(shape[0]):
  39. for i1 in range(shape[1]):
  40. for i2 in range(shape[2]):
  41. for i3 in range(shape[3]):
  42. for i4 in range(shape[4]):
  43. for i5 in range(shape[5]):
  44. for i6 in range(shape[6]):
  45. y[i0,i1,i2,i6] += yy[i0,i1,i2,i3,i4,i5,i6]

循环融合后:

  1. shape2 = [N,H-Kh+1,W-Kw+1,Kc]
  2. y = np.zeros(shape2, x.dtype)
  3. for i0 in range(shape[0]):
  4. for i1 in range(shape[1]):
  5. for i2 in range(shape[2]):
  6. for i3 in range(shape[3]):
  7. for i4 in range(shape[4]):
  8. for i5 in range(shape[5]):
  9. for i6 in range(shape[6]):
  10. if not is_overflow(i0,i1,i2,i3,i4,i5,i6):
  11. y[i0,i1,i2,i6] += x[i0,i1+i3,i2+i4,i5] * w[i3,i4,i5,i6]

这是就元算子的优化技巧,它可以将多个算子融合为一个复杂的融合算子,包括许多卷积的变化(例如group conv,separate conv等)。

jittor会尝试将融合算子优化得尽可能快。 让我们尝试一些优化(将形状作为常量编译到内核中),并编译到底层的c++内核代码中。

  1. jt.flags.compile_options={"compile_shapes":1}
  2. with jt.profile_scope() as report:
  3. jy = conv(jx, jw).fetch_sync()
  4. jt.flags.compile_options={}
  5. print(f"Time: {float(report[1][4])/1e6}ms")
  6. with open(report[1][1], 'r') as f:
  7. print(f.read())

比之前的实现还要更快! 从输出中我们可以看一看func0的函数定义,这是我们卷积内核的主要代码,该内核代码是即时生成的。因为编译器知道内核的形状,所以使用了更多的优化方法。

在这个教程中,Jittor简单演示了元算子的使用,并不是真正的性能测试,所以使用了比较小的数据规模进行测试,如果需要性能测试,请打开jt.flags.enable_tuner = 1,会启动使用专门的硬件库加速。