def set_bounds(tensor_spec: dm_env_rpc_pb2.TensorSpec, minimum, maximum):
  """Modifies `tensor_spec` to have its inclusive bounds set.

  Packs `minimum` in to `tensor_spec.min` and `maximum` in to
  `tensor_spec.max`.

  Args:
    tensor_spec: An instance of a dm_env_rpc TensorSpec proto. It should
      already have its `name`, `dtype` and `shape` attributes set.
    minimum: The minimum value that elements in the described tensor can
      obtain. A scalar, iterable of scalars, or None. If None, `min` will be
      cleared on `tensor_spec`.
    maximum: The maximum value that elements in the described tensor can
      obtain. A scalar, iterable of scalars, or None. If None, `max` will be
      cleared on `tensor_spec`.

  Raises:
    ValueError: If `tensor_spec.dtype` is not a numeric type; if a bound has
      more than one element and its shape differs from `tensor_spec.shape`;
      if a bound cannot be safely cast to the spec's dtype; or if any element
      of `maximum` is less than the corresponding element of `minimum`.
  """
  np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype)
  if not issubclass(np_type, np.number):
    raise ValueError(f'TensorSpec has non-numeric type "{np_type}".')

  spec_shape = tuple(tensor_spec.shape)

  def _as_bound_array(label, value):
    """Converts `value` to an array, validating its shape against the spec."""
    array = np.asarray(value)
    # A single-element bound is accepted whatever the spec's shape; any other
    # bound must match the spec's shape exactly.
    if array.size != 1 and array.shape != spec_shape:
      raise ValueError(
          f'{label} has shape {array.shape}, which is incompatible with '
          f"tensor_spec {tensor_spec.name}'s shape {tensor_spec.shape}.")
    return array

  has_min = minimum is not None
  has_max = maximum is not None

  if has_min:
    minimum = _as_bound_array('minimum', minimum)
  if has_max:
    maximum = _as_bound_array('maximum', maximum)

  if ((has_min and not _can_cast(minimum, np_type)) or
      (has_max and not _can_cast(maximum, np_type))):
    raise ValueError(
        _BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE.format(
            name=tensor_spec.name,
            minimum=minimum,
            maximum=maximum,
            dtype=dm_env_rpc_pb2.DataType.Name(tensor_spec.dtype)))

  # Element-wise comparison; numpy broadcasts a single-element bound.
  if has_min and has_max and np.any(maximum < minimum):
    raise ValueError(
        f'TensorSpec "{tensor_spec.name}" has min {minimum} larger than max '
        f'{maximum}.')

  packer = tensor_utils.get_packer(np_type)
  if has_min:
    packer.pack(tensor_spec.min, minimum)
  else:
    tensor_spec.ClearField('min')
  if has_max:
    packer.pack(tensor_spec.max, maximum)
  else:
    tensor_spec.ClearField('max')
def test_can_unpack(self):
  """An int32 packer reads back the values stored in `int32s.array`."""
  expected = [1, 2, 3]
  proto = dm_env_rpc_pb2.Tensor()
  proto.int32s.array.extend(expected)
  int32_packer = tensor_utils.get_packer(np.int32)
  unpacked = int32_packer.unpack(proto)
  np.testing.assert_array_equal(expected, unpacked)
def test_can_pack(self):
  """An int32 packer writes array values into the tensor's `int32s` field."""
  proto = dm_env_rpc_pb2.Tensor()
  values = np.asarray([1, 2, 3])
  int32_packer = tensor_utils.get_packer(np.int32)
  int32_packer.pack(proto, values)
  self.assertEqual([1, 2, 3], proto.int32s.array)
def test_cannot_get_packer_for_invalid_type(self):
  """Requesting a packer for complex64 raises a TypeError naming the type."""
  self.assertRaisesRegex(
      TypeError, 'complex64', tensor_utils.get_packer, np.complex64)