# Example No. 1
# 0
def test_flatten_roundtripping(space):
    """Check that flattening then unflattening reproduces the original samples.

    Ten samples are drawn from ``space``. Each is flattened with
    ``utils.flatten`` and then restored with ``utils.unflatten``. Every
    restored sample must compare equal, according to ``compare_nested``,
    to the sample it came from.
    """
    # Draw all samples first, then flatten all, then restore all.
    # This keeps the same call order as before.
    originals = [space.sample() for _ in range(10)]
    flat = [utils.flatten(space, item) for item in originals]
    restored = [utils.unflatten(space, item) for item in flat]

    for idx, pair in enumerate(zip(originals, restored)):
        before, after = pair
        assert compare_nested(before, after), (
            f"Expected sample #{idx} {before} to equal {after}"
        )
# Example No. 2
# 0
    def step(self, actions):
        """Send ``actions`` to the remote environment and advance it one step.

        Args:
            actions: A single action, a list or tuple of actions, or a NumPy
                array of actions. It is normalized to a Python list before
                being placed in the ``StepRequest``.

        Returns:
            A list ``[obs, reward, done, info]``:
            - ``obs`` is ``res.observation`` unflattened against
              ``self.observation_space``.
            - ``reward`` is ``res.reward``.
            - ``done`` is ``res.isDone``.
            - ``info`` is currently always an empty dict.
        """
        # Make sure actions is a list before building the request.
        if isinstance(actions, np.ndarray):
            # A 0-d array's .tolist() returns a bare scalar, not a list.
            # atleast_1d promotes it to shape (1,) and leaves arrays with
            # ndim >= 1 unchanged.
            actions = np.atleast_1d(actions).tolist()
        elif isinstance(actions, tuple):
            # A tuple is already a sequence of actions. Wrapping it in a list
            # would nest it as a single element.
            actions = list(actions)
        elif not isinstance(actions, list):
            actions = [actions]

        req = api_v1.StepRequest(instanceId=self.instanceId, actions=actions)
        res = self.client.Step(req)

        reward = res.reward
        done = res.isDone
        obs = gym_utils.unflatten(self.observation_space, res.observation)
        # TODO: convert the response's Protobuf map<string, string> into a
        # dict. Whether StepResponse exposes such a field is not visible
        # here — confirm against the proto.
        info = {}

        return [obs, reward, done, info]
# Example No. 3
# 0
def test_dtypes(original_space, expected_flattened_dtype):
    """Check dtype consistency when flattening and unflattening a space.

    The test asserts three things:
    - The flattened space contains a flattened sample.
    - The flattened space's dtype equals ``expected_flattened_dtype``.
    - The flattened sample has the same dtype as the flattened space.

    It then passes the original and the unflattened sample to
    ``compare_sample_types``, which presumably checks that their types match.
    """
    flat_space = utils.flatten_space(original_space)

    sample = original_space.sample()
    flat_sample = utils.flatten(original_space, sample)
    restored_sample = utils.unflatten(original_space, flat_sample)

    assert flat_space.contains(flat_sample), (
        "Expected flattened_space to contain flattened_sample"
    )
    assert flat_space.dtype == expected_flattened_dtype, (
        f"Expected flattened_space's dtype to equal {expected_flattened_dtype}"
    )
    assert flat_sample.dtype == flat_space.dtype, (
        "Expected flattened_space's dtype to equal flattened_sample's dtype "
    )

    compare_sample_types(original_space, sample, restored_sample)
# Example No. 4
# 0
def test_unflatten(space, flattened_sample, expected_sample):
    """Check that ``utils.unflatten`` maps ``flattened_sample`` to ``expected_sample``.

    Equality is judged with ``compare_nested``.
    """
    assert compare_nested(utils.unflatten(space, flattened_sample), expected_sample)