def test_constraints(train_params: str, beam_size: int, batch_size: int):
    """Integration test: digit-copy task decoded with target constraints.

    A temporary digits dataset is generated with unsorted targets (a pure
    copy task). The model is trained with ``train_params`` and decoded with
    the requested batch and beam sizes, passing ``use_target_constraints=True``
    to ``run_train_translate``. No second (equivalent) decoding configuration
    is requested (``translate_params_equiv=None``).

    The metrics returned by ``run_train_translate`` are discarded: this test
    only checks that the pipeline completes.

    :param train_params: Extra command-line arguments used for training.
    :param beam_size: Value passed to ``--beam-size`` when decoding.
    :param batch_size: Value passed to ``--batch-size`` when decoding.
    """
    dataset_config = dict(prefix="test_constraints",
                          train_line_count=_TRAIN_LINE_COUNT,
                          train_max_length=_LINE_MAX_LENGTH,
                          dev_line_count=_DEV_LINE_COUNT,
                          dev_max_length=_LINE_MAX_LENGTH,
                          test_line_count=_TEST_LINE_COUNT,
                          test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                          test_max_length=_TEST_MAX_LENGTH,
                          sort_target=False)

    with tmp_digits_dataset(**dataset_config) as data:
        decode_args = " --batch-size {} --beam-size {}".format(batch_size, beam_size)
        # Room for the extra sentence-boundary symbols on top of the raw line length.
        max_len = _LINE_MAX_LENGTH + C.SPACE_FOR_XOS

        # Return values (perplexity, BLEU, ...) are intentionally ignored.
        run_train_translate(train_params=train_params,
                            translate_params=decode_args,
                            translate_params_equiv=None,
                            train_source_path=data['source'],
                            train_target_path=data['target'],
                            dev_source_path=data['validation_source'],
                            dev_target_path=data['validation_target'],
                            test_source_path=data['test_source'],
                            test_target_path=data['test_target'],
                            max_seq_len=max_len,
                            work_dir=data['work_dir'],
                            use_prepared_data=False,
                            restrict_lexicon=False,
                            use_target_constraints=True)
# Example #2
# 0
def test_seq_copy(train_params: str,
                  translate_params: str,
                  restrict_lexicon: bool,
                  use_prepared_data: bool,
                  use_source_factors: bool,
                  use_constrained_decoding: bool):
    """Integration test: train and decode on a digit-copy task.

    Optionally, the token files themselves are reused as a single source
    factor (``use_source_factors``), or JSON input with target constraints is
    used for decoding (``use_constrained_decoding``). These two features are
    not combined in the tests.

    Decoding parameters are supplied twice: once as given, and once with
    ``--batch-size 2`` appended (``translate_params_equiv``). This lets the
    model configuration be tested together with batch vs. no-batch output
    equivalence. The returned metrics are ignored.
    """
    with tmp_digits_dataset(prefix="test_seq_copy",
                            train_line_count=_TRAIN_LINE_COUNT,
                            train_max_length=_LINE_MAX_LENGTH,
                            dev_line_count=_DEV_LINE_COUNT,
                            dev_max_length=_LINE_MAX_LENGTH,
                            test_line_count=_TEST_LINE_COUNT,
                            test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                            test_max_length=_TEST_MAX_LENGTH,
                            sort_target=False,
                            with_source_factors=use_source_factors,
                            with_target_constraints=use_constrained_decoding) as data:

        # The tests exercise at most one of these features at a time.
        assert not (use_source_factors and use_constrained_decoding)

        if use_source_factors:
            # The token files double as the (single) source-factor files.
            factor_paths = ([data['source']], [data['validation_source']], [data['test_source']])
        else:
            factor_paths = (None, None, None)
        train_factors, dev_factors, test_factors = factor_paths

        json_flag = " --json-input" if use_constrained_decoding else ""
        single_params = translate_params + json_flag
        # Same settings with batching, to compare batch and no-batch outputs.
        batched_params = single_params + " --batch-size 2"

        # Return values (perplexity, BLEU, ...) are intentionally ignored.
        run_train_translate(train_params=train_params,
                            translate_params=single_params,
                            translate_params_equiv=batched_params,
                            train_source_path=data['source'],
                            train_target_path=data['target'],
                            dev_source_path=data['validation_source'],
                            dev_target_path=data['validation_target'],
                            test_source_path=data['test_source'],
                            test_target_path=data['test_target'],
                            train_source_factor_paths=train_factors,
                            dev_source_factor_paths=dev_factors,
                            test_source_factor_paths=test_factors,
                            max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                            restrict_lexicon=restrict_lexicon,
                            work_dir=data['work_dir'],
                            use_prepared_data=use_prepared_data)
