Esempio n. 1
0
def make_network() -> hk.RNNCore:
    """Builds the recurrent core used by the model.

    The stack is: one-hot encoding over `dataset.NUM_CHARS` classes, an LSTM,
    a ReLU, a second LSTM, and an MLP whose last layer has
    `dataset.NUM_CHARS` outputs. Both LSTMs and the MLP's first layer use
    `FLAGS.hidden_size` units.

    Returns:
      The `hk.DeepRNN` wrapping the layers above.
    """

    def to_one_hot(x):
        # Turns the inputs into one-hot vectors before the first LSTM.
        return jax.nn.one_hot(x, num_classes=dataset.NUM_CHARS)

    layers = [
        to_one_hot,
        hk.LSTM(FLAGS.hidden_size),
        jax.nn.relu,
        hk.LSTM(FLAGS.hidden_size),
        hk.nets.MLP([FLAGS.hidden_size, dataset.NUM_CHARS]),
    ]
    return hk.DeepRNN(layers)
Esempio n. 2
0
 def __init__(self, vocab_size, lstm_dim, dropout_rate, is_training=True):
   """Creates the sub-modules of the encoder.

   Args:
     vocab_size: number of entries in the embedding table.
     lstm_dim: embedding width, Conv1D output channels and LSTM hidden size.
     dropout_rate: stored unchanged as `self.dropout_rate`.
     is_training: stored unchanged as `self.is_training`.
   """
   super().__init__()
   # Plain configuration values.
   self.is_training = is_training
   self.dropout_rate = dropout_rate

   # Sub-modules are created in the same order as before so that the names
   # Haiku assigns to them stay the same.
   self.embed = hk.Embed(vocab_size=vocab_size, embed_dim=lstm_dim)
   self.conv1 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3, padding='SAME')
   self.conv2 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3, padding='SAME')
   self.conv3 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3, padding='SAME')
   self.bn1 = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
   self.bn2 = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
   self.bn3 = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
   self.lstm_fwd = hk.LSTM(hidden_size=lstm_dim)
   # The backward LSTM is wrapped in a ResetCore.
   backward_lstm = hk.LSTM(hidden_size=lstm_dim)
   self.lstm_bwd = hk.ResetCore(backward_lstm)
Esempio n. 3
0
  def __init__(self, is_training=True):
    """Creates the encoder, decoder, projection, prenet and postnet modules.

    Args:
      is_training: stored as `self.is_training` and forwarded to the
        `TokenEncoder`.
    """
    super().__init__()
    self.is_training = is_training

    self.encoder = TokenEncoder(
        FLAGS.vocab_size, FLAGS.acoustic_encoder_dim, 0.5, is_training)

    # Two stacked LSTMs combined with skip connections.
    decoder_layers = [
        hk.LSTM(FLAGS.acoustic_decoder_dim),
        hk.LSTM(FLAGS.acoustic_decoder_dim),
    ]
    self.decoder = hk.deep_rnn_with_skip_connections(decoder_layers)
    self.projection = hk.Linear(FLAGS.mel_dim)

    # Prenet: two 256-unit linear layers.
    self.prenet_fc1 = hk.Linear(256, with_bias=True)
    self.prenet_fc2 = hk.Linear(256, with_bias=True)

    # Postnet: four convolutions of width `postnet_dim` followed by one of
    # width `mel_dim`; the last convolution has no batch norm (None entry).
    convs = [hk.Conv1D(FLAGS.postnet_dim, 5) for _ in range(4)]
    convs.append(hk.Conv1D(FLAGS.mel_dim, 5))
    self.postnet_convs = convs

    batch_norms = [hk.BatchNorm(True, True, 0.9) for _ in range(4)]
    batch_norms.append(None)
    self.postnet_bns = batch_norms
