# Example #1
# 0
def test(config):
    """Run inference on the test batches, log translations and report BLEU.

    For every batch returned by ``get_batch_for_test()`` the predictions from
    ``Inferencer.run`` are mapped to tokens with ``idx2en``, joined with
    spaces and cut at the first ``</S>`` marker. Each source / expected / got
    triple is written to the file ``os.path.join(config.eval_dir,
    config.model_name)``, followed by a final corpus-level BLEU line
    (score multiplied by 100).

    Only sentence pairs where both the reference and the hypothesis have more
    than three whitespace-separated tokens contribute to the BLEU score.

    Args:
        config: Configuration object. This function reads ``config.eval_dir``
            and ``config.model_name`` directly and forwards the object to
            ``_config_test``, ``ConvSeq2Seq``, ``GraphHandler`` and
            ``Inferencer``.
    """
    _config_test(config)

    # Only idx2en is used below; the other vocab mappings are loaded exactly
    # as before so any work done by the loaders still happens.
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    model = ConvSeq2Seq(config)
    graph_handler = GraphHandler(config)
    inferencer = Inferencer(config, model)

    refs = []
    hypotheses = []
    out_path = os.path.join(config.eval_dir, config.model_name)

    sess = tf.Session()
    try:
        graph_handler.initialize(sess)

        with codecs.open(out_path, "w", "utf-8") as fout:
            for batch in tqdm(get_batch_for_test()):
                preds = inferencer.run(sess, batch)
                sources = batch['source']
                targets = batch['target']
                for source, target, pred in zip(sources, targets, preds):
                    # Decode ids to tokens and keep only the text before the
                    # first end-of-sentence marker.
                    decoded = " ".join(idx2en[idx] for idx in pred)
                    got = decoded.split("</S>")[0].strip()

                    fout.write("- source: " + source + "\n")
                    fout.write("- expected: " + target + "\n")
                    fout.write("- got: " + got + "\n\n")
                    # Flushed per sample so the file can be followed while
                    # the evaluation is still running.
                    fout.flush()

                    ref = target.split()
                    hypothesis = got.split()
                    # Very short sentences are left out of the BLEU corpus.
                    if len(ref) > 3 and len(hypothesis) > 3:
                        refs.append([ref])
                        hypotheses.append(hypothesis)

            # NOTE(review): if no pair passed the length filter the corpus is
            # empty; corpus_bleu's behaviour on empty input is not visible
            # from here, so report 0.0 instead of calling it — confirm.
            score = corpus_bleu(refs, hypotheses) if refs else 0.0
            fout.write("Bleu Score = " + str(100*score))
    finally:
        # Release the session's resources even if inference or writing fails.
        sess.close()