7.2. 梯度下降和随机梯度下降
在本节中,我们将介绍梯度下降(gradientdescent)的工作原理。虽然梯度下降在深度学习中很少被直接使用,但理解梯度的意义以及沿着梯度反方向更新自变量可能降低目标函数值的原因是学习后续优化算法的基础。随后,我们将引出随机梯度下降(stochasticgradient descent)。
7.2.1. 一维梯度下降
我们先以简单的一维梯度下降为例,解释梯度下降算法可能降低目标函数值的原因。假设连续可导的函数
的输入和输出都是标量。给定绝对值足够小的数 ,根据泰勒展开公式(参见附录中“数学基础”一节),我们得到以下的近似:
这里
是函数 在 处的梯度。一维函数的梯度是一个标量,也称导数。
接下来,找到一个常数
,使得 足够小,那么可以将 替换为 并得到
如果导数
,那么 ,所以
这意味着,如果通过
来迭代
,函数 的值可能会降低。因此在梯度下降中,我们先选取一个初始值 和常数 ,然后不断通过上式来迭代 ,直到达到停止条件,例如 的值已足够小或迭代次数已达到某个值。
下面我们以目标函数
为例来看一看梯度下降是如何工作的。虽然我们知道最小化 的解为 ,这里依然使用这个简单函数来观察 是如何被迭代的。首先,导入本节实验所需的包或模块。
- In [1]:
- %matplotlib inline
- import d2lzh as d2l
- import math
- from mxnet import nd
- import numpy as np
接下来使用
作为初始值,并设 。使用梯度下降对 迭代10次,可见最终 的值较接近最优解。
- In [2]:
- def gd(eta):
- x = 10
- results = [x]
- for i in range(10):
- x -= eta * 2 * x # f(x) = x * x的导数为f'(x) = 2 * x
- results.append(x)
- print('epoch 10, x:', x)
- return results
- res = gd(0.2)
- epoch 10, x: 0.06046617599999997
下面将绘制出自变量
的迭代轨迹。
- In [3]:
- def show_trace(res):
- n = max(abs(min(res)), abs(max(res)), 10)
- f_line = np.arange(-n, n, 0.1)
- d2l.set_figsize()
- d2l.plt.plot(f_line, [x * x for x in f_line])
- d2l.plt.plot(res, [x * x for x in res], '-o')
- d2l.plt.xlabel('x')
- d2l.plt.ylabel('f(x)')
- show_trace(res)
7.2.2. 学习率
上述梯度下降算法中的正数
通常叫作学习率。这是一个超参数,需要人工设定。如果使用过小的学习率,会导致 更新缓慢从而需要更多的迭代才能得到较好的解。
下面展示使用学习率
时自变量 的迭代轨迹。可见,同样迭代10次后,当学习率过小时,最终 的值依然与最优解存在较大偏差。
- In [4]:
- show_trace(gd(0.05))
- epoch 10, x: 3.4867844009999995
如果使用过大的学习率,
可能会过大从而使前面提到的一阶泰勒展开公式不再成立:这时我们无法保证迭代 会降低 的值。
举个例子,当设学习率
时,可以看到 不断越过(overshoot)最优解 并逐渐发散。
- In [5]:
- show_trace(gd(1.1))
- epoch 10, x: 61.917364224000096
7.2.3. 多维梯度下降
在了解了一维梯度下降之后,我们再考虑一种更广义的情况:目标函数的输入为向量,输出为标量。假设目标函数
的输入是一个 维向量 。目标函数 有关 的梯度是一个由 个偏导数组成的向量:
为表示简洁,我们用
代替 。梯度中每个偏导数元素 代表着 在 有关输入 的变化率。为了测量 沿着单位向量 (即 )方向上的变化率,在多元微积分中,我们定义 在 上沿着 方向的方向导数为
依据方向导数性质 [1,14.6节定理三],以上方向导数可以改写为
方向导数
给出了 在 上沿着所有可能方向的变化率。为了最小化 ,我们希望找到 能被降低最快的方向。因此,我们可以通过单位向量 来最小化方向导数 。
由于
,其中 为梯度 和单位向量 之间的夹角,当 时, 取得最小值 。因此,当 在梯度方向 的相反方向时,方向导数 被最小化。因此,我们可能通过梯度下降算法来不断降低目标函数 的值:
同样,其中
(取正数)称作学习率。
下面我们构造一个输入为二维向量
和输出为标量的目标函数 。那么,梯度 。我们将观察梯度下降从初始位置 开始对自变量 的迭代轨迹。我们先定义两个辅助函数,第一个函数使用给定的自变量更新函数,从初始位置 开始迭代自变量 共20次,第二个函数对自变量 的迭代轨迹进行可视化。
- In [6]:
- def train_2d(trainer): # 本函数将保存在d2lzh包中方便以后使用
- x1, x2, s1, s2 = -5, -2, 0, 0 # s1和s2是自变量状态,本章后续几节会使用
- results = [(x1, x2)]
- for i in range(20):
- x1, x2, s1, s2 = trainer(x1, x2, s1, s2)
- results.append((x1, x2))
- print('epoch %d, x1 %f, x2 %f' % (i + 1, x1, x2))
- return results
- def show_trace_2d(f, results): # 本函数将保存在d2lzh包中方便以后使用
- d2l.plt.plot(*zip(*results), '-o', color='#ff7f0e')
- x1, x2 = np.meshgrid(np.arange(-5.5, 1.0, 0.1), np.arange(-3.0, 1.0, 0.1))
- d2l.plt.contour(x1, x2, f(x1, x2), colors='#1f77b4')
- d2l.plt.xlabel('x1')
- d2l.plt.ylabel('x2')
然后,观察学习率为
时自变量的迭代轨迹。使用梯度下降对自变量 迭代20次后,可见最终 的值较接近最优解 。
- In [7]:
- eta = 0.1
- def f_2d(x1, x2): # 目标函数
- return x1 ** 2 + 2 * x2 ** 2
- def gd_2d(x1, x2, s1, s2):
- return (x1 - eta * 2 * x1, x2 - eta * 4 * x2, 0, 0)
- show_trace_2d(f_2d, train_2d(gd_2d))
- epoch 20, x1 -0.057646, x2 -0.000073
7.2.4. 随机梯度下降
在深度学习里,目标函数通常是训练数据集中有关各个样本的损失函数的平均。设
是有关索引为 的训练数据样本的损失函数, 是训练数据样本数, 是模型的参数向量,那么目标函数定义为
目标函数在
处的梯度计算为
如果使用梯度下降,每次自变量迭代的计算开销为
,它随着 线性增长。因此,当训练数据样本数很大时,梯度下降每次迭代的计算开销很高。
随机梯度下降(stochastic gradientdescent,SGD)减少了每次迭代的计算开销。在随机梯度下降的每次迭代中,我们随机均匀采样的一个样本索引
,并计算梯度 来迭代 :
这里
同样是学习率。可以看到每次迭代的计算开销从梯度下降的 降到了常数 。值得强调的是,随机梯度 是对梯度 的无偏估计:
这意味着,平均来说,随机梯度是对梯度的一个良好的估计。
下面我们通过在梯度中添加均值为0的随机噪声来模拟随机梯度下降,以此来比较它与梯度下降的区别。
- In [8]:
- def sgd_2d(x1, x2, s1, s2):
- return (x1 - eta * (2 * x1 + np.random.normal(0.1)),
- x2 - eta * (4 * x2 + np.random.normal(0.1)), 0, 0)
- show_trace_2d(f_2d, train_2d(sgd_2d))
- epoch 20, x1 -0.229936, x2 -0.125876
可以看到,随机梯度下降中自变量的迭代轨迹相对于梯度下降中的来说更为曲折。这是由于实验所添加的噪声使模拟的随机梯度的准确度下降。在实际中,这些噪声通常指训练数据集中的无意义的干扰。
7.2.5. 小结
- 使用适当的学习率,沿着梯度反方向更新自变量可能降低目标函数值。梯度下降重复这一更新过程直到得到满足要求的解。
- 学习率过大或过小都有问题。一个合适的学习率通常是需要通过多次实验找到的。
- 当训练数据集的样本较多时,梯度下降每次迭代的计算开销较大,因而随机梯度下降通常更受青睐。
7.2.6. 练习
- 使用一个不同的目标函数,观察梯度下降和随机梯度下降中自变量的迭代轨迹。
- 在二维梯度下降的实验中尝试使用不同的学习率,观察并分析实验现象。
7.2.7. 参考文献
[1] Stewart, J. (2010). Calculus: early transcendentals. 7th ed. CengageLearning.