Beispiel #1
0
def compute_loss_and_prob_from_probs_with_duplicates(
    probs: Array,
    classes: Array,
    targets: Array,
    weights: Array,
) -> Tuple[float, float, float]:
  """Weighted negative log-likelihood when several items may share a class.

  Every position holds `num_items` candidate items, each with a class id and a
  probability. The probability credited to the target is the total probability
  of all items whose class equals the target class.

  Args:
    probs: [batch, length, num_items] float array of item probabilities.
    classes: [batch, length, num_items] class id of every item.
    targets: [batch, length] categorical target int array.
    weights: [batch, length] per-position weights.

  Returns:
    Tuple of (weighted loss summed over positions, weighted sum of the target
    probabilities, sum of weights). None of the values is normalized.
  """
  float_probs = probs.astype(jnp.float32)
  float_weights = weights.astype(jnp.float32)

  # 1.0 where an item's class is the target class, 0.0 everywhere else.
  is_target = jnp.equal(classes, targets[..., None]).astype(jnp.float32)

  # Probability mass on the target class, accumulated over duplicate items.
  target_prob = jnp.sum(is_target * float_probs, axis=-1)
  # The small constant keeps the log finite when no item matches the target.
  nll = -jnp.log(target_prob + _SMALL_NUMBER)

  total_loss = jnp.sum(nll * float_weights)
  total_prob = jnp.sum(target_prob * float_weights)
  return total_loss, total_prob, jnp.sum(float_weights)
Beispiel #2
0
def compute_weighted_cross_entropy(
    scores: Array,
    targets: Array,
    weights: Array,
    inputs_are_prob: Optional[bool] = False,
) -> Tuple[float, float]:
  """Compute weighted cross entropy between scores and categorical targets.

  Args:
   scores: [batch, length, num_classes] float array. Interpreted as logits
     unless `inputs_are_prob` is set, in which case they are probabilities.
   targets: [batch, length] categorical target integer array.
   weights: [batch, length].
   inputs_are_prob: true if inputs are probabilities rather than logits.

  Returns:
    Tuple of scalar loss (summed over all positions, not normalized) and batch
    denominator (sum of the weights).

  Raises:
    ValueError: if `scores` does not have exactly one more dimension than
      `targets`.
  """
  scores = scores.astype(jnp.float32)
  weights = weights.astype(jnp.float32)
  # Keep integer class ids as integers: float32 represents integers exactly
  # only up to 2**24, so a float round trip could make a large id compare equal
  # to a neighbouring class. Non-integer targets keep the float32 path.
  if jnp.issubdtype(targets.dtype, jnp.integer):
    targets = targets.astype(jnp.int32)
  else:
    targets = targets.astype(jnp.float32)
  if scores.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s scores and %s targets' %
                     (str(scores.shape), str(targets.shape)))
  vocab_size = scores.shape[-1]
  soft_targets = jax.nn.one_hot(targets, vocab_size)

  if inputs_are_prob:
    # The small constant keeps the log finite for zero probabilities.
    log_probs = jnp.log(scores + _SMALL_NUMBER)
  else:
    log_probs = jax.nn.log_softmax(scores)
  loss = -jnp.sum(soft_targets * log_probs, axis=-1)

  loss = loss * weights
  normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
def get_mrr(labels: Array, logits: Array) -> Dict[str, Array]:
    """Mean reciprocal rank in https://www.aclweb.org/anthology/P18-1009.pdf.

    Classes are ranked by decreasing logit. For every sample the reciprocal
    ranks of its labeled classes are averaged; samples without any label are
    excluded from both the value and the denominator.

    Args:
      labels: 0-1 array whose last axis enumerates classes (presumably
        [batch_size, num_classes] -- confirm against callers).
      logits: array of per-class scores with the same shape as `labels`.

    Returns:
      Dictionary with the summed per-sample MRR under 'value' and the number of
      samples that have at least one label under 'denominator'.
    """
    labels_per_sample = labels.sum(axis=-1)
    labels_exists = (labels_per_sample > 0).astype(jnp.float32)

    # Applying argsort twice converts scores into 0-based ranks, where rank 0
    # belongs to the highest logit.
    order = jnp.argsort(-logits, axis=-1)
    ranks = jnp.argsort(order, axis=-1)
    reciprocal_ranks = 1.0 / (ranks + 1)

    # The small constant avoids division by zero for samples without labels.
    mrr_per_sample = (labels * reciprocal_ranks).sum(-1) / (labels_per_sample +
                                                            1e-5)
    return {
        'value': jnp.dot(mrr_per_sample, labels_exists),
        'denominator': labels_exists.sum(),
    }
