将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

进入正文之前,先考虑一下像 ChatGPT 这样的 Transformer 语言模型(LM)的 prompt:

将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

随着每天产生数百万用户和查询,ChatGPT 使用自注意力机制对 prompt 进行反复编码,其时间和内存复杂度随输入长度呈二次方增长。缓存 prompt 的 transformer 激活可以防止部分重新计算,但随着缓存 prompt 数量的增加,这种策略仍然会产生很大的内存和保存老本。在大规模情况下,即使 prompt 长度稍微减少一点,也可能会带来计算、内存和保存空间的节省,同时还可以让用户将更多内容放入 LM 有限的上下文窗口中。

那么。应该如何降低 prompt 的老本呢?典型的方式是微调或蒸馏模型,使其在没有 prompt 的情况下表现得与原始模型相似,或许还可以使用参数高效的自适应方式。然而,这种方式的一个基本缺点是每次必要为新的 prompt 重新训练模型(下图 1 中间所示)。

将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

本文中,斯坦福大学的研究者提出了 gisting 模型(上图 1 底部),它将任意 prompt 紧缩成一组更小的虚拟「Gist」 token,类似于前缀微调 。然而,前缀微调必要通过梯度下降为每个工作进修 prefix,而 Gisting 采用元进修方式,仅仅通过 prompt 预测 Gist prefix,而不必要为每个工作进行 prefix 进修。这样可以摊销每个工作 prefix 进修的老本,使得在没有额外训练的情况下泛化到未知的指令。

此外,由于「Gist」token 比完整 prompt 要短得多,因此 Gisting 允许 prompt 被紧缩、缓存和重复使用,以提高计算服从。

将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

论文地址:https://arxiv.org/pdf/2304.08467v1.pdf

研究者提出了一种非常简单的方式来进修指令遵循的 gist 模型:简单地进行指令微调,在 prompt 后插入 gish token,修改后的注意力掩膜阻止 gist token 后的 token 参照 gist token 前的 token。这使得模型同时进修 prompt 紧缩和指令遵循,而无需额外的训练老本。

在 decodr-only(LLaMA-7B)和 encoder-decoder(FLAN-T5-XXL)LM 上,gisting 可实现高达 26 倍的即时紧缩率,同时保持与原始模型相似的输入质量。这使得推理过程中 FLOPs 减少了 40%,延迟加速了 4.2%,与传统的 prompt 缓存方式相比,保存老本大大降低。

Gisting

研究者首先在指令微调的背景下描述 gisting。对于指令遵循数据集将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间,t 表示用自然语言 prompt 编码的工作 (例如将此翻译成法语),x 表示工作的(可选)输入 (例如 The cat),y 表示期望的输入(例如 Le chat)。指令微调的目的是通过连接 t 和 x,然后让通常预先训练的语言模型自回归地预测 y,从而进修分布 pLM(y | t,x)。推理时可以使用新的工作 t 和输入 x 进行 prompt,从模型中解码以获得预测结果。

然而,连接 t 和 x 的这种模式具有缺点:基于 Transformer 的 LM 具有有限的上下文窗口,其受架构或计算能力所限。后者特别难解决,因为自注意力随输入长度呈二次方扩展。因此很长的 prompt,尤其那些被反复重用的 prompt,计算服从低下。有哪些选项可以用来降低 prompt 的老本呢?

一种简单的方式是针对特定工作 t 进行 LM 微调,即给定包含仅在工作 t 下的输入 / 输入示例的数据集将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间,可以进修一个专门的将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间,它更快,因为不必要考虑 t。

更好的是,prefix/prompt 微调或 adapter 等参数高效微调方式能够以比全面微调低得多的老本实现相同的目的。然而仍然存在问题:必须至少保存每个工作的一部分模型权重,并且更重要的是,对于每个工作 t,必须收集相应的输入 / 输入对数据集 D^t 并重新训练模型。

Gisting 是一种不同的方式,它摊销了两部分老本:(1)在 t 上条件化 p_LM 的推理时间老本,(2)进修每个 t 的新 p^t_LM 的训练时间老本。其思想是在微调期间进修 t 的紧缩版本 G (t),使得从 p_G (y | G (t),x) 进行推理比从 p_LM (y|t,x) 更快。

在 LM 术语中,G (t) 将是一组「虚拟」的 Gist token,其数量比 t 中的 token 少,但仍会在 LM 中引起类似的行为。接着可以缓存并重复使用 G (t) 上的 transformer 激活(例如键和值矩阵)以提高计算服从。重要的是,研究者希望 G 可以泛化到未见过的工作:给定一个新工作 t,则可以预测并使用相应的 Gist 激活 G (t) 而无需进行任何额外训练。

通过掩膜进修 Gisting

上文描述了 Gisting 的一般框架,接下来将探讨一种进修此类模型的极简单方式:使用 LM 本身用作 Gist 预测器 G。这不仅利用了 LM 中的预存在知识,而且允许通过简单地执行标准指令微调来进修 gisting 并修改 Transformer 注意力掩膜来增强 prompt 紧缩。这意味着 Gisting 不会产生额外训练老本,只必要基于标准指令微调即可!

