Пример #1
0
    def test_automatic_reset_after_create(self):
        """The first ``step`` on a new wrapper returns a FIRST time step."""
        env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())

        # No explicit reset() is issued before stepping.
        initial_step = env.step(np.array(0, dtype=np.int64))
        self.assertTrue(initial_step.is_first())
Пример #2
0
 def test_default_batch_properties(self):
     """Check default batch properties and that the base wrapper mirrors them."""
     base_env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())
     self.assertFalse(base_env.batched)
     self.assertEqual(base_env.batch_size, 1)

     # The generic wrapper must report the same values as the wrapped env.
     wrapped = alf_wrappers.AlfEnvironmentBaseWrapper(base_env)
     for attr in ('batched', 'batch_size'):
         self.assertEqual(getattr(wrapped, attr), getattr(base_env, attr))
Пример #3
0
 def test_get_info(self):
     """``get_info`` is None until the first step, then an empty dict.

     The info is checked right after construction, after ``reset`` and
     after a single ``step``.
     """
     cartpole_env = gym.spec('CartPole-v1').make()
     env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
     # Identity check: an object that merely compares equal to None must not
     # satisfy the assertion.
     self.assertIsNone(env.get_info())
     env.reset()
     self.assertIsNone(env.get_info())
     action = np.array(0, dtype=np.int64)
     env.step(action)
     self.assertEqual({}, env.get_info())
Пример #4
0
    def test_wrapped_cartpole_reset(self):
        """Check the time step returned by ``reset`` on a wrapped CartPole."""
        env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())
        initial = env.reset()

        self.assertTrue(initial.is_first())
        self.assertEqual(0.0, initial.reward)
        self.assertEqual(1.0, initial.discount)

        obs = initial.observation
        self.assertEqual((4, ), obs.shape)
        self.assertEqual("float32", str(obs.dtype))
Пример #5
0
    def test_extra_env_methods_work(self):
        """``get_info`` keeps working through a ``TimeLimit`` wrapper.

        The info is None before stepping and an empty dict after one step.
        """
        cartpole_env = gym.make('CartPole-v1')
        env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
        env = alf_wrappers.TimeLimit(env, 2)

        # Identity check rather than ``== None``.
        self.assertIsNone(env.get_info())
        env.reset()
        action = np.array(0, dtype=np.int64)
        env.step(action)
        self.assertEqual({}, env.get_info())
Пример #6
0
    def test_wrapped_cartpole_transition(self):
        """Check the MID time step produced by one step after ``reset``."""
        cartpole_env = gym.spec('CartPole-v1').make()
        env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
        env.reset()
        action = np.array(0, dtype=np.int64)
        transition_time_step = env.step(action)

        self.assertTrue(transition_time_step.is_mid())
        # Identity check: ``!=`` against an array-like reward would rely on
        # the truthiness of an element-wise comparison.
        self.assertIsNotNone(transition_time_step.reward)
        self.assertEqual(1.0, transition_time_step.discount)
        self.assertEqual((4, ), transition_time_step.observation.shape)
Пример #7
0
 def test_method_propagation(self):
     """``render``, ``seed`` and ``close`` on the wrapper reach the gym env."""
     gym_env = gym.spec('CartPole-v1').make()
     # Swap the forwarded methods for mocks so their calls can be counted.
     gym_env.render = mock.MagicMock()
     gym_env.seed = mock.MagicMock()
     gym_env.close = mock.MagicMock()
     wrapper = alf_gym_wrapper.AlfGymWrapper(gym_env)

     wrapper.render()
     self.assertEqual(1, gym_env.render.call_count)

     wrapper.seed(0)
     self.assertEqual(1, gym_env.seed.call_count)
     gym_env.seed.assert_called_with(0)

     wrapper.close()
     self.assertEqual(1, gym_env.close.call_count)
Пример #8
0
    def test_automatic_reset_after_done_not_using_reset_directly(self):
        """Stepping past an episode end starts a new one without ``reset``."""
        env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())

        # Repeat the same action until the episode reports its last step.
        constant_action = np.array(1, dtype=np.int64)
        while True:
            current = env.step(constant_action)
            if current.is_last():
                break

        self.assertTrue(current.is_last())
        following = env.step(np.array(0, dtype=np.int64))
        self.assertTrue(following.is_first())
