PyTorch 모델 특정 종류의 레이어를 바꾸기
개발새발/개발 셋업
2023. 3. 16. 11:54
API를 쓰다보면, activation function을 비롯해서 모델의 레이어를 바꾸고 싶은 일이 종종 생기는데, nn.Sequantial()이 다중으로 겹쳐진 경우 구글링 해서 나오는 코드들은 대부분 동작을 안 한다. (예시 isinstance()를 한두번 정도 묻는 것, 혹은 setattr로 변경을 시도하는 것) 코드 아래의 코드는 재귀함수 형태로, 바꾸고 싶은 instance를 target에 전달하고, 바꾸고 싶은 instance를 source에 전달을 한다. 이러면 nn.Sequantial에 얼마나 감싸져있던 재귀형태로 모두 진행하여 모델 레이어를 변경할 수 있다. def replace_module(modules:nn.Module, target, source): for name, child in..