具体来说,向模型词汇表和嵌入矩阵中添加一个特殊的 gist token,类似于此类模型中常见的句子开头 / 结尾 token。然后对于给定的(工作,输入)元组(t,x),使用 (t, g_1, . . . , g_k, x) 中一组 k 个连续的 gist token 将 t 和 x 连接在一起,例如将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间。这个序列被输入到模型中,有一个限制,即在 gist token 之后的输入 token 不能参照之前的 prompt token(但它们可以参照 gist token)。这会强制模型将 prompt 中的信息紧缩成 gist token,因为输入 x (输入 y) 无法处理 prompt t。

下图 2 展示了所必要的更改。对于 GPT-3 或 LLaMA 等通常采用自回归因果注意力掩膜的 decoder-only LM,只需 mask out 图 2a 所示的三角形左下角。对于具有双向编码器和自回归解码器的 encoder-decoder LM,则必要进行两项修改(图 2b 所示)。

首先,在通常没有掩膜的编码器中,阻止输入 token x 参照 prompt token t。但还必须防止 prompt t 和 gist token g_i 参照输入 token x,否则编码器将根据输入进修不同的 gist 表示。最后解码器正常运行,除了在交叉注意力期间,这时必要阻止解码器参照 prompt token t。

将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

实验结果

对于不同数量的 gist token, LLaMA-7B 和 FLAN-T5-XXL 的 ROUGE-L 和 ChatGPT 评估结果如下图 3 所示。

将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

模型通常对 gist token 的数量 k 不敏感:将 prompt 紧缩到单个 token 并不会导致显著性能下降。事实上,在某些情况下,过多的 gist token 会损害性能 (例如 LLaMA-7B, 10 gist tokens),这可能是因为增加的容量使训练分布过拟合。因此,研究者在下表 1 中给出了单 token 模型的具体数值,并在剩余实验中使用单个 gist 模型。

将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

在见过的指令上,gist 模型获得了与其对应阳性对照模型几乎相同的 ROUGE 和 ChatGPT 性能,在 LLaMA-7B FLANT5-XXL 上的胜率分别为 48.6% 和 50.8%。这里研究者最感兴趣的是它们在未见过工作上的泛化能力,这必要通过另外两个数据集来衡量的。

在 Alpaca+ 训练数据集中未见过的 prompt 中,可以看到 gist 模型在未见过 prompt 上有着强大的泛化能力:与对照组相比,分别有 49.7%(LLaMA)和 46.2%(FLAN-T5)的胜率。在最具挑战性的 OOD Human split 上,gist 模型的胜率略微下降,分别为 45.8%(LLaMA)和 42.5%(FLANT5)。

本文的目的是让 gist 模型紧密地模仿原始模型的功能,因此有人可能会问究竟什么时候 gist 模型与对照组无差别。下图 4 说明了这种情况发生的频率:对于已见过工作(但是未见过的输入),gist 模型几乎有一半的时间与对照组不相上下。对于未见过的工作,这一数字下降到了 20-25%。对于 OOD Human 工作,这一数字又下降到 10%。无论如何,gist 模型输入的质量是很高的。

将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

总的来说,这些结果表明,gist 模型可以可靠地紧缩 prompt,甚至在训练分布之外的某些 prompt 上也可以做到这一点,特别是像 LLaMA 这样的 decoder-only 因果 LM。FLAN-T5 等 encoder-decoder 模型表现略差,一个可能的原因是 gist 掩膜抑制了编码器中的双向注意力流,这比仅 mask 自回归解码器的一部分 history 更具挑战性。未来必要进一步的工作来研究这个假设。

计算、内存和保存服从

最后,回到这项工作的核心动机之一:gisting 可以带来什么样的服从提升?

下表 2 展示了使用 PyTorch 2.0 分析器对模型进行单次前向传递的结果(即使用单个输入 token 的自回归解码的一步),并对 Human eval split 中的 252 个指令取平均值。与未经优化的模型相比,gist 缓存显著提高了服从。两种模型的 FLOPs 节约率达到了 40%,时钟时间降低了 4-7%。

将26个token紧缩成1个,新方式极致节省ChatGPT输入框空间

然而更重要的是,与指令缓存相比,gist 缓存有着除延迟之外的关键优势:将 26 个 token 紧缩为 1 个可以在输入上下文窗口中腾出更多空间,这受到绝对位置嵌入或者 GPU VRAM 的限制。特别是对于 LLaMA-7B,KV 缓存中的每个 token 必要 1.05MB 的保存空间。尽管在测试的 prompt 长度下,KV 缓存相对于 LLaMA-7B 推断所需的内存总贡献微不足道,但一个越来越常见的场景是开发人员在大量用户之间缓存许多 prompt,保存老本很快就会增加。在保存空间相同的情况下,gist 缓存能比完整指令缓存多 26 倍的 prompt。

给TA打赏
共{{data.count}}人
人已打赏
AI

GIF动画渲染、让灯塔闪烁、创立航空动态图……ChatGPT代码解释器插件「不止于代码」

2023-5-7 12:47:00

AI

首个单细胞生物学基础大型语言模型,在超1000万个细胞从事预训练

2023-5-9 11:29:00

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
今日签到
搜索