Пример #1
0
    def load_adversarial_model(self, model_file=None):
        """Build the adversarial autoencoder's three networks and restore their weights.

        Weights are read from ``model_file`` when it is given. Otherwise they
        come from the newest checkpoint found by ``latest_checkpoint`` in
        ``self.model_root``, and that directory is created if it is missing.
        When no checkpoint is available, the freshly constructed networks are
        returned unchanged.

        :param model_file: path of a saved checkpoint, or None to use the latest
            checkpoint in ``self.model_root``
        :return: tuple ``(Q_encoder, P_decoder, D_guess)``
        """
        Q_encoder = Encoder(self.latent_vector_dim)
        P_decoder = Decoder(self.latent_vector_dim)
        D_guess = Discriminator(self.latent_vector_dim)

        if model_file is None:
            checkpoint_dir = self.model_root
            # exist_ok avoids the race between an existence check and creation.
            os.makedirs(checkpoint_dir, exist_ok=True)
            model_file = latest_checkpoint(checkpoint_dir)

        if model_file is not None:
            print("loading >>> ", model_file, " ...")
            # The checkpoint is a dict holding one state dict per network.
            # load_state_dict copies the tensors into the networks' own
            # parameters, so deserialising onto the CPU is enough. This also
            # lets checkpoints saved on CUDA load on machines without a GPU.
            save_model = torch.load(model_file, map_location="cpu")
            Q_encoder.load_state_dict(save_model["encoder"])
            P_decoder.load_state_dict(save_model["decoder"])
            D_guess.load_state_dict(save_model["guess"])

        return Q_encoder, P_decoder, D_guess
Пример #2
0
    def load_model(self, model_file=None):
        """Load the network model.

        If ``self.model_name == "aae"``, this returns the first network produced
        by ``self.load_adversarial_model(model_file)``.

        Otherwise a model from ``self.create_initial_model()`` is filled with a
        state dict. The state dict comes from ``model_file``, or from the newest
        checkpoint in ``self.model_root`` when ``model_file`` is None; that
        directory is created if missing. If no checkpoint exists, the freshly
        created model is returned unchanged.

        :param model_file: path of the saved model (state dict) to load, or None
            to use the latest checkpoint
        :return: the network model
        """
        if self.model_name == "aae":
            model, _, _ = self.load_adversarial_model(model_file)
            return model

        model = self.create_initial_model()

        if model_file is None:
            checkpoint_dir = self.model_root
            # exist_ok avoids the race between an existence check and creation.
            os.makedirs(checkpoint_dir, exist_ok=True)
            model_file = latest_checkpoint(checkpoint_dir)

        if model_file is not None:
            print("loading >>> ", model_file, " ...")
            # load_state_dict copies the tensors into the model's parameters, so
            # deserialising onto the CPU is enough. This also lets checkpoints
            # saved on CUDA load on machines without a GPU.
            model.load_state_dict(torch.load(model_file, map_location="cpu"))

        return model
    def load_model(self, model_file=None):
        """Load a whole serialized network model.

        The object is read from ``model_file`` when it is given. Otherwise it is
        read from the newest checkpoint in ``self.model_root``; that directory is
        created if it does not exist. When no checkpoint is found, a new model
        is built with ``self.create_new_top_cnn_model()``.

        :param model_file: model file to load, or None to use the latest checkpoint
        :return: the network model
        """
        source = model_file
        if source is None:
            root = self.model_root
            if not os.path.exists(root):
                os.makedirs(root)
            source = latest_checkpoint(root)
            if source is None:
                return self.create_new_top_cnn_model()

        print("loading >>> ", source, " ...")
        # NOTE(review): torch.load is pickle-based and here returns the stored
        # object itself; only load checkpoint files from trusted sources.
        return torch.load(source)
Пример #4
0
    def load_model(self, model_file=None):
        """Load the network model.

        ``model_file`` may contain either a state dict or a whole serialized
        model:

        - A dict is loaded into a fresh ``self.create_initial_model()``.
        - Any other object is returned as the model itself.

        When ``model_file`` is None, the newest checkpoint in ``self.model_root``
        is used, and that directory is created if missing. If there is no
        checkpoint, a freshly created model is returned.

        :param model_file: model file to load, or None to use the latest checkpoint
        :return: the network model
        """
        if model_file is None:
            checkpoint_dir = self.model_root
            # exist_ok avoids the race between an existence check and creation.
            os.makedirs(checkpoint_dir, exist_ok=True)
            model_file = latest_checkpoint(checkpoint_dir)

        if model_file is None:
            return self.create_initial_model()

        print("loading >>> ", model_file, " ...")
        # NOTE(review): torch.load is pickle-based; only load trusted files.
        load_object = torch.load(model_file)
        if isinstance(load_object, dict):
            # Reuse the already-deserialised state dict rather than reading and
            # unpickling the same file a second time.
            model = self.create_initial_model()
            model.load_state_dict(load_object)
            return model
        return load_object
Пример #5
0
 def test_latest_checkpoint(self):
     """Smoke-run util.latest_checkpoint on a local model directory and print the result.

     No assertion is made; the returned value is only printed.
     """
     # NOTE(review): absolute, machine-specific path; this only does something
     # useful on the original author's workstation.
     model_dir = "D:/CloudSpace/WorkSpace/PatholImage/models/pytorch/se_densenet_22_x_256"
     found = util.latest_checkpoint(model_dir)
     print(found)