def main(args):
    """Generate one sentence per sampled source sentence, then print and save.

    Seeds the Python, NumPy and torch RNGs from ``args.seed``, loads the
    trained models through ``load_models``, draws a single batch of
    ``args.ngenerations`` source sentences from the corpus chunks, and hands
    that batch to ``generate``.  Each (source, generated) pair is echoed to
    stdout unless ``args.noprint`` is set, and is always written to
    ``args.outf``.

    Args:
        args: parsed command-line namespace.  Attributes read here: ``seed``,
            ``ae_args``, ``gan_args``, ``vocab_file``, ``ae_model``,
            ``g_model``, ``d_model``, ``data_path``, ``dict_file``,
            ``ngenerations``, ``sample``, ``maxlen``, ``noprint``, ``outf``.
    """
    # --- Reproducibility: seed every RNG with the same value. ---
    seed = args.seed
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)
    if not torch.cuda.is_available():
        print("Note that our pre-trained models require CUDA to evaluate.")
    else:
        torch.cuda.manual_seed(seed)

    # --- Load the models. ---
    # Only the vocabulary, the autoencoder and the generator are needed below;
    # the remaining items returned by load_models are intentionally ignored.
    loaded = load_models(args.ae_args, args.gan_args, args.vocab_file,
                         args.ae_model, args.g_model, args.d_model)
    _ae_args, _gan_args, idx2word, autoencoder, gan_gen, _gan_disc = loaded

    # --- Generation. ---
    corpus = Corpus(args.data_path, args.dict_file, vocab_size=len(idx2word))
    chunk_batches = BatchGen(corpus.get_chunks(size=2), args.ngenerations)
    source, _ = next(chunk_batches)

    # Decode the source indices back to text before wrapping the tensor.
    contexts = [decode_idx(corpus.dictionary, row) for row in source.tolist()]

    # volatile=True is the legacy (pre-0.4) PyTorch way to disable autograd.
    sentences = generate(
        autoencoder,
        gan_gen,
        inp=Variable(source, volatile=True),
        vocab=idx2word,
        sample=args.sample,
        maxlen=args.maxlen,
    )

    # --- Report. ---
    if not args.noprint:
        print("\nSentence generations:\n")
        for context, generated in zip(contexts, sentences):
            print(context)
            print(" ", generated)
            print("")

    with open(args.outf, "w") as out_file:
        out_file.write("Sentence generations:\n\n")
        for context, generated in zip(contexts, sentences):
            out_file.write(context + '\n')
            out_file.write("-> " + generated + '\n\n')
