예제 #1
0
파일: atari.py 프로젝트: dzorlu/acme
 def __init__(self, num_actions: int):
     """Assemble the embedding, recurrent core and output head.

     Args:
       num_actions: Passed to both the OAR embedding and the policy/value
         head, and also kept on the instance as `_num_actions`.
     """
     super().__init__(name='deep_impala_atari_network')

     # Embedding stage: a ResNet torso wrapped by OAREmbedding.
     resnet_torso = vision.ResNetTorso()
     self._embed = embedding.OAREmbedding(
         torso=resnet_torso, num_actions=num_actions)

     # Recurrent stage: an LSTM constructed with 256 as its size argument.
     self._core = snt.LSTM(256)

     # Output stage, applied in order: Linear(256) -> ReLU -> PolicyValueHead.
     head_layers = [
         snt.Linear(256),
         tf.nn.relu,
         policy_value.PolicyValueHead(num_actions),
     ]
     self._head = snt.Sequential(head_layers)

     self._num_actions = num_actions
예제 #2
0
파일: atari.py 프로젝트: dzorlu/acme
 def __init__(self, num_actions: int):
     """Assemble the embedding, recurrent core and duelling head.

     Args:
       num_actions: Passed to both the OAR embedding and the duelling MLP head.
     """
     super().__init__(name='r2d2_atari_network')

     # Embedding stage: an AtariTorso wrapped by OAREmbedding.
     atari_torso = AtariTorso()
     self._embed = embedding.OAREmbedding(
         torso=atari_torso, num_actions=num_actions)

     # Recurrent stage: an LSTM constructed with 256 as its size argument.
     self._core = snt.LSTM(256)

     # Output stage: duelling MLP with a single hidden layer size of 256.
     head_hidden_sizes = [256]
     self._head = duelling.DuellingMLP(
         num_actions, hidden_sizes=head_hidden_sizes)
예제 #3
0
파일: atari.py 프로젝트: deepmind/acme
 def __init__(self, num_actions: int, core: Optional[base.RNNCore] = None):
     """Assemble the embedding, recurrent core and duelling head.

     Args:
       num_actions: Passed to both the OAR embedding and the duelling MLP head.
       core: Recurrent core to use. When left as None, a
         `recurrence.LSTM(512)` is created instead.
     """
     super().__init__(name='r2d2_atari_network')

     # Embedding stage: an AtariTorso wrapped by OAREmbedding.
     atari_torso = AtariTorso()
     self._embed = embedding.OAREmbedding(
         torso=atari_torso, num_actions=num_actions)

     # Recurrent stage: use the caller's core if given, otherwise build the
     # default LSTM here (only when it is actually needed).
     if core is None:
         self._core = recurrence.LSTM(512)
     else:
         self._core = core

     # Output stage: duelling MLP with a single hidden layer size of 512.
     head_hidden_sizes = [512]
     self._head = duelling.DuellingMLP(
         num_actions, hidden_sizes=head_hidden_sizes)