Example #1
def test_mx_pt_eq_training_data():
    pytest.importorskip("mxnet")
    from sockeye import data_io

    train_line_count = 100
    train_line_count_empty = 0
    train_max_length = 30
    dev_line_count = 20
    dev_max_length = 30
    test_line_count = 20
    test_line_count_empty = 0
    test_max_length = 30
    batch_size = 5
    with tmp_digits_dataset("tmp_corpus", train_line_count,
                            train_line_count_empty,
                            train_max_length - C.SPACE_FOR_XOS, dev_line_count,
                            dev_max_length - C.SPACE_FOR_XOS, test_line_count,
                            test_line_count_empty,
                            test_max_length - C.SPACE_FOR_XOS) as data:

        vcb = vocab.build_from_paths(
            [data['train_source'], data['train_target']])

        train_iters = {}
        val_iters = {}

        # For each implementation
        for key, data_io_module in (('mx', data_io), ('pt', data_io_pt)):
            # Create iterators with no data permutation (preserve order for
            # batch equality checks)
            train_iter, val_iter, _, _ = data_io_module.get_training_data_iters(
                sources=[data['train_source']],
                targets=[data['train_target']],
                validation_sources=[data['dev_source']],
                validation_targets=[data['dev_target']],
                source_vocabs=[vcb],
                target_vocabs=[vcb],
                source_vocab_paths=[None],
                target_vocab_paths=[None],
                shared_vocab=True,
                batch_size=batch_size,
                batch_type=C.BATCH_TYPE_SENTENCE,
                max_seq_len_source=train_max_length,
                max_seq_len_target=train_max_length,
                bucketing=True,
                bucket_width=10,
                permute=False)
            train_iters[key] = train_iter
            val_iters[key] = val_iter

        # Check equality of all MXNet/PyTorch batches
        for iters in (train_iters, val_iters):
            for mx_batch, pt_batch in zip(iters['mx'], iters['pt']):
                _assert_mx_pt_batches_equal(mx_batch, pt_batch)
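
These snippets come from Sockeye's data_io test suite at several different points in its history, which is why the call signatures differ between examples; the surrounding module-level imports are not shown on this page. A rough preamble covering the names the tests use might look like the following (a sketch: the data_io_pt module name for the PyTorch port used in the MXNet/PyTorch comparison tests is an assumption based on the identifiers above, and the MXNet data_io module is imported inside those tests after pytest.importorskip("mxnet")):

import os
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import torch

from sockeye import constants as C
from sockeye import utils, vocab
from sockeye import data_io        # PyTorch-based in Sockeye 3.x (see Example #6)
from sockeye import data_io_pt     # assumption: PyTorch port compared against MXNet in Examples #1 and #7
from sockeye.test_utils import tmp_digits_dataset

Helpers such as _assert_mx_pt_batches_equal, tmp_img_captioning_dataset, _FEATURE_SHAPE and _CNN_INPUT_IMAGE_SHAPE are defined in or alongside the test modules themselves rather than in Sockeye's public API.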
Example #2
def test_get_training_data_iters():
    train_line_count = 100
    train_line_count_empty = 0
    train_max_length = 30
    dev_line_count = 20
    dev_max_length = 30
    expected_mean = 1.0
    expected_std = 0.0
    test_line_count = 20
    test_line_count_empty = 0
    test_max_length = 30
    batch_size = 5
    with tmp_digits_dataset("tmp_corpus", train_line_count,
                            train_line_count_empty,
                            train_max_length - C.SPACE_FOR_XOS, dev_line_count,
                            dev_max_length - C.SPACE_FOR_XOS, test_line_count,
                            test_line_count_empty,
                            test_max_length - C.SPACE_FOR_XOS) as data:
        # tmp common vocab
        vcb = vocab.build_from_paths(
            [data['train_source'], data['train_target']])

        train_iter, val_iter, config_data, data_info = data_io.get_training_data_iters(
            sources=[data['train_source']],
            target=data['train_target'],
            validation_sources=[data['dev_source']],
            validation_target=data['dev_target'],
            source_vocabs=[vcb],
            target_vocab=vcb,
            source_vocab_paths=[None],
            target_vocab_path=None,
            shared_vocab=True,
            batch_size=batch_size,
            batch_by_words=False,
            batch_num_devices=1,
            max_seq_len_source=train_max_length,
            max_seq_len_target=train_max_length,
            bucketing=True,
            bucket_width=10)
        assert isinstance(train_iter, data_io.ParallelSampleIter)
        assert isinstance(val_iter, data_io.ParallelSampleIter)
        assert isinstance(config_data, data_io.DataConfig)
        assert data_info.sources == [data['train_source']]
        assert data_info.target == data['train_target']
        assert data_info.source_vocabs == [None]
        assert data_info.target_vocab is None
        assert config_data.data_statistics.max_observed_len_source == train_max_length
        assert config_data.data_statistics.max_observed_len_target == train_max_length
        assert np.isclose(config_data.data_statistics.length_ratio_mean,
                          expected_mean)
        assert np.isclose(config_data.data_statistics.length_ratio_std,
                          expected_std)

        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size
        assert train_iter.default_bucket_key == (train_max_length,
                                                 train_max_length)
        assert val_iter.default_bucket_key == (dev_max_length, dev_max_length)
        assert train_iter.dtype == 'float32'

        # test some batches
        bos_id = vcb[C.BOS_SYMBOL]
        eos_id = vcb[C.EOS_SYMBOL]
        expected_first_target_symbols = np.full((batch_size, ),
                                                bos_id,
                                                dtype='float32')
        for epoch in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert isinstance(batch, data_io.Batch)
                source = batch.source.asnumpy()
                target = batch.target.asnumpy()
                label = batch.labels[C.TARGET_LABEL_NAME].asnumpy()
                length_ratio_label = batch.labels[C.LENRATIO_LABEL_NAME].asnumpy()
                assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size
                # target first symbol should be BOS
                # each source sequence contains one EOS symbol
                assert np.sum(source == eos_id) == batch_size
                assert np.array_equal(target[:, 0],
                                      expected_first_target_symbols)
                # label first symbol should be 2nd target symbol
                assert np.array_equal(label[:, 0], target[:, 1])
                # each label sequence contains one EOS symbol
                assert np.sum(label == eos_id) == batch_size
            train_iter.reset()
Example #3
def test_get_training_image_text_data_iters():
    # Test images
    source_list = ['1', '2', '3', '4', '100']
    prefix = "tmp_corpus"
    use_feature_loader = False
    preload_features = False
    train_max_length = 30
    dev_max_length = 30
    expected_mean = 1.0
    expected_std = 1.0
    test_max_length = 30
    batch_size = 5
    if use_feature_loader:
        source_image_size = _FEATURE_SHAPE
    else:
        source_image_size = _CNN_INPUT_IMAGE_SHAPE
    with tmp_img_captioning_dataset(source_list, prefix, train_max_length,
                                    dev_max_length, test_max_length,
                                    use_feature_loader) as data:
        # tmp common vocab
        vcb = vocab.build_from_paths([data['target'], data['target']])

        train_iter, val_iter, config_data, data_info = data_io.get_training_image_text_data_iters(
            source_root=data['work_dir'],
            source=data['source'],
            target=data['target'],
            validation_source_root=data['work_dir'],
            validation_source=data['validation_source'],
            validation_target=data['validation_target'],
            vocab_target=vcb,
            vocab_target_path=None,
            batch_size=batch_size,
            batch_by_words=False,
            batch_num_devices=1,
            source_image_size=source_image_size,
            fill_up="replicate",
            max_seq_len_target=train_max_length,
            bucketing=False,
            bucket_width=10,
            use_feature_loader=use_feature_loader,
            preload_features=preload_features)
        assert isinstance(train_iter, data_io.ParallelSampleIter)
        assert isinstance(val_iter, data_io.ParallelSampleIter)
        assert isinstance(config_data, data_io.DataConfig)
        assert isinstance(data_info.sources[0], data_io.FileListReader)
        assert data_info.target == data['target']
        assert data_info.source_vocabs is None
        assert data_info.target_vocab is None
        assert config_data.data_statistics.max_observed_len_source == 0
        assert config_data.data_statistics.max_observed_len_target == train_max_length - 1
        assert np.isclose(config_data.data_statistics.length_ratio_mean,
                          expected_mean)
        assert np.isclose(config_data.data_statistics.length_ratio_std,
                          expected_std)

        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size
        assert train_iter.default_bucket_key == (0, train_max_length)
        assert val_iter.default_bucket_key == (0, dev_max_length)
        assert train_iter.dtype == 'float32'

        # test some batches
        bos_id = vcb[C.BOS_SYMBOL]
        expected_first_target_symbols = np.full((batch_size, ),
                                                bos_id,
                                                dtype='float32')
        for epoch in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert len(batch.data) == 2
                assert len(batch.label) == 1
                assert batch.bucket_key in train_iter.buckets
                source = batch.data[0].asnumpy()
                target = batch.data[1].asnumpy()
                label = batch.label[0].asnumpy()
                assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size
                # target first symbol should be BOS
                assert np.array_equal(target[:, 0],
                                      expected_first_target_symbols)
                # label first symbol should be 2nd target symbol
                assert np.array_equal(label[:, 0], target[:, 1])
                # each label sequence contains one EOS symbol
                assert np.sum(label == vcb[C.EOS_SYMBOL]) == batch_size
            train_iter.reset()
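
Note the asymmetric bucket keys in this image-captioning variant: the source side is an image (read through a FileListReader) rather than a token sequence, so max_observed_len_source is 0 and the bucket keys take the form (0, max_seq_len_target), as the assertions above check.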
Example #4
def test_get_training_data_iters():
    train_line_count = 100
    train_max_length = 30
    dev_line_count = 20
    dev_max_length = 30
    expected_mean = 1.0
    expected_std = 0.0
    test_line_count = 20
    test_line_count_empty = 0
    test_max_length = 30
    batch_size = 5
    with tmp_digits_dataset("tmp_corpus",
                            train_line_count, train_max_length - C.SPACE_FOR_XOS,
                            dev_line_count, dev_max_length - C.SPACE_FOR_XOS,
                            test_line_count, test_line_count_empty,
                            test_max_length - C.SPACE_FOR_XOS) as data:
        # tmp common vocab
        vcb = vocab.build_from_paths([data['source'], data['target']])

        train_iter, val_iter, config_data, data_info = data_io.get_training_data_iters(
            sources=[data['source']],
            target=data['target'],
            validation_sources=[data['validation_source']],
            validation_target=data['validation_target'],
            source_vocabs=[vcb],
            target_vocab=vcb,
            source_vocab_paths=[None],
            target_vocab_path=None,
            shared_vocab=True,
            batch_size=batch_size,
            batch_by_words=False,
            batch_num_devices=1,
            fill_up="replicate",
            max_seq_len_source=train_max_length,
            max_seq_len_target=train_max_length,
            bucketing=True,
            bucket_width=10)
        assert isinstance(train_iter, data_io.ParallelSampleIter)
        assert isinstance(val_iter, data_io.ParallelSampleIter)
        assert isinstance(config_data, data_io.DataConfig)
        assert data_info.sources == [data['source']]
        assert data_info.target == data['target']
        assert data_info.source_vocabs == [None]
        assert data_info.target_vocab is None
        assert config_data.data_statistics.max_observed_len_source == train_max_length
        assert config_data.data_statistics.max_observed_len_target == train_max_length
        assert np.isclose(config_data.data_statistics.length_ratio_mean, expected_mean)
        assert np.isclose(config_data.data_statistics.length_ratio_std, expected_std)

        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size
        assert train_iter.default_bucket_key == (train_max_length, train_max_length)
        assert val_iter.default_bucket_key == (dev_max_length, dev_max_length)
        assert train_iter.dtype == 'float32'

        # test some batches
        bos_id = vcb[C.BOS_SYMBOL]
        eos_id = vcb[C.EOS_SYMBOL]
        expected_first_target_symbols = np.full((batch_size,), bos_id, dtype='float32')
        for epoch in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert len(batch.data) == 2
                assert len(batch.label) == 1
                assert batch.bucket_key in train_iter.buckets
                source = batch.data[0].asnumpy()
                target = batch.data[1].asnumpy()
                label = batch.label[0].asnumpy()
                assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size
                # target first symbol should be BOS
                # each source sequence contains one EOS symbol
                assert np.sum(source == eos_id) == batch_size
                assert np.array_equal(target[:, 0], expected_first_target_symbols)
                # label first symbol should be 2nd target symbol
                assert np.array_equal(label[:, 0], target[:, 1])
                # each label sequence contains one EOS symbol
                assert np.sum(label == eos_id) == batch_size
            train_iter.reset()
Example #5
def test_get_training_data_iters():
    train_line_count = 100
    train_max_length = 30
    dev_line_count = 20
    dev_max_length = 30
    expected_mean = 1.0
    expected_std = 0.0
    test_line_count = 20
    test_line_count_empty = 0
    test_max_length = 30
    batch_size = 5
    with tmp_digits_dataset("tmp_corpus",
                            train_line_count, train_max_length, dev_line_count, dev_max_length,
                            test_line_count, test_line_count_empty, test_max_length) as data:
        # tmp common vocab
        vcb = vocab.build_from_paths([data['source'], data['target']])

        train_iter, val_iter, config_data = data_io.get_training_data_iters(
            data['source'], data['target'],
            data['validation_source'], data['validation_target'],
            vocab_source=vcb,
            vocab_target=vcb,
            vocab_source_path=None,
            vocab_target_path=None,
            shared_vocab=True,
            batch_size=batch_size,
            batch_by_words=False,
            batch_num_devices=1,
            fill_up="replicate",
            max_seq_len_source=train_max_length,
            max_seq_len_target=train_max_length,
            bucketing=True,
            bucket_width=10)
        assert isinstance(train_iter, data_io.ParallelSampleIter)
        assert isinstance(val_iter, data_io.ParallelSampleIter)
        assert isinstance(config_data, data_io.DataConfig)
        assert config_data.source == data['source']
        assert config_data.target == data['target']
        assert config_data.vocab_source is None
        assert config_data.vocab_target is None
        assert config_data.data_statistics.max_observed_len_source == train_max_length - 1
        assert config_data.data_statistics.max_observed_len_target == train_max_length
        assert np.isclose(config_data.data_statistics.length_ratio_mean, expected_mean)
        assert np.isclose(config_data.data_statistics.length_ratio_std, expected_std)

        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size
        assert train_iter.default_bucket_key == (train_max_length, train_max_length)
        assert val_iter.default_bucket_key == (dev_max_length, dev_max_length)
        assert train_iter.dtype == 'float32'

        # test some batches
        bos_id = vcb[C.BOS_SYMBOL]
        expected_first_target_symbols = np.full((batch_size,), bos_id, dtype='float32')
        for epoch in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert len(batch.data) == 2
                assert len(batch.label) == 1
                assert batch.bucket_key in train_iter.buckets
                source = batch.data[0].asnumpy()
                target = batch.data[1].asnumpy()
                label = batch.label[0].asnumpy()
                assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size
                # target first symbol should be BOS
                assert np.array_equal(target[:, 0], expected_first_target_symbols)
                # label first symbol should be 2nd target symbol
                assert np.array_equal(label[:, 0], target[:, 1])
                # each label sequence contains one EOS symbol
                assert np.sum(label == vcb[C.EOS_SYMBOL]) == batch_size
            train_iter.reset()
Example #6
def test_get_training_data_iters():
    from sockeye.test_utils import tmp_digits_dataset

    train_line_count = 100
    train_line_count_empty = 0
    train_max_length = 30
    dev_line_count = 20
    dev_max_length = 30
    expected_mean = 1.0
    expected_std = 0.0
    test_line_count = 20
    test_line_count_empty = 0
    test_max_length = 30
    batch_size = 5
    num_source_factors = num_target_factors = 1
    with tmp_digits_dataset("tmp_corpus", train_line_count,
                            train_line_count_empty,
                            train_max_length - C.SPACE_FOR_XOS, dev_line_count,
                            dev_max_length - C.SPACE_FOR_XOS, test_line_count,
                            test_line_count_empty,
                            test_max_length - C.SPACE_FOR_XOS) as data:
        # tmp common vocab
        vcb = vocab.build_from_paths(
            [data['train_source'], data['train_target']])

        train_iter, val_iter, config_data, data_info = data_io.get_training_data_iters(
            sources=[data['train_source']],
            targets=[data['train_target']],
            validation_sources=[data['dev_source']],
            validation_targets=[data['dev_target']],
            source_vocabs=[vcb],
            target_vocabs=[vcb],
            source_vocab_paths=[None],
            target_vocab_paths=[None],
            shared_vocab=True,
            batch_size=batch_size,
            batch_type=C.BATCH_TYPE_SENTENCE,
            max_seq_len_source=train_max_length,
            max_seq_len_target=train_max_length,
            bucketing=True,
            bucket_width=10)
        assert isinstance(train_iter, data_io.ParallelSampleIter)
        assert isinstance(val_iter, data_io.ParallelSampleIter)
        assert isinstance(config_data, data_io.DataConfig)
        assert data_info.sources == [data['train_source']]
        assert data_info.targets == [data['train_target']]
        assert data_info.source_vocabs == [None]
        assert data_info.target_vocabs == [None]
        assert config_data.data_statistics.max_observed_len_source == train_max_length
        assert config_data.data_statistics.max_observed_len_target == train_max_length
        assert np.isclose(config_data.data_statistics.length_ratio_mean,
                          expected_mean)
        assert np.isclose(config_data.data_statistics.length_ratio_std,
                          expected_std)

        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size

        # test some batches
        bos_id = vcb[C.BOS_SYMBOL]
        eos_id = vcb[C.EOS_SYMBOL]
        expected_first_target_symbols = torch.full((batch_size, 1),
                                                   bos_id,
                                                   dtype=torch.int32)
        for epoch in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert isinstance(batch, data_io.Batch)
                source = batch.source
                target = batch.target
                label = batch.labels[C.TARGET_LABEL_NAME]  # TODO: still 2-shape: (batch, length)
                length_ratio_label = batch.labels[C.LENRATIO_LABEL_NAME]
                assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size
                assert source.shape[2] == target.shape[2] == num_source_factors == num_target_factors
                # target first symbol should be BOS
                # each source sequence contains one EOS symbol
                assert torch.sum(source == eos_id) == batch_size
                assert torch.equal(target[:, 0], expected_first_target_symbols)
                # label first symbol should be 2nd target symbol
                assert torch.equal(label[:, 0], target[:, 1, 0])
                # each label sequence contains one EOS symbol
                assert torch.sum(label == eos_id) == batch_size
            train_iter.reset()
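
As the assertions in this PyTorch-based variant show, batch.source and batch.target carry a trailing factor dimension, i.e. they are shaped (batch, length, num_factors), while the label tensor under C.TARGET_LABEL_NAME is still two-dimensional (batch, length); batch.labels also provides a length-ratio label under C.LENRATIO_LABEL_NAME.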
Example #7
def test_mx_pt_eq_prepared_data():
    pytest.importorskip("mxnet")
    from sockeye import data_io

    train_line_count = 100
    train_line_count_empty = 0
    train_max_length = 30
    dev_line_count = 20
    dev_max_length = 30
    test_line_count = 20
    test_line_count_empty = 0
    test_max_length = 30
    batch_size = 5
    batch_sentences_multiple_of = 8

    with tmp_digits_dataset("tmp_corpus", train_line_count,
                            train_line_count_empty,
                            train_max_length - C.SPACE_FOR_XOS, dev_line_count,
                            dev_max_length - C.SPACE_FOR_XOS, test_line_count,
                            test_line_count_empty,
                            test_max_length - C.SPACE_FOR_XOS) as data:

        with TemporaryDirectory() as work_dir, utils.create_pool(2) as pool:

            vcb = vocab.build_from_paths(
                [data['train_source'], data['train_target']])

            train_iters = {}
            val_iters = {}

            # For each implementation
            for key, data_io_module in (('mx', data_io), ('pt', data_io_pt)):
                output_folder = os.path.join(work_dir, key)
                os.mkdir(output_folder)

                # Create 1 shard (avoid random assignment that breaks equality)
                shards, keep_tmp_shard_files = data_io_module.create_shards(
                    source_fnames=[data['train_source']],
                    target_fnames=[data['train_target']],
                    num_shards=1,
                    output_prefix=output_folder)

                # Prepare data using multiple processes
                data_io_module.prepare_data(
                    source_fnames=[data['train_source']],
                    target_fnames=[data['train_target']],
                    source_vocabs=[vcb],
                    target_vocabs=[vcb],
                    source_vocab_paths=[None],
                    target_vocab_paths=[None],
                    shared_vocab=True,
                    max_seq_len_source=train_max_length,
                    max_seq_len_target=train_max_length,
                    bucketing=True,
                    bucket_width=10,
                    num_shards=1,
                    output_prefix=output_folder,
                    bucket_scaling=True,
                    keep_tmp_shard_files=keep_tmp_shard_files,
                    pool=pool,
                    shards=shards)

                # Create iterators
                train_iter, val_iter, _, _, _ = data_io_module.get_prepared_data_iters(
                    prepared_data_dir=output_folder,
                    validation_sources=[data['dev_source']],
                    validation_targets=[data['dev_target']],
                    shared_vocab=True,
                    batch_size=batch_size,
                    batch_type=C.BATCH_TYPE_SENTENCE,
                    batch_sentences_multiple_of=batch_sentences_multiple_of,
                    permute=False)

                train_iters[key] = train_iter
                val_iters[key] = val_iter

            # Check equality of all MXNet/PyTorch batches
            for iters in (train_iters, val_iters):
                for i, (mx_batch, pt_batch) in enumerate(zip(iters['mx'], iters['pt']), 1):
                    print(i)
                    _assert_mx_pt_batches_equal(mx_batch, pt_batch)
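
This example exercises the prepared-data path: create_shards and prepare_data serialize the training shards to disk, and get_prepared_data_iters later loads them back, mirroring Sockeye's usual workflow of preparing data ahead of training and then pointing the trainer at the prepared directory. A single shard and permute=False keep the MXNet and PyTorch iterators in the same order so their batches can be compared one-to-one.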
Example #8
def test_get_training_image_text_data_iters():
    # Test images
    source_list = ['1', '2', '3', '4', '100']
    prefix = "tmp_corpus"
    use_feature_loader = False
    preload_features = False
    train_max_length = 30
    dev_max_length = 30
    expected_mean = 1.0
    expected_std = 1.0
    test_max_length = 30
    batch_size = 5
    if use_feature_loader:
        source_image_size = _FEATURE_SHAPE
    else:
        source_image_size = _CNN_INPUT_IMAGE_SHAPE
    with tmp_img_captioning_dataset(source_list,
                                    prefix,
                                    train_max_length,
                                    dev_max_length,
                                    test_max_length,
                                    use_feature_loader) as data:
        # tmp common vocab
        vcb = vocab.build_from_paths([data['target'], data['target']])

        train_iter, val_iter, config_data, data_info = data_io.get_training_image_text_data_iters(
            source_root=data['work_dir'],
            source=data['source'],
            target=data['target'],
            validation_source_root=data['work_dir'],
            validation_source=data['validation_source'],
            validation_target=data['validation_target'],
            vocab_target=vcb,
            vocab_target_path=None,
            batch_size=batch_size,
            batch_by_words=False,
            batch_num_devices=1,
            source_image_size=source_image_size,
            fill_up="replicate",
            max_seq_len_target=train_max_length,
            bucketing=False,
            bucket_width=10,
            use_feature_loader=use_feature_loader,
            preload_features=preload_features)
        assert isinstance(train_iter, data_io.ParallelSampleIter)
        assert isinstance(val_iter, data_io.ParallelSampleIter)
        assert isinstance(config_data, data_io.DataConfig)
        assert isinstance(data_info.sources[0], data_io.FileListReader)
        assert data_info.target == data['target']
        assert data_info.source_vocabs is None
        assert data_info.target_vocab is None
        assert config_data.data_statistics.max_observed_len_source == 0
        assert config_data.data_statistics.max_observed_len_target == train_max_length - 1
        assert np.isclose(config_data.data_statistics.length_ratio_mean, expected_mean)
        assert np.isclose(config_data.data_statistics.length_ratio_std, expected_std)

        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size
        assert train_iter.default_bucket_key == (0, train_max_length)
        assert val_iter.default_bucket_key == (0, dev_max_length)
        assert train_iter.dtype == 'float32'

        # test some batches
        bos_id = vcb[C.BOS_SYMBOL]
        expected_first_target_symbols = np.full((batch_size,), bos_id, dtype='float32')
        for epoch in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert len(batch.data) == 2
                assert len(batch.label) == 1
                assert batch.bucket_key in train_iter.buckets
                source = batch.data[0].asnumpy()
                target = batch.data[1].asnumpy()
                label = batch.label[0].asnumpy()
                assert source.shape[0] == target.shape[0] == label.shape[0] == batch_size
                # target first symbol should be BOS
                assert np.array_equal(target[:, 0], expected_first_target_symbols)
                # label first symbol should be 2nd target symbol
                assert np.array_equal(label[:, 0], target[:, 1])
                # each label sequence contains one EOS symbol
                assert np.sum(label == vcb[C.EOS_SYMBOL]) == batch_size
            train_iter.reset()