Beispiel #4
0
def compute_assignments(
    centroids: Array,
    observations: Array,
    n_splits: int,
) -> Tuple[Array, Array]:
  """Finds the closest centroid for every observation, chunk by chunk.

  Distances between observations and centroids are obtained from `l2_distance`.
  The full pairwise distance array would be very large, so observations are
  divided into `n_splits` chunks that are processed with `jax.lax.map`.

  Args:
    centroids: [n_clusters, dim] cluster centroids.
    observations: [n_observations, dim] data points.
    n_splits: split observations into this many chunks.

  Returns:
    assignments: [n_observations] closest cluster for each observation.
    min_dist: [n_observations] distance to closest cluster by observation.
  """
  dim = centroids.shape[-1]
  # [n_splits, observations per split, dim]
  observation_chunks = observations.reshape(n_splits, -1, dim)

  def nearest_centroid(chunk):
    # Distances between every point of the chunk and every centroid.
    pairwise_dist = l2_distance(chunk, centroids)
    return jnp.argmin(pairwise_dist, axis=-1), jnp.min(pairwise_dist, axis=-1)

  chunk_assignments, chunk_min_dist = jax.lax.map(nearest_centroid,
                                                  observation_chunks)

  # Undo the chunking so that results line up with the input observations.
  return chunk_assignments.reshape(-1), chunk_min_dist.reshape(-1)
Beispiel #5
0
def get_batch_and_retrievals_entity_overlap(
    mention_target_batch_positions: Array,
    mention_target_ids: Array,
    mention_target_weights: Array,
    memory_text_entities: Array,
    batch_size: int,
):
    """Count entity IDs shared by each batch sample and each retrieval.

    Args:
      mention_target_batch_positions: [n_target_mentions] position of a mention
        in its batch.
      mention_target_ids: [n_target_mentions] IDs of mentions.
      mention_target_weights: [n_target_mentions] per-mention weight for
        computing loss and metrics.
      memory_text_entities: [n_retrievals, n_memory_text_entities] IDs of
        mentions in the passage where memory is coming from. Note, entities in
        the same retrieval are assumed to be unique.
      batch_size: batch size.

    Returns:
      Array of shape [batch_size, n_retrievals] with the number of
        overlapping unique entity IDs in the batch and in the retrieval results.
    """
    n_target_mentions = mention_target_batch_positions.shape[0]
    n_retrievals = memory_text_entities.shape[0]
    n_memory_text_entities = memory_text_entities.shape[1]

    # Stage 1: remove duplicate entities within the batch.
    # Scaling IDs by the weights presumably zeroes out padding mentions
    # (assumes 0-1 weights -- confirm against callers).
    # [n_target_mentions]
    weighted_ids = mention_target_ids * mention_target_weights
    # [n_target_mentions]
    unique_batch_ids = mention_utils.mask_duplicate_ids(
        mention_target_batch_positions, weighted_ids)

    # Stage 2: compare every batch entity with every retrieved entity.
    # [n_retrievals * n_memory_text_entities]
    flat_memory_entities = memory_text_entities.reshape(
        [n_retrievals * n_memory_text_entities])
    # [n_target_mentions, n_retrievals * n_memory_text_entities]
    matches = mention_utils.all_compare_without_pad(
        unique_batch_ids, flat_memory_entities).astype(jnp.int32)

    # Stage 3: count the matches inside each retrieval.
    # [n_target_mentions, n_retrievals]
    matches_per_retrieval = matches.reshape(
        [n_target_mentions, n_retrievals, n_memory_text_entities]).sum(-1)

    # Stage 4: accumulate the counts of mentions from the same batch sample.
    # [batch_size, n_retrievals]
    return mention_utils.sum_by_batch_position(
        mention_target_batch_positions, matches_per_retrieval, batch_size)
