Exemple #1
0
def test_parallel_data_set_permute():
    """Round-trip a filled-up ParallelDataSet through a permutation and its inverse.

    After ``permute(perms).permute(inv_perms)`` every per-bucket source,
    target and label array must equal the original one.  Buckets whose
    source array has zero rows are instead required to be falsy in the
    restored dataset.
    """
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    random_data = _get_random_bucketed_data(buckets, min_count=0, max_count=5)
    dataset = data_io.ParallelDataSet(*random_data).fill_up(bucket_batch_sizes, 'replicate')

    perms, inv_perms = data_io.get_permutations(dataset.get_bucket_counts())
    assert len(perms) == len(inv_perms) == len(dataset)

    round_trip = dataset.permute(perms).permute(inv_perms)
    assert len(dataset) == len(round_trip)

    # Checked in this order for every bucket: source, target, label.
    fields = ('source', 'target', 'label')
    for idx in range(len(dataset)):
        # Emptiness is decided by the source side of the ORIGINAL dataset.
        bucket_is_empty = dataset.source[idx].shape[0] == 0
        for field in fields:
            restored = getattr(round_trip, field)[idx]
            if bucket_is_empty:
                assert not restored
            else:
                original = getattr(dataset, field)[idx]
                assert (original == restored).asnumpy().all()
Exemple #2
0
def test_parallel_data_set_permute():
    """Check that permute followed by the inverse permute restores the dataset.

    Non-empty buckets must have element-wise identical source and target
    arrays after the round trip.  Buckets whose source has zero rows must be
    falsy in the restored dataset.
    """
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    random_data = _get_random_bucketed_data(buckets, min_count=0, max_count=5)
    dataset = data_io.ParallelDataSet(*random_data).fill_up(bucket_batch_sizes)

    perms, inv_perms = data_io.get_permutations(dataset.get_bucket_counts())
    assert len(perms) == len(inv_perms) == len(dataset)

    round_trip = dataset.permute(perms).permute(inv_perms)
    assert len(dataset) == len(round_trip)

    fields = ('source', 'target')
    for idx in range(len(dataset)):
        bucket_is_empty = dataset.source[idx].shape[0] == 0
        for field in fields:
            restored = getattr(round_trip, field)[idx]
            if bucket_is_empty:
                assert not restored
            else:
                original = getattr(dataset, field)[idx]
                assert (original == restored).asnumpy().all()
Exemple #3
0
def test_parallel_data_set_permute():
    """Verify that a permutation followed by its inverse is a no-op on the data.

    Skipped when mxnet is not installed.  For each bucket with rows, source
    and target must match element-wise after the round trip.  Buckets whose
    source has zero rows must be falsy in the restored dataset.
    """
    pytest.importorskip('mxnet')
    from sockeye import data_io
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_type=C.BATCH_TYPE_SENTENCE,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    random_data = _get_random_bucketed_data(buckets, min_count=0, max_count=5)
    dataset = data_io.ParallelDataSet(*random_data).fill_up(bucket_batch_sizes)

    perms, inv_perms = data_io.get_permutations(dataset.get_bucket_counts())
    assert len(perms) == len(inv_perms) == len(dataset)

    round_trip = dataset.permute(perms).permute(inv_perms)
    assert len(dataset) == len(round_trip)

    fields = ('source', 'target')
    for idx in range(len(dataset)):
        bucket_is_empty = dataset.source[idx].shape[0] == 0
        for field in fields:
            restored = getattr(round_trip, field)[idx]
            if bucket_is_empty:
                assert not restored
            else:
                original = getattr(dataset, field)[idx]
                assert (original == restored).all()
Exemple #4
0
def test_get_permutations():
    """Each bucket's permutation composed with its inverse must be the identity.

    Both index arrays per bucket must be duplicate-free and draw from the same
    index set.  For the empty bucket the test expects exactly one index.
    """
    bucket_data = [list(range(3)), list(range(1)), list(range(7)), []]
    bucket_counts = list(map(len, bucket_data))

    permutation, inverse_permutation = data_io.get_permutations(bucket_counts)
    assert len(permutation) == len(inverse_permutation) == len(bucket_counts) == len(bucket_data)

    for values, perm, inv in zip(bucket_data, permutation, inverse_permutation):
        perm_indices = set(perm.tolist())
        inv_indices = set(inv.tolist())
        # No index may repeat inside either array.
        assert len(perm_indices) == len(perm)
        assert len(inv_indices) == len(inv)
        assert perm_indices - inv_indices == set()
        if not values:
            assert len(perm_indices) == 1
            continue
        tensor = torch.tensor(values)
        # Applying perm then inv must give back the original ordering.
        assert (tensor[perm][inv] == tensor).all()
