def __init__(self, decode_from, params, cuda=False):
    """Build a VAE and restore its weights from a checkpoint for decoding.

    Args:
        decode_from: path to a saved ``state_dict`` checkpoint file.
        params: config namespace; ``train_data`` and ``dec_nh`` are read,
            and ``enc_nh``/``device`` are written onto it.
        cuda: run on GPU when True, otherwise CPU.
    """
    self.decode_from = decode_from
    self.params = params
    # mirror the decoder hidden size onto the encoder (kept from original setup)
    params.enc_nh = params.dec_nh

    params.device = self.device = torch.device("cuda" if cuda else "cpu")

    # the training corpus supplies the vocabulary shared by encoder/decoder
    self.train_data = MonoTextData(params.train_data, label=False)
    self.vocab = self.train_data.vocab
    self.vocab_size = len(self.vocab)

    weight_init = uniform_initializer(0.01)
    embed_init = uniform_initializer(0.1)
    self.encoder = LSTMEncoder(params, self.vocab_size, weight_init, embed_init)
    self.decoder = LSTMDecoder(params, self.vocab, weight_init, embed_init)
    self.vae = VAE(self.encoder, self.decoder, params).to(params.device)

    # checkpoints were written on GPU; remap tensors to CPU when cuda is off
    state = torch.load(self.decode_from, map_location=None if cuda else 'cpu')
    self.vae.load_state_dict(state)
def create_corpus(args):
    """Load the train/val/test corpora, sharing the vocab built from train.

    Returns:
        ``((train, val, test), vocab)``.
    """
    corpus_train = MonoTextData(args.train_data, label=args.label)
    shared_vocab = corpus_train.vocab
    # val/test reuse the training vocabulary so ids line up across splits
    corpus_val = MonoTextData(args.val_data, label=args.label, vocab=shared_vocab)
    corpus_test = MonoTextData(args.test_data, label=args.label, vocab=shared_vocab)

    print('Train data: %d samples' % len(corpus_train))
    print('finish reading datasets, vocab size is %d' % len(shared_vocab))
    print('dropped sentences: %d' % corpus_train.dropped)
    sys.stdout.flush()

    return (corpus_train, corpus_val, corpus_test), shared_vocab
def _load_train_data(self):
    """Read the training corpus, mapping words via the embedding vocabulary.

    Words absent from the embedding vocab fall back to index 0 through a
    defaulting dict, so unknown tokens never raise ``KeyError``.
    """
    class _ZeroDefaultDict(dict):
        # any missing word maps to index 0
        def __missing__(self, key):
            return 0

    emb_vocab = self.bp.emb.vocab
    word2idx = _ZeroDefaultDict({word: emb_vocab[word].index for word in emb_vocab})
    return MonoTextData(self.params.train_data, label=False, vocab=word2idx)
def read_dataset(fn, vocab, device, model, classify_using_samples):
    """Read dataset in file fn with vocab and return (codes, labels)."""
    data = MonoTextData(fn, label=True, vocab=vocab)
    batches, label_batches = data.create_data_batch_labels(batch_size=1,
                                                           device=device,
                                                           batch_first=True)
    with torch.no_grad():
        all_labels, all_codes = [], []
        for idx in np.random.permutation(len(batches)):
            sents = batches[idx]
            raw_labels = [int(x) for x in label_batches[idx]]
            label_tensor = torch.tensor(raw_labels,
                                        dtype=torch.long,
                                        requires_grad=False,
                                        device=device)
            batch_size, sent_len = sents.size()
            # a single encode call serves both modes below
            params, KL = model.encode(sents, 1, return_parameters=True)
            if classify_using_samples:
                # to use samples, we can't simply encode more samples:
                # we need to keep the correspondence between z and x so that
                # we can cross-validate without leaks if we want to resample
                code = torch.cat(params, dim=1)
            else:
                if type(params) == tuple:
                    code = params[0]  # mean
                else:
                    # Gumbel Softmax
                    code = params
            all_labels.append(label_tensor)
            all_codes.append(code)

        codes = torch.cat(all_codes, 0)
        labels = torch.stack(all_labels).reshape((-1, ))
    return (codes, labels)
