def forward_pass(batch):
  x = batch['x']  # [batch_size, time_steps].
  x = jnp.transpose(x)  # [time_steps, batch_size].
  embedding_layer = hk.Embed(full_vocab_size, embed_size)
  embeddings = embedding_layer(x)  # [time_steps, batch_size, embed_size].
  lstm_layers = []
  for _ in range(lstm_num_layers):
    lstm_layers.extend([
        hk.LSTM(hidden_size=lstm_hidden_size),
        jnp.tanh,
        # Projection changes dimension from lstm_hidden_size to embed_size.
        hk.Linear(embed_size)
    ])
  rnn_core = hk.DeepRNN(lstm_layers)
  initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
  # [time_steps, batch_size, embed_size].
  output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)
  if share_input_output_embeddings:
    output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
    output = hk.Bias(bias_dims=[-1])(output)
  else:
    output = hk.Linear(full_vocab_size)(output)
  # [batch_size, time_steps, full_vocab_size].
  output = jnp.transpose(output, axes=(1, 0, 2))
  return output
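# A minimal usage sketch, not from the original source: the hyperparameter
# values and the batch shape below are assumptions; in the original they come
# from the enclosing scope.
import haiku as hk
import jax
import jax.numpy as jnp

full_vocab_size, embed_size = 1000, 64      # Assumed values.
lstm_num_layers, lstm_hidden_size = 2, 128  # Assumed values.
share_input_output_embeddings = True

model = hk.transform(forward_pass)
rng = jax.random.PRNGKey(0)
batch = {'x': jnp.zeros([8, 20], dtype=jnp.int32)}  # [batch_size, time_steps].
params = model.init(rng, batch)
logits = model.apply(params, rng, batch)  # [batch_size, time_steps, full_vocab_size].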
def make_network() -> hk.RNNCore:
  """Defines the network architecture."""
  model = hk.DeepRNN([
      lambda x: jax.nn.one_hot(x, num_classes=dataset.NUM_CHARS),
      hk.LSTM(FLAGS.hidden_size),
      jax.nn.relu,
      hk.LSTM(FLAGS.hidden_size),
      hk.nets.MLP([FLAGS.hidden_size, dataset.NUM_CHARS]),
  ])
  return model
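# Sketch of putting make_network to work (assumes dataset.NUM_CHARS and
# FLAGS.hidden_size are defined as in the original training script).
def unroll(tokens):  # tokens: [time_steps, batch_size] integer ids.
  core = make_network()
  state = core.initial_state(batch_size=tokens.shape[1])
  logits, _ = hk.dynamic_unroll(core, tokens, state)
  return logits  # [time_steps, batch_size, dataset.NUM_CHARS].

model = hk.without_apply_rng(hk.transform(unroll))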
def forward_pass(batch):
  x = batch['x']  # [batch_size, time_steps].
  x = jnp.transpose(x)  # [time_steps, batch_size].
  embedding_layer = hk.Embed(full_vocab_size, embed_size)
  embeddings = embedding_layer(x)  # [time_steps, batch_size, embed_size].
  lstm_layers = []
  for _ in range(lstm_num_layers):
    lstm_layers.extend([hk.LSTM(hidden_size=lstm_hidden_size), jnp.tanh])
  rnn_core = hk.DeepRNN(lstm_layers)
  initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
  # [time_steps, batch_size, lstm_hidden_size].
  output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)
  output = hk.Linear(full_vocab_size)(output)
  # [batch_size, time_steps, full_vocab_size].
  output = jnp.transpose(output, axes=(1, 0, 2))
  return output
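# Hedged sketch of a next-token loss on top of forward_pass; the 'y' targets
# key and its [batch_size, time_steps] shape are assumptions, not part of the
# original snippet.
def loss_fn(batch):
  logits = forward_pass(batch)  # [batch_size, time_steps, full_vocab_size].
  targets = jax.nn.one_hot(batch['y'], full_vocab_size)
  log_probs = jax.nn.log_softmax(logits)
  return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))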
def make_flexible_recurrent_net(core_type: str,
                                net_type: str,
                                output_dims: int,
                                num_units: Union[Sequence[int], int],
                                num_layers: Optional[int],
                                activation: Activation,
                                activate_final: bool = False,
                                name: Optional[str] = None,
                                **unused_kwargs):
  """Commonly used for creating flexible recurrent nets."""
  if net_type != "mlp":
    raise ValueError("We do not support convolutional recurrent nets atm.")
  if unused_kwargs:
    logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                    str(unused_kwargs))

  if isinstance(num_units, (list, tuple)):
    num_units = list(num_units) + [output_dims]
    num_layers = len(num_units)
  else:
    assert num_layers is not None
    num_units = [num_units] * (num_layers - 1) + [output_dims]

  name = name or f"{core_type.upper()}"
  activation = utils.get_activation(activation)
  core_list = []
  for i, n in enumerate(num_units):
    if core_type.lower() == "vanilla":
      core_list.append(hk.VanillaRNN(hidden_size=n, name=f"{name}_{i}"))
    elif core_type.lower() == "lstm":
      core_list.append(hk.LSTM(hidden_size=n, name=f"{name}_{i}"))
    elif core_type.lower() == "gru":
      core_list.append(hk.GRU(hidden_size=n, name=f"{name}_{i}"))
    else:
      raise ValueError(f"Unrecognized core_type={core_type}.")
    if i != num_layers - 1:
      core_list.append(activation)
  if activate_final:
    core_list.append(activation)
  return hk.DeepRNN(core_list, name="RNN")
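# Hypothetical call, placed inside an hk.transform-ed function since the cores
# are Haiku modules. The sizes and activation name below are assumptions, and
# utils.get_activation is presumed to map "tanh" to jnp.tanh as in the
# original repo.
def unroll_fn(seq):  # seq: [time_steps, batch_size, feature_dim].
  core = make_flexible_recurrent_net(
      core_type="lstm", net_type="mlp", output_dims=32,
      num_units=64, num_layers=3, activation="tanh")
  state = core.initial_state(batch_size=seq.shape[1])
  return hk.static_unroll(core, seq, state)

model = hk.without_apply_rng(hk.transform(unroll_fn))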
def initial_state(batch_size: int):
  network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
  return network.initial_state(batch_size)
def network(inputs: jnp.ndarray, state: hk.LSTMState):
  return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)
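# Sketch of using the pair above: the two functions are transformed
# separately, and parameter reuse works because the modules get identical
# names in each transform. output_size and the input shape are assumed.
init_state_fn = hk.without_apply_rng(hk.transform(initial_state))
network_fn = hk.without_apply_rng(hk.transform(network))

rng = jax.random.PRNGKey(0)
inputs = jnp.zeros([4, 8, 8])  # [batch_size, ...]; hk.Flatten keeps dim 0.
state = init_state_fn.apply(init_state_fn.init(rng, 4), 4)
params = network_fn.init(rng, inputs, state)
outputs, next_state = network_fn.apply(params, inputs, state)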
def initial_state(batch_size: Optional[int] = None):
  network = hk.DeepRNN(
      [hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])
  return network.initial_state(batch_size)
def network(inputs: jnp.ndarray, state: hk.LSTMState):
  return hk.DeepRNN(
      [hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])(inputs, state)
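# Unroll-style sketch: hk.Reshape([-1], preserve_dims=1) flattens everything
# except the leading (batch) dimension, matching hk.Flatten() above. The
# time-major sequence shape is an assumption.
def unroll(seq):  # seq: [time_steps, batch_size, ...].
  core = hk.DeepRNN([hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])
  state = core.initial_state(batch_size=seq.shape[1])
  return hk.static_unroll(core, seq, state)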
def initial_state(batch_size: Optional[int] = None):
  network = hk.DeepRNN(
      [lambda x: jnp.reshape(x, [-1]), hk.LSTM(output_size)])
  return network.initial_state(batch_size)
def network(inputs: jnp.ndarray, state: hk.LSTMState):
  return hk.DeepRNN(
      [lambda x: jnp.reshape(x, [-1]), hk.LSTM(output_size)])(inputs, state)
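# Unbatched sketch: jnp.reshape(x, [-1]) flattens the whole input to a single
# vector, so the state is created with batch_size=None and each call processes
# one example. output_size and the input shape are assumptions.
state_fn = hk.without_apply_rng(hk.transform(initial_state))
step_fn = hk.without_apply_rng(hk.transform(network))

rng = jax.random.PRNGKey(0)
x = jnp.zeros([8, 8])  # A single, unbatched observation.
state = state_fn.apply(state_fn.init(rng))
params = step_fn.init(rng, x, state)
out, next_state = step_fn.apply(params, x, state)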