HyperAIHyperAI
il y a 18 jours

Décalage de représentation : unifier la compression de jetons avec FlashAttention

Joonmyung Choi, Sanghyeok Lee, Byungoh Ko, Eunseo Kim, Jihyung Kil, Hyunwoo J. Kim
Décalage de représentation : unifier la compression de jetons avec FlashAttention
Résumé

Les Transformers ont fait preuve d’un succès remarquable dans les domaines de la vision, du langage et de la vidéo. Toutefois, l’augmentation de la complexité des tâches a entraîné une croissance des modèles et du nombre de tokens, ce qui accroît le coût quadratique de l’attention auto-attentive ainsi que la surcharge liée à l’accès à la mémoire GPU. Afin de réduire le coût computationnel de l’attention auto-attentive, des travaux antérieurs ont proposé des techniques de compression de tokens, consistant à éliminer les tokens redondants ou moins informatifs. Parallèlement, des noyaux d’attention fusionnés, tels que FlashAttention, ont été développés pour atténuer la surcharge mémoire en évitant la construction des cartes d’attention et les opérations d’entrée/sortie associées vers la HBM (High Bandwidth Memory). Cependant, cette approche rend incompatible la plupart des méthodes de compression de tokens sans entraînement, qui dépendent des cartes d’attention pour évaluer l’importance des tokens. Dans ce travail, nous proposons Representation Shift, une métrique sans entraînement et indépendante du modèle, qui mesure le degré de variation de la représentation de chaque token. Cette méthode s’intègre de manière transparente à FlashAttention, sans nécessiter de cartes d’attention ni de re-entraînement. Notre approche se généralise également au-delà des Transformers, s’appliquant aux réseaux de neurones convolutifs (CNN) et aux modèles à espace d’état. Des expérimentations étendues montrent que Representation Shift permet une compression efficace des tokens compatible avec FlashAttention, entraînant des accélérations significatives, atteignant respectivement 5,5 % et 4,4 % dans les tâches de recherche vidéo-texte et de question-réponse vidéo. Le code est disponible à l’adresse suivante : https://github.com/mlvlab/Representation-Shift.