def test_cannot_set_dissimilar_shape_on_max_bounds(self):
    """A 3-element maximum is rejected for a spec of shape [2, 2]."""
    spec = dm_env_rpc_pb2.TensorSpec(
        name='test', dtype=dm_env_rpc_pb2.DataType.INT32, shape=[2, 2])
    # Callable form of assertRaisesRegex: set_bounds is invoked with the
    # trailing positional/keyword arguments and must raise ValueError.
    self.assertRaisesRegex(
        ValueError, 'incompatible',
        tensor_spec_utils.set_bounds, spec, minimum=0, maximum=[1, 2, 3])
 def test_cannot_set_any_min_bounds_to_exceed_maximum(self):
     """Rejects bounds where one minimum element (4) exceeds its maximum (1)."""
     spec = dm_env_rpc_pb2.TensorSpec(
         name='test', dtype=dm_env_rpc_pb2.DataType.INT32, shape=[2])
     lower, upper = [0, 4], [1, 1]
     self.assertRaisesRegex(ValueError, 'larger than max',
                            tensor_spec_utils.set_bounds, spec,
                            minimum=lower, maximum=upper)
    def test_can_set_broadcastable_max_bounds(self):
        """A scalar maximum is accepted alongside a per-element minimum."""
        spec = dm_env_rpc_pb2.TensorSpec(
            name='test', dtype=dm_env_rpc_pb2.DataType.INT32, shape=[2])
        tensor_spec_utils.set_bounds(spec, minimum=[2, 3], maximum=4)

        # The minimum keeps both elements; the scalar maximum is stored as a
        # single value rather than being expanded to the spec's shape.
        min_values = spec.min.int32s.array
        max_values = spec.max.int32s.array
        self.assertEqual([2, 3], min_values)
        self.assertEqual([4], max_values)
 def test_cannot_set_multiple_max_bounds_on_variable_shape(self):
     """A multi-element maximum is rejected when the shape has a -1 dim."""
     spec = dm_env_rpc_pb2.TensorSpec(
         name='test', dtype=dm_env_rpc_pb2.DataType.INT32, shape=[2, -1])
     bounds = dict(minimum=1, maximum=[2, 3])
     with self.assertRaisesRegex(ValueError, 'incompatible'):
         tensor_spec_utils.set_bounds(spec, **bounds)
    def test_set_only_max(self):
        """minimum=None leaves min unset while max gets the scalar value."""
        spec = dm_env_rpc_pb2.TensorSpec(
            dtype=dm_env_rpc_pb2.DataType.INT32, name='test')
        tensor_spec_utils.set_bounds(spec, maximum=1, minimum=None)

        min_payload = spec.min.WhichOneof('payload')
        self.assertIsNone(min_payload)
        self.assertEqual([1], spec.max.int32s.array)
 def test_cannot_set_nonnumeric_bounds(self):
     """set_bounds raises for a STRING spec even when both bounds are None."""
     spec = dm_env_rpc_pb2.TensorSpec(
         name='test', dtype=dm_env_rpc_pb2.DataType.STRING)
     self.assertRaisesRegex(ValueError, 'non-numeric',
                            tensor_spec_utils.set_bounds, spec,
                            minimum=None, maximum=None)
    def test_set_scalar_bounds(self):
        """Scalar bounds on a shapeless INT32 spec are stored one value each."""
        spec = dm_env_rpc_pb2.TensorSpec(
            name='test', dtype=dm_env_rpc_pb2.DataType.INT32)
        lower, upper = 1, 2
        tensor_spec_utils.set_bounds(spec, minimum=lower, maximum=upper)

        self.assertEqual([lower], spec.min.int32s.array)
        self.assertEqual([upper], spec.max.int32s.array)
Example #8
0
  def test_set_multiple_bounds(self):
    """Per-element bounds on a length-2 spec are stored element-wise."""
    lower, upper = [1, 2], [3, 4]
    spec = dm_env_rpc_pb2.TensorSpec(
        name='test', shape=(2,), dtype=dm_env_rpc_pb2.DataType.INT32)
    tensor_spec_utils.set_bounds(spec, minimum=lower, maximum=upper)

    self.assertEqual(lower, spec.min.int32s.array)
    self.assertEqual(upper, spec.max.int32s.array)
Example #9
0
 def test_set_scalar_bounds_int8(self):
   """INT8 bounds are compared as the raw bytes of np.int8 values."""
   spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.INT8)
   lower, upper = 1, 2
   tensor_spec_utils.set_bounds(spec, minimum=lower, maximum=upper)
   # The int8s payload's `array` field is checked against the byte
   # encoding of each bound.
   expected_min = np.int8(lower).tobytes()
   expected_max = np.int8(upper).tobytes()
   self.assertEqual(expected_min, spec.min.int8s.array)
   self.assertEqual(expected_max, spec.max.int8s.array)
Example #10
0
def _action_spec():
    """Returns the action spec as a dict holding a single INT8 tensor spec.

    The tensor is named `_ACTION_PADDLE` and bounded by the smallest and
    largest entries of `_VALID_ACTIONS`. It is stored under key 1
    (presumably the action's UID; confirm against the caller).
    """
    low = np.min(_VALID_ACTIONS)
    high = np.max(_VALID_ACTIONS)
    spec = dm_env_rpc_pb2.TensorSpec(name=_ACTION_PADDLE,
                                     dtype=dm_env_rpc_pb2.INT8)
    tensor_spec_utils.set_bounds(spec, minimum=low, maximum=high)
    return {1: spec}
Example #11
0
    def test_unset_min_and_max(self):
        """set_bounds(None, None) clears bounds set by an earlier call."""
        spec = dm_env_rpc_pb2.TensorSpec(
            name='test', dtype=dm_env_rpc_pb2.DataType.INT32)
        # First populate both bounds, then clear them again.
        for lower, upper in ((1, 2), (None, None)):
            tensor_spec_utils.set_bounds(spec, minimum=lower, maximum=upper)

        for bound in (spec.min, spec.max):
            self.assertIsNone(bound.WhichOneof('payload'))
Example #12
0
 def test_new_bounds_must_be_safely_castable_to_dtype(self, minimum, maximum):
   """Bounds that cannot be safely cast to INT8 raise the module's message.

   `minimum` and `maximum` are presumably supplied by a parameterization
   decorator that is not visible here.
   """
   name = 'test'
   dtype = dm_env_rpc_pb2.DataType.INT8
   expected_message = (
       tensor_spec_utils._BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE.format(
           name=name,
           minimum=minimum,
           maximum=maximum,
           dtype=dm_env_rpc_pb2.DataType.Name(dtype)))
   spec = dm_env_rpc_pb2.TensorSpec(name=name, dtype=dtype)
   with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
     tensor_spec_utils.set_bounds(spec, minimum=minimum, maximum=maximum)