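# The snippets below are method- and function-level excerpts; the enclosing
# module's imports are not part of the excerpt. A minimal sketch of what the
# module level presumably pulls in, inferred from the names used below
# (only DataLM's location is confirmed by the imports inside main() further down):
import logging
import os
import random
from operator import itemgetter

import torch

import NVLL.argparser
from NVLL.data.lm import DataLM
# Project-specific helpers referenced below (load_args, load_model, Code2Code,
# GVar, set_seed, set_save_name_log_nvdm, set_save_name_log_nvrnn, device) are
# not shown in this excerpt; their import paths are left unspecified rather
# than guessed here.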
def __init__(self, args):
    self.data = DataLM(os.path.join(args.root_path, args.data_path),
                       args.batch_size,
                       args.eval_batch_size,
                       condition=True)
    # Sort the vocabulary by index so embedding rows line up with their words.
    word_list = sorted(self.data.dictionary.word2idx.items(), key=itemgetter(1))

    # Dump the word embeddings of the pretrained vMF model.
    vmf_args = load_args(args.exp_path, args.model_vmf)
    vmf_model = load_model(vmf_args, len(self.data.dictionary), args.exp_path, args.model_vmf)
    vmf_emb = vmf_model.emb.weight
    self.write_word_embedding(args.exp_path, args.model_vmf + '_emb', word_list, vmf_emb)

    # Dump the word embeddings of the pretrained Gaussian (nor) model.
    nor_args = load_args(args.exp_path, args.model_nor)
    nor_model = load_model(nor_args, len(self.data.dictionary), args.exp_path, args.model_nor)
    nor_emb = nor_model.emb.weight
    self.write_word_embedding(args.exp_path, args.model_nor + '_emb', word_list, nor_emb)
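# write_word_embedding (called above) is not included in this excerpt. A minimal
# sketch of what such a helper might look like, assuming a plain text format of
# one "<word> v1 v2 ... vn" row per vocabulary entry; the file name and format
# here are assumptions, not the project's confirmed implementation.
def write_word_embedding(self, exp_path, fname, word_list, emb_weight):
    emb = emb_weight.data.cpu().numpy()
    with open(os.path.join(exp_path, fname + '.txt'), 'w') as fd:
        for word, idx in word_list:
            vec = ' '.join('{:.6f}'.format(x) for x in emb[idx])
            fd.write('{} {}\n'.format(word, vec))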
def load_data(data_path, eval_batch_size, condition):
    # Evaluation-only loading: the eval batch size is used for both splits.
    data = DataLM(data_path, eval_batch_size, eval_batch_size, condition)
    return data
class CodeLearner():
    def __init__(self, args, condition, c2b, nor):
        self.data = DataLM(os.path.join(args.root_path, args.data_path),
                           args.batch_size,
                           args.eval_batch_size,
                           condition=condition)
        self.c2b = c2b
        # Pick which pretrained model to probe: Gaussian (nor) or vMF.
        if nor:
            args.model_run = args.model_nor
        else:
            args.model_run = args.model_vmf
        self.args = load_args(args.exp_path, args.model_run)
        self.model = load_model(self.args, len(self.data.dictionary),
                                args.exp_path, args.model_run)
        # Small probe mapping between the latent code and the mean word embedding.
        self.learner = Code2Code(self.model.lat_dim, self.model.ninp)
        self.learner.cuda()
        self.optim = torch.optim.Adam(self.learner.parameters(), lr=0.001)

    def run_train(self):
        valid_acc = []
        for e in range(10):
            print("EPO: {}".format(e))
            self.train_epo(self.data.train)
            acc = self.evaluate(self.data.test)
            valid_acc.append(acc)
        return min(valid_acc)

    def train_epo(self, train_batches):
        self.learner.train()
        print("Epo start")
        acc_loss = 0
        cnt = 0
        random.shuffle(train_batches)
        for idx, batch in enumerate(train_batches):
            self.optim.zero_grad()
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                # The first row of a conditional batch carries the condition bit.
                seq_len -= 1
                if self.model.input_cd_bit > 1:
                    bit = batch[0, :]
                    bit = GVar(bit)
                else:
                    bit = None
                batch = batch[1:, :]
            else:
                bit = None
            feed = self.data.get_feed(batch)
            seq_len, batch_sz = feed.size()
            emb = self.model.drop(self.model.emb(feed))
            if self.model.input_cd_bit > 1:
                bit = self.model.enc_bit(bit)
            else:
                bit = None
            h = self.model.forward_enc(emb, bit)
            tup, kld, vecs = self.model.forward_build_lat(h)  # batchsz, lat dim
            # Use the distribution's location parameter as the latent code.
            if self.model.dist_type == 'vmf':
                code = tup['mu']
            elif self.model.dist_type == 'nor':
                code = tup['mean']
            else:
                raise NotImplementedError
            # Bag-of-words representation: mean embedding over time steps.
            emb = torch.mean(emb, dim=0)
            # Note: both branches are identical as written; c2b only labels the direction.
            if self.c2b:
                loss = self.learner(code, emb)
            else:
                loss = self.learner(code, emb)
            loss.backward()
            self.optim.step()
            acc_loss += loss.data[0]
            cnt += 1
            if idx % 400 == 0 and (idx > 0):
                print("Training {}".format(acc_loss / cnt))
                acc_loss = 0
                cnt = 0

    def evaluate(self, dev_batches):
        self.learner.eval()
        print("Test start")
        acc_loss = 0
        cnt = 0
        random.shuffle(dev_batches)
        for idx, batch in enumerate(dev_batches):
            self.optim.zero_grad()
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1
                if self.model.input_cd_bit > 1:
                    bit = batch[0, :]
                    bit = GVar(bit)
                else:
                    bit = None
                batch = batch[1:, :]
            else:
                bit = None
            feed = self.data.get_feed(batch)
            seq_len, batch_sz = feed.size()
            emb = self.model.drop(self.model.emb(feed))
            if self.model.input_cd_bit > 1:
                bit = self.model.enc_bit(bit)
            else:
                bit = None
            h = self.model.forward_enc(emb, bit)
            tup, kld, vecs = self.model.forward_build_lat(h)  # batchsz, lat dim
            if self.model.dist_type == 'vmf':
                code = tup['mu']
            elif self.model.dist_type == 'nor':
                code = tup['mean']
            else:
                raise NotImplementedError
            emb = torch.mean(emb, dim=0)
            if self.c2b:
                loss = self.learner(code, emb)
            else:
                loss = self.learner(code, emb)
            acc_loss += loss.data[0]
            cnt += 1
            # The running average is reset every 400 batches, so the value
            # returned below covers only the final window.
            if idx % 400 == 0 and (idx > 0):
                acc_loss = 0
                cnt = 0
        print(acc_loss / cnt)
        return float(acc_loss / cnt)
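# Code2Code is not defined in this excerpt. Below is a minimal sketch of a
# compatible probe, assuming it is a small regressor trained with an MSE
# objective between the latent code and the mean word embedding; the layer
# sizes, activation, and loss are assumptions, not the project's actual module.
import torch.nn as nn


class Code2Code(nn.Module):
    def __init__(self, code_dim, emb_dim):
        super(Code2Code, self).__init__()
        # Map the latent code into embedding space and score it against the target.
        self.proj = nn.Sequential(
            nn.Linear(code_dim, emb_dim),
            nn.Tanh(),
            nn.Linear(emb_dim, emb_dim),
        )
        self.criterion = nn.MSELoss()

    def forward(self, code, target):
        pred = self.proj(code)
        return self.criterion(pred, target)


# Hypothetical driver for CodeLearner; the argument values are placeholders.
# learner = CodeLearner(args, condition=True, c2b=True, nor=False)
# best_loss = learner.run_train()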
def synthesis_bow_rep(args):
    data = DataLM(os.path.join(args.root_path, args.data_path),
                  args.batch_size,
                  args.eval_batch_size,
                  condition=True)
def main():
    args = NVLL.argparser.parse_arg()

    if args.model == 'nvdm':
        set_seed(args)
        args, writer = set_save_name_log_nvdm(args)
        logging.info("Device Info: {}".format(device))
        print("Current dir {}".format(os.getcwd()))

        from NVLL.data.ng import DataNg
        # from NVLL.model.nvdm import BowVAE   # For rcv use nvdm_v2?
        if args.data_name == '20ng':
            from NVLL.model.nvdm import BowVAE
        elif args.data_name == 'rcv':
            from NVLL.model.nvdm import BowVAE
        else:
            raise NotImplementedError
        from NVLL.framework.train_eval_nvdm import Runner

        # Datarcv_Distvmf_Modelnvdm_Emb400_Hid400_lat50
        data = DataNg(args)
        model = BowVAE(args, vocab_size=data.vocab_size, n_hidden=args.nhid,
                       n_lat=args.lat_dim, n_sample=args.nsample, dist=args.dist)

        # Automatic checkpoint loading
        if args.load is not None:
            model.load_state_dict(torch.load(args.load), strict=False)
        else:
            print("Auto load temp closed.")
            """
            files = os.listdir(os.path.join(args.exp_path))
            files = [f for f in files if f.endswith(".model")]
            current_name = "Data{}_Dist{}_Model{}_Emb{}_Hid{}_lat{}".format(
                args.data_name, str(args.dist), args.model,
                args.emsize, args.nhid, args.lat_dim)
            for f in files:
                if current_name in f:
                    try:
                        model.load_state_dict(torch.load(os.path.join(args.exp_path, f)),
                                              strict=False)
                        print("Auto Load success! File Name: {}".format(f))
                        break
                    except RuntimeError:
                        print("Automatic Load failed!")
            """

        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     print("Model in GPU")
        #     model = model.cuda()
        runner = Runner(args, model, data, writer)
        runner.start()
        runner.end()

    elif args.model == 'nvrnn':
        set_seed(args)
        args, writer = set_save_name_log_nvrnn(args)
        logging.info("Device Info: {}".format(device))
        print("Current dir {}".format(os.getcwd()))

        from NVLL.data.lm import DataLM
        from NVLL.model.nvrnn import RNNVAE
        from NVLL.framework.train_eval_nvrnn import Runner

        if (args.data_name == 'ptb') or (args.data_name == 'trec') or (args.data_name == 'yelp_sent'):
            data = DataLM(os.path.join(args.root_path, args.data_path),
                          args.batch_size,
                          args.eval_batch_size,
                          condition=False)
        elif args.data_name == 'yelp':
            data = DataLM(os.path.join(args.root_path, args.data_path),
                          args.batch_size,
                          args.eval_batch_size,
                          condition=True)
        else:
            raise NotImplementedError

        model = RNNVAE(args, args.enc_type, len(data.dictionary), args.emsize,
                       args.nhid, args.lat_dim, args.nlayers,
                       dropout=args.dropout, tie_weights=False,
                       input_z=args.input_z, mix_unk=args.mix_unk,
                       condition=(args.cd_bit or args.cd_bow),
                       input_cd_bow=args.cd_bow, input_cd_bit=args.cd_bit)

        # Automatic checkpoint loading
        if args.load is not None:
            model.load_state_dict(torch.load(args.load), strict=False)
        else:
            print("Auto load temp closed.")
            """
            files = os.listdir(os.path.join(args.exp_path))
            files = [f for f in files if f.endswith(".model")]
            current_name = "Data{}_Dist{}_Model{}_Enc{}Bi{}_Emb{}_Hid{}_lat{}".format(
                args.data_name, str(args.dist), args.model, args.enc_type, args.bi,
                args.emsize, args.nhid, args.lat_dim)
            for f in files:
                if current_name in f and ("mixunk0.0" in f):
                    try:
                        model.load_state_dict(torch.load(os.path.join(args.exp_path, f)),
                                              strict=False)
                        print("Auto Load success! {}".format(f))
                        break
                    except RuntimeError:
                        print("Automatic Load failed!")
            """

        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     model = model.cuda()
        runner = Runner(args, model, data, writer)
        runner.start()
        runner.end()
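# set_seed, set_save_name_log_nvdm/_nvrnn and the global `device` used in main()
# are project helpers outside this excerpt. A minimal sketch of what set_seed
# presumably does, assuming the parser exposes an integer `args.seed`
# (the exact implementation is an assumption):
def set_seed(args):
    # Seed every RNG the training loop touches so runs are reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)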
def load_data(self, data_path):
    # Rebuild the dataset with this runner's batch sizes; no condition flag is
    # passed here, so DataLM's default is used.
    data = DataLM(data_path, self.args.batch_size, self.args.eval_batch_size)
    return data