ICLR 2020Citations: 1,500+

Large Batch Optimization for Deep Learning: Training BERT in 76 minutes

딥러닝을 위한 대규모 배치 최적화: BERT 76분 학습

Yang You, Jing Li, Sashank Reddi, et al. (2019)

LARS의 레이어별 적응적 스케일링을 Adam 옵티마이저에 결합한 LAMB(Layer-wise Adaptive Moments optimizer for Batch training) 알고리즘을 제안하여, BERT 사전학습을 배치 크기 64K에서 안정적으로 수행하고 학습 시간을 3일에서 76분으로 단축했다.

배경

BERT 사전학습은 방대한 계산 자원을 요구하여, 원래 논문에서는 16개 TPU로 4일간 학습해야 했다. 학습을 가속하기 위해 배치 크기를 크게 늘리고 GPU/TPU를 추가하는 데이터 병렬화가 자연스러운 접근이지만, 큰 배치에서의 학습 불안정 문제가 있었다. LARS는 합성곱 네트워크에서 성공했으나, Transformer 기반의 BERT에는 그대로 적용이 어려웠다. 특히 BERT는 SGD가 아닌 Adam 옵티마이저를 사용하며, 어텐션 레이어와 임베딩 레이어의 그래디언트 특성이 합성곱 레이어와 다르다.

핵심 아이디어

LAMB는 두 가지 핵심 기법을 결합한다. 첫째, Adam의 적응적 모멘트 추정(1차, 2차 모멘트)을 유지하여 파라미터별 학습률 조정을 수행한다. 둘째, LARS에서 영감을 받은 레이어별 신뢰 비율(trust ratio) φ(||w||) / ||Adam_update||를 곱하여, 각 레이어의 업데이트 크기를 가중치 노름에 비례하도록 정규화한다. 이 두 수준의 적응(파라미터 수준의 Adam + 레이어 수준의 신뢰 비율)이 결합되어, 매우 큰 배치에서도 안정적인 학습이 가능해진다. 이론적으로도 LAMB의 수렴성을 비볼록 최적화 설정에서 증명했다.

방법론

표준 Adam의 업데이트 규칙 m_t = β₁m_{t-1} + (1-β₁)g_t, v_t = β₂v_{t-1} + (1-β₂)g_t²에 바이어스 보정을 적용한 후, 정규화된 업데이트 r_t = m̂_t/√(v̂_t) + λw_t를 계산한다. 최종 업데이트에 φ(||w_t||) / ||r_t||의 신뢰 비율을 곱한다: w_{t+1} = w_t - η × (φ(||w_t||)/||r_t||) × r_t. 여기서 φ는 범위 제한 함수로 보통 항등 함수를 사용한다. BERT 학습 시 점진적 웜업과 선형 학습률 감쇠를 함께 적용한다.

주요 결과

BERT-Large 사전학습에서 배치 크기를 512에서 64K(65,536)까지 증가시키면서 GLUE 벤치마크에서의 미세조정 성능을 유지했다. 1,024개의 TPU v3를 사용하여 76분 만에 BERT 사전학습을 완료했으며, 이는 원래의 3일에서 약 49배 가속이다. SQuAD v1.1에서도 원래 BERT와 동등한 F1 점수를 달성했다. Adam과 LARS 단독으로는 배치 크기 8K 이상에서 성능이 저하되었으나, LAMB는 64K까지 안정적이었다.

임팩트

LAMB는 대규모 모델 사전학습의 효율화에 직접적으로 기여한 옵티마이저로, BERT 이후 다양한 Transformer 모델의 대규모 배치 학습에 채택되었다. 레이어별 적응적 학습률이라는 개념이 Transformer 아키텍처에서도 효과적임을 입증하여, 이후 대규모 언어 모델 학습의 최적화 연구에 영향을 미쳤다. 76분 BERT 학습이라는 상징적 결과는 계산 효율성의 중요성을 대중적으로 알리는 데도 기여했다.

관련 Foundation 논문

관련 논문