Exemple #5
0
def test_get_permutations():
    """Each bucket's permutation composed with its inverse must be the identity.

    Index arrays are converted to numpy integer arrays, checked for duplicates
    and for sharing one index set.  Then ``d[p][pi] == d`` is asserted.  For
    the empty bucket the test expects exactly one index.
    """
    data = [list(range(3)), list(range(1)), list(range(7)), []]
    bucket_counts = [len(d) for d in data]

    permutation, inverse_permutation = data_io.get_permutations(bucket_counts)
    assert len(permutation) == len(inverse_permutation) == len(bucket_counts) == len(data)

    for d, p, pi in zip(data, permutation, inverse_permutation):
        # `np.int` was only an alias for the builtin `int`; it was deprecated in
        # NumPy 1.20 and removed in 1.24, so use `int` directly.
        p = p.asnumpy().astype(int)
        pi = pi.asnumpy().astype(int)
        p_set = set(p)
        pi_set = set(pi)
        assert len(p_set) == len(p)
        assert len(pi_set) == len(pi)
        assert p_set - pi_set == set()
        if d:
            d = np.array(d)
            assert (d[p][pi] == d).all()
        else:
            assert len(p_set) == 1
Exemple #6
0
def test_get_permutations():
    """Each bucket's permutation composed with its inverse must be the identity.

    Index arrays are converted to numpy integer arrays, checked for duplicates
    and for sharing one index set.  Then ``d[p][pi] == d`` is asserted.  For
    the empty bucket the test expects exactly one index.
    """
    data = [list(range(3)), list(range(1)), list(range(7)), []]
    bucket_counts = [len(d) for d in data]

    permutation, inverse_permutation = data_io.get_permutations(bucket_counts)
    assert len(permutation) == len(inverse_permutation) == len(bucket_counts) == len(data)

    for d, p, pi in zip(data, permutation, inverse_permutation):
        # `np.int` was only an alias for the builtin `int`; it was deprecated in
        # NumPy 1.20 and removed in 1.24, so use `int` directly.
        p = p.asnumpy().astype(int)
        pi = pi.asnumpy().astype(int)
        p_set = set(p)
        pi_set = set(pi)
        assert len(p_set) == len(p)
        assert len(pi_set) == len(pi)
        assert p_set - pi_set == set()
        if d:
            d = np.array(d)
            assert (d[p][pi] == d).all()
        else:
            assert len(p_set) == 1
Exemple #7
0
def test_get_permutations():
    """Each bucket's permutation composed with its inverse must be the identity.

    Skipped when mxnet is not installed.  Index arrays are cast to int32,
    checked for duplicates and for sharing one index set.  For the empty
    bucket the test expects exactly one index.
    """
    pytest.importorskip('mxnet')
    from sockeye import data_io
    from mxnet import np
    bucket_data = [list(range(3)), list(range(1)), list(range(7)), []]
    bucket_counts = list(map(len, bucket_data))

    permutation, inverse_permutation = data_io.get_permutations(bucket_counts)
    assert len(permutation) == len(inverse_permutation) == len(bucket_counts) == len(bucket_data)

    for values, raw_perm, raw_inv in zip(bucket_data, permutation, inverse_permutation):
        perm = raw_perm.astype(np.int32)
        inv = raw_inv.astype(np.int32)
        perm_indices = set(perm.tolist())
        inv_indices = set(inv.tolist())
        # No index may repeat inside either array.
        assert len(perm_indices) == len(perm)
        assert len(inv_indices) == len(inv)
        assert perm_indices - inv_indices == set()
        if not values:
            assert len(perm_indices) == 1
            continue
        arr = np.array(values)
        # Applying perm then inv must give back the original ordering.
        assert (arr[perm][inv] == arr).all()