# Example #3
# 0
def test_seq_copy(train_params: str, translate_params: str,
                  restrict_lexicon: bool, use_prepared_data: bool,
                  use_source_factors: bool):
    """Integration test: train and decode on a digit-copy task.

    When ``use_source_factors`` is set, the train/dev/test source files are
    reused as single source-factor files. A batched decoding configuration
    (``--batch-size 2``) is requested for output-equivalence checking. The
    exception is when ``--nbest-size`` appears among the translate
    parameters, in which case no second configuration is passed. Target
    constraints are explicitly disabled, and the returned metrics are ignored.
    """
    with tmp_digits_dataset(prefix="test_seq_copy",
                            train_line_count=_TRAIN_LINE_COUNT,
                            train_max_length=_LINE_MAX_LENGTH,
                            dev_line_count=_DEV_LINE_COUNT,
                            dev_max_length=_LINE_MAX_LENGTH,
                            test_line_count=_TEST_LINE_COUNT,
                            test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                            test_max_length=_TEST_MAX_LENGTH,
                            sort_target=False,
                            with_source_factors=use_source_factors) as data:

        # Source factors, if requested, are simply the token files themselves.
        train_factors = [data['source']] if use_source_factors else None
        dev_factors = [data['validation_source']] if use_source_factors else None
        test_factors = [data['test_source']] if use_source_factors else None

        # nbest produces json output, which doesn't work with the splitting
        # of translations and scores in run_train_translate, which in turn
        # makes the comparison with the batch decoding fail.
        # TODO: Refactor the run_train_translate function!
        uses_nbest = "--nbest-size" in translate_params.split()
        batched_params = None if uses_nbest else translate_params + " --batch-size 2"

        # Return values (perplexity, BLEU, ...) are intentionally ignored.
        run_train_translate(
            train_params=train_params,
            translate_params=translate_params,
            translate_params_equiv=batched_params,
            train_source_path=data['source'],
            train_target_path=data['target'],
            dev_source_path=data['validation_source'],
            dev_target_path=data['validation_target'],
            test_source_path=data['test_source'],
            test_target_path=data['test_target'],
            train_source_factor_paths=train_factors,
            dev_source_factor_paths=dev_factors,
            test_source_factor_paths=test_factors,
            max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
            restrict_lexicon=restrict_lexicon,
            work_dir=data['work_dir'],
            use_prepared_data=use_prepared_data,
            use_target_constraints=False)
# Example #4
# 0
def test_seq_copy(train_params: str, translate_params: str,
                  restrict_lexicon: bool, use_prepared_data: bool,
                  use_source_factors: bool, use_constrained_decoding: bool):
    """Task: copy short sequences of digits.

    Trains on a temporary digits dataset and decodes its test split. Decoding
    parameters are supplied twice: as given, and with ``--batch-size 2``
    appended. This tests the model configuration together with the output
    equivalence of batch and no-batch decoding.

    Optional features:
    - ``use_source_factors``: the source files are reused as single
      source-factor files.
    - ``use_constrained_decoding``: ``--json-input`` is appended to the
      translate parameters, and the dataset is generated with
      ``with_target_constraints``.

    The two features are not combined. The returned metrics are ignored.
    """
    with tmp_digits_dataset(
            prefix="test_seq_copy",
            train_line_count=_TRAIN_LINE_COUNT,
            train_max_length=_LINE_MAX_LENGTH,
            dev_line_count=_DEV_LINE_COUNT,
            dev_max_length=_LINE_MAX_LENGTH,
            test_line_count=_TEST_LINE_COUNT,
            test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
            test_max_length=_TEST_MAX_LENGTH,
            sort_target=False,
            with_source_factors=use_source_factors,
            with_target_constraints=use_constrained_decoding) as data:

        # Only one of these is supported at a time in the tests
        assert not (use_source_factors and use_constrained_decoding)

        # When using source factors, the token files double as factor files.
        train_source_factor_paths, dev_source_factor_paths, test_source_factor_paths = None, None, None
        if use_source_factors:
            train_source_factor_paths = [data['source']]
            dev_source_factor_paths = [data['validation_source']]
            test_source_factor_paths = [data['test_source']]

        if use_constrained_decoding:
            translate_params += " --json-input"

        # Test model configuration, including the output equivalence of batch and no-batch decoding
        translate_params_batch = translate_params + " --batch-size 2"

        # Ignore return values (perplexity and BLEU) for integration test
        run_train_translate(
            train_params=train_params,
            translate_params=translate_params,
            translate_params_equiv=translate_params_batch,
            train_source_path=data['source'],
            train_target_path=data['target'],
            dev_source_path=data['validation_source'],
            dev_target_path=data['validation_target'],
            # Decode the dataset's test *source*, not the reference target.
            # This keeps the decoder input consistent with the other variants
            # of this test; presumably 'test_source' is also what the dataset
            # prepares as JSON input when with_target_constraints is set
            # (verify against tmp_digits_dataset).
            test_source_path=data['test_source'],
            test_target_path=data['test_target'],
            train_source_factor_paths=train_source_factor_paths,
            dev_source_factor_paths=dev_source_factor_paths,
            test_source_factor_paths=test_source_factor_paths,
            max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
            restrict_lexicon=restrict_lexicon,
            work_dir=data['work_dir'],
            use_prepared_data=use_prepared_data)
