Training Deep Nets with Sublinear Memory Cost
서브선형 메모리 비용으로 심층 네트워크 학습
Tianqi Chen, Bing Xu, Chiyuan Zhang, Carlos Guestrin (2016)
순전파 시 모든 중간 활성화를 메모리에 저장하는 대신, 일부 체크포인트만 저장하고 역전파 시 필요한 활성화를 재계산하는 기법으로, 메모리 사용량을 O(n)에서 O(√n)으로 줄이면서 계산 비용은 약 20%만 증가시켰다.
배경
심층 신경망의 표준 역전파는 모든 레이어의 중간 활성화를 메모리에 저장해야 하므로, 메모리 사용량이 네트워크 깊이에 비례하여 선형적으로 증가한다. ResNet-1001과 같은 매우 깊은 네트워크나 긴 시퀀스의 RNN 학습에서 GPU 메모리가 병목이 되어, 배치 크기를 줄이거나 모델 크기를 제한해야 했다. 계산 비용을 약간 증가시키더라도 메모리를 극적으로 줄일 수 있는 기법이 필요했다.
핵심 아이디어
핵심 아이디어는 메모리와 계산의 트레이드오프이다. n개 레이어의 네트워크에서 √n개 간격으로 체크포인트 레이어를 지정하고, 이 레이어의 활성화만 메모리에 저장한다. 역전파 시 그래디언트 계산에 필요한 중간 활성화는 가장 가까운 체크포인트에서부터 순전파를 재실행하여 복원한다. 이렇게 하면 각 세그먼트(체크포인트 간격)의 최대 활성화 수가 √n이고, 체크포인트 수도 √n이므로 총 메모리는 O(√n)이 된다. 이 전략은 재귀적으로 적용할 수 있으며, 일반적인 계산 그래프(RNN, LSTM 포함)에도 적용 가능한 프레임워크로 확장된다.
방법론
임의의 계산 그래프를 세그먼트로 분할하고, 각 세그먼트의 경계에서만 활성화를 저장하는 자동 미분(automatic differentiation) 확장을 구현했다. 체크포인트 선택 전략으로는 균등 간격 분할이 가장 단순하며, 동적 프로그래밍으로 최적 체크포인트 위치를 탐색할 수도 있다. 구현은 TensorFlow, MXNet 등 자동 미분 프레임워크에 통합되며, 사용자가 체크포인트 레이어를 지정하는 간단한 API를 제공한다.
주요 결과
1,000 레이어의 피드포워드 네트워크에서 메모리 사용량을 약 32배 감소시키면서 계산 시간은 약 20-30%만 증가했다. ImageNet 학습에서 ResNet-101을 더 큰 배치 크기로 학습할 수 있게 했으며, 1,000 스텝의 RNN에서도 서브선형 메모리를 달성했다. 재귀적 체크포인팅을 적용하면 O(log n) 메모리까지 감소가 가능하지만 계산 오버헤드가 O(n log n)으로 증가한다.
임팩트
그래디언트 체크포인팅은 현대 딥러닝 프레임워크(PyTorch의 torch.utils.checkpoint, TensorFlow의 tf.recompute_grad)에 표준 기능으로 통합되어, 대규모 모델 학습의 핵심 인프라가 되었다. 특히 Transformer 기반 대규모 언어 모델 학습에서 메모리 제약을 극복하는 필수 기법으로 활용되며, FlashAttention, 모델 병렬화 등 다른 메모리 최적화 기법과 함께 사용된다. 이 아이디어는 계산-메모리 트레이드오프의 근본적 원리를 확립하여 효율적 학습 연구에 지속적인 영향을 미치고 있다.