Beispiel #6
0
    def __call__(
        self,
        encoded_input: Array,
        retrieval_values: Array,
        retrieval_scores: Array,
        mention_batch_positions: Array,
        mention_start_positions: Array,
        mention_end_positions: Array,
        mention_mask: Array,
        deterministic: bool,
    ) -> Array:
        """Update the input encoding with MLP-processed retrieved values.

        Each retrieved value is concatenated with a projection of the mention
        it was retrieved for and passed through MLP layers. The results are
        pooled with the retrieval scores, processed further, and added to the
        encoding at the mention start positions before layer normalization.

        Args:
          encoded_input: input representation (presumably
            [batch_size, n_tokens, hidden_size] -- confirm against callers).
          retrieval_values: [n_mentions, k_retrieval, dim] retrieved values.
          retrieval_scores: [n_mentions, k_retrieval] pooling weights.
          mention_batch_positions: sample position in batch for each mention.
          mention_start_positions: start position in input for each mention.
          mention_end_positions: end position in input for each mention.
          mention_mask: per-mention mask; updates of masked mentions are zeroed.
          deterministic: forwarded to the dropout and MLP layers.

        Returns:
          Layer-normalized updated encoding.
        """
        start_index = (mention_batch_positions, mention_start_positions)
        end_index = (mention_batch_positions, mention_end_positions)

        # A mention is represented by its start and end token encodings.
        start_encodings = jut.matmul_2d_index_select(encoded_input, start_index)
        end_encodings = jut.matmul_2d_index_select(encoded_input, end_index)
        span_encodings = jnp.concatenate((start_encodings, end_encodings),
                                         axis=-1)
        mention_values = self.value_projection(span_encodings)

        # Repeat the mention value once per retrieval so that it can be paired
        # with every retrieved value.
        k_retrieval = retrieval_scores.shape[-1]
        mention_values = jnp.tile(
            jnp.expand_dims(mention_values, axis=1), (1, k_retrieval, 1))

        # [mentions, k, mention value dim + retrieval value dim]
        paired_values = jnp.concatenate((mention_values, retrieval_values),
                                        axis=-1)

        # MLP over each (mention value, retrieved value) pair.
        paired_values = self.concat_mlp(paired_values)
        paired_values = self.concat_dense(nn.gelu(paired_values))
        paired_values = self.concat_dropout(paired_values, deterministic)
        for layer in self.additional_concat_mlp_layers:
            paired_values = layer(paired_values, deterministic)

        # Score-weighted sum over the retrievals of every mention.
        pooled = jnp.einsum('qk,qkd->qd', retrieval_scores, paired_values)
        for layer in self.pooled_mlp_layers:
            pooled = layer(pooled, deterministic)

        # Masked mentions must not contribute to the update.
        pooled = pooled * mention_mask.reshape(-1, 1)

        updated_input = jut.matmul_2d_index_add(encoded_input, start_index,
                                                pooled)
        return self.layer_norm(updated_input)
Beispiel #7
0
def compute_weighted_accuracy(scores: Array, targets: Array,
                              weights: Array) -> Tuple[float, float]:
  """Compute weighted accuracy of argmax predictions against targets.

  Args:
   scores: [batch, length, num_classes] float array.
   targets: [batch, length] categorical targets int array.
   weights: [batch, length].

  Returns:
    Tuple of the weighted number of correct predictions (not normalized) and
    the batch normalizing factor (sum of the weights).

  Raises:
    ValueError: if `scores` does not have exactly one more dimension than
      `targets`.
  """
  if scores.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s scores and %s targets' %
                     (str(scores.shape), str(targets.shape)))

  # The predicted class is the one with the highest score.
  predictions = jnp.argmax(scores, axis=-1)
  weighted_hits = (predictions == targets) * weights

  return weighted_hits.sum(), weights.sum()
Beispiel #8
0
    def __call__(
        self,
        encoded_input: Array,
        retrieval_values: Array,
        retrieval_scores: Array,
        mention_batch_positions: Array,
        mention_start_positions: Array,
        mention_end_positions: Array,
        mention_mask: Array,
        deterministic: bool,
    ) -> Array:
        """Add the projected, score-weighted retrieved values to the encoding.

        `mention_end_positions` and `deterministic` are accepted but not used
        by this method.

        Args:
          encoded_input: input representation (presumably
            [batch_size, n_tokens, hidden_size] -- confirm against callers).
          retrieval_values: [n_mentions, k_retrieval, dim] retrieved values.
          retrieval_scores: [n_mentions, k_retrieval] pooling weights.
          mention_batch_positions: sample position in batch for each mention.
          mention_start_positions: start position in input for each mention.
          mention_end_positions: unused.
          mention_mask: per-mention mask; updates of masked mentions are zeroed.
          deterministic: unused.

        Returns:
          Layer-normalized updated encoding.
        """
        # Score-weighted sum over the retrievals of every mention.
        pooled_values = jnp.einsum('qk,qkd->qd', retrieval_scores,
                                   retrieval_values)

        # Masked mentions must not contribute to the update.
        update = self.retrieval_projector(pooled_values)
        update = update * mention_mask.reshape(-1, 1)

        # The update is added at the start token of every mention.
        start_index = (mention_batch_positions, mention_start_positions)
        updated_input = jut.matmul_2d_index_add(encoded_input, start_index,
                                                update)

        return self.layer_norm(updated_input)
