コード例 #1
0
def _observation_spec():
    """Builds the observation spec as a dict keyed by integer id.

    Returns:
      A two-entry dict: key 1 maps to a FLOAT `TensorSpec` named
      `_OBSERVATION_BOARD` with shape `[_NUM_ROWS, _NUM_COLUMNS]`; key 2 maps
      to a FLOAT `TensorSpec` named `_OBSERVATION_REWARD` with no shape set.
    """
    board_spec = dm_env_rpc_pb2.TensorSpec(
        name=_OBSERVATION_BOARD,
        shape=[_NUM_ROWS, _NUM_COLUMNS],
        dtype=dm_env_rpc_pb2.FLOAT,
    )
    reward_spec = dm_env_rpc_pb2.TensorSpec(
        name=_OBSERVATION_REWARD,
        dtype=dm_env_rpc_pb2.FLOAT,
    )
    return {1: board_spec, 2: reward_spec}
コード例 #2
0
 def setUp(self):
     # Creates the SpecManager used by the tests, built from two specs
     # stored under keys 54 ('fuzz') and 55 ('foo').
     super(SpecManagerTests, self).setUp()
     fuzz_spec = dm_env_rpc_pb2.TensorSpec(
         name='fuzz', shape=[2], dtype=dm_env_rpc_pb2.DataType.FLOAT)
     foo_spec = dm_env_rpc_pb2.TensorSpec(
         name='foo', shape=[3], dtype=dm_env_rpc_pb2.DataType.INT32)
     self._spec_manager = spec_manager.SpecManager({
         54: fuzz_spec,
         55: foo_spec,
     })
コード例 #3
0
 def test_duplicate_names_raise_error(self):
     # Two specs under different keys share the name 'fuzz'; building a
     # SpecManager from them is expected to raise ValueError.
     float_type = dm_env_rpc_pb2.DataType.FLOAT
     first = dm_env_rpc_pb2.TensorSpec(
         name='fuzz', shape=[3], dtype=float_type)
     second = dm_env_rpc_pb2.TensorSpec(
         name='fuzz', shape=[2], dtype=float_type)
     self.assertRaisesRegex(
         ValueError, 'duplicate name',
         spec_manager.SpecManager, {54: first, 55: second})
コード例 #4
0
ファイル: dm_env_rpc_test.py プロジェクト: rsfb/dm_env_rpc
 def test_setting_spec(self):
   # Only populates the message's fields; there are no assertions, so the
   # test passes as long as none of the assignments raise.
   spec = dm_env_rpc_pb2.TensorSpec()
   spec.dtype = dm_env_rpc_pb2.DataType.FLOAT
   spec.shape[:] = [2, 2]
   spec.name = 'Foo'
   spec.max.floats.array[:] = [0.0]
   spec.min.floats.array[:] = [0.0]
コード例 #5
0
  def test_set_multiple_bounds(self):
    # Per-element bounds on a shape-(2,) INT32 spec; asserts that both lists
    # are readable back from the int32s arrays.
    lower, upper = [1, 2], [3, 4]
    spec = dm_env_rpc_pb2.TensorSpec(
        dtype=dm_env_rpc_pb2.DataType.INT32, shape=(2,), name='test')

    tensor_spec_utils.set_bounds(spec, maximum=upper, minimum=lower)

    stored_min = spec.min.int32s.array
    stored_max = spec.max.int32s.array
    self.assertEqual([1, 2], stored_min)
    self.assertEqual([3, 4], stored_max)
コード例 #6
0
 def test_set_scalar_bounds_int8(self):
   # For INT8 the stored bounds are compared against the raw bytes of the
   # corresponding np.int8 values.
   low, high = 1, 2
   spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.INT8)
   tensor_spec_utils.set_bounds(spec, minimum=low, maximum=high)
   expected_min = np.int8(low).tobytes()
   expected_max = np.int8(high).tobytes()
   self.assertEqual(expected_min, spec.min.int8s.array)
   self.assertEqual(expected_max, spec.max.int8s.array)
