def test_flatten_unflatten(self):
   """Round-tripping through flatten_dict/unflatten_dict restores the input."""
   input_output = {
       'foo': {
           'bar': 1,
           'baz': False
       },
       'fiz': object(),
   }
   # assertEqual compares the whole nested structure (keys and values).
   # assertSameElements on dicts only iterates, and therefore only compares,
   # the top-level keys. The object() leaf is expected to be passed through
   # unchanged, so its identity-based equality holds on both sides.
   self.assertEqual(
       input_output,
       dm_env_flatten_utils.unflatten_dict(
           dm_env_flatten_utils.flatten_dict(input_output, '.'), '.'))
 def test_empty_dict_values_flatten(self):
   """Empty mappings stay as leaf values instead of being flattened away."""
   input_dict = {
       'foo': {},
       'bar': {
           'baz': {}
       },
   }
   expected = {
       'foo': {},
       'bar.baz': {},
   }
   # assertEqual also verifies the {} leaf values. assertSameElements on dicts
   # would compare only the keys.
   self.assertEqual(expected,
                    dm_env_flatten_utils.flatten_dict(input_dict, '.'))
 def test_flatten(self):
   """flatten_dict joins nested keys with the separator and keeps leaf values."""
   # One sentinel instance is shared by the input and the expectation.
   # object() only compares equal to itself, so a fresh object() in
   # `expected` could never match the flattened value.
   fiz = object()
   input_dict = {
       'foo': {
           'bar': 1,
           'baz': False
       },
       'fiz': fiz,
   }
   expected = {
       'foo.bar': 1,
       'foo.baz': False,
       'fiz': fiz,
   }
   # assertEqual compares keys *and* values. assertSameElements on dicts only
   # iterates (and therefore only compares) the keys.
   self.assertEqual(expected,
                    dm_env_flatten_utils.flatten_dict(input_dict, '.'))
Example #4
0
    def step(self, actions):
        """Implements dm_env.Environment.step.

        Sends a single StepRequest over ``self._connection`` and converts the
        response into a ``dm_env.TimeStep``.

        When ``self._nested_tensors`` is set, ``actions`` is flattened (and the
        returned observations unflattened) with DEFAULT_KEY_SEPARATOR.

        Args:
            actions: The actions to send. They are packed with
                ``self._action_specs.pack``; presumably a mapping of action
                names to values (confirm against the spec helpers).

        Returns:
            ``dm_env.TimeStep(step_type, reward, discount, observations)``.

        Raises:
            RuntimeError: If neither the previous state nor the new state is
                RUNNING.
        """
        if self._nested_tensors:
            actions = dm_env_flatten_utils.flatten_dict(
                actions, DEFAULT_KEY_SEPARATOR)

        request = dm_env_rpc_pb2.StepRequest(
            requested_observations=self._requested_observation_uids,
            actions=self._action_specs.pack(actions))
        step_response = self._connection.send(request)

        observations = self._observation_specs.unpack(
            step_response.observations)

        new_state = step_response.state
        running = dm_env_rpc_pb2.EnvironmentStateType.RUNNING
        is_running = new_state == running
        was_running = self._last_state == running

        if is_running:
            # RUNNING -> RUNNING continues the episode; any other state ->
            # RUNNING begins a new one.
            step_type = (dm_env.StepType.MID if was_running
                         else dm_env.StepType.FIRST)
        elif was_running:
            # RUNNING -> not RUNNING closes the episode.
            step_type = dm_env.StepType.LAST
        else:
            raise RuntimeError('Environment transitioned from {} to {}'.format(
                self._last_state, new_state))

        self._last_state = new_state

        reward = self.reward(
            state=new_state, step_type=step_type, observations=observations)
        discount = self.discount(
            state=new_state, step_type=step_type, observations=observations)

        # Observations may hold the default reward/discount entries. They are
        # removed (if present) unless they were explicitly requested.
        if not self._is_reward_requested:
            observations.pop(DEFAULT_REWARD_KEY, None)
        if not self._is_discount_requested:
            observations.pop(DEFAULT_DISCOUNT_KEY, None)

        if self._nested_tensors:
            observations = dm_env_flatten_utils.unflatten_dict(
                observations, DEFAULT_KEY_SEPARATOR)
        return dm_env.TimeStep(step_type, reward, discount, observations)
 def test_flatten_with_key_containing_separator_raises_error(self):
   """flatten_dict raises ValueError for a key already holding the separator."""
   # The second argument of assertRaisesRegex is a regular expression. The '.'
   # is escaped so it matches only a literal dot in the error message rather
   # than any character.
   with self.assertRaisesRegex(ValueError, r'foo\.bar'):
     dm_env_flatten_utils.flatten_dict({'foo.bar': True}, '.')