Esempio n. 1
0
 def __init__(self,
              action_space: Space,
              *,
              framework: str,
              policy_config: dict,
              model: ModelV2,
              num_workers: int,
              worker_index: int):
     """Stores the exploration settings and looks up the model's device.

     Args:
         action_space (Space): The action space in which to explore.
         framework (str): One of "tf" or "torch". The value is passed
             through `check_framework` before being stored.
         policy_config (dict): The Policy's config dict.
         model (ModelV2): The Policy's model.
         num_workers (int): The overall number of workers used.
         worker_index (int): The index of the worker using this class.
     """
     # Worker bookkeeping.
     self.num_workers = num_workers
     self.worker_index = worker_index
     # References handed over by the Policy.
     self.action_space = action_space
     self.policy_config = policy_config
     self.model = model
     self.framework = check_framework(framework)

     # This Exploration lives on the same device as the Model: the device
     # of the model's first parameter, or None if the model is not an
     # nn.Module or has no parameters at all.
     model_params = []
     if isinstance(self.model, nn.Module):
         model_params = list(self.model.parameters())
     self.device = model_params[0].device if model_params else None
Esempio n. 2
0
 def __init__(self, action_space: Space, num_workers: int,
              worker_index: int, framework: str = "tf"):
     """Records the action space, the worker layout and the framework.

     Args:
         action_space (Space): The action space in which to explore.
         num_workers (int): The overall number of workers used.
         worker_index (int): The index of the worker using this class.
         framework (str): One of "tf" or "torch". The value is passed
             through `check_framework` before being stored.
     """
     # Worker bookkeeping first, then the space to explore in.
     self.num_workers = num_workers
     self.worker_index = worker_index
     self.action_space = action_space
     # Store whatever `check_framework` hands back for the given value.
     checked_framework = check_framework(framework)
     self.framework = checked_framework
Esempio n. 3
0
 def __init__(self, action_space=None, num_workers=None,
              worker_index=None, framework="tf"):
     """Records the (optional) action space, worker layout and framework.

     Args:
         action_space (Optional[gym.spaces.Space]): The action space in
             which to explore.
         num_workers (Optional[int]): The overall number of workers used.
         worker_index (Optional[int]): The index of the Worker using this
             Exploration.
         framework (str): One of "tf" or "torch". The value is passed
             through `check_framework` before being stored.
     """
     # Worker bookkeeping first, then the space to explore in.
     self.num_workers = num_workers
     self.worker_index = worker_index
     self.action_space = action_space
     # Store whatever `check_framework` hands back for the given value.
     checked_framework = check_framework(framework)
     self.framework = checked_framework
Esempio n. 4
0
 def __init__(self,
              action_space: Space,
              *,
              framework: str,
              num_workers: int,
              worker_index: int,
              policy_config: dict,
              model: ModelV2):
     """Stores the exploration settings handed over by the Policy.

     Args:
         action_space (Space): The action space in which to explore.
         framework (str): One of "tf" or "torch". The value is passed
             through `check_framework` before being stored.
         num_workers (int): The overall number of workers used.
         worker_index (int): The index of the worker using this class.
         policy_config (dict): The Policy's config dict.
         model (ModelV2): The Policy's model.
     """
     # Worker bookkeeping.
     self.num_workers = num_workers
     self.worker_index = worker_index
     # References handed over by the Policy.
     self.action_space = action_space
     self.policy_config = policy_config
     self.model = model
     # Store whatever `check_framework` hands back for the given value.
     checked_framework = check_framework(framework)
     self.framework = checked_framework
