def test_automatic_reset_after_create(self):
    """Stepping a freshly built env without calling reset() yields a FIRST step."""
    env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())
    step_result = env.step(np.array(0, dtype=np.int64))
    self.assertTrue(step_result.is_first())
def test_default_batch_properties(self):
    """An unbatched gym env reports batch_size 1, and a base wrapper mirrors it."""
    base_env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())
    self.assertFalse(base_env.batched)
    self.assertEqual(base_env.batch_size, 1)

    wrapped = alf_wrappers.AlfEnvironmentBaseWrapper(base_env)
    # The wrapper must forward both batch properties unchanged.
    for attr in ('batched', 'batch_size'):
        self.assertEqual(getattr(wrapped, attr), getattr(base_env, attr))
def test_get_info(self):
    """get_info() is None until the first step, then exposes the step's info dict."""
    cartpole_env = gym.spec('CartPole-v1').make()
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    # Nothing has been stepped yet, so there is no info to report.
    self.assertIsNone(env.get_info())
    env.reset()
    # reset() alone does not populate the info.
    self.assertIsNone(env.get_info())
    action = np.array(0, dtype=np.int64)
    env.step(action)
    # After a step the info dict from the env is available (empty here).
    self.assertEqual({}, env.get_info())
def test_wrapped_cartpole_reset(self):
    """reset() returns a FIRST step: zero reward, unit discount, float32 (4,) obs."""
    env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())
    ts = env.reset()

    self.assertTrue(ts.is_first())
    self.assertEqual(0.0, ts.reward)
    self.assertEqual(1.0, ts.discount)

    obs = ts.observation
    self.assertEqual((4, ), obs.shape)
    self.assertEqual("float32", str(obs.dtype))
def test_extra_env_methods_work(self):
    """get_info() is forwarded through an ALF TimeLimit wrapper."""
    # gym.make (unlike gym.spec(...).make) also attaches gym's own TimeLimit.
    cartpole_env = gym.make('CartPole-v1')
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    env = alf_wrappers.TimeLimit(env, 2)
    # No step has been taken yet, so there is no info to report.
    self.assertIsNone(env.get_info())
    env.reset()
    action = np.array(0, dtype=np.int64)
    env.step(action)
    self.assertEqual({}, env.get_info())
def test_wrapped_cartpole_transition(self):
    """A step after reset() is a MID step with a reward, discount 1 and (4,) obs."""
    cartpole_env = gym.spec('CartPole-v1').make()
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    env.reset()
    action = np.array(0, dtype=np.int64)
    transition_time_step = env.step(action)
    self.assertTrue(transition_time_step.is_mid())
    # Identity check: ``None != array`` would be evaluated elementwise and can
    # be ambiguous for non-scalar arrays.
    self.assertIsNotNone(transition_time_step.reward)
    self.assertEqual(1.0, transition_time_step.discount)
    self.assertEqual((4, ), transition_time_step.observation.shape)
def test_method_propagation(self):
    """render/seed/close on the wrapper are delegated to the wrapped gym env."""
    cartpole_env = gym.spec('CartPole-v1').make()
    # Swap the real methods for mocks so each delegated call can be checked.
    for name in ('render', 'seed', 'close'):
        setattr(cartpole_env, name, mock.MagicMock())
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)

    env.render()
    cartpole_env.render.assert_called_once()

    env.seed(0)
    cartpole_env.seed.assert_called_once_with(0)

    env.close()
    cartpole_env.close.assert_called_once()
def test_automatic_reset_after_done_not_using_reset_directly(self):
    """Once an episode ends, the next step() auto-resets without reset()."""
    env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())
    repeated_action = np.array(1, dtype=np.int64)

    # Repeat one action until the episode reaches a LAST step.
    time_step = None
    while time_step is None or not time_step.is_last():
        time_step = env.step(repeated_action)
    self.assertTrue(time_step.is_last())

    next_step = env.step(np.array(0, dtype=np.int64))
    self.assertTrue(next_step.is_first())
