def test_custom_space_vector_env():
    """A ``VectorEnv`` over ``CustomSpace`` batches its spaces into ``Tuple``.

    The env is built with ``4`` as its first argument and ``CustomSpace``
    for both the observation and action space. The ``single_*`` spaces must
    remain ``CustomSpace`` instances, while the batched ``observation_space``
    and ``action_space`` must be ``Tuple`` instances.
    """
    vector_env = VectorEnv(4, CustomSpace(), CustomSpace())

    # Each space attribute, paired with the type it must be an instance of.
    # Attributes are read lazily, one per iteration, in this order.
    expected_types = (
        ("single_observation_space", CustomSpace),
        ("observation_space", Tuple),
        ("single_action_space", CustomSpace),
        ("action_space", Tuple),
    )
    for attribute, expected_type in expected_types:
        assert isinstance(getattr(vector_env, attribute), expected_type)
Dict({
        "position": Dict({
            "x": MultiDiscrete([29, 29, 29, 29]),
            "y": MultiDiscrete([31, 31, 31, 31]),
        }),
        "velocity": Tuple((
            MultiDiscrete([37, 37, 37, 37]),
            Box(low=0, high=255, shape=(4, ), dtype=np.uint8),
        )),
    }),
]

# Expected outputs of ``batch_space(..., n=4)`` for spaces involving the
# non-standard ``CustomSpace``: each ``CustomSpace`` is expected to be batched
# into a ``Tuple`` of four ``CustomSpace`` instances.
# NOTE(review): not referenced in this view; presumably consumed index-wise by a
# separate custom-space batching test together with a list of input spaces —
# confirm against the rest of the file.
expected_custom_batch_spaces_4 = [
    Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())),
    Tuple((
        Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())),
        Box(low=0, high=255, shape=(4, ), dtype=np.uint8),
    )),
]


@pytest.mark.parametrize(
    "space,expected_batch_space_4",
    # NOTE: zip() silently truncates to the shorter list, so an entry missing
    # from either list would drop a test case rather than make it fail.
    list(zip(spaces, expected_batch_spaces_4)),
    # Use the input space's class name as the human-readable pytest case id.
    ids=[space.__class__.__name__ for space in spaces],
)
def test_batch_space(space, expected_batch_space_4):
    """Check that batching ``space`` with ``n=4`` yields the expected space.

    Each case pairs an entry of ``spaces`` with the entry at the same index
    in ``expected_batch_spaces_4``; the result is compared with ``==``.
    """
    batch_space_4 = batch_space(space, n=4)
    assert batch_space_4 == expected_batch_space_4