Example #1
def forward_pass(batch):
    x = batch['x']
    # [time_steps, batch_size, ...].
    x = jnp.transpose(x)
    # [time_steps, batch_size, embed_dim].
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    embeddings = embedding_layer(x)

    lstm_layers = []
    for _ in range(lstm_num_layers):
        lstm_layers.extend([
            hk.LSTM(hidden_size=lstm_hidden_size),
            jnp.tanh,
            # Projection changes dimension from lstm_hidden_size to embed_size.
            hk.Linear(embed_size)
        ])
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
    # [time_steps, batch_size, embed_size] (final layer projects to embed_size).
    output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

    if share_input_output_embeddings:
        output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
        output = hk.Bias(bias_dims=[-1])(output)
    else:
        output = hk.Linear(full_vocab_size)(output)
    # [batch_size, time_steps, full_vocab_size].
    output = jnp.transpose(output, axes=(1, 0, 2))
    return output
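
A minimal sketch of how a forward pass like this can be transformed and run. The hyperparameter values, batch contents, and shapes below are illustrative assumptions, not part of the original example; forward_pass is assumed to close over the names defined here.

import haiku as hk
import jax
import jax.numpy as jnp

# Illustrative values (assumptions, not from the original snippet).
full_vocab_size, embed_size = 1000, 64
lstm_hidden_size, lstm_num_layers = 128, 2
share_input_output_embeddings = True

forward = hk.without_apply_rng(hk.transform(forward_pass))
batch = {'x': jnp.zeros((8, 16), dtype=jnp.int32)}  # [batch_size, time_steps]
params = forward.init(jax.random.PRNGKey(42), batch)
logits = forward.apply(params, batch)  # [8, 16, full_vocab_size]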
Example #2
def make_network() -> hk.RNNCore:
    """Defines the network architecture."""
    model = hk.DeepRNN([
        lambda x: jax.nn.one_hot(x, num_classes=dataset.NUM_CHARS),
        hk.LSTM(FLAGS.hidden_size),
        jax.nn.relu,
        hk.LSTM(FLAGS.hidden_size),
        hk.nets.MLP([FLAGS.hidden_size, dataset.NUM_CHARS]),
    ])
    return model
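
One plausible way to unroll this core over a character sequence with hk.dynamic_unroll; FLAGS.hidden_size and dataset.NUM_CHARS come from the surrounding project, and the token shapes below are assumptions.

import haiku as hk
import jax
import jax.numpy as jnp

def unroll(tokens):
    # tokens: [time_steps, batch_size] integer character ids (assumed layout).
    core = make_network()
    initial_state = core.initial_state(batch_size=tokens.shape[1])
    logits, _ = hk.dynamic_unroll(core, tokens, initial_state)
    return logits  # [time_steps, batch_size, dataset.NUM_CHARS]

unroll_fn = hk.without_apply_rng(hk.transform(unroll))
tokens = jnp.zeros((32, 4), dtype=jnp.int32)  # assumed shapes
params = unroll_fn.init(jax.random.PRNGKey(0), tokens)
logits = unroll_fn.apply(params, tokens)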
Example #3
def forward_pass(batch):
  x = batch['x']
  # [time_steps, batch_size, ...].
  x = jnp.transpose(x)
  # [time_steps, batch_size, embed_dim].
  embedding_layer = hk.Embed(full_vocab_size, embed_size)
  embeddings = embedding_layer(x)

  lstm_layers = []
  for _ in range(lstm_num_layers):
    lstm_layers.extend([hk.LSTM(hidden_size=lstm_hidden_size), jnp.tanh])
  rnn_core = hk.DeepRNN(lstm_layers)
  initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
  # [time_steps, batch_size, hidden_size].
  output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

  output = hk.Linear(full_vocab_size)(output)
  # [batch_size, time_steps, full_vocab_size].
  output = jnp.transpose(output, axes=(1, 0, 2))
  return output
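
For completeness, a hedged sketch of a next-token cross-entropy loss computed from logits shaped like the return value above; the target layout and the absence of masking are assumptions.

import jax
import jax.numpy as jnp

def sequence_loss(logits, targets):
  # logits: [batch_size, time_steps, full_vocab_size].
  # targets: [batch_size, time_steps] integer token ids (assumed layout).
  log_probs = jax.nn.log_softmax(logits)
  one_hot_targets = jax.nn.one_hot(targets, logits.shape[-1])
  return -jnp.mean(jnp.sum(one_hot_targets * log_probs, axis=-1))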
Example #4
def make_flexible_recurrent_net(core_type: str,
                                net_type: str,
                                output_dims: int,
                                num_units: Union[Sequence[int], int],
                                num_layers: Optional[int],
                                activation: Activation,
                                activate_final: bool = False,
                                name: Optional[str] = None,
                                **unused_kwargs):
    """Commonly used for creating a flexible recurrences."""
    if net_type != "mlp":
        raise ValueError("We do not support convolutional recurrent nets atm.")
    if unused_kwargs:
        logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                        str(unused_kwargs))

    if isinstance(num_units, (list, tuple)):
        num_units = list(num_units) + [output_dims]
        num_layers = len(num_units)
    else:
        assert num_layers is not None
        num_units = [num_units] * (num_layers - 1) + [output_dims]
    name = name or core_type.upper()

    activation = utils.get_activation(activation)
    core_list = []
    for i, n in enumerate(num_units):
        if core_type.lower() == "vanilla":
            core_list.append(hk.VanillaRNN(hidden_size=n, name=f"{name}_{i}"))
        elif core_type.lower() == "lstm":
            core_list.append(hk.LSTM(hidden_size=n, name=f"{name}_{i}"))
        elif core_type.lower() == "gru":
            core_list.append(hk.GRU(hidden_size=n, name=f"{name}_{i}"))
        else:
            raise ValueError(f"Unrecognized core_type={core_type}.")
        if i != num_layers - 1:
            core_list.append(activation)
    if activate_final:
        core_list.append(activation)

    return hk.DeepRNN(core_list, name="RNN")
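
A sketch of how this factory might be called; the argument values are illustrative, and passing the string "tanh" assumes the project-local utils.get_activation accepts it.

import haiku as hk
import jax
import jax.numpy as jnp

def unroll(x):
    # x: [time_steps, batch_size, feature_dim] (assumed layout).
    core = make_flexible_recurrent_net(
        core_type="lstm", net_type="mlp", output_dims=10,
        num_units=32, num_layers=2, activation="tanh")
    state = core.initial_state(batch_size=x.shape[1])
    out, _ = hk.static_unroll(core, x, state)
    return out

unroll_fn = hk.without_apply_rng(hk.transform(unroll))
x = jnp.zeros((5, 4, 8))  # assumed shapes
params = unroll_fn.init(jax.random.PRNGKey(0), x)
out = unroll_fn.apply(params, x)  # [5, 4, 10]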
Example #5
def initial_state(batch_size: int):
  network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
  return network.initial_state(batch_size)
Example #6
def network(inputs: jnp.ndarray, state: hk.LSTMState):
  return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)
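
The two functions above are typically transformed separately and used together; a minimal sketch, assuming output_size is defined in the enclosing scope and that observations have the assumed shape. Note that the DeepRNN state is a tuple holding the single LSTMState.

import haiku as hk
import jax
import jax.numpy as jnp

output_size = 16  # assumption for the sketch

initial_state_fn = hk.without_apply_rng(hk.transform(initial_state))
network_fn = hk.without_apply_rng(hk.transform(network))

rng = jax.random.PRNGKey(0)
state = initial_state_fn.apply(initial_state_fn.init(rng, 1), 1)

inputs = jnp.zeros((1, 3, 7))  # assumed observation shape, batch size 1
params = network_fn.init(rng, inputs, state)
outputs, next_state = network_fn.apply(params, inputs, state)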
Example #7
def initial_state(batch_size: Optional[int] = None):
  network = hk.DeepRNN([hk.Reshape([-1], preserve_dims=1),
                        hk.LSTM(output_size)])
  return network.initial_state(batch_size)
Example #8
def network(inputs: jnp.ndarray, state: hk.LSTMState):
  return hk.DeepRNN([hk.Reshape([-1], preserve_dims=1),
                     hk.LSTM(output_size)])(inputs, state)
Example #9
def initial_state(batch_size: Optional[int] = None):
    network = hk.DeepRNN(
        [lambda x: jnp.reshape(x, [-1]),
         hk.LSTM(output_size)])
    return network.initial_state(batch_size)
Example #10
def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return hk.DeepRNN(
        [lambda x: jnp.reshape(x, [-1]),
         hk.LSTM(output_size)])(inputs, state)
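
The flattening variants in Examples #5 through #10 differ only in how they treat the leading dimension: hk.Flatten() and hk.Reshape([-1], preserve_dims=1) keep the batch axis, while the bare jnp.reshape(x, [-1]) collapses everything, which suits unbatched, per-example inputs. A small sketch of the distinction (shapes are assumptions):

import haiku as hk
import jax
import jax.numpy as jnp

def shapes(x):
  batched = hk.Flatten()(x)                        # (4, 21): batch axis kept.
  reshaped = hk.Reshape([-1], preserve_dims=1)(x)  # (4, 21): same as Flatten.
  flat = jnp.reshape(x, [-1])                      # (84,): everything collapsed.
  return batched.shape, reshaped.shape, flat.shape

f = hk.without_apply_rng(hk.transform(shapes))
x = jnp.zeros((4, 3, 7))  # assumed batched input
print(f.apply(f.init(jax.random.PRNGKey(0), x), x))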