def compute_cross_entropy_loss_with_positives_and_negatives_masks(
    scores: Array,
    positives: Array,
    negatives: Array,
    weights: Optional[Array] = None,
) -> Tuple[float, Dict[str, float], Tuple[Array, Array]]:
  """Cross-entropy loss and accuracy with per-sample positive/negative masks.

  Every sample may have several positive classes, several negative classes
  and any number of neutral classes (neither positive nor negative); neutral
  classes are ignored. For sample i and each of its positive classes j, one
  cross-entropy term is formed from j's score against the scores of all the
  negatives of sample i:

      loss_ij = -log softmax([s_ij, s_i,n_1, ..., s_i,n_M])[0]

  The per-sample loss is the mean of loss_ij over the positives of sample i.
  The returned loss is the weighted sum of per-sample losses.

  Only samples with at least one positive AND at least one negative class
  contribute. `weights` can restrict this set further.

  Args:
    scores: [batch_size, num_classes] scores or logits.
    positives: [batch_size, num_classes] 0-1 mask; `positives[i, j]` is true
      iff class j is a positive for sample i.
    negatives: [batch_size, num_classes] 0-1 mask; `negatives[i, j]` is true
      iff class j is a negative for sample i.
    weights: optional [batch_size] 0-1 mask selecting which samples count.

  Returns:
    `(loss, metrics, (accuracy_per_sample, weight_per_sample))`.
    - `loss` is the summed scalar loss.
    - `metrics` has keys 'loss', 'acc' (sum of per-sample accuracies) and
      'denominator' (sum of sample weights).
    - The per-sample accuracy is the fraction of a sample's positives whose
      score beats every negative of that sample. It is zero for masked-out
      samples.
  """
  has_positive = positives.sum(-1) > 0
  has_negative = negatives.sum(-1) > 0
  usable = jnp.logical_and(has_positive, has_negative)
  if weights is not None:
    usable = jnp.logical_and(weights, usable)

  logits = scores.astype(jnp.float32)
  pos_mask = positives.astype(jnp.float32)
  neg_mask = negatives.astype(jnp.float32)
  sample_weight = usable.astype(jnp.float32)

  # Keep negative scores and push every other class towards -inf, so that
  # non-negatives vanish from the logsumexp and from the max below.
  masked_negative_scores = logits * neg_mask - _BIG_NUMBER * (1.0 - neg_mask)
  # log of the negatives' part of the softmax denominator: S with
  # exp(S) = sum_j exp(n_j).
  log_neg_denominator = jax.nn.logsumexp(
      masked_negative_scores, axis=-1, keepdims=True)

  # -log(exp(p) / (exp(p) + exp(S))) == -log sigmoid(p - S): a single
  # vectorized pass instead of one softmax per positive class.
  per_class_loss = -jax.nn.log_sigmoid(logits - log_neg_denominator)
  # Small epsilon keeps samples without positives from dividing by zero.
  num_positives = pos_mask.sum(axis=-1) + _SMALL_NUMBER
  per_sample_loss = (
      (per_class_loss * pos_mask).sum(axis=-1) / num_positives * sample_weight)
  total_loss = per_sample_loss.sum()

  # A positive class counts as correctly predicted only if it outscores the
  # strongest negative of its sample.
  hardest_negative = masked_negative_scores.max(axis=-1, keepdims=True)
  beats_negatives = (logits > hardest_negative).astype(jnp.float32)
  per_sample_accuracy = (
      (beats_negatives * pos_mask).sum(axis=-1) / num_positives *
      sample_weight)

  metrics = {
      'loss': total_loss,
      'acc': per_sample_accuracy.sum(),
      'denominator': sample_weight.sum(),
  }
  return total_loss, metrics, (per_sample_accuracy, sample_weight)
def has_updated(self, state: MultiStepsState) -> Array:
  """Return whether `state` is at an update boundary.

  The result is true exactly when `state.mini_step == 0` and
  `state.gradient_step > 0`. It is presumably "an accumulated gradient was
  just applied"; verify against the surrounding class.
  """
  mini_step_reset = state.mini_step == 0
  took_a_gradient_step = state.gradient_step > 0
  return jnp.logical_and(mini_step_reset, took_a_gradient_step)
def cond_function(carry):
  """Loop predicate over `carry = (k, _, r, qnorm_scaled)`.

  Keeps iterating while `k < max_iterations - 1` and the norm of `r` (as
  returned by `_safe_normalize`) is still below `qnorm_scaled`.
  `max_iterations` and `_safe_normalize` come from the enclosing scope.
  """
  k, _, residual, qnorm_scaled = carry
  residual_norm = _safe_normalize(residual)[1]
  within_iteration_budget = k < (max_iterations - 1)
  residual_still_small = residual_norm < qnorm_scaled
  return jnp.logical_and(within_iteration_budget, residual_still_small)
def cond_fun(value):
  """Loop predicate over a 4-tuple `(_, k, _, residual_norm)`.

  Continues while `k < maxiter` and `residual_norm > atol`. `maxiter` and
  `atol` come from the enclosing scope.
  """
  _, step, _, res_norm = value
  has_budget = step < maxiter
  above_tolerance = res_norm > atol
  return jnp.logical_and(has_budget, above_tolerance)
def _do_cholesky(args): matrix, j, coefs, err = args unconverged = err > (eps * jnp.linalg.norm(matrix)) iterating = j < maxiter return jnp.logical_and(unconverged, iterating)[0]
def cond_f(args):
  """Loop predicate over a 4-tuple `(_, _, j, error)`.

  Returns the first element of `(j < maxiter) & (error > thresh)`.
  `maxiter` and `thresh` come from the enclosing scope.
  """
  _, _, counter, err = args
  keep_going = jnp.logical_and(counter < maxiter, err > thresh)
  return keep_going[0]