Beispiel #9
0
def same_entity_set_retrieval_loss(
    mention_target_batch_positions: Array,
    mention_target_ids: Array,
    mention_target_weights: Array,
    mention_batch_positions: Array,
    mention_mask: Array,
    memory_text_entities: Array,
    memory_attention_weights: Array,
    memory_mask: Array,
    batch_size: int,
    same_entity_set_target_threshold: int,
):
    """Computes same-entity-set-retrieval loss.

    The loss rewards attention placed on memories whose source passage shares
    at least `same_entity_set_target_threshold` unique entities with the
    passage of the query mention.

    Entity IDs present in the batch are described by
    `mention_target_batch_positions`, `mention_target_ids` and
    `mention_target_weights`, which correspond to linked mentions. Memories are
    retrieved for all mentions, but entity IDs of non-linked mentions are
    unknown and are not needed for computing the overlap.

    Formally, let E(b) be the set of entity IDs in the b-th sample of the
    batch. The k-th retrieval for the i-th mention is "correct" if and only if
    IntersectionSize(E(mention_batch_positions[i]),
    memory_text_entities[i, k]) is at least `same_entity_set_target_threshold`
    and memory_mask[i, k] is set.

    The per-mention loss is the negative log of the sum of
    memory_attention_weights[i, k] over the "correct" retrievals k. Only
    non-padding mentions that have at least one "correct" and at least one
    "incorrect" retrieval (both subject to `memory_mask`) contribute. The
    function returns sums; normalization is left to the caller.

    Args:
      mention_target_batch_positions: [n_target_mentions] position of a linked
        (target) mention in its batch.
      mention_target_ids: [n_target_mentions] IDs of mentions.
      mention_target_weights: [n_target_mentions] per-mention weight for linked
        (target) mentions indicating whether a mention is padding or not.
      mention_batch_positions: [n_mentions] position of a mention in its batch.
      mention_mask: [n_mentions] whether a mention is padding or not.
      memory_text_entities: [n_mentions, k_top, n_memory_text_entities] IDs of
        mentions in the passage where memory is coming from.
      memory_attention_weights: [n_mentions, k_top] attention weights for the
        retrieval results.
      memory_mask: [n_mentions, k_top] which retrievals to use and which are to
        ignore in the loss computations. Typical usage is to ignore
        "disallowed" or same-passage retrievals.
      batch_size: batch size.
      same_entity_set_target_threshold: how many common entities needs to be
        between memory's passage and mention's passage in order to treat
        retrieval result as positive. If it's equal to 2 this loss becomes
        `same-mtb-retrieval` loss.

    Returns:
      Tuple of scalar loss (sum), sum of correct probabilities and normalizing
      factor (number of mentions included in the loss).
    """
    n_mentions, k_top = memory_text_entities.shape[:2]

    # Every (mention, retrieval) pair is treated as a separate retrieval.
    # [n_mentions * k_top, n_memory_text_entities]
    flat_memory_text_entities = memory_text_entities.reshape(
        [n_mentions * k_top, -1])

    # [batch_size, n_mentions * k_top]
    overlap_by_sample = get_batch_and_retrievals_entity_overlap(
        mention_target_batch_positions=mention_target_batch_positions,
        mention_target_ids=mention_target_ids,
        mention_target_weights=mention_target_weights,
        memory_text_entities=flat_memory_text_entities,
        batch_size=batch_size,
    )
    # [batch_size, n_mentions, k_top]
    overlap_by_sample = overlap_by_sample.reshape(
        [batch_size, n_mentions, k_top])

    # One-hot map from batch samples to the mentions they contain.
    # [batch_size, n_mentions]
    sample_index = jnp.expand_dims(jnp.arange(batch_size), 1)
    mention_in_sample = (sample_index == mention_batch_positions).astype(
        jnp.int32)

    # For every mention keep only the overlap with its own sample.
    # [n_mentions, k_top]
    overlap_by_mention = jnp.einsum('bm,bmk->mk', mention_in_sample,
                                    overlap_by_sample)

    # Retrievals sharing enough entities with the mention's passage count as
    # "correct"; all remaining usable retrievals count as "incorrect".
    # [n_mentions, k_top]
    is_positive = (overlap_by_mention >=
                   same_entity_set_target_threshold).astype(jnp.int32)
    positives_mask = is_positive * memory_mask
    negatives_mask = (1 - is_positive) * memory_mask

    # A mention is included only if it has both kinds of retrievals and is not
    # padding.
    # [n_mentions]
    has_positive = positives_mask.sum(-1) > 0
    has_negative = negatives_mask.sum(-1) > 0
    loss_mask = jnp.logical_and(has_positive, has_negative)
    loss_mask = loss_mask.astype(mention_mask.dtype) * mention_mask
    loss_mask = loss_mask.astype(jnp.float32)

    # Total attention received by the "correct" retrievals of each mention.
    # [n_mentions]
    positive_prob = jnp.einsum('mk,mk->m',
                               positives_mask.astype(jnp.float32),
                               memory_attention_weights)
    per_mention_loss = -jnp.log(positive_prob + _SMALL_NUMBER)

    total_loss = (per_mention_loss * loss_mask).sum()
    total_prob = (positive_prob * loss_mask).sum()
    return total_loss, total_prob, loss_mask.sum()
