def test_extra_env_methods_work(self):
    """Check that ``get_info`` is reachable through the ``TimeLimit`` wrapper.

    Before any reset the wrapped env reports ``None`` as its info; after a
    reset and a single step the info is an empty dict.
    """
    cartpole_env = gym.make('CartPole-v1')
    # Release the gym environment even if one of the assertions fails.
    self.addCleanup(cartpole_env.close)
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    env = alf_wrappers.TimeLimit(env, 2)

    self.assertIsNone(env.get_info())

    env.reset()
    action = np.array(0, dtype=np.int64)
    env.step(action)
    self.assertEqual({}, env.get_info())
def test_limit_duration_stops_after_duration(self):
    """``TimeLimit`` ends the episode after ``duration`` steps.

    With a duration of 2, the second step after ``reset()`` must be the last
    one. Its discount must be present and non-zero, presumably because being
    cut off by the time limit is not a true terminal state.
    """
    cartpole_env = gym.make('CartPole-v1')
    # Release the gym environment even if one of the assertions fails.
    self.addCleanup(cartpole_env.close)
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    env = alf_wrappers.TimeLimit(env, 2)

    env.reset()
    action = np.array(0, dtype=np.int64)
    env.step(action)
    time_step = env.step(action)

    self.assertTrue(time_step.is_last())
    self.assertIsNotNone(time_step.discount)
    self.assertNotEqual(0.0, time_step.discount)
def test_limit_duration_wrapped_env_forwards_calls(self):
    """``TimeLimit`` exposes the wrapped env's action and observation specs.

    Uses ``gym.make`` like the other tests in this suite. ``EnvSpec.make``
    is not available in newer gym releases.
    """
    cartpole_env = gym.make('CartPole-v1')
    # Release the gym environment even if one of the assertions fails.
    self.addCleanup(cartpole_env.close)
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    env = alf_wrappers.TimeLimit(env, 10)

    # Scalar discrete action with two choices: 0 and 1.
    action_spec = env.action_spec()
    self.assertEqual((), action_spec.shape)
    self.assertEqual(0, action_spec.minimum)
    self.assertEqual(1, action_spec.maximum)

    observation_spec = env.observation_spec()
    self.assertEqual((4, ), observation_spec.shape)
    # Expected symmetric bounds of CartPole-v1's 4-dim observation box.
    high = np.array([
        4.8,
        np.finfo(np.float32).max,
        2 / 15.0 * math.pi,
        np.finfo(np.float32).max,
    ])
    np.testing.assert_array_almost_equal(-high, observation_spec.minimum)
    np.testing.assert_array_almost_equal(high, observation_spec.maximum)
def test_automatic_reset(self):
    """Stepping past an episode boundary starts the next episode by itself.

    No explicit ``reset()`` is issued. With a duration of 2, each of two
    consecutive episodes must produce exactly first -> mid -> last.
    """
    env = alf_wrappers.TimeLimit(
        alf_gym_wrapper.AlfGymWrapper(gym.make('CartPole-v1')), 2)
    action = np.array(0, dtype=np.int64)

    for _episode in range(2):
        self.assertTrue(env.step(action).is_first())
        self.assertTrue(env.step(action).is_mid())
        self.assertTrue(env.step(action).is_last())
def test_duration_applied_after_episode_terminates_early(self):
    """The step limit still applies to an episode that follows an early end.

    The first episode runs under a very large limit until the env finishes
    it on its own. The limit is then shrunk to 2, and the following episode
    must be cut off after first -> mid -> last.
    """
    env = alf_wrappers.TimeLimit(
        alf_gym_wrapper.AlfGymWrapper(gym.make('CartPole-v1')), 10000)

    # Episode 1: repeat action 1 until the episode reports its last step.
    action_one = np.array(1, dtype=np.int64)
    while True:
        time_step = env.step(action_one)
        if time_step.is_last():
            break
    self.assertTrue(time_step.is_last())

    # Tighten the limit through the wrapper's private attribute.
    env._duration = 2

    # Episode 2: the short limit must truncate it.
    action_zero = np.array(0, dtype=np.int64)
    self.assertTrue(env.step(action_zero).is_first())
    self.assertTrue(env.step(action_zero).is_mid())
    self.assertTrue(env.step(action_zero).is_last())