학습된 파라미터 로드 중 표제와 같은 에러가 발생하는 경우가 있다. 이런 경우에 대한 해결책을 정리해본다.
에러 발생 작업 순서:
1. apex의 DistributedDataParallel를 이용해 multi gpu로 모델 학습 중 아래 코드를 이용해 모델 state 저장
torch.save('state_dict':model.state_dict())
2. 저장된 state_dict을 단일 gpu를 사용해 테스트 하기 위해 torch.load()로 복원 하던 중 다음과 같은 에러가 발생했다.
'''
Error(s) in loading state_dict for model:
Missing key(s) in state_dict: "backbone.block0.0.0.weight", ~, ~
Unexpected key(s) in state_dict: "module.backbone.0.0.weight", ~, ~
'''
3. 위 에러 메시지에서 확인 가능 하듯 저장된 state_dict에는 모든 weight 이름 앞에 "module" 이라는 prefix가 추가되어 있다.
발생 사유:
아래와 같이 DistributedDataParallel 사용 시
DDPmodel = DistributedDataParallel(model)
리턴된 _DDPmodel_은 _model_을 _module_로 감싼 형태 이다.
즉, 기존 model_의 *attribute_은 *model.module.attribute 과 같이 접근 해야 하는데 그걸 빼먹은 것이다.
해결 방법 1:
학습 중 모델 저장 시 아래 코드를 이용해 모델 상태를 저장한다.
torch.save('state_dict':model.module.state_dict())
해결 방법 2:
본래 목적이 기 학습된 모델의 구조를 유지하고 파라미터의 극히 일부분만 finetuning_하는 것이 었기 때문에 *_해결 방법 1** 은 좋은 해결 방법은 아니다. 본래의 목적을 위해서는 로드된 모델 _state_dict_의 파라미터 이름에서 *module. (또는 .module)을 제거 하면 정상적인 로드가 가능하다.다음은 예제 코드이다.
loaded_state_dict = torch.load(state_dict_path)
new_state_dict = OrderedDict()
for n, v in loaded_state_dict.items():
name = n.replace("module.","") # .module이 중간에 포함된 형태라면 (".module","")로 치환
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
비슷 한 이슈를 격은 사람들이 이 링크에 있는거 같다.
link:참조
'pytorch' 카테고리의 다른 글
[pytorch] RuntimeError " All tensor must be on devices[0]: 0" (0) | 2021.08.31 |
---|---|
[Profile] GPU profile을 통한 병목 진단 및 개선 (6) | 2021.07.19 |
[pytorch] AttributeError: DistributedDataParallel has no attribute (0) | 2021.04.21 |
[pytorch] torch.gather 설명 (2) | 2021.03.05 |