Esempio n. 5
0
    def _get_v2_model_class(obs_space, model_config, framework="tf"):
        """Picks the default ModelV2 class for the given observation space.

        Args:
            obs_space (Space): The observation space to pick a model for.
            model_config (dict): The model config dict (not consulted by
                this function).
            framework (str): Framework identifier, passed through
                `check_framework`. "torch" selects the torch model classes,
                any other accepted value selects the tf ones.

        Returns:
            type: The fully connected network class if `obs_space` is
                Discrete or its shape has at most two dimensions, otherwise
                the vision network class.
        """
        framework = check_framework(framework)

        # Only the model modules of the selected framework are imported.
        if framework != "torch":
            from ray.rllib.models.tf.fcnet import (FullyConnectedNetwork as
                                                   FCNet)
            from ray.rllib.models.tf.visionnet import (VisionNetwork as
                                                       VisionNet)
        else:
            from ray.rllib.models.torch.fcnet import \
                FullyConnectedNetwork as FCNet
            from ray.rllib.models.torch.visionnet import \
                VisionNetwork as VisionNet

        # Discrete/low-dimensional obs-spaces get the fully connected net,
        # everything else the default Conv2D net.
        use_fcnet = isinstance(obs_space, gym.spaces.Discrete) or \
            len(obs_space.shape) <= 2
        return FCNet if use_fcnet else VisionNet
Esempio n. 6
0
    def get_model_v2(obs_space,
                     action_space,
                     num_outputs,
                     model_config,
                     framework="tf",
                     name="default_model",
                     model_interface=None,
                     default_model=None,
                     **model_kwargs):
        """Returns a suitable model compatible with given spaces and output.

        If `model_config["custom_model"]` is set and resolves to a ModelV2
        subclass, that class is instantiated and returned. Otherwise a
        default model class is chosen per framework, wrapped in
        `model_interface` (via `ModelCatalog._wrap_if_needed`) and
        instantiated.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (dict): The model config dict. Note that a
                deprecated "custom_options" entry is moved to
                "custom_model_config" in place, i.e. the passed-in dict is
                mutated.
            framework (str): One of "tf", "tfe", or "torch".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model
            default_model (cls): Override the default class for the model. This
                only has an effect when not using a custom model
            model_kwargs (dict): args to pass to the ModelV2 constructor

        Returns:
            model (ModelV2): Model to use for the policy.

        Raises:
            ValueError: If a custom tf model created variables that do not
                show up in its `variables()`, or if TF executes eagerly and
                no ModelV2 class is available.
            NotImplementedError: If `framework` (after `check_framework`)
                is none of "tf", "tfe" or "torch".
        """

        # Make sure, framework is ok.
        framework = check_framework(framework)

        if model_config.get("custom_model"):

            # Migrate the deprecated `custom_options` key to its new name
            # (only warns, does not raise).
            if "custom_options" in model_config and \
                    model_config["custom_options"] != DEPRECATED_VALUE:
                deprecation_warning("model.custom_options",
                                    "model.custom_model_config",
                                    error=False)
                model_config["custom_model_config"] = \
                    model_config.pop("custom_options")

            # `custom_model` is either the class itself or a key under which
            # the class is looked up in the global registry.
            if isinstance(model_config["custom_model"], type):
                model_cls = model_config["custom_model"]
            else:
                model_cls = _global_registry.get(RLLIB_MODEL,
                                                 model_config["custom_model"])
            # TODO(sven): Hard-deprecate Model(V1).
            if issubclass(model_cls, ModelV2):
                logger.info("Wrapping {} as {}".format(model_cls,
                                                       model_interface))
                model_cls = ModelCatalog._wrap_if_needed(
                    model_cls, model_interface)

                if framework in ["tf", "tfe"]:
                    # Track and warn if vars were created but not registered.
                    created = set()

                    # Variable creator hook: delegates to the next creator
                    # and records every variable that gets created.
                    def track_var_creation(next_creator, **kw):
                        v = next_creator(**kw)
                        created.add(v)
                        return v

                    with tf.variable_creator_scope(track_var_creation):
                        # Try calling with kwargs first (custom ModelV2 should
                        # accept these as kwargs, not get them from
                        # config["custom_model_config"] anymore).
                        try:
                            instance = model_cls(obs_space, action_space,
                                                 num_outputs, model_config,
                                                 name, **model_kwargs)
                        except TypeError as e:
                            # Keyword error: Try old way w/o kwargs.
                            # NOTE(review): detection relies on the exact
                            # wording of the TypeError message in e.args[0].
                            if "__init__() got an unexpected " in e.args[0]:
                                logger.warning(
                                    "Custom ModelV2 should accept all custom "
                                    "options as **kwargs, instead of expecting"
                                    " them in config['custom_model_config']!")
                                instance = model_cls(obs_space, action_space,
                                                     num_outputs, model_config,
                                                     name)
                            # Other error -> re-raise.
                            else:
                                raise e
                    # Every variable created during construction must also be
                    # reported by the model's `variables()`.
                    registered = set(instance.variables())
                    not_registered = set()
                    for var in created:
                        if var not in registered:
                            not_registered.add(var)
                    if not_registered:
                        raise ValueError(
                            "It looks like variables {} were created as part "
                            "of {} but does not appear in model.variables() "
                            "({}). Did you forget to call "
                            "model.register_variables() on the variables in "
                            "question?".format(not_registered, instance,
                                               registered))
                else:
                    # PyTorch automatically tracks nn.Modules inside the parent
                    # nn.Module's constructor.
                    # TODO(sven): Do this for TF as well.
                    instance = model_cls(obs_space, action_space, num_outputs,
                                         model_config, name, **model_kwargs)
                return instance
            # TODO(sven): Hard-deprecate Model(V1). This check will be
            #   superfluous then.
            # NOTE(review): `tf` is dereferenced here regardless of
            # `framework` - confirm it cannot be None on this path.
            elif tf.executing_eagerly():
                raise ValueError(
                    "Eager execution requires a TFModelV2 model to be "
                    "used, however you specified a custom model {}".format(
                        model_cls))

        # From here on: either no custom model was given, or the custom model
        # is not a ModelV2 subclass.
        if framework in ["tf", "tfe"]:
            v2_class = None
            # try to get a default v2 model
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)
            # fallback to a default v1 model
            if v2_class is None:
                if tf.executing_eagerly():
                    raise ValueError(
                        "Eager execution requires a TFModelV2 model to be "
                        "used, however there is no default V2 model for this "
                        "observation space: {}, use_lstm={}".format(
                            obs_space, model_config.get("use_lstm")))
                v2_class = make_v1_wrapper(ModelCatalog.get_model)
            # wrap in the requested interface
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        elif framework == "torch":
            # `default_model` takes precedence over the catalog's default.
            v2_class = \
                default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        else:
            raise NotImplementedError(
                "Framework must be 'tf' or 'torch': {}".format(framework))