コード例 #7
0
 def test_min_and_max_legacy(self):
     # Bounds are written through the scalar `int32` fields of min/max
     # rather than through `int32s.array`.
     spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.INT32)
     spec.max.int32 = 1
     spec.min.int32 = -1
     self.assertEqual((-1, 1), tensor_spec_utils.bounds(spec))
コード例 #8
0
def _action_spec():
    """Builds the action spec as a dict keyed by integer id.

    Returns:
      A single-entry dict mapping 1 to an INT8 `TensorSpec` named
      `_ACTION_PADDLE`.
    """
    paddle_spec = dm_env_rpc_pb2.TensorSpec(
        name=_ACTION_PADDLE,
        dtype=dm_env_rpc_pb2.INT8,
    )
    return {1: paddle_spec}
コード例 #9
0
 def test_min_and_max(self):
     # Single-element min/max arrays on a spec with no shape set; bounds()
     # is expected to compare equal to the tuple (-1, 1).
     spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.INT32)
     spec.max.int32s.array[:] = [1]
     spec.min.int32s.array[:] = [-1]
     actual = tensor_spec_utils.bounds(spec)
     self.assertEqual((-1, 1), actual)
コード例 #10
0
    def test_set_scalar_bounds(self):
        # Scalar minimum/maximum on a spec with no shape set; asserts that
        # each is readable back as a one-element int32s array.
        spec = dm_env_rpc_pb2.TensorSpec(
            dtype=dm_env_rpc_pb2.DataType.INT32, name='test')
        tensor_spec_utils.set_bounds(spec, maximum=2, minimum=1)

        stored_min = spec.min.int32s.array
        stored_max = spec.max.int32s.array
        self.assertEqual([1], stored_min)
        self.assertEqual([2], stored_max)
コード例 #11
0
 def test_invalid_min_shape(self):
     # A two-element min on a spec with no shape set is expected to be
     # rejected by bounds().
     spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.UINT32)
     spec.min.uint32s.array[:] = [1, 2]
     expected_error = 'Scalar tensors must have exactly 1 element.*'
     with self.assertRaisesRegex(ValueError, expected_error):
         tensor_spec_utils.bounds(spec)
コード例 #12
0
 def test_cannot_set_any_min_bounds_to_exceed_maximum(self):
     # The second element's minimum (4) is above its maximum (1), which is
     # expected to make set_bounds raise.
     spec = dm_env_rpc_pb2.TensorSpec(
         dtype=dm_env_rpc_pb2.DataType.INT32, shape=[2], name='test')
     self.assertRaisesRegex(
         ValueError, 'larger than max',
         tensor_spec_utils.set_bounds, spec, minimum=[0, 4], maximum=[1, 1])
コード例 #13
0
    def test_set_only_max(self):
        # With minimum=None, asserts that min has no 'payload' oneof set
        # while max holds the single value.
        spec = dm_env_rpc_pb2.TensorSpec(
            dtype=dm_env_rpc_pb2.DataType.INT32, name='test')
        tensor_spec_utils.set_bounds(spec, maximum=1, minimum=None)

        min_payload = spec.min.WhichOneof('payload')
        self.assertIsNone(min_payload)
        self.assertEqual([1], spec.max.int32s.array)
コード例 #14
0
 def test_cannot_set_multiple_max_bounds_on_variable_shape(self):
     # The spec's second dimension is -1; a two-element maximum is expected
     # to be rejected with an 'incompatible' error.
     spec = dm_env_rpc_pb2.TensorSpec(
         dtype=dm_env_rpc_pb2.DataType.INT32, shape=[2, -1], name='test')
     self.assertRaisesRegex(
         ValueError, 'incompatible',
         tensor_spec_utils.set_bounds, spec, minimum=1, maximum=[2, 3])
