def __init__(self, num_actions: int):
  super().__init__(name='impala_atari_network')
  self._embed = embedding.OAREmbedding(
      DeepAtariTorso(use_layer_norm=True), num_actions)
  self._core = hk.GRU(256)
  self._head = policy_value.PolicyValueHead(num_actions)
  self._num_actions = num_actions
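# Hedged sketch (not taken from the original class): one plausible single-step
# __call__ for a module built as in the constructor above. `inputs` is assumed
# to be an (observation, action, reward) structure accepted by OAREmbedding,
# and `state` a GRU state of shape [batch, 256].
def __call__(self, inputs, state):
  embeddings = self._embed(inputs)                 # OAR -> feature vector
  embeddings, new_state = self._core(embeddings, state)
  logits, value = self._head(embeddings)           # policy logits and value
  return (logits, value), new_state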
def network(inputs: List[jnp.ndarray], state) -> ModelOutput:
  observation = hk.Flatten()(inputs[0]).reshape((1, -1))
  previous_reward = inputs[1].reshape((1, 1))
  previous_action = inputs[2].reshape((1, -1))
  torso = hk.nets.MLP(encoding_hidden_size)
  gru = hk.GRU(rnn_hidden_size)
  policy_head = hk.Linear(action_spec.num_values)
  value_head = hk.Linear(1)
  input_embedding = jnp.concatenate(
      [observation, previous_reward, previous_action], -1)
  input_embedding = torso(input_embedding)
  embedding, state = gru(input_embedding, state)
  logits = policy_head(embedding)
  value = value_head(embedding)
  return (logits, jnp.squeeze(value, axis=-1), embedding, embedding,
          embedding), state
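# A minimal usage sketch for `network`, with hypothetical values standing in
# for the names it closes over (`encoding_hidden_size`, `rnn_hidden_size`,
# `action_spec`). It transforms the function, builds a zero GRU state, and
# runs a single step on dummy inputs; assumes `jnp` and `hk` are imported.
import collections
import jax

encoding_hidden_size = (64, 64)   # hypothetical MLP layer sizes
rnn_hidden_size = 128             # hypothetical GRU width
action_spec = collections.namedtuple('Spec', 'num_values')(num_values=4)

net_fn = hk.without_apply_rng(hk.transform(network))
dummy_inputs = [jnp.zeros((8, 8)), jnp.zeros(()), jnp.zeros((1,))]
state = jnp.zeros((1, rnn_hidden_size))  # hk.GRU's initial state is zeros
params = net_fn.init(jax.random.PRNGKey(0), dummy_inputs, state)
(logits, value, *_), state = net_fn.apply(params, dummy_inputs, state)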
def make_flexible_recurrent_net(
    core_type: str,
    net_type: str,
    output_dims: int,
    num_units: Union[Sequence[int], int],
    num_layers: Optional[int],
    activation: Activation,
    activate_final: bool = False,
    name: Optional[str] = None,
    **unused_kwargs):
  """Commonly used for creating flexible recurrent nets."""
  if net_type != "mlp":
    raise ValueError("We do not support convolutional recurrent nets atm.")
  if unused_kwargs:
    logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                    str(unused_kwargs))

  if isinstance(num_units, (list, tuple)):
    num_units = list(num_units) + [output_dims]
    num_layers = len(num_units)
  else:
    assert num_layers is not None
    num_units = [num_units] * (num_layers - 1) + [output_dims]

  name = name or f"{core_type.upper()}"
  activation = utils.get_activation(activation)

  core_list = []
  for i, n in enumerate(num_units):
    if core_type.lower() == "vanilla":
      core_list.append(hk.VanillaRNN(hidden_size=n, name=f"{name}_{i}"))
    elif core_type.lower() == "lstm":
      core_list.append(hk.LSTM(hidden_size=n, name=f"{name}_{i}"))
    elif core_type.lower() == "gru":
      core_list.append(hk.GRU(hidden_size=n, name=f"{name}_{i}"))
    else:
      raise ValueError(f"Unrecognized core_type={core_type}.")
    if i != num_layers - 1:
      core_list.append(activation)

  if activate_final:
    core_list.append(activation)

  return hk.DeepRNN(core_list, name="RNN")
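# A hedged usage sketch: build a two-layer GRU stack with the factory above
# and unroll it over a dummy [time, batch, features] sequence. The activation
# string "relu" is an assumption about what `utils.get_activation` accepts.
def unroll_fn(sequence: jnp.ndarray):
  core = make_flexible_recurrent_net(
      core_type="gru", net_type="mlp", output_dims=4,
      num_units=32, num_layers=2, activation="relu")
  initial_state = core.initial_state(batch_size=sequence.shape[1])
  return hk.dynamic_unroll(core, sequence, initial_state)

unroll = hk.without_apply_rng(hk.transform(unroll_fn))
sequence = jnp.zeros((10, 2, 8))  # [time, batch, features]
params = unroll.init(jax.random.PRNGKey(0), sequence)
outputs, final_state = unroll.apply(params, sequence)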
def __call__(self, inputs, state):
  batch_size = inputs.shape[0]
  resets = np.broadcast_to(True, (batch_size,))
  return self.wrapped((inputs, resets), state)


# RNN cores. For shape, use the shape of a single example.
RNN_CORES = (
    ModuleDescriptor(
        name="ResetCore",
        create=lambda: ResetCoreAdapter(hk.ResetCore(DummyCore())),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="GRU",
        create=lambda: hk.GRU(1),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="IdentityCore",
        create=lambda: hk.IdentityCore(),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="LSTM",
        create=lambda: hk.LSTM(1),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="Conv1DLSTM",
        create=lambda: hk.Conv1DLSTM([2], 3, 3),
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="Conv2DLSTM",