def test_bounds_wrong_type_gives_error(self):
  """Checks that float bounds on a UINT32 spec raise a ValueError."""
  uint_spec = dm_env_rpc_pb2.TensorSpec(
      name='foo',
      dtype=dm_env_rpc_pb2.DataType.UINT32,
      shape=[3],
  )
  # The minimum is deliberately supplied through the `floats` field, which
  # does not match the declared UINT32 dtype.
  uint_spec.min.floats.array[:] = [1.9]

  with self.assertRaisesRegex(ValueError, 'uint32'):
    dm_env_utils.tensor_spec_to_dm_env_spec(uint_spec)
def test_bounds_on_string_gives_error(self):
  """Checks that min/max bounds on a STRING spec raise a ValueError."""
  string_spec = dm_env_rpc_pb2.TensorSpec(
      name='named',
      dtype=dm_env_rpc_pb2.DataType.STRING,
      shape=[2],
  )
  # Both bounds are populated so the conversion has to deal with bounds on
  # a STRING dtype.
  string_spec.min.floats.array[:] = [1.9]
  string_spec.max.floats.array[:] = [10.0]

  with self.assertRaisesRegex(ValueError, 'string'):
    dm_env_utils.tensor_spec_to_dm_env_spec(string_spec)
def __init__(self, connection, specs, requested_observations=None):
  """Builds the adaptor on top of an established dm_env_rpc connection.

  Args:
    connection: A Connection instance that is already connected to a
      dm_env_rpc server and for which a JoinWorldRequest has succeeded.
    specs: The dm_env_rpc ActionObservationSpecs message for the
      environment.
    requested_observations: Names of the observations to request from the
      environment when step is called. When None, all observations are
      requested.

  Raises:
    ValueError: If a requested observation name is not present in the
      observation specs.
  """
  self._dm_env_rpc_specs = specs
  self._action_specs = spec_manager.SpecManager(specs.actions)
  self._observation_specs = spec_manager.SpecManager(specs.observations)
  self._connection = connection
  self._last_state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED

  to_dm_env_spec = dm_env_utils.tensor_spec_to_dm_env_spec

  if requested_observations is None:
    requested_names = set(self._observation_specs.names())
  else:
    requested_names = set(requested_observations)

  # Capture these flags before the reward/discount keys are force-added to
  # the requested set below.
  self._is_reward_requested = DEFAULT_REWARD_KEY in requested_names
  self._is_discount_requested = DEFAULT_DISCOUNT_KEY in requested_names

  self._default_reward_spec = None
  self._default_discount_spec = None

  # Reward and discount are always requested when the environment provides
  # them, whether or not the caller asked for them.
  if DEFAULT_REWARD_KEY in self._observation_specs.names():
    reward_tensor_spec = self._observation_specs.name_to_spec(
        DEFAULT_REWARD_KEY)
    self._default_reward_spec = to_dm_env_spec(reward_tensor_spec)
    requested_names.add(DEFAULT_REWARD_KEY)

  if DEFAULT_DISCOUNT_KEY in self._observation_specs.names():
    discount_tensor_spec = self._observation_specs.name_to_spec(
        DEFAULT_DISCOUNT_KEY)
    self._default_discount_spec = to_dm_env_spec(discount_tensor_spec)
    requested_names.add(DEFAULT_DISCOUNT_KEY)

  unknown_names = requested_names.difference(
      self._observation_specs.names())
  if unknown_names:
    message = 'Unsupported observations requested: {}'.format(
        unknown_names)
    raise ValueError(message)

  # Sorting is not strictly necessary but it makes the unit tests
  # deterministic.
  self._requested_observation_uids = sorted(
      self._observation_specs.name_to_uid(name) for name in requested_names)
def test_no_bounds_gives_arrayspec(self):
  """Checks that a spec without bounds converts to a plain specs.Array."""
  unbounded_spec = dm_env_rpc_pb2.TensorSpec(
      name='foo',
      dtype=dm_env_rpc_pb2.DataType.UINT32,
      shape=[3],
  )

  converted = dm_env_utils.tensor_spec_to_dm_env_spec(unbounded_spec)

  expected = specs.Array(shape=[3], dtype=np.uint32)
  self.assertEqual(expected, converted)
  self.assertEqual('foo', converted.name)
