Example #1
0
    def get_model(registry, inputs, num_outputs, options=None):
        """Returns a suitable model conforming to given input and output specs.

        Selection order: a registered custom model named by
        ``options["custom_model"]``; otherwise a VisionNetwork when ``inputs``
        has more than two dimensions; otherwise a FullyConnectedNetwork.

        Args:
            registry (obj): Registry of named objects (ray.tune.registry).
            inputs (Tensor): The input tensor to the model.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                ``None`` (the default) is treated as a fresh empty dict.

        Returns:
            model (Model): Neural network model.
        """

        # Build a fresh dict per call. The old ``dict()`` default was created
        # once at definition time and shared by every call that omitted
        # ``options``, so any mutation made by a model constructor (which
        # receives this dict) would leak into later calls.
        if options is None:
            options = {}

        if "custom_model" in options:
            model = options["custom_model"]
            print("Using custom model {}".format(model))
            return registry.get(RLLIB_MODEL, model)(inputs, num_outputs,
                                                    options)

        # Rank of the input minus its first dimension (presumably the batch
        # axis -- confirm against callers).
        obs_rank = len(inputs.shape) - 1

        if obs_rank > 1:
            return VisionNetwork(inputs, num_outputs, options)

        return FullyConnectedNetwork(inputs, num_outputs, options)
Example #2
0
    def get_model(registry, inputs, num_outputs, options=None):
        """Returns a suitable model conforming to given input and output specs.

        Selection order: a registered custom model named by
        ``options["custom_model"]``; a MultiAgentFullyConnectedNetwork when
        ``options["custom_options"]["multiagent_fcnet_hiddens"]`` is a list
        and ``num_outputs > 1``; a VisionNetwork when ``inputs`` has more
        than two dimensions; otherwise a FullyConnectedNetwork.

        Args:
            registry (obj): Registry of named objects (ray.tune.registry).
            inputs (Tensor): The input tensor to the model.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                ``None`` (the default) is treated as a fresh empty dict.

        Returns:
            model (Model): Neural network model.
        """

        # Build a fresh dict per call. The old ``{}`` default was created once
        # at definition time and shared by every call that omitted
        # ``options``, so any mutation made by a model constructor (which
        # receives this dict) would leak into later calls.
        if options is None:
            options = {}

        if "custom_model" in options:
            model = options["custom_model"]
            print("Using custom model {}".format(model))
            return registry.get(RLLIB_MODEL, model)(inputs, num_outputs,
                                                    options)

        # Rank of the input minus its first dimension (presumably the batch
        # axis -- confirm against callers).
        obs_rank = len(inputs.shape) - 1

        # The fallback value 1 is just a non-list sentinel meaning "not set".
        custom_options = options.get("custom_options", {})
        multiagent_hiddens = custom_options.get("multiagent_fcnet_hiddens", 1)

        # num_outputs > 1 used to avoid hitting this with the value function
        if isinstance(multiagent_hiddens, list) and num_outputs > 1:
            return MultiAgentFullyConnectedNetwork(inputs, num_outputs,
                                                   options)

        if obs_rank > 1:
            return VisionNetwork(inputs, num_outputs, options)

        return FullyConnectedNetwork(inputs, num_outputs, options)
Example #3
0
    def _get_model(inputs, num_outputs, options, state_in, seq_lens):
        """Instantiate the model described by ``options`` for ``inputs``.

        If ``options`` contains a ``"custom_model"`` key, the named class is
        fetched from the global registry and constructed with ``state_in`` and
        ``seq_lens`` forwarded as keyword arguments. Otherwise a built-in
        network is chosen from the number of dimensions of ``inputs``.

        Returns:
            The constructed model object.
        """
        if "custom_model" not in options:
            # More than two dimensions -> VisionNetwork, else fully connected.
            rank = len(inputs.shape) - 1
            network_cls = VisionNetwork if rank > 1 else FullyConnectedNetwork
            return network_cls(inputs, num_outputs, options)

        model_name = options["custom_model"]
        print("Using custom model {}".format(model_name))
        model_cls = _global_registry.get(RLLIB_MODEL, model_name)
        return model_cls(inputs, num_outputs, options,
                         state_in=state_in, seq_lens=seq_lens)
