Example #1
0
 def test_bounds_wrong_type_gives_error(self):
     """Float bounds on a UINT32 spec raise a ValueError mentioning 'uint32'."""
     spec_proto = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32, shape=[3])
     # The bound is deliberately stored in the `floats` payload, which does
     # not match the declared dtype.
     spec_proto.min.floats.array.append(1.9)
     with self.assertRaisesRegex(ValueError, 'uint32'):
         dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
Example #2
0
 def test_bounds_on_string_gives_error(self):
     """Any bounds on a STRING spec raise a ValueError mentioning 'string'."""
     spec_proto = dm_env_rpc_pb2.TensorSpec(
         name='named', dtype=dm_env_rpc_pb2.DataType.STRING, shape=[2])
     # Both a lower and an upper bound are supplied for the string spec.
     spec_proto.min.floats.array.append(1.9)
     spec_proto.max.floats.array.append(10.0)
     with self.assertRaisesRegex(ValueError, 'string'):
         dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
Example #3
0
  def __init__(self, connection, specs, requested_observations=None):
    """Wraps a joined dm_env_rpc connection as an environment.

    Args:
      connection: A Connection to a dm_env_rpc server on which a
        JoinWorldRequest has already succeeded.
      specs: The dm_env_rpc ActionObservationSpecs message describing the
        environment.
      requested_observations: Names of the observations to request on each
        step. When None, every observation listed in `specs` is requested.

    Raises:
      ValueError: If a requested observation name is not present in `specs`.
    """
    self._dm_env_rpc_specs = specs
    self._action_specs = spec_manager.SpecManager(specs.actions)
    self._observation_specs = spec_manager.SpecManager(specs.observations)
    self._connection = connection
    self._last_state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED
    observation_specs = self._observation_specs

    wanted = set(observation_specs.names() if requested_observations is None
                 else requested_observations)

    # Record what the caller asked for before the default reward/discount keys
    # are force-added to the request set below.
    self._is_reward_requested = DEFAULT_REWARD_KEY in wanted
    self._is_discount_requested = DEFAULT_DISCOUNT_KEY in wanted

    def default_spec(key):
      """Returns the dm_env spec for `key` and requests it; None if absent."""
      if key not in observation_specs.names():
        return None
      converted = dm_env_utils.tensor_spec_to_dm_env_spec(
          observation_specs.name_to_spec(key))
      wanted.add(key)
      return converted

    self._default_reward_spec = default_spec(DEFAULT_REWARD_KEY)
    self._default_discount_spec = default_spec(DEFAULT_DISCOUNT_KEY)

    unknown = wanted.difference(observation_specs.names())
    if unknown:
      raise ValueError(
          'Unsupported observations requested: {}'.format(unknown))

    # Sorting is not strictly necessary but it makes the unit tests
    # deterministic.
    self._requested_observation_uids = sorted(
        observation_specs.name_to_uid(name) for name in wanted)
Example #4
0
 def test_no_bounds_gives_arrayspec(self):
     """A spec with neither min nor max converts to a named `specs.Array`."""
     spec_proto = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32, shape=[3])
     converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
     unbounded = specs.Array(shape=[3], dtype=np.uint32)
     self.assertEqual(unbounded, converted)
     self.assertEqual('foo', converted.name)
Example #5
0
 def test_string_give_string_array(self):
     """A STRING spec converts to a named `specs.StringArray` of equal shape."""
     spec_proto = dm_env_rpc_pb2.TensorSpec(
         name='string_spec',
         dtype=dm_env_rpc_pb2.DataType.STRING,
         shape=[1, 2, 3])
     converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
     string_array = specs.StringArray(shape=[1, 2, 3])
     self.assertEqual(string_array, converted)
     self.assertEqual('string_spec', converted.name)
Example #6
0
    def spec(self) -> Optional[dm_env_specs.Array]:
        """Converts the wrapped property's tensor spec into a dm_env spec.

        Returns:
          The dm_env spec for the property, or None when the property's dtype
          is INVALID_DATA_TYPE.
        """
        tensor_spec = self._property_spec_proto.spec
        # Properties without a valid dtype have no dm_env representation.
        if tensor_spec.dtype == dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE:
            return None
        return dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
Example #7
0
    def test_spec_generate_and_validate_scalars(self):
        """Each valid dtype's scalar spec validates a value it generated."""
        # Build one scalar TensorSpec per data type, skipping the invalid
        # sentinel type.
        scalar_protos = [
            dm_env_rpc_pb2.TensorSpec(name=name, shape=(), dtype=dtype)
            for name, dtype in dm_env_rpc_pb2.DataType.items()
            if dtype != dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE
        ]

        for scalar_proto in scalar_protos:
            converted = dm_env_utils.tensor_spec_to_dm_env_spec(scalar_proto)
            # validate() is expected to accept the spec's own generated value.
            generated = converted.generate_value()
            converted.validate(generated)
Example #8
0
 def observation_spec(self):
   """Implements dm_env.Environment.observation_spec."""
   observation_specs = self._observation_specs
   # Map every requested observation name to its converted dm_env spec.
   result = {
       observation_specs.uid_to_name(uid):
           dm_env_utils.tensor_spec_to_dm_env_spec(
               observation_specs.uid_to_spec(uid))
       for uid in self._requested_observation_uids
   }
   # Drop the reward/discount entries unless their request flag is set.
   optional_keys = (
       (DEFAULT_REWARD_KEY, self._is_reward_requested),
       (DEFAULT_DISCOUNT_KEY, self._is_discount_requested),
   )
   for key, is_requested in optional_keys:
     if not is_requested:
       result.pop(key, None)
   return result
