0. 前言
本文是个人原创专栏《大模型学习笔记 - 知乎》的第三十三篇,点击原文「链接」获取更加阅读体验。欢迎各位在评论区交流、指正,更需要您的关注、点赞 ^.^。
书接上回,继续围绕 MiniMax-01 技术报告 中值得关注的点进行精读,顺带着对已经略感遗忘的大模型知识进行复习。此篇笔记的内容将主要围绕 MiniMax-01 的层归一化来展开。
笔者的归纳总结:
- 对于层归一化的方法,MiniMax-01 选择了常规的 RMSNorm;
- 对于层归一化的位置,MiniMax-01 选择了当下并不流行的 Post-Norm 结构,这是经过实验对比后的选择;
- 对于 DeepNorm 的应用,笔者查看代码后的判断是,只采用了放大残差连接的部分。
强烈建议先阅读苏神的文章:浅谈Transformer的初始化、参数化与标准化,从二阶矩的角度清晰地介绍了参数初始化、Normalization、NTK参数化(Softmax Attention 的除以)是如何相互配合以保证深度模型的训练稳定性。
如果有需要,也可以阅读本专栏的第三篇文章:大模型结构基础(三):归一化技术的升级,对于 RMSNorm、DeepNorm,以及 Post-Norm 和 Pre-Norm 的对比进行了基本的介绍。本篇笔记会对其中的一些重点进行强调和补充。
1. 基础要点回顾
1.1 RMSNorm
仅借用苏神文章里的公式,强调一下 RMSNorm 究竟是什么操作。
在 Vanilla Transformer 中,下表 i 对应着 Batch 中第 i 个 Sequence,j 对应 Sequence 中第 j 个 Token,k 对应着 Token 某层向量表示的第 k 维。RMSNorm 的操作:
- 先对一个 Batch 中全部 Token 的向量表示,独立地做归一化;
- 再对一个 Batch 中全部 Token 的向量表示,共用 d 个系数,在每一维上进行不同尺度的放缩。
1.2 DeepNorm
建议阅读 DeepNorm 的论文,Arxiv 版链接,PAMI 版链接。
设计 DeepNorm 的原始出发点是提高 Post-Norm 的训练稳定性,论文先论证了 Post-Norm 训练不稳定的原因:
- 在训练初始阶段,model update 过快;
- 因 model update 剧烈,进一步引发层归一化的梯度消失。
DeepNorm 论文用模型在单步更新前后的输出变化的范数(上图中的)来刻画 Model Update 的剧烈程度,并以控制 Model 单步 Update 的剧烈程度至常量级作为目标,来推导 DeepNorm 的具体设计。
DeepNorm 最终选择了一种平衡使用放大残差连接和缩小参数初始化两种手段的设计。笔者这里想表达的是,DeepNorm 论文给出了一套适用于 Encoder/Deocder-only 结构或 Encoder-Decoder 结构的固定设计参数,但实际上可以有不同的具体设计,只要能够将 Model 单步 Update 的剧烈程度限制在常量级。
DeepNorm 论文给出的标准操作如下图所示:残差连接的放大系数为,参数初始化的缩小系数是。
1.3 Post-Norm vs Pre-Norm
DeepNorm 论文从 Model Update 的剧烈程度入手,分析了 Post-Norm 结构训练稳定性差的原因。相比之下,Post-Norm 对于残差连接效果的弱化更容易让人理解 Pre-Norm 为什么出现,相关内容在前言部分建议阅读的资料中都包含,笔者就不在本篇笔记中赘述了。
但是 Pre-Norm 也有缺点,在模型能力方面,Pre-Norm 要逊于 Post-Norm。MiniMax-01 技术报告采用的解释是:Pre-Norm 对于残差连接的保护,在一定程度上造成了模型有效深度的降低(笔者注:保护残差连接的手段太简单粗暴,过犹不及);DeepNorm 论文采用的解释是:底层的梯度整体趋向于大于顶层的梯度(笔者注:各层的等效学习率之间存在明显差异,各层参数更新的同步性不佳)。
2. MiniMax-01 的层归一化
如上图所示,无论是 Softmax Attention Layer 还是 Lightning Attention Layer(可参考:大模型结构基础(八):MiniMax-01 精读之 Hybrid Lightning Attention),MiniMax-01 都采用了:
- RMSNorm;
- Post-Norm,即层归一化在残差连接后;
- 残差连接根据系数的放大。
2.1 MiniMaxText01RMSNorm
这部分可自行对照代码与1.1节中的公式。
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
class MiniMaxText01RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
2.2 Post-Norm
MiniMax-01 对于 Post-Norm 的代码实现,是在 Pre-Norm 的代码基础上改造而来的。注意看 MiniMaxText01DecoderLayer 构造函数中层归一化的名称,input_layernorm 和 post_attention_layernorm,典型的 Pre-Norm 特征。仔细看完 forward 函数就会意识到,MiniMax-01 在首层的输入端仍然保留有一次 Pre-Norm;而且最后一层输出端的 Post-Norm 需要单独实现在 DecoderLayer 之外,事实也的确如此。
class MiniMaxText01DecoderLayer(nn.Module):
def __init__(self, config: MiniMaxText01Config, layer_idx: int):
...
self.input_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...
def forward(
...
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.postnorm:
residual = hidden_states
... self.self_attn ...
hidden_states = residual * self.layernorm_attention_alpha \
+ hidden_states * self.layernorm_attention_beta
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.postnorm:
residual = hidden_states
... self.block_sparse_moe ...
hidden_states = residual * self.layernorm_mlp_alpha \
+ hidden_states * self.layernorm_mlp_beta
...
2.3 一半的 DeepNorm
上方的代码中就包含了 MiniMax-01 对于 DeepNorm 的应用,下面单独抽出来。
hidden_states = residual * self.layernorm_attention_alpha \
+ hidden_states * self.layernorm_attention_beta
hidden_states = residual * self.layernorm_mlp_alpha \
+ hidden_states * self.layernorm_mlp_beta
layernorm_attention_alpha、layernorm_attention_beta、layernorm_mlp_alpha 和 layernorm_mlp_beta 都是提前设定好的超参数:
{
"layernorm_full_attention_alpha": 3.5565588200778455,
"layernorm_full_attention_beta": 1.0,
"layernorm_linear_attention_alpha": 3.5565588200778455,
"layernorm_linear_attention_beta": 1.0,
"layernorm_mlp_alpha": 3.5565588200778455,
"layernorm_mlp_beta": 1.0,
}
无论是 Softmax Attention Layer 还是 Linear Attention Layer,attention 之后的残差连接的放大系数都是 3.55655(过一会儿再说这个数字是怎么来的)。
注意!MiniMax-01 代码中的 beta 变量,跟 DeepNorm 原始设计中的完全不是一个含义,而且固定为1,可以忽略其存在。而笔者未发现 MiniMax-01 在相关参数的初始化中采用了 DeepNorm 原始设计中的缩小。
因此,MiniMax-01 应该只采用了 DeepNorm 放大残差连接的部分,更类似对于 Post-Norm 弱化残差连接效果的补偿。
最后,我们来看放大系数 3.55655 是如何确定的。MiniMax-01 完全采用了 DeepNorm 论文给定的方案,即在 decoder-only 结构中,残差连接的放大系数是二倍层数的四分之一次方,MiniMax-01 共包含80层,故有:
3. 参考资料
- MiniMax-01: Scaling Foundation Models with Lightning Attention
- HF Mirror: MiniMaxAI/MiniMax-Text-01/modeling_minimax_text_01.py
- DeepNet: Scaling Transformers to 1,000 Layers
- 大模型结构基础(三):归一化技术的升级
- 大模型结构基础(八):MiniMax-01 精读之 Hybrid Lightning Attention
自认为纸上谈兵得还行,欢迎大佬们介绍工作!也欢迎直接赏饭!
本文暂时没有评论,来添加一个吧(●'◡'●)