# Example #5
# 0
def test_other_clis(train_params: str, translate_params: str):
    """
    Task: test CLIs and core features other than train & translate.

    A minimal model is trained and decoded first. The dict returned by
    run_train_translate is then used to run, in order:
    - the checkpoint-decoder check
    - the parameter-averaging check
    - the parameter-extraction CLI check
    - the evaluation CLI check
    """
    with tmp_digits_dataset(prefix="test_other_clis",
                            train_line_count=_TRAIN_LINE_COUNT,
                            train_line_count_empty=_TRAIN_LINE_COUNT_EMPTY,
                            train_max_length=_LINE_MAX_LENGTH,
                            dev_line_count=_DEV_LINE_COUNT,
                            dev_max_length=_LINE_MAX_LENGTH,
                            test_line_count=_TEST_LINE_COUNT,
                            test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                            test_max_length=_TEST_MAX_LENGTH) as dataset:
        # Train (and decode with) a minimal default model.
        results = run_train_translate(train_params=train_params,
                                      translate_params=translate_params,
                                      data=dataset,
                                      max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS)

        _test_checkpoint_decoder(results['dev_source'], results['dev_target'], results['model'])
        _test_parameter_averaging(results['model'])
        _test_extract_parameters_cli(results['model'])
        _test_evaluate_cli(results['test_outputs'], results['test_target'])
# Example #6
# 0
def test_seq_copy(name, train_params, translate_params, use_prepared_data,
                  perplexity_thresh, bleu_thresh):
    """System test: a digit-copy model must reach quality thresholds.

    Trains on a seeded digits dataset and decodes the test split with
    ``restrict_lexicon=True``. It then asserts that:
    - perplexity is at most ``perplexity_thresh``;
    - both the plain and the lexicon-restricted BLEU are at least
      ``bleu_thresh``.

    chrF is only logged, not asserted.

    NOTE(review): ``seed`` is not a parameter of this function and must come
    from module scope (or elsewhere) -- confirm where it is defined.
    """
    dataset_args = ("test_seq_copy.",
                    _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH,
                    _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                    _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH)
    with tmp_digits_dataset(*dataset_args,
                            seed_train=_SEED_TRAIN_DATA,
                            seed_dev=_SEED_DEV_DATA) as data:
        corpus_paths = (data['source'], data['target'],
                        data['validation_source'], data['validation_target'],
                        data['test_source'], data['test_target'])
        # The third positional argument is None: no second set of translate parameters.
        scores = run_train_translate(train_params, translate_params, None, *corpus_paths,
                                     use_prepared_data=use_prepared_data,
                                     max_seq_len=_LINE_MAX_LENGTH + 1,
                                     restrict_lexicon=True,
                                     work_dir=data['work_dir'],
                                     seed=seed)
        perplexity, bleu, bleu_restrict, chrf = scores

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f",
                    perplexity, bleu, bleu_restrict, chrf)
        assert perplexity <= perplexity_thresh
        for bleu_score in (bleu, bleu_restrict):
            assert bleu_score >= bleu_thresh
