nanoVLM에서 KV 캐싱 구현으로 38% 속도 향상
nanoVLM에서 KV 캐시 구현하기 우리의 nanoVLM 저장소는 순수한 PyTorch로 자체 비전 언어 모델을 훈련시키는 작은 코드베이스입니다. 이 저장소에서 KV 캐시를 처음부터 구현하여 생성 속도를 38% 개선했습니다. 이 블로그 글에서는 KV 캐시와 그 구현 과정에서 얻은 경험을 공유합니다. 이러한 교훈들은 모든 자동 회귀 언어 모델 생성에 적용될 수 있습니다. 작은 코드베이스에서 처음부터 구현하는 것은 매우 유익한 학습 경험이며, 이 글을 통해 함께 그 여정을 즐겨보세요. 트랜스포머 아키텍처 복습 캐싱에 들어가기 전에, 트랜스포머 모델에서 어텐션이 어떻게 작동하는지 다시 살펴봅시다. 트랜스포머 언어 모델은 여러 층으로 구성되며, 각 층은 다음과 같은 요소로 이루어져 있습니다: - 멀티-헤드 어텐션 - 포지셔널 피딩 포워드 네트워크 - 노멀라이제이션 레이어 자체 어텐션 메커니즘, 특히 단일 어텐션 헤드 내에서 KV 캐시가 도움이 되는 부분을 집중적으로 살펴보겠습니다. 간단한 PyTorch 구현을 통해 주요 계산을 시각화해 보겠습니다. ```python import torch input_seq_length = 5 dim_model = 10 input_ids_emb = torch.randn(input_seq_length, dim_model) W_q = torch.randn(dim_model, dim_model) W_k = torch.randn(dim_model, dim_model) W_v = torch.randn(dim_model, dim_model) Q = input_ids_emb @ W_q K = input_ids_emb @ W_k V = input_ids_emb @ W_v ``` 자체 어텐션 계산 T개의 입력 임베딩 시퀀스 (X \in \mathbb{R}^{T \times D})에 대한 자체 어텐션은 다음과 같이 계산됩니다: [ \text{Attention}(X; Q, K, V) = \text{softmax}\left( \frac{QK^\top \cdot M}{\sqrt{d_k}} \right)V ] 다음은 인과 마스크를 사용한 최소한의 PyTorch 구현입니다. ```python import torch.nn.functional as F attention_scores = Q @ K.T 미래 토큰 접근을 막는 하삼각형 마스크 causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length)) masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf')) attention_weights = F.softmax(masked_scores, dim=-1) output = attention_weights @ V ``` 불필요한 중복 발생 자동 회귀 생성에서 모델은 한 번에 하나의 토큰을 생성합니다. 각 단계마다 전체 시퀀스의 Q, K, V를 다시 계산하는데, 이미 변경되지 않은 이전 토큰들까지 모두 다시 계산됩니다. ```python new_token_emb = torch.randn(1, dim_model) extended_input = torch.cat([input_ids_emb, new_token_emb], dim=0) Q_ext = extended_input @ W_q K_ext = extended_input @ W_k V_ext = extended_input @ W_v ``` 중복을 확인하기 위해 다음과 같이 테스트할 수 있습니다. python torch.testing.assert_close(K, K_ext[:input_seq_length]) # 테스트 통과 torch.testing.assert_close(V, V_ext[:input_seq_length]) # 테스트 통과 이 테스트 결과는 대부분의 어텐션 계산이 불필요하게 반복되며, 시퀀스가 길어질수록 비용이 더 많이 들음을 보여줍니다. KV 캐시로 이 문제 해결 불필요한 중복을 제거하기 위해 KV 캐시를 사용합니다. 이 방법은 전체 시퀀스를 다시 계산하는 대신 경량화된 점진적 업데이트를 수행합니다. 실제로, 이 캐시는 층별로 "key"와 "value"를 키로 하는 사전 형태이며, 각 키의 모양은 (배치_사이즈, 헤드_수, 캐시_시퀀스_길이, 헤드_차원)입니다. 이는 현대의 큰 언어 모델들이 효율적으로 긴 출력을 생성할 수 있는 기초가 됩니다. nanoVLM에서의 KV 캐시: 이론에서 실천 KV 캐시의 이론을 이해한 후, nanoVLM 저장소에서 실제로 어떻게 구현되는지 알아보겠습니다. 이 저장소는 매우 간결하고 독립적인 코드베이스이므로 이상적인 테스트 베드입니다. KV 캐시는 모델의 세 가지 주요 구성요소에서 사용됩니다: 1. KV 캐시를 사용하고 업데이트하는 어텐션 블록 2. 층별로 캐시를 추적하는 언어 모델 3. 초기 채움(pre-fill)과 순차 디코딩(decode) 단계로 나뉘어진 생성 루프 1. 어텐션 블록에서의 KV 캐시 업데이트 LanguageModelGroupedAttention 클래스에서 forward 함수를 수정하여, 키와 값을 캐싱하고 업데이트할 수 있도록 합니다. 이전에는 매번 생성 단계에서 K, V를 재계산했지만, 이제는 현재 토큰에 대한 Knew, Vnew만 계산하고 캐시에 추가합니다. ```python def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None): is_prefill = block_kv_cache is None B, T_curr, C = x.size() # 입력을 Q, K, V로 프로젝팅 q_curr, k_curr, v_curr = project_current_tokens(x) q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin) if not is_prefill and block_kv_cache['key'] is not None: # 캐시에 새로운 키와 값을 추가 k = torch.cat([block_kv_cache['key'], k_rotated], dim=2) v = torch.cat([block_kv_cache['value'], v_curr], dim=2) else: # 첫 번째 패스 (pre-fill) — 캐시 없음 k, v = k_rotated, v_curr block_kv_cache = {'key': k, 'value': v} return attention_output, block_kv_cache ``` 2. 층별 캐시 추적 LanguageModel 클래스에서 층별 캐시 추적을 도입합니다. start_pos 인자는 새로 생성된 토큰에 대한 올바른 회전 위치 인코딩 계산을 보장합니다. ```python def forward(self, x, kv_cache=None, start_pos=0): T_curr = x.size(1) position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device) cos, sin = self.rotary_embd(position_ids) for i, block in enumerate(self.blocks): # 층별 KV 캐시 전달 x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i]) return x, kv_cache ``` 3. 생성 루프에서의 초기 채움과 디코딩 구분 VisionLanguageModel 클래스의 generate() 메서드에서 가장 큰 아키텍처 변경이 이루어졌습니다. 생성을 두 단계로 나눕니다: - 초기 채움 단계 (pre-fill): 입력 프롬프트를 처리하고 캐시를 채웁니다. - 예: [프롬프트: "What is"] → [트랜스포머] → [캐시: 모든 층의 K, V] - 디코딩 단계 (decode): 캐시된 K, V를 사용하여 한 토큰씩 생성합니다. - 예: [토큰: "the"] → [Q("the") + 캐시된 K/V] → [다음 토큰: "?"] → ... 해당 코드는 다음과 같습니다. ```python PRE-FILL: 입력 프롬프트 처리, 캐시 채우기 prompt_output, kv_cache_list = self.forward( inputs, kv_cache=None, start_pos=0 ) DECODE: 캐시된 K/V를 사용하여 한 토큰씩 생성 for i in range(max_new_tokens): next_token = sample_from(prompt_output) decode_output, kv_cache_list = self.forward( next_token, kv_cache=kv_cache_list, start_pos=current_position # 각 단계마다 업데이트 ) prompt_output = decode_output ``` 이 단계들을 분리함으로써 불필요한 계산을 피하고, 특히 긴 프롬프트에서 추론 속도를 크게 향상시킵니다. 변경 요약 | 모듈 | 원래 동작 | 새로운 동작 | |------|-----------|-------------| | LanguageModelGroupedAttention.forward | 매 단계마다 Q, K, V 재계산 | KV 캐시 사용 및 업데이트 | | LanguageModel.forward | 이전 상태 기억하지 않음 | 층별 KV 캐시 추적, start_pos 처리 | | VisionLanguageModel.generate | 단일 단계 생성 루프 | 초기 채움(pre-fill)과 디코딩(decode) 단계로 분리 | KV 캐시의 중요성 | 이점 | 설명 | |------|------| | 점진적 성장 | 캐시는 각 새 토큰당 한 행씩 증가 | | 위치 인식 디코딩 | start_pos는 위치 인코딩 계산의 정확성을 보장 | | 효율성 | 각 토큰 추론 비용을 O(시퀀스_길이)로 줄임, 이전에는 이차 비용이었음 | KV 캐시는 자동 회귀 생성 동안 불필요한 계산을 제거하여, 특히 긴 시퀀스와 실시간 애플리케이션에서 더 빠르고 효율적인 추론을 가능하게 합니다. 이는 속도와 메모리 사이의 균형을 맞추는 방법으로, 더욱 복잡한 추론 방식(예: 빔 서치)을 제한할 수 있다는 단점이 있지만, 큰 언어 모델의 추론 속도를 높이는 인기 있는 방법입니다. 이제 여러분도 이 방법이 어떻게 작동하는지 알게 되었습니다! 업계 전문가들은 nanoVLM의 KV 캐시 구현이 큰 언어 모델의 실용성을 크게 향상시킨다고 평가합니다. nanoVLM은 소비자 하드웨어에서도 큰 언어 모델을 실행할 수 있게끔 만드는 중요한 기술적 발전을 이루어냈습니다. 이러한 개선은 모델의 추론 시간을 줄이고, 실시간 응용 분야에서의 성능을 향상시키는 데 크게 기여합니다.