Ejemplo n.º 1
0
def test_seq_copy(train_params: str, translate_params: str,
                  use_prepared_data: bool, use_source_factors: bool):
    """
    Task: copy short sequences of digits.

    Creates a temporary digits dataset (targets left unsorted, source factors
    optional) and hands it to check_train_translate.

    :param train_params: Forwarded to check_train_translate as train_params.
    :param translate_params: Forwarded to check_train_translate as translate_params.
    :param use_prepared_data: Forwarded to check_train_translate as use_prepared_data.
    :param use_source_factors: Passed to the dataset helper as with_source_factors.
    """
    dataset_args = dict(prefix="test_seq_copy",
                        train_line_count=_TRAIN_LINE_COUNT,
                        train_line_count_empty=_TRAIN_LINE_COUNT_EMPTY,
                        train_max_length=_LINE_MAX_LENGTH,
                        dev_line_count=_DEV_LINE_COUNT,
                        dev_max_length=_LINE_MAX_LENGTH,
                        test_line_count=_TEST_LINE_COUNT,
                        test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                        test_max_length=_TEST_MAX_LENGTH,
                        sort_target=False,
                        with_source_factors=use_source_factors)

    with tmp_digits_dataset(**dataset_args) as data:
        max_seq_len = _LINE_MAX_LENGTH + C.SPACE_FOR_XOS

        # TODO: Comparing translation and scoring scores is temporarily disabled because it
        # sometimes produces inconsistent results for --batch-size > 1 (see issue #639 on github).
        check_train_translate(train_params=train_params,
                              translate_params=translate_params,
                              data=data,
                              use_prepared_data=use_prepared_data,
                              max_seq_len=max_seq_len,
                              compare_output=False)
Ejemplo n.º 2
0
def test_other_clis(train_params: str, translate_params: str):
    """
    Task: test CLIs and core features other than train & translate.

    Trains and translates once on a temporary digits dataset, then runs the
    checkpoint decoder, parameter averaging, parameter extraction and
    evaluation checks against the returned data dictionary.

    :param train_params: Forwarded to run_train_translate as train_params.
    :param translate_params: Forwarded to run_train_translate as translate_params.
    """
    dataset_args = dict(prefix="test_other_clis",
                        train_line_count=_TRAIN_LINE_COUNT,
                        train_line_count_empty=_TRAIN_LINE_COUNT_EMPTY,
                        train_max_length=_LINE_MAX_LENGTH,
                        dev_line_count=_DEV_LINE_COUNT,
                        dev_max_length=_LINE_MAX_LENGTH,
                        test_line_count=_TEST_LINE_COUNT,
                        test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                        test_max_length=_TEST_MAX_LENGTH)

    with tmp_digits_dataset(**dataset_args) as data:
        # A minimal default model is trained first; the remaining checks reuse its outputs.
        results = run_train_translate(train_params=train_params,
                                      translate_params=translate_params,
                                      data=data,
                                      max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS)
        model_path = results['model']

        _test_checkpoint_decoder(results['dev_source'], results['dev_target'], model_path)
        _test_parameter_averaging(model_path)
        _test_extract_parameters_cli(model_path)
        _test_evaluate_cli(results['test_outputs'], results['test_target'])
Ejemplo n.º 3
0
def test_seq_sort(name, train_params, translate_params, use_prepared_data,
                  use_source_factor, perplexity_thresh, bleu_thresh):
    """
    Task: sort short sequences of digits.

    :param name: Test name; only used for logging.
    :param train_params: Forwarded to run_train_translate as train_params.
    :param translate_params: Forwarded to run_train_translate as translate_params.
    :param use_prepared_data: Forwarded to run_train_translate as use_prepared_data.
    :param use_source_factor: Passed to the dataset helper as with_source_factors.
    :param perplexity_thresh: Upper bound asserted on the returned perplexity.
    :param bleu_thresh: Lower bound asserted on both returned BLEU values.
    """
    dataset_args = ("test_seq_sort.", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                    _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH)
    with tmp_digits_dataset(*dataset_args,
                            sort_target=True,
                            seed_train=_SEED_TRAIN_DATA,
                            seed_dev=_SEED_DEV_DATA,
                            with_source_factors=use_source_factor) as data:
        # Test model configuration
        # NOTE(review): `seed` is neither a parameter nor a local here; presumably a
        # module-level name defined elsewhere in the file -- confirm.
        run_args = dict(train_params=train_params,
                        translate_params=translate_params,
                        translate_params_equiv=None,
                        train_source_path=data['source'],
                        train_target_path=data['target'],
                        dev_source_path=data['validation_source'],
                        dev_target_path=data['validation_target'],
                        test_source_path=data['test_source'],
                        test_target_path=data['test_target'],
                        train_source_factor_paths=data.get('train_source_factors'),
                        dev_source_factor_paths=data.get('dev_source_factors'),
                        test_source_factor_paths=data.get('test_source_factors'),
                        use_prepared_data=use_prepared_data,
                        max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                        restrict_lexicon=True,
                        work_dir=data['work_dir'],
                        seed=seed)
        scores = run_train_translate(**run_args)
        perplexity, bleu, bleu_restrict, chrf = scores

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f", perplexity, bleu, bleu_restrict, chrf)

        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
