执行阶段
这部分要短得多,更简单。 首先,我们加载 MNIST。 我们可以像之前的章节那样使用 ScikitLearn,但是 TensorFlow 提供了自己的助手来获取数据,将其缩放(0 到 1 之间),将它洗牌,并提供一个简单的功能来一次加载一个小批量:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
现在我们定义我们要运行的迭代数,以及小批量的大小:
n_epochs = 10001
batch_size = 50
现在我们去训练模型:
with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
for iteration in range(mnist.train.num_examples // batch_size):
X_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
save_path = saver.save(sess, "./my_model_final.ckpt")
该代码打开一个 TensorFlow 会话,并运行初始化所有变量的init
节点。 然后它运行的主要训练循环:在每个时期,通过一些小批次的对应于训练集的大小的代码进行迭代。 每个小批量通过next_batch()
方法获取,然后代码简单地运行训练操作,为当前的小批量输入数据和目标提供。 接下来,在每个时期结束时,代码评估最后一个小批量和完整训练集上的模型,并打印出结果。 最后,模型参数保存到磁盘。