FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
FlashAttention-2: 개선된 병렬화와 작업 분할로 더 빠른 어텐션
Tri Dao (2023)
FlashAttention의 알고리즘을 GPU의 병렬화와 작업 분할 측면에서 재설계하여, 비인과적 어텐션에서 이론적 최대 FLOPS의 약 70%를 달성하고, 원래 FlashAttention 대비 약 2배의 속도 향상을 이루었다.
배경
FlashAttention이 IO-aware 어텐션의 실용성을 입증한 후, GPU 하드웨어 활용률을 더 높이기 위한 최적화가 필요했다. FlashAttention v1은 A100 GPU의 이론적 최대 FLOPS의 약 25-35%만 활용하고 있었는데, 이는 GPU 스레드 블록 간 작업 분배의 비효율, 불필요한 공유 메모리 읽기/쓰기, 워프(warp) 간 동기화 오버헤드 등에 기인했다. 특히 인과적(causal) 마스킹이 적용된 경우 삼각 형태의 불균등한 워크로드로 인해 GPU 점유율이 더욱 낮았다.
핵심 아이디어
FlashAttention-2는 세 가지 핵심 최적화를 도입한다. 첫째, 알고리즘의 루프 구조를 변경하여 외부 루프를 Q 블록에, 내부 루프를 K/V 블록에 대해 순회하도록 재구성한다. 이를 통해 각 스레드 블록이 하나의 Q 블록에 대한 전체 어텐션 출력을 독립적으로 계산하므로, 스레드 블록 간 통신(synchronization)이 불필요해진다. 둘째, 워프 내 작업 분할을 최적화하여 공유 메모리(shared memory) 접근과 동기화를 최소화한다. 기존에는 K를 워프 간에 분할했지만, FlashAttention-2는 Q를 워프 간에 분할하여 리덕션(reduction) 단계의 공유 메모리 읽기/쓰기를 제거한다. 셋째, 시퀀스 길이 차원에서의 병렬화를 추가하여, 배치 크기나 헤드 수가 적을 때도 GPU SM(Streaming Multiprocessor)을 충분히 활용한다.
방법론
Q를 외부 루프로, K/V를 내부 루프로 하는 타일링을 적용한다. 각 스레드 블록은 하나의 Q 블록을 담당하여 모든 K/V 블록을 순회하면서 부분 출력을 축적한다. 워프 수준에서 Q를 4개 워프에 분할하되 K/V는 모든 워프가 공유하여, softmax 통계(최대값, 합)의 워프 간 동기화를 register-level warp shuffle로 처리한다. 인과적 마스킹에서는 유효한 블록만 계산하도록 조기 종료를 적용하여 불필요한 계산을 제거한다. 시퀀스 길이가 길 때는 Q를 추가로 분할하여 더 많은 스레드 블록을 생성하고, 최종 결과를 별도 커널에서 병합한다.
주요 결과
A100 80GB GPU에서 FlashAttention-2는 비인과적 어텐션에서 최대 230 TFLOPS를 달성하여 이론적 최대(312 TFLOPS)의 약 73%에 도달했다. 이는 FlashAttention v1 대비 약 2배의 속도 향상이다. 인과적 어텐션에서도 유사한 비율의 향상을 보였으며, 특히 긴 시퀀스(4K, 8K, 16K)에서 더 큰 속도 이득을 달성했다. GPT 스타일 모델의 end-to-end 학습에서 FlashAttention v1 대비 1.3-1.5배, 표준 어텐션 대비 5-8배의 학습 속도 향상을 기록했다.
임팩트
FlashAttention-2는 현재 대부분의 LLM 학습 및 추론 프레임워크의 기본 어텐션 구현으로 사용되고 있다. PyTorch의 SDPA(Scaled Dot-Product Attention), HuggingFace Transformers, DeepSpeed, Megatron-LM 등에 통합되어 사실상 산업 표준이 되었다. GPU 하드웨어에 대한 깊은 이해를 바탕으로 한 시스템 수준 최적화가 알고리즘 수준 개선 못지않게 중요함을 재확인시켰다. 이후 FlashAttention-3(Hopper 아키텍처 최적화), FlashDecoding(추론 특화) 등으로 지속적으로 발전하고 있다.