Ejemplo n.º 4
0
def test_seq_copy(name, train_params, translate_params, use_prepared_data,
                  perplexity_thresh, bleu_thresh):
    """
    Task: copy short sequences of digits.

    :param name: Test name; only used for logging.
    :param train_params: First positional argument to run_train_translate.
    :param translate_params: Second positional argument to run_train_translate.
    :param use_prepared_data: Forwarded to run_train_translate as use_prepared_data.
    :param perplexity_thresh: Upper bound asserted on the returned perplexity.
    :param bleu_thresh: Lower bound asserted on both returned BLEU values.
    """
    dataset_args = ("test_seq_copy.",
                    _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH,
                    _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                    _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH)
    with tmp_digits_dataset(*dataset_args,
                            seed_train=_SEED_TRAIN_DATA,
                            seed_dev=_SEED_DEV_DATA) as data:
        # Test model configuration
        # NOTE(review): `seed` is neither a parameter nor a local here; presumably a
        # module-level name defined elsewhere in the file -- confirm.
        scores = run_train_translate(train_params,
                                     translate_params,
                                     None,  # no second set of parameters
                                     data['source'],
                                     data['target'],
                                     data['validation_source'],
                                     data['validation_target'],
                                     data['test_source'],
                                     data['test_target'],
                                     use_prepared_data=use_prepared_data,
                                     max_seq_len=_LINE_MAX_LENGTH + 1,
                                     restrict_lexicon=True,
                                     work_dir=data['work_dir'],
                                     seed=seed)
        perplexity, bleu, bleu_restrict, chrf = scores

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f",
                    perplexity, bleu, bleu_restrict, chrf)

        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
Ejemplo n.º 5
0
def test_seq_sort(name, train_params, translate_params, use_prepared_data,
                  use_source_factor, perplexity_thresh, bleu_thresh):
    """
    Task: sort short sequences of digits.

    Runs check_train_translate on a temporary digits dataset, then derives the
    best validation perplexity from the model's metrics file and computes
    BLEU/chrF on the test outputs before checking them against the thresholds.

    :param name: Test name; only used for logging.
    :param train_params: Forwarded to check_train_translate as train_params.
    :param translate_params: Forwarded to check_train_translate as translate_params.
    :param use_prepared_data: Forwarded to check_train_translate as use_prepared_data.
    :param use_source_factor: Passed to the dataset helper as with_source_factors.
    :param perplexity_thresh: Upper bound asserted on the best validation perplexity.
    :param bleu_thresh: Lower bound asserted on both BLEU values.
    """
    dataset_args = ("test_seq_sort.", _TRAIN_LINE_COUNT, _TRAIN_LINE_COUNT_EMPTY, _LINE_MAX_LENGTH,
                    _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                    _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH)
    with tmp_digits_dataset(*dataset_args,
                            sort_target=True,
                            seed_train=_SEED_TRAIN_DATA,
                            seed_dev=_SEED_DEV_DATA,
                            with_source_factors=use_source_factor) as data:
        # NOTE(review): `seed` is neither a parameter nor a local here; presumably a
        # module-level name defined elsewhere in the file -- confirm.
        results = check_train_translate(train_params=train_params,
                                        translate_params=translate_params,
                                        data=data,
                                        use_prepared_data=use_prepared_data,
                                        max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                                        compare_output=True,
                                        seed=seed)

        # Best (lowest) validation perplexity recorded in the metrics file
        metrics_path = os.path.join(results['model'], C.METRICS_NAME)
        metrics = sockeye.utils.read_metrics_file(metrics_path)
        perplexity = min(entry[C.PERPLEXITY + '-val'] for entry in metrics)

        # Corpus-level scores for the regular and the restricted test outputs
        references = results['test_targets']
        bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=results['test_outputs'], references=references)
        chrf = sockeye.evaluate.raw_corpus_chrf(hypotheses=results['test_outputs'], references=references)
        bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=results['test_outputs_restricted'],
                                                         references=references)

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f", perplexity, bleu, bleu_restrict, chrf)

        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
