def make_dqn(num_actions: int):
    """Build a convolutional Q-network with a duelling MLP head.

    The stack is four blocks of ``snt.Conv2D(32, [3, 3], [2, 2])`` (32 output
    channels, 3x3 kernel, stride 2), each followed by ``tf.nn.relu``. These
    are followed by ``snt.Flatten()`` and a ``duelling.DuellingMLP`` with a
    single 512-unit hidden layer.

    Args:
        num_actions: Forwarded to ``duelling.DuellingMLP``; presumably the
            size of the discrete action space (one output per action) --
            confirm against ``DuellingMLP``.

    Returns:
        An ``snt.Sequential`` wrapping the layers above, in that order.
    """
    layers = []
    # Modules are created in exactly the same order as an explicit list
    # literal would create them, so Sonnet's auto-generated names match.
    for _ in range(4):
        layers.append(snt.Conv2D(32, [3, 3], [2, 2]))
        layers.append(tf.nn.relu)
    layers.append(snt.Flatten())
    layers.append(duelling.DuellingMLP(num_actions, hidden_sizes=[512]))
    return snt.Sequential(layers)
def __init__(self, num_actions: int):
    """Create the embedding, recurrent core and duelling head.

    Args:
        num_actions: Passed both to ``embedding.OAREmbedding`` and to the
            ``duelling.DuellingMLP`` head.
    """
    # The base-class constructor runs first and fixes the module name.
    super().__init__(name='r2d2_atari_network')

    # Observation torso wrapped into an OAR embedding; the torso is created
    # before the embedding, exactly as when passed inline.
    atari_torso = AtariTorso()
    self._embed = embedding.OAREmbedding(
        torso=atari_torso, num_actions=num_actions)

    # Recurrent core: a 256-unit Sonnet LSTM.
    self._core = snt.LSTM(256)

    # Duelling head with a single 256-unit hidden layer.
    self._head = duelling.DuellingMLP(num_actions, hidden_sizes=[256])
def __init__(self, num_actions: int):
    """Assemble the torso-plus-duelling-head network.

    Args:
        num_actions: Forwarded to the ``duelling.DuellingMLP`` head.
    """
    super().__init__(name='dqn_atari_network')

    # Build the two stages in the same order the original list literal did,
    # then chain them so the torso output feeds the head.
    torso = AtariTorso()
    head = duelling.DuellingMLP(num_actions, hidden_sizes=[512])
    self._network = snt.Sequential([torso, head])
def __init__(self, num_actions: int, core: Optional[base.RNNCore] = None):
    """Create the embedding, recurrent core and duelling head.

    Args:
        num_actions: Passed both to ``embedding.OAREmbedding`` and to the
            ``duelling.DuellingMLP`` head.
        core: Optional recurrent core to use. When ``None``, a
            ``recurrence.LSTM(512)`` is constructed instead.
    """
    super().__init__(name='r2d2_atari_network')

    atari_torso = AtariTorso()
    self._embed = embedding.OAREmbedding(
        torso=atari_torso, num_actions=num_actions)

    # Only build the default LSTM when the caller supplied no core, so no
    # extra module is created when one is injected.
    if core is None:
        core = recurrence.LSTM(512)
    self._core = core

    self._head = duelling.DuellingMLP(num_actions, hidden_sizes=[512])