def get_vae(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, beta=1):
    """Return the Optimus VAE restored from ``OUTPUT_DIR/checkpoint-full-31250``.

    The given encoder/decoder models and tokenizers are wrapped in ``VAE``
    together with a small ``Args`` record:

    * latent size ``LATENT_SIZE_LARGE``;
    * device from ``get_device()``;
    * ``fb_mode=0``;
    * the given ``beta``.

    The state dict stored under ``"model_state_dict"`` in ``training.bin`` is
    loaded, and the model is moved to that device.
    """
    Args = namedtuple("Args", ["latent_size", "device", "fb_mode", "beta"])
    vae_args = Args(
        latent_size=LATENT_SIZE_LARGE,
        device=get_device(),
        fb_mode=0,
        beta=beta,
    )

    training_bin = os.path.join(OUTPUT_DIR, "checkpoint-full-31250", "training.bin")
    # Without CUDA, tensors saved from a GPU run must be remapped onto the CPU.
    load_kwargs = {} if torch.cuda.is_available() else {"map_location": "cpu"}
    checkpoint = torch.load(training_bin, **load_kwargs)

    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, vae_args)
    model_vae.load_state_dict(checkpoint["model_state_dict"])
    model_vae.to(vae_args.device)
    return model_vae
def main(data_root='/home/ajays/Downloads/'):
    """Load a pre-trained MNIST image VAE and show one image decoded from a random latent.

    Args:
        data_root: directory MNIST is downloaded to / read from. Defaults to the
            path that was previously hard-coded here.

    Training is currently disabled: the weights come from ``LOAD_PATH``. To
    retrain instead, replace the ``load_state_dict`` call with
    ``vae = train_image_vae(vae, train_loader)``.
    """
    mnist_train_data = datasets.MNIST(
        data_root, download=True, transform=transforms.ToTensor()
    )
    # Only consumed when training is re-enabled (see docstring).
    train_loader = torch.utils.data.DataLoader(
        mnist_train_data, batch_size=batch_size, shuffle=True
    )

    vae = VAE(n_inputs=32)

    # The model is never moved off the CPU (its output is converted with
    # .numpy() below), so remap any GPU-saved tensors in the checkpoint.
    vae.load_state_dict(torch.load(LOAD_PATH, map_location="cpu"))

    # Pure inference: eval mode and no autograd graph.
    vae.eval()
    with torch.no_grad():
        # NOTE(review): assumes the decoder's latent size is 128 -- confirm against VAE.
        sample = vae.decode(torch.randn(128))
    plt.imshow(sample.numpy().reshape((28, 28)))
    plt.show()
def get_vae_recons(loader, hidden_size=256):
    """Reconstruct the first batch of ``loader`` with the pre-trained ImageNet VAE.

    Args:
        loader: iterable yielding ``(images, labels)`` batches; only the first
            batch is used.
        hidden_size: passed to ``VAE(3, hidden_size, hidden_size)``.

    Returns:
        An 8-column image grid (``make_grid``) of the reconstructions, on the CPU.

    Raises:
        ValueError: if ``loader`` yields no batch.
    """
    import types

    model = VAE(3, hidden_size, hidden_size).to(DEVICE)
    # NOTE(review): the checkpoint path is fixed, so a hidden_size other than
    # the one it was trained with will fail in load_state_dict.
    # map_location lets a GPU-saved checkpoint load when DEVICE is the CPU.
    ckpt = torch.load("./models/imagenet_hs_128_256_vae.pt", map_location=DEVICE)
    model.load_state_dict(ckpt)
    # Inference mode, so layers such as dropout/batch-norm (if any) don't use
    # training behaviour.
    model.eval()

    # generate_samples only reads ``args.device``.
    args = types.SimpleNamespace(device=DEVICE)

    try:
        gen_img, _ = next(iter(loader))
    except StopIteration:
        raise ValueError("loader yielded no batches") from None

    with torch.no_grad():
        reconstruction = vae.generate_samples(gen_img, model, args)
    return make_grid(reconstruction.cpu(), nrow=8)