def test_seq_sort(name, train_params, translate_params, use_prepared_data,
                  use_source_factor, perplexity_thresh, bleu_thresh):
    """
    Task: sort short sequences of digits.

    :param name: Test name; only used for logging.
    :param train_params: Forwarded to run_train_translate as train_params.
    :param translate_params: Forwarded to run_train_translate as translate_params.
    :param use_prepared_data: Forwarded to run_train_translate as use_prepared_data.
    :param use_source_factor: Passed to the dataset helper as with_source_factors.
    :param perplexity_thresh: Upper bound asserted on the returned perplexity.
    :param bleu_thresh: Lower bound asserted on both returned BLEU values.
    """
    positional = ("test_seq_sort.", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                  _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH)
    with tmp_digits_dataset(*positional,
                            sort_target=True,
                            seed_train=_SEED_TRAIN_DATA,
                            seed_dev=_SEED_DEV_DATA,
                            with_source_factors=use_source_factor) as data:
        # Test model configuration
        # NOTE(review): `seed` is neither a parameter nor a local here; presumably a
        # module-level name defined elsewhere in the file -- confirm.
        run_args = dict(train_params=train_params,
                        translate_params=translate_params,
                        translate_params_equiv=None,
                        train_source_path=data['source'],
                        train_target_path=data['target'],
                        dev_source_path=data['validation_source'],
                        dev_target_path=data['validation_target'],
                        test_source_path=data['test_source'],
                        test_target_path=data['test_target'],
                        train_source_factor_paths=data.get('train_source_factors'),
                        dev_source_factor_paths=data.get('dev_source_factors'),
                        test_source_factor_paths=data.get('test_source_factors'),
                        use_prepared_data=use_prepared_data,
                        max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                        restrict_lexicon=True,
                        work_dir=data['work_dir'],
                        seed=seed)
        perplexity, bleu, bleu_restrict, chrf = run_train_translate(**run_args)

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f", perplexity, bleu, bleu_restrict, chrf)

        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
def test_constraints(train_params: str, beam_size: int, batch_size: int):
    """
    Task: copy short sequences of digits, with use_target_constraints enabled.

    :param train_params: Forwarded to run_train_translate as train_params.
    :param beam_size: Inserted into the translate parameters as --beam-size.
    :param batch_size: Inserted into the translate parameters as --batch-size.
    """
    dataset_args = dict(prefix="test_constraints",
                        train_line_count=_TRAIN_LINE_COUNT,
                        train_max_length=_LINE_MAX_LENGTH,
                        dev_line_count=_DEV_LINE_COUNT,
                        dev_max_length=_LINE_MAX_LENGTH,
                        test_line_count=_TEST_LINE_COUNT,
                        test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                        test_max_length=_TEST_MAX_LENGTH,
                        sort_target=False)

    with tmp_digits_dataset(**dataset_args) as data:
        translate_params = " --batch-size {} --beam-size {}".format(batch_size, beam_size)

        run_args = dict(train_params=train_params,
                        translate_params=translate_params,
                        translate_params_equiv=None,
                        train_source_path=data['source'],
                        train_target_path=data['target'],
                        dev_source_path=data['validation_source'],
                        dev_target_path=data['validation_target'],
                        test_source_path=data['test_source'],
                        test_target_path=data['test_target'],
                        max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                        work_dir=data['work_dir'],
                        use_prepared_data=False,
                        restrict_lexicon=False,
                        use_target_constraints=True)

        # The returned scores (perplexity and BLEU) are deliberately discarded in this integration test
        run_train_translate(**run_args)
