Beispiel #1
0
  def testStepHybrid(self):
    """Checks one-hot decoding of a dict action with discrete + continuous parts.

    The discrete entry is handed to the wrapper as a one-hot row and the
    wrapped environment is expected to receive the value 2 (one-hot position
    1 plus the spec minimum of 1). The continuous entry is expected to reach
    the wrapped environment unchanged.
    """
    observation_spec = array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10)
    hybrid_action_spec = dict(
        discrete=array_spec.BoundedArraySpec((1,), np.int32, 1, 3),
        continuous=array_spec.ArraySpec((2,), np.float32),
    )

    # Wrap the real environment in a Mock so the forwarded action can be
    # inspected via `call_args` while the real step still runs.
    inner_env = random_py_environment.RandomPyEnvironment(
        observation_spec, hybrid_action_spec)
    mock_env = mock.Mock(wraps=inner_env)
    wrapper = wrappers.OneHotActionWrapper(mock_env)
    wrapper.reset()

    one_hot_discrete = np.array([[0, 1, 0]], dtype=np.int32)
    continuous_part = np.array([0.5, 0.3], dtype=np.float32)
    wrapper.step({'discrete': one_hot_discrete, 'continuous': continuous_part})
    self.assertTrue(mock_env.step.called)

    # First positional argument of the most recent step() call on the mock.
    forwarded = mock_env.step.call_args[0][0]
    np.testing.assert_array_almost_equal(
        np.array([2]), forwarded['discrete'])
    np.testing.assert_array_almost_equal(
        np.array([0.5, 0.3]), forwarded['continuous'])
Beispiel #2
0
 def testActionSpec(self):
     """Checks the action spec the wrapper exposes for CartPole-v1.

     The expected spec is a shape-(2,) int64 vector bounded to [0, 1] and
     named 'one_hot_action_spec'.
     """
     gym_env = gym_wrapper.GymWrapper(gym.spec('CartPole-v1').make())
     wrapped = wrappers.OneHotActionWrapper(gym_env)
     expected = array_spec.BoundedArraySpec(
         shape=(2,),
         dtype=np.int64,
         minimum=0,
         maximum=1,
         name='one_hot_action_spec',
     )
     self.assertEqual(wrapped.action_spec(), expected)
Beispiel #3
0
    def testStep(self):
        """Checks what the wrapper forwards for a scalar bounded action spec.

        The float vector [0.5, 0.4, 0.1] is passed to the wrapper's step().
        The wrapped environment is expected to receive 0 (presumably the
        position of the largest entry; confirm against OneHotActionWrapper).
        """
        observation_spec = array_spec.BoundedArraySpec(
            (2, 3), np.int32, -10, 10)
        scalar_action_spec = array_spec.BoundedArraySpec((1,), np.int32, 0, 2)

        # Mock wraps the real environment so the forwarded action is recorded.
        inner_env = random_py_environment.RandomPyEnvironment(
            observation_spec, scalar_action_spec)
        mock_env = mock.Mock(wraps=inner_env)
        wrapper = wrappers.OneHotActionWrapper(mock_env)
        wrapper.reset()

        wrapper.step(np.array([0.5, 0.4, 0.1]))
        self.assertTrue(mock_env.step.called)
        forwarded_action = mock_env.step.call_args[0][0]
        np.testing.assert_array_equal(0, forwarded_action)