def init_model_from_states(config):
    """Build an NvNet model and load weights from ``config['saved_model_path']``.

    Parameters
    ----------
    config : dict
        Must provide ``"cuda_devices"``, ``"saved_model_path"`` and
        ``"load_from_data_parallel"``.

    Returns
    -------
    The NvNet model — possibly wrapped in ``torch.nn.DataParallel`` and moved
    to GPU — with the checkpoint's ``state_dict`` loaded.
    """
    print("Init model...")
    model = NvNet(config=config)
    if config["cuda_devices"] is not None:
        # NOTE(review): `num_gpu` is not defined in this function — presumably
        # a module-level global; confirm it exists, otherwise this line raises
        # NameError at runtime.
        if num_gpu > 0:
            model = torch.nn.DataParallel(model)  # multi-gpu inference
        model = model.cuda()
    # Load onto CPU first so the checkpoint restores regardless of which
    # device(s) it was saved from.
    checkpoint = torch.load(config['saved_model_path'], map_location='cpu')
    state_dict = checkpoint["state_dict"]
    if not config["load_from_data_parallel"]:
        model.load_state_dict(state_dict)
    else:
        from collections import OrderedDict

        # Checkpoint was trained with multi-GPU (DataParallel), so keys may be
        # missing the "module." prefix the wrapped model expects — normalize.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if "vae" not in k:  # disable the vae path
                if "module." in k:
                    new_state_dict[k] = v
                else:
                    # fix the bug of missing keywords caused by data parallel
                    new_state_dict["module." + k] = v
        model.load_state_dict(new_state_dict)
    return model
def init_model_from_states(config):
    """Build an NvNet model and load weights from ``config['best_model_file']``.

    NOTE(review): this redefines ``init_model_from_states`` and shadows the
    earlier variant at import time — confirm which one callers actually want.

    Parameters
    ----------
    config : dict
        Must provide ``"cuda_devices"``, ``"best_model_file"`` and
        ``"load_from_data_parallel"``.

    Returns
    -------
    The NvNet model (moved to GPU when ``cuda_devices`` is set) with the
    checkpoint's ``state_dict`` loaded.
    """
    print("Init model...")
    model = NvNet(config=config)
    if config["cuda_devices"] is not None:
        # model = torch.nn.DataParallel(model)  # multi-gpu training
        model = model.cuda()
    # map_location='cpu' keeps loading robust on machines that lack the GPUs
    # the checkpoint was saved on; load_state_dict moves tensors afterwards.
    checkpoint = torch.load(config["best_model_file"], map_location="cpu")
    state_dict = checkpoint["state_dict"]
    if not config["load_from_data_parallel"]:
        model.load_state_dict(state_dict)
    else:
        from collections import OrderedDict

        # Checkpoint was saved from a DataParallel model: strip the "module."
        # prefix. Guard with startswith — the original unconditional k[7:]
        # silently mangled any key that did not carry the prefix.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith("module.") else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    return model