Esempio n. 7
0
    def get_action_dist(action_space,
                        config,
                        dist_type=None,
                        framework="tf",
                        **kwargs):
        """Returns a distribution class and size for the given action space.

        The distribution is chosen in this order: a custom distribution
        named in `config["custom_action_dist"]`, a distribution class passed
        directly as `dist_type`, and finally a default based on the type of
        `action_space`.

        Args:
            action_space (Space): Action space of the target gym env.
            config (Optional[dict]): Optional model config. Falls back to
                MODEL_DEFAULTS if empty or None.
            dist_type (Optional[str]): Identifier of the action distribution.
                May also be an ActionDistribution subclass. For Box spaces,
                only None and "deterministic" are supported as identifiers.
            framework (str): One of "tf", "tfe", or "torch".
            kwargs (dict): Accepted for interface compatibility; not used
                by this function.

        Returns:
            dist_class (ActionDistribution): Python class of the distribution
                (a `functools.partial` around the class for Tuple/Dict and
                MultiDiscrete spaces).
            dist_dim (int): The size of the input vector to the distribution.

        Raises:
            UnsupportedSpaceException: If a Box action space has more than
                one dimension.
            NotImplementedError: If no distribution is available for the
                given combination of action space, `dist_type` and
                framework.
        """

        # Make sure, framework is ok.
        framework = check_framework(framework)

        dist = None
        config = config or MODEL_DEFAULTS
        # Custom distribution given.
        if config.get("custom_action_dist"):
            action_dist_name = config["custom_action_dist"]
            logger.debug(
                "Using custom action distribution {}".format(action_dist_name))
            dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
        # Dist_type is given directly as a class.
        elif type(dist_type) is type and \
                issubclass(dist_type, ActionDistribution) and \
                dist_type not in (
                MultiActionDistribution, TorchMultiActionDistribution):
            dist = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, gym.spaces.Box):
            if len(action_space.shape) > 1:
                raise UnsupportedSpaceException(
                    "Action space has multiple dimensions "
                    "{}. ".format(action_space.shape) +
                    "Consider reshaping this into a single dimension, "
                    "using a custom action distribution, "
                    "using a Tuple action space, or the multi-agent API.")
            # TODO(sven): Check for bounds and return SquashedNormal, etc..
            if dist_type is None:
                dist = TorchDiagGaussian if framework == "torch" \
                    else DiagGaussian
            elif dist_type == "deterministic":
                dist = TorchDeterministic if framework == "torch" \
                    else Deterministic
            else:
                # Without this branch, `dist` would stay None and the final
                # return would fail with an AttributeError on None. Raise
                # the same error as for unknown spaces instead.
                raise NotImplementedError("Unsupported args: {} {}".format(
                    action_space, dist_type))
        # Discrete Space -> Categorical.
        elif isinstance(action_space, gym.spaces.Discrete):
            dist = TorchCategorical if framework == "torch" else Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif dist_type in (MultiActionDistribution,
                           TorchMultiActionDistribution) or \
                isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
            # Resolve one child distribution (and its input size) per
            # flattened sub-space; `dist_type` is not forwarded to children.
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(
                    s, config, framework=framework), flat_action_space)
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return partial((TorchMultiActionDistribution if framework
                            == "torch" else MultiActionDistribution),
                           action_space=action_space,
                           child_distributions=child_dists,
                           input_lens=input_lens), int(sum(input_lens))
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch.")
            dist = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            dist = TorchMultiCategorical if framework == "torch" else \
                MultiCategorical
            return partial(dist, input_lens=action_space.nvec), \
                int(sum(action_space.nvec))
        # Unknown type -> Error.
        else:
            raise NotImplementedError("Unsupported args: {} {}".format(
                action_space, dist_type))

        return dist, dist.required_model_output_shape(action_space, config)