def main(args):
    """Train (or, with ``args.eval``, only evaluate) an LSTM text VAE.

    Data are ``MonoTextData`` splits from ``args.train_data`` / ``args.val_data`` /
    ``args.test_data``; the validation and test splits share the training vocabulary.

    Training setup:

    * Encoder and decoder are trained with SGD (lr 1.0, momentum ``args.momentum``).
    * The KL weight is annealed linearly from ``args.kl_start`` to 1 over
      roughly ``args.warm_up`` epochs.
    * While ``aggressive_flag`` is set, up to 99 extra encoder-only updates
      are made per batch on randomly drawn batches.
    * Aggressive mode is switched off the first time the validation mutual
      information is lower than at the previous per-epoch check.

    Checkpointing and LR schedule:

    * The checkpoint with the lowest validation loss is saved to
      ``args.save_path``.
    * After ``decay_epoch`` non-improving epochs (from epoch 15 on), that
      checkpoint is reloaded and the learning rate is multiplied by ``lr_decay``.
    * Training stops after ``max_decay`` decays.

    The best checkpoint is finally evaluated on the test set, including the
    importance-weighted NLL from ``calc_iwnll``.

    Relies on module-level ``clip_grad``, ``decay_epoch``, ``lr_decay`` and
    ``max_decay``.
    """

    class uniform_initializer(object):
        """Initializer filling a tensor from U(-stdv, stdv)."""

        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    # Log ~10 times per epoch; clamp to 1 so tiny datasets don't hit `% 0`.
    log_niter = max(1, (len(train_data) // args.batch_size) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    if args.enc_type == 'lstm':
        encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device
    vae = VAE(encoder, decoder, args).to(device)

    if args.eval:
        print('begin evaluation')
        # map_location lets a GPU-trained checkpoint be evaluated on the CPU.
        vae.load_state_dict(torch.load(args.load_path, map_location=device))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            test(vae, test_data_batch, "TEST", args)
            au, _ = calc_au(vae, test_data_batch)
            print("%d active units" % au)

            # The importance-weighted NLL is computed with batch size 1.
            test_data_batch = test_data.create_data_batch(
                batch_size=1, device=device, batch_first=True)
            calc_iwnll(vae, test_data_batch, args)
        return

    enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0, momentum=args.momentum)
    dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0, momentum=args.momentum)
    opt_dict['lr'] = 1.0

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    aggressive_flag = bool(args.aggressive)
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device, batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device, batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device, batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_size, _ = batch_data.size()
            report_num_sents += batch_size

            kl_weight = min(1.0, kl_weight + anneal_rate)

            # Aggressive inner loop: update only the encoder on randomly drawn
            # batches until the per-word loss over a 15-step window stops
            # decreasing, or 99 steps have been taken.
            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:
                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                # (len - 1): the start symbol is not predicted.
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc, kl_weight,
                                                  nsamples=args.nsamples)
                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)
                enc_optimizer.step()

                # randint(0, n) is the non-deprecated equivalent of random_integers(0, n - 1).
                id_ = np.random.randint(0, len(train_data_batch))
                batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

            # Regular joint step (encoder is frozen while aggressive).
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples)
            loss = loss.mean(dim=-1)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()
            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_data_batch)
                        au, _ = calc_au(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'au %d, time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_sents = 0

            iter_ += 1

            # Once per epoch: stop aggressive training when validation MI drops.
            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                with torch.no_grad():
                    cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                print("pre mi:%.4f. cur mi:%.4f" % (pre_mi, cur_mi))
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            au, _ = calc_au(vae, val_data_batch)
            print("%d active units" % au)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                # Restart from the best checkpoint with the decayed learning rate.
                vae.load_state_dict(torch.load(args.save_path, map_location=device))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=opt_dict["lr"],
                                          momentum=args.momentum)
                dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=opt_dict["lr"],
                                          momentum=args.momentum)
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)

        vae.train()

    print('best_loss: %.4f, kl: %.4f, nll: %.4f, ppl: %.4f'
          % (best_loss, best_kl, best_nll, best_ppl))
    sys.stdout.flush()

    # Evaluate the best checkpoint, including the importance-weighted estimate of log p(x).
    vae.load_state_dict(torch.load(args.save_path, map_location=device))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, _ = calc_au(vae, test_data_batch)
        print("%d active units" % au)

    test_data_batch = test_data.create_data_batch(batch_size=1, device=device, batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
class VAESampler:
    """Decode text from, and encode text into, a trained LSTM VAE.

    The encoder/decoder are rebuilt from ``params``, using the vocabulary of
    ``params.train_data``. Weights are restored from the ``decode_from``
    checkpoint.
    """

    def __init__(self, decode_from, params, cuda=False):
        self.decode_from = decode_from
        self.params = params
        # NOTE(review): the training scripts set enc_nh = dec_nh *after*
        # building the encoder; here it happens before, which matches the
        # checkpoint only when enc_nh == dec_nh -- confirm for your config.
        params.enc_nh = params.dec_nh
        self.train_data = MonoTextData(params.train_data, label=False)
        self.vocab = self.train_data.vocab
        self.vocab_size = len(self.vocab)

        # The constructors require initializers. With load_state_dict's default
        # strict=True, every parameter is then overwritten by the checkpoint.
        model_init = uniform_initializer(0.01)
        emb_init = uniform_initializer(0.1)

        params.device = self.device = torch.device("cuda" if cuda else "cpu")
        self.encoder = LSTMEncoder(params, self.vocab_size, model_init, emb_init)
        self.decoder = LSTMDecoder(params, self.vocab, model_init, emb_init)
        self.vae = VAE(self.encoder, self.decoder, params).to(params.device)

        # map_location remaps tensors saved on a GPU onto the chosen device.
        self.vae.load_state_dict(torch.load(self.decode_from, map_location=self.device))
        # This class only runs inference, so disable training-mode behaviour (e.g. dropout).
        self.vae.eval()

    def to_s(self, decoded):
        """Join each decoded token list into a space-separated string."""
        return [' '.join(item) for item in decoded]

    def beam(self, z, K=5):
        """Beam-search decode latents ``z`` with beam width ``K``."""
        with torch.no_grad():
            decoded_batch = self.vae.decoder.beam_search_decode(z, K)
        return self.to_s(decoded_batch)

    def sample(self, z, temperature=1.0):
        """Sample-decode latents ``z`` at the given softmax temperature."""
        with torch.no_grad():
            decoded_batch = self.vae.decoder.sample_decode(z, temperature)
        return self.to_s(decoded_batch)

    def greedy(self, z):
        """Greedy-decode latents ``z``."""
        with torch.no_grad():
            decoded_batch = self.vae.decoder.greedy_decode(z)
        return self.to_s(decoded_batch)

    def str2ids(self, s):
        """Encode string ``s`` as a list of word ids (not implemented yet)."""
        raise NotImplementedError("VAESampler.str2ids is not implemented")

    def encode(self, t):
        """Return ``(z, mu, log_var)`` from the encoder for a list of strings ``t``.

        ``z`` is one sample from the Gaussian given by ``(mu, log_var)``.
        """
        str_ids = [self.str2ids(s) for s in t]
        tensor = self.train_data._to_tensor(str_ids, True, self.device)[0]
        z, (mu, log_var) = self.vae.encoder.sample(tensor, 1)
        return z, mu, log_var

    def z(self, t):
        """Return sampled latents for the list of strings ``t`` (sample dim squeezed)."""
        z, _, _ = self.encode(t)
        return z.squeeze(1)

    def mu(self, t):
        """Return posterior means for the list of strings ``t`` (sample dim squeezed)."""
        _, mu, _ = self.encode(t)
        return mu.squeeze(1)
def main(args):
    """Train, evaluate, or sample from a ResNet-encoder / PixelCNN-decoder image VAE.

    Modes:

    * ``args.sample_from`` non-empty: load that checkpoint, decode 400 prior
      samples, and save binary and continuous image grids under
      ``samples/<args.dataset>``.
    * ``args.eval``: load ``args.load_path`` and report test loss, active units
      and importance-weighted NLL.
    * Otherwise: train with Adam (lr 0.001) and optional aggressive
      encoder-only updates. The checkpoint with the lowest validation loss is
      kept at ``args.save_path`` and evaluated on the test set at the end.

    Data:

    * ``args.data_file`` is read with ``torch.load`` and unpacked as
      ``(x_train, x_val, x_test)``.
    * Training batches are binarized with ``torch.bernoulli`` each time they
      are drawn.

    Relies on module-level ``clip_grad``, ``decay_epoch``, ``lr_decay`` and
    ``max_decay``.
    """
    if args.save_path == '':
        make_savepath(args)
    # Seed every run so results with an explicit save_path are reproducible too.
    seed(args)

    if args.cuda:
        print('using cuda')

    print(args)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data
    x_train = x_train.to(device)
    x_val = x_val.to(device)
    x_test = x_test.to(device)

    # Dummy all-zero labels so each TensorDataset yields (x, y) pairs.
    y_size = 1
    y_train = x_train.new_zeros(x_train.size(0), y_size)
    y_val = x_train.new_zeros(x_val.size(0), y_size)
    y_test = x_train.new_zeros(x_test.size(0), y_size)
    print(torch.__version__)

    train_data = torch.utils.data.TensorDataset(x_train, y_train)
    val_data = torch.utils.data.TensorDataset(x_val, y_val)
    test_data = torch.utils.data.TensorDataset(x_test, y_test)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))
    sys.stdout.flush()

    # Log ~5 times per epoch; clamp to 1 so tiny datasets don't hit `% 0`.
    log_niter = max(1, len(train_loader) // 5)

    encoder = ResNetEncoderV2(args)
    decoder = PixelCNNDecoderV2(args)
    vae = VAE(encoder, decoder, args).to(device)

    if args.sample_from != '':
        save_dir = "samples/%s" % args.dataset
        os.makedirs(save_dir, exist_ok=True)

        vae.load_state_dict(torch.load(args.sample_from, map_location=device))
        vae.eval()
        with torch.no_grad():
            sample_z = vae.sample_from_prior(400).to(device)
            sample_x, sample_probs = vae.decode(sample_z, False)

        # Checkpoint file name without directory or extension.
        ckpt_name = os.path.splitext(os.path.basename(args.sample_from))[0]
        image_file = 'sample_binary_from_%s.png' % ckpt_name
        save_image(sample_x.data.cpu(), os.path.join(save_dir, image_file), nrow=20)
        image_file = 'sample_cont_from_%s.png' % ckpt_name
        save_image(sample_probs.data.cpu(), os.path.join(save_dir, image_file), nrow=20)
        return

    if args.eval:
        print('begin evaluation')
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True)
        vae.load_state_dict(torch.load(args.load_path, map_location=device))
        vae.eval()
        with torch.no_grad():
            test(vae, test_loader, "TEST", args)
            au, _ = calc_au(vae, test_loader)
            print("%d active units" % au)
            calc_iwnll(vae, test_loader, args)
        return

    enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
    dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
    opt_dict['lr'] = 0.001

    iter_ = 0
    best_loss = 1e4
    best_kl = best_nll = 0
    decay_cnt = best_mi = mi_not_improved = 0
    aggressive_flag = bool(args.aggressive)
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * len(train_loader))

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_examples = 0
        for batch_data, _ in train_loader:
            # Draw binary pixels from the intensities afresh for every batch.
            batch_data = torch.bernoulli(batch_data)
            report_num_examples += batch_data.size(0)

            kl_weight = min(1.0, kl_weight + anneal_rate)

            # Aggressive inner loop: encoder-only updates on random batches until
            # the per-example loss over a 10-step window stops decreasing (max 99).
            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_examples = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:
                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_num_examples += batch_data_enc.size(0)
                loss, loss_rc, loss_kl = vae.loss(batch_data_enc, kl_weight,
                                                  nsamples=args.nsamples)
                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)
                enc_optimizer.step()

                id_ = np.random.choice(x_train.size(0), args.batch_size, replace=False)
                batch_data_enc = torch.bernoulli(x_train[id_])

                if sub_iter % 10 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_examples
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_examples = 0

                sub_iter += 1

            # Regular joint step (encoder is frozen while aggressive).
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples)
            loss = loss.mean(dim=-1)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()
            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss + report_kl_loss) / report_num_examples
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_loader)
                        au, _ = calc_au(vae, val_loader)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'au %d, time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_examples, mi,
                           report_rec_loss / report_num_examples, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_examples,
                           report_rec_loss / report_num_examples, time.time() - start))
                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_examples = 0

            iter_ += 1

            # Once per epoch: stop aggressive training after 5 checks without a new best MI.
            if aggressive_flag and (iter_ % len(train_loader)) == 0:
                vae.eval()
                with torch.no_grad():
                    cur_mi = calc_mi(vae, val_loader)
                vae.train()
                if cur_mi - best_mi < 0:
                    mi_not_improved += 1
                    if mi_not_improved == 5:
                        aggressive_flag = False
                        print("STOP BURNING")
                else:
                    best_mi = cur_mi

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl = test(vae, val_loader, "VAL", args)
            au, _ = calc_au(vae, val_loader)
            print("%d active units" % au)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            torch.save(vae.state_dict(), args.save_path)

        # Compare against opt_dict["best_loss"] (reset at each decay), as the
        # text trainers do; comparing against best_loss left that entry unused.
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                # Restart from the best checkpoint with the decayed learning rate.
                vae.load_state_dict(torch.load(args.save_path, map_location=device))
                decay_cnt += 1
                print('new lr: %f' % opt_dict["lr"])
                enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=opt_dict["lr"])
                dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=opt_dict["lr"])
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl = test(vae, test_loader, "TEST", args)

        vae.train()

    print('best_loss: %.4f, kl: %.4f, nll: %.4f' % (best_loss, best_kl, best_nll))
    sys.stdout.flush()

    # Evaluate the best checkpoint, including the importance-weighted estimate of log p(x).
    vae.load_state_dict(torch.load(args.save_path, map_location=device))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl = test(vae, test_loader, "TEST", args)
        au, _ = calc_au(vae, test_loader)
        print("%d active units" % au)

    test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True)
    with torch.no_grad():
        calc_iwnll(vae, test_loader, args)
