def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  node_key, edge_key = hk.next_rng_keys(2)
  nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
  edges = graph.edges
  if not self._disable_edge_updates:
    edges = hk.dropout(edge_key, self._dropout_rate, edges)
  return graph._replace(nodes=nodes, edges=edges)
def __call__(self,
             h: jnp.ndarray,
             mask: Optional[jnp.ndarray],
             is_training: bool) -> jnp.ndarray:
  """Connects the transformer.

  Args:
    h: Inputs, [B, T, H].
    mask: Padding mask, [B, T].
    is_training: Whether we're training or not.

  Returns:
    Array of shape [B, T, H].
  """
  init_scale = 2. / np.sqrt(self._num_layers)
  dropout_rate = self._dropout_rate if is_training else 0.
  if mask is not None:
    mask = mask[:, None, None, :]

  # Note: names chosen to approximately match those used in the GPT-2 code;
  # see https://github.com/openai/gpt-2/blob/master/src/model.py.
  for i in range(self._num_layers):
    h_norm = layer_norm(h, name=f'h{i}_ln_1')
    h_attn = CausalSelfAttention(self._num_heads, init_scale,
                                 name=f'h{i}_attn')(h_norm, mask)
    h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
    h = h + h_attn
    h_norm = layer_norm(h, name=f'h{i}_ln_2')
    h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
    h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
    h = h + h_dense
  h = layer_norm(h, name='ln_f')

  return h
def __call__(self,
             h: jnp.ndarray,
             mask: Optional[jnp.ndarray],
             is_training: bool) -> jnp.ndarray:
  """Connects the transformer.

  Args:
    h: Inputs, [B, T, H].
    mask: Padding mask, [B, T].
    is_training: Whether we're training or not.

  Returns:
    Array of shape [B, T, H].
  """
  init_scale = 2. / self._num_layers
  dropout_rate = self._dropout_rate if is_training else 0.
  if mask is not None:
    mask = mask[:, None, None, :]

  for i in range(self._num_layers):
    h_norm = layer_norm(h, name=f'h{i}_ln_1')
    h_attn = SelfAttention(num_heads=self._num_heads,
                           key_size=64,
                           w_init_scale=init_scale,
                           name=f'h{i}_attn')(h_norm, mask=mask)
    h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
    h = h + h_attn
    h_norm = layer_norm(h, name=f'h{i}_ln_2')
    h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
    h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
    h = h + h_dense
  h = layer_norm(h, name='ln_f')

  return h
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
  hiddens = x.shape[-1]
  x = conv1d(x, num_units=self._dense_dim, init_scale=self._init_scale)
  x = jax.nn.relu(x)
  x = hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
  x = conv1d(x, num_units=hiddens, init_scale=self._init_scale)
  return hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
def __call__(self, x, lengths):
  x = self.embed(x)
  x = jax.nn.relu(self.bn1(self.conv1(x), is_training=self.is_training))
  x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) if self.is_training else x
  x = jax.nn.relu(self.bn2(self.conv2(x), is_training=self.is_training))
  x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) if self.is_training else x
  x = jax.nn.relu(self.bn3(self.conv3(x), is_training=self.is_training))
  x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) if self.is_training else x

  B, L, D = x.shape
  mask = jnp.arange(0, L)[None, :] >= (lengths[:, None] - 1)

  h0c0_fwd = self.lstm_fwd.initial_state(B)
  new_hx_fwd, new_hxcx_fwd = hk.dynamic_unroll(self.lstm_fwd, x, h0c0_fwd,
                                               time_major=False)

  x_bwd, mask_bwd = jax.tree_map(lambda x: jnp.flip(x, axis=1), (x, mask))
  h0c0_bwd = self.lstm_bwd.initial_state(B)
  new_hx_bwd, new_hxcx_bwd = hk.dynamic_unroll(self.lstm_bwd, (x_bwd, mask_bwd),
                                               h0c0_bwd, time_major=False)

  x = jnp.concatenate((new_hx_fwd, jnp.flip(new_hx_bwd, axis=1)), axis=-1)
  return x
def __call__(self,
             h: jnp.ndarray,
             mask: Optional[jnp.ndarray],
             is_training: bool) -> jnp.ndarray:
  """Connects the transformer.

  Args:
    h: Inputs, [B, T, H].
    mask: Padding mask, [B, T].
    is_training: Whether we're training or not.

  Returns:
    Array of shape [B, T, H].
  """
  init_scale = 2. / np.sqrt(self._num_layers)
  dropout_rate = self._dropout_rate if is_training else 0.
  if mask is not None:
    mask = mask[:, None, None, :]

  h = layer_norm(h)
  for _ in range(self._num_layers):
    h_attn = CausalSelfAttention(self._num_heads, init_scale)(h, mask)
    h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
    h = layer_norm(h + h_attn)
    h_dense = DenseBlock(init_scale)(h)
    h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
    h = layer_norm(h + h_dense)

  return h
def forward(self, x, is_training):
  # Block 1
  x = jax.nn.relu(self.conv1_1(x))
  x = self.bn1_1(x, is_training)
  x = jax.nn.relu(self.conv1_2(x))
  x = self.bn1_2(x, is_training)
  x = hk.max_pool(x, 2, 2, "SAME")
  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.2, x)

  # Block 2
  x = jax.nn.relu(self.conv2_1(x))
  x = self.bn2_1(x, is_training)
  x = jax.nn.relu(self.conv2_2(x))
  x = self.bn2_2(x, is_training)
  x = hk.max_pool(x, 2, 2, "SAME")
  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.3, x)

  # Block 3
  x = jax.nn.relu(self.conv3_1(x))
  x = self.bn3_1(x, is_training)
  x = jax.nn.relu(self.conv3_2(x))
  x = self.bn3_2(x, is_training)
  x = hk.max_pool(x, 2, 2, "SAME")
  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.4, x)

  # Linear part
  x = hk.Flatten()(x)
  x = jax.nn.relu(self.lin1(x))
  x = self.bn4(x, is_training)
  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.5, x)
  x = self.lin2(x)
  return x  # logits
def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
             graph_idx: jnp.ndarray, is_training: bool) -> jnp.ndarray:
  """Predict logits or values.

  Parameters
  ----------
  node_feats : ndarray of shape (N, in_feats)
    Batch input node features. N is the total number of nodes in the batch.
  adj : ndarray of shape (2, E)
    Batch adjacency list. E is the total number of edges in the batch.
  graph_idx : ndarray of shape (N,)
    Graph index for each node in node_feats. Nodes with the same graph index
    belong to the same graph.
  is_training : bool
    Whether the model is training or not.

  Returns
  -------
  out : ndarray of shape (batch_size, n_out)
    Predictor output.
  """
  predicator_dropout = self.predicator_dropout if is_training else 0.0
  node_feats = self.gcn(node_feats, adj, is_training)
  # pooling
  graph_feat = self.pooling(node_feats, graph_idx)
  if predicator_dropout != 0.0:
    graph_feat = hk.dropout(hk.next_rng_key(), predicator_dropout, graph_feat)
  graph_feat = self.fc(graph_feat)
  graph_feat = self.activation(graph_feat)
  out = self.out(graph_feat)
  return out
def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
  dropout_rate = self._dropout_rate if is_training else 0.
  h = hk.Linear(self.output_size,
                w_init=hk.initializers.Constant(1),
                b_init=hk.initializers.Constant(1))(x)
  return hk.dropout(hk.next_rng_key(), dropout_rate, h)
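A self-contained sketch of how a module like the one above is typically driven; the class name `LinearWithDropout` and the shapes are my own assumptions, not part of the original snippet. It shows that `hk.transform` supplies the RNG consumed by `hk.next_rng_key()` inside `hk.dropout`.

import haiku as hk
import jax
import jax.numpy as jnp

class LinearWithDropout(hk.Module):
  # Hypothetical wrapper that inlines the __call__ above for self-containedness.
  def __init__(self, output_size, dropout_rate, name=None):
    super().__init__(name=name)
    self.output_size = output_size
    self._dropout_rate = dropout_rate

  def __call__(self, x, is_training):
    dropout_rate = self._dropout_rate if is_training else 0.
    h = hk.Linear(self.output_size,
                  w_init=hk.initializers.Constant(1),
                  b_init=hk.initializers.Constant(1))(x)
    return hk.dropout(hk.next_rng_key(), dropout_rate, h)

forward = hk.transform(lambda x, is_training: LinearWithDropout(4, 0.1)(x, is_training))
key = jax.random.PRNGKey(0)
x = jnp.ones([2, 3])
params = forward.init(key, x, is_training=True)
out = forward.apply(params, key, x, is_training=True)  # rng required because of dropout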
def __call__(
    self,
    inputs: jnp.ndarray,
    dropout_rate: Optional[float] = None,
    rng=None,
) -> jnp.ndarray:
  """Connects the module to some inputs.

  Args:
    inputs: A Tensor of shape `[batch_size, input_size]`.
    dropout_rate: Optional dropout rate.
    rng: Optional RNG key. Required when using dropout.

  Returns:
    output: The output of the model of size `[batch_size, output_size]`.
  """
  if dropout_rate is not None and rng is None:
    raise ValueError("When using dropout an rng key must be passed.")
  elif dropout_rate is None and rng is not None:
    raise ValueError("RNG should only be passed when using dropout.")

  rng = hk.PRNGSequence(rng) if rng is not None else None
  num_layers = len(self.layers)

  out = inputs
  for i, layer in enumerate(self.layers):
    out = layer(out)
    if i < (num_layers - 1) or self.activate_final:
      # Only perform dropout if we are activating the output.
      if dropout_rate is not None:
        out = hk.dropout(next(rng), dropout_rate, out)
      out = self.activation(out)

  return out
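The call above matches the interface of `hk.nets.MLP`, where dropout is enabled by passing both a rate and an RNG key at call time. A minimal usage sketch under that assumption (layer sizes are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, dropout_rate=None, rng=None):
  # hk.nets.MLP only applies dropout when both a rate and an rng are given.
  return hk.nets.MLP([512, 256, 10])(x, dropout_rate=dropout_rate, rng=rng)

forward_fn = hk.transform(forward)
x = jnp.ones([8, 784])
init_key, drop_key = jax.random.split(jax.random.PRNGKey(0))
params = forward_fn.init(init_key, x)  # no dropout at init
train_out = forward_fn.apply(params, None, x, dropout_rate=0.1, rng=drop_key)
eval_out = forward_fn.apply(params, None, x)  # dropout disabled at eval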
def __call__(self, x, rng, is_training=True, **kwargs):
  # This function assumes that the input is batched!
  batch_size, H, W, C = x.shape

  if rng.ndim > 1:
    # In case we did the split in ResNet or CNN
    assert rng.ndim == 2
    assert rng.shape[0] == len(self.channel_sizes)
    rngs = rng
  else:
    rngs = random.split(rng, len(self.channel_sizes))

  for i, (rng, out_channel, kernel_shape) in enumerate(
      zip(rngs, self.channel_sizes, self.kernel_shapes)):
    if i == len(self.channel_sizes) - 1 and self.gate:
      ab = Conv(2 * out_channel, kernel_shape, name=f"conv_{i}",
                **self.conv_kwargs)(x, is_training=is_training)
      a, b = jnp.split(ab, 2, axis=-1)
      x = a * jax.nn.sigmoid(b)
    else:
      x = Conv(out_channel, kernel_shape, name=f"conv_{i}",
               **self.conv_kwargs)(x, is_training=is_training)

    if self.norm is not None:
      x = self.norm(f"norm_{i}")(x, is_training=is_training)

    if i < len(self.channel_sizes) - 1:
      x = self.nonlinearity(x)

      if self.dropout_rate is not None:
        rate = self.dropout_rate if is_training else 0.0
        x = hk.dropout(rng, rate, x)

  return x
def __call__(self, inputs, is_training):
  dropout_rate = self._dropout_rate if is_training else 0.0
  h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs)
  h = hk.Linear(self._vocab_size, with_bias=False)(h)
  return hk.BatchNorm(create_scale=False, create_offset=False,
                      decay_rate=0.9)(h, is_training)
def __call__(self, x, is_training=True, return_metrics=False):
  """Return the output of the final layer without any [log-]softmax."""
  # Stem
  outputs = {}
  out = self.initial_conv(x)
  out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                    strides=(1, 2, 2, 1), padding='SAME')
  if return_metrics:
    outputs.update(base.signal_metrics(out, 0))
  # Blocks
  for i, block in enumerate(self.blocks):
    out, res_avg_var = block(out, is_training=is_training)
    if return_metrics:
      outputs.update(base.signal_metrics(out, i + 1))
      outputs[f'res_avg_var_{i}'] = res_avg_var
  # Final-conv->activation, pool, dropout, classify
  pool = jnp.mean(self.activation(out), [1, 2])
  outputs['pool'] = pool
  # Optionally apply dropout
  if self.drop_rate > 0.0 and is_training:
    pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
  outputs['logits'] = self.fc(pool)
  return outputs
def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
             is_training: bool) -> jnp.ndarray:
  """Predict logits or values.

  Parameters
  ----------
  node_feats : ndarray of shape (batch_size, N, in_feats)
    Batch input node features. N is the total number of nodes in the batch
    of graphs.
  adj : ndarray of shape (batch_size, N, N)
    Batch adjacency matrix.
  is_training : bool
    Whether the model is training or not.

  Returns
  -------
  out : ndarray of shape (batch_size, n_out)
    Predictor output.
  """
  predicator_dropout = self.predicator_dropout if is_training else 0.0
  node_feats = self.gcn(node_feats, adj, is_training)
  # pooling
  graph_feat = self.pooling(node_feats)
  if predicator_dropout != 0.0:
    graph_feat = hk.dropout(hk.next_rng_key(), predicator_dropout, graph_feat)
  graph_feat = self.fc(graph_feat)
  graph_feat = self.activation(graph_feat)
  out = self.out(graph_feat)
  return out
def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
             is_training: bool) -> jnp.ndarray:
  """Update node features.

  Parameters
  ----------
  node_feats : ndarray of shape (batch_size, N, in_feats)
    Batch input node features. N is the total number of nodes in the batch
    of graphs.
  adj : ndarray of shape (batch_size, N, N)
    Batch adjacency matrix.
  is_training : bool
    Whether the model is training or not.

  Returns
  -------
  new_node_feats : ndarray of shape (batch_size, N, out_feats)
    Batch new node features.
  """
  dropout = self.dropout if is_training else 0.0
  # for batch data
  new_node_feats = jax.vmap(self._update_nodes)(node_feats, adj)
  if self.bias:
    new_node_feats += self.b
  new_node_feats = self.activation(new_node_feats)
  if dropout != 0.0:
    new_node_feats = hk.dropout(hk.next_rng_key(), dropout, new_node_feats)
  if self.batch_norm:
    new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats, is_training)
  return new_node_feats
def __call__(self, q: jnp.ndarray, k: jnp.ndarray) -> jnp.ndarray:
  """Computes the relative position embedding.

  Args:
    q: The query.
    k: The key.

  Returns:
    Relative position embedding.
  """
  # Use key instead of query to obtain the length.
  batch_size, key_length, num_heads, head_dim = list(k.shape)
  # Content based addressing and global content bias
  content_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_w_bias, k)

  # Relative position encoding
  positional_encodings = self._sinusoidal_pos_emb(key_length, batch_size)
  positional_encodings = hk.dropout(hk.next_rng_key(), self._dropout_rate,
                                    positional_encodings)
  rel_pos_emb = hk.Conv1D(
      output_channels=self._dim, kernel_shape=1, with_bias=False,
      w_init=init.RandomNormal(stddev=self._init_scale))(positional_encodings)
  rel_pos_emb = jnp.reshape(
      rel_pos_emb, [batch_size, key_length, num_heads, head_dim])

  # Content dependent positional bias and global positional bias
  rel_pos_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_r_bias, rel_pos_emb)
  rel_pos_score = relative_shift(rel_pos_score)
  assert content_score.shape == rel_pos_score.shape
  return content_score + rel_pos_score
def wrapped(*args):
  out = fn(*args)
  if is_training:
    mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                      jnp.ones([out.shape[0], 1]))
    out = out * mask
  return out
def fn(x):
  if dropout:
    x = hk.dropout(hk.next_rng_key(), 0.5, x)
  if batchnorm:
    x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(
        x, is_training=True)
  return x
def __call__(self, x, *, is_training):
  dropout_prob = self._dropout_prob if is_training else 0.0
  output_channels = x.shape[-1]
  x = conv_1d(
      output_channels=self._widening_factor * output_channels,
      init_scale=self._init_scale)(x)
  x = jax.nn.gelu(x)
  x = conv_1d(
      output_channels=output_channels,
      init_scale=self._init_scale)(x)
  return hk.dropout(hk.next_rng_key(), dropout_prob, x)
def __call__(self, X: jnp.ndarray, dropout: float,
             train: bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
  X = l2_normalize(X)
  if train:
    X = hk.dropout(hk.next_rng_key(), dropout, X)
  h = self.mlp(X)
  mu = h[:, :self.latent_dim]
  log_var = h[:, self.latent_dim:]
  return mu, log_var
def maybe_dropedge(x):
  """Dropout on edge messages."""
  if not is_training:
    return x
  return x * hk.dropout(
      hk.next_rng_key(),
      dropedge_rate,
      jnp.ones([x.shape[0], 1]),
  )
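As an aside (my own illustration, not from the source): applying `hk.dropout` to a `[N, 1]` column of ones and multiplying produces a per-row keep mask, so whole rows of `x` are dropped and the survivors are rescaled by `1 / (1 - rate)`, which is the DropEdge-style behaviour used above.

import haiku as hk
import jax
import jax.numpy as jnp

def row_dropout(x, rate):
  # hk.dropout on a [N, 1] vector of ones yields a per-row keep mask,
  # already scaled by 1 / (1 - rate); multiplying drops whole rows of x.
  mask = hk.dropout(hk.next_rng_key(), rate, jnp.ones([x.shape[0], 1]))
  return x * mask

f = hk.transform(row_dropout)
x = jnp.ones([4, 3])
params = f.init(jax.random.PRNGKey(0), x, 0.5)  # no params; dropout is stateless
out = f.apply(params, jax.random.PRNGKey(1), x, 0.5)
# Each row of `out` is either all zeros or all 2.0 (= 1 / (1 - 0.5)).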
def postnet(self, mel: ndarray) -> ndarray:
  x = mel
  for conv, bn in zip(self.postnet_convs, self.postnet_bns):
    x = conv(x)
    if bn is not None:
      x = bn(x, is_training=self.is_training)
      x = jnp.tanh(x)
    x = hk.dropout(hk.next_rng_key(), 0.5, x) if self.is_training else x
  return x
def mlp_function(X: jnp.ndarray, training: bool) -> Any:
  layers: List[Any] = []
  for d_o in config.intermediate_dims:
    if training:
      layers.append(
          lambda x: hk.dropout(hk.next_rng_key(), config.dropout, x))
    layers.append(hk.Linear(d_o))
    layers.append(config.activation)
  layers.append(hk.Linear(dim_out))
  return hk.Sequential(layers)(X)
def __call__(self,
             x: jnp.ndarray,
             mask: Optional[jnp.ndarray] = None,
             should_reset: Optional[jnp.ndarray] = None,
             cache_steps: int = 0,
             extra: Optional[jnp.ndarray] = None,
             extra_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Computes the outputs of the self attention block.

  Args:
    x: query input [batch, x_timesteps, in_dim].
    mask: attention mask [batch, 1, 1, x_timesteps].
    should_reset: reset marker [batch, timesteps].
    cache_steps: number of timesteps in the cache.
    extra: if provided should be extra key-value input
      [batch, extra_timesteps, in_dim'].
    extra_mask: if provided should be the mask for extra key-value input,
      [batch, extra_timesteps].

  Returns:
    output: block output [batch, x_timesteps, in_dim].
  """
  if self._causal:
    timesteps = x.shape[1]
    batch_size = x.shape[0]
    t = jnp.arange(timesteps, dtype=jnp.int32)
    causal_mask = (t[:, None] >= t[None, :])[None, None, :, :]
    causal_mask = causal_mask.astype(x.dtype)
    if mask is None:
      mask = jnp.broadcast_to(causal_mask,
                              (batch_size, 1, timesteps, timesteps))
    else:
      mask *= causal_mask

    x = Attention(self._r_w_bias,
                  self._r_r_bias,
                  num_heads=self._num_heads,
                  init_scale=self._init_scale,
                  relative_pos_clamp_len=self._relative_pos_clamp_len,
                  dropout_prob=self._dropout_attn_prob)(
                      x, mask=mask, should_reset=should_reset,
                      cache_steps=cache_steps, extra=extra,
                      extra_mask=extra_mask)
  else:
    x = Attention(self._r_w_bias,
                  self._r_r_bias,
                  num_heads=self._num_heads,
                  init_scale=self._init_scale,
                  dropout_prob=self._dropout_attn_prob)(
                      x, mask=mask, extra=extra, extra_mask=extra_mask)

  return hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
def __call__(self, inputs, is_training):
  dropout_rate = self._dropout_rate if is_training else 0.0
  h = jax.nn.softplus(hk.Linear(self._hidden)(inputs))
  h = jax.nn.softplus(hk.Linear(self._hidden)(h))
  h = hk.dropout(hk.next_rng_key(), dropout_rate, h)
  h = hk.Linear(self._num_topics)(h)

  # NB: here we set `create_scale=False` and `create_offset=False` to reduce
  # the number of learnable parameters.
  log_concentration = hk.BatchNorm(create_scale=False,
                                   create_offset=False,
                                   decay_rate=0.9)(h, is_training)
  return jnp.exp(log_concentration)
def __call__(self, inputs, rng_key, stochastic, is_training, test_local_stats):
  out = shortcut = inputs

  if self.use_projection:
    shortcut = self.proj_conv(shortcut, rng_key, stochastic)
    shortcut = self.proj_batchnorm(shortcut, is_training, test_local_stats)
    # DROPOUT
    if self.dropout and is_training:
      shortcut = hk.dropout(rng_key, self.dropout_rate, shortcut)

  for i, (conv_i, bn_i) in enumerate(self.layers):
    out = conv_i(out, rng_key, stochastic)
    out = bn_i(out, is_training, test_local_stats)
    if i < len(self.layers) - 1:  # Don't apply relu or dropout on last layer
      out = jax.nn.relu(out)
      # DROPOUT
      if self.dropout and is_training:
        out = hk.dropout(rng_key, self.dropout_rate, out)

  return jax.nn.relu(out + shortcut)
def __call__(self, h: jnp.ndarray, input_embs: jnp.ndarray,
             mask: Optional[jnp.ndarray], is_training: bool) -> jnp.ndarray:
  """Connects the transformer.

  Args:
    h: Hidden, [B, T, H].
    input_embs: Inputs, [B, T, H].
    mask: Padding mask, [B, T].
    is_training: Whether we're training or not.

  Returns:
    Array of shape [B, T, H].
  """
  init_scale = 2. / np.sqrt(self._num_layers)
  dropout_rate = self._dropout_rate if is_training else 0.
  if mask is not None:
    mask = mask[:, None, None, :]

  for i in range(self._num_layers):
    # input injections
    h = h + input_embs
    # regular transformer block
    h_norm = layer_norm(h, name=f'h{i}_ln_1')
    h_attn = CausalSelfAttention(self._num_heads, init_scale,
                                 name=f'h{i}_attn')(h_norm, mask)
    h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
    h = h + h_attn
    h_norm = layer_norm(h, name=f'h{i}_ln_2')
    h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
    h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
    h = h + h_dense
  h = layer_norm(h, name='ln_f')

  return h
def __call__(self, inputs_q, inputs_kv, *, attention_mask=None, is_training):
  dropout_prob = self._dropout_prob if is_training else 0.0
  dropout_attn_prob = self._dropout_attn_prob if is_training else 0.0

  output_channels = inputs_q.shape[-1]
  if self._shape_for_attn == 'q':
    qk_channels = inputs_q.shape[-1]
  elif self._shape_for_attn == 'kv':
    qk_channels = inputs_kv.shape[-1]
  else:
    raise ValueError(f'Unknown value {self._shape_for_attn} for '
                     'shape_for_attention.')

  v_channels = None
  if self._qk_channels is not None:
    qk_channels = self._qk_channels
  if self._v_channels is not None:
    v_channels = self._v_channels

  attention = Attention(num_heads=self._num_heads,
                        init_scale=self._att_init_scale,
                        dropout_prob=dropout_attn_prob,
                        qk_channels=qk_channels,
                        v_channels=v_channels,
                        output_channels=output_channels)(
                            layer_norm(inputs_q),
                            layer_norm(inputs_kv),
                            attention_mask=attention_mask)
  attention = hk.dropout(hk.next_rng_key(), dropout_prob, attention)

  # Optionally include a residual to the query.
  # Consider omitting the residual if the semantics of query and output
  # are different, e.g. if queries are positions and outputs are pixels.
  if self._use_query_residual:
    x = inputs_q + attention
  else:
    x = attention

  x += MLP(widening_factor=self._widening_factor,
           dropout_prob=dropout_prob,
           init_scale=self._dense_init_scale)(
               layer_norm(x), is_training=is_training)
  return x
def attend(q, k, v, dropout_prob=0.0, attention_mask=None):
  """Computes multi-head attention using a query, key and value.

  Args:
    q: Query with shape [batch, q_indices, num_heads, head_dim].
    k: Key with shape [batch, kv_indices, num_heads, head_dim].
    v: Value with shape [batch, kv_indices, num_heads, head_dim].
    dropout_prob: dropout probability on the attention weights.
    attention_mask: Array of shape [batch, q_indices, kv_indices] indicating
      which attention entries are valid.

  Returns:
    Output of the attention with shape [batch, q_indices, hiddens].
  """
  batch, q_indices, num_heads, q_head_dim = q.shape
  _, _, _, v_head_dim = v.shape
  hiddens = num_heads * v_head_dim

  attention = jnp.einsum('bthd,bThd->bhtT', q, k)

  scale = 1. / math.sqrt(q_head_dim)
  attention *= scale

  if attention_mask is not None:
    # Use large_k instead of np.NINF because np.NINF breaks for
    # causal-masked left-padded sampling.
    large_k = jnp.array(1e4 if attention.dtype == jnp.float16 else 1e30,
                        dtype=attention.dtype)
    attention = jnp.where(attention_mask[:, None, :, :], attention, -large_k)

  normalized = jax.nn.softmax(attention)
  if dropout_prob > 0:
    normalized = hk.dropout(hk.next_rng_key(), dropout_prob, normalized)
  summed = jnp.einsum('bhtT,bThd->bthd', normalized, v)
  summed = jnp.reshape(summed, [batch, q_indices, hiddens])

  if attention_mask is not None:
    # If a query attends to no valid keys (its mask row is all zeros), the
    # softmax above gives a uniform row and produces non-zero outputs where
    # they should be zero. Force those rows to zero.
    wipe_attn = jnp.all(attention_mask == 0, axis=2, keepdims=True)  # shape (B, T, 1)
    summed = jnp.where(wipe_attn, jnp.zeros_like(summed), summed)
  return summed
def __call__(self, x, rng, is_training=True, update_params=True, **kwargs):
  # This function assumes that the input is batched!
  batch_size, in_dim = x.shape

  rngs = random.split(rng, len(self.layer_sizes))
  for i, (rng, out_dim) in enumerate(zip(rngs, self.layer_sizes)):
    if self.zero_init and i == len(self.layer_sizes) - 1:
      w, b = data_dependent_param_init(
          x,
          out_dim,
          name_suffix=f"{i}",
          w_init=hk.initializers.RandomNormal(stddev=0.01),
          b_init=jnp.zeros,
          is_training=is_training,
          update_params=update_params,
          parameter_norm=None)
    else:
      w, b = data_dependent_param_init(
          x,
          out_dim,
          name_suffix=f"{i}",
          w_init=self.w_init,
          b_init=self.b_init,
          is_training=is_training,
          update_params=update_params,
          parameter_norm=self.parameter_norm)
    z = jnp.dot(x, w.T) + b

    if i < len(self.layer_sizes) - 1:
      z = self.nonlinearity(z)

    # Residual connection
    if self.skip_connection and x.shape[-1] == z.shape[-1]:
      x += z
    else:
      x = z

    if i < len(self.layer_sizes) - 1:
      if self.dropout_rate is not None:
        rate = self.dropout_rate if is_training else 0.0
        x = hk.dropout(rng, rate, x)

  return x