def test_string_give_string_array(self):
  """Checks that a STRING spec converts to a specs.StringArray."""
  string_spec = dm_env_rpc_pb2.TensorSpec(
      name='string_spec',
      dtype=dm_env_rpc_pb2.DataType.STRING,
      shape=[1, 2, 3],
  )

  converted = dm_env_utils.tensor_spec_to_dm_env_spec(string_spec)

  expected = specs.StringArray(shape=[1, 2, 3])
  self.assertEqual(expected, converted)
  self.assertEqual('string_spec', converted.name)
def spec(self) -> Optional[dm_env_specs.Array]:
  """Converts the property's tensor spec into a dm_env spec.

  Returns:
    The dm_env spec for the property, or None when the property's dtype is
    INVALID_DATA_TYPE.
  """
  tensor_spec = self._property_spec_proto.spec
  # Guard clause: an invalid dtype has no dm_env spec.
  if tensor_spec.dtype == dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE:
    return None
  return dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
def test_spec_generate_and_validate_scalars(self):
  """Checks each valid dtype's scalar spec validates its generated value."""
  invalid_dtype = dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE
  # One scalar (shape=()) TensorSpec per data type, skipping the invalid one.
  scalar_tensor_specs = [
      dm_env_rpc_pb2.TensorSpec(name=type_name, shape=(), dtype=type_value)
      for type_name, type_value in dm_env_rpc_pb2.DataType.items()
      if type_value != invalid_dtype
  ]

  for scalar_tensor_spec in scalar_tensor_specs:
    converted = dm_env_utils.tensor_spec_to_dm_env_spec(scalar_tensor_spec)
    generated = converted.generate_value()
    converted.validate(generated)
def observation_spec(self):
  """Implements dm_env.Environment.observation_spec.

  Returns:
    A dict mapping each requested observation name to its dm_env spec. The
    reward and discount entries are removed unless they were requested.
  """
  requested_specs = {
      self._observation_specs.uid_to_name(uid):
          dm_env_utils.tensor_spec_to_dm_env_spec(
              self._observation_specs.uid_to_spec(uid))
      for uid in self._requested_observation_uids
  }

  # Drop the reward/discount entries when the corresponding flag is unset;
  # pop with a default tolerates the key being absent.
  if not self._is_reward_requested:
    requested_specs.pop(DEFAULT_REWARD_KEY, None)
  if not self._is_discount_requested:
    requested_specs.pop(DEFAULT_DISCOUNT_KEY, None)

  return requested_specs
def test_only_max_bounds(self):
  """Checks that a max-only UINT32 spec gets a minimum of 0."""
  max_only_spec = dm_env_rpc_pb2.TensorSpec(
      name='foo',
      dtype=dm_env_rpc_pb2.DataType.UINT32,
      shape=[3],
  )
  max_only_spec.max.uint32s.array[:] = [10]

  converted = dm_env_utils.tensor_spec_to_dm_env_spec(max_only_spec)

  # No minimum was set, so the expected minimum is 0.
  self.assertEqual(
      specs.BoundedArray(
          shape=[3], dtype=np.uint32, minimum=0, maximum=10),
      converted)
  self.assertEqual('foo', converted.name)
def test_scalar_with_0_min_and_no_max_bounds_gives_bounded_array(self):
  """Checks a scalar with min 0 and no max converts to a BoundedArray."""
  scalar_spec = dm_env_rpc_pb2.TensorSpec(
      name='foo',
      dtype=dm_env_rpc_pb2.DataType.UINT32,
  )
  scalar_spec.min.uint32s.array[:] = [0]

  converted = dm_env_utils.tensor_spec_to_dm_env_spec(scalar_spec)

  # With no maximum set, the expected maximum is 2**32 - 1.
  uint32_max = 2**32 - 1
  expected = specs.BoundedArray(
      shape=(), dtype=np.uint32, minimum=0, maximum=uint32_max, name='foo')
  self.assertEqual(expected, converted)
  self.assertEqual('foo', converted.name)
def test_scalar_with_0_n_bounds_gives_discrete_array(self):
  """Checks a scalar bounded by [0, n] converts to a DiscreteArray."""
  upper_bound = 9
  discrete_spec = dm_env_rpc_pb2.TensorSpec(
      name='foo',
      dtype=dm_env_rpc_pb2.DataType.UINT32,
  )
  discrete_spec.min.uint32s.array[:] = [0]
  discrete_spec.max.uint32s.array[:] = [upper_bound]

  converted = dm_env_utils.tensor_spec_to_dm_env_spec(discrete_spec)

  # The inclusive range [0, upper_bound] holds upper_bound + 1 values.
  expected = specs.DiscreteArray(
      num_values=upper_bound + 1, dtype=np.uint32, name='foo')
  self.assertEqual(expected, converted)
  self.assertEqual(0, converted.minimum)
  self.assertEqual(upper_bound, converted.maximum)
  self.assertEqual('foo', converted.name)