def main():
    """Train a generator/critic pair on codes from a frozen autoencoder.

    Reads its configuration from the enclosing scope: ``args`` (parsed
    options), ``out_dir`` (output directory) and ``log`` (logger) are not
    parameters and must already be defined when this runs.

    Steps:
      1. Load the autoencoder weights and hyper-parameters, build the
         ``Seq2Seq`` model and freeze all of its parameters.
      2. Dump the run arguments to ``<out_dir>/args.json``.
      3. Build ``MLP_G`` / ``MLP_D`` and one Adam optimizer for each.
      4. For ``args.updates`` iterations, run ``args.niters_gan_d`` critic
         steps followed by ``args.niters_gan_g`` generator steps, logging
         every ``args.log_interval`` iterations and saving both networks
         every ``args.save_interval`` iterations.
    """
    # Map all storages to CPU while deserialising so a checkpoint written on
    # a GPU can still be loaded on a CPU-only host (or one with a different
    # device layout).  The model is moved to the GPU below when args.cuda is
    # set, so the GPU path ends up in the same state as before.
    state_dict = torch.load(args.ae_model,
                            map_location=lambda storage, loc: storage)
    with open(args.ae_args) as f:
        ae_args = json.load(f)

    corpus = Corpus(args.data_file, args.dict_file,
                    vocab_size=ae_args['vocab_size'])
    autoencoder = Seq2Seq(emsize=ae_args['emsize'],
                          nhidden=ae_args['nhidden'],
                          ntokens=ae_args['ntokens'],
                          nlayers=ae_args['nlayers'],
                          noise_radius=ae_args['noise_radius'],
                          hidden_init=ae_args['hidden_init'],
                          dropout=ae_args['dropout'],
                          gpu=args.cuda)
    autoencoder.load_state_dict(state_dict)
    # The autoencoder is used only as a fixed encoder: no gradients needed.
    for param in autoencoder.parameters():
        param.requires_grad = False

    # save arguments
    with open(os.path.join(out_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f)
    log.info('[Data and AE model loaded.]')

    # The critic receives two codes per call, hence 2 * nhidden inputs.
    gan_gen = MLP_G(ninput=args.nhidden, noutput=args.nhidden,
                    layers=args.arch_g)
    gan_disc = MLP_D(ninput=2 * args.nhidden, noutput=1, layers=args.arch_d)
    optimizer_gan_g = optim.Adam(gan_gen.parameters(),
                                 lr=args.lr_gan_g,
                                 betas=(args.beta1, 0.999))
    optimizer_gan_d = optim.Adam(gan_disc.parameters(),
                                 lr=args.lr_gan_d,
                                 betas=(args.beta1, 0.999))

    if args.cuda:
        autoencoder = autoencoder.cuda()
        gan_gen = gan_gen.cuda()
        gan_disc = gan_disc.cuda()

    # Gradient seeds passed to backward(): +1 and -1.
    one = to_gpu(args.cuda, torch.FloatTensor([1]))
    mone = one * -1

    train_pairs = BatchGen(corpus.get_chunks(size=2), args.batch_size)

    def train_gan_g(batch):
        """Run one generator update on ``batch``; return the critic output."""
        gan_gen.train()
        gan_gen.zero_grad()

        # Only the source half of the pair is used for the generator step.
        source, _ = batch
        source = to_gpu(args.cuda, Variable(source))
        source_hidden = autoencoder(source, noise=False, encode_only=True)

        fake_hidden = gan_gen(source_hidden)
        errG = gan_disc(source_hidden, fake_hidden)

        # loss / backprop
        errG.backward(one)
        optimizer_gan_g.step()

        return errG

    def train_gan_d(batch):
        """Run one critic update on ``batch``.

        Returns:
            Tuple ``(errD, errD_real, errD_fake)`` where
            ``errD = -(errD_real - errD_fake)``.
        """
        # clamp parameters to a cube
        for p in gan_disc.parameters():
            p.data.clamp_(-args.gan_clamp, args.gan_clamp)

        gan_disc.train()
        gan_disc.zero_grad()

        # positive samples ----------------------------
        # generate real codes
        source, target = batch
        source = to_gpu(args.cuda, Variable(source))
        target = to_gpu(args.cuda, Variable(target))

        # batch_size x nhidden (per the original author's note)
        source_hidden = autoencoder(source, noise=False, encode_only=True)
        target_hidden = autoencoder(target, noise=False, encode_only=True)

        # loss / backprop
        errD_real = gan_disc(source_hidden, target_hidden)
        errD_real.backward(one)

        # negative samples ----------------------------
        # loss / backprop
        fake_hidden = gan_gen(source_hidden)
        # Detach so this backward pass does not reach the generator.
        errD_fake = gan_disc(source_hidden.detach(), fake_hidden.detach())
        errD_fake.backward(mone)

        optimizer_gan_d.step()
        errD = -(errD_real - errD_fake)

        return errD, errD_real, errD_fake

    start_time = datetime.now()

    for t in range(args.updates):
        niter = t + 1  # 1-based iteration counter used for logging/saving

        # train discriminator/critic
        # NOTE(review): errD*/errG are only bound if the corresponding
        # niters_gan_* option is >= 1; the log call below assumes that.
        for _ in range(args.niters_gan_d):
            # feed a seen sample within this epoch; good for early training
            errD, errD_real, errD_fake = train_gan_d(next(train_pairs))

        # train generator
        for _ in range(args.niters_gan_g):
            errG = train_gan_g(next(train_pairs))

        if niter % args.log_interval == 0:
            # ETA = mean time per finished update * updates still to run;
            # split('.') drops the fractional seconds from the timedelta.
            eta = str((datetime.now() - start_time) / (t + 1)
                      * (args.updates - t - 1)).split('.')[0]
            log.info('[{}/{}] Loss_D: {:.6f} (real: {:.6f} '
                     'fake: {:.6f}) Loss_G: {:.6f} ETA: {}'.format(
                         niter, args.updates,
                         errD.data.cpu()[0], errD_real.data.cpu()[0],
                         errD_fake.data.cpu()[0], errG.data.cpu()[0], eta))

        if niter % args.save_interval == 0:
            # NOTE(review): file names use the 0-based index t while the log
            # line reports the 1-based niter; kept as-is so existing
            # checkpoint names do not change.
            save_model(gan_gen, out_dir, 'gan_gen_model_{}.pt'.format(t))
            save_model(gan_disc, out_dir, 'gan_disc_model_{}.pt'.format(t))