Ejemplo n.º 1
0
def register_custom_components():
    """
    Register the ``gym_`` environment family with the global env registry.

    Environment names starting with ``gym_`` are bound to ``make_gym_env_func`` as their factory.
    The module-level ``add_extra_params_func`` and ``override_default_params_func`` are registered
    as the family's parameter hooks.

    NOTE(review): another ``register_custom_components`` is defined right after this one; if both
    live in the same module, only the later definition is reachable — confirm they are meant to be
    separate modules.
    """
    registry = global_env_registry()
    registry.register_env(
        env_name_prefix='gym_',
        make_env_func=make_gym_env_func,
        add_extra_params_func=add_extra_params_func,
        override_default_params_func=override_default_params_func,
    )
def register_custom_components():
    """
    Register the ``my_custom_env_`` environment family and its custom encoder.

    First, the global env registry binds names prefixed with ``my_custom_env_`` to
    ``make_custom_env_func``, using the module-level ``add_extra_params_func`` and
    ``override_default_params_func`` as parameter hooks. Then ``CustomEncoder`` is registered
    under the key ``'custom_env_encoder'``.
    """
    registry = global_env_registry()
    registry.register_env(
        env_name_prefix='my_custom_env_',
        make_env_func=make_custom_env_func,
        add_extra_params_func=add_extra_params_func,
        override_default_params_func=override_default_params_func,
    )

    # The encoder is registered only after the env family, matching the original call order.
    register_custom_encoder('custom_env_encoder', CustomEncoder)
Ejemplo n.º 3
0
def add_env_args(env, parser):
    """
    Add the generic environment CLI options to ``parser``, then apply the env family's extra options.

    Generic options added (in this order, which determines help output order):
      * ``--env_frameskip`` (int, default None — use the environment's own value)
      * ``--env_framestack`` (int, default 4)
      * ``--pixel_format`` (str, default 'CHW')

    After that, the registry entry resolved for ``env`` is consulted. If it defines an
    ``add_extra_params_func``, that hook is called as ``hook(env, parser)``.

    :param env: environment name, resolved through the global env registry
    :param parser: object exposing ``add_argument`` (presumably an ``argparse.ArgumentParser``)
    """
    parser.add_argument('--env_frameskip', default=None, type=int, help='Number of frames for action repeat (frame skipping). Default (None) means use default environment value')
    parser.add_argument('--env_framestack', default=4, type=int, help='Frame stacking (only used in Atari?)')
    parser.add_argument('--pixel_format', default='CHW', type=str, help='PyTorch expects CHW by default, Ray & TensorFlow expect HWC')

    # Named to avoid shadowing the module-level ``add_extra_params_func``.
    family_entry = global_env_registry().resolve_env_name(env)
    extra_params_hook = family_entry.add_extra_params_func
    if extra_params_hook is None:
        return
    extra_params_hook(env, parser)
Ejemplo n.º 4
0
def create_env(full_env_name, cfg=None, env_config=None):
    """
    Build an environment instance by dispatching to the registered environment family.

    The global env registry resolves ``full_env_name`` against the registered env family prefixes
    (the first match wins). The resolved entry's ``make_env_func`` then constructs the environment.

    :param full_env_name: full environment name, starting with the prefix of a registered env family,
        e.g. atari_breakout or doom_battle. It is forwarded unchanged to the family's make_env_func()
        for family-specific handling (see doom_utils.py or dmlab_env.py).
    :param cfg: namespace holding the full system configuration — argparser output, or an AttrDict
        when loaded from JSON.
    :param env_config: AttrDict with additional system information, e.g.
        AttrDict(worker_index=self.worker_idx, vector_index=vector_idx, env_id=env_id)

    :return: the environment instance returned by the family's make_env_func()
    """
    family_entry = global_env_registry().resolve_env_name(full_env_name)
    return family_entry.make_env_func(full_env_name, cfg=cfg, env_config=env_config)
Ejemplo n.º 5
0
def env_override_defaults(env, parser):
    """
    Give the env family registered for ``env`` a chance to change the parser's defaults.

    Resolves ``env`` in the global env registry and reads the entry's
    ``override_default_params_func``. When that hook is set (not None), it is called as
    ``hook(env, parser)``. Otherwise nothing happens.

    :param env: environment name, resolved through the global env registry
    :param parser: argument parser passed through to the hook (presumably argparse — confirm)
    """
    # Named to avoid shadowing the module-level ``override_default_params_func``.
    override_hook = global_env_registry().resolve_env_name(env).override_default_params_func
    if override_hook is None:
        return
    override_hook(env, parser)