前言
Flamingo算是DeepMind的多模态融合LLM的一个较老的工作了(2022年),之前粗略读过没来得及及时总结,本次过年笔者重新细读了论文,发现其在50多页的论文中有着不少细节,本文对该工作进行读后感笔记,希望对诸位读者有所帮助。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注明出处,谢谢。
∇ \nabla ∇ 联系方式
e-mail: FesianXu@gmail.com
github: https://github.com/FesianXu
知乎专栏: 计算机视觉/计算机图形理论与应用(https://www.zhihu.com/column/c_1265262560611299328)
之前笔者在介绍BLIP2的博文[1]中,曾经介绍过采用Q-Former融合视觉语义向量和LLM的方法,BLIP2工作中由于只采用了图文对(Image-Text Pair)数据,因此其in-context能力欠缺,few-shot效果不佳。而在Flamingo [2] 这个工作中,作者从互联网数据中收集了大量的图文交织(Image-Text Interleaving)数据,这为Flamingo提供few-shot、in-context能力提供了基础保障,虽然Flamingo中采用Perceiver Resampler和Gated cross-attention的方法融合多模态信息目前不是主流(主流还是Q-Former),但其论文中提到的数据构建方式,模型结构消融实验等仍然能提供很多有意义的参考。我们接下来开始对Flamingo进行介绍。
Flamingo [2] 和 BLIP2 [3] 都是尝试将已预训练好的视觉特征编码器(如ViT、Resnet等)和已预训练好的大语言模型(LLama、OPT等)进行结合的工作,从而使得大语言模型可以交织文本和视觉进行输入(如图片、视频等),最终输出文本,我们称之为MLLM(多模态大语言模型),读者可在博文 [1] 中得到更多相关的背景信息,在此不再累述。Flamingo采用了所谓的感知重采样(Perceiver Resampler)技术和门控交叉注意力技术(Gated Cross-Attention)进行视觉多模态信息和LLM的融合,整体结构如Fig 1.所示,其中视觉编码器和LLM都是固定参数而不在训练中更新,感知重采样器将变长的视觉向量转换成定长的多模态语义向量,通过门控注意力单元将信息融入固定的LLM中,最终实现输入中可混合多模态信息而输出文本信息。
Fig 1. Flamingo的框架图,主要由视觉编码器(vision encoder)、感知重采样器(perceiver resampler)、LLM和交织在LLM中的门控交叉注意力层(gated xattn-dense)组成。其中感知重采样器和门控注意力单元的结构如Fig 2所示,其中的视觉编码器采用NFNet(NormalizerFree ResNet),作者先在图文对数据上采用CLIP的方式对NFNet进行预训练,随后进行参数固定。如果视觉端输入是视频,则按照1 fps进行采样后将 N N N帧进行视觉特征提取(若是图片输入,则N=1),注意到此时position embedding按照帧粒度组织,即是统一帧的不同patch共用一个position embedding以建模帧间序列信息。尔后对多帧的特征进行展开、拼接,作为transformer的k和v,而采用一个可学习的query向量作为transformer的q输入,这个思路可参考博文 [1],不在此展开,具体伪代码可见Code 1。感知重采样机制的一个好处就是,可以将变长的视频输入转变为定长的输入,此处定长的输入长度为64。
门控注意力单元的设计,则是在原先固定的LLM结构的每一层基础上叠加了门控单元,门控单元由交叉注意力机制和门控结构、FFW交替组成,其中交叉注意力的k和v都是感知重采样器的输出,而q则是文本输入。为了保证在训练初始阶段模型和原先的LLM不至于偏差太远,作者采用了门控机制,具体来说就是将新层的输出乘上一个可学习的 tanh ( α ) \tanh(\alpha) tanh(α),将LLM的原先输入与其加和,只需要在初始化时候将 α = 0 \alpha = 0 α=0即可确保初始化时候和原先LLM无太大偏差。作者对在训练过程中每一LM层的 α \alpha α变化进行了可视化,见Fig 3.可发现两个规律,第一随着层数加深,门控值则更大,第二随着训练过程,门控值也逐渐变大,这个倒是符合我们的认识,浅层提取基础特征而深层则更加富有语义信息,因此在深层中的门控更大有利于引入更多的视觉语义信息。
Fig 2. Flamingo中采用的感知重采样器和门控交叉注意力模型结构。def perceiver_resampler( x_f, # The [T, S, d] visual features (T=time, S=space) time_embeddings, # The [T, 1, d] time pos embeddings. x, # R learned latents of shape [R, d] num_layers, # Number of layers ): """The Perceiver Resampler model.""" # Add the time position embeddings and flatten. x_f = x_f + time_embeddings x_f = flatten(x_f) # [T, S, d] -> [T * S, d] # Apply the Perceiver Resampler layers. for i in range(num_layers): # Attention. x = x + attention_i(q=x, kv=concat([x_f, x])) # Feed forward. x = x + ffw_i(x) return x def gated_xattn_dense( y, # input language features x, # input visual features alpha_xattn, # xattn gating parameter – init at 0. alpha_dense, # ffw gating parameter – init at 0. ): """Applies a GATED XATTN-DENSE layer.""" # 1. Gated Cross Attention y = y + tanh(alpha_xattn) * attention(q=y, kv=x) # 2. Gated Feed Forward (dense) Layer y = y + tanh(alpha_dense) * ffw(y) # Regular self-attention + FFW on language y = y + frozen_attention(q=y, kv=y) y = y + frozen_ffw(y) return y # output visually informed language featuresCode 1. 感知重采样器和门控交叉注意力单元的伪代码。 Fig 3. 注意力层中和FFW的门控值在不同层的变化趋势。
说完了模型结构上的改动,我们还需要关注到本工作中的数据构建,在本工作中,作者不仅仅构建了图文对数据(LTIP),而且还构建了视频对数据(VTP)和图文交织数据。图文交织数据(M3W: Interleaved image and text dataset)指的是图片和文本进行多次交织组成的数据,图片会穿插在文本上下文中,而不是简单的图文一对一的关系数据。作者通过解析大概4.3千万个网页的DOM,构建了图文交织数据,如Fig 4.所示,图片穿插在了文章上下文中,而上文和下文可能和该图片都由语义关联。
Fig 4. 来自于网页上的图文交织数据示例,注意到相关的图片会穿插在文本之间,上文和下文都可能和该图片有语义关联。原网页来自 [4]。怎么对图文交织数据进行建模也是一个值得关注的点,如Fig 5 (b) 所示,一张内嵌到文本中的图片可能和上文或者下文或者两者都产生语义关联,在Flamingo中作者选择按照概率 p n e x t p_{next} pnext采样后续文本作为成对文本,亦或者反过来选择前继文本。从图文交织数据中可以组成多个成对数据,如Fig 5 (a)所示,此时通过门控交叉注意力单元中的掩膜设置,可以同时对该图文交织数据中出现的多个成对数据进行建模,具体原理见 [5]。当然,这里对图文交织数据的应用会比较朴素,只考虑以某个图片为中心的局部单向语义信息,而没有考虑到全局信息的建模,相当于还是简单将图文交织数据去局部组织图文对数据进行训练。
在交叉注意力单元中采用这种方式,虽然一次性只能让图片直接关注到一个相关联的文本,但是通过后续的LM单元的自注意力模块,能同时建模任意数量的图片输入和文本输入,实现图文交织数据作为输入的对话,当然也就能支持BLIP2所欠缺的few-shot功能了。这里的做法,按照笔者的认识,相当于就是交叉注意力单元只负责建模图文局部的语义对齐,而图文交织数据全局的信息对齐则由紧接着的LM完成。
Fig 5. 在Flamingo中应用图文交织型数据的方法,由于嵌入到文本中的图片可能和上文、下文产生语义关联,在本工作中采用按一定概率的方式采样后续文本进行应用。Flamingo的效果从benchmark测试看是能吊打很多few-shot和zero-shot数据集的sota方法的,受限于篇幅笔者不会展开。在文中作者做了很多坚实的消融实验,去验证Flamingo的各种设计的效果,如Fig 6.所示,主要对几点进行了消融:
- 是否采用全量数据? 特别是对M3W图文交织数据的有无进行了消融,我们发现图文交织数据能提供大约17%的提升。
- 是否采用门控机制?实验证明采用门控机制能带来月8%的提升。
- 采用交叉注意力层的频率?实验证明每一层都引入门控交叉注意力层效果是最好的。
- 是否采用感知重采样单元引入视觉信息?实验证明该设计能带来约4%的提升。
- 视觉编码器的选择同样对结果影响巨大。
- 是否固定LLM的参数?实验证明固定LLM反而能带来最好的效果,而让LLM随着训练一起进行(会采用massive text数据集一起训练)反而效果会差8%左右,笔者估计是训练过程需要平衡多个目标导致的,如何让LLM也能训练起来可能也是一个值得关注的点。
Reference
[1]. https://blog.csdn.net/LoseInVain/article/details/136013909, 《BLIP2——采用Q-Former融合视觉语义与LLM能力的方法》
[2]. Alayrac, J. B., Donahue, J., Luc, P., Miech, A., Barr, I., Hasson, Y., … & Simonyan, K. (2022). Flamingo: a visual language model for few-shot learning. Advances in Neural Information Processing Systems, 35, 23716-23736. aka Flamingo
[3]. Li, Junnan, Dongxu Li, Silvio Savarese, and Steven Hoi. “Blip-2: Bootstrapping language-image pre-training with frozen image encoders and large language models.” arXiv preprint arXiv:2301.12597 (2023). aka BLIP2
[4]. https://baijiahao.baidu.com/s?id=1761390872940868294&wfr=spider&for=pc
[5]. https://blog.csdn.net/LoseInVain/article/details/119530520
还没有评论,来说两句吧...