def testStepHybrid(self):
    """Steps a dict action through the wrapper and checks what the env got.

    The action spec has a bounded 'discrete' entry (min 1, max 3) and an
    unbounded 'continuous' entry. The wrapper is stepped with a one-hot
    'discrete' value and a float 'continuous' value; the test then inspects
    the action forwarded to the mocked environment.
    """
    observation_spec = array_spec.BoundedArraySpec(
        (2, 3), np.int32, -10, 10)
    discrete_spec = array_spec.BoundedArraySpec((1,), np.int32, 1, 3)
    continuous_spec = array_spec.ArraySpec((2,), np.float32)
    hybrid_spec = {'discrete': discrete_spec, 'continuous': continuous_spec}

    # Mock(wraps=...) delegates to the real env while recording every call.
    wrapped_env = mock.Mock(
        wraps=random_py_environment.RandomPyEnvironment(
            observation_spec, hybrid_spec))
    wrapper = wrappers.OneHotActionWrapper(wrapped_env)
    wrapper.reset()

    wrapper.step({
        'discrete': np.array([[0, 1, 0]]).astype(np.int32),
        'continuous': np.array([0.5, 0.3]).astype(np.float32),
    })
    self.assertTrue(wrapped_env.step.called)

    # First positional argument of the recorded step() call.
    forwarded = wrapped_env.step.call_args[0][0]
    # Expected 'discrete' is 2 -- presumably the hot index (1) offset by the
    # spec minimum (1); confirm against OneHotActionWrapper. 'continuous'
    # is expected to arrive unchanged.
    expected = {
        'discrete': np.array([2]),
        'continuous': np.array([0.5, 0.3]),
    }
    for key in ('discrete', 'continuous'):
        np.testing.assert_array_almost_equal(expected[key], forwarded[key])
def testActionSpec(self):
    """Checks the action spec reported when wrapping a CartPole-v1 env.

    The wrapper's action_spec() must equal a bounded int64 spec of shape
    (2,) with values in [0, 1], named 'one_hot_action_spec'.
    """
    gym_env = gym.spec('CartPole-v1').make()
    wrapper = wrappers.OneHotActionWrapper(gym_wrapper.GymWrapper(gym_env))

    expected = array_spec.BoundedArraySpec(
        shape=(2,),
        dtype=np.int64,
        minimum=0,
        maximum=1,
        name='one_hot_action_spec',
    )
    self.assertEqual(wrapper.action_spec(), expected)
def testStep(self):
    """Steps a single-array action and checks what the wrapped env received.

    The wrapper is stepped with a three-element vector whose largest entry
    sits at index 0; the test asserts that the mocked environment's step()
    was called and that its first positional argument equals 0.
    """
    observation_spec = array_spec.BoundedArraySpec(
        (2, 3), np.int32, -10, 10)
    discrete_spec = array_spec.BoundedArraySpec((1,), np.int32, 0, 2)

    # Mock(wraps=...) delegates to the real env while recording every call.
    inner_env = random_py_environment.RandomPyEnvironment(
        observation_spec, discrete_spec)
    wrapped_env = mock.Mock(wraps=inner_env)
    wrapper = wrappers.OneHotActionWrapper(wrapped_env)
    wrapper.reset()

    scores = np.array([0.5, 0.4, 0.1])
    wrapper.step(scores)
    self.assertTrue(wrapped_env.step.called)

    # First positional argument of the recorded step() call.
    forwarded_action = wrapped_env.step.call_args[0][0]
    np.testing.assert_array_equal(0, forwarded_action)