示例#1
0
    def test_extra_env_methods_work(self):
        """Extra env methods such as ``get_info`` are callable through TimeLimit.

        Before any reset/step, ``get_info()`` must return None; after a reset
        and one step on CartPole it must return an empty dict. (Presumably the
        call is forwarded by TimeLimit to the wrapped env — confirm against
        ``alf_wrappers.TimeLimit``.)
        """
        cartpole_env = gym.make('CartPole-v1')
        env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
        env = alf_wrappers.TimeLimit(env, 2)

        # Identity check: ``assertIsNone`` is stricter than ``== None`` and
        # reports the actual value on failure.
        self.assertIsNone(env.get_info())
        env.reset()
        action = np.array(0, dtype=np.int64)
        env.step(action)
        self.assertEqual({}, env.get_info())
示例#2
0
    def test_limit_duration_stops_after_duration(self):
        """The episode is cut off once the step limit is reached.

        With a limit of 2, the time step returned by the second step after
        ``reset()`` must be LAST. Its discount must be present and non-zero:
        hitting the time limit should not be reported as a zero-discount
        transition (presumably reserved for genuine termination).
        """
        cartpole_env = gym.make('CartPole-v1')
        env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
        env = alf_wrappers.TimeLimit(env, 2)

        env.reset()
        action = np.array(0, dtype=np.int64)
        env.step(action)
        time_step = env.step(action)

        self.assertTrue(time_step.is_last())
        self.assertIsNotNone(time_step.discount)
        # A truncated (time-limited) episode must keep a non-zero discount.
        self.assertNotEqual(0.0, time_step.discount)
示例#3
0
    def test_limit_duration_wrapped_env_forwards_calls(self):
        """TimeLimit exposes the wrapped CartPole env's action/observation specs.

        The action spec must be a scalar with bounds [0, 1]. The observation
        spec must be shape (4,) with symmetric bounds ``[-high, high]``.
        """
        # Same construction as the other tests in this file. ``gym.make`` is
        # also available across gym versions, unlike ``EnvSpec.make()``.
        cartpole_env = gym.make('CartPole-v1')
        env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
        env = alf_wrappers.TimeLimit(env, 10)

        action_spec = env.action_spec()
        self.assertEqual((), action_spec.shape)
        self.assertEqual(0, action_spec.minimum)
        self.assertEqual(1, action_spec.maximum)

        observation_spec = env.observation_spec()
        self.assertEqual((4, ), observation_spec.shape)
        # Expected upper bounds; entries 1 and 3 are float32-max (effectively
        # unbounded) and entry 2 is 2*pi/15 rad. Presumably this mirrors gym's
        # CartPole definition (x, x_dot, theta, theta_dot) -- verify there.
        high = np.array([
            4.8,
            np.finfo(np.float32).max, 2 / 15.0 * math.pi,
            np.finfo(np.float32).max
        ])
        np.testing.assert_array_almost_equal(-high, observation_spec.minimum)
        np.testing.assert_array_almost_equal(high, observation_spec.maximum)
示例#4
0
    def test_automatic_reset(self):
        """Stepping without an explicit reset starts new episodes on its own.

        No ``reset()`` is called. The very first ``step`` and the step right
        after a LAST time step must both yield a FIRST time step. With a limit
        of 2, every episode therefore runs FIRST -> MID -> LAST, and this is
        checked for two consecutive episodes.
        """
        limited = alf_wrappers.TimeLimit(
            alf_gym_wrapper.AlfGymWrapper(gym.make('CartPole-v1')), 2)
        push_left = np.array(0, dtype=np.int64)

        for _episode in range(2):
            opening = limited.step(push_left)
            self.assertTrue(opening.is_first())
            middle = limited.step(push_left)
            self.assertTrue(middle.is_mid())
            closing = limited.step(push_left)
            self.assertTrue(closing.is_last())
示例#5
0
    def test_duration_applied_after_episode_terminates_early(self):
        """A new step limit applies to the episode after an early termination.

        Episode 1 runs under a very large limit (10000) and is stepped with
        action 1 until it reports LAST. The limit is then lowered to 2 through
        the wrapper's private ``_duration`` attribute. Episode 2 must then be
        cut after two steps (FIRST -> MID -> LAST).
        """
        limited = alf_wrappers.TimeLimit(
            alf_gym_wrapper.AlfGymWrapper(gym.make('CartPole-v1')), 10000)

        # Episode 1: keep applying action 1 until the episode reports LAST
        # (expected to happen long before the 10000-step limit).
        push_right = np.array(1, dtype=np.int64)
        current = limited.step(push_right)
        while not current.is_last():
            current = limited.step(push_right)
        self.assertTrue(current.is_last())

        # Tighten the limit directly on the wrapper's internal state.
        limited._duration = 2

        # Episode 2: the short limit must produce FIRST, MID, LAST in order.
        push_left = np.array(0, dtype=np.int64)
        for predicate_name in ('is_first', 'is_mid', 'is_last'):
            outcome = limited.step(push_left)
            self.assertTrue(getattr(outcome, predicate_name)())