Esempio n. 1
0
def test_union(ray_start_regular_shared):
    """Union two single-shard iterators and check the repr and item order."""
    first = from_items(["a", "b", "c"], 1)
    second = from_items(["x", "y", "z"], 1)
    merged = first.union(second)

    expected_repr = (
        "ParallelIterator[ParallelUnion["
        "ParallelIterator[from_items[str, 3, shards=1]], "
        "ParallelIterator[from_items[str, 3, shards=1]]]]"
    )
    assert repr(merged) == expected_repr

    # Synchronous gathering is expected to alternate between the two inputs.
    gathered = list(merged.gather_sync())
    assert gathered == ["a", "x", "b", "y", "c", "z"]
Esempio n. 2
0
def test_local_shuffle(ray_start_regular_shared):
    """Check ``local_shuffle`` for shard locality, randomness and mixing.

    Covers four properties: items stay inside their original shard, two
    unseeded shuffles produce different orders, a buffer size of 1 leaves
    the order untouched, and a shuffled 0/1 stream contains every adjacent
    pair with a reasonable relative frequency.
    """
    # confirm that no data disappears, and they all stay within the same shard
    it = from_range(8, num_shards=2).local_shuffle(shuffle_buffer_size=2)
    assert repr(it) == ("ParallelIterator[from_range[8, shards=2]" +
                        ".local_shuffle(shuffle_buffer_size=2, seed=None)]")
    shard_0 = it.get_shard(0)
    shard_1 = it.get_shard(1)
    assert set(shard_0) == {0, 1, 2, 3}
    assert set(shard_1) == {4, 5, 6, 7}

    # check that shuffling results in different orders
    it1 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5)
    it2 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5)
    assert list(it1.gather_sync()) != list(it2.gather_sync())

    # buffer size of 1 should not result in any shuffling
    it3 = from_range(10, num_shards=1).local_shuffle(shuffle_buffer_size=1)
    assert list(it3.gather_sync()) == list(range(10))

    # statistical test: in a well-mixed 0/1 stream each of the four adjacent
    # pairs ("00", "01", "10", "11") should appear about a quarter of the time
    it4 = from_items(
        [0, 1] * 10000, num_shards=1).local_shuffle(shuffle_buffer_size=100)
    result = "".join(it4.gather_sync().for_each(str))
    freq_counter = Counter(zip(result[:-1], result[1:]))
    assert len(freq_counter) == 4
    # Normalize by the number of observed pairs, not by the number of
    # distinct pairs (always 4 here), so the threshold is a real relative
    # frequency and the assertion can actually fail on poor mixing.
    total_pairs = sum(freq_counter.values())
    for pair, count in freq_counter.items():
        assert count / total_pairs > 0.2, (pair, count, total_pairs)
Esempio n. 3
0
def test_serialization(ray_start_regular_shared):
    """Ship a chained iterator to a remote task and consume it there."""
    pipeline = from_items([1, 2, 3, 4]).gather_sync()
    pipeline = pipeline.for_each(lambda x: x)
    pipeline = pipeline.filter(lambda x: True)
    pipeline = pipeline.batch(2).flatten()

    expected_repr = (
        "LocalIterator[ParallelIterator[from_items[int, 4, shards=2]]"
        ".gather_sync().for_each().filter().batch(2).flatten()]"
    )
    assert repr(pipeline) == expected_repr

    @ray.remote
    def drain(iterator):
        # Runs as a remote task, so the iterator argument must survive
        # being shipped to it.
        return list(iterator)

    fetched = ray.get(drain.remote(pipeline))
    assert fetched == [1, 2, 3, 4]
Esempio n. 4
0
def test_flatten(ray_start_regular_shared):
    """Flatten a single-shard iterator of lists into its elements."""
    nested = [[1, 2], [3, 4]]
    flattened = from_items(nested, 1).flatten()

    expected_repr = (
        "ParallelIterator[from_items[list, 2, shards=1].flatten()]")
    assert repr(flattened) == expected_repr

    gathered = list(flattened.gather_sync())
    assert gathered == [1, 2, 3, 4]
Esempio n. 5
0
def test_from_items_repeat(ray_start_regular_shared):
    """With ``repeat=True`` the items are served again after the first pass."""
    repeating = from_items([1, 2, 3, 4], repeat=True)

    expected_repr = (
        "ParallelIterator[from_items[int, 4, shards=2, repeat=True]]")
    assert repr(repeating) == expected_repr

    # Taking twice the input length should yield the sequence twice over.
    first_eight = repeating.take(8)
    assert first_eight == [1, 2, 3, 4, 1, 2, 3, 4]
Esempio n. 6
0
def test_union_local(ray_start_regular_shared):
    """Union two ``gather_async()`` results and check no item is lost."""
    letters = from_items(["a", "b", "c"], 1).gather_async()
    digits = from_range(5, 2).for_each(str).gather_async()
    combined = letters.union(digits)

    # The result is sorted before comparing, so only the contents are
    # checked, not the arrival order.
    expected = ["0", "1", "2", "3", "4", "a", "b", "c"]
    assert sorted(combined) == expected
Esempio n. 7
0
def test_from_items(ray_start_regular_shared):
    """Build an iterator from a list and check its repr and gathered items."""
    values = [1, 2, 3, 4]
    parallel_it = from_items(values)

    expected_repr = "ParallelIterator[from_items[int, 4, shards=2]]"
    assert repr(parallel_it) == expected_repr

    # Each gather_sync() call is consumed independently: once in full,
    # then once for its first element only.
    everything = list(parallel_it.gather_sync())
    assert everything == [1, 2, 3, 4]
    first_item = next(parallel_it.gather_sync())
    assert first_item == 1