def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(f"Adjusted span indices must lie inside the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}.")

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                           + float_start_sentinel_mask * self._start_sentinel

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()

    return combined_tensors
def _compute_span_pair_embeddings(self,
                                  top_span_embeddings: torch.FloatTensor,
                                  antecedent_embeddings: torch.FloatTensor,
                                  antecedent_offsets: torch.FloatTensor):
    """
    Computes an embedding representation of pairs of spans for the pairwise scoring function
    to consider. This includes both the original span representations, the element-wise
    similarity of the span representations, and an embedding representation of the distance
    between the two spans.

    Parameters
    ----------
    top_span_embeddings : ``torch.FloatTensor``, required.
        Embedding representations of the top spans. Has shape
        (batch_size, num_spans_to_keep, embedding_size).
    antecedent_embeddings : ``torch.FloatTensor``, required.
        Embedding representations of the antecedent spans we are considering
        for each top span. Has shape
        (batch_size, num_spans_to_keep, max_antecedents, embedding_size).
    antecedent_offsets : ``torch.IntTensor``, required.
        The offsets between each top span and its antecedent spans in terms
        of spans we are considering. Has shape (1, max_antecedents).

    Returns
    -------
    span_pair_embeddings : ``torch.FloatTensor``
        Embedding representation of the pair of spans to consider. Has shape
        (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    """
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    target_embeddings = top_span_embeddings.unsqueeze(2).expand_as(antecedent_embeddings)

    # Shape: (1, max_antecedents, embedding_size)
    # Shape (coarse-to-fine pruning): (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    antecedent_distance_embeddings = self._distance_embedding(
        util.bucket_values(antecedent_offsets,
                           num_total_buckets=self._num_distance_buckets))

    if not self._do_coarse_to_fine_prune:
        # Shape: (1, 1, max_antecedents, embedding_size)
        antecedent_distance_embeddings = antecedent_distance_embeddings.unsqueeze(0)

        expanded_distance_embeddings_shape = (antecedent_embeddings.size(0),
                                              antecedent_embeddings.size(1),
                                              antecedent_embeddings.size(2),
                                              antecedent_distance_embeddings.size(-1))
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        antecedent_distance_embeddings = antecedent_distance_embeddings.expand(
            *expanded_distance_embeddings_shape)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = torch.cat([target_embeddings,
                                      antecedent_embeddings,
                                      antecedent_embeddings * target_embeddings,
                                      antecedent_distance_embeddings], -1)
    return span_pair_embeddings
def _compute_span_pair_embeddings(self,
                                  top_span_embeddings: torch.FloatTensor,
                                  antecedent_embeddings: torch.FloatTensor,
                                  antecedent_offsets: torch.FloatTensor):
    """
    Computes an embedding representation of pairs of spans for the pairwise scoring function
    to consider. This includes both the original span representations, the element-wise
    similarity of the span representations, and an embedding representation of the distance
    between the two spans.

    Parameters
    ----------
    top_span_embeddings : ``torch.FloatTensor``, required.
        Embedding representations of the top spans. Has shape
        (batch_size, num_spans_to_keep, embedding_size).
    antecedent_embeddings : ``torch.FloatTensor``, required.
        Embedding representations of the antecedent spans we are considering
        for each top span. Has shape
        (batch_size, num_spans_to_keep, max_antecedents, embedding_size).
    antecedent_offsets : ``torch.IntTensor``, required.
        The offsets between each top span and its antecedent spans in terms
        of spans we are considering. Has shape (1, max_antecedents).

    Returns
    -------
    span_pair_embeddings : ``torch.FloatTensor``
        Embedding representation of the pair of spans to consider. Has shape
        (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    """
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    target_embeddings = top_span_embeddings.unsqueeze(2).expand_as(antecedent_embeddings)

    # Shape: (1, max_antecedents, embedding_size)
    antecedent_distance_embeddings = self._distance_embedding(
        util.bucket_values(antecedent_offsets,
                           num_total_buckets=self._num_distance_buckets))

    # Shape: (1, 1, max_antecedents, embedding_size)
    antecedent_distance_embeddings = antecedent_distance_embeddings.unsqueeze(0)
    expanded_distance_embeddings_shape = (antecedent_embeddings.size(0),
                                          antecedent_embeddings.size(1),
                                          antecedent_embeddings.size(2),
                                          antecedent_distance_embeddings.size(-1))
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    antecedent_distance_embeddings = antecedent_distance_embeddings.expand(
        *expanded_distance_embeddings_shape)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = torch.cat([target_embeddings,
                                      antecedent_embeddings,
                                      antecedent_embeddings * target_embeddings,
                                      antecedent_distance_embeddings], -1)
    return span_pair_embeddings
def _compute_span_pair_embeddings(
    self,
    top_span_embeddings: torch.FloatTensor,
    antecedent_embeddings: torch.FloatTensor,
    antecedent_offsets: torch.FloatTensor,
):
    """
    Computes an embedding representation of pairs of spans for the pairwise scoring function
    to consider. This includes both the original span representations, the element-wise
    similarity of the span representations, and an embedding representation of the distance
    between the two spans.

    # Parameters

    top_span_embeddings : `torch.FloatTensor`, required.
        Embedding representations of the top spans. Has shape
        (batch_size, num_spans_to_keep, embedding_size).
    antecedent_embeddings : `torch.FloatTensor`, required.
        Embedding representations of the antecedent spans we are considering
        for each top span. Has shape
        (batch_size, num_spans_to_keep, max_antecedents, embedding_size).
    antecedent_offsets : `torch.IntTensor`, required.
        The offsets between each top span and its antecedent spans in terms
        of spans we are considering. Has shape
        (batch_size, num_spans_to_keep, max_antecedents).

    # Returns

    span_pair_embeddings : `torch.FloatTensor`
        Embedding representation of the pair of spans to consider. Has shape
        (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    """
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    target_embeddings = top_span_embeddings.unsqueeze(2).expand_as(antecedent_embeddings)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    antecedent_distance_embeddings = self._distance_embedding(
        util.bucket_values(antecedent_offsets,
                           num_total_buckets=self._num_distance_buckets))

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = torch.cat(
        [
            target_embeddings,
            antecedent_embeddings,
            antecedent_embeddings * target_embeddings,
            antecedent_distance_embeddings,
        ],
        -1,
    )
    return span_pair_embeddings
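# The following is a small, self-contained sketch (not part of the model above) of how the
# pair embedding in `_compute_span_pair_embeddings` is assembled. The module attributes
# (`self._distance_embedding`, `self._num_distance_buckets`) are replaced by local stand-ins,
# and all shapes and values are made up purely for illustration.
import torch
from allennlp.nn import util

batch_size, num_spans_to_keep, max_antecedents, embedding_size = 2, 5, 3, 8
num_distance_buckets = 10
distance_embedding = torch.nn.Embedding(num_distance_buckets, 4)

top_span_embeddings = torch.randn(batch_size, num_spans_to_keep, embedding_size)
antecedent_embeddings = torch.randn(batch_size, num_spans_to_keep, max_antecedents, embedding_size)
antecedent_offsets = torch.randint(1, 50, (batch_size, num_spans_to_keep, max_antecedents))

# Broadcast each kept span against all of its candidate antecedents.
target_embeddings = top_span_embeddings.unsqueeze(2).expand_as(antecedent_embeddings)
# Bucket the integer offsets and embed the bucket ids.
antecedent_distance_embeddings = distance_embedding(
    util.bucket_values(antecedent_offsets, num_total_buckets=num_distance_buckets))

span_pair_embeddings = torch.cat([target_embeddings,
                                  antecedent_embeddings,
                                  antecedent_embeddings * target_embeddings,
                                  antecedent_distance_embeddings], -1)
print(span_pair_embeddings.shape)  # torch.Size([2, 5, 3, 28]) == 3 * embedding_size + distance_dim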
def forward(
        self,
        span: torch.Tensor,  # SHAPE: (batch_size, num_spans, span_dim)
        span_pairs: torch.LongTensor  # SHAPE: (batch_size, num_span_pairs, 2)
):
    span1 = span2 = span
    if self.dim_reduce_layer1 is not None:
        span1 = self.dim_reduce_layer1(span)
    if self.dim_reduce_layer2 is not None:
        span2 = self.dim_reduce_layer2(span)

    if not self.pair:
        return span1, span2

    num_spans = span.size(1)

    # get span pair embedding
    span_pairs_p = span_pairs[:, :, 0]
    span_pairs_c = span_pairs[:, :, 1]
    # SHAPE: (batch_size * num_span_pairs)
    flat_span_pairs_p = util.flatten_and_batch_shift_indices(span_pairs_p, num_spans)
    flat_span_pairs_c = util.flatten_and_batch_shift_indices(span_pairs_c, num_spans)
    # SHAPE: (batch_size, num_span_pairs, span_dim)
    span_pair_p_emb = util.batched_index_select(span1, span_pairs_p, flat_span_pairs_p)
    span_pair_c_emb = util.batched_index_select(span2, span_pairs_c, flat_span_pairs_c)

    if self.combine == 'concat':
        # SHAPE: (batch_size, num_span_pairs, span_dim * 2)
        span_pair_emb = torch.cat([span_pair_p_emb, span_pair_c_emb], -1)
    elif self.combine == 'coref':
        # use the indices gap as distance, which requires the indices to be consistent
        # with the order they appear in the sentences
        distance = span_pairs_p - span_pairs_c
        # SHAPE: (batch_size, num_span_pairs, dist_emb_dim)
        distance_embeddings = self.distance_embedding(
            util.bucket_values(distance, num_total_buckets=self.num_distance_buckets))
        # SHAPE: (batch_size, num_span_pairs, span_dim * 3 + dist_emb_dim)
        span_pair_emb = torch.cat([span_pair_p_emb,
                                   span_pair_c_emb,
                                   span_pair_p_emb * span_pair_c_emb,
                                   distance_embeddings], -1)

    if self.repr_layer is not None:
        # SHAPE: (batch_size, num_span_pairs, out_dim)
        span_pair_emb = self.repr_layer(span_pair_emb)

    return span_pair_emb
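# A minimal sketch (made-up shapes) of how `flatten_and_batch_shift_indices` and
# `batched_index_select` cooperate in the forward above: the per-batch pair indices are
# shifted into a flattened (batch * num_spans) index space, and the span embeddings are
# then gathered in a single pass.
import torch
from allennlp.nn import util

batch_size, num_spans, span_dim, num_span_pairs = 2, 4, 6, 3
span = torch.randn(batch_size, num_spans, span_dim)
# Each pair is (parent_index, child_index) into the span dimension.
span_pairs = torch.randint(0, num_spans, (batch_size, num_span_pairs, 2))

span_pairs_p = span_pairs[:, :, 0]
# Shape: (batch_size * num_span_pairs), indices offset by their batch position.
flat_span_pairs_p = util.flatten_and_batch_shift_indices(span_pairs_p, num_spans)
# Shape: (batch_size, num_span_pairs, span_dim)
span_pair_p_emb = util.batched_index_select(span, span_pairs_p, flat_span_pairs_p)

# Passing the precomputed flattened indices is only an optimisation; this is equivalent:
assert torch.allclose(span_pair_p_emb, util.batched_index_select(span, span_pairs_p))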
def _compute_distance_embeddings(self, top_trig_spans, top_arg_spans):
    """
    Compute integer distance and positional features of two tensors of span interval indices
    of different size. Embeds the bucketed distance values.

    :param top_trig_spans: Size (batch_size, num_trig_spans, 2), 2 for
        (trigger_start_index, trigger_end_index)
    :param top_arg_spans: Size (batch_size, num_arg_spans, 2), 2 for
        (arg_start_index, arg_end_index)
    :return res: Tensor of size (batch_size, num_trig_spans, num_arg_spans, dist_emb_dim + 2),
        the concatenation of:
        - dist_emb: the embedded, bucketed integer distance values,
          size (batch_size, num_trig_spans, num_arg_spans, dist_emb_dim).
        - trigger_before_feature: same leading sizes, boolean feature indicating that the
          trigger comes before the argument.
        - trigger_overlap_feature: same leading sizes, boolean feature indicating that the
          trigger overlaps with the argument.
    """
    # tile the span matrices
    num_trigs = top_trig_spans.size(1)
    num_spans = top_arg_spans.size(1)
    trig_span_tiled = top_trig_spans.unsqueeze(2).repeat(1, 1, num_spans, 1)
    arg_span_tiled = top_arg_spans.unsqueeze(1).repeat(1, num_trigs, 1, 1)

    # get start_idc and end_idc
    trig_span_starts = trig_span_tiled[:, :, :, 0]
    trig_span_ends = trig_span_tiled[:, :, :, 1]
    arg_span_starts = arg_span_tiled[:, :, :, 0]
    arg_span_ends = arg_span_tiled[:, :, :, 1]

    # compute all distances, abs().min is the correct dist value
    dist_start2end = trig_span_starts - arg_span_ends
    dist_end2start = trig_span_ends - arg_span_starts
    dist = torch.min(dist_start2end.abs(), dist_end2start.abs())

    # When the trigger overlaps with the arg span, also set the distance to zero.
    # Overlap happens when the trigger is neither entirely before nor entirely after the span.
    trigger_before = (trig_span_starts <= trig_span_ends) & (trig_span_ends < arg_span_starts)
    trigger_after = (arg_span_starts <= arg_span_ends) & (arg_span_ends < trig_span_starts)
    trigger_overlap = ~(trigger_before | trigger_after)
    dist[trigger_overlap] = 0

    # compute bucketed embeddings and add the before / overlap boolean features
    dist_buckets = util.bucket_values(dist, self._num_distance_buckets)
    dist_emb = self._distance_embedding(dist_buckets)
    trigger_before_feature = trigger_before.float().unsqueeze(-1)
    trigger_overlap_feature = trigger_overlap.float().unsqueeze(-1)
    res = torch.cat([dist_emb, trigger_before_feature, trigger_overlap_feature], dim=-1)
    return res
def _compute_distance_embeddings(self, top_trig_indices, top_arg_spans):
    top_trig_ixs = top_trig_indices.unsqueeze(2)
    arg_span_starts = top_arg_spans[:, :, 0].unsqueeze(1)
    arg_span_ends = top_arg_spans[:, :, 1].unsqueeze(1)
    dist_from_start = top_trig_ixs - arg_span_starts
    dist_from_end = top_trig_ixs - arg_span_ends

    # Distance from trigger to arg.
    dist = torch.min(dist_from_start.abs(), dist_from_end.abs())

    # When the trigger is inside the arg span, also set the distance to zero.
    trigger_inside = (top_trig_ixs >= arg_span_starts) & (top_trig_ixs <= arg_span_ends)
    dist[trigger_inside] = 0

    dist_buckets = util.bucket_values(dist, self._num_distance_buckets)
    dist_emb = self._distance_embedding(dist_buckets)

    trigger_before_feature = (top_trig_ixs < arg_span_starts).float().unsqueeze(-1)
    trigger_inside_feature = trigger_inside.float().unsqueeze(-1)
    res = torch.cat([dist_emb, trigger_before_feature, trigger_inside_feature], dim=-1)
    return res
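# A tiny worked example (values chosen for illustration) of the distance computation above:
# for a trigger at token index t and an argument span [s, e], the distance is
# min(|t - s|, |t - e|), clamped to zero when the trigger falls inside the span.
import torch

top_trig_indices = torch.tensor([[2, 9]])          # (batch=1, num_trigs=2)
top_arg_spans = torch.tensor([[[4, 6], [0, 1]]])   # (batch=1, num_args=2, 2)

trig = top_trig_indices.unsqueeze(2)               # (1, 2, 1)
starts = top_arg_spans[:, :, 0].unsqueeze(1)       # (1, 1, 2)
ends = top_arg_spans[:, :, 1].unsqueeze(1)         # (1, 1, 2)

dist = torch.min((trig - starts).abs(), (trig - ends).abs())
trigger_inside = (trig >= starts) & (trig <= ends)
dist[trigger_inside] = 0
print(dist)  # tensor([[[2, 1], [3, 8]]])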
def forward(self,  # pylint: disable=arguments-differ
            sequence_tensor: torch.FloatTensor,
            indices: torch.LongTensor) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in indices.split(1, dim=-1)]
    start_embeddings = batched_index_select(sequence_tensor, span_starts)
    end_embeddings = batched_index_select(sequence_tensor, span_ends)

    combined_tensors = combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = bucket_values(span_ends - span_starts,
                                        num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)

    return combined_tensors
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    start_embeddings = batched_index_select(sequence_tensor, span_starts)
    end_embeddings = batched_index_select(sequence_tensor, span_ends)

    combined_tensors = combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = bucket_values(span_ends - span_starts,
                                        num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()

    return combined_tensors
def forward(
    self,
    sequence_tensor: torch.FloatTensor,
    span_indices: torch.LongTensor,
    sequence_mask: torch.BoolTensor = None,
    span_indices_mask: torch.BoolTensor = None,
):
    """
    Given a sequence tensor, extract spans, concatenate width embeddings
    when needed and return representations of them.

    # Parameters

    sequence_tensor : `torch.FloatTensor`, required.
        A tensor of shape (batch_size, sequence_length, embedding_size)
        representing an embedded sequence of words.
    span_indices : `torch.LongTensor`, required.
        A tensor of shape `(batch_size, num_spans, 2)`, where the last
        dimension represents the inclusive start and end indices of the
        span to be extracted from the `sequence_tensor`.
    sequence_mask : `torch.BoolTensor`, optional (default = `None`).
        A tensor of shape (batch_size, sequence_length) representing padded
        elements of the sequence.
    span_indices_mask : `torch.BoolTensor`, optional (default = `None`).
        A tensor of shape (batch_size, num_spans) representing the valid
        spans in the `indices` tensor. This mask is optional because
        sometimes it's easier to worry about masking after calling this
        function, rather than passing a mask directly.

    # Returns

    A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
    where `embedded_span_size` depends on the way spans are represented.
    """
    # shape (batch_size, num_spans, embedding_dim)
    span_embeddings = self._embed_spans(sequence_tensor, span_indices, sequence_mask, span_indices_mask)
    if self._span_width_embedding is not None:
        # width = end_index - start_index + 1 since `SpanField` uses inclusive indices.
        # But here we do not add 1 because we often initialize the span width
        # embedding matrix with `num_width_embeddings = max_span_width`.
        # shape (batch_size, num_spans)
        widths_minus_one = span_indices[..., 1] - span_indices[..., 0]
        if self._bucket_widths:
            widths_minus_one = util.bucket_values(
                widths_minus_one, num_total_buckets=self._num_width_embeddings  # type: ignore
            )

        # Embed the span widths and concatenate to the rest of the representations.
        span_width_embeddings = self._span_width_embedding(widths_minus_one)
        span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        # Here we are masking the spans which were originally passed in as padding.
        return span_embeddings * span_indices_mask.unsqueeze(-1)

    return span_embeddings
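# A hedged usage sketch of the interface documented above, using AllenNLP's
# EndpointSpanExtractor as a concrete SpanExtractor. The constructor arguments shown here
# reflect my understanding of its signature and may need adjusting for a particular
# AllenNLP release.
import torch
from allennlp.modules.span_extractors import EndpointSpanExtractor

extractor = EndpointSpanExtractor(input_dim=16,
                                  combination="x,y",
                                  num_width_embeddings=10,
                                  span_width_embedding_dim=5,
                                  bucket_widths=True)

sequence_tensor = torch.randn(2, 20, 16)                       # (batch, seq_len, dim)
span_indices = torch.tensor([[[0, 3], [5, 5]],
                             [[2, 10], [0, 0]]])               # inclusive (start, end) pairs
span_indices_mask = torch.tensor([[True, True], [True, False]])

span_embeddings = extractor(sequence_tensor, span_indices, span_indices_mask=span_indices_mask)
print(span_embeddings.shape)  # (2, 2, 37): "x,y" gives 16 + 16, plus a 5-dim width embedding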
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # We want `exclusive` span starts, so we remove 1 from the forward span starts
    # as the AllenNLP ``SpanField`` is inclusive.
    # shape (batch_size, num_spans)
    exclusive_span_starts = span_starts - 1
    # shape (batch_size, num_spans, 1)
    start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

    # We want `exclusive` span ends for the backward direction
    # (so that the `start` of the span in that direction is exclusive), so
    # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

    # shape (batch_size, num_spans, 1)
    end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

    # As we added 1 to the span_ends to make them exclusive, which might have caused indices
    # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
    # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
    # The same argument follows for the exclusive span start indices.
    exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
    exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

    # We'll check the indices here at runtime, because it's difficult to debug
    # if this goes wrong and it's tricky to get right.
    if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
        raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                         f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                         f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                         f"{sequence_lengths}.")

    # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
    # Forward Direction: end indices are inclusive, so we can just use span_ends.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

    # Backward Direction: The backward start embeddings use the `forward` end
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
    # Backward Direction: The backward end embeddings use the `forward` start
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        # If we're using sentinels, we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with either the start sentinel,
        # or the end sentinel.
        float_end_sentinel_mask = end_sentinel_mask.float()
        float_start_sentinel_mask = start_sentinel_mask.float()
        forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                   + float_start_sentinel_mask * self._start_sentinel
        backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                    + float_end_sentinel_mask * self._end_sentinel

    # Now we combine the forward and backward spans in the manner specified by the
    # respective combinations and concatenate these representations.
    # Shape (batch_size, num_spans, forward_combination_dim)
    forward_spans = util.combine_tensors(self._forward_combination,
                                         [forward_start_embeddings, forward_end_embeddings])
    # Shape (batch_size, num_spans, backward_combination_dim)
    backward_spans = util.combine_tensors(self._backward_combination,
                                          [backward_start_embeddings, backward_end_embeddings])
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.float().unsqueeze(-1)

    return span_embeddings
def test_bucket_values(self):
    indices = torch.LongTensor([1, 2, 7, 1, 56, 900])
    bucketed_distances = util.bucket_values(indices)
    numpy.testing.assert_array_equal(bucketed_distances.numpy(),
                                     numpy.array([1, 2, 5, 1, 8, 9]))
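# The test above pins down the bucketing scheme: values up to `num_identity_buckets`
# (default 4) keep their own bucket, larger values fall into logarithmically sized buckets,
# and everything is clamped to `num_total_buckets - 1`. Below is a minimal re-implementation
# sketch of that behaviour (not the library code itself), which reproduces the expected
# values in the test.
import math
import torch

def bucket_values_sketch(distances: torch.Tensor,
                         num_identity_buckets: int = 4,
                         num_total_buckets: int = 10) -> torch.Tensor:
    # Log-spaced bucket index for "large" distances; the +3 offset makes the log buckets
    # continue where the default identity buckets leave off.
    logspace_index = (distances.float().log() / math.log(2)).floor().long() + 3
    use_identity = (distances <= num_identity_buckets).long()
    combined_index = use_identity * distances + (1 - use_identity) * logspace_index
    return combined_index.clamp(0, num_total_buckets - 1)

print(bucket_values_sketch(torch.LongTensor([1, 2, 7, 1, 56, 900])))
# tensor([1, 2, 5, 1, 8, 9])  -- matches the expected array in the test above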
def forward(
    self,
    sequence_tensor: torch.FloatTensor,
    span_indices: torch.LongTensor,
    sequence_mask: torch.LongTensor = None,
    span_indices_mask: torch.LongTensor = None,
) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        if sequence_tensor.size(-1) != self._input_dim:
            raise ValueError(
                f"Dimension mismatch expected ({sequence_tensor.size(-1)}) "
                f"received ({self._input_dim}).")
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP `SpanField` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(
                f"Adjusted span indices must lie inside the sequence tensor, "
                f"but found: exclusive_span_starts: {exclusive_span_starts}.")

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = (start_embeddings * (1 - float_start_sentinel_mask)
                            + float_start_sentinel_mask * self._start_sentinel)

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()

    return combined_tensors
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # We want `exclusive` span starts, so we remove 1 from the forward span starts
    # as the AllenNLP ``SpanField`` is inclusive.
    # shape (batch_size, num_spans)
    exclusive_span_starts = span_starts - 1
    # shape (batch_size, num_spans, 1)
    start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

    # We want `exclusive` span ends for the backward direction
    # (so that the `start` of the span in that direction is exclusive), so
    # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

    # shape (batch_size, num_spans, 1)
    end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

    # As we added 1 to the span_ends to make them exclusive, which might have caused indices
    # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
    # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
    # The same argument follows for the exclusive span start indices.
    exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
    exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

    # We'll check the indices here at runtime, because it's difficult to debug
    # if this goes wrong and it's tricky to get right.
    if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
        raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                         f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                         f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                         f"{sequence_lengths}.")

    # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
    # Forward Direction: end indices are inclusive, so we can just use span_ends.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

    # Backward Direction: The backward start embeddings use the `forward` end
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
    # Backward Direction: The backward end embeddings use the `forward` start
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        # If we're using sentinels, we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with either the start sentinel,
        # or the end sentinel.
        float_end_sentinel_mask = end_sentinel_mask.float()
        float_start_sentinel_mask = start_sentinel_mask.float()
        forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                   + float_start_sentinel_mask * self._start_sentinel
        backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                    + float_end_sentinel_mask * self._end_sentinel

    # Now we combine the forward and backward spans in the manner specified by the
    # respective combinations and concatenate these representations.
    # Shape (batch_size, num_spans, forward_combination_dim)
    forward_spans = util.combine_tensors(self._forward_combination,
                                         [forward_start_embeddings, forward_end_embeddings])
    # Shape (batch_size, num_spans, backward_combination_dim)
    backward_spans = util.combine_tensors(self._backward_combination,
                                          [backward_start_embeddings, backward_end_embeddings])
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.float().unsqueeze(-1)

    return span_embeddings
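# A hedged usage sketch of the bidirectional extractor above, assuming it is wired up as
# AllenNLP's BidirectionalEndpointSpanExtractor; the constructor arguments and defaults shown
# here are my best understanding of that class and may differ between releases. The input is
# expected to be the output of a bidirectional encoder, with the forward and backward halves
# concatenated on the last dimension.
import torch
from allennlp.modules.span_extractors import BidirectionalEndpointSpanExtractor

# input_dim must be even: the first half is treated as the forward encoder states,
# the second half as the backward encoder states.
extractor = BidirectionalEndpointSpanExtractor(input_dim=8,
                                               forward_combination="y-x",
                                               backward_combination="x-y",
                                               use_sentinels=True)

sequence_tensor = torch.randn(1, 12, 8)
span_indices = torch.tensor([[[0, 2], [4, 11]]])  # inclusive starts/ends; the span starting at
                                                  # position 0 exercises the learned start sentinel
span_embeddings = extractor(sequence_tensor, span_indices)
print(span_embeddings.shape)  # (1, 2, 8): "y-x" and "x-y" each contribute input_dim / 2 = 4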
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    # span_starts, span_ends = span_indices.split(1, dim=-1)
    # batch_size, max_seq_len, _ = sequence_tensor.shape
    # max_span_num = span_indices.shape[1]
    # range_vector = util.get_range_vector(max_seq_len, util.get_device_of(sequence_tensor)).repeat(
    #     (batch_size, max_span_num, 1))
    # att_mask = (span_ends >= range_vector) - (span_starts > range_vector)
    # att_mask = att_mask * span_mask.unsqueeze(-1)
    # res = self._attention(sequence_tensor.repeat((max_span_num, 1, 1)), att_mask)
    # combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

    # both of shape (batch_size, num_spans, 1)
    span_starts, span_ends = span_indices.split(1, dim=-1)

    # shape (batch_size, num_spans, 1)
    # These span widths are off by 1, because the span ends are `inclusive`.
    span_widths = span_ends - span_starts

    # We need to know the maximum span width so we can
    # generate indices to extract the spans from the sequence tensor.
    # These indices will then get masked below, such that if the length
    # of a given span is smaller than the max, the rest of the values
    # are masked.
    max_batch_span_width = span_widths.max().item() + 1

    # shape (batch_size, sequence_length, 1)
    # global_attention_logits = self._global_attention(sequence_tensor)

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                   util.get_device_of(sequence_tensor)).view(1, 1, -1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.
    #
    # We're using <= here (and for the mask below) because the span ends are
    # inclusive, so we want to include indices which are equal to span_widths rather
    # than using it as a non-inclusive upper bound.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the sequence
    # have an end index < max_batch_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)
    # Zero out padded positions, then max-pool over the span dimension.
    span_embeddings = span_embeddings * span_mask.unsqueeze(-1)
    span_embeddings = span_embeddings.max(2)[0]

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_widths = span_widths.squeeze(-1)

        span_width_embeddings = self._span_width_embedding(span_widths)
        combined_tensors = torch.cat([span_embeddings, span_width_embeddings], -1)
    else:
        combined_tensors = span_embeddings

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()

    return combined_tensors
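# A small standalone sketch (made-up shapes, no class state) of the span-gathering trick used
# in the forward above: build a (num_spans, max_width) grid of token indices counting back
# from each span end, mask out positions that fall outside the span, gather the token
# embeddings, and pool over the span dimension.
import torch
from allennlp.nn import util

sequence_tensor = torch.randn(1, 10, 4)                 # (batch, seq_len, dim)
span_indices = torch.tensor([[[0, 2], [5, 8]]])         # inclusive (start, end)

span_starts, span_ends = span_indices.split(1, dim=-1)  # each (1, num_spans, 1)
span_widths = span_ends - span_starts
max_batch_span_width = span_widths.max().item() + 1     # 4 for the spans above

# (1, 1, max_width) range, counted back from each span end.
max_span_range_indices = util.get_range_vector(
    max_batch_span_width, util.get_device_of(sequence_tensor)).view(1, 1, -1)
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends - max_span_range_indices
span_mask = span_mask * (raw_span_indices >= 0).float()
token_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

# (1, num_spans, max_width, dim): every token of every span, padded positions included.
span_token_embeddings = util.batched_index_select(sequence_tensor, token_indices)
# As in the forward above, padded positions are zeroed before taking the max over the span.
span_embeddings = (span_token_embeddings * span_mask.unsqueeze(-1)).max(2)[0]
print(span_embeddings.shape)                            # (1, 2, 4)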
def forward(
    self,
    sequence_tensor: torch.FloatTensor,
    span_indices: torch.LongTensor,
    sequence_mask: torch.BoolTensor = None,
    span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(
        int(self._input_dim / 2), dim=-1
    )
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # Exclusive span starts for the forward direction; positions that fall off the
    # front of the sequence are flagged so they can be replaced by the start sentinel.
    exclusive_span_starts = span_starts - 1
    start_sentinel_mask = (exclusive_span_starts == -1).unsqueeze(-1)
    # Exclusive span ends for the backward direction.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length of the sequence_tensor.
        sequence_lengths = torch.ones_like(
            sequence_tensor[:, 0, 0], dtype=torch.long
        ) * sequence_tensor.size(1)

    end_sentinel_mask = (exclusive_span_ends >= sequence_lengths.unsqueeze(-1)).unsqueeze(-1)

    # Zero out the out-of-bounds indices; they are replaced by sentinels below.
    exclusive_span_ends = exclusive_span_ends * ~end_sentinel_mask.squeeze(-1)
    exclusive_span_starts = exclusive_span_starts * ~start_sentinel_mask.squeeze(-1)

    if (exclusive_span_starts < 0).any() or (
        exclusive_span_ends > sequence_lengths.unsqueeze(-1)
    ).any():
        raise ValueError(
            f"Adjusted span indices must lie inside the length of the sequence tensor, "
            f"but found: exclusive_span_starts: {exclusive_span_starts}, "
            f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
            f"{sequence_lengths}."
        )

    # Forward direction: exclusive starts, inclusive ends.
    forward_start_embeddings = util.batched_index_select(
        forward_sequence, exclusive_span_starts
    )
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)
    # Backward direction: the roles of the indices are swapped because we read backwards.
    backward_start_embeddings = util.batched_index_select(
        backward_sequence, exclusive_span_ends
    )
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        forward_start_embeddings = (
            forward_start_embeddings * ~start_sentinel_mask
            + start_sentinel_mask * self._start_sentinel
        )
        backward_start_embeddings = (
            backward_start_embeddings * ~end_sentinel_mask
            + end_sentinel_mask * self._end_sentinel
        )

    forward_spans = util.combine_tensors(
        self._forward_combination, [forward_start_embeddings, forward_end_embeddings]
    )
    backward_spans = util.combine_tensors(
        self._backward_combination, [backward_start_embeddings, backward_end_embeddings]
    )
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        if self._bucket_widths:
            span_widths = util.bucket_values(
                span_ends - span_starts, num_total_buckets=self._num_width_embeddings
            )
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.unsqueeze(-1)

    return span_embeddings