Code example #1
0
 def load_model(self, path, name):
     """Rebuild an RNNVAE from the hyper-parameters in ``self.args`` and
     load the weights saved as ``<path>/<name>.model``.

     Args:
         path: Directory containing the checkpoint file.
         name: Checkpoint base name (without the ``.model`` suffix).

     Returns:
         The restored RNNVAE, moved to the GPU when CUDA is available,
         otherwise left on the CPU.
     """
     model = RNNVAE(self.args,
                    self.args.enc_type,
                    len(self.data.dictionary),
                    self.args.emsize,
                    self.args.nhid,
                    self.args.lat_dim,
                    self.args.nlayers,
                    dropout=self.args.dropout,
                    tie_weights=self.args.tied,
                    input_z=self.args.input_z,
                    mix_unk=self.args.mix_unk,
                    # Conditioning is enabled if either bit or BoW codes are used.
                    condition=(self.args.cd_bit or self.args.cd_bow),
                    input_cd_bow=self.args.cd_bow,
                    input_cd_bit=self.args.cd_bit)
     # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host;
     # without it torch.load raises when deserializing CUDA tensors.
     map_location = None if torch.cuda.is_available() else 'cpu'
     model.load_state_dict(
         torch.load(os.path.join(path, name + '.model'),
                    map_location=map_location))
     # Fix: the original called model.cuda() unconditionally, which raises
     # on machines without CUDA. Guard the device transfer instead,
     # matching the convention used elsewhere in this project.
     if torch.cuda.is_available():
         model = model.cuda()
     return model
Code example #2
0
 def load_model(args, ntoken, path, name):
     """Reconstruct an RNNVAE from saved weights and prepare it for eval.

     Args:
         args: Parsed experiment arguments holding the model hyper-parameters.
         ntoken: Vocabulary size for the decoder's output layer.
         path: Directory containing the checkpoint file.
         name: Checkpoint base name (without the ``.model`` suffix).

     Returns:
         The restored RNNVAE in eval mode, on GPU when both CUDA and the
         project's GPU flag permit it.
     """
     from NVLL.util.gpu_flag import GPU_FLAG

     vae = RNNVAE(args, args.enc_type, ntoken, args.emsize,
                  args.nhid, args.lat_dim, args.nlayers,
                  dropout=args.dropout, tie_weights=args.tied,
                  input_z=args.input_z, mix_unk=args.mix_unk,
                  condition=(args.cd_bit or args.cd_bow),
                  input_cd_bow=args.cd_bow, input_cd_bit=args.cd_bit)
     checkpoint = os.path.join(path, name + '.model')
     vae.load_state_dict(torch.load(checkpoint))
     # Only move to GPU when both the hardware and the project flag allow it.
     if GPU_FLAG and torch.cuda.is_available():
         vae = vae.cuda()
     return vae.eval()
Code example #3
0
    def end(self):
        """Finalize a training run: reload the best checkpoint, evaluate it
        on the train and test splits, rename the checkpoint files to embed
        the test loss, and flush results to the summary writer.

        NOTE(review): relies on a module-level ``device`` and assumes
        ``self.args.save_name + '.model'`` / ``'.args'`` exist on disk —
        confirm this runs only after training saved a checkpoint.
        """
        # Load the best saved model.
        model = RNNVAE(self.args,
                       self.args.enc_type,
                       len(self.data.dictionary),
                       self.args.emsize,
                       self.args.nhid,
                       self.args.lat_dim,
                       self.args.nlayers,
                       dropout=self.args.dropout,
                       tie_weights=self.args.tied,
                       input_z=self.args.input_z,
                       mix_unk=self.args.mix_unk,
                       # Conditioning is enabled if either bit or BoW codes are used.
                       condition=(self.args.cd_bit or self.args.cd_bow),
                       input_cd_bow=self.args.cd_bow,
                       input_cd_bit=self.args.cd_bit)
        # strict=False tolerates missing/unexpected keys between the saved
        # state dict and the freshly built model.
        model.load_state_dict(torch.load(self.args.save_name + '.model'),
                              strict=False)
        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     model = model.cuda()
        # Disable dropout etc. before evaluation.
        model = model.eval()
        print(model)
        print(self.args)
        # print("Anneal Type: {}".format(anneal_list[self.args.anneal]))
        train_loss, train_kl, train_total_loss = self.evaluate(
            self.args, model, self.data.train)

        cur_loss, cur_kl, test_loss = self.evaluate(self.args, model,
                                                    self.data.test)
        Runner.log_eval(self.writer, None, cur_loss, cur_kl, test_loss, True)

        # Embed the final test loss in the saved file names so different
        # runs remain distinguishable on disk.
        os.rename(self.args.save_name + '.model',
                  self.args.save_name + '_' + str(test_loss) + '.model')
        os.rename(self.args.save_name + '.args',
                  self.args.save_name + '_' + str(test_loss) + '.args')

        # Write result to board
        self.write_board(self.args, train_loss, train_kl, train_total_loss,
                         cur_loss, cur_kl, test_loss)
        self.writer.close()
