コード例 #1
0
ファイル: test_rouge.py プロジェクト: isabella232/ignite-2
def test_wrong_inputs():
    """Invalid constructor arguments and an early ``compute()`` must raise.

    Each case below builds a ROUGE metric with a bad argument and expects a
    ``ValueError`` (checked against the message pattern where one is given).
    The ``RougeL`` case expects ``NotComputableError`` when ``compute()`` is
    called on a freshly constructed metric with no ``update()`` in between.
    """
    # n-gram order of zero.
    ngram_pattern = r"ngram order must be greater than zero"
    with pytest.raises(ValueError, match=ngram_pattern):
        RougeN(ngram=0)

    # alpha below and above the [0, 1] interval share the same message.
    alpha_pattern = r"alpha must be in interval \[0, 1\]"
    for bad_alpha in (-1, 2):
        with pytest.raises(ValueError, match=alpha_pattern):
            RougeN(alpha=bad_alpha)

    # Empty multiref mode; the trailing space in the pattern is intentional
    # and must be kept as-is.
    multiref_pattern = r"multiref : valid values are \['best', 'average'\] "
    with pytest.raises(ValueError, match=multiref_pattern):
        RougeN(multiref="")

    # A variant name that is neither 'L' nor an integer.
    variant_pattern = r"variant must be 'L' or integer greater to zero"
    with pytest.raises(ValueError, match=variant_pattern):
        Rouge(variants=["error"])

    # compute() on a metric that has not received any update().
    with pytest.raises(NotComputableError):
        RougeL().compute()

    # Unknown multiref mode on the combined metric; message is not checked.
    with pytest.raises(ValueError):
        Rouge(multiref="unknown")
コード例 #2
0
ファイル: test_rouge.py プロジェクト: pytorch/ignite
def test_rouge_n_alpha(ngram, candidate, reference, expected):
    """Check RougeN precision, recall and F-score across several alpha values.

    ``expected`` is indexed as ``(precision, recall)``. The expected F-score
    is derived from those two values and the current alpha; when the
    denominator is zero the expected F-score is taken to be 0.
    """
    precision, recall = expected[0], expected[1]
    precision_key = f"Rouge-{ngram}-P"
    recall_key = f"Rouge-{ngram}-R"
    f_key = f"Rouge-{ngram}-F"

    for alpha in (0, 1, 0.3, 0.5, 0.8):
        # A fresh metric per alpha, fed a single candidate with one reference.
        metric = RougeN(ngram=ngram, alpha=alpha)
        metric.update(([candidate], [[reference]]))
        scores = metric.compute()

        assert scores[precision_key] == precision
        assert scores[recall_key] == recall

        weighted = (1 - alpha) * precision + alpha * recall
        try:
            f_score = precision * recall / weighted
        except ZeroDivisionError:
            f_score = 0
        assert scores[f_key] == f_score