def test_bounds_oneof_not_set_gives_dtype_bounds(self):
  """Checks that a `min` message with no payload gives dtype bounds."""
  empty_min_spec = dm_env_rpc_pb2.TensorSpec(
      name='foo',
      dtype=dm_env_rpc_pb2.DataType.UINT32,
      shape=[3],
  )
  # Write then clear a value, just to force the `min` message to get
  # created without leaving any payload set on it.
  empty_min_spec.min.floats.array[:] = [3.0]
  empty_min_spec.min.ClearField('floats')

  converted = dm_env_utils.tensor_spec_to_dm_env_spec(empty_min_spec)

  uint32_max = 2**32 - 1
  self.assertEqual(
      specs.BoundedArray(
          shape=[3], dtype=np.uint32, minimum=0, maximum=uint32_max),
      converted)
  self.assertEqual('foo', converted.name)
def __init__(
    self,
    connection: dm_env_rpc_connection.Connection,
    specs: dm_env_rpc_pb2.ActionObservationSpecs,
    requested_observations: Optional[Sequence[str]] = None,
    nested_tensors: bool = True,
    extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()):
  """Builds the adaptor on top of an established dm_env_rpc connection.

  Args:
    connection: A Connection instance that is already connected to a
      dm_env_rpc server and for which a JoinWorldRequest has succeeded.
    specs: The dm_env_rpc ActionObservationSpecs message for the
      environment.
    requested_observations: Names of the observations to request from the
      environment when step is called. When None, all observations are
      requested.
    nested_tensors: Whether tensors should be flattened/unflattened.
    extensions: Optional mapping from attribute name to extension instance.
      Each entry is set on this object under its name.

  Raises:
    ValueError: If a requested observation name is not present in the
      observation specs, or if an extension name matches an attribute this
      object already has.
  """
  self._dm_env_rpc_specs = specs
  self._action_specs = spec_manager.SpecManager(specs.actions)
  self._observation_specs = spec_manager.SpecManager(specs.observations)
  self._connection = connection
  self._last_state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED
  self._nested_tensors = nested_tensors

  to_dm_env_spec = dm_env_utils.tensor_spec_to_dm_env_spec

  if requested_observations is not None:
    # Membership is tested on the caller's sequence, before it is copied
    # into a set and the default keys are added.
    self._is_reward_requested = DEFAULT_REWARD_KEY in requested_observations
    self._is_discount_requested = (
        DEFAULT_DISCOUNT_KEY in requested_observations)
    requested_names = set(requested_observations)
  else:
    # With no explicit request, all observations are requested but the
    # reward and discount flags stay unset.
    self._is_reward_requested = False
    self._is_discount_requested = False
    requested_names = set(self._observation_specs.names())

  self._default_reward_spec = None
  self._default_discount_spec = None

  # Reward and discount are always requested when the environment provides
  # them, whether or not the caller asked for them.
  if DEFAULT_REWARD_KEY in self._observation_specs.names():
    reward_tensor_spec = self._observation_specs.name_to_spec(
        DEFAULT_REWARD_KEY)
    self._default_reward_spec = to_dm_env_spec(reward_tensor_spec)
    requested_names.add(DEFAULT_REWARD_KEY)

  if DEFAULT_DISCOUNT_KEY in self._observation_specs.names():
    discount_tensor_spec = self._observation_specs.name_to_spec(
        DEFAULT_DISCOUNT_KEY)
    self._default_discount_spec = to_dm_env_spec(discount_tensor_spec)
    requested_names.add(DEFAULT_DISCOUNT_KEY)

  unknown_names = requested_names.difference(
      self._observation_specs.names())
  if unknown_names:
    message = 'Unsupported observations requested: {}'.format(
        unknown_names)
    raise ValueError(message)

  # Sorting is not strictly necessary but it makes the unit tests
  # deterministic.
  self._requested_observation_uids = sorted(
      self._observation_specs.name_to_uid(name) for name in requested_names)

  if extensions is not None:
    for attribute_name, extension_instance in extensions.items():
      # Refuse to shadow anything the adaptor already defines.
      if hasattr(self, attribute_name):
        raise ValueError(
            f'DmEnvAdaptor already has attribute "{attribute_name}"!'
        )
      setattr(self, attribute_name, extension_instance)