def get_torch_model(obs_space,
                    num_outputs,
                    options=None,
                    default_model_cls=None):
    """Build a PyTorch model for the given observation space.

    The model is chosen in this order:
      1. The custom model registered under ``options["custom_model"]``.
      2. ``default_model_cls``, when one is supplied.
      3. The built-in vision network if the observation space has rank
         greater than 1, otherwise the built-in fully connected network.

    Args:
        obs_space (Space): The input observation space.
        num_outputs (int): The size of the output vector of the model.
        options (dict): Optional args to pass to the model constructor.
            ``MODEL_DEFAULTS`` is used when this is falsy.
        default_model_cls (cls): Optional class to use if no custom model.

    Returns:
        model (models.Model): Neural network model.

    Raises:
        NotImplementedError: If ``use_lstm`` is set and no custom model was
            requested; LSTM auto-wrapping is not supported for torch here.
    """
    # Imported inside the function so the torch model modules are only
    # loaded when this builder is actually called.
    from ray.rllib.models.torch.fcnet import (
        FullyConnectedNetwork as PyTorchFCNet)
    from ray.rllib.models.torch.visionnet import (
        VisionNetwork as PyTorchVisionNet)

    cfg = options or MODEL_DEFAULTS

    # A registered custom model takes precedence over everything else,
    # including the use_lstm check below.
    custom_name = cfg.get("custom_model")
    if custom_name:
        logger.debug("Using custom torch model {}".format(custom_name))
        model_ctor = _global_registry.get(RLLIB_MODEL, custom_name)
        return model_ctor(obs_space, num_outputs, cfg)

    if cfg.get("use_lstm"):
        raise NotImplementedError(
            "LSTM auto-wrapping not implemented for torch")

    if default_model_cls:
        return default_model_cls(obs_space, num_outputs, cfg)

    # Discrete spaces are treated as rank-1 inputs instead of using their
    # shape; every other space's rank is the length of its shape.
    if isinstance(obs_space, gym.spaces.Discrete):
        rank = 1
    else:
        rank = len(obs_space.shape)

    net_ctor = PyTorchVisionNet if rank > 1 else PyTorchFCNet
    return net_ctor(obs_space, num_outputs, cfg)
def _get_default_torch_model_v2(obs_space, action_space, num_outputs,
                                model_config, name):
    """Build one of the built-in PyTorch ModelV2 networks.

    Observation spaces of rank greater than 2 get the built-in vision
    network; all others (including Discrete spaces, treated as rank 1) get
    the built-in fully connected network.

    Args:
        obs_space (Space): The input observation space.
        action_space (Space): The action space, forwarded to the model.
        num_outputs (int): The size of the output vector of the model.
        model_config (dict): Model options; ``MODEL_DEFAULTS`` is used when
            this is falsy.
        name (str): Name forwarded to the model constructor.

    Returns:
        The constructed built-in torch model instance.

    Raises:
        NotImplementedError: If ``use_lstm`` is set, since LSTM
            auto-wrapping is not supported for torch here.
    """
    # Imported inside the function so the torch model modules are only
    # loaded when this builder is actually called.
    from ray.rllib.models.torch.fcnet import (
        FullyConnectedNetwork as PyTorchFCNet)
    from ray.rllib.models.torch.visionnet import (
        VisionNetwork as PyTorchVisionNet)

    cfg = model_config or MODEL_DEFAULTS

    if cfg.get("use_lstm"):
        raise NotImplementedError(
            "LSTM auto-wrapping not implemented for torch")

    # Discrete spaces are treated as rank 1 instead of using their shape.
    rank = (1 if isinstance(obs_space, gym.spaces.Discrete) else
            len(obs_space.shape))

    # NOTE(review): get_torch_model sends rank > 1 spaces to the vision net,
    # whereas this v2 path requires rank > 2 (presumably HxWxC image-like
    # inputs), so rank-2 spaces use the FC net here. Confirm intended.
    net_ctor = PyTorchVisionNet if rank > 2 else PyTorchFCNet
    return net_ctor(obs_space, action_space, num_outputs, cfg, name)