Ejemplo n.º 1
0
def test_env(spec):
    """Smoke-test one registered environment spec against the Gym API.

    Steps:
      1. Build the env from ``spec`` while recording every warning emitted.
      2. Run ``check_env`` on it (render check skipped).
      3. Fail if any recorded warning mentions an "autodetected dtype", i.e.
         a ``gym.Box`` space was created without an explicit dtype.
      4. Do one ``reset`` and one ``step`` and check that the observations lie
         in the observation space (and, for ``Box`` spaces, carry its dtype),
         that the reward is a scalar and that ``done`` is a bool.

    The env is closed even when one of the checks fails.
    """
    import warnings

    # Capture warnings. ``pytest.warns(None)`` is deprecated in pytest 7 and
    # rejected by pytest 8; the stdlib recorder is the equivalent replacement.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure repeated warnings are recorded too
        env = spec.make()

    try:
        # Test if env adheres to Gym API
        check_env(env, skip_render_check=True)

        # Check that dtype is explicitly declared for gym.Box spaces
        for warning_msg in caught:
            assert "autodetected dtype" not in str(warning_msg.message)

        ob_space = env.observation_space
        act_space = env.action_space
        ob = env.reset()
        assert ob_space.contains(ob), f"Reset observation: {ob!r} not in space"
        if isinstance(ob_space, Box):
            # Only checking dtypes for Box spaces to avoid iterating through tuple entries
            assert (
                ob.dtype == ob_space.dtype
            ), f"Reset observation dtype: {ob.dtype}, expected: {ob_space.dtype}"

        a = act_space.sample()
        observation, reward, done, _info = env.step(a)
        assert ob_space.contains(
            observation), f"Step observation: {observation!r} not in space"
        assert np.isscalar(reward), f"{reward} is not a scalar for {env}"
        assert isinstance(done, bool), f"Expected {done} to be a boolean"
        if isinstance(ob_space, Box):
            # Report the dtype of the *step* observation in the failure message.
            assert (
                observation.dtype == ob_space.dtype
            ), f"Step observation dtype: {observation.dtype}, expected: {ob_space.dtype}"
    finally:
        env.close()
Ejemplo n.º 2
0
def test_gym_wrapper():
    """Check every ViZDoom Gym env against the Gym API.

    Each env in ``vizdoom_envs`` is created with ``frame_skip`` 1 and 4 and
    ``max_buttons_pressed=0``. It is then run through ``check_env`` (render
    check skipped) plus one ``reset``/``step`` round trip. The assertions
    check observation containment, a scalar reward and a boolean ``done``.
    Every env is closed even if a check fails, so a failure does not leave
    game instances running.
    """
    print("Testing Gym wrapper compatibility with gym API")
    for env_name in vizdoom_envs:
        for frame_skip in [1, 4]:
            env = gym.make(env_name,
                           frame_skip=frame_skip,
                           max_buttons_pressed=0)
            try:
                # Test if env adheres to Gym API
                check_env(env, skip_render_check=True)

                ob_space = env.observation_space
                act_space = env.action_space
                ob = env.reset()
                assert ob_space.contains(
                    ob), f"Reset observation: {ob!r} not in space"

                a = act_space.sample()
                observation, reward, done, _info = env.step(a)
                assert ob_space.contains(
                    observation), f"Step observation: {observation!r} not in space"
                assert np.isscalar(reward), f"{reward} is not a scalar for {env}"
                assert isinstance(done, bool), f"Expected {done} to be a boolean"
            finally:
                env.close()
Ejemplo n.º 3
0
def test_env(spec):
    """Smoke-test one registered environment spec, including rendering.

    Steps:
      1. Build the env from ``spec`` while recording every warning emitted.
      2. Run ``check_env(warn=True, skip_render_check=True)`` on it.
      3. Fail if any recorded warning mentions an "autodetected dtype".
      4. Do one ``reset`` and one ``step`` and check observation containment,
         a scalar reward and a boolean ``done``.
      5. Render once in every mode listed under ``metadata["render.modes"]``.
      6. Close the env and render every mode again, to make sure rendering
         still works after ``close()``.

    The env is closed at the end even when a check fails.
    """
    import warnings

    # Capture warnings. ``pytest.warns(None)`` is deprecated in pytest 7 and
    # rejected by pytest 8; the stdlib recorder is the equivalent replacement.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure repeated warnings are recorded too
        env = spec.make()

    try:
        # Test if env adheres to Gym API
        check_env(env, warn=True, skip_render_check=True)

        # Check that dtype is explicitly declared for gym.Box spaces
        for warning_msg in caught:
            assert "autodetected dtype" not in str(warning_msg.message)

        ob_space = env.observation_space
        act_space = env.action_space
        ob = env.reset()
        assert ob_space.contains(
            ob), "Reset observation: {!r} not in space".format(ob)
        a = act_space.sample()
        observation, reward, done, _info = env.step(a)
        assert ob_space.contains(
            observation), "Step observation: {!r} not in space".format(observation)
        assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
        assert isinstance(done, bool), "Expected {} to be a boolean".format(done)

        render_modes = env.metadata.get("render.modes", [])
        for mode in render_modes:
            env.render(mode=mode)

        # Make sure we can render the environment after close: the env must
        # actually be closed before the second round of renders.
        env.close()
        for mode in render_modes:
            env.render(mode=mode)
    finally:
        # Release anything the post-close renders may have re-acquired.
        # NOTE(review): assumes close() tolerates being called twice — confirm per env.
        env.close()
