Пример #1
0
def test_write_to_shared_memory(space):
    """Check that samples written from child processes land in shared memory.

    Eight samples are drawn from ``space``. Each one is written into slot ``i``
    of a shared-memory structure created with ``n=8``, using its own
    ``Process``. After every process has exited successfully, each
    ``SynchronizedArray`` leaf is compared with the flattened stack of the
    corresponding per-slot sample values.
    """
    def assert_nested_equal(lhs, rhs):
        # ``lhs`` is a (possibly nested) shared-memory structure; ``rhs`` is a
        # list holding one sample per slot, with the same nesting.
        assert isinstance(rhs, list)
        if isinstance(lhs, (list, tuple)):
            for i, lhs_ in enumerate(lhs):
                assert_nested_equal(lhs_, [rhs_[i] for rhs_ in rhs])

        elif isinstance(lhs, (dict, OrderedDict)):
            for key, lhs_ in lhs.items():
                assert_nested_equal(lhs_, [rhs_[key] for rhs_ in rhs])

        elif isinstance(lhs, SynchronizedArray):
            # The shared array is read as a flat sequence, so compare it with
            # the per-slot samples stacked along a new axis and flattened.
            assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())

        else:
            raise TypeError(f"Got unknown type `{type(lhs)}`.")

    # NOTE(review): ``write`` is a nested function, so it cannot be pickled;
    # this presumably relies on a fork-based start method for ``Process``
    # (under "spawn" the target would fail to pickle) — confirm.
    def write(i, shared_memory, sample):
        write_to_shared_memory(space, i, sample, shared_memory)

    n = 8
    shared_memory_n8 = create_shared_memory(space, n=n)
    samples = [space.sample() for _ in range(n)]

    processes = [
        Process(target=write, args=(i, shared_memory_n8, sample))
        for i, sample in enumerate(samples)
    ]

    for process in processes:
        process.start()
    for process in processes:
        process.join()
        # An exception in a child does not propagate to the parent. Without
        # this check a failed writer could go unnoticed, for example when the
        # untouched (zero) slot happens to equal the sample.
        assert process.exitcode == 0

    assert_nested_equal(shared_memory_n8, samples)
Пример #2
0
def test_create_shared_memory(space, expected_type, n, ctx):
    """Verify the layout of shared memory built for ``space`` with ``n`` slots.

    ``ctx`` is either ``None`` (use the ``multiprocessing`` module directly)
    or a start-method name passed to ``mp.get_context``. The structure
    returned by ``create_shared_memory`` is walked alongside
    ``expected_type``:

    * sequences must match in type and length;
    * mappings must match in type and key set;
    * each ``SynchronizedArray`` leaf must hold ``n`` times as many elements
      as the reference leaf, with the same element type.
    """
    def check_layout(actual, reference, n):
        assert type(actual) == type(reference)

        if isinstance(actual, (list, tuple)):
            assert len(actual) == len(reference)
            for actual_item, reference_item in zip(actual, reference):
                check_layout(actual_item, reference_item, n)
            return

        if isinstance(actual, (dict, OrderedDict)):
            unmatched_keys = set(actual.keys()) ^ set(reference.keys())
            assert unmatched_keys == set()
            for name in actual.keys():
                check_layout(actual[name], reference[name], n)
            return

        if not isinstance(actual, SynchronizedArray):
            raise TypeError(f"Got unknown type `{type(actual)}`.")

        # There is one reference-sized block of elements per slot.
        assert len(actual[:]) == n * len(reference[:])
        # The element (scalar) type must agree with the reference.
        assert type(actual[0]) == type(reference[0])  # noqa: E721

    if ctx is None:
        context = mp
    else:
        context = mp.get_context(ctx)
    memory = create_shared_memory(space, n=n, ctx=context)
    check_layout(memory, expected_type, n=n)
