Example #1
    def load(self):
        timer = Timer()
        print("Loading model %d" % self.model_id, file=sys.stderr)
        timer.start()
        self.out_file = io.StringIO()
        try:
            self.translator = make_translator(self.opt,
                                              report_score=False,
                                              out_file=self.out_file)
        except RuntimeError as e:
            raise ServerModelError("Runtime Error: %s" % str(e))

        timer.tick("model_loading")
        if self.tokenizer_opt is not None:
            print("Loading tokenizer", file=sys.stderr)
            mandatory = ["type", "model"]
            for m in mandatory:
                if m not in self.tokenizer_opt:
                    raise ValueError(
                        "Missing mandatory tokenizer option '%s'" % m)
            if self.tokenizer_opt['type'] == 'sentencepiece':
                import sentencepiece as spm
                sp = spm.SentencePieceProcessor()
                model_path = os.path.join(self.model_root,
                                          self.tokenizer_opt['model'])
                sp.Load(model_path)
                self.tokenizer = sp
            else:
                raise ValueError("Invalid value for tokenizer type")

        self.load_time = timer.tick()
        self.reset_unload_timer()
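A minimal sketch of the tokenizer_opt that this load() expects: the mandatory
"type" and "model" keys checked above, with "sentencepiece" as the only
accepted type. The model filename below is an illustrative assumption; it is
resolved against self.model_root.

# Hypothetical tokenizer configuration for the sentencepiece branch above.
tokenizer_opt = {
    "type": "sentencepiece",         # the only type this loader accepts
    "model": "sentencepiece.model",  # assumed filename under model_root
}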
Example #2
def main(opt):
    translator = make_translator(opt, report_score=True)
    if opt.src != "":
        translator.translate(opt.src_dir, opt.src, opt.tgt, opt.batch_size,
                             opt.attn_debug)
    else:
        translator.translate(opt.src_dir, sys.stdin, opt.tgt, opt.batch_size,
                             opt.attn_debug)
Example #3
def main(opt):
    translator = make_translator(opt, report_score=True)
    translator.translate(opt.src_dir,
                         opt.src,
                         opt.tgt,
                         opt.batch_size,
                         opt.attn_debug,
                         aux_vec_path=opt.aux_vec_path,
                         retrieved_path=opt.retrieved_path)
Example #4
def main(opt):
    translator = make_translator(opt, report_score=True, logger=logger)
    translator.translate(opt.src_dir, opt.src, opt.tgt,
                         opt.batch_size, opt.attn_debug)
    print(translator.output)
    print(opt.output)
    with open(opt.output, "r") as f:
        for line in f:
            print(line)
Example #5
def main(opt):

    translator = make_translator(opt, report_score=True)

    start = timeit.default_timer()
    _, attns_info, oov_info, copy_info, context_attns_info = translator.translate(
        opt.src_dir, opt.src, opt.tgt, opt.batch_size, opt.attn_debug)
    end = timeit.default_timer()
    print("Translation takes {}s".format(end - start))

    # currently attns_info and oov_info only contain data for the first batch element
    if len(context_attns_info) == 0:
        return attns_info, oov_info, copy_info
    else:
        return attns_info, oov_info, copy_info, context_attns_info
Example #6
def main(opt):
    translator = make_translator(opt, report_score=True)
    translator.translate(opt.src_dir,
                         opt.src,
                         opt.tgt,
                         opt.phrase_table,
                         opt.global_phrase_table,
                         opt.batch_size,
                         opt.attn_debug,
                         opt.side_src,
                         opt.side_tgt,
                         opt.oracle,
                         opt.lower,
                         psi=opt.psi,
                         theta=opt.theta,
                         k=opt.k)
Example #7
def sub_main(queue, opt):
    
    translator = make_translator(opt, report_score=True)
    
    start = timeit.default_timer()
    # NOTE: context attns info and raw attns info are being used
    # interchangeably here; this must be fixed later.
    _, attns_info, oov_info, copy_info, context_attns_info = translator.translate(
        opt.src_dir, opt.src, opt.tgt, opt.batch_size, opt.attn_debug,
        raw_attn=True)
    end = timeit.default_timer()
    print("Translation takes {}s".format(end-start))
    
    # currently attns_info and oov_info only contain data for the first batch element
    if len(context_attns_info) == 0:
        queue.put((attns_info, oov_info, copy_info))
        return attns_info, oov_info, copy_info
    else:
        queue.put((attns_info, oov_info, copy_info, context_attns_info))
        return attns_info, oov_info, copy_info, context_attns_info
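Since sub_main reports its results through the queue as well as returning
them, it can be driven from a child process. A minimal sketch of such a
caller, assuming opt is an already-parsed translate options namespace:

import multiprocessing as mp

def run_sub_main(opt):
    # Run sub_main in a child process and collect its results via the queue.
    queue = mp.Queue()
    proc = mp.Process(target=sub_main, args=(queue, opt))
    proc.start()
    result = queue.get()  # (attns_info, oov_info, copy_info[, context_attns_info])
    proc.join()
    return result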
Example #8
def main(opt):
    translator = make_translator(opt, report_score=True)
    translator.translate(opt.src_dir, opt.src, opt.conversation, opt.tgt, opt.score,
                         opt.batch_size, opt.attn_debug)
Example #9
def main(opt):
    translator = make_translator(opt, report_score=True)
    translator.translate(opt.src_dir, opt.src, opt.tgt, opt.batch_size)
Example #10
    def __init__(self, model_filename, cmdline_args):
        parser = argparse.ArgumentParser(
            description='translate.py',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        opts.add_md_help_argument(parser)
        opts.translate_opts(parser)
        opt = parser.parse_args(['-model', model_filename, '-src', ''] +
                                (cmdline_args or []))

        translator = make_translator(opt)
        model = translator.model
        fields = translator.fields
        tgt_vocab = fields["tgt"].vocab

        def encode_from_src(src):
            enc_states, memory_bank = model.encoder(src)
            return dict(enc_states=enc_states,
                        memory_bank=memory_bank,
                        src=src)

        @lru_cache(maxsize=32)
        def encode_text(in_text):
            text_preproc = fields['src'].preprocess(in_text)
            src, src_len = fields['src'].process([text_preproc],
                                                 device=-1,
                                                 train=False)
            src = src.unsqueeze(2)  # add the nfeats dim: [src_len x batch x 1]
            return encode_from_src(src)

        @lru_cache(maxsize=32)
        def encode_img(image_idx):
            if isinstance(image_idx, str):
                image_idx = int(image_idx)
            src = Variable(torch.IntTensor([image_idx]), volatile=True)
            return encode_from_src(src)

        def encode(inp):
            if model.encoder.__class__.__name__ == 'VecsEncoder':
                return encode_img(inp)
            else:
                return encode_text(inp)

        @lru_cache(maxsize=128)
        def get_decoder_state(in_text, tokens_so_far):
            encoder_out = encode(in_text)
            enc_states = encoder_out['enc_states']
            memory_bank = encoder_out['memory_bank']
            src = encoder_out['src']

            if len(tokens_so_far) == 0:
                return None, translator.model.decoder.init_decoder_state(
                    src, memory_bank, enc_states)

            prev_out, prev_state = get_decoder_state(in_text,
                                                     tokens_so_far[:-1])

            tgt_in = Variable(torch.LongTensor(
                [tgt_vocab.stoi[tokens_so_far[-1]]]),
                              volatile=True)  # [tgt_len]
            tgt_in = tgt_in.unsqueeze(1)  # [tgt_len x batch=1]
            tgt_in = tgt_in.unsqueeze(1)  # [tgt_len x batch=1 x nfeats=1]

            # Prepare to call the decoder. Unfortunately the decoder mutates the state passed in!
            memory_bank = copy.deepcopy(memory_bank)
            assert isinstance(prev_state.hidden, tuple)
            prev_state.hidden = tuple(v.detach() for v in prev_state.hidden)
            prev_state = copy.deepcopy(prev_state)

            assert memory_bank.size()[1] == 1

            dec_out, dec_states, attn = translator.model.decoder(
                tgt_in, memory_bank, prev_state)

            assert dec_out.shape[0] == 1
            return dec_out[0], dec_states

        def generate_completions(in_text, tokens_so_far):
            tokens_so_far = [onmt.io.BOS_WORD] + tokens_so_far
            tokens_so_far = tuple(tokens_so_far)  # Make it hashable
            dec_out, dec_states = get_decoder_state(in_text, tokens_so_far)
            logits = model.generator.forward(dec_out).data
            vocab = tgt_vocab.itos

            assert logits.shape[0] == 1
            logits = logits[0]
            return logits, vocab

        def eval_logprobs(in_text, tokens, *, use_eos):
            encoder_out = encode(in_text)
            enc_states = encoder_out['enc_states']
            memory_bank = encoder_out['memory_bank']
            src = encoder_out['src']

            tokens = [onmt.io.BOS_WORD] + tokens
            if use_eos:
                tokens = tokens + [onmt.io.EOS_WORD]

            decoder_state = model.decoder.init_decoder_state(
                src, memory_bank=memory_bank, encoder_final=enc_states)
            tgt = Variable(
                torch.LongTensor([tgt_vocab.stoi[tok] for tok in tokens
                                  ]).unsqueeze(1).unsqueeze(1))
            dec_out, dec_states, attn = model.decoder(tgt[:-1], memory_bank,
                                                      decoder_state)
            logits = model.generator(dec_out)
            return F.nll_loss(logits.squeeze(1),
                              tgt[1:].squeeze(1).squeeze(1),
                              reduce=False,
                              size_average=False).data.numpy()

        self.model = model
        self.fields = fields
        self.translator = translator
        self.encode = encode
        self.get_decoder_state = get_decoder_state
        self.generate_completions = generate_completions
        self.eval_logprobs = eval_logprobs
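Once constructed, the wrapper above can score next-token candidates
incrementally through generate_completions. A usage sketch; the class name
Wrapper, the checkpoint path, and the input text are placeholders, not part
of the original source:

# Hypothetical usage; "Wrapper" stands in for whatever class owns the
# __init__ above, and the checkpoint path is a placeholder.
wrapper = Wrapper("model_acc_00.00_ppl_0.00_e0.pt", cmdline_args=[])
logits, vocab = wrapper.generate_completions("a source sentence", [])
scores = logits.tolist()
best = scores.index(max(scores))  # index of the highest-scoring next token
print(vocab[best])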
Example #11
def main(opt):
    translator = make_translator(opt, report_score=True, logger=logger)
    translator.translate(opt.src_dir, opt.src, opt.tgt,
                         opt.batch_size, opt.attn_debug)
Example #12
def train_model(model, fields, optim, data_type, model_opt):
    translate_parser = argparse.ArgumentParser(
        description='translate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(translate_parser)
    onmt.opts.translate_opts(translate_parser)
    opt_translate = translate_parser.parse_args(args=[])
    opt_translate.replace_unk = False
    opt_translate.verbose = True

    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model,
                                   fields["tgt"].vocab,
                                   opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
                           shard_size, data_type, norm_method,
                           grad_accum_count)

    logger.info('')
    logger.info('Start training...')
    logger.info(' * number of epochs: %d, starting from Epoch %d' %
                (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    logger.info(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        logger.info('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields,
                                       opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        logger.info('Train perplexity: %g' % train_stats.ppl())
        logger.info('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields,
                                       opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        logger.info('Validation perplexity: %g' % valid_stats.ppl())
        logger.info('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate
        decay = trainer.epoch_step(valid_stats.ppl(), epoch)
        if decay:
            logger.info("Decaying learning rate to %g" % trainer.optim.lr)

        # 5. Drop a checkpoint if needed.
        if epoch % 10 == 0:  #epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)

            opt_translate.src = 'cache/valid_src_{:s}.txt'.format(
                opt.file_templ)
            opt_translate.tgt = 'cache/valid_eval_refs_{:s}.txt'.format(
                opt.file_templ)
            opt_translate.output = 'result/{:s}/valid_res_{:s}.txt'.format(
                opt.dataset, opt.file_templ)
            opt_translate.model = '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (
                opt.save_model, valid_stats.accuracy(), valid_stats.ppl(),
                epoch)

            check_save_result_path(opt_translate.output)

            translator = make_translator(opt_translate,
                                         report_score=False,
                                         logger=logger)
            translator.calc_sacre_bleu = False
            translator.translate(opt_translate.src_dir, opt_translate.src,
                                 opt_translate.tgt, opt_translate.batch_size,
                                 opt_translate.attn_debug)
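check_save_result_path is not shown in this example; from its call sites it
presumably just ensures that the directory of the output file exists. A
minimal stand-in under that assumption:

import os

def check_save_result_path(path):
    # Assumed implementation: create the directory that will hold the
    # translation output before the translator writes to it.
    out_dir = os.path.dirname(path)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)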
Example #13
def train_model(model, fields, optim, data_type, opt_per_pred):
    translate_parser = argparse.ArgumentParser(
        description='translate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(translate_parser)
    onmt.opts.translate_opts(translate_parser)
    opt_translate = translate_parser.parse_args(args=[])
    opt_translate.replace_unk = False
    opt_translate.verbose = False
    opt_translate.block_ngram_repeat = False
    if opt.gpuid:
        opt_translate.gpu = opt.gpuid[0]

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = {}
    for predicate in opt.parser.predicates:
        train_loss = make_loss_compute(model[predicate],
                                       fields[predicate]["tgt"].vocab,
                                       opt_per_pred[predicate])
        valid_loss = make_loss_compute(model[predicate],
                                       fields[predicate]["tgt"].vocab,
                                       opt_per_pred[predicate],
                                       train=False)
        trainer[predicate] = onmt.Trainer(model[predicate], train_loss,
                                          valid_loss, optim[predicate],
                                          trunc_size, shard_size,
                                          data_type[predicate], norm_method,
                                          grad_accum_count)

    logger.info('')
    logger.info('Start training...')
    logger.info(' * number of epochs: %d, starting from Epoch %d' %
                (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    logger.info(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        logger.info('')

        train_stats = {}
        valid_stats = {}
        for predicate in opt.parser.predicates:
            logger.info('Train predicate: %s' % predicate)
            # 1. Train for one epoch on the training set.
            train_iter = make_dataset_iter(
                lazily_load_dataset("train", opt_per_pred[predicate]),
                fields[predicate], opt_per_pred[predicate])
            train_stats[predicate] = trainer[predicate].train(
                train_iter, epoch, fields[predicate], report_func)
            logger.info('Train perplexity: %g' % train_stats[predicate].ppl())
            logger.info('Train accuracy: %g' %
                        train_stats[predicate].accuracy())

            # 2. Validate on the validation set.
            valid_iter = make_dataset_iter(lazily_load_dataset(
                "valid", opt_per_pred[predicate]),
                                           fields[predicate],
                                           opt_per_pred[predicate],
                                           is_train=False)
            valid_stats[predicate] = trainer[predicate].validate(valid_iter)
            logger.info('Validation perplexity: %g' %
                        valid_stats[predicate].ppl())
            logger.info('Validation accuracy: %g' %
                        valid_stats[predicate].accuracy())

            # 3. Log to remote server.
            if opt_per_pred[predicate].exp_host:
                train_stats[predicate].log("train", experiment,
                                           optim[predicate].lr)
                valid_stats[predicate].log("valid", experiment,
                                           optim[predicate].lr)
            if opt_per_pred[predicate].tensorboard:
                train_stats[predicate].log_tensorboard("train", writer,
                                                       optim[predicate].lr,
                                                       epoch)
                train_stats[predicate].log_tensorboard("valid", writer,
                                                       optim[predicate].lr,
                                                       epoch)

            # 4. Update the learning rate
            decay = trainer[predicate].epoch_step(valid_stats[predicate].ppl(),
                                                  epoch)
            if decay:
                logger.info("Decaying learning rate to %g" %
                            trainer[predicate].optim.lr)

        # 5. Drop a checkpoint if needed.
        if epoch % 10 == 0:  #epoch >= opt.start_checkpoint_at:
            opt_translates = []
            for predicate in opt.parser.predicates:
                opt_translate.predicate = predicate
                opt_translate.batch_size = opt_per_pred[predicate].batch_size
                opt_translate.src = 'cache/valid_src_{:s}.txt'.format(
                    opt_per_pred[predicate].file_templ)
                opt_translate.tgt = 'cache/valid_eval_refs_{:s}.txt'.format(
                    opt_per_pred[predicate].file_templ)
                opt_translate.output = 'result/{:s}/valid_res_{:s}.txt'.format(
                    opt_per_pred[predicate].dataset,
                    opt_per_pred[predicate].file_templ)
                #opt_translate.model = '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (
                #opt_per_pred[predicate].save_model, valid_stats[predicate].accuracy(), valid_stats[predicate].ppl(), epoch)

                check_save_result_path(opt_translate.output)

                translator = make_translator(opt_translate,
                                             report_score=False,
                                             logger=logger,
                                             fields=fields[predicate],
                                             model=trainer[predicate].model,
                                             model_opt=opt_per_pred[predicate])
                translator.output_beam = 'result/{:s}/valid_res_beam_{:s}.txt'.format(
                    opt_per_pred[predicate].dataset,
                    opt_per_pred[predicate].file_templ)
                #translator.beam_size = 5
                #translator.n_best = 5
                translator.translate(opt_translate.src_dir, opt_translate.src,
                                     opt_translate.tgt,
                                     opt_translate.batch_size,
                                     opt_translate.attn_debug)
                opt_translates.append(copy(opt_translate))
            corpusBLEU, bleu, rouge, coverage, bleu_per_predicate = evaluate(
                opt_translates)
            for predicate in opt.parser.predicates:
                trainer[predicate].drop_checkpoint(
                    opt_per_pred[predicate], epoch, corpusBLEU, bleu, rouge,
                    coverage, bleu_per_predicate[predicate], fields[predicate],
                    valid_stats[predicate])
Example #14
def main():
    # load the data!
    if opt.dataset.lower() == 'e2e':
        dataparser = DatasetParser('data/e2e/trainset.csv',
                                   'data/e2e/devset.csv',
                                   'data/e2e/testset_w_refs.csv',
                                   'E2E',
                                   opt,
                                   light=True)
    elif opt.dataset.lower() == 'webnlg':
        dataparser = DatasetParser('data/webNLG_challenge_data/train',
                                   'data/webNLG_challenge_data/dev',
                                   False,
                                   'webNLG',
                                   opt,
                                   light=True)
    elif opt.dataset.lower() == 'sfhotel':
        dataparser = DatasetParser('data/sfx_data/sfxhotel/train.json',
                                   'data/sfx_data/sfxhotel/valid.json',
                                   'data/sfx_data/sfxhotel/test.json',
                                   'SFHotel',
                                   opt,
                                   light=True)

    opt.data = 'save_data/{:s}/'.format(opt.dataset)
    gen_templ = dataparser.get_onmt_file_templ(opt)

    opt.parser = dataparser
    model = {}
    fields = {}
    optim = {}
    data_type = {}

    opt_per_pred = {}

    translate_parser = argparse.ArgumentParser(
        description='translate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(translate_parser)
    onmt.opts.translate_opts(translate_parser)
    opt_translate = translate_parser.parse_args(args=[])
    opt_translate.replace_unk = False
    opt_translate.verbose = False
    opt_translate.block_ngram_repeat = False
    if opt.gpuid:
        opt_translate.gpu = opt.gpuid[0]
    opt_translates = []
    for predicate in dataparser.predicates:
        opt_per_pred[predicate] = copy(opt)
        opt_per_pred[predicate].predicate = predicate
        opt_per_pred[predicate].file_templ = gen_templ.format(predicate)
        opt_per_pred[predicate].save_model = 'save_model/{:s}/{:s}'.format(
            opt_per_pred[predicate].dataset, predicate)

        # Get the saved model with the highest reported BLEU in dev
        dir_path = 'save_model/{:s}/'.format(opt_per_pred[predicate].dataset)
        poss_models = [
            f for f in listdir(dir_path)
            if isfile(join(dir_path, f)) and f.startswith(predicate + "_e")
            and '_bleuForPred_' in f and '_corpusBLEU_' in f
        ]
        bleu_models = [
            float(f[f.find('_bleuForPred_') + 13:f.find('_corpusBLEU_')])
            for f in poss_models
        ]
        checkpoint_file = 'save_model/{:s}/{:s}'.format(
            opt_per_pred[predicate].dataset,
            poss_models[bleu_models.index(max(bleu_models))])

        logger.info('Loading checkpoint from %s' % checkpoint_file)
        checkpoint = torch.load(checkpoint_file,
                                map_location=lambda storage, loc: storage)

        # Peek at the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(
            lazily_load_dataset("train", opt_per_pred[predicate]))
        data_type[predicate] = first_dataset.data_type

        # Load fields generated from preprocess phase.
        fields[predicate] = load_fields(first_dataset, data_type[predicate],
                                        opt_per_pred[predicate], checkpoint)

        # Report src/tgt features.
        collect_report_features(fields[predicate])

        # Load model.
        model[predicate] = build_model(opt_per_pred[predicate],
                                       fields[predicate], checkpoint)
        model[predicate].predicate = predicate
        model[predicate].eval()
        tally_parameters(model[predicate])
        check_save_model_path()

        # Load optimizer.
        optim[predicate] = build_optim(model[predicate], checkpoint)

        opt_translate.predicate = predicate
        opt_translate.batch_size = opt_per_pred[predicate].batch_size
        opt_translate.src = 'cache/test_src_{:s}.txt'.format(
            opt_per_pred[predicate].file_templ)
        opt_translate.tgt = 'cache/test_eval_refs_{:s}.txt'.format(
            opt_per_pred[predicate].file_templ)
        opt_translate.output = 'result/{:s}/test_res_{:s}.txt'.format(
            opt_per_pred[predicate].dataset,
            opt_per_pred[predicate].file_templ)

        check_save_result_path(opt_translate.output)
        if os.path.isfile(opt_translate.src) and os.path.isfile(
                opt_translate.tgt):
            translator = make_translator(opt_translate,
                                         report_score=False,
                                         logger=logger,
                                         fields=fields[predicate],
                                         model=model[predicate],
                                         model_opt=opt_per_pred[predicate])
            translator.output_beam = 'result/{:s}/test_res_beam_{:s}.txt'.format(
                opt_per_pred[predicate].dataset,
                opt_per_pred[predicate].file_templ)
            #translator.beam_size = 5
            #translator.n_best = 5
            translator.translate(opt_translate.src_dir, opt_translate.src,
                                 opt_translate.tgt, opt_translate.batch_size,
                                 opt_translate.attn_debug)
            opt_translates.append(copy(opt_translate))
    evaluate(opt_translates)
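The checkpoint selection in main() recovers each model's dev BLEU by slicing
the filename between the '_bleuForPred_' and '_corpusBLEU_' markers (the
offset 13 is len('_bleuForPred_')). A quick self-contained check of that
slicing, using a made-up filename that matches the pattern the filter
expects:

# Hypothetical checkpoint filename matching the filename filter above.
f = "predicate_e10_bleuForPred_0.4213_corpusBLEU_0.3987.pt"
bleu = float(f[f.find('_bleuForPred_') + 13:f.find('_corpusBLEU_')])
print(bleu)  # 0.4213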
Example #15
    #     txt = f.readline()
    # with open(src_file, 'w') as f:
    #     f.write(some_char+txt)
    LIST_predict = translator.translate(opt.src_dir, src_file, out_file,
                                        opt.batch_size, opt.attn_debug)
    return LIST_predict[1][0]['score'], ' '.join(LIST_predict[1][0]['area'])


# -------------------------------------
# one-time setup, run when the module is first loaded
# -------------------------------------
parser = argparse.ArgumentParser(
    description='translate.py',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
onmt.opts.add_md_help_argument(parser)
onmt.opts.translate_opts(parser)
opt = parser.parse_args()

# set some default values
opt.gpu = -1
opt.model = '%s/deppon_model_acc_99.87_ppl_1.00_e9.pt' % folder  # no detail
# opt.model = '%s/deppon_model_acc_98.56_ppl_1.07_e11_detail.pt' % folder  # no detail
opt.replace_unk = True
opt.verbose = True
opt.attn_debug = False
translator = make_translator(opt, report_score=True)

# network_parse('t.txt', 'o.txt')
with open('%s/back_code/network/t_muban.txt' % (root_dir), 'r') as f:
    a = f.readline()
some_char = a[0:3]
Example #16
def main(NMT_config):

    ### Load RL (global) configurations ###
    config = parse_args()

    ### Load trained QA model ###
    QA_checkpoint = torch.load(config.data_dir + config.QA_best_model)
    QA_config = QA_checkpoint['config']

    QA_mod = BiDAF(QA_config)
    if QA_config.use_gpu:
        QA_mod.cuda()
    QA_mod.load_state_dict(QA_checkpoint['state_dict'])

    ### Load SQuAD dataset ###
    data_filter = get_squad_data_filter(QA_config)

    train_data = read_data(QA_config,
                           'train',
                           QA_config.load,
                           data_filter=data_filter)
    dev_data = read_data(QA_config, 'dev', True, data_filter=data_filter)

    update_config(QA_config, [train_data, dev_data])

    print("Total vocabulary for training is %s" % QA_config.word_vocab_size)

    # from all
    word2vec_dict = train_data.shared[
        'lower_word2vec'] if QA_config.lower_word else train_data.shared[
            'word2vec']
    # from filter-out set
    word2idx_dict = train_data.shared['word2idx']

    # filter-out set idx-vector
    idx2vec_dict = {
        word2idx_dict[word]: vec
        for word, vec in word2vec_dict.items() if word in word2idx_dict
    }
    print("{}/{} unique words have corresponding glove vectors.".format(
        len(idx2vec_dict), len(word2idx_dict)))

    # <null> and <unk> have no pretrained vectors, so initialize them randomly.
    emb_mat = np.array([
        idx2vec_dict[idx]
        if idx in idx2vec_dict else np.random.multivariate_normal(
            np.zeros(QA_config.word_emb_size), np.eye(QA_config.word_emb_size))
        for idx in range(QA_config.word_vocab_size)
    ])

    config.emb_mat = emb_mat
    config.new_emb_mat = train_data.shared['new_emb_mat']

    num_steps = int(
        math.ceil(train_data.num_examples /
                  (QA_config.batch_size *
                   QA_config.num_gpus))) * QA_config.num_epochs

    # offset for question mark
    NMT_config.max_length = QA_config.ques_size_th - 1
    NMT_config.batch_size = QA_config.batch_size

    ### Construct translator ###
    translator = make_translator(NMT_config, report_score=True)

    ### Construct optimizer ###
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 translator.model.parameters()),
                          lr=config.lr)

    ### Start RL training ###
    count = 0
    QA_mod.eval()
    F1_eval = F1Evaluator(QA_config, QA_mod)
    #eval_model(QA_mod, train_data, dev_data, QA_config, NMT_config, config, translator)

    for i in range(config.n_episodes):
        for batches in tqdm(train_data.get_multi_batches(
                QA_config.batch_size,
                QA_config.num_gpus,
                num_steps=num_steps,
                shuffle=True,
                cluster=QA_config.cluster),
                            total=num_steps):

            #for n, p in translator.model.named_parameters():
            #    print(n)
            #    print(p)
            #print(p.requires_grad)

            start = datetime.now()
            to_input(batches[0][1].data['q'], config.RL_path + config.RL_file)

            # obtain rewrite and log_prob
            q, scores, log_prob = translator.translate(NMT_config.src_dir,
                                                       NMT_config.src,
                                                       NMT_config.tgt,
                                                       NMT_config.batch_size,
                                                       NMT_config.attn_debug)

            q, cq = ref_query(q)
            batches[0][1].data['q'] = q
            batches[0][1].data['cq'] = cq

            log_prob = torch.stack(log_prob).squeeze(-1)
            #print(log_prob)

            translator.model.zero_grad()

            QA_mod(batches)

            e = F1_eval.get_evaluation(batches, False, NMT_config, config,
                                       translator)
            reward = Variable(torch.FloatTensor(e.f1s), requires_grad=False)
            #print(reward)

            ## Initial loss
            loss = create_loss(log_prob, reward)

            loss.backward()
            optimizer.step()
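create_loss is defined outside this snippet. Given that it combines the
stacked per-step log-probabilities with the per-example F1 rewards and that
its output is backpropagated, a standard REINFORCE-style formulation is a
reasonable reading; the sketch below is consistent with the call site, not
the original implementation:

def create_loss(log_prob, reward):
    # REINFORCE-style policy gradient: maximize the reward-weighted
    # log-likelihood by minimizing its negative mean. Assumes log_prob
    # and reward broadcast against each other.
    return -(log_prob * reward).mean()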