from gym.spaces import Discrete, Tuple

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.misc import normc_initializer as normc_init_torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class TorchAutoregressiveActionModel(TorchModelV2, nn.Module):
    """Model producing a context vector plus autoregressive action heads."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs).
        self.context_layer = SlimFC(
            in_size=obs_space.shape[0],
            out_size=num_outputs,
            initializer=normc_init_torch(1.0),
            activation_fn=nn.Tanh,
        )

        # V(s)
        self.value_branch = SlimFC(
            in_size=num_outputs,
            out_size=1,
            initializer=normc_init_torch(0.01),
            activation_fn=None,
        )

        # P(a1 | obs)
        self.a1_logits = SlimFC(
            in_size=num_outputs,
            out_size=2,
            activation_fn=None,
            initializer=normc_init_torch(0.01),
        )

        class _ActionModel(nn.Module):
            """Produces both action logits from (context, sampled a1)."""

            def __init__(self):
                nn.Module.__init__(self)
                # P(a2 | a1): small hidden layer on top of the sampled a1.
                self.a2_hidden = SlimFC(
                    in_size=1,
                    out_size=16,
                    activation_fn=nn.Tanh,
                    initializer=normc_init_torch(1.0),
                )
                self.a2_logits = SlimFC(
                    in_size=16,
                    out_size=2,
                    activation_fn=None,
                    initializer=normc_init_torch(0.01),
                )

            # Note: `self_` is the inner module; `self` (captured via
            # closure) is the outer model, whose a1 head only sees the
            # observation context.
            def forward(self_, ctx_input, a1_input):
                a1_logits = self.a1_logits(ctx_input)
                a2_logits = self_.a2_logits(self_.a2_hidden(a1_input))
                return a1_logits, a2_logits

        # P(a2 | a1)
        # Note: typically you'd want to condition on the observation as
        # well, i.e. implement P(a2 | a1, obs), by concatenating the inputs:
        # a2_context = torch.cat([ctx_input, a1_input], dim=1)
        self.action_module = _ActionModel()
        self._context = None
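    # --- Sketch (not shown in the excerpt above): TorchModelV2 subclasses
    # also need forward() and value_function() companions. Minimal versions
    # consistent with the layers defined in __init__ would look like this;
    # treat them as an assumption about the surrounding file, not verbatim
    # source.
    def forward(self, input_dict, state, seq_lens):
        # Compute and cache the context features; these become the "inputs"
        # handed to the autoregressive action distribution.
        self._context = self.context_layer(input_dict["obs"])
        return self._context, state

    def value_function(self):
        # V(s) is computed from the cached context features.
        return torch.reshape(self.value_branch(self._context), [-1])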
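
# --- Hedged sketch (not in the original excerpt): how an autoregressive
# action distribution would typically consume this model. In RLlib this
# logic lives in a custom ActionDistribution class; the free function below
# is only an illustration, and `sample_autoregressive` is a hypothetical
# name.
def sample_autoregressive(model, context):
    batch = context.shape[0]
    # a1 depends only on the obs context, so feed a dummy a1 input here.
    dummy_a1 = torch.zeros((batch, 1), device=context.device)
    a1_logits, _ = model.action_module(context, dummy_a1)
    a1 = torch.distributions.Categorical(logits=a1_logits).sample()
    # a2 is conditioned on the sampled a1: P(a2 | a1).
    _, a2_logits = model.action_module(context, a1.float().unsqueeze(1))
    a2 = torch.distributions.Categorical(logits=a2_logits).sample()
    return a1, a2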