def main(args):
    """Train (or evaluate, with --eval) a text VAE with optional aggressive
    encoder training: the encoder is burned-in on extra batches until its
    mutual-information estimate stops improving.

    NOTE(review): ``clip_grad``, ``decay_epoch``, ``lr_decay`` and
    ``max_decay`` are read but not defined here — presumably module-level
    constants; confirm in the enclosing file.
    """
    # weight/embedding initializers local to this script
    class uniform_initializer(object):
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')
    print(args)

    # lr-decay bookkeeping: best val loss seen and epochs without improvement
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # val/test share the vocabulary built from the training split
    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    # log roughly 10 times per epoch
    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    if args.enc_type == 'lstm':
        encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device
    vae = VAE(encoder, decoder, args).to(device)

    # evaluation-only path: load checkpoint, report test metrics + IW-NLL
    if args.eval:
        print('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            print("%d active units" % au)
            # print(au_var)
            # importance-weighted NLL is computed one sentence at a time
            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            calc_iwnll(vae, test_data_batch, args)
        return

    enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0,
                              momentum=args.momentum)
    dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0,
                              momentum=args.momentum)
    opt_dict['lr'] = 1.0

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0  # previous epoch's mutual information, for the burn-in stop test
    aggressive_flag = True if args.aggressive else False
    vae.train()
    start = time.time()

    # linear KL annealing from kl_start up to 1.0 over warm_up epochs
    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            # ---- aggressive encoder burn-in: update only the encoder on
            # ---- random batches until the per-word loss stops decreasing
            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:
                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc, kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                # only the encoder steps during burn-in
                enc_optimizer.step()

                # NOTE(review): np.random.random_integers is deprecated and
                # removed in NumPy >= 1.25; np.random.randint(0, len(...))
                # is the drop-in replacement — confirm the pinned NumPy.
                id_ = np.random.random_integers(0, len(train_data_batch) - 1)
                batch_data_enc = train_data_batch[id_]

                # every 15 sub-iters, compare per-word loss with the previous
                # window and stop burning when it no longer improves
                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

            # if sub_iter >= 30:
            #     break
            # print(sub_iter)

            # ---- regular VAE step on the original batch
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            # in aggressive mode the encoder was already updated in burn-in;
            # the decoder always steps
            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_data_batch)
                        au, _ = calc_au(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'au %d, time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))
                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            iter_ += 1

            # once per epoch's worth of iterations: stop aggressive training
            # when validation MI stops improving
            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                print("pre mi:%.4f. cur mi:%.4f" % (pre_mi, cur_mi))
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)

        # ---- end-of-epoch validation
        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            au, au_var = calc_au(vae, val_data_batch)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        # lr decay with reload of the best checkpoint, only after epoch 15
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        print("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
class VAESampler:
    """Wraps a trained text-VAE checkpoint and decodes sentences from latent
    codes ``z`` (beam search, sampling, or greedy), and encodes strings back
    into the latent space.

    ``str2ids`` is an abstract hook: subclasses must implement the
    string-to-word-id conversion for the vocabulary in use.
    """

    def __init__(self, decode_from, params, cuda=False):
        """Build the VAE described by ``params`` and load weights from the
        ``decode_from`` checkpoint (remapped to CPU when ``cuda`` is False)."""
        self.decode_from = decode_from
        self.params = params
        params.enc_nh = params.dec_nh  # not sure why this is necessary...

        # training corpus supplies the vocabulary shared by encoder/decoder
        self.train_data = MonoTextData(params.train_data, label=False)
        self.vocab = self.train_data.vocab
        self.vocab_size = len(self.vocab)

        model_init = uniform_initializer(0.01)
        emb_init = uniform_initializer(0.1)
        params.device = self.device = torch.device("cuda" if cuda else "cpu")
        self.encoder = LSTMEncoder(params, self.vocab_size, model_init, emb_init)
        self.decoder = LSTMDecoder(params, self.vocab, model_init, emb_init)
        self.vae = VAE(self.encoder, self.decoder, params).to(params.device)

        # assuming models were trained on a gpu...
        if cuda:
            self.vae.load_state_dict(torch.load(self.decode_from))
        else:
            self.vae.load_state_dict(
                torch.load(self.decode_from, map_location='cpu'))

    def to_s(self, decoded):
        """Join each decoded token list into a single space-separated string."""
        return [' '.join(item) for item in decoded]

    def beam(self, z, K=5):
        """Beam-search decode latent codes ``z`` with beam width ``K``."""
        decoded_batch = self.vae.decoder.beam_search_decode(z, K)
        return self.to_s(decoded_batch)

    def sample(self, z, temperature=1.0):
        """Sample decodings of ``z`` at the given softmax ``temperature``."""
        decoded_batch = self.vae.decoder.sample_decode(z, temperature)
        return self.to_s(decoded_batch)

    def greedy(self, z):
        """Greedy (argmax) decode latent codes ``z``."""
        decoded_batch = self.vae.decoder.greedy_decode(z)
        return self.to_s(decoded_batch)

    def str2ids(self, s):
        "encode string s as list of word ids"
        # BUGFIX: was `raise NotImplemented`, which raises a TypeError because
        # NotImplemented is a constant, not an exception class.
        raise NotImplementedError

    def encode(self, t):
        """
        Returns (z, mu, log_var) from encoder given list of strings.
        z is a sample from gaussian specified with (mu, log_var)
        """
        str_ids = []
        for s in t:
            ids = self.str2ids(s)
            str_ids.append(ids)
        tensor = self.train_data._to_tensor(str_ids, True, self.device)[0]
        z, (mu, log_var) = self.vae.encoder.sample(tensor, 1)
        return z, mu, log_var

    def z(self, t):
        "return sampled latent zs for list of strings t"
        z, mu, logvar = self.encode(t)
        return z.squeeze(1)

    def mu(self, t):
        "return mean of latent gaussian for list of strings t"
        z, mu, logvar = self.encode(t)
        return mu.squeeze(1)
parser.add_argument('--dec_h_dim', default=1024, type=int) parser.add_argument('--dec_num_layers', default=1, type=int) parser.add_argument('--dec_dropout', default=0.5, type=float) parser.add_argument('--num_nu_updates', default=5, type=int) parser.add_argument('--nu_lr', default=1e-5, type=float) parser.add_argument('--end2end_lr', default=8e-4, type=float) parser.add_argument('--max_grad_norm', default=5.0, type=float) if sys.argv[1:] == ['0', '0']: args = parser.parse_args([]) # run in pycharm console else: args = parser.parse_args() # run in cmd # parameters train_data_all = MonoTextData(args.train_data, label=True) vocab = train_data_all.vocab vocab_size = len(vocab) val_data_all = MonoTextData(args.val_data, label=True, vocab=vocab) test_data_all = MonoTextData(args.test_data, label=True, vocab=vocab) print('Batch size: %d' % args.batch_size) print('Train data: %d sentences' % len(train_data_all)) print('Val data: %d sentences' % len(val_data_all)) print('Test data: %d sentences' % len(test_data_all)) print('finish reading datasets, vocab size is %d' % len(vocab)) print('dropped sentences: %d' % train_data_all.dropped) results_folder = args.results_folder_prefix + args.model + '/' if not os.path.exists(results_folder): os.makedirs(results_folder) logging.basicConfig(filename=os.path.join(results_folder, args.log_prefix+'.log'),
def main(args):
    """Dump the vocabulary next to the training data, load a trained VAE
    checkpoint, and save latent vectors + labels for the train/val/test
    splits via ``save_latents``.
    """
    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)

    # write vocab.txt into the directory containing the training file,
    # one word per line, ordered by id
    vocab_path = os.path.join("/".join(args.train_data.split("/")[:-1]), "vocab.txt")
    with open(vocab_path, "w") as fout:
        for i in range(vocab_size):
            fout.write("{}\n".format(vocab.id2word(i)))
    #return

    # val/test share the training vocabulary so ids line up across splits
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    # computed but unused here; kept for parity with the training script
    log_niter = (len(train_data)//args.batch_size)//10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    # load the trained checkpoint and extract latents without gradients
    print('begin evaluation')
    vae.load_state_dict(torch.load(args.load_path))
    vae.eval()
    with torch.no_grad():
        test_data_batch, test_batch_labels = test_data.create_data_batch_labels(batch_size=args.batch_size,
                                                                                device=device,
                                                                                batch_first=True)
        # test(vae, test_data_batch, "TEST", args)
        # au, au_var = calc_au(vae, test_data_batch)
        # print("%d active units" % au)

        train_data_batch, train_batch_labels = train_data.create_data_batch_labels(batch_size=args.batch_size,
                                                                                   device=device,
                                                                                   batch_first=True)
        val_data_batch, val_batch_labels = val_data.create_data_batch_labels(batch_size=args.batch_size,
                                                                             device=device,
                                                                             batch_first=True)

        print("getting vectors for training")
        print(args.save_dir)
        save_latents(args, vae, train_data_batch, train_batch_labels, "train")
        print("getting vectors for validating")
        save_latents(args, vae, val_data_batch, val_batch_labels, "val")
        print("getting vectors for testing")
        save_latents(args, vae, test_data_batch, test_batch_labels, "test")
def main(args):
    """Train a text VAE with optional aggressive encoder burn-in, while
    recording/plotting posterior-mean vs. inference-mean trajectories
    (``--plot_mode single``) or grid plots (``--plot_mode multiple``).

    NOTE(review): ``clip_grad``, ``decay_epoch``, ``lr_decay`` and
    ``max_decay`` are read but not defined here — presumably module-level
    constants; confirm in the enclosing file.
    """
    # weight/embedding initializers local to this script
    class uniform_initializer(object):
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')
    print(args)

    # lr-decay bookkeeping: best val loss seen and epochs without improvement
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # val/test share the vocabulary built from the training split
    train_data = MonoTextData(args.train_data)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, vocab=vocab)
    test_data = MonoTextData(args.test_data, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    # log roughly 10 times per epoch
    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
    args.enc_nh = args.dec_nh

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args).to(device)

    if args.optim == 'sgd':
        enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0)
        dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0)
        opt_dict['lr'] = 1.0
    else:
        enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001,
                                   betas=(0.9, 0.999))
        dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001,
                                   betas=(0.9, 0.999))
        opt_dict['lr'] = 0.001

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = -1  # previous mutual information, for the burn-in stop test
    aggressive_flag = True if args.aggressive else False
    vae.train()
    start = time.time()

    # linear KL annealing from kl_start up to 1.0 over warm_up epochs
    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))

    # fixed sample of sentences used for all plotting
    plot_data = train_data.data_sample(nsample=args.num_plot, device=device,
                                       batch_first=True)

    if args.plot_mode == 'multiple':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_multiple
    elif args.plot_mode == 'single':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_single
        # trajectories of the true model posterior mean vs. the inference
        # network mean, recorded every step of epoch 0
        posterior_mean = []
        infer_mean = []
        posterior_mean.append(
            vae.calc_model_posterior_mean(plot_data[0], grid_z))
        infer_mean.append(vae.calc_infer_mean(plot_data[0]))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            # in 'single' mode the model is trained on the fixed plot batch
            if args.plot_mode == "single":
                batch_data, _ = plot_data
            else:
                batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            # ---- aggressive encoder burn-in: update only the encoder until
            # ---- the per-word loss stops decreasing
            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:
                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc, kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                # only the encoder steps during burn-in
                enc_optimizer.step()

                if args.plot_mode == "single":
                    batch_data_enc, _ = plot_data
                else:
                    # NOTE(review): np.random.random_integers is deprecated
                    # and removed in NumPy >= 1.25; np.random.randint is the
                    # replacement — confirm the pinned NumPy version.
                    id_ = np.random.random_integers(0, len(train_data_batch) - 1)
                    batch_data_enc = train_data_batch[id_]

                # every 15 sub-iters, stop burning once per-word loss
                # no longer improves over the previous window
                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

            # during burn-in the model posterior is unchanged, so repeat the
            # last posterior-mean point and record only the inference mean
            if args.plot_mode == 'single' and epoch == 0 and aggressive_flag:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(posterior_mean[-1])
                    infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            # ---- regular VAE step on the original batch
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            # in aggressive mode the encoder was already updated in burn-in;
            # the decoder always steps
            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            # after the decoder update the model posterior has moved; the
            # inference mean only moved if the encoder stepped too
            if args.plot_mode == 'single' and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(
                        vae.calc_model_posterior_mean(plot_data[0], grid_z))
                    if aggressive_flag:
                        infer_mean.append(infer_mean[-1])
                    else:
                        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    mi = calc_mi(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))
                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            # periodic plotting during epoch 0; 'single' mode plots once and
            # exits the whole run
            if iter_ % args.plot_niter == 0 and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    if args.plot_mode == 'single' and iter_ != 0:
                        plot_fn(infer_mean, posterior_mean, args)
                        return
                    elif args.plot_mode == "multiple":
                        plot_fn(vae, plot_data, grid_z, iter_, args)
                vae.train()

            iter_ += 1

            # once per epoch's worth of iterations: stop aggressive training
            # when validation MI stops improving
            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi
                # return

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        # NOTE(review): this call uses the 'multiple'-mode signature; in
        # 'single' mode plot_fn is plot_single — confirm intended here.
        with torch.no_grad():
            plot_fn(vae, plot_data, grid_z, iter_, args)

        # ---- end-of-epoch validation
        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl = test(vae, val_data_batch, "VAL", args)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        # lr decay with reload of the best checkpoint
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.optim == 'sgd':
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"])
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"])
                else:
                    enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
                    dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl = test(vae, test_data_batch, "TEST", args)

        vae.train()

    print('best_loss: %.4f, kl: %.4f, nll: %.4f, ppl: %.4f' \
          % (best_loss, best_kl, best_nll, best_ppl))
    sys.stdout.flush()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()

    # importance-weighted NLL is computed one sentence at a time
    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
