def test_step(self):
  """Checks that one-hot actions reach the wrapped env as integer indices.

  A batch of two one-hot rows (hot at positions 1 and 3) is passed to the
  wrapper's `step`; the wrapped environment's `step` is expected to have been
  called with a first positional argument equal to `[1, 3]`.
  """
  scalar_action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 4)
  base_env = _build_test_env(action_spec=scalar_action_spec)
  # `wraps=` forwards calls to the real environment while recording them.
  spy_env = mock.Mock(wraps=base_env)
  one_hot_env = tf_wrappers.OneHotActionWrapper(spy_env)
  one_hot_env.reset()

  one_hot_actions = tf.constant(
      [[0, 1, 0, 0, 0], [0, 0, 0, 1, 0]], tf.int32)
  one_hot_env.step(one_hot_actions)

  self.assertTrue(spy_env.step.called)
  # call_args[0] holds the positional args of the most recent call.
  forwarded_action = spy_env.step.call_args[0][0]
  self.assertAllEqual([1, 3], forwarded_action)
def test_action_spec(self):
  """Checks the action spec exposed by the one-hot wrapper.

  For a wrapped int32 spec of shape `(1,)` bounded to `[0, 4]`, the wrapper is
  expected to report an int32 spec of shape `(1, 5)` bounded to `[0, 1]` and
  named `'one_hot_action_spec'`.
  """
  discrete_spec = tensor_spec.BoundedTensorSpec((1,), tf.int32, 0, 4)
  one_hot_env = tf_wrappers.OneHotActionWrapper(
      _build_test_env(action_spec=discrete_spec))

  one_hot_spec = tensor_spec.BoundedTensorSpec(
      shape=(1, 5),
      dtype=tf.int32,
      minimum=0,
      maximum=1,
      name='one_hot_action_spec')

  self.assertEqual(one_hot_spec, one_hot_env.action_spec())
def test_raises_invalid_action_spec(self):
  """Checks that wrapping an env with a rank-2 action spec is rejected.

  Constructing `OneHotActionWrapper` around an environment whose action spec
  has shape `(1, 1)` is expected to raise a `ValueError` whose message matches
  'at most one dimension'.
  """
  action_spec = tensor_spec.BoundedTensorSpec((1, 1), tf.int32, 0, 4)
  # Build the environment outside the assertion context so that only an error
  # raised by the wrapper's constructor can satisfy the check.
  env = _build_test_env(action_spec=action_spec)

  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from `unittest.TestCase` in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'at most one dimension'):
    tf_wrappers.OneHotActionWrapper(env)