示例#1
0
    def _test(metric_device):
        """Compare the engine's micro-averaged BLEU with NLTK's ``corpus_bleu``.

        Runs an ``Engine`` built on ``update`` over ``range(size)`` for one
        epoch with a 4-gram, ``smooth2`` ``Bleu`` metric attached, then
        recomputes the score from ``data`` with ``corpus_bleu`` and checks that
        both values agree within ``pytest.approx`` tolerance.

        Args:
            metric_device: Device passed to ``Bleu`` so that the metric is
                created on the device under test.
        """
        engine = Engine(update)
        # Forward the requested device; otherwise the argument is ignored and
        # every invocation exercises whatever device ``Bleu`` defaults to.
        m = Bleu(ngram=4, smooth="smooth2", average="micro", device=metric_device)
        m.attach(engine, "bleu")

        engine.run(data=list(range(size)), max_epochs=1)

        assert "bleu" in engine.state.metrics

        # Only the first candidate and the first reference set of each item in
        # ``data`` contribute to the reference score.
        references = []
        candidates = []
        for _candidates, _references in data:
            references.append(_references[0])
            candidates.append(_candidates[0])

        # Silence any warnings raised while NLTK computes the score.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            ref_bleu = corpus_bleu(
                references,
                candidates,
                weights=[0.25, 0.25, 0.25, 0.25],
                smoothing_function=SmoothingFunction().method2,
            )

        assert pytest.approx(engine.state.metrics["bleu"]) == ref_bleu
示例#2
0
    def _test(metric_device):
        """Compare the engine's BLEU with the mean of NLTK ``sentence_bleu`` scores.

        Runs an ``Engine`` built on ``update`` over ``range(size)`` for one
        epoch with a 4-gram, ``smooth2`` ``Bleu`` metric attached, then sums
        ``sentence_bleu`` over every item in ``data`` and checks that the
        metric equals that sum divided by ``len(data)`` within
        ``pytest.approx`` tolerance.

        Args:
            metric_device: Device passed to ``Bleu`` so that the metric is
                created on the device under test.
        """
        engine = Engine(update)
        # Forward the requested device; otherwise the argument is ignored and
        # every invocation exercises whatever device ``Bleu`` defaults to.
        m = Bleu(ngram=4, smooth="smooth2", device=metric_device)
        m.attach(engine, "bleu")

        engine.run(data=list(range(size)), max_epochs=1)

        assert "bleu" in engine.state.metrics

        ref_bleu = 0
        # One warning filter around the whole loop is enough; there is no need
        # to re-enter the context manager for every sample.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Only the first candidate and first reference set of each item
            # are scored.
            for candidates, references in data:
                ref_bleu += sentence_bleu(
                    references[0],
                    candidates[0],
                    weights=[0.25, 0.25, 0.25, 0.25],
                    smoothing_function=SmoothingFunction().method2,
                )

        assert pytest.approx(engine.state.metrics["bleu"]) == ref_bleu / len(data)