def test_wrapped_cartpole_final(self):
    """A naturally terminating episode ends with a LAST step of discount 0."""
    cartpole_env = gym.spec('CartPole-v1').make()
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    time_step = env.reset()
    action = np.array(1, dtype=np.int64)
    # Repeat one action until the episode terminates.
    while not time_step.is_last():
        time_step = env.step(action)
    self.assertTrue(time_step.is_last())
    # Identity check avoids an elementwise ``!=`` against None.
    self.assertIsNotNone(time_step.reward)
    self.assertEqual(0.0, time_step.discount)
    self.assertEqual((4, ), time_step.observation.shape)
def test_limit_duration_stops_after_duration(self):
    """TimeLimit(env, 2) ends the episode after two steps without zeroing discount."""
    cartpole_env = gym.make('CartPole-v1')
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    env = alf_wrappers.TimeLimit(env, 2)
    env.reset()
    action = np.array(0, dtype=np.int64)
    env.step(action)
    time_step = env.step(action)
    self.assertTrue(time_step.is_last())
    # Identity check avoids an elementwise ``!=`` against None.
    self.assertIsNotNone(time_step.discount)
    # Unlike a natural termination (discount 0), hitting the time limit is
    # expected to leave the discount non-zero.
    self.assertNotEqual(0.0, time_step.discount)
def test_limit_duration_wrapped_env_forwards_calls(self):
    """TimeLimit forwards action_spec() and observation_spec() to the inner env."""
    env = alf_wrappers.TimeLimit(
        alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make()), 10)

    act_spec = env.action_spec()
    self.assertEqual((), act_spec.shape)
    self.assertEqual(0, act_spec.minimum)
    self.assertEqual(1, act_spec.maximum)

    obs_spec = env.observation_spec()
    self.assertEqual((4, ), obs_spec.shape)
    f32_max = np.finfo(np.float32).max
    bound = np.array([4.8, f32_max, 2 / 15.0 * math.pi, f32_max])
    np.testing.assert_array_almost_equal(-bound, obs_spec.minimum)
    np.testing.assert_array_almost_equal(bound, obs_spec.maximum)
def test_wrapped_cartpole_specs(self):
    """The wrapper exposes CartPole's discrete action and float32 observation specs."""
    # Note we use spec.make on gym envs to avoid getting a TimeLimit wrapper on
    # the environment.
    env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())

    act_spec = env.action_spec()
    self.assertEqual((), act_spec.shape)
    self.assertEqual(0, act_spec.minimum)
    self.assertEqual(1, act_spec.maximum)

    obs_spec = env.observation_spec()
    self.assertEqual((4, ), obs_spec.shape)
    self.assertEqual(torch.float32, obs_spec.dtype)
    f32_max = np.finfo(np.float32).max
    bound = np.array([4.8, f32_max, 2 / 15.0 * math.pi, f32_max])
    np.testing.assert_array_almost_equal(-bound, obs_spec.minimum)
    np.testing.assert_array_almost_equal(bound, obs_spec.maximum)
def test_automatic_reset(self):
    """With a 2-step TimeLimit, plain step() calls cycle FIRST -> MID -> LAST."""
    env = alf_wrappers.TimeLimit(
        alf_gym_wrapper.AlfGymWrapper(gym.make('CartPole-v1')), 2)
    action = np.array(0, dtype=np.int64)
    # Two consecutive episodes, each driven only by step() (auto-reset).
    for _episode in range(2):
        self.assertTrue(env.step(action).is_first())
        self.assertTrue(env.step(action).is_mid())
        self.assertTrue(env.step(action).is_last())
def test_duration_applied_after_episode_terminates_early(self):
    """A new duration takes effect for the episode after an early termination."""
    env = alf_wrappers.TimeLimit(
        alf_gym_wrapper.AlfGymWrapper(gym.make('CartPole-v1')), 10000)

    # Episode 1: step with a fixed action until termination occurs.
    first_action = np.array(1, dtype=np.int64)
    time_step = None
    while time_step is None or not time_step.is_last():
        time_step = env.step(first_action)
    self.assertTrue(time_step.is_last())

    # Shrink the limit via the wrapper's private attribute.
    env._duration = 2

    # Episode 2: the short duration now ends the episode after two steps.
    second_action = np.array(0, dtype=np.int64)
    for predicate in ('is_first', 'is_mid', 'is_last'):
        self.assertTrue(getattr(env.step(second_action), predicate)())