Esempio n. 4
0
    def forward_pass(batch):
        """Embeds `batch['x']`, runs the stacked LSTM core and projects to the vocabulary.

        Each recurrent layer is an LSTM followed by `tanh` and a linear
        projection back to `embed_size`. When `share_input_output_embeddings`
        is set, the output logits reuse the transposed embedding matrix plus
        a bias; otherwise a separate linear layer is used.

        Returns:
          The outputs with the first two axes swapped back, i.e.
          [batch_size, time_steps, full_vocab_size].
        """
        # Make the inputs time-major: [time_steps, batch_size, ...].
        token_ids = jnp.transpose(batch['x'])
        embedding_layer = hk.Embed(full_vocab_size, embed_size)
        # [time_steps, batch_size, embed_dim].
        embedded = embedding_layer(token_ids)

        core_layers = []
        for _ in range(lstm_num_layers):
            core_layers.append(hk.LSTM(hidden_size=lstm_hidden_size))
            core_layers.append(jnp.tanh)
            # Projection changes dimension from lstm_hidden_size to embed_size.
            core_layers.append(hk.Linear(embed_size))
        rnn_core = hk.DeepRNN(core_layers)

        batch_size = embedded.shape[1]
        start_state = rnn_core.initial_state(batch_size=batch_size)
        # Per-step core outputs, time-major.
        hidden, _ = hk.static_unroll(rnn_core, embedded, start_state)

        if not share_input_output_embeddings:
            logits = hk.Linear(full_vocab_size)(hidden)
        else:
            # Tie the output weights to the input embedding matrix.
            tied_weights = jnp.transpose(embedding_layer.embeddings)
            logits = jnp.dot(hidden, tied_weights)
            logits = hk.Bias(bias_dims=[-1])(logits)

        # Back to batch-major: [batch_size, time_steps, full_vocab_size].
        return jnp.transpose(logits, axes=(1, 0, 2))
Esempio n. 5
0
 def __init__(self, num_actions: int):
     """Creates the embedding, LSTM core and duelling head.

     Args:
       num_actions: passed to the embedding and the duelling head, and
         stored as `self._num_actions`.
     """
     super().__init__(name='r2d2_atari_network')
     self._num_actions = num_actions

     torso = DeepAtariTorso()
     self._embed = embedding.OAREmbedding(torso, num_actions)
     self._core = hk.LSTM(512)
     self._duelling_head = duelling.DuellingMLP(
         num_actions, hidden_sizes=[512])
Esempio n. 6
0
        def unroll_fn(observations, state):
            """Unrolls an LSTM over `observations` and applies policy and value heads.

            Returns:
              `((logits, values), new_state)`, where `values` has its last
              axis squeezed away.
            """
            core = hk.LSTM(hidden_size)
            features, new_state = hk.dynamic_unroll(core, observations, state)

            policy_head = hk.Linear(num_actions)
            logits = policy_head(features)

            value_head = hk.Linear(1)
            # The value head emits a trailing axis of size one; drop it.
            values = jnp.squeeze(value_head(features), axis=-1)

            return (logits, values), new_state
Esempio n. 7
0
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """Applies an `InteractionNetwork` with an LSTM in the edge update.

    Edge and node features are wrapped in `StatefulField`s (edges also carry
    the LSTM state), and the same network is then applied for a fixed number
    of message passing steps.

    Args:
      graph: input graph; `graph.edges` and `graph.nodes` become the initial
        embeddings.

    Returns:
      The graph after the last message passing step.
    """
    # Recurrent memory over the inputs of the edge model.
    edge_fn_lstm = hk.LSTM(hidden_size=HIDDEN_SIZE)

    # Edge and node MLPs. Their output size is `EMBEDDING_SIZE`, so the same
    # model can be applied repeatedly. In a real model this would usually be
    # achieved by first encoding the input data into a common
    # `EMBEDDING_SIZE`.
    edge_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])
    node_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])

    # Edges start with their input embedding plus an initial LSTM state (one
    # entry per edge). Nodes have no recurrent state in this example but are
    # wrapped in a `StatefulField` as well, for symmetry.
    num_edges = graph.edges.shape[0]
    initial_edges = StatefulField(
        embedding=graph.edges,
        state=edge_fn_lstm.initial_state(num_edges))
    initial_nodes = StatefulField(embedding=graph.nodes, state=None)
    graph = graph._replace(edges=initial_edges, nodes=initial_nodes)

    def update_edge_fn(edges, sender_nodes, receiver_nodes):
        # The LSTM sees the concatenated edge/sender/receiver embeddings;
        # its output is then processed by the edge MLP.
        features = [
            edges.embedding,
            sender_nodes.embedding,
            receiver_nodes.embedding,
        ]
        lstm_output, updated_state = edge_fn_lstm(
            jnp.concatenate(features, axis=-1), edges.state)
        return StatefulField(
            embedding=edge_fn_mlp(lstm_output), state=updated_state)

    def update_node_fn(nodes, received_edges):
        # `received_edges.state` also holds the aggregated state of all
        # received edges; it is not used in this node update.
        features = [nodes.embedding, received_edges.embedding]
        new_embedding = node_fn_mlp(jnp.concatenate(features, axis=-1))
        return StatefulField(embedding=new_embedding, state=None)

    recurrent_graph_network = jraph.InteractionNetwork(
        update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)

    # Apply the model recurrently for 10 message passing steps.
    # If the LSTM were instead meant to process a sequence of features per
    # node/edge, each iteration would have to select the matching slice along
    # the sequence axis to build that step's input graph.
    num_message_passing_steps = 10
    step = 0
    while step < num_message_passing_steps:
        graph = recurrent_graph_network(graph)
        step += 1

    return graph