# Example #7
# 0
def test_seq_copy(train_params, translate_params, perplexity_thresh,
                  bleu_thresh):
    """Task: copy short sequences of digits (path-based variant).

    Writes digit files for train and dev into a temporary directory, then
    trains and decodes. Asserts that perplexity <= ``perplexity_thresh`` and
    BLEU >= ``bleu_thresh``.
    """
    with TemporaryDirectory(prefix="test_seq_copy.") as work_dir:
        file_names = ("train.src", "train.tgt", "dev.src", "dev.tgt")
        train_source_path, train_target_path, dev_source_path, dev_target_path = (
            os.path.join(work_dir, file_name) for file_name in file_names)

        # Simple digit files for the train and dev data.
        generate_digits_file(train_source_path, train_target_path, _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH)
        generate_digits_file(dev_source_path, dev_target_path, _DEV_LINE_COUNT, _LINE_MAX_LENGTH)

        perplexity, bleu = run_train_translate(train_params, translate_params,
                                               train_source_path, train_target_path,
                                               dev_source_path, dev_target_path,
                                               max_seq_len=_LINE_MAX_LENGTH + 1,
                                               work_dir=work_dir)
        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
# Example #8
# 0
def test_seq_sort(name, train_params, translate_params, use_prepared_data,
                  use_source_factor, perplexity_thresh, bleu_thresh):
    """System test: learn to sort short sequences of digits.

    The dataset is generated with sorted targets and fixed seeds, optionally
    with source factors. Source-factor paths are read with ``data.get``, so
    they are None when the dataset has no such entry. After training and
    decoding (with ``restrict_lexicon=True``), the test asserts perplexity
    and both BLEU scores against the thresholds. chrF is only logged.

    NOTE(review): ``seed`` is not a parameter of this function and must come
    from an enclosing scope -- confirm where it is defined.
    """
    with tmp_digits_dataset("test_seq_sort.", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                            _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH,
                            sort_target=True, seed_train=_SEED_TRAIN_DATA, seed_dev=_SEED_DEV_DATA,
                            with_source_factors=use_source_factor) as data:
        run_kwargs = dict(train_params=train_params,
                          translate_params=translate_params,
                          translate_params_equiv=None,
                          train_source_path=data['source'],
                          train_target_path=data['target'],
                          dev_source_path=data['validation_source'],
                          dev_target_path=data['validation_target'],
                          test_source_path=data['test_source'],
                          test_target_path=data['test_target'],
                          train_source_factor_paths=data.get('train_source_factors'),
                          dev_source_factor_paths=data.get('dev_source_factors'),
                          test_source_factor_paths=data.get('test_source_factors'),
                          use_prepared_data=use_prepared_data,
                          max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                          restrict_lexicon=True,
                          work_dir=data['work_dir'],
                          seed=seed)
        perplexity, bleu, bleu_restrict, chrf = run_train_translate(**run_kwargs)

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f", perplexity, bleu, bleu_restrict, chrf)
        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
def test_seq_sort(name, train_params, translate_params, use_prepared_data,
                  use_source_factor, perplexity_thresh, bleu_thresh):
    """Task: sort short sequences of digits.

    Builds a seeded digits dataset whose targets are the sorted sources
    (``sort_target=True``), optionally with source factors. It then trains,
    decodes with a restricted lexicon, and checks:
    - perplexity <= ``perplexity_thresh``;
    - BLEU >= ``bleu_thresh``, both unrestricted and restricted.

    NOTE(review): ``seed`` is not a parameter here; it is resolved from an
    enclosing scope -- confirm its definition.
    """
    with tmp_digits_dataset("test_seq_sort.", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                            _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH,
                            sort_target=True, seed_train=_SEED_TRAIN_DATA, seed_dev=_SEED_DEV_DATA,
                            with_source_factors=use_source_factor) as data:
        source_train, target_train = data['source'], data['target']
        source_dev, target_dev = data['validation_source'], data['validation_target']
        source_test, target_test = data['test_source'], data['test_target']
        # Missing factor entries yield None, which disables source factors.
        factors_train = data.get('train_source_factors')
        factors_dev = data.get('dev_source_factors')
        factors_test = data.get('test_source_factors')

        outcome = run_train_translate(train_params=train_params,
                                      translate_params=translate_params,
                                      translate_params_equiv=None,
                                      train_source_path=source_train,
                                      train_target_path=target_train,
                                      dev_source_path=source_dev,
                                      dev_target_path=target_dev,
                                      test_source_path=source_test,
                                      test_target_path=target_test,
                                      train_source_factor_paths=factors_train,
                                      dev_source_factor_paths=factors_dev,
                                      test_source_factor_paths=factors_test,
                                      use_prepared_data=use_prepared_data,
                                      max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                                      restrict_lexicon=True,
                                      work_dir=data['work_dir'],
                                      seed=seed)
        perplexity, bleu, bleu_restrict, chrf = outcome

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f", perplexity, bleu, bleu_restrict, chrf)
        assert perplexity <= perplexity_thresh
        for bleu_value in (bleu, bleu_restrict):
            assert bleu_value >= bleu_thresh
