본문 바로가기
Language/python

[ Pytorch ] torch 기반의 모델 저장과 불러오기

by _YUJIN_ 2023. 8. 15.
  • Pytorch 모델을 불러올때 있었던 이슈는 모델을 잘 저장을 했지만, 불러올때 자꾸 GPU를 잡는 문제가 있었다.
  • 처음에는 디바이스 설정을 잘못한줄 알았지만.. 디버깅을 해보았을때 모델을 불러올때부터 GPU를 사용한다는 것을 알게 되어서 torch 모델 저장과 불러오는 방법을 다시 찾아보게 되었다 ㅎㅎ..
    < 참고한 블로그 :  https://gaussian37.github.io/dl-pytorch-deploy/  >
  • 지금 생각해보면 pytorch를 제대로 공부해보고 써보지 않아서 생긴 문제 같기도...ㅎㅎ

01. 모델 저장

  • 모델을 저장하면 '모델 구조 자체'와 '학습한 파라미터' 두가지를 저장하게 된다. 
  1. pytorch에서는 'state_dict' 함수를 사용하여 파라미터의 텐서를 사전 형식으로 추출할 수 있다. 
  2. 'torch.save' 라는 'pickel'의 wrapper 함수를 사용하여 파일로 저장할 수 있다. 
#-- 학습이 끝난 신경망 모델 변수 = model
params = model.state_dict()

#-- model.prm라는 파일로 저장
troch.save(params, 'model.prm', pickle_protocol = 4)
  • 모델을 사용하게 될 디바이스가 CPU라면 파라미터를 저장을 할때 모델을 CPU로 전송하고 저장하는 방법도 있다. 
#-- 모델을 CPU로 이동
model.cpu()

#-- 파라미터 저장
params = model.state_dict()
torch.save(params, 'model.prm', pickle_protocol = 4)

02. 저장된 모델 불러오기 

  • 'torch.load' 라는 메서드를 활용해서 위에서 저장한 'model.prm' 을 불러올 수 있다. (모델의 구조 + 학습된 파라미터)
  • 여기서 'map_location'이라는 인수를 사용하게 되는데,, 이 인수를 설정하지 않아서 이슈가 생겼던 것이다..
  • GPU로 학습한 모델의 파라미터를 그대로 저장하게 되면 torch.load를 통해 모델을 불러올 경우 CPU로 불러온 후 GPU로 전송하게 된다. 이때 GPU가 잡히는 문제가 발생했던 것이다..
  • 이런 오류를 방지하기 위해서는 꼭 'map_location' 인수가 필요하다!!!  (모델을 cpu로 저장했다면 상관없을 것 같다.)
#-- model.prm 불러오기
params = torch.load('model.prm', map_location = 'cpu')
model.load_state_dict(params)

 

반응형