Esempio n. 8
0
    def _setup(self, config):
        """Builds the env creator, the final config and the workers.

        In order, this method: resolves `self._env_id` into
        `self.env_creator`, merges `config` into the class default config,
        translates the deprecated `use_pytorch`/`eager` keys into
        `framework`, optionally wraps created envs for action normalization,
        validates the config and the callbacks, sets the "ray.rllib" log
        level, and finally calls `self._init` (plus the creation of
        evaluation workers if `evaluation_interval` is set).

        Args:
            config (dict): The user-provided config. It is kept as
                `self.raw_user_config`; its "env" key is overwritten in
                place when `self._env_id` is set.

        Raises:
            ValueError: If `config["callbacks"]` is not callable.
        """
        env = self._env_id
        if env:
            config["env"] = env
            # An already registered env.
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            # A class specifier.
            elif "." in env:
                self.env_creator = \
                    lambda env_config: from_config(env, env_config)
            # Try gym.
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            # No env given: the creator yields None.
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default, but store the
        # user-provided one.
        self.raw_user_config = config
        self.config = Trainer.merge_trainer_configs(self._default_config,
                                                    config)

        # Check and resolve DL framework settings.
        # Both deprecated keys only warn; a truthy value overrides
        # `framework`, and the key is removed from the config either way.
        if "use_pytorch" in self.config and \
                self.config["use_pytorch"] != DEPRECATED_VALUE:
            deprecation_warning("use_pytorch", "framework=torch", error=False)
            if self.config["use_pytorch"]:
                self.config["framework"] = "torch"
            self.config.pop("use_pytorch")
        if "eager" in self.config and self.config["eager"] != DEPRECATED_VALUE:
            deprecation_warning("eager", "framework=tfe", error=False)
            if self.config["eager"]:
                self.config["framework"] = "tfe"
            self.config.pop("eager")

        # Check all dependencies and resolve "auto" framework.
        self.config["framework"] = check_framework(self.config["framework"])
        # Notify about eager/tracing support.
        if tf and self.config["framework"] == "tfe":
            if not tf.executing_eagerly():
                tf.enable_eager_execution()
            logger.info("Executing eagerly, with eager_tracing={}".format(
                self.config["eager_tracing"]))
        if tf and not tf.executing_eagerly() and \
                self.config["framework"] != "torch":
            logger.info("Tip: set framework=tfe or the --eager flag to enable "
                        "TensorFlow eager execution")

        if self.config["normalize_actions"]:
            # Keep a handle on the unwrapped creator for the closure below.
            inner = self.env_creator

            def normalize(env):
                """Wraps `env` in a NormalizeActionWrapper.

                Raises:
                    ValueError: If `env` is not a gym.Env instance.
                """
                import gym  # soft dependency
                if not isinstance(env, gym.Env):
                    # The message must be formatted explicitly; passing
                    # `type(env)` as a second exception argument would
                    # leave the placeholder unfilled.
                    raise ValueError(
                        "Cannot apply NormalizeActionWrapper to env of "
                        "type {}, which does not subclass gym.Env.".format(
                            type(env)))
                return NormalizeActionWrapper(env)

            self.env_creator = lambda env_config: normalize(inner(env_config))

        Trainer._validate_config(self.config)
        if not callable(self.config["callbacks"]):
            raise ValueError(
                "`callbacks` must be a callable method that "
                "returns a subclass of DefaultCallbacks, got {}".format(
                    self.config["callbacks"]))
        self.callbacks = self.config["callbacks"]()
        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info("Current log_level is {}. For more information, "
                        "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                        "-vv flags.".format(log_level))
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        def get_scope():
            """Returns the context manager to run `_init` in.

            A fresh default graph for non-eager TF; otherwise a throwaway
            file handle on os.devnull that only serves as a context
            manager and is closed on exit.
            """
            if tf and not tf.executing_eagerly():
                return tf.Graph().as_default()
            else:
                return open(os.devnull)  # fake a no-op scope

        with get_scope():
            self._init(self.config, self.env_creator)

            # Evaluation setup.
            if self.config.get("evaluation_interval"):
                # Update env_config with evaluation settings:
                extra_config = copy.deepcopy(self.config["evaluation_config"])
                # Assert that user has not unset "in_evaluation".
                assert "in_evaluation" not in extra_config or \
                    extra_config["in_evaluation"] is True
                # These settings always override the user's evaluation
                # config.
                extra_config.update({
                    "batch_mode": "complete_episodes",
                    "rollout_fragment_length": 1,
                    "in_evaluation": True,
                })
                logger.debug(
                    "using evaluation_config: {}".format(extra_config))

                self.evaluation_workers = self._make_workers(
                    self.env_creator,
                    self._policy,
                    merge_dicts(self.config, extra_config),
                    num_workers=self.config["evaluation_num_workers"])
                self.evaluation_metrics = {}
Esempio n. 9
0
 def __init__(self, framework=None):
     """Stores the framework after passing it through `check_framework`.

     Args:
         framework (Optional[str]): The framework identifier to check and
             store. Defaults to None.
     """
     checked_framework = check_framework(framework)
     self.framework = checked_framework
Esempio n. 10
0
 def __init__(self, framework=None):
     """Stores the framework after passing it through `check_framework`.

     Args:
         framework (Optional[str]): The framework identifier to check and
             store. Defaults to None.
     """
     # TODO(sven): replace with .tf_value() / torch_value() methods that
     # can be applied late binding, so no need to set framework during
     # construction.
     checked_framework = check_framework(framework)
     self.framework = checked_framework