Esempio n. 8
0
 def __init__(self, num_actions: int):
     """Creates the torso, the LSTM core and the policy/value head.

     Args:
       num_actions: passed to `networks.PolicyValueHead`.
     """
     super().__init__(name='my_network')

     def flatten_all(x):
         # Collapses every axis of `x` into a single one.
         return jnp.reshape(x, [np.prod(x.shape)])

     torso_layers = [flatten_all, hk.nets.MLP([50, 50])]
     self._torso = hk.Sequential(torso_layers)
     self._core = hk.LSTM(20)
     self._head = networks.PolicyValueHead(num_actions)
Esempio n. 9
0
  def network(inputs: jnp.ndarray,
              state) -> Tuple[Tuple[Logits, Value], LSTMState]:
    """Runs one step: flatten, MLP torso, LSTM, then policy and value heads.

    Args:
      inputs: observations; flattened with `hk.Flatten` before the torso.
      state: recurrent state passed to the LSTM.

    Returns:
      `((logits, value), new_state)`, where `value` has its last axis
      squeezed away.
    """
    flat_inputs = hk.Flatten()(inputs)

    # Modules are created in the same order as before so their names stay put.
    torso = hk.nets.MLP([hidden_size, hidden_size])
    lstm = hk.LSTM(hidden_size)
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)

    core_output, new_state = lstm(torso(flat_inputs), state)
    logits = policy_head(core_output)
    value = jnp.squeeze(value_head(core_output), axis=-1)
    return (logits, value), new_state
Esempio n. 10
0
  def forward_pass(batch):
    """Embeds `batch['x']`, runs the stacked LSTM core and projects to the vocabulary.

    Each recurrent layer is an LSTM followed by `tanh`.

    Returns:
      The outputs with the first two axes swapped back, i.e.
      [batch_size, time_steps, full_vocab_size].
    """
    # Make the inputs time-major: [time_steps, batch_size, ...].
    token_ids = jnp.transpose(batch['x'])
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    # [time_steps, batch_size, embed_dim].
    embedded = embedding_layer(token_ids)

    core_layers = []
    for _ in range(lstm_num_layers):
      core_layers.append(hk.LSTM(hidden_size=lstm_hidden_size))
      core_layers.append(jnp.tanh)
    rnn_core = hk.DeepRNN(core_layers)

    batch_size = embedded.shape[1]
    start_state = rnn_core.initial_state(batch_size=batch_size)
    # [time_steps, batch_size, hidden_size].
    hidden, _ = hk.static_unroll(rnn_core, embedded, start_state)

    logits = hk.Linear(full_vocab_size)(hidden)
    # Back to batch-major: [batch_size, time_steps, full_vocab_size].
    return jnp.transpose(logits, axes=(1, 0, 2))