Beispiel #10
0
def entity_linking_loss(
    mention_encodings: Array, entity_embeddings: Array,
    mention_target_ids: Array, mention_target_weights: Array, mode: str
) -> Tuple[Array, Dict[str, Array], Tuple[Array, Array]]:
    """Compute entity linking loss.

    Args:
      mention_encodings: [n_mentions, hidden_size] mention encodings to be used
        for computing the loss.
      entity_embeddings: [n_entities, hidden_size] entity embeddings table.
      mention_target_ids: [n_mentions] IDs of mentions.
      mention_target_weights: [n_mentions] per-mention weight for computing
        loss and metrics.
      mode: how to compute the scores -- using dot product ('dot'), dot product
        divided by the sqrt root of the hidden dim ('dot_sqrt') or cosine
        similarity ('cos').

    Returns:
      Loss, a dictionary with metrics values, per sample information
      (a tuple of accuracy per mention and weight per mention). The loss and
      the metric values are sums; the 'denominator' entry holds the sum of
      weights to normalize them with.

    Raises:
      ValueError: if `mode` is not one of 'dot', 'dot_sqrt' or 'cos'.
    """
    scores = jnp.einsum('qd,ed->qe', mention_encodings, entity_embeddings)
    scores = scores.astype(jnp.float32)

    mention_encodings_norm = jnp.linalg.norm(mention_encodings, axis=-1)
    entity_embeddings_norm = jnp.linalg.norm(entity_embeddings, axis=-1)

    # The cosine similarity is computed as dot product divided by norms of
    # both vectors. It is always needed for the 'cos_sim' metric, regardless
    # of `mode`. Explicit divisions (rather than in-place `/=` on an alias of
    # `scores`) make it obvious that `scores` itself is left untouched.
    cos_scores = scores / (_SMALL_NUMBER +
                           jnp.expand_dims(mention_encodings_norm, 1))
    cos_scores = cos_scores / (_SMALL_NUMBER +
                               jnp.expand_dims(entity_embeddings_norm, 0))

    if mode == 'dot':
        pass
    elif mode == 'dot_sqrt':
        hidden_dim = mention_encodings.shape[1]
        scores = scores / jnp.sqrt(hidden_dim)
    elif mode == 'cos':
        scores = cos_scores
    else:
        raise ValueError('Unknown entity linking mode: ' + mode)

    mention_target_weights = mention_target_weights.astype(jnp.float32)

    loss, _ = metric_utils.compute_weighted_cross_entropy(
        scores,
        mention_target_ids,
        mention_target_weights,
        inputs_are_prob=False)

    # A mention is linked correctly if its target has the highest score.
    acc_per_mention = jnp.equal(jnp.argmax(scores, axis=-1),
                                mention_target_ids)
    acc_per_mention = acc_per_mention * mention_target_weights

    # Cosine similarity between every mention and its target entity.
    n_mentions = mention_target_ids.shape[0]
    cos_per_mention = cos_scores[jnp.arange(n_mentions), mention_target_ids]
    cos_per_mention = cos_per_mention * mention_target_weights

    metrics = {
        'loss': loss,
        'acc': acc_per_mention.sum(),
        'cos_sim': cos_per_mention.sum(),
        'denominator': mention_target_weights.sum()
    }
    return loss, metrics, (acc_per_mention, mention_target_weights)