def main(args):
    """Encode the train/val/test splits with a trained text VAE and save their latents.

    Steps:

    1. Write the training vocabulary to ``vocab.txt`` (one word per line, in id
       order) next to ``args.train_data``.
    2. Rebuild the Gaussian LSTM VAE and load weights from ``args.load_path``.
    3. Call ``save_latents`` for the ``"train"``, ``"val"`` and ``"test"``
       splits, batched together with their labels.
    """
    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)

    vocab_path = os.path.join(os.path.dirname(args.train_data), "vocab.txt")
    # Explicit utf-8 so non-ASCII vocabulary is written identically on every platform.
    with open(vocab_path, "w", encoding="utf-8") as fout:
        fout.writelines("{}\n".format(vocab.id2word(i)) for i in range(vocab_size))

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    print('begin evaluation')
    # map_location lets a GPU-trained checkpoint be used on the CPU.
    vae.load_state_dict(torch.load(args.load_path, map_location=device))
    vae.eval()
    with torch.no_grad():
        test_data_batch, test_batch_labels = test_data.create_data_batch_labels(
            batch_size=args.batch_size, device=device, batch_first=True)
        train_data_batch, train_batch_labels = train_data.create_data_batch_labels(
            batch_size=args.batch_size, device=device, batch_first=True)
        val_data_batch, val_batch_labels = val_data.create_data_batch_labels(
            batch_size=args.batch_size, device=device, batch_first=True)

        print("getting vectors for training")
        print(args.save_dir)
        save_latents(args, vae, train_data_batch, train_batch_labels, "train")
        print("getting vectors for validating")
        save_latents(args, vae, val_data_batch, val_batch_labels, "val")
        print("getting vectors for testing")
        save_latents(args, vae, test_data_batch, test_batch_labels, "test")
