def register_custom_components():
    """Register the ``gym_`` environment family in the global env registry.

    Every environment whose name starts with ``gym_`` is associated with
    ``make_gym_env_func`` as its factory, together with the hooks that add
    extra command-line parameters and override default parameter values.
    """
    register_family = global_env_registry().register_env
    register_family(
        env_name_prefix='gym_',
        make_env_func=make_gym_env_func,
        add_extra_params_func=add_extra_params_func,
        override_default_params_func=override_default_params_func,
    )
def register_custom_components():
    """Register the custom environment family and its encoder.

    Two registrations are performed, in this order:

    1. The ``my_custom_env_`` name prefix is registered in the global env
       registry with ``make_custom_env_func`` as the factory, plus the hooks
       for extra parameters and default overrides.
    2. ``CustomEncoder`` is registered under the name ``custom_env_encoder``.
    """
    register_family = global_env_registry().register_env
    register_family(
        env_name_prefix='my_custom_env_',
        make_env_func=make_custom_env_func,
        add_extra_params_func=add_extra_params_func,
        override_default_params_func=override_default_params_func,
    )

    register_custom_encoder('custom_env_encoder', CustomEncoder)
def add_env_args(env, parser):
    """Add environment-related command-line options to ``parser``.

    Three generic options are always added (``--env_frameskip``,
    ``--env_framestack`` and ``--pixel_format``). Afterwards the registry
    entry resolved for ``env`` gets a chance to add its own options through
    its ``add_extra_params_func`` hook, if one is set.

    :param env: environment name, resolved through the global env registry
    :param parser: parser object the options are added to via ``add_argument``
    """
    parser.add_argument(
        '--env_frameskip',
        default=None,
        type=int,
        help='Number of frames for action repeat (frame skipping). Default (None) means use default environment value',
    )
    parser.add_argument(
        '--env_framestack',
        default=4,
        type=int,
        help='Frame stacking (only used in Atari?)',
    )
    parser.add_argument(
        '--pixel_format',
        default='CHW',
        type=str,
        help='PyTorch expects CHW by default, Ray & TensorFlow expect HWC',
    )

    # Family-specific options are optional: a family without a hook adds nothing.
    registry_entry = global_env_registry().resolve_env_name(env)
    extra_params_hook = registry_entry.add_extra_params_func
    if extra_params_hook is None:
        return
    extra_params_hook(env, parser)
def create_env(full_env_name, cfg=None, env_config=None):
    """
    Create an environment instance from its full name.

    The name is resolved to a registered environment family through the global
    env registry, and that family's make_env_func() is called to build the
    environment.

    :param full_env_name: complete environment name, beginning with the prefix of a registered environment family
    (e.g. atari_breakout or doom_battle). It is forwarded unchanged to make_env_func(), which handles any
    family-specific parsing.
    :param cfg: full system configuration (argparse namespace, or AttrDict when loaded from JSON)
    :param env_config: AttrDict with additional system information, e.g.
    env_config = AttrDict(worker_index=self.worker_idx, vector_index=vector_idx, env_id=env_id)
    :return: environment instance
    """
    family_entry = global_env_registry().resolve_env_name(full_env_name)
    return family_entry.make_env_func(full_env_name, cfg=cfg, env_config=env_config)
def env_override_defaults(env, parser):
    """Let the environment family of ``env`` override parser defaults.

    The registry entry resolved for ``env`` may carry an
    ``override_default_params_func`` hook; when it is set, it is called with
    ``env`` and ``parser``. When it is ``None`` nothing happens.

    :param env: environment name, resolved through the global env registry
    :param parser: parser object passed on to the override hook
    """
    registry_entry = global_env_registry().resolve_env_name(env)
    override_hook = registry_entry.override_default_params_func
    if override_hook is None:
        return
    override_hook(env, parser)