Theano 实例:人工神经网络

神经网络的模型可以参考 UFLDL 的教程,这里不做过多描述。

http://ufldl.stanford.edu/wiki/index.php/%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C

In [1]:

  1. import theano
  2. import theano.tensor as T
  3.  
  4. import numpy as np
  5. from load import mnist
  1. Using gpu device 1: Tesla K10.G2.8GB (CNMeM is disabled)

我们在这里使用一个简单的三层神经网络:输入 - 隐层 - 输出。

对于网络的激活函数,隐层用 sigmoid 函数,输出层用 softmax 函数,其模型如下:

\begin{aligned} h & = \sigma (W_h X) \ o & = \text{softmax} (W_o h)\end{aligned}

In [2]:

  1. def model(X, w_h, w_o):
  2. """
  3. input:
  4. X: input data
  5. w_h: hidden unit weights
  6. w_o: output unit weights
  7. output:
  8. Y: probability of y given x
  9. """
  10. # 隐层
  11. h = T.nnet.sigmoid(T.dot(X, w_h))
  12. # 输出层
  13. pyx = T.nnet.softmax(T.dot(h, w_o))
  14. return pyx

使用随机梯度下降的方法进行训练:

In [3]:

  1. def sgd(cost, params, lr=0.05):
  2. """
  3. input:
  4. cost: cost function
  5. params: parameters
  6. lr: learning rate
  7. output:
  8. update rules
  9. """
  10. grads = T.grad(cost=cost, wrt=params)
  11. updates = []
  12. for p, g in zip(params, grads):
  13. updates.append([p, p - g * lr])
  14. return updates

对于 MNIST 手写数字的问题,我们使用一个 784 × 625 × 10 即输入层大小为 784,隐层大小为 625,输出层大小为 10 的神经网络来模拟,最后的输出表示数字为 09 的概率。

为了对权重进行更新,我们需要将权重设为 shared 变量:

In [4]:

  1. def floatX(X):
  2. return np.asarray(X, dtype=theano.config.floatX)
  3.  
  4. def init_weights(shape):
  5. return theano.shared(floatX(np.random.randn(*shape) * 0.01))

因此变量初始化为:

In [5]:

  1. X = T.matrix()
  2. Y = T.matrix()
  3.  
  4. w_h = init_weights((784, 625))
  5. w_o = init_weights((625, 10))

模型输出为:

In [6]:

  1. py_x = model(X, w_h, w_o)

预测的结果为:

In [7]:

  1. y_x = T.argmax(py_x, axis=1)

模型的误差函数为:

In [8]:

  1. cost = T.mean(T.nnet.categorical_crossentropy(py_x, Y))

更新规则为:

In [9]:

  1. updates = sgd(cost, [w_h, w_o])

定义训练和预测的函数:

In [10]:

  1. train = theano.function(inputs=[X, Y], outputs=cost, updates=updates, allow_input_downcast=True)
  2. predict = theano.function(inputs=[X], outputs=y_x, allow_input_downcast=True)

训练:

导入 MNIST 数据:

In [11]:

  1. trX, teX, trY, teY = mnist(onehot=True)

训练 100 轮,正确率为 0.956:

In [12]:

  1. for i in range(100):
  2. for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)):
  3. cost = train(trX[start:end], trY[start:end])
  4. print "{0:03d}".format(i), np.mean(np.argmax(teY, axis=1) == predict(teX))
  1. 000 0.7028
  2. 001 0.8285
  3. 002 0.8673
  4. 003 0.883
  5. 004 0.89
  6. 005 0.895
  7. 006 0.8984
  8. 007 0.9017
  9. 008 0.9047
  10. 009 0.907
  11. 010 0.9089
  12. 011 0.9105
  13. 012 0.9127
  14. 013 0.914
  15. 014 0.9152
  16. 015 0.9159
  17. 016 0.9169
  18. 017 0.9173
  19. 018 0.918
  20. 019 0.9185
  21. 020 0.919
  22. 021 0.9197
  23. 022 0.9201
  24. 023 0.9205
  25. 024 0.9206
  26. 025 0.9212
  27. 026 0.9219
  28. 027 0.9228
  29. 028 0.9228
  30. 029 0.9229
  31. 030 0.9236
  32. 031 0.9244
  33. 032 0.925
  34. 033 0.9255
  35. 034 0.9263
  36. 035 0.927
  37. 036 0.9274
  38. 037 0.9278
  39. 038 0.928
  40. 039 0.9284
  41. 040 0.9289
  42. 041 0.9294
  43. 042 0.9298
  44. 043 0.9302
  45. 044 0.9311
  46. 045 0.932
  47. 046 0.9325
  48. 047 0.9332
  49. 048 0.934
  50. 049 0.9347
  51. 050 0.9354
  52. 051 0.9358
  53. 052 0.9365
  54. 053 0.9372
  55. 054 0.9377
  56. 055 0.9385
  57. 056 0.9395
  58. 057 0.9399
  59. 058 0.9405
  60. 059 0.9411
  61. 060 0.9416
  62. 061 0.9422
  63. 062 0.9427
  64. 063 0.9429
  65. 064 0.9431
  66. 065 0.9438
  67. 066 0.9444
  68. 067 0.9446
  69. 068 0.9449
  70. 069 0.9453
  71. 070 0.9458
  72. 071 0.9462
  73. 072 0.9469
  74. 073 0.9475
  75. 074 0.9474
  76. 075 0.9476
  77. 076 0.948
  78. 077 0.949
  79. 078 0.9497
  80. 079 0.95
  81. 080 0.9503
  82. 081 0.9507
  83. 082 0.9507
  84. 083 0.9515
  85. 084 0.9519
  86. 085 0.9521
  87. 086 0.9523
  88. 087 0.9529
  89. 088 0.9536
  90. 089 0.9538
  91. 090 0.9542
  92. 091 0.9545
  93. 092 0.9544
  94. 093 0.9546
  95. 094 0.9547
  96. 095 0.9549
  97. 096 0.9552
  98. 097 0.9554
  99. 098 0.9557
  100. 099 0.9562

原文: https://nbviewer.jupyter.org/github/lijin-THU/notes-python/blob/master/09-theano/09.11-net-on-mnist.ipynb