Code example #4
0
File: nvll.py — Project: wangzheallen/vmf_vae_nlp
def main():
    """Entry point: parse arguments, build the dataset and model for the
    selected variant ('nvdm' bag-of-words VAE or 'nvrnn' RNN VAE), optionally
    load pre-trained weights, then run training and final evaluation.
    """
    args = NVLL.argparser.parse_arg()

    if args.model == 'nvdm':
        # --- Bag-of-words document VAE branch ---
        set_seed(args)
        args, writer = set_save_name_log_nvdm(args)
        logging.info("Device Info: {}".format(device))
        print("Current dir {}".format(os.getcwd()))

        from NVLL.data.ng import DataNg
        # from NVLL.model.nvdm import BowVAE
        # For rcv use nvdm_v2?

        # NOTE(review): both branches import the same BowVAE; the comment
        # above suggests 'rcv' may need nvdm_v2 — confirm before changing.
        if args.data_name == '20ng':
            from NVLL.model.nvdm import BowVAE
        elif args.data_name == 'rcv':
            from NVLL.model.nvdm import BowVAE
        else:
            raise NotImplementedError
        from NVLL.framework.train_eval_nvdm import Runner
        # Datarcv_Distvmf_Modelnvdm_Emb400_Hid400_lat50
        data = DataNg(args)
        model = BowVAE(args,
                       vocab_size=data.vocab_size,
                       n_hidden=args.nhid,
                       n_lat=args.lat_dim,
                       n_sample=args.nsample,
                       dist=args.dist)
        # Automatic matching loading
        # strict=False so partially matching checkpoints still load.
        if args.load is not None:
            model.load_state_dict(torch.load(args.load), strict=False)
        else:
            print("Auto load temp closed.")
            """
            files = os.listdir(os.path.join(args.exp_path))
            files = [f for f in files if f.endswith(".model")]
            current_name = "Data{}_Dist{}_Model{}_Emb{}_Hid{}_lat{}".format(args.data_name, str(args.dist),
                                                                                      args.model,
            for f in files:
                if current_name in f:
                    try:
                        model.load_state_dict(torch.load(os.path.join(
                            args.exp_path, f)), strict=False)
                        print("Auto Load success! File Name: {}".format(f))
                        break
                    except RuntimeError:
                        print("Automatic Load failed!")
            """
        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     print("Model in GPU")
        #     model = model.cuda()
        runner = Runner(args, model, data, writer)
        runner.start()
        runner.end()

    elif args.model == 'nvrnn':
        # --- RNN language-model VAE branch ---

        set_seed(args)
        args, writer = set_save_name_log_nvrnn(args)
        logging.info("Device Info: {}".format(device))
        print("Current dir {}".format(os.getcwd()))

        from NVLL.data.lm import DataLM
        from NVLL.model.nvrnn import RNNVAE
        from NVLL.framework.train_eval_nvrnn import Runner
        # 'yelp' carries condition labels; the other corpora do not.
        if (args.data_name == 'ptb') or (args.data_name
                                         == 'trec') or (args.data_name
                                                        == 'yelp_sent'):
            data = DataLM(os.path.join(args.root_path, args.data_path),
                          args.batch_size,
                          args.eval_batch_size,
                          condition=False)
        elif args.data_name == 'yelp':
            data = DataLM(os.path.join(args.root_path, args.data_path),
                          args.batch_size,
                          args.eval_batch_size,
                          condition=True)
        else:
            raise NotImplementedError
        model = RNNVAE(args,
                       args.enc_type,
                       len(data.dictionary),
                       args.emsize,
                       args.nhid,
                       args.lat_dim,
                       args.nlayers,
                       dropout=args.dropout,
                       # NOTE(review): tie_weights is hard-coded False here,
                       # ignoring args.tied — confirm this is intentional.
                       tie_weights=False,
                       input_z=args.input_z,
                       mix_unk=args.mix_unk,
                       condition=(args.cd_bit or args.cd_bow),
                       input_cd_bow=args.cd_bow,
                       input_cd_bit=args.cd_bit)
        # Automatic matching loading
        if args.load is not None:
            model.load_state_dict(torch.load(args.load), strict=False)
        else:
            print("Auto load temp closed.")
            """
            files = os.listdir(os.path.join( args.exp_path))
            files = [f for f in files if f.endswith(".model")]
            current_name = "Data{}_Dist{}_Model{}_Enc{}Bi{}_Emb{}_Hid{}_lat{}".format(args.data_name, str(args.dist),
                                                                                      args.model, args.enc_type,
                                                                                      args.bi,
                                                                                      args.emsize,
                                                                                      args.nhid, args.lat_dim)
            for f in files:
                if current_name in f and ("mixunk0.0" in f):
                    try:
                        model.load_state_dict(torch.load(os.path.join(
                                                                      args.exp_path, f)), strict=False)
                        print("Auto Load success! {}".format(f))
                        break
                    except RuntimeError:
                        print("Automatic Load failed!")
            """
        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     model = model.cuda()
        runner = Runner(args, model, data, writer)
        runner.start()
        runner.end()