Ejemplo n.º 1
0
 def __init__(self, num_actions: int):
     """Construct the submodules of the IMPALA-style Atari network.

     Args:
       num_actions: Number of discrete actions. It is forwarded both to the
         OAR embedding (presumably observation/action/reward -- confirm) and
         to the policy-value head.
     """
     super().__init__(name='impala_atari_network')
     # Build the torso first, exactly as the original inline call did, so
     # submodule construction order (and hence Haiku naming) is unchanged.
     torso = DeepAtariTorso(use_layer_norm=True)
     self._embed = embedding.OAREmbedding(torso, num_actions)
     # Recurrent core with a 256-unit hidden state.
     self._core = hk.GRU(256)
     self._head = policy_value.PolicyValueHead(num_actions)
     self._num_actions = num_actions
Ejemplo n.º 2
0
    def network(inputs: List[jnp.ndarray], state) -> ModelOutput:
        """Embed one step of (observation, previous reward, previous action).

        The step is encoded with an MLP, one GRU step is run, and the
        policy logits and value are read off the GRU output.

        Args:
          inputs: Sequence whose items 0, 1 and 2 are the observation, the
            previous reward and the previous action. Each is reshaped to a
            leading batch dimension of 1; the reward must hold one element.
          state: GRU state carried between calls.

        Returns:
          ``((logits, value, emb, emb, emb), new_state)``. Here ``value`` has
          its trailing unit axis squeezed, and ``emb`` is the GRU output,
          repeated three times.
        """
        # Keep module construction order: Haiku names same-type modules by
        # creation order (linear, linear_1), so swapping the heads would
        # swap their parameter names.
        flat_obs = hk.Flatten()(inputs[0]).reshape((1, -1))
        reward_col = inputs[1].reshape((1, 1))
        action_row = inputs[2].reshape((1, -1))

        encoder = hk.nets.MLP(encoding_hidden_size)
        recurrent_core = hk.GRU(rnn_hidden_size)
        logits_layer = hk.Linear(action_spec.num_values)
        baseline_layer = hk.Linear(1)

        joint = jnp.concatenate([flat_obs, reward_col, action_row], -1)
        features = encoder(joint)
        core_out, next_state = recurrent_core(features, state)

        logits = logits_layer(core_out)
        baseline = jnp.squeeze(baseline_layer(core_out), axis=-1)
        outputs = (logits, baseline, core_out, core_out, core_out)
        return outputs, next_state
Ejemplo n.º 3
0
def make_flexible_recurrent_net(core_type: str,
                                net_type: str,
                                output_dims: int,
                                num_units: Union[Sequence[int], int],
                                num_layers: Optional[int],
                                activation: Activation,
                                activate_final: bool = False,
                                name: Optional[str] = None,
                                **unused_kwargs):
    """Build a stack of recurrent cores wrapped in an ``hk.DeepRNN``.

    One recurrent core is created per layer. ``activation`` is inserted
    between consecutive cores and, when ``activate_final`` is set, after the
    last core as well.

    Args:
      core_type: Recurrent cell type, case-insensitive: ``"vanilla"``,
        ``"lstm"`` or ``"gru"``.
      net_type: Must be ``"mlp"``; convolutional recurrent nets are not
        supported.
      output_dims: Hidden size of the final core.
      num_units: Either a list/tuple of hidden sizes for the layers before
        the final one, or a single int used for every non-final layer.
      num_layers: Total number of layers, including the final one. Required
        (and must be >= 1) when ``num_units`` is an int. Ignored when
        ``num_units`` is a list/tuple.
      activation: Activation spec, resolved via ``utils.get_activation``.
      activate_final: Whether to also apply ``activation`` after the last
        core.
      name: Prefix for the per-layer core names (``f"{name}_{i}"``).
        Defaults to ``core_type.upper()``.
      **unused_kwargs: Accepted for call-site compatibility; logged and
        otherwise ignored.

    Returns:
      An ``hk.DeepRNN`` named ``"RNN"``.

    Raises:
      ValueError: If ``net_type`` is not ``"mlp"``, if ``core_type`` is not
        recognized, or if ``num_units`` is an int and ``num_layers`` is
        missing or < 1.
    """
    if net_type != "mlp":
        raise ValueError("We do not support convolutional recurrent nets atm.")

    # Resolve the core class once instead of re-dispatching on every layer.
    core_classes = {
        "vanilla": hk.VanillaRNN,
        "lstm": hk.LSTM,
        "gru": hk.GRU,
    }
    core_cls = core_classes.get(core_type.lower())
    if core_cls is None:
        raise ValueError(f"Unrecognized core_type={core_type}.")

    if unused_kwargs:
        logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                        unused_kwargs)

    if isinstance(num_units, (list, tuple)):
        layer_sizes = list(num_units) + [output_dims]
    else:
        # An explicit check (not `assert`) so validation survives `python -O`
        # and a non-positive count cannot silently yield a malformed stack.
        if num_layers is None or num_layers < 1:
            raise ValueError(
                "num_layers must be an int >= 1 when num_units is an int; "
                f"got num_layers={num_layers}.")
        layer_sizes = [num_units] * (num_layers - 1) + [output_dims]

    name = name or core_type.upper()
    act_fn = utils.get_activation(activation)

    # Derive the final index from the actual layer list, so the "no activation
    # after the last core" rule holds regardless of how sizes were given.
    last_index = len(layer_sizes) - 1
    core_list = []
    for i, size in enumerate(layer_sizes):
        core_list.append(core_cls(hidden_size=size, name=f"{name}_{i}"))
        if i != last_index or activate_final:
            core_list.append(act_fn)

    return hk.DeepRNN(core_list, name="RNN")
Ejemplo n.º 4
0
  def __call__(self, inputs, state):
    batch_size = inputs.shape[0]
    resets = np.broadcast_to(True, (batch_size,))
    return self.wrapped((inputs, resets), state)


# RNN cores. For shape, use the shape of a single example.
RNN_CORES = (
    ModuleDescriptor(
        name="ResetCore",
        create=lambda: ResetCoreAdapter(hk.ResetCore(DummyCore())),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="GRU",
        create=lambda: hk.GRU(1),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="IdentityCore",
        create=lambda: hk.IdentityCore(),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="LSTM",
        create=lambda: hk.LSTM(1),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="Conv1DLSTM",
        create=lambda: hk.Conv1DLSTM([2], 3, 3),
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="Conv2DLSTM",