Ejemplo n.º 1
0
    def get_torch_model(input_shape, num_outputs, options=None):
        """Build a PyTorch model for the given input and output sizes.

        PyTorch models are currently only supported in A3C.

        A registered custom model (``options["custom_model"]``) takes
        precedence. Otherwise a vision network is built when
        ``input_shape`` has more than two dimensions, and a fully connected
        network sized by ``input_shape[0]`` is built for everything else.

        Args:
            input_shape (tuple): The input shape to the model.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                ``MODEL_DEFAULTS`` is used when this is None or empty.

        Returns:
            model (models.Model): Neural network model.
        """
        from ray.rllib.models.pytorch.fcnet import (FullyConnectedNetwork as
                                                    PyTorchFCNet)
        from ray.rllib.models.pytorch.visionnet import (VisionNetwork as
                                                        PyTorchVisionNet)

        model_options = options or MODEL_DEFAULTS

        custom_name = model_options.get("custom_model")
        if custom_name:
            logger.info("Using custom torch model {}".format(custom_name))
            custom_cls = _global_registry.get(RLLIB_MODEL, custom_name)
            return custom_cls(input_shape, num_outputs, model_options)

        # TODO(alok): fix to handle Discrete(n) state spaces
        # "Rank" here excludes the leading dimension, so only inputs with
        # three or more dimensions are routed to the vision network.
        if len(input_shape) - 1 > 1:
            return PyTorchVisionNet(input_shape, num_outputs, model_options)

        # TODO(alok): overhaul PyTorchFCNet so it can just
        # take input shape directly
        return PyTorchFCNet(input_shape[0], num_outputs, model_options)
Ejemplo n.º 2
0
    def get_torch_model(registry, input_shape, num_outputs, options=None):
        """Returns a PyTorch suitable model. This is currently only supported
        in A3C.

        A custom model named by ``options["custom_model"]`` is looked up in
        ``registry`` and takes precedence. Otherwise a vision network is
        built when ``input_shape`` has more than two dimensions, and a fully
        connected network sized by ``input_shape[0]`` is built for everything
        else.

        Args:
            registry (obj): Registry of named objects (ray.tune.registry).
            input_shape (tuple): The input shape to the model.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                Defaults to a new empty dict on each call.

        Returns:
            model (Model): Neural network model.
        """
        from ray.rllib.models.pytorch.fcnet import (FullyConnectedNetwork as
                                                    PyTorchFCNet)
        from ray.rllib.models.pytorch.visionnet import (VisionNetwork as
                                                        PyTorchVisionNet)

        # A ``dict()`` default is created once and shared by every call.
        # ``options`` is handed to model constructors, so any mutation there
        # would leak into later calls. Build a fresh dict instead.
        if options is None:
            options = {}

        # Check truthiness rather than key presence. An options dict carrying
        # ``"custom_model": None`` should fall through to the built-in models
        # instead of asking the registry for a model named ``None``.
        if options.get("custom_model"):
            model = options["custom_model"]
            print("Using custom torch model {}".format(model))
            return registry.get(RLLIB_MODEL, model)(input_shape, num_outputs,
                                                    options)

        # "Rank" excludes the leading dimension of input_shape.
        obs_rank = len(input_shape) - 1

        if obs_rank > 1:
            return PyTorchVisionNet(input_shape, num_outputs, options)

        return PyTorchFCNet(input_shape[0], num_outputs, options)
Ejemplo n.º 3
0
    def get_torch_model(obs_space,
                        num_outputs,
                        options=None,
                        default_model_cls=None):
        """Build a model for PyTorch algorithms.

        Selection order:
            1. a custom model registered under ``options["custom_model"]``;
            2. ``default_model_cls``, if one was given;
            3. a built-in vision network for observation spaces of rank > 1,
               otherwise a built-in fully connected network. ``Discrete``
               spaces count as rank 1.

        Args:
            obs_space (Space): The input observation space.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                ``MODEL_DEFAULTS`` is used when this is None or empty.
            default_model_cls (cls): Optional class to use if no custom model.

        Returns:
            model (models.Model): Neural network model.

        Raises:
            NotImplementedError: If ``options["use_lstm"]`` is set and no
                custom model was requested.
        """
        from ray.rllib.models.pytorch.fcnet import (FullyConnectedNetwork as
                                                    PyTorchFCNet)
        from ray.rllib.models.pytorch.visionnet import (VisionNetwork as
                                                        PyTorchVisionNet)

        cfg = options or MODEL_DEFAULTS

        custom_name = cfg.get("custom_model")
        if custom_name:
            logger.debug("Using custom torch model {}".format(custom_name))
            custom_cls = _global_registry.get(RLLIB_MODEL, custom_name)
            return custom_cls(obs_space, num_outputs, cfg)

        # Custom models may handle recurrence themselves; the built-in path
        # cannot wrap with an LSTM.
        if cfg.get("use_lstm"):
            raise NotImplementedError(
                "LSTM auto-wrapping not implemented for torch")

        if default_model_cls:
            return default_model_cls(obs_space, num_outputs, cfg)

        # Discrete spaces are treated as flat (rank-1) inputs.
        obs_rank = (1 if isinstance(obs_space, gym.spaces.Discrete) else
                    len(obs_space.shape))
        net_cls = PyTorchVisionNet if obs_rank > 1 else PyTorchFCNet
        return net_cls(obs_space, num_outputs, cfg)
Ejemplo n.º 4
0
    def get_torch_model(input_shape, num_outputs, options=None):
        """Returns a PyTorch suitable model.

        A vision network is built when ``input_shape`` has more than two
        dimensions; otherwise a fully connected network sized by
        ``input_shape[0]`` is built.

        Args:
            input_shape (tuple): The input shape to the model.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                Defaults to a new empty dict on each call.

        Returns:
            model (Model): Neural network model.
        """
        from ray.rllib.models.pytorch.fcnet import (FullyConnectedNetwork as
                                                    PyTorchFCNet)
        from ray.rllib.models.pytorch.visionnet import (VisionNetwork as
                                                        PyTorchVisionNet)

        # A ``dict()`` default is evaluated once and shared by every call.
        # ``options`` is forwarded to the model constructors, so use a fresh
        # dict per call to keep any mutation there from leaking between calls.
        if options is None:
            options = {}

        # "Rank" excludes the leading dimension of input_shape.
        obs_rank = len(input_shape) - 1

        if obs_rank > 1:
            return PyTorchVisionNet(input_shape, num_outputs, options)

        return PyTorchFCNet(input_shape[0], num_outputs, options)