Ejemplo n.º 8
0
def test_seq_sort(name, train_params, translate_params, perplexity_thresh,
                  bleu_thresh):
    """
    Task: sort short sequences of digits.

    :param name: Test name; only used for logging.
    :param train_params: First positional argument to run_train_translate.
    :param translate_params: Second positional argument to run_train_translate.
    :param perplexity_thresh: Upper bound asserted on the returned perplexity.
    :param bleu_thresh: Lower bound asserted on both returned BLEU values.
    """
    dataset_args = ("test_seq_sort.",
                    _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH,
                    _DEV_LINE_COUNT, _LINE_MAX_LENGTH)
    with tmp_digits_dataset(*dataset_args,
                            sort_target=True,
                            seed_train=_SEED_TRAIN,
                            seed_dev=_SEED_DEV) as data:
        # Test model configuration
        scores = run_train_translate(train_params,
                                     translate_params,
                                     None,  # no second set of parameters
                                     data['source'],
                                     data['target'],
                                     data['validation_source'],
                                     data['validation_target'],
                                     max_seq_len=_LINE_MAX_LENGTH + 1,
                                     restrict_lexicon=True,
                                     work_dir=data['work_dir'])
        perplexity, bleu, bleu_restrict = scores

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f", perplexity,
                    bleu, bleu_restrict)

        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
Ejemplo n.º 9
0
def test_seq_copy(train_params: str,
                  translate_params: str,
                  restrict_lexicon: bool,
                  use_prepared_data: bool,
                  use_source_factors: bool,
                  use_constrained_decoding: bool):
    """
    Task: copy short sequences of digits.

    :param train_params: Forwarded to run_train_translate as train_params.
    :param translate_params: Base translate parameters; " --json-input" is appended
                             when constrained decoding is used.
    :param restrict_lexicon: Forwarded to run_train_translate as restrict_lexicon.
    :param use_prepared_data: Forwarded to run_train_translate as use_prepared_data.
    :param use_source_factors: If set, the source files are reused as source factor files.
    :param use_constrained_decoding: Passed to the dataset helper as with_target_constraints.
    """
    dataset_args = dict(prefix="test_seq_copy",
                        train_line_count=_TRAIN_LINE_COUNT,
                        train_max_length=_LINE_MAX_LENGTH,
                        dev_line_count=_DEV_LINE_COUNT,
                        dev_max_length=_LINE_MAX_LENGTH,
                        test_line_count=_TEST_LINE_COUNT,
                        test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                        test_max_length=_TEST_MAX_LENGTH,
                        sort_target=False,
                        with_source_factors=use_source_factors,
                        with_target_constraints=use_constrained_decoding)

    with tmp_digits_dataset(**dataset_args) as data:

        # The tests support source factors or constrained decoding, never both together
        assert not (use_source_factors and use_constrained_decoding)

        # Source factor files default to None and are only filled in on request
        train_source_factor_paths = [data['source']] if use_source_factors else None
        dev_source_factor_paths = [data['validation_source']] if use_source_factors else None
        test_source_factor_paths = [data['test_source']] if use_source_factors else None

        if use_constrained_decoding:
            translate_params = translate_params + " --json-input"

        # The batched variant is used to check output equivalence of batch and no-batch decoding
        translate_params_batch = translate_params + " --batch-size 2"

        run_args = dict(train_params=train_params,
                        translate_params=translate_params,
                        translate_params_equiv=translate_params_batch,
                        train_source_path=data['source'],
                        train_target_path=data['target'],
                        dev_source_path=data['validation_source'],
                        dev_target_path=data['validation_target'],
                        test_source_path=data['test_source'],
                        test_target_path=data['test_target'],
                        train_source_factor_paths=train_source_factor_paths,
                        dev_source_factor_paths=dev_source_factor_paths,
                        test_source_factor_paths=test_source_factor_paths,
                        max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                        restrict_lexicon=restrict_lexicon,
                        work_dir=data['work_dir'],
                        use_prepared_data=use_prepared_data)

        # The returned scores (perplexity and BLEU) are deliberately discarded in this integration test
        run_train_translate(**run_args)
