def test_invalid_flattened_dict_raises_error(self):
   """unflatten_dict rejects a leaf key that clashes with an existing sub-tree.

   'foo.bar' introduces 'foo' as a sub-tree, so the later plain 'foo' key
   must not be accepted as a leaf as well.
   """
   # OrderedDict keeps the insertion order deterministic: the sub-tree key
   # comes first and the conflicting leaf second.
   clashing_entries = [
       ('foo.bar', True),
       ('foo', 'invalid_value_for_sub_key'),
   ]
   flat_dict = collections.OrderedDict(clashing_entries)
   with self.assertRaisesRegex(ValueError, 'Duplicate key'):
     dm_env_flatten_utils.unflatten_dict(flat_dict, '.')
 def test_sub_tree_has_value_raises_error(self):
   """unflatten_dict rejects nesting under a key that already holds a leaf."""
   # Insertion order matters: the leaf 'branch' is added before the nested
   # key 'branch.leaf', which would need 'branch' to be a sub-tree.
   flat_dict = collections.OrderedDict()
   flat_dict['branch'] = 'should_not_have_value'
   flat_dict['branch.leaf'] = True
   expected_error = "Sub-tree 'branch' has already been assigned"
   with self.assertRaisesRegex(ValueError, expected_error):
     dm_env_flatten_utils.unflatten_dict(flat_dict, '.')
Example #3
0
 def action_spec(self):
     """Implements dm_env.Environment.action_spec.

     Returns:
       The action specs produced by `dm_env_utils.dm_env_spec`. When
       `self._nested_tensors` is set, the flat names are expanded into
       nested dicts by splitting on `DEFAULT_KEY_SEPARATOR`.
     """
     flat_spec = dm_env_utils.dm_env_spec(self._action_specs)
     if not self._nested_tensors:
         return flat_spec
     return dm_env_flatten_utils.unflatten_dict(
         flat_spec, DEFAULT_KEY_SEPARATOR)
 def test_flatten_unflatten(self):
   """flatten_dict followed by unflatten_dict reproduces the input exactly.

   This uses assertEqual rather than assertSameElements. Iterating a dict
   yields only its keys, so assertSameElements would check just the
   top-level keys ('foo', 'fiz'). It would miss lost or altered nested
   entries and values.
   """
   input_output = {
       'foo': {
           'bar': 1,
           'baz': False
       },
       'fiz': object(),
   }
   round_tripped = dm_env_flatten_utils.unflatten_dict(
       dm_env_flatten_utils.flatten_dict(input_output, '.'), '.')
   # NOTE(review): the object() leaf only compares equal by identity, so this
   # assumes flatten/unflatten pass leaf values through without copying.
   self.assertEqual(input_output, round_tripped)
Example #5
0
    def observation_spec(self):
        """Implements dm_env.Environment.observation_spec.

        Returns:
          A dict mapping each requested observation's name to its dm_env
          spec. The reward and discount entries are removed unless they were
          requested. The dict is nested by `DEFAULT_KEY_SEPARATOR` when
          `self._nested_tensors` is set.
        """
        observation_specs = self._observation_specs
        specs = {}
        for uid in self._requested_observation_uids:
            spec_name = observation_specs.uid_to_name(uid)
            tensor_spec = observation_specs.uid_to_spec(uid)
            specs[spec_name] = dm_env_utils.tensor_spec_to_dm_env_spec(
                tensor_spec)

        # Drop the reward/discount entries unless the caller asked for them.
        optional_keys = (
            (self._is_reward_requested, DEFAULT_REWARD_KEY),
            (self._is_discount_requested, DEFAULT_DISCOUNT_KEY),
        )
        for requested, key in optional_keys:
            if not requested:
                specs.pop(key, None)

        if not self._nested_tensors:
            return specs
        return dm_env_flatten_utils.unflatten_dict(
            specs, DEFAULT_KEY_SEPARATOR)
Example #6
0
    def step(self, actions):
        """Implements dm_env.Environment.step.

        Sends `actions` to the server in a StepRequest and unpacks the
        returned observations. It then derives the dm_env step type from the
        previous and current environment states and wraps everything in a
        dm_env.TimeStep.

        Args:
          actions: Actions keyed by name. They are flattened with
            `DEFAULT_KEY_SEPARATOR` before packing when `self._nested_tensors`
            is set.

        Returns:
          A dm_env.TimeStep built from (step_type, reward, discount,
          observations).

        Raises:
          RuntimeError: If neither the previous nor the new state is RUNNING.
        """
        if self._nested_tensors:
            actions = dm_env_flatten_utils.flatten_dict(
                actions, DEFAULT_KEY_SEPARATOR)
        request = dm_env_rpc_pb2.StepRequest(
            requested_observations=self._requested_observation_uids,
            actions=self._action_specs.pack(actions))
        step_response = self._connection.send(request)

        observations = self._observation_specs.unpack(
            step_response.observations)

        # Classify the transition using the state *before* it is overwritten.
        running = dm_env_rpc_pb2.EnvironmentStateType.RUNNING
        new_state = step_response.state
        was_running = self._last_state == running
        is_running = new_state == running
        if is_running:
            step_type = (dm_env.StepType.MID if was_running
                         else dm_env.StepType.FIRST)
        elif was_running:
            step_type = dm_env.StepType.LAST
        else:
            raise RuntimeError('Environment transitioned from {} to {}'.format(
                self._last_state, new_state))

        self._last_state = new_state

        # Reward and discount are computed before the optional keys are
        # stripped from the observations below.
        reward = self.reward(state=new_state,
                             step_type=step_type,
                             observations=observations)
        discount = self.discount(state=new_state,
                                 step_type=step_type,
                                 observations=observations)
        if not self._is_reward_requested:
            observations.pop(DEFAULT_REWARD_KEY, None)
        if not self._is_discount_requested:
            observations.pop(DEFAULT_DISCOUNT_KEY, None)
        if self._nested_tensors:
            observations = dm_env_flatten_utils.unflatten_dict(
                observations, DEFAULT_KEY_SEPARATOR)
        return dm_env.TimeStep(step_type, reward, discount, observations)
 def test_unflatten(self):
   """unflatten_dict nests keys on the separator and keeps leaf values.

   assertEqual is used because assertSameElements on dicts compares only
   their top-level keys. That would not catch wrong nesting or wrong values.
   """
   input_dict = {
       'foo.bar.baz': True,
       'fiz.buz': 1,
       'foo.baz': 'val',
       # An empty dict is a leaf value and is expected back unchanged.
       'buz': {}
   }
   expected = {
       'foo': {
           'bar': {
               'baz': True
           },
           'baz': 'val'
       },
       'fiz': {
           'buz': 1
       },
       'buz': {},
   }
   self.assertEqual(
       expected, dm_env_flatten_utils.unflatten_dict(input_dict, '.'))
 def test_unflatten_different_separator(self):
   """Keys split only on the given separator; '.' stays inside names.

   assertEqual checks the nested values too. assertSameElements on dicts
   would only compare the top-level keys.
   """
   input_dict = {'foo::bar.baz': True, 'foo.bar::baz': 1}
   expected = {'foo': {'bar.baz': True}, 'foo.bar': {'baz': 1}}
   self.assertEqual(
       expected, dm_env_flatten_utils.unflatten_dict(input_dict, '::'))