def load_adversarial_model(self, model_file=None):
    """
    Build the adversarial autoencoder networks and optionally restore their weights.

    :param model_file: path of the checkpoint to load. When None, the checkpoint
        returned by ``latest_checkpoint(self.model_root)`` is used; the directory
        is created first if it does not exist. If no checkpoint is found, the
        freshly constructed networks are returned without loading anything.
    :return: tuple ``(Q_encoder, P_decoder, D_guess)``; the checkpoint is expected
        to be a dict holding their state_dicts under the keys "encoder",
        "decoder" and "guess".
    """
    Q_encoder = Encoder(self.latent_vector_dim)
    P_decoder = Decoder(self.latent_vector_dim)
    D_guess = Discriminator(self.latent_vector_dim)

    if model_file is None:
        checkpoint_dir = self.model_root
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(checkpoint_dir, exist_ok=True)
        model_file = latest_checkpoint(checkpoint_dir)

    if model_file is not None:
        print("loading >>> ", model_file, " ...")
        # Map to CPU so checkpoints saved on a GPU also load on CPU-only hosts;
        # load_state_dict copies the values onto the networks' own parameter devices.
        save_model = torch.load(model_file, map_location="cpu")
        Q_encoder.load_state_dict(save_model["encoder"])
        P_decoder.load_state_dict(save_model["decoder"])
        D_guess.load_state_dict(save_model["guess"])

    return Q_encoder, P_decoder, D_guess
def load_model(self, model_file=None):
    '''
    Load the network model.

    When ``self.model_name == "aae"`` the encoder returned by
    ``self.load_adversarial_model(model_file)`` is the model; the decoder and
    discriminator are discarded. Otherwise a fresh model is created with
    ``self.create_initial_model()`` and its weights are restored from a
    state_dict checkpoint.

    :param model_file: path of the checkpoint file to load. When None, the
        checkpoint returned by ``latest_checkpoint(self.model_root)`` is used
        (the directory is created if missing); if there is none, the freshly
        initialised model is returned unchanged.
    :return: the network model
    '''
    if self.model_name == "aae":
        model, _, _ = self.load_adversarial_model(model_file)
        return model

    model = self.create_initial_model()

    if model_file is None:
        checkpoint_dir = self.model_root
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(checkpoint_dir, exist_ok=True)
        model_file = latest_checkpoint(checkpoint_dir)

    if model_file is not None:
        print("loading >>> ", model_file, " ...")
        # Map to CPU so GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies the values onto the model's parameter devices.
        model.load_state_dict(torch.load(model_file, map_location="cpu"))
    return model
def load_model(self, model_file=None):
    '''
    Load a whole (pickled) network model.

    :param model_file: checkpoint file to load. When None, the directory
        ``self.model_root`` is created if absent and the checkpoint returned by
        ``latest_checkpoint`` for it is loaded; if that yields None, a new model
        from ``self.create_new_top_cnn_model()`` is returned instead.
    :return: the network model (whatever object ``torch.load`` returns, or the new model)
    '''
    source = model_file
    if source is None:
        root = self.model_root
        if not os.path.exists(root):
            os.makedirs(root)
        source = latest_checkpoint(root)
        if source is None:
            # Nothing saved yet: start from a freshly built model.
            return self.create_new_top_cnn_model()

    print("loading >>> ", source, " ...")
    return torch.load(source)
def load_model(self, model_file=None):
    '''
    Load the network model from either a state_dict or a whole-model checkpoint.

    If the loaded object is a ``dict`` it is treated as a state_dict and restored
    into a fresh model from ``self.create_initial_model()``; any other object is
    returned as the model itself.

    :param model_file: checkpoint file to load. When None, the checkpoint returned
        by ``latest_checkpoint(self.model_root)`` is used (the directory is created
        if missing); if there is none, a freshly initialised model is returned.
    :return: the network model
    '''
    if model_file is None:
        checkpoint_dir = self.model_root
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(checkpoint_dir, exist_ok=True)
        model_file = latest_checkpoint(checkpoint_dir)

    if model_file is None:
        return self.create_initial_model()

    print("loading >>> ", model_file, " ...")
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # checkpoints. Recent PyTorch releases default to weights_only=True, which
    # rejects whole pickled models; confirm against the torch version in use.
    load_object = torch.load(model_file)
    if isinstance(load_object, dict):
        model = self.create_initial_model()
        # Reuse the object already loaded rather than reading the file a second time.
        model.load_state_dict(load_object)
    else:
        model = load_object
    return model
def test_latest_checkpoint(self):
    """Smoke check: print whatever util.latest_checkpoint finds in a fixed model directory."""
    # NOTE(review): machine-specific absolute Windows path; this only runs meaningfully
    # on the original author's workstation and asserts nothing.
    model_dir = "D:/CloudSpace/WorkSpace/PatholImage/models/pytorch/se_densenet_22_x_256"
    newest = util.latest_checkpoint(model_dir)
    print(newest)