def main(args):
    """Train a plain LSTM language model, with checkpointing on best
    validation NLL, lr decay on plateau, and a final test evaluation of the
    best checkpoint.

    Fixes vs. the previous revision:
    - the lr-decay branch stored the stale training-batch ``loss`` tensor in
      ``opt_dict["best_loss"]`` instead of the validation ``nll`` being
      tracked (the non-decay branch already stored ``nll``);
    - the final test now runs under ``lm.eval()`` on every exit path (before,
      a normal loop exit left the model in train mode, i.e. dropout active).

    NOTE(review): ``max_decay`` is read but not defined here — presumably a
    module-level constant; confirm in the enclosing file.
    """
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=["text_cyc_anneal.py"])

    if args.cuda:
        logging('using cuda')
    logging('model saving path: %s' % args.save_path)
    logging(str(args))

    # lr-decay bookkeeping: best val nll seen and epochs without improvement
    opt_dict = {"not_improved": 0, "lr": args.lr, "best_loss": 1e4}

    # val/test share the vocabulary built from the training split
    train_data = MonoTextData(args.train_data)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, vocab=vocab)
    test_data = MonoTextData(args.test_data, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    lm = LSTM_LM(args, vocab, model_init, emb_init).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        lm.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    if args.opt == "sgd":
        optimizer = optim.SGD(lm.parameters(), lr=args.lr,
                              momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(lm.parameters(), lr=args.lr)
        opt_dict['lr'] = args.lr
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_nll = best_ppl = 0
    lm.train()
    start = time.time()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    for epoch in range(args.epochs):
        report_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size

            optimizer.zero_grad()

            loss = lm.reconstruct_error(batch_data)

            report_loss += loss.sum().item()
            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(lm.parameters(), args.clip_grad)

            optimizer.step()

            if iter_ % args.log_niter == 0:
                train_loss = report_loss / report_num_sents
                logging('epoch: %d, iter: %d, avg_loss: %.4f, time elapsed %.2fs' %
                        (epoch, iter_, train_loss, time.time() - start))
                sys.stdout.flush()

            iter_ += 1

        # periodic test-set report
        if epoch % args.test_nepoch == 0:
            #logging('epoch: %d, testing' % epoch)
            lm.eval()
            with torch.no_grad():
                nll, ppl = test(lm, test_data_batch, args)
                logging('test | epoch: %d, nll: %.4f, ppl: %.4f' % (epoch, nll, ppl))
            lm.train()

        # ---- end-of-epoch validation; best checkpoint tracked by val nll
        lm.eval()
        with torch.no_grad():
            nll, ppl = test(lm, val_data_batch, args)
            logging('valid | epoch: %d, nll: %.4f, ppl: %.4f' % (epoch, nll, ppl))

        if nll < best_loss:
            logging('update best loss')
            best_loss = nll
            best_nll = nll
            best_ppl = ppl
            torch.save(lm.state_dict(), args.save_path)

        # lr decay with reload of the best checkpoint
        if nll > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= args.decay_epoch:
                # BUGFIX: was `opt_dict["best_loss"] = loss`, storing the last
                # training batch's loss tensor instead of the tracked val nll
                opt_dict["best_loss"] = nll
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * args.lr_decay
                lm.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(lm.parameters(), lr=opt_dict["lr"],
                                          momentum=args.momentum)
                elif args.opt == "adam":
                    optimizer = optim.Adam(lm.parameters(), lr=opt_dict["lr"])
                else:
                    raise ValueError("optimizer not supported")
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = nll

        if decay_cnt == max_decay:
            break

        lm.train()

    logging('valid | best_loss: %.4f, nll: %.4f, ppl: %.4f' \
            % (best_loss, best_nll, best_ppl))

    # reload best lm model and report final test metrics
    lm.load_state_dict(torch.load(args.save_path))
    lm.eval()  # BUGFIX: ensure dropout is off regardless of loop exit path
    with torch.no_grad():
        nll, ppl = test(lm, test_data_batch, args)
    logging('test | nll: %.4f, ppl: %.4f' % (nll, ppl))
def main(args):
    """Train a beta-weighted Gaussian-LSTM VAE (fixed ``kl_weight =
    args.beta``), with eval-only and reconstruct-only entry paths, Ctrl-C
    early exit, and a final importance-weighted NLL estimate.

    NOTE(review): ``clip_grad``, ``decay_epoch``, ``lr_decay``, ``max_decay``
    and ``ns`` (used by ``loss_iw``) are read but not defined here —
    presumably module-level globals; confirm in the enclosing file.
    """
    global logging
    debug = (args.reconstruct_from != "" or args.eval == True)  # don't make exp dir for reconstruction
    logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    # lr-decay bookkeeping: best val loss seen and epochs without improvement
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # val/test share the vocabulary built from the training split
    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    # log roughly 10 times per epoch
    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

        # optionally re-initialize the decoder while keeping the encoder
        if args.reset_dec:
            vae.decoder.reset_parameters(model_init, emb_init)

    # evaluation-only path: test metrics, active units, IW-NLL
    if args.eval:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            logging("%d active units" % au)
            # print(au_var)
            # importance-weighted NLL is computed one sentence at a time
            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            nll, ppl = calc_iwnll(vae, test_data_batch, args)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
        return

    # reconstruction-only path: decode the test set with a given checkpoint
    if args.reconstruct_from != "":
        print("begin decoding")
        sys.stdout.flush()
        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            # test(vae, test_data_batch, "TEST", args)
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy,
                        args.reconstruct_to)
        return

    if args.opt == "sgd":
        enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=args.lr,
                                  momentum=args.momentum)
        dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=args.lr,
                                  momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
        dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    vae.train()
    start = time.time()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_words = report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):
                batch_data = train_data_batch[i]
                batch_size, sent_len = batch_data.size()

                # not predict start symbol
                report_num_words += (sent_len - 1) * batch_size
                report_num_sents += batch_size

                # constant beta weighting (no annealing in this script)
                kl_weight = args.beta

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                # standard ELBO, or importance-weighted objective when
                # iw_train_nsamples is set
                if args.iw_train_nsamples < 0:
                    loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight,
                                                      nsamples=args.nsamples)
                else:
                    loss, loss_rc, loss_kl = vae.loss_iw(
                        batch_data, kl_weight,
                        nsamples=args.iw_train_nsamples, ns=ns)
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                loss_rc = loss_rc.sum()
                loss_kl = loss_kl.sum()

                enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.item()
                report_kl_loss += loss_kl.item()
                report_loss += loss.item() * batch_size

                if iter_ % log_niter == 0:
                    #train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                            'time elapsed %.2fs, kl_weight %.4f' %
                            (epoch, iter_, train_loss,
                             report_kl_loss / report_num_sents,
                             report_rec_loss / report_num_sents,
                             time.time() - start, kl_weight))
                    #sys.stdout.flush()

                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_words = report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)

            # ---- end-of-epoch validation
            vae.eval()
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)
                # print(au_var)

            # keep per-epoch checkpoints for the first save_ckpt epochs
            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(
                    vae.state_dict(),
                    os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                best_nll = nll
                best_kl = kl
                best_ppl = ppl
                torch.save(vae.state_dict(), args.save_path)

            # lr decay with reload of the best checkpoint, only after
            # load_best_epoch epochs
            if loss > opt_dict["best_loss"]:
                opt_dict["not_improved"] += 1
                if opt_dict[
                        "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    # NOTE(review): decay always switches to SGD even when
                    # args.opt == "adam" — confirm this is intended.
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if epoch % args.test_nepoch == 0:
                with torch.no_grad():
                    loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)

            # dump latent visualizations for the first save_latent epochs
            if args.save_latent > 0 and epoch <= args.save_latent:
                visualize_latent(args, epoch, vae, "cuda", test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))

    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        logging("%d active units" % au)
        # print(au_var)

    # importance-weighted NLL is computed one sentence at a time
    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        nll, ppl = calc_iwnll(vae, test_data_batch, args)
        logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
