예제 #1
0
def main(args):
    """Generate sentences with a trained autoencoder + GAN generator.

    Seeds the Python, NumPy and torch RNGs from ``args.seed``, loads the
    models through ``load_models``, pulls a single batch of
    ``args.ngenerations`` source sentences from the corpus and passes it to
    ``generate``. Every (source sentence, generated sentence) pair is echoed
    to stdout unless ``args.noprint`` is set, and is always written to
    ``args.outf`` (the file is overwritten).

    Args:
        args: parsed command-line namespace. Attributes read: seed, ae_args,
            gan_args, vocab_file, ae_model, g_model, d_model, data_path,
            dict_file, ngenerations, sample, maxlen, noprint, outf.
    """
    # Make the run reproducible by seeding every RNG that may be consulted.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        print("Note that our pre-trained models require CUDA to evaluate.")
    else:
        torch.cuda.manual_seed(seed)

    # --- model loading -----------------------------------------------------
    # Only the vocabulary, the autoencoder and the generator are needed here;
    # the arg dicts and the discriminator are discarded.
    (_, _, idx2word, autoencoder, gan_gen, _) = load_models(
        args.ae_args, args.gan_args, args.vocab_file,
        args.ae_model, args.g_model, args.d_model)

    # --- generation --------------------------------------------------------
    corpus = Corpus(args.data_path, args.dict_file, vocab_size=len(idx2word))
    source_batch, _ = next(
        BatchGen(corpus.get_chunks(size=2), args.ngenerations))
    originals = [decode_idx(corpus.dictionary, row)
                 for row in source_batch.tolist()]
    # NOTE(review): `volatile` is the PyTorch <= 0.3 way to disable autograd;
    # on 0.4+ it is ignored with a warning (torch.no_grad() replaces it).
    sentences = generate(autoencoder, gan_gen,
                         inp=Variable(source_batch, volatile=True),
                         vocab=idx2word,
                         sample=args.sample,
                         maxlen=args.maxlen)

    # --- output ------------------------------------------------------------
    # zip() is rebuilt for each consumer rather than materialised once, so
    # the pairing is consumed exactly as in a two-pass read.
    if not args.noprint:
        print("\nSentence generations:\n")
        for original, produced in zip(originals, sentences):
            print(original)
            print("    ", produced)
            print("")
    with open(args.outf, "w") as out_file:
        out_file.write("Sentence generations:\n\n")
        for original, produced in zip(originals, sentences):
            out_file.write(original + '\n')
            out_file.write("-> " + produced + '\n\n')
