Example #1
File: test.py Project: snnclsr/nmt
import argparse

# Project-local imports; the exact module paths are assumed from the repo
# layout and may differ in snnclsr/nmt.
from model import Seq2Seq
from utils import (add_start_end_tokens, beam_search,
                   compute_corpus_level_bleu_score, read_text)


def main():

    arg_parser = argparse.ArgumentParser(
        description="Neural Machine Translation Testing")
    arg_parser.add_argument("--model_file", required=True, help="Model File")
    arg_parser.add_argument("--valid_data",
                            required=True,
                            nargs="+",
                            help="Validation_data")

    args = vars(arg_parser.parse_args())
    print(args)

    # Load the trained model and run inference on the CPU.
    model = Seq2Seq.load(args["model_file"])
    print(model)
    model.device = "cpu"

    # The two --valid_data files form a parallel corpus: source (tr) and
    # target (en), one sentence per line.
    tr_dev_dataset_fn, en_dev_dataset_fn = args["valid_data"]
    tr_valid_data = read_text(tr_dev_dataset_fn)
    en_valid_data = read_text(en_dev_dataset_fn)

    valid_data = list(zip(tr_valid_data, en_valid_data))

    # Split back into source/target lists, adding start/end tokens.
    src_valid, tgt_valid = add_start_end_tokens(valid_data)

    # Decode every source sentence with beam search; each entry of
    # `hypotheses` is a list of candidate translations, best first.
    hypotheses = beam_search(model,
                             src_valid,
                             beam_size=3,
                             max_decoding_time_step=70)
    top_hypotheses = [hyps[0] for hyps in hypotheses]

    # Score the best candidates against the references.
    bleu_score = compute_corpus_level_bleu_score(tgt_valid, top_hypotheses)
    print('Corpus BLEU: {}'.format(bleu_score * 100))
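The script is run with --model_file pointing at a saved checkpoint and --valid_data naming the source and target files. `compute_corpus_level_bleu_score` itself is not shown in the example; a minimal sketch using NLTK's `corpus_bleu`, assuming the references still carry the `<sos>`/`<eos>` markers added by `add_start_end_tokens` and that each hypothesis exposes its tokens as `.value`, could look like this (an illustration, not the project's actual helper):

from nltk.translate.bleu_score import corpus_bleu

def compute_corpus_level_bleu_score(references, hypotheses):
    # Strip the <sos>/<eos> markers before scoring.
    references = [ref[1:-1] if ref and ref[0] == '<sos>' else ref
                  for ref in references]
    # corpus_bleu expects a list of reference lists per hypothesis.
    return corpus_bleu([[ref] for ref in references],
                       [hyp.value for hyp in hypotheses])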
Example #2
import pickle

from flask import render_template, request

# Project-local imports and constants; exact module paths are assumed.
from model import Seq2Seq
from utils import beam_search, generate_attention_map, save_attention

# `tokenizer`, `VOCAB_FILE`, and `MODEL_PATH` are assumed to be defined at
# module level. This is a Flask view; its @app.route registration is
# assumed to live elsewhere in the app.
def translate():

    if request.method == "GET":
        return render_template("index.html")

    # POST: read the sentence submitted through the form's <textarea>.
    args = request.form
    print(args)
    text_input = args["textarea"]

    print("Input: ", text_input)
    tokenized_sent = tokenizer.tokenize(text_input)
    print("Tokenized input: ", tokenized_sent)

    # Load the vocabulary and the trained model; inference runs on the CPU.
    with open(VOCAB_FILE, "rb") as f:
        vocabs = pickle.load(f)

    model = Seq2Seq.load(MODEL_PATH)
    model.device = "cpu"

    # Beam-search decode the single input sentence and keep its
    # candidate list (best hypothesis first).
    hypothesis = beam_search(model, [tokenized_sent],
                             beam_size=3,
                             max_decoding_time_step=70)[0]
    print("Hypothesis")
    print(hypothesis)

    # Render an attention heatmap for each returned candidate (at most 3;
    # beam search may return fewer than beam_size hypotheses).
    for i in range(min(3, len(hypothesis))):
        new_target = [['<sos>'] + hypothesis[i].value + ['<eos>']]
        a_ts = generate_attention_map(model, vocabs, [tokenized_sent],
                                      new_target)
        save_attention(tokenized_sent,
                       hypothesis[i].value,
                       [a[0].detach().cpu().numpy()
                        for a in a_ts[:len(hypothesis[i].value)]],
                       save_path="static/list_{}.png".format(i))

    # Pair each candidate with its index for display in the template.
    result_hypothesis = [(idx, " ".join(hyp.value))
                         for idx, hyp in enumerate(hypothesis)]

    return render_template("index.html",
                           hypothesis=result_hypothesis,
                           sentence=text_input)
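For completeness, the view above still needs a Flask application and a route registration, which the snippet does not show. A minimal wiring sketch follows; the route path and file layout are assumptions, not taken from the project:

from flask import Flask

app = Flask(__name__)

# Register translate() for both the landing page (GET) and the form
# submission (POST). The "/" path is an assumption.
app.add_url_rule("/", view_func=translate, methods=["GET", "POST"])

if __name__ == "__main__":
    app.run(debug=True)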