def main(args):
    """Train a discriminator (linear or MLP) on top of a pre-trained VAE
    encoder for sentence classification.

    Builds the VAE from an explicit vocab file, optionally restores its
    weights from ``args.load_path``, then trains only the discriminator with
    gradient accumulation (``args.update_every``), saving the best model (by
    validation loss) to ``args.save_path`` and decaying the lr on plateaus.

    Relies on module-level globals: ``logging`` (rebound here), ``clip_grad``,
    ``decay_epoch``, ``lr_decay``, ``max_decay``, and the helper ``test``.
    """
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    # Plateau-based lr-decay bookkeeping.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # Vocabulary file: one token per line, id = line index.
    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i
    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)

    # Log roughly 10 times per epoch (counted in optimizer updates).
    log_niter = max(1, (len(train_data) //
                        (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = "cuda" if args.cuda else "cpu"
    args.device = device

    # fb == 3 selects the delta-VAE encoder; otherwise a Gaussian LSTM.
    if args.fb == 3:
        encoder = DeltaGaussianLSTMEncoder(args, vocab_size, model_init,
                                           emb_init)
        args.enc_nh = args.dec_nh
    elif args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    # NOTE(review): no `else` clause -- an unknown args.discriminator value
    # leaves `discriminator` unbound and crashes below with NameError.
    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args, vae.encoder).to(device)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        # NOTE(review): the adam path ignores args.lr and hard-codes 0.001.
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    discriminator.train()
    start = time.time()

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(
        batch_size=args.batch_size, device=device, batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    # Gradient-accumulation state: acc_cnt counts batches since the last
    # optimizer step; acc_loss / acc_batch_size accumulate across them.
    acc_cnt = 1
    acc_loss = 0.

    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_words = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()

        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            # Skip singleton batches.
            if batch_data.size(0) < 2:
                continue
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels,
                                        dtype=torch.long,
                                        requires_grad=False,
                                        device=device)

            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance(batch_data,
                                                          batch_labels)

            acc_loss = acc_loss + loss.sum()

            if acc_cnt % args.update_every == 0:
                # One optimizer step per `update_every` accumulated batches.
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                               clip_grad)
                optimizer.step()
                optimizer.zero_grad()
                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0
            acc_cnt += 1

            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                # NOTE(review): the concatenated format string is missing a
                # space between "acc %.4f," and "time" in the emitted log line.
                train_loss = report_loss / report_num_sents
                logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f,' \
                        'time %.2fs' %
                        (epoch, iter_, train_loss,
                         report_correct / report_num_sents,
                         time.time() - start))

            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))
        print(report_num_sents)

        discriminator.eval()
        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch,
                             "VAL", args)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc
            print(args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict[
                    "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                # Plateau: reload best weights, decay lr, rebuild optimizer.
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(),
                                           lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch,
                                 test_labels_batch, "TEST", args)

        discriminator.train()

    # Final TEST evaluation with the best checkpoint.
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()
    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch,
                         "TEST", args)