Пример #9
0
    def test_wrapped_cartpole_final(self):
        """Check the LAST time step reached by repeating a constant action."""
        cartpole_env = gym.spec('CartPole-v1').make()
        env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
        time_step = env.reset()

        action = np.array(1, dtype=np.int64)
        while not time_step.is_last():
            time_step = env.step(action)

        self.assertTrue(time_step.is_last())
        # Identity check: ``!=`` against an array-like reward would rely on
        # the truthiness of an element-wise comparison.
        self.assertIsNotNone(time_step.reward)
        self.assertEqual(0.0, time_step.discount)
        self.assertEqual((4, ), time_step.observation.shape)
Пример #10
0
    def test_limit_duration_stops_after_duration(self):
        """The second step under a 2-step ``TimeLimit`` is LAST.

        The discount on that step must be set and non-zero.
        """
        cartpole_env = gym.make('CartPole-v1')
        env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
        env = alf_wrappers.TimeLimit(env, 2)

        env.reset()
        action = np.array(0, dtype=np.int64)
        env.step(action)
        time_step = env.step(action)

        self.assertTrue(time_step.is_last())
        # Identity check rather than ``!= None``.
        self.assertIsNotNone(time_step.discount)
        self.assertNotEqual(0.0, time_step.discount)
Пример #11
0
    def test_limit_duration_wrapped_env_forwards_calls(self):
        """Specs queried through ``TimeLimit`` have the expected values."""
        env = alf_wrappers.TimeLimit(
            alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make()), 10)

        act_spec = env.action_spec()
        self.assertEqual((), act_spec.shape)
        self.assertEqual(0, act_spec.minimum)
        self.assertEqual(1, act_spec.maximum)

        obs_spec = env.observation_spec()
        self.assertEqual((4, ), obs_spec.shape)

        # Expected upper bounds; the lower bounds are their negation.
        float32_max = np.finfo(np.float32).max
        upper = np.array([4.8, float32_max, 2 / 15.0 * math.pi, float32_max])
        np.testing.assert_array_almost_equal(-upper, obs_spec.minimum)
        np.testing.assert_array_almost_equal(upper, obs_spec.maximum)
Пример #12
0
    def test_wrapped_cartpole_specs(self):
        """Check the action and observation specs of a wrapped CartPole."""
        # spec.make is used on the gym env so that it does not come with a
        # TimeLimit wrapper.
        env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())

        act_spec = env.action_spec()
        self.assertEqual((), act_spec.shape)
        self.assertEqual(0, act_spec.minimum)
        self.assertEqual(1, act_spec.maximum)

        obs_spec = env.observation_spec()
        self.assertEqual((4, ), obs_spec.shape)
        self.assertEqual(torch.float32, obs_spec.dtype)

        # Expected upper bounds; the lower bounds are their negation.
        float32_max = np.finfo(np.float32).max
        upper = np.array([4.8, float32_max, 2 / 15.0 * math.pi, float32_max])
        np.testing.assert_array_almost_equal(-upper, obs_spec.minimum)
        np.testing.assert_array_almost_equal(upper, obs_spec.maximum)
Пример #13
0
    def test_automatic_reset(self):
        """Under a 2-step ``TimeLimit``, steps cycle FIRST, MID, LAST."""
        env = alf_wrappers.TimeLimit(
            alf_gym_wrapper.AlfGymWrapper(gym.make('CartPole-v1')), 2)
        action = np.array(0, dtype=np.int64)

        # Two consecutive episodes, driven by step() alone (no reset()).
        for _ in range(2):
            self.assertTrue(env.step(action).is_first())
            self.assertTrue(env.step(action).is_mid())
            self.assertTrue(env.step(action).is_last())
Пример #14
0
    def test_duration_applied_after_episode_terminates_early(self):
        """A shortened duration applies to the episode after an early end."""
        env = alf_wrappers.TimeLimit(
            alf_gym_wrapper.AlfGymWrapper(gym.make('CartPole-v1')), 10000)

        # Episode 1: repeat one action until the episode reports its end.
        action = np.array(1, dtype=np.int64)
        while True:
            time_step = env.step(action)
            if time_step.is_last():
                break
        self.assertTrue(time_step.is_last())

        # Shrink the limit by overwriting the wrapper's attribute directly.
        env._duration = 2

        # Episode 2: the short duration makes the third step the last one.
        action = np.array(0, dtype=np.int64)
        for step_type_check in ('is_first', 'is_mid', 'is_last'):
            self.assertTrue(getattr(env.step(action), step_type_check)())
