Example No. 1
        def split_top_k(split_queries):
            split_scores = jnp.einsum('qd,rvd->qrv', split_queries, table)

            # Find highest scoring vector for each row.
            top_id_by_row = jnp.argmax(split_scores, axis=-1)
            top_score_by_row = jnp.max(split_scores, axis=-1)

            # Take k highest scores among all rows.
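            # (argsort is ascending, so the reversed slice below keeps the
            # k_top largest scores, in descending order.)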
            top_row_idx = jnp.argsort(top_score_by_row,
                                      axis=-1)[:, :-self.k_top - 1:-1]

            # Sub-select best indices for k best rows.
            ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

            # Gather highest scoring vectors for k best rows.
            split_topk_values = table[top_row_idx, ids_by_topk_row]

            # Convert row indices to indices into flattened table.
            top_table_id_by_row = top_id_by_row + jnp.arange(
                0, table_size, scores_per_row)
            # Get best ids into flattened table.
            split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)

            split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

            return split_topk_values, split_topk_scores, split_topk_ids
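
All of these examples lean on `jut.matmul_slice`. Judging by the test in
Example No. 4, which checks it against plain integer indexing, it behaves as a
gather along the sequence axis; a minimal sketch (an assumption, not the
library's code) writes that gather as a one-hot einsum, which tends to lower to
matmul-friendly ops on TPU:

    import jax.numpy as jnp

    def matmul_slice_sketch(array, indices):
        """Gather `array` entries at `indices` via a one-hot einsum.

        Covers the shapes exercised in Example No. 4:
          [seq_len, dim] with [index_len]           -> [index_len, dim]
          [bsz, seq_len] with [bsz, index_len]      -> [bsz, index_len]
          [bsz, seq_len, dim] with [bsz, index_len] -> [bsz, index_len, dim]
        """
        if indices.ndim == 1:
            one_hot = indices[:, None] == jnp.arange(array.shape[0])
            return jnp.einsum('is,s...->i...', one_hot.astype(array.dtype),
                              array)
        one_hot = indices[:, :, None] == jnp.arange(array.shape[1])
        return jnp.einsum('bis,bs...->bi...', one_hot.astype(array.dtype),
                          array)

On integer inputs the cast keeps values exact, which matches the `jnp.allclose`
checks in Example No. 4.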
Example No. 2
        def split_top_k(split_queries: Array) -> Tuple[Array, Array, Array]:
            # Find the most similar clusters.
            prototype_scores = jnp.einsum('qd,pd->qp', split_queries,
                                          prototypes)
            top_indices = jax.lax.top_k(prototype_scores, self.n_search)[1]
            # Perform approximate top-k similarity search over most similar clusters.
            selected_data = table[top_indices]
            split_scores = jnp.einsum('qd,qcrvd->qcrv', split_queries,
                                      selected_data)

            # Find highest scoring vector for each row.
            top_id_by_row = jnp.argmax(split_scores, axis=-1)
            top_score_by_row = jnp.max(split_scores, axis=-1)

            top_id_by_row = top_id_by_row.reshape(
                queries_per_split, self.n_search * rows_per_cluster)
            top_score_by_row = top_score_by_row.reshape(
                queries_per_split, self.n_search * rows_per_cluster)

            # Take k highest scores among all rows.
            top_row_idx = jnp.argsort(top_score_by_row,
                                      axis=-1)[:, :-self.k_top - 1:-1]

            # Sub-select best indices for k best rows.
            ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

            # Gather highest scoring vectors for k best rows.
            query_index = jnp.tile(
                jnp.arange(queries_per_split).reshape(-1, 1), [1, self.k_top])
            top_cluster_idx, top_cluster_row_idx = jnp.divmod(
                top_row_idx, rows_per_cluster)
            split_topk_values = selected_data[query_index, top_cluster_idx,
                                              top_cluster_row_idx,
                                              ids_by_topk_row]

            row_offset = jnp.mod(
                jnp.arange(0, self.n_search * values_per_cluster,
                           values_per_row), values_per_cluster)
            cluster_offset = jnp.arange(0, table_size, values_per_cluster)

            # Convert row indices to indices into flattened table.
            top_table_id_by_row = top_id_by_row + row_offset.reshape(
                1, -1) + cluster_offset[top_indices].repeat(rows_per_cluster,
                                                            axis=-1)
            # Get best ids into flattened table.
            split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)

            split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

            return split_topk_values, split_topk_scores, split_topk_ids
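
The `jnp.divmod` above simply undoes the earlier flattening of the
`(cluster, row)` grid into a single row index. A toy check with hypothetical
sizes:

    import jax.numpy as jnp

    rows_per_cluster = 4
    top_row_idx = jnp.array([[0, 5, 11]])  # flat row indices for one query
    cluster_idx, row_idx = jnp.divmod(top_row_idx, rows_per_cluster)
    # cluster_idx == [[0, 1, 2]], row_idx == [[0, 1, 3]], and
    # cluster_idx * rows_per_cluster + row_idx recovers top_row_idx.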
Example No. 3
    def __call__(
        self,
        encoded_input: Array,
        mlm_target_positions: Array,
        shared_embedding: Array,
    ) -> Array:
        """Perform masked language modeling scoring.

    Args:
      encoded_input: [bsz, n_tokens, hidden_size].
      mlm_target_positions: [bsz, max_mlm_targets] positions of mlm targets in
        passage.
      shared_embedding: [vocab_size, hidden_size] word embedding array, shared
        with initial embedding.

    Returns:
      Array of masked language modeling logits.
    """

        target_encodings = jut.matmul_slice(encoded_input,
                                            mlm_target_positions)
        target_encodings = self.dense(target_encodings)
        target_encodings = nn.gelu(target_encodings)
        target_encodings = self.layer_norm(target_encodings)

        mlm_logits = self.embedding_dense.apply(
            {'params': {
                'kernel': shared_embedding.T
            }}, target_encodings)
        mlm_logits = mlm_logits + self.bias

        return mlm_logits
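
The logits above tie the output projection to the input embedding by handing a
transposed copy of `shared_embedding` to a Flax `Dense` module at call time. A
self-contained sketch of that pattern (sizes and module construction here are
illustrative, not the original setup):

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    vocab_size, hidden_size = 32, 8
    shared_embedding = jax.random.normal(
        jax.random.PRNGKey(0), (vocab_size, hidden_size))
    # Dense expects a kernel of shape [in_features, out_features], hence the
    # transpose of the [vocab_size, hidden_size] embedding table.
    embedding_dense = nn.Dense(features=vocab_size, use_bias=False)

    target_encodings = jnp.ones((2, 3, hidden_size))
    mlm_logits = embedding_dense.apply(
        {'params': {'kernel': shared_embedding.T}}, target_encodings)
    # mlm_logits.shape == (2, 3, vocab_size)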
Example No. 4
    def test_slice_values_int(self, bsz, seq_len, index_len, dim):
        # no batch dim
        array = np.random.randint(_MAX_INT_VALUE, size=(seq_len, dim))
        indices = np.random.randint(seq_len, size=(index_len,))
        matmul_slice = jut.matmul_slice(array, indices)
        vmap_slice = array[indices]
        self.assertTrue(jnp.allclose(matmul_slice, vmap_slice))

        # 2d array
        array = np.random.randint(_MAX_INT_VALUE, size=(bsz, seq_len))
        indices = np.random.randint(seq_len, size=(bsz, index_len))
        matmul_slice = jut.matmul_slice(array, indices)
        vmap_slice = jut.vmap_slice(array, indices)
        self.assertTrue(jnp.allclose(matmul_slice, vmap_slice))

        # 3d array
        array = np.random.randint(_MAX_INT_VALUE, size=(bsz, seq_len, dim))
        indices = np.random.randint(seq_len, size=(bsz, index_len))
        matmul_slice = jut.matmul_slice(array, indices)
        vmap_slice = jut.vmap_slice(array, indices)
        self.assertTrue(jnp.allclose(matmul_slice, vmap_slice))
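
`jut.vmap_slice` serves as the reference implementation in this test; a
plausible sketch of it (an assumption, not the library's code) is plain
indexing mapped over the batch axis:

    import jax
    import jax.numpy as jnp

    def vmap_slice_sketch(array, indices):
        # array: [bsz, seq_len, ...], indices: [bsz, index_len].
        return jax.vmap(lambda a, i: a[i])(array, indices)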
Example No. 5
    def __call__(self, batch: Dict[str, Array], deterministic: bool):
        _, loss_helpers, logging_helpers = self.encoder.forward(
            batch, deterministic)
        mention_encodings = loss_helpers[self.mention_encodings_feature]

        subject_mention_encodings = jut.matmul_slice(
            mention_encodings, batch['mention_subject_indices'])

        object_mention_encodings = jut.matmul_slice(
            mention_encodings, batch['mention_object_indices'])

        relation_encodings = jnp.concatenate(
            [subject_mention_encodings, object_mention_encodings], -1)

        for mlp_layer in self.classification_mlp_layers:
            relation_encodings = mlp_layer(relation_encodings, deterministic)

        classifier_logits = self.linear_classifier(relation_encodings)
        loss_helpers['classifier_logits'] = classifier_logits

        return loss_helpers, logging_helpers
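
For intuition, the pairing above amounts to two gathers from a shared mention
table followed by a feature-axis concatenation. A toy shape check with
hypothetical sizes:

    import jax.numpy as jnp

    n_mentions, dim = 6, 4
    mention_encodings = jnp.ones((n_mentions, dim))
    subject = mention_encodings[jnp.array([0, 2])]  # stands in for matmul_slice
    object_ = mention_encodings[jnp.array([1, 3])]
    relation_encodings = jnp.concatenate([subject, object_], -1)
    # relation_encodings.shape == (2, 8)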
Example No. 6
            def process_el_im_loss(loss, weight, prefix=''):
                memory_attention_weights = loss_helpers[
                    prefix + 'memory_attention_weights']
                memory_entity_ids = loss_helpers[prefix + 'top_entity_ids']

                target_mentions_memory_attention_weights = jut.matmul_slice(
                    memory_attention_weights, batch['mention_target_indices'])

                intermediate_entity_ids = jut.matmul_slice(
                    memory_entity_ids, batch['mention_target_indices'])

                el_loss_intermediate, same_entity_avg_prob, el_im_denom = metric_utils.compute_loss_and_prob_from_probs_with_duplicates(
                    target_mentions_memory_attention_weights,
                    intermediate_entity_ids, mention_target_ids,
                    batch['mention_target_weights'])

                if weight > 0:
                    loss += weight * el_loss_intermediate / el_im_denom
                metrics[prefix + 'el_intermediate'] = {
                    'loss': el_loss_intermediate,
                    'same_entity_avg_prob': same_entity_avg_prob,
                    'denominator': el_im_denom,
                }
                return loss
Example No. 7
    def forward(
        self,
        batch: Dict[str, Array],
        deterministic: bool,
    ) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
        loss_helpers = {}
        logging_helpers = {}

        embedded_input = self.embedder({
            'token_ids': batch['text_ids'],
            'position_ids': batch['position_ids'],
            'segment_ids': batch['segment_ids']
        })

        embedded_input = self.embeddings_layer_norm(embedded_input)
        embedded_input = self.embeddings_dropout(embedded_input,
                                                 deterministic=deterministic)

        loss_helpers['word_embeddings'] = self.embedder.variables['params'][
            'embedders_token_ids']['embedding']

        attention_mask = batch['text_mask']
        encoding = self.initial_encoder(encoding=embedded_input,
                                        attention_mask=attention_mask,
                                        deterministic=deterministic)

        memory_values = jnp.asarray(
            self.memory_values.value,
            dtype=self.dtype) if self.separate_memory_values else None
        memory_keys = jnp.asarray(self.memory_keys.value, dtype=self.dtype)
        memory_entity_ids = self.memory_entity_ids.value
        memory_identifiers = self.memory_identifiers.value

        loss_helpers['memory_values'] = memory_values
        loss_helpers['memory_keys'] = memory_keys
        loss_helpers['memory_entity_ids'] = memory_entity_ids
        loss_helpers['memory_identifiers'] = memory_identifiers

        def apply_memory_attention(memory_layer, encoding, prefix=''):
            encoding, mem_loss_helpers, mem_logging_helpers = memory_layer(
                encoded_input=encoding,
                mention_batch_positions=batch['mention_batch_positions'],
                mention_start_positions=batch['mention_start_positions'],
                mention_end_positions=batch['mention_end_positions'],
                mention_mask=batch['mention_mask'],
                memory_keys=memory_keys,
                memory_identifiers=memory_identifiers,
                memory_entity_ids=memory_entity_ids,
                deterministic=deterministic,
                memory_values=memory_values,
                text_identifiers=batch.get('text_identifiers', None),
                memory_text_entities=(self.memory_text_entities.value
                                      if self.memory_text_entities is not None
                                      else None),
                same_passage_memory_policy=self.same_passage_memory_policy,
            )
            loss_helpers.update({
                prefix + key: value
                for key, value in mem_loss_helpers.items()
            })
            logging_helpers.update({
                prefix + key: value
                for key, value in mem_logging_helpers.items()
            })
            return encoding

        if self.num_intermediate_layers is None:
            encoding = apply_memory_attention(self.memory_attention_layer,
                                              encoding)
        else:
            encoding = apply_memory_attention(
                self.intermediate_memory_attention_layer, encoding)
            encoding = self.intermediate_encoder(encoding=encoding,
                                                 attention_mask=attention_mask,
                                                 deterministic=deterministic)
            encoding = apply_memory_attention(
                self.final_memory_attention_layer, encoding, 'second_')
        encoding = self.final_encoder(encoding=encoding,
                                      attention_mask=attention_mask,
                                      deterministic=deterministic)

        if 'mention_target_batch_positions' in batch:
            mention_start_final_encodings = jut.matmul_2d_index_select(
                encoding, (batch['mention_target_batch_positions'],
                           batch['mention_target_start_positions']))
            mention_end_final_encodings = jut.matmul_2d_index_select(
                encoding, (batch['mention_target_batch_positions'],
                           batch['mention_target_end_positions']))

            loss_helpers[
                'intermediate_target_mention_encodings'] = jut.matmul_slice(
                    loss_helpers['memory_attention_mention_encodings'],
                    batch['mention_target_indices'])
            if self.num_intermediate_layers is not None:
                loss_helpers[
                    'second_intermediate_target_mention_encodings'] = jut.matmul_slice(
                        loss_helpers[
                            'second_memory_attention_mention_encodings'],
                        batch['mention_target_indices'])

            loss_helpers['target_mention_encodings'] = self.mention_projector(
                jnp.concatenate((mention_start_final_encodings,
                                 mention_end_final_encodings),
                                axis=-1))

            # Final retrieval layer is only applied over target mentions.
            if self.apply_final_retrieval:
                queries = self.final_query_projector(
                    loss_helpers['target_mention_encodings'])

                retrieval_result = self.final_memory_retrieval_layer(
                    queries=queries,
                    memory_keys=memory_keys,
                    memory_identifiers=memory_identifiers,
                    memory_entity_ids=memory_entity_ids,
                    memory_values=memory_values,
                    text_identifiers=None,
                    memory_text_entities=None,
                    same_passage_memory_policy='disallow',
                )

                loss_helpers.update(
                    {'final_' + k: v
                     for k, v in retrieval_result.items()})

        return encoding, loss_helpers, logging_helpers
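
`jut.matmul_2d_index_select` pulls out `encoding[b, p]` for paired batch and
position index vectors. A sketch of those semantics (an assumption about the
helper, written as a flat gather rather than the matmul formulation its name
suggests):

    import jax.numpy as jnp

    def matmul_2d_index_select_sketch(encoding, indices):
        # encoding: [bsz, seq_len, dim]; indices: ([n], [n]) batch/position
        # pairs.
        batch_positions, seq_positions = indices
        bsz, seq_len, dim = encoding.shape
        flat = encoding.reshape(bsz * seq_len, dim)
        return flat[batch_positions * seq_len + seq_positions]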
Example No. 8
    def forward(self, batch: Dict[str, Array], deterministic: bool):
        loss_helpers = {}
        logging_helpers = {}

        embedded_input = self.embedder({
            'token_ids': batch['text_ids'],
            'position_ids': batch['position_ids'],
            'segment_ids': batch['segment_ids']
        })

        embedded_input = self.embeddings_layer_norm(embedded_input)
        embedded_input = self.embeddings_dropout(embedded_input, deterministic)

        loss_helpers['word_embeddings'] = self.embedder.variables['params'][
            'embedders_token_ids']['embedding']

        attention_mask = batch['text_mask']
        encoding = self.initial_encoder(encoding=embedded_input,
                                        attention_mask=attention_mask,
                                        deterministic=deterministic)

        if not self.no_retrieval:
            encoding = self.retrieval_update_layer(
                encoded_input=encoding,
                retrieval_values=jnp.expand_dims(
                    # [max_retrieval_indices, retrieval_dim]
                    batch['retrieval_mention_values'],
                    -2),
                retrieval_scores=jnp.expand_dims(
                    # [max_retrieval_indices]
                    batch['retrieval_mention_scores'],
                    -1),
                mention_batch_positions=batch[
                    'retrieval_mention_batch_positions'],
                mention_start_positions=batch[
                    'retrieval_mention_start_positions'],
                mention_end_positions=batch['retrieval_mention_end_positions'],
                mention_mask=batch['retrieval_mention_mask'],
                deterministic=deterministic)

        encoding = self.final_encoder(encoding=encoding,
                                      attention_mask=attention_mask,
                                      deterministic=deterministic)

        mention_target_batch_positions = jut.matmul_slice(
            batch['mention_batch_positions'], batch['mention_target_indices'])
        mention_target_start_positions = jut.matmul_slice(
            batch['mention_start_positions'], batch['mention_target_indices'])
        mention_target_end_positions = jut.matmul_slice(
            batch['mention_end_positions'], batch['mention_target_indices'])

        mention_start_final_encodings = jut.matmul_2d_index_select(
            encoding,
            (mention_target_batch_positions, mention_target_start_positions))
        mention_end_final_encodings = jut.matmul_2d_index_select(
            encoding,
            (mention_target_batch_positions, mention_target_end_positions))
        loss_helpers['target_mention_encodings'] = self.mention_projector(
            jnp.concatenate(
                (mention_start_final_encodings, mention_end_final_encodings),
                axis=-1))

        return encoding, loss_helpers, logging_helpers
Example No. 9
    def loss_fn(
        model_config: ml_collections.FrozenConfigDict,
        model_params: Dict[str, Any],
        model_vars: Dict[str, Any],  # pylint: disable=unused-argument
        batch: Dict[str, Any],
        deterministic: bool,
        dropout_rng: Optional[Dict[str, Array]] = None,
    ) -> Tuple[float, MetricGroups, Dict[str, Any]]:
      """Task-specific loss function. See BaseTask."""

      batch_size = batch['text_ids'].shape[0]
      loss_helpers, logging_helpers = cls.build_model(model_config).apply(  # pylint: disable=unused-variable
          {'params': model_params},
          batch,
          deterministic=deterministic,
          rngs=dropout_rng)
      mention_target_is_masked = batch['mention_target_is_masked']
      mention_target_is_not_masked = 1 - batch['mention_target_is_masked']
      mention_target_ids = batch['mention_target_ids']
      mention_target_ids = mention_target_ids * batch['mention_target_weights']

      mlm_logits = loss_helpers['mlm_logits']

      mlm_loss, mlm_denom = metric_utils.compute_weighted_cross_entropy(
          mlm_logits, batch['mlm_target_ids'], batch['mlm_target_weights'])

      mlm_correct_mask = jnp.equal(
          jnp.argmax(mlm_logits, axis=-1),
          batch['mlm_target_ids']) * batch['mlm_target_weights']
      mlm_acc = mlm_correct_mask.sum()
      mlm_mention_acc = (mlm_correct_mask *
                         batch['mlm_target_is_mention']).sum()
      mlm_mention_denom = (batch['mlm_target_weights'] *
                           batch['mlm_target_is_mention']).sum()
      mlm_non_mention_acc = (mlm_correct_mask *
                             (1 - batch['mlm_target_is_mention'])).sum()
      mlm_non_mention_denom = (batch['mlm_target_weights'] *
                               (1 - batch['mlm_target_is_mention'])).sum()

      metrics = {
          'mlm': {
              'loss': mlm_loss,
              'acc': mlm_acc,
              'denominator': mlm_denom,
          },
          'mlm_mention': {
              'acc': mlm_mention_acc,
              'denominator': mlm_mention_denom,
          },
          'mlm_non_mention': {
              'acc': mlm_non_mention_acc,
              'denominator': mlm_non_mention_denom,
          },
      }

      if 'intermediate_mention_encodings' in loss_helpers:
        intermediate_target_mention_encodings = jut.matmul_slice(
            loss_helpers['intermediate_mention_encodings'],
            batch['mention_target_indices'])
      else:
        intermediate_target_mention_encodings = loss_helpers[
            'im_target_mention_encodings']

      if model_config.encoder_config.get('no_entity_attention', False):
        (el_im_loss, el_im_metrics,
         (el_im_acc_per_mention,
          el_im_weight_per_mention)) = mention_losses.entity_linking_loss(
              intermediate_target_mention_encodings,
              loss_helpers['entity_embeddings'], mention_target_ids,
              batch['mention_target_weights'], el_score_mode)
        el_im_denom = el_im_metrics['denominator']
        metrics['el_intermediate'] = el_im_metrics
        metrics['el_intermediate_masked'] = {
            'acc':
                jnp.dot(el_im_acc_per_mention,
                        el_im_weight_per_mention * mention_target_is_masked),
            'denominator':
                jnp.dot(el_im_weight_per_mention, mention_target_is_masked),
        }
        metrics['el_intermediate_non_masked'] = {
            'acc':
                jnp.dot(
                    el_im_acc_per_mention,
                    el_im_weight_per_mention * mention_target_is_not_masked),
            'denominator':
                jnp.dot(el_im_weight_per_mention,
                        mention_target_is_not_masked),
        }
      else:
        intermediate_entity_attention = loss_helpers[
            'intermediate_entity_attention']

        # Construct targets and ids for intermediate entity linking loss
        intermediate_target_ids = jnp.zeros_like(batch['mention_mask'])
        intermediate_target_ids = intermediate_target_ids.at[
            batch['mention_target_indices']].add(
                mention_target_ids * batch['mention_target_weights'])

        intermediate_target_weights = jnp.zeros_like(
            batch['mention_mask'], dtype=intermediate_entity_attention.dtype)
        intermediate_target_weights = intermediate_target_weights.at[
            batch['mention_target_indices']].add(
                batch['mention_target_weights'])

        mention_is_masked = jnp.zeros_like(batch['mention_mask'])
        mention_is_masked = mention_is_masked.at[
            batch['mention_target_indices']].add(
                mention_target_is_masked * batch['mention_target_weights'])

        el_im_loss, el_im_denom = metric_utils.compute_weighted_cross_entropy(
            intermediate_entity_attention,
            intermediate_target_ids,
            intermediate_target_weights,
            inputs_are_prob=True)

        el_im_correct_mask = jnp.equal(
            jnp.argmax(intermediate_entity_attention, axis=-1),
            intermediate_target_ids) * intermediate_target_weights

        el_im_acc, _ = metric_utils.compute_weighted_accuracy(
            intermediate_entity_attention, intermediate_target_ids,
            intermediate_target_weights)

        intermediate_entity_cos_sim = loss_helpers[
            'intermediate_entity_cos_sim'][batch['mention_target_indices'],
                                           mention_target_ids]

        metrics['el_intermediate'] = {
            'loss':
                el_im_loss,
            'acc':
                el_im_acc,
            'cos_sim':
                jnp.dot(intermediate_entity_cos_sim,
                        batch['mention_target_weights']),
            'denominator':
                el_im_denom,
        }
        metrics['el_intermediate_masked'] = {
            'acc':
                jnp.dot(el_im_correct_mask, mention_is_masked),
            'denominator':
                jnp.dot(batch['mention_target_weights'],
                        batch['mention_target_is_masked']),
        }
        metrics['el_intermediate_non_masked'] = {
            'acc':
                jnp.dot(el_im_correct_mask, (1 - mention_is_masked)),
            'denominator':
                jnp.dot(batch['mention_target_weights'],
                        (1 - batch['mention_target_is_masked'])),
        }

        im_final_mention_encodings_cos_sim = jut.cosine_similarity(
            intermediate_target_mention_encodings,
            loss_helpers['target_mention_encodings'])
        metrics['im_final_mention_encodings'] = {
            'cos_sim':
                jnp.dot(im_final_mention_encodings_cos_sim,
                        batch['mention_target_weights']),
            'denominator':
                batch['mention_target_weights'].sum(),
        }

      (el_final_loss, el_final_metrics,
       (el_final_acc_per_mention,
        el_final_weight_per_mention)) = mention_losses.entity_linking_loss(
            loss_helpers['target_mention_encodings'],
            loss_helpers['entity_embeddings'], mention_target_ids,
            batch['mention_target_weights'], el_score_mode)
      el_final_denom = el_final_metrics['denominator']
      metrics['el_final'] = el_final_metrics
      metrics['el_final_masked'] = {
          'acc':
              jnp.dot(el_final_acc_per_mention,
                      el_final_weight_per_mention * mention_target_is_masked),
          'denominator':
              jnp.dot(el_final_weight_per_mention, mention_target_is_masked),
      }
      metrics['el_final_non_masked'] = {
          'acc':
              jnp.dot(
                  el_final_acc_per_mention,
                  el_final_weight_per_mention * mention_target_is_not_masked),
          'denominator':
              jnp.dot(el_final_weight_per_mention,
                      mention_target_is_not_masked),
      }

      loss = mlm_weight * mlm_loss / mlm_denom
      loss += el_im_weight * el_im_loss / el_im_denom
      loss += el_final_weight * el_final_loss / el_final_denom

      if mtb_im_weight > 0:
        (mtb_im_loss, mtb_im_metrics) = mention_losses.mtb_loss(
            intermediate_target_mention_encodings,
            batch['mention_target_batch_positions'], mention_target_ids,
            batch_size, mtb_score_mode, mention_target_is_masked, 'im_')
        mtb_im_denom = mtb_im_metrics['im_mtb']['denominator']
        loss += mtb_im_weight * mtb_im_loss / mtb_im_denom
        metrics.update(mtb_im_metrics)

      if mtb_final_weight > 0:
        (mtb_final_loss, mtb_final_metrics) = mention_losses.mtb_loss(
            loss_helpers['target_mention_encodings'],
            batch['mention_target_batch_positions'], mention_target_ids,
            batch_size, mtb_score_mode, mention_target_is_masked, 'final_')
        mtb_final_denom = mtb_final_metrics['final_mtb']['denominator']
        loss += mtb_final_weight * mtb_final_loss / mtb_final_denom
        metrics.update(mtb_final_metrics)

      metrics['agg'] = {
          'loss': loss,
          'denominator': 1.0,
      }
      return loss, metrics, {}
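
The loss terms above all follow a `sum / denominator` convention. A plausible
contract for `compute_weighted_cross_entropy` (an assumption, not the
library's code) is a weight-summed cross entropy returned together with the
weight total, with `inputs_are_prob` switching between logits and
pre-normalized probabilities:

    import jax
    import jax.numpy as jnp

    def weighted_xent_sketch(scores, targets, weights, inputs_are_prob=False):
        if inputs_are_prob:
            log_probs = jnp.log(jnp.clip(scores, 1e-10, 1.0))
        else:
            log_probs = jax.nn.log_softmax(scores, axis=-1)
        nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
        return (nll[..., 0] * weights).sum(), weights.sum()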
Example No. 10
    def __call__(
        self,
        queries: Array,
        memory_keys: Array,
        memory_identifiers: Array,
        memory_entity_ids: Array,
        memory_values: Optional[Array] = None,
        text_identifiers: Optional[Array] = None,
        memory_text_entities: Optional[Array] = None,
        same_passage_memory_policy: str = 'disallow',
    ) -> Dict[str, Array]:
        """Perform attention update over memory table.

    Args:
      queries: [n_mentions, hidden_size] query vectors.
      memory_keys: [rows, values per row, key_dim] mention memory keys. The
        number of rows in the memory table governs the recall vs speed of the
        topk similarity search. Search is performed by taking max over each row,
        and then top-k between rows. Distributing the same values over more rows
        leads to higher recall but slower search.
      memory_identifiers: [memory_size] identifier for memory vectors.
      memory_entity_ids: [memory_size] entity ids for mentions in memory table
      memory_values: [values, memory_dim] if separate keys and values.
      text_identifiers: [n_mentions] search will not retrieve memory vectors
        with the same identifier as passage mention.
      memory_text_entities: [n_mentions, n_memory_text_entities] entity ids for
        passages where memories are coming from.
      same_passage_memory_policy: how to treat mentions from the same passage.
        Possible options: `allow`, `disallow` and `only`.

    Returns:
      Dictionary with retrieval results, including values, entity IDs, attention
      weights and etc.
    """
        _assert_array_is_integer_or_none(memory_entity_ids,
                                         'memory_entity_ids')
        _assert_array_is_integer_or_none(memory_identifiers,
                                         'memory_identifiers')
        _assert_array_is_integer_or_none(memory_text_entities,
                                         'memory_text_entities')
        _assert_array_is_integer_or_none(text_identifiers, 'text_identifiers')

        retrieval_result = {}
        memory_size = memory_keys.shape[0] * memory_keys.shape[1]
        memory_key_dim = memory_keys.shape[2]
        n_queries = queries.shape[0]

        # We generate a version of the queries with stop gradient to use as input to
        # the topk similarity layer. We actually do want gradient to flow to the
        # queries, but backward differentiation over the topk layer yields
        # inefficient HLO ops. Instead we use queries with gradient to recompute
        # attention scores later.
        queries_sg = jax.lax.stop_gradient(queries)

        # Gather queries from all devices. Each device contains a shard of the
        # mention memory. Ultimately we want to perform search over the entire
        # mention memory, so we gather mentions from all devices, apply similarity
        # search over the local shard, then distribute the results back.
        gathered_queries = jax.lax.all_gather(queries_sg, 'batch')
        if text_identifiers is not None:
            gathered_identifiers = jax.lax.all_gather(text_identifiers,
                                                      'batch')

        n_devices = gathered_queries.shape[0]
        gathered_queries = gathered_queries.reshape(n_devices * n_queries,
                                                    memory_key_dim)

        # Perform top-k similarity search over queries, yielding
        # top_values: (n_devices * queries_per_device, k_top_device, memory_key_dim)
        # top_ids: (n_devices * queries_per_device, k_top_device)
        top_keys, top_scores, top_ids = self.topk_similarity(
            gathered_queries, memory_keys)

        if memory_values is not None:
            top_values = memory_values[top_ids]
        else:
            top_values = top_keys
        memory_dim = top_values.shape[-1]

        # Also return entity ids
        top_entity_ids = memory_entity_ids[top_ids]

        top_values = top_values.reshape(n_devices, n_queries,
                                        self.k_top_device, memory_dim)
        top_entity_ids = top_entity_ids.reshape(n_devices, n_queries,
                                                self.k_top_device)
        global_top_ids = top_ids.reshape(n_devices, n_queries,
                                         self.k_top_device)

        # Now that we have searched the local shard using queries from all devices,
        # we need to distribute the search results back to all devices. Applying
        # pswapaxes followed by swapaxes makes us go from
        # (devices, queries per device, local shard retrievals) to
        # (local queries, devices, memory retrievals per device).
        (top_values, top_entity_ids, global_top_ids) = jax.lax.pswapaxes(
            (top_values, top_entity_ids, global_top_ids),
            axis_name='batch',
            axis=0)

        top_values = jnp.swapaxes(top_values, 0, 1)
        top_entity_ids = jnp.swapaxes(top_entity_ids, 0, 1)

        # (local queries, devices, memory retrievals per device).
        global_top_ids = jnp.swapaxes(global_top_ids, 0, 1)
        # IDs are device specific. Therefore, we need to convert them to `global`
        # memory IDs. Note that every device operates on a memory of the same size.
        # Therefore, IDs on the device 0 don't need to be changed, we need to add
        # `memory_size` to IDs from the device 1, 2 * `memory_size` to IDs from the
        # device 2, etc.
        global_top_ids = global_top_ids + jnp.arange(n_devices).reshape(
            1, -1, 1) * memory_size

        # Reshape results to (local_queries, global retrievals).
        k_top = n_devices * self.k_top_device
        top_values = top_values.reshape(n_queries, k_top, memory_dim)
        top_entity_ids = top_entity_ids.reshape(n_queries, k_top)
        global_top_ids = global_top_ids.reshape(n_queries, k_top)

        # At this point, we have selected `k_top = n_devices * self.k_top_device`
        # memories for every query. The selection process is approximate since
        # we retrieve `self.k_top_device` memories from every device and then
        # just concatenate the results.
        # Due to computational constraints we may wish to limit the number
        # of memories per query, so we subselect even further and keep only
        # `self.k_top_post_selection` retrieved memories for every query.
        if self.k_top_post_selection is not None:
            top_scores = top_scores.reshape(n_devices, n_queries,
                                            self.k_top_device)
            top_scores = jax.lax.pswapaxes(top_scores,
                                           axis_name='batch',
                                           axis=0)
            top_scores = jnp.swapaxes(top_scores, 0, 1)
            top_scores = top_scores.reshape(n_queries, k_top)
            # Take k highest scores among all rows.
            # pylint:disable=invalid-unary-operand-type
            top_post_selection_index = jnp.argsort(
                top_scores, axis=-1)[:, :-self.k_top_post_selection - 1:-1]
            # pylint:enable=invalid-unary-operand-type
            top_values = jut.matmul_slice(top_values, top_post_selection_index)
            top_entity_ids = jut.matmul_slice(top_entity_ids,
                                              top_post_selection_index)
            global_top_ids = jut.matmul_slice(global_top_ids,
                                              top_post_selection_index)

        # If we use separate memory values, distribute keys back also.
        if memory_values is not None:
            top_keys = top_keys.reshape(n_devices, n_queries,
                                        self.k_top_device, memory_key_dim)
            top_keys = jax.lax.pswapaxes(top_keys, axis_name='batch', axis=0)
            top_keys = jnp.swapaxes(top_keys, 0, 1)
            top_keys = top_keys.reshape(n_queries, k_top, memory_key_dim)
            if self.k_top_post_selection is not None:
                top_keys = jut.matmul_slice(top_keys, top_post_selection_index)
        else:
            top_keys = top_values

        retrieval_result['top_entity_ids'] = top_entity_ids
        retrieval_result['top_memory_ids'] = global_top_ids
        retrieval_result['top_values'] = top_values

        # We re-compute top scores using the queries with gradient (wg) to make
        # sure the mention encoder and the rest of the model receive gradients.
        top_scores_wg = jnp.einsum('qd,qkd->qk', queries, top_keys)

        retrieval_result[
            'memory_attention_scores_with_disallowed'] = top_scores_wg

        # We want to disallow some mentions from being retrieved (e.g. mentions
        # from the same passage during pre-training). Here we mask retrieved
        # mentions which have the same identifier as the query.
        if text_identifiers is not None:
            top_ids = top_ids.reshape(n_devices, n_queries, self.k_top_device)
            gathered_identifiers = gathered_identifiers.reshape(
                n_devices, n_queries, 1)
            identifier_mask = (
                memory_identifiers[top_ids] == gathered_identifiers)

            # We manually cast `identifier_mask` into int32. Otherwise, `pswapaxes`
            # which is known to have undefined behaviour on CPU, "corrupts" a vector
            # making it effectively int32, while keeping boolean dtype. This in turn
            # leads to a compilation error for the einsum operation in the
            # `matmul_slice` (types mismatch).
            identifier_mask = identifier_mask.astype(dtype=jnp.int32)
            identifier_mask = jax.lax.pswapaxes(identifier_mask,
                                                axis_name='batch',
                                                axis=0)
            identifier_mask = jnp.swapaxes(identifier_mask, 0, 1)
            identifier_mask = identifier_mask.reshape(n_queries, k_top)
            if self.k_top_post_selection is not None:
                identifier_mask = jut.matmul_slice(identifier_mask,
                                                   top_post_selection_index)
            retrieval_result[
                'memory_attention_disallowed_mask'] = identifier_mask.astype(
                    jnp.bool_)
            identifier_mask = identifier_mask.astype(top_scores_wg.dtype)

            # Depending on `same_passage_memory_policy` we treat memories from the
            # same passage as query mentions differently.
            if same_passage_memory_policy == 'disallow':
                top_scores_wg = top_scores_wg - identifier_mask * default_values.LARGE_NUMBER
            elif same_passage_memory_policy == 'only':
                top_scores_wg = top_scores_wg - (
                    1.0 - identifier_mask) * default_values.LARGE_NUMBER
            elif same_passage_memory_policy == 'allow':
                pass
            else:
                raise ValueError(
                    'Unknown value for `same_passage_memory_policy`: %s' %
                    same_passage_memory_policy)
            n_disallowed = identifier_mask.sum()
            retrieval_result['n_disallowed'] = n_disallowed

        if memory_text_entities is not None:
            top_ids = top_ids.reshape(n_devices, n_queries, self.k_top_device)
            # shape [n_devices, n_queries, k_top_device, n_text_entities_per_passage]
            top_text_entities = memory_text_entities[top_ids]
            top_text_entities = jax.lax.pswapaxes(top_text_entities,
                                                  axis_name='batch',
                                                  axis=0)
            # shape [n_queries, n_devices, k_top_device, n_text_entities_per_passage]
            top_text_entities = jnp.swapaxes(top_text_entities, 0, 1)
            # shape [n_queries, n_devices * k_top_device, n_text_entities_per_passage]
            top_text_entities = top_text_entities.reshape(n_queries, k_top, -1)
            if self.k_top_post_selection is not None:
                top_text_entities = jut.matmul_slice(top_text_entities,
                                                     top_post_selection_index)
            retrieval_result['memory_top_text_entities'] = top_text_entities

        # We perform dot product attention using the retrieved memory vectors
        # as keys, a dense projection of the retrieved vectors as values, and
        # the mention representations as queries.
        attention_weights = nn.softmax(top_scores_wg, axis=-1)
        retrieval_result['memory_attention_weights'] = attention_weights

        return retrieval_result
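
The all_gather / pswapaxes round trip is the heart of the sharded search:
every device scores all queries against its local memory shard, then each
device gets back the results for its own queries. A toy sketch of just the
redistribution under jax.pmap (hypothetical shapes; with a single local device
the swap is trivial, and the actual top-k search is omitted):

    import jax
    import jax.numpy as jnp

    n_dev = jax.local_device_count()
    q_per_dev, dim = 2, 4

    def shard_fn(queries):
        gathered = jax.lax.all_gather(queries, 'batch')  # (n_dev, q_per_dev, dim)
        results = gathered  # stand-in for this shard's retrieval results
        # Swap the device axis with axis 0, then move queries back in front.
        results = jax.lax.pswapaxes(results, 'batch', 0)
        return jnp.swapaxes(results, 0, 1)  # (q_per_dev, n_dev, dim)

    queries = jnp.arange(n_dev * q_per_dev * dim, dtype=jnp.float32)
    out = jax.pmap(shard_fn, axis_name='batch')(
        queries.reshape(n_dev, q_per_dev, dim))
    # out.shape == (n_dev, q_per_dev, n_dev, dim)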