def main(args, args_model):
    """Few-shot probing of a pre-trained VAE's latent space.

    Rebuilds the VAE described by ``args_model``, restores its weights from
    ``args.load_path``, encodes the labelled test split, then — for each
    few-shot training subset matching
    ``<train_data>.seed_*.n_<num_label_per_class>`` — cross-validates a
    logistic-regression classifier on the latent codes (GridSearchCV over C)
    and reports macro-F1 and log-loss on the held-out test codes.  Results
    are dumped as JSON into ``args.exp_dir``.
    """
    global logging
    logging = get_logger_existing_dir(os.path.dirname(args.load_path),
                                      'log_classifier.txt')
    if args.cuda:
        logging('using cuda')
    logging(str(args))
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # Optional fixed vocabulary: one token per line, id = line index.
    vocab = {}
    if getattr(args, 'vocab_file', None) is not None:
        with open(args.vocab_file) as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i
        vocab = VocabEntry(vocab)

    # Few-shot training subsets generated for several seeds.
    filename_glob = args.train_data + '.seed_*.n_' + str(
        args.num_label_per_class)
    train_sets = glob.glob(filename_glob)
    print("Train sets:", train_sets)

    main_train_data = MonoTextData(args.train_data,
                                   label=args.label,
                                   vocab=vocab)
    vocab = main_train_data.vocab
    vocab_size = len(vocab)
    logging('finish reading datasets, vocab size is %d' % len(vocab))

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)
    device = "cuda" if args.cuda else "cpu"
    args_model.device = device

    if args_model.enc_type == 'lstm':
        args_model.pooling = getattr(args_model, 'pooling', None)
        encoder = GaussianLSTMEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            pooling=args_model.pooling,
        )
    elif args_model.enc_type in ['max_avg_pool', 'max_pool', 'avg_pool']:
        args_model.skip_first_word = getattr(args_model, 'skip_first_word',
                                             None)
        encoder = GaussianPoolEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            enc_type=args_model.enc_type,
            skip_first_word=args_model.skip_first_word)
    else:
        raise ValueError("the specified encoder type is not supported")

    args_model.encode_length = getattr(args_model, 'encode_length', None)
    if args_model.dec_type == 'lstm':
        decoder = LSTMDecoder(args_model, vocab, model_init, emb_init,
                              args_model.encode_length)
    elif args_model.dec_type == 'unigram':
        decoder = UnigramDecoder(args_model, vocab, model_init, emb_init)
    else:
        # BUG FIX: an unknown dec_type used to fall through silently and
        # crash later with NameError on `decoder`.
        raise ValueError("the specified decoder type is not supported")

    vae = VAE(encoder, decoder, args_model,
              args_model.encode_length).to(device)

    if args.load_path:
        print("load args!")
        print(vae)
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)
    vae.eval()

    def preprocess(data_fn):
        """Encode file `data_fn`; return (codes, labels) as numpy arrays."""
        codes, labels = read_dataset(data_fn, vocab, device, vae,
                                     args.classify_using_samples)
        if args.classify_using_samples:
            # Gaussian encoders return (mu, logvar) concatenated: 2*nz wide.
            is_gaussian_enc = codes.shape[1] == (vae.encoder.nz * 2)
            # use only 1 sample for test
            codes = augment_dataset(codes, 1, is_gaussian_enc, vae)
        codes = codes.cpu().numpy()
        labels = labels.cpu().numpy()
        return codes, labels

    test_codes, test_labels = preprocess(args.test_data)

    average_f1 = 'macro'
    f1_scorer = make_scorer(f1_score,
                            average=average_f1,
                            labels=np.unique(test_labels),
                            greater_is_better=True)
    # log loss: negative log likelihood. We should minimize that, so greater_is_better=False
    log_loss_scorer = make_scorer(log_loss,
                                  needs_proba=True,
                                  greater_is_better=False)
    warnings.filterwarnings('ignore')

    results = {
        'n_samples_per_class': args.num_label_per_class,
    }
    n_repeats = args.n_repeats
    n_splits = min(args.num_label_per_class, 5)

    for i, fn in enumerate(train_sets):
        codes, labels = preprocess(fn)
        if args.resample > 1:
            # going to augment the training set by sampling
            # then create a new cross validation function to get the correct indices
            cross_val = augment_cross_val(labels, args.resample, n_splits,
                                          n_repeats)
            labels = np.repeat(labels, args.resample)
        else:
            cross_val = RepeatedStratifiedKFold(n_splits=n_splits,
                                                n_repeats=n_repeats)

        # Standardize with train statistics only; apply the same transform
        # to the test codes.
        scaler = StandardScaler()
        codes = scaler.fit_transform(codes)
        scaled_test_codes = scaler.transform(test_codes)

        gridsearch = GridSearchCV(
            LogisticRegression(solver='sag', multi_class='auto'),
            {
                "penalty": ['l2'],
                "C": [0.01, 0.1, 1, 10, 100],
            },
            cv=cross_val,
            scoring={
                "f1": f1_scorer,
                "log": log_loss_scorer,
            },
            refit=False,
        )
        clf = gridsearch
        clf.fit(codes, labels)

        crossval_f1, test_f1 = refit_and_eval(
            'f1',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            f1_scorer,
        )
        crossval_log, test_log_loss = refit_and_eval(
            'log',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            log_loss_scorer,
        )
        results[i] = {
            "F1": {
                'crossval': crossval_f1,
                'test': test_f1
            },
            "log": {
                'crossval': crossval_log,
                'test': test_log_loss
            },
        }
        print(results[i])

    # BUG FIX: n_per_class was previously assigned only inside the
    # `classify_using_samples` branch, so the plain-codes path crashed with
    # NameError when building its output filename.
    n_per_class = str(args.num_label_per_class)
    if args.classify_using_samples:
        resample = 1 if args.resample == -1 else args.resample
        output_fn = os.path.join(
            args.exp_dir,
            'results_sample_' + str(resample) + '_' + n_per_class + '.json')
    else:
        output_fn = os.path.join(args.exp_dir,
                                 'results_' + n_per_class + '.json')
    with open(output_fn, 'w') as f:
        json.dump(results, f)
