def train_vae(model, train_iter, valid_iter, tgtvocab, optim):
    """Train a VAE model epoch by epoch: train, validate, log, step LR, checkpoint.

    NOTE(review): besides its parameters, this function reads these names from
    module scope: ``opt`` (parsed options: ``start_epoch``, ``epochs``,
    ``exp_host``, ``start_checkpoint_at``), ``use_gpu``, ``report_func``,
    ``Loss``, ``Trainer`` and, when ``opt.exp_host`` is truthy,
    ``experiment``. ``main()`` in this file binds ``opt`` only as a local
    variable, so a module-level ``opt`` must be provided elsewhere before
    calling this -- TODO confirm.

    Args:
        model: the VAE model; must expose ``generator`` (used to build the
            losses) and ``encoder.Varianceanneal()``.
        train_iter: training batch iterator, handed to ``Trainer.VaeTrainer``.
        valid_iter: validation batch iterator, handed to ``Trainer.VaeTrainer``.
        tgtvocab: target vocabulary passed to ``Loss.VAELoss``.
        optim: optimizer wrapper; its ``lr`` attribute is logged each epoch.
    """
    train_loss = Loss.VAELoss(model.generator, tgtvocab)
    valid_loss = Loss.VAELoss(model.generator, tgtvocab)
    if use_gpu(opt):
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()

    # NOTE(review): opt.truncated_decoder and opt.max_generator_batches were
    # previously read into locals here but never passed to the trainer, so
    # truncated BPTT / generator sharding are not applied by this loop.
    trainer = Trainer.VaeTrainer(model, train_iter, valid_iter,
                                 train_loss, valid_loss, optim)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(opt, epoch, valid_stats)

        # 6. Advance the VAE annealing schedules once per epoch.
        # NOTE(review): the original indentation was lost; these calls are
        # placed at loop level (every epoch, independent of checkpointing)
        # because they take ``epoch`` -- confirm against the intended schedule.
        train_loss.VAE_weightaneal(epoch)
        valid_loss.VAE_weightaneal(epoch)
        model.encoder.Varianceanneal()
def _load_vocabs(rebuild_vocab):
    """Return ``(src_vocab, tgt_vocab, tgtwords)``, rebuilt from JSON or loaded from disk.

    When ``rebuild_vocab`` is true, the vocabularies are built with
    ``hierdata.buildvocab`` from the ``context`` / ``replies`` columns of the
    two conversation JSON files. Otherwise they are read from
    ``test_vocabs.pt``.
    """
    if rebuild_vocab:
        # NOTE(review): the "training" file here is conv-test_v.json -- confirm intended.
        trainfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-test_v.json'
        train = pd.read_json(trainfile)
        print('Read training data from: {}'.format(trainfile))
        valfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-val_v.json'
        val = pd.read_json(valfile)
        print('Read validation data from: {}'.format(valfile))
        train_srs = train.context.values.tolist()
        train_tgt = train.replies.values.tolist()
        val_srs = val.context.values.tolist()
        val_tgt = val.replies.values.tolist()
        src_vocab, _ = hierdata.buildvocab(train_srs + val_srs)
        tgt_vocab, tgtwords = hierdata.buildvocab(train_tgt + val_tgt)
        return src_vocab, tgt_vocab, tgtwords

    print('load vocab from pt file')
    # torch.load unpickles the file: only load vocab files from trusted sources.
    dicts = torch.load('test_vocabs.pt')
    return dicts['src_word2id'], dicts['tgt_word2id'], dicts['tgt_id2word']


def _print_vocab_checks(src_vocab, tgt_vocab, tgtwords):
    """Print vocab sizes and round-trip lookups of a few known words as a sanity check."""
    print('source vocab size: {}'.format(len(src_vocab)))
    print('source vocab test, bill: {} , {}'.format(
        src_vocab['<pad>'], src_vocab['bill']))
    print('target vocab size: {}'.format(len(tgt_vocab)))
    print('target vocab test, bill: {}, {}'.format(tgt_vocab['<pad>'], tgt_vocab['bill']))
    print('target vocat testing:')
    print('word: <pad> get :{}'.format(tgtwords[tgt_vocab['<pad>']]))
    print('word: bill get :{}'.format(tgtwords[tgt_vocab['bill']]))
    print('word: service get :{}'.format(tgtwords[tgt_vocab['service']]))


def main(rebuild_vocab=False):
    """Load vocabularies, build the hierarchical VAE model and validate it once.

    Args:
        rebuild_vocab: if true, rebuild the vocabularies from the conversation
            JSON files instead of loading ``test_vocabs.pt`` (default False,
            matching the previously hard-coded behaviour).
    """
    src_vocab, tgt_vocab, tgtwords = _load_vocabs(rebuild_vocab)
    _print_vocab_checks(src_vocab, tgt_vocab, tgtwords)

    parser = argparse.ArgumentParser(description='train.py')
    # Option definitions live in opts.py.
    opts.add_md_help_argument(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    opt = parser.parse_args()
    # A former ``parser.parse_known_args([])`` call was dropped: its result was
    # unused, and re-parsing an empty argv exits if any option is required.

    opt.cuda = opt.gpuid[0] > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpuid[0])

    # NOTE(review): opt.model is passed straight through as the checkpoint
    # argument -- confirm make_base_model expects this value (path vs. dict).
    checkpoint = opt.model
    print('Building model...')
    model = ModelHVAE.make_base_model(
        opt, src_vocab, tgt_vocab, opt.cuda, checkpoint
    )
    # TODO: decide how to integrate the two embedding layers.
    print(model)
    tally_parameters(model)

    testfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-val_v.json'
    test = pd.read_json(testfile)
    print('Test training data from: {}'.format(testfile))
    test_srs = test.context.values.tolist()
    test_tgt = test.replies.values.tolist()
    test_batch_size = 16
    test_iter = data_util.gen_minibatch(test_srs, test_tgt, test_batch_size,
                                        src_vocab, tgt_vocab)

    optim = Optim.Optim('adam', 1e-3, 5)
    train_loss = Loss.VAELoss(model.generator, tgt_vocab)
    valid_loss = Loss.VAELoss(model.generator, tgt_vocab)
    if opt.cuda:
        # make_base_model received opt.cuda (presumably placing the model on
        # the GPU), so the loss modules must live on the same device.
        train_loss = train_loss.cuda()
        valid_loss = valid_loss.cuda()

    # Only validation is run, so the same iterator serves both roles.
    trainer = Trainer.VaeTrainer(model, test_iter, test_iter,
                                 train_loss, valid_loss, optim)
    valid_stats = trainer.validate()
    print('Validation perplexity: %g' % valid_stats.ppl())
    print('Validation accuracy: %g' % valid_stats.accuracy())