Fast Inference from Transformers via Speculative Decoding
추측적 디코딩을 통한 트랜스포머의 빠른 추론
Yaniv Leviathan, Matan Kalman, Yossi Matias (2023)
작은 드래프트 모델이 여러 토큰을 빠르게 생성한 뒤, 큰 타겟 모델이 이를 한 번에 병렬 검증하는 추측적 디코딩을 제안했다. 이 방법은 타겟 모델의 출력 분포를 정확히 보존하면서 2-3배의 추론 속도 향상을 달성한다.
배경
자기 회귀적(autoregressive) 언어 모델의 추론은 본질적으로 순차적이다. 각 토큰 생성에 전체 모델의 순전파가 필요하며, 이는 배치 크기가 작을 때 GPU 계산 자원의 활용률이 매우 낮아지는(memory-bandwidth bound) 문제를 야기한다. 특히 대규모 모델(100B+)에서 단일 토큰 생성의 지연 시간(latency)이 길어 실시간 응용에 제약이 있었다. 모델 병렬화, 양자화 등의 최적화가 연구되었지만, 자기 회귀 디코딩의 순차적 특성 자체를 해결하는 방법은 부족했다.
핵심 아이디어
추측적 디코딩(speculative decoding)은 '대부분의 토큰은 예측하기 쉬우며, 작은 모델도 올바르게 생성할 수 있다'는 직관에 기반한다. 작은 드래프트 모델(draft model)이 K개의 토큰을 자기 회귀적으로 빠르게 생성하고, 큰 타겟 모델(target model)이 이 K개 토큰에 대해 단 한 번의 순전파로 각 위치의 확률 분포를 병렬 계산한다. 그런 다음 수정된 거부 샘플링(modified rejection sampling) 기법으로 드래프트 토큰을 앞에서부터 순서대로 승인하거나 거부한다. 수학적으로 이 과정은 타겟 모델에서 직접 샘플링한 것과 동일한 분포를 보장하므로, 출력 품질의 저하가 전혀 없다.
방법론
각 디코딩 스텝에서 드래프트 모델 M_q가 현재 컨텍스트에서 gamma개의 토큰을 자기 회귀적으로 생성하며, 각 위치의 드래프트 분포 q(x)를 저장한다. 타겟 모델 M_p는 원본 컨텍스트에 드래프트 토큰을 붙인 시퀀스에 대해 한 번의 순전파를 수행하여 각 위치의 타겟 분포 p(x)를 계산한다. 각 드래프트 토큰 x에 대해 min(1, p(x)/q(x))의 확률로 승인한다. 거부된 위치에서는 수정 분포 max(0, p(x)-q(x))에서 재샘플링하고, 모든 토큰이 승인되면 추가로 p에서 한 토큰을 더 샘플링한다. 이를 통해 한 라운드에서 평균적으로 1/(1-alpha) + 1개의 토큰을 생성하며, alpha는 드래프트와 타겟 분포의 일치도이다.
주요 결과
T5-XXL(11B) 모델에서 T5-Small을 드래프트 모델로 사용했을 때, 텍스트 요약과 번역 태스크에서 2-3배의 추론 지연 시간 감소를 달성했다. 출력 분포가 타겟 모델과 수학적으로 동일함을 이론적으로 증명하고 실험적으로 검증했다. 드래프트 길이 gamma의 최적값은 태스크와 모델 쌍에 따라 4-8 범위에서 결정되었다. 채팅과 같은 다양한 생성 태스크에서 일관된 속도 향상을 보여 방법의 범용성을 입증했다.
임팩트
추측적 디코딩은 LLM 추론 최적화의 핵심 기법으로 자리잡아, Google(PaLM/Gemini 서빙), Meta, Anthropic 등 주요 AI 기업의 프로덕션 서빙 스택에 통합되었다. 이 아이디어는 Medusa(다중 헤드 드래프트), EAGLE(자기 드래프트), SpecInfer(트리 기반 검증) 등 다양한 변형을 촉발했다. 드래프트 모델 없이도 적용 가능한 self-speculative decoding, 트리 구조 검증(tree attention) 등으로 발전하고 있으며, vLLM 등 주요 서빙 프레임워크에서 기본 지원된다.