def reload_network(self):
    """
    Reload the entire model or only the decoder (atlasnet), depending on the options.

    ``self.opt.reload_model_path`` takes precedence. If it is empty,
    ``self.opt.reload_decoder_path`` is used to replace only
    ``self.network.module.decoder``. Otherwise a notice is printed and nothing
    is loaded.

    Checkpoint keys may or may not carry the ``module.`` prefix that
    ``nn.DataParallel.state_dict()`` adds. The prefix is removed when present,
    so both kinds of checkpoint load into the unwrapped ``.module``.

    Tensors are mapped onto ``self.opt.device`` (set by ``build_network``), so
    reloading also works on CPU-only machines.

    NOTE(review): this file defines ``reload_network`` a second time further
    down. If both definitions live in the same class, the later one replaces
    this one -- consider keeping only one of them.
    :return:
    """
    from collections import OrderedDict

    def _load_checkpoint(path):
        # Load a state dict onto the run's device and drop a leading
        # DataParallel "module." prefix; keys without it are kept as-is.
        device = getattr(self.opt, "device", None)
        if device is None:
            # Fall back to the historical default, but never to CUDA when it is absent.
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        state_dict = torch.load(path, map_location=device)
        prefix = "module."
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    if self.opt.reload_model_path != "":
        self.network.module.load_state_dict(
            _load_checkpoint(self.opt.reload_model_path))
        # Report only after the weights were actually loaded.
        yellow_print(
            f"Network weights loaded from {self.opt.reload_model_path}!")
    elif self.opt.reload_decoder_path != "":
        # Build a throw-away network from a copy of the options with SVR
        # disabled, load the checkpoint into it, then graft its decoder onto
        # the current network. (Presumably the decoder checkpoint comes from a
        # non-SVR run -- confirm.)
        opt = deepcopy(self.opt)
        opt.SVR = False
        network = EncoderDecoder(opt)
        network = nn.DataParallel(network, device_ids=opt.multi_gpu)
        network.module.load_state_dict(
            _load_checkpoint(opt.reload_decoder_path))
        self.network.module.decoder = network.module.decoder
        yellow_print(
            f"Network Decoder weights loaded from {self.opt.reload_decoder_path}!"
        )
    else:
        yellow_print("No network weights to reload!")
def build_network(self):
    """
    Build the EncoderDecoder architecture (see auxiliary.model), wrap it in
    nn.DataParallel over ``self.opt.multi_gpu``, then reload any requested
    weights via ``self.reload_network()``.

    Side effects: sets ``self.opt.device`` and ``self.network``.
    :return:
    """
    has_cuda = torch.cuda.is_available()
    # First listed GPU when CUDA is available, otherwise run on the CPU.
    device_name = f"cuda:{self.opt.multi_gpu[0]}" if has_cuda else "cpu"
    self.opt.device = torch.device(device_name)

    self.network = EncoderDecoder(self.opt)
    self.network = nn.DataParallel(self.network,
                                   device_ids=self.opt.multi_gpu)

    self.reload_network()
def reload_network(self):
    """
    Reload the entire model or only the decoder (atlasnet), depending on the options.

    ``self.opt.reload_model_path`` takes precedence. If it is empty,
    ``self.opt.reload_decoder_path`` is used to replace only
    ``self.network.module.decoder``. Otherwise a notice is printed and nothing
    is loaded.

    Checkpoint keys may or may not carry the ``module.`` prefix that
    ``nn.DataParallel.state_dict()`` adds. The prefix is removed only when
    present, so checkpoints saved from either the wrapper or the unwrapped
    module both load into ``self.network.module``.

    Tensors are mapped onto ``self.opt.device`` (set by ``build_network``), so
    reloading also works on CPU-only machines.
    :return:
    """
    from collections import OrderedDict

    def _load_checkpoint(path):
        # Load a state dict onto the run's device and drop a leading
        # DataParallel "module." prefix; keys without it are kept as-is.
        device = getattr(self.opt, "device", None)
        if device is None:
            # Fall back to the historical default, but never to CUDA when it is absent.
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        state_dict = torch.load(path, map_location=device)
        prefix = "module."
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    if self.opt.reload_model_path != "":
        self.network.module.load_state_dict(
            _load_checkpoint(self.opt.reload_model_path))
        # Report only after the weights were actually loaded.
        yellow_print(
            f"Network weights loaded from {self.opt.reload_model_path}!")
    elif self.opt.reload_decoder_path != "":
        # Build a throw-away network from a copy of the options with SVR
        # disabled, load the checkpoint into it, then graft its decoder onto
        # the current network. (Presumably the decoder checkpoint comes from a
        # non-SVR run -- confirm.)
        opt = deepcopy(self.opt)
        opt.SVR = False
        network = EncoderDecoder(opt)
        network = nn.DataParallel(network, device_ids=opt.multi_gpu)
        network.module.load_state_dict(
            _load_checkpoint(opt.reload_decoder_path))
        self.network.module.decoder = network.module.decoder
        yellow_print(
            f"Network Decoder weights loaded from {self.opt.reload_decoder_path}!"
        )
    else:
        yellow_print("No network weights to reload!")