Esempio n. 11
0
def make_flexible_recurrent_net(core_type: str,
                                net_type: str,
                                output_dims: int,
                                num_units: Union[Sequence[int], int],
                                num_layers: Optional[int],
                                activation: Activation,
                                activate_final: bool = False,
                                name: Optional[str] = None,
                                **unused_kwargs):
    """Builds a stack of recurrent cores wrapped in one `hk.DeepRNN`.

    Args:
      core_type: "vanilla", "lstm" or "gru" (case-insensitive).
      net_type: must be "mlp".
      output_dims: hidden size of the final core.
      num_units: either a list/tuple of hidden sizes (then `output_dims` is
        appended and `num_layers` is recomputed from it), or one size that is
        repeated for `num_layers - 1` cores before the final one.
      num_layers: total number of cores; required when `num_units` is not a
        list or tuple.
      activation: resolved through `utils.get_activation` and inserted after
        every core except the one at index `num_layers - 1`.
      activate_final: if true, the activation is also appended at the end.
      name: prefix for the core names; defaults to `core_type.upper()`.
      **unused_kwargs: ignored, but reported with a warning when non-empty.

    Returns:
      An `hk.DeepRNN` named "RNN" holding the cores and activations.

    Raises:
      ValueError: if `net_type` is not "mlp" or `core_type` is unrecognized.
    """
    if net_type != "mlp":
        raise ValueError("We do not support convolutional recurrent nets atm.")
    if unused_kwargs:
        logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                        str(unused_kwargs))

    if isinstance(num_units, (list, tuple)):
        layer_sizes = [*num_units, output_dims]
        num_layers = len(layer_sizes)
    else:
        assert num_layers is not None
        layer_sizes = [num_units] * (num_layers - 1) + [output_dims]
    prefix = name or core_type.upper()

    activation_fn = utils.get_activation(activation)

    # Resolve the core class once instead of re-testing it for every layer.
    core_classes = {
        "vanilla": hk.VanillaRNN,
        "lstm": hk.LSTM,
        "gru": hk.GRU,
    }
    core_cls = core_classes.get(core_type.lower())
    if core_cls is None:
        raise ValueError(f"Unrecognized core_type={core_type}.")

    last_index = num_layers - 1
    layers = []
    for index, size in enumerate(layer_sizes):
        layers.append(core_cls(hidden_size=size, name=f"{prefix}_{index}"))
        if index != last_index:
            layers.append(activation_fn)
    if activate_final:
        layers.append(activation_fn)

    return hk.DeepRNN(layers, name="RNN")
Esempio n. 12
0
 def forward(batch, is_training):
   """Computes class logits: embed, Conv1D, optional max-pool, LSTM, linear.

   Args:
     batch: a pair whose first element is the input; the second element is
       ignored here. Assumes the input is [batch, time] integer ids - TODO
       confirm against the caller.
     is_training: not used by this function body.

   Returns:
     The output of the final `hk.Linear(num_classes)` applied to the LSTM
     output of the last time step.
   """
   inputs, _ = batch
   batch_size = inputs.shape[0]
   x = hk.Embed(vocab_size=max_features, embed_dim=embedding_size)(inputs)
   x = hk.Conv1D(output_channels=num_filters, kernel_shape=kernel_size,
                 padding="VALID")(x)
   activation = jax.nn.swish if use_swish else jax.nn.relu
   x = activation(x)
   if use_maxpool:
       x = hk.MaxPool(
           window_shape=pool_size, strides=pool_size, padding='VALID',
           channel_axis=2)(x)
   # static_unroll iterates over the leading axis, so move time to the
   # front: [T, B, F]. (The former trailing `[:, :]` slice was a no-op.)
   x = jnp.moveaxis(x, 1, 0)
   lstm_layer = hk.LSTM(hidden_size=cell_size)
   init_state = lstm_layer.initial_state(batch_size)
   # Only the per-step outputs are needed; the final state is discarded.
   outputs, _ = hk.static_unroll(lstm_layer, x, init_state)
   last_output = outputs[-1]
   logits = hk.Linear(num_classes)(last_output)
   return logits