Example #9
0
 def test_only_max_bounds(self):
     """With only a max bound, the converted spec has a minimum of 0."""
     spec_proto = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32, shape=[3])
     spec_proto.max.uint32s.array.append(10)
     converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
     bounded = specs.BoundedArray(
         shape=[3], dtype=np.uint32, minimum=0, maximum=10)
     self.assertEqual(bounded, converted)
     self.assertEqual('foo', converted.name)
Example #10
0
 def test_scalar_with_0_min_and_no_max_bounds_gives_bounded_array(self):
     """A scalar with min 0 and no max becomes a BoundedArray up to 2**32 - 1."""
     spec_proto = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32)
     # Only the lower bound is provided; no max is set on the proto.
     spec_proto.min.uint32s.array.append(0)
     converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
     bounded = specs.BoundedArray(
         shape=(), dtype=np.uint32, minimum=0, maximum=2**32 - 1, name='foo')
     self.assertEqual(bounded, converted)
     self.assertEqual('foo', converted.name)
Example #11
0
    def test_scalar_with_0_n_bounds_gives_discrete_array(self):
        """A scalar bounded by [0, n] converts to a DiscreteArray of n + 1 values."""
        max_value = 9
        spec_proto = dm_env_rpc_pb2.TensorSpec(
            name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32)
        spec_proto.min.uint32s.array.append(0)
        spec_proto.max.uint32s.array.append(max_value)

        converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)

        discrete = specs.DiscreteArray(
            num_values=max_value + 1, dtype=np.uint32, name='foo')
        self.assertEqual(discrete, converted)
        self.assertEqual(0, converted.minimum)
        self.assertEqual(max_value, converted.maximum)
        self.assertEqual('foo', converted.name)
Example #12
0
    def test_bounds_oneof_not_set_gives_dtype_bounds(self):
        """A `min` message with no payload set yields bounds of 0 and 2**32 - 1."""
        spec_proto = dm_env_rpc_pb2.TensorSpec(
            name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32, shape=[3])

        # Populate and then clear a payload, just to force the `min` message
        # to get created without any value left in it.
        spec_proto.min.floats.array.append(3.0)
        spec_proto.min.ClearField('floats')

        converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
        bounded = specs.BoundedArray(
            shape=[3], dtype=np.uint32, minimum=0, maximum=2**32 - 1)
        self.assertEqual(bounded, converted)
        self.assertEqual('foo', converted.name)
Example #13
0
    def __init__(
        self,
        connection: dm_env_rpc_connection.Connection,
        specs: dm_env_rpc_pb2.ActionObservationSpecs,
        requested_observations: Optional[Sequence[str]] = None,
        nested_tensors: bool = True,
        extensions: Optional[Mapping[str,
                                     Any]] = immutabledict.immutabledict()):
        """Wraps a joined dm_env_rpc connection as an environment.

        Args:
          connection: A Connection to a dm_env_rpc server on which a
            JoinWorldRequest has already succeeded.
          specs: The dm_env_rpc ActionObservationSpecs message describing the
            environment.
          requested_observations: Names of the observations to request on
            each step. When None, every observation listed in `specs` is
            requested.
          nested_tensors: Whether tensors should be flattened/unflattened.
          extensions: Optional mapping from attribute name to the extension
            instance to attach to this object under that name.

        Raises:
          ValueError: If a requested observation name is not present in
            `specs`, or if an extension name collides with an attribute this
            object already has.
        """
        self._dm_env_rpc_specs = specs
        self._action_specs = spec_manager.SpecManager(specs.actions)
        self._observation_specs = spec_manager.SpecManager(specs.observations)
        self._connection = connection
        self._last_state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED
        self._nested_tensors = nested_tensors
        observation_specs = self._observation_specs

        if requested_observations is not None:
            # Explicit request: remember whether the caller named the default
            # reward/discount keys before they are force-added below.
            self._is_reward_requested = (DEFAULT_REWARD_KEY
                                         in requested_observations)
            self._is_discount_requested = (DEFAULT_DISCOUNT_KEY
                                           in requested_observations)
        else:
            # No explicit request: ask for everything, but treat reward and
            # discount as not explicitly requested.
            requested_observations = observation_specs.names()
            self._is_reward_requested = False
            self._is_discount_requested = False

        wanted = set(requested_observations)

        def default_spec(key):
            """Returns the dm_env spec for `key` and requests it; None if absent."""
            if key not in observation_specs.names():
                return None
            converted = dm_env_utils.tensor_spec_to_dm_env_spec(
                observation_specs.name_to_spec(key))
            wanted.add(key)
            return converted

        self._default_reward_spec = default_spec(DEFAULT_REWARD_KEY)
        self._default_discount_spec = default_spec(DEFAULT_DISCOUNT_KEY)

        unknown = wanted.difference(observation_specs.names())
        if unknown:
            raise ValueError(
                'Unsupported observations requested: {}'.format(unknown))

        # Sorting is not strictly necessary but it makes the unit tests
        # deterministic.
        self._requested_observation_uids = sorted(
            observation_specs.name_to_uid(name) for name in wanted)

        if extensions is None:
            return
        for extension_name, extension in extensions.items():
            # Refuse to shadow anything this object already exposes.
            if hasattr(self, extension_name):
                raise ValueError(
                    f'DmEnvAdaptor already has attribute "{extension_name}"!'
                )
            setattr(self, extension_name, extension)