def test_flatten_roundtripping(space):
    """Check that flattening then unflattening reproduces samples of ``space``.

    Ten samples are drawn from ``space``; each one is flattened and then
    unflattened against the same space, and every result is compared with
    its original using ``compare_nested``.
    """
    originals = [space.sample() for _ in range(10)]
    flat_versions = [utils.flatten(space, item) for item in originals]
    restored = [utils.unflatten(space, item) for item in flat_versions]

    for index, before in enumerate(originals):
        after = restored[index]
        matches = compare_nested(before, after)
        assert matches, f"Expected sample #{index} {before} to equal {after}"
def step(self, actions):
    """Advance the remote environment instance by one step.

    ``actions`` is normalized to a plain list, sent in a ``StepRequest``
    for ``self.instanceId`` through ``self.client.Step``, and the response
    is unpacked into the returned list.

    Args:
        actions: A list, a NumPy array, or a single action value.

    Returns:
        list: ``[obs, reward, done, info]`` where ``obs`` is the response
        observation unflattened against ``self.observation_space``,
        ``reward`` and ``done`` come from the response, and ``info`` is
        currently always an empty dict.
    """
    # Make sure actions is a list. ndarray.tolist() on a 0-d array yields a
    # bare scalar rather than a list, so the list check must run after the
    # conversion instead of being an alternative to it.
    if isinstance(actions, np.ndarray):
        actions = actions.tolist()
    if not isinstance(actions, list):
        actions = [actions]

    req = api_v1.StepRequest(instanceId=self.instanceId, actions=actions)
    res = self.client.Step(req)

    reward = res.reward
    done = res.isDone
    obs = gym_utils.unflatten(self.observation_space, res.observation)
    info = {}  # @TODO, convert Protobuf map<string, string> to Dict

    return [obs, reward, done, info]
def test_dtypes(original_space, expected_flattened_dtype):
    """Check dtypes produced by flattening ``original_space`` and its samples.

    Verifies that the flattened space contains the flattened sample, that
    the flattened space has ``expected_flattened_dtype``, and that the
    flattened sample shares the flattened space's dtype. Finally the
    unflattened sample is checked against the original sample with
    ``compare_sample_types``.
    """
    flat_space = utils.flatten_space(original_space)
    sample = original_space.sample()
    flat_sample = utils.flatten(original_space, sample)
    # Round-trip before asserting so an unflatten error surfaces first.
    restored_sample = utils.unflatten(original_space, flat_sample)

    is_member = flat_space.contains(flat_sample)
    assert is_member, "Expected flattened_space to contain flattened_sample"

    assert (
        flat_space.dtype == expected_flattened_dtype
    ), "Expected flattened_space's dtype to equal {}".format(
        expected_flattened_dtype
    )

    assert (
        flat_sample.dtype == flat_space.dtype
    ), "Expected flattened_space's dtype to equal flattened_sample's dtype "

    compare_sample_types(original_space, sample, restored_sample)
def test_unflatten(space, flattened_sample, expected_sample):
    """Check that unflattening ``flattened_sample`` yields ``expected_sample``.

    The flattened sample is unflattened against ``space`` and the result is
    compared with the expected value using ``compare_nested``.
    """
    assert compare_nested(
        utils.unflatten(space, flattened_sample),
        expected_sample,
    )