# Example #10
# 0
def test_seq_sort(name, train_params, translate_params, perplexity_thresh,
                  bleu_thresh):
    """Task: sort short sequences of digits (train/dev-only variant).

    Trains and decodes with a restricted lexicon on a seeded, target-sorted
    digits dataset. Asserts perplexity and both BLEU scores (plain and
    lexicon-restricted) against the given thresholds.
    """
    with tmp_digits_dataset("test_seq_sort.",
                            _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH,
                            _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                            sort_target=True,
                            seed_train=_SEED_TRAIN,
                            seed_dev=_SEED_DEV) as data:
        corpus = [data[key] for key in ('source', 'target', 'validation_source', 'validation_target')]
        # The third positional argument is None: no second set of translate parameters.
        perplexity, bleu, bleu_restrict = run_train_translate(train_params, translate_params, None, *corpus,
                                                              max_seq_len=_LINE_MAX_LENGTH + 1,
                                                              restrict_lexicon=True,
                                                              work_dir=data['work_dir'])

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f", perplexity, bleu, bleu_restrict)
        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
# Example #11
# 0
def test_seq_copy(train_params: str, translate_params: str, restrict_lexicon: bool):
    """Integration test: digit copy, checking batch vs. no-batch decoding.

    The model is decoded with ``translate_params`` and with
    ``--batch-size 2`` appended; the latter is passed as the third positional
    argument of run_train_translate. The returned metrics are ignored.
    """
    with tmp_digits_dataset("test_seq_copy", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT,
                            _LINE_MAX_LENGTH) as data:
        batched_params = translate_params + " --batch-size 2"
        corpus = (data['source'], data['target'], data['validation_source'], data['validation_target'])

        # Return values (perplexity, BLEU) are intentionally ignored.
        run_train_translate(train_params, translate_params, batched_params, *corpus,
                            max_seq_len=_LINE_MAX_LENGTH + 1,
                            restrict_lexicon=restrict_lexicon,
                            work_dir=data['work_dir'])
# Example #12
# 0
def test_seq_copy(train_params, translate_params):
    """Task: copy short sequences of digits (path-based integration test).

    Generates train and dev digit files in a temporary directory, then runs
    training and decoding. The returned metrics are ignored; the test only
    checks that the pipeline completes.
    """
    with TemporaryDirectory(prefix="test_seq_copy") as work_dir:
        def in_work_dir(file_name):
            return os.path.join(work_dir, file_name)

        train_source_path = in_work_dir("train.src")
        train_target_path = in_work_dir("train.tgt")
        dev_source_path = in_work_dir("dev.src")
        dev_target_path = in_work_dir("dev.tgt")

        # Simple digit files for the train and dev data.
        generate_digits_file(train_source_path, train_target_path, _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH)
        generate_digits_file(dev_source_path, dev_target_path, _DEV_LINE_COUNT, _LINE_MAX_LENGTH)

        # Return values (perplexity, BLEU) are intentionally ignored.
        run_train_translate(train_params, translate_params,
                            train_source_path, train_target_path,
                            dev_source_path, dev_target_path,
                            max_seq_len=_LINE_MAX_LENGTH + 1,
                            work_dir=work_dir)
# Example #13
# 0
def test_constraints(train_params: str, translate_params: str):
    """Train a small copy model, then exercise constrained decoding.

    After run_train_translate returns the updated data dict,
    ``_test_constrained_type`` is called twice:
    - with "constraints" (positive constraints: tokens that must appear);
    - with "avoid" (negative constraints: tokens that must not appear).
    """
    with tmp_digits_dataset(prefix="test_constraints",
                            train_line_count=_TRAIN_LINE_COUNT,
                            train_max_length=_LINE_MAX_LENGTH,
                            dev_line_count=_DEV_LINE_COUNT,
                            dev_max_length=_LINE_MAX_LENGTH,
                            test_line_count=_TEST_LINE_COUNT,
                            test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                            test_max_length=_TEST_MAX_LENGTH,
                            sort_target=False) as dataset:
        # Train a minimal default model.
        trained = run_train_translate(train_params=train_params,
                                      translate_params=translate_params,
                                      data=dataset,
                                      max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS)

        for kind in ("constraints", "avoid"):
            _test_constrained_type(constraint_type=kind, data=trained, translate_params=translate_params)