Ejemplo n.º 10
0
def test_seq_copy(train_params: str, translate_params: str,
                  restrict_lexicon: bool, use_prepared_data: bool,
                  use_source_factors: bool):
    """
    Task: copy short sequences of digits.

    :param train_params: Forwarded to run_train_translate as train_params.
    :param translate_params: Forwarded to run_train_translate as translate_params.
    :param restrict_lexicon: Forwarded to run_train_translate as restrict_lexicon.
    :param use_prepared_data: Forwarded to run_train_translate as use_prepared_data.
    :param use_source_factors: If set, the source files are reused as source factor files.
    """
    dataset_args = dict(prefix="test_seq_copy",
                        train_line_count=_TRAIN_LINE_COUNT,
                        train_max_length=_LINE_MAX_LENGTH,
                        dev_line_count=_DEV_LINE_COUNT,
                        dev_max_length=_LINE_MAX_LENGTH,
                        test_line_count=_TEST_LINE_COUNT,
                        test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                        test_max_length=_TEST_MAX_LENGTH,
                        sort_target=False,
                        with_source_factors=use_source_factors)

    with tmp_digits_dataset(**dataset_args) as data:

        # Source factor files default to None and are only filled in on request
        train_source_factor_paths = [data['source']] if use_source_factors else None
        dev_source_factor_paths = [data['validation_source']] if use_source_factors else None
        test_source_factor_paths = [data['test_source']] if use_source_factors else None

        # Test model configuration, including the output equivalence of batch and no-batch decoding
        if "--nbest-size" in translate_params.split():
            # nbest produces json output, which doesn't work with the splitting
            # of translations and scores in run_train_translate, which in turn
            # makes the comparison with the batch decoding fail.
            # TODO: Refactor the run_train_translate function!
            translate_params_batch = None
        else:
            translate_params_batch = translate_params + " --batch-size 2"

        run_args = dict(train_params=train_params,
                        translate_params=translate_params,
                        translate_params_equiv=translate_params_batch,
                        train_source_path=data['source'],
                        train_target_path=data['target'],
                        dev_source_path=data['validation_source'],
                        dev_target_path=data['validation_target'],
                        test_source_path=data['test_source'],
                        test_target_path=data['test_target'],
                        train_source_factor_paths=train_source_factor_paths,
                        dev_source_factor_paths=dev_source_factor_paths,
                        test_source_factor_paths=test_source_factor_paths,
                        max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                        restrict_lexicon=restrict_lexicon,
                        work_dir=data['work_dir'],
                        use_prepared_data=use_prepared_data,
                        use_target_constraints=False)

        # The returned scores (perplexity and BLEU) are deliberately discarded in this integration test
        run_train_translate(**run_args)
Ejemplo n.º 11
0
def test_seq_copy(train_params: str, translate_params: str,
                  restrict_lexicon: bool, use_prepared_data: bool,
                  use_source_factors: bool, use_constrained_decoding: bool):
    """
    Task: copy short sequences of digits.

    :param train_params: Forwarded to run_train_translate as train_params.
    :param translate_params: Base translate parameters; " --json-input" is appended
                             when constrained decoding is used.
    :param restrict_lexicon: Forwarded to run_train_translate as restrict_lexicon.
    :param use_prepared_data: Forwarded to run_train_translate as use_prepared_data.
    :param use_source_factors: If set, the source files are reused as source factor files.
    :param use_constrained_decoding: Passed to the dataset helper as with_target_constraints.
    """

    with tmp_digits_dataset(
            prefix="test_seq_copy",
            train_line_count=_TRAIN_LINE_COUNT,
            train_max_length=_LINE_MAX_LENGTH,
            dev_line_count=_DEV_LINE_COUNT,
            dev_max_length=_LINE_MAX_LENGTH,
            test_line_count=_TEST_LINE_COUNT,
            test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
            test_max_length=_TEST_MAX_LENGTH,
            sort_target=False,
            with_source_factors=use_source_factors,
            with_target_constraints=use_constrained_decoding) as data:

        # Only one of these is supported at a time in the tests
        assert not (use_source_factors and use_constrained_decoding)

        # When using source factors, the source files double as factor files
        train_source_factor_paths, dev_source_factor_paths, test_source_factor_paths = None, None, None
        if use_source_factors:
            train_source_factor_paths = [data['source']]
            dev_source_factor_paths = [data['validation_source']]
            test_source_factor_paths = [data['test_source']]

        if use_constrained_decoding:
            translate_params += " --json-input"

        # Test model configuration, including the output equivalence of batch and no-batch decoding
        translate_params_batch = translate_params + " --batch-size 2"

        # Ignore return values (perplexity and BLEU) for integration test
        run_train_translate(
            train_params=train_params,
            translate_params=translate_params,
            translate_params_equiv=translate_params_batch,
            train_source_path=data['source'],
            train_target_path=data['target'],
            dev_source_path=data['validation_source'],
            dev_target_path=data['validation_target'],
            # The test input must be the test *source* file: it is the file the dataset helper
            # prepares as translation input and the one the test source factors refer to.
            # Passing the target file here would decode the reference instead.
            test_source_path=data['test_source'],
            test_target_path=data['test_target'],
            train_source_factor_paths=train_source_factor_paths,
            dev_source_factor_paths=dev_source_factor_paths,
            test_source_factor_paths=test_source_factor_paths,
            max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
            restrict_lexicon=restrict_lexicon,
            work_dir=data['work_dir'],
            use_prepared_data=use_prepared_data)
