def make_network() -> hk.RNNCore:
    """Assembles the recurrent core used by the model.

    The core one-hot encodes its input over `dataset.NUM_CHARS` classes, feeds
    it through two LSTM layers of width `FLAGS.hidden_size` separated by a
    ReLU, and ends with an MLP whose final layer has `dataset.NUM_CHARS`
    outputs.

    Returns:
      An `hk.DeepRNN` wrapping the layers listed above.
    """

    def to_one_hot(tokens):
        return jax.nn.one_hot(tokens, num_classes=dataset.NUM_CHARS)

    # Haiku modules are created in the same order as they are applied.
    layers = [
        to_one_hot,
        hk.LSTM(FLAGS.hidden_size),
        jax.nn.relu,
        hk.LSTM(FLAGS.hidden_size),
        hk.nets.MLP([FLAGS.hidden_size, dataset.NUM_CHARS]),
    ]
    return hk.DeepRNN(layers)
def __init__(self, vocab_size, lstm_dim, dropout_rate, is_training=True):
    """Creates the sub-modules of the encoder.

    Args:
      vocab_size: Number of entries in the embedding table.
      lstm_dim: Embedding dimension; also the output channel count of the
        three convolutions and the hidden size of both LSTMs.
      dropout_rate: Stored unchanged as `self.dropout_rate`.
      is_training: Stored unchanged as `self.is_training`.
    """
    super().__init__()
    self.is_training = is_training
    self.dropout_rate = dropout_rate

    self.embed = hk.Embed(vocab_size=vocab_size, embed_dim=lstm_dim)

    # Three width-3 convolutions, each paired with its own batch norm below.
    self.conv1 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3,
                           padding='SAME')
    self.conv2 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3,
                           padding='SAME')
    self.conv3 = hk.Conv1D(output_channels=lstm_dim, kernel_shape=3,
                           padding='SAME')
    self.bn1 = hk.BatchNorm(create_scale=True, create_offset=True,
                            decay_rate=0.9)
    self.bn2 = hk.BatchNorm(create_scale=True, create_offset=True,
                            decay_rate=0.9)
    self.bn3 = hk.BatchNorm(create_scale=True, create_offset=True,
                            decay_rate=0.9)

    # Only the backward LSTM is wrapped in a ResetCore.
    self.lstm_fwd = hk.LSTM(hidden_size=lstm_dim)
    backward_lstm = hk.LSTM(hidden_size=lstm_dim)
    self.lstm_bwd = hk.ResetCore(backward_lstm)
def __init__(self, is_training=True):
    """Creates the encoder, decoder, projection, prenet and postnet modules.

    Layer sizes are read from `FLAGS` (`vocab_size`, `acoustic_encoder_dim`,
    `acoustic_decoder_dim`, `mel_dim`, `postnet_dim`).

    Args:
      is_training: Stored as `self.is_training` and forwarded to the
        `TokenEncoder`.
    """
    super().__init__()
    self.is_training = is_training

    self.encoder = TokenEncoder(
        FLAGS.vocab_size, FLAGS.acoustic_encoder_dim, 0.5, is_training)

    # Two stacked LSTMs joined with skip connections.
    decoder_layers = [hk.LSTM(FLAGS.acoustic_decoder_dim) for _ in range(2)]
    self.decoder = hk.deep_rnn_with_skip_connections(decoder_layers)
    self.projection = hk.Linear(FLAGS.mel_dim)

    # prenet
    self.prenet_fc1 = hk.Linear(256, with_bias=True)
    self.prenet_fc2 = hk.Linear(256, with_bias=True)

    # postnet: four `postnet_dim` convolutions followed by one `mel_dim`
    # convolution; the last convolution has no batch norm (`None` entry).
    postnet_channels = [FLAGS.postnet_dim] * 4 + [FLAGS.mel_dim]
    self.postnet_convs = [hk.Conv1D(width, 5) for width in postnet_channels]
    self.postnet_bns = [hk.BatchNorm(True, True, 0.9) for _ in range(4)]
    self.postnet_bns.append(None)
