본문 바로가기
에러(Error) 이야기

[Pytorch] RuntimeError: Error(s) in loading state_dict for RobertaForSequenceClassification

by Kaya_Alpha 2024. 4. 9.

1. 에러 발생 상황

학습된 모델을 저장하고 다시 불러오는(model.load_state_dict()) 상황에서 발생하였다..

에러 메시지는 다음과 같다.

에러 발생...

 

2. 해결 방법

찾아보니 버전차이가 원인이였다.

fine-tuning을 진행한 pytorch의 버전은 2.2버전이고, inference하기 위해 불러온 환경에서의 pytorch 버전은 1.7 버전이였다.

아마 버전차이로 인해 에러가 발생하는것 같다.

 

해결방법은 의외로 간단하다.

load_state_dict의 인자 중 strict를 False로 주면 바로 해결된다.(load할 수 있는 key값만 불러오는 방식인듯)

model.load_state_dict(torch.load("./model/saved_model.pt"),strict = False)

 

위와 같이 인자를 추가하면 에러가 해결된다!