Пример #3
0
def test_read_from_shared_memory(space):
    """Check that a read-only view of shared memory reflects child writes.

    A view is obtained with ``read_from_shared_memory`` *before* any data is
    written. Eight child processes then each write one sample into their own
    slot. The view is finally checked against ``space``:

    * ``Tuple`` spaces yield a ``tuple`` with one entry per sub-space;
    * ``Dict`` spaces yield an ``OrderedDict`` with the same keys;
    * leaf (basic gym) spaces yield an ``np.ndarray`` of shape
      ``(n,) + space.shape`` and dtype ``space.dtype``, equal to the stacked
      samples.
    """
    def assert_nested_equal(lhs, rhs, space, n):
        # ``rhs`` is a list holding one sample per slot, nested like ``space``.
        assert isinstance(rhs, list)
        if isinstance(space, Tuple):
            assert isinstance(lhs, tuple)
            # Catch a view with missing or extra entries rather than only
            # checking the entries that happen to be present.
            assert len(lhs) == len(space.spaces)
            for i, lhs_ in enumerate(lhs):
                assert_nested_equal(lhs_, [rhs_[i] for rhs_ in rhs],
                                    space.spaces[i], n)

        elif isinstance(space, Dict):
            assert isinstance(lhs, OrderedDict)
            assert set(lhs) == set(space.spaces)
            for key, lhs_ in lhs.items():
                assert_nested_equal(lhs_, [rhs_[key] for rhs_ in rhs],
                                    space.spaces[key], n)

        elif isinstance(space, _BaseGymSpaces):
            assert isinstance(lhs, np.ndarray)
            assert lhs.shape == ((n, ) + space.shape)
            assert lhs.dtype == space.dtype
            assert np.all(lhs == np.stack(rhs, axis=0))

        else:
            raise TypeError(f"Got unknown type `{type(space)}`")

    # NOTE(review): ``write`` is a nested function, so it cannot be pickled;
    # this presumably relies on a fork-based start method for ``Process``
    # (under "spawn" the target would fail to pickle) — confirm.
    def write(i, shared_memory, sample):
        write_to_shared_memory(space, i, sample, shared_memory)

    n = 8
    shared_memory_n8 = create_shared_memory(space, n=n)
    # The view is taken before writing, so the test also checks that it
    # aliases the shared buffer instead of copying it.
    memory_view_n8 = read_from_shared_memory(space, shared_memory_n8, n=n)
    samples = [space.sample() for _ in range(n)]

    processes = [
        Process(target=write, args=(i, shared_memory_n8, sample))
        for i, sample in enumerate(samples)
    ]

    for process in processes:
        process.start()
    for process in processes:
        process.join()
        # Exceptions in children do not reach the parent; surface them here
        # so a failed writer cannot be masked by a coincidentally equal slot.
        assert process.exitcode == 0

    assert_nested_equal(memory_view_n8, samples, space, n=n)
def test_create_shared_memory(space, expected_type, n):
    """Verify the layout of shared memory built for ``space`` with ``n`` slots.

    The structure returned by ``create_shared_memory`` is compared with
    ``expected_type`` recursively. Containers must agree in type and in
    length or key set. Every ``SynchronizedArray`` leaf must contain ``n``
    times as many elements as its reference leaf, with the same element type.
    """
    def check_layout(actual, reference, n):
        assert type(actual) == type(reference)

        if isinstance(actual, (list, tuple)):
            assert len(actual) == len(reference)
            for pair in zip(actual, reference):
                check_layout(pair[0], pair[1], n)
            return

        if isinstance(actual, (dict, OrderedDict)):
            mismatched = set(actual.keys()) ^ set(reference.keys())
            assert mismatched == set()
            for name in actual.keys():
                check_layout(actual[name], reference[name], n)
            return

        if not isinstance(actual, SynchronizedArray):
            raise TypeError('Got unknown type `{0}`.'.format(type(actual)))

        # One reference-sized block of elements per slot.
        assert len(actual[:]) == n * len(reference[:])
        # Element types must agree.
        assert type(actual[0]) == type(reference[0])

    memory = create_shared_memory(space, n=n)
    check_layout(memory, expected_type, n=n)
Пример #5
0
def test_create_shared_memory_custom_space(n, ctx, space):
    """Check that ``create_shared_memory`` raises ``CustomSpaceError``.

    ``ctx`` is either ``None`` (use the ``multiprocessing`` module directly)
    or a start-method name passed to ``mp.get_context``. ``space`` is
    presumably a user-defined space supplied by a fixture or parametrization
    not shown here — confirm.
    """
    ctx = mp if (ctx is None) else mp.get_context(ctx)
    with pytest.raises(CustomSpaceError):
        # The call is expected to raise, so its return value is not bound.
        create_shared_memory(space, n=n, ctx=ctx)