def wrap_env(gym_env,
             env_id=None,
             discount=1.0,
             max_episode_steps=0,
             gym_env_wrappers=(),
             time_limit_wrapper=alf_wrappers.TimeLimit,
             normalize_action=True,
             clip_action=True,
             alf_env_wrappers=(),
             image_channel_first=True,
             auto_reset=True):
    """Wraps given gym environment with AlfGymWrapper.

    When ``max_episode_steps`` is positive, ``time_limit_wrapper``
    (``alf_wrappers.TimeLimit`` by default) is used to limit episode lengths,
    typically to the default benchmark defined by the registered environment.

    Also note that all gym wrappers assume images are 'channel_last' by
    default, while PyTorch only supports 'channel_first' image inputs. To
    enable this transpose, ``image_channel_first`` is set as True by default.

    Wrappers are applied in this order (innermost first):
    ``gym_env_wrappers`` -> ``gym_wrappers.ImageChannelFirst`` (if
    ``image_channel_first``) -> ``gym_wrappers.NormalizedAction`` (if
    ``normalize_action``) -> ``gym_wrappers.ContinuousActionClip`` (if
    ``clip_action``) -> ``AlfGymWrapper`` -> ``time_limit_wrapper`` (if
    ``max_episode_steps > 0``) -> ``alf_env_wrappers``.

    Args:
        gym_env (gym.Env): An instance of OpenAI gym environment.
        env_id (int): (optional) ID of the environment.
        discount (float): Discount to use for the environment.
        max_episode_steps (int): Used to create a TimeLimit wrapper. No limit
            is applied if set to 0 or ``None``. Usually set to
            ``gym_spec.max_episode_steps`` as done in ``load``.
        gym_env_wrappers (Iterable): Iterable with references to gym_wrappers
            classes to use directly on the gym environment.
        time_limit_wrapper (Callable): Wrapper that accepts
            ``(env, max_episode_steps)`` params to enforce a TimeLimit.
            Usually this should be left as the default,
            ``alf_wrappers.TimeLimit``.
        normalize_action (bool): if True, will scale continuous actions to
            ``[-1, 1]`` to be better used by algorithms that compute entropies.
        clip_action (bool): If True, will clip continuous action to its bound
            specified by ``action_spec``. If ``normalize_action`` is also
            ``True``, this clipping happens after the normalization (i.e.,
            clips to ``[-1, 1]``).
        alf_env_wrappers (Iterable): Iterable with references to alf_wrappers
            classes to use on the ALF environment.
        image_channel_first (bool): whether to transpose image channels to the
            first dimension. PyTorch only supports channel_first image inputs.
        auto_reset (bool): If True (default), reset the environment
            automatically after a terminal state is reached.

    Returns:
        An AlfEnvironment instance.
    """
    for wrapper in gym_env_wrappers:
        gym_env = wrapper(gym_env)

    # To apply channel_first transpose on gym (py) env
    if image_channel_first:
        gym_env = gym_wrappers.ImageChannelFirst(gym_env)

    if normalize_action:
        # normalize continuous actions to [-1, 1]
        gym_env = gym_wrappers.NormalizedAction(gym_env)

    if clip_action:
        # clip continuous actions according to gym_env.action_space
        gym_env = gym_wrappers.ContinuousActionClip(gym_env)

    env = alf_gym_wrapper.AlfGymWrapper(
        gym_env=gym_env,
        env_id=env_id,
        discount=discount,
        auto_reset=auto_reset,
    )

    # Treat None like 0 (no limit) so an unset limit can be passed straight
    # through instead of raising TypeError on ``None > 0``.
    if max_episode_steps is not None and max_episode_steps > 0:
        env = time_limit_wrapper(env, max_episode_steps)

    for wrapper in alf_env_wrappers:
        env = wrapper(env)

    return env
def test_obs_dtype(self):
    """The observation returned by reset() has the dtype declared in its spec."""
    env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())
    first_step = env.reset()
    expected_dtype = torch_dtype_to_str(env.observation_spec().dtype)
    actual_dtype = str(first_step.observation.dtype)
    self.assertEqual(expected_dtype, actual_dtype)