def forward_pass(batch):
    """Embeds `batch['x']`, unrolls a stacked LSTM over it and emits logits.

    Each of the `lstm_num_layers` layers is an LSTM followed by `tanh` and a
    linear projection back to `embed_size`. When
    `share_input_output_embeddings` is set, the output layer reuses the
    transposed embedding matrix plus a learned bias; otherwise it is a
    separate linear layer of width `full_vocab_size`.

    Args:
      batch: Mapping whose `'x'` entry holds the input tokens.

    Returns:
      The output-layer activations with the first two axes swapped back.
    """
    # [time_steps, batch_size, ...].
    tokens = jnp.transpose(batch['x'])

    # [time_steps, batch_size, embed_dim].
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    embeddings = embedding_layer(tokens)

    # The projection changes dimension from lstm_hidden_size to embed_size.
    lstm_layers = [
        layer
        for _ in range(lstm_num_layers)
        for layer in (
            hk.LSTM(hidden_size=lstm_hidden_size),
            jnp.tanh,
            hk.Linear(embed_size),
        )
    ]
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])

    # [time_steps, batch_size, embed_size] (the last stacked layer is the
    # projection back to embed_size).
    output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

    if not share_input_output_embeddings:
        output = hk.Linear(full_vocab_size)(output)
    else:
        tied_weights = jnp.transpose(embedding_layer.embeddings)
        output = jnp.dot(output, tied_weights)
        output = hk.Bias(bias_dims=[-1])(output)

    # [batch_size, time_steps, full_vocab_size].
    return jnp.transpose(output, axes=(1, 0, 2))
def __init__(self, num_actions: int):
    """Creates the embedding, recurrent core and duelling head.

    Args:
      num_actions: Passed to the embedding and the duelling head, and kept
        as `self._num_actions`.
    """
    super().__init__(name='r2d2_atari_network')
    self._num_actions = num_actions

    torso = DeepAtariTorso()
    self._embed = embedding.OAREmbedding(torso, num_actions)
    self._core = hk.LSTM(hidden_size=512)
    self._duelling_head = duelling.DuellingMLP(
        num_actions, hidden_sizes=[512])
def unroll_fn(observations, state):
    """Unrolls an LSTM over `observations` and applies policy/value heads.

    Args:
      observations: Input sequence handed to `hk.dynamic_unroll`.
      state: Recurrent state at the start of the unroll.

    Returns:
      `((logits, values), new_state)`, where `values` has its trailing axis
      squeezed away.
    """
    core = hk.LSTM(hidden_size)
    features, new_state = hk.dynamic_unroll(core, observations, state)

    policy_logits = hk.Linear(num_actions)(features)
    value_out = hk.Linear(1)(features)
    return (policy_logits, jnp.squeeze(value_out, axis=-1)), new_state
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """`InteractionNetwork` with an LSTM in the edge update.

    Args:
      graph: Input graph; its `edges` and `nodes` are wrapped in
        `StatefulField`s before message passing starts.

    Returns:
      The graph after 10 applications of the interaction network.
    """
    # The LSTM keeps a memory of the inputs to the edge model.
    lstm_core = hk.LSTM(hidden_size=HIDDEN_SIZE)

    # Edge and node MLPs. Their output size matches their input size here so
    # the same model can be applied repeatedly; a real model would usually
    # first encode the inputs into a common `EMBEDDING_SIZE`.
    edge_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])
    node_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])

    # Edges carry both the input embedding and the initial LSTM state. Nodes
    # have no recurrent model in this example, but are wrapped in a
    # `StatefulField` too, for symmetry.
    num_edges = graph.edges.shape[0]
    initial_edges = StatefulField(
        embedding=graph.edges, state=lstm_core.initial_state(num_edges))
    initial_nodes = StatefulField(embedding=graph.nodes, state=None)
    graph = graph._replace(edges=initial_edges, nodes=initial_nodes)

    def update_edge_fn(edges, sender_nodes, receiver_nodes):
        # Run the LSTM on the concatenated inputs, then the MLP on its output.
        features = (
            edges.embedding,
            sender_nodes.embedding,
            receiver_nodes.embedding,
        )
        edge_inputs = jnp.concatenate(features, axis=-1)
        lstm_output, next_state = lstm_core(edge_inputs, edges.state)
        return StatefulField(
            embedding=edge_mlp(lstm_output), state=next_state)

    def update_node_fn(nodes, received_edges):
        # `received_edges.state` also holds the aggregated state of all
        # received edges; it is not used in this node update.
        features = (nodes.embedding, received_edges.embedding)
        node_inputs = jnp.concatenate(features, axis=-1)
        return StatefulField(embedding=node_mlp(node_inputs), state=None)

    recurrent_graph_network = jraph.InteractionNetwork(
        update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)

    # Apply the model recurrently for 10 message passing steps. To process a
    # per-node/edge feature sequence with the LSTM instead, the inputs for
    # each step would have to be selected along the sequence axis here.
    num_message_passing_steps = 10
    step = 0
    while step < num_message_passing_steps:
        graph = recurrent_graph_network(graph)
        step += 1
    return graph
