Example 1
def train_vae(model, train_iter, valid_iter, tgtvocab, optim):
    """Run the VAE training loop: one train/validate cycle per epoch, with
    logging, learning-rate updates, checkpointing, and annealing.

    Note: `opt` (parsed options), `report_func` (progress callback), and
    `experiment` (remote logging handle) are module-level globals here.
    """
    # Separate loss instances for the training and validation splits.
    train_loss = Loss.VAELoss(model.generator, tgtvocab)
    valid_loss = Loss.VAELoss(model.generator, tgtvocab)

    # Move the loss modules to the GPU when one is available.
    if use_gpu(opt):
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()

    # NOTE: computed here but not used in this function.
    trunc_size = opt.truncated_decoder  # truncated-BPTT window (default 0: disabled)
    shard_size = opt.max_generator_batches  # generator shard size (default 32)

    trainer = Trainer.VaeTrainer(model, train_iter, valid_iter, train_loss,
                                 valid_loss, optim)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')  # blank line between epoch reports

        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(epoch, report_func)

        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(opt, epoch, valid_stats)

        # 6. Anneal the KL weight and the encoder variance for the next epoch.
        train_loss.VAE_weightaneal(epoch)
        valid_loss.VAE_weightaneal(epoch)
        model.encoder.Varianceanneal()
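
The two annealing calls above point to KL cost annealing, a standard trick for text VAEs: the weight on the KL term starts near zero so the decoder first learns to use the latent code, then ramps up toward its full value. A minimal sketch of such a schedule; the names and constants (`kl_weight`, `anneal_rate`, the epoch offset) are illustrative assumptions, not taken from this codebase:

import math

def kl_weight(epoch, anneal_rate=0.25, offset=10, max_weight=1.0):
    # Sigmoid ramp: close to 0 for early epochs, saturating at max_weight.
    return max_weight / (1.0 + math.exp(-anneal_rate * (epoch - offset)))

def vae_loss(recon_loss, kl_div, epoch):
    # Annealed ELBO-style objective: reconstruction term plus weighted KL.
    return recon_loss + kl_weight(epoch) * kl_div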
Example 2
def main():
    # Assumed context for this snippet: argparse, pandas as pd, torch, and the
    # project modules (opts, hierdata, data_util, ModelHVAE, Loss, Trainer,
    # Optim, tally_parameters) are imported at module level.
    # Toggle: rebuild the vocab from the raw JSON or load a cached .pt file.
    rebuild_vocab = False
    if rebuild_vocab:
        trainfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-test_v.json'
        train = pd.read_json(trainfile)
        print('Read training data from: {}'.format(trainfile))

        valfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-val_v.json'
        val = pd.read_json(valfile)
        print('Read validation data from: {}'.format(valfile))
        train_srs = train.context.values.tolist()
        train_tgt = train.replies.values.tolist()
        val_srs = val.context.values.tolist()
        val_tgt = val.replies.values.tolist()
        src_vocab, _ = hierdata.buildvocab(train_srs + val_srs)
        tgt_vocab, tgtwords = hierdata.buildvocab(train_tgt + val_tgt)

    else:
        print('Loading vocab from test_vocabs.pt')
        dicts = torch.load('test_vocabs.pt')
        src_vocab = dicts['src_word2id']
        tgt_vocab = dicts['tgt_word2id']
        tgtwords = dicts['tgt_id2word']
        print('source vocab size: {}'.format(len(src_vocab)))
        print('source vocab check, <pad>/bill ids: {}, {}'.format(
            src_vocab['<pad>'], src_vocab['bill']))
        print('target vocab size: {}'.format(len(tgt_vocab)))
        print('target vocab check, <pad>/bill ids: {}, {}'.format(
            tgt_vocab['<pad>'], tgt_vocab['bill']))
        print('target vocab round-trip test:')
        print('word: <pad> -> {}'.format(tgtwords[tgt_vocab['<pad>']]))
        print('word: bill -> {}'.format(tgtwords[tgt_vocab['bill']]))
        print('word: service -> {}'.format(tgtwords[tgt_vocab['service']]))

    parser = argparse.ArgumentParser(description='train.py')

    # opts.py
    opts.add_md_help_argument(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    opt = parser.parse_args()

    dummy_opt = parser.parse_known_args([])[0]  # default options; unused below

    # A gpuid of -1 means run on the CPU.
    opt.cuda = opt.gpuid[0] > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpuid[0])

    checkpoint = opt.model  # optional checkpoint to resume from
    print('Building model...')
    model = ModelHVAE.make_base_model(opt, src_vocab, tgt_vocab, opt.cuda,
                                      checkpoint)
    print(model)
    tally_parameters(model)

    # NOTE: the validation split is reused here as the test set.
    testfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-val_v.json'
    test = pd.read_json(testfile)
    print('Read test data from: {}'.format(testfile))

    test_srs = test.context.values.tolist()
    test_tgt = test.replies.values.tolist()

    test_batch_size = 16
    test_iter = data_util.gen_minibatch(test_srs, test_tgt, test_batch_size,
                                        src_vocab, tgt_vocab)

    tgtvocab = tgt_vocab  # alias kept for the call signatures below

    # OpenNMT-style Optim: (method, learning rate, max gradient norm).
    optim = Optim.Optim('adam', 1e-3, 5)
    train_loss = Loss.VAELoss(model.generator, tgtvocab)
    valid_loss = Loss.VAELoss(model.generator, tgtvocab)
    # The test iterator fills both the train and valid slots; only
    # validate() is called below, so no training pass runs.
    trainer = Trainer.VaeTrainer(model, test_iter, test_iter, train_loss,
                                 valid_loss, optim)
    valid_stats = trainer.validate()
    print('Validation perplexity: %g' % valid_stats.ppl())
    print('Validation accuracy: %g' % valid_stats.accuracy())
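
For reference, the `test_vocabs.pt` file loaded above evidently holds three objects keyed `src_word2id`, `tgt_word2id`, and `tgt_id2word`. A minimal sketch of how such a file could be produced, assuming (as the rebuild branch in `main()` suggests) that the project-local `hierdata.buildvocab` returns a word-to-id mapping plus an id-to-word lookup:

import torch

def save_vocabs(train_srs, train_tgt, val_srs, val_tgt,
                path='test_vocabs.pt'):
    # Build vocabularies over the combined train + validation text,
    # mirroring the rebuild_vocab branch in main().
    src_vocab, _ = hierdata.buildvocab(train_srs + val_srs)
    tgt_vocab, tgtwords = hierdata.buildvocab(train_tgt + val_tgt)

    # Store under the same keys the loading branch expects.
    torch.save({'src_word2id': src_vocab,
                'tgt_word2id': tgt_vocab,
                'tgt_id2word': tgtwords}, path)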