def split_top_k(split_queries):
  split_scores = jnp.einsum('qd,rvd->qrv', split_queries, table)

  # Find highest scoring vector for each row.
  top_id_by_row = jnp.argmax(split_scores, axis=-1)
  top_score_by_row = jnp.max(split_scores, axis=-1)

  # Take k highest scores among all rows.
  top_row_idx = jnp.argsort(
      top_score_by_row, axis=-1)[:, :-self.k_top - 1:-1]

  # Sub-select best indices for k best rows.
  ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

  # Gather highest scoring vectors for k best rows.
  split_topk_values = table[top_row_idx, ids_by_topk_row]

  # Convert row indices to indices into flattened table.
  top_table_id_by_row = top_id_by_row + jnp.arange(0, table_size,
                                                   scores_per_row)

  # Get best ids into flattened table.
  split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)
  split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

  return split_topk_values, split_topk_scores, split_topk_ids
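# `jut.matmul_slice` above gathers rows by index. A minimal sketch of the
# idea, on the assumption that it is a one-hot matmul gather (the unit test
# further below checks it against plain indexing); `matmul_slice_sketch` is
# a hypothetical stand-in, not the actual `jut` implementation:
import jax
import jax.numpy as jnp


def matmul_slice_sketch(array, indices):
  """Gathers `array[indices]` via a one-hot matmul (illustrative only)."""
  # [n_indices, seq_len] one-hot selection matrix in the array's dtype, so
  # integer inputs stay exact.
  one_hot = jax.nn.one_hot(indices, array.shape[0], dtype=array.dtype)
  # Contract away the sequence axis: result[i] == array[indices[i]].
  return jnp.einsum('is,s...->i...', one_hot, array)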
def split_top_k(split_queries: Array) -> Tuple[Array, Array, Array]:
  # Find most similar clusters.
  prototype_scores = jnp.einsum('qd,pd->qp', split_queries, prototypes)
  top_indices = jax.lax.top_k(prototype_scores, self.n_search)[1]

  # Perform approximate top-k similarity search over most similar clusters.
  selected_data = table[top_indices]
  split_scores = jnp.einsum('qd,qcrvd->qcrv', split_queries, selected_data)

  # Find highest scoring vector for each row.
  top_id_by_row = jnp.argmax(split_scores, axis=-1)
  top_score_by_row = jnp.max(split_scores, axis=-1)

  top_id_by_row = top_id_by_row.reshape(queries_per_split,
                                        self.n_search * rows_per_cluster)
  top_score_by_row = top_score_by_row.reshape(
      queries_per_split, self.n_search * rows_per_cluster)

  # Take k highest scores among all rows.
  top_row_idx = jnp.argsort(
      top_score_by_row, axis=-1)[:, :-self.k_top - 1:-1]

  # Sub-select best indices for k best rows.
  ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

  # Gather highest scoring vectors for k best rows.
  query_index = jnp.arange(queries_per_split).reshape(-1, 1).tile(
      [1, self.k_top])
  top_cluster_idx, top_cluster_row_idx = jnp.divmod(top_row_idx,
                                                    rows_per_cluster)
  split_topk_values = selected_data[query_index, top_cluster_idx,
                                    top_cluster_row_idx, ids_by_topk_row]

  row_offset = jnp.mod(
      jnp.arange(0, self.n_search * values_per_cluster, values_per_row),
      values_per_cluster)
  cluster_offset = jnp.arange(0, table_size, values_per_cluster)

  # Convert row indices to indices into flattened table.
  top_table_id_by_row = top_id_by_row + row_offset.reshape(
      1, -1) + cluster_offset[top_indices].repeat(rows_per_cluster, axis=-1)

  # Get best ids into flattened table.
  split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)
  split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

  return split_topk_values, split_topk_scores, split_topk_ids
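# A toy check of the flat-id arithmetic above, with made-up sizes. A value in
# slot `v` of row `r` of cluster `c` lives at flat id
#   c * values_per_cluster + r * values_per_row + v
# when the table is flattened in row-major order. The code reconstructs this
# from the per-row argmax (`v`), `row_offset` (`r * values_per_row`, cycling
# within each cluster) and `cluster_offset[top_indices]` (start of each
# searched cluster), while `jnp.divmod(top_row_idx, rows_per_cluster)` splits
# a searched row index into (searched cluster, row within cluster):
rows_per_cluster, values_per_row = 3, 5
values_per_cluster = rows_per_cluster * values_per_row  # 15
c, r, v = 2, 1, 4
assert c * values_per_cluster + r * values_per_row + v == 39
# With n_search clusters concatenated, searched row 4 means: cluster 1, row 1.
assert divmod(4, rows_per_cluster) == (1, 1)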
def __call__(
    self,
    encoded_input: Array,
    mlm_target_positions: Array,
    shared_embedding: Array,
) -> Array:
  """Perform masked language modeling scoring.

  Args:
    encoded_input: [bsz, n_tokens, hidden_size].
    mlm_target_positions: [bsz, max_mlm_targets] positions of mlm targets in
      passage.
    shared_embedding: [vocab_size, hidden_size] word embedding array, shared
      with initial embedding.

  Returns:
    Array of masked language modeling logits.
  """
  target_encodings = jut.matmul_slice(encoded_input, mlm_target_positions)
  target_encodings = self.dense(target_encodings)
  target_encodings = nn.gelu(target_encodings)
  target_encodings = self.layer_norm(target_encodings)

  mlm_logits = self.embedding_dense.apply(
      {'params': {
          'kernel': shared_embedding.T
      }}, target_encodings)
  mlm_logits = mlm_logits + self.bias

  return mlm_logits
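# The `embedding_dense` call above ties the output projection to the input
# word embeddings: rather than learning a separate [hidden_size, vocab_size]
# kernel, the transposed shared embedding table is injected as the Dense
# kernel at call time. A self-contained sketch of the pattern (shapes and
# names here are illustrative assumptions):
import flax.linen as nn
import jax.numpy as jnp

vocab_size, hidden_size = 11, 4
embedding_dense = nn.Dense(features=vocab_size, use_bias=False)
shared_embedding = jnp.ones((vocab_size, hidden_size))  # stand-in table
encodings = jnp.ones((2, hidden_size))
# The layer trains no kernel of its own; parameters are supplied per call.
logits = embedding_dense.apply(
    {'params': {'kernel': shared_embedding.T}}, encodings)
assert logits.shape == (2, vocab_size)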
def test_slice_values_int(self, bsz, seq_len, index_len, dim):
  # No batch dim.
  array = np.random.randint(_MAX_INT_VALUE, size=(seq_len, dim))
  indices = np.random.randint(seq_len, size=(index_len,))
  matmul_slice = jut.matmul_slice(array, indices)
  vmap_slice = array[indices]
  self.assertTrue(jnp.allclose(matmul_slice, vmap_slice))

  # 2D array.
  array = np.random.randint(_MAX_INT_VALUE, size=(bsz, seq_len))
  indices = np.random.randint(seq_len, size=(bsz, index_len))
  matmul_slice = jut.matmul_slice(array, indices)
  vmap_slice = jut.vmap_slice(array, indices)
  self.assertTrue(jnp.allclose(matmul_slice, vmap_slice))

  # 3D array.
  array = np.random.randint(_MAX_INT_VALUE, size=(bsz, seq_len, dim))
  indices = np.random.randint(seq_len, size=(bsz, index_len))
  matmul_slice = jut.matmul_slice(array, indices)
  vmap_slice = jut.vmap_slice(array, indices)
  self.assertTrue(jnp.allclose(matmul_slice, vmap_slice))
def __call__(self, batch: Dict[str, Array], deterministic: bool):
  _, loss_helpers, logging_helpers = self.encoder.forward(
      batch, deterministic)

  mention_encodings = loss_helpers[self.mention_encodings_feature]
  subject_mention_encodings = jut.matmul_slice(
      mention_encodings, batch['mention_subject_indices'])
  object_mention_encodings = jut.matmul_slice(
      mention_encodings, batch['mention_object_indices'])

  relation_encodings = jnp.concatenate(
      [subject_mention_encodings, object_mention_encodings], -1)

  for mlp_layer in self.classification_mlp_layers:
    relation_encodings = mlp_layer(relation_encodings, deterministic)

  classifier_logits = self.linear_classifier(relation_encodings)
  loss_helpers['classifier_logits'] = classifier_logits

  return loss_helpers, logging_helpers
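# A shape sketch for the relation head above, assuming each batch element
# marks one subject and one object mention by index into the flat mention
# list produced by the encoder (sizes are illustrative; plain indexing
# stands in for `jut.matmul_slice`, which the tests show is equivalent):
import numpy as np

bsz, n_mentions, d = 2, 6, 8
mention_encodings = np.random.randn(n_mentions, d)
mention_subject_indices = np.array([0, 3])  # one subject per example
mention_object_indices = np.array([2, 5])  # one object per example
relation_encodings = np.concatenate([
    mention_encodings[mention_subject_indices],
    mention_encodings[mention_object_indices],
], axis=-1)
assert relation_encodings.shape == (bsz, 2 * d)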
def process_el_im_loss(loss, weight, prefix=''):
  memory_attention_weights = loss_helpers[prefix +
                                          'memory_attention_weights']
  memory_entity_ids = loss_helpers[prefix + 'top_entity_ids']

  target_mentions_memory_attention_weights = jut.matmul_slice(
      memory_attention_weights, batch['mention_target_indices'])
  intermediate_entity_ids = jut.matmul_slice(
      memory_entity_ids, batch['mention_target_indices'])

  el_loss_intermediate, same_entity_avg_prob, el_im_denom = (
      metric_utils.compute_loss_and_prob_from_probs_with_duplicates(
          target_mentions_memory_attention_weights,
          intermediate_entity_ids, mention_target_ids,
          batch['mention_target_weights']))

  if weight > 0:
    loss += weight * el_loss_intermediate / el_im_denom

  metrics[prefix + 'el_intermediate'] = {
      'loss': el_loss_intermediate,
      'same_entity_avg_prob': same_entity_avg_prob,
      'denominator': el_im_denom,
  }
  return loss
def forward(
    self,
    batch: Dict[str, Array],
    deterministic: bool,
) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
  loss_helpers = {}
  logging_helpers = {}

  embedded_input = self.embedder({
      'token_ids': batch['text_ids'],
      'position_ids': batch['position_ids'],
      'segment_ids': batch['segment_ids']
  })
  embedded_input = self.embeddings_layer_norm(embedded_input)
  embedded_input = self.embeddings_dropout(
      embedded_input, deterministic=deterministic)

  loss_helpers['word_embeddings'] = self.embedder.variables['params'][
      'embedders_token_ids']['embedding']

  attention_mask = batch['text_mask']
  encoding = self.initial_encoder(
      encoding=embedded_input,
      attention_mask=attention_mask,
      deterministic=deterministic)

  memory_values = jnp.asarray(
      self.memory_values.value,
      dtype=self.dtype) if self.separate_memory_values else None
  memory_keys = jnp.asarray(self.memory_keys.value, dtype=self.dtype)
  memory_entity_ids = self.memory_entity_ids.value
  memory_identifiers = self.memory_identifiers.value

  loss_helpers['memory_values'] = memory_values
  loss_helpers['memory_keys'] = memory_keys
  loss_helpers['memory_entity_ids'] = memory_entity_ids
  loss_helpers['memory_identifiers'] = memory_identifiers

  def apply_memory_attention(memory_layer, encoding, prefix=''):
    encoding, mem_loss_helpers, mem_logging_helpers = memory_layer(
        encoded_input=encoding,
        mention_batch_positions=batch['mention_batch_positions'],
        mention_start_positions=batch['mention_start_positions'],
        mention_end_positions=batch['mention_end_positions'],
        mention_mask=batch['mention_mask'],
        memory_keys=memory_keys,
        memory_identifiers=memory_identifiers,
        memory_entity_ids=memory_entity_ids,
        deterministic=deterministic,
        memory_values=memory_values,
        text_identifiers=batch.get('text_identifiers', None),
        memory_text_entities=(self.memory_text_entities.value
                              if self.memory_text_entities is not None
                              else None),
        same_passage_memory_policy=self.same_passage_memory_policy,
    )
    loss_helpers.update(
        {prefix + key: value for key, value in mem_loss_helpers.items()})
    logging_helpers.update(
        {prefix + key: value for key, value in mem_logging_helpers.items()})
    return encoding

  if self.num_intermediate_layers is None:
    encoding = apply_memory_attention(self.memory_attention_layer, encoding)
  else:
    encoding = apply_memory_attention(
        self.intermediate_memory_attention_layer, encoding)
    encoding = self.intermediate_encoder(
        encoding=encoding,
        attention_mask=attention_mask,
        deterministic=deterministic)
    encoding = apply_memory_attention(self.final_memory_attention_layer,
                                      encoding, 'second_')
  encoding = self.final_encoder(
      encoding=encoding,
      attention_mask=attention_mask,
      deterministic=deterministic)

  if 'mention_target_batch_positions' in batch:
    mention_start_final_encodings = jut.matmul_2d_index_select(
        encoding, (batch['mention_target_batch_positions'],
                   batch['mention_target_start_positions']))
    mention_end_final_encodings = jut.matmul_2d_index_select(
        encoding, (batch['mention_target_batch_positions'],
                   batch['mention_target_end_positions']))

    loss_helpers['intermediate_target_mention_encodings'] = jut.matmul_slice(
        loss_helpers['memory_attention_mention_encodings'],
        batch['mention_target_indices'])
    if self.num_intermediate_layers is not None:
      loss_helpers[
          'second_intermediate_target_mention_encodings'] = jut.matmul_slice(
              loss_helpers['second_memory_attention_mention_encodings'],
              batch['mention_target_indices'])

    loss_helpers['target_mention_encodings'] = self.mention_projector(
        jnp.concatenate(
            (mention_start_final_encodings, mention_end_final_encodings),
            axis=-1))

  # The final retrieval layer is only applied over target mentions.
  if self.apply_final_retrieval:
    queries = self.final_query_projector(
        loss_helpers['target_mention_encodings'])

    retrieval_result = self.final_memory_retrieval_layer(
        queries=queries,
        memory_keys=memory_keys,
        memory_identifiers=memory_identifiers,
        memory_entity_ids=memory_entity_ids,
        memory_values=memory_values,
        text_identifiers=None,
        memory_text_entities=None,
        same_passage_memory_policy='disallow',
    )

    loss_helpers.update(
        {'final_' + k: v for k, v in retrieval_result.items()})

  return encoding, loss_helpers, logging_helpers
def forward(self, batch: Dict[str, Array], deterministic: bool):
  loss_helpers = {}
  logging_helpers = {}

  embedded_input = self.embedder({
      'token_ids': batch['text_ids'],
      'position_ids': batch['position_ids'],
      'segment_ids': batch['segment_ids']
  })
  embedded_input = self.embeddings_layer_norm(embedded_input)
  embedded_input = self.embeddings_dropout(embedded_input, deterministic)

  loss_helpers['word_embeddings'] = self.embedder.variables['params'][
      'embedders_token_ids']['embedding']

  attention_mask = batch['text_mask']
  encoding = self.initial_encoder(
      encoding=embedded_input,
      attention_mask=attention_mask,
      deterministic=deterministic)

  if not self.no_retrieval:
    encoding = self.retrieval_update_layer(
        encoded_input=encoding,
        retrieval_values=jnp.expand_dims(
            # [max_retrieval_indices, retrieval_dim]
            batch['retrieval_mention_values'],
            -2),
        retrieval_scores=jnp.expand_dims(
            # [max_retrieval_indices]
            batch['retrieval_mention_scores'],
            -1),
        mention_batch_positions=batch['retrieval_mention_batch_positions'],
        mention_start_positions=batch['retrieval_mention_start_positions'],
        mention_end_positions=batch['retrieval_mention_end_positions'],
        mention_mask=batch['retrieval_mention_mask'],
        deterministic=deterministic)

  encoding = self.final_encoder(
      encoding=encoding,
      attention_mask=attention_mask,
      deterministic=deterministic)

  mention_target_batch_positions = jut.matmul_slice(
      batch['mention_batch_positions'], batch['mention_target_indices'])
  mention_target_start_positions = jut.matmul_slice(
      batch['mention_start_positions'], batch['mention_target_indices'])
  mention_target_end_positions = jut.matmul_slice(
      batch['mention_end_positions'], batch['mention_target_indices'])

  mention_start_final_encodings = jut.matmul_2d_index_select(
      encoding,
      (mention_target_batch_positions, mention_target_start_positions))
  mention_end_final_encodings = jut.matmul_2d_index_select(
      encoding,
      (mention_target_batch_positions, mention_target_end_positions))

  loss_helpers['target_mention_encodings'] = self.mention_projector(
      jnp.concatenate(
          (mention_start_final_encodings, mention_end_final_encodings),
          axis=-1))

  return encoding, loss_helpers, logging_helpers
def loss_fn(
    model_config: ml_collections.FrozenConfigDict,
    model_params: Dict[str, Any],
    model_vars: Dict[str, Any],  # pylint: disable=unused-argument
    batch: Dict[str, Any],
    deterministic: bool,
    dropout_rng: Optional[Dict[str, Array]] = None,
) -> Tuple[float, MetricGroups, Dict[str, Any]]:
  """Task-specific loss function. See BaseTask."""

  batch_size = batch['text_ids'].shape[0]
  loss_helpers, logging_helpers = cls.build_model(model_config).apply(  # pylint: disable=unused-variable
      {'params': model_params},
      batch,
      deterministic=deterministic,
      rngs=dropout_rng)

  mention_target_is_masked = batch['mention_target_is_masked']
  mention_target_is_not_masked = 1 - batch['mention_target_is_masked']
  mention_target_ids = batch['mention_target_ids']
  mention_target_ids = mention_target_ids * batch['mention_target_weights']

  mlm_logits = loss_helpers['mlm_logits']
  mlm_loss, mlm_denom = metric_utils.compute_weighted_cross_entropy(
      mlm_logits, batch['mlm_target_ids'], batch['mlm_target_weights'])

  mlm_correct_mask = jnp.equal(
      jnp.argmax(mlm_logits, axis=-1),
      batch['mlm_target_ids']) * batch['mlm_target_weights']
  mlm_acc = mlm_correct_mask.sum()
  mlm_mention_acc = (mlm_correct_mask *
                     batch['mlm_target_is_mention']).sum()
  mlm_mention_denom = (batch['mlm_target_weights'] *
                       batch['mlm_target_is_mention']).sum()
  mlm_non_mention_acc = (mlm_correct_mask *
                         (1 - batch['mlm_target_is_mention'])).sum()
  mlm_non_mention_denom = (batch['mlm_target_weights'] *
                           (1 - batch['mlm_target_is_mention'])).sum()

  metrics = {
      'mlm': {
          'loss': mlm_loss,
          'acc': mlm_acc,
          'denominator': mlm_denom,
      },
      'mlm_mention': {
          'acc': mlm_mention_acc,
          'denominator': mlm_mention_denom,
      },
      'mlm_non_mention': {
          'acc': mlm_non_mention_acc,
          'denominator': mlm_non_mention_denom,
      },
  }

  if 'intermediate_mention_encodings' in loss_helpers:
    intermediate_target_mention_encodings = jut.matmul_slice(
        loss_helpers['intermediate_mention_encodings'],
        batch['mention_target_indices'])
  else:
    intermediate_target_mention_encodings = loss_helpers[
        'im_target_mention_encodings']

  if model_config.encoder_config.get('no_entity_attention', False):
    (el_im_loss, el_im_metrics,
     (el_im_acc_per_mention,
      el_im_weight_per_mention)) = mention_losses.entity_linking_loss(
          intermediate_target_mention_encodings,
          loss_helpers['entity_embeddings'], mention_target_ids,
          batch['mention_target_weights'], el_score_mode)
    el_im_denom = el_im_metrics['denominator']
    metrics['el_intermediate'] = el_im_metrics
    metrics['el_intermediate_masked'] = {
        'acc':
            jnp.dot(el_im_acc_per_mention,
                    el_im_weight_per_mention * mention_target_is_masked),
        'denominator':
            jnp.dot(el_im_weight_per_mention, mention_target_is_masked),
    }
    metrics['el_intermediate_non_masked'] = {
        'acc':
            jnp.dot(el_im_acc_per_mention,
                    el_im_weight_per_mention * mention_target_is_not_masked),
        'denominator':
            jnp.dot(el_im_weight_per_mention, mention_target_is_not_masked),
    }
  else:
    intermediate_entity_attention = loss_helpers[
        'intermediate_entity_attention']

    # Construct targets and ids for intermediate entity linking loss.
    intermediate_target_ids = jnp.zeros_like(batch['mention_mask'])
    intermediate_target_ids = intermediate_target_ids.at[
        batch['mention_target_indices']].add(
            mention_target_ids * batch['mention_target_weights'])

    intermediate_target_weights = jnp.zeros_like(
        batch['mention_mask'], dtype=intermediate_entity_attention.dtype)
    intermediate_target_weights = intermediate_target_weights.at[
        batch['mention_target_indices']].add(
            batch['mention_target_weights'])

    mention_is_masked = jnp.zeros_like(batch['mention_mask'])
    mention_is_masked = mention_is_masked.at[
        batch['mention_target_indices']].add(
            mention_target_is_masked * batch['mention_target_weights'])
    el_im_loss, el_im_denom = metric_utils.compute_weighted_cross_entropy(
        intermediate_entity_attention,
        intermediate_target_ids,
        intermediate_target_weights,
        inputs_are_prob=True)

    el_im_correct_mask = jnp.equal(
        jnp.argmax(intermediate_entity_attention, axis=-1),
        intermediate_target_ids) * intermediate_target_weights
    el_im_acc, _ = metric_utils.compute_weighted_accuracy(
        intermediate_entity_attention, intermediate_target_ids,
        intermediate_target_weights)

    intermediate_entity_cos_sim = loss_helpers[
        'intermediate_entity_cos_sim'][batch['mention_target_indices'],
                                       mention_target_ids]

    metrics['el_intermediate'] = {
        'loss': el_im_loss,
        'acc': el_im_acc,
        'cos_sim': jnp.dot(intermediate_entity_cos_sim,
                           batch['mention_target_weights']),
        'denominator': el_im_denom,
    }
    metrics['el_intermediate_masked'] = {
        'acc': jnp.dot(el_im_correct_mask, mention_is_masked),
        'denominator': jnp.dot(batch['mention_target_weights'],
                               batch['mention_target_is_masked']),
    }
    metrics['el_intermediate_non_masked'] = {
        'acc': jnp.dot(el_im_correct_mask, (1 - mention_is_masked)),
        'denominator': jnp.dot(batch['mention_target_weights'],
                               (1 - batch['mention_target_is_masked'])),
    }

  im_final_mention_encodings_cos_sim = jut.cosine_similarity(
      intermediate_target_mention_encodings,
      loss_helpers['target_mention_encodings'])
  metrics['im_final_mention_encodings'] = {
      'cos_sim': jnp.dot(im_final_mention_encodings_cos_sim,
                         batch['mention_target_weights']),
      'denominator': batch['mention_target_weights'].sum(),
  }

  (el_final_loss, el_final_metrics,
   (el_final_acc_per_mention,
    el_final_weight_per_mention)) = mention_losses.entity_linking_loss(
        loss_helpers['target_mention_encodings'],
        loss_helpers['entity_embeddings'], mention_target_ids,
        batch['mention_target_weights'], el_score_mode)
  el_final_denom = el_final_metrics['denominator']
  metrics['el_final'] = el_final_metrics
  metrics['el_final_masked'] = {
      'acc':
          jnp.dot(el_final_acc_per_mention,
                  el_final_weight_per_mention * mention_target_is_masked),
      'denominator':
          jnp.dot(el_final_weight_per_mention, mention_target_is_masked),
  }
  metrics['el_final_non_masked'] = {
      'acc':
          jnp.dot(el_final_acc_per_mention,
                  el_final_weight_per_mention * mention_target_is_not_masked),
      'denominator':
          jnp.dot(el_final_weight_per_mention, mention_target_is_not_masked),
  }

  loss = mlm_weight * mlm_loss / mlm_denom
  loss += el_im_weight * el_im_loss / el_im_denom
  loss += el_final_weight * el_final_loss / el_final_denom

  if mtb_im_weight > 0:
    (mtb_im_loss, mtb_im_metrics) = mention_losses.mtb_loss(
        intermediate_target_mention_encodings,
        batch['mention_target_batch_positions'], mention_target_ids,
        batch_size, mtb_score_mode, mention_target_is_masked, 'im_')
    mtb_im_denom = mtb_im_metrics['im_mtb']['denominator']
    loss += mtb_im_weight * mtb_im_loss / mtb_im_denom
    metrics.update(mtb_im_metrics)

  if mtb_final_weight > 0:
    (mtb_final_loss, mtb_final_metrics) = mention_losses.mtb_loss(
        loss_helpers['target_mention_encodings'],
        batch['mention_target_batch_positions'], mention_target_ids,
        batch_size, mtb_score_mode, mention_target_is_masked, 'final_')
    mtb_final_denom = mtb_final_metrics['final_mtb']['denominator']
    loss += mtb_final_weight * mtb_final_loss / mtb_final_denom
    metrics.update(mtb_final_metrics)

  metrics['agg'] = {
      'loss': loss,
      'denominator': 1.0,
  }
  return loss, metrics, {}
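# Every metric group above follows the same convention: raw sums plus a
# 'denominator', with normalization deferred until metrics are aggregated
# (e.g. across devices and steps), mirroring how each loss term is divided
# by its denominator before weighting. A minimal sketch of that final
# normalization step; `normalize_metrics` is a hypothetical helper, not part
# of this codebase:
def normalize_metrics(metrics):
  """Turns {'group': {'loss': sum, 'denominator': n}} into averages."""
  return {
      group: {
          key: value / values['denominator']
          for key, value in values.items()
          if key != 'denominator'
      } for group, values in metrics.items()
  }

# For example, {'mlm': {'loss': 6.0, 'denominator': 3.0}} normalizes to
# {'mlm': {'loss': 2.0}}.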
def __call__(
    self,
    queries: Array,
    memory_keys: Array,
    memory_identifiers: Array,
    memory_entity_ids: Array,
    memory_values: Optional[Array] = None,
    text_identifiers: Optional[Array] = None,
    memory_text_entities: Optional[Array] = None,
    same_passage_memory_policy: str = 'disallow',
) -> Dict[str, Array]:
  """Perform attention update over memory table.

  Args:
    queries: [n_mentions, hidden_size] query vectors.
    memory_keys: [rows, values per row, key_dim] mention memory keys. The
      number of rows in the memory table governs the recall vs speed
      trade-off of the top-k similarity search. Search is performed by
      taking the max over each row, and then top-k between rows.
      Distributing the same values over more rows leads to higher recall
      but slower search.
    memory_identifiers: [memory_size] identifier for memory vectors.
    memory_entity_ids: [memory_size] entity ids for mentions in the memory
      table.
    memory_values: [values, memory_dim] if separate keys and values.
    text_identifiers: [n_mentions] search will not retrieve memory vectors
      with the same identifier as the passage mention.
    memory_text_entities: [n_mentions, n_memory_text_entities] entity ids
      for the passages the memories come from.
    same_passage_memory_policy: how to treat mentions from the same passage.
      Possible options: `allow`, `disallow` and `only`.

  Returns:
    Dictionary with retrieval results, including values, entity IDs,
    attention weights, etc.
  """
  _assert_array_is_integer_or_none(memory_entity_ids, 'memory_entity_ids')
  _assert_array_is_integer_or_none(memory_identifiers, 'memory_identifiers')
  _assert_array_is_integer_or_none(memory_text_entities,
                                   'memory_text_entities')
  _assert_array_is_integer_or_none(text_identifiers, 'text_identifiers')

  retrieval_result = {}

  memory_size = memory_keys.shape[0] * memory_keys.shape[1]
  memory_key_dim = memory_keys.shape[2]
  n_queries = queries.shape[0]

  # We generate a version of the queries with stop gradient to use as input
  # to the top-k similarity layer. We actually do want gradient to flow to
  # the queries, but backward differentiation over the top-k layer yields
  # inefficient HLO ops. Instead we use queries with gradient to recompute
  # attention scores later.
  queries_sg = jax.lax.stop_gradient(queries)

  # Gather queries from all devices. Each device contains a shard of the
  # mention memory. Ultimately we want to perform search over the entire
  # mention memory, so we gather mentions from all devices, apply similarity
  # search over the local shard, then distribute the results back.
  gathered_queries = jax.lax.all_gather(queries_sg, 'batch')
  if text_identifiers is not None:
    gathered_identifiers = jax.lax.all_gather(text_identifiers, 'batch')
  n_devices = gathered_queries.shape[0]
  gathered_queries = gathered_queries.reshape(n_devices * n_queries,
                                              memory_key_dim)

  # Perform top-k similarity search over queries, yielding
  #   top_keys: (n_devices * queries_per_device, k_top_device, memory_key_dim)
  #   top_ids: (n_devices * queries_per_device, k_top_device)
  top_keys, top_scores, top_ids = self.topk_similarity(
      gathered_queries, memory_keys)

  if memory_values is not None:
    top_values = memory_values[top_ids]
  else:
    top_values = top_keys
  memory_dim = top_values.shape[-1]

  # Also return entity ids.
  top_entity_ids = memory_entity_ids[top_ids]

  top_values = top_values.reshape(n_devices, n_queries, self.k_top_device,
                                  memory_dim)
  top_entity_ids = top_entity_ids.reshape(n_devices, n_queries,
                                          self.k_top_device)
  global_top_ids = top_ids.reshape(n_devices, n_queries, self.k_top_device)

  # Now that we have searched the local shard using queries from all
  # devices, we need to distribute the search results back to all devices.
  # Applying pswapaxes followed by swapaxes takes us from
  # (devices, queries per device, local shard retrievals) to
  # (local queries, devices, memory retrievals per device).
  (top_values, top_entity_ids, global_top_ids) = jax.lax.pswapaxes(
      (top_values, top_entity_ids, global_top_ids),
      axis_name='batch',
      axis=0)
  top_values = jnp.swapaxes(top_values, 0, 1)
  top_entity_ids = jnp.swapaxes(top_entity_ids, 0, 1)
  global_top_ids = jnp.swapaxes(global_top_ids, 0, 1)

  # IDs are device-specific, so we need to convert them to global memory
  # IDs. Every device operates on a memory shard of the same size, so IDs
  # on device 0 are unchanged, IDs from device 1 are offset by
  # `memory_size`, IDs from device 2 by 2 * `memory_size`, etc.
  global_top_ids = global_top_ids + jnp.arange(n_devices).reshape(
      1, -1, 1) * memory_size

  # Reshape results to (local queries, global retrievals).
  k_top = n_devices * self.k_top_device
  top_values = top_values.reshape(n_queries, k_top, memory_dim)
  top_entity_ids = top_entity_ids.reshape(n_queries, k_top)
  global_top_ids = global_top_ids.reshape(n_queries, k_top)

  # At this point, we have selected `k_top = n_devices * self.k_top_device`
  # memories for every query. The selection process is approximate since we
  # retrieve `self.k_top_device` memories from every device and then just
  # concatenate the results. Due to computational constraints we may wish
  # to limit the number of memories per query, so we sub-select even
  # further and keep only the `self.k_top_post_selection` highest-scoring
  # retrieved memories for every query.
  if self.k_top_post_selection is not None:
    top_scores = top_scores.reshape(n_devices, n_queries, self.k_top_device)
    top_scores = jax.lax.pswapaxes(top_scores, axis_name='batch', axis=0)
    top_scores = jnp.swapaxes(top_scores, 0, 1)
    top_scores = top_scores.reshape(n_queries, k_top)

    # Take k highest scores among all rows.
    # pylint: disable=invalid-unary-operand-type
    top_post_selection_index = jnp.argsort(
        top_scores, axis=-1)[:, :-self.k_top_post_selection - 1:-1]
    # pylint: enable=invalid-unary-operand-type
    top_values = jut.matmul_slice(top_values, top_post_selection_index)
    top_entity_ids = jut.matmul_slice(top_entity_ids,
                                      top_post_selection_index)
    global_top_ids = jut.matmul_slice(global_top_ids,
                                      top_post_selection_index)

  # If we use separate memory values, distribute the keys back as well.
  if memory_values is not None:
    top_keys = top_keys.reshape(n_devices, n_queries, self.k_top_device,
                                memory_key_dim)
    top_keys = jax.lax.pswapaxes(top_keys, axis_name='batch', axis=0)
    top_keys = jnp.swapaxes(top_keys, 0, 1)
    top_keys = top_keys.reshape(n_queries, k_top, memory_key_dim)
    if self.k_top_post_selection is not None:
      top_keys = jut.matmul_slice(top_keys, top_post_selection_index)
  else:
    top_keys = top_values

  retrieval_result['top_entity_ids'] = top_entity_ids
  retrieval_result['top_memory_ids'] = global_top_ids
  retrieval_result['top_values'] = top_values

  # We re-compute top scores using the queries with gradient (wg) to make
  # sure the mention encoder and the rest of the model receive gradient.
  top_scores_wg = jnp.einsum('qd,qkd->qk', queries, top_keys)
  retrieval_result['memory_attention_scores_with_disallowed'] = top_scores_wg

  # We want to disallow some mentions from being retrieved (i.e. from the
  # same passage during pre-training). Here we mask retrieved mentions
  # which have the same identifier as the query.
  if text_identifiers is not None:
    top_ids = top_ids.reshape(n_devices, n_queries, self.k_top_device)
    gathered_identifiers = gathered_identifiers.reshape(
        n_devices, n_queries, 1)
    identifier_mask = (memory_identifiers[top_ids] == gathered_identifiers)

    # We manually cast `identifier_mask` to int32. Otherwise, `pswapaxes`,
    # which is known to have undefined behavior on CPU, "corrupts" the
    # vector, making it effectively int32 while keeping the boolean dtype.
    # This in turn leads to a compilation error for the einsum operation in
    # `matmul_slice` (types mismatch).
    identifier_mask = identifier_mask.astype(dtype=jnp.int32)
    identifier_mask = jax.lax.pswapaxes(identifier_mask,
                                        axis_name='batch',
                                        axis=0)
    identifier_mask = jnp.swapaxes(identifier_mask, 0, 1)
    identifier_mask = identifier_mask.reshape(n_queries, k_top)
    if self.k_top_post_selection is not None:
      identifier_mask = jut.matmul_slice(identifier_mask,
                                         top_post_selection_index)

    retrieval_result['memory_attention_disallowed_mask'] = (
        identifier_mask.astype(jnp.bool_))
    identifier_mask = identifier_mask.astype(top_scores_wg.dtype)

    # Depending on `same_passage_memory_policy`, we treat memories that
    # come from the same passage as the query mention differently.
    if same_passage_memory_policy == 'disallow':
      top_scores_wg = (
          top_scores_wg - identifier_mask * default_values.LARGE_NUMBER)
    elif same_passage_memory_policy == 'only':
      top_scores_wg = top_scores_wg - (
          1.0 - identifier_mask) * default_values.LARGE_NUMBER
    elif same_passage_memory_policy == 'allow':
      pass
    else:
      raise ValueError('Unknown value for `same_passage_memory_policy`: %s' %
                       same_passage_memory_policy)
    n_disallowed = identifier_mask.sum()
    retrieval_result['n_disallowed'] = n_disallowed

  if memory_text_entities is not None:
    top_ids = top_ids.reshape(n_devices, n_queries, self.k_top_device)
    # shape [n_devices, n_queries, k_top_device, n_text_entities_per_passage]
    top_text_entities = memory_text_entities[top_ids]
    top_text_entities = jax.lax.pswapaxes(top_text_entities,
                                          axis_name='batch',
                                          axis=0)
    # shape [n_queries, n_devices, k_top_device, n_text_entities_per_passage]
    top_text_entities = jnp.swapaxes(top_text_entities, 0, 1)
    # shape [n_queries, n_devices * k_top_device, n_text_entities_per_passage]
    top_text_entities = top_text_entities.reshape(n_queries, k_top, -1)
    if self.k_top_post_selection is not None:
      top_text_entities = jut.matmul_slice(top_text_entities,
                                           top_post_selection_index)
    retrieval_result['memory_top_text_entities'] = top_text_entities

  # We perform dot-product attention using the retrieved memory vectors as
  # keys, a dense projection of the retrieved vectors as values, and the
  # mention representations as queries.
  attention_weights = nn.softmax(top_scores_wg, axis=-1)
  retrieval_result['memory_attention_weights'] = attention_weights

  return retrieval_result
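# A toy version of the cross-device choreography above: all_gather the
# queries, search the local memory shard, then pswapaxes the results so each
# device ends up with retrievals for its own queries from every shard. This
# is a sketch under assumed tiny sizes and scores-only retrieval; it mirrors
# the communication pattern, not the full layer.
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
queries_per_device, k_top_device, dim = 2, 1, 4


def retrieve(queries, memory_shard):
  # queries: [queries_per_device, dim]; memory_shard: [shard_size, dim].
  gathered = jax.lax.all_gather(queries, 'batch')  # [devices, q, dim]
  gathered = gathered.reshape(-1, dim)
  scores = gathered @ memory_shard.T  # scores against the local shard only
  top_scores, _ = jax.lax.top_k(scores, k_top_device)
  top_scores = top_scores.reshape(n_devices, queries_per_device,
                                  k_top_device)
  # Swap the leading "whose queries" axis with the device axis: afterwards
  # axis 0 indexes the shard each result came from, and the device holds
  # only its own queries.
  top_scores = jax.lax.pswapaxes(top_scores, axis_name='batch', axis=0)
  top_scores = jnp.swapaxes(top_scores, 0, 1)  # [q, devices, k_top_device]
  return top_scores.reshape(queries_per_device, n_devices * k_top_device)


queries = jnp.ones((n_devices, queries_per_device, dim))
memory = jnp.ones((n_devices, 3, dim))  # one memory shard per device
out = jax.pmap(retrieve, axis_name='batch')(queries, memory)
assert out.shape == (n_devices, queries_per_device,
                     n_devices * k_top_device)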