학습된 파라미터 로드 중 표제와 같은 에러가 발생하는 경우가 있다. 이런 경우에 대한 해결책을 정리해본다.

에러 발생 작업 순서:

1. apex의 DistributedDataParallel를 이용해 multi gpu로 모델 학습 중 아래 코드를 이용해 모델 state 저장

torch.save('state_dict':model.state_dict())

 

2. 저장된 state_dict을 단일 gpu를 사용해 테스트 하기 위해 torch.load()로 복원 하던 중 다음과 같은 에러가 발생했다.

'''
Error(s) in loading state_dict for model: 
	Missing key(s) in state_dict: "backbone.block0.0.0.weight", ~, ~
	Unexpected key(s) in state_dict: "module.backbone.0.0.weight", ~, ~
'''

3. 위 에러 메시지에서 확인 가능 하듯 저장된 state_dict에는 모든 weight 이름 앞에 "module" 이라는 prefix가 추가되어 있다.

 

발생 사유:

아래와 같이 DistributedDataParallel 사용 시

DDPmodel = DistributedDataParallel(model)

리턴된 _DDPmodel_은 _model_을 _module_로 감싼 형태 이다.

즉, 기존 model_의 *attribute_은 *model.module.attribute 과 같이 접근 해야 하는데 그걸 빼먹은 것이다.

 

 

해결 방법 1:

학습 중 모델 저장 시 아래 코드를 이용해 모델 상태를 저장한다.

torch.save('state_dict':model.module.state_dict())

 

 

해결 방법 2:

본래 목적이 기 학습된 모델의 구조를 유지하고 파라미터의 극히 일부분만 finetuning_하는 것이 었기 때문에 *_해결 방법 1** 은 좋은 해결 방법은 아니다. 본래의 목적을 위해서는 로드된 모델 _state_dict_의 파라미터 이름에서 *module. (또는 .module)을 제거 하면 정상적인 로드가 가능하다.다음은 예제 코드이다.

loaded_state_dict = torch.load(state_dict_path)
new_state_dict = OrderedDict()
for n, v in loaded_state_dict.items():
    name = n.replace("module.","") # .module이 중간에 포함된 형태라면 (".module","")로 치환
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

비슷 한 이슈를 격은 사람들이 이 링크에 있는거 같다.

link:참조

+ Recent posts