def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Initialize a TorchModelV2.

    Here is an example implementation for a subclass
    ``MyModelClass(TorchModelV2, nn.Module)``::

        def __init__(self, *args, **kwargs):
            TorchModelV2.__init__(self, *args, **kwargs)
            nn.Module.__init__(self)
            self._hidden_layers = nn.Sequential(...)
            self._logits = ...
            self._value_branch = ...

    Args:
        obs_space: Forwarded unchanged to ``ModelV2.__init__``.
        action_space: Forwarded unchanged to ``ModelV2.__init__``.
        num_outputs: Forwarded unchanged to ``ModelV2.__init__``.
        model_config: Forwarded unchanged to ``ModelV2.__init__``.
        name: Forwarded unchanged to ``ModelV2.__init__``.

    Raises:
        ValueError: If ``self`` is not an instance of ``nn.Module``.
    """
    # Validate the class hierarchy first so that a mis-declared subclass
    # fails before any ModelV2 initialization runs.
    if not isinstance(self, nn.Module):
        # The hint must name the real base class (TorchModelV2), otherwise
        # users copying it from the message get a NameError.
        raise ValueError(
            "Subclasses of TorchModelV2 must also inherit from "
            "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")

    ModelV2.__init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        framework="torch")
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Initialize a TFModelV2.

    Here is an example implementation for a subclass
    ``MyModelClass(TFModelV2)``::

        def __init__(self, *args, **kwargs):
            super(MyModelClass, self).__init__(*args, **kwargs)
            input_layer = tf.keras.layers.Input(...)
            hidden_layer = tf.keras.layers.Dense(...)(input_layer)
            output_layer = tf.keras.layers.Dense(...)(hidden_layer)
            value_layer = tf.keras.layers.Dense(...)(hidden_layer)
            self.base_model = tf.keras.Model(
                input_layer, [output_layer, value_layer])
            self.register_variables(self.base_model.variables)

    All five arguments are forwarded unchanged to ``ModelV2.__init__``
    together with ``framework="tf"``.
    """
    ModelV2.__init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        framework="tf")

    # Starts out empty.
    self.var_list = []

    # No graph is recorded when executing eagerly; otherwise remember the
    # graph that is the default at construction time.
    self.graph = None if tf.executing_eagerly() else tf.get_default_graph()
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Initialize the model by delegating to ``ModelV2.__init__``.

    Every argument is passed through unchanged; the only thing added here
    is ``framework="tf"``.
    """
    ModelV2.__init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        framework="tf")
def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str):
    """Initializes a JAXModelV2 instance.

    Every argument is passed through unchanged to ``ModelV2.__init__``;
    the only thing added here is ``framework="jax"``.
    """
    ModelV2.__init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        framework="jax")
def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
):
    """Initialize a TorchModelV2.

    Here is an example implementation for a subclass
    ``MyModelClass(TorchModelV2, nn.Module)``::

        def __init__(self, *args, **kwargs):
            TorchModelV2.__init__(self, *args, **kwargs)
            nn.Module.__init__(self)
            self._hidden_layers = nn.Sequential(...)
            self._logits = ...
            self._value_branch = ...

    Raises:
        ValueError: If ``self`` is not an instance of ``nn.Module``.
    """
    # Reject subclasses that forgot to mix in nn.Module before doing any
    # other initialization work.
    is_torch_module = isinstance(self, nn.Module)
    if not is_torch_module:
        message = (
            "Subclasses of TorchModelV2 must also inherit from "
            "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")
        raise ValueError(message)

    ModelV2.__init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        framework="torch")

    # Per-tower stats storage for multi-GPU setups.
    # With PyTorch multi-GPU, a single TorchPolicy is used and its Model(s)
    # are copied n times (one copy per GPU). Stats computed with each
    # tower's loss (e.g. `entropy`) cannot live on the policy object,
    # because the towers would all touch the same property at the same
    # time and race with each other.
    self.tower_stats = {}
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Initialize a TorchModelV2.

    Here is an example implementation for a subclass
    ``MyModelClass(TorchModelV2, nn.Module)``::

        def __init__(self, *args, **kwargs):
            TorchModelV2.__init__(self, *args, **kwargs)
            nn.Module.__init__(self)
            self._hidden_layers = nn.Sequential(...)
            self._logits = ...
            self._value_branch = ...

    All five arguments are forwarded unchanged to ``ModelV2.__init__``
    together with ``framework="torch"``.
    """
    # Order matters: the ModelV2 initializer runs first, then the
    # nn.Module initializer.
    ModelV2.__init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        framework="torch")
    nn.Module.__init__(self)