Ejemplo n.º 1
0
    def get_torch_model(obs_space,
                        num_outputs,
                        options=None,
                        default_model_cls=None):
        """Returns a custom model for PyTorch algorithms.

        Models are selected in this order:
            1. ``options["custom_model"]``: looked up in the global registry
               under ``RLLIB_MODEL`` and constructed.
            2. ``options["use_lstm"]``: not supported for torch, so this
               raises.
            3. ``default_model_cls``: constructed if given.
            4. Otherwise a built-in network chosen by observation rank. A
               vision network is used for rank > 2 observations. A fully
               connected network is used for everything else.

        Args:
            obs_space (Space): The input observation space.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                ``MODEL_DEFAULTS`` is used when this is None or empty.
            default_model_cls (cls): Optional class to use if no custom model.

        Returns:
            model (models.Model): Neural network model.

        Raises:
            NotImplementedError: If ``options["use_lstm"]`` is truthy.
        """
        from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
                                                  PyTorchFCNet)
        from ray.rllib.models.torch.visionnet import (VisionNetwork as
                                                      PyTorchVisionNet)

        options = options or MODEL_DEFAULTS

        if options.get("custom_model"):
            model = options["custom_model"]
            logger.debug("Using custom torch model %s", model)
            return _global_registry.get(RLLIB_MODEL,
                                        model)(obs_space, num_outputs, options)

        if options.get("use_lstm"):
            raise NotImplementedError(
                "LSTM auto-wrapping not implemented for torch")

        if default_model_cls:
            return default_model_cls(obs_space, num_outputs, options)

        # Discrete spaces are special-cased to rank 1 instead of using the
        # length of their shape (presumably empty for gym Discrete spaces).
        if isinstance(obs_space, gym.spaces.Discrete):
            obs_rank = 1
        else:
            obs_rank = len(obs_space.shape)

        # Only rank > 2 observations go to the vision network. This matches
        # _get_default_torch_model_v2. Rank-2 observations previously went
        # to the vision net here as well. They now use the fully connected
        # network.
        if obs_rank > 2:
            return PyTorchVisionNet(obs_space, num_outputs, options)

        return PyTorchFCNet(obs_space, num_outputs, options)
Ejemplo n.º 2
0
    def _get_default_torch_model_v2(obs_space, action_space, num_outputs,
                                    model_config, name):
        """Builds the built-in (non-custom) torch ModelV2 for an observation space.

        Args:
            obs_space (Space): The observation space the model consumes.
            action_space (Space): The action space, forwarded to the model.
            num_outputs (int): Size of the model's output vector.
            model_config (dict): Model options. ``MODEL_DEFAULTS`` is used
                when this is None or empty.
            name (str): Name passed to the constructed model.

        Returns:
            A ``VisionNetwork`` if the observation rank exceeds 2, otherwise a
            ``FullyConnectedNetwork`` (both from ``ray.rllib.models.torch``).

        Raises:
            NotImplementedError: If ``model_config["use_lstm"]`` is truthy.
        """
        # Function-local imports, presumably so torch is only loaded when a
        # torch model is actually requested.
        from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
        from ray.rllib.models.torch.visionnet import VisionNetwork

        config = model_config or MODEL_DEFAULTS
        if config.get("use_lstm"):
            raise NotImplementedError(
                "LSTM auto-wrapping not implemented for torch")

        # Discrete spaces are treated as rank 1 rather than measured by the
        # length of their shape.
        is_discrete = isinstance(obs_space, gym.spaces.Discrete)
        rank = 1 if is_discrete else len(obs_space.shape)

        net_cls = VisionNetwork if rank > 2 else FullyConnectedNetwork
        return net_cls(obs_space, action_space, num_outputs, config, name)