Esempio n. 1
0
    def get_model_v2(obs_space,
                     action_space,
                     num_outputs,
                     model_config,
                     framework="tf",
                     name="default_model",
                     model_interface=None,
                     default_model=None,
                     **model_kwargs):
        """Builds the ModelV2 instance matching the given spaces and output.

        A registered custom ModelV2 class (``model_config["custom_model"]``)
        takes precedence. Otherwise a default class is selected per framework
        and wrapped in ``model_interface`` before being instantiated.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (dict): Model options. The keys read here are
                "custom_model" and "use_lstm" (the latter only to build an
                error message).
            framework (str): One of "tf" or "torch".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model
            default_model (cls): Override the default class for the model. This
                only has an effect when not using a custom model
            model_kwargs (dict): args to pass to the ModelV2 constructor

        Returns:
            model (ModelV2): Model to use for the policy.

        Raises:
            ValueError: If a custom TF model created variables that are not
                part of `model.variables()`, or if eager execution is active
                and no ModelV2 class can be used.
            NotImplementedError: If `framework` is neither "tf" nor "torch".
        """
        # Positional constructor arguments shared by every model class below.
        ctor_args = (obs_space, action_space, num_outputs, model_config, name)
        custom_model = model_config.get("custom_model")

        if custom_model:
            model_cls = _global_registry.get(RLLIB_MODEL, custom_model)
            if not issubclass(model_cls, ModelV2):
                # Custom class that is not a ModelV2: rejected under eager
                # execution, otherwise falls through to the branches below.
                if tf.executing_eagerly():
                    raise ValueError(
                        "Eager execution requires a TFModelV2 model to be "
                        "used, however you specified a custom model {}".format(
                            model_cls))
            elif framework != "tf":
                # No interface wrapping and no variable tracking here.
                return model_cls(*ctor_args, **model_kwargs)
            else:
                logger.info("Wrapping {} as {}".format(
                    model_cls, model_interface))
                model_cls = ModelCatalog._wrap_if_needed(
                    model_cls, model_interface)

                # Record every variable created while the model is built so
                # that unregistered ones can be reported afterwards.
                seen_vars = set()

                def _record_variable(next_creator, **kw):
                    new_var = next_creator(**kw)
                    seen_vars.add(new_var)
                    return new_var

                with tf.variable_creator_scope(_record_variable):
                    model = model_cls(*ctor_args, **model_kwargs)

                known_vars = set(model.variables())
                missing_vars = {
                    var
                    for var in seen_vars if var not in known_vars
                }
                if missing_vars:
                    raise ValueError(
                        "It looks like variables {} were created as part "
                        "of {} but does not appear in model.variables() "
                        "({}). Did you forget to call "
                        "model.register_variables() on the variables in "
                        "question?".format(missing_vars, model, known_vars))
                return model

        # No custom ModelV2 instance was returned: pick a class per framework.
        if framework == "tf":
            if custom_model:
                # A custom (non-ModelV2) model was given: skip the default V2
                # lookup and go straight to the V1 wrapper.
                v2_class = None
            else:
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)
            if v2_class is None:
                if tf.executing_eagerly():
                    raise ValueError(
                        "Eager execution requires a TFModelV2 model to be "
                        "used, however there is no default V2 model for this "
                        "observation space: {}, use_lstm={}".format(
                            obs_space, model_config.get("use_lstm")))
                # Fall back to a default V1 model behind the V2 API.
                v2_class = make_v1_wrapper(ModelCatalog.get_model)
        elif framework == "torch":
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework)
        else:
            raise NotImplementedError(
                "Framework must be 'tf' or 'torch': {}".format(framework))

        # Wrap in the requested interface and instantiate.
        wrapped_cls = ModelCatalog._wrap_if_needed(v2_class, model_interface)
        return wrapped_cls(*ctor_args, **model_kwargs)
