用优雅的方式计算Transformers自注意力层梯度
用优雅的方式计算Transformers自注意力层梯度
1. 引言
在深度学习中,我们经常使用 backward 方法自动计算梯度,却很少真正动手推导。然而在一些需要手动优化的场景中,理解梯度的计算过程是不可或缺的。
可惜的是,目前相关资料不仅稀缺,内容质量也参差不齐。如果仅依赖繁琐的逐元素(element-wise)连加符号进行推导,过程往往显得不够直观和优雅。因此,本文将以 Transformer 中的自注意力层为例,使用微分法计算,力求严谨、清晰。
我们为了简化问题,这里不考虑多头注意力。对于一般的注意力层,我们有
其中。设损失函数为,已知上游梯度,我们考虑求,以及其他中间所有变量的梯度。
2. 基本概念
2.1 梯度
对于标量函数和矩阵变量,梯度是一个与形状相同的矩阵,满足
直观地理解,梯度矩阵指示了使函数值增长最快的方向(最陡上升方向)。为了使推导简洁,我们在后文中统一记为。
2.2 微分
在一元微积分中,我们知道函数的微分是。
推广到多元函数,对于标量函数和矩阵变量,全微分定义为所有自变量变化引起函数值变化的线性主要部分,即
事实上,这本质是梯度矩阵与变化量矩阵 的内积。利用矩阵迹(Trace)的性质,我们可以将其写成更紧凑的矩阵形式:结合前文定义的简写,此时我们得到了矩阵微分法最核心的联系公式:这给了我们一种全新的求解梯度的视角: 如果我们能通过运算法则求出,并将其整理为 的形式,那么根据一一对应关系,矩阵就是我们要找的梯度。3. 重要性质
3.1 迹
矩阵的迹定义为对角线元素之和。在矩阵求导中,迹主要起到了交换位置的作用。
最常用的性质如下:
- 转置不变性:
- 循环性质:这是最重要的性质,允许我们在乘积中循环移动矩阵的位置(前提是维度匹配):
注意:这里是循环移动,而不是随意交换。一般情况下 。 - 伴随性质:这其实是源于内积的伴随性质
3.2 哈达玛乘积
符号表示矩阵的逐元素乘积(Element-wise Product),即
我们会比较深入的利用它的如下几个性质。- 交换律结合律与分配律:
这是最基本的性质,和标量乘法一致: - 迹运算中的“游动”性质(最重要的推导工具)
在计算矩阵内积(即迹运算)时,哈达玛积中的项可以“跳”到另一边的矩阵上:这是因为在推导梯度时,这个性质允许我们将 旁边的系数(比如激活函数的导数)“移走”,从而凑出 的形式。 - 广播与矩阵乘法的结合律(行缩放性质)
设 是列向量, 是矩阵, 是任意矩阵(或向量),则有:这是因为因此,当参与运算的是列向量(用于行广播)时,哈达玛积可以表现出特殊的结合性。
3.3 微分
矩阵微分的运算法则与标量微积分高度相似,唯一的区别在于矩阵乘法不满足交换律,因此在求导时必须严格保持矩阵的左右顺序。
假设 是常数矩阵, 是变量矩阵:
- 加法法则:
- 乘法法则(注意顺序):
- 转置法则:
- 逐元素函数法则:
若 是对 逐元素应用函数(如 ReLU),则:其中 是对 中每个元素求导后组成的矩阵。
4. 问题解决
有了这些性质后,解决我们最开始提出的问题会变得非常简单。
4.1 P、V 的梯度
首先我们计算
因此从而4.2 S 的梯度
接下来我们要计算,会稍微麻烦一点。注意到
其中,元素全为 1,表示行和组成的矩阵。此时我们先取对数此时两边求微分有注意到正好满足列向量的行传播形式,因此因此注意到右边化简为这里用到了伴随性质。因此从而4.3 Q、K 的梯度
剩余比较简单,直接给出结果为
5. 总结常见的梯度
下面我们不加证明的给出常见的梯度公式,可作为练习。
5.1 线性函数
对于,,为行向量,设
此时其中为行向量。5.2 激活函数
对于,设
此时其中对于,设
此时其中。5.3 损失函数
对于,均方误差为
此时对于,,交叉熵损失为
此时若则