显存占用

对于transformer中, 最占用显存的部分就是注意力计算了, 他会产生一个很大的注意力得分矩阵, 这个矩阵形状通常是 ( batch , head, N, N) N代表序列维度

如果我们使用32批次, 一张640 640 的图片, 然后patch选 16 那 N就是1600, head选8, 显存占用就有

1
32 * 8 * 1600 * 1600 * 4 / 1024 / 1024 / 1024 = 2.44GB

一个注意力分数矩阵就已经2.44G了, 通常encoder和decoder有6层, 这个显存占用就已经很大了

主要减少显存占用的方法:

  1. 减小batch
  2. 增大patch, 也就可以减少N, 这个是O(N^2)级别的, 降低起来很有效

问题汇总

这个是我学习transformer中的一些疑惑

词嵌入的目的

问: 能否直接将词在词库中的编码直接交给transformer而不使用词嵌入

答:

直接使用单词在词库中的索引数字(即数字ID)作为输入是不可行的,原因有以下几点:

  1. 缺乏语义信息:数字ID只是单纯的标识符,不包含任何语义信息。例如,"猫"的ID可能是42,"狗"的ID可能是128,但这些数字之间的差值(86)没有任何语义意义。
  2. 缺乏相似性表示:相似词汇应该有相似的表示,但数字ID无法体现这一点。"猫"和"猫咪"语义接近,但它们的ID可能是完全不相关的数字。
  3. 模型无法学习:神经网络在处理连续数值时,会假设数值的大小关系反映了某种含义。但词汇ID的大小关系是随机的,没有语义基础,这会导致模型无法学习有意义的模式。
  4. 梯度问题:在反向传播过程中,如果直接使用数字ID,无法为相似的词建立相似的梯度更新路径,阻碍了模型的学习过程。
  5. 维度问题:使用单一数字会将所有词汇信息压缩到一维,而词嵌入通常是几百维的向量,能够捕捉词汇的多种语义特征。

这就是为什么词嵌入层是必要的 - 它将离散的词汇ID转换为连续的向量表示,为每个词创建一个具有丰富语义信息的表示,使模型能够理解词与词之间的关系和相似性,从而更好地学习和理解文

计算损失方式的区别

问: 在Transformer模型特别是解码器(decoder)训练过程中,为什么我们不直接用解码器的输出向量(维度为[序列长度-1, 512])与经过词嵌入和位置编码后的真实标签向量(同样维度为[序列长度-1, 512])直接计算损失?而是要先将解码器输出通过线性层和softmax转换成词汇表大小的概率分布,然后与真实标签的独热编码计算交叉熵损失?这两种方法在原理和效果上有什么不同?

答:

  1. 目标不同:模型的目标是预测下一个词,而不是预测下一个词的嵌入向量。嵌入空间是连续的,同一个词的嵌入表示可能有多种合理的变体,而我们关心的是最终预测出正确的词。
  2. 嵌入空间的冗余性:嵌入空间是高维的(通常是512维),且存在大量冗余。同一个语义概念可能有多种有效的嵌入表示,直接在这个空间中比较会导致不必要的惩罚。
  3. 一词多嵌入问题:同一个词在不同上下文中可能有略微不同的嵌入表示,直接比较嵌入向量会惩罚这种合理的变化。
  4. 训练稳定性:在离散的词汇分布上计算损失(通常是交叉熵)比在连续的嵌入空间中计算损失(如MSE)更加稳定,梯度更加明确。
  5. 评估一致性:最终我们评估模型时是看它预测词的准确性,而不是嵌入的相似度,所以损失函数应该与评估指标一致。
  6. 梯度传播效率:通过softmax和交叉熵损失,模型可以获得更明确的梯度信号,知道应该增强或减弱哪些词的概率,而不是在高维空间中进行模糊的调整。

这种设计让模型明确地学习"选择正确的词"这一任务,而不是"生成一个与目标嵌入相似的向量",使训练更加聚焦和高效。

vit为何不使用encoder的全部输出

问: vit为何不使用encoder的全部输出来预测分类结果, 而单纯使用一个cls patch来预测呢?

答:

  1. 继承自NLP领域的BERT模型设计:ViT的设计很大程度上借鉴了NLP领域的Transformer结构,特别是BERT。在BERT中,[CLS]令牌被设计用来捕获整个序列的全局信息,并用于分类任务。ViT沿用了这一设计理念。
  2. 自注意力机制的全局信息整合:通过多层自注意力机制,[CLS]令牌能够与所有图像patch交互,逐渐积累和整合全局信息。在最后一层,[CLS]令牌理论上已经包含了整个图像的关键信息。
  3. 计算效率考虑:使用单一令牌进行分类比处理所有patch的输出更加高效。一个典型的ViT可能将图像分成196个patch,如果全部用于分类将大大增加分类头的参数量和计算量。
  4. 简化模型结构:使用单一的[CLS]令牌简化了模型结构,使其在概念上更加清晰,也更容易与其他架构(如基于CNN的模型)进行比较和集成。
  5. 实验验证的有效性:在原始ViT论文的实验中,这种设计已被证明非常有效,能够达到甚至超过当时最先进的CNN模型的性能。

值得注意的是,后续的一些研究确实探索了使用更多encoder输出信息的方法,例如:

  • 为分类头增加全局平均池化层,利用所有patch的信息
  • 在特定任务(如目标检测)中使用所有patch的输出
  • 设计混合模型结合CNN和Transformer的优势

问: 有什么实验验证了只使用一个和全部使用的区别吗?

答:

  1. DeiT (Data-efficient Image Transformers) 论文中的实验:研究者们比较了使用[CLS]令牌和使用所有patch的全局平均池化(GAP)的效果。他们发现在某些设置下,全局平均池化的表现略优于单独使用[CLS]令牌。

  2. Token Pooling研究

    :一些工作如"Token Pooling in Vision Transformers"系统地比较了不同的池化策略,包括:

    • 仅使用[CLS]令牌
    • 全局平均池化(GAP)所有patch
    • 全局最大池化(GMP)所有patch
    • 加权池化策略

    结果表明不同数据集上最佳策略可能不同,但GAP通常能提供稳定的性能改进,特别是在小型数据集上。

  3. ToMe (Token Merging) 相关研究中提到,当模型规模较小或训练数据有限时,全局平均池化通常优于单独使用[CLS]令牌;而在大规模预训练模型中,两者差异变得不那么显著。

  4. CaiT (Class-Attention in Image Transformers) 通过分离patch的自注意力和分类注意力,表明了一种更优的利用所有patch信息的方式,在某些配置下取得了比标准ViT更好的结果。

  5. LV-ViT (Token Labeling) 研究提出了一种在训练过程中为每个patch分配标签的方法,这种方法隐含地利用了所有patch的信息,并在多个基准测试中取得了显著提升。

总体来说,研究表明:

  • 在规模较小的模型或有限数据场景下,使用全局平均池化等方法利用所有patch信息通常能带来1-2%的准确率提升。
  • 随着模型规模和预训练数据量增加,[CLS]令牌能够更有效地聚合全局信息,与全局池化的性能差距减小。
  • 任务的性质也会影响最佳选择 - 例如,细粒度分类任务可能更受益于使用所有patch的信息。
  • 一些混合方法,例如结合[CLS]和GAP的加权平均,有时能够达到最佳性能。