예제 #1
0
파일: registration.py 프로젝트: araffin/gym
def _check_version_exists(ns: Optional[str], name: str,
                          version: Optional[int]):
    """Check that an env version exists in a namespace, raising a helpful error if not.

    This is a complete test of whether an environment identifier is valid, and
    it puts the best available hints into the raised error message. The
    function returns ``None`` when the exact id is registered, when ``version``
    is ``None``, or when no versioned spec is registered for the name.

    Args:
        ns: The environment namespace
        name: The environment name
        version: The environment version

    Raises:
        DeprecatedEnv: The requested version doesn't exist and either only an
            unversioned (default) spec is registered, or a newer version is.
        VersionNotFound: The requested ``version`` is greater than every
            registered version.
    """
    if get_env_id(ns, name, version) in registry:
        return

    # Name-level validation is delegated; presumably this raises when the
    # name itself is unknown -- confirm against `_check_name_exists`.
    _check_name_exists(ns, name)
    if version is None:
        return

    message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."

    env_specs = [
        spec_ for spec_ in registry.values()
        if spec_.namespace == ns and spec_.name == name
    ]

    default_spec = [spec_ for spec_ in env_specs if spec_.version is None]

    if default_spec:
        message += f" It provides the default version `{default_spec[0].id}`."
        if len(env_specs) == 1:
            raise error.DeprecatedEnv(message)

    # Process possible versioned environments. Only specs that actually carry
    # a version are sorted and listed, so the default spec never shows up as
    # `vNone` and version 0 is ordered as 0.
    versioned_specs = sorted(
        (spec_ for spec_ in env_specs if spec_.version is not None),
        key=lambda spec_: spec_.version,
    )

    latest_spec = max(versioned_specs,
                      key=lambda spec_: spec_.version,
                      default=None)  # type: ignore
    if latest_spec is None:
        return

    if version > latest_spec.version:
        version_list_msg = ", ".join(f"`v{spec_.version}`"
                                     for spec_ in versioned_specs)
        message += f" It provides versioned environments: [ {version_list_msg} ]."

        raise error.VersionNotFound(message)

    if version < latest_spec.version:
        raise error.DeprecatedEnv(
            f"Environment version v{version} for `{get_env_id(ns, name, None)}` is deprecated. "
            f"Please use `{latest_spec.id}` instead.")
    def spec(self, id):
        """Return the spec stored in ``self.env_specs`` under ``id``.

        Args:
            id: The environment ID string; it must match ``env_id_re``.

        Returns:
            The entry of ``self.env_specs`` keyed by ``id``.

        Raises:
            error.Error: ``id`` does not match ``env_id_re``.
            error.DeprecatedEnv: ``id`` is not registered, but at least one
                registered spec shares its env name.
            error.UnregisteredEnv: ``id`` is not registered and no registered
                spec shares its env name.
        """
        parsed = env_id_re.search(id)
        if not parsed:
            malformed_msg = 'Attempted to look up malformed environment ID: {}. (Currently all IDs must be of the form {}.)'
            raise error.Error(
                malformed_msg.format(id.encode('utf-8'), env_id_re.pattern))

        try:
            return self.env_specs[id]
        except KeyError:
            # The first regex group is the env name without its version part;
            # collect every registered ID whose spec carries that same name.
            base_name = parsed.group(1)
            candidates = []
            for registered_id, registered_spec in self.env_specs.items():
                if base_name == registered_spec._env_name:
                    candidates.append(registered_id)

            if not candidates:
                raise error.UnregisteredEnv(
                    'No registered env with id: {}'.format(id))
            raise error.DeprecatedEnv(
                'Env {} not found (valid versions include {})'.format(
                    id, candidates))
예제 #3
0
    def spec(self, path):
        """Return the spec registered for ``path``.

        ``path`` is either a bare environment ID or ``"<module>:<id>"``. In the
        second form the module is imported first, and the part after the first
        colon is used as the environment ID.

        Args:
            path: Environment ID, optionally prefixed with ``"<module>:"``.

        Returns:
            The entry of ``self.env_specs`` keyed by the environment ID.

        Raises:
            error.Error: The named module cannot be imported, or the ID does
                not match ``env_id_re``.
            error.DeprecatedEnv: The ID is not registered, but at least one
                registered spec shares its env name.
            error.UnregisteredEnv: The ID is not registered and no registered
                spec shares its env name.
        """
        # partition() yields an empty separator when there is no colon, in
        # which case the whole path is the environment ID.
        module_part, colon, remainder = path.partition(':')
        if colon:
            try:
                importlib.import_module(module_part)
            # catch ImportError for python2.7 compatibility
            except ImportError:
                raise error.Error('A module ({}) was specified for the environment but was not found, make sure the package is installed with `pip install` before calling `gym.make()`'.format(module_part))
            env_id = remainder
        else:
            env_id = path

        parsed = env_id_re.search(env_id)
        if not parsed:
            raise error.Error('Attempted to look up malformed environment ID: {}. (Currently all IDs must be of the form {}.)'.format(env_id.encode('utf-8'), env_id_re.pattern))

        try:
            return self.env_specs[env_id]
        except KeyError:
            # The first regex group is the env name without its version part;
            # collect every registered ID whose spec carries that same name.
            base_name = parsed.group(1)
            candidates = []
            for registered_id, registered_spec in self.env_specs.items():
                if base_name == registered_spec._env_name:
                    candidates.append(registered_id)

            if not candidates:
                raise error.UnregisteredEnv('No registered env with id: {}'.format(env_id))
            raise error.DeprecatedEnv('Env {} not found (valid versions include {})'.format(env_id, candidates))
