def test_cannot_set_dissimilar_shape_on_max_bounds(self):
    """A 3-element maximum on a [2, 2] INT32 spec raises 'incompatible'."""
    spec = dm_env_rpc_pb2.TensorSpec(
        name='test',
        shape=[2, 2],
        dtype=dm_env_rpc_pb2.DataType.INT32,
    )
    with self.assertRaisesRegex(ValueError, 'incompatible'):
        tensor_spec_utils.set_bounds(
            spec, minimum=0, maximum=[1, 2, 3])
def test_cannot_set_any_min_bounds_to_exceed_maximum(self):
    """set_bounds raises when any single minimum element exceeds its maximum."""
    spec = dm_env_rpc_pb2.TensorSpec(
        name='test',
        shape=[2],
        dtype=dm_env_rpc_pb2.DataType.INT32,
    )
    # Element 0 is ordered correctly (0 <= 1); only element 1 (4 > 1) violates.
    with self.assertRaisesRegex(ValueError, 'larger than max'):
        tensor_spec_utils.set_bounds(
            spec, minimum=[0, 4], maximum=[1, 1])
def test_can_set_broadcastable_max_bounds(self):
    """A scalar maximum is accepted next to a per-element minimum."""
    int32 = dm_env_rpc_pb2.DataType.INT32
    spec = dm_env_rpc_pb2.TensorSpec(name='test', shape=[2], dtype=int32)
    tensor_spec_utils.set_bounds(spec, minimum=[2, 3], maximum=4)
    self.assertEqual([2, 3], spec.min.int32s.array)
    # The scalar maximum is stored as a single value rather than expanded.
    self.assertEqual([4], spec.max.int32s.array)
def test_cannot_set_multiple_max_bounds_on_variable_shape(self):
    """Per-element maxima on a spec with a -1 dimension raise 'incompatible'."""
    variable_shape = [2, -1]
    spec = dm_env_rpc_pb2.TensorSpec(
        name='test',
        shape=variable_shape,
        dtype=dm_env_rpc_pb2.DataType.INT32,
    )
    with self.assertRaisesRegex(ValueError, 'incompatible'):
        tensor_spec_utils.set_bounds(spec, minimum=1, maximum=[2, 3])
def test_set_only_max(self):
    """minimum=None leaves `min` without a payload while `max` is filled."""
    spec = dm_env_rpc_pb2.TensorSpec(
        name='test', dtype=dm_env_rpc_pb2.DataType.INT32)
    tensor_spec_utils.set_bounds(spec, minimum=None, maximum=1)
    min_payload = spec.min.WhichOneof('payload')
    self.assertIsNone(min_payload)
    self.assertEqual([1], spec.max.int32s.array)
def test_cannot_set_nonnumeric_bounds(self):
    """A STRING spec is rejected with 'non-numeric' even when both bounds are None."""
    string_spec = dm_env_rpc_pb2.TensorSpec(
        name='test', dtype=dm_env_rpc_pb2.DataType.STRING)
    with self.assertRaisesRegex(ValueError, 'non-numeric'):
        tensor_spec_utils.set_bounds(
            string_spec, minimum=None, maximum=None)
def test_set_scalar_bounds(self):
    """Scalar bounds on a shapeless INT32 spec are stored as one-element arrays."""
    lo, hi = 1, 2
    spec = dm_env_rpc_pb2.TensorSpec(
        name='test', dtype=dm_env_rpc_pb2.DataType.INT32)
    tensor_spec_utils.set_bounds(spec, minimum=lo, maximum=hi)
    self.assertEqual([lo], spec.min.int32s.array)
    self.assertEqual([hi], spec.max.int32s.array)
def test_set_multiple_bounds(self):
    """Per-element bounds on a shape-(2,) INT32 spec are stored element-wise."""
    int32 = dm_env_rpc_pb2.DataType.INT32
    spec = dm_env_rpc_pb2.TensorSpec(name='test', shape=(2,), dtype=int32)
    tensor_spec_utils.set_bounds(spec, minimum=[1, 2], maximum=[3, 4])
    self.assertEqual([1, 2], spec.min.int32s.array)
    self.assertEqual([3, 4], spec.max.int32s.array)
def test_set_scalar_bounds_int8(self):
    """INT8 scalar bounds are compared against their int8 byte encoding."""
    lo, hi = 1, 2
    spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.INT8)
    tensor_spec_utils.set_bounds(spec, minimum=lo, maximum=hi)
    # The int8s payload is compared to raw bytes, not to a list of ints.
    expected_min = np.int8(lo).tobytes()
    expected_max = np.int8(hi).tobytes()
    self.assertEqual(expected_min, spec.min.int8s.array)
    self.assertEqual(expected_max, spec.max.int8s.array)
def _action_spec():
    """Builds the action-spec mapping.

    Returns:
      A dict with a single entry under key 1 (presumably the action's UID —
      confirm against callers). The entry is an INT8 TensorSpec named
      `_ACTION_PADDLE`, bounded by the smallest and largest entries of
      `_VALID_ACTIONS`.
    """
    lowest = np.min(_VALID_ACTIONS)
    highest = np.max(_VALID_ACTIONS)
    paddle = dm_env_rpc_pb2.TensorSpec(
        dtype=dm_env_rpc_pb2.INT8, name=_ACTION_PADDLE)
    tensor_spec_utils.set_bounds(paddle, minimum=lowest, maximum=highest)
    return {1: paddle}
def test_unset_min_and_max(self):
    """Calling set_bounds with None/None clears bounds that were set earlier."""
    spec = dm_env_rpc_pb2.TensorSpec(
        name='test', dtype=dm_env_rpc_pb2.DataType.INT32)
    # Populate both bounds first, then clear them with a second call.
    for lo, hi in ((1, 2), (None, None)):
        tensor_spec_utils.set_bounds(spec, minimum=lo, maximum=hi)
    for bound in (spec.min, spec.max):
        self.assertIsNone(bound.WhichOneof('payload'))
def test_new_bounds_must_be_safely_castable_to_dtype(self, minimum, maximum):
    """Checks the exact error raised for bounds that can't be cast to INT8.

    minimum/maximum are supplied by a parameterization that is not visible
    here; presumably they are values that can't be cast safely to INT8.
    set_bounds must raise ValueError whose message is exactly
    `_BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE` formatted with these values.
    """
    spec_name = 'test'
    int8 = dm_env_rpc_pb2.DataType.INT8
    expected_message = (
        tensor_spec_utils._BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE.format(
            name=spec_name,
            minimum=minimum,
            maximum=maximum,
            dtype=dm_env_rpc_pb2.DataType.Name(int8)))
    spec = dm_env_rpc_pb2.TensorSpec(name=spec_name, dtype=int8)
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
        tensor_spec_utils.set_bounds(spec, minimum=minimum, maximum=maximum)