def call(self,
         logits,
         annotation_begins,
         annotation_ends,
         annotation_labels,
         block_ids,
         num_replicas=None,
         eps=0):
  """Calls the layer.

  Args:
    logits: <float32>[batch_size, main_seq_len, 2] Logits per position.
    annotation_begins: <int32>[batch_size, max_num_annotations] Positions of
      beginnings of answer spans.
    annotation_ends: <int32>[batch_size, max_num_annotations] Positions of
      endings of answer spans.
    annotation_labels: <int32>[batch_size, max_num_annotations] Labels of
      answer spans. Label is 0 when the span is a placeholder one (included
      only for padding purposes) and should be ignored.
    block_ids: <int32>[batch_size] Block IDs of every sample in the batch.
    num_replicas: Number of replicas to gather summaries from. If None
      (default) then cross-replica summaries are not used.
    eps: <float> Small constant for numerical stability.

  Returns:
    total_loss: <float>
  """
  seq_length = tf.shape(logits)[1]

  # (1) Aggregate block_ids across global batch. Compute cross block mask.
  all_block_ids = block_ids
  if num_replicas:
    all_block_ids = tpu_utils.cross_replica_concat(
        tensor=all_block_ids,
        num_replicas=num_replicas,
        name='block_ids_concat')

  # [batch_size, global_batch_size]
  cross_blocks_eq_mask = tf.cast(
      tf.equal(
          tf.expand_dims(block_ids, 1), tf.expand_dims(all_block_ids, 0)),
      tf.float32)

  # (2) Apply softmax over all positions in the (global) batch
  # across the blocks with the same `block_id`.
  # [batch_size, seq_len, 2]
  probs = cross_batch_softmax(logits, cross_blocks_eq_mask, num_replicas)

  # (3) Prepare one-hot labels based on annotation begins and ends.
  # [batch_size, seq_len, 1]
  annotation_begins_one_hot = _one_hot_multi(
      annotation_begins,
      annotation_labels > 0,
      seq_length,
  )
  # [batch_size, seq_len, 1]
  annotation_ends_one_hot = _one_hot_multi(
      annotation_ends,
      annotation_labels > 0,
      seq_length,
  )
  # [batch_size, seq_len, 2]
  one_hot_labels = tf.concat(
      [annotation_begins_one_hot, annotation_ends_one_hot], 2)

  # (4) Compute the probability of the current begin / end positions across
  # the blocks with the same `block_id`.
  # [batch_size, 2]
  correct_probs = tf.reduce_sum(probs * one_hot_labels, axis=1)
  if num_replicas:
    # [global_batch_size, 2]
    correct_probs = tpu_utils.cross_replica_concat(
        tensor=correct_probs,
        num_replicas=num_replicas,
        name='correct_probs_concat')

  # [batch_size, 2]
  correct_probs = tf.matmul(cross_blocks_eq_mask, correct_probs)

  # (5) Compute log probability. We allow cases when there are no correct
  # labels not only for the current sample, but for the whole document
  # across the whole batch. In that case the probability of the correct label
  # would be 0 and the loss would be infinite. Therefore, we just do not
  # compute loss on these documents.
  # [batch_size, 1]
  num_annotations_per_sample = tf.reduce_sum(
      annotation_labels, 1, keepdims=True)
  if num_replicas:
    # [global_batch_size, 1]
    num_annotations_per_sample = tpu_utils.cross_replica_concat(
        tensor=num_annotations_per_sample,
        num_replicas=num_replicas,
        name='num_annotations_per_sample_concat')

  # [batch_size, 1]
  num_annotations_per_doc = tf.matmul(
      cross_blocks_eq_mask, tf.cast(num_annotations_per_sample, tf.float32))
  # [batch_size, 2]
  doc_with_annotations_mask = tf.stop_gradient(
      tf.cast(tf.tile(num_annotations_per_doc > 0, [1, 2]), tf.float32))
  doc_without_annotations_mask = tf.stop_gradient(
      1 - doc_with_annotations_mask)
  log_correct_probs = tf.log(
      correct_probs + eps +
      doc_without_annotations_mask) * doc_with_annotations_mask

  # (6) Divide by the number of blocks per block_id.
  # If there are K blocks with the same block_id, then on step (4) we'll
  # compute loss for this document K times. So we need to divide it back by K.
  # [batch_size, 2]
  log_correct_probs /= tf.reduce_sum(cross_blocks_eq_mask, 1, keepdims=True)

  # (7) Sum over blocks and begin/end predictions.
  loss = -tf.reduce_sum(log_correct_probs)

  return loss
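

# The helper below is an illustrative sketch added for exposition, not part of
# the original library: a NumPy mock-up of steps (4)-(6) above for a single
# "correct probability" column per sample (instead of the begin/end pair),
# restricted to the local-batch case (num_replicas=None). All values are
# hypothetical; the point is only to show why dividing by the number of blocks
# sharing a block_id makes every document contribute to the loss exactly once.
def _toy_cross_block_loss_sketch():
  import numpy as np  # Local import; illustration only.

  block_ids = np.array([7, 7, 9])  # Document 7 is split into two blocks.
  correct_probs = np.array([[0.2], [0.3], [0.5]])

  # [batch_size, batch_size] mask of blocks sharing a block_id.
  eq_mask = (block_ids[:, None] == block_ids[None, :]).astype(np.float32)
  # Each row now holds the probability mass of its whole document, so the
  # value for document 7 (0.2 + 0.3 = 0.5) appears in two rows.
  doc_probs = eq_mask @ correct_probs
  # Dividing the log-probability by the per-document block count K undoes
  # this double counting before summing over the batch.
  loss = -np.sum(np.log(doc_probs) / eq_mask.sum(axis=1, keepdims=True))
  return loss  # == -2 * log(0.5): one term per document.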
def call(self,
         yesno_logits,
         yesno_labels,
         supporting_fact_logits,
         supporting_fact_labels,
         block_ids,
         num_replicas=None,
         eps=0):
  """Calls the layer.

  Args:
    yesno_logits: <float32>[batch_size, 3] Logits for the yes/no/span
      classification.
    yesno_labels: <int32>[batch_size] Labels for the yes/no/span
      classification.
    supporting_fact_logits: <float32>[batch_size] Logits for supporting facts
      classification.
    supporting_fact_labels: <int32>[batch_size] Binary labels for supporting
      facts classification.
    block_ids: <int32>[batch_size] Block IDs of every sample in the batch.
    num_replicas: Number of replicas to gather summaries from. If None
      (default) then cross-replica summaries are not used.
    eps: <float> Small constant for numerical stability.

  Returns:
    yes_no_span_loss: <float>
    supporting_facts_loss: <float>
  """
  batch_size = tf.shape(supporting_fact_logits)[0]
  supporting_fact_logits = tf.expand_dims(supporting_fact_logits, 1)
  supporting_fact_labels = tf.expand_dims(supporting_fact_labels, 1)
  example_mask = tf.cast(
      tf.expand_dims(tf.not_equal(block_ids, 0), 1), tf.float32)

  # (1) Aggregate block_ids across global batch. Compute cross block mask.
  all_block_ids = block_ids
  if num_replicas:
    all_block_ids = tpu_utils.cross_replica_concat(
        tensor=all_block_ids,
        num_replicas=num_replicas,
        name='block_ids_concat')

  # [batch_size, global_batch_size]
  cross_blocks_eq_mask = tf.cast(
      tf.equal(
          tf.expand_dims(block_ids, 1), tf.expand_dims(all_block_ids, 0)),
      tf.float32)

  # (2) Apply softmax over all positions in the (global) batch
  # across the blocks with the same `block_id`.
  # [batch_size, 3, 1]
  yes_no_span_probs = losses.cross_batch_softmax(
      tf.expand_dims(yesno_logits, 2), cross_blocks_eq_mask, num_replicas)
  yes_no_span_probs = tf.squeeze(yes_no_span_probs, 2)

  # [batch_size, 1]
  supporting_facts_probs = losses.cross_batch_softmax(
      tf.expand_dims(supporting_fact_logits, 2), cross_blocks_eq_mask,
      num_replicas)
  supporting_facts_probs = tf.squeeze(supporting_facts_probs, 2)

  # (3) Prepare one-hot labels.
  supporting_fact_labels = tf.cast(supporting_fact_labels, tf.float32)

  # [batch_size, 3]
  yes_no_span_one_hot = tf.one_hot(yesno_labels, depth=3, dtype=tf.float32)
  yes_no_span_one_hot = yes_no_span_one_hot * supporting_fact_labels

  # (4) Compute the masked mean of the per-example negative log-likelihoods
  # across the blocks with the same `block_id`.
  def mean_loss(all_losses):
    return tf.reduce_sum(all_losses * example_mask) / (
        tf.reduce_sum(example_mask) + eps)

  supporting_facts_loss = -mean_loss(
      tf.log(supporting_facts_probs * supporting_fact_labels + eps))
  yes_no_span_loss = -mean_loss(
      tf.log(yes_no_span_probs * yes_no_span_one_hot + eps))

  return yes_no_span_loss, supporting_facts_loss
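

# Illustrative sketch added for exposition, not part of the original library:
# a NumPy mock-up of the `example_mask` logic used by `mean_loss` above,
# showing that padding examples (block_ids == 0) are excluded from the mean.
# All values are hypothetical.
def _toy_example_mask_sketch():
  import numpy as np  # Local import; illustration only.

  block_ids = np.array([3, 5, 0])  # The last example is padding.
  per_example_loss = np.array([[1.0], [2.0], [123.0]])
  eps = 1e-6

  example_mask = (block_ids != 0).astype(np.float32)[:, None]
  masked_mean = np.sum(per_example_loss * example_mask) / (
      np.sum(example_mask) + eps)
  return masked_mean  # ~1.5; the padding example's loss is ignored.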
def cross_batch_softmax(logits, cross_blocks_eq_mask, num_replicas=None):
  """Computes softmax across the whole (global) batch.

  The computations are independent with respect to the 3rd, innermost
  dimension. In case of span prediction, the size of this dimension is K=2,
  which corresponds to beginnings and ends of annotations.

  Args:
    logits: <float32>[batch_size, seq_len, K] Tensor of logits.
    cross_blocks_eq_mask: <float32>[batch_size, global_batch_size] The mask
      which indicates which samples in the batch have the same block IDs.
    num_replicas: Optional[int]. If provided, the function performs
      computations over the global (multi-device) batch. Should be equal to
      the number of devices.

  Returns:
    probs: <float32>[batch_size, seq_len, K]
  """
  # (1) Apply the max-trick to improve softmax numerical stability.

  # [batch_size, K]
  max_logits_per_sample = tf.math.reduce_max(logits, axis=1)
  if num_replicas:
    # [global_batch_size, K]
    max_logits_per_sample = tpu_utils.cross_replica_concat(
        tensor=max_logits_per_sample,
        num_replicas=num_replicas,
        name='max_logits_per_sample_concat')
  # [1, global_batch_size, K]
  max_logits_per_sample = tf.expand_dims(max_logits_per_sample, 0)
  # [batch_size, global_batch_size, 1]
  one_minus_one_mask = 2 * tf.expand_dims(cross_blocks_eq_mask, 2) - 1
  # [batch_size, global_batch_size, K]
  masked_max_logits_per_sample = tf.minimum(max_logits_per_sample,
                                            one_minus_one_mask * np.inf)
  # [batch_size, K]
  max_logits_per_sample = tf.reduce_max(masked_max_logits_per_sample, axis=1)
  # [batch_size, seq_len, K]
  logits -= tf.expand_dims(max_logits_per_sample, 1)

  # (2) Take the exponent.
  unnormalized_probs = tf.exp(logits)

  # (3) Compute softmax's denominator (the normalization constant).

  # [batch_size, K]
  softmax_denominator_per_sample = tf.math.reduce_sum(
      unnormalized_probs, axis=1)
  if num_replicas:
    # [global_batch_size, K]
    softmax_denominator_per_sample = tpu_utils.cross_replica_concat(
        tensor=softmax_denominator_per_sample,
        num_replicas=num_replicas,
        name='softmax_denominator_per_sample_concat')
  # [batch_size, K]
  softmax_denominator_per_sample = tf.matmul(cross_blocks_eq_mask,
                                             softmax_denominator_per_sample)

  # (4) Compute probabilities.
  # [batch_size, seq_len, K]
  probs = unnormalized_probs / tf.expand_dims(
      softmax_denominator_per_sample, 1)
  return probs
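

# Illustrative sketch added for exposition, not part of the original library:
# a NumPy reference for what `cross_batch_softmax` computes in the
# single-replica case (num_replicas=None), with the innermost dimension K
# folded away for brevity. Every position is normalized over all positions of
# all blocks in the batch that share its block_id. The values are
# hypothetical.
def _toy_cross_batch_softmax_reference():
  import numpy as np  # Local import; illustration only.

  logits = np.array([[1.0, 2.0], [0.5, 0.0], [3.0, 1.0]])  # [batch, seq_len]
  block_ids = np.array([7, 7, 9])
  eq_mask = (block_ids[:, None] == block_ids[None, :]).astype(np.float32)

  # Max-trick: subtract, per sample, the max logit over its whole document.
  doc_max = np.array(
      [logits[eq_mask[i] > 0].max() for i in range(len(logits))])
  exp = np.exp(logits - doc_max[:, None])
  # Normalize by the sum of exponents over the whole document.
  denom = eq_mask @ exp.sum(axis=1, keepdims=True)
  probs = exp / denom
  # Probabilities sum to 1 over each document, not over each block.
  assert np.isclose(probs[block_ids == 7].sum(), 1.0)
  return probs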
def call(self,
         hidden_states,
         block_ids,
         block_pos,
         annotation_begins,
         annotation_ends,
         annotation_labels,
         main_seq_length,
         num_replicas_concat,
         cross_block_attention_mode,
         training,
         token_ids=None):
  """Calls the layer.

  Args:
    hidden_states: <float32>[batch_size, main_seq_len, hidden_size] Final
      hidden states of the input after the first pass of the model.
    block_ids: <int32>[batch_size] Block IDs of every sample in the batch.
    block_pos: <int32>[batch_size] Optional Tensor of absolute position ids
      of blocks in the original document.
    annotation_begins: <int32>[batch_size, max_num_annotations] Begin index
      of annotations.
    annotation_ends: <int32>[batch_size, max_num_annotations] End index of
      annotations (inclusive).
    annotation_labels: <int32>[batch_size, max_num_annotations] Label for
      annotations.
    main_seq_length: Length of the input text.
    num_replicas_concat: Number of replicas to gather summaries from. If None
      (default) then cross-replica summaries are not used.
    cross_block_attention_mode: The policy on how summaries between different
      blocks are allowed to interact with each other.
    training: bool. True for training mode, False for eval mode. Controls
      whether dropout will be applied.
    token_ids: <int32>[batch_size, main_seq_len] Tokens.

  Returns:
    summary_output: SummaryExtractionOutput object.
  """
  # [batch_size, local_num_summaries, hidden_size]
  first_token_tensor = self._extract_summary(
      hidden_states=hidden_states,
      annotation_begins=annotation_begins,
      annotation_ends=annotation_ends,
      annotation_labels=annotation_labels,
      token_ids=token_ids)

  batch_size = tf.shape(first_token_tensor)[0]
  local_num_summaries = tf.shape(first_token_tensor)[1]
  original_block_ids = block_ids
  original_block_pos = block_pos
  block_ids = tf.tile(tf.expand_dims(block_ids, 1), [1, local_num_summaries])
  block_pos = tf.tile(tf.expand_dims(block_pos, 1), [1, local_num_summaries])

  first_token_tensor = tf.reshape(
      first_token_tensor,
      [batch_size * local_num_summaries, self.hidden_size])
  block_ids = tf.reshape(block_ids, [batch_size * local_num_summaries])
  block_pos = tf.reshape(block_pos, [batch_size * local_num_summaries])

  all_first_token_tensor = first_token_tensor
  all_block_ids = block_ids
  all_block_pos = block_pos
  if num_replicas_concat:
    # Concatenate all the required tensors across TPU cores.
    all_first_token_tensor = tpu_utils.cross_replica_concat(
        tensor=all_first_token_tensor,
        num_replicas=num_replicas_concat,
        name="cls_token_concat")
    all_block_ids = tpu_utils.cross_replica_concat(
        tensor=all_block_ids,
        num_replicas=num_replicas_concat,
        name="block_ids_concat")
    all_block_ids = tf.stop_gradient(all_block_ids)
    all_block_pos = tpu_utils.cross_replica_concat(
        tensor=all_block_pos,
        num_replicas=num_replicas_concat,
        name="block_pos_concat")
    all_block_pos = tf.stop_gradient(all_block_pos)

  first_token_tensor.set_shape([None, self.hidden_size])
  all_first_token_tensor.set_shape([None, self.hidden_size])

  if self.mode == "cls":
    labels = block_ids
    all_labels = all_block_ids
  elif self.mode == "text_block":
    token_block_offset = self.text_block_extract_every_x - 1
    token_block_len = self.text_block_extract_every_x
    labels = tf.cast(
        tf.logical_and(
            tf.not_equal(token_ids[:, ::token_block_len], 0),
            tf.not_equal(
                token_ids[:, token_block_offset::token_block_len], 0),
        ), tf.int32)
    labels = tf.reshape(labels, [batch_size * local_num_summaries])
    all_labels = labels
    if num_replicas_concat:
      all_labels = tpu_utils.cross_replica_concat(
          tensor=all_labels,
          num_replicas=num_replicas_concat,
          name="labels_concat")
      all_labels = tf.stop_gradient(all_labels)
  else:
    assert self.mode == "entity"
    labels = tf.reshape(annotation_labels,
                        [batch_size * local_num_summaries])
    all_labels = labels
    if num_replicas_concat:
      all_labels = tpu_utils.cross_replica_concat(
          tensor=all_labels,
          num_replicas=num_replicas_concat,
          name="labels_concat")
      all_labels = tf.stop_gradient(all_labels)

  # TODO(urikz): Consider using this.
  # Filter out all padding summaries -- the convention is that
  # padding summaries will have label 0.
  # non_padding_summary_mask = tf.not_equal(all_labels, 0)
  # all_first_token_tensor = tf.boolean_mask(all_first_token_tensor,
  #                                          non_padding_summary_mask)
  # all_block_pos = tf.boolean_mask(all_block_pos, non_padding_summary_mask)
  # all_block_ids = tf.boolean_mask(all_block_ids, non_padding_summary_mask)
  # all_labels = tf.boolean_mask(all_labels, non_padding_summary_mask)

  if self.postprocessing_type == "none":
    all_cls_summary = all_first_token_tensor
  elif self.postprocessing_type == "linear":
    all_cls_summary = self.postprocessing(all_first_token_tensor)
  elif self.postprocessing_type in ["pos", "transformer"]:
    # We treat the sequence of summaries as just a single sentence.
    # [1, global_num_summaries, hidden_dim]
    all_cls_summary = tf.expand_dims(all_first_token_tensor, 0)
    # Add positional embeddings based on positions of blocks in their
    # original documents.
    all_cls_summary += self.position_embedding(all_block_pos)
    all_cls_summary = self.embedding_norm(all_cls_summary)
    # Note, we don't apply dropout here.
    # all_cls_summary = self.embedding_dropout(
    #     all_cls_summary, training=training)
    if self.postprocessing_type == "transformer":
      # Create the cross block attention map
      # according to the `cross_block_attention_mode`.
      # [global_num_summaries, global_num_summaries]
      block_att_mask = get_cross_block_att(all_block_ids, all_block_pos,
                                           all_block_ids, all_block_pos,
                                           cross_block_attention_mode)
      # [1, global_num_summaries, global_num_summaries]
      block_att_mask = tf.expand_dims(block_att_mask, 0)
      all_cls_summary = self.postprocessing(
          main_input=all_cls_summary,
          side_input=None,
          att_mask=block_att_mask,
          training=training)
    all_cls_summary = tf.squeeze(all_cls_summary, 0)
  else:
    raise ValueError("Unknown `postprocessing_type`: '{}'".format(
        self.postprocessing_type))

  # [batch_size, global_num_summaries]
  token_to_global_summary_att_map = get_cross_block_att(
      original_block_ids, original_block_pos, all_block_ids, all_block_pos,
      cross_block_attention_mode)
  # [batch_size, main_seq_length, global_num_summaries]
  token_to_global_summary_att_map = tf.tile(
      tf.expand_dims(token_to_global_summary_att_map, 1),
      [1, main_seq_length, 1])
  # Do not attend over pad entity summaries.
  # [1, 1, global_num_summaries]
  is_not_pad_summary = tf.expand_dims(
      tf.expand_dims(tf.cast(tf.not_equal(all_labels, 0), tf.int32), 0), 0)
  token_to_global_summary_att_map *= is_not_pad_summary

  if self.use_sparse_memory_attention:
    if self.mode == "entity":
      # Only allow entity mentions to attend to summaries.
      # [batch_size, max_num_annotations, 1]
      annotation_mask = tf.expand_dims(
          tf.cast(tf.not_equal(annotation_labels, 0), tf.int32), -1)
      # [batch_size, max_num_annotations, main_seq_length]
      mask_begin = tf.sequence_mask(
          annotation_begins, main_seq_length, dtype=tf.int32)
      mask_end = tf.sequence_mask(
          annotation_ends + 1, main_seq_length, dtype=tf.int32)

      def make_mask(x):
        x = x * annotation_mask
        x = tf.reduce_sum(x, 1)
        x = tf.minimum(x, 1)
        return x

      # [batch_size, main_seq_length, 1]
      is_token_belongs_to_entity = tf.expand_dims(
          make_mask(mask_end - mask_begin), -1)
      token_to_global_summary_att_map *= is_token_belongs_to_entity
    elif self.mode == "cls":
      # Only allow the first token (position 0) to attend to summaries.
      # [batch_size, main_seq_length]
      only_cls_mask = tf.concat([
          tf.cast(tf.fill(dims=[batch_size, 1], value=1), dtype=tf.int32),
          tf.cast(
              tf.fill(dims=[batch_size, main_seq_length - 1], value=0),
              dtype=tf.int32)
      ], axis=1)
      # [batch_size, main_seq_length, 1]
      only_cls_mask = tf.expand_dims(only_cls_mask, -1)
      # [batch_size, main_seq_length, global_num_summaries]
      token_to_global_summary_att_map *= only_cls_mask
    elif self.mode == "text_block":
      # Only allow the first token of every text block to attend to summaries.
      # [main_seq_length]
      text_block_mask = tf.range(main_seq_length, delta=1, dtype=tf.int32)
      # [main_seq_length]
      text_block_mask = tf.math.floormod(text_block_mask,
                                         self.text_block_extract_every_x)
      # [main_seq_length]
      text_block_mask = tf.cast(tf.equal(text_block_mask, 0), tf.int32)
      # [batch_size, main_seq_length]
      text_block_mask = tf.tile(
          tf.expand_dims(text_block_mask, 0), [batch_size, 1])
      # [batch_size, main_seq_length, 1]
      text_block_mask = tf.expand_dims(text_block_mask, -1)
      # [batch_size, main_seq_length, global_num_summaries]
      token_to_global_summary_att_map *= text_block_mask
    else:
      raise ValueError("Unknown summary mode: %s" % self.mode)

  return SummaryExtractionOutput(
      local_summary=Summary(
          states=first_token_tensor,
          processed_states=None,
          block_ids=block_ids,
          block_pos=block_pos,
          labels=labels),
      global_summary=Summary(
          states=all_first_token_tensor,
          processed_states=all_cls_summary,
          block_ids=all_block_ids,
          block_pos=all_block_pos,
          labels=all_labels),
      token_to_global_summary_att_map=token_to_global_summary_att_map)
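

# Illustrative sketch added for exposition, not part of the original library:
# a NumPy mock-up of the "text_block" sparse-attention mask built above, using
# the hypothetical values text_block_extract_every_x=4 and main_seq_length=10.
# Only every 4th token (positions 0, 4, 8) is allowed to attend to the global
# summaries.
def _toy_text_block_mask_sketch():
  import numpy as np  # Local import; illustration only.

  main_seq_length = 10
  text_block_extract_every_x = 4
  positions = np.arange(main_seq_length)
  # [main_seq_length]: 1 for the first token of every text block, else 0.
  return (positions % text_block_extract_every_x == 0).astype(np.int32)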