Exemplo n.º 1
0
def setup_and_load(use_cmd_line_args=True, **kwargs):
    """
    Set up the global config and return the arguments it was initialized with.

    The arguments are produced by `Config.initialize_args`; once they are
    available, `load_for_setup_if_necessary()` is called before returning.

    `use_cmd_line_args`: pass False to ignore command line arguments given to the program
    `**kwargs`: values that override the defaults from `config.py`
    """
    init_options = {"use_cmd_line_args": use_cmd_line_args, **kwargs}
    config_args = Config.initialize_args(**init_options)

    # Runs only after the config has been initialized above.
    load_for_setup_if_necessary()

    return config_args
Exemplo n.º 2
0
def create_env(
    num_envs,
    *,
    env_kind="procgen",
    epsilon_greedy=0.0,
    reward_scale=1.0,
    frame_stack=1,
    use_sticky_actions=0,
    coinrun_old_extra_actions=0,
    **kwargs,
):
    """
    Build a vectorized environment and wrap it with the standard wrappers.

    `num_envs`: number of environments in the vectorized env
    `env_kind`: "procgen" or "atari"; any other value raises ValueError
    `epsilon_greedy`: when > 0, the result is wrapped in `EpsilonGreedy`
    `reward_scale`: when != 1, the env is wrapped in `VecRewardScale`
    `frame_stack`: when > 1, the env is wrapped in `VecFrameStack`
    `use_sticky_actions`: atari only; 1 selects the "v0" game version, else "v4"
    `coinrun_old_extra_actions`: "coinrun_old" only; when > 0, the env is
        wrapped in `VecExtraActions` with that many extra actions
    `**kwargs`: for procgen, must contain `env_name`; entries whose value is
        None are dropped and the rest are forwarded to the environment.
        For atari, must contain `env_id` (a key of `ATARI_ENV_DICT`).

    Returns the wrapped vectorized environment.
    """
    global coinrun_initialized

    if env_kind == "procgen":
        # Options left as None are treated as if they had not been passed.
        provided = {name: value for name, value in kwargs.items() if value is not None}
        env_name = provided.pop("env_name")

        if env_name != "coinrun_old":
            from procgen import ProcgenGym3Env
            import gym3

            # Forward only the options listed in PROCGEN_KWARG_KEYS.
            procgen_options = {
                name: value
                for name, value in provided.items()
                if name in PROCGEN_KWARG_KEYS
            }
            gym3_env = ProcgenGym3Env(num_envs, env_name=env_name, **procgen_options)
            rgb_env = gym3.ExtractDictObWrapper(gym3_env, "rgb")
            venv = gym3.ToBaselinesVecEnv(rgb_env)

        else:
            import coinrun
            from coinrun.config import Config

            Config.initialize_args(use_cmd_line_args=False, **provided)
            # init_args_and_threads() is only called the first time through.
            if not coinrun_initialized:
                coinrun.init_args_and_threads()
                coinrun_initialized = True

            venv = coinrun.make("standard", num_envs)
            if coinrun_old_extra_actions > 0:
                venv = VecExtraActions(
                    venv,
                    extra_actions=coinrun_old_extra_actions,
                    default_action=0,
                )

    elif env_kind == "atari":
        version_tag = "v0" if use_sticky_actions == 1 else "v4"
        atari_key = kwargs["env_id"]
        full_env_id = ATARI_ENV_DICT[atari_key] + f"NoFrameskip-{version_tag}"

        def build_single_env():
            # Frame stacking is applied to the whole vectorized env further down.
            return wrap_deepmind(
                make_atari(full_env_id), frame_stack=False, clip_rewards=False
            )

        venv = SubprocVecEnv([build_single_env for _ in range(num_envs)])

    else:
        raise ValueError(f"Unsupported env_kind: {env_kind}")

    # Optional wrappers first, then monitoring, then epsilon-greedy on top.
    if frame_stack > 1:
        venv = VecFrameStack(venv=venv, nstack=frame_stack)
    if reward_scale != 1:
        venv = VecRewardScale(venv, reward_scale)

    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)

    if epsilon_greedy > 0:
        venv = EpsilonGreedy(venv, epsilon_greedy)

    return VecShallowCopy(venv)