def load_model(self, path, name):
    """Rebuild an RNNVAE from this runner's args and load trained weights.

    Args:
        path: Directory containing the saved checkpoint.
        name: Checkpoint base name; '<name>.model' is read from `path`.

    Returns:
        The RNNVAE with restored parameters, moved to GPU when one is available.
    """
    # Reconstruct the network with the exact hyper-parameters used at
    # training time so the state-dict keys and shapes line up.
    model = RNNVAE(self.args, self.args.enc_type, len(self.data.dictionary),
                   self.args.emsize, self.args.nhid, self.args.lat_dim,
                   self.args.nlayers, dropout=self.args.dropout,
                   tie_weights=self.args.tied, input_z=self.args.input_z,
                   mix_unk=self.args.mix_unk,
                   condition=(self.args.cd_bit or self.args.cd_bow),
                   input_cd_bow=self.args.cd_bow,
                   input_cd_bit=self.args.cd_bit)
    model.load_state_dict(torch.load(os.path.join(path, name + '.model')))
    # Fix: the original called .cuda() unconditionally, which raises on
    # CPU-only machines; guard it like the module-level load_model() does.
    if torch.cuda.is_available():
        model = model.cuda()
    return model
def load_model(args, ntoken, path, name):
    """Restore a trained RNNVAE checkpoint for evaluation.

    Args:
        args: Parsed hyper-parameter namespace used at training time.
        ntoken: Vocabulary size for the embedding/output layers.
        path: Directory holding the checkpoint file.
        name: Checkpoint base name; '<name>.model' is read from `path`.

    Returns:
        The restored RNNVAE in eval mode, on GPU when enabled and available.
    """
    use_condition = args.cd_bit or args.cd_bow
    net = RNNVAE(args, args.enc_type, ntoken, args.emsize, args.nhid,
                 args.lat_dim, args.nlayers, dropout=args.dropout,
                 tie_weights=args.tied, input_z=args.input_z,
                 mix_unk=args.mix_unk, condition=use_condition,
                 input_cd_bow=args.cd_bow, input_cd_bit=args.cd_bit)
    checkpoint = os.path.join(path, name + '.model')
    net.load_state_dict(torch.load(checkpoint))
    # Move to GPU only when both the hardware and the project flag allow it.
    from NVLL.util.gpu_flag import GPU_FLAG
    if torch.cuda.is_available() and GPU_FLAG:
        net = net.cuda()
    # eval() returns the module itself, so this hands back `net` in eval mode.
    return net.eval()
def end(self):
    """Wrap up a training run: reload the best checkpoint, score it on the
    train and test splits, log the results, and tag the saved files with
    the final test loss.
    """
    # Load the best saved model.
    # Rebuild the network with the same hyper-parameters used for training
    # so the saved state dict matches.
    model = RNNVAE(self.args, self.args.enc_type, len(self.data.dictionary),
                   self.args.emsize, self.args.nhid, self.args.lat_dim,
                   self.args.nlayers, dropout=self.args.dropout,
                   tie_weights=self.args.tied, input_z=self.args.input_z,
                   mix_unk=self.args.mix_unk,
                   condition=(self.args.cd_bit or self.args.cd_bow),
                   input_cd_bow=self.args.cd_bow,
                   input_cd_bit=self.args.cd_bit)
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(torch.load(self.args.save_name + '.model'),
                          strict=False)
    model.to(device)
    # if torch.cuda.is_available() and GPU_FLAG:
    #     model = model.cuda()
    model = model.eval()  # disable dropout etc. for evaluation
    print(model)
    print(self.args)
    # print("Anneal Type: {}".format(anneal_list[self.args.anneal]))
    train_loss, train_kl, train_total_loss = self.evaluate(
        self.args, model, self.data.train)
    cur_loss, cur_kl, test_loss = self.evaluate(self.args, model,
                                                self.data.test)
    Runner.log_eval(self.writer, None, cur_loss, cur_kl, test_loss, True)
    # Rename the checkpoint and args files so the final test loss becomes
    # part of the file name (e.g. '<save_name>_<test_loss>.model').
    os.rename(self.args.save_name + '.model',
              self.args.save_name + '_' + str(test_loss) + '.model')
    os.rename(self.args.save_name + '.args',
              self.args.save_name + '_' + str(test_loss) + '.args')
    # Write result to board
    self.write_board(self.args, train_loss, train_kl, train_total_loss,
                     cur_loss, cur_kl, test_loss)
    self.writer.close()
