FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
FlashAttention: IO 인식 기반의 빠르고 메모리 효율적인 정확한 어텐션
Tri Dao, Daniel Y. Fu, Stefano Ermon, et al. (2022)
GPU 메모리 계층 구조를 고려한 IO-aware 알고리즘으로 어텐션 연산을 재구성하여, 근사 없이 정확한(exact) 어텐션을 기존 대비 2-4배 빠르게 수행하면서 메모리 사용량을 시퀀스 길이에 대해 선형으로 줄였다.
배경
Transformer의 셀프 어텐션은 시퀀스 길이 N에 대해 O(N^2)의 시간 및 메모리 복잡도를 가져 긴 시퀀스 처리에 심각한 병목이었다. 기존 연구들은 이를 해결하기 위해 sparse attention(Longformer), low-rank approximation(Linformer), kernel-based methods(Performer) 등 근사 어텐션(approximate attention)을 제안했으나, 정확도 손실이 불가피했고 실제 wall-clock 속도 향상이 제한적이었다. 이는 기존 접근들이 FLOPS 절감에만 집중하고, GPU의 메모리 접근 패턴(IO 복잡도)을 무시했기 때문이다.
핵심 아이디어
FlashAttention의 핵심 통찰은 어텐션 연산의 병목이 산술 연산(FLOPS)이 아니라 GPU HBM(고대역폭 메모리)과 SRAM(온칩 캐시) 간의 데이터 이동(IO)이라는 것이다. 표준 어텐션은 N x N 크기의 어텐션 행렬 전체를 HBM에 저장하고 다시 읽어야 하므로 O(N^2)의 메모리 접근이 발생한다. FlashAttention은 타일링(tiling)과 재계산(recomputation) 기법을 결합하여, 어텐션 행렬을 절대 HBM에 저장하지 않고 SRAM 내에서 블록 단위로 계산을 완료한다. 소프트맥스의 온라인 정규화(online softmax normalization)를 통해 블록 단위 계산에서도 수학적으로 정확한 결과를 보장한다.
방법론
입력 Q, K, V 행렬을 SRAM에 적재 가능한 블록 크기로 분할한다. 외부 루프에서 K, V의 블록을, 내부 루프에서 Q의 블록을 순회하며, 각 블록 조합에 대해 부분 어텐션 출력을 SRAM 내에서 계산한다. 소프트맥스의 분자와 분모를 별도로 추적하는 온라인 소프트맥스 기법을 사용하여 블록 간 결과를 점진적으로 병합한다. 역전파 시에는 어텐션 행렬을 저장하지 않고 Q, K, V와 출력 통계량(row-wise max, sum)만 저장한 뒤 필요 시 재계산한다. 이를 통해 메모리 사용량이 O(N)으로 감소한다. CUDA 커널로 구현되어 Fused 연산을 통해 커널 론치 오버헤드도 제거한다.
주요 결과
FlashAttention은 표준 PyTorch 어텐션 대비 2-4배의 wall-clock 속도 향상과 5-20배의 메모리 절감을 달성했다. GPT-2 학습에서 HuggingFace 및 Megatron-LM 구현 대비 최대 3배 빠른 학습 속도를 보였다. 시퀀스 길이를 최대 16K까지 확장할 수 있게 하여, Long Range Arena 벤치마크에서 기존 근사 어텐션 방법들을 능가하는 정확도를 달성했다. Path-X(16K 길이) 태스크에서 Transformer가 최초로 랜덤 이상의 성능을 달성했다.
임팩트
FlashAttention은 Transformer 추론 및 학습의 효율성을 획기적으로 개선한 시스템 수준의 혁신이다. PyTorch 2.0의 기본 어텐션 구현으로 통합되었으며, 현재 거의 모든 LLM 학습 및 추론 프레임워크(Hugging Face, DeepSpeed, Megatron 등)에서 사용되고 있다. IO-aware 알고리즘 설계라는 원칙은 이후 FlashAttention-2, FlashAttention-3, FlashDecoding 등 후속 최적화의 기반이 되었으며, 하드웨어 인식 알고리즘 연구의 중요성을 학계에 각인시켰다.