示例#1
0
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initialize a TorchModelV2.

        Here is an example implementation for a subclass
        ``MyModelClass(TorchModelV2, nn.Module)``::

            def __init__(self, *args, **kwargs):
                TorchModelV2.__init__(self, *args, **kwargs)
                nn.Module.__init__(self)
                self._hidden_layers = nn.Sequential(...)
                self._logits = ...
                self._value_branch = ...

        Args:
            obs_space: Forwarded unchanged to ``ModelV2.__init__``.
            action_space: Forwarded unchanged to ``ModelV2.__init__``.
            num_outputs: Forwarded unchanged to ``ModelV2.__init__``.
            model_config: Forwarded unchanged to ``ModelV2.__init__``.
            name: Forwarded unchanged to ``ModelV2.__init__``.

        Raises:
            ValueError: If ``self`` is not an instance of ``nn.Module``,
                i.e. the concrete subclass forgot to also inherit from it.
        """

        # Fail fast with an actionable message: the check is class-based, so
        # it works even though nn.Module.__init__ has not been called yet.
        if not isinstance(self, nn.Module):
            raise ValueError(
                "Subclasses of TorchModelV2 must also inherit from "
                "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")

        # Delegate the framework-independent setup, tagging this model as a
        # torch model.
        ModelV2.__init__(self,
                         obs_space,
                         action_space,
                         num_outputs,
                         model_config,
                         name,
                         framework="torch")
示例#2
0
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Set up the TensorFlow-specific state of a TFModelV2.

        The five arguments are passed through unchanged to
        ``ModelV2.__init__`` together with ``framework="tf"``. Afterwards an
        empty ``var_list`` and a ``graph`` attribute are created on the
        instance.

        A subclass ``MyModelClass(TFModelV2)`` could be written like this::

            def __init__(self, *args, **kwargs):
                super(MyModelClass, self).__init__(*args, **kwargs)
                input_layer = tf.keras.layers.Input(...)
                hidden_layer = tf.keras.layers.Dense(...)(input_layer)
                output_layer = tf.keras.layers.Dense(...)(hidden_layer)
                value_layer = tf.keras.layers.Dense(...)(hidden_layer)
                self.base_model = tf.keras.Model(
                    input_layer, [output_layer, value_layer])
                self.register_variables(self.base_model.variables)
        """
        ModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            framework="tf",
        )

        # Starts out empty. NOTE(review): presumably filled through
        # `register_variables` (see the docstring example) -- confirm.
        self.var_list = []

        # No graph handle is kept when executing eagerly; otherwise remember
        # whichever graph is the default at construction time.
        self.graph = (None if tf.executing_eagerly() else
                      tf.get_default_graph())
示例#3
0
 def __init__(self, obs_space, action_space, num_outputs, model_config,
              name):
     """Initialize the model by delegating to ``ModelV2.__init__``.

     All five arguments are forwarded positionally and unchanged; the only
     value added here is ``framework="tf"``.
     """
     ModelV2.__init__(
         self,
         obs_space,
         action_space,
         num_outputs,
         model_config,
         name,
         framework="tf",
     )
示例#4
0
    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
    ):
        """Initializes a JAXModelV2 instance.

        Every argument is handed on, positionally and unchanged, to
        ``ModelV2.__init__``; this constructor only adds
        ``framework="jax"``.
        """
        ModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name,
            framework="jax")
示例#5
0
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        """Initialize a TorchModelV2.

        The arguments are forwarded unchanged to ``ModelV2.__init__`` with
        ``framework="torch"``; an empty ``tower_stats`` dict is then created
        on the instance.

        Example of a subclass ``MyModelClass(TorchModelV2, nn.Module)``::

            def __init__(self, *args, **kwargs):
                TorchModelV2.__init__(self, *args, **kwargs)
                nn.Module.__init__(self)
                self._hidden_layers = nn.Sequential(...)
                self._logits = ...
                self._value_branch = ...

        Raises:
            ValueError: If ``self`` is not an instance of ``nn.Module``.
        """
        # Guard: the concrete class has to mix in nn.Module as well.
        if not isinstance(self, nn.Module):
            message = ("Subclasses of TorchModelV2 must also inherit from "
                       "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")
            raise ValueError(message)

        ModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name,
            framework="torch")

        # Per-tower (multi-GPU) stats container.
        # With PyTorch multi-GPU there is a single TorchPolicy whose Model(s)
        # are copied n times (one copy per GPU). Stats such as `entropy` that
        # are produced while computing the loss on a tower cannot live on the
        # policy object itself: all towers would touch the same property at
        # the same time, which would be a race condition. Each model copy
        # therefore carries its own dict.
        self.tower_stats = {}
示例#6
0
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initialize a TorchModelV2.

        Runs ``ModelV2.__init__`` with the given arguments and
        ``framework="torch"``, and then ``nn.Module.__init__`` on the same
        instance.

        Example of a subclass ``MyModelClass(TorchModelV2, nn.Module)``::

            def __init__(self, *args, **kwargs):
                TorchModelV2.__init__(self, *args, **kwargs)
                nn.Module.__init__(self)
                self._hidden_layers = nn.Sequential(...)
                self._logits = ...
                self._value_branch = ...
        """
        # Framework-independent setup first ...
        ModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            framework="torch",
        )
        # ... then the torch module setup, in this order.
        nn.Module.__init__(self)