def main():
    """Entry point: parse command-line args, then train either the NVDM
    (bag-of-words VAE) branch or the NVRNN (sequence VAE) branch,
    depending on args.model.
    """
    args = NVLL.argparser.parse_arg()
    if args.model == 'nvdm':
        set_seed(args)
        args, writer = set_save_name_log_nvdm(args)
        logging.info("Device Info: {}".format(device))
        print("Current dir {}".format(os.getcwd()))
        from NVLL.data.ng import DataNg
        # from NVLL.model.nvdm import BowVAE
        # For rcv use nvdm_v2?
        # NOTE(review): both dataset branches import the same BowVAE; the
        # split is presumably a placeholder for an rcv-specific model —
        # confirm before collapsing the branches.
        if args.data_name == '20ng':
            from NVLL.model.nvdm import BowVAE
        elif args.data_name == 'rcv':
            from NVLL.model.nvdm import BowVAE
        else:
            raise NotImplementedError
        from NVLL.framework.train_eval_nvdm import Runner
        # Datarcv_Distvmf_Modelnvdm_Emb400_Hid400_lat50
        data = DataNg(args)
        model = BowVAE(args, vocab_size=data.vocab_size, n_hidden=args.nhid,
                       n_lat=args.lat_dim, n_sample=args.nsample,
                       dist=args.dist)
        # Automatic matching loading
        if args.load is not None:
            # strict=False tolerates key mismatches with the checkpoint.
            model.load_state_dict(torch.load(args.load), strict=False)
        else:
            print("Auto load temp closed.")
        # Dead code kept below as a string literal: auto-load a checkpoint
        # whose file name matches the current configuration.
        """ files = os.listdir(os.path.join(args.exp_path)) files = [f for f in files if f.endswith(".model")] current_name = "Data{}_Dist{}_Model{}_Emb{}_Hid{}_lat{}".format(args.data_name, str(args.dist), args.model, for f in files: if current_name in f: try: model.load_state_dict(torch.load(os.path.join( args.exp_path, f)), strict=False) print("Auto Load success! 
File Name: {}".format(f)) break except RuntimeError: print("Automatic Load failed!") """
        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     print("Model in GPU")
        #     model = model.cuda()
        runner = Runner(args, model, data, writer)
        runner.start()
        runner.end()
    elif args.model == 'nvrnn':
        set_seed(args)
        args, writer = set_save_name_log_nvrnn(args)
        logging.info("Device Info: {}".format(device))
        print("Current dir {}".format(os.getcwd()))
        from NVLL.data.lm import DataLM
        from NVLL.model.nvrnn import RNNVAE
        from NVLL.framework.train_eval_nvrnn import Runner
        # Only 'yelp' loads with condition=True; the others are plain LM data.
        if (args.data_name == 'ptb') or (args.data_name == 'trec') or (args.data_name == 'yelp_sent'):
            data = DataLM(os.path.join(args.root_path, args.data_path),
                          args.batch_size,
                          args.eval_batch_size, condition=False)
        elif args.data_name == 'yelp':
            data = DataLM(os.path.join(args.root_path, args.data_path),
                          args.batch_size,
                          args.eval_batch_size, condition=True)
        else:
            raise NotImplementedError
        # NOTE(review): tie_weights is hard-coded to False here, ignoring
        # args.tied (unlike load_model above) — confirm this is intentional.
        model = RNNVAE(args, args.enc_type, len(data.dictionary), args.emsize,
                       args.nhid, args.lat_dim, args.nlayers,
                       dropout=args.dropout, tie_weights=False,
                       input_z=args.input_z, mix_unk=args.mix_unk,
                       condition=(args.cd_bit or args.cd_bow),
                       input_cd_bow=args.cd_bow, input_cd_bit=args.cd_bit)
        # Automatic matching loading
        if args.load is not None:
            model.load_state_dict(torch.load(args.load), strict=False)
        else:
            print("Auto load temp closed.")
        # Dead code kept below as a string literal: auto-load a checkpoint
        # whose name matches the current configuration (mixunk0.0 only).
        """ files = os.listdir(os.path.join( args.exp_path)) files = [f for f in files if f.endswith(".model")] current_name = "Data{}_Dist{}_Model{}_Enc{}Bi{}_Emb{}_Hid{}_lat{}".format(args.data_name, str(args.dist), args.model, args.enc_type, args.bi, args.emsize, args.nhid, args.lat_dim) for f in files: if current_name in f and ("mixunk0.0" in f): try: model.load_state_dict(torch.load(os.path.join( args.exp_path, f)), strict=False) print("Auto Load success! 
{}".format(f)) break except RuntimeError: print("Automatic Load failed!") """
        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     model = model.cuda()
        runner = Runner(args, model, data, writer)
        runner.start()
        runner.end()