Example #1
    def __call__(
        self,
        encoded_input: Array,
        retrieval_values: Array,
        retrieval_scores: Array,
        mention_batch_positions: Array,
        mention_start_positions: Array,
        mention_end_positions: Array,
        mention_mask: Array,
        deterministic: bool,
    ) -> Array:

        # Generate mention values from input representation
        mention_start_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_start_positions))
        mention_end_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_end_positions))

        passage_mention_values = self.value_projection(
            jnp.concatenate((mention_start_encodings, mention_end_encodings),
                            axis=-1))
        k_retrieval = retrieval_scores.shape[-1]
        passage_mention_values = jnp.expand_dims(passage_mention_values,
                                                 axis=1)
        passage_mention_values = jnp.tile(passage_mention_values,
                                          (1, k_retrieval, 1))

        # Generate concatenated values of shape [mentions, k, 2 * retrieval_dim]
        concat_values = jnp.concatenate(
            (passage_mention_values, retrieval_values), axis=-1)

        # MLP over the concatenation of the mention value and each retrieved value
        concat_values = nn.gelu(self.concat_mlp(concat_values))
        concat_values = self.concat_dense(concat_values)
        concat_values = self.concat_dropout(concat_values, deterministic)

        # Additional MLP layers
        for concat_mlp_layer in self.additional_concat_mlp_layers:
            concat_values = concat_mlp_layer(concat_values, deterministic)

        pooled_values = jnp.einsum('qk,qkd->qd', retrieval_scores,
                                   concat_values)

        # MLP layers applied to pooled retrieval values
        for pooled_mlp_layer in self.pooled_mlp_layers:
            pooled_values = pooled_mlp_layer(pooled_values, deterministic)
        pooled_values = pooled_values * mention_mask.reshape(-1, 1)

        encoded_input = jut.matmul_2d_index_add(
            encoded_input, (mention_batch_positions, mention_start_positions),
            pooled_values)

        encoded_input = self.layer_norm(encoded_input)

        return encoded_input
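
The update above relies on two helpers from jut: matmul_2d_index_select gathers the start/end token vectors of each mention, and matmul_2d_index_add scatters the pooled values back into the input. A minimal sketch of those semantics (the select behaviour matches the fancy-indexing test in Example #3 below; the add counterpart is assumed to behave like a scatter-add):

import jax.numpy as jnp

# Toy input: batch of 2 sequences, 4 tokens each, hidden size 3.
encoded = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
batch_pos = jnp.array([0, 1])   # sample index of each mention
start_pos = jnp.array([1, 2])   # start token of each mention

# jut.matmul_2d_index_select(encoded, (batch_pos, start_pos)) is equivalent to
selected = encoded[batch_pos, start_pos]             # [n_mentions, hidden_size]

# jut.matmul_2d_index_add(encoded, (batch_pos, start_pos), updates) is assumed
# to be the scatter-add counterpart:
updates = jnp.ones_like(selected)
updated = encoded.at[batch_pos, start_pos].add(updates)
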
Example #2
    def __call__(
        self,
        encoding: Array,
        mention_batch_positions: Array,
        mention_start_positions: Array,
        mention_end_positions: Array,
        mention_mask: Array,
        mention_entity_ids: Array,
    ) -> Dict[str, Array]:
        """Generates a memory table from mention encodings gathered across devices.

    Args:
      encoding: [batch_size, n_tokens, hidden_size].
      mention_batch_positions: [n_mentions].
      mention_start_positions: [n_mentions].
      mention_end_positions: [n_mentions].
      mention_mask: [n_mentions].
      mention_entity_ids: [n_mentions].

    Returns:
      Dict containing cross-device memory keys, values, mask and entity ids, as
      well as the local (per-device) memory keys and values.
    """
        mention_start_encodings = jut.matmul_2d_index_select(
            encoding, (mention_batch_positions, mention_start_positions))
        mention_end_encodings = jut.matmul_2d_index_select(
            encoding, (mention_batch_positions, mention_end_positions))
        projection_input = jnp.concatenate(
            (mention_start_encodings, mention_end_encodings), axis=-1)
        n_mentions = projection_input.shape[0]
        local_memory_keys = self.key_projector(projection_input)
        local_memory_values = self.value_projector(projection_input)

        memory_keys = jax.lax.all_gather(local_memory_keys, 'batch')
        memory_values = jax.lax.all_gather(local_memory_values, 'batch')
        memory_mask = jax.lax.all_gather(mention_mask, 'batch')
        memory_entity_ids = jax.lax.all_gather(mention_entity_ids, 'batch')
        n_devices = memory_keys.shape[0]

        memory_keys = memory_keys.reshape(n_devices * n_mentions,
                                          self.memory_key_dim)
        memory_values = memory_values.reshape(n_devices * n_mentions,
                                              self.memory_value_dim)
        memory_mask = memory_mask.reshape(n_devices * n_mentions)
        memory_entity_ids = memory_entity_ids.reshape(n_devices * n_mentions)

        return_dict = {
            'memory_keys': memory_keys,
            'memory_values': memory_values,
            'memory_mask': memory_mask,
            'memory_entity_ids': memory_entity_ids,
            'local_memory_keys': local_memory_keys,
            'local_memory_values': local_memory_values,
        }

        return return_dict
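
This layer assumes it runs inside a pmap (or equivalent) with an axis named 'batch', since jax.lax.all_gather stacks the per-device keys, values, masks and entity ids along a new leading device axis before they are flattened into a single memory table. A stand-alone sketch of that gather, with made-up dimensions:

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
n_mentions, key_dim = 3, 4
local_keys = jnp.ones((n_devices, n_mentions, key_dim))  # one slice per device

def build_memory(local_memory_keys):
    # After all_gather every device sees [n_devices, n_mentions, key_dim].
    memory_keys = jax.lax.all_gather(local_memory_keys, 'batch')
    # Flatten into a single memory table, as in the layer above.
    return memory_keys.reshape(-1, key_dim)

memory_keys = jax.pmap(build_memory, axis_name='batch')(local_keys)
# memory_keys.shape == (n_devices, n_devices * n_mentions, key_dim)
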
Example #3
 def test_matmul_2d_index_select(self, dim1, dim2, dim3, n_index):
     shape = [dim1, dim2]
     if dim3 is not None:
         shape.append(dim3)
     array = np.random.randint(_MAX_INT_VALUE, size=shape)
     indices_1 = np.random.randint(dim1, size=(n_index))
     indices_2 = np.random.randint(dim2, size=(n_index))
     actual = jut.matmul_2d_index_select(array, (indices_1, indices_2))
     self.assertTrue(jnp.array_equal(actual, array[indices_1, indices_2]))
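
The test confirms that matmul_2d_index_select matches plain fancy indexing. The "matmul" in the name suggests the gather is expressed as one-hot matrix multiplications, which tend to be friendlier to TPUs than gather ops; a hedged sketch of that equivalence (not necessarily the actual implementation in jut):

import jax
import jax.numpy as jnp

def one_hot_2d_index_select(array, indices):
    # Gather array[idx1, idx2] via one-hot matmuls instead of a gather op.
    idx1, idx2 = indices
    onehot1 = jax.nn.one_hot(idx1, array.shape[0], dtype=array.dtype)  # [n, dim1]
    onehot2 = jax.nn.one_hot(idx2, array.shape[1], dtype=array.dtype)  # [n, dim2]
    rows = jnp.einsum('nd,d...->n...', onehot1, array)    # select along dim1
    return jnp.einsum('nd,nd...->n...', onehot2, rows)    # select along dim2
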
Example #4
    def forward(
        self,
        batch: Dict[str, Array],
        deterministic: bool,
    ) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:

        loss_helpers = {}
        logging_helpers = {}

        embedded_input = self.embedder({
            'token_ids': batch['text_ids'],
            'position_ids': batch['position_ids'],
            'segment_ids': batch['segment_ids']
        })

        embedded_input = self.embeddings_layer_norm(embedded_input)
        embedded_input = self.embeddings_dropout(embedded_input,
                                                 deterministic=deterministic)

        loss_helpers['word_embeddings'] = self.embedder.variables['params'][
            'embedders_token_ids']['embedding']

        attention_mask = batch['text_mask']
        encoding = self.encoder(encoding=embedded_input,
                                attention_mask=attention_mask,
                                deterministic=deterministic)

        if 'mention_target_batch_positions' in batch:
            mention_start_encodings = jut.matmul_2d_index_select(
                encoding, (batch['mention_target_batch_positions'],
                           batch['mention_target_start_positions']))
            mention_end_encodings = jut.matmul_2d_index_select(
                encoding, (batch['mention_target_batch_positions'],
                           batch['mention_target_end_positions']))
            loss_helpers['target_mention_encodings'] = self.mention_projector(
                jnp.concatenate(
                    (mention_start_encodings, mention_end_encodings), axis=-1))

        return encoding, loss_helpers, logging_helpers
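
The target-mention branch above follows a pattern used throughout these examples: gather the start and end token encodings of each mention, concatenate them, and project the result. A self-contained Flax sketch of that span-projection pattern, using a plain Dense layer as a stand-in for self.mention_projector:

import jax
import jax.numpy as jnp
import flax.linen as nn

class SpanProjector(nn.Module):
    """Concatenates start/end token vectors of each mention and projects them."""
    dim: int

    @nn.compact
    def __call__(self, encoding, batch_pos, start_pos, end_pos):
        start_vecs = encoding[batch_pos, start_pos]   # [n_mentions, hidden_size]
        end_vecs = encoding[batch_pos, end_pos]       # [n_mentions, hidden_size]
        return nn.Dense(self.dim)(
            jnp.concatenate((start_vecs, end_vecs), axis=-1))

encoding = jnp.ones((2, 8, 16))                       # [batch, tokens, hidden]
batch_pos = jnp.array([0, 1])
start_pos = jnp.array([1, 2])
end_pos = jnp.array([3, 5])
model = SpanProjector(dim=4)
params = model.init(jax.random.PRNGKey(0), encoding, batch_pos, start_pos, end_pos)
mention_encodings = model.apply(params, encoding, batch_pos, start_pos, end_pos)
# mention_encodings.shape == (2, 4)
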
Example #5
    def __call__(
        self,
        encoded_input: Array,
        mention_batch_positions: Array,
        mention_start_positions: Array,
        mention_end_positions: Array,
        mention_mask: Array,
        memory_keys: Array,
        memory_identifiers: Array,
        memory_entity_ids: Array,
        deterministic: bool,
        memory_values: Optional[Array] = None,
        text_identifiers: Optional[Array] = None,
        memory_text_entities: Optional[Array] = None,
        same_passage_memory_policy: str = 'disallow',
    ) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
        """Perform attention update over memory table.

    Args:
      encoded_input: [batch_size, n_tokens, hidden_size] input representation.
      mention_batch_positions: [n_mentions] mention sample position in batch.
      mention_start_positions: [n_mentions] mention start position in input.
      mention_end_positions: [n_mentions] mention end position in input.
      mention_mask: [n_mentions] attention mask to prevent updates from padding.
      memory_keys: [rows, values per row, key_dim] mention memory keys. The
        number of rows in the memory table governs the recall vs speed of the
        topk similarity search. Search is performed by taking max over each row,
        and then top-k between rows. Distributing the same values over more rows
        leads to higher recall but slower search.
      memory_identifiers: [memory_size] identifier for memory vectors.
      memory_entity_ids: [memory_size] entity ids for mentions in memory table.
      deterministic: don't apply dropout if true.
      memory_values: [values, memory_dim] if keys and values are stored
        separately.
      text_identifiers: [n_mentions] if provided, search will not retrieve
        memory vectors with the same identifier as the passage mention.
      memory_text_entities: [n_mentions, n_memory_text_entities] entity ids for
        the passages the memories come from.
      same_passage_memory_policy: how to treat mentions from the same passage.
        Possible options: `allow`, `disallow` and `only`.

    Returns:
      Updated input, loss and logging helper dicts.
    """
        _assert_array_is_integer_or_none(mention_batch_positions,
                                         'mention_batch_positions')
        _assert_array_is_integer_or_none(mention_start_positions,
                                         'mention_start_positions')
        _assert_array_is_integer_or_none(mention_end_positions,
                                         'mention_end_positions')
        _assert_array_is_integer_or_none(memory_entity_ids,
                                         'memory_entity_ids')
        _assert_array_is_integer_or_none(memory_identifiers,
                                         'memory_identifiers')
        _assert_array_is_integer_or_none(memory_text_entities,
                                         'memory_text_entities')
        _assert_array_is_integer_or_none(text_identifiers, 'text_identifiers')

        loss_helpers, logging_helpers = {}, {}

        # We generate mention representations to use as queries for similarity
        # search by concatenating start and end tokens for each mention and
        # projecting the concatenation with a dense layer.
        mention_start_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_start_positions))
        mention_end_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_end_positions))

        queries = self.query_projector(
            jnp.concatenate((mention_start_encodings, mention_end_encodings),
                            axis=-1))

        loss_helpers['memory_attention_mention_encodings'] = queries

        retrieval_result = self.memory_retrieval_layer(
            queries=queries,
            memory_keys=memory_keys,
            memory_identifiers=memory_identifiers,
            memory_entity_ids=memory_entity_ids,
            memory_values=memory_values,
            text_identifiers=text_identifiers,
            memory_text_entities=memory_text_entities,
            same_passage_memory_policy=same_passage_memory_policy,
        )

        # Most of the information from retrieval_result goes to `loss_helpers`,
        # except for `n_disallowed`. In the future, we might merge these two
        # into a single dictionary.
        loss_helpers.update(retrieval_result)
        if 'n_disallowed' in retrieval_result:
            logging_helpers['n_disallowed'] = retrieval_result['n_disallowed']

        encoded_input = self.update_layer(
            encoded_input=encoded_input,
            retrieval_values=retrieval_result['top_values'],
            retrieval_scores=retrieval_result['memory_attention_weights'],
            mention_batch_positions=mention_batch_positions,
            mention_start_positions=mention_start_positions,
            mention_end_positions=mention_end_positions,
            mention_mask=mention_mask,
            deterministic=deterministic,
        )

        return encoded_input, loss_helpers, logging_helpers
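
_assert_array_is_integer_or_none is not shown in this example; a plausible minimal implementation, consistent with how it is called above (the exact behaviour is an assumption):

import jax.numpy as jnp

def _assert_array_is_integer_or_none(array, array_name):
    # Position and id arrays are used as gather indices, so they must be
    # integer-typed; None is allowed for optional inputs.
    if array is not None and not jnp.issubdtype(array.dtype, jnp.integer):
        raise ValueError(
            f'Array {array_name} is expected to have integer dtype, '
            f'got {array.dtype}.')
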
Example #6
    def __call__(
        self,
        encoded_input: Array,
        mention_batch_positions: Array,
        mention_start_positions: Array,
        mention_end_positions: Array,
        mention_mask: Array,
        entity_embeddings: Array,
    ) -> Dict[str, Array]:
        """Perform attention update over entity embedding table.

    Args:
      encoded_input: [batch_size, n_tokens, hidden_size].
      mention_batch_positions: [n_mentions].
      mention_start_positions: [n_mentions].
      mention_end_positions: [n_mentions].
      mention_mask: [n_mentions] attention mask to prevent updates from padding
        mentions.
      entity_embeddings: entity embedding table.

    Returns:
      Updated input, mention encodings and entity attention scores.
    """

        mention_start_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_start_positions))
        mention_end_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_end_positions))
        mention_encodings = self.mention_query_projector(
            jnp.concatenate((mention_start_encodings, mention_end_encodings),
                            axis=-1))

        scores = jnp.einsum('qd,ed->qe', mention_encodings, entity_embeddings)
        attention_weights = nn.softmax(scores, axis=-1)

        retrieved_values = jnp.einsum('qe,ed->qd', attention_weights,
                                      entity_embeddings)
        retrieved_values = self.entity_projector(retrieved_values)
        retrieved_values = retrieved_values * jnp.expand_dims(mention_mask, -1)

        encoded_input = jut.matmul_2d_index_add(
            encoded_input, (mention_batch_positions, mention_start_positions),
            retrieved_values)
        encoded_input = self.layer_norm(encoded_input)

        # The cosine similarity is computed as the dot product divided by the
        # norms of both vectors.
        mention_encodings_norm = jnp.linalg.norm(mention_encodings, axis=-1)
        entity_embeddings_norm = jnp.linalg.norm(entity_embeddings, axis=-1)
        cos_scores = scores
        cos_scores = cos_scores / (_SMALL_NUMBER +
                                   jnp.expand_dims(mention_encodings_norm, 1))
        cos_scores = cos_scores / (_SMALL_NUMBER +
                                   jnp.expand_dims(entity_embeddings_norm, 0))

        return {
            'encoded_output': encoded_input,
            'mention_encodings': mention_encodings,
            'cosine_similarity': cos_scores,
            'attention_weights': attention_weights,
        }
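
The cosine-similarity block above normalizes the raw dot-product scores by both vector norms, with _SMALL_NUMBER guarding against division by zero. A compact equivalent, assuming _SMALL_NUMBER is a tiny constant such as 1e-8:

import jax.numpy as jnp

_SMALL_NUMBER = 1e-8  # assumed value of the constant used above

def cosine_similarity_scores(mention_encodings, entity_embeddings):
    # [n_mentions, n_entities] raw dot products.
    scores = jnp.einsum('qd,ed->qe', mention_encodings, entity_embeddings)
    q_norm = jnp.linalg.norm(mention_encodings, axis=-1, keepdims=True)  # [q, 1]
    e_norm = jnp.linalg.norm(entity_embeddings, axis=-1)                 # [e]
    return scores / (q_norm + _SMALL_NUMBER) / (e_norm[None, :] + _SMALL_NUMBER)
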
Example #7
  def forward(
      self, batch: Dict[str, Array],
      deterministic: bool) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
    loss_helpers = {}
    logging_helpers = {}

    embedded_input = self.embedder({
        'token_ids': batch['text_ids'],
        'position_ids': batch['position_ids'],
        'segment_ids': batch['segment_ids']
    })

    embedded_input = self.embeddings_layer_norm(embedded_input)
    embedded_input = self.embeddings_dropout(
        embedded_input, deterministic=deterministic)

    loss_helpers['word_embeddings'] = self.embedder.variables['params'][
        'embedders_token_ids']['embedding']

    attention_mask = batch['text_mask']

    def apply_encoder_block(encoder_block, encoding):
      return encoder_block(
          encoding=encoding,
          attention_mask=attention_mask,
          deterministic=deterministic,
      )

    # First read
    initial_encoding_first = apply_encoder_block(self.initial_encoder,
                                                 embedded_input)
    if self.num_intermediate_layers is not None:
      encoding_first = apply_encoder_block(self.intermediate_encoder,
                                           initial_encoding_first)
    else:
      encoding_first = initial_encoding_first
    encoding_first = apply_encoder_block(self.final_encoder, encoding_first)

    loss_helpers['final_encoding_first'] = encoding_first

    mention_batch_positions = batch['mention_batch_positions']
    mention_target_batch_positions = batch['mention_target_batch_positions']

    # Generate memory table
    if self.extract_unlinked_mentions:
      local_memory_entity_ids = jnp.zeros(
          dtype=jnp.int32, shape=batch['mention_mask'].shape[0])
      local_memory_entity_ids = local_memory_entity_ids.at[
          batch['mention_target_indices']].set(batch['mention_target_ids'])
      extracted_memory_dict = self.memory_extraction_layer(
          encoding=encoding_first,
          mention_batch_positions=mention_batch_positions,
          mention_start_positions=batch['mention_start_positions'],
          mention_end_positions=batch['mention_end_positions'],
          mention_mask=batch['mention_mask'],
          mention_entity_ids=local_memory_entity_ids,
      )
    else:
      extracted_memory_dict = self.memory_extraction_layer(
          encoding=encoding_first,
          mention_batch_positions=mention_target_batch_positions,
          mention_start_positions=batch['mention_target_start_positions'],
          mention_end_positions=batch['mention_target_end_positions'],
          mention_mask=batch['mention_target_weights'],
          mention_entity_ids=batch['mention_target_ids'],
      )
    memory_keys = extracted_memory_dict['memory_keys']
    memory_values = extracted_memory_dict['memory_values']
    memory_mask = extracted_memory_dict['memory_mask']
    memory_entity_ids = extracted_memory_dict['memory_entity_ids']
    local_memory_keys = extracted_memory_dict['local_memory_keys']
    local_memory_values = extracted_memory_dict['local_memory_values']

    loss_helpers['memory_keys'] = local_memory_keys
    loss_helpers['memory_values'] = local_memory_values

    if self.same_passage_retrieval_policy == 'only':
      memory_keys = local_memory_keys
      memory_values = local_memory_values
      memory_mask = mention_utils.all_compare(mention_batch_positions,
                                              mention_target_batch_positions)
      memory_entity_ids = batch['mention_target_ids']
    elif self.same_passage_retrieval_policy == 'disallow':
      # Transform local batch positions to global batch positions
      bsz = encoding_first.shape[0]
      (global_mention_batch_positions,
       _) = mention_utils.get_globally_consistent_batch_positions(
           mention_batch_positions, bsz)

      (_, all_global_mention_target_batch_positions
      ) = mention_utils.get_globally_consistent_batch_positions(
          mention_target_batch_positions, bsz)

      # Set mask to true only for mentions from different samples
      memory_mask = mention_utils.all_compare(
          global_mention_batch_positions,
          all_global_mention_target_batch_positions)
    elif self.same_passage_retrieval_policy != 'allow':
      raise ValueError('Unknown value for same_passage_retrieval_policy: %s' %
                       self.same_passage_retrieval_policy)

    mention_mask = batch['mention_mask']
    if self.no_retrieval_for_masked_mentions:
      mention_mask = mention_mask * (1 - batch['mention_is_masked'])

    def apply_memory_layer(attention_layer, encoding, prefix=''):
      encoding, mem_loss_helpers, mem_logging_helpers = attention_layer(
          encoding=encoding,
          mention_batch_positions=mention_batch_positions,
          mention_start_positions=batch['mention_start_positions'],
          mention_end_positions=batch['mention_end_positions'],
          mention_mask=mention_mask,
          memory_keys=memory_keys,
          memory_values=memory_values,
          memory_mask=memory_mask,
          memory_entity_ids=memory_entity_ids,
          deterministic=deterministic,
      )
      loss_helpers.update(
          {prefix + key: value for key, value in mem_loss_helpers.items()})
      logging_helpers.update(
          {prefix + key: value for key, value in mem_logging_helpers.items()})
      return encoding

    # If second read shares initial encoder, just reuse initial representation
    if self.shared_initial_encoder:
      encoding_second = initial_encoding_first
    else:
      encoding_second = apply_encoder_block(self.second_initial_encoder,
                                            embedded_input)

    if self.num_intermediate_layers is not None:
      if not self.no_retrieval:
        encoding_second = apply_memory_layer(
            attention_layer=self.intermediate_memory_attention_layer,
            encoding=encoding_second)
      encoding_second = apply_encoder_block(self.second_intermediate_encoder,
                                            encoding_second)
      if not self.no_retrieval:
        encoding_second = apply_memory_layer(
            attention_layer=self.final_memory_attention_layer,
            encoding=encoding_second,
            prefix='second_')
    else:
      if not self.no_retrieval:
        encoding_second = apply_memory_layer(
            attention_layer=self.memory_attention_layer,
            encoding=encoding_second)
    encoding_second = apply_encoder_block(self.second_final_encoder,
                                          encoding_second)

    if 'mention_target_batch_positions' in batch:
      mention_start_final_encodings = jut.matmul_2d_index_select(
          encoding_second, (batch['mention_target_batch_positions'],
                            batch['mention_target_start_positions']))
      mention_end_final_encodings = jut.matmul_2d_index_select(
          encoding_second, (batch['mention_target_batch_positions'],
                            batch['mention_target_end_positions']))
      loss_helpers['target_mention_encodings'] = self.mention_projector(
          jnp.concatenate(
              (mention_start_final_encodings, mention_end_final_encodings),
              axis=-1))

    return encoding_second, loss_helpers, logging_helpers
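
In the 'disallow' branch above, memory_mask marks which memories each mention may attend to, so that memories extracted from the same (globally indexed) passage are excluded. The mention_utils helpers also handle cross-device index offsets; below is a simplified single-device sketch of that kind of pairwise mask, under the assumption that all_compare produces an inequality matrix:

import jax.numpy as jnp

def different_sample_mask(query_batch_positions, memory_batch_positions):
    # [n_queries, n_memories]; True where the mention and the memory come from
    # different samples, so same-passage retrieval is disallowed.
    return query_batch_positions[:, None] != memory_batch_positions[None, :]

mask = different_sample_mask(jnp.array([0, 0, 1]), jnp.array([0, 1, 1, 2]))
# mask.shape == (3, 4)
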
Example #8
    def forward(
        self,
        batch: Dict[str, Array],
        deterministic: bool,
    ) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
        loss_helpers = {}
        logging_helpers = {}

        embedded_input = self.embedder({
            'token_ids': batch['text_ids'],
            'position_ids': batch['position_ids'],
            'segment_ids': batch['segment_ids']
        })

        embedded_input = self.embeddings_layer_norm(embedded_input)
        embedded_input = self.embeddings_dropout(embedded_input,
                                                 deterministic=deterministic)

        loss_helpers['word_embeddings'] = self.embedder.variables['params'][
            'embedders_token_ids']['embedding']

        attention_mask = batch['text_mask']
        encoding = self.initial_encoder(encoding=embedded_input,
                                        attention_mask=attention_mask,
                                        deterministic=deterministic)

        memory_values = jnp.asarray(
            self.memory_values.value,
            dtype=self.dtype) if self.separate_memory_values else None
        memory_keys = jnp.asarray(self.memory_keys.value, dtype=self.dtype)
        memory_entity_ids = self.memory_entity_ids.value
        memory_identifiers = self.memory_identifiers.value

        loss_helpers['memory_values'] = memory_values
        loss_helpers['memory_keys'] = memory_keys
        loss_helpers['memory_entity_ids'] = memory_entity_ids
        loss_helpers['memory_identifiers'] = memory_identifiers

        def apply_memory_attention(memory_layer, encoding, prefix=''):
            encoding, mem_loss_helpers, mem_logging_helpers = memory_layer(
                encoded_input=encoding,
                mention_batch_positions=batch['mention_batch_positions'],
                mention_start_positions=batch['mention_start_positions'],
                mention_end_positions=batch['mention_end_positions'],
                mention_mask=batch['mention_mask'],
                memory_keys=memory_keys,
                memory_identifiers=memory_identifiers,
                memory_entity_ids=memory_entity_ids,
                deterministic=deterministic,
                memory_values=memory_values,
                text_identifiers=batch.get('text_identifiers', None),
                memory_text_entities=(self.memory_text_entities.value
                                      if self.memory_text_entities is not None
                                      else None),
                same_passage_memory_policy=self.same_passage_memory_policy,
            )
            loss_helpers.update({
                prefix + key: value
                for key, value in mem_loss_helpers.items()
            })
            logging_helpers.update({
                prefix + key: value
                for key, value in mem_logging_helpers.items()
            })
            return encoding

        if self.num_intermediate_layers is None:
            encoding = apply_memory_attention(self.memory_attention_layer,
                                              encoding)
        else:
            encoding = apply_memory_attention(
                self.intermediate_memory_attention_layer, encoding)
            encoding = self.intermediate_encoder(encoding=encoding,
                                                 attention_mask=attention_mask,
                                                 deterministic=deterministic)
            encoding = apply_memory_attention(
                self.final_memory_attention_layer, encoding, 'second_')
        encoding = self.final_encoder(encoding=encoding,
                                      attention_mask=attention_mask,
                                      deterministic=deterministic)

        if 'mention_target_batch_positions' in batch:
            mention_start_final_encodings = jut.matmul_2d_index_select(
                encoding, (batch['mention_target_batch_positions'],
                           batch['mention_target_start_positions']))
            mention_end_final_encodings = jut.matmul_2d_index_select(
                encoding, (batch['mention_target_batch_positions'],
                           batch['mention_target_end_positions']))

            loss_helpers[
                'intermediate_target_mention_encodings'] = jut.matmul_slice(
                    loss_helpers['memory_attention_mention_encodings'],
                    batch['mention_target_indices'])
            if self.num_intermediate_layers is not None:
                loss_helpers[
                    'second_intermediate_target_mention_encodings'] = jut.matmul_slice(
                        loss_helpers[
                            'second_memory_attention_mention_encodings'],
                        batch['mention_target_indices'])

            loss_helpers['target_mention_encodings'] = self.mention_projector(
                jnp.concatenate((mention_start_final_encodings,
                                 mention_end_final_encodings),
                                axis=-1))

            # Final retrieval layer is only applied over target mentions.
            if self.apply_final_retrieval:
                queries = self.final_query_projector(
                    loss_helpers['target_mention_encodings'])

                retrieval_result = self.final_memory_retrieval_layer(
                    queries=queries,
                    memory_keys=memory_keys,
                    memory_identifiers=memory_identifiers,
                    memory_entity_ids=memory_entity_ids,
                    memory_values=memory_values,
                    text_identifiers=None,
                    memory_text_entities=None,
                    same_passage_memory_policy='disallow',
                )

                loss_helpers.update(
                    {'final_' + k: v
                     for k, v in retrieval_result.items()})

        return encoding, loss_helpers, logging_helpers
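
jut.matmul_slice is used above to pick out the rows of a per-mention array at the target-mention indices. A minimal sketch of that row-selection semantics, assuming it is equivalent to fancy indexing along the first axis (in the spirit of the matmul_2d_index_select test in Example #3):

import jax.numpy as jnp

mention_encodings = jnp.arange(5 * 3, dtype=jnp.float32).reshape(5, 3)
mention_target_indices = jnp.array([0, 2, 4])

# jut.matmul_slice(mention_encodings, mention_target_indices) is assumed to equal
target_mention_encodings = mention_encodings[mention_target_indices]  # [3, 3]
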
Example #9
    def forward(self, batch: Dict[str, Array], deterministic: bool):
        loss_helpers = {}
        logging_helpers = {}

        embedded_input = self.embedder({
            'token_ids': batch['text_ids'],
            'position_ids': batch['position_ids'],
            'segment_ids': batch['segment_ids']
        })

        embedded_input = self.embeddings_layer_norm(embedded_input)
        embedded_input = self.embeddings_dropout(embedded_input, deterministic)

        loss_helpers['word_embeddings'] = self.embedder.variables['params'][
            'embedders_token_ids']['embedding']

        attention_mask = batch['text_mask']
        encoding = self.initial_encoder(encoding=embedded_input,
                                        attention_mask=attention_mask,
                                        deterministic=deterministic)

        if not self.no_retrieval:
            encoding = self.retrieval_update_layer(
                encoded_input=encoding,
                retrieval_values=jnp.expand_dims(
                    # [max_retrieval_indices, retrieval_dim]
                    batch['retrieval_mention_values'],
                    -2),
                retrieval_scores=jnp.expand_dims(
                    # [max_retrieval_indices]
                    batch['retrieval_mention_scores'],
                    -1),
                mention_batch_positions=batch[
                    'retrieval_mention_batch_positions'],
                mention_start_positions=batch[
                    'retrieval_mention_start_positions'],
                mention_end_positions=batch['retrieval_mention_end_positions'],
                mention_mask=batch['retrieval_mention_mask'],
                deterministic=deterministic)

        encoding = self.final_encoder(encoding=encoding,
                                      attention_mask=attention_mask,
                                      deterministic=deterministic)

        mention_target_batch_positions = jut.matmul_slice(
            batch['mention_batch_positions'], batch['mention_target_indices'])
        mention_target_start_positions = jut.matmul_slice(
            batch['mention_start_positions'], batch['mention_target_indices'])
        mention_target_end_positions = jut.matmul_slice(
            batch['mention_end_positions'], batch['mention_target_indices'])

        mention_start_final_encodings = jut.matmul_2d_index_select(
            encoding,
            (mention_target_batch_positions, mention_target_start_positions))
        mention_end_final_encodings = jut.matmul_2d_index_select(
            encoding,
            (mention_target_batch_positions, mention_target_end_positions))
        loss_helpers['target_mention_encodings'] = self.mention_projector(
            jnp.concatenate(
                (mention_start_final_encodings, mention_end_final_encodings),
                axis=-1))

        return encoding, loss_helpers, logging_helpers
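
Here the retrieval values and scores come precomputed in the batch, so they are reshaped to look like a k=1 retrieval before being passed to the update layer, which expects values of shape [n_mentions, k, retrieval_dim] and scores of shape [n_mentions, k] (as in Example #1). A shape-only sketch with made-up dimensions:

import jax.numpy as jnp

max_retrieval_indices, retrieval_dim = 6, 8
retrieval_mention_values = jnp.ones((max_retrieval_indices, retrieval_dim))
retrieval_mention_scores = jnp.ones((max_retrieval_indices,))

retrieval_values = jnp.expand_dims(retrieval_mention_values, -2)  # [6, 1, 8]
retrieval_scores = jnp.expand_dims(retrieval_mention_scores, -1)  # [6, 1]
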
Example #10
    def forward(
        self,
        batch: Dict[str, Array],
        deterministic: bool,
    ) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:

        loss_helpers = {}
        logging_helpers = {}

        embedded_input = self.embedder({
            'token_ids': batch['text_ids'],
            'position_ids': batch['position_ids'],
            'segment_ids': batch['segment_ids']
        })

        embedded_input = self.embeddings_layer_norm(embedded_input)
        embedded_input = self.embeddings_dropout(embedded_input,
                                                 deterministic=deterministic)

        loss_helpers['word_embeddings'] = self.embedder.variables['params'][
            'embedders_token_ids']['embedding']

        attention_mask = batch['text_mask']
        encoding = self.initial_encoder(encoding=embedded_input,
                                        attention_mask=attention_mask,
                                        deterministic=deterministic)

        if self.no_entity_attention:
            if 'mention_target_batch_positions' in batch:
                mention_start_encodings = jut.matmul_2d_index_select(
                    encoding, (batch['mention_target_batch_positions'],
                               batch['mention_target_start_positions']))
                mention_end_encodings = jut.matmul_2d_index_select(
                    encoding, (batch['mention_target_batch_positions'],
                               batch['mention_target_end_positions']))
                loss_helpers[
                    'im_target_mention_encodings'] = self.intermediate_mention_projector(
                        jnp.concatenate(
                            (mention_start_encodings, mention_end_encodings),
                            axis=-1))
                loss_helpers['entity_embeddings'] = jnp.asarray(
                    self.entity_embeddings, dtype=self.dtype)
        else:
            entity_attention_output = self.entity_attention_layer(
                encoded_input=encoding,
                mention_batch_positions=batch['mention_batch_positions'],
                mention_start_positions=batch['mention_start_positions'],
                mention_end_positions=batch['mention_end_positions'],
                mention_mask=batch['mention_mask'],
                entity_embeddings=jnp.asarray(self.entity_embeddings,
                                              dtype=self.dtype),
            )
            encoding = entity_attention_output['encoded_output']
            loss_helpers[
                'intermediate_mention_encodings'] = entity_attention_output[
                    'mention_encodings']
            loss_helpers[
                'intermediate_entity_attention'] = entity_attention_output[
                    'attention_weights']
            loss_helpers[
                'intermediate_entity_cos_sim'] = entity_attention_output[
                    'cosine_similarity']

        encoding = self.final_encoder(encoding=encoding,
                                      attention_mask=attention_mask,
                                      deterministic=deterministic)

        if 'mention_target_batch_positions' in batch:
            mention_start_encodings = jut.matmul_2d_index_select(
                encoding, (batch['mention_target_batch_positions'],
                           batch['mention_target_start_positions']))
            mention_end_encodings = jut.matmul_2d_index_select(
                encoding, (batch['mention_target_batch_positions'],
                           batch['mention_target_end_positions']))
            loss_helpers['target_mention_encodings'] = self.mention_projector(
                jnp.concatenate(
                    (mention_start_encodings, mention_end_encodings), axis=-1))
            loss_helpers['entity_embeddings'] = jnp.asarray(
                self.entity_embeddings, dtype=self.dtype)

        return encoding, loss_helpers, logging_helpers
  def __call__(
      self,
      encoding: Array,
      mention_batch_positions: Array,
      mention_start_positions: Array,
      mention_end_positions: Array,
      mention_mask: Array,
      memory_keys: Array,
      memory_values: Array,
      memory_mask: Array,
      memory_entity_ids: Array,
      deterministic: bool,
  ) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
    """Perform attention update over memory table.

    Args:
      encoding: [batch_size, n_tokens, hidden_size] input representation.
      mention_batch_positions: [n_mentions] mention sample position in batch.
      mention_start_positions: [n_mentions] mention start position in input.
      mention_end_positions: [n_mentions] mention end position in input.
      mention_mask: [n_mentions] attention mask to prevent updates from padding.
      memory_keys: [memory_size, memory_key_dim] mention memory keys.
      memory_values: [memory_size, memory_value_dim] mention memory values.
      memory_mask: [memory_size] mask for valid mentions in memory.
      memory_entity_ids: [memory_size] mention memory entity ids.
      deterministic: don't apply dropout if true.

    Returns:
      Updated input, loss and logging helper dicts.
    """
    loss_helpers, logging_helpers = {}, {}

    # We generate mention representations to use as queries for similarity
    # search by concatenating start and end tokens for each mention and
    # projecting the concatenation with a dense layer.
    mention_start_encodings = jut.matmul_2d_index_select(
        encoding, (mention_batch_positions, mention_start_positions))
    mention_end_encodings = jut.matmul_2d_index_select(
        encoding, (mention_batch_positions, mention_end_positions))

    queries = self.query_projector(
        jnp.concatenate((mention_start_encodings, mention_end_encodings),
                        axis=-1))

    n_queries = queries.shape[0]

    # For attention over the entire memory table, we do not want to duplicate
    # the table for each query. Instead, we perform an attention-weighted sum
    # to produce a single value, and then feed this value to the update layer
    # as a set of retrieved values of size 1, with score 1.
    if self.k_top is None:
      loss_helpers['top_entity_ids'] = jnp.tile(memory_entity_ids,
                                                (n_queries, 1))
      scores = jnp.einsum('qd,md->qm', queries, memory_keys)
      scores = scores - (1 - memory_mask) * _LARGE_NUMBER
      true_attention_weights = nn.softmax(scores, axis=-1)
      loss_helpers['memory_attention_weights'] = true_attention_weights
      top_values = jnp.einsum('qm,md->qd', true_attention_weights,
                              memory_values)
      # Expand value as though it were a set of retrieved values for each query.
      # Shape (n_queries, 1, memory_value_dim)
      top_values = jnp.expand_dims(top_values, axis=1)
      # Generate pseudo-score (n_queries, 1).
      attention_weights = jnp.ones_like(top_values, shape=(n_queries, 1))
    else:
      # Reshape memory keys for use in approximate top-k similarity layer.
      memory_keys = memory_keys.reshape(self.rows, -1, self.memory_key_dim)
      # We generate a version of the queries with stop gradient to use as input
      # to the topk similarity layer. We actually do want gradient to flow to
      # the queries, but backward differentiation over the topk layer yields
      # inefficient HLO ops. Instead we use queries with gradient to recompute
      # attention scores later.
      queries_sg = jax.lax.stop_gradient(queries)

      # Perform top-k similarity search over queries, yielding
      #   top_keys: (n_queries, k_top, memory_key_dim)
      #   top_ids: (n_queries, k_top)
      top_keys, _, top_ids = self.topk_similarity(queries_sg, memory_keys)

      top_ids = top_ids.reshape(n_queries, self.k_top)
      top_values = memory_values[top_ids]
      loss_helpers['top_entity_ids'] = memory_entity_ids[top_ids]

      # We re-compute top scores using the queries with gradient (wg) to make
      # sure the query projector and the rest of the model receives gradient.
      top_scores_wg = jnp.einsum('qd,qkd->qk', queries, top_keys)
      top_mask = memory_mask[top_ids]
      top_scores_wg = top_scores_wg - (1 - top_mask) * _LARGE_NUMBER

      # We perform dot-product attention using the retrieved memory vectors as
      # keys, a dense projection of the retrieved vectors as values, and the
      # mention representations as queries.
      attention_weights = nn.softmax(top_scores_wg, axis=-1)
      loss_helpers['memory_attention_weights'] = attention_weights
    encoding = self.update_layer(
        encoded_input=encoding,
        retrieval_values=top_values,
        retrieval_scores=attention_weights,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        deterministic=deterministic,
    )

    return encoding, loss_helpers, logging_helpers
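
The k_top branch searches with gradient-stopped queries (since differentiating through the top-k layer is noted as inefficient) and then recomputes the scores against the retrieved keys so the query projector still receives gradients. A condensed sketch of that pattern over a flat memory, using jax.lax.top_k in place of the approximate topk_similarity layer:

import jax
import jax.numpy as jnp

def topk_memory_attention(queries, memory_keys, memory_values, k_top):
    # Search without gradient: scores are recomputed with gradient below.
    scores_sg = jnp.einsum('qd,md->qm',
                           jax.lax.stop_gradient(queries), memory_keys)
    _, top_ids = jax.lax.top_k(scores_sg, k_top)      # [n_queries, k_top]
    top_keys = memory_keys[top_ids]                   # [n_queries, k_top, key_dim]
    top_values = memory_values[top_ids]               # [n_queries, k_top, value_dim]
    # Recompute scores with gradient so the query projector is trained.
    top_scores = jnp.einsum('qd,qkd->qk', queries, top_keys)
    weights = jax.nn.softmax(top_scores, axis=-1)
    return jnp.einsum('qk,qkd->qd', weights, top_values)
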