对角批处理在长上下文中的循环记忆Transformer中解锁并行性
Sivtsov, Danil ; Rodkin, Ivan ; Kuzmin, Gleb ; Kuratov, Yuri ; Oseledets, Ivan
发布日期: 6/8/2025

摘要
变压器模型在处理长上下文推理时面临挑战,因为它们的时间复杂度为二次方,内存复杂度为线性。循环记忆变压器(Recurrent Memory Transformers, RMTs)提供了一种解决方案,通过将渐近成本降低到线性时间和常数内存使用。然而,RMTs的内存更新机制导致了顺序执行,从而形成了性能瓶颈。我们引入了对角批处理(Diagonal Batching),这是一种调度方案,能够在RMTs中解锁跨段并行计算,同时保持精确的递归性。该方法消除了顺序约束,使得即使对于单个长上下文输入,也能在没有复杂的批处理和流水线技术的情况下实现高效的GPU推理。由于这一技术纯粹是对运行时计算顺序的重新排列,现有的RMT模型可以无需重新训练即可采用它。将对角批处理应用于LLaMA-1B自回归RMT模型时,在131,072个标记序列上,其速度比标准全注意力LLaMA-1B提高了3.3倍,比顺序RMT实现提高了1.8倍。通过对消除顺序瓶颈,对角批处理降低了推理成本和延迟,从而增强了RMTs作为实际应用中长上下文问题的有效解决方案的地位。