from gym.spaces import Discrete, Tuple

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.misc import normc_initializer as normc_init_torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class TorchAutoregressiveActionModel(TorchModelV2, nn.Module):
    """Model that outputs a context vector consumed by an autoregressive
    action distribution (a1 is sampled first, then a2 conditioned on a1)."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs)
        self.context_layer = SlimFC(
            in_size=obs_space.shape[0],
            out_size=num_outputs,
            initializer=normc_init_torch(1.0),
            activation_fn=nn.Tanh,
        )

        # V(s)
        self.value_branch = SlimFC(
            in_size=num_outputs,
            out_size=1,
            initializer=normc_init_torch(0.01),
            activation_fn=None,
        )

        # P(a1 | obs)
        self.a1_logits = SlimFC(in_size=num_outputs,
                                out_size=2,
                                activation_fn=None,
                                initializer=normc_init_torch(0.01))

        class _ActionModel(nn.Module):
            def __init__(self):
                nn.Module.__init__(self)
                self.a2_hidden = SlimFC(in_size=1,
                                        out_size=16,
                                        activation_fn=nn.Tanh,
                                        initializer=normc_init_torch(1.0))
                self.a2_logits = SlimFC(in_size=16,
                                        out_size=2,
                                        activation_fn=None,
                                        initializer=normc_init_torch(0.01))

            def forward(self_, ctx_input, a1_input):
                # Note: `self_` is the inner module, while `self` still
                # refers to the enclosing model, so `a1_logits` is computed
                # by the outer model's head on the shared context features.
                a1_logits = self.a1_logits(ctx_input)
                a2_logits = self_.a2_logits(self_.a2_hidden(a1_input))
                return a1_logits, a2_logits

        # P(a2 | a1)
        # --note: typically you'd want to implement P(a2 | a1, obs) by
        # feeding the context in alongside a1, e.g.:
        # a2_context = torch.cat([ctx_input, a1_input], dim=1)
        self.action_module = _ActionModel()

        self._context = None
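
    # Standard TorchModelV2 hooks; a minimal sketch mirroring RLlib's
    # autoregressive-action example. forward() only computes the shared
    # context features; the action distribution later queries
    # `self.action_module` with this context to produce the logits for
    # a1 and a2.
    def forward(self, input_dict, state, seq_lens):
        context = self.context_layer(input_dict["obs"])
        # Cache the context so value_function() can reuse it.
        self._context = context
        return context, state

    def value_function(self):
        # V(s) is read off the context computed in the last forward() call.
        return torch.reshape(self.value_branch(self._context), [-1])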
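

# A hedged usage sketch, separate from the model itself: how such a custom
# model is typically registered with RLlib and paired with a custom
# autoregressive action distribution. `TorchBinaryAutoregressiveDistribution`
# and its import path are assumptions based on RLlib's companion
# autoregressive_action_dist example; adjust them to your setup.
if __name__ == "__main__":
    from ray.rllib.models import ModelCatalog
    # Assumed companion distribution that calls model.action_module(ctx, a1)
    # to obtain a2's logits after sampling a1.
    from ray.rllib.examples.models.autoregressive_action_dist import \
        TorchBinaryAutoregressiveDistribution

    ModelCatalog.register_custom_model("autoreg_model",
                                       TorchAutoregressiveActionModel)
    ModelCatalog.register_custom_action_dist(
        "binary_autoreg_dist", TorchBinaryAutoregressiveDistribution)

    # Both names above are then referenced from the trainer config.
    config = {
        "framework": "torch",
        "model": {
            "custom_model": "autoreg_model",
            "custom_action_dist": "binary_autoreg_dist",
        },
    }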