コード例 #15
0
 def test_cannot_set_dissimilar_shape_on_max_bounds(self):
     # A three-element maximum against a spec of shape [2, 2] is expected to
     # be rejected with an 'incompatible' error.
     upper = [1, 2, 3]
     spec = dm_env_rpc_pb2.TensorSpec(
         dtype=dm_env_rpc_pb2.DataType.INT32, shape=[2, 2], name='test')
     with self.assertRaisesRegex(ValueError, 'incompatible'):
         tensor_spec_utils.set_bounds(spec, maximum=upper, minimum=0)
コード例 #16
0
 def test_cannot_set_nonnumeric_bounds(self):
     # Calling set_bounds on a STRING spec is expected to raise, even though
     # both bounds passed here are None.
     string_spec = dm_env_rpc_pb2.TensorSpec(
         dtype=dm_env_rpc_pb2.DataType.STRING, name='test')
     self.assertRaisesRegex(
         ValueError, 'non-numeric',
         tensor_spec_utils.set_bounds, string_spec,
         minimum=None, maximum=None)
コード例 #17
0
    def test_can_set_broadcastable_max_bounds(self):
        # Per-element minimum combined with a scalar maximum on a shape-[2]
        # spec; asserts the maximum is stored as a single element.
        spec = dm_env_rpc_pb2.TensorSpec(
            dtype=dm_env_rpc_pb2.DataType.INT32, shape=[2], name='test')
        tensor_spec_utils.set_bounds(spec, maximum=4, minimum=[2, 3])

        stored_min = spec.min.int32s.array
        stored_max = spec.max.int32s.array
        self.assertEqual([2, 3], stored_min)
        self.assertEqual([4], stored_max)
コード例 #18
0
 def test_infinite_bounds_are_valid_for_floats(self, minimum, maximum):
     # No assertions: passes as long as bounds() accepts the supplied
     # values. `minimum`/`maximum` come from the caller, presumably a
     # parameterized decorator that is not visible here.
     spec = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.DOUBLE)
     spec.min.double = minimum
     spec.max.double = maximum
     tensor_spec_utils.bounds(spec)
コード例 #19
0
 def test_nonnumeric_type_raises_error(self):
     # A max bound on a STRING spec; the error is expected to mention the
     # spec name, 'non-numeric' and 'string', in that order.
     spec = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.STRING)
     spec.max.int32s.array[:] = [1]
     expected_error = 'foo.*non-numeric.*string'
     with self.assertRaisesRegex(ValueError, expected_error):
         tensor_spec_utils.bounds(spec)
コード例 #20
0
 def test_max_mismatches_type_raises_error(self):
     # dtype is UINT32 but max is populated through the int32s field, which
     # bounds() is expected to reject.
     spec = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32)
     spec.max.int32s.array[:] = [1]
     self.assertRaisesRegex(ValueError, 'foo.*uint32.*max.*int32',
                            tensor_spec_utils.bounds, spec)
コード例 #21
0
 def test_no_bounds_gives_arrayspec(self):
     # A UINT32 spec with no min/max set is expected to convert to a
     # specs.Array that keeps the name.
     spec_proto = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32)
     spec_proto.shape[:] = [3]
     converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
     expected = specs.Array(shape=[3], dtype=np.uint32)
     self.assertEqual(expected, converted)
     self.assertEqual('foo', converted.name)
コード例 #22
0
 def test_string_give_string_array(self):
     # A STRING spec is expected to convert to a specs.StringArray that
     # keeps the name.
     spec_proto = dm_env_rpc_pb2.TensorSpec(
         name='string_spec', dtype=dm_env_rpc_pb2.DataType.STRING)
     spec_proto.shape[:] = [1, 2, 3]
     converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)
     expected = specs.StringArray(shape=[1, 2, 3])
     self.assertEqual(expected, converted)
     self.assertEqual('string_spec', converted.name)
コード例 #23
0
 def test_invalid_max_shape(self):
     # Two max values against a (2, 2) shape; bounds() is expected to fail
     # with a reshape error.
     spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.UINT32)
     spec.shape[:] = (2, 2)
     spec.max.uint32s.array[:] = [1, 2]
     reshape_error = 'cannot reshape array of size .* into shape.*'
     with self.assertRaisesRegex(ValueError, reshape_error):
         tensor_spec_utils.bounds(spec)
