def test_parallel_data_set_permute():
    """Permuting a bucketed dataset and then applying the inverse permutation
    must restore the original source/target/label arrays in every bucket."""
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    # Random bucket fill (possibly empty buckets), padded up to full batches.
    ds = data_io.ParallelDataSet(
        *_get_random_bucketed_data(buckets, min_count=0, max_count=5)).fill_up(
            bucket_batch_sizes, 'replicate')

    perms, inv_perms = data_io.get_permutations(ds.get_bucket_counts())
    assert len(perms) == len(inv_perms) == len(ds)

    round_tripped = ds.permute(perms).permute(inv_perms)
    assert len(ds) == len(round_tripped)

    for idx in range(len(ds)):
        if ds.source[idx].shape[0]:
            # Non-empty bucket: permute + inverse-permute is the identity.
            assert (ds.source[idx] == round_tripped.source[idx]).asnumpy().all()
            assert (ds.target[idx] == round_tripped.target[idx]).asnumpy().all()
            assert (ds.label[idx] == round_tripped.label[idx]).asnumpy().all()
        else:
            # Empty bucket must stay empty after the round trip.
            assert not round_tripped.source[idx]
            assert not round_tripped.target[idx]
            assert not round_tripped.label[idx]
def test_parallel_data_set_permute():
    """Round-tripping a dataset through permute() and the inverse permutation
    must leave source and target arrays of every bucket unchanged."""
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    # Buckets are filled with 0..5 random samples and padded to full batches.
    ds = data_io.ParallelDataSet(
        *_get_random_bucketed_data(buckets, min_count=0,
                                   max_count=5)).fill_up(bucket_batch_sizes)

    perms, inv_perms = data_io.get_permutations(ds.get_bucket_counts())
    assert len(perms) == len(inv_perms) == len(ds)

    round_tripped = ds.permute(perms).permute(inv_perms)
    assert len(ds) == len(round_tripped)

    for idx in range(len(ds)):
        if ds.source[idx].shape[0]:
            # Non-empty bucket: data must be bit-identical after the round trip.
            assert (ds.source[idx] == round_tripped.source[idx]).asnumpy().all()
            assert (ds.target[idx] == round_tripped.target[idx]).asnumpy().all()
        else:
            # Empty bucket must remain empty.
            assert not round_tripped.source[idx]
            assert not round_tripped.target[idx]
def test_parallel_data_set_permute():
    """Permute followed by the inverse permutation must reproduce the dataset.

    Skipped entirely when mxnet is not installed.
    """
    pytest.importorskip('mxnet')
    from sockeye import data_io

    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, True, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_type=C.BATCH_TYPE_SENTENCE,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    # Buckets get 0..5 random samples each, then are padded to full batches.
    ds = data_io.ParallelDataSet(
        *_get_random_bucketed_data(buckets, min_count=0,
                                   max_count=5)).fill_up(bucket_batch_sizes)

    perms, inv_perms = data_io.get_permutations(ds.get_bucket_counts())
    assert len(perms) == len(inv_perms) == len(ds)

    round_tripped = ds.permute(perms).permute(inv_perms)
    assert len(ds) == len(round_tripped)

    for idx in range(len(ds)):
        if ds.source[idx].shape[0]:
            # Non-empty bucket: round trip must be the identity.
            assert (ds.source[idx] == round_tripped.source[idx]).all()
            assert (ds.target[idx] == round_tripped.target[idx]).all()
        else:
            # Empty bucket must stay empty.
            assert not round_tripped.source[idx]
            assert not round_tripped.target[idx]
def test_get_permutations():
    """get_permutations must return one bijective permutation (and its inverse)
    per bucket, with permute-then-inverse acting as the identity."""
    samples = [list(range(3)), list(range(1)), list(range(7)), []]
    counts = [len(s) for s in samples]

    forward, inverse = data_io.get_permutations(counts)
    assert len(forward) == len(inverse) == len(counts) == len(samples)

    for values, fwd, inv in zip(samples, forward, inverse):
        fwd_set = set(fwd.tolist())
        inv_set = set(inv.tolist())
        # No duplicate indices: each permutation is a bijection.
        assert len(fwd_set) == len(fwd)
        assert len(inv_set) == len(inv)
        assert fwd_set - inv_set == set()
        if values:
            tensor = torch.tensor(values)
            # Applying the permutation and then its inverse restores the data.
            assert (tensor[fwd][inv] == tensor).all()
        else:
            # An empty bucket still yields a single-index permutation.
            assert len(fwd_set) == 1