Пример #15
0
def wrap_env(gym_env,
             env_id=None,
             discount=1.0,
             max_episode_steps=0,
             gym_env_wrappers=(),
             time_limit_wrapper=alf_wrappers.TimeLimit,
             normalize_action=True,
             clip_action=True,
             alf_env_wrappers=(),
             image_channel_first=True,
             auto_reset=True):
    """Wraps given gym environment with AlfGymWrapper.

    Note that by default a TimeLimit wrapper is used to limit episode lengths
    to the default benchmarks defined by the registered environments.

    Also note that all gym wrappers assume images are 'channel_last' by default,
    while PyTorch only supports 'channel_first' image inputs. To enable this
    transpose, 'image_channel_first' is set as True by default.
    ``gym_wrappers.ImageChannelFirst`` is applied after all gym_env_wrappers
    and before the AlfGymWrapper.

    The wrappers are applied in this order: ``gym_env_wrappers``,
    ``ImageChannelFirst``, ``NormalizedAction``, ``ContinuousActionClip``,
    ``AlfGymWrapper``, ``time_limit_wrapper`` and finally ``alf_env_wrappers``.

    Args:
        gym_env (gym.Env): An instance of OpenAI gym environment.
        env_id (int): (optional) ID of the environment.
        discount (float): Discount to use for the environment.
        max_episode_steps (int|None): Used to create a TimeLimitWrapper. No
            limit is applied if set to 0 or ``None``. Usually set to
            ``gym_spec.max_episode_steps`` as done in ``load``.
        gym_env_wrappers (Iterable): Iterable with references to gym_wrappers,
            classes to use directly on the gym environment.
        time_limit_wrapper (AlfEnvironmentBaseWrapper): Wrapper that accepts
            (env, max_episode_steps) params to enforce a TimeLimit. Usually this
            should be left as the default, alf_wrappers.TimeLimit.
        normalize_action (bool): if True, will scale continuous actions to
            ``[-1, 1]`` to be better used by algorithms that compute entropies.
        clip_action (bool): If True, will clip continuous action to its bound
            specified by ``action_spec``. If ``normalize_action`` is also
            ``True``, this clipping happens after the normalization (i.e.,
            clips to ``[-1, 1]``).
        alf_env_wrappers (Iterable): Iterable with references to alf_wrappers
            classes to use on the ALF environment.
        image_channel_first (bool): whether transpose image channels to first
            dimension. PyTorch only supports channel_first image inputs.
        auto_reset (bool): If True (default), reset the environment
            automatically after a terminal state is reached.

    Returns:
        An AlfEnvironment instance.
    """

    for wrapper in gym_env_wrappers:
        gym_env = wrapper(gym_env)

    # To apply channel_first transpose on gym (py) env
    if image_channel_first:
        gym_env = gym_wrappers.ImageChannelFirst(gym_env)

    if normalize_action:
        # normalize continuous actions to [-1, 1]
        gym_env = gym_wrappers.NormalizedAction(gym_env)

    if clip_action:
        # clip continuous actions according to gym_env.action_space
        gym_env = gym_wrappers.ContinuousActionClip(gym_env)

    env = alf_gym_wrapper.AlfGymWrapper(
        gym_env=gym_env,
        env_id=env_id,
        discount=discount,
        auto_reset=auto_reset,
    )

    # ``None`` (what a gym spec reports for an unlimited environment) is
    # treated like 0 instead of raising TypeError on the ``>`` comparison.
    if max_episode_steps is not None and max_episode_steps > 0:
        env = time_limit_wrapper(env, max_episode_steps)

    for wrapper in alf_env_wrappers:
        env = wrapper(env)

    return env
Пример #16
0
 def test_obs_dtype(self):
     """The observation from ``reset`` has the dtype named by the spec."""
     env = alf_gym_wrapper.AlfGymWrapper(gym.spec('CartPole-v1').make())
     initial = env.reset()
     expected_dtype = torch_dtype_to_str(env.observation_spec().dtype)
     self.assertEqual(expected_dtype, str(initial.observation.dtype))