神经网络的梯度

考虑有如下 dd 层的神经网络

ht=ft(ht1) and y=fdf1(x)\mathbf{h}^{t}=f_{t}\left(\mathbf{h}^{t-1}\right) \quad \text { and } \quad y=\ell \circ f_{d} \circ \ldots \circ f_{1}(\mathbf{x})

计算损失 \ell 关于参数 Wt\mathbf{W}^{t} 的梯度

Wt=hdhdhd1ht+1hthtWt\frac{\partial \ell}{\partial \mathbf{W}^{t}}=\frac{\partial \ell}{\partial \mathbf{h}^{d}} \frac{\partial \mathbf{h}^{d}}{\partial \mathbf{h}^{d-1}} \ldots \frac{\partial \mathbf{h}^{t+1}}{\partial \mathbf{h}^{t}} \frac{\partial \mathbf{h}^{t}}{\partial \mathbf{W}^{t}}

进行了 dtd-t 次矩阵乘法

数值稳定性的两个常见问题

梯度爆炸

1.51004×10171.5^{100} \approx 4 \times 10^{17}

  • 值超出值域(infinity)
    • 对于 16 位浮点数尤为严重(数值区间 [6×105,6×104][6\times10^{-5},6\times10^{4}]
  • 对学习率敏感
    • 如果学习率太大 → 大参数值 → 更大的梯度
    • 如果学习率太小 → 训练无进展
    • 可能需要在训练过程中不断调整学习率

梯度消失

0.81002×10100.8^{100} \approx 2 \times 10^{-10}

  • 梯度值变成 0
    • 对 16 位浮点数尤为严重
  • 训练没有进展
    • 不管如何选择学习率
  • 对于底部层尤为严重
    • 仅仅顶部层训练的较好
    • 无法让神经网络更深

以 MLP 为例

加入如下 MLP(为了简单省略了偏移)

ft(ht1)=σ(Wtht1)σ 是激活函数 f_{t}\left(\mathbf{h}^{t-1}\right)=\sigma\left(\mathbf{W}^{t} \mathbf{h}^{t-1}\right) \quad \sigma \text { 是激活函数 }

htht1=diag(σ(Wtht1))(W^t)Tσ 是 σ 的导数函数 \frac{\partial \mathbf{h}^{t}}{\partial \mathbf{h}^{t-1}}=\operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{t} \mathbf{h}^{t-1}\right)\right)\left(\hat{W}^{t}\right)^{T} \quad \sigma^{\prime} \text { 是 } \sigma \text { 的导数函数 }

i=td1hi+1hi=i=td1diag(σ(Wihi1))(Wi)T\prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T}

梯度爆炸

使用 ReLU 作为激活函数

σ(x)=max(0,x) and σ(x)={1 if x>00 otherwise \sigma(x)=\max (0, x) \quad \text { and } \quad \sigma^{\prime}(x)= \begin{cases}1 & \text { if } x>0 \\ 0 & \text { otherwise }\end{cases}

i=td1hi+1hi=i=td1diag(σ(Wihi1))(Wi)T\prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T} 的一些元素会来自 i=td1(Wi)T\prod_{i=t}^{d-1} \left(W^{i}\right)^{T}

如果 dtd-t 很大(层数很深),值可能会很大

梯度消失

使用 sigmoid 作为激活函数

σ(x)=11+exσ(x)=σ(x)(1σ(x))\sigma(x)=\frac{1}{1+e^{-x}} \quad \sigma^{\prime}(x)=\sigma(x)(1-\sigma(x))

Untitled

i=td1hi+1hi=i=td1diag(σ(Wihi1))(Wi)T\prod_{i=t}^{d-1} \frac{\partial \mathbf{h}^{i+1}}{\partial \mathbf{h}^{i}}=\prod_{i=t}^{d-1} \operatorname{diag}\left(\sigma^{\prime}\left(\mathbf{W}^{i} \mathbf{h}^{i-1}\right)\right)\left(W^{i}\right)^{T} 可能是 dtd-t 个小数值的乘积

总结

  • 当数值过大或者过小时会导致数值问题
  • 常发生在深度模型中,因为其会对 nn 个数累乘