def _load_train_data(self):
    """Read the unlabelled training corpus configured in ``self.params``."""
    return MonoTextData(self.params.train_data, label=False)
def main(args, args_model):
    """Train (or evaluate) a text VAE.

    Dispatches on the evaluation flags (``--eval``, ``--eval_iw_elbo``,
    ``--eval_valid_elbo``, ``--study_pooling``, ``--export_avg_loss_per_ts``,
    ``--reconstruct_from``): each such branch runs its analysis and returns
    early.  Otherwise runs the KL-annealed / free-bits training loop, keeps
    the best checkpoint (by validation loss) at ``args.save_path``, decays
    the lr on plateaus, and ends with a TEST-set evaluation.

    Relies on module-level globals: ``logging`` (rebound here), ``clip_grad``,
    ``decay_epoch``, ``lr_decay``, ``max_decay``, plus helpers such as
    ``create_model``, ``test``, ``calc_au``, ``calc_iwnll``.
    """
    global logging
    # Evaluation-only invocations must not create a fresh experiment dir.
    eval_mode = (args.reconstruct_from != "" or args.eval or args.eval_iw_elbo
                 or args.eval_valid_elbo or args.export_avg_loss_per_ts
                 or args.study_pooling)
    logging = create_exp_dir(args.exp_dir,
                             scripts_to_save=None,
                             debug=eval_mode)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    # Plateau-based lr-decay bookkeeping.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # Optional fixed vocabulary: one token per line, id = line index.
    vocab = {}
    if getattr(args, 'vocab_file', None):
        with open(args.vocab_file, 'r', encoding='utf-8') as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i
        vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)

    # Log roughly 10 times per epoch.
    log_niter = max((len(train_data) // args.batch_size) // 10, 1)

    device = torch.device("cuda" if args.cuda else "cpu")
    vae = create_model(vocab, args, args_model, logging, eval_mode)

    if args.eval:
        # NOTE(review): this branch computes IW-NLL on the *validation* split
        # and does not load args.load_path explicitly (presumably create_model
        # already restored weights in eval_mode) -- confirm both are intended.
        logging('begin evaluation')
        vae.eval()
        with torch.no_grad():
            test_data_batch = val_data.create_data_batch(batch_size=1,
                                                         device=device,
                                                         batch_first=True)
            nll, ppl = calc_iwnll(vae, test_data_batch, args, ns=250)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
        return

    if args.eval_iw_elbo:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            nll, ppl = calc_iw_elbo(vae, test_data_batch, args)
            logging('iw ELBo: %.4f, iw PPL*: %.4f' % (nll, ppl))
        return

    if args.eval_valid_elbo:
        logging('begin evaluation on validation set')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        val_data_batch = val_data.create_data_batch(
            batch_size=args.batch_size, device=device, batch_first=True)
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            logging('nll: %.4f, iw ppl: %.4f' % (nll, ppl))
        return

    if args.study_pooling:
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            data_batch = train_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            model_dir = os.path.dirname(args.load_path)
            archive_npy = os.path.join(model_dir, 'pooling.npy')
            random.shuffle(data_batch)
            logs = study_pooling(vae, data_batch, args, min_doc_size=4)
            logs['exp_dir'] = model_dir
            np.save(archive_npy, logs)
        return

    if args.export_avg_loss_per_ts:
        print("MODEL")
        print(vae)
        export_avg_loss_per_ts(
            vae,
            train_data,
            device,
            args.batch_size,
            args.load_path,
            args.export_avg_loss_per_ts,
        )
        return

    if args.reconstruct_from != "":
        print("begin decoding")
        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            if args.reconstruct_add_labels_to_source:
                test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
                    batch_size=args.reconstruct_batch_size,
                    device=device,
                    batch_first=True,
                    deterministic=True)
                c = list(zip(test_data_batch, test_labels_batch))
                test_data_batch, test_labels_batch = zip(*c)
            else:
                test_data_batch = test_data.create_data_batch(
                    batch_size=args.reconstruct_batch_size,
                    device=device,
                    batch_first=True)
                test_labels_batch = None
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy,
                        args.reconstruct_to, test_labels_batch,
                        args.reconstruct_max_examples,
                        args.force_absolute_length, args.no_unk)
        return

    # ---------------- optimizer setup ----------------
    if args.freeze_encoder_exc:
        # Only the encoder's final linear projection stays trainable.
        assert args.enc_type == 'lstm'
        enc_params = vae.encoder.linear.parameters()
    else:
        enc_params = vae.encoder.parameters()
    dec_params = vae.decoder.parameters()

    if args.opt not in ('sgd', 'adam'):
        raise ValueError("optimizer not supported")

    def optimizer_fn_(params):
        # BUG FIX: the old code forwarded `momentum=args.momentum` to both
        # optimizer classes, but torch.optim.Adam has no `momentum` argument
        # and raised TypeError.  Momentum now only goes to SGD.
        if args.opt == 'sgd':
            return optim.SGD(params, lr=args.lr, momentum=args.momentum)
        return optim.Adam(params, lr=args.lr)

    enc_optimizer = optimizer_fn_(enc_params)
    dec_optimizer = optimizer_fn_(dec_params)

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0

    vae.train()
    start = time.time()

    # Linear KL annealing from kl_start up to 1.0 over `warm_up` epochs.
    kl_weight = args.kl_start
    if args.warm_up > 0:
        anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                               (len(train_data) /
                                                args.batch_size))
    else:
        anneal_rate = 0

    # Per-dimension KL floor, used by the fb == 2 (free-bits) mode.
    dim_target_kl = args.target_kl / float(args.nz)

    train_data_batch = train_data.create_data_batch(
        batch_size=args.batch_size, device=device, batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_words = report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):
                batch_data = train_data_batch[i]
                batch_size, sent_len = batch_data.size()

                # not predict start symbol
                report_num_words += (sent_len - 1) * batch_size
                report_num_sents += batch_size

                kl_weight = min(1.0, kl_weight + anneal_rate)

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                if args.fb == 0:
                    # Plain annealed ELBO.
                    loss, loss_rc, loss_kl = vae.loss(batch_data,
                                                      kl_weight,
                                                      nsamples=args.nsamples)
                elif args.fb == 1:
                    # Sentence-level free bits: the KL term only contributes
                    # where it exceeds target_kl.
                    loss, loss_rc, loss_kl = vae.loss(batch_data,
                                                      kl_weight,
                                                      nsamples=args.nsamples,
                                                      sum_over_len=False)
                    kl_mask = (loss_kl > args.target_kl).float()
                    loss_rc = loss_rc.sum(-1)
                    loss = loss_rc + kl_mask * kl_weight * loss_kl
                elif args.fb == 2:
                    # Dimension-wise free bits on the Gaussian posterior.
                    mu, logvar = vae.encoder(batch_data)
                    z = vae.encoder.reparameterize(mu, logvar, args.nsamples)
                    loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
                    kl_mask = (loss_kl > dim_target_kl).float()
                    fake_loss_kl = (kl_mask * loss_kl).sum(dim=1)
                    loss_rc = vae.decoder.reconstruct_error(batch_data,
                                                            z).mean(dim=1)
                    loss = loss_rc + kl_weight * fake_loss_kl
                else:
                    # Previously an unhandled fb value crashed further down
                    # with NameError on `loss`; fail fast and clearly instead.
                    raise ValueError("unsupported free-bits mode: %d" %
                                     args.fb)

                loss = loss.mean(dim=-1)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                loss_rc = loss_rc.sum()
                loss_kl = loss_kl.sum()

                if not args.freeze_encoder:
                    enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.item()
                report_kl_loss += loss_kl.item()
                report_loss += loss_rc.item() + loss_kl.item()

                if iter_ % log_niter == 0:
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                            'time %.2fs, kl_weight %.4f' %
                            (epoch, iter_, train_loss,
                             report_kl_loss / report_num_sents,
                             report_rec_loss / report_num_sents,
                             time.time() - start, kl_weight))
                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_words = report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)
            logging('lr {}'.format(opt_dict["lr"]))

            vae.eval()
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL",
                                              args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)

            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(
                    vae.state_dict(),
                    os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                best_nll = nll
                best_kl = kl
                best_ppl = ppl
                torch.save(vae.state_dict(), args.save_path)

            if loss > opt_dict["best_loss"]:
                opt_dict["not_improved"] += 1
                if opt_dict[
                        "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                    # Plateau: reload the best weights and decay the lr.
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    # NOTE(review): after a decay the optimizers are rebuilt
                    # as SGD regardless of args.opt -- preserved from the
                    # original code; confirm it is intended.
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if args.save_latent > 0 and epoch <= args.save_latent:
                visualize_latent(args, epoch, vae, "cuda", test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # Final TEST evaluation with the best checkpoint.
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
def main(args, args_model):
    """Train a discriminator (linear or MLP) on the full labelled training
    set, on top of a frozen pre-trained VAE encoder.

    Unlike the few-shot variant, the best weights are kept in memory
    (``best_discriminator``) rather than written to disk.

    Relies on module-level globals: ``logging`` (rebound here), ``clip_grad``,
    ``decay_epoch``, ``lr_decay``, ``max_decay``, and helpers
    ``create_model`` and ``test``.
    """
    global logging
    log_fn = 'log_classifier_full_data_' + args.discriminator + '.txt'
    logging = get_logger_existing_dir(os.path.dirname(args.load_path), log_fn)
    if args.cuda:
        logging('using cuda')
    logging(str(args))

    # Plateau-based lr-decay bookkeeping.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # Optional fixed vocabulary: one token per line, id = line index.
    vocab = {}
    if getattr(args, 'vocab_file', None):
        with open(args.vocab_file) as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i
        vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)

    # Log roughly 10 times per epoch (counted in optimizer updates).
    log_niter = max(1, (len(train_data) //
                        (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = "cuda" if args.cuda else "cpu"

    # The VAE is only used as a frozen feature extractor.
    vae = create_model(vocab, args, args_model, logging, eval_mode=True)
    vae.eval()

    args_model.ncluster = train_data.n_unique_labels
    print("Number of targets:", args_model.ncluster)

    # NOTE(review): no `else` clause -- an unknown args.discriminator value
    # leaves `discriminator` unbound and crashes below with NameError.
    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args_model, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args_model, vae.encoder).to(device)
    print("Discriminator:")
    print(discriminator)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=args.lr)
        opt_dict['lr'] = args.lr
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    # NOTE(review): best_kl/best_nll/best_ppl/pre_mi and the KL-annealing
    # state below (kl_weight, anneal_rate, dim_target_kl) are never used in
    # this classifier loop -- apparent leftovers from the VAE training script.
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    discriminator.train()
    start = time.time()

    kl_weight = args.kl_start
    if args.warm_up > 0:
        anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                               (len(train_data) /
                                                args.batch_size))
    else:
        anneal_rate = 0

    dim_target_kl = args.target_kl / float(args.nz)

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(
        batch_size=args.batch_size, device=device, batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    # Gradient-accumulation state (one step per args.update_every batches).
    acc_cnt = 1
    acc_loss = 0.

    for epoch in range(args.epochs):
        report_loss = 0
        report_num_words = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()

        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_labels = train_labels_batch[i]
            # Labels may be 1-indexed in the corpus; shift to 0-indexed.
            if args.one_indexed_labels:
                batch_labels = [int(x) - 1 for x in batch_labels]
            else:
                batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels,
                                        dtype=torch.long,
                                        requires_grad=False,
                                        device=device)

            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, _ = discriminator.get_performance(batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

            if acc_cnt % args.update_every == 0:
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                               clip_grad)
                optimizer.step()
                optimizer.zero_grad()
                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0
            acc_cnt += 1

            report_loss += loss.sum().item()

            if iter_ % log_niter == 0:
                train_loss = report_loss / report_num_sents
                logging('epoch: %d, iter: %d, avg_loss: %.4f, ' \
                        'time %.2fs' %
                        (epoch, iter_, train_loss, time.time() - start))

            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))

        discriminator.eval()
        with torch.no_grad():
            loss, macro_f1 = test(discriminator, val_data_batch,
                                  val_labels_batch, "VAL", args)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = macro_f1
            # Best weights are kept in memory instead of on disk.
            best_discriminator = discriminator.state_dict()

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict[
                    "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                # Plateau: restore best weights, decay lr, rebuild optimizer.
                # NOTE(review): `best_discriminator` is unbound if no epoch
                # ever improved on the initial best_loss of 1e4 -- confirm
                # that cannot happen before this branch triggers.
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(best_discriminator)
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(),
                                           lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        discriminator.train()

    # Final TEST evaluation with the best in-memory weights.
    # NOTE(review): raises NameError when args.epochs == 0 or loss never
    # improved (best_discriminator never assigned).
    discriminator.load_state_dict(best_discriminator)
    discriminator.eval()
    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch,
                         "TEST", args)