コード例 #24
0
 def test_max_broadcast(self):
   # A single max value on a (2, 2) UINT32 spec with no min set; asserts
   # that the returned min/max equal arrays of the spec's shape filled with
   # 0 and 1 respectively.
   spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.UINT32)
   spec.shape[:] = (2, 2)
   spec.max.uint32s.array[:] = [1]
   result = tensor_spec_utils.bounds(spec)
   expected_min = np.full(spec.shape, 0)
   expected_max = np.full(spec.shape, 1)
   np.testing.assert_array_equal(expected_min, result.min)
   np.testing.assert_array_equal(expected_max, result.max)
コード例 #25
0
 def test_max_less_than_min_raises_error(self):
     # min (1) is above max (-1); the error is expected to mention the spec
     # name and both values.
     spec = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.INT32)
     spec.min.int32s.array[:] = [1]
     spec.max.int32s.array[:] = [-1]
     self.assertRaisesRegex(ValueError, 'foo.*min 1.*max -1',
                            tensor_spec_utils.bounds, spec)
コード例 #26
0
 def test_broadcast_var_shape(self):
     # Single-element bounds on a spec whose shape is (-1,); bounds() is
     # expected to compare equal to the tuple (-1, 1).
     spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.INT32)
     spec.shape[:] = (-1, )
     spec.max.int32s.array[:] = [1]
     spec.min.int32s.array[:] = [-1]
     self.assertEqual((-1, 1), tensor_spec_utils.bounds(spec))
コード例 #27
0
def _action_spec():
    """Builds the bounded action spec as a dict keyed by integer id.

    Returns:
      A single-entry dict mapping 1 to an INT8 `TensorSpec` named
      `_ACTION_PADDLE`, whose bounds are set from the smallest and largest
      values in `_VALID_ACTIONS`.
    """
    spec = dm_env_rpc_pb2.TensorSpec(
        name=_ACTION_PADDLE,
        dtype=dm_env_rpc_pb2.INT8,
    )
    lowest = np.min(_VALID_ACTIONS)
    highest = np.max(_VALID_ACTIONS)
    tensor_spec_utils.set_bounds(spec, minimum=lowest, maximum=highest)
    return {1: spec}
コード例 #28
0
    def test_unset_min_and_max(self):
        # Bounds are set, then set again with None for both; afterwards
        # neither min nor max should have a 'payload' oneof set.
        spec = dm_env_rpc_pb2.TensorSpec(
            dtype=dm_env_rpc_pb2.DataType.INT32, name='test')
        tensor_spec_utils.set_bounds(spec, maximum=2, minimum=1)
        tensor_spec_utils.set_bounds(spec, maximum=None, minimum=None)

        for bound in (spec.min, spec.max):
            self.assertIsNone(bound.WhichOneof('payload'))
コード例 #29
0
 def test_max_scalar_doesnt_broadcast(self):
     # A single max value on a (2, 2) UINT32 spec with no min set; asserts
     # that the returned min and max compare equal to the scalars 0 and 1.
     spec = dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.DataType.UINT32)
     spec.shape[:] = (2, 2)
     spec.max.uint32s.array[:] = [1]
     result = tensor_spec_utils.bounds(spec)
     self.assertEqual(0, result.min)
     self.assertEqual(1, result.max)
コード例 #30
0
 def test_bounds_wrong_type_gives_error(self):
     # dtype is UINT32 but min is populated through the floats field;
     # conversion is expected to raise an error mentioning 'uint32'.
     spec = dm_env_rpc_pb2.TensorSpec(
         name='foo', dtype=dm_env_rpc_pb2.DataType.UINT32)
     spec.shape[:] = [3]
     spec.min.floats.array[:] = [1.9]
     self.assertRaisesRegex(ValueError, 'uint32',
                            dm_env_utils.tensor_spec_to_dm_env_spec, spec)