def multi_head_dot_product_attention(scope: Scope,
                                     inputs_q,
                                     inputs_kv,
                                     num_heads,
                                     dtype=jnp.float32,
                                     qkv_features=None,
                                     out_features=None,
                                     attention_axis=None,
                                     causal_mask=False,
                                     padding_mask=None,
                                     key_padding_mask=None,
                                     segmentation=None,
                                     key_segmentation=None,
                                     cache=False,
                                     broadcast_dropout=True,
                                     dropout_rng=None,
                                     dropout_rate=0.,
                                     deterministic=False,
                                     precision=None,
                                     kernel_init=default_kernel_init,
                                     bias_init=initializers.zeros,
                                     bias=True,
                                     attention_fn=dot_product_attention):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    scope: functional-core scope; child scopes 'query', 'key', 'value' and
      'out' are created for the projections.
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. `qkv_features` (defaulting to
      inputs_q.shape[-1]) must be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (None means
      attention over all axes but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: whether to use the ('cache', 'entry') variable for efficient
      autoregressive decoding (requires `causal_mask`).
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation; see
      `jax.lax.Precision` for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts
      query, key, value, and returns output of shape
      `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.

  Raises:
    AssertionError: if `cache` is set without `causal_mask`, or if
      `qkv_features` is not divisible by `num_heads`.
    ValueError: if the cache entry is malformed or the decoding input shape
      does not match the cache.
  """
  assert causal_mask or not cache, (
      'Caching is only supported for causal attention.')

  if inputs_kv is None:
    inputs_kv = inputs_q

  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  dense = functools.partial(dense_general,
                            axis=-1,
                            dtype=dtype,
                            features=(num_heads, head_dim),
                            kernel_init=kernel_init,
                            bias_init=bias_init,
                            bias=bias,
                            precision=precision)
  # project inputs_q to multi-headed q/k/v
  # dimensions are then [bs, dims..., n_heads, n_features_per_head]
  query = scope.child(dense, 'query')(inputs_q)
  key = scope.child(dense, 'key')(inputs_kv)
  value = scope.child(dense, 'value')(inputs_kv)

  if cache:
    if not scope.has_variable('cache', 'entry'):
      # First call: store an initializer function (not a CacheEntry) under
      # ('cache', 'entry'). Presumably the caller invokes it with the
      # batch + attention-dims shape to build the real cache; TODO confirm.
      ndim, tail_shape = (key.ndim, key.shape[-2:])

      def init_fn(shape, dtype=jnp.float32):
        full_shape = shape + tail_shape
        if len(full_shape) != ndim:
          # Trailing space so the concatenated literals read correctly.
          raise ValueError(
              'Shape should be a tuple with the shape of the batch '
              'and attention dims.')
        return CacheEntry(key=jnp.zeros(full_shape, dtype),
                          value=jnp.zeros(full_shape, dtype),
                          i=jnp.zeros((), jnp.uint32))

      cache_entry = init_fn
    else:
      cache_entry = scope.get_variable('cache', 'entry')
      if not isinstance(cache_entry, CacheEntry):
        raise ValueError('Cache is not initialized.')

      # During decoding each call must provide exactly one position along
      # every attention axis.
      expected_shape = list(cache_entry.key.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, '
                         'expected shape %s instead got %s.' %
                         (expected_shape, inputs_q.shape))

      # Convert the flat position counter `i` into per-axis indices
      # (row-major over the attention axes) for the slice update.
      cshape = cache_entry.key.shape
      indices = [0] * len(cshape)
      i = cache_entry.i
      attn_size = onp.prod(onp.take(cshape, attention_axis))
      for attn_dim in attention_axis:
        attn_size //= cshape[attn_dim]
        indices[attn_dim] = i // attn_size
        i = i % attn_size

      key = lax.dynamic_update_slice(cache_entry.key, key, indices)
      value = lax.dynamic_update_slice(cache_entry.value, value, indices)
      one = jnp.array(1, jnp.uint32)
      cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                        key=key,
                                        value=value)

      # TODO(levskaya): verify this is still needed in translation decoding.
      # Only cache positions already written (< i) are valid keys.
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]
    scope.put_variable('cache', 'entry', cache_entry)

  # create attention masks
  mask_components = []

  if causal_mask:
    if cache and isinstance(cache_entry, CacheEntry):
      # Decoding: attend only to cache positions filled so far.
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(onp.take(key.shape, attention_axis))
      attn_size = onp.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.uint32)
      mask = ii < cache_entry.i
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))

  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                     padding_mask_key=key_padding_mask,
                                     query_shape=query.shape,
                                     key_shape=key.shape,
                                     attention_axis=attention_axis)
    mask_components.append(padding_mask)

  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(padding_mask_query=segmentation,
                                          padding_mask_key=key_segmentation,
                                          query_shape=query.shape,
                                          key_shape=key.shape,
                                          attention_axis=attention_axis,
                                          segmentation_mask=True)
    mask_components.append(segmentation_mask)

  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # attention mask in the form of attention bias: 0 where allowed, a large
    # negative value where masked out.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # apply attention
  x = scope.child(attention_fn)(query,
                                key,
                                value,
                                dtype=dtype,
                                axis=attention_axis,
                                bias=attention_bias,
                                precision=precision,
                                dropout_rng=dropout_rng,
                                dropout_rate=dropout_rate,
                                broadcast_dropout=broadcast_dropout,
                                deterministic=deterministic)

  # back to the original inputs dimensions
  out = scope.child(dense_general, name='out')(x,
                                               features=features,
                                               axis=(-2, -1),
                                               kernel_init=kernel_init,
                                               bias_init=bias_init,
                                               bias=bias,
                                               dtype=dtype,
                                               precision=precision)

  return out
def eval_step(state, inputs, outputs, programs, bos_token, eos_token, config,
              lp_config):
  """Evaluate on batch of program tasks.

  The batch goes through two pipelines:
  1. the autoencoder (`LatentProgramTransformer` on programs), and
  2. the end-to-end pipeline: latent predictor (`ProgramTransformer`), then
     quantize, then decode.

  Args:
    state: training state; `optimizer.target`, `lp_optimizer.target` and
      `model_state` are read.
    inputs: input examples passed to both models.
    outputs: output examples; positions > 0 form the encoder mask.
    programs: target program token ids; ids > 0 are counted in the weights.
    bos_token: BOS id; excluded from the autoencoding embedding mask.
    eos_token: EOS id; excluded from that mask, appended to the latent
      indices, and excluded from the latent mask.
    config: config for `models.LatentProgramTransformer`.
    lp_config: config for `models.ProgramTransformer` (latent predictor).

  Returns:
    `(metrics, latent_metrics)`.
    - `metrics` holds end-to-end program metrics plus autoencoder metrics
      prefixed 'ae_'.
    - `latent_metrics` holds latent-prediction metrics prefixed 'latent_'.
  """
  params = state.optimizer.target
  lp_params = state.lp_optimizer.target

  def run_latent_model(*args, **kwargs):
    # Fresh module + variables dict per call, never mutating model state.
    return models.LatentProgramTransformer(config).apply(
        {'params': params, 'vqvae': state.model_state},
        *args,
        mutable=False,
        **kwargs)

  def positive_mask(ids):
    return jnp.where(ids > 0, 1, 0).astype(jnp.float32)

  weights = positive_mask(programs)

  # Embedding mask for autoencoding: codes 0, BOS and EOS are disabled.
  emb_mask = jnp.ones((1, FLAGS.latent_vocab_size), jnp.float32)
  emb_mask = emb_mask.at[:, [0, bos_token, eos_token]].set(0)

  # Autoencoder pass.
  ae_logits, vq = run_latent_model(inputs, outputs, programs, emb_mask)

  # Postprocess latent indices.
  latent_indices = add_eos_token(vq['latent_indices'], eos_token)
  latent_weights = positive_mask(latent_indices)
  encoded_mask = positive_mask(outputs)
  # Latent positions that are neither padding nor EOS.
  latents_mask = jnp.where(
      jnp.logical_and(latent_indices > 0, latent_indices != eos_token), 1,
      0).astype(jnp.float32)

  # End-to-end pass: predict latents, quantize them, decode programs.
  latent_logits = models.ProgramTransformer(lp_config).apply(
      {'params': lp_params}, inputs, outputs, latent_indices)
  encoded = run_latent_model(
      inputs, outputs, method=models.LatentProgramTransformer.encode)
  latents = run_latent_model(
      latent_logits, method=models.LatentProgramTransformer.quantize)
  logits = run_latent_model(
      programs, latents, encoded, latents_mask, encoded_mask,
      method=models.LatentProgramTransformer.decode)

  metrics = compute_metrics(logits, programs, weights)
  metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
  latent_metrics = compute_metrics(latent_logits, latent_indices,
                                   latent_weights, prefix='latent_')
  return metrics, latent_metrics
def predict_step(state,
                 inputs,
                 outputs,
                 cache,
                 lp_cache,
                 beam_size,
                 bos_token,
                 eos_token,
                 max_decode_len,
                 config,
                 lp_config):
  """Predict programs with a two-stage beam search on a batch.

  Stage 1 beam-searches latent token sequences with the latent predictor
  (`models.ProgramTransformer`) using `FLAGS.latent_beam_size` beams. Stage 2
  beam-searches program tokens with `models.LatentProgramTransformer`,
  conditioned on each quantized latent sequence, using
  `beam_size // FLAGS.latent_beam_size` beams per latent sequence.

  Args:
    state: training state; `optimizer.target`, `lp_optimizer.target` and
      `model_state` are read.
    inputs: input examples for the batch.
    outputs: output examples; positions > 0 form the encoder padding mask.
    cache: decoder cache for the program decoder (stage 2). It is expanded
      by `FLAGS.latent_beam_size` before use.
    lp_cache: decoder cache for the latent predictor (stage 1).
    beam_size: requested total beam size. It is rounded down to a multiple
      of `FLAGS.latent_beam_size`.
      NOTE(review): if `beam_size < FLAGS.latent_beam_size`, the per-latent
      beam size becomes 0. Callers presumably avoid this; confirm.
    bos_token: BOS token id.
    eos_token: EOS token id.
    max_decode_len: maximum program length. Latents are decoded up to
      `ceil(max_decode_len / 2**FLAGS.c)` tokens.
    config: config for `models.LatentProgramTransformer`.
    lp_config: config for `models.ProgramTransformer`.

  Returns:
    `(beam_seqs, latent_beam_seqs)`.
    - `beam_seqs` is reshaped to `[batch, beam_size, ...]`.
    - `latent_beam_seqs` is repeated along axis 1 so that each program beam
      lines up with the latent sequence it was decoded from.
  """
  params = state.optimizer.target
  lp_params = state.lp_optimizer.target

  # Split beam over latent sequences and programs.
  per_latent_beam_size = beam_size // FLAGS.latent_beam_size
  beam_size = FLAGS.latent_beam_size * per_latent_beam_size

  flat_lp_encoded = decode.flat_batch_beam_expand(
      models.ProgramTransformer(lp_config).apply(
          {'params': lp_params},
          inputs,
          outputs,
          method=models.ProgramTransformer.encode),
      FLAGS.latent_beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, FLAGS.latent_beam_size)

  def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
    """Token slice to logits from the latent-predictor decoder."""
    # NOTE: `flat_encoded_padding_mask` is read from the enclosing scope when
    # this is called. Calls happen inside the stage-1 beam search below,
    # before the name is rebound to the stage-2 (beam_size) expansion.
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
        {
            'params': lp_params,
            'cache': flat_lp_cache
        },
        flat_ids,
        flat_lp_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.ProgramTransformer.decode)
    new_flat_lp_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_lp_cache

  # Step 1: Beam-search over latent tokens.
  latent_beam_seqs, _ = decode.beam_search(
      inputs,
      lp_cache,
      tokens_ids_to_latent_logits,
      beam_size=FLAGS.latent_beam_size,
      alpha=0.6,
      bos_token=bos_token,
      eos_token=eos_token,
      max_decode_len=np.ceil(max_decode_len / 2**FLAGS.c).astype(np.int32))

  # Every latent sequence is repeated once per program beam.
  flat_latent_seqs = decode.flat_batch_beam_expand(
      decode.flatten_beam_dim(latent_beam_seqs), per_latent_beam_size)
  # Quantize the predicted latent codes.
  flat_latents = models.LatentProgramTransformer(config).apply(
      {
          'params': params,
          'vqvae': state.model_state
      },
      flat_latent_seqs,
      mutable=False,
      method=models.LatentProgramTransformer.quantize)

  flat_encoded = decode.flat_batch_beam_expand(
      models.LatentProgramTransformer(config).apply(
          {
              'params': params,
              'vqvae': state.model_state
          },
          inputs,
          outputs,
          mutable=False,
          method=models.LatentProgramTransformer.encode),
      beam_size)

  # Padding masks: latent positions that are neither padding nor EOS.
  flat_latents_mask = jnp.where(
      jnp.logical_and(flat_latent_seqs > 0, flat_latent_seqs != eos_token),
      1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from the program decoder."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state,
            'cache': flat_cache
        },
        flat_ids,
        flat_latents,
        flat_encoded,
        flat_latents_mask,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.LatentProgramTransformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Step 2: Beam-search over program tokens. Each latent beam is treated as
  # its own "batch" entry.
  per_latent_inputs = decode.flat_batch_beam_expand(
      inputs, FLAGS.latent_beam_size)
  per_latent_cache = jax.tree_map(
      lambda x: decode.flat_batch_beam_expand(x, FLAGS.latent_beam_size),
      cache)
  beam_seqs, _ = decode.beam_search(
      per_latent_inputs,
      per_latent_cache,
      tokens_ids_to_logits,
      beam_size=per_latent_beam_size,
      alpha=0.6,
      bos_token=bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len)
  # Collapse both beam dimensions into one.
  beam_seqs = beam_seqs.reshape((inputs.shape[0], beam_size) +
                                beam_seqs.shape[2:])
  latent_beam_seqs = jnp.repeat(latent_beam_seqs, per_latent_beam_size, axis=1)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  # NOTE(review): after the reshape above, that ordering presumably holds
  # only within each latent group of `per_latent_beam_size` beams. The
  # combined beam axis is not re-sorted here.
  return beam_seqs, latent_beam_seqs
def antiu(x, xi):
  """Return `g(x, xi)` as a NumPy array with in-band values replaced by NaN.

  Elements strictly between `lb(x)` and `ub(x)` become NaN. Only values
  outside the open band `(lb(x), ub(x))` remain. `g`, `lb` and `ub` come from
  the enclosing module.

  A boolean-mask assignment is used so that this works for arrays of any
  rank. The previous `onp.where(mask)[0]` kept only the first-axis indices,
  which NaN'd entire rows of an N-D array. For 1-D arrays the result is
  unchanged.

  Args:
    x: first argument forwarded to `g`, `lb` and `ub`.
    xi: second argument forwarded to `g`.

  Returns:
    A new NumPy array. Assigning NaN assumes `g` yields floating-point
    values; an integer array would raise.
  """
  dark = onp.array(g(x, xi))
  inside_band = onp.logical_and(dark > lb(x), dark < ub(x))
  dark[inside_band] = onp.nan
  return dark
def train_step(state,
               inputs,
               outputs,
               programs,
               pretrain,
               bos_token,
               eos_token,
               learning_rate_fn,
               config,
               lp_config,
               train_rng=None):
  """Train on batch of program tasks.

  Two gradient computations are combined in one step:
  1. An autoencoder loss (program reconstruction + `vq['loss']`) w.r.t. the
     `LatentProgramTransformer` params.
  2. A latent-prediction loss w.r.t. the `ProgramTransformer` (latent
     predictor) params. Unless `pretrain` is set, an end-to-end program loss
     is added; its gradient also reaches the `LatentProgramTransformer`
     params.

  Gradients are `pmean`-ed over the 'batch' axis, so this must run inside a
  pmap with that axis name.

  Args:
    state: training state with `step`, `optimizer`, `lp_optimizer` and
      `model_state`.
    inputs: input examples.
    outputs: output examples; positions > 0 form the encoder mask.
    programs: target program token ids; ids > 0 are weighted.
    pretrain: if truthy, the end-to-end program loss is excluded from
      `loss_fn`. It is also forwarded to the autoencoder.
    bos_token: BOS token id.
    eos_token: EOS token id.
    learning_rate_fn: maps the step to a learning rate.
    config: config for `models.LatentProgramTransformer`.
    lp_config: config for `models.ProgramTransformer`.
    train_rng: PRNG key; it is split immediately.
      NOTE(review): the `None` default would fail in `jax.random.split`, so
      callers presumably always pass a key.

  Returns:
    `(new_state, metrics, latent_metrics, new_train_rng)`.
  """
  # We handle PRNG splitting inside the top pmap, rather
  # than handling it outside in the training loop - doing the
  # latter can add some stalls to the devices.
  train_rng, new_train_rng = jax.random.split(train_rng)

  weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)

  # Embedding mask for autoencoding: codes 0, BOS and EOS are disabled.
  emb_mask = jnp.ones((1, FLAGS.latent_vocab_size),
                      jnp.float32).at[:, [0, bos_token, eos_token]].set(0)

  def ae_loss_fn(params):
    """Loss function used for training autoencoder."""
    (logits, vq), new_variables = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        inputs,
        outputs,
        programs,
        emb_mask,
        pretrain=pretrain,
        mutable=['vqvae'],
        rngs={'dropout': train_rng})
    loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)

    # Add EOS token for latent predictor loss.
    vq_weight_sum = jnp.sum(
        jnp.where(vq['latent_indices'] > 0, 1, 0).astype(jnp.float32))
    latent_indices = add_eos_token(vq['latent_indices'], eos_token)

    mean_loss = loss / weight_sum + vq['loss'] / vq_weight_sum
    return mean_loss, (new_variables['vqvae'], logits, latent_indices)

  step = state.step
  optimizer = state.optimizer
  lp_optimizer = state.lp_optimizer
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(ae_loss_fn, has_aux=True)
  (_, (new_model_state, ae_logits, latent_indices)), ae_grad = grad_fn(
      optimizer.target)
  ae_grad = jax.lax.pmean(ae_grad, 'batch')

  # The autoencoder's latent codes (with EOS) are the targets for the latent
  # predictor.
  latent_weights = jnp.where(latent_indices > 0, 1, 0).astype(jnp.float32)

  encoded_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  # Additionally mask out eos token in latents.
  latents_mask = jnp.where(
      jnp.logical_and(latent_indices > 0, latent_indices != eos_token), 1,
      0).astype(jnp.float32)

  def loss_fn(params, lp_params):
    """Loss function used for training."""
    # NOTE(review): the same `train_rng` drives dropout here and in
    # `ae_loss_fn`, so the dropout masks may be correlated across the two
    # passes. Confirm this is intended.
    latent_logits = models.ProgramTransformer(lp_config).apply(
        {'params': lp_params},
        inputs,
        outputs,
        latent_indices,
        rngs={'dropout': train_rng})
    latent_loss, latent_weight_sum = compute_weighted_cross_entropy(
        latent_logits, latent_indices, latent_weights)

    # End-to-end prediction.
    encoded = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        inputs,
        outputs,
        mutable=False,
        rngs={'dropout': train_rng},
        method=models.LatentProgramTransformer.encode)
    latents = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        latent_logits,
        mutable=False,
        rngs={'dropout': train_rng},
        method=models.LatentProgramTransformer.quantize)
    logits = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        programs,
        latents,
        encoded,
        latents_mask,
        encoded_mask,
        mutable=False,
        rngs={'dropout': train_rng},
        method=models.LatentProgramTransformer.decode)
    loss, weight_sum = compute_weighted_cross_entropy(logits, programs, weights)

    mean_loss = latent_loss / latent_weight_sum
    if not pretrain:
      mean_loss += loss / weight_sum
    return mean_loss, (logits, latent_logits)

  grad_fn = jax.value_and_grad(loss_fn, argnums=[0, 1], has_aux=True)
  (_, (logits, latent_logits)), grads = grad_fn(optimizer.target,
                                                lp_optimizer.target)
  grads = jax.lax.pmean(grads, 'batch')
  # The LatentProgramTransformer update uses the sum of its autoencoder and
  # end-to-end gradients.
  # NOTE(review): `jax.tree_multimap` was removed in newer JAX releases;
  # `jax.tree_util.tree_map` accepts multiple trees and is the replacement.
  new_optimizer = optimizer.apply_gradient(
      jax.tree_multimap(jnp.add, grads[0], ae_grad), learning_rate=lr)
  new_lp_optimizer = lp_optimizer.apply_gradient(grads[1], learning_rate=lr)

  metrics = compute_metrics(logits, programs, weights)
  metrics['learning_rate'] = lr
  metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
  latent_metrics = compute_metrics(latent_logits, latent_indices,
                                   latent_weights, prefix='latent_')

  new_state = state.replace(
      step=step + 1,
      optimizer=new_optimizer,
      model_state=jax.lax.pmean(new_model_state, 'batch'),
      lp_optimizer=new_lp_optimizer)
  return new_state, metrics, latent_metrics, new_train_rng
def __call__(self,
             inputs_q,
             inputs_kv,
             *,
             padding_mask,
             key_padding_mask,
             segmentation=None,
             key_segmentation=None):
  """Applies multi-head dot product attention on the input data.

  Quantization of the dense projections and attention activations is
  configured through `self.hparams` (`dense_kqv`, `dense_out`, `attn_acts`).

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, query_len, features]`. The code
      unpacks exactly three dimensions.
    inputs_kv: key/values of shape `[bs, key_len, features]`, or None for
      self-attention, in which case key/values will be derived from
      inputs_q.
    padding_mask: query padding mask of shape `[bs, query_len, 1]`, or None.
    key_padding_mask: key padding mask of shape `[bs, key_len, 1]`, or None.
      It defaults to `padding_mask`.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.

  Returns:
    output of shape `[bs, query_len, features]`.
  """
  batch_size, query_sequence_length, channel_size = inputs_q.shape
  hparams = self.hparams
  if inputs_kv is None:
    inputs_kv = inputs_q
    key_sequence_length = inputs_q.shape[1]
  else:
    key_sequence_length = inputs_kv.shape[1]
    shape_utils.assert_shapes_equal(
        inputs_kv.shape, (batch_size, key_sequence_length, channel_size))

  jax_precision = jax.lax.Precision.DEFAULT

  if padding_mask is not None:
    shape_utils.assert_shapes_equal(padding_mask.shape,
                                    (batch_size, query_sequence_length, 1))

  if key_padding_mask is None:
    key_padding_mask = padding_mask
  else:
    shape_utils.assert_shapes_equal(key_padding_mask.shape,
                                    (batch_size, key_sequence_length, 1))

  attention_axis = self.attention_axis
  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  qkv_features = self.qkv_features
  qkv_features = qkv_features or inputs_q.shape[-1]

  num_heads = self.num_heads
  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  paxis_name = self.paxis_name
  train = self.train
  kernel_init = self.kernel_init
  bias_init = self.bias_init
  use_bias = self.use_bias
  dtype = self.dtype

  def multi_batch_dense_aqt(inputs, *, name, padding_mask):
    """Project [bs, len, ch] -> [bs, len, num_heads, head_dim] with DenseAqt."""
    # DenseAqt is applied to a flattened [bs * len, ch] matrix.
    batch_size, sequence_length, channel_size = inputs.shape
    inputs = inputs.reshape(batch_size * sequence_length, channel_size)
    if padding_mask is not None:
      padding_mask = padding_mask.reshape(batch_size * sequence_length, 1)
    out = flax_layers.DenseAqt(
        name=name,
        features=num_heads * head_dim,
        paxis_name=paxis_name,
        train=train,
        quant_context=self.quant_context,
        hparams=hparams.dense_kqv,
        kernel_init=kernel_init,
        bias_init=bias_init,
        use_bias=use_bias,
        dtype=dtype)(inputs, padding_mask=padding_mask)
    return out.reshape(batch_size, sequence_length, num_heads, head_dim)

  # project inputs_q to multi-headed q/k/v
  # dimensions are then [bs, sequence_length, n_heads, n_features_per_head]
  query = multi_batch_dense_aqt(
      inputs_q, name='query', padding_mask=padding_mask)
  key = multi_batch_dense_aqt(
      inputs_kv, name='key', padding_mask=key_padding_mask)
  value = multi_batch_dense_aqt(
      inputs_kv, name='value', padding_mask=key_padding_mask)

  is_cache_initialized = False
  if self.decode:
    # The cache variables are created (zero-filled) on the first call. They
    # are only read and updated once they already exist.
    is_cache_initialized = self.has_variable('cache', 'cached_key')
    cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape,
                               key.dtype)
    cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                 value.shape, value.dtype)
    cache_index = self.variable('cache', 'cache_index',
                                lambda: jnp.array(0, dtype=jnp.int32))
    if is_cache_initialized:
      # Each decoding call must provide exactly one position along every
      # attention axis.
      expected_shape = list(cached_key.value.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, '
                         'expected shape %s instead got %s.' %
                         (expected_shape, inputs_q.shape))

      cshape = cached_key.value.shape
      # NOTE(review): this `indices` list and `attn_size` are dead. `indices`
      # is reassigned below and `attn_size` is never read in this branch.
      indices = [0] * len(cshape)
      i = cache_index.value
      attn_size = onp.prod(onp.take(cshape, attention_axis))

      # NOTE: this rebinds `num_heads` from the cache shape. It should equal
      # `self.num_heads`, and it is used in the output shape check below.
      *batch_dims, max_length, num_heads, depth_per_head = (  # pylint: disable=unused-variable
          cached_key.value.shape)
      # Write the new key/value at sequence position `i`.
      indices = (0,) * len(batch_dims) + (i, 0, 0)
      key = lax.dynamic_update_slice(cached_key.value, key, indices)
      value = lax.dynamic_update_slice(cached_value.value, value, indices)
      one = jnp.array(1, jnp.int32)
      cache_index.value = cache_index.value + one
      cached_key.value = key
      cached_value.value = value

      # TODO(levskaya): verify this is still needed in translation decoding.
      # Only cache positions already written (< cache_index) are valid keys.
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(max_length) < cache_index.value), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[Ellipsis, None]

  # create attention masks
  mask_components = []

  if self.causal_mask:
    if self.decode and is_cache_initialized:
      # Decoding: attend only to cache positions filled so far.
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(onp.take(key.shape, attention_axis))
      attn_size = onp.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.int32)
      mask = ii < cache_index.value
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))

  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    attn_padding_mask = make_padding_mask(
        padding_mask_query=padding_mask,
        padding_mask_key=key_padding_mask,
        query_shape=query.shape,
        key_shape=key.shape,
        attention_axis=attention_axis)
    mask_components.append(attn_padding_mask)

  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(
        padding_mask_query=segmentation,
        padding_mask_key=key_segmentation,
        query_shape=query.shape,
        key_shape=key.shape,
        attention_axis=attention_axis,
        segmentation_mask=True)
    mask_components.append(segmentation_mask)

  attention_mask = None
  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)
    attention_mask = attention_mask.astype(jnp.bool_)

    # attention mask in the form of attention bias: 0 where allowed, a large
    # negative value where masked out.
    attention_bias = jnp.where(
        attention_mask,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # Add an extra dimension to the mask corresponding to the head
  # dimension. eg, if inputs_q has shape [batch_size, sequence_length,
  # n_features], then padding_mask will have a shape
  # [batch_size, sequence_length, 1] and query will have shape
  # [batch_size, sequence_length, n_heads, n_features_per_head].
  # We create query_padding_mask with shape [batch_size, sequence_length,
  # 1, 1] to be broadcast-compatible with 'query'.
  if padding_mask is not None:
    padding_mask = padding_mask[Ellipsis, None]
    shape_utils.assert_shapes_equal(padding_mask.shape,
                                    (batch_size, query_sequence_length, 1, 1))
  if key_padding_mask is not None:
    key_padding_mask = key_padding_mask[Ellipsis, None]
    # During prediction, the key padding mask is only going to be
    # broadcast-compatible with the key.
    shape_utils.assert_shapes_compatible(
        key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1))

  # apply attention
  attention_fn = self.attention_fn
  dropout_rate = self.dropout_rate
  broadcast_dropout = self.broadcast_dropout
  deterministic = self.deterministic
  if not deterministic and self.dropout_rate > 0.0:
    dropout_rng = self.make_rng('dropout')
  else:
    dropout_rng = None

  x = attention_fn(  # pylint: disable=redundant-keyword-arg
      query=query,
      key=key,
      value=value,
      hparams=hparams.attn_acts,
      paxis_name=paxis_name,
      train=train,
      quant_context=self.quant_context,
      dtype=dtype,
      axis=attention_axis,
      bias=attention_bias,
      precision=jax_precision,
      dropout_rng=dropout_rng,
      dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout,
      deterministic=deterministic,
      query_padding_mask=padding_mask,
      key_padding_mask=key_padding_mask,
      attn_mask=attention_mask)
  shape_utils.assert_shapes_equal(
      x.shape, (batch_size, query_sequence_length, num_heads, head_dim))
  # Flatten heads and batch for the output DenseAqt projection.
  x = x.reshape(batch_size * query_sequence_length, num_heads * head_dim)
  if padding_mask is not None:
    padding_mask = padding_mask.reshape(batch_size * query_sequence_length, 1)
  # back to the original inputs dimensions
  out = flax_layers.DenseAqt(
      features=channel_size,
      hparams=hparams.dense_out,
      quant_context=self.quant_context,
      paxis_name=paxis_name,
      train=train,
      kernel_init=kernel_init,
      bias_init=bias_init,
      use_bias=use_bias,
      dtype=dtype,
      name='dense_out')(x, padding_mask=padding_mask)
  shape_utils.assert_shapes_equal(
      out.shape, (batch_size * query_sequence_length, channel_size))
  out = out.reshape(batch_size, query_sequence_length, channel_size)
  return out
def cond(v):
  """Loop predicate over `(t, rmse, _)`.

  Continues while `t < max_iter` and `rmse > tolerance`. Both names come
  from the enclosing scope.
  """
  step, rmse, _ = v
  within_budget = step < max_iter
  not_converged = rmse > tolerance
  return jnp.logical_and(within_budget, not_converged)
def is_transitional(state):
  """Checks whether individuals are in a state that can develop.

  Returns an element-wise boolean: `EXPOSED <= state <= INFECTED_3`
  (inclusive on both ends). The state constants come from module scope.
  """
  at_or_after_exposed = EXPOSED <= state
  at_or_before_infected_3 = state <= INFECTED_3
  return np.logical_and(at_or_after_exposed, at_or_before_infected_3)
def logical_and(self: TensorType, other: TensorOrScalar) -> TensorType:
  """Element-wise boolean AND of this tensor with `other`.

  Both operands are checked with `assert_bool` first. `other` may be a
  wrapped tensor or a scalar; it is unwrapped with `unwrap1`. The result is
  re-wrapped in the same tensor class as `self`.
  """
  assert_bool(self)
  assert_bool(other)
  raw_result = np.logical_and(self.raw, unwrap1(other))
  tensor_cls = type(self)
  return tensor_cls(raw_result)
def cond(vals):
  """While-loop predicate over `(sigma, phi, phiprime, unconverged, j)`.

  Continues while any element of `unconverged` is set and `j < maxiter`.
  `maxiter` comes from the enclosing scope.
  """
  _sigma, _phi, _phiprime, unconverged, j = vals
  any_unconverged = np.any(unconverged)
  return np.logical_and(any_unconverged, j < maxiter)
def cond_fun(state):
  """Loop predicate over a 5-tuple whose last two entries are flags.

  Returns `is_unconverged AND is_not_max_iteration`. The first three
  entries are ignored.
  """
  _, _, _, still_unconverged, under_iteration_cap = state
  keep_looping = jnp.logical_and(still_unconverged, under_iteration_cap)
  return keep_looping
def piter(f, x, trust_radius, args):
  """Performs one trust-region Newton iteration on ``f`` starting from ``x``.

  The trust-region subproblem is solved in the eigenbasis of the Hessian H.
  The step is ``p(sigma) = -(H + sigma I)^-1 g``. The shift ``sigma >= 0`` is
  found by Newton's method on ``phi(sigma) = 1/||p(sigma)|| - 1/trust_radius``.
  ``sigma = 0`` is forced when the unconstrained Newton step is valid, i.e. H
  is positive definite and the step fits inside the trust region. The step is
  then accepted or rejected, and the radius updated, following Nocedal &
  Wright (2nd ed.), Algorithm 4.1.

  NOTE(review): ``np`` here is presumably ``jax.numpy``, because the
  sub-problem helpers are traced inside ``jax.lax.while_loop``. Confirm
  against the file's imports.

  Args:
    f: objective, called as ``f(x, *args)``. It must return a scalar so that
      ``jax.value_and_grad`` / ``jax.jacfwd(jax.grad(f))`` give the gradient
      and Hessian.
    x: current parameters. Presumably a single 1-D vector: the eigen-basis
      change uses ``u.T``, which transposes all axes, so batched problems
      presumably go through ``jax.vmap``. TODO confirm.
    trust_radius: current trust-region radius.
    args: tuple of extra positional arguments forwarded to ``f``.

  Returns:
    Tuple ``(x_out, trust_radius_out, val, gradmag, edm, e0)``:
      x_out: ``x + p`` if the step is accepted (``rho > 0.15``), else ``x``.
      trust_radius_out: ``0.25 * trust_radius`` if ``rho < 0.25``.
        Otherwise ``min(2 * trust_radius, 1e3)`` if ``rho > 0.75`` and the
        step was constrained. Otherwise unchanged.
      val: ``f(x, *args)`` at the INPUT ``x`` (not at ``x_out``).
      gradmag: Euclidean norm of the gradient at the input ``x``.
      edm: estimated distance to minimum, ``sum(a_i**2 / (2 * lam_i))``. This
        is the predicted decrease of the unconstrained Newton step and is
        only meaningful when ``e0 > 0``.
      e0: smallest Hessian eigenvalue (``eigh`` returns ascending order).
  """
  fg = jax.value_and_grad(f)
  g = jax.grad(f)
  # Hessian via forward-over-reverse differentiation of the gradient
  # (jax.hessian(f) or jax.jacrev(g) would be equivalent alternatives).
  h = jax.jacfwd(g)

  # Function value and gradient at the input point.
  val, grad = fg(x, *args)
  gradmag = np.linalg.norm(grad, axis=-1)
  gradcol = np.expand_dims(grad, axis=-1)

  # Hessian and its eigen-decomposition H = u diag(e) u^T (e ascending).
  hess = h(x, *args)
  e, u = np.linalg.eigh(hess)

  ut = u.T

  # Gradient expressed in the eigen-basis: a = u^T g.
  a = np.matmul(ut, gradcol)
  a = np.squeeze(a, axis=-1)

  lam = e

  # Smallest eigenvalue; e0 > 0 means H is positive definite.
  e0 = lam[..., 0]

  # TODO deal with null gradient components and repeated eigenvectors
  lambar = lam
  abarsq = np.square(a)

  def phif(s):
    # phi(s) = 1/||p(s)|| - 1/trust_radius, where
    # ||p(s)||^2 = sum_i a_i^2 / (lam_i + s)^2.
    pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
    pmag = np.sqrt(pmagsq)
    phipartial = np.reciprocal(pmag)
    # At a pole (s == -lam_i), ||p|| is infinite, so 1/||p|| is set to 0.
    singular = np.any(np.equal(-s, lambar), axis=-1)
    phipartial = np.where(singular, 0., phipartial)
    phi = phipartial - np.reciprocal(trust_radius)
    return phi

  def phiphiprime(s):
    # Returns phi(s) and its derivative,
    # d/ds ||p||^-1 = ||p||^-3 * sum_i a_i^2 / (lam_i + s)^3.
    phi = phif(s)
    pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
    phiprime = np.power(pmagsq, -1.5) * np.sum(
        abarsq / np.power(lambar + s, 3), axis=-1)
    return (phi, phiprime)

  # The unconstrained Newton step (sigma = 0) is used when H is positive
  # definite and ||p(0)|| <= trust_radius, i.e. phi(0) >= 0.
  sigma0 = np.maximum(-e0, 0.)
  phisigma0 = phif(sigma0)
  usesolu = np.logical_and(e0 > 0., phisigma0 >= 0.)

  # Starting point for the Newton iteration on sigma: max_i(|a_i|/radius -
  # lam_i), clamped at 0, and forced to 0 where the unconstrained step is used.
  sigma = np.max(np.abs(a) / trust_radius - lam, axis=-1)
  sigma = np.maximum(sigma, 0.)
  sigma = np.where(usesolu, 0., sigma)
  phi, phiprime = phiphiprime(sigma)

  # TODO, add handling of additional cases here (singular and "hard" cases)

  # Iteratively solve phi(sigma) = 0, enforcing sigma = 0 where the
  # unconstrained solution applies.
  unconverged = np.ones(shape=sigma.shape, dtype=np.bool_)
  j = 0
  maxiter = 200

  # A plain Python while loop cannot be used here: its condition depends on
  # traced values, which breaks jit/vmap. jax.lax.while_loop is used instead.
  def cond(vals):
    sigma, phi, phiprime, unconverged, j = vals
    return np.logical_and(np.any(unconverged), j < maxiter)

  def body(vals):
    sigma, phi, phiprime, unconverged, j = vals
    # Newton update on phi(sigma) = 0.
    sigma = sigma - phi / phiprime
    sigma = np.where(usesolu, 0., sigma)
    phiout, phiprimeout = phiphiprime(sigma)
    # Keep iterating while phi is still increasing and still negative.
    unconverged = np.logical_and((phiout > phi), (phiout < 0.))
    phi, phiprime = (phiout, phiprimeout)
    j = j + 1
    return (sigma, phi, phiprime, unconverged, j)

  sigma = jax.lax.while_loop(cond, body,
                             (sigma, phi, phiprime, unconverged, j))[0]

  # Step in the eigen-basis, c_i = -a_i / (lam_i + sigma), rotated back: p = u c.
  coeffs = -a / (lam + sigma)
  coeffscol = np.expand_dims(coeffs, axis=-1)

  p = np.matmul(u, coeffscol)
  p = np.squeeze(p, axis=-1)

  # Predicted reduction: minus the quadratic model value g.p + 0.5 p^T H p,
  # written in the eigen-basis.
  predicted_reduction = -np.sum(
      a * coeffs + 0.5 * lam * np.square(coeffs), axis=-1)

  # Actual reduction of the objective along the step.
  x_new = x + p
  val_new = f(x_new, *args)
  actual_reduction = val - val_new

  # Trust-radius update and step acceptance, following Nocedal and Wright
  # 2nd ed. Algorithm 4.1.
  eta = 0.15
  trust_radius_max = 1e3
  # rho = actual / predicted. It is 0 when the actual reduction is exactly 0
  # (avoids 0/0), and NaN is mapped to 0 below.
  rho = actual_reduction / np.where(
      np.equal(actual_reduction, 0.), 1., predicted_reduction)
  rho = np.where(np.isnan(rho), 0., rho)
  # The step is on the trust-region boundary iff it was not the unconstrained one.
  at_boundary = np.logical_not(usesolu)
  trust_radius_out = np.where(
      rho < 0.25, 0.25 * trust_radius,
      np.where(
          np.logical_and(rho > 0.75, at_boundary),
          np.minimum(2. * trust_radius, trust_radius_max), trust_radius))

  x_out = np.where(rho > eta, x_new, x)

  # Estimated distance to minimum for the unconstrained Newton step (only
  # valid if e0 > 0): sum_i a_i^2 / (2 * lam_i).
  coeffs0 = -a / lam
  edm = -np.sum(a * coeffs0 + 0.5 * lam * np.square(coeffs0), axis=-1)

  return x_out, trust_radius_out, val, gradmag, edm, e0
def _unconverged(lk, j, maxiter, err, tol_delta, tol_lk): changing = err > tol_delta far_from_end = jnp.abs(1 - lk) > tol_lk unconverged = jnp.logical_or(changing, far_from_end) iterating = j < maxiter return jnp.logical_and(iterating, unconverged)[0]
def _rational_quadratic_spline_inv(y: Array, x_pos: Array, y_pos: Array, knot_slopes: Array) -> Tuple[Array, Array]: """Applies the inverse of a rational-quadratic spline to a scalar. Args: y: a scalar (0-dimensional array). The scalar `y` can be any real number; it will be transformed by the spline if it's in the closed interval `[y_pos[0], y_pos[-1]]`, and it will be transformed linearly if it's outside that interval. x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis. y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis. knot_slopes: array of shape [num_bins + 1], the slopes at the knot points. Returns: A tuple of two scalars: the output of the inverse transformation and the log of the absolute first derivative of the inverse at `y`. """ # Search to find the right bin. NOTE: The bins are sorted, so we could use # binary search, but this is more GPU/TPU friendly. # The following implementation avoids indexing for faster TPU computation. below_range = y <= y_pos[0] above_range = y >= y_pos[-1] correct_bin = jnp.logical_and(y >= y_pos[:-1], y < y_pos[1:]) any_bin_in_range = jnp.any(correct_bin) first_bin = jnp.concatenate( [jnp.array([1]), jnp.zeros(len(correct_bin) - 1)]).astype(bool) # If y does not fall into any bin, we use the first spline in the following # computations to avoid numerical issues. correct_bin = jnp.where(any_bin_in_range, correct_bin, first_bin) # Dot product of each parameter with the correct bin mask. params = jnp.stack([x_pos, y_pos, knot_slopes], axis=1) params_bin_left = jnp.sum(correct_bin[:, None] * params[:-1], axis=0) params_bin_right = jnp.sum(correct_bin[:, None] * params[1:], axis=0) # These are the parameters for the corresponding bin. 
x_pos_bin = (params_bin_left[0], params_bin_right[0]) y_pos_bin = (params_bin_left[1], params_bin_right[1]) knot_slopes_bin = (params_bin_left[2], params_bin_right[2]) bin_width = x_pos_bin[1] - x_pos_bin[0] bin_height = y_pos_bin[1] - y_pos_bin[0] bin_slope = bin_height / bin_width w = (y - y_pos_bin[0]) / bin_height w = jnp.clip(w, 0., 1.) # Ensure w is in [0, 1]. # Compute quadratic coefficients: az^2 + bz + c = 0 slopes_term = knot_slopes_bin[1] + knot_slopes_bin[0] - 2. * bin_slope c = -bin_slope * w b = knot_slopes_bin[0] - slopes_term * w a = bin_slope - b # Solve quadratic to obtain z and then x. z = -2. * c / (b + jnp.sqrt(b**2 - 4. * a * c)) z = jnp.clip(z, 0., 1.) # Ensure z is in [0, 1]. x = bin_width * z + x_pos_bin[0] # Compute log det Jacobian. sq_z = z * z z1mz = z - sq_z # z(1-z) sq_1mz = (1. - z)**2 denominator = bin_slope + slopes_term * z1mz logdet = -2. * jnp.log(bin_slope) - jnp.log( knot_slopes_bin[1] * sq_z + 2. * bin_slope * z1mz + knot_slopes_bin[0] * sq_1mz) + 2. * jnp.log(denominator) # If y is outside the spline range, we default to a linear transformation. x = jnp.where(below_range, (y - y_pos[0]) / knot_slopes[0] + x_pos[0], x) x = jnp.where(above_range, (y - y_pos[-1]) / knot_slopes[-1] + x_pos[-1], x) logdet = jnp.where(below_range, -jnp.log(knot_slopes[0]), logdet) logdet = jnp.where(above_range, -jnp.log(knot_slopes[-1]), logdet) return x, logdet
def loop_cond(carry):
  """Loop predicate: continue while under ``restart`` steps and no breakdown.

  Args:
    carry: 4-tuple ``(_, _, breakdown, k)``; only the breakdown flag and the
      step counter ``k`` are used.

  Returns:
    ``(k < restart) & ~breakdown`` as a JAX boolean. ``restart`` comes from
    the enclosing scope.
  """
  _, _, broke_down, step = carry
  within_restart = step < restart
  healthy = jnp.logical_not(broke_down)
  return jnp.logical_and(within_restart, healthy)
def _rational_quadratic_spline_fwd(x: Array, x_pos: Array, y_pos: Array, knot_slopes: Array) -> Tuple[Array, Array]: """Applies a rational-quadratic spline to a scalar. Args: x: a scalar (0-dimensional array). The scalar `x` can be any real number; it will be transformed by the spline if it's in the closed interval `[x_pos[0], x_pos[-1]]`, and it will be transformed linearly if it's outside that interval. x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis. y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis. knot_slopes: array of shape [num_bins + 1], the slopes at the knot points. Returns: A tuple of two scalars: the output of the transformation and the log of the absolute first derivative at `x`. """ # Search to find the right bin. NOTE: The bins are sorted, so we could use # binary search, but this is more GPU/TPU friendly. # The following implementation avoids indexing for faster TPU computation. below_range = x <= x_pos[0] above_range = x >= x_pos[-1] correct_bin = jnp.logical_and(x >= x_pos[:-1], x < x_pos[1:]) any_bin_in_range = jnp.any(correct_bin) first_bin = jnp.concatenate( [jnp.array([1]), jnp.zeros(len(correct_bin) - 1)]).astype(bool) # If y does not fall into any bin, we use the first spline in the following # computations to avoid numerical issues. correct_bin = jnp.where(any_bin_in_range, correct_bin, first_bin) # Dot product of each parameter with the correct bin mask. 
params = jnp.stack([x_pos, y_pos, knot_slopes], axis=1) params_bin_left = jnp.sum(correct_bin[:, None] * params[:-1], axis=0) params_bin_right = jnp.sum(correct_bin[:, None] * params[1:], axis=0) x_pos_bin = (params_bin_left[0], params_bin_right[0]) y_pos_bin = (params_bin_left[1], params_bin_right[1]) knot_slopes_bin = (params_bin_left[2], params_bin_right[2]) bin_width = x_pos_bin[1] - x_pos_bin[0] bin_height = y_pos_bin[1] - y_pos_bin[0] bin_slope = bin_height / bin_width z = (x - x_pos_bin[0]) / bin_width # `z` should be in range [0, 1] to avoid NaNs later. This can happen because # of small floating point issues or when x is outside of the range of bins. # To avoid all problems, we restrict z in [0, 1]. z = jnp.clip(z, 0., 1.) sq_z = z * z z1mz = z - sq_z # z(1-z) sq_1mz = (1. - z)**2 slopes_term = knot_slopes_bin[1] + knot_slopes_bin[0] - 2. * bin_slope numerator = bin_height * (bin_slope * sq_z + knot_slopes_bin[0] * z1mz) denominator = bin_slope + slopes_term * z1mz y = y_pos_bin[0] + numerator / denominator # Compute log det Jacobian. # The logdet is a sum of 3 logs. It is easy to see that the inputs of the # first two logs are guaranteed to be positive because we ensured that z is in # [0, 1]. This is also true of the log(denominator) because: # denominator # == bin_slope + (knot_slopes_bin[1] + knot_slopes_bin[0] - 2 * bin_slope) * # z*(1-z) # >= bin_slope - 2 * bin_slope * z * (1-z) # >= bin_slope - 2 * bin_slope * (1/4) # == bin_slope / 2 logdet = 2. * jnp.log(bin_slope) + jnp.log( knot_slopes_bin[1] * sq_z + 2. * bin_slope * z1mz + knot_slopes_bin[0] * sq_1mz) - 2. * jnp.log(denominator) # If x is outside the spline range, we default to a linear transformation. 
y = jnp.where(below_range, (x - x_pos[0]) * knot_slopes[0] + y_pos[0], y) y = jnp.where(above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1], y) logdet = jnp.where(below_range, jnp.log(knot_slopes[0]), logdet) logdet = jnp.where(above_range, jnp.log(knot_slopes[-1]), logdet) return y, logdet
def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate,
                      mask_token_proportion, random_token_proportion, mode,
                      rng):
  """Corrupts a batch of token ids for masked-language-model training.

  A random ``mask_rate`` fraction of the non-padding positions is selected.
  Each selected position is then replaced by ``mask_token`` with probability
  ``mask_token_proportion``. Otherwise it is replaced by a token drawn
  uniformly from ``random_tokens`` with probability
  ``random_token_proportion``. Any remaining selected positions keep their
  original token.

  Args:
    inputs: [batch x length] input token ids.
    random_tokens: candidate replacement token ids. It is indexed with a JAX
      integer array, so it is presumably a numpy/jnp array.
    mask_token: int id written at positions chosen for MASK replacement.
    pad_token: int id of the PAD token; padding positions are never selected.
    mask_rate: probability that a non-padding position is selected.
    mask_token_proportion: share of selected positions replaced with MASK.
    random_token_proportion: share of selected positions replaced with a
      random token.
    mode: Mode key. In ``Mode.predict`` the inputs pass through unchanged.
    rng: JAX PRNG key. It may be None in ``Mode.eval``, in which case a key
      derived from ``jnp.sum(inputs)`` is used. It is not consulted in
      ``Mode.predict``.

  Returns:
    Tuple ``(masked_inputs, targets, weights)`` of [batch x length] arrays:
    the corrupted inputs, the unmodified ``inputs`` as targets, and a
    per-position weight that is nonzero exactly at the selected positions
    (all ones in ``Mode.predict``).

  Raises:
    ValueError: if the two proportions sum to a value outside [0, 1], or if
      ``rng`` is None in a mode other than ``Mode.eval`` / ``Mode.predict``.
  """
  total = random_token_proportion + mask_token_proportion
  if total > 1 or total < 0:
    raise ValueError('Sum of random proportion and mask proportion must be'
                     ' in [0, 1] range.')
  targets = inputs

  if mode == Mode.predict:
    # Pass-through: nothing is corrupted and every position is weighted.
    return inputs, targets, jnp.full_like(targets, 1)

  if rng is None:
    if mode is not Mode.eval:
      raise ValueError('Must provide RNG unless in eval mode.')
    # TODO(b/157055145): How to keep same eval set across runs?
    # Make each sequences mask invariant to other members
    # of the batch. Right now there is batch size dependence.
    rng = jrandom.PRNGKey(jnp.sum(inputs))

  # Select positions to corrupt, never touching padding.
  rng, select_rng = jax.random.split(rng)
  selected = jrandom.bernoulli(select_rng, p=mask_rate, shape=inputs.shape)
  selected = jnp.where(inputs == pad_token, 0, selected)

  # Candidate replacements: a random token and the MASK token per position.
  rng, token_rng = jax.random.split(rng)
  random_ids = jax.random.randint(token_rng, inputs.shape, minval=0,
                                  maxval=len(random_tokens))
  random_fill = random_tokens[random_ids]
  mask_fill = jnp.full_like(inputs, mask_token)

  # One uniform draw decides MASK / random / keep for each position.
  draw = jax.random.uniform(rng, shape=inputs.shape)
  use_mask = draw < mask_token_proportion
  use_random = jnp.logical_and(
      draw >= mask_token_proportion,
      draw < mask_token_proportion + random_token_proportion)
  replacement = jnp.where(use_mask, mask_fill,
                          jnp.where(use_random, random_fill, inputs))

  masked_inputs = jnp.where(selected, replacement, inputs)
  return masked_inputs, targets, selected
def _safe_div(x1, x2): return jnp.where(jnp.logical_and(x1 == 0, x2 == 0), x1, x1 / x2)
def logical_and(x1, x2):
  """Elementwise logical AND that accepts ``JaxArray`` wrappers or raw arrays.

  Wrapped operands are unwrapped via ``.value``. The result is always
  returned wrapped in a ``JaxArray``.
  """
  operands = [arg.value if isinstance(arg, JaxArray) else arg
              for arg in (x1, x2)]
  return JaxArray(jnp.logical_and(*operands))
def _iter_condition(state):
  """Loop predicate: continue while ``i < num_iters`` and ``run_step`` holds.

  Args:
    state: 5-tuple ``(i, v, s, s_v, run_step)``; only ``i`` and ``run_step``
      are used. ``num_iters`` comes from the enclosing scope.
  """
  step, _, _, _, should_run = state
  under_limit = step < num_iters
  return jnp.logical_and(under_limit, should_run)
def qnorm_cond(carry):
  """Loop predicate: continue while not done and before the final iteration.

  Args:
    carry: 4-tuple ``(k, not_done, _, _)``. Iteration continues while
      ``not_done`` and ``k < max_iterations - 1``; ``max_iterations`` comes
      from the enclosing scope.
  """
  iteration, still_working, _, _ = carry
  before_last = iteration < (max_iterations - 1)
  return jnp.logical_and(still_working, before_last)
def _iter_condition(state):
  """Loop predicate: continue while under ``iter_count`` and work remains.

  Work remains while ``error > error_tolerance`` OR ``run_step`` is set.
  ``iter_count`` and ``error_tolerance`` come from the enclosing scope.

  Args:
    state: 6-tuple ``(i, mat_m, mat_h, old_mat_h, error, run_step)``; only
      ``i``, ``error`` and ``run_step`` are used.
  """
  step, _, _, _, error, should_run = state
  within_limit = step < iter_count
  keep_refining = jnp.logical_or(error > error_tolerance, should_run)
  return jnp.logical_and(within_limit, keep_refining)
def loop_cond(carry):
  """Loop predicate: continue while ``k < restart`` and ``err > ptol``.

  Args:
    carry: 6-tuple whose first two entries are the step counter ``k`` and the
      current error ``err``. ``restart`` and ``ptol`` come from the enclosing
      scope.
  """
  step, residual, _, _, _, _ = carry
  steps_left = step < restart
  above_tolerance = residual > ptol
  return jnp.logical_and(steps_left, above_tolerance)
def _get_attention_result(query,
                          key,
                          value,
                          dtype,
                          precision,
                          dropout_rng,
                          dropout_rate,
                          broadcast_dropout,
                          deterministic,
                          mask=None,
                          padding_mask=None,
                          key_padding_mask=None,
                          segmentation=None,
                          key_segmentation=None,
                          apply_causal_mask=False):
  """Runs dot-product attention over axis 1 with optional combined masks.

  All supplied masks are AND-ed together. These are: the explicit ``mask``, a
  causal lower-triangular mask, a query/key padding mask, and a query/key
  segmentation mask. The combined mask becomes an additive bias (0 where
  attention is allowed, -1e10 where it is blocked) that is passed to
  ``nn.attention.dot_product_attention``.

  Args:
    query, key, value: assumed ``[batch_size, seq_len, heads, features]``
      (not checked here).
    dtype: dtype of the attention bias and of the attention computation.
    precision, dropout_rng, dropout_rate, broadcast_dropout, deterministic:
      forwarded unchanged to ``dot_product_attention``.
    mask: optional explicit attention mask, combined first.
    padding_mask: optional query padding mask. It is also used for the keys
      when ``key_padding_mask`` is None. NOTE(review): ``key_padding_mask``
      is ignored when ``padding_mask`` is None — confirm that is intended.
    key_padding_mask: optional key padding mask.
    segmentation, key_segmentation: analogous to the padding masks, built
      with ``segmentation_mask=True``.
    apply_causal_mask: if True, adds a ``[1, 1, seq_len, seq_len]``
      lower-triangular mask, where ``seq_len = query.shape[1]``.

  Returns:
    The output of ``dot_product_attention``, expected to be
    ``[batch_size, seq_len, heads, features]``.
  """
  seq_len = query.shape[1]
  masks = []
  if mask is not None:
    masks.append(mask)

  if apply_causal_mask:
    lower_triangle = np.reshape(np.tri(seq_len, k=0), [1, 1, seq_len, seq_len])
    masks.append(jnp.array(lower_triangle).astype(jnp.bool_))

  if padding_mask is not None:
    masks.append(
        nn.attention.make_padding_mask(
            padding_mask_query=padding_mask,
            padding_mask_key=(padding_mask
                              if key_padding_mask is None else key_padding_mask),
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=(1,)))

  if segmentation is not None:
    masks.append(
        nn.attention.make_padding_mask(
            padding_mask_query=segmentation,
            padding_mask_key=(segmentation
                              if key_segmentation is None else key_segmentation),
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=(1,),
            segmentation_mask=True))

  attention_bias = None
  if masks:
    combined = masks[0]
    for extra in masks[1:]:
      combined = jnp.logical_and(combined, extra)
    # Additive bias: 0 keeps a logit, -1e10 effectively removes it.
    attention_bias = lax.select(
        combined > 0,
        jnp.full(combined.shape, 0.).astype(dtype),
        jnp.full(combined.shape, -1e10).astype(dtype))

  return nn.attention.dot_product_attention(
      query,
      key,
      value,
      dtype=dtype,
      axis=1,
      bias=attention_bias,
      precision=precision,
      dropout_rng=dropout_rng,
      dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout,
      deterministic=deterministic)