def test_batch_across_shards(ray_start_regular_shared):
    """Check the repr and output of ``batch_across_shards()`` on two shards.

    The expected result pairs up the i-th item of each shard.
    """
    parallel_it = from_iterators([[0, 1], [2, 3]])
    batched = parallel_it.batch_across_shards()

    expected_repr = (
        "LocalIterator[ParallelIterator[from_iterators[shards=2]]"
        ".batch_across_shards()]"
    )
    assert repr(batched) == expected_repr

    # Batches are sorted before comparing, so their arrival order is not
    # part of what this test pins down.
    assert sorted(batched) == [[0, 2], [1, 3]]
def test_union_async(ray_start_regular_shared):
    """Union a fast and a slow parallel iterator, then gather asynchronously.

    The fast source sleeps 0.05s per item and the slow one 0.3s per item,
    ten items each. The test asserts that the last three gathered items
    are all tagged "slow".
    """

    def fast_producer():
        for n in range(10):
            time.sleep(0.05)
            print("PRODUCE FAST", n)
            yield n

    def slow_producer():
        for n in range(10):
            time.sleep(0.3)
            print("PRODUCE SLOW", n)
            yield n

    # Tag every item with its origin so the tail of the merged stream can
    # be attributed to a source.
    fast_it = from_iterators([fast_producer]).for_each(
        lambda item: ("fast", item))
    slow_it = from_iterators([slow_producer]).for_each(
        lambda item: ("slow", item))

    merged = fast_it.union(slow_it)
    results = list(merged.gather_async())

    tail = results[-3:]
    assert all(entry[0] == "slow" for entry in tail), results
def test_union_local_async(ray_start_regular_shared):
    """Union two already-gathered (local) async iterators.

    Each parallel iterator is turned into a local iterator with
    ``gather_async()`` before the union. The test checks the repr of the
    union and that the last three items are all tagged "slow".
    """

    def fast_producer():
        for n in range(10):
            time.sleep(0.05)
            print("PRODUCE FAST", n)
            yield n

    def slow_producer():
        for n in range(10):
            time.sleep(0.3)
            print("PRODUCE SLOW", n)
            yield n

    fast_it = from_iterators([fast_producer]).for_each(
        lambda item: ("fast", item))
    slow_it = from_iterators([slow_producer]).for_each(
        lambda item: ("slow", item))

    local_fast = fast_it.gather_async()
    local_slow = slow_it.gather_async()
    merged = local_fast.union(local_slow)

    expected_repr = (
        "LocalIterator[LocalUnion[LocalIterator["
        "ParallelIterator[from_iterators[shards=1].for_each()]"
        ".gather_async()], LocalIterator[ParallelIterator["
        "from_iterators[shards=1].for_each()].gather_async()]]]"
    )
    assert repr(merged) == expected_repr

    results = list(merged)
    tail = results[-3:]
    assert all(entry[0] == "slow" for entry in tail), results
def test_remote(ray_start_regular_shared):
    """Check that a parallel iterator can be passed to remote tasks.

    Shards are read from inside a remote task, and then again from a
    remote task that itself launches the shard-reading task.
    """
    expected_shards = ([0, 1], [3, 4], [5, 6, 7])

    it = from_iterators([[0, 1], [3, 4], [5, 6, 7]])
    assert it.num_shards() == 3

    @ray.remote
    def read_shard(parallel_it, index):
        return list(parallel_it.get_shard(index))

    # One level of remoting: the driver hands the iterator to a task.
    for index, expected in enumerate(expected_shards):
        assert ray.get(read_shard.remote(it, index)) == expected

    @ray.remote
    def check_nested(parallel_it):
        # Two levels of remoting: a task forwards the iterator to another
        # task.
        for index, expected in enumerate(expected_shards):
            fetched = ray.get(read_shard.remote(parallel_it, index))
            assert fetched == expected

    ray.get(check_nested.remote(it))
def test_from_iterators(ray_start_regular_shared):
    """Check the repr and ``gather_sync()`` output of a two-shard iterator."""
    parallel_it = from_iterators([range(2), range(2)])

    assert repr(parallel_it) == "ParallelIterator[from_iterators[shards=2]]"

    gathered = list(parallel_it.gather_sync())
    # The expected list alternates between the two shards: 0, 0, then 1, 1.
    assert gathered == [0, 0, 1, 1]