def _init_models(self):
    """Build the generator and its three auxiliary networks, then set their weights.

    Creates four networks on ``self``:

    * ``net_G``  -- ``CustomPoseGenerator`` built from ``opt.pose_feature_size``,
      a 2048-wide feature input and ``opt.noise_feature_size``.
    * ``net_E``  -- ``SiameseNet`` whose ``EltwiseSubEmbed`` head has 2 classes.
    * ``net_Di`` -- ``SiameseNet`` whose ``EltwiseSubEmbed`` head has 1 class.
    * ``net_Dp`` -- ``NLayerDiscriminator`` taking a 3 + 18 channel input
      (presumably an RGB image plus 18 pose maps -- TODO confirm).

    How weights are obtained depends on ``self.opt.stage``:

    * stage 1: ``net_G`` and ``net_Dp`` are initialised with ``init_weights``.
      ``net_E`` is loaded from ``opt.netE_pretrain``. ``net_Di`` is seeded from
      the same checkpoint, keeping only row 1 of the 2-way classifier.
    * stage 2: all four networks are loaded from their respective
      ``opt.net*_pretrain`` checkpoints.

    Finally every network is wrapped in ``torch.nn.DataParallel`` and moved to
    CUDA.

    Raises:
        ValueError: if ``self.opt.stage`` is neither 1 nor 2.
    """
    opt = self.opt

    self.net_G = CustomPoseGenerator(
        opt.pose_feature_size, 2048, opt.noise_feature_size,
        dropout=opt.drop, norm_layer=self.norm_layer,
        fuse_mode=opt.fuse_mode, connect_layers=opt.connect_layers)

    e_base_model = create(opt.arch, cut_at_pooling=True)
    e_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                    num_features=2048, num_classes=2)
    self.net_E = SiameseNet(e_base_model, e_embed_model)

    di_base_model = create(opt.arch, cut_at_pooling=True)
    di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                     num_features=2048, num_classes=1)
    self.net_Di = SiameseNet(di_base_model, di_embed_model)

    self.net_Dp = NLayerDiscriminator(3 + 18, norm_layer=self.norm_layer)

    if opt.stage == 1:
        init_weights(self.net_G)
        init_weights(self.net_Dp)
        state_dict = remove_module_key(torch.load(opt.netE_pretrain))
        self.net_E.load_state_dict(state_dict)
        # net_Di has the same architecture as net_E except for a 1-output
        # classifier, so it is seeded from the same checkpoint and keeps only
        # classifier row 1. Slice with [1:] instead of indexing with [1] so the
        # leading dimension survives: the weight stays (1, num_features) and the
        # bias stays (1,), assuming the classifier is an nn.Linear. Indexing
        # would drop that dimension, and load_state_dict in current PyTorch
        # rejects the resulting shape mismatch. Slicing also keeps the
        # checkpoint's dtype/device instead of forcing a CPU FloatTensor.
        state_dict['embed_model.classifier.weight'] = \
            state_dict['embed_model.classifier.weight'][1:]
        state_dict['embed_model.classifier.bias'] = \
            state_dict['embed_model.classifier.bias'][1:]
        self.net_Di.load_state_dict(state_dict)
    elif opt.stage == 2:
        self._load_state_dict(self.net_E, opt.netE_pretrain)
        self._load_state_dict(self.net_G, opt.netG_pretrain)
        self._load_state_dict(self.net_Di, opt.netDi_pretrain)
        self._load_state_dict(self.net_Dp, opt.netDp_pretrain)
    else:
        # The previous `assert('...')` asserted a non-empty string, which is
        # always true, so an invalid stage silently continued with
        # uninitialised networks.
        raise ValueError('unknown training stage: {}'.format(opt.stage))

    self.net_E = torch.nn.DataParallel(self.net_E).cuda()
    self.net_G = torch.nn.DataParallel(self.net_G).cuda()
    self.net_Di = torch.nn.DataParallel(self.net_Di).cuda()
    self.net_Dp = torch.nn.DataParallel(self.net_Dp).cuda()
def _load_state_dict(self, net, path):
    """Load the checkpoint stored at ``path`` into ``net``.

    The file is read with ``torch.load``. Its keys are normalised by
    ``remove_module_key`` before the result is passed to
    ``net.load_state_dict``.

    Args:
        net: module whose parameters are overwritten.
        path: filesystem path of the saved state dict.
    """
    checkpoint = torch.load(path)
    cleaned = remove_module_key(checkpoint)
    net.load_state_dict(cleaned)