Beispiel #11
0
    def __call__(
        self,
        queries: Array,
        table: Array,
        prototypes: Array,
    ) -> Tuple[Array, Array, Array]:
        """Perform approximate top-k similarity search over vector table.

        For every query the `self.n_search` clusters with the highest scoring
        prototypes are selected. Within those clusters the best value of every
        row is found and the `self.k_top` best rows are returned. Queries are
        processed in `self.splits` chunks to limit memory usage.

        Args:
          queries: [n_queries, vector_dim].
          table: [n_clusters, rows, values per row, vector_dim] vector table.
          prototypes: [n_clusters, vector_dim] representative vectors for
            clusters.

        Returns:
          Top-k vectors [n_queries, k_top, vector_dim], scores
          [n_queries, k_top] and ids [n_queries, k_top]. The ids index into the
          table flattened over clusters, rows and values per row.
        """
        n_queries = queries.shape[0]
        queries_per_split = n_queries // self.splits

        n_clusters = table.shape[0]
        rows_per_cluster = table.shape[1]
        values_per_row = table.shape[2]
        values_per_cluster = rows_per_cluster * values_per_row

        # Number of vectors in the flattened table. It bounds the per-cluster
        # offsets below, which must contain one entry for every cluster; using
        # only `values_per_row * n_clusters` would leave too few offsets
        # whenever a cluster has more than one row.
        table_size = n_clusters * values_per_cluster
        vector_dim = queries.shape[1]
        assert table.shape[-1] == vector_dim

        # Split queries to reduce size of selected clusters and save memory.
        queries = queries.reshape(self.splits, queries_per_split, vector_dim)

        def split_top_k(split_queries: Array) -> Tuple[Array, Array, Array]:
            # Find most similar clusters
            prototype_scores = jnp.einsum('qd,pd->qp', split_queries,
                                          prototypes)
            top_indices = jax.lax.top_k(prototype_scores, self.n_search)[1]
            # Perform approximate top-k similarity search over most similar
            # clusters.
            selected_data = table[top_indices]
            split_scores = jnp.einsum('qd,qcrvd->qcrv', split_queries,
                                      selected_data)

            # Find highest scoring vector for each row.
            top_id_by_row = jnp.argmax(split_scores, axis=-1)
            top_score_by_row = jnp.max(split_scores, axis=-1)

            # Merge the cluster and row axes: [queries, n_search * rows].
            top_id_by_row = top_id_by_row.reshape(
                queries_per_split, self.n_search * rows_per_cluster)
            top_score_by_row = top_score_by_row.reshape(
                queries_per_split, self.n_search * rows_per_cluster)

            # Take k highest scores among all rows. argsort is ascending, so
            # the last k entries are read in reverse order.
            top_row_idx = jnp.argsort(top_score_by_row,
                                      axis=-1)[:, :-self.k_top - 1:-1]

            # Sub-select best indices for k best rows.
            ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

            # Gather highest scoring vectors for k best rows.
            query_index = jnp.tile(
                jnp.arange(queries_per_split).reshape(-1, 1),
                (1, self.k_top))
            # Recover (selected cluster, row in cluster) from the merged index.
            top_cluster_idx, top_cluster_row_idx = jnp.divmod(
                top_row_idx, rows_per_cluster)
            split_topk_values = selected_data[query_index, top_cluster_idx,
                                              top_cluster_row_idx,
                                              ids_by_topk_row]

            # Offset of every row within its cluster, repeated per selected
            # cluster.
            row_offset = jnp.mod(
                jnp.arange(0, self.n_search * values_per_cluster,
                           values_per_row), values_per_cluster)
            # Offset of the first value of every cluster in the flattened
            # table: [n_clusters].
            cluster_offset = jnp.arange(0, table_size, values_per_cluster)

            # Convert row indices to indices into flattened table.
            top_table_id_by_row = top_id_by_row + row_offset.reshape(
                1, -1) + cluster_offset[top_indices].repeat(rows_per_cluster,
                                                            axis=-1)
            # Get best ids into flattened table.
            split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)

            split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

            return split_topk_values, split_topk_scores, split_topk_ids

        # Perform similarity over each chunk of queries sequentially
        # (not in parallel), so that only one score tensor is in memory at a
        # time.
        topk_values, topk_scores, topk_ids = jax.lax.map(split_top_k, queries)

        topk_values = topk_values.reshape(n_queries, self.k_top, -1)
        topk_scores = topk_scores.reshape(n_queries, self.k_top)
        topk_ids = topk_ids.reshape(n_queries, self.k_top)

        return topk_values, topk_scores, topk_ids
