Ejemplo n.º 1
0
    def test_step(self):
        """Checks that one-hot actions reach the wrapped env as integer indices.

        A batch of two one-hot rows is stepped through the wrapper. The
        underlying (mocked) environment must receive the positions of the
        hot entries, i.e. ``[1, 3]``.
        """
        # Scalar int32 action spec bounded to [0, 4] -> 5 discrete actions.
        scalar_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 4)
        # Mock(wraps=...) forwards calls to the real env while recording them,
        # so the arguments the wrapper passes down can be inspected afterwards.
        spy_env = mock.Mock(wraps=_build_test_env(action_spec=scalar_spec))
        one_hot_wrapper = tf_wrappers.OneHotActionWrapper(spy_env)
        one_hot_wrapper.reset()

        one_hot_batch = tf.constant(
            [[0, 1, 0, 0, 0],
             [0, 0, 0, 1, 0]],
            tf.int32)
        one_hot_wrapper.step(one_hot_batch)

        self.assertTrue(spy_env.step.called)
        # call_args unpacks into (positional_args, keyword_args).
        step_args, _ = spy_env.step.call_args
        self.assertAllEqual([1, 3], step_args[0])
Ejemplo n.º 2
0
 def test_action_spec(self):
     """Checks the one-hot action spec exposed by the wrapper.

     A ``(1,)`` int32 spec bounded to ``[0, 4]`` (5 discrete actions) should
     be reported as a ``(1, 5)`` int32 spec bounded to ``[0, 1]`` named
     ``'one_hot_action_spec'``.
     """
     # The expected spec depends only on constants, so build it first.
     expected_spec = tensor_spec.BoundedTensorSpec(
         shape=(1, 5), dtype=tf.int32, minimum=0, maximum=1,
         name='one_hot_action_spec')

     source_spec = tensor_spec.BoundedTensorSpec((1,), tf.int32, 0, 4)
     one_hot_wrapper = tf_wrappers.OneHotActionWrapper(
         _build_test_env(action_spec=source_spec))

     self.assertEqual(expected_spec, one_hot_wrapper.action_spec())
Ejemplo n.º 3
0
 def test_raises_invalid_action_spec(self):
     """Checks the wrapper rejects action specs with more than one dimension.

     A ``(1, 1)`` spec has two dimensions, so constructing the wrapper must
     raise ``ValueError`` whose message matches 'at most one dimension'.
     """
     action_spec = tensor_spec.BoundedTensorSpec((1, 1), tf.int32, 0, 4)
     # Build the env outside the assertion context so that only the wrapper
     # constructor is expected to raise.
     env = _build_test_env(action_spec=action_spec)
     # assertRaisesRegexp is a deprecated alias removed in Python 3.12;
     # assertRaisesRegex is the supported name.
     with self.assertRaisesRegex(ValueError, 'at most one dimension'):
         tf_wrappers.OneHotActionWrapper(env)