Esempio n. 2
0
    def get_model_v2(obs_space,
                     action_space,
                     num_outputs,
                     model_config,
                     framework="tf",
                     name="default_model",
                     model_interface=None,
                     default_model=None,
                     **model_kwargs):
        """Returns a suitable model compatible with given spaces and output.

        A custom ModelV2 class (``model_config["custom_model"]``, given either
        as a class or as a name looked up in the registry) takes precedence.
        Otherwise a default class is selected per framework, optionally
        wrapped in an LSTM wrapper (``model_config["use_lstm"]``), then wrapped
        in ``model_interface`` and instantiated.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (dict): Model options. Keys read here are
                "custom_model", "custom_options" (deprecated) and "use_lstm".
                NOTE: a deprecated "custom_options" entry is moved in place to
                "custom_model_config", i.e. the passed dict is mutated.
            framework (str): One of "tf", "tfe", "tf2", or "torch".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model
            default_model (cls): Override the default class for the model. This
                only has an effect when not using a custom model
            model_kwargs (dict): args to pass to the ModelV2 constructor

        Returns:
            model (ModelV2): Model to use for the policy.

        Raises:
            ValueError: If a custom TF model created variables that are not
                part of `model.variables()`, or if eager execution is active
                and no ModelV2 class can be used.
            NotImplementedError: If `framework` is not supported.
        """

        if model_config.get("custom_model"):

            # Migrate the deprecated "custom_options" key (with a warning).
            if "custom_options" in model_config and \
                    model_config["custom_options"] != DEPRECATED_VALUE:
                deprecation_warning(
                    "model.custom_options",
                    "model.custom_model_config",
                    error=False)
                model_config["custom_model_config"] = \
                    model_config.pop("custom_options")

            # The custom model may be given directly as a class or as a
            # registered name.
            if isinstance(model_config["custom_model"], type):
                model_cls = model_config["custom_model"]
            else:
                model_cls = _global_registry.get(RLLIB_MODEL,
                                                 model_config["custom_model"])

            # TODO(sven): Hard-deprecate Model(V1).
            if issubclass(model_cls, ModelV2):
                logger.info("Wrapping {} as {}".format(model_cls,
                                                       model_interface))
                model_cls = ModelCatalog._wrap_if_needed(
                    model_cls, model_interface)

                # Same TF framework names as the default-model branch below,
                # so "tf2" custom models get the registration check as well.
                if framework in ["tf", "tfe", "tf2"]:
                    # Track and warn if vars were created but not registered.
                    # Variables are stored via `.ref()` so they can be put
                    # into sets.
                    created = set()

                    def track_var_creation(next_creator, **kw):
                        v = next_creator(**kw)
                        created.add(v.ref())
                        return v

                    with tf.variable_creator_scope(track_var_creation):
                        # Try calling with kwargs first (custom ModelV2 should
                        # accept these as kwargs, not get them from
                        # config["custom_model_config"] anymore).
                        try:
                            instance = model_cls(obs_space, action_space,
                                                 num_outputs, model_config,
                                                 name, **model_kwargs)
                        except TypeError as e:
                            # Other error -> re-raise with the original
                            # traceback. `str(e)` is used so that exceptions
                            # with empty or non-string args cannot break the
                            # check itself.
                            if "__init__() got an unexpected " not in str(e):
                                raise
                            # Keyword error: Try old way w/o kwargs.
                            logger.warning(
                                "Custom ModelV2 should accept all custom "
                                "options as **kwargs, instead of expecting"
                                " them in config['custom_model_config']!")
                            instance = model_cls(obs_space, action_space,
                                                 num_outputs, model_config,
                                                 name)
                    registered = {v.ref() for v in instance.variables()}
                    # Both sets already hold references; do not call `.ref()`
                    # on their elements again.
                    not_registered = created - registered
                    if not_registered:
                        raise ValueError(
                            "It looks like variables {} were created as part "
                            "of {} but does not appear in model.variables() "
                            "({}). Did you forget to call "
                            "model.register_variables() on the variables in "
                            "question?".format(not_registered, instance,
                                               registered))
                else:
                    # PyTorch automatically tracks nn.Modules inside the parent
                    # nn.Module's constructor.
                    # TODO(sven): Do this for TF as well.
                    instance = model_cls(obs_space, action_space, num_outputs,
                                         model_config, name, **model_kwargs)
                return instance
            # TODO(sven): Hard-deprecate Model(V1). This check will be
            #   superfluous then.
            elif tf.executing_eagerly():
                raise ValueError(
                    "Eager execution requires a TFModelV2 model to be "
                    "used, however you specified a custom model {}".format(
                        model_cls))

        if framework in ["tf", "tfe", "tf2"]:
            v2_class = None
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)

            # Only wrap when there is a V2 class to wrap. Without one (e.g. a
            # custom non-ModelV2 class was given), skip this so the fallback
            # below is reached instead of failing on `None.forward`.
            if v2_class is not None and model_config.get("use_lstm"):
                wrapped_cls = v2_class
                forward = wrapped_cls.forward
                v2_class = ModelCatalog._wrap_if_needed(
                    wrapped_cls, LSTMWrapper)
                # Keep the wrapped class' own forward reachable on the wrapper.
                v2_class._wrapped_forward = forward

            # fallback to a default v1 model
            if v2_class is None:
                if tf.executing_eagerly():
                    raise ValueError(
                        "Eager execution requires a TFModelV2 model to be "
                        "used, however there is no default V2 model for this "
                        "observation space: {}, use_lstm={}".format(
                            obs_space, model_config.get("use_lstm")))
                v2_class = make_v1_wrapper(ModelCatalog.get_model)
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        elif framework == "torch":
            v2_class = \
                default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)
            if model_config.get("use_lstm"):
                from ray.rllib.models.torch.recurrent_net import LSTMWrapper \
                    as TorchLSTMWrapper
                wrapped_cls = v2_class
                forward = wrapped_cls.forward
                v2_class = ModelCatalog._wrap_if_needed(
                    wrapped_cls, TorchLSTMWrapper)
                # Keep the wrapped class' own forward reachable on the wrapper.
                v2_class._wrapped_forward = forward
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        else:
            raise NotImplementedError(
                "`framework` must be 'tf|tfe|tf2|torch', but is "
                "{}!".format(framework))
