Example #1
    def call(self,
             logits,
             annotation_begins,
             annotation_ends,
             annotation_labels,
             block_ids,
             num_replicas=None,
             eps=0):
        """Calls the layer.

    Args:
      logits: <float32>[batch_size, main_seq_len, 2] Logits per position.
      annotation_begins: <int32>[batch_size, main_seq_len] Positions of
        beginnings of answer spans.
      annotation_ends: <int32>[batch_size, main_seq_len] Positions of endings of
        answer spans.
      annotation_labels: <int32>[batch_size, main_seq_len] Labels of answer
        spans. The label is 0 when the span is a placeholder (included only
        for padding purposes) and should be ignored.
      block_ids: <int32>[batch_size] Block IDs of every sample in the batch.
      num_replicas: Number of replicas to gather summaries from. If None
        (default) then cross-replica summaries are not used.
      eps: <float> Small constant for numerical stability.

    Returns:
        total_loss: <float>
    """
        seq_length = tf.shape(logits)[1]

        # (1) Aggregate block_ids across global batch. Compute cross block mask.
        all_block_ids = block_ids
        if num_replicas:
            all_block_ids = tpu_utils.cross_replica_concat(
                tensor=all_block_ids,
                num_replicas=num_replicas,
                name='block_ids_concat')

        # [batch_size, global_batch_size]
        cross_blocks_eq_mask = tf.cast(
            tf.equal(tf.expand_dims(block_ids, 1),
                     tf.expand_dims(all_block_ids, 0)), tf.float32)
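        # For example, with a single replica and block_ids = [7, 7, 8], the
        # resulting mask is [[1, 1, 0],
        #                    [1, 1, 0],
        #                    [0, 0, 1]].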

        # (2) Apply softmax over all positions in the (global) batch
        # across the blocks with the same `block_id`.

        # [batch_size, seq_len, 2]
        probs = cross_batch_softmax(logits, cross_blocks_eq_mask, num_replicas)

        # (3) Prepare one-hot labels based on annotation begins and ends

        # [batch_size, seq_len, 1]
        annotation_begins_one_hot = _one_hot_multi(
            annotation_begins,
            annotation_labels > 0,
            seq_length,
        )
        # [batch_size, seq_len, 1]
        annotation_ends_one_hot = _one_hot_multi(
            annotation_ends,
            annotation_labels > 0,
            seq_length,
        )
        # [batch_size, seq_len, 2]
        one_hot_labels = tf.concat(
            [annotation_begins_one_hot, annotation_ends_one_hot], 2)

        # (4) Compute the probability of the current begin / end positions across
        # the blocks with the same `block_id`.

        # [batch_size, 2]
        correct_probs = tf.reduce_sum(probs * one_hot_labels, axis=1)
        if num_replicas:
            # [global_batch_size, 2]
            correct_probs = tpu_utils.cross_replica_concat(
                tensor=correct_probs,
                num_replicas=num_replicas,
                name='correct_probs_concat')

        # [batch_size, 2]
        correct_probs = tf.matmul(cross_blocks_eq_mask, correct_probs)

        # (5) Compute the log probability. We allow cases where there are no
        # correct labels, not only for the current sample, but for the whole
        # document across the whole batch. In that case the probability of the
        # correct label would be 0 and the loss would be infinite, so we simply
        # do not compute the loss for these documents.

        # [batch_size, 1]
        num_annotations_per_sample = tf.reduce_sum(annotation_labels,
                                                   1,
                                                   keepdims=True)
        if num_replicas:
            # [global_batch_size, 1]
            num_annotations_per_sample = tpu_utils.cross_replica_concat(
                tensor=num_annotations_per_sample,
                num_replicas=num_replicas,
                name='num_annotations_per_sample_concat')

        # [batch_size, 1]
        num_annotations_per_doc = tf.matmul(
            cross_blocks_eq_mask,
            tf.cast(num_annotations_per_sample, tf.float32))
        # [batch_size, 2]
        doc_with_annotations_mask = tf.stop_gradient(
            tf.cast(tf.tile(num_annotations_per_doc > 0, [1, 2]), tf.float32))
        doc_without_annotations_mask = tf.stop_gradient(
            1 - doc_with_annotations_mask)
        log_correct_probs = tf.log(
            correct_probs + eps +
            doc_without_annotations_mask) * doc_with_annotations_mask
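        # For documents without annotations, adding `doc_without_annotations_mask`
        # keeps the argument of the log at least 1 (so no inf/NaN is produced),
        # and the outer multiplication by `doc_with_annotations_mask` zeroes the
        # term out entirely.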

        # (6) Divide by the number of blocks per block_id.
        # If there are K blocks with the same block_id, then in step (4) we
        # compute the loss for this document K times, so we divide it back by K.

        # [batch_size, 2]
        log_correct_probs /= tf.reduce_sum(cross_blocks_eq_mask,
                                           1,
                                           keepdims=True)

        # (7) Sum over blocks and begin/end predictions

        loss = -tf.reduce_sum(log_correct_probs)
        return loss
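
The helper _one_hot_multi is not shown in this listing. A minimal sketch of what it might look like, inferred from how it is used above (a hypothetical reconstruction, not necessarily the library's actual implementation):

def _one_hot_multi(positions, mask, depth):
    """Builds a [batch_size, depth, 1] indicator of annotated positions."""
    # positions: <int32>[batch_size, num_annotations], mask: <bool> of the
    # same shape marking which annotations are real (label > 0).
    # [batch_size, num_annotations, depth]
    one_hot = tf.one_hot(positions, depth=depth, dtype=tf.float32)
    # Drop placeholder annotations.
    one_hot *= tf.expand_dims(tf.cast(mask, tf.float32), 2)
    # Collapse all annotations into a single indicator per position.
    # [batch_size, depth]
    indicator = tf.minimum(tf.reduce_sum(one_hot, axis=1), 1.0)
    # [batch_size, depth, 1]
    return tf.expand_dims(indicator, 2)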
Example #2
    def call(self,
             yesno_logits,
             yesno_labels,
             supporting_fact_logits,
             supporting_fact_labels,
             block_ids,
             num_replicas=None,
             eps=0):
        """Calls the layer.

    Args:
      yesno_logits: <float32>[batch_size, 3] Logits for the yes/no/span
        classification.
      yesno_labels: <int32>[batch_size] Labels for the yes/no/span
        classification.
      supporting_fact_logits: <float32>[batch_size] Logits for supporting
        facts classification.
      supporting_fact_labels: <int32>[batch_size] Labels for supporting facts
        classification.
      block_ids: <int32>[batch_size] Block IDs of every sample in the batch.
      num_replicas: Number of replicas to gather summaries from. If None
        (default) then cross-replica summaries are not used.
      eps: <float> Small constant for numerical stability.

    Returns:
        A tuple (yes_no_span_loss, supporting_facts_loss), both <float>.
    """
        batch_size = tf.shape(supporting_fact_logits)[0]
        supporting_fact_logits = tf.expand_dims(supporting_fact_logits, 1)
        supporting_fact_labels = tf.expand_dims(supporting_fact_labels, 1)
        example_mask = tf.cast(tf.expand_dims(tf.not_equal(block_ids, 0), 1),
                               tf.float32)

        # (1) Aggregate block_ids across global batch. Compute cross block mask.
        all_block_ids = block_ids
        if num_replicas:
            all_block_ids = tpu_utils.cross_replica_concat(
                tensor=all_block_ids,
                num_replicas=num_replicas,
                name='block_ids_concat')

        # [batch_size, global_batch_size]
        cross_blocks_eq_mask = tf.cast(
            tf.equal(tf.expand_dims(block_ids, 1),
                     tf.expand_dims(all_block_ids, 0)), tf.float32)

        # (2) Apply softmax over all positions in the (global) batch
        # across the blocks with the same `block_id`.

        # [batch_size, 3, 1]
        yes_no_span_probs = losses.cross_batch_softmax(
            tf.expand_dims(yesno_logits, 2), cross_blocks_eq_mask,
            num_replicas)
        yes_no_span_probs = tf.squeeze(yes_no_span_probs, 2)

        # [batch_size, 1]
        supporting_facts_probs = losses.cross_batch_softmax(
            tf.expand_dims(supporting_fact_logits, 2), cross_blocks_eq_mask,
            num_replicas)
        supporting_facts_probs = tf.squeeze(supporting_facts_probs, 2)

        # (3) Prepare one-hot labels for the yes/no/span classification

        supporting_fact_labels = tf.cast(supporting_fact_labels, tf.float32)

        # [batch_size, 3]
        yes_no_span_one_hot = tf.one_hot(yesno_labels,
                                         depth=3,
                                         dtype=tf.float32)
        yes_no_span_one_hot = yes_no_span_one_hot * supporting_fact_labels
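        # This keeps the yes/no/span targets only for samples whose supporting
        # fact label is 1; all other samples contribute zero targets.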

        # (4) Compute the masked mean of the per-example losses, ignoring
        # padding examples (those with block_id == 0).

        def mean_loss(all_losses):
            return tf.reduce_sum(all_losses * example_mask) / (
                tf.reduce_sum(example_mask) + eps)

        supporting_facts_loss = -mean_loss(
            tf.log(supporting_facts_probs * supporting_fact_labels + eps))

        yes_no_span_loss = -mean_loss(
            tf.log(yes_no_span_probs * yes_no_span_one_hot + eps))

        return yes_no_span_loss, supporting_facts_loss
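
A minimal usage sketch for this layer (the instance name `layer`, the toy values, and the equal weighting of the two losses are assumptions, not part of the listing):

# Hypothetical toy batch: two blocks from the same document (block_id 3).
yesno_logits = tf.constant([[2.0, 0.1, -1.0], [0.3, 0.2, 0.1]])
yesno_labels = tf.constant([0, 0])
supporting_fact_logits = tf.constant([1.5, -0.5])
supporting_fact_labels = tf.constant([1, 0])
block_ids = tf.constant([3, 3])

yes_no_span_loss, supporting_facts_loss = layer(
    yesno_logits, yesno_labels, supporting_fact_logits,
    supporting_fact_labels, block_ids, num_replicas=None, eps=1e-6)
# Assumed equal weighting when combining the two objectives.
total_loss = yes_no_span_loss + supporting_facts_loss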
Example #3
def cross_batch_softmax(logits, cross_blocks_eq_mask, num_replicas=None):
    """Computes softmax across the whole (global) batch.

  The computations are independent with respect to the 3rd, innermost
  dimension. In the case of span prediction, the size of this dimension is
  K=2, which corresponds to beginnings and ends of annotations.

  Args:
    logits: <float32>[batch_size, seq_len, K] Tensor of logits.
    cross_blocks_eq_mask: <float32>[batch_size, global_batch_size] The mask
      which indicates which samples in the batch have the same block IDs.
    num_replicas: Optional[int]. If provided, the function performs
      computations over the global (multi-device) batch. Should be equal to
      the number of devices.

  Returns:
      probs: <float32>[batch_size, seq_len, K]
  """
    # (1) Apply max-trick to improve softmax numerical stability.

    # [batch_size, K]
    max_logits_per_sample = tf.math.reduce_max(logits, axis=1)
    if num_replicas:
        # [global_batch_size, K]
        max_logits_per_sample = tpu_utils.cross_replica_concat(
            tensor=max_logits_per_sample,
            num_replicas=num_replicas,
            name='max_logits_per_sample_concat')
    # [1, global_batch_size, K]
    max_logits_per_sample = tf.expand_dims(max_logits_per_sample, 0)

    # [batch_size, global_batch_size, 1]
    one_minus_one_mask = 2 * tf.expand_dims(cross_blocks_eq_mask, 2) - 1
    # [batch_size, global_batch_size, K]
    masked_max_logits_per_sample = tf.minimum(max_logits_per_sample,
                                              one_minus_one_mask * np.inf)
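    # `one_minus_one_mask` is +1 for samples that share a block_id and -1
    # otherwise, so multiplying by np.inf and taking the element-wise minimum
    # keeps the maxima of matching samples and sends the rest to -inf, which
    # the reduce_max below then ignores.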
    # [batch_size, K]
    max_logits_per_sample = tf.reduce_max(masked_max_logits_per_sample, axis=1)

    # [batch_size, seq_len, K]
    logits -= tf.expand_dims(max_logits_per_sample, 1)

    # (2) Take exponent
    unnormalized_probs = tf.exp(logits)

    # (3) Compute softmax's denominator (normalization constant)

    # [batch_size, K]
    softmax_denominator_per_sample = tf.math.reduce_sum(unnormalized_probs,
                                                        axis=1)
    if num_replicas:
        # [global_batch_size, K]
        softmax_denominator_per_sample = tpu_utils.cross_replica_concat(
            tensor=softmax_denominator_per_sample,
            num_replicas=num_replicas,
            name='softmax_denominator_per_sample_concat')

    # [batch_size, K]
    softmax_denominator_per_sample = tf.matmul(cross_blocks_eq_mask,
                                               softmax_denominator_per_sample)

    # (4) Compute probabilities

    # [batch_size, seq_len, K]
    probs = unnormalized_probs / tf.expand_dims(softmax_denominator_per_sample,
                                                1)
    return probs
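
A minimal single-host usage sketch (num_replicas=None; the toy shapes and values below are assumptions):

# [batch_size=4, seq_len=128, K=2]
logits = tf.random.normal([4, 128, 2])
block_ids = tf.constant([7, 7, 8, 9])  # the first two blocks share a document
# [batch_size, batch_size] -- with a single replica the "global" batch is the
# local batch, so the mask can be built directly from block_ids.
eq_mask = tf.cast(
    tf.equal(tf.expand_dims(block_ids, 1), tf.expand_dims(block_ids, 0)),
    tf.float32)
probs = cross_batch_softmax(logits, eq_mask, num_replicas=None)
# For each k, probs[:, :, k] of the two samples with block_id 7 jointly sum
# to 1 over all their positions, while samples 8 and 9 each sum to 1 alone.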
Example #4
    def call(self,
             hidden_states,
             block_ids,
             block_pos,
             annotation_begins,
             annotation_ends,
             annotation_labels,
             main_seq_length,
             num_replicas_concat,
             cross_block_attention_mode,
             training,
             token_ids=None):
        """Calls the layer.

    Args:
      hidden_states: <int32>[batch_size, main_seq_len, hidden size]. Final
        hidden states of the input after the first pass of the model.
      block_ids: <int32>[batch_size] Block IDs of every sample in the batch.
      block_pos: <int32>[batch_size] Optional Tensor of absolute position ids of
        blocks in the original document.
      annotation_begins: <int32>[batch_size, max_num_annotations] Begin index of
        annotations.
      annotation_ends: <int32>[batch_size, max_num_annotations] End index of
        annotations (inclusive).
      annotation_labels: <int32>[batch_size, max_num_annotations] Label for
        annotations.
      main_seq_length: Length of the input text.
      num_replicas_concat: Number of replicas to gather summaries from. If None
        (default) then cross-replica summaries are not used.
      cross_block_attention_mode: The policy on how summaries between different
        blocks are allowed to interact with each other.
      training: bool. True for training mode, False for eval mode. Controls
        whether dropout will be applied.
      token_ids: <int32>[batch_size, main_seq_len] Tokens.

    Returns:
        summary_output: SummaryExtractionOutput object
    """
        # [batch_size, local_num_summaries, hidden_size]
        first_token_tensor = self._extract_summary(
            hidden_states=hidden_states,
            annotation_begins=annotation_begins,
            annotation_ends=annotation_ends,
            annotation_labels=annotation_labels,
            token_ids=token_ids)
        batch_size = tf.shape(first_token_tensor)[0]
        local_num_summaries = tf.shape(first_token_tensor)[1]
        original_block_ids = block_ids
        original_block_pos = block_pos
        block_ids = tf.tile(tf.expand_dims(block_ids, 1),
                            [1, local_num_summaries])
        block_pos = tf.tile(tf.expand_dims(block_pos, 1),
                            [1, local_num_summaries])

        first_token_tensor = tf.reshape(
            first_token_tensor,
            [batch_size * local_num_summaries, self.hidden_size])
        block_ids = tf.reshape(block_ids, [batch_size * local_num_summaries])
        block_pos = tf.reshape(block_pos, [batch_size * local_num_summaries])

        all_first_token_tensor = first_token_tensor
        all_block_ids = block_ids
        all_block_pos = block_pos
        if num_replicas_concat:
            # Concatenate all the required tensors across tpu cores.
            all_first_token_tensor = tpu_utils.cross_replica_concat(
                tensor=all_first_token_tensor,
                num_replicas=num_replicas_concat,
                name="cls_token_concat")
            all_block_ids = tpu_utils.cross_replica_concat(
                tensor=all_block_ids,
                num_replicas=num_replicas_concat,
                name="block_ids_concat")
            all_block_ids = tf.stop_gradient(all_block_ids)

            all_block_pos = tpu_utils.cross_replica_concat(
                tensor=all_block_pos,
                num_replicas=num_replicas_concat,
                name="block_pos_concat")
            all_block_pos = tf.stop_gradient(all_block_pos)

        first_token_tensor.set_shape([None, self.hidden_size])
        all_first_token_tensor.set_shape([None, self.hidden_size])

        if self.mode == "cls":
            labels = block_ids
            all_labels = all_block_ids
        elif self.mode == "text_block":
            token_block_offset = self.text_block_extract_every_x - 1
            token_block_len = self.text_block_extract_every_x
            labels = tf.cast(
                tf.logical_and(
                    tf.not_equal(token_ids[:, ::token_block_len], 0),
                    tf.not_equal(
                        token_ids[:, token_block_offset::token_block_len], 0),
                ), tf.int32)
            labels = tf.reshape(labels, [batch_size * local_num_summaries])
            all_labels = labels
            if num_replicas_concat:
                all_labels = tpu_utils.cross_replica_concat(
                    tensor=all_labels,
                    num_replicas=num_replicas_concat,
                    name="labels_concat")
                all_labels = tf.stop_gradient(all_labels)
        else:
            assert self.mode == "entity"
            labels = tf.reshape(annotation_labels,
                                [batch_size * local_num_summaries])
            all_labels = labels
            if num_replicas_concat:
                all_labels = tpu_utils.cross_replica_concat(
                    tensor=all_labels,
                    num_replicas=num_replicas_concat,
                    name="labels_concat")
                all_labels = tf.stop_gradient(all_labels)

        # TODO(urikz): Consider using this
        # Filter out all padding summaries -- the convention is that
        # padding summaries will have label 0.
        # non_padding_summary_mask = tf.not_equal(all_labels, 0)
        # all_first_token_tensor = tf.boolean_mask(all_first_token_tensor,
        #                                          non_padding_summary_mask)
        # all_block_pos = tf.boolean_mask(all_block_pos, non_padding_summary_mask)
        # all_block_ids = tf.boolean_mask(all_block_ids, non_padding_summary_mask)
        # all_labels = tf.boolean_mask(all_labels, non_padding_summary_mask)

        if self.postprocessing_type == "none":
            all_cls_summary = all_first_token_tensor
        elif self.postprocessing_type == "linear":
            all_cls_summary = self.postprocessing(all_first_token_tensor)
        elif self.postprocessing_type in ["pos", "transformer"]:
            # We treat sequence of summaries as just a single sentence.
            # [1, global_num_summaries, hidden_dim]
            all_cls_summary = tf.expand_dims(all_first_token_tensor, 0)

            # Add positional embeddings based on positions of blocks in their
            # original documents.
            all_cls_summary += self.position_embedding(all_block_pos)
            all_cls_summary = self.embedding_norm(all_cls_summary)
            # Note, we don't apply dropout here
            # all_cls_summary = self.embedding_dropout(
            #     all_cls_summary, training=training)
            if self.postprocessing_type == "transformer":
                # Create cross block attention map
                # according to the `cross_block_attention_mode`.
                # [global_num_summaries, global_num_summaries]
                block_att_mask = get_cross_block_att(
                    all_block_ids, all_block_pos, all_block_ids, all_block_pos,
                    cross_block_attention_mode)
                # [1, global_num_summaries, global_num_summaries]
                block_att_mask = tf.expand_dims(block_att_mask, 0)

                all_cls_summary = self.postprocessing(
                    main_input=all_cls_summary,
                    side_input=None,
                    att_mask=block_att_mask,
                    training=training)
            all_cls_summary = tf.squeeze(all_cls_summary, 0)
        else:
            raise ValueError("Unknown `postprocessing_type`: '{}'".format(
                self.postprocessing_type))

        # [batch_size, global_num_summaries]
        token_to_global_summary_att_map = get_cross_block_att(
            original_block_ids, original_block_pos, all_block_ids,
            all_block_pos, cross_block_attention_mode)
        # [batch_size, main_seq_length, global_num_summaries]
        token_to_global_summary_att_map = tf.tile(
            tf.expand_dims(token_to_global_summary_att_map, 1),
            [1, main_seq_length, 1])

        # Do not attend to pad entity summaries
        # [1, 1, global_num_summaries]
        is_not_pad_summary = tf.expand_dims(
            tf.expand_dims(tf.cast(tf.not_equal(all_labels, 0), tf.int32), 0),
            0)
        token_to_global_summary_att_map *= is_not_pad_summary

        if self.use_sparse_memory_attention:
            if self.mode == "entity":
                # Only allow entity mentions to attend to summaries.
                # [batch_size, max_num_annotations, 1]
                annotation_mask = tf.expand_dims(
                    tf.cast(tf.not_equal(annotation_labels, 0), tf.int32), -1)
                # [batch_size, max_num_annotations, main_seq_length]
                mask_begin = tf.sequence_mask(annotation_begins,
                                              main_seq_length,
                                              dtype=tf.int32)
                mask_end = tf.sequence_mask(annotation_ends + 1,
                                            main_seq_length,
                                            dtype=tf.int32)
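                # `mask_end - mask_begin` is 1 exactly at token positions
                # inside the [begin, end] span of each annotation.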

                def make_mask(x):
                    x = x * annotation_mask
                    x = tf.reduce_sum(x, 1)
                    x = tf.minimum(x, 1)
                    return x

                # [batch_size, main_seq_length, 1]
                is_token_belongs_to_entity = tf.expand_dims(
                    make_mask(mask_end - mask_begin), -1)
                token_to_global_summary_att_map *= is_token_belongs_to_entity
            elif self.mode == "cls":
                # [batch_size, main_seq_length]
                only_cls_mask = tf.concat([
                    tf.cast(tf.fill(dims=[batch_size, 1], value=1),
                            dtype=tf.int32),
                    tf.cast(tf.fill(dims=[batch_size, main_seq_length - 1],
                                    value=0),
                            dtype=tf.int32)
                ],
                                          axis=1)
                # [batch_size, main_seq_length, 1]
                only_cls_mask = tf.expand_dims(only_cls_mask, -1)
                # [batch_size, main_seq_length, global_num_summaries]
                token_to_global_summary_att_map *= only_cls_mask
            elif self.mode == "text_block":
                # [main_seq_length]
                text_block_mask = tf.range(main_seq_length,
                                           delta=1,
                                           dtype=tf.int32)
                # [main_seq_length]
                text_block_mask = tf.math.floormod(
                    text_block_mask, self.text_block_extract_every_x)
                # [main_seq_length]
                text_block_mask = tf.cast(tf.equal(text_block_mask, 0),
                                          tf.int32)
                # [batch_size, main_seq_length]
                text_block_mask = tf.tile(tf.expand_dims(text_block_mask, 0),
                                          [batch_size, 1])
                # [batch_size, main_seq_length, 1]
                text_block_mask = tf.expand_dims(text_block_mask, -1)
                # [batch_size, main_seq_length, global_num_summaries]
                token_to_global_summary_att_map *= text_block_mask
            else:
                raise ValueError("Unknown summary mode: %s" % self.mode)

        return SummaryExtractionOutput(
            local_summary=Summary(states=first_token_tensor,
                                  processed_states=None,
                                  block_ids=block_ids,
                                  block_pos=block_pos,
                                  labels=labels),
            global_summary=Summary(states=all_first_token_tensor,
                                   processed_states=all_cls_summary,
                                   block_ids=all_block_ids,
                                   block_pos=all_block_pos,
                                   labels=all_labels),
            token_to_global_summary_att_map=token_to_global_summary_att_map)
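
The helper get_cross_block_att is not part of this listing. A minimal sketch of a same-document variant, assuming a hypothetical mode in which summaries may only attend to blocks sharing the same block_id (the real helper presumably handles the other cross_block_attention_mode values and the block positions as well):

def get_cross_block_att_sketch(from_block_ids, from_block_pos,
                               to_block_ids, to_block_pos, mode="same_doc"):
    """Returns an <int32>[num_from, num_to] attention mask."""
    del from_block_pos, to_block_pos  # Unused in this simplified mode.
    if mode != "same_doc":
        raise ValueError("Unknown cross_block_attention_mode: %s" % mode)
    # 1 where the query and key blocks come from the same document.
    return tf.cast(
        tf.equal(tf.expand_dims(from_block_ids, 1),
                 tf.expand_dims(to_block_ids, 0)), tf.int32)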