[논문] About Model Parallelism
드디어 종강했다. 중간고사 이후, 수업 과제로 논문(EFFICIENT AND ROBUST PARALLEL DNN TRAINING THROUGH MODEL PARALLELISM ON MULTI-GPU PLATFORM) 발표를 진행했었다. 그때 느낀 바로는 굉장히 간단한 아이디어라고 생각했다. 하지만 코로나 때문에 기말고사가 사라지고 논문에 나온 실험을 진행해서 보고서로 제출하게 되었다. 이게 불행의 시작이었다. 이 논문은 저자가 공개한 코드가 없다... 그래서 어쩔 수 없이 여러 시행착오 끝에 타협을 하게 되었다. 구현하고자 하는 내용은 다음과 같다.
- Model Parallelism에서 일어나는 weight staleness 구현
- weight staleness를 저자들이 제안한 SpecTrain: weight prediction using smoothed gradient로 해결
일단 간단히 설명하면, Multi-GPU를 효율적으로 사용하기 위해 다양한 parallelism 기법들이 제안되었다. 그중 Data parallelism은 모델을 여러 GPU에 복사하고, train data를 나눠서 공급하는 기법이다. Model parallelism은 모델을 분할해서 여러 GPU에 올리고 학습을 진행한다. 각각의 장단, issue 등이 있지만 발표 논문에서는 Model parallelism에서 볼 수 있는 weight staleness를 해결하였다. 논문에 나온 다음 그림을 보자. 3개의 GPU에 모델이 나뉘어 있다.
정사각형의 밝은 네모는 forward process이고, 직사각형의 어두운 네모는 backward process이다. 모델이 나뉘어 있기 때문에 순차적으로 처리할 수밖에 없는데, 위에서 화살표로 표시되듯이 input 4에 대한 forward를 진행하고 해당하는 gradient를 backward 통해서 업데이트할 때, t=4 시점의 weight에 업데이트해야 하는데 batch pipelining 때문에 해당하는 weight는 이미 사라져 버렸다. 이러면 제대로 된 학습이 힘들다고 할 수 있다. 이를 해결하기 위해 과거의 weight를 기억해 놓았다가 update 하는 weight stashing 등이 제안되었는데, 모델이 클수록 메모리를 많이 차지한다. 위 그림의 갈색 화살표는 update 시점 전까지의 weight들을 smoothed gradient 상태로 보관하였다가 얼마나 많은 time이 경과하였는지를 반영하여, 해당 time에서의 weight를 예측하여 사용하는 것을 나타낸 그림이다. 예측 식도 간단하고 해서, 금방 할 줄 알았다. 환경은 RTX 2080Ti x2, CUDA 10.1, Cudnn7.6.5이다.
- 첫 번째 시도: multiprocessing package를 사용하기
일단 저자들이 VGG16과 CIFAR-10을 사용했기 때문에, 따라해준다. 다만 코드가 없어 model shape이나 batch 등은 알 수 없어 내 맘대로 하기로 했다. 일단 인터넷에서 긁어온 VGG16을 두 개로 나눠준다.
m1 = nn.Sequential(
# 3 224 128
nn.Conv2d(3, 64, 3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 64, 3, padding=1),
nn.LeakyReLU(0.2),
nn.MaxPool2d(2, 2),
# 64 112 64
nn.Conv2d(64, 128, 3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 128, 3, padding=1),
nn.LeakyReLU(0.2),
nn.MaxPool2d(2, 2),
# 128 56 32
nn.Conv2d(128, 256, 3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(256, 256, 3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(256, 256, 3, padding=1),
nn.LeakyReLU(0.2),
nn.MaxPool2d(2, 2),
# 256 28 16
nn.Conv2d(256, 512, 3, padding=1),
nn.LeakyReLU(0.2),
)
# 21 layers - 10 layers with weight
m2 = nn.Sequential(
# 256 28 16
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(0.2),
nn.MaxPool2d(2, 2),
# 512 14 8
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(0.2),
nn.MaxPool2d(2, 2),
# classifier
nn.Flatten(),
nn.Dropout(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Dropout(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
nn.Softmax(),
)
이제 python의 multiprocessing package를 사용해서, 2개의 child process에게 각 model의 forward process 및 backward process를 맡기고, model parallelism에서 발생하는 weight staleness를 관측하려고 했다.
# import multiprocessing
from multiprocessing import Process, Queue
# Set start method: Spawn (options: spawn, fork, ...)
torch.multiprocessing.set_start_method("spawn")
# Run
m1_result = Queue()
m2_result = Queue()
while True:
# load data from dataiter
inputs, labels = get_next_data(dataiter)
# forward in m1
proc1 = Process(target=mod1.run, args=(inputs, m1_result, m2_result))
proc1.start()
proc2 = Process(target=mod2.run, args=(None, m1_result, m2_result))
proc2.start()
proc1.join()
proc2.join()
일단 model이라는 class를 선언 후 두 개의 instance를 만들고, 거기게 각각 m1, m2를 넣어주었다. 그리고 member function으로 run(inputs, m1_result, m2_result) 함수를 선언해줬다. 모델 병렬 화이니만큼 앞 모델의 결괏값을 뒤 모델이 받고, 뒤 모델에서 계산한 loss를 앞 모델에도 전달해 줘야 한다고 생각해서 이렇게 했다. 이 때는 아직 torch의 forward & backward 과정을 이해하지 못해서 이렇게 짰다. 근데 이게 문제가 아니라 process handling을 못해서 결국 제대로 학습시켜보지도 못하고 실패했다. process 간의 communication과 scheduling에 뭔가 문제가 있는 모양인데, 제대로 이해도 못했고 시간이 없어서 바로 다음 시도로 넘어갔다. 언제 시간이 남으면 꼭 공부해봐야지.
- 두 번째 시도: Gpipe 라이브러리 사용 (설명 참조: Kakao Brain)
약간 치트키 느낌인데, 조원분들이 조사 중에 발견한 package이다. 이거는 다른 논문 (GPipe: Easy Scaling with Micro-Batch Pipeline Parallelism)을 구현한 건데, 간단히 말하면 batch를 micro-batch로 또 쪼개서 GPU idle을 최소화하는 방법이다. 예를 들어, batch size가 32이면 첫 번째 모델이 32개의 데이터를 처리하고 2번째 모델로 넘겨준다. 그럼 두 번째 모델은 첫 번째 모델의 처리 결과를 받기 전까지는 놀게 된다. 그러면 아까우니까, 32개를 사용자가 지정한 n개의 micro-batch로 나눠서 준다. n=8이면, 32 / 8 = 4니까 앞 모델이 4개 처리하고 넘겨주고를 반복하게 된다. 그러면 2번째 모델도 더 빨리 일을 시작할 수 있다. 더더 자세히는 위 링크에 아주 잘 나와있다! 코드를 보자.
# import Gpipe package
from torchgpipe import GPipe
# partition model using Gpipe
model = GPipe(VGG16, balance=[15, 16], chunks=1)
분할하기 매우 매우 쉽다. 분할된 모델 = Gpipe(Seq_model_to_partition, balance=[], chunks=num_micro_batches)이다. 우리는 weight staleness를 봐야 하므로 chunks=1 (batch = micro-batch)로 정했다. 여기서 balance는 list인데, len(balance)는 모델을 몇 개로 나눌지이고 각 index의 값은 몇 개의 layer를 할당할지 이다. 예를 들어 GPU가 3개이고 model layer가 29개면 [10, 10, 9] 등으로 정할 수 있다. layer 총개수만 맞추면 어떻게 분할하든 내 마음이다! 자동으로 configure 해주는 기능도 있는데 안 썼다. 근데 이렇게 하니까 문제점이, 내가 지금 Gpipe 실험을 하는 건지 SpecTrain실험을 하는 건지... 저자들은 Gpipe를 안 썼으니까 나도 Gpipe를 사용하지 않기로 했다.
- 세 번째 시도: weight staleness 상황 구현
결국 시간 내에 진정한 model parallelism을 구현하진 못했다. 심지어 따로따로 업데이트하니까, ._version이라는 memver variable이 자꾸 버전 안 맞는다고 업데이트 안 해줘서 스트레스받았다. 심지어 바꿀 수도 없는 c...var... 욕이 아니라 c로 짜여 있다는 소리 같다 ㅎㅎ. 좀 일찍 할 걸... 그래도 일단 뭐라도 해야 하니까, update시에 강제로 weight의 gradient를 0으로 만들고 update 했다. 나는 GPU가 2개라서, 1번째 GPU에 올라간 모델만 staleness issue를 겪는 상황을 가정했다.
### maintain current gradient for gradient smoothing
current_grads = []
for param in m1.parameters():
current_grads.append(param.grad)
### if it's first learning
if not recent_m1_grads:
for param in m1.parameters():
recent_m1_grads.append(param.grad)
else:
for j, param in enumerate(m1.parameters(), 0):
### smooth the gradient using paper's equation
recent_m1_grads[j] = gamma * recent_m1_grads[j] + (1 - gamma) * param.grad
### period means version difference.
if i % period == 0:
for j, param in enumerate(m1.parameters(), 0):
### Use predicted gradient, using paper's equation
param.grad = period * recent_m1_grads[j]
else:
### else set the gradient to 0, to resemble update miss due to version difference
opt1.zero_grad()
첫 current_grad 부분은, 현재 계산된 weight를 저장하는 부분이다. 밑에 ### smooth~ 부분에서는 현재 계산된 weight를 smoothing을 하고, 맨 밑에 부분에서는 update 주기가 되면 주기 * smoothed gradient로 gradient를 바꿔준다. update하지 않을 때에는 optimzer.zero_grad()를 사용해서 gradient를 다 0으로 만들고 update 한다. 내가 뭘 하는지 모르겠지만, 일단 이렇게 학습을 진행했을 때 loss를 찍어보면 다음과 같다. 시간 없어서 plot도 대충 다 나오게 찍었다.
원래 plot은 잘 나온 것만 보여주는 거다. diff=n 은 몇 번마다 update 할지이고, 옆에는 얻어진 Accuracy가 적혀 있다. 업데이트를 더 드물게 할수록 모델 정확도가 떡락하는 것을 볼 수 있다. 아무튼 이렇게 얼렁뚱땅 해서 냈다.
아마 잘 아시는 분들이 이 포스팅을 읽으면 기가 찰 수도 있다. 나도 좀 기가 차긴 하다. 하필 골라도 코드 없는 논문을 골랐다는 게 일단 첫 번째로 기가 찬다. 앞으로도 공부에는 아마 이딴 포스팅만 올라올 것이다. 난 성능충은 아니라서 학습 시간 줄이고 GPU 최적화하고 이런 쪽은 재미없는 거 같아... 관련 연구자들께는 죄송합니다 ㅎㅎ
감사합니다.