Esempio n. 3
0
    def get_model_v2(obs_space,
                     action_space,
                     num_outputs,
                     model_config,
                     framework="tf",
                     name=None,
                     model_interface=None,
                     **model_kwargs):
        """Returns a suitable model compatible with given spaces and output.

        A registered custom ModelV2 class (``model_config["custom_model"]``)
        is instantiated directly. Otherwise, for framework "tf", the legacy
        model returned by `ModelCatalog.get_model` is wrapped into the V2 API
        and into ``model_interface``.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (dict): Model options. The only key read here is
                "custom_model".
            framework (str): Either "tf" or "torch".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model
            model_kwargs (dict): args to pass to the ModelV2 constructor

        Returns:
            model (ModelV2): Model to use for the policy.

        Raises:
            ValueError: If the custom ModelV2 class does not subclass
                `model_interface`, or if it created variables that are not
                part of `model.variables()`.
            NotImplementedError: If no custom ModelV2 was returned and
                `framework` is not "tf".
        """

        if model_config.get("custom_model"):
            model_cls = _global_registry.get(RLLIB_MODEL,
                                             model_config["custom_model"])
            # Custom classes that are not ModelV2 subclasses fall through to
            # the legacy path below.
            if issubclass(model_cls, ModelV2):
                if model_interface and not issubclass(model_cls,
                                                      model_interface):
                    # Format the interface into the message rather than
                    # passing it as a second exception argument.
                    raise ValueError(
                        "The given model must subclass {}".format(
                            model_interface))
                created = set()

                # Track and warn if variables were created but not registered
                def track_var_creation(next_creator, **kw):
                    v = next_creator(**kw)
                    created.add(v)
                    return v

                with tf.variable_creator_scope(track_var_creation):
                    instance = model_cls(obs_space, action_space, num_outputs,
                                         model_config, name, **model_kwargs)
                registered = set(instance.variables())
                not_registered = created - registered
                if not_registered:
                    raise ValueError(
                        "It looks like variables {} were created as part of "
                        "{} but does not appear in model.variables() ({}). "
                        "Did you forget to call model.register_variables() "
                        "on the variables in question?".format(
                            not_registered, instance, registered))
                return instance

        if framework == "tf":
            # Wrap the legacy (V1) model factory into the V2 API, then into
            # the requested interface.
            legacy_model_cls = ModelCatalog.get_model
            wrapper = ModelCatalog._wrap_if_needed(
                make_v1_wrapper(legacy_model_cls), model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)

        raise NotImplementedError("TODO: support {} models".format(framework))