def __init__(self, num_actions: int):
    """Creates the torso, recurrent core and policy/value head.

    Args:
      num_actions: Passed to `networks.PolicyValueHead`.
    """
    super().__init__(name='my_network')

    def flatten_all(x):
        # Collapses every axis of `x` into a single one.
        return jnp.reshape(x, [np.prod(x.shape)])

    torso_layers = [flatten_all, hk.nets.MLP([50, 50])]
    self._torso = hk.Sequential(torso_layers)
    self._core = hk.LSTM(20)
    self._head = networks.PolicyValueHead(num_actions)
def network(inputs: jnp.ndarray,
            state) -> Tuple[Tuple[Logits, Value], LSTMState]:
    """Runs one step of a flatten -> MLP -> LSTM -> policy/value network.

    Args:
      inputs: Observation array; flattened with `hk.Flatten` first.
      state: LSTM state before this step.

    Returns:
      `((logits, value), new_state)`, where `value` has its trailing axis
      squeezed away.
    """
    flat_inputs = hk.Flatten()(inputs)

    # Modules are built in the same order as before: torso, LSTM, policy
    # head, value head.
    torso = hk.nets.MLP([hidden_size, hidden_size])
    lstm = hk.LSTM(hidden_size)
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)

    features, next_state = lstm(torso(flat_inputs), state)
    logits = policy_head(features)
    value = jnp.squeeze(value_head(features), axis=-1)
    return (logits, value), next_state
def forward_pass(batch):
    """Embeds `batch['x']`, unrolls a stacked LSTM over it and emits logits.

    The recurrent core is `lstm_num_layers` repetitions of an LSTM followed
    by `tanh`; a linear layer of width `full_vocab_size` produces the output.

    Args:
      batch: Mapping whose `'x'` entry holds the input tokens.

    Returns:
      The output-layer activations with the first two axes swapped back.
    """
    # [time_steps, batch_size, ...].
    tokens = jnp.transpose(batch['x'])

    # [time_steps, batch_size, embed_dim].
    embeddings = hk.Embed(full_vocab_size, embed_size)(tokens)

    lstm_layers = [
        layer
        for _ in range(lstm_num_layers)
        for layer in (hk.LSTM(hidden_size=lstm_hidden_size), jnp.tanh)
    ]
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])

    # [time_steps, batch_size, hidden_size].
    hidden, _ = hk.static_unroll(rnn_core, embeddings, initial_state)
    logits = hk.Linear(full_vocab_size)(hidden)

    # [batch_size, time_steps, full_vocab_size].
    return jnp.transpose(logits, axes=(1, 0, 2))
