Theano Example: Softmax Regression
Downloading and Importing the MNIST Dataset
The MNIST dataset is a collection of handwritten digits that is now used as a standard benchmark for evaluating machine-learning algorithms.
Here is a script that downloads and unpacks the data:
In [1]:
%%file download_mnist.py
import os
import os.path
import urllib
import gzip
import shutil

if not os.path.exists('mnist'):
    os.mkdir('mnist')

def download_and_gzip(name):
    # fetch the gzipped file if it is not already there
    if not os.path.exists(name + '.gz'):
        urllib.urlretrieve('http://yann.lecun.com/exdb/' + name + '.gz', name + '.gz')
    # unpack it next to the archive
    if not os.path.exists(name):
        with gzip.open(name + '.gz', 'rb') as f_in, open(name, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)

download_and_gzip('mnist/train-images-idx3-ubyte')
download_and_gzip('mnist/train-labels-idx1-ubyte')
download_and_gzip('mnist/t10k-images-idx3-ubyte')
download_and_gzip('mnist/t10k-labels-idx1-ubyte')
Overwriting download_mnist.py
Run the script to download and unpack the data:
In [2]:
%run download_mnist.py
Use the following script to load the MNIST data; the original source is:
https://github.com/Newmu/Theano-Tutorials/blob/master/load.py
In [3]:
%%file load.py
import numpy as np
import os

datasets_dir = './'

def one_hot(x, n):
    # convert a vector of integer labels into a (len(x), n) one-hot matrix
    if type(x) == list:
        x = np.array(x)
    x = x.flatten()
    o_h = np.zeros((len(x), n))
    o_h[np.arange(len(x)), x] = 1
    return o_h

def mnist(ntrain=60000, ntest=10000, onehot=True):
    data_dir = os.path.join(datasets_dir, 'mnist/')

    # the IDX files start with a header (16 bytes for images,
    # 8 bytes for labels) that we skip over
    fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    trX = loaded[16:].reshape((60000, 28 * 28)).astype(float)

    fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    trY = loaded[8:].reshape((60000))

    fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    teX = loaded[16:].reshape((10000, 28 * 28)).astype(float)

    fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    teY = loaded[8:].reshape((10000))

    # scale pixel values from [0, 255] down to [0, 1]
    trX = trX / 255.
    teX = teX / 255.

    trX = trX[:ntrain]
    trY = trY[:ntrain]
    teX = teX[:ntest]
    teY = teY[:ntest]

    if onehot:
        trY = one_hot(trY, 10)
        teY = one_hot(teY, 10)
    else:
        trY = np.asarray(trY)
        teY = np.asarray(teY)

    return trX, teX, trY, teY
Overwriting load.py
Softmax Regression
Softmax regression is a generalization of logistic regression: where logistic regression handles two-class problems, softmax regression handles N-class problems. Logistic regression outputs the probability that the label is 1 (which also determines the probability that it is 0); correspondingly, for an N-class problem softmax outputs a probability for each class.
For details, see the UFLDL tutorial:
http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92
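Concretely, softmax maps a vector of $N$ real-valued scores $z$ to a probability distribution over the $N$ classes; this is the standard definition:

$$\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}, \qquad i = 1, \dots, N.$$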
In [4]:
import theano
from theano import tensor as T
import numpy as np
from load import mnist

Using gpu device 1: Tesla C2075 (CNMeM is disabled)
Now let's look at the implementation in detail.
Of these two helper functions, the first converts data to the dtype used for GPU computation, and the second initializes the weights:
In [5]:
def floatX(X):
    # cast to theano.config.floatX (float32 when running on the GPU)
    return np.asarray(X, dtype=theano.config.floatX)

def init_weights(shape):
    # small random Gaussian values, wrapped in a Theano shared variable
    return theano.shared(floatX(np.random.randn(*shape) * 0.01))
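As a quick illustrative check (not part of the original notebook; w_demo is a hypothetical name), the shared variable holds a float32 matrix of the requested shape:

w_demo = init_weights((784, 10))  # hypothetical demo variable
print w_demo.get_value().shape, w_demo.get_value().dtype
# (784, 10) float32  -- assuming theano.config.floatX is float32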
The softmax model is already implemented in Theano:
In [6]:
A = T.matrix()
B = T.nnet.softmax(A)
test_softmax = theano.function([A], B)
a = floatX(np.random.rand(3, 4))
b = test_softmax(a)
print b.shape
# row sums
print b.sum(1)

(3, 4)
[ 1.00000012  1.          1.        ]
The softmax function applies softmax normalization to the matrix row by row.
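As a sanity check, here is a small NumPy sketch (illustrative, not from the original notebook) that reproduces the same row-wise normalization and compares it against Theano's result for the matrix a above:

def np_softmax(x):
    # subtract the row max for numerical stability, then normalize each row
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

print np.allclose(np_softmax(a), test_softmax(a), atol=1e-6)  # True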
So our model is:
In [7]:
def model(X, w):
    # linear scores followed by row-wise softmax: p(y | x) = softmax(xw)
    return T.nnet.softmax(T.dot(X, w))
Load the data:
In [8]:
trX, teX, trY, teY = mnist(onehot=True)
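Given how load.py parses the files, the returned arrays have the following shapes (the labels are one-hot matrices because of onehot=True):

print trX.shape, trY.shape  # (60000, 784) (60000, 10)
print teX.shape, teY.shape  # (10000, 784) (10000, 10)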
Define the symbolic variables and initialize the weights:
In [9]:
X = T.fmatrix()
Y = T.fmatrix()
w = init_weights((784, 10))  # 28 * 28 = 784 pixels in, 10 digit classes out
Define the model output and the prediction:
In [10]:
py_x = model(X, w)               # class probabilities, one row per sample
y_pred = T.argmax(py_x, axis=1)  # predicted class: the most probable one
The loss function is the multiclass cross-entropy, which is also predefined in Theano.
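For a one-hot target row $y$ and a predicted distribution $p$ over the 10 classes, the cross-entropy of a single sample is

$$\ell(y, p) = -\sum_{i=1}^{10} y_i \log p_i,$$

and the cost below averages this over the mini-batch: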
In [11]:
cost = T.mean(T.nnet.categorical_crossentropy(py_x, Y))
gradient = T.grad(cost=cost, wrt=w)
update = [[w, w - gradient * 0.05]]  # gradient descent with learning rate 0.05
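The update list implements one step of plain gradient descent on $w$ with learning rate $\eta = 0.05$:

$$w \leftarrow w - \eta \, \nabla_w \mathrm{cost}.$$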
Compile the train and predict functions:
In [12]:
# allow_input_downcast=True lets float64 NumPy arrays be cast down to float32
train = theano.function(inputs=[X, Y], outputs=cost, updates=update, allow_input_downcast=True)
predict = theano.function(inputs=[X], outputs=y_pred, allow_input_downcast=True)
Training for 100 iterations brings the test-set accuracy to about 0.925:
In [13]:
for i in range(100):
    # mini-batches of 128 samples; zip drops the final partial batch
    for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)):
        cost = train(trX[start:end], trY[start:end])
    print "{0:03d}".format(i), np.mean(np.argmax(teY, axis=1) == predict(teX))
000 0.8862
001 0.8985
002 0.9042
003 0.9084
004 0.9104
005 0.9121
006 0.9121
007 0.9142
008 0.9158
009 0.9163
010 0.9162
011 0.9166
012 0.9171
013 0.9176
014 0.9182
015 0.9182
016 0.9184
017 0.9188
018 0.919
019 0.919
020 0.9194
021 0.9201
022 0.9204
023 0.9203
024 0.9205
025 0.9207
026 0.9207
027 0.9209
028 0.9214
029 0.9213
030 0.9212
031 0.9211
032 0.9217
033 0.9217
034 0.9217
035 0.922
036 0.9222
037 0.922
038 0.922
039 0.9218
040 0.9219
041 0.9223
042 0.9225
043 0.9226
044 0.9227
045 0.9225
046 0.9227
047 0.9231
048 0.9231
049 0.9231
050 0.9232
051 0.9232
052 0.9231
053 0.9231
054 0.9233
055 0.9233
056 0.9237
057 0.9239
058 0.9239
059 0.9239
060 0.924
061 0.9242
062 0.9242
063 0.9243
064 0.9243
065 0.9244
066 0.9244
067 0.9244
068 0.9245
069 0.9244
070 0.9244
071 0.9245
072 0.9244
073 0.9243
074 0.9243
075 0.9244
076 0.9243
077 0.9242
078 0.9244
079 0.9244
080 0.9243
081 0.9242
082 0.9239
083 0.9241
084 0.9242
085 0.9243
086 0.9244
087 0.9243
088 0.9243
089 0.9244
090 0.9246
091 0.9246
092 0.9246
093 0.9247
094 0.9246
095 0.9246
096 0.9246
097 0.9246
098 0.9246
099 0.9248
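Once training finishes, the compiled predict function can classify new images directly. A minimal sketch reusing the variables above (the printed values depend on the trained weights):

# predicted digits for the first five test images, next to the true labels
print predict(teX[:5])
print np.argmax(teY[:5], axis=1)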