Example #1
def compute_cross_entropy_loss_with_positives_and_negatives_masks(
    scores: Array,
    positives: Array,
    negatives: Array,
    weights: Optional[Array] = None,
) -> Tuple[float, Dict[str, float], Tuple[Array, Array]]:
  """Compute (weighted) cross-entropy loss and accuracy-related metrics.

  The function computes cross-entropy loss when each sample may have multiple
  positive classes, multiple negative classes, and the remaining classes are
  neutral. In this case, the loss per sample is the average of the cross-entropy
  losses computed by pairing each positive class with all negative classes.
  Neutral classes are ignored.

  Arguments `positives` and `negatives` are boolean matrices that specify
  which classes are considered positive or negative for every sample.
  `positives[i, j]` is True <=> class j is considered positive for sample i
  `negatives[i, j]` is True <=> class j is considered negative for sample i

  The loss is computed in 3 stages:

  (1) For every sample i and positive class j we compute cross-entropy loss
  using j as a positive class and all negative classes for i as negatives.

  (2) For every sample i the total loss is the average of the losses over its
  positive classes.

  (3) The total loss is the sum of the per-sample losses. The loss only includes
  samples that have at least one positive and at least one negative class. Users
  can restrict this further by providing custom `weights`.

  Args:
   scores: [batch_size, num_classes] scores or logits.
   positives: [batch_size, num_classes] 0-1 mask for which classes are positive.
   negatives: [batch_size, num_classes] 0-1 mask for which classes are negative.
   weights: [batch_size] 0-1 masks indicating whether the loss should be
     computed for the corresponding item in the batch.

  Returns:
    A tuple of the scalar loss, a dictionary with metrics, and per-sample
    information (a tuple of per-sample accuracy, i.e. the fraction of positive
    classes ranked above all negatives, and per-sample weight).
  """
  at_least_one_positive_and_negative = jnp.logical_and(
      positives.sum(-1) > 0,
      negatives.sum(-1) > 0)
  if weights is None:
    weights = at_least_one_positive_and_negative
  else:
    weights = jnp.logical_and(weights, at_least_one_positive_and_negative)

  scores = scores.astype(jnp.float32)
  positives = positives.astype(jnp.float32)
  negatives = negatives.astype(jnp.float32)
  weights = weights.astype(jnp.float32)

  # For simplicity, we ignore the first batch dimension in the equations below
  # and assume that the loss is computed for a single sample.
  # Let p_1, ..., p_N be scores of positive classes
  # and n_1, ..., n_M be scores of negative classes.
  # In this case the loss is
  # sum_{i=1..N} -log softmax([p_i, n_1, ..., n_M])_1.
  # Computing this naively is too expensive, so we implement the loss
  # in the following way:

  # (1) compute S, the negatives part of softmax denominator. In other words,
  # exp(S) = sum_{j=1..M} exp(n_j)
  negative_scores = scores * negatives - _BIG_NUMBER * (1.0 - negatives)

  negative_scores_log_sum_exp = jax.nn.logsumexp(
      negative_scores, axis=-1, keepdims=True)

  # (2) now the loss per positive class i is just
  # -log (exp(p_i) / (exp(p_i) + exp(S)) = -log(1 / (1 + exp(-(p_i - S))))
  # = -log sigmoid(p_i - S)
  scores_minus_negatives = scores - negative_scores_log_sum_exp
  positives_weight = (positives.sum(axis=-1) + _SMALL_NUMBER)
  per_positive_loss = -jax.nn.log_sigmoid(scores_minus_negatives)

  # (3) compute average loss over all positive classes
  loss_per_sample = (per_positive_loss * positives).sum(axis=-1)
  loss_per_sample /= positives_weight
  loss_per_sample *= weights

  # (4) compute sum of losses over all positive samples
  loss = loss_per_sample.sum()

  # Now we need to compute the average accuracy.
  # First, compute the max score of negative classes per sample.
  # A positive class needs to have a higher score in order to get predicted.
  max_negative_scores = negative_scores.max(axis=-1, keepdims=True)

  # Second, a prediction for pair of a sample and its positive class
  # is correct if the score of the positive class is larger than
  # scores of all corresponding negative classes. In other words, the score
  # of the positive class needs to be larger than `max_negative_scores`.
  correct_prediction = (scores > max_negative_scores).astype(jnp.float32)

  # Take average over all positive classes per sample
  correct_prediction = (correct_prediction * positives).sum(axis=-1)
  correct_prediction /= positives_weight

  # Mask out samples with 0 weight
  correct_prediction = correct_prediction * weights

  metrics = {
      'loss': loss,
      'acc': correct_prediction.sum(),
      'denominator': weights.sum(),
  }
  return loss, metrics, (correct_prediction, weights)
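
A minimal usage sketch for the function above (hypothetical toy inputs; it assumes the module-level constants _BIG_NUMBER and _SMALL_NUMBER referenced in the code are defined, e.g. as 1e9 and 1e-9). Class 0 is positive, classes 1 and 2 are negative, and class 3 is neutral, so it is ignored:

import jax.numpy as jnp

scores = jnp.array([[2.0, 0.5, -1.0, 0.3]])   # [batch_size=1, num_classes=4]
positives = jnp.array([[1, 0, 0, 0]])         # class 0 is positive
negatives = jnp.array([[0, 1, 1, 0]])         # classes 1 and 2 are negative

loss, metrics, (per_sample_acc, per_sample_weight) = (
    compute_cross_entropy_loss_with_positives_and_negatives_masks(
        scores, positives, negatives))
# metrics['loss'] and metrics['acc'] are sums over the batch; dividing them by
# metrics['denominator'] gives per-batch averages.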
Example #2
 def has_updated(self, state: MultiStepsState) -> Array:
     return jnp.logical_and(state.mini_step == 0, state.gradient_step > 0)
Example #3
 def cond_function(carry):
     k, _, r, qnorm_scaled = carry
     _, rnorm = _safe_normalize(r)
     return jnp.logical_and(k < (max_iterations - 1), rnorm < qnorm_scaled)
Example #4
 def cond_fun(value):
     _, k, _, residual_norm = value
     return jnp.logical_and(k < maxiter, residual_norm > atol)
Example #5
 def _do_cholesky(args):
   matrix, j, coefs, err = args
   unconverged = err > (eps * jnp.linalg.norm(matrix))
   iterating = j < maxiter
   return jnp.logical_and(unconverged, iterating)[0]
Example #6
 def cond_f(args):
     _, _, j, error = args
     still_counting = j < maxiter
     unconverged = error > thresh
     return jnp.logical_and(still_counting, unconverged)[0]
Example #7
def multi_head_dot_product_attention(scope: Scope,
                                     inputs_q,
                                     inputs_kv,
                                     num_heads,
                                     dtype=jnp.float32,
                                     qkv_features=None,
                                     out_features=None,
                                     attention_axis=None,
                                     causal_mask=False,
                                     padding_mask=None,
                                     key_padding_mask=None,
                                     segmentation=None,
                                     key_segmentation=None,
                                     cache=False,
                                     broadcast_dropout=True,
                                     dropout_rng=None,
                                     dropout_rate=0.,
                                     deterministic=False,
                                     precision=None,
                                     kernel_init=default_kernel_init,
                                     bias_init=initializers.zeros,
                                     bias=True,
                                     attention_fn=dot_product_attention):
    """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention, and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv` or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32)
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts
      query, key, value, and returns output of shape
      `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """

    assert causal_mask or not cache, (
        'Caching is only supported for causal attention.')

    if inputs_kv is None:
        inputs_kv = inputs_q

    if attention_axis is None:
        attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = out_features or inputs_q.shape[-1]
    qkv_features = qkv_features or inputs_q.shape[-1]

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads

    dense = functools.partial(dense_general,
                              axis=-1,
                              dtype=dtype,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    query = scope.child(dense, 'query')(inputs_q)
    key = scope.child(dense, 'key')(inputs_kv)
    value = scope.child(dense, 'value')(inputs_kv)

    if cache:
        if not scope.has_variable('cache', 'entry'):
            ndim, tail_shape = (key.ndim, key.shape[-2:])

            def init_fn(shape, dtype=jnp.float32):
                full_shape = shape + tail_shape
                if len(full_shape) != ndim:
                    raise ValueError(
                        'Shape should be a tuple with the shape of the batch '
                        'and attention dims.')
                return CacheEntry(key=jnp.zeros(full_shape, dtype),
                                  value=jnp.zeros(full_shape, dtype),
                                  i=jnp.zeros((), jnp.uint32))

            cache_entry = init_fn
        else:
            cache_entry = scope.get_variable('cache', 'entry')
            if not isinstance(cache_entry, CacheEntry):
                raise ValueError('Cache is not initialized.')

            expected_shape = list(cache_entry.key.shape[:-2])
            for attn_dim in attention_axis:
                expected_shape[attn_dim] = 1
            expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
            if expected_shape != inputs_q.shape:
                raise ValueError('Invalid shape provided, '
                                 'expected shape %s instead got %s.' %
                                 (expected_shape, inputs_q.shape))

            cshape = cache_entry.key.shape
            indices = [0] * len(cshape)
            i = cache_entry.i
            attn_size = onp.prod(onp.take(cshape, attention_axis))
            for attn_dim in attention_axis:
                attn_size //= cshape[attn_dim]
                indices[attn_dim] = i // attn_size
                i = i % attn_size

            key = lax.dynamic_update_slice(cache_entry.key, key, indices)
            value = lax.dynamic_update_slice(cache_entry.value, value, indices)
            one = jnp.array(1, jnp.uint32)
            cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                              key=key,
                                              value=value)

            # TODO(levskaya): verify this is still needed in translation decoding.
            key_padding_mask = jnp.broadcast_to(
                (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
            key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]
        scope.put_variable('cache', 'entry', cache_entry)

    # create attention masks
    mask_components = []

    if causal_mask:
        if cache and isinstance(cache_entry, CacheEntry):
            bias_pre_shape = (1, ) * (key.ndim - 1)
            attn_shape = tuple(onp.take(key.shape, attention_axis))
            attn_size = onp.prod(attn_shape)
            ii = jnp.arange(attn_size, dtype=jnp.uint32)
            mask = ii < cache_entry.i
            mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
        else:
            mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                         padding_mask_key=key_padding_mask,
                                         query_shape=query.shape,
                                         key_shape=key.shape,
                                         attention_axis=attention_axis)
        mask_components.append(padding_mask)

    if segmentation is not None:
        if key_segmentation is None:
            key_segmentation = segmentation
        segmentation_mask = make_padding_mask(
            padding_mask_query=segmentation,
            padding_mask_key=key_segmentation,
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=attention_axis,
            segmentation_mask=True)
        mask_components.append(segmentation_mask)

    if mask_components:
        attention_mask = mask_components[0]
        for component in mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)

        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(dtype),
            jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
        attention_bias = None

    # apply attention
    x = scope.child(attention_fn)(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

    # back to the original inputs dimensions
    out = scope.child(dense_general, name='out')(x,
                                                 features=features,
                                                 axis=(-2, -1),
                                                 kernel_init=kernel_init,
                                                 bias_init=bias_init,
                                                 bias=bias,
                                                 dtype=dtype,
                                                 precision=precision)

    return out
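
A self-contained sketch of the masking pattern used above, with toy shapes and without the Flax scope machinery: boolean mask components are combined with jnp.logical_and, then turned into an additive attention bias where masked positions get a large negative value.

import jax.numpy as jnp
from jax import lax

seq_len = 4
# [batch, heads, query_len, key_len] boolean mask components.
causal = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=bool))
padding = jnp.ones((1, 1, seq_len, seq_len), dtype=bool).at[..., -1].set(False)

attention_mask = jnp.logical_and(causal, padding)
# Allowed positions contribute 0 to the logits; disallowed positions get -1e10,
# which effectively removes them after the softmax.
attention_bias = lax.select(
    attention_mask,
    jnp.zeros(attention_mask.shape, jnp.float32),
    jnp.full(attention_mask.shape, -1e10, jnp.float32))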
Example #8
def eval_step(state, inputs, outputs, programs, bos_token, eos_token, config,
              lp_config):
    """Evaluate on batch of program tasks."""
    params = state.optimizer.target
    lp_params = state.lp_optimizer.target

    weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)
    # Embedding mask for autoencoding.
    emb_mask = jnp.ones((1, FLAGS.latent_vocab_size),
                        jnp.float32).at[:, [0, bos_token, eos_token]].set(0)

    ae_logits, vq = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        inputs,
        outputs,
        programs,
        emb_mask,
        mutable=False)

    # Postprocess latent indices.
    latent_indices = add_eos_token(vq['latent_indices'], eos_token)
    latent_weights = jnp.where(latent_indices > 0, 1, 0).astype(jnp.float32)

    encoded_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    # Additionally mask out eos token in latents.
    latents_mask = jnp.where(
        jnp.logical_and(latent_indices > 0, latent_indices != eos_token), 1,
        0).astype(jnp.float32)

    latent_logits = models.ProgramTransformer(lp_config).apply(
        {'params': lp_params}, inputs, outputs, latent_indices)

    encoded = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        inputs,
        outputs,
        mutable=False,
        method=models.LatentProgramTransformer.encode)
    latents = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        latent_logits,
        mutable=False,
        method=models.LatentProgramTransformer.quantize)
    logits = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        programs,
        latents,
        encoded,
        latents_mask,
        encoded_mask,
        mutable=False,
        method=models.LatentProgramTransformer.decode)

    metrics = compute_metrics(logits, programs, weights)
    metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
    latent_metrics = compute_metrics(latent_logits,
                                     latent_indices,
                                     latent_weights,
                                     prefix='latent_')
    return metrics, latent_metrics
Example #9
def predict_step(state, inputs, outputs, cache, lp_cache, beam_size, bos_token,
                 eos_token, max_decode_len, config, lp_config):
    """Predict translation with fast decoding beam search on a batch."""
    params = state.optimizer.target
    lp_params = state.lp_optimizer.target

    # Split beam over latent sequences and programs.
    per_latent_beam_size = beam_size // FLAGS.latent_beam_size
    beam_size = FLAGS.latent_beam_size * per_latent_beam_size

    flat_lp_encoded = decode.flat_batch_beam_expand(
        models.ProgramTransformer(lp_config).apply(
            {'params': lp_params},
            inputs,
            outputs,
            method=models.ProgramTransformer.encode), FLAGS.latent_beam_size)

    encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, FLAGS.latent_beam_size)

    def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
            {
                'params': lp_params,
                'cache': flat_lp_cache
            },
            flat_ids,
            flat_lp_encoded,
            flat_encoded_padding_mask,
            mutable=['cache'],
            method=models.ProgramTransformer.decode)
        new_flat_lp_cache = new_vars['cache']
        # Remove singleton sequence-length dimension:
        # [batch * beam, 1, vocab] --> [batch * beam, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_lp_cache

    # Step 1: Beam-search over latent tokens.
    latent_beam_seqs, _ = decode.beam_search(
        inputs,
        lp_cache,
        tokens_ids_to_latent_logits,
        beam_size=FLAGS.latent_beam_size,
        alpha=0.6,
        bos_token=bos_token,
        eos_token=eos_token,
        max_decode_len=np.ceil(max_decode_len / 2**FLAGS.c).astype(np.int32))

    flat_latent_seqs = decode.flat_batch_beam_expand(
        decode.flatten_beam_dim(latent_beam_seqs), per_latent_beam_size)
    # Quantize the predicted latent codes.
    flat_latents = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        flat_latent_seqs,
        mutable=False,
        method=models.LatentProgramTransformer.quantize)

    flat_encoded = decode.flat_batch_beam_expand(
        models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            inputs,
            outputs,
            mutable=False,
            method=models.LatentProgramTransformer.encode), beam_size)

    # Padding masks.
    flat_latents_mask = jnp.where(
        jnp.logical_and(flat_latent_seqs > 0, flat_latent_seqs != eos_token),
        1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, beam_size)

    def tokens_ids_to_logits(flat_ids, flat_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state,
                'cache': flat_cache
            },
            flat_ids,
            flat_latents,
            flat_encoded,
            flat_latents_mask,
            flat_encoded_padding_mask,
            mutable=['cache'],
            method=models.LatentProgramTransformer.decode)
        new_flat_cache = new_vars['cache']
        # Remove singleton sequence-length dimension:
        # [batch * beam, 1, vocab] --> [batch * beam, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_cache

    # Step 2: Beam-search over program tokens.
    per_latent_inputs = decode.flat_batch_beam_expand(inputs,
                                                      FLAGS.latent_beam_size)
    per_latent_cache = jax.tree_map(
        lambda x: decode.flat_batch_beam_expand(x, FLAGS.latent_beam_size),
        cache)
    beam_seqs, _ = decode.beam_search(per_latent_inputs,
                                      per_latent_cache,
                                      tokens_ids_to_logits,
                                      beam_size=per_latent_beam_size,
                                      alpha=0.6,
                                      bos_token=bos_token,
                                      eos_token=eos_token,
                                      max_decode_len=max_decode_len)
    # Collapse both beam dimensions into one.
    beam_seqs = beam_seqs.reshape((inputs.shape[0], beam_size) +
                                  beam_seqs.shape[2:])
    latent_beam_seqs = jnp.repeat(latent_beam_seqs,
                                  per_latent_beam_size,
                                  axis=1)

    # Beam search returns [n_batch, n_beam, n_length] with beam dimension
    # sorted in increasing order of log-probability.
    return beam_seqs, latent_beam_seqs
Example #10
def antiu(x, xi):
    dark = onp.array(g(x, xi))
    ind = onp.where(np.logical_and(dark > lb(x), dark < ub(x)))[0]
    dark[ind] = np.nan
    return dark
Example #11
def train_step(state,
               inputs,
               outputs,
               programs,
               pretrain,
               bos_token,
               eos_token,
               learning_rate_fn,
               config,
               lp_config,
               train_rng=None):
    """Train on batch of program tasks."""
    # We handle PRNG splitting inside the top pmap, rather
    # than handling it outside in the training loop - doing the
    # latter can add some stalls to the devices.
    train_rng, new_train_rng = jax.random.split(train_rng)

    weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)

    # Embedding mask for autoencoding.
    emb_mask = jnp.ones((1, FLAGS.latent_vocab_size),
                        jnp.float32).at[:, [0, bos_token, eos_token]].set(0)

    def ae_loss_fn(params):
        """Loss function used for training autoencoder."""
        (logits,
         vq), new_variables = models.LatentProgramTransformer(config).apply(
             {
                 'params': params,
                 'vqvae': state.model_state
             },
             inputs,
             outputs,
             programs,
             emb_mask,
             pretrain=pretrain,
             mutable=['vqvae'],
             rngs={'dropout': train_rng})
        loss, weight_sum = compute_weighted_cross_entropy(
            logits, programs, weights)

        # Add EOS token for latent predictor loss.
        vq_weight_sum = jnp.sum(
            jnp.where(vq['latent_indices'] > 0, 1, 0).astype(jnp.float32))
        latent_indices = add_eos_token(vq['latent_indices'], eos_token)

        mean_loss = loss / weight_sum + vq['loss'] / vq_weight_sum
        return mean_loss, (new_variables['vqvae'], logits, latent_indices)

    step = state.step
    optimizer = state.optimizer
    lp_optimizer = state.lp_optimizer
    lr = learning_rate_fn(step)
    grad_fn = jax.value_and_grad(ae_loss_fn, has_aux=True)
    (_, (new_model_state, ae_logits,
         latent_indices)), ae_grad = grad_fn(optimizer.target)
    ae_grad = jax.lax.pmean(ae_grad, 'batch')

    latent_weights = jnp.where(latent_indices > 0, 1, 0).astype(jnp.float32)

    encoded_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    # Additionally mask out eos token in latents.
    latents_mask = jnp.where(
        jnp.logical_and(latent_indices > 0, latent_indices != eos_token), 1,
        0).astype(jnp.float32)

    def loss_fn(params, lp_params):
        """Loss function used for training."""
        latent_logits = models.ProgramTransformer(lp_config).apply(
            {'params': lp_params},
            inputs,
            outputs,
            latent_indices,
            rngs={'dropout': train_rng})
        latent_loss, latent_weight_sum = compute_weighted_cross_entropy(
            latent_logits, latent_indices, latent_weights)

        # End-to-end prediction.
        encoded = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            inputs,
            outputs,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.encode)
        latents = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            latent_logits,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.quantize)
        logits = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            programs,
            latents,
            encoded,
            latents_mask,
            encoded_mask,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.decode)
        loss, weight_sum = compute_weighted_cross_entropy(
            logits, programs, weights)

        mean_loss = latent_loss / latent_weight_sum
        if not pretrain:
            mean_loss += loss / weight_sum
        return mean_loss, (logits, latent_logits)

    grad_fn = jax.value_and_grad(loss_fn, argnums=[0, 1], has_aux=True)
    (_, (logits, latent_logits)), grads = grad_fn(optimizer.target,
                                                  lp_optimizer.target)
    grads = jax.lax.pmean(grads, 'batch')
    new_optimizer = optimizer.apply_gradient(jax.tree_multimap(
        jnp.add, grads[0], ae_grad),
                                             learning_rate=lr)
    new_lp_optimizer = lp_optimizer.apply_gradient(grads[1], learning_rate=lr)

    metrics = compute_metrics(logits, programs, weights)
    metrics['learning_rate'] = lr
    metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
    latent_metrics = compute_metrics(latent_logits,
                                     latent_indices,
                                     latent_weights,
                                     prefix='latent_')

    new_state = state.replace(step=step + 1,
                              optimizer=new_optimizer,
                              model_state=jax.lax.pmean(
                                  new_model_state, 'batch'),
                              lp_optimizer=new_lp_optimizer)
    return new_state, metrics, latent_metrics, new_train_rng
Example #12
    def __call__(self,
                 inputs_q,
                 inputs_kv,
                 *,
                 padding_mask,
                 key_padding_mask,
                 segmentation=None,
                 key_segmentation=None):
        """Applies multi-head dot product attention on the input data.

    If weight_prec is not None, scales and quantizes weights to signed int with
    weight_prec bits.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv` or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or
        None for self-attention, in which case key/values will be derived from
        inputs_q.
      padding_mask: boolean tensor specifying query tokens that are pad token.
      key_padding_mask: boolean tensor specifying key-value tokens that are pad
        token.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """
        batch_size, query_sequence_length, channel_size = inputs_q.shape
        hparams = self.hparams
        if inputs_kv is None:
            inputs_kv = inputs_q
            key_sequence_length = inputs_q.shape[1]
        else:
            key_sequence_length = inputs_kv.shape[1]
            shape_utils.assert_shapes_equal(
                inputs_kv.shape,
                (batch_size, key_sequence_length, channel_size))

        jax_precision = jax.lax.Precision.DEFAULT

        if padding_mask is not None:
            shape_utils.assert_shapes_equal(
                padding_mask.shape, (batch_size, query_sequence_length, 1))
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        else:
            shape_utils.assert_shapes_equal(
                key_padding_mask.shape, (batch_size, key_sequence_length, 1))
        attention_axis = self.attention_axis
        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        qkv_features = self.qkv_features
        qkv_features = qkv_features or inputs_q.shape[-1]

        num_heads = self.num_heads
        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        paxis_name = self.paxis_name
        train = self.train
        kernel_init = self.kernel_init
        bias_init = self.bias_init
        use_bias = self.use_bias
        dtype = self.dtype

        def multi_batch_dense_aqt(inputs, *, name, padding_mask):
            batch_size, sequence_length, channel_size = inputs.shape
            inputs = inputs.reshape(batch_size * sequence_length, channel_size)
            if padding_mask is not None:
                padding_mask = padding_mask.reshape(
                    batch_size * sequence_length, 1)
            out = flax_layers.DenseAqt(name=name,
                                       features=num_heads * head_dim,
                                       paxis_name=paxis_name,
                                       train=train,
                                       quant_context=self.quant_context,
                                       hparams=hparams.dense_kqv,
                                       kernel_init=kernel_init,
                                       bias_init=bias_init,
                                       use_bias=use_bias,
                                       dtype=dtype)(inputs,
                                                    padding_mask=padding_mask)
            return out.reshape(batch_size, sequence_length, num_heads,
                               head_dim)

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, sequence_length, n_heads, n_features_per_head]
        query = multi_batch_dense_aqt(inputs_q,
                                      name='query',
                                      padding_mask=padding_mask)
        key = multi_batch_dense_aqt(inputs_kv,
                                    name='key',
                                    padding_mask=key_padding_mask)
        value = multi_batch_dense_aqt(inputs_kv,
                                      name='value',
                                      padding_mask=key_padding_mask)
        is_cache_initialized = False
        if self.decode:
            is_cache_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_cache_initialized:
                expected_shape = list(cached_key.value.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                cshape = cached_key.value.shape
                indices = [0] * len(cshape)
                i = cache_index.value
                attn_size = onp.prod(onp.take(cshape, attention_axis))

                *batch_dims, max_length, num_heads, depth_per_head = (  # pylint: disable=unused-variable
                    cached_key.value.shape)
                indices = (0, ) * len(batch_dims) + (i, 0, 0)

                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                one = jnp.array(1, jnp.int32)
                cache_index.value = cache_index.value + one
                cached_key.value = key
                cached_value.value = value

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(max_length) < cache_index.value), cshape[:2])
                key_padding_mask = key_padding_mask.astype(
                    jnp.float32)[Ellipsis, None]

        # create attention masks
        mask_components = []
        if self.causal_mask:
            if self.decode and is_cache_initialized:
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(onp.take(key.shape, attention_axis))
                attn_size = onp.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.int32)
                mask = ii < cache_index.value
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))
        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            attn_padding_mask = make_padding_mask(
                padding_mask_query=padding_mask,
                padding_mask_key=key_padding_mask,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis)
            mask_components.append(attn_padding_mask)
        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)
        attention_mask = None
        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)
            attention_mask = attention_mask.astype(jnp.bool_)

            # attention mask in the form of attention bias
            attention_bias = jnp.where(
                attention_mask,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # Add an extra dimension to the mask corresponding to the head
        # dimension. eg, if inputs_q has shape [batch_size, sequence_length,
        # n_features], then padding_mask will have a shape
        # [batch_size, sequence_length, 1] and query will have shape
        # [batch_size, sequence_length, n_heads, n_features_per_head].
        # We create query_padding_mask with shape [batch_size, sequence_length,
        # 1, 1] to be broadcast-compatible with 'query'.
        if padding_mask is not None:
            padding_mask = padding_mask[Ellipsis, None]
            shape_utils.assert_shapes_equal(
                padding_mask.shape, (batch_size, query_sequence_length, 1, 1))
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[Ellipsis, None]
            # During prediction, the key padding mask is only going to be
            # broadcast-compatible with the key.
            shape_utils.assert_shapes_compatible(
                key_padding_mask.shape,
                (batch_size, key_sequence_length, 1, 1))

        # apply attention
        attention_fn = self.attention_fn
        dropout_rate = self.dropout_rate
        broadcast_dropout = self.broadcast_dropout
        deterministic = self.deterministic
        if not deterministic and self.dropout_rate > 0.0:
            dropout_rng = self.make_rng('dropout')
        else:
            dropout_rng = None
        x = attention_fn(  # pylint: disable=redundant-keyword-arg
            query=query,
            key=key,
            value=value,
            hparams=hparams.attn_acts,
            paxis_name=paxis_name,
            train=train,
            quant_context=self.quant_context,
            dtype=dtype,
            axis=attention_axis,
            bias=attention_bias,
            precision=jax_precision,
            dropout_rng=dropout_rng,
            dropout_rate=dropout_rate,
            broadcast_dropout=broadcast_dropout,
            deterministic=deterministic,
            query_padding_mask=padding_mask,
            key_padding_mask=key_padding_mask,
            attn_mask=attention_mask)
        shape_utils.assert_shapes_equal(
            x.shape, (batch_size, query_sequence_length, num_heads, head_dim))
        x = x.reshape(batch_size * query_sequence_length, num_heads * head_dim)
        if padding_mask is not None:
            padding_mask = padding_mask.reshape(
                batch_size * query_sequence_length, 1)
        # back to the original inputs dimensions
        out = flax_layers.DenseAqt(features=channel_size,
                                   hparams=hparams.dense_out,
                                   quant_context=self.quant_context,
                                   paxis_name=paxis_name,
                                   train=train,
                                   kernel_init=kernel_init,
                                   bias_init=bias_init,
                                   use_bias=use_bias,
                                   dtype=dtype,
                                   name='dense_out')(x,
                                                     padding_mask=padding_mask)
        shape_utils.assert_shapes_equal(
            out.shape, (batch_size * query_sequence_length, channel_size))
        out = out.reshape(batch_size, query_sequence_length, channel_size)
        return out
Example #13
 def cond(v):
     t, rmse, _ = v
     return jnp.logical_and(t < max_iter, rmse > tolerance)
Example #14
def is_transitional(state):
  """Checks whether individuals are in a state that can develop."""
  return np.logical_and(EXPOSED <= state, state <= INFECTED_3)
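
A toy illustration of the range check above; the EXPOSED and INFECTED_3 values here are assumptions for the sketch, not taken from the original module:

import numpy as np

EXPOSED, INFECTED_3 = 1, 4            # assumed integer state codes
state = np.array([0, 1, 3, 5])
np.logical_and(EXPOSED <= state, state <= INFECTED_3)
# -> array([False,  True,  True, False])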
Example #15
File: jax.py Project: yibit/eagerpy
 def logical_and(self: TensorType, other: TensorOrScalar) -> TensorType:
     assert_bool(self)
     assert_bool(other)
     return type(self)(np.logical_and(self.raw, unwrap1(other)))
Example #16
 def cond(vals):
     sigma, phi, phiprime, unconverged, j = vals
     return np.logical_and(np.any(unconverged), j < maxiter)
Example #17
 def cond_fun(state):
     _, _, _, is_unconverged, is_not_max_iteration = state
     return jnp.logical_and(is_unconverged, is_not_max_iteration)
Example #18
def piter(f, x, trust_radius, args):

    fg = jax.value_and_grad(f)
    g = jax.grad(f)
    #h = jax.hessian(f)
    h = jax.jacfwd(g)
    #h = jax.jacrev(g)

    #compute function value and gradient
    val, grad = fg(x, *args)
    gradmag = np.linalg.norm(grad, axis=-1)
    gradcol = np.expand_dims(grad, axis=-1)

    #compute hessian and eigen-decomposition
    hess = h(x, *args)
    e, u = np.linalg.eigh(hess)

    ut = u.T

    #convert gradient to eigen-basis
    a = np.matmul(ut, gradcol)
    a = np.squeeze(a, axis=-1)

    lam = e
    e0 = lam[..., 0]

    #TODO deal with null gradient components and repeated eigenvectors
    lambar = lam
    abarsq = np.square(a)

    def phif(s):
        pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
        pmag = np.sqrt(pmagsq)
        phipartial = np.reciprocal(pmag)
        singular = np.any(np.equal(-s, lambar), axis=-1)
        phipartial = np.where(singular, 0., phipartial)
        phi = phipartial - np.reciprocal(trust_radius)
        return phi

    def phiphiprime(s):
        phi = phif(s)
        pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
        phiprime = np.power(pmagsq, -1.5) * np.sum(
            abarsq / np.power(lambar + s, 3), axis=-1)
        return (phi, phiprime)

    #check if unconstrained solution is valid
    sigma0 = np.maximum(-e0, 0.)
    phisigma0 = phif(sigma0)
    usesolu = np.logical_and(e0 > 0., phisigma0 >= 0.)

    sigma = np.max(np.abs(a) / trust_radius - lam, axis=-1)
    sigma = np.maximum(sigma, 0.)
    sigma = np.where(usesolu, 0., sigma)
    phi, phiprime = phiphiprime(sigma)

    #TODO, add handling of additional cases here (singular and "hard" cases)

    #iteratively solve for sigma, enforcing unconstrained solution sigma=0 where appropriate
    unconverged = np.ones(shape=sigma.shape, dtype=np.bool_)
    j = 0
    maxiter = 200

    #This can't work with vmap+jit because of the dynamic condition, so we use the jax while_loop below
    #j = 0
    #while np.logical_and(np.any(unconverged), j<maxiter):
    #sigma = sigma - phi/phiprime
    #sigma = np.where(usesolu, 0., sigma)
    #phiout, phiprimeout = phiphiprime(sigma)
    #unconverged = np.logical_and( (phiout > phi) , (phiout < 0.) )
    #phi,phiprime = (phiout, phiprimeout)
    #j = j +1

    def cond(vals):
        sigma, phi, phiprime, unconverged, j = vals
        return np.logical_and(np.any(unconverged), j < maxiter)

    def body(vals):
        sigma, phi, phiprime, unconverged, j = vals
        sigma = sigma - phi / phiprime
        sigma = np.where(usesolu, 0., sigma)
        phiout, phiprimeout = phiphiprime(sigma)
        unconverged = np.logical_and((phiout > phi), (phiout < 0.))
        phi, phiprime = (phiout, phiprimeout)
        j = j + 1
        return (sigma, phi, phiprime, unconverged, j)

    sigma = jax.lax.while_loop(cond, body,
                               (sigma, phi, phiprime, unconverged, j))[0]

    #compute solution from eigenvalues and eigenvectors
    coeffs = -a / (lam + sigma)
    coeffscol = np.expand_dims(coeffs, axis=-1)

    p = np.matmul(u, coeffscol)
    p = np.squeeze(p, axis=-1)

    #compute predicted reduction in loss function from eigenvalues and eigenvectors
    predicted_reduction = -np.sum(a * coeffs + 0.5 * lam * np.square(coeffs),
                                  axis=-1)

    #compute actual reduction in loss
    x_new = x + p
    val_new = f(x_new, *args)
    #actual_reduction = -(val_new - val)
    actual_reduction = val - val_new

    #update trust radius and output parameters, following Nocedal and Wright 2nd ed. Algorithm 4.1
    eta = 0.15
    trust_radius_max = 1e3
    rho = actual_reduction / np.where(np.equal(actual_reduction, 0.), 1.,
                                      predicted_reduction)
    rho = np.where(np.isnan(rho), 0., rho)
    at_boundary = np.logical_not(usesolu)
    trust_radius_out = np.where(
        rho < 0.25, 0.25 * trust_radius,
        np.where(np.logical_and(rho > 0.75, at_boundary),
                 np.minimum(2. * trust_radius, trust_radius_max),
                 trust_radius))

    x_out = np.where(rho > eta, x_new, x)

    #compute estimated distance to minimum for unconstrained solution (only valid if e0>0)
    coeffs0 = -a / lam
    edm = -np.sum(a * coeffs0 + 0.5 * lam * np.square(coeffs0), axis=-1)

    return x_out, trust_radius_out, val, gradmag, edm, e0
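
The commented-out Python loop above cannot run under jit/vmap because its stopping condition depends on traced values; expressing the condition with np.logical_and inside jax.lax.while_loop is the standard workaround. A minimal, self-contained sketch of that pattern (a toy fixed-point iteration, unrelated to the trust-region solver above):

import jax
import jax.numpy as jnp

def fixed_point_cos(x0, maxiter=100, tol=1e-6):
    def cond(vals):
        x, j = vals
        # Keep iterating while unconverged AND under the iteration budget.
        return jnp.logical_and(jnp.abs(jnp.cos(x) - x) > tol, j < maxiter)

    def body(vals):
        x, j = vals
        return jnp.cos(x), j + 1

    return jax.lax.while_loop(cond, body, (x0, 0))

x_star, n_iters = jax.jit(fixed_point_cos)(jnp.float32(1.0))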
Example #19
def _unconverged(lk, j, maxiter, err, tol_delta, tol_lk):
    changing = err > tol_delta
    far_from_end = jnp.abs(1 - lk) > tol_lk
    unconverged = jnp.logical_or(changing, far_from_end)
    iterating = j < maxiter
    return jnp.logical_and(iterating, unconverged)[0]
Example #20
def _rational_quadratic_spline_inv(y: Array, x_pos: Array, y_pos: Array,
                                   knot_slopes: Array) -> Tuple[Array, Array]:
    """Applies the inverse of a rational-quadratic spline to a scalar.

  Args:
    y: a scalar (0-dimensional array). The scalar `y` can be any real number; it
      will be transformed by the spline if it's in the closed interval
      `[y_pos[0], y_pos[-1]]`, and it will be transformed linearly if it's
      outside that interval.
    x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis.
    y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [num_bins + 1], the slopes at the knot points.
  Returns:
    A tuple of two scalars: the output of the inverse transformation and the log
    of the absolute first derivative of the inverse at `y`.
  """
    # Search to find the right bin. NOTE: The bins are sorted, so we could use
    # binary search, but this is more GPU/TPU friendly.
    # The following implementation avoids indexing for faster TPU computation.
    below_range = y <= y_pos[0]
    above_range = y >= y_pos[-1]
    correct_bin = jnp.logical_and(y >= y_pos[:-1], y < y_pos[1:])
    any_bin_in_range = jnp.any(correct_bin)
    first_bin = jnp.concatenate(
        [jnp.array([1]), jnp.zeros(len(correct_bin) - 1)]).astype(bool)
    # If y does not fall into any bin, we use the first spline in the following
    # computations to avoid numerical issues.
    correct_bin = jnp.where(any_bin_in_range, correct_bin, first_bin)
    # Dot product of each parameter with the correct bin mask.
    params = jnp.stack([x_pos, y_pos, knot_slopes], axis=1)
    params_bin_left = jnp.sum(correct_bin[:, None] * params[:-1], axis=0)
    params_bin_right = jnp.sum(correct_bin[:, None] * params[1:], axis=0)

    # These are the parameters for the corresponding bin.
    x_pos_bin = (params_bin_left[0], params_bin_right[0])
    y_pos_bin = (params_bin_left[1], params_bin_right[1])
    knot_slopes_bin = (params_bin_left[2], params_bin_right[2])

    bin_width = x_pos_bin[1] - x_pos_bin[0]
    bin_height = y_pos_bin[1] - y_pos_bin[0]
    bin_slope = bin_height / bin_width
    w = (y - y_pos_bin[0]) / bin_height
    w = jnp.clip(w, 0., 1.)  # Ensure w is in [0, 1].
    # Compute quadratic coefficients: az^2 + bz + c = 0
    slopes_term = knot_slopes_bin[1] + knot_slopes_bin[0] - 2. * bin_slope
    c = -bin_slope * w
    b = knot_slopes_bin[0] - slopes_term * w
    a = bin_slope - b

    # Solve quadratic to obtain z and then x.
    z = -2. * c / (b + jnp.sqrt(b**2 - 4. * a * c))
    z = jnp.clip(z, 0., 1.)  # Ensure z is in [0, 1].
    x = bin_width * z + x_pos_bin[0]

    # Compute log det Jacobian.
    sq_z = z * z
    z1mz = z - sq_z  # z(1-z)
    sq_1mz = (1. - z)**2
    denominator = bin_slope + slopes_term * z1mz
    logdet = -2. * jnp.log(bin_slope) - jnp.log(
        knot_slopes_bin[1] * sq_z + 2. * bin_slope * z1mz +
        knot_slopes_bin[0] * sq_1mz) + 2. * jnp.log(denominator)

    # If y is outside the spline range, we default to a linear transformation.
    x = jnp.where(below_range, (y - y_pos[0]) / knot_slopes[0] + x_pos[0], x)
    x = jnp.where(above_range, (y - y_pos[-1]) / knot_slopes[-1] + x_pos[-1],
                  x)
    logdet = jnp.where(below_range, -jnp.log(knot_slopes[0]), logdet)
    logdet = jnp.where(above_range, -jnp.log(knot_slopes[-1]), logdet)
    return x, logdet
Example #21
 def loop_cond(carry):
     _, _, breakdown, k = carry
     return jnp.logical_and(k < restart, jnp.logical_not(breakdown))
Example #22
def _rational_quadratic_spline_fwd(x: Array, x_pos: Array, y_pos: Array,
                                   knot_slopes: Array) -> Tuple[Array, Array]:
    """Applies a rational-quadratic spline to a scalar.

  Args:
    x: a scalar (0-dimensional array). The scalar `x` can be any real number; it
      will be transformed by the spline if it's in the closed interval
      `[x_pos[0], x_pos[-1]]`, and it will be transformed linearly if it's
      outside that interval.
    x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis.
    y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [num_bins + 1], the slopes at the knot points.
  Returns:
    A tuple of two scalars: the output of the transformation and the log of the
    absolute first derivative at `x`.
  """
    # Search to find the right bin. NOTE: The bins are sorted, so we could use
    # binary search, but this is more GPU/TPU friendly.
    # The following implementation avoids indexing for faster TPU computation.
    below_range = x <= x_pos[0]
    above_range = x >= x_pos[-1]
    correct_bin = jnp.logical_and(x >= x_pos[:-1], x < x_pos[1:])
    any_bin_in_range = jnp.any(correct_bin)
    first_bin = jnp.concatenate(
        [jnp.array([1]), jnp.zeros(len(correct_bin) - 1)]).astype(bool)
    # If x does not fall into any bin, we use the first spline in the following
    # computations to avoid numerical issues.
    correct_bin = jnp.where(any_bin_in_range, correct_bin, first_bin)
    # Dot product of each parameter with the correct bin mask.
    params = jnp.stack([x_pos, y_pos, knot_slopes], axis=1)
    params_bin_left = jnp.sum(correct_bin[:, None] * params[:-1], axis=0)
    params_bin_right = jnp.sum(correct_bin[:, None] * params[1:], axis=0)

    x_pos_bin = (params_bin_left[0], params_bin_right[0])
    y_pos_bin = (params_bin_left[1], params_bin_right[1])
    knot_slopes_bin = (params_bin_left[2], params_bin_right[2])

    bin_width = x_pos_bin[1] - x_pos_bin[0]
    bin_height = y_pos_bin[1] - y_pos_bin[0]
    bin_slope = bin_height / bin_width

    z = (x - x_pos_bin[0]) / bin_width
    # `z` should be in range [0, 1] to avoid NaNs later. This can happen because
    # of small floating point issues or when x is outside of the range of bins.
    # To avoid all problems, we restrict z in [0, 1].
    z = jnp.clip(z, 0., 1.)
    sq_z = z * z
    z1mz = z - sq_z  # z(1-z)
    sq_1mz = (1. - z)**2
    slopes_term = knot_slopes_bin[1] + knot_slopes_bin[0] - 2. * bin_slope
    numerator = bin_height * (bin_slope * sq_z + knot_slopes_bin[0] * z1mz)
    denominator = bin_slope + slopes_term * z1mz
    y = y_pos_bin[0] + numerator / denominator

    # Compute log det Jacobian.
    # The logdet is a sum of 3 logs. It is easy to see that the inputs of the
    # first two logs are guaranteed to be positive because we ensured that z is in
    # [0, 1]. This is also true of the log(denominator) because:
    # denominator
    # == bin_slope + (knot_slopes_bin[1] + knot_slopes_bin[0] - 2 * bin_slope) *
    # z*(1-z)
    # >= bin_slope - 2 * bin_slope * z * (1-z)
    # >= bin_slope - 2 * bin_slope * (1/4)
    # == bin_slope / 2
    logdet = 2. * jnp.log(bin_slope) + jnp.log(
        knot_slopes_bin[1] * sq_z + 2. * bin_slope * z1mz +
        knot_slopes_bin[0] * sq_1mz) - 2. * jnp.log(denominator)

    # If x is outside the spline range, we default to a linear transformation.
    y = jnp.where(below_range, (x - x_pos[0]) * knot_slopes[0] + y_pos[0], y)
    y = jnp.where(above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1],
                  y)
    logdet = jnp.where(below_range, jnp.log(knot_slopes[0]), logdet)
    logdet = jnp.where(above_range, jnp.log(knot_slopes[-1]), logdet)
    return y, logdet
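
Assuming the forward and inverse spline functions in Examples #20 and #22 come from the same module, a small round-trip check shows how they fit together (toy knot parameters chosen only for illustration):

import jax.numpy as jnp

x_pos = jnp.array([-1.0, -0.3, 0.4, 1.0])       # 3 bins on the x axis
y_pos = jnp.array([-1.0, -0.5, 0.2, 1.0])       # 3 bins on the y axis
knot_slopes = jnp.array([1.0, 0.7, 1.3, 1.0])   # positive slopes at the knots

x = jnp.array(0.1)
y, logdet_fwd = _rational_quadratic_spline_fwd(x, x_pos, y_pos, knot_slopes)
x_back, logdet_inv = _rational_quadratic_spline_inv(y, x_pos, y_pos, knot_slopes)
# x_back should match x up to numerical error, and logdet_fwd + logdet_inv
# should be close to zero.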
Example #23
def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate,
                      mask_token_proportion, random_token_proportion, mode,
                      rng):
    """Preprocess inputs for masked language modeling.

  Args:
    inputs: [batch x length] input tokens.
    random_tokens: Set of tokens usable for replacing masked positions.
    mask_token: Int ID to mask blanks with.
    pad_token: Int ID for PAD token. Positions left unchanged.
    mask_rate: Proportion of tokens to mask out.
    mask_token_proportion: Replace this proportion of chosen positions with
      MASK.
    random_token_proportion: Replace this proportion of chosen positions with
      randomly sampled tokens
    mode: Mode key.
    rng: Jax RNG.

  Returns:
    Tuple of [batch x length] inputs, targets, per position weights. targets
      will have random positions masked out with either a MASK token, or a
      randomly chosen token from the vocabulary.
  """
    total = random_token_proportion + mask_token_proportion
    if total < 0 or total > 1:
        raise ValueError('Sum of random proportion and mask proportion must be'
                         ' in [0, 1] range.')
    targets = inputs

    if mode == Mode.predict:
        weights = jnp.full_like(targets, 1)
        masked_inputs = inputs  # Pass through
    else:
        if rng is None:
            if mode is not Mode.eval:
                raise ValueError('Must provide RNG unless in eval mode.')
            # TODO(b/157055145): How to keep same eval set across runs?
            # Make each sequence's mask invariant to other members
            # of the batch. Right now there is a batch-size dependence.
            rng = jrandom.PRNGKey(jnp.sum(inputs))

        # Get positions to leave untouched
        is_pad = inputs == pad_token

        # Positions to mask
        rng, subrng = jax.random.split(rng)
        should_mask = jrandom.bernoulli(subrng,
                                        p=mask_rate,
                                        shape=inputs.shape)
        should_mask = jnp.where(is_pad, 0,
                                should_mask)  # Don't mask out padding.

        # Generate full array of random tokens.
        rng, subrng = jax.random.split(rng)
        random_ids = jax.random.randint(subrng,
                                        inputs.shape,
                                        minval=0,
                                        maxval=len(random_tokens))

        fullrandom = random_tokens[random_ids]
        # Full array of MASK tokens
        fullmask = jnp.full_like(inputs, mask_token)

        # Build up masked array by selecting from inputs/fullmask/fullrandom.
        rand = jax.random.uniform(rng, shape=inputs.shape)
        masked_inputs = inputs
        # The remaining probability mass keeps the original values after MASK and RANDOM.
        # MASK tokens.
        masked_inputs = jnp.where(rand < mask_token_proportion, fullmask,
                                  masked_inputs)
        # Random tokens.
        masked_inputs = jnp.where(
            jnp.logical_and(
                rand >= mask_token_proportion,
                rand < mask_token_proportion + random_token_proportion),
            fullrandom, masked_inputs)

        # Only replace positions where `should_mask`
        masked_inputs = jnp.where(should_mask, masked_inputs, inputs)
        weights = should_mask

    return masked_inputs, targets, weights
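
A hypothetical usage sketch; the Mode enum and its train member, as well as the toy token ids, are assumptions about the surrounding module rather than part of the original code:

import jax
import jax.numpy as jnp

inputs = jnp.array([[5, 7, 9, 3, 0, 0]])    # 0 is the pad token
random_tokens = jnp.arange(3, 20)           # ids usable as random replacements

masked_inputs, targets, weights = preprocess_masked(
    inputs, random_tokens, mask_token=1, pad_token=0,
    mask_rate=0.15, mask_token_proportion=0.8, random_token_proportion=0.1,
    mode=Mode.train, rng=jax.random.PRNGKey(0))
# targets equals inputs; weights is 1 at the positions selected for prediction
# (replaced by MASK, by a random token, or kept as-is).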
Example #24
 def _safe_div(x1, x2):
     return jnp.where(jnp.logical_and(x1 == 0, x2 == 0), x1, x1 / x2)
Example #25
def logical_and(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.logical_and(x1, x2))
Example #26
 def _iter_condition(state):
     i, unused_v, unused_s, unused_s_v, run_step = state
     return jnp.logical_and(i < num_iters, run_step)
Example #27
 def qnorm_cond(carry):
     k, not_done, _, _ = carry
     return jnp.logical_and(not_done, k < (max_iterations - 1))
Example #28
 def _iter_condition(state):
     (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
      run_step) = state
     return jnp.logical_and(
         i < iter_count, jnp.logical_or(error > error_tolerance, run_step))
Example #29
 def loop_cond(carry):
     k, err, _, _, _, _ = carry
     return jnp.logical_and(k < restart, err > ptol)
def _get_attention_result(query,
                          key,
                          value,
                          dtype,
                          precision,
                          dropout_rng,
                          dropout_rate,
                          broadcast_dropout,
                          deterministic,
                          mask=None,
                          padding_mask=None,
                          key_padding_mask=None,
                          segmentation=None,
                          key_segmentation=None,
                          apply_causal_mask=False):
    """Helper function returning `[batch_size, seq_len, heads, features]` output."""
    # assumes query/key/value has shape `[batch_size, seq_len, heads, features]`.

    mask_components = [] if mask is None else [mask]

    seq_len = query.shape[1]

    if apply_causal_mask:
        causal_mask = jnp.array(
            np.reshape(np.tri(seq_len, k=0),
                       [1, 1, seq_len, seq_len])).astype(jnp.bool_)
        mask_components.append(causal_mask)
    if padding_mask is not None:
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        padding_mask = nn.attention.make_padding_mask(
            padding_mask_query=padding_mask,
            padding_mask_key=key_padding_mask,
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=(1, ))
        mask_components.append(padding_mask)

    if segmentation is not None:
        if key_segmentation is None:
            key_segmentation = segmentation
        segmentation_mask = nn.attention.make_padding_mask(
            padding_mask_query=segmentation,
            padding_mask_key=key_segmentation,
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=(1, ),
            segmentation_mask=True)
        mask_components.append(segmentation_mask)

    if mask_components:
        attention_mask = mask_components[0]
        for component in mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)

        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(dtype),
            jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
        attention_bias = None

    return nn.attention.dot_product_attention(
        query,
        key,
        value,
        dtype=dtype,
        axis=1,
        bias=attention_bias,
        precision=precision,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        broadcast_dropout=broadcast_dropout,
        deterministic=deterministic)