Beispiel #12
0
def compute_cross_entropy_loss_with_positives_and_negatives_masks(
    scores: Array,
    positives: Array,
    negatives: Array,
    weights: Optional[Array] = None,
) -> Tuple[float, Dict[str, float], Tuple[Array, Array]]:
  """Compute (weighted) cross-entropy loss and accuracy-related metrics.

  Handles the setting where every sample may have several positive classes,
  several negative classes and any number of neutral classes. Neutral classes
  are ignored entirely.

  `positives` and `negatives` are masks over classes:
  `positives[i, j]` is True <=> class j is positive for the sample i
  `negatives[i, j]` is True <=> class j is negative for the sample i

  The loss is built up as follows:

  (1) For every sample i and each of its positive classes j, a cross-entropy
  loss is computed with j as the single positive class and all negative
  classes of i as negatives.

  (2) The loss of sample i is the average of those losses over its positive
  classes.

  (3) The total loss is the sum of the per-sample losses. Only samples with at
  least one positive and at least one negative class are included; `weights`
  can restrict this set further.

  Args:
   scores: [batch_size, num_classes] scores or logits.
   positives: [batch_size, num_classes] 0-1 mask for which classes are positive.
   negatives: [batch_size, num_classes] 0-1 mask for which classes are negative.
   weights: [batch_size] 0-1 masks indicating whether the loss should be
     computed for the corresponding item in the batch.

  Returns:
    A tuple of scalar loss, a dictionary with metrics, per sample information
    (a tuple of average positive accuracy per sample and weight per sample).
  """
  has_positive_and_negative = jnp.logical_and(
      positives.sum(-1) > 0, negatives.sum(-1) > 0)
  if weights is not None:
    weights = jnp.logical_and(weights, has_positive_and_negative)
  else:
    weights = has_positive_and_negative

  scores = scores.astype(jnp.float32)
  positives = positives.astype(jnp.float32)
  negatives = negatives.astype(jnp.float32)
  weights = weights.astype(jnp.float32)

  # The derivation below drops the batch dimension and looks at one sample.
  # With positive scores p_1, ..., p_N and negative scores n_1, ..., n_M the
  # loss is
  #   sum_{i=1..N} -log softmax([p_i, n_1, ..., n_M])_1.
  # Evaluating every softmax separately would be too expensive, so the
  # negative part of the denominator is shared between all positives.

  # Step 1: S = log sum_{j=1..M} exp(n_j). Non-negative classes get a very
  # low score so that they do not contribute to the sum.
  negative_scores = scores * negatives - _BIG_NUMBER * (1.0 - negatives)
  negatives_log_sum_exp = jax.nn.logsumexp(
      negative_scores, axis=-1, keepdims=True)

  # Step 2: for a positive class i the loss simplifies to
  # -log(exp(p_i) / (exp(p_i) + exp(S))) = -log sigmoid(p_i - S).
  # It is evaluated for every class and masked in the next step.
  per_class_loss = -jax.nn.log_sigmoid(scores - negatives_log_sum_exp)

  # Step 3: average over the positive classes of a sample. The small constant
  # guards against division by zero for samples without positives.
  n_positives = positives.sum(axis=-1) + _SMALL_NUMBER
  loss_per_sample = (per_class_loss * positives).sum(axis=-1)
  loss_per_sample = loss_per_sample / n_positives * weights

  # Step 4: sum over the samples.
  loss = loss_per_sample.sum()

  # Accuracy: a positive class counts as predicted correctly if its score is
  # strictly larger than the scores of all negative classes of the sample,
  # i.e. larger than the highest negative score.
  highest_negative_score = negative_scores.max(axis=-1, keepdims=True)
  is_correct = (scores > highest_negative_score).astype(jnp.float32)

  # Average over the positive classes and zero out excluded samples.
  correct_prediction = (is_correct * positives).sum(axis=-1)
  correct_prediction = correct_prediction / n_positives
  correct_prediction = correct_prediction * weights

  metrics = {
      'loss': loss,
      'acc': correct_prediction.sum(),
      'denominator': weights.sum(),
  }
  return loss, metrics, (correct_prediction, weights)
    def __call__(
        self,
        queries: Array,
        table: Array,
    ) -> Tuple[Array, Array, Array]:
        """Perform approximate top-k similarity search over vector table.

        The best value of every table row is found first, then the
        `self.k_top` best rows are kept. Queries are processed in
        `self.splits` chunks to limit the size of the score tensor.

        Args:
          queries: [n_queries, vector_dim].
          table: [rows, values per row, vector_dim] vector table. The number of
            rows in the table governs the recall vs speed of the topk
            similarity search. Search is performed by taking max over each row,
            and then top-k between rows. Distributing the same values over more
            rows leads to higher recall but slower search.

        Returns:
          Top-k vectors, scores and ids. The ids index into the table flattened
          over rows and values per row.
        """
        n_queries, vector_dim = queries.shape[0], queries.shape[1]
        n_rows, scores_per_row = table.shape[0], table.shape[1]
        assert table.shape[-1] == vector_dim

        queries_per_split = n_queries // self.splits
        table_size = n_rows * scores_per_row

        # Chunking the queries keeps the intermediate score tensor small.
        query_chunks = queries.reshape(self.splits, queries_per_split,
                                       vector_dim)

        def split_top_k(split_queries):
            # [queries, rows, values per row]
            scores = jnp.einsum('qd,rvd->qrv', split_queries, table)

            # Best value of every row and its position inside the row.
            best_id_in_row = jnp.argmax(scores, axis=-1)
            best_score_in_row = jnp.max(scores, axis=-1)

            # argsort is ascending, so the k best rows are the last k entries
            # read in reverse order.
            ascending_rows = jnp.argsort(best_score_in_row, axis=-1)
            top_rows = ascending_rows[:, :-self.k_top - 1:-1]

            # Position inside the row for each of the k best rows.
            best_id_in_top_rows = jut.matmul_slice(best_id_in_row, top_rows)

            # Vectors belonging to the k best rows.
            chunk_values = table[top_rows, best_id_in_top_rows]

            # Adding the offset of each row turns positions inside a row into
            # positions inside the flattened table.
            row_offsets = jnp.arange(0, table_size, scores_per_row)
            flat_id_by_row = best_id_in_row + row_offsets
            chunk_ids = jut.matmul_slice(flat_id_by_row, top_rows)

            chunk_scores = jut.matmul_slice(best_score_in_row, top_rows)

            return chunk_values, chunk_scores, chunk_ids

        # jax.lax.map handles the chunks one after another, so that only one
        # score tensor has to be held in memory at a time.
        all_values, all_scores, all_ids = jax.lax.map(split_top_k,
                                                      query_chunks)

        # Merge the chunk axis back into the query axis.
        return (
            all_values.reshape(n_queries, self.k_top, -1),
            all_scores.reshape(n_queries, self.k_top),
            all_ids.reshape(n_queries, self.k_top),
        )
  def __call__(
      self,
      encoding: Array,
      mention_batch_positions: Array,
      mention_start_positions: Array,
      mention_end_positions: Array,
      mention_mask: Array,
      memory_keys: Array,
      memory_values: Array,
      memory_mask: Array,
      memory_entity_ids: Array,
      deterministic: bool,
  ) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
    """Perform attention update over memory table.

    Args:
      encoding: [batch_size, n_tokens, hidden_size] input representation.
      mention_batch_positions: [n_mentions] mention sample position in batch.
      mention_start_positions: [n_mentions] mention start position in input.
      mention_end_positions: [n_mentions] mention end position in input.
      mention_mask: [n_mentions] attention mask to prevent updates from padding.
      memory_keys: [memory_size, memory_key_dim] mention memory keys.
      memory_values: [memory_size, memory_value_dim] mention memory values.
      memory_mask: [memory_size] mask for valid mentions in memory.
      memory_entity_ids: [memory_size] mention memory entity ids.
      deterministic: don't apply dropout if true.

    Returns:
      Updated input, loss and logging helper dicts. The loss helpers hold
      'top_entity_ids' and 'memory_attention_weights'; the logging helpers are
      returned empty by this method.
    """
    loss_helpers = {}
    logging_helpers = {}

    # Queries for the similarity search are dense projections of the
    # concatenated start and end token encodings of every mention.
    start_index = (mention_batch_positions, mention_start_positions)
    end_index = (mention_batch_positions, mention_end_positions)
    span_encodings = jnp.concatenate(
        (jut.matmul_2d_index_select(encoding, start_index),
         jut.matmul_2d_index_select(encoding, end_index)),
        axis=-1)
    queries = self.query_projector(span_encodings)
    n_queries = queries.shape[0]

    if self.k_top is None:
      # Attention over the whole memory table. Copying the table for every
      # query would be wasteful, so the attention-weighted sum is computed
      # here and handed to the update layer as a single retrieved value with
      # score 1.
      loss_helpers['top_entity_ids'] = jnp.tile(memory_entity_ids,
                                                (n_queries, 1))
      full_scores = jnp.einsum('qd,md->qm', queries, memory_keys)
      # Invalid memory entries get a very low score before the softmax.
      full_scores = full_scores - (1 - memory_mask) * _LARGE_NUMBER
      full_attention = nn.softmax(full_scores, axis=-1)
      loss_helpers['memory_attention_weights'] = full_attention
      pooled_values = jnp.einsum('qm,md->qd', full_attention, memory_values)
      # Shape (n_queries, 1, memory_value_dim): a retrieval set of size one.
      retrieved_values = jnp.expand_dims(pooled_values, axis=1)
      # Pseudo-score of shape (n_queries, 1).
      retrieval_weights = jnp.ones_like(
          retrieved_values, shape=(n_queries, 1))
    else:
      # The approximate top-k similarity layer expects keys arranged by rows.
      keys_by_row = memory_keys.reshape(self.rows, -1, self.memory_key_dim)
      # The top-k layer receives queries without gradient, because backward
      # differentiation over that layer yields inefficient HLO ops. Gradient
      # still reaches the queries through the scores recomputed below.
      detached_queries = jax.lax.stop_gradient(queries)

      # The search yields
      #   top_keys: (queries, k_top, memory_dim)
      #   top_ids: (queries, k_top)
      top_keys, _, top_ids = self.topk_similarity(detached_queries,
                                                  keys_by_row)

      top_ids = top_ids.reshape(n_queries, self.k_top)
      retrieved_values = memory_values[top_ids]
      loss_helpers['top_entity_ids'] = memory_entity_ids[top_ids]

      # Scores are recomputed from the queries with gradient so that the query
      # projector and the rest of the model receive gradient.
      top_scores = jnp.einsum('qd,qkd->qk', queries, top_keys)
      # Retrieved entries that are invalid get a very low score.
      top_scores = top_scores - (1 - memory_mask[top_ids]) * _LARGE_NUMBER

      # Dot product attention over the retrieved memories.
      retrieval_weights = nn.softmax(top_scores, axis=-1)
      loss_helpers['memory_attention_weights'] = retrieval_weights

    updated_encoding = self.update_layer(
        encoded_input=encoding,
        retrieval_values=retrieved_values,
        retrieval_scores=retrieval_weights,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        deterministic=deterministic,
    )

    return updated_encoding, loss_helpers, logging_helpers