Esempio n. 13
0
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 latent_dim,
                 max_num_segments,
                 temp_b=1.,
                 temp_z=1.,
                 latent_dist='gaussian',
                 name='compile'):
        """Stores the configuration and creates the sub-modules.

        Args:
          input_dim: embedding vocabulary size and decoder output size.
          hidden_dim: embedding width and hidden size of the LSTM and the
            intermediate linear layers.
          latent_dim: latent size; the second latent head outputs
            `2 * latent_dim` values for 'gaussian' and `latent_dim` for
            'concrete'.
          max_num_segments: stored as `self.max_num_segments`.
          temp_b: stored as `self.temp_b`.
          temp_z: stored as `self.temp_z`.
          latent_dist: 'gaussian' or 'concrete'.
          name: module name.

        Raises:
          ValueError: if `latent_dist` is neither 'gaussian' nor 'concrete'.
        """
        super().__init__(name=name)

        # Configuration, kept verbatim on the instance.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.max_num_segments = max_num_segments
        self.temp_b = temp_b
        self.temp_z = temp_z
        self.latent_dist = latent_dist

        self.embed = hk.Embed(input_dim, hidden_dim)
        self.lstm_cell = hk.LSTM(hidden_dim)

        # LSTM output heads: latents (z).
        self.head_z_1 = hk.Linear(hidden_dim)
        if latent_dist not in ('gaussian', 'concrete'):
            raise ValueError('Invalid argument for `latent_dist`.')
        if latent_dist == 'gaussian':
            z_output_size = latent_dim * 2
        else:
            z_output_size = latent_dim
        self.head_z_2 = hk.Linear(z_output_size)

        # LSTM output heads: boundaries (b).
        self.head_b_1 = hk.Linear(hidden_dim)
        self.head_b_2 = hk.Linear(1)

        # Decoder MLP.
        self.decode_1 = hk.Linear(hidden_dim)
        self.decode_2 = hk.Linear(input_dim)
Esempio n. 14
0
 def initial_state(batch_size: int):
   """Returns the initial state of the Flatten -> LSTM stack for `batch_size`."""
   layers = [hk.Flatten(), hk.LSTM(output_size)]
   return hk.DeepRNN(layers).initial_state(batch_size)
Esempio n. 15
0
 def network(inputs: jnp.ndarray, state: hk.LSTMState):
   """Applies a Flatten -> LSTM stack to `inputs` with the given `state`."""
   layers = [hk.Flatten(), hk.LSTM(output_size)]
   core = hk.DeepRNN(layers)
   return core(inputs, state)
Esempio n. 16
0
 def initial_state_fn():
     """Returns the LSTM's initial state, requested with a batch size of `None`."""
     core = hk.LSTM(hidden_size)
     return core.initial_state(None)
Esempio n. 17
0
 def __init__(self, num_actions: int):
     """Creates the embedding, LSTM core and policy/value head.

     Args:
       num_actions: passed to the embedding and the head, and stored as
         `self._num_actions`.
     """
     super().__init__(name='impala_atari_network')
     self._num_actions = num_actions

     torso = DeepAtariTorso()
     self._embed = embedding.OAREmbedding(torso, num_actions)
     self._core = hk.LSTM(256)
     self._head = policy_value.PolicyValueHead(num_actions)
Esempio n. 18
0
 def initial_state(batch_size: Optional[int] = None):
   """Returns the initial state of the Reshape -> LSTM stack for `batch_size`."""
   flatten = hk.Reshape([-1], preserve_dims=1)
   lstm = hk.LSTM(output_size)
   return hk.DeepRNN([flatten, lstm]).initial_state(batch_size)
Esempio n. 19
0
        hk.Flatten(),
        hk.Linear(32),
        jax.nn.relu,
        hk.Linear(10),
    ])(features)