예제 #2
0
파일: train.py 프로젝트: JasonK93/Deechat
def main():
    """Train a GAN over frozen autoencoder codes of (source, target) pairs.

    Loads a pre-trained ``Seq2Seq`` autoencoder (weights from
    ``args.ae_model``, hyper-parameters from the JSON file ``args.ae_args``)
    and freezes it. It then trains a generator ``MLP_G``, which maps the code
    of a source sentence to a code for its paired target. The generator is
    trained against a critic ``MLP_D``, which scores (source code, target
    code) pairs. The critic's weights are clamped to
    ``[-args.gan_clamp, args.gan_clamp]`` before every critic step
    (WGAN-style weight clipping).

    Losses are logged every ``args.log_interval`` updates. Both networks are
    checkpointed into ``out_dir`` every ``args.save_interval`` updates.

    Reads the module-level ``args``, ``out_dir`` and ``log`` objects.
    """
    # Load the checkpoint onto the CPU whatever device it was saved from. The
    # autoencoder is moved to the GPU further down when ``args.cuda`` is set.
    # Without ``map_location``, a GPU-saved checkpoint fails to load on a
    # CPU-only machine.
    state_dict = torch.load(args.ae_model,
                            map_location=lambda storage, loc: storage)
    with open(args.ae_args) as f:
        ae_args = json.load(f)

    corpus = Corpus(args.data_file,
                    args.dict_file,
                    vocab_size=ae_args['vocab_size'])
    autoencoder = Seq2Seq(emsize=ae_args['emsize'],
                          nhidden=ae_args['nhidden'],
                          ntokens=ae_args['ntokens'],
                          nlayers=ae_args['nlayers'],
                          noise_radius=ae_args['noise_radius'],
                          hidden_init=ae_args['hidden_init'],
                          dropout=ae_args['dropout'],
                          gpu=args.cuda)
    autoencoder.load_state_dict(state_dict)
    # The autoencoder is only used as a fixed encoder; never update it.
    for param in autoencoder.parameters():
        param.requires_grad = False
    # save arguments
    with open(os.path.join(out_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f)
    log.info('[Data and AE model loaded.]')

    gan_gen = MLP_G(ninput=args.nhidden,
                    noutput=args.nhidden,
                    layers=args.arch_g)
    # The critic takes a (source code, target code) pair, hence 2 * nhidden
    # inputs (presumably concatenated inside MLP_D -- confirm there).
    gan_disc = MLP_D(ninput=2 * args.nhidden, noutput=1, layers=args.arch_d)
    optimizer_gan_g = optim.Adam(gan_gen.parameters(),
                                 lr=args.lr_gan_g,
                                 betas=(args.beta1, 0.999))
    optimizer_gan_d = optim.Adam(gan_disc.parameters(),
                                 lr=args.lr_gan_d,
                                 betas=(args.beta1, 0.999))

    if args.cuda:
        autoencoder = autoencoder.cuda()
        gan_gen = gan_gen.cuda()
        gan_disc = gan_disc.cuda()

    # Gradient seeds for backward(). Seeding with +1 makes the optimizer
    # step decrease that score; seeding with -1 makes it increase it.
    one = to_gpu(args.cuda, torch.FloatTensor([1]))
    mone = one * -1
    train_pairs = BatchGen(corpus.get_chunks(size=2), args.batch_size)

    def train_gan_g(batch):
        """Run one generator update on ``batch`` and return its critic score."""
        gan_gen.train()
        gan_gen.zero_grad()

        source, _ = batch
        source = to_gpu(args.cuda, Variable(source))
        source_hidden = autoencoder(source, noise=False, encode_only=True)

        fake_hidden = gan_gen(source_hidden)
        errG = gan_disc(source_hidden, fake_hidden)

        # loss / backprop: push the critic's score for fake pairs down.
        # Gradients also land in gan_disc, but train_gan_d zeroes them
        # before its own step.
        errG.backward(one)
        optimizer_gan_g.step()

        return errG

    def train_gan_d(batch):
        """Run one critic update on ``batch``.

        Returns:
            (errD, errD_real, errD_fake), where errD = errD_fake - errD_real.
        """
        # clamp parameters to a cube (WGAN-style weight clipping)
        for p in gan_disc.parameters():
            p.data.clamp_(-args.gan_clamp, args.gan_clamp)

        gan_disc.train()
        gan_disc.zero_grad()

        # positive samples ----------------------------
        # generate real codes
        source, target = batch
        source = to_gpu(args.cuda, Variable(source))
        target = to_gpu(args.cuda, Variable(target))

        # batch_size x nhidden (assumed from the encoder; confirm in Seq2Seq)
        source_hidden = autoencoder(source, noise=False, encode_only=True)
        target_hidden = autoencoder(target, noise=False, encode_only=True)

        # loss / backprop: the critic pushes real-pair scores down ...
        errD_real = gan_disc(source_hidden, target_hidden)
        errD_real.backward(one)

        # negative samples ----------------------------

        # ... and fake-pair scores up. Detaching keeps the generator out of
        # this update.
        fake_hidden = gan_gen(source_hidden)
        errD_fake = gan_disc(source_hidden.detach(), fake_hidden.detach())
        errD_fake.backward(mone)

        optimizer_gan_d.step()
        errD = -(errD_real - errD_fake)

        return errD, errD_real, errD_fake

    niter = 0
    start_time = datetime.now()

    # NOTE(review): if args.niters_gan_d or args.niters_gan_g is 0, the loss
    # names used in the logging branch below are never bound, and logging
    # raises NameError.
    for t in range(args.updates):
        niter += 1

        # train discriminator/critic
        for i in range(args.niters_gan_d):
            errD, errD_real, errD_fake = \
                train_gan_d(next(train_pairs))

        # train generator
        for i in range(args.niters_gan_g):
            errG = train_gan_g(next(train_pairs))

        if niter % args.log_interval == 0:
            # Linear ETA from the mean time per update so far, without the
            # fractional seconds.
            eta = str((datetime.now() - start_time) / (t + 1) *
                      (args.updates - t - 1)).split('.')[0]
            log.info('[{}/{}] Loss_D: {:.6f} (real: {:.6f} '
                     'fake: {:.6f}) Loss_G: {:.6f} ETA: {}'.format(
                         niter, args.updates,
                         errD.data.cpu()[0],
                         errD_real.data.cpu()[0],
                         errD_fake.data.cpu()[0],
                         errG.data.cpu()[0], eta))
        if niter % args.save_interval == 0:
            # Checkpoint names use the 0-based index t (= niter - 1).
            save_model(gan_gen, out_dir, 'gan_gen_model_{}.pt'.format(t))
            save_model(gan_disc, out_dir, 'gan_disc_model_{}.pt'.format(t))