示例#1
0
def parse_sacrebleu_uri(uri: str) -> Tuple[str, str]:
    """
    Parses the test set and language pair from a URI of the form

        sacrebleu:wmt19:de-en
        sacrebleu:wmt19/google/ar:de-en

    Args:
        uri: A string with exactly three colon-separated fields. The first
            field is not inspected; the second is the test set name and the
            third is the language pair.

    Returns:
        A ``(testset, langpair)`` tuple of strings.

    On failure this logs an error and calls ``sys.exit(1)`` (raising
    ``SystemExit``) when:
      * the URI does not split into exactly three fields on ``":"``,
      * the test set is not a key of ``DATASETS``, or
      * the language pair is not among those returned by
        ``get_langpairs_for_testset(testset)``.
    """
    try:
        _, testset, langpair = uri.split(":")
    except ValueError:
        logger.error('sacrebleu:* flags must take the form "sacrebleu:testset:langpair"')
        sys.exit(1)

    # Check membership on the mapping itself; the sorted key list is only
    # needed to print the list of alternatives on the error path.
    if testset not in DATASETS:
        logger.error(f"Test set '{testset}' was not found. Available sacrebleu test sets are:")
        for key in sorted(DATASETS, reverse=True):
            logger.error(f"  {key:20s}: {DATASETS[key].get('description', '')}")
        sys.exit(1)

    lang_pairs = get_langpairs_for_testset(testset)

    if langpair not in lang_pairs:
        logger.error(f"Language pair '{langpair}' not available for testset '{testset}'.\n"
                     f" Language pairs available for {testset}: {', '.join(lang_pairs)}")
        sys.exit(1)

    return testset, langpair
示例#2
0
def test_api_get_langpairs_for_testset():
    """
    Cross-check ``get_langpairs_for_testset`` against the raw ``DATASETS`` table.

    For every test set, the API result must be exactly of type ``list``. Each
    key of that test set's entry that contains a hyphen must appear in the
    result. Each key without a hyphen (metadata such as a description) must
    not appear. No ``"slashdot_"``-prefixed variant of any key may appear.
    """
    for name, entry in sacrebleu.DATASETS.items():
        langpairs = sacrebleu.get_langpairs_for_testset(name)
        assert type(langpairs) is list

        for key in entry:
            looks_like_langpair = "-" in key
            if looks_like_langpair:
                assert key in langpairs
            else:
                # metadata keys must be filtered out by the API
                assert key not in langpairs
            assert "slashdot_" + key not in langpairs
示例#3
0
########################################
# Translation tasks
########################################

# 6 language pairs in total.
gpt3_translation_benchmarks = dict(
    wmt14=["en-fr", "fr-en"],  # French
    wmt16=["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
)


# The six pairs above, every WMT20 pair reported by sacrebleu, and IWSLT17
# English<->Arabic. Totalled 28 when written; the WMT20 share is queried from
# the installed sacrebleu, so the count may differ across versions.
selected_translation_benchmarks = dict(
    gpt3_translation_benchmarks,
    wmt20=sacrebleu.get_langpairs_for_testset("wmt20"),
    iwslt17=["en-ar", "ar-en"],  # Arabic
)

# Every language pair of every test set sacrebleu exposes. Totalled 319 when
# written; the exact count depends on the installed sacrebleu version.
all_translation_benchmarks = {
    testset_name: sacrebleu.get_langpairs_for_testset(testset_name)
    for testset_name in sacrebleu.get_available_testsets()
}


########################################
# All tasks
########################################