Ejemplo n.º 12
0
def test_seq_copy(train_params: str, translate_params: str, restrict_lexicon: bool):
    """
    Task: copy short sequences of digits.

    :param train_params: First positional argument to run_train_translate.
    :param translate_params: Second positional argument to run_train_translate; a batched
                             variant (" --batch-size 2" appended) is passed as the third.
    :param restrict_lexicon: Forwarded to run_train_translate as restrict_lexicon.
    """
    dataset_args = ("test_seq_copy", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH)
    with tmp_digits_dataset(*dataset_args) as data:
        # The batched variant is used to check output equivalence of batch and no-batch decoding
        translate_params_batch = translate_params + " --batch-size 2"

        positional = (train_params,
                      translate_params,
                      translate_params_batch,
                      data['source'],
                      data['target'],
                      data['validation_source'],
                      data['validation_target'])

        # The returned scores (perplexity and BLEU) are deliberately discarded in this integration test
        run_train_translate(*positional,
                            max_seq_len=_LINE_MAX_LENGTH + 1,
                            restrict_lexicon=restrict_lexicon,
                            work_dir=data['work_dir'])
Ejemplo n.º 13
0
def test_constraints(train_params: str, translate_params: str):
    """
    Trains a minimal model on a temporary digits dataset, then runs the
    constrained-decoding check once per constraint type.

    :param train_params: Forwarded to run_train_translate as train_params.
    :param translate_params: Forwarded to run_train_translate and to each constraint check.
    """
    dataset_args = dict(prefix="test_constraints",
                        train_line_count=_TRAIN_LINE_COUNT,
                        train_max_length=_LINE_MAX_LENGTH,
                        dev_line_count=_DEV_LINE_COUNT,
                        dev_max_length=_LINE_MAX_LENGTH,
                        test_line_count=_TEST_LINE_COUNT,
                        test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                        test_max_length=_TEST_MAX_LENGTH,
                        sort_target=False)

    with tmp_digits_dataset(**dataset_args) as data:
        # train a minimal default model
        trained = run_train_translate(train_params=train_params,
                                      translate_params=translate_params,
                                      data=data,
                                      max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS)

        # 'constraints' = positive constraints (must appear), 'avoid' = negative constraints (must not appear)
        constraint_types = ["constraints", "avoid"]
        for constraint_type in constraint_types:
            _test_constrained_type(constraint_type=constraint_type,
                                   data=trained,
                                   translate_params=translate_params)