Ejemplo n.º 4
0
def test_api():
    """Run ``check_env`` on a single-agent SUMO single-intersection scenario.

    The environment is built from the network/route files under
    ``nets/single-intersection/`` with ``num_seconds=100000``. It is reset
    once and then validated. It is closed even if the check fails, so a
    failing check does not leave the simulation running.
    """
    env = SumoEnvironment(single_agent=True,
                          num_seconds=100000,
                          net_file='nets/single-intersection/single-intersection.net.xml',
                          route_file='nets/single-intersection/single-intersection.rou.xml')
    try:
        env.reset()
        check_env(env)
    finally:
        env.close()
Ejemplo n.º 5
0
def test_check_env_dict_action():
    """``check_env`` must reject ``ActionDictTestEnv`` with the expected message.

    Per the original test's intent, ``ActionDictTestEnv.step()`` only returns
    3 values (obs, reward, done) and no info, so ``check_env`` should fail its
    "four values" assertion.
    """
    env = ActionDictTestEnv()

    with pytest.raises(AssertionError) as errorinfo:
        check_env(env=env, warn=True)

    # This must sit after the ``with`` block. Once check_env raises, any
    # statement following it inside the block is skipped and never executes.
    assert (
        str(errorinfo.value) ==
        "The `step()` method must return four values: obs, reward, done, info"
    )
Ejemplo n.º 6
0
def make(
    id: str | EnvSpec,
    max_episode_steps: Optional[int] = None,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    **kwargs,
) -> Env:
    """
    Create an environment according to the given ID.

    Warnings:
        In v0.24, the environment checker (`check_env`) is run for every initialised environment.
        This calls the :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render` functions to validate
        whether they follow the gym API. To disable this feature, set parameter `disable_env_checker=True`.

    Args:
        id: Name of the environment, or an already-resolved :class:`EnvSpec`. Optionally, a module to
            import can be prefixed, eg. 'module:Env-v0'; that module is imported first (everything
            before the first ':'), so it can register the environment.
        max_episode_steps: Maximum length of an episode (TimeLimit wrapper). When given, it takes
            precedence over the spec's own `max_episode_steps`.
        autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
        disable_env_checker: If True, skip running the environment checker on the created env.
        kwargs: Additional arguments to pass to the environment constructor. They override the
            kwargs registered in the spec.

    Returns:
        An instance of the environment. It is wrapped in OrderEnforcing if the spec requests it,
        TimeLimit if an episode limit is set, and AutoResetWrapper if `autoreset` is True.

    Raises:
        ModuleNotFoundError: If the module prefix of `id` cannot be imported.
        error.Error: If no environment is registered under `id` (after version resolution), or if
            the spec has no entry point. `_check_version_exists` may raise a more specific error first.
    """
    if isinstance(id, EnvSpec):
        spec_ = id
    else:
        # Split only on the first ':' — module paths never contain ':', but env
        # namespaces/names may, and an unbounded split would fail to unpack.
        module, id = (None, id) if ":" not in id else id.split(":", 1)
        if module is not None:
            try:
                importlib.import_module(module)
            except ModuleNotFoundError as e:
                raise ModuleNotFoundError(
                    f"{e}. Environment registration via importing a module failed. "
                    f"Check whether '{module}' contains env registration and can be imported."
                ) from e
        spec_ = registry.get(id)

        ns, name, version = parse_env_id(id)
        latest_version = find_highest_version(ns, name)
        if (version is not None and latest_version is not None
                and latest_version > version):
            logger.warn(
                f"The environment {id} is out of date. You should consider "
                f"upgrading to version `v{latest_version}`.")
        if version is None and latest_version is not None:
            # Unversioned id requested: resolve to the highest registered version.
            version = latest_version
            new_env_id = get_env_id(ns, name, version)
            spec_ = registry.get(new_env_id)
            logger.warn(
                f"Using the latest versioned environment `{new_env_id}` "
                f"instead of the unversioned environment `{id}`.")

        if spec_ is None:
            _check_version_exists(ns, name, version)
            raise error.Error(f"No registered env with id: {id}")

    # Caller kwargs override the kwargs registered with the spec.
    _kwargs = spec_.kwargs.copy()
    _kwargs.update(kwargs)

    # TODO: add a minimal env checker on initialization
    if spec_.entry_point is None:
        raise error.Error(
            f"{spec_.id} registered but entry_point is not specified")
    elif callable(spec_.entry_point):
        cls = spec_.entry_point
    else:
        # Assume it's a string
        cls = load(spec_.entry_point)

    env = cls(**_kwargs)

    # Attach a private copy of the spec that records the kwargs actually used,
    # so the registered spec is not mutated.
    spec_ = copy.deepcopy(spec_)
    spec_.kwargs = _kwargs

    env.unwrapped.spec = spec_

    if spec_.order_enforce:
        env = OrderEnforcing(env)

    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps)
    elif spec_.max_episode_steps is not None:
        env = TimeLimit(env, spec_.max_episode_steps)

    if autoreset:
        env = AutoResetWrapper(env)

    if not disable_env_checker:
        # NOTE(review): per the Warnings section the checker calls reset/step on the
        # fully wrapped env, so the returned env may already have been stepped —
        # callers should still reset() before use; confirm against check_env.
        # Best-effort: a failed check is logged, not raised.
        try:
            check_env(env)
        except Exception as e:
            logger.warn(
                f"Env check failed with the following message: {e}\nYou can set `disable_env_checker=True` to disable this check."
            )

    return env