def lstm_model(x, vocab_size=10_000, seq_len=256, args=None, **_):
    """Embeds tokens, runs an LSTM over the sequence, average-pools and classifies.

    Args:
      x: token ids. Assumes shape [batch, time] - TODO confirm against the
        caller; `x.shape[0]` is used as the LSTM batch size.
      vocab_size: the embedding table has `vocab_size + 4` rows.
      seq_len: window and stride of the average pooling.
      args: optional object; when truthy and `args.dynamic_unroll` is truthy,
        `hk.dynamic_unroll` is used instead of `hk.static_unroll`.
      **_: ignored.

    Returns:
      The output of the final `hk.Linear(2)` layer.
    """
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size + 4, 100, w_init=embed_init)
    o2 = token_embedding_map(x)

    # The unroll functions iterate over the leading axis, so swap batch and
    # time. This must be a transpose: the previous `reshape` kept the memory
    # order and therefore mixed tokens from different sequences.
    o2 = jnp.swapaxes(o2, 0, 1)

    # LSTM Part of Network
    core = hk.LSTM(100)
    if args and args.dynamic_unroll:
        unroll = hk.dynamic_unroll
    else:
        unroll = hk.static_unroll
    outs, _ = unroll(core, o2, core.initial_state(x.shape[0]))
    # Back to batch-major, again with a real transpose rather than a reshape.
    outs = jnp.swapaxes(outs, 0, 1)

    # Avg Pool -> Linear
    red_dim_outs = hk.avg_pool(outs, seq_len, seq_len, "SAME").squeeze()
    final_layer = hk.Linear(2)
    ret = final_layer(red_dim_outs)

    return ret
Esempio n. 20
0
 def __init__(self, num_actions, use_resnet, use_lstm, name=None):
     """Stores the configuration flags and creates the recurrent core.

     Args:
       num_actions: stored as `self._num_actions`.
       use_resnet: stored as `self._use_resnet`.
       use_lstm: stored as `self._use_lstm`.
       name: module name.
     """
     super().__init__(name=name)
     self._use_lstm = use_lstm
     self._use_resnet = use_resnet
     self._num_actions = num_actions
     # 256-unit LSTM wrapped in a ResetCore.
     lstm = hk.LSTM(256)
     self._core = hk.ResetCore(lstm)
Esempio n. 21
0
 def initial_state(batch_size: Optional[int] = None):
     """Returns the initial state of the reshape -> LSTM stack for `batch_size`."""

     def flatten_all(x):
         # Collapses every axis of `x` into a single one.
         return jnp.reshape(x, [-1])

     layers = [flatten_all, hk.LSTM(output_size)]
     return hk.DeepRNN(layers).initial_state(batch_size)
Esempio n. 22
0
 def network(inputs: jnp.ndarray, state: hk.LSTMState):
   """Applies a Reshape -> LSTM stack to `inputs` with the given `state`."""
   flatten = hk.Reshape([-1], preserve_dims=1)
   lstm = hk.LSTM(output_size)
   core = hk.DeepRNN([flatten, lstm])
   return core(inputs, state)
Esempio n. 23
0
RNN_CORES = (
    ModuleDescriptor(
        name="ResetCore",
        create=lambda: ResetCoreAdapter(hk.ResetCore(DummyCore())),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="GRU",
        create=lambda: hk.GRU(1),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="IdentityCore",
        create=lambda: hk.IdentityCore(),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="LSTM",
        create=lambda: hk.LSTM(1),
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="Conv1DLSTM",
        create=lambda: hk.Conv1DLSTM([2], 3, 3),
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="Conv2DLSTM",
        create=lambda: hk.Conv2DLSTM([2, 2], 3, 3),
        shape=(BATCH_SIZE, 2, 2, 2)),
    ModuleDescriptor(
        name="Conv3DLSTM",
        create=lambda: hk.Conv3DLSTM([2, 2, 2], 3, 3),
        shape=(BATCH_SIZE, 2, 2, 2, 2)),
    ModuleDescriptor(
        name="VanillaRNN",
Esempio n. 24
0
 def network(inputs: jnp.ndarray, state: hk.LSTMState):
     """Applies a reshape -> LSTM stack to `inputs` with the given `state`."""

     def flatten_all(x):
         # Collapses every axis of `x` into a single one.
         return jnp.reshape(x, [-1])

     core = hk.DeepRNN([flatten_all, hk.LSTM(output_size)])
     return core(inputs, state)