def make_flexible_recurrent_net(core_type: str,
                                net_type: str,
                                output_dims: int,
                                num_units: Union[Sequence[int], int],
                                num_layers: Optional[int],
                                activation: Activation,
                                activate_final: bool = False,
                                name: Optional[str] = None,
                                **unused_kwargs):
    """Builds a stack of recurrent cores wrapped in an `hk.DeepRNN`.

    Args:
      core_type: One of "vanilla", "lstm" or "gru" (case-insensitive).
      net_type: Must be "mlp".
      output_dims: Hidden size of the last core in the stack.
      num_units: Either a list/tuple with the hidden size of every core before
        the last one, or a single int used for all of them.
      num_layers: Total number of cores. Required when `num_units` is a single
        int; recomputed from `num_units` when that is a list/tuple.
      activation: Passed to `utils.get_activation`; the result is inserted
        after every core except the one at index `num_layers - 1`.
      activate_final: If true, the activation is also appended at the end.
      name: Prefix for the per-core module names (`f"{name}_{i}"`). Defaults
        to the upper-cased `core_type`.
      **unused_kwargs: Ignored; a warning is logged if any are given.

    Returns:
      An `hk.DeepRNN` named "RNN" over the cores and activations.

    Raises:
      ValueError: If `net_type` is not "mlp", if `core_type` is not
        recognized, or if `num_units` is a single int and `num_layers` is
        None.
    """
    if net_type != "mlp":
        raise ValueError("We do not support convolutional recurrent nets atm.")
    if unused_kwargs:
        logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                        unused_kwargs)

    # Validate the core type once, up front, rather than on every iteration.
    core_key = core_type.lower()
    if core_key not in ("vanilla", "lstm", "gru"):
        raise ValueError(f"Unrecognized core_type={core_type}.")

    if isinstance(num_units, (list, tuple)):
        num_units = list(num_units) + [output_dims]
        num_layers = len(num_units)
    else:
        # Raise explicitly: an `assert` is stripped under `python -O`, which
        # would turn this into an obscure TypeError from `None - 1`.
        if num_layers is None:
            raise ValueError(
                "`num_layers` must be provided when `num_units` is an int.")
        num_units = [num_units] * (num_layers - 1) + [output_dims]

    name = name or f"{core_type.upper()}"
    activation = utils.get_activation(activation)

    if core_key == "vanilla":
        core_cls = hk.VanillaRNN
    elif core_key == "lstm":
        core_cls = hk.LSTM
    else:
        core_cls = hk.GRU

    core_list = []
    for i, n in enumerate(num_units):
        core_list.append(core_cls(hidden_size=n, name=f"{name}_{i}"))
        if i != num_layers - 1:
            core_list.append(activation)
    if activate_final:
        core_list.append(activation)
    return hk.DeepRNN(core_list, name="RNN")
def forward(batch, is_training):
    """Embed -> Conv1D -> activation -> optional max-pool -> LSTM -> logits.

    Args:
      batch: A pair whose first element holds the input tokens; the second
        element is ignored.
      is_training: Unused by this function.

    Returns:
      Logits computed by a linear layer of width `num_classes` from the LSTM
      output at the last time step.
    """
    tokens, _ = batch
    batch_size = tokens.shape[0]

    embedded = hk.Embed(
        vocab_size=max_features, embed_dim=embedding_size)(tokens)
    conv = hk.Conv1D(
        output_channels=num_filters,
        kernel_shape=kernel_size,
        padding="VALID")
    activation = jax.nn.swish if use_swish else jax.nn.relu
    features = activation(conv(embedded))

    if use_maxpool:
        pool = hk.MaxPool(
            window_shape=pool_size,
            strides=pool_size,
            padding='VALID',
            channel_axis=2)
        features = pool(features)

    # Swap the first two axes for the unroll; the trailing full slice is a
    # no-op kept from the original.  #[T, B, F]
    sequence = jnp.moveaxis(features, 1, 0)[:, :]

    lstm_layer = hk.LSTM(hidden_size=cell_size)
    init_state = lstm_layer.initial_state(batch_size)
    outputs, _ = hk.static_unroll(lstm_layer, sequence, init_state)

    last_output = outputs[-1]
    return hk.Linear(num_classes)(last_output)