Example #4
0
    def get_model(inputs, num_outputs):
        """Build a network whose architecture matches the rank of ``inputs``.

        Args:
            inputs (Tensor): Tensor fed into the model.
            num_outputs (int): Width of the model's output vector.

        Returns:
            model (Model): A VisionNetwork when ``inputs`` has more than two
                dimensions, otherwise a FullyConnectedNetwork.
        """
        sample_rank = len(inputs.get_shape()) - 1
        if sample_rank <= 1:
            return FullyConnectedNetwork(inputs, num_outputs)
        return VisionNetwork(inputs, num_outputs)
Example #5
0
    def _get_model(input_dict, obs_space, num_outputs, options, state_in,
                   seq_lens):
        """Instantiate a custom or built-in model for ``input_dict``.

        A truthy ``options["custom_model"]`` selects a class from the global
        registry, which receives ``state_in``/``seq_lens`` as keyword
        arguments. Otherwise the rank of ``input_dict["obs"]`` decides between
        VisionNetwork (more than two dimensions) and FullyConnectedNetwork.

        Returns:
            The constructed model object.
        """
        custom_name = options.get("custom_model")
        if not custom_name:
            rank = len(input_dict["obs"].shape) - 1
            builder = VisionNetwork if rank > 1 else FullyConnectedNetwork
            return builder(input_dict, obs_space, num_outputs, options)

        logger.debug("Using custom model {}".format(custom_name))
        builder = _global_registry.get(RLLIB_MODEL, custom_name)
        return builder(input_dict, obs_space, num_outputs, options,
                       state_in=state_in, seq_lens=seq_lens)
    def _get_model(inputs, num_outputs, options):
        """Instantiate the model described by ``options`` for ``inputs``.

        Precedence: a registry-provided custom model, then the multi-agent
        fully connected network, then a vision or plain fully connected
        network depending on the number of dimensions of ``inputs``.

        Returns:
            The constructed model object.
        """
        if "custom_model" in options:
            name = options["custom_model"]
            print("Using custom model {}".format(name))
            constructor = _global_registry.get(RLLIB_MODEL, name)
            return constructor(inputs, num_outputs, options)

        rank = len(inputs.shape) - 1

        hiddens = options.get("custom_options", {}).get(
            "multiagent_fcnet_hiddens", 1)
        # The num_outputs > 1 guard keeps the value-function head off the
        # multi-agent path.
        wants_multiagent = isinstance(hiddens, list) and num_outputs > 1

        if wants_multiagent:
            network_cls = MultiAgentFullyConnectedNetwork
        elif rank > 1:
            network_cls = VisionNetwork
        else:
            network_cls = FullyConnectedNetwork
        return network_cls(inputs, num_outputs, options)
Example #7
0
    def get_model(inputs, num_outputs, options=None):
        """Pick and build a network suited to the dimensionality of ``inputs``.

        Args:
            inputs (Tensor): Tensor fed into the model.
            num_outputs (int): Width of the model's output vector.
            options (dict): Extra arguments for the model constructor;
                ``None`` is replaced by an empty dict.

        Returns:
            model (Model): A VisionNetwork when ``inputs`` has more than two
                dimensions, otherwise a FullyConnectedNetwork.
        """
        opts = {} if options is None else options

        ndims = len(inputs.get_shape())
        # ndims > 2 is the same test as "rank excluding the first axis > 1".
        if ndims > 2:
            model_cls = VisionNetwork
        else:
            model_cls = FullyConnectedNetwork
        return model_cls(inputs, num_outputs, opts)