Пример #1
0
def test_seq_copy(train_params: str, translate_params: str,
                  use_prepared_data: bool, use_source_factors: bool):
    """
    Task: copy short sequences of digits.

    Creates a temporary digits dataset with unsorted targets and runs
    ``check_train_translate`` on it.

    :param train_params: Training parameters, forwarded to ``check_train_translate``.
    :param translate_params: Translation parameters, forwarded to ``check_train_translate``.
    :param use_prepared_data: Forwarded to ``check_train_translate``.
    :param use_source_factors: Forwarded to ``tmp_digits_dataset`` as ``with_source_factors``.
    """
    # All dataset settings in one place; expanded into keyword arguments below.
    dataset_options = dict(prefix="test_seq_copy",
                           train_line_count=_TRAIN_LINE_COUNT,
                           train_line_count_empty=_TRAIN_LINE_COUNT_EMPTY,
                           train_max_length=_LINE_MAX_LENGTH,
                           dev_line_count=_DEV_LINE_COUNT,
                           dev_max_length=_LINE_MAX_LENGTH,
                           test_line_count=_TEST_LINE_COUNT,
                           test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                           test_max_length=_TEST_MAX_LENGTH,
                           sort_target=False,
                           with_source_factors=use_source_factors)

    with tmp_digits_dataset(**dataset_options) as digits_data:
        # NOTE(review): C.SPACE_FOR_XOS presumably reserves room for sentence boundary symbols -- confirm.
        longest_sequence = _LINE_MAX_LENGTH + C.SPACE_FOR_XOS

        # TODO: Comparing translation scores with scoring scores is temporarily switched off, because it
        # sometimes produces inconsistent results for --batch-size > 1 (see issue #639 on github).
        check_train_translate(train_params=train_params,
                              translate_params=translate_params,
                              data=digits_data,
                              use_prepared_data=use_prepared_data,
                              max_seq_len=longest_sequence,
                              compare_output=False)
Пример #2
0
def test_seq_sort(name, train_params, translate_params, use_prepared_data,
                  n_source_factors, n_target_factors, perplexity_thresh, bleu_thresh):
    """
    Task: sort short sequences of digits.

    Creates a temporary digits dataset with sorted targets, runs ``check_train_translate`` on it and
    asserts that the best validation perplexity and the BLEU scores of both the regular and the
    restricted test outputs meet the given thresholds.

    :param name: Name of the test configuration; only used for logging.
    :param train_params: Forwarded to ``check_train_translate``.
    :param translate_params: Forwarded to ``check_train_translate``.
    :param use_prepared_data: Forwarded to ``check_train_translate``.
    :param n_source_factors: Forwarded to ``tmp_digits_dataset`` as ``with_n_source_factors``.
    :param n_target_factors: Forwarded to ``tmp_digits_dataset`` as ``with_n_target_factors``.
    :param perplexity_thresh: Upper bound for the best validation perplexity.
    :param bleu_thresh: Lower bound for BLEU of the regular and the restricted outputs.
    """
    with tmp_digits_dataset("test_seq_sort.",
                            _TRAIN_LINE_COUNT, _TRAIN_LINE_COUNT_EMPTY, _LINE_MAX_LENGTH,
                            _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                            _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH,
                            sort_target=True, seed_train=_SEED_TRAIN_DATA, seed_dev=_SEED_DEV_DATA,
                            with_n_source_factors=n_source_factors,
                            with_n_target_factors=n_target_factors) as data:
        # NOTE(review): `seed` is not a parameter of this test; presumably a module-level name -- confirm.
        data = check_train_translate(train_params=train_params,
                                     translate_params=translate_params,
                                     data=data,
                                     use_prepared_data=use_prepared_data,
                                     max_seq_len=_LINE_MAX_LENGTH,
                                     compare_output=True,
                                     seed=seed)

        # Lowest validation perplexity over all entries of the metrics file.
        metrics_path = os.path.join(data['model'], C.METRICS_NAME)
        perplexity = min(entry[C.PERPLEXITY + '-val']
                         for entry in sockeye.utils.read_metrics_file(metrics_path))

        # Pull the translated strings out of the regular and the restricted outputs.
        hypotheses = [output['translation'] for output in data['test_outputs']]
        hypotheses_restricted = [output['translation'] for output in data['test_outputs_restricted']]
        references = data['test_targets']

        # Score the translations against the test targets.
        bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses, references=references)
        chrf = sockeye.evaluate.raw_corpus_chrf(hypotheses=hypotheses, references=references)
        bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses_restricted,
                                                         references=references)

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f", perplexity, bleu, bleu_restrict, chrf)

        # chrF is only reported; the thresholds apply to perplexity and both BLEU scores.
        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