def main(args):
    """Train a text VAE while recording posterior-mean vs. inference-mean dynamics.

    The plotting mode is chosen by ``args.plot_mode``:

    * ``'multiple'``: ``plot_multiple(vae, plot_data, grid_z, iter_, args)`` is
      called every ``args.plot_niter`` iterations of epoch 0 and at the end of
      every epoch.
    * ``'single'``: every batch is ``plot_data`` itself. During epoch 0, the
      model-posterior mean (on a 1-D latent grid) and the inference-network mean
      are appended after each encoder/decoder update. At the first
      ``args.plot_niter`` boundary after iteration 0,
      ``plot_single(infer_mean, posterior_mean, args)`` is called and the
      function returns.

    Otherwise training is as in the aggressive text-VAE trainer:

    * SGD or Adam, chosen by ``args.optim``.
    * KL annealing from ``args.kl_start``.
    * Aggressive encoder updates until validation MI drops.
    * LR decay after ``decay_epoch`` non-improving epochs, stopping after
      ``max_decay`` decays.
    * Final test evaluation of the best checkpoint at ``args.save_path``.

    Relies on module-level ``clip_grad``, ``decay_epoch``, ``lr_decay`` and
    ``max_decay``.

    Raises:
        ValueError: if ``args.plot_mode`` is neither ``'single'`` nor ``'multiple'``.
    """

    class uniform_initializer(object):
        """Initializer filling a tensor from U(-stdv, stdv)."""

        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data)
    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, vocab=vocab)
    test_data = MonoTextData(args.test_data, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    # Log ~10 times per epoch; clamp to 1 so tiny datasets don't hit `% 0`.
    log_niter = max(1, (len(train_data) // args.batch_size) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
    args.enc_nh = args.dec_nh
    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.optim == 'sgd':
        enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0)
        dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0)
        opt_dict['lr'] = 1.0
    else:
        enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001, betas=(0.9, 0.999))
        dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001, betas=(0.9, 0.999))
        opt_dict['lr'] = 0.001

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = -1
    aggressive_flag = bool(args.aggressive)
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))

    plot_data = train_data.data_sample(nsample=args.num_plot, device=device, batch_first=True)
    if args.plot_mode == 'multiple':
        plot_fn = plot_multiple
    elif args.plot_mode == 'single':
        plot_fn = plot_single
    else:
        # Fail fast instead of a later NameError on an unbound plot_fn/grid_z.
        raise ValueError("unsupported plot_mode %r (expected 'single' or 'multiple')"
                         % args.plot_mode)
    grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)

    # Starting point of the trajectory, measured like the in-loop points (eval, no grad).
    posterior_mean = []
    infer_mean = []
    vae.eval()
    with torch.no_grad():
        posterior_mean.append(vae.calc_model_posterior_mean(plot_data[0], grid_z))
        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
    vae.train()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device, batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device, batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device, batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            if args.plot_mode == "single":
                batch_data, _ = plot_data
            else:
                batch_data = train_data_batch[i]
            batch_size, _ = batch_data.size()
            report_num_sents += batch_size

            kl_weight = min(1.0, kl_weight + anneal_rate)

            # Aggressive inner loop: encoder-only updates until the per-word loss
            # over a 15-step window stops decreasing (max 99 steps).
            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:
                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                # (len - 1): the start symbol is not predicted.
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc, kl_weight,
                                                  nsamples=args.nsamples)
                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)
                enc_optimizer.step()

                if args.plot_mode == "single":
                    batch_data_enc, _ = plot_data
                else:
                    # randint(0, n) is the non-deprecated equivalent of random_integers(0, n - 1).
                    id_ = np.random.randint(0, len(train_data_batch))
                    batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

                # Encoder-only step: the model posterior is unchanged, the inference mean moves.
                if args.plot_mode == 'single' and epoch == 0 and aggressive_flag:
                    vae.eval()
                    with torch.no_grad():
                        posterior_mean.append(posterior_mean[-1])
                        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                    vae.train()

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples)
            loss = loss.mean(dim=-1)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()
            dec_optimizer.step()

            # Decoder step: the model posterior moves; the inference mean only
            # moves if the encoder was also updated.
            if args.plot_mode == 'single' and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(vae.calc_model_posterior_mean(plot_data[0], grid_z))
                    if aggressive_flag:
                        infer_mean.append(infer_mean[-1])
                    else:
                        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_sents = 0

            if iter_ % args.plot_niter == 0 and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    if args.plot_mode == 'single' and iter_ != 0:
                        plot_fn(infer_mean, posterior_mean, args)
                        return
                    elif args.plot_mode == "multiple":
                        plot_fn(vae, plot_data, grid_z, iter_, args)
                vae.train()

            iter_ += 1

            # Once per epoch: stop aggressive training when validation MI drops.
            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                with torch.no_grad():
                    cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        vae.eval()
        # In 'single' mode plot_fn is called as plot_fn(infer_mean, posterior_mean, args),
        # so this model-based call only applies to 'multiple' mode.
        if args.plot_mode == 'multiple':
            with torch.no_grad():
                plot_fn(vae, plot_data, grid_z, iter_, args)

        with torch.no_grad():
            loss, nll, kl, ppl = test(vae, val_data_batch, "VAL", args)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path, map_location=device))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.optim == 'sgd':
                    enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=opt_dict["lr"])
                    dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=opt_dict["lr"])
                else:
                    # NOTE(review): betas here are (0.5, 0.999) but the initial
                    # optimizers use (0.9, 0.999) -- confirm which is intended.
                    enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
                    dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl = test(vae, test_data_batch, "TEST", args)

        vae.train()

    print('best_loss: %.4f, kl: %.4f, nll: %.4f, ppl: %.4f' \
          % (best_loss, best_kl, best_nll, best_ppl))
    sys.stdout.flush()

    # Compute the importance-weighted estimate of log p(x) for the best checkpoint.
    vae.load_state_dict(torch.load(args.save_path, map_location=device))
    vae.eval()
    test_data_batch = test_data.create_data_batch(batch_size=1, device=device, batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
def _build_latent_generation_parser():
    """Return the command-line parser for the Optimus latent-space generation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
                        help="The directory where checkpoints are saved.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")

    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
    parser.add_argument("--total_sents", default=10, type=int,
                        help="Total sentences to test reconstruction.")
    parser.add_argument("--num_interpolation_steps", default=10, type=int,
                        help="Number of interpolation steps.")
    parser.add_argument("--play_mode", default="interpolation", type=str,
                        help="interpolation or analogy.")

    ## Encoder options
    parser.add_argument("--encoder_model_type", default="bert", type=str,
                        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument("--encoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--encoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    # NOTE(review): this default does not match decoder_model_type="gpt2"; callers
    # are expected to pass --decoder_model_name_or_path explicitly.
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    # The misspelled flag is kept for existing callers; --global_step_eval is an alias.
    parser.add_argument('--gloabl_step_eval', '--global_step_eval', dest='gloabl_step_eval',
                        type=int, default=661,
                        help="Evaluate the results at the given global step")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence "
                             "will be dropped if it is longer than the max_seq_length")

    # Interact with users
    parser.add_argument("--interact_with_user_input", action='store_true',
                        help="Use user input to interact_with.")
    parser.add_argument("--sent_source", type=str, default="")
    parser.add_argument("--sent_target", type=str, default="")
    parser.add_argument("--sent_input", type=str, default="")
    parser.add_argument("--degree_to_target", type=float, default=1.0)

    ## Variational auto-encoder
    parser.add_argument("--nz", default=32, type=int, help="Latent space dimension.")

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization. "
                             "The training dataset will be truncated in block of this size for training. "
                             "Default to the model max input length for single sentence inputs "
                             "(take into account special tokens).")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--use_philly", action='store_true', help="Use Philly for computing.")
    return parser


def _load_optimus_vae(args, global_step):
    """Load the fine-tuned encoder, decoder and full VAE checkpoint for ``global_step``.

    Also clamps ``args.block_size`` to what both tokenizers accept.

    Returns:
        ``(model_vae, tokenizer_encoder, tokenizer_decoder)``.
    """
    output_encoder_dir = os.path.join(args.checkpoint_dir,
                                      'checkpoint-encoder-{}'.format(global_step))
    output_decoder_dir = os.path.join(args.checkpoint_dir,
                                      'checkpoint-decoder-{}'.format(global_step))
    checkpoints = [[output_encoder_dir, output_decoder_dir]]
    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    # Load a trained Encoder model and vocabulary that you have fine-tuned.
    _, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
    model_encoder = encoder_model_class.from_pretrained(output_encoder_dir,
                                                        latent_size=args.latent_size)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name or args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_encoder.to(args.device)
    if args.block_size <= 0:
        # Our input block size will be the max possible for the model.
        args.block_size = tokenizer_encoder.max_len_single_sentence
    args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)

    # Load a trained Decoder model and vocabulary that you have fine-tuned.
    _, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
    model_decoder = decoder_model_class.from_pretrained(output_decoder_dir,
                                                        latent_size=args.latent_size)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name or args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_decoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence
    args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)

    # Load the full model state onto the CPU first; it is moved to args.device below.
    output_full_dir = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
    checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'),
                            map_location=torch.device('cpu'))

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    # resize_token_embeddings expects the full size of the new vocabulary,
    # i.e. the length of the tokenizer.
    model_decoder.resize_token_embeddings(len(tokenizer_decoder))
    assert tokenizer_decoder.pad_token == '<PAD>'

    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args)
    model_vae.load_state_dict(checkpoint['model_state_dict'])
    logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device)
    return model_vae, tokenizer_encoder, tokenizer_decoder


def main():
    """Run Optimus latent-space interpolation/analogy or latent-space evaluation.

    Behaviour depends on the flags:

    * With ``--interact_with_user_input`` and ``--play_mode interpolation``:
      interpolates from ``--sent_source`` to ``--sent_target``.
    * With ``--interact_with_user_input`` and ``--play_mode analogy``: also
      needs ``--sent_input``.
    * Without ``--interact_with_user_input``: runs ``evaluate_latent_space``.
    """
    args = _build_latent_generation_parser().parse_args()

    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    set_seed(args)

    args.encoder_model_type = args.encoder_model_type.lower()
    args.decoder_model_type = args.decoder_model_type.lower()

    global_step = args.gloabl_step_eval
    model_vae, tokenizer_encoder, tokenizer_decoder = _load_optimus_vae(args, global_step)

    if args.interact_with_user_input:
        if args.play_mode == 'interpolation':
            # Both endpoints are required (previously sent_source was checked twice).
            if args.sent_source and args.sent_target:
                interpolate(model_vae, tokenizer_encoder, tokenizer_decoder, args)
            else:
                print('Please check: specify the source and target sentences!')
        elif args.play_mode == 'analogy':
            if args.sent_source and args.sent_target and args.sent_input:
                analogy(model_vae, tokenizer_encoder, tokenizer_decoder, args)
            else:
                print('Please check: specify the source, target and input analogy sentences!')
        else:
            print("Unknown --play_mode %r; expected 'interpolation' or 'analogy'." % args.play_mode)
    else:
        evaluate_latent_space(args, model_vae, tokenizer_encoder, tokenizer_decoder,
                              prefix=global_step)
def main(args):
    """Train, evaluate, or reconstruct with an LSTM text VAE.

    The mode is selected by ``args``:

    * ``args.eval`` -- reload ``args.load_path``, report test loss, active
      units and importance-weighted NLL/PPL, then return.
    * ``args.reconstruct_from`` non-empty -- load that checkpoint, run
      ``reconstruct`` on the test set (output to ``args.reconstruct_to``),
      then return.
    * otherwise -- train for ``args.epochs`` epochs. The checkpoint with the
      lowest validation loss is saved to ``args.save_path``. When validation
      loss stops improving, the learning rate is decayed and the best
      checkpoint is reloaded.

    Relies on names defined elsewhere in this module (not visible here):
    ``clip_grad``, ``decay_epoch``, ``lr_decay``, ``max_decay``, ``ns``,
    ``test``, ``calc_au``, ``calc_iwnll``, ``reconstruct``,
    ``visualize_latent``.
    """
    global logging
    # don't make exp dir for reconstruction / evaluation runs
    debug = args.reconstruct_from != "" or bool(args.eval)
    logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)

    # Validation and test sets reuse the vocabulary built from training data.
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)

    # Log about ten times per epoch. Clamp to 1: with fewer than 10 batches
    # the unclamped value is 0 and ``iter_ % log_niter`` would raise
    # ZeroDivisionError.
    log_niter = max(1, (len(train_data) // args.batch_size) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

        if args.reset_dec:
            vae.decoder.reset_parameters(model_init, emb_init)

    if args.eval:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            logging("%d active units" % au)

            # calc_iwnll is fed batches of a single sentence.
            test_data_batch = test_data.create_data_batch(
                batch_size=1, device=device, batch_first=True)
            nll, ppl = calc_iwnll(vae, test_data_batch, args)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
        return

    if args.reconstruct_from != "":
        print("begin decoding")
        sys.stdout.flush()

        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy,
                        args.reconstruct_to)
        return

    def make_optimizers(lr):
        """Return fresh (encoder, decoder) optimizers of type ``args.opt``.

        Used both at start-up and after every LR decay, so a run configured
        for Adam stays on Adam (previously decay always switched to SGD).
        """
        if args.opt == "sgd":
            return (optim.SGD(vae.encoder.parameters(), lr=lr, momentum=args.momentum),
                    optim.SGD(vae.decoder.parameters(), lr=lr, momentum=args.momentum))
        if args.opt == "adam":
            return (optim.Adam(vae.encoder.parameters(), lr=lr),
                    optim.Adam(vae.decoder.parameters(), lr=lr))
        raise ValueError("optimizer not supported")

    # SGD starts from args.lr; Adam always starts from a fixed 1e-3.
    opt_dict['lr'] = args.lr if args.opt == "sgd" else 0.001
    enc_optimizer, dec_optimizer = make_optimizers(opt_dict['lr'])

    iter_ = decay_cnt = 0
    best_loss = 1e4
    # The KL weight is constant for this script (no annealing).
    kl_weight = args.beta
    vae.train()
    start = time.time()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):
                batch_data = train_data_batch[i]
                batch_size, _ = batch_data.size()
                report_num_sents += batch_size

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()
                if args.iw_train_nsamples < 0:
                    loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight,
                                                      nsamples=args.nsamples)
                else:
                    loss, loss_rc, loss_kl = vae.loss_iw(
                        batch_data, kl_weight,
                        nsamples=args.iw_train_nsamples, ns=ns)
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.sum().item()
                report_kl_loss += loss_kl.sum().item()
                # `loss` is the batch mean; scale back to a per-batch sum.
                report_loss += loss.item() * batch_size

                if iter_ % log_niter == 0:
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,'
                            'time elapsed %.2fs, kl_weight %.4f' %
                            (epoch, iter_, train_loss,
                             report_kl_loss / report_num_sents,
                             report_rec_loss / report_num_sents,
                             time.time() - start, kl_weight))
                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)

            vae.eval()
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)

            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(vae.state_dict(),
                           os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                torch.save(vae.state_dict(), args.save_path)

            # Plateau handling: after `decay_epoch` non-improving epochs,
            # reload the best checkpoint and continue with a decayed LR.
            if loss > opt_dict["best_loss"]:
                opt_dict["not_improved"] += 1
                if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    enc_optimizer, dec_optimizer = make_optimizers(opt_dict["lr"])
            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if epoch % args.test_nepoch == 0:
                with torch.no_grad():
                    loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)

            if args.save_latent > 0 and epoch <= args.save_latent:
                # Use the configured device rather than a hard-coded "cuda",
                # which crashed CPU-only runs.
                visualize_latent(args, epoch, vae, device, test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # Final report on the checkpoint stored at args.save_path.
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        logging("%d active units" % au)

    # compute importance weighted estimate of log p(x), one sentence per batch
    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        nll, ppl = calc_iwnll(vae, test_data_batch, args)
        logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
def main(args):
    """Train a classifier on top of a VAE encoder and report test accuracy.

    Steps:

    1. Read the vocabulary from ``args.vocab_file``.
    2. Build the text VAE, optionally loading ``args.load_path``.
    3. Train ``args.discriminator`` ("linear" or "mlp"), which is constructed
       from ``vae.encoder``, on labelled batches. Gradients are accumulated
       over ``args.update_every`` batches.
    4. Keep the checkpoint with the lowest validation loss at
       ``args.save_path``, decaying the LR on plateaus.

    Raises:
        ValueError: for an unsupported encoder, discriminator or optimizer.

    Relies on names defined elsewhere in this module (not visible here):
    ``clip_grad``, ``decay_epoch``, ``lr_decay``, ``max_decay``, ``test``.
    """
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # One token per line; the line index is the token id.
    with open(args.vocab_file) as fvocab:
        vocab = {line.strip(): i for i, line in enumerate(fvocab)}
    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)

    # Clamped to 1 so `iter_ % log_niter` never divides by zero.
    log_niter = max(1, (len(train_data) // (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.fb == 3:
        encoder = DeltaGaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    elif args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args, vae.encoder).to(device)
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError("the specified discriminator type is not supported")

    def make_optimizer(lr):
        """Return a fresh optimizer of type ``args.opt`` over the discriminator."""
        if args.opt == "sgd":
            return optim.SGD(discriminator.parameters(), lr=lr, momentum=args.momentum)
        if args.opt == "adam":
            return optim.Adam(discriminator.parameters(), lr=lr)
        raise ValueError("optimizer not supported")

    # SGD starts from args.lr; Adam always starts from a fixed 1e-3.
    opt_dict['lr'] = args.lr if args.opt == "sgd" else 0.001
    optimizer = make_optimizer(opt_dict['lr'])

    iter_ = decay_cnt = 0
    best_loss = 1e4
    discriminator.train()
    start = time.time()

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(
        batch_size=args.batch_size, device=device, batch_first=True)
    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)
    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_sents = 0
        # Reset ALL gradient-accumulation state each epoch. Previously only
        # acc_batch_size was reset, so a partial accumulation left over from
        # the last epoch was divided by the wrong sentence count.
        acc_cnt = 1
        acc_loss = 0.
        acc_batch_size = 0
        optimizer.zero_grad()

        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            # Batches with fewer than two sentences are skipped.
            if batch_data.size(0) < 2:
                continue
            batch_labels = torch.tensor([int(x) for x in train_labels_batch[i]],
                                        dtype=torch.long,
                                        requires_grad=False,
                                        device=device)

            batch_size, _ = batch_data.size()
            report_num_sents += batch_size
            acc_batch_size += batch_size

            loss, correct = discriminator.get_performance(batch_data, batch_labels)
            acc_loss = acc_loss + loss.sum()

            if acc_cnt % args.update_every == 0:
                # Mean over every sentence accumulated since the last update.
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), clip_grad)
                optimizer.step()
                optimizer.zero_grad()
                acc_cnt = 0
                acc_loss = 0.
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                # Running averages over the epoch so far.
                train_loss = report_loss / report_num_sents
                logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f,'
                        'time %.2fs' %
                        (epoch, iter_, train_loss,
                         report_correct / report_num_sents,
                         time.time() - start))
            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))
        print(report_num_sents)

        discriminator.eval()
        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch, "VAL", args)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            print(args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        # Plateau handling: after `decay_epoch` non-improving epochs, reload
        # the best checkpoint and rebuild the optimizer with a decayed LR.
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                optimizer = make_optimizer(opt_dict["lr"])
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)

        discriminator.train()

    # Final test-set evaluation using the checkpoint stored at args.save_path.
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()
    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)
def main(args, args_model):
    """Probe frozen VAE codes with a logistic-regression classifier.

    For every training subset matching
    ``<args.train_data>.seed_*.n_<args.num_label_per_class>``, the function:

    1. Encodes the sentences with the VAE (loaded from ``args.load_path``).
    2. Standardises the codes.
    3. Grid-searches the L2 penalty ``C`` with repeated stratified
       cross-validation.
    4. Scores macro-F1 and log-loss on ``args.test_data``.

    The per-subset results are written as one JSON file into
    ``args.exp_dir``.

    Args:
        args: run configuration (data paths, ``load_path``, ``exp_dir``,
            ``num_label_per_class``, ``n_repeats``, ``resample``,
            ``classify_using_samples``, ``cuda``).
        args_model: configuration used to rebuild the VAE architecture.

    Raises:
        ValueError: for an unsupported encoder or decoder type.
    """
    global logging
    logging = get_logger_existing_dir(os.path.dirname(args.load_path), 'log_classifier.txt')

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    vocab = {}
    if getattr(args, 'vocab_file', None) is not None:
        # One token per line; the line index is the token id.
        with open(args.vocab_file) as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i
        vocab = VocabEntry(vocab)

    filename_glob = args.train_data + '.seed_*.n_' + str(args.num_label_per_class)
    train_sets = glob.glob(filename_glob)
    print("Train sets:", train_sets)

    # The vocabulary actually used downstream is the one MonoTextData exposes.
    main_train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
    vocab = main_train_data.vocab
    vocab_size = len(vocab)

    logging('finish reading datasets, vocab size is %d' % len(vocab))

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = "cuda" if args.cuda else "cpu"
    args_model.device = device

    # Older saved configs may lack the newer options; default them to None.
    if args_model.enc_type == 'lstm':
        args_model.pooling = getattr(args_model, 'pooling', None)
        encoder = GaussianLSTMEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            pooling=args_model.pooling,
        )
    elif args_model.enc_type in ['max_avg_pool', 'max_pool', 'avg_pool']:
        args_model.skip_first_word = getattr(args_model, 'skip_first_word', None)
        encoder = GaussianPoolEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            enc_type=args_model.enc_type,
            skip_first_word=args_model.skip_first_word)
    else:
        raise ValueError("the specified encoder type is not supported")

    args_model.encode_length = getattr(args_model, 'encode_length', None)
    if args_model.dec_type == 'lstm':
        decoder = LSTMDecoder(args_model, vocab, model_init, emb_init,
                              args_model.encode_length)
    elif args_model.dec_type == 'unigram':
        decoder = UnigramDecoder(args_model, vocab, model_init, emb_init)
    else:
        # Previously fell through and failed with a NameError on `decoder`.
        raise ValueError("the specified decoder type is not supported")
    vae = VAE(encoder, decoder, args_model, args_model.encode_length).to(device)

    if args.load_path:
        print("load args!")
        print(vae)
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)
    vae.eval()

    def preprocess(data_fn):
        """Encode ``data_fn`` with the VAE and return (codes, labels) as numpy arrays."""
        codes, labels = read_dataset(data_fn, vocab, device, vae,
                                     args.classify_using_samples)
        if args.classify_using_samples:
            # NOTE(review): presumably a width of 2*nz means codes hold
            # mean and (log-)variance side by side -- confirm in read_dataset.
            is_gaussian_enc = codes.shape[1] == (vae.encoder.nz * 2)
            codes = augment_dataset(codes, 1, is_gaussian_enc, vae)  # use only 1 sample
        return codes.cpu().numpy(), labels.cpu().numpy()

    test_codes, test_labels = preprocess(args.test_data)

    f1_scorer = make_scorer(f1_score,
                            average='macro',
                            labels=np.unique(test_labels),
                            greater_is_better=True)
    # log loss: negative log likelihood. We should minimize that, so greater_is_better=False.
    # NOTE(review): `needs_proba` is deprecated in scikit-learn >= 1.4 in favour of
    # response_method="predict_proba"; confirm the pinned sklearn version.
    log_loss_scorer = make_scorer(log_loss, needs_proba=True, greater_is_better=False)
    # Silences every warning for the rest of the process (e.g. convergence warnings).
    warnings.filterwarnings('ignore')

    results = {
        'n_samples_per_class': args.num_label_per_class,
    }
    n_repeats = args.n_repeats
    # Stratified K-fold needs at least n_splits examples per class.
    n_splits = min(args.num_label_per_class, 5)
    for i, fn in enumerate(train_sets):
        codes, labels = preprocess(fn)
        if args.resample > 1:
            # going to augment the training set by sampling
            # then create a new cross validation function to get the correct indices
            cross_val = augment_cross_val(labels, args.resample, n_splits, n_repeats)
            labels = np.repeat(labels, args.resample)
        else:
            cross_val = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats)

        # Fit the scaler on the training subset only and apply it to test codes.
        scaler = StandardScaler()
        codes = scaler.fit_transform(codes)
        scaled_test_codes = scaler.transform(test_codes)

        clf = GridSearchCV(
            LogisticRegression(solver='sag', multi_class='auto'),
            {
                "penalty": ['l2'],
                "C": [0.01, 0.1, 1, 10, 100],
            },
            cv=cross_val,
            scoring={
                "f1": f1_scorer,
                "log": log_loss_scorer,
            },
            refit=False,  # refit is done per metric by refit_and_eval below
        )
        clf.fit(codes, labels)
        crossval_f1, test_f1 = refit_and_eval(
            'f1',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            f1_scorer,
        )
        crossval_log, test_log_loss = refit_and_eval(
            'log',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            log_loss_scorer,
        )
        results[i] = {
            "F1": {
                'crossval': crossval_f1,
                'test': test_f1
            },
            "log": {
                'crossval': crossval_log,
                'test': test_log_loss
            },
        }
        print(results[i])

    # Defined for both branches; previously only the first branch set it,
    # so the non-sampling path raised NameError.
    n_per_class = str(args.num_label_per_class)
    if args.classify_using_samples:
        resample = 1 if args.resample == -1 else args.resample
        output_fn = os.path.join(
            args.exp_dir,
            'results_sample_' + str(resample) + '_' + n_per_class + '.json')
    else:
        output_fn = os.path.join(args.exp_dir, 'results_' + n_per_class + '.json')
    with open(output_fn, 'w') as f:
        json.dump(results, f)