Ejemplo n.º 14
0
def test_seq_copy(train_params: str, translate_params: str,
                  use_prepared_data: bool, use_source_factors: bool):
    """
    Task: copy short sequences of digits.

    Creates a temporary digits dataset (targets left unsorted, source factors
    optional) and hands it to check_train_translate.

    :param train_params: Forwarded to check_train_translate as train_params.
    :param translate_params: Forwarded to check_train_translate as translate_params.
    :param use_prepared_data: Forwarded to check_train_translate as use_prepared_data.
    :param use_source_factors: Passed to the dataset helper as with_source_factors.
    """
    dataset_args = dict(prefix="test_seq_copy",
                        train_line_count=_TRAIN_LINE_COUNT,
                        train_max_length=_LINE_MAX_LENGTH,
                        dev_line_count=_DEV_LINE_COUNT,
                        dev_max_length=_LINE_MAX_LENGTH,
                        test_line_count=_TEST_LINE_COUNT,
                        test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                        test_max_length=_TEST_MAX_LENGTH,
                        sort_target=False,
                        with_source_factors=use_source_factors)

    with tmp_digits_dataset(**dataset_args) as data:
        max_seq_len = _LINE_MAX_LENGTH + C.SPACE_FOR_XOS
        check_train_translate(train_params=train_params,
                              translate_params=translate_params,
                              data=data,
                              use_prepared_data=use_prepared_data,
                              max_seq_len=max_seq_len)
Ejemplo n.º 15
0
def test_get_training_data_iters():
    """
    Checks data_io.get_training_data_iters on a temporary digits corpus.

    Verifies the types of the returned objects, the recorded data info and
    statistics, the iterator configuration, and the BOS/EOS layout of every
    training batch over two passes through the training iterator.
    """
    batch_size = 5
    train_line_count, train_max_length = 100, 30
    dev_line_count, dev_max_length = 20, 30
    test_line_count, test_line_count_empty, test_max_length = 20, 0, 30
    expected_mean, expected_std = 1.0, 0.0

    # Sequence lengths are reduced by C.SPACE_FOR_XOS before being handed to the dataset helper
    corpus_args = ("tmp_corpus",
                   train_line_count, train_max_length - C.SPACE_FOR_XOS,
                   dev_line_count, dev_max_length - C.SPACE_FOR_XOS,
                   test_line_count, test_line_count_empty,
                   test_max_length - C.SPACE_FOR_XOS)

    with tmp_digits_dataset(*corpus_args) as data:
        # One vocabulary built from both sides, used for source and target alike
        common_vocab = vocab.build_from_paths([data['source'], data['target']])

        iter_args = dict(sources=[data['source']],
                         target=data['target'],
                         validation_sources=[data['validation_source']],
                         validation_target=data['validation_target'],
                         source_vocabs=[common_vocab],
                         target_vocab=common_vocab,
                         source_vocab_paths=[None],
                         target_vocab_path=None,
                         shared_vocab=True,
                         batch_size=batch_size,
                         batch_by_words=False,
                         batch_num_devices=1,
                         fill_up="replicate",
                         max_seq_len_source=train_max_length,
                         max_seq_len_target=train_max_length,
                         bucketing=True,
                         bucket_width=10)
        train_iter, val_iter, config_data, data_info = data_io.get_training_data_iters(**iter_args)

        # Returned object types
        assert isinstance(train_iter, data_io.ParallelSampleIter)
        assert isinstance(val_iter, data_io.ParallelSampleIter)
        assert isinstance(config_data, data_io.DataConfig)

        # Data info mirrors the paths passed in; no vocabulary paths were given
        assert data_info.sources == [data['source']]
        assert data_info.target == data['target']
        assert data_info.source_vocabs == [None]
        assert data_info.target_vocab is None

        # Corpus statistics
        assert config_data.data_statistics.max_observed_len_source == train_max_length
        assert config_data.data_statistics.max_observed_len_target == train_max_length
        assert np.isclose(config_data.data_statistics.length_ratio_mean, expected_mean)
        assert np.isclose(config_data.data_statistics.length_ratio_std, expected_std)

        # Iterator configuration
        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size
        assert train_iter.default_bucket_key == (train_max_length, train_max_length)
        assert val_iter.default_bucket_key == (dev_max_length, dev_max_length)
        assert train_iter.dtype == 'float32'

        # Inspect every batch over two passes
        bos_id = common_vocab[C.BOS_SYMBOL]
        eos_id = common_vocab[C.EOS_SYMBOL]
        expected_first_target_symbols = np.full((batch_size, ), bos_id, dtype='float32')
        for _ in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert len(batch.data) == 2
                assert len(batch.label) == 1
                assert batch.bucket_key in train_iter.buckets

                source = batch.data[0].asnumpy()
                target = batch.data[1].asnumpy()
                label = batch.label[0].asnumpy()
                assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size

                # one EOS symbol per source sequence
                assert np.sum(source == eos_id) == batch_size
                # every target sequence starts with BOS
                assert np.array_equal(target[:, 0], expected_first_target_symbols)
                # the first label symbol equals the second target symbol
                assert np.array_equal(label[:, 0], target[:, 1])
                # one EOS symbol per label sequence
                assert np.sum(label == eos_id) == batch_size
            train_iter.reset()