Пример #3
0
def test_seq_copy(name, train_params, translate_params, use_prepared_data,
                  perplexity_thresh, bleu_thresh):
    """
    Task: copy short sequences of digits.

    Creates a temporary digits dataset (unsorted targets, no source factors), runs
    ``check_train_translate`` on it and asserts that the best validation perplexity and the BLEU
    score of the test translations meet the given thresholds. If the returned data also contains
    restricted outputs (``'test_outputs_restricted'``), their BLEU score is checked as well.

    :param name: Name of the test configuration; only used for logging.
    :param train_params: Forwarded to ``check_train_translate``.
    :param translate_params: Forwarded to ``check_train_translate``.
    :param use_prepared_data: Forwarded to ``check_train_translate``.
    :param perplexity_thresh: Upper bound for the best validation perplexity.
    :param bleu_thresh: Lower bound for BLEU; also applied to the restricted outputs if present.
    """
    with tmp_digits_dataset(prefix="test_seq_copy",
                            train_line_count=_TRAIN_LINE_COUNT,
                            train_line_count_empty=_TRAIN_LINE_COUNT_EMPTY,
                            train_max_length=_LINE_MAX_LENGTH,
                            dev_line_count=_DEV_LINE_COUNT,
                            dev_max_length=_LINE_MAX_LENGTH,
                            test_line_count=_TEST_LINE_COUNT,
                            test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                            test_max_length=_TEST_MAX_LENGTH,
                            sort_target=False,
                            with_n_source_factors=0) as data:
        # NOTE(review): `seed` is not a parameter of this test; presumably a module-level name -- confirm.
        data = check_train_translate(train_params=train_params,
                                     translate_params=translate_params,
                                     data=data,
                                     use_prepared_data=use_prepared_data,
                                     max_seq_len=_LINE_MAX_LENGTH,
                                     compare_output=True,
                                     seed=seed)

        # get best (lowest) validation perplexity over all entries of the metrics file
        metrics = sockeye.utils.read_metrics_file(
            os.path.join(data['model'], C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        # compute metrics
        hypotheses = [output['translation'] for output in data['test_outputs']]
        bleu = sockeye.evaluate.raw_corpus_bleu(
            hypotheses=hypotheses, references=data['test_targets'])
        chrf = sockeye.evaluate.raw_corpus_chrf(
            hypotheses=hypotheses, references=data['test_targets'])

        # Restricted outputs are optional; without them there is no restricted BLEU score.
        if 'test_outputs_restricted' in data:
            hypotheses_restricted = [
                output['translation'] for output in data['test_outputs_restricted']
            ]
            bleu_restrict = sockeye.evaluate.raw_corpus_bleu(
                hypotheses=hypotheses_restricted,
                references=data['test_targets'])
        else:
            bleu_restrict = None

        # '%f' cannot format None (TypeError while the log record is emitted), so the optional
        # score is rendered up front and logged via '%s'. Output is unchanged when a score exists.
        bleu_restrict_text = "n/a" if bleu_restrict is None else "%f" % bleu_restrict

        logger.info("================")
        logger.info("test results: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%s chrf=%f",
                    perplexity, bleu, bleu_restrict_text, chrf)
        logger.info("================\n")

        # chrF is only reported; the thresholds apply to perplexity and BLEU.
        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        if bleu_restrict is not None:
            assert bleu_restrict >= bleu_thresh
Пример #4
0
def test_seq_copy(train_params: str, translate_params: str,
                  use_prepared_data: bool, use_source_factors: bool):
    """
    Task: copy short sequences of digits.

    Creates a temporary digits dataset with unsorted targets and runs
    ``check_train_translate`` on it.

    :param train_params: Training parameters, forwarded to ``check_train_translate``.
    :param translate_params: Translation parameters, forwarded to ``check_train_translate``.
    :param use_prepared_data: Forwarded to ``check_train_translate``.
    :param use_source_factors: Forwarded to ``tmp_digits_dataset`` as ``with_source_factors``.
    """
    # All dataset settings in one place; expanded into keyword arguments below.
    dataset_options = dict(prefix="test_seq_copy",
                           train_line_count=_TRAIN_LINE_COUNT,
                           train_max_length=_LINE_MAX_LENGTH,
                           dev_line_count=_DEV_LINE_COUNT,
                           dev_max_length=_LINE_MAX_LENGTH,
                           test_line_count=_TEST_LINE_COUNT,
                           test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                           test_max_length=_TEST_MAX_LENGTH,
                           sort_target=False,
                           with_source_factors=use_source_factors)

    with tmp_digits_dataset(**dataset_options) as digits_data:
        # NOTE(review): C.SPACE_FOR_XOS presumably reserves room for sentence boundary symbols -- confirm.
        longest_sequence = _LINE_MAX_LENGTH + C.SPACE_FOR_XOS
        check_train_translate(train_params=train_params,
                              translate_params=translate_params,
                              data=digits_data,
                              use_prepared_data=use_prepared_data,
                              max_seq_len=longest_sequence)