def __call__(self,
             h: jnp.ndarray,
             mask: Optional[jnp.ndarray],
             is_training: bool) -> jnp.ndarray:
  """Connects the transformer.

  Args:
    h: Inputs, [B, T, H].
    mask: Padding mask, [B, T].
    is_training: Whether we're training or not.

  Returns:
    Array of shape [B, T, H].
  """
  init_scale = 2. / self._num_layers
  dropout_rate = self._dropout_rate if is_training else 0.
  if mask is not None:
    mask = mask[:, None, None, :]

  for i in range(self._num_layers):
    h_norm = layer_norm(h, name=f'h{i}_ln_1')
    h_attn = SelfAttention(
        num_heads=self._num_heads,
        key_size=64,
        w_init_scale=init_scale,
        name=f'h{i}_attn')(h_norm, mask=mask)
    h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
    h = h + h_attn
    h_norm = layer_norm(h, name=f'h{i}_ln_2')
    h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
    h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
    h = h + h_dense
  h = layer_norm(h, name='ln_f')
  return h
def __call__(self, x, lengths):
  x = self.embed(x)
  x = jax.nn.relu(self.bn1(self.conv1(x), is_training=self.is_training))
  x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) if self.is_training else x
  x = jax.nn.relu(self.bn2(self.conv2(x), is_training=self.is_training))
  x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) if self.is_training else x
  x = jax.nn.relu(self.bn3(self.conv3(x), is_training=self.is_training))
  x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) if self.is_training else x

  B, L, D = x.shape
  mask = jnp.arange(0, L)[None, :] >= (lengths[:, None] - 1)

  h0c0_fwd = self.lstm_fwd.initial_state(B)
  new_hx_fwd, new_hxcx_fwd = hk.dynamic_unroll(
      self.lstm_fwd, x, h0c0_fwd, time_major=False)

  x_bwd, mask_bwd = jax.tree_map(lambda x: jnp.flip(x, axis=1), (x, mask))
  h0c0_bwd = self.lstm_bwd.initial_state(B)
  new_hx_bwd, new_hxcx_bwd = hk.dynamic_unroll(
      self.lstm_bwd, (x_bwd, mask_bwd), h0c0_bwd, time_major=False)

  x = jnp.concatenate((new_hx_fwd, jnp.flip(new_hx_bwd, axis=1)), axis=-1)
  return x
def __call__(self,
             h: jnp.ndarray,
             mask: Optional[jnp.ndarray],
             is_training: bool) -> jnp.ndarray:
  """Connects the transformer.

  Args:
    h: Inputs, [B, T, H].
    mask: Padding mask, [B, T].
    is_training: Whether we're training or not.

  Returns:
    Array of shape [B, T, H].
  """
  init_scale = 2. / np.sqrt(self._num_layers)
  dropout_rate = self._dropout_rate if is_training else 0.
  if mask is not None:
    mask = mask[:, None, None, :]

  h = layer_norm(h)
  for _ in range(self._num_layers):
    h_attn = CausalSelfAttention(self._num_heads, init_scale)(h, mask)
    h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
    h = layer_norm(h + h_attn)
    h_dense = DenseBlock(init_scale)(h)
    h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
    h = layer_norm(h + h_dense)
  return h
def forward(self, x, is_training):
  # Block 1
  x = jax.nn.relu(self.conv1_1(x))
  x = self.bn1_1(x, is_training)
  x = jax.nn.relu(self.conv1_2(x))
  x = self.bn1_2(x, is_training)
  x = hk.max_pool(x, 2, 2, "SAME")
  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.2, x)

  # Block 2
  x = jax.nn.relu(self.conv2_1(x))
  x = self.bn2_1(x, is_training)
  x = jax.nn.relu(self.conv2_2(x))
  x = self.bn2_2(x, is_training)
  x = hk.max_pool(x, 2, 2, "SAME")
  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.3, x)

  # Block 3
  x = jax.nn.relu(self.conv3_1(x))
  x = self.bn3_1(x, is_training)
  x = jax.nn.relu(self.conv3_2(x))
  x = self.bn3_2(x, is_training)
  x = hk.max_pool(x, 2, 2, "SAME")
  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.4, x)

  # Linear part
  x = hk.Flatten()(x)
  x = jax.nn.relu(self.lin1(x))
  x = self.bn4(x, is_training)
  if is_training:
    x = hk.dropout(hk.next_rng_key(), 0.5, x)
  x = self.lin2(x)
  return x  # logits
def net_fn(inputs): """Function representing a linear layer with learned noise distribution.""" num_inputs = inputs.shape[-1] mu_initializer = _dqn_default_initializer(num_inputs) mu_layer = hk.Linear(num_outputs, name='mu', with_bias=with_bias, w_init=mu_initializer, b_init=mu_initializer) sigma_initializer = hk.initializers.Constant( # weight_init_stddev / jnp.sqrt(num_inputs)) sigma_layer = hk.Linear(num_outputs, name='sigma', with_bias=True, w_init=sigma_initializer, b_init=sigma_initializer) # Broadcast noise over batch dimension. input_noise_sqrt = make_noise_sqrt(hk.next_rng_key(), [1, num_inputs]) output_noise_sqrt = make_noise_sqrt(hk.next_rng_key(), [1, num_outputs]) # Factorized Gaussian noise. mu = mu_layer(inputs) noisy_inputs = input_noise_sqrt * inputs sigma = sigma_layer(noisy_inputs) * output_noise_sqrt return mu + sigma
def func_boxspace(S, is_training):
  batch_norm = hk.BatchNorm(False, False, 0.99)
  mu = hk.Sequential((
      hk.Flatten(),
      hk.Linear(8),
      jax.nn.relu,
      partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
      partial(batch_norm, is_training=is_training),
      hk.Linear(8),
      jnp.tanh,
      hk.Linear(onp.prod(boxspace.shape)),
      hk.Reshape(boxspace.shape),
  ))
  logvar = hk.Sequential((
      hk.Flatten(),
      hk.Linear(8),
      jax.nn.relu,
      partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
      partial(batch_norm, is_training=is_training),
      hk.Linear(8),
      jnp.tanh,
      hk.Linear(onp.prod(boxspace.shape)),
      hk.Reshape(boxspace.shape),
  ))
  return {'mu': mu(S), 'logvar': logvar(S)}
def net_fn(inputs): """Function representing multi-head DQN Q-network.""" network = hk.Sequential([ dqn_torso(), dqn_value_head(num_heads * num_actions), ]) network_output = network(inputs) multi_head_output = jnp.reshape(network_output, (-1, num_heads, num_actions)) mask = jax.random.choice(key=hk.next_rng_key(), a=2, shape=( multi_head_output.shape[0], num_heads, ), p=binomial_probabilities) random_head_indices = jax.random.choice( key=hk.next_rng_key(), a=num_heads, shape=(multi_head_output.shape[0], )) random_head_q_value = jnp.reshape( multi_head_output[:, random_head_indices], (-1, num_actions)) # TODO: make the q values (used for eval) the output of voting or weighted mean. # Currently random head q value used as placeholder return MultiHeadQNetworkOutputs( q_values=jnp.mean(multi_head_output, axis=1), multi_head_output=multi_head_output, random_head_q_value=random_head_q_value)
def __call__(self,
             h: jnp.ndarray,
             mask: Optional[jnp.ndarray],
             is_training: bool) -> jnp.ndarray:
  """Connects the transformer.

  Args:
    h: Inputs, [B, T, H].
    mask: Padding mask, [B, T].
    is_training: Whether we're training or not.

  Returns:
    Array of shape [B, T, H].
  """
  init_scale = 2. / np.sqrt(self._num_layers)
  dropout_rate = self._dropout_rate if is_training else 0.
  if mask is not None:
    mask = mask[:, None, None, :]

  # Note: names chosen to approximately match those used in the GPT-2 code;
  # see https://github.com/openai/gpt-2/blob/master/src/model.py.
  for i in range(self._num_layers):
    h_norm = layer_norm(h, name=f'h{i}_ln_1')
    h_attn = CausalSelfAttention(self._num_heads,
                                 init_scale,
                                 name=f'h{i}_attn')(h_norm, mask)
    h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
    h = h + h_attn
    h_norm = layer_norm(h, name=f'h{i}_ln_2')
    h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
    h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
    h = h + h_dense
  h = layer_norm(h, name='ln_f')
  return h
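# The block above draws its dropout keys via hk.next_rng_key(), so it only
# works inside hk.transform, and apply() must be given an explicit rng whenever
# is_training=True. A minimal usage sketch, assuming a hypothetical Transformer
# module wrapping the __call__ above; the hyperparameters and shapes below are
# made up for illustration.
def forward(h, mask, is_training):
  model = Transformer(num_heads=8, num_layers=6, dropout_rate=0.1)
  return model(h, mask, is_training)

forward = hk.transform(forward)

h = jnp.zeros([2, 16, 128])   # [B, T, H] dummy inputs
mask = jnp.ones([2, 16])      # [B, T] padding mask
rng = jax.random.PRNGKey(42)

params = forward.init(rng, h, mask, is_training=True)
out = forward.apply(params, rng, h, mask, is_training=True)  # [B, T, H]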
def get_latents(self, encodings, probs_b, training):
  """Read out latents (z) from input encodings for a single segment."""
  readout_mask = probs_b[:, 1:, None]  # Offset readout by 1 to left.
  readout = (encodings[:, :-1] * readout_mask).sum(1)
  hidden = nn.relu(self.head_z_1(readout))
  logits_z = self.head_z_2(hidden)

  # Gaussian latents.
  if self.latent_dist == 'gaussian':
    if training:
      mu, log_var = jnp.split(logits_z, 2, axis=1)
      sample_z = utils.gaussian_sample(hk.next_rng_key(), mu, log_var)
    else:
      sample_z = logits_z[:, :self.latent_dim]
  # Concrete / Gumbel softmax latents.
  elif self.latent_dist == 'concrete':
    if training:
      sample_z = utils.gumbel_softmax_sample(
          hk.next_rng_key(), logits_z, temp=self.temp_z)
    else:
      sample_z_idx = jnp.argmax(logits_z, axis=1)
      sample_z = utils.to_one_hot(sample_z_idx, logits_z.shape[1])
  else:
    raise ValueError('Invalid argument for `latent_dist`.')
  return logits_z, sample_z
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
  hiddens = x.shape[-1]
  x = conv1d(x, num_units=self._dense_dim, init_scale=self._init_scale)
  x = jax.nn.relu(x)
  x = hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
  x = conv1d(x, num_units=hiddens, init_scale=self._init_scale)
  return hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
def __call__(self, x: jnp.ndarray) -> VAEOutput:
  x = x.astype(jnp.float32)
  mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
  z = mean + stddev * jax.random.normal(hk.next_rng_key(), mean.shape)
  logits = Decoder(self._hidden_size, self._output_shape)(z)

  p = jax.nn.sigmoid(logits)
  image = jax.random.bernoulli(hk.next_rng_key(), p)

  return VAEOutput(image, mean, stddev, logits)
def forward_fn(x: jnp.ndarray) -> jnp.ndarray:
  linear_1 = linear_with_dropout(3, 0.5)
  transformed_linear = hk.transform(linear_1)
  inner_params = hk.experimental.lift(transformed_linear.init)(
      hk.next_rng_key(), x, True)

  def fun(_params, _rng, h):
    return transformed_linear.apply(_params, _rng, h, True)

  z = deq(inner_params, hk.next_rng_key(), x, fun, max_iter)
  return hk.Linear(output_size, name='l2', with_bias=False)(z)
def forward_fn(x: jnp.ndarray) -> jnp.ndarray:
  linear_1 = hk.Linear(output_size,
                       name='l1',
                       w_init=hk.initializers.Constant(1),
                       b_init=hk.initializers.Constant(1))
  transformed_linear = hk.transform(linear_1)
  inner_params = hk.experimental.lift(transformed_linear.init)(
      hk.next_rng_key(), x)
  z = deq(inner_params, hk.next_rng_key(), x, transformed_linear.apply,
          max_iter)
  return z
def _elbo_fun(input_data):
  if _ENCODER.value is EncoderArch.color_mnist_mlp_encoder:
    encoder = encoders.ColorMnistMLPEncoder(_LATENT_DIM.value)
  if _DECODER.value is DecoderArch.color_mnist_mlp_decoder:
    decoder = decoders.ColorMnistMLPDecoder(_OBS_VAR.value)
  vae_obj = vae.VAE(encoder, decoder, _RHO.value)
  if _MODEL.value is Model.vae:
    return vae_obj.vae_elbo(input_data, hk.next_rng_key())
  else:
    return vae_obj.avae_elbo(input_data, hk.next_rng_key())
def init_fn(rng: Optional[Union[PRNGKey]],
            inputs: Mapping[str, jnp.ndarray],
            batch_axes=(),
            return_initial_output=False,
            **kwargs) -> Tuple[Params, State]:
  """Initializes your function collecting parameters and state."""
  rng = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
  with new_custom_context(rng=rng) as ctx:
    # Create the model
    model = create_fun()

    # Load the batch axes for the inputs
    Layer.batch_axes = batch_axes

    key = hk.next_rng_key()

    # Initialize the model
    outputs = model(inputs, key, **kwargs)

    # Unset the batch axes
    Layer.batch_axes = ()

  nonlocal constants
  params, state, constants = (ctx.collect_params(),
                              ctx.collect_initial_state(),
                              ctx.collect_constants())

  if return_initial_output:
    return params, state, outputs

  return params, state
def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray, is_training: bool) -> jnp.ndarray: """Predict logits or values Parameters ---------- node_feats : ndarray of shape (batch_size, N, in_feats) Batch input node features. N is the total number of nodes in the batch of graphs. adj : ndarray of shape (batch_size, N, N) Batch adjacency matrix. is_training : bool Whether the model is training or not. Returns ------- out : ndarray of shape (batch_size, n_out) Predicator output. """ predicator_dropout = self.predicator_dropout if is_training is True else 0.0 node_feats = self.gcn(node_feats, adj, is_training) # pooling graph_feat = self.pooling(node_feats) if predicator_dropout != 0.0: graph_feat = hk.dropout(hk.next_rng_key(), predicator_dropout, graph_feat) graph_feat = self.fc(graph_feat) graph_feat = self.activation(graph_feat) out = self.out(graph_feat) return out
def __call__(self, q: jnp.ndarray, k: jnp.ndarray) -> jnp.ndarray:
  """Computes the relative position embedding.

  Args:
    q: The query.
    k: The key.

  Returns:
    Relative position embedding.
  """
  # Use key instead of query to obtain the length.
  batch_size, key_length, num_heads, head_dim = list(k.shape)
  # Content based addressing and global content bias.
  content_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_w_bias, k)

  # Relative position encoding.
  positional_encodings = self._sinusoidal_pos_emb(key_length, batch_size)
  positional_encodings = hk.dropout(hk.next_rng_key(), self._dropout_rate,
                                    positional_encodings)
  rel_pos_emb = hk.Conv1D(
      output_channels=self._dim,
      kernel_shape=1,
      with_bias=False,
      w_init=init.RandomNormal(stddev=self._init_scale))(positional_encodings)
  rel_pos_emb = jnp.reshape(
      rel_pos_emb, [batch_size, key_length, num_heads, head_dim])

  # Content dependent positional bias and global positional bias.
  rel_pos_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_r_bias,
                             rel_pos_emb)
  rel_pos_score = relative_shift(rel_pos_score)
  assert content_score.shape == rel_pos_score.shape
  return content_score + rel_pos_score
def generate_initial(self, context, length):
  # Slice the last token off the context (we use it in generate_once to
  # generate the first new token).
  last = context[-1:]
  context = context[:-1]

  input_len = context.shape[0]

  if self.rpe is not None:
    attn_bias = self.rpe(input_len, input_len, self.heads_per_shard, 32)
  else:
    attn_bias = 0

  x = self.embed(context)

  states = []
  for l in self.transformer_layers:
    res, layer_state = l.get_init_decode_state(x, length - 1, attn_bias)
    x = x + res
    states.append(layer_state)

  return self.proj(x), (last.astype(jnp.uint32), states, hk.next_rng_key())
def get_boundaries(self, encodings, segment_id, lengths, training):
  """Get boundaries (b) for a single segment in batch."""
  if segment_id == self.max_num_segments - 1:
    # Last boundary is always placed on last sequence element.
    logits_b = None
    # sample_b = jnp.zeros_like(encodings[:, :, 0]).scatter_(
    #     1, jnp.expand_dims(lengths, -1) - 1, 1)
    sample_b = jnp.zeros_like(encodings[:, :, 0])
    sample_b = jax.ops.index_update(
        sample_b, jax.ops.index[jnp.arange(len(lengths)), lengths - 1], 1)
  else:
    hidden = nn.relu(self.head_b_1(encodings))
    logits_b = jnp.squeeze(self.head_b_2(hidden), -1)

    # Mask out first position with large neg. value.
    neg_inf = jnp.ones((encodings.shape[0], 1)) * utils.NEG_INF
    # TODO(tkipf): Mask out padded positions with large neg. value.
    logits_b = jnp.concatenate([neg_inf, logits_b[:, 1:]], axis=1)

    if training:
      sample_b = utils.gumbel_softmax_sample(
          hk.next_rng_key(), logits_b, temp=self.temp_b)
    else:
      sample_b_idx = jnp.argmax(logits_b, axis=1)
      sample_b = nn.one_hot(sample_b_idx, logits_b.shape[1])
  return logits_b, sample_b
def __call__(self, inputs, is_training):
  dropout_rate = self._dropout_rate if is_training else 0.0
  h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs)
  h = hk.Linear(self._vocab_size, with_bias=False)(h)
  return hk.BatchNorm(create_scale=False, create_offset=False,
                      decay_rate=0.9)(h, is_training)
def __call__(self,
             node_feats: jnp.ndarray,
             adj: jnp.ndarray,
             graph_idx: jnp.ndarray,
             is_training: bool) -> jnp.ndarray:
  """Predict logits or values.

  Parameters
  ----------
  node_feats : ndarray of shape (N, in_feats)
    Batch input node features. N is the total number of nodes in the batch.
  adj : ndarray of shape (2, E)
    Batch adjacency list. E is the total number of edges in the batch.
  graph_idx : ndarray of shape (N,)
    Graph number for each entry of node_feats in the batch. Nodes that share
    the same graph idx belong to the same graph.
  is_training : bool
    Whether the model is training or not.

  Returns
  -------
  out : ndarray of shape (batch_size, n_out)
    Predicator output.
  """
  predicator_dropout = self.predicator_dropout if is_training else 0.0
  node_feats = self.gcn(node_feats, adj, is_training)
  # pooling
  graph_feat = self.pooling(node_feats, graph_idx)
  if predicator_dropout != 0.0:
    graph_feat = hk.dropout(hk.next_rng_key(), predicator_dropout, graph_feat)
  graph_feat = self.fc(graph_feat)
  graph_feat = self.activation(graph_feat)
  out = self.out(graph_feat)
  return out
def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
  dropout_rate = self._dropout_rate if is_training else 0.
  h = hk.Linear(self.output_size,
                w_init=hk.initializers.Constant(1),
                b_init=hk.initializers.Constant(1))(x)
  return hk.dropout(hk.next_rng_key(), dropout_rate, h)
def wrapped(*args):
  out = fn(*args)
  if is_training:
    mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                      jnp.ones([out.shape[0], 1]))
    out = out * mask
  return out
def __call__(self, x, is_training=True, return_metrics=False):
  """Return the output of the final layer without any [log-]softmax."""
  # Stem
  outputs = {}
  out = self.initial_conv(x)
  out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                    strides=(1, 2, 2, 1), padding='SAME')
  if return_metrics:
    outputs.update(base.signal_metrics(out, 0))

  # Blocks
  for i, block in enumerate(self.blocks):
    out, res_avg_var = block(out, is_training=is_training)
    if return_metrics:
      outputs.update(base.signal_metrics(out, i + 1))
      outputs[f'res_avg_var_{i}'] = res_avg_var

  # Final-conv->activation, pool, dropout, classify
  pool = jnp.mean(self.activation(out), [1, 2])
  outputs['pool'] = pool
  # Optionally apply dropout
  if self.drop_rate > 0.0 and is_training:
    pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
  outputs['logits'] = self.fc(pool)
  return outputs
def layer(
    self,
    x: jnp.ndarray,
    latents: jnp.ndarray,
    output_channels: int,
    upsample: bool = False,
) -> jnp.ndarray:
  if upsample:
    conv = UpsampleConv2D(
        output_channels=output_channels,
        kernel_shape=3,
        upsample_factor=2,
        resample_kernel=self.resample_kernel,
    )
  else:
    conv = ModulatedConv2D(output_channels=output_channels,
                           kernel_shape=3,
                           padding="SAME")
  y = conv(x, latents)

  if self.data_format == ChannelOrder.channels_first:
    noise_shape = (y.shape[0], 1, y.shape[2], y.shape[3])
  else:
    noise_shape = (y.shape[0], y.shape[1], y.shape[2], 1)

  key = hk.next_rng_key()
  noise = jax.random.normal(key, shape=noise_shape, dtype=y.dtype)
  noise_strength = hk.get_parameter(
      "noise_strength", (1, 1, 1, 1), dtype=y.dtype, init=jnp.zeros)
  y += noise_strength * noise
  return self.activation_function(y)
def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray, is_training: bool) -> jnp.ndarray: """Update node features. Parameters ---------- node_feats : ndarray of shape (batch_size, N, in_feats) Batch input node features. N is the total number of nodes in the batch of graphs. adj : ndarray of shape (batch_size, N, N) Batch adjacency matrix. is_training : bool Whether the model is training or not. Returns ------- new_node_feats : ndarray of shape (batch_size, N, out_feats) Batch new node features. """ dropout = self.dropout if is_training is True else 0.0 # for batch data new_node_feats = jax.vmap(self._update_nodes)(node_feats, adj) if self.bias: new_node_feats += self.b new_node_feats = self.activation(new_node_feats) if dropout != 0.0: new_node_feats = hk.dropout(hk.next_rng_key(), dropout, new_node_feats) if self.batch_norm: new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats, is_training) return new_node_feats
def fn(x):
  if dropout:
    x = hk.dropout(hk.next_rng_key(), 0.5, x)
  if batchnorm:
    x = hk.BatchNorm(create_offset=True,
                     create_scale=True,
                     decay_rate=0.001)(x, is_training=True)
  return x
def maybe_dropedge(x):
  """Dropout on edge messages."""
  if not is_training:
    return x
  return x * hk.dropout(
      hk.next_rng_key(),
      dropedge_rate,
      jnp.ones([x.shape[0], 1]),
  )
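# The jnp.ones([x.shape[0], 1]) mask above drops whole edge messages rather
# than individual features: hk.dropout zeroes entire rows, rescales the
# survivors by 1 / (1 - dropedge_rate), and the [E, 1] mask broadcasts across
# the feature axis. A tiny illustrative call with made-up shapes (edge_feats is
# hypothetical and, like the helper itself, must run inside hk.transform):
num_edges, feat_dim = 32, 8
edge_feats = jnp.ones([num_edges, feat_dim])  # [E, D], one message per edge
edge_feats = maybe_dropedge(edge_feats)       # each row kept or zeroed as a unit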
def func_discrete_type2(S, is_training):
  batch_norm = hk.BatchNorm(False, False, 0.99)
  seq = hk.Sequential((
      hk.Flatten(),
      hk.Linear(8),
      jax.nn.relu,
      partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
      partial(batch_norm, is_training=is_training),
      hk.Linear(8),
      jnp.tanh,
      hk.Linear(discrete.n * discrete.n),
      hk.Reshape((discrete.n, discrete.n)),
      jax.nn.softmax,
  ))
  return seq(S)
def func_discrete_type1(S, A, is_training):
  batch_norm = hk.BatchNorm(False, False, 0.99)
  seq = hk.Sequential((
      hk.Flatten(),
      hk.Linear(8),
      jax.nn.relu,
      partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
      partial(batch_norm, is_training=is_training),
      hk.Linear(8),
      jnp.tanh,
      hk.Linear(discrete.n),
      jax.nn.softmax,
  ))
  X = jax.vmap(jnp.kron)(S, A)
  return seq(X)