def __init__(self, input_dim, hidden_dim, latent_dim, max_num_segments,
             temp_b=1., temp_z=1., latent_dist='gaussian', name='compile'):
    """Stores the hyper-parameters and creates all sub-modules.

    Args:
      input_dim: Size of the embedding vocabulary and of the decoder output.
      hidden_dim: Width of the embedding, the LSTM and the hidden layers of
        the heads and decoder.
      latent_dim: Base width of the latent head (see `latent_dist`).
      max_num_segments: Stored as `self.max_num_segments`.
      temp_b: Stored as `self.temp_b`.
      temp_z: Stored as `self.temp_z`.
      latent_dist: 'gaussian' gives the latent head `2 * latent_dim` outputs,
        'concrete' gives it `latent_dim` outputs.
      name: Module name passed to the parent constructor.

    Raises:
      ValueError: If `latent_dist` is neither 'gaussian' nor 'concrete'.
    """
    super().__init__(name=name)

    self.input_dim, self.hidden_dim = input_dim, hidden_dim
    self.latent_dim = latent_dim
    self.max_num_segments = max_num_segments
    self.temp_b, self.temp_z = temp_b, temp_z
    self.latent_dist = latent_dist

    self.embed = hk.Embed(input_dim, hidden_dim)
    self.lstm_cell = hk.LSTM(hidden_dim)

    # LSTM output heads.
    # Latents (z).
    self.head_z_1 = hk.Linear(hidden_dim)
    if latent_dist == 'gaussian':
        latent_head_width = latent_dim * 2
    elif latent_dist == 'concrete':
        latent_head_width = latent_dim
    else:
        raise ValueError('Invalid argument for `latent_dist`.')
    self.head_z_2 = hk.Linear(latent_head_width)

    # Boundaries (b).
    self.head_b_1 = hk.Linear(hidden_dim)
    self.head_b_2 = hk.Linear(1)

    # Decoder MLP.
    self.decode_1 = hk.Linear(hidden_dim)
    self.decode_2 = hk.Linear(input_dim)
def initial_state(batch_size: int):
    """Returns the initial state of a flatten-then-LSTM `hk.DeepRNN`.

    Args:
      batch_size: Batch size forwarded to `DeepRNN.initial_state`.
    """
    layers = [hk.Flatten(), hk.LSTM(output_size)]
    core = hk.DeepRNN(layers)
    return core.initial_state(batch_size)
def network(inputs: jnp.ndarray, state: hk.LSTMState):
    """Applies one step of a flatten-then-LSTM `hk.DeepRNN`.

    Args:
      inputs: Input array for this step.
      state: Recurrent state before this step.

    Returns:
      Whatever the `hk.DeepRNN` call returns for `(inputs, state)`.
    """
    layers = [hk.Flatten(), hk.LSTM(output_size)]
    core = hk.DeepRNN(layers)
    return core(inputs, state)
def initial_state_fn():
    """Returns the initial state of an `hk.LSTM(hidden_size)`.

    `None` is passed as the batch size.
    """
    core = hk.LSTM(hidden_size)
    return core.initial_state(None)
def __init__(self, num_actions: int):
    """Creates the embedding, recurrent core and policy/value head.

    Args:
      num_actions: Passed to the embedding and the head, and kept as
        `self._num_actions`.
    """
    super().__init__(name='impala_atari_network')
    self._num_actions = num_actions

    torso = DeepAtariTorso()
    self._embed = embedding.OAREmbedding(torso, num_actions)
    self._core = hk.LSTM(hidden_size=256)
    self._head = policy_value.PolicyValueHead(num_actions)
def initial_state(batch_size: Optional[int] = None):
    """Returns the initial state of a reshape-then-LSTM `hk.DeepRNN`.

    Args:
      batch_size: Batch size forwarded to `DeepRNN.initial_state`; may be
        None.
    """
    layers = [
        hk.Reshape([-1], preserve_dims=1),
        hk.LSTM(output_size),
    ]
    core = hk.DeepRNN(layers)
    return core.initial_state(batch_size)
