def test_discrete_space_spec(self):
    """A Discrete(100) space yields an int64 FixedLenFeature of shape [1]."""
    feature_spec = gym_spaces_utils.gym_space_spec(Discrete(100))
    expected_shape = [1]
    self.assertIsInstance(feature_spec, tf.FixedLenFeature)
    self.assertEqual(feature_spec.dtype, tf.int64)
    self.assertListEqual(list(feature_spec.shape), expected_shape)
 def test_box_space_spec(self):
     """A float32 Box of shape [5, 6] yields a float32 FixedLenFeature of the same shape."""
     space = Box(low=0, high=10, shape=[5, 6], dtype=np.float32)
     feature_spec = gym_spaces_utils.gym_space_spec(space)
     expected_shape = [5, 6]
     self.assertIsInstance(feature_spec, tf.FixedLenFeature)
     self.assertEqual(feature_spec.dtype, tf.float32)
     self.assertListEqual(list(feature_spec.shape), expected_shape)
# Example #3
# 0
 def action_spec(self):
     """The spec for reading an action stored in a tf.Example.

     The spec is derived from ``self.action_space`` by
     ``gym_spaces_utils.gym_space_spec``.
     """
     return gym_spaces_utils.gym_space_spec(self.action_space)