def compute_loss_and_prob_from_probs_with_duplicates(
    probs: Array,
    classes: Array,
    targets: Array,
    weights: Array,
) -> Tuple[float, float, float]:
  """Compute weighted loss and avg correct probability given probs and targets.

  Args:
    probs: [batch, length, num_items] float array.
    classes: [batch, length, num_items] class for each item.
    targets: [batch, length] categorical target int array.
    weights: [batch, length].

  Returns:
    Tuple of scalar loss, avg correct probability and normalizing factor.
  """
  probs = probs.astype(jnp.float32)
  weights = weights.astype(jnp.float32)

  correct_mask = (classes == jnp.expand_dims(targets, axis=-1))
  correct_mask = correct_mask.astype(jnp.float32)
  correct_probs = (correct_mask * probs).sum(axis=-1)
  avg_probs = correct_probs * weights

  loss = -jnp.log(correct_probs + _SMALL_NUMBER)
  loss = loss * weights
  return loss.sum(), avg_probs.sum(), weights.sum()
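# Hedged usage sketch (not part of the original module): illustrates that the
# probabilities of all items sharing the target class are summed before taking
# the log. The `_example_*` name is hypothetical; assumes the module-level
# `jnp` and `_SMALL_NUMBER` used throughout this file.
def _example_loss_from_probs_with_duplicates():
  probs = jnp.array([[[0.2, 0.3, 0.5]]])   # [batch=1, length=1, num_items=3]
  classes = jnp.array([[[7, 7, 9]]])       # items 0 and 1 share class 7
  targets = jnp.array([[7]])
  weights = jnp.array([[1.]])
  loss, avg_prob, denom = compute_loss_and_prob_from_probs_with_duplicates(
      probs, classes, targets, weights)
  # avg_prob / denom == 0.5 (= 0.2 + 0.3), loss ~= -log(0.5).
  return loss, avg_prob, denom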
def compute_weighted_cross_entropy(
    scores: Array,
    targets: Array,
    weights: Array,
    inputs_are_prob: Optional[bool] = False,
) -> Tuple[float, float]:
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
    scores: [batch, length, num_classes] float array.
    targets: [batch, length] categorical target integer array.
    weights: [batch, length].
    inputs_are_prob: true if inputs are probabilities rather than logits.

  Returns:
    Tuple of scalar loss and batch denominator.
  """
  scores = scores.astype(jnp.float32)
  targets = targets.astype(jnp.float32)
  weights = weights.astype(jnp.float32)

  if scores.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s scores and %s targets' %
                     (str(scores.shape), str(targets.shape)))
  vocab_size = scores.shape[-1]
  soft_targets = jax.nn.one_hot(targets, vocab_size)

  if inputs_are_prob:
    loss = -jnp.sum(soft_targets * jnp.log(scores + _SMALL_NUMBER), axis=-1)
  else:
    loss = -jnp.sum(soft_targets * jax.nn.log_softmax(scores), axis=-1)

  loss = loss * weights
  normalizing_factor = weights.sum()
  return loss.sum(), normalizing_factor
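# Hedged usage sketch (illustrative only, not from the original code): shows
# the expected shapes and the weight masking. The `_example_*` name is
# hypothetical; assumes the module-level `jax`/`jnp` imports.
def _example_weighted_cross_entropy():
  batch, length, num_classes = 2, 3, 5
  scores = jnp.zeros((batch, length, num_classes))  # uniform logits
  targets = jnp.array([[1, 2, 3], [0, 0, 4]])
  # Mask out the last position of the second sample.
  weights = jnp.array([[1., 1., 1.], [1., 1., 0.]])
  loss_sum, denom = compute_weighted_cross_entropy(scores, targets, weights)
  # With uniform logits, the per-position loss is log(num_classes).
  return loss_sum / denom  # ~= jnp.log(5.)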
def get_mrr(labels: Array, logits: Array) -> Dict[str, Array]:
  """Mean reciprocal rank as in https://www.aclweb.org/anthology/P18-1009.pdf."""
  labels_exists = labels.sum(axis=-1) > 0
  labels_exists = labels_exists.astype(jnp.float32)

  order = jnp.argsort(-logits, axis=-1)
  ranks = jnp.argsort(order, axis=-1)
  mrr_per_sample = 1.0 / (ranks + 1)
  mrr_per_sample = (labels * mrr_per_sample).sum(-1) / (
      labels.sum(axis=-1) + 1e-5)

  return {
      'value': jnp.dot(mrr_per_sample, labels_exists),
      'denominator': labels_exists.sum(),
  }
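# Hedged example (illustrative only, not from the original code): a minimal
# check of `get_mrr` on a single sample. The `_example_*` name is
# hypothetical; assumes the module-level `jnp`.
def _example_get_mrr():
  # One sample, 4 classes; the only relevant label has the 2nd-highest logit,
  # so its reciprocal rank is 1/2.
  labels = jnp.array([[0., 1., 0., 0.]])
  logits = jnp.array([[2.0, 1.0, 0.5, -1.0]])
  mrr = get_mrr(labels, logits)
  return mrr['value'] / mrr['denominator']  # ~= 0.5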
def compute_assignments(
    centroids: Array,
    observations: Array,
    n_splits: int,
) -> Tuple[Array, Array]:
  """Assigns observations to cluster centroids.

  Computes l2 distance between each pair of observation and centroids, and
  assigns observation to closest cluster. Because the array of pairwise
  distances is very large, the cluster assignment is performed chunk by chunk.

  Args:
    centroids: [n_clusters, dim] cluster centroids.
    observations: [n_observations, dim] data points.
    n_splits: split observations into this many chunks.

  Returns:
    assignments: [n_observations] closest cluster for each observation.
    min_dist: [n_observations] distance to closest cluster by observation.
  """
  reshaped_observations = observations.reshape(n_splits, -1,
                                               centroids.shape[-1])

  def compute_split_assignments(split_points):
    dist = l2_distance(split_points, centroids)
    split_assignments = jnp.argmin(dist, axis=-1)
    split_min_dist = jnp.min(dist, axis=-1)
    return split_assignments, split_min_dist

  assignments, min_dist = jax.lax.map(compute_split_assignments,
                                      reshaped_observations)
  assignments = assignments.reshape(-1)
  min_dist = min_dist.reshape(-1)
  return assignments, min_dist
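# Hedged usage sketch (not in the original file): k-means-style assignment of
# random points to centroids. Assumes the module-level `jax`/`jnp` and an
# `l2_distance(points, centroids)` helper defined elsewhere in this codebase
# that returns pairwise distances of shape [n_points, n_clusters]; the
# `_example_*` name is hypothetical.
def _example_compute_assignments():
  key1, key2 = jax.random.split(jax.random.PRNGKey(0))
  centroids = jax.random.normal(key1, (8, 16))       # 8 clusters, dim 16
  observations = jax.random.normal(key2, (64, 16))   # 64 points, dim 16
  # n_splits must evenly divide the number of observations.
  assignments, min_dist = compute_assignments(centroids, observations,
                                              n_splits=4)
  return assignments.shape, min_dist.shape  # ((64,), (64,))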
def get_batch_and_retrievals_entity_overlap(
    mention_target_batch_positions: Array,
    mention_target_ids: Array,
    mention_target_weights: Array,
    memory_text_entities: Array,
    batch_size: int,
) -> Array:
  """Compute the overlap between entities in the batch and in retrievals.

  Args:
    mention_target_batch_positions: [n_target_mentions] position of a mention
      in its batch.
    mention_target_ids: [n_target_mentions] IDs of mentions.
    mention_target_weights: [n_target_mentions] per-mention weight for
      computing loss and metrics.
    memory_text_entities: [n_retrievals, n_memory_text_entities] IDs of
      mentions in the passage where memory is coming from. Note, entities in
      the same retrieval are assumed to be unique.
    batch_size: batch size.

  Returns:
    Array of shape [batch_size, n_retrievals] with the number of overlapping
    unique entity IDs in the batch and in the retrieval results.
  """
  n_target_mentions = mention_target_batch_positions.shape[0]
  n_retrievals = memory_text_entities.shape[0]
  n_memory_text_entities = memory_text_entities.shape[1]

  # Step 1: de-duplicate entities in the batch.
  # [n_target_mentions]
  mention_target_ids = mention_target_ids * mention_target_weights
  # [n_target_mentions]
  batch_ids = mention_utils.mask_duplicate_ids(mention_target_batch_positions,
                                               mention_target_ids)

  # Step 2: compare all entities in the batch against all entities in the
  # retrieval result.
  # [n_retrievals * n_memory_text_entities]
  memory_text_entities = memory_text_entities.reshape(
      [n_retrievals * n_memory_text_entities])
  # [n_target_mentions, n_retrievals * n_memory_text_entities]
  mention_id_in_retrieved_passages = mention_utils.all_compare_without_pad(
      batch_ids, memory_text_entities)
  mention_id_in_retrieved_passages = mention_id_in_retrieved_passages.astype(
      jnp.int32)

  # Step 3: sum up the comparison results by retrieval ID.
  # [n_target_mentions, n_retrievals]
  mention_id_in_retrieved_passages = mention_id_in_retrieved_passages.reshape(
      [n_target_mentions, n_retrievals, n_memory_text_entities]).sum(-1)

  # Step 4: sum up the comparison results by batch position.
  # [batch_size, n_retrievals]
  num_common_ids_between_samples = mention_utils.sum_by_batch_position(
      mention_target_batch_positions, mention_id_in_retrieved_passages,
      batch_size)

  return num_common_ids_between_samples
def __call__(
    self,
    encoded_input: Array,
    retrieval_values: Array,
    retrieval_scores: Array,
    mention_batch_positions: Array,
    mention_start_positions: Array,
    mention_end_positions: Array,
    mention_mask: Array,
    deterministic: bool,
) -> Array:
  # Generate mention values from input representation.
  mention_start_encodings = jut.matmul_2d_index_select(
      encoded_input, (mention_batch_positions, mention_start_positions))
  mention_end_encodings = jut.matmul_2d_index_select(
      encoded_input, (mention_batch_positions, mention_end_positions))
  passage_mention_values = self.value_projection(
      jnp.concatenate((mention_start_encodings, mention_end_encodings),
                      axis=-1))

  k_retrieval = retrieval_scores.shape[-1]
  passage_mention_values = jnp.expand_dims(passage_mention_values, axis=1)
  passage_mention_values = jnp.tile(passage_mention_values,
                                    (1, k_retrieval, 1))

  # Generate concatenated values of shape [mentions, k, 2 * retrieval_dim].
  concat_values = jnp.concatenate((passage_mention_values, retrieval_values),
                                  axis=-1)

  # MLP over concatenation of mention value and individual retrieved value.
  concat_values = nn.gelu(self.concat_mlp(concat_values))
  concat_values = self.concat_dense(concat_values)
  concat_values = self.concat_dropout(concat_values, deterministic)

  # Additional MLP layers.
  for concat_mlp_layer in self.additional_concat_mlp_layers:
    concat_values = concat_mlp_layer(concat_values, deterministic)

  # Pool concatenated values over retrievals, weighted by retrieval scores.
  pooled_values = jnp.einsum('qk,qkd->qd', retrieval_scores, concat_values)

  # MLP layers applied to pooled retrieval values.
  for pooled_mlp_layer in self.pooled_mlp_layers:
    pooled_values = pooled_mlp_layer(pooled_values, deterministic)
  pooled_values = pooled_values * mention_mask.reshape(-1, 1)

  encoded_input = jut.matmul_2d_index_add(
      encoded_input, (mention_batch_positions, mention_start_positions),
      pooled_values)
  encoded_input = self.layer_norm(encoded_input)

  return encoded_input
def compute_weighted_accuracy(scores: Array, targets: Array,
                              weights: Array) -> Tuple[float, float]:
  """Compute weighted accuracy for log probs and targets.

  Args:
    scores: [batch, length, num_classes] float array.
    targets: [batch, length] categorical targets int array.
    weights: [batch, length].

  Returns:
    Tuple of scalar accuracy and batch normalizing factor.
  """
  if scores.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s scores and %s targets' %
                     (str(scores.shape), str(targets.shape)))
  acc = jnp.equal(jnp.argmax(scores, axis=-1), targets)
  acc = acc * weights
  normalizing_factor = weights.sum()
  return acc.sum(), normalizing_factor
def __call__(
    self,
    encoded_input: Array,
    retrieval_values: Array,
    retrieval_scores: Array,
    mention_batch_positions: Array,
    mention_start_positions: Array,
    mention_end_positions: Array,
    mention_mask: Array,
    deterministic: bool,
) -> Array:
  # Attention-weighted sum of retrieved values, projected back to the
  # encoder's hidden dimension and added at mention start positions.
  weighted_values = jnp.einsum('qk,qkd->qd', retrieval_scores,
                               retrieval_values)
  projected_values = self.retrieval_projector(weighted_values)
  projected_values = projected_values * mention_mask.reshape(-1, 1)
  encoded_input = jut.matmul_2d_index_add(
      encoded_input, (mention_batch_positions, mention_start_positions),
      projected_values)
  encoded_input = self.layer_norm(encoded_input)
  return encoded_input
def same_entity_set_retrieval_loss(
    mention_target_batch_positions: Array,
    mention_target_ids: Array,
    mention_target_weights: Array,
    mention_batch_positions: Array,
    mention_mask: Array,
    memory_text_entities: Array,
    memory_attention_weights: Array,
    memory_mask: Array,
    batch_size: int,
    same_entity_set_target_threshold: int,
):
  """Computes same-entity-set-retrieval loss.

  We want to maximize attention scores received by memories whose passages
  have large enough entity overlap (at least
  `same_entity_set_target_threshold` unique entities in common) with the query
  mention's passage.

  User specifies which entity IDs exist in the current batch via the
  arguments: `mention_target_batch_positions`, `mention_target_ids` and
  `mention_target_weights`. Note that these correspond to linked mentions.
  While we retrieve memories for all mentions, we don't know entity IDs for
  non-linked mentions and they are not needed for computing entity overlap.

  Formally, let E(j) be the set of entity IDs in the j-th sample in the batch.
  The j-th retrieval for the i-th mention is "correct" if and only if
  IntersectionSize(E(mention_batch_positions[i]), memory_text_entities[i, j])
  is at least `same_entity_set_target_threshold`. The loss is first computed
  for every mention separately and is the negative log of the sum of
  memory_attention_weights[i, j] over "correct" retrievals j for the i-th
  mention. The final loss is the average of losses over those mentions which
  have at least one "correct" and at least one "incorrect" retrieval.

  Args:
    mention_target_batch_positions: [n_target_mentions] position of a linked
      (target) mention in its batch.
    mention_target_ids: [n_target_mentions] IDs of mentions.
    mention_target_weights: [n_target_mentions] per-mention weight for linked
      (target) mentions indicating whether a mention is padding or not.
    mention_batch_positions: [n_mentions] position of a mention in its batch.
    mention_mask: [n_mentions] whether a mention is padding or not.
    memory_text_entities: [n_mentions, k_top, n_memory_text_entities] IDs of
      mentions in the passage where memory is coming from.
    memory_attention_weights: [n_mentions, k_top] attention weights for the
      retrieval results.
    memory_mask: [n_mentions, k_top] which retrievals to use and which to
      ignore in the loss computations. Typical usage is to ignore "disallowed"
      or same-passage retrievals.
    batch_size: batch size.
    same_entity_set_target_threshold: how many entities the memory's passage
      and the mention's passage need to have in common in order to treat the
      retrieval result as positive. If it's equal to 2 this loss becomes the
      `same-mtb-retrieval` loss.

  Returns:
    Tuple of scalar loss, avg correct probability and normalizing factor.
  """
  n_mentions = memory_text_entities.shape[0]
  k_top = memory_text_entities.shape[1]

  # [batch_size, n_mentions * k_top]
  num_common_ids_between_samples = get_batch_and_retrievals_entity_overlap(
      mention_target_batch_positions=mention_target_batch_positions,
      mention_target_ids=mention_target_ids,
      mention_target_weights=mention_target_weights,
      memory_text_entities=memory_text_entities.reshape(
          [n_mentions * k_top, -1]),
      batch_size=batch_size,
  )

  # [batch_size, n_mentions, k_top]
  num_common_ids_between_samples = num_common_ids_between_samples.reshape(
      [batch_size, n_mentions, k_top])

  # [batch_size, n_mentions]
  position2item = (jnp.expand_dims(jnp.arange(batch_size), 1) ==
                   mention_batch_positions)
  position2item = position2item.astype(jnp.int32)

  # [n_mentions, k_top]
  num_common_ids_between_mentions = jnp.einsum(
      'bm,bmk->mk', position2item, num_common_ids_between_samples)

  # Compute which retrievals have enough common elements with the
  # passage in the batch (at least `same_entity_set_target_threshold`) and,
  # therefore, should be marked as "correct" retrievals.
  # [n_mentions, k_top]
  enough_common_elements = (
      num_common_ids_between_mentions >= same_entity_set_target_threshold)
  enough_common_elements = enough_common_elements.astype(jnp.int32)
  correct_retrievals_mask = enough_common_elements * memory_mask
  incorrect_retrievals_mask = (1 - enough_common_elements) * memory_mask

  # Only include mentions that have at least one "correct" and at least one
  # "incorrect" retrieval.
  # [n_mentions]
  correct_retrieval_exists = correct_retrievals_mask.sum(-1) > 0
  incorrect_retrieval_exists = incorrect_retrievals_mask.sum(-1) > 0
  loss_mask = jnp.logical_and(correct_retrieval_exists,
                              incorrect_retrieval_exists)
  loss_mask = loss_mask.astype(mention_mask.dtype) * mention_mask
  loss_mask = loss_mask.astype(jnp.float32)

  # [n_mentions, k_top]
  correct_retrievals_mask = correct_retrievals_mask.astype(jnp.float32)

  # Compute loss and metrics.
  # [n_mentions]
  correct_probs = jnp.einsum('mk,mk->m', correct_retrievals_mask,
                             memory_attention_weights)
  avg_probs = correct_probs * loss_mask
  loss = -jnp.log(correct_probs + _SMALL_NUMBER)
  loss = loss * loss_mask

  return loss.sum(), avg_probs.sum(), loss_mask.sum()
def entity_linking_loss(
    mention_encodings: Array, entity_embeddings: Array,
    mention_target_ids: Array, mention_target_weights: Array,
    mode: str) -> Tuple[float, Dict[str, float], Tuple[Array, Array]]:
  """Compute entity linking loss.

  Args:
    mention_encodings: [n_mentions, hidden_size] mention encodings to be used
      for computing the loss.
    entity_embeddings: [n_entities, hidden_size] entity embeddings table.
    mention_target_ids: [n_mentions] IDs of mentions.
    mention_target_weights: [n_mentions] per-mention weight for computing loss
      and metrics.
    mode: how to compute the scores -- using dot product ('dot'), dot product
      divided by the square root of the hidden dim ('dot_sqrt') or cosine
      similarity ('cos').

  Returns:
    Loss, a dictionary with metrics values, per sample information (a tuple of
    accuracy per mention and weight per mention).
  """
  scores = jnp.einsum('qd,ed->qe', mention_encodings, entity_embeddings)
  scores = scores.astype(jnp.float32)

  mention_encodings_norm = jnp.linalg.norm(mention_encodings, axis=-1)
  entity_embeddings_norm = jnp.linalg.norm(entity_embeddings, axis=-1)

  # The cosine similarity is computed as dot product divided by norms of
  # both vectors.
  cos_scores = scores
  cos_scores /= (_SMALL_NUMBER + jnp.expand_dims(mention_encodings_norm, 1))
  cos_scores /= (_SMALL_NUMBER + jnp.expand_dims(entity_embeddings_norm, 0))

  if mode == 'dot':
    pass
  elif mode == 'dot_sqrt':
    hidden_dim = mention_encodings.shape[1]
    scores /= jnp.sqrt(hidden_dim)
  elif mode == 'cos':
    scores = cos_scores
  else:
    raise ValueError('Unknown entity linking mode: ' + mode)

  mention_target_weights = mention_target_weights.astype(jnp.float32)

  loss, _ = metric_utils.compute_weighted_cross_entropy(
      scores, mention_target_ids, mention_target_weights,
      inputs_are_prob=False)

  acc_per_mention = jnp.equal(jnp.argmax(scores, axis=-1), mention_target_ids)
  acc_per_mention = acc_per_mention * mention_target_weights

  n_mentions = mention_target_ids.shape[0]
  cos_per_mention = cos_scores[jnp.arange(n_mentions), mention_target_ids]
  cos_per_mention = cos_per_mention * mention_target_weights

  metrics = {
      'loss': loss,
      'acc': acc_per_mention.sum(),
      'cos_sim': cos_per_mention.sum(),
      'denominator': mention_target_weights.sum(),
  }
  return loss, metrics, (acc_per_mention, mention_target_weights)
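# Hedged usage sketch (hypothetical, not part of the original module): calling
# `entity_linking_loss` with random encodings in 'dot' mode. Assumes the
# module-level `jax`/`jnp` and `metric_utils` imports used by this function;
# the `_example_*` name is hypothetical.
def _example_entity_linking_loss():
  key1, key2 = jax.random.split(jax.random.PRNGKey(0))
  mention_encodings = jax.random.normal(key1, (4, 32))    # 4 mentions
  entity_embeddings = jax.random.normal(key2, (100, 32))  # 100 entities
  mention_target_ids = jnp.array([3, 17, 0, 0])
  mention_target_weights = jnp.array([1., 1., 1., 0.])    # last one is padding
  loss, metrics, _ = entity_linking_loss(mention_encodings, entity_embeddings,
                                         mention_target_ids,
                                         mention_target_weights, mode='dot')
  return loss, metrics['acc'] / metrics['denominator']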
def __call__(
    self,
    queries: Array,
    table: Array,
    prototypes: Array,
) -> Tuple[Array, Array, Array]:
  """Perform approximate top-k similarity search over vector table.

  Args:
    queries: [n_queries, vector_dim].
    table: [n_clusters, rows, values per row, vector_dim] vector table.
    prototypes: [n_clusters, vector_dim] representative vectors for clusters.

  Returns:
    Top-k vectors, scores and ids.
  """
  n_queries = queries.shape[0]
  queries_per_split = n_queries // self.splits
  rows_per_cluster = table.shape[1]
  values_per_row = table.shape[2]
  values_per_cluster = rows_per_cluster * values_per_row
  table_size = values_per_cluster * table.shape[0]
  vector_dim = queries.shape[1]
  assert table.shape[-1] == vector_dim

  # Split queries to reduce size of selected clusters and save memory.
  queries = queries.reshape(self.splits, queries_per_split, vector_dim)

  def split_top_k(split_queries: Array) -> Tuple[Array, Array, Array]:
    # Find most similar clusters.
    prototype_scores = jnp.einsum('qd,pd->qp', split_queries, prototypes)
    top_indices = jax.lax.top_k(prototype_scores, self.n_search)[1]
    # Perform approximate top-k similarity search over most similar clusters.
    selected_data = table[top_indices]
    split_scores = jnp.einsum('qd,qcrvd->qcrv', split_queries, selected_data)

    # Find highest scoring vector for each row.
    top_id_by_row = jnp.argmax(split_scores, axis=-1)
    top_score_by_row = jnp.max(split_scores, axis=-1)
    top_id_by_row = top_id_by_row.reshape(queries_per_split,
                                          self.n_search * rows_per_cluster)
    top_score_by_row = top_score_by_row.reshape(
        queries_per_split, self.n_search * rows_per_cluster)

    # Take k highest scores among all rows.
    top_row_idx = jnp.argsort(
        top_score_by_row, axis=-1)[:, :-self.k_top - 1:-1]

    # Sub-select best indices for k best rows.
    ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

    # Gather highest scoring vectors for k best rows.
    query_index = jnp.tile(
        jnp.arange(queries_per_split).reshape(-1, 1), [1, self.k_top])
    top_cluster_idx, top_cluster_row_idx = jnp.divmod(top_row_idx,
                                                      rows_per_cluster)
    split_topk_values = selected_data[query_index, top_cluster_idx,
                                      top_cluster_row_idx, ids_by_topk_row]

    row_offset = jnp.mod(
        jnp.arange(0, self.n_search * values_per_cluster, values_per_row),
        values_per_cluster)
    cluster_offset = jnp.arange(0, table_size, values_per_cluster)

    # Convert row indices to indices into flattened table.
    top_table_id_by_row = top_id_by_row + row_offset.reshape(
        1, -1) + cluster_offset[top_indices].repeat(rows_per_cluster, axis=-1)
    # Get best ids into flattened table.
    split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)

    split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

    return split_topk_values, split_topk_scores, split_topk_ids

  # Perform similarity over each chunk of queries sequentially
  # (not in parallel), so that only one score tensor is in memory at a time.
  topk_values, topk_scores, topk_ids = jax.lax.map(split_top_k, queries)

  topk_values = topk_values.reshape(n_queries, self.k_top, -1)
  topk_scores = topk_scores.reshape(n_queries, self.k_top)
  topk_ids = topk_ids.reshape(n_queries, self.k_top)

  return topk_values, topk_scores, topk_ids
def compute_cross_entropy_loss_with_positives_and_negatives_masks(
    scores: Array,
    positives: Array,
    negatives: Array,
    weights: Optional[Array] = None,
) -> Tuple[float, Dict[str, float], Tuple[Array, Array]]:
  """Compute (weighted) cross-entropy loss and accuracy-related metrics.

  The function computes cross entropy loss when there are potentially multiple
  positive classes per sample, multiple negative classes, and the rest are
  neutral. In this case, the loss per sample is the average of cross entropy
  losses computed by considering each positive class and all negative classes.
  Neutral classes are ignored.

  Arguments `positives` and `negatives` are boolean matrices that specify
  which class is considered positive or negative for every sample:

  `positives[i, j]` is True <=> class j is considered positive for sample i
  `negatives[i, j]` is True <=> class j is considered negative for sample i

  The loss is computed in 3 stages:
  (1) For every sample i and positive class j we compute cross-entropy loss
      using j as a positive class and all negative classes for i as negatives.
  (2) For every sample i the total loss is the average of losses per each of
      its positive classes.
  (3) Total loss is the sum of losses per each sample.

  The loss only includes samples which have at least one positive and one
  negative class. Users can limit this even further by providing a custom
  `weights`.

  Args:
    scores: [batch_size, num_classes] scores or logits.
    positives: [batch_size, num_classes] 0-1 mask for which classes are
      positive.
    negatives: [batch_size, num_classes] 0-1 mask for which classes are
      negative.
    weights: [batch_size] 0-1 masks indicating whether the loss should be
      computed for the corresponding item in the batch.

  Returns:
    A tuple of scalar loss, a dictionary with metrics, per sample information
    (a tuple of average positive probability per sample and weight per
    sample).
  """
  at_least_one_positive_and_negative = jnp.logical_and(
      positives.sum(-1) > 0,
      negatives.sum(-1) > 0)

  if weights is None:
    weights = at_least_one_positive_and_negative
  else:
    weights = jnp.logical_and(weights, at_least_one_positive_and_negative)

  scores = scores.astype(jnp.float32)
  positives = positives.astype(jnp.float32)
  negatives = negatives.astype(jnp.float32)
  weights = weights.astype(jnp.float32)

  # For simplicity, we ignore the first batch dimension in the equations below
  # and assume that the loss is computed for a single sample.
  # Let p_1, ..., p_N be scores of positive classes
  # and n_1, ..., n_M be scores of negative classes.
  # In this case the loss is
  #   sum_{i=1..N} -log softmax([p_i, n_1, ..., n_M])_1.
  # It's too computationally expensive to compute it naively.
  # We implement the loss in the following way.

  # (1) Compute S, the negatives part of the softmax denominator.
  # In other words, exp(S) = sum_{j=1..M} exp(n_j).
  negative_scores = scores * negatives - _BIG_NUMBER * (1.0 - negatives)
  negative_scores_log_sum_exp = jax.nn.logsumexp(
      negative_scores, axis=-1, keepdims=True)

  # (2) Now the loss per positive class i is just
  #   -log(exp(p_i) / (exp(p_i) + exp(S))) = -log(1 / (1 + exp(-(p_i - S))))
  #   = -log sigmoid(p_i - S).
  scores_minus_negatives = scores - negative_scores_log_sum_exp
  positives_weight = (positives.sum(axis=-1) + _SMALL_NUMBER)
  per_positive_loss = -jax.nn.log_sigmoid(scores_minus_negatives)

  # (3) Compute average loss over all positive classes.
  loss_per_sample = (per_positive_loss * positives).sum(axis=-1)
  loss_per_sample /= positives_weight
  loss_per_sample *= weights

  # (4) Compute sum of losses over all samples.
  loss = loss_per_sample.sum()

  # Now we need to compute the average accuracy.
  # First, compute the max score of negative classes per sample.
  # A positive class needs to have a higher score in order to get predicted.
  max_negative_scores = negative_scores.max(axis=-1, keepdims=True)
  # Second, a prediction for a pair of a sample and its positive class is
  # correct if the score of the positive class is larger than the scores of
  # all corresponding negative classes. In other words, the score of the
  # positive class needs to be larger than `max_negative_scores`.
  correct_prediction = (scores > max_negative_scores).astype(jnp.float32)
  # Take average over all positive classes per sample.
  correct_prediction = (correct_prediction * positives).sum(axis=-1)
  correct_prediction /= positives_weight
  # Mask out samples with 0 weight.
  correct_prediction = correct_prediction * weights

  metrics = {
      'loss': loss,
      'acc': correct_prediction.sum(),
      'denominator': weights.sum(),
  }
  return loss, metrics, (correct_prediction, weights)
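# Hedged numeric sketch (illustrative, not from the original code): checks
# that the logsumexp/log-sigmoid formulation above matches the naive
# per-positive softmax loss on a tiny example. The `_example_*` name is
# hypothetical; assumes the module-level `jax`/`jnp`.
def _example_positives_negatives_loss():
  scores = jnp.array([[2.0, 0.5, -1.0, 3.0]])
  positives = jnp.array([[True, False, False, False]])
  negatives = jnp.array([[False, True, True, False]])  # class 3 is neutral
  loss, _, _ = compute_cross_entropy_loss_with_positives_and_negatives_masks(
      scores, positives, negatives)
  # Naive computation for the single positive class p = 2.0 against
  # negatives n = [0.5, -1.0]: -log softmax([p, n_1, n_2])_1.
  naive = -jax.nn.log_softmax(jnp.array([2.0, 0.5, -1.0]))[0]
  return loss, naive  # approximately equal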
def __call__(
    self,
    queries: Array,
    table: Array,
) -> Tuple[Array, Array, Array]:
  """Perform approximate top-k similarity search over vector table.

  Args:
    queries: [n_queries, vector_dim].
    table: [rows, values per row, vector_dim] vector table. The number of rows
      in the table governs the recall vs speed of the top-k similarity search.
      Search is performed by taking max over each row, and then top-k between
      rows. Distributing the same values over more rows leads to higher recall
      but slower search.

  Returns:
    Top-k vectors, scores and ids.
  """
  n_queries = queries.shape[0]
  queries_per_split = n_queries // self.splits
  scores_per_row = table.shape[1]
  table_size = scores_per_row * table.shape[0]
  vector_dim = queries.shape[1]
  assert table.shape[-1] == vector_dim

  # Split queries to reduce size of intermediate score tensor and save memory.
  queries = queries.reshape(self.splits, queries_per_split, vector_dim)

  def split_top_k(split_queries):
    split_scores = jnp.einsum('qd,rvd->qrv', split_queries, table)

    # Find highest scoring vector for each row.
    top_id_by_row = jnp.argmax(split_scores, axis=-1)
    top_score_by_row = jnp.max(split_scores, axis=-1)

    # Take k highest scores among all rows.
    top_row_idx = jnp.argsort(
        top_score_by_row, axis=-1)[:, :-self.k_top - 1:-1]

    # Sub-select best indices for k best rows.
    ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

    # Gather highest scoring vectors for k best rows.
    split_topk_values = table[top_row_idx, ids_by_topk_row]

    # Convert row indices to indices into flattened table.
    top_table_id_by_row = top_id_by_row + jnp.arange(0, table_size,
                                                     scores_per_row)
    # Get best ids into flattened table.
    split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)

    split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

    return split_topk_values, split_topk_scores, split_topk_ids

  # Perform similarity over each chunk of queries sequentially
  # (not in parallel), so that only one score tensor is in memory at a time.
  topk_values, topk_scores, topk_ids = jax.lax.map(split_top_k, queries)

  topk_values = topk_values.reshape(n_queries, self.k_top, -1)
  topk_scores = topk_scores.reshape(n_queries, self.k_top)
  topk_ids = topk_ids.reshape(n_queries, self.k_top)

  return topk_values, topk_scores, topk_ids
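# Hedged conceptual sketch (not part of the original module): the core of the
# row-wise approximate top-k above, written for a single query in plain jnp
# without the module machinery. It scores every value, keeps the best value
# per row, then takes top-k across rows, so at most one value per row can be
# returned. The function name is hypothetical; assumes the module-level `jnp`.
def _approx_top_k_single_query(query, table, k_top):
  # table: [rows, values_per_row, dim]; query: [dim].
  scores = jnp.einsum('d,rvd->rv', query, table)
  top_id_by_row = jnp.argmax(scores, axis=-1)        # [rows]
  top_score_by_row = jnp.max(scores, axis=-1)        # [rows]
  top_rows = jnp.argsort(-top_score_by_row)[:k_top]  # best k rows
  top_ids = top_id_by_row[top_rows]
  top_values = table[top_rows, top_ids]              # [k_top, dim]
  top_scores = top_score_by_row[top_rows]
  # Flattened table index of each returned value.
  flat_ids = top_rows * table.shape[1] + top_ids
  return top_values, top_scores, flat_ids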
def __call__(
    self,
    encoding: Array,
    mention_batch_positions: Array,
    mention_start_positions: Array,
    mention_end_positions: Array,
    mention_mask: Array,
    memory_keys: Array,
    memory_values: Array,
    memory_mask: Array,
    memory_entity_ids: Array,
    deterministic: bool,
) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
  """Perform attention update over memory table.

  Args:
    encoding: [batch_size, n_tokens, hidden_size] input representation.
    mention_batch_positions: [n_mentions] mention sample position in batch.
    mention_start_positions: [n_mentions] mention start position in input.
    mention_end_positions: [n_mentions] mention end position in input.
    mention_mask: [n_mentions] attention mask to prevent updates from padding.
    memory_keys: [memory_size, memory_key_dim] mention memory keys.
    memory_values: [memory_size, memory_value_dim] mention memory values.
    memory_mask: [memory_size] mask for valid mentions in memory.
    memory_entity_ids: [memory_size] mention memory entity ids.
    deterministic: don't apply dropout if true.

  Returns:
    Updated input, loss and logging helper dicts.
  """
  loss_helpers, logging_helpers = {}, {}

  # We generate mention representations to use as queries for similarity
  # search by concatenating start and end tokens for each mention and
  # projecting the concatenation with a dense layer.
  mention_start_encodings = jut.matmul_2d_index_select(
      encoding, (mention_batch_positions, mention_start_positions))
  mention_end_encodings = jut.matmul_2d_index_select(
      encoding, (mention_batch_positions, mention_end_positions))

  queries = self.query_projector(
      jnp.concatenate((mention_start_encodings, mention_end_encodings),
                      axis=-1))
  n_queries = queries.shape[0]

  # For attention over entire memory table, we do not want to duplicate the
  # entire memory table for each query. Instead, we perform an
  # attention-weighted sum to produce a single value. We then feed this value
  # to the update layer as a set of retrieved values of size 1, with score 1.
  if self.k_top is None:
    loss_helpers['top_entity_ids'] = jnp.tile(memory_entity_ids,
                                              (n_queries, 1))
    scores = jnp.einsum('qd,md->qm', queries, memory_keys)
    scores = scores - (1 - memory_mask) * _LARGE_NUMBER
    true_attention_weights = nn.softmax(scores, axis=-1)
    loss_helpers['memory_attention_weights'] = true_attention_weights
    top_values = jnp.einsum('qm,md->qd', true_attention_weights,
                            memory_values)
    # Expand value as though it were a set of retrieved values for each query.
    # Shape (n_queries, 1, memory_value_dim).
    top_values = jnp.expand_dims(top_values, axis=1)
    # Generate pseudo-score of shape (n_queries, 1).
    attention_weights = jnp.ones_like(top_values, shape=(n_queries, 1))
  else:
    # Reshape memory keys for use in approximate top-k similarity layer.
    memory_keys = memory_keys.reshape(self.rows, -1, self.memory_key_dim)
    # We generate a version of the queries with stop gradient to use as input
    # to the topk similarity layer. We actually do want gradient to flow to
    # the queries, but backward differentiation over the topk layer yields
    # inefficient HLO ops. Instead we use queries with gradient to recompute
    # attention scores later.
    queries_sg = jax.lax.stop_gradient(queries)

    # Perform top-k similarity search over queries, yielding
    #   top_keys: (n_queries, k_top, memory_key_dim)
    #   top_ids: (n_queries, k_top)
    top_keys, _, top_ids = self.topk_similarity(queries_sg, memory_keys)

    top_ids = top_ids.reshape(n_queries, self.k_top)
    top_values = memory_values[top_ids]
    loss_helpers['top_entity_ids'] = memory_entity_ids[top_ids]

    # We re-compute top scores using the queries with gradient (wg) to make
    # sure the query projector and the rest of the model receives gradient.
    top_scores_wg = jnp.einsum('qd,qkd->qk', queries, top_keys)
    top_mask = memory_mask[top_ids]
    top_scores_wg = top_scores_wg - (1 - top_mask) * _LARGE_NUMBER

    # We perform dot product attention using retrieved memory vectors as key,
    # dense projection of retrieved vectors as value, and mention
    # representations as query.
    attention_weights = nn.softmax(top_scores_wg, axis=-1)
    loss_helpers['memory_attention_weights'] = attention_weights

  encoding = self.update_layer(
      encoded_input=encoding,
      retrieval_values=top_values,
      retrieval_scores=attention_weights,
      mention_batch_positions=mention_batch_positions,
      mention_start_positions=mention_start_positions,
      mention_end_positions=mention_end_positions,
      mention_mask=mention_mask,
      deterministic=deterministic,
  )

  return encoding, loss_helpers, logging_helpers