Ejemplo n.º 16
0
def test_get_training_data_iters():
    """
    Checks data_io.get_training_data_iters on a temporary digits corpus.

    Verifies the types of the returned objects, the recorded data info and
    statistics, the iterator configuration, and the BOS/EOS layout of every
    training batch over two passes through the training iterator.
    """
    # Corpus sizes and length limits
    train_line_count, dev_line_count, test_line_count = 100, 20, 20
    train_max_length = dev_max_length = test_max_length = 30
    test_line_count_empty = 0
    # Expected target/source length ratio statistics
    expected_mean = 1.0
    expected_std = 0.0
    batch_size = 5

    with tmp_digits_dataset("tmp_corpus",
                            train_line_count,
                            train_max_length - C.SPACE_FOR_XOS,
                            dev_line_count,
                            dev_max_length - C.SPACE_FOR_XOS,
                            test_line_count,
                            test_line_count_empty,
                            test_max_length - C.SPACE_FOR_XOS) as data:
        source_path, target_path = data['source'], data['target']

        # One vocabulary built from both sides, used for source and target alike
        joint_vocab = vocab.build_from_paths([source_path, target_path])

        returned = data_io.get_training_data_iters(sources=[source_path],
                                                   target=target_path,
                                                   validation_sources=[data['validation_source']],
                                                   validation_target=data['validation_target'],
                                                   source_vocabs=[joint_vocab],
                                                   target_vocab=joint_vocab,
                                                   source_vocab_paths=[None],
                                                   target_vocab_path=None,
                                                   shared_vocab=True,
                                                   batch_size=batch_size,
                                                   batch_by_words=False,
                                                   batch_num_devices=1,
                                                   fill_up="replicate",
                                                   max_seq_len_source=train_max_length,
                                                   max_seq_len_target=train_max_length,
                                                   bucketing=True,
                                                   bucket_width=10)
        train_iter, val_iter, config_data, data_info = returned

        # Returned object types
        assert isinstance(train_iter, data_io.ParallelSampleIter)
        assert isinstance(val_iter, data_io.ParallelSampleIter)
        assert isinstance(config_data, data_io.DataConfig)

        # Data info mirrors the paths passed in; no vocabulary paths were given
        assert data_info.sources == [data['source']]
        assert data_info.target == data['target']
        assert data_info.source_vocabs == [None]
        assert data_info.target_vocab is None

        # Corpus statistics
        assert config_data.data_statistics.max_observed_len_source == train_max_length
        assert config_data.data_statistics.max_observed_len_target == train_max_length
        assert np.isclose(config_data.data_statistics.length_ratio_mean, expected_mean)
        assert np.isclose(config_data.data_statistics.length_ratio_std, expected_std)

        # Iterator configuration
        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size
        assert train_iter.default_bucket_key == (train_max_length, train_max_length)
        assert val_iter.default_bucket_key == (dev_max_length, dev_max_length)
        assert train_iter.dtype == 'float32'

        # Inspect every batch over two passes
        bos_id = joint_vocab[C.BOS_SYMBOL]
        eos_id = joint_vocab[C.EOS_SYMBOL]
        expected_first_target_symbols = np.full((batch_size,), bos_id, dtype='float32')
        for _ in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert len(batch.data) == 2
                assert len(batch.label) == 1
                assert batch.bucket_key in train_iter.buckets

                source = batch.data[0].asnumpy()
                target = batch.data[1].asnumpy()
                label = batch.label[0].asnumpy()
                assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size

                # one EOS symbol per source sequence
                assert np.sum(source == eos_id) == batch_size
                # every target sequence starts with BOS
                assert np.array_equal(target[:, 0], expected_first_target_symbols)
                # the first label symbol equals the second target symbol
                assert np.array_equal(label[:, 0], target[:, 1])
                # one EOS symbol per label sequence
                assert np.sum(label == eos_id) == batch_size
            train_iter.reset()