예제 #1
0
def _as_bound_array(label, bound, tensor_spec):
  """Converts `bound` to an ndarray and validates its shape for `tensor_spec`.

  Args:
    label: Name of the bound used in the error message ('minimum' or
      'maximum').
    bound: A scalar or iterable of scalars. Must not be None.
    tensor_spec: The dm_env_rpc TensorSpec proto the bound is for.

  Returns:
    `bound` converted with `np.asarray`.

  Raises:
    ValueError: If `bound` has more than one element and its shape differs from
      `tensor_spec.shape`.
  """
  bound = np.asarray(bound)
  # A single-element bound is accepted for any spec shape (presumably applying
  # to every element); otherwise the shape must match the spec exactly.
  if bound.size != 1 and bound.shape != tuple(tensor_spec.shape):
    raise ValueError(
        f'{label} has shape {bound.shape}, which is incompatible with '
        f"tensor_spec {tensor_spec.name}'s shape {tensor_spec.shape}.")
  return bound


def set_bounds(tensor_spec: dm_env_rpc_pb2.TensorSpec, minimum, maximum):
  """Modifies `tensor_spec` to have its inclusive bounds set.

  Packs `minimum` in to `tensor_spec.min` and `maximum` in to `tensor_spec.max`.

  Args:
    tensor_spec: An instance of a dm_env_rpc TensorSpec proto.  It should
      already have its `name`, `dtype` and `shape` attributes set.
    minimum: The minimum value that elements in the described tensor can obtain.
      A scalar, iterable of scalars, or None.  If None, `min` will be cleared on
      `tensor_spec`.
    maximum: The maximum value that elements in the described tensor can obtain.
      A scalar, iterable of scalars, or None.  If None, `max` will be cleared on
      `tensor_spec`.

  Raises:
    ValueError: If the spec's dtype is not numeric, if a bound's shape is
      incompatible with the spec's shape, if a bound cannot be safely cast to
      the spec's dtype, or if any element of `maximum` is less than the
      corresponding element of `minimum`.
  """
  np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype)
  if not issubclass(np_type, np.number):
    raise ValueError(f'TensorSpec has non-numeric type "{np_type}".')

  has_min = minimum is not None
  has_max = maximum is not None

  # Validate shapes first (minimum before maximum) so callers get the most
  # specific error for malformed input.
  if has_min:
    minimum = _as_bound_array('minimum', minimum, tensor_spec)
  if has_max:
    maximum = _as_bound_array('maximum', maximum, tensor_spec)

  if ((has_min and not _can_cast(minimum, np_type)) or
      (has_max and not _can_cast(maximum, np_type))):
    raise ValueError(
        _BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE.format(
            name=tensor_spec.name,
            minimum=minimum,
            maximum=maximum,
            dtype=dm_env_rpc_pb2.DataType.Name(tensor_spec.dtype)))

  # Element-wise comparison; a single-element bound broadcasts against the
  # other bound.
  if has_min and has_max and np.any(maximum < minimum):
    raise ValueError(
        f'TensorSpec "{tensor_spec.name}" has min {minimum} larger than max '
        f'{maximum}.')

  packer = tensor_utils.get_packer(np_type)
  if has_min:
    packer.pack(tensor_spec.min, minimum)
  else:
    tensor_spec.ClearField('min')

  if has_max:
    packer.pack(tensor_spec.max, maximum)
  else:
    tensor_spec.ClearField('max')
예제 #2
0
 def test_can_unpack(self):
     """An int32 packer returns the values stored in `int32s.array`."""
     source_tensor = dm_env_rpc_pb2.Tensor()
     source_tensor.int32s.array[:] = [1, 2, 3]
     int32_packer = tensor_utils.get_packer(np.int32)
     unpacked = int32_packer.unpack(source_tensor)
     np.testing.assert_array_equal([1, 2, 3], unpacked)
예제 #3
0
 def test_can_pack(self):
     """Packing an int32 array writes the values into `int32s.array`."""
     target_tensor = dm_env_rpc_pb2.Tensor()
     values = np.asarray([1, 2, 3])
     tensor_utils.get_packer(np.int32).pack(target_tensor, values)
     self.assertEqual([1, 2, 3], target_tensor.int32s.array)
예제 #4
0
 def test_cannot_get_packer_for_invalid_type(self):
     """Requesting a packer for complex64 raises a TypeError naming the type."""
     self.assertRaisesRegex(
         TypeError, 'complex64', tensor_utils.get_packer, np.complex64)