def test_invalid_flattened_dict_raises_error(self):
  """A plain key that collides with an existing nested path is rejected."""
  # OrderedDict pins processing order: 'foo.bar' comes before plain 'foo',
  # which is expected to be reported as a duplicate key.
  colliding_entries = [
      ('foo.bar', True),
      ('foo', 'invalid_value_for_sub_key'),
  ]
  flat = collections.OrderedDict(colliding_entries)
  with self.assertRaisesRegex(ValueError, 'Duplicate key'):
    dm_env_flatten_utils.unflatten_dict(flat, '.')
def test_sub_tree_has_value_raises_error(self):
  """A nested key under a path that already holds a leaf is rejected."""
  # Insertion order matters: 'branch' receives a leaf value before
  # 'branch.leaf' tries to treat it as a sub-tree.
  flat = collections.OrderedDict()
  flat['branch'] = 'should_not_have_value'
  flat['branch.leaf'] = True
  expected_message = "Sub-tree 'branch' has already been assigned"
  with self.assertRaisesRegex(ValueError, expected_message):
    dm_env_flatten_utils.unflatten_dict(flat, '.')
def action_spec(self):
  """Implements dm_env.Environment.action_spec.

  Converts the stored action specs with `dm_env_utils.dm_env_spec`. When
  nested tensors are enabled, the flat spec dict is unflattened on
  DEFAULT_KEY_SEPARATOR; otherwise it is returned as-is.
  """
  flat_spec = dm_env_utils.dm_env_spec(self._action_specs)
  if not self._nested_tensors:
    return flat_spec
  return dm_env_flatten_utils.unflatten_dict(flat_spec, DEFAULT_KEY_SEPARATOR)
def test_flatten_unflatten(self):
  """Flattening then unflattening reproduces the original nested dict."""
  input_output = {
      'foo': {
          'bar': 1,
          'baz': False
      },
      'fiz': object(),
  }
  round_tripped = dm_env_flatten_utils.unflatten_dict(
      dm_env_flatten_utils.flatten_dict(input_output, '.'), '.')
  # assertEqual rather than assertSameElements: iterating a dict yields only
  # its top-level keys, so assertSameElements never compared nested values.
  # The object() leaf compares by identity, so this also pins that the leaf
  # is passed through untouched.
  self.assertEqual(input_output, round_tripped)
def observation_spec(self):
  """Implements dm_env.Environment.observation_spec.

  Builds one dm_env spec per requested observation uid, keyed by the
  observation's name. The reward and discount entries are dropped unless
  they were explicitly requested. When nested tensors are enabled, the
  result is unflattened on DEFAULT_KEY_SEPARATOR.
  """
  obs_specs = self._observation_specs
  result = {}
  for requested_uid in self._requested_observation_uids:
    key = obs_specs.uid_to_name(requested_uid)
    result[key] = dm_env_utils.tensor_spec_to_dm_env_spec(
        obs_specs.uid_to_spec(requested_uid))

  # Hide the implicit reward/discount observations the caller did not ask for.
  for requested, hidden_key in (
      (self._is_reward_requested, DEFAULT_REWARD_KEY),
      (self._is_discount_requested, DEFAULT_DISCOUNT_KEY)):
    if not requested:
      result.pop(hidden_key, None)

  if not self._nested_tensors:
    return result
  return dm_env_flatten_utils.unflatten_dict(result, DEFAULT_KEY_SEPARATOR)
def step(self, actions):
  """Implements dm_env.Environment.step.

  Sends a StepRequest carrying the packed actions (flattened first when
  nested tensors are enabled). It then derives the dm_env step type from the
  previous and new environment states and returns a `dm_env.TimeStep` built
  from the computed reward, discount and observations.

  Raises:
    RuntimeError: if neither the previous nor the new state is RUNNING.
  """
  if self._nested_tensors:
    actions = dm_env_flatten_utils.flatten_dict(actions, DEFAULT_KEY_SEPARATOR)
  request = dm_env_rpc_pb2.StepRequest(
      requested_observations=self._requested_observation_uids,
      actions=self._action_specs.pack(actions))
  step_response = self._connection.send(request)
  observations = self._observation_specs.unpack(step_response.observations)

  running = dm_env_rpc_pb2.EnvironmentStateType.RUNNING
  now_running = step_response.state == running
  was_running = self._last_state == running
  if now_running:
    # Staying in RUNNING continues the episode; entering it starts one.
    step_type = dm_env.StepType.MID if was_running else dm_env.StepType.FIRST
  elif was_running:
    step_type = dm_env.StepType.LAST
  else:
    raise RuntimeError('Environment transitioned from {} to {}'.format(
        self._last_state, step_response.state))
  # Record the new state only after the transition has been validated.
  self._last_state = step_response.state

  step_context = dict(
      state=step_response.state,
      step_type=step_type,
      observations=observations)
  reward = self.reward(**step_context)
  discount = self.discount(**step_context)

  # Reward/discount may have been read from the observations above; strip
  # them from the returned observations unless the caller requested them.
  if not self._is_reward_requested:
    observations.pop(DEFAULT_REWARD_KEY, None)
  if not self._is_discount_requested:
    observations.pop(DEFAULT_DISCOUNT_KEY, None)
  if self._nested_tensors:
    observations = dm_env_flatten_utils.unflatten_dict(
        observations, DEFAULT_KEY_SEPARATOR)
  return dm_env.TimeStep(step_type, reward, discount, observations)
def test_unflatten(self):
  """Dotted keys are expanded into nested dicts at every depth."""
  input_dict = {
      'foo.bar.baz': True,
      'fiz.buz': 1,
      'foo.baz': 'val',
      'buz': {}
  }
  expected = {
      'foo': {
          'bar': {
              'baz': True
          },
          'baz': 'val'
      },
      'fiz': {
          'buz': 1
      },
      'buz': {},
  }
  # assertEqual rather than assertSameElements: the latter iterates its
  # arguments, which for dicts means only top-level keys ('foo', 'fiz',
  # 'buz'), so the nested structure was never actually verified.
  self.assertEqual(
      expected, dm_env_flatten_utils.unflatten_dict(input_dict, '.'))
def test_unflatten_different_separator(self):
  """Only the given separator splits keys; other punctuation is kept."""
  input_dict = {'foo::bar.baz': True, 'foo.bar::baz': 1}
  expected = {'foo': {'bar.baz': True}, 'foo.bar': {'baz': 1}}
  # assertEqual rather than assertSameElements: the latter only compares a
  # dict's top-level keys, so the nested values were never checked.
  self.assertEqual(
      expected, dm_env_flatten_utils.unflatten_dict(input_dict, '::'))