hk.Flatten(), hk.Linear(32), jax.nn.relu, hk.Linear(10), ])(features) def lstm_model(x, vocab_size=10_000, seq_len=256, args=None, **_): embed_init = hk.initializers.TruncatedNormal(stddev=0.02) token_embedding_map = hk.Embed(vocab_size + 4, 100, w_init=embed_init) o2 = token_embedding_map(x) o2 = jnp.reshape(o2, (o2.shape[1], o2.shape[0], o2.shape[2])) # LSTM Part of Network core = hk.LSTM(100) if args and args.dynamic_unroll: outs, state = hk.dynamic_unroll(core, o2, core.initial_state(x.shape[0])) else: outs, state = hk.static_unroll(core, o2, core.initial_state(x.shape[0])) outs = outs.reshape(outs.shape[1], outs.shape[0], outs.shape[2]) # Avg Pool -> Linear red_dim_outs = hk.avg_pool(outs, seq_len, seq_len, "SAME").squeeze() final_layer = hk.Linear(2) ret = final_layer(red_dim_outs) return ret
def __init__(self, num_actions, use_resnet, use_lstm, name=None):
    """Stores the configuration flags and creates the recurrent core.

    Args:
      num_actions: Stored as `self._num_actions`.
      use_resnet: Stored as `self._use_resnet`.
      use_lstm: Stored as `self._use_lstm`.
      name: Module name passed to the parent constructor.
    """
    super(AtariNet, self).__init__(name=name)
    self._num_actions, self._use_resnet, self._use_lstm = (
        num_actions, use_resnet, use_lstm)

    # The LSTM is wrapped in a ResetCore.
    lstm = hk.LSTM(256)
    self._core = hk.ResetCore(lstm)
def initial_state(batch_size: Optional[int] = None):
    """Returns the initial state of a flatten-all-then-LSTM `hk.DeepRNN`.

    Args:
      batch_size: Batch size forwarded to `DeepRNN.initial_state`; may be
        None.
    """

    def flatten_all(x):
        # Collapses every axis of `x` into one.
        return jnp.reshape(x, [-1])

    core = hk.DeepRNN([flatten_all, hk.LSTM(output_size)])
    return core.initial_state(batch_size)
def network(inputs: jnp.ndarray, state: hk.LSTMState):
    """Applies one step of a reshape-then-LSTM `hk.DeepRNN`.

    Args:
      inputs: Input array for this step.
      state: Recurrent state before this step.

    Returns:
      Whatever the `hk.DeepRNN` call returns for `(inputs, state)`.
    """
    layers = [
        hk.Reshape([-1], preserve_dims=1),
        hk.LSTM(output_size),
    ]
    core = hk.DeepRNN(layers)
    return core(inputs, state)
RNN_CORES = ( ModuleDescriptor( name="ResetCore", create=lambda: ResetCoreAdapter(hk.ResetCore(DummyCore())), shape=(BATCH_SIZE, 128)), ModuleDescriptor( name="GRU", create=lambda: hk.GRU(1), shape=(BATCH_SIZE, 128)), ModuleDescriptor( name="IdentityCore", create=lambda: hk.IdentityCore(), shape=(BATCH_SIZE, 128)), ModuleDescriptor( name="LSTM", create=lambda: hk.LSTM(1), shape=(BATCH_SIZE, 128)), ModuleDescriptor( name="Conv1DLSTM", create=lambda: hk.Conv1DLSTM([2], 3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv2DLSTM", create=lambda: hk.Conv2DLSTM([2, 2], 3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv3DLSTM", create=lambda: hk.Conv3DLSTM([2, 2, 2], 3, 3), shape=(BATCH_SIZE, 2, 2, 2, 2)), ModuleDescriptor( name="VanillaRNN",
def network(inputs: jnp.ndarray, state: hk.LSTMState):
    """Applies one step of a flatten-all-then-LSTM `hk.DeepRNN`.

    Args:
      inputs: Input array for this step; every axis is collapsed into one
        before the LSTM.
      state: Recurrent state before this step.

    Returns:
      Whatever the `hk.DeepRNN` call returns for `(inputs, state)`.
    """

    def flatten_all(x):
        return jnp.reshape(x, [-1])

    core = hk.DeepRNN([flatten_all, hk.LSTM(output_size)])
    return core(inputs, state)