Exemplo n.º 1
0
def test_common_failures_reset():
    """
    Make sure `check_env` rejects the usual mistakes in a `reset()` implementation.

    Covers a Box-observation env (wrong shape, non-array observation, extra
    return values) and a Dict-observation env (one key with the wrong shape).
    """
    box_env = IdentityEnvBox()

    # Observation shape disagrees with observation_space
    check_reset_assert_error(box_env, np.ones((3, )))
    # Observation is a plain int rather than a numpy array
    check_reset_assert_error(box_env, 1)
    # A tuple is returned instead of the observation alone
    check_reset_assert_error(box_env,
                             (box_env.observation_space.sample(), False))

    dict_env = SimpleMultiObsEnv()
    first_obs = dict_env.reset()

    def _reset_with_mismatched_vec(self):
        # The "vec" entry is filled with the image data, so its shape is wrong
        return {"img": first_obs["img"], "vec": first_obs["img"]}

    dict_env.reset = types.MethodType(_reset_with_mismatched_vec, dict_env)
    with pytest.raises(AssertionError) as raised:
        check_env(dict_env)

    # The failure message has to name the offending key
    message = str(raised.value)
    assert "vec" in message
Exemplo n.º 2
0
def test_common_failures_step():
    """
    Make sure `check_env` rejects the usual mistakes in a `step()` implementation.

    Covers a Box-observation env (bad observation, bad reward, missing info
    dict, non-boolean done flag) and a Dict-observation env (one key with the
    wrong shape).
    """
    box_env = IdentityEnvBox()

    # Observation with the wrong shape
    check_step_assert_error(box_env, (np.ones((4, )), 1.0, False, {}))
    # Observation that is not a numpy array
    check_step_assert_error(box_env, (1, 1.0, False, {}))

    # Reward given as an array instead of a scalar
    check_step_assert_error(
        box_env,
        (box_env.observation_space.sample(), np.ones(1), False, {}),
    )

    # Only three values returned: the info dict is missing
    check_step_assert_error(
        box_env,
        (box_env.observation_space.sample(), 0.0, False),
    )

    # `done` flag that is not a bool: first a float, then an int
    check_step_assert_error(
        box_env,
        (box_env.observation_space.sample(), 0.0, 3.0, {}),
    )
    check_step_assert_error(
        box_env,
        (box_env.observation_space.sample(), 0.0, 1, {}),
    )

    dict_env = SimpleMultiObsEnv()
    first_obs = dict_env.reset()

    def _step_with_mismatched_img(self, action):
        # The "img" entry is filled with the vector data, so its shape is wrong
        bad_obs = {"img": first_obs["vec"], "vec": first_obs["vec"]}
        return bad_obs, 0.0, False, {}

    dict_env.step = types.MethodType(_step_with_mismatched_img, dict_env)
    with pytest.raises(AssertionError) as raised:
        check_env(dict_env)

    # The failure message has to name the offending key
    message = str(raised.value)
    assert "img" in message
Exemplo n.º 3
0
def test_dict_vec_framestack(model_class, channel_last):
    """
    Additional tests to check observation space support
    for Dictionary spaces and VecEnvWrapper using MultiInputPolicy.

    Builds a single-env DummyVecEnv of SimpleMultiObsEnv and wraps it in a
    3-frame VecFrameStack. It then trains ``model_class`` briefly and runs a
    short evaluation; the test passes if nothing raises.

    :param model_class: algorithm class under test (supplied by the test
        parametrization, which is not visible here).
    :param channel_last: if True the "img" observation is channel-last,
        otherwise channel-first.
    """
    import inspect

    # TQC is run with continuous actions, every other model with discrete ones.
    use_discrete_actions = model_class not in [TQC]
    # Per-key channel order for frame stacking; None for the non-image "vec" key.
    channels_order = {"vec": None, "img": "last" if channel_last else "first"}
    env = DummyVecEnv([
        lambda: SimpleMultiObsEnv(random_start=True,
                                  discrete_actions=use_discrete_actions,
                                  channel_last=channel_last)
    ])

    env = VecFrameStack(env, n_stack=3, channels_order=channels_order)

    total_timesteps = 256

    # On-policy algorithms are configured via a rollout length (`n_steps`),
    # off-policy ones via a replay buffer. Checking the constructor signature
    # keeps the on-policy branch reachable. (Assumes the SB3 convention that
    # only on-policy algorithms accept `n_steps` -- confirm when adding models.)
    is_on_policy = "n_steps" in inspect.signature(model_class).parameters

    if is_on_policy:
        kwargs = dict(
            n_steps=128,
            policy_kwargs=dict(
                net_arch=[32],
                features_extractor_kwargs=dict(cnn_output_dim=32),
            ),
        )
    else:
        # Avoid memory error when using replay buffer
        # Reduce the size of the features and make learning faster
        kwargs = dict(
            buffer_size=250,
            policy_kwargs=dict(
                net_arch=[32],
                features_extractor_kwargs=dict(cnn_output_dim=32),
                n_quantiles=20,
            ),
            train_freq=8,
            gradient_steps=1,
        )
        if model_class == QRDQN:
            # Start learning immediately given the very short training run
            kwargs["learning_starts"] = 0

    model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs)

    model.learn(total_timesteps=total_timesteps)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)