def test_get_permutations():
    """get_permutations must return one bijective permutation (and its inverse)
    per bucket, with permute-then-inverse acting as the identity.
    """
    data = [list(range(3)), list(range(1)), list(range(7)), []]
    bucket_counts = [len(d) for d in data]
    permutation, inverse_permutation = data_io.get_permutations(bucket_counts)
    assert len(permutation) == len(inverse_permutation) == len(bucket_counts) == len(data)
    for d, p, pi in zip(data, permutation, inverse_permutation):
        # np.int was a deprecated alias for the builtin int and was removed in
        # NumPy 1.24; use the explicit fixed-width dtype instead.
        p = p.asnumpy().astype(np.int64)
        pi = pi.asnumpy().astype(np.int64)
        p_set = set(p)
        pi_set = set(pi)
        # No duplicate indices: each permutation is a bijection.
        assert len(p_set) == len(p)
        assert len(pi_set) == len(pi)
        assert p_set - pi_set == set()
        if d:
            d = np.array(d)
            # Applying the permutation and then its inverse restores the data.
            assert (d[p][pi] == d).all()
        else:
            # An empty bucket still yields a single-index permutation.
            assert len(p_set) == 1
def test_get_permutations():
    """get_permutations must return one bijective permutation (and its inverse)
    per bucket, with permute-then-inverse acting as the identity.
    """
    data = [list(range(3)), list(range(1)), list(range(7)), []]
    bucket_counts = [len(d) for d in data]
    permutation, inverse_permutation = data_io.get_permutations(bucket_counts)
    assert len(permutation) == len(inverse_permutation) == len(bucket_counts) == len(data)
    for d, p, pi in zip(data, permutation, inverse_permutation):
        # np.int was a deprecated alias for the builtin int and was removed in
        # NumPy 1.24; use the explicit fixed-width dtype instead.
        p = p.asnumpy().astype(np.int64)
        pi = pi.asnumpy().astype(np.int64)
        p_set = set(p)
        pi_set = set(pi)
        # No duplicate indices: each permutation is a bijection.
        assert len(p_set) == len(p)
        assert len(pi_set) == len(pi)
        assert p_set - pi_set == set()
        if d:
            d = np.array(d)
            # Applying the permutation and then its inverse restores the data.
            assert (d[p][pi] == d).all()
        else:
            # An empty bucket still yields a single-index permutation.
            assert len(p_set) == 1
def test_get_permutations():
    """get_permutations must return one bijective permutation (and its inverse)
    per bucket, with permute-then-inverse acting as the identity.

    Skipped entirely when mxnet is not installed.
    """
    pytest.importorskip('mxnet')
    from sockeye import data_io
    from mxnet import np

    samples = [list(range(3)), list(range(1)), list(range(7)), []]
    counts = [len(s) for s in samples]

    forward, inverse = data_io.get_permutations(counts)
    assert len(forward) == len(inverse) == len(counts) == len(samples)

    for values, fwd, inv in zip(samples, forward, inverse):
        fwd = fwd.astype(np.int32)
        inv = inv.astype(np.int32)
        fwd_set = set(fwd.tolist())
        inv_set = set(inv.tolist())
        # No duplicate indices: each permutation is a bijection.
        assert len(fwd_set) == len(fwd)
        assert len(inv_set) == len(inv)
        assert fwd_set - inv_set == set()
        if values:
            arr = np.array(values)
            # Applying the permutation and then its inverse restores the data.
            assert (arr[fwd][inv] == arr).all()
        else:
            # An empty bucket still yields a single-index permutation.
            assert len(fwd_set) == 1