Example #1
def _get_parser():
    parser = ArgumentParser(description='train.py')
    # Construct config
    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser
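
The helpers called above (opts.config_opts, opts.model_opts, opts.train_opts) each register one group of related flags on the shared parser. A minimal, self-contained sketch of that pattern, using simplified stand-in option groups rather than the real OpenNMT-style opts module:

from argparse import ArgumentParser


def model_opts(parser):
    # hypothetical stand-in for the real model option group
    parser.add_argument('-word_vec_size', type=int, default=500)
    parser.add_argument('-layers', type=int, default=2)


def train_opts(parser):
    # hypothetical stand-in for the real training option group
    parser.add_argument('-batch_size', type=int, default=64)
    parser.add_argument('-epochs', type=int, default=13)


if __name__ == '__main__':
    parser = ArgumentParser(description='train.py')
    model_opts(parser)
    train_opts(parser)
    opt = parser.parse_args()
    print(opt)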
Example #2
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    # engine = DBEngine(opt.db_file)

    with codecs.open(opt.source_file, "r", "utf-8") as corpus_file:
        sql_list = [json.loads(line)['sql'] for line in corpus_file]

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):
        print(fn_model)
        print(opt.anno)
        opt.model = fn_model

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
Example #3
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    engine = DBEngine(opt.db_file)

    with codecs.open(opt.source_file, "r", "utf-8") as corpus_file:
        sql_list = [json.loads(line)['sql'] for line in corpus_file]

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):

        opt.model = fn_model

        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        if opt.beam_search:
            print('Using execution guidance for inference.')
        r_list = []

        for batch in test_data:
            r_list += translator.translate(batch, js_list, sql_list)

        r_list.sort(key=lambda x: x.idx)

        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        # evaluation
        for pred, gold, sql_gold in zip(r_list, js_list, sql_list):
            pred.eval(gold, sql_gold, engine)
        print('Results:')
        for metric_name in ('all', 'exe'):
            c_correct = sum((x.correct[metric_name] for x in r_list))
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list),
                                                c_correct / len(r_list)))
            if metric_name == 'all' and (prev_best[0] is None
                                         or c_correct > prev_best[1]):
                prev_best = (fn_model, c_correct)

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.data_path, 'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
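
Several of these examples build a second, throwaway parser and call parse_known_args([]). Parsing an empty argument list fills a Namespace with nothing but the registered defaults, without touching sys.argv; the resulting dummy_opt.__dict__ is then handed to the Translator as a plain dict of default model and training options. A minimal sketch, using a hypothetical stand-in flag instead of the real opts module:

import argparse

# Stand-in for opts.model_opts / opts.train_opts: a single hypothetical flag.
dummy_parser = argparse.ArgumentParser(description='train.py')
dummy_parser.add_argument('-rnn_size', type=int, default=500)

# An empty argv yields only the registered defaults; unknown flags would be ignored.
dummy_opt = dummy_parser.parse_known_args([])[0]
print(dummy_opt.__dict__)  # {'rnn_size': 500}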
Example #4
def main():
    parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(parser)
    opts.train_opts(parser)
    opts.data_opts(parser)
    opts.score_opts(parser)
    add_md_help_argument(parser)
    options = parser.parse_args()
Example #5
File: svr.py Project: marcwww/LL
def load_opt():
    parser = argparse.ArgumentParser(
        description='main.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.model_opts(parser)
    opts.train_opts(parser)
    opt = parser.parse_args()
    return opt
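
This variant differs from the earlier parsers mainly in its formatter_class: ArgumentDefaultsHelpFormatter appends each option's default value to the --help text. A tiny illustration with a hypothetical stand-in flag:

import argparse

parser = argparse.ArgumentParser(
    description='main.py',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-batch_size', type=int, default=64,
                    help='examples per batch')  # stand-in flag

parser.print_help()  # the help line for -batch_size now ends with "(default: 64)"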
Example #6
def main(anno_file_name, col_headers, raw_args=None, verbose=True):
    parser = argparse.ArgumentParser(description='evaluate.py')
    opts.translate_opts(parser)
    opt = parser.parse_args(raw_args)
    torch.cuda.set_device(opt.gpu)
    opt.db_file = os.path.join(opt.data_path, '{}.db'.format(opt.split))
    opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding')
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    opt.anno = anno_file_name

    engine = DBEngine(opt.db_file)

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    sql_query = []
    for fn_model in glob.glob(opt.model_path):

        opt.model = fn_model

        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        pred = r_list[-1]
        sql_pred = {
            'agg': pred.agg,
            'sel': pred.sel,
            'conds': pred.recover_cond_to_gloss(js_list[-1])
        }
        if verbose:
            print('\n sql_pred: ', sql_pred, '\n')
            print('\n col_headers: ', col_headers, '\n')
        sql_query = Query(sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
        try:
            ans_pred = engine.execute_query(js_list[-1]['table_id'],
                                            Query.from_dict(sql_pred),
                                            lower=True,
                                            verbose=verbose)
        except Exception as e:
            ans_pred = None
    return sql_query.get_complete_query(col_headers), ans_pred
Example #7
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    js_list = table.IO.read_anno_json(opt.anno, opt)

    metric_name_list = ['tgt']
    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):
        opt.model = fn_model
        print(fn_model)
        print(opt.anno)

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, 0, None,
                                     False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r = translator.translate(batch)
            r_list += r
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        # evaluation
        for pred, gold in zip(r_list, js_list):
            pred.eval(gold)
        print('Results:')
        for metric_name in metric_name_list:
            c_correct = sum((x.correct[metric_name] for x in r_list))
            acc = c_correct / len(r_list)
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list), acc))
            if metric_name == 'tgt' and (prev_best[0] is None
                                         or acc > prev_best[1]):
                prev_best = (fn_model, acc)

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.root_dir, opt.dataset,
                                      'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
Example #8
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    js_list = table.IO.read_anno_json(opt.anno, opt)
    # metric_name_list = ['tgt']
    prev_best = (None, None)
    # print(opt.model_path)
    for fn_model in glob.glob(opt.model_path):
        opt.model = fn_model
        print(fn_model)
        with torch.no_grad():
            translator = table.Translator(opt, dummy_opt.__dict__)
            data = table.IO.TableDataset(js_list, translator.fields, 0, None,
                                         False)
            test_data = table.IO.OrderedIterator(dataset=data,
                                                 device=opt.gpu,
                                                 batch_size=opt.batch_size,
                                                 train=False,
                                                 sort=True,
                                                 sort_within_batch=False)
            # inference
            r_list = []
            for batch in test_data:
                r = translator.translate(batch)
                r_list += r

        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        metric, _ = com_metric(js_list, r_list)
    if opt.split == 'test':
        ref_dic, pre_dict = effect_len(js_list, r_list)
        for i in range(len(ref_dic)):
            js_list = ref_dic[i]
            r_list = pre_dict[i]
            print("the effect of length {}".format(i))
            metric, _ = com_metric(js_list, r_list)

        if prev_best[0] is None or float(metric['Bleu_1']) > prev_best[1]:
            prev_best = (fn_model, metric['Bleu_1'])

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.root_dir, opt.dataset,
                                      'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
Example #9
def parse_args():
    parser = argparse.ArgumentParser(
        description='umt.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    opts.add_md_help_argument(parser)
    opts.model_opts(parser)
    opts.preprocess_opts(parser)
    opts.train_opts(parser)

    opt = parser.parse_args()
    torch.manual_seed(opt.seed)

    if opt.word_vec_size != -1:
        opt.src_word_vec_size = opt.word_vec_size
        opt.tgt_word_vec_size = opt.word_vec_size

    if opt.layers != -1:
        opt.enc_layers = opt.layers
        opt.dec_layers = opt.layers

    opt.brnn = (opt.encoder_type == "brnn")

    # if opt.seed > 0:
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)

    if torch.cuda.is_available() and not opt.gpuid:
        print("WARNING: You have a CUDA device, should run with -gpuid 0")

    if opt.gpuid:
        cuda.set_device(opt.gpuid[0])
        if opt.seed > 0:
            torch.cuda.manual_seed(opt.seed)

    if len(opt.gpuid) > 1:
        sys.stderr.write("Sorry, multigpu isn't supported yet, coming soon!\n")
        sys.exit(1)

    # Set up the Crayon logging server.
    if opt.exp_host != "":
        from pycrayon import CrayonClient

        cc = CrayonClient(hostname=opt.exp_host)

        experiments = cc.get_experiment_names()
        print(experiments)
        if opt.exp in experiments:
            cc.remove_experiment(opt.exp)

    return opt
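
The seeding logic above follows a common reproducibility pattern: seed Python's random module, seed PyTorch on the CPU, and, once a GPU has been selected, seed CUDA as well. A condensed, self-contained sketch (the seed value and device index are illustrative):

import random

import torch

SEED = 1234  # illustrative value

random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.set_device(0)      # assumed device index
    torch.cuda.manual_seed(SEED)  # seeds the current CUDA device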
Example #10
def main():
    parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(parser)
    opts.train_opts(parser)
    opts.data_opts(parser)
    opts.score_opts(parser)
    options = parser.parse_args()

    print(options)

    argfile = options.save_model + '_arg.p'

    print('Saving arguments in ' + argfile)
    pickle.dump(options, open(argfile, "wb"))

    train(options)
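
Dumping the parsed options next to the model, as done above, records the exact configuration of a run; the Namespace can later be unpickled to restore it. A small sketch of the save/restore round trip (the option values and file name are illustrative):

import argparse
import pickle

options = argparse.Namespace(save_model='demo', batch_size=64)  # stand-in options

argfile = options.save_model + '_arg.p'
with open(argfile, 'wb') as f:
    pickle.dump(options, f)

with open(argfile, 'rb') as f:
    restored = pickle.load(f)
print(restored.batch_size)  # 64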
Example #11
import argparse
import glob
import random

import torch
from torch import cuda

import opts

print(torch.cuda.is_available())
print(cuda.device_count())
print(cuda.current_device())

parser = argparse.ArgumentParser(
    description='train.py',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# opts.py
opts.add_md_help_argument(parser)
opts.model_opts(parser)
opts.train_opts(parser)

opt = parser.parse_args()
if opt.word_vec_size != -1:
    opt.src_word_vec_size = opt.word_vec_size
    opt.tgt_word_vec_size = opt.word_vec_size

if opt.layers != -1:
    opt.enc_layers = opt.layers
    opt.dec_layers = opt.layers

opt.brnn = (opt.encoder_type == "brnn")
if opt.seed > 0:
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
Example #12
import argparse
import unittest
import math

import torch
from torch.autograd import Variable

import onmt
import onmt.io
import opts
from onmt.ModelConstructor import make_embeddings, \
                            make_encoder, make_decoder
from onmt.modules import ImageEncoder, AudioEncoder

parser = argparse.ArgumentParser(description='train.py')
opts.model_opts(parser)
opts.train_opts(parser)

# -data option is required, but not used in this test, so dummy.
opt = parser.parse_known_args(['-data', 'dummy'])[0]


class TestModel(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        super(TestModel, self).__init__(*args, **kwargs)
        self.opt = opt

    # Helper to generate a vocabulary

    def get_vocab(self):
        src = onmt.io.get_fields("text", 0, 0)["src"]
Example #13
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    js_list = table.IO.read_anno_json(opt.anno, opt)

    metric_name_list = ['tgt']
    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):
        opt.model = fn_model
        print(fn_model)
        print(opt.anno)

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, 0, None,
                                     False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r = translator.translate(batch)
            r_list += r
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        # evaluation
        for pred, gold in zip(r_list, js_list):
            print("pred tgt: ", pred.tgt)
            print("pred lay: ", pred.lay)
            print("gold:", gold)

            pred.eval(gold)
        print('Results:')
        for metric_name in metric_name_list:
            c_correct = sum((x.correct[metric_name] for x in r_list))
            acc = c_correct / len(r_list)
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list), acc))
            if metric_name == 'tgt' and (prev_best[0] is None
                                         or acc > prev_best[1]):
                prev_best = (fn_model, acc)

        # calculate bleu score
        pred_tgt_tokens = [pred.tgt for pred in r_list]
        gold_tgt_tokens = [gold['tgt'] for gold in js_list]
        # print('pred_tgt_tokens[0]', pred_tgt_tokens[0])
        # print('gold_tgt_tokens[0]', gold_tgt_tokens[0])
        bleu_score = table.modules.bleu_score.compute_bleu(gold_tgt_tokens,
                                                           pred_tgt_tokens,
                                                           smooth=False)
        bleu_score = bleu_score[0]

        bleu_score_nltk = corpus_bleu(
            gold_tgt_tokens,
            pred_tgt_tokens,
            smoothing_function=SmoothingFunction().method3)

        print('{}: = {:.4}'.format('tgt bleu score', bleu_score))

        print('{}: = {:.4}'.format('tgt nltk bleu score', bleu_score_nltk))

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.root_dir, opt.dataset,
                                      'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
Example #14
def add_cmdline_args(argparser):
    # opts.py
    opts.add_md_help_argument(argparser)
    opts.model_opts(argparser)
    opts.train_opts(argparser)
    opt = argparser.parse_args()
Example #15
def main():
    rebuild_vocab = False
    if rebuild_vocab:
        trainfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-test_v.json'
        train = pd.read_json(trainfile)
        print('Read training data from: {}'.format(trainfile))

        valfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-val_v.json'
        val = pd.read_json(valfile)
        print('Read validation data from: {}'.format(valfile))
        train_srs = train.context.values.tolist()
        train_tgt = train.replies.values.tolist()
        val_srs = val.context.values.tolist()
        val_tgt = val.replies.values.tolist()
        src_vocab, _ = hierdata.buildvocab(train_srs + val_srs)
        tgt_vocab, tgtwords = hierdata.buildvocab(train_tgt + val_tgt)

    else:
        print('load vocab from pt file')
        dicts = torch.load('test_vocabs.pt')
        #tgt = pd.read_json('./tgt.json')
        #src = pd.read_json('./src.json')
        src_vocab = dicts['src_word2id']
        tgt_vocab = dicts['tgt_word2id']
        tgtwords = dicts['tgt_id2word']
        print('source vocab size: {}'.format(len(src_vocab)))
        print('source vocab test, bill: {} , {}'.format(
            src_vocab['<pad>'], src_vocab['bill']))
        print('target vocab size: {}'.format(len(tgt_vocab)))
        print('target vocab test, bill: {}, {}'.format(tgt_vocab['<pad>'],
                                                       tgt_vocab['bill']))
        print('target vocab testing:')
        print('word: <pad> get :{}'.format(tgtwords[tgt_vocab['<pad>']]))
        print('word: bill get :{}'.format(tgtwords[tgt_vocab['bill']]))
        print('word: service get :{}'.format(tgtwords[tgt_vocab['service']]))

    parser = argparse.ArgumentParser(description='train.py')

    # opts.py
    opts.add_md_help_argument(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    opt = parser.parse_args()

    dummy_opt = parser.parse_known_args([])[0]

    opt.cuda = opt.gpuid[0] > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpuid[0])

    checkpoint = opt.model
    print('Building model...')
    model = ModelHVAE.make_base_model(
        opt, src_vocab, tgt_vocab, opt.cuda, checkpoint
    )  ### Done  #### How to integrate the two embedding layers...
    print(model)
    tally_parameters(model)  ### Done

    testfile = '/D/home/lili/mnt/DATA/convaws/convdata/conv-val_v.json'
    test = pd.read_json(testfile)
    print('Test training data from: {}'.format(testfile))

    test_srs = test.context.values.tolist()
    test_tgt = test.replies.values.tolist()

    test_batch_size = 16
    test_iter = data_util.gen_minibatch(test_srs, test_tgt, test_batch_size,
                                        src_vocab, tgt_vocab)

    tgtvocab = tgt_vocab

    optim = Optim.Optim('adam', 1e-3, 5)
    train_loss = Loss.VAELoss(model.generator, tgtvocab)
    valid_loss = Loss.VAELoss(model.generator, tgtvocab)
    trainer = Trainer.VaeTrainer(model, test_iter, test_iter, train_loss,
                                 valid_loss, optim)
    valid_stats = trainer.validate()
    print('Validation perplexity: %g' % valid_stats.ppl())
    print('Validation accuracy: %g' % valid_stats.accuracy())
Example #16
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    engine = DBEngine(opt.db_file)

    with codecs.open(opt.source_file, "r", "utf-8") as corpus_file:
        sql_list = [json.loads(line)['sql'] for line in corpus_file]

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    print(opt.split, opt.model_path)

    num_models = 0

    f_out = open('Two-stream-' + opt.unseen_table + '-out-case', 'w')

    for fn_model in glob.glob(opt.model_path):
        num_models += 1
        sys.stdout.flush()
        print(fn_model)
        print(opt.anno)
        opt.model = fn_model

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        #torch.save(data, open( 'data.pt', 'wb'))
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))
        # evaluation
        error_cases = []
        for pred, gold, sql_gold in zip(r_list, js_list, sql_list):
            error_cases.append(pred.eval(opt.split, gold, sql_gold, engine))
#            error_cases.append(pred.eval(opt.split, gold, sql_gold))
        print('Results:')
        for metric_name in ('all', 'exe', 'agg', 'sel', 'where', 'col', 'span',
                            'lay', 'BIO', 'BIO_col'):
            c_correct = sum((x.correct[metric_name] for x in r_list))
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list),
                                                c_correct / len(r_list)))
            if metric_name == 'all':
                all_acc = c_correct
            if metric_name == 'exe':
                exe_acc = c_correct
        if prev_best[
                0] is None or all_acc + exe_acc > prev_best[1] + prev_best[2]:
            prev_best = (fn_model, all_acc, exe_acc)

#        random.shuffle(error_cases)
        for error_case in error_cases:
            if len(error_case) == 0:
                continue
            json.dump(error_case, f_out)
            f_out.write('\n')


#            print('table_id:\t', error_case['table_id'])
#            print('question_id:\t',error_case['question_id'])
#            print('question:\t', error_case['question'])
#            print('table_head:\t', error_case['table_head'])
#            print('table_content:\t', error_case['table_content'])
#            print()

#            print(error_case['BIO'])
#            print(error_case['BIO_col'])
#            print()

#            print('gold:','agg:',error_case['gold']['agg'],'sel:',error_case['predict']['sel'])
#            for i in range(len(error_case['gold']['conds'])):
#                print(error_case['gold']['conds'][i])

#           print('predict:','agg:',error_case['predict']['agg'],'sel:',error_case['predict']['sel'])
#           for i in range(len(error_case['predict']['conds'])):
#               print(error_case['predict']['conds'][i])
#           print('\n\n')

    print(prev_best)
    if (opt.split == 'dev') and (prev_best[0] is not None) and num_models != 1:
        if opt.unseen_table == 'full':
            with codecs.open(os.path.join(opt.save_path, 'dev_best.txt'),
                             'w',
                             encoding='utf-8') as f_out:
                f_out.write('{}\n'.format(prev_best[0]))
        else:
            with codecs.open(os.path.join(
                    opt.save_path, 'dev_best_' + opt.unseen_table + '.txt'),
                             'w',
                             encoding='utf-8') as f_out:
                f_out.write('{}\n'.format(prev_best[0]))
Example #17
def trainModel(model, trainData, validData, dataset, optim, stats, opt):
    print(model)

    # define criterion of each GPU
    criterion = NMTCriterion(dataset['dicts']['tgt'].size(), opt.gpus)
    translator = onmt.Translator(opt)
    lm = kenlm.Model(opt.lm_path)

    start_time = time.time()

    def trainEpoch(epoch):

        model.train()

        if opt.extra_shuffle and epoch > opt.curriculum:
            trainData.shuffle()

        # shuffle mini batch order
        batchOrder = torch.randperm(len(trainData))

        total_loss, total_KLD, total_KLD_obj, total_words, total_num_correct = 0, 0, 0, 0, 0
        report_loss, report_KLD, report_KLD_obj, report_tgt_words, report_src_words, report_num_correct = 0, 0, 0, 0, 0, 0
        start = time.time()
        for i in range(len(trainData)):

            total_step = epoch * len(trainData) + i
            batchIdx = batchOrder[i] if epoch > opt.curriculum else i
            batch = trainData[batchIdx][:-1] # exclude original indices

            model.zero_grad()
            outputs, mu, logvar = model(batch, total_step)
            targets = batch[1][1:]  # exclude <s> from targets
            _memoryEfficientLoss = memoryEfficientLoss(opt.max_generator_batches)
            loss, gradOutput, num_correct = _memoryEfficientLoss(
                    outputs, targets, criterion)

            KLD, KLD_obj = KLDLoss(opt.kl_min)(mu, logvar)
            if opt.k != 0:
                kl_rate = 1 / (1 + opt.k * math.exp(-total_step/opt.k))
            else:
                kl_rate = 1
            KLD_obj = kl_rate * KLD_obj

            elbo = KLD_obj + loss
            elbo.backward()

            # update the parameters
            optim.step()

            num_words = targets.data.ne(onmt.Constants.PAD).sum()
            report_loss += loss.data[0]
            report_KLD += KLD.data[0]
            report_KLD_obj += KLD_obj.data[0]
            report_num_correct += num_correct
            report_tgt_words += num_words
            report_src_words += sum(batch[0][1])
            total_loss += loss.data[0]
            total_KLD += KLD.data[0]
            total_KLD_obj += KLD_obj.data[0]
            total_num_correct += num_correct
            total_words += num_words
            stats['kl_rate'].append(kl_rate)
            if i % opt.log_interval == -1 % opt.log_interval:
                print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; KLD: %6.2f; KLD obj: %6.2f; kl rate: %2.6f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
                      (epoch, i+1, len(trainData),
                      report_num_correct / report_tgt_words * 100,
                      math.exp(report_loss / report_tgt_words),
                      report_KLD / report_tgt_words,
                      report_KLD_obj / report_tgt_words,
                      kl_rate,
                      report_src_words/(time.time()-start),
                      report_tgt_words/(time.time()-start),
                      time.time()-start_time))
                mu_mean = mu.mean()
                mu_std  = mu.std()
                logvar_mean = logvar.mean()
                logvar_std = logvar.std()
                print("mu mean: {:0.5f}".format(mu_mean.data[0]))
                print("mu std: {:0.5f}".format(mu_std.data[0]))
                print("logvar mean: {:0.5f}".format(logvar_mean.data[0]))
                print("logvar std: {:0.5f}".format(logvar_std.data[0]))
                report_loss = report_KLD = report_KLD_obj = report_tgt_words = report_src_words = report_num_correct = 0

                start = time.time()

        return total_loss / total_words, total_KLD / total_words, total_KLD_obj / total_words, total_num_correct / total_words

    best_valid_acc = max(stats['valid_accuracy']) if stats['valid_accuracy'] else 0
    best_valid_ppl = math.exp(min(stats['valid_loss'])) if stats['valid_loss'] else math.inf
    best_valid_lm_nll = math.exp(min(stats['valid_lm_nll'])) if stats['valid_lm_nll'] else math.inf
    best_epoch = 1 + np.argmax(stats['valid_accuracy']) if stats['valid_accuracy'] else 1
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        #  (1) train for one epoch on the training set
        train_loss, train_KLD, train_KLD_obj, train_acc = trainEpoch(epoch)
        train_ppl = math.exp(min(train_loss, 100))

        stats['train_loss'].append(train_loss)
        stats['train_KLD'].append(train_KLD)
        stats['train_KLD_obj'].append(train_KLD_obj)
        stats['train_accuracy'].append(train_acc)

        print('Train perplexity: %g' % train_ppl)
        print('Train KL Divergence: %g' % train_KLD)
        print('Train KL divergence objective: %g' % train_KLD_obj)
        print('Train accuracy: %g' % (train_acc*100))

        #  (2) evaluate on the validation set

        plot_tsne = plotTsne(epoch, opt.save_model)
        _eval = eval(model, criterion, plot_tsne, opt.tsne_num_batches)
        valid_loss, valid_KLD, valid_acc = _eval(validData)
        valid_ppl = math.exp(min(valid_loss, 100))
        sampled_sentences = []
        for i in range(opt.validation_num_batches):
            predBatch, predScore = translator.sample(opt.batch_size)
            for pred in predBatch:
                sampled_sentences.append(" ".join(pred[0]))
        valid_lm_nll = get_nll(lm, sampled_sentences)



        stats['valid_loss'].append(valid_loss)
        stats['valid_KLD'].append(valid_KLD)
        stats['valid_accuracy'].append(valid_acc)
        stats['valid_lm_nll'].append(valid_lm_nll)
        stats['step'].append(epoch * len(trainData))

        print('Validation perplexity: %g' % valid_ppl)
        print('Validation KL Divergence: %g' % valid_KLD)
        print('Validation accuracy: %g' % (valid_acc*100))
        print('Validation kenlm nll: %g' % (valid_lm_nll))

        #  (3) plot statistics
        _plot_stats = plot_stats(opt.save_model)
        _plot_stats(stats)

        #  (4) update the learning rate
        optim.updateLearningRate(valid_loss, epoch)
        if best_valid_lm_nll > valid_lm_nll:  # only store checkpoints if LM NLL improved
            if epoch > opt.start_epoch:
                os.remove('%s_acc_%.2f_ppl_%.2f_lmnll_%.2f_e%d.pt'\
                % (opt.save_model, 100*best_valid_acc, best_valid_ppl, best_valid_lm_nll, best_epoch))
            best_valid_acc = valid_acc
            best_valid_lm_nll = valid_lm_nll
            best_valid_ppl = valid_ppl
            best_epoch = epoch
            model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
            #  (5) drop a checkpoint
            checkpoint = {
                'model': model_state_dict,
                'dicts': dataset['dicts'],
                'opt': opt,
                'epoch': epoch,
                'optim': optim,
                'stats': stats
            }
            torch.save(checkpoint,
                   '%s_acc_%.2f_ppl_%.2f_lmnll_%.2f_e%d.pt' % (opt.save_model, 100*valid_acc, valid_ppl, valid_lm_nll, epoch))

    return best_valid_lm_nll


def train(opt, dataset):

    if torch.cuda.is_available() and not opt.gpus:
        print("WARNING: You have a CUDA device, so you should probably run with -gpus 0")

    if opt.gpus:
        cuda.set_device(opt.gpus[0])
        opt.cuda = True
    else:
        opt.cuda = False

    ckpt_path = opt.train_from
    if ckpt_path:
        print('Loading dicts from checkpoint at %s' % ckpt_path)
        checkpoint = torch.load(ckpt_path)
        opt = checkpoint['opt']

    print("Loading data from '%s'" % opt.data)

    if ckpt_path:
        dataset['dicts'] = checkpoint['dicts']
    model_dir = os.path.dirname(opt.save_model)
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    trainData = onmt.Dataset(dataset['train']['src'],
                             dataset['train']['tgt'], opt.batch_size, opt.gpus)
    validData = onmt.Dataset(dataset['valid']['src'],
                             dataset['valid']['tgt'], opt.batch_size, opt.gpus,
                             volatile=True)

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    print(' * number of training sentences. %d' %
          len(dataset['train']['src']))
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')
    assert dicts['src'].size() == dicts['tgt'].size()
    dict_size = dicts['src'].size()
    word_lut = nn.Embedding(dicts['src'].size(),
                            opt.word_vec_size,
                            padding_idx=onmt.Constants.PAD)
    generator = nn.Sequential(
        nn.Linear(opt.rnn_size, dicts['tgt'].size()),
        nn.LogSoftmax())
    encoder = onmt.Models.Encoder(opt, word_lut)
    decoder = onmt.Models.Decoder(opt, word_lut, generator)

    model = onmt.Models.NMTModel(encoder, decoder, opt)


    if ckpt_path:
        print('Loading model from checkpoint at %s' % ckpt_path)
        model.load_state_dict(checkpoint['model'])
        opt.start_epoch = checkpoint['epoch'] + 1

    if len(opt.gpus) >= 1:
        model.cuda()
    else:
        model.cpu()

    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)

    if not ckpt_path:
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)

        encoder.load_pretrained_vectors(opt)
        decoder.load_pretrained_vectors(opt)

        optim = onmt.Optim(
            opt.optim, opt.learning_rate, opt.max_grad_norm,
            lr_decay=opt.learning_rate_decay,
            start_decay_at=opt.start_decay_at
        )
        optim.set_parameters(model.parameters())
    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        optim.set_parameters(model.parameters())
        optim.optimizer.load_state_dict(checkpoint['optim'].optimizer.state_dict())


    if ckpt_path:
        stats = checkpoint['stats']
    else:
        stats = {'train_loss': [], 'train_KLD': [], 'train_KLD_obj': [],
                 'train_accuracy': [], 'kl_rate': [], 'valid_loss': [], 'valid_KLD': [],
                 'valid_accuracy': [], 'valid_lm_nll': [], 'step': []}

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    best_valid_lm_nll = trainModel(model, trainData, validData, dataset, optim, stats, opt)
    return best_valid_lm_nll


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='train.py')

    opts.model_opts(parser)
    opts.train_opts(parser)
    opt = parser.parse_args()
    # train() expects the preprocessed dataset dict as its second argument;
    # assumed here to have been saved with torch.save during preprocessing
    dataset = torch.load(opt.data)
    train(opt, dataset)