符号微分
图 D-1 展示了符号微分是如何运行在相当简单的函数上的,。该函数的计算图如图的左边所示。通过符号微分,我们可得到图的右部分,它代表了 ,相似地也可得到关于y
的导数。
概算法先获得叶子节点的偏导数。常数 5 返回常数 0,因为常数的导数总是 0。变量x
返回常数 1,变量y
返回常数 0,因为 (如果我们找关于y
的偏导数,那它将反过来)。
现在我们移动到计算图的相乘节点处,代数告诉我们,u
和v
相乘后的导数为 。因此我们可以构造有图中大的部分,代表0 × x + y × 1
。
最后我们往上走到计算图的相加节点处,正如 5 条规则里提到的,和的导数等于导数的和。所以我们只需要创建一个相加节点,连接我们已经计算出来的部分。我们可以得到正确的偏导数,即:。
然而,这个过程可简化。对该图应用一些微不足道的剪枝步骤,可以去掉所有不必要的操作,然后我们可以得到一个小得多的只有一个节点的偏导计算图:。
在这个例子里,简化操作是相当简单的,但对更复杂的函数来说,符号微分会产生一个巨大的计算图,该图可能很难去简化,以导致次优的性能。更重要的是,符号微分不能处理由任意代码定义的函数,例如,如下已在第 9 章讨论过的函数:
def my_func(a, b):
z = 0
for i in range(100):
z = a * np.cos(z + i) + z * np.sin(b - i)
return z