예제 #4
0
    def spec(self, path):
        """Return the spec registered for ``path``.

        ``path`` is either a bare environment ID or ``"<module>:<id>"``. In the
        second form the module is imported first, and the part after the first
        colon is used as the environment ID.

        Args:
            path: Environment ID, optionally prefixed with ``"<module>:"``.

        Returns:
            The entry of ``self.env_specs`` keyed by the environment ID.

        Raises:
            error.Error: The named module is not found, or the ID does not
                match ``env_id_re``.
            error.DeprecatedEnv: The ID is not registered, but at least one
                registered spec shares its env name.
            error.UnregisteredEnv: The ID is not registered and no registered
                spec shares its env name. For the algorithmic and toytext env
                names listed below, the message points at the package to
                install instead.
        """
        if ":" in path:
            mod_name, _, id = path.partition(":")
            try:
                importlib.import_module(mod_name)
            except ModuleNotFoundError:
                raise error.Error(
                    "A module ({}) was specified for the environment but was not found, make sure the package is installed with `pip install` before calling `gym.make()`"
                    .format(mod_name))
        else:
            id = path

        match = env_id_re.search(id)
        if not match:
            raise error.Error(
                "Attempted to look up malformed environment ID: {}. (Currently all IDs must be of the form {}.)"
                .format(id.encode("utf-8"), env_id_re.pattern))

        try:
            return self.env_specs[id]
        except KeyError:
            # Parse the env name and check to see if it matches the non-version
            # part of a valid env (could also check the exact number here)
            env_name = match.group(1)
            matching_envs = [
                valid_env_name
                for valid_env_name, valid_env_spec in self.env_specs.items()
                if env_name == valid_env_spec._env_name
            ]
            # Env names that get a dedicated "moved out of Gym" hint. The
            # entry "ReversedAddition" was previously misspelled, so lookups
            # of that env fell through to the generic error below.
            algorithmic_envs = [
                "Copy",
                "RepeatCopy",
                "DuplicatedInput",
                "Reverse",
                "ReversedAddition",
                "ReversedAddition3",
            ]
            toytext_envs = [
                "KellyCoinflip",
                "KellyCoinflipGeneralized",
                "NChain",
                "Roulette",
                "GuessingGame",
                "HotterColder",
            ]
            if matching_envs:
                raise error.DeprecatedEnv(
                    "Env {} not found (valid versions include {})".format(
                        id, matching_envs))
            elif env_name in algorithmic_envs:
                raise error.UnregisteredEnv(
                    "Algorithmic environment {} has been moved out of Gym. Install it via `pip install gym-algorithmic` and add `import gym_algorithmic` before using it."
                    .format(id))
            elif env_name in toytext_envs:
                raise error.UnregisteredEnv(
                    "Toytext environment {} has been moved out of Gym. Install it via `pip install gym-legacy-toytext` and add `import gym_toytext` before using it."
                    .format(id))
            else:
                raise error.UnregisteredEnv(
                    "No registered env with id: {}".format(id))
예제 #5
0
파일: registration.py 프로젝트: chksi/gym
    def _assert_version_exists(self, namespace: Optional[str], name: str,
                               version: Optional[int]):
        """Raise a descriptive error unless ``version`` is registered for ``name``.

        After ``self._assert_name_exists`` has been called, the versions
        registered under ``self.tree[namespace][name]`` are inspected. The
        method returns ``None`` when ``version`` is one of them.

        Args:
            namespace: The environment namespace, or ``None`` for no namespace.
            name: The environment name.
            version: The requested version; ``None`` requests the unversioned
                (default) spec.

        Raises:
            error.DeprecatedEnv: A higher integer version is registered, or
                only the default (unversioned) spec is registered.
            error.VersionNotFound: Any other case in which ``version`` is not
                registered.
        """
        self._assert_name_exists(namespace, name)
        known = self.tree[namespace][name]
        if version in known:
            return

        assert len(known) > 0

        # Separate the integer-versioned specs from the optional default spec,
        # which is stored under the key ``None``.
        numbered = [
            entry for entry in known.values()
            if isinstance(entry.version, int)
        ]
        fallback = known[None] if None in known else None
        assert len(numbered) > 0 or fallback is not None

        # The highest integer version, or the default spec when none exists.
        newest = max(numbered,
                     key=lambda entry: entry.version,
                     default=fallback)

        if version is None:
            lead = "The default version for `"
        else:
            lead = f"Environment version `v{version}` for `"
        scope = "" if namespace is None else f"{namespace}/"
        message = f"{lead}{scope}{name}` "

        if newest:
            # Only a default version exists, so no explicit version is valid.
            if newest.version is None:
                message += "is deprecated. "
                message += f"`{newest.name}` only provides the default version. "
                message += (
                    f'You can initialize the environment as `gym.make("{newest.id}")`.'
                )
                raise error.DeprecatedEnv(message)

            # A newer versioned spec exists: the requested one is deprecated.
            if version is not None and version < newest.version:
                message += "is deprecated. "
                message += f"Please use the latest version `v{newest.version}`."
                raise error.DeprecatedEnv(message)

        # Otherwise the requested version simply doesn't exist; describe what
        # is available instead.
        message += f"could not be found. `{name}` provides "

        offerings = []
        if fallback:
            offerings.append("a default version")
        if numbered:
            ascending = sorted(numbered, key=lambda entry: entry.version)
            listing = ", ".join(f"`v{entry.version}`" for entry in ascending)
            offerings.append("the versioned environments: [ " + listing + " ]")

        message += " and ".join(offerings)
        message += "."
        raise error.VersionNotFound(message)