def test_local_shuffle(ray_start_regular_shared):
    """Check that local_shuffle keeps items in their shard and really shuffles.

    Covers four properties:
    * no item is lost or duplicated, and items stay in their original shard;
    * two unseeded shuffles of the same data produce different orders;
    * a shuffle buffer of size 1 leaves the order unchanged;
    * adjacent-pair frequencies of a shuffled alternating 0/1 stream are
      roughly uniform.
    """
    # confirm that no data disappears, and they all stay within the same shard
    it = from_range(8, num_shards=2).local_shuffle(shuffle_buffer_size=2)
    assert repr(it) == ("ParallelIterator[from_range[8, shards=2]" +
                        ".local_shuffle(shuffle_buffer_size=2, seed=None)]")
    shard_0 = it.get_shard(0)
    shard_1 = it.get_shard(1)
    # sorted() rather than set(): a duplicated item must also fail the check.
    assert sorted(shard_0) == [0, 1, 2, 3]
    assert sorted(shard_1) == [4, 5, 6, 7]

    # check that shuffling results in different orders
    it1 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5)
    it2 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5)
    assert list(it1.gather_sync()) != list(it2.gather_sync())

    # buffer size of 1 should not result in any shuffling
    it3 = from_range(10, num_shards=1).local_shuffle(shuffle_buffer_size=1)
    assert list(it3.gather_sync()) == list(range(10))

    # statistical test: in a well-shuffled stream of equally many 0s and 1s,
    # each of the four adjacent pairs ("00", "01", "10", "11") should make up
    # roughly a quarter of all pairs.
    it4 = from_items(
        [0, 1] * 10000, num_shards=1).local_shuffle(shuffle_buffer_size=100)
    result = "".join(it4.gather_sync().for_each(str))
    freq_counter = Counter(zip(result[:-1], result[1:]))
    assert len(freq_counter) == 4
    # Normalize by the total number of pairs, not by the number of distinct
    # pairs (4): dividing by 4 let the check pass for any count above 0.8.
    total_pairs = sum(freq_counter.values())
    for pair, count in freq_counter.items():
        assert count / total_pairs > 0.2, (pair, count, total_pairs)
def test_gather_async(ray_start_regular_shared):
    """gather_async() wraps the parallel iterator and yields every item."""
    local_iter = from_range(4).gather_async()
    expected_repr = ("LocalIterator[ParallelIterator[from_range[4, shards=2]]"
                     ".gather_async()]")
    assert repr(local_iter) == expected_repr
    # Async gathering gives no ordering guarantee, so compare sorted output.
    assert sorted(local_iter) == [0, 1, 2, 3]
def test_batch(ray_start_regular_shared):
    """batch(n) on a single shard groups consecutive items into lists."""
    batched = from_range(4, 1).batch(2)
    expected_repr = "ParallelIterator[from_range[4, shards=1].batch(2)]"
    assert repr(batched) == expected_repr
    batches = list(batched.gather_sync())
    assert batches == [[0, 1], [2, 3]]
def test_filter(ray_start_regular_shared):
    """filter() drops items failing the predicate; output is shard-interleaved."""
    filtered = from_range(4).filter(lambda x: x < 3)
    expected_repr = "ParallelIterator[from_range[4, shards=2].filter()]"
    assert repr(filtered) == expected_repr
    # Shards are [0, 1] and [2, 3]; gather_sync alternates between them.
    gathered = list(filtered.gather_sync())
    assert gathered == [0, 2, 1]
def test_chain(ray_start_regular_shared):
    """Chained for_each() calls are applied in sequence to every item."""
    def double(x):
        return x * 2

    chained = from_range(4).for_each(double).for_each(double)
    expected_repr = (
        "ParallelIterator[from_range[4, shards=2].for_each().for_each()]")
    assert repr(chained) == expected_repr
    gathered = list(chained.gather_sync())
    assert gathered == [0, 8, 4, 12]
def test_combine(ray_start_regular_shared):
    """combine() flattens the list returned for each item into the stream."""
    combined = from_range(4, 1).combine(lambda x: [x, x])
    expected_repr = "ParallelIterator[from_range[4, shards=1].combine()]"
    assert repr(combined) == expected_repr
    flattened = list(combined.gather_sync())
    assert flattened == [0, 0, 1, 1, 2, 2, 3, 3]
def test_from_range(ray_start_regular_shared):
    """from_range(n) defaults to two shards, gathered round-robin."""
    ranged = from_range(4)
    expected_repr = "ParallelIterator[from_range[4, shards=2]]"
    assert repr(ranged) == expected_repr
    gathered = list(ranged.gather_sync())
    assert gathered == [0, 2, 1, 3]
def test_union_local(ray_start_regular_shared):
    """union() of two local iterators yields every item from both."""
    letters = from_items(["a", "b", "c"], 1).gather_async()
    digits = from_range(5, 2).for_each(str).gather_async()
    merged = letters.union(digits)
    # Interleaving order is unspecified, so compare sorted output.
    expected = ["0", "1", "2", "3", "4", "a", "b", "c"]
    assert sorted(merged) == expected