def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
    combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
    dot_product = torch.matmul(combined_tensors, self._weight_vector)
    return torch.matmul(self._activation(dot_product + self._bias), self._V)
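# Hedged illustration (not part of the original module): what util.combine_tensors produces
# for the common "x,y,x*y" combination, written with plain torch. The combination string used
# by the module above is configurable, so "x,y,x*y" is only an example, and weight_vector is a
# hypothetical stand-in for self._weight_vector.
import torch

batch_size, dim = 4, 10
tensor_1 = torch.randn(batch_size, dim)
tensor_2 = torch.randn(batch_size, dim)

# "x,y,x*y" concatenates both inputs and their elementwise product along the last dimension.
combined = torch.cat([tensor_1, tensor_2, tensor_1 * tensor_2], dim=-1)
assert combined.shape == (batch_size, 3 * dim)

# A weight vector of the combined width then reduces this to a single score per item,
# mirroring the torch.matmul(combined_tensors, self._weight_vector) step above.
weight_vector = torch.randn(3 * dim)
dot_product = torch.matmul(combined, weight_vector)
assert dot_product.shape == (batch_size,)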
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(f"Adjusted span indices must lie inside the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}.")

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                           + float_start_sentinel_mask * self._start_sentinel

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()
    return combined_tensors
def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    # TODO(mattg): Remove the need for this tiling.
    # https://github.com/allenai/allennlp/pull/1235#issuecomment-391540133
    tiled_vector = vector.unsqueeze(1).expand(vector.size()[0],
                                              matrix.size()[1],
                                              vector.size()[1])
    combined_tensors = util.combine_tensors(self._combination, [tiled_vector, matrix])
    dot_product = torch.matmul(combined_tensors, self._weight_vector)
    return self._activation(dot_product + self._bias)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # If a constituent has more than one word, take the average of its word representations.
    if x.size()[1] > 1:
        x = torch.mean(x, dim=1).unsqueeze(1)
    if y.size()[1] > 1:
        y = torch.mean(y, dim=1).unsqueeze(1)
    combined = util.combine_tensors(self._combination, [x, y])
    product = torch.matmul(combined, self._weight_vector)
    return self._activation(product)
def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
    assert mask is not None, "TalkativeIntraSentenceAttentionEncoder requires a mask to be provided."
    batch_size, sequence_length, _ = tokens.size()
    # Shape: (batch_size, sequence_length, sequence_length)
    similarity_matrix = self._matrix_attention(tokens, tokens)

    if self._num_attention_heads > 1:
        # In this case, the similarity matrix actually has shape
        # (batch_size, sequence_length, sequence_length, num_heads). To make the rest of the
        # logic below easier, we'll permute this to
        # (batch_size, sequence_length, num_heads, sequence_length).
        similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

    # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
    similarity_matrix = similarity_matrix.contiguous()
    temp_mask = mask.unsqueeze(1)
    # Shape: (batch_size, sequence_length, projection_dim)
    output_token_representation = self._projection(tokens)
    self.report_unnormalized_log_attn_weights(similarity_matrix * temp_mask,
                                              output_token_representation)

    intra_sentence_attention = util.masked_softmax(similarity_matrix.contiguous(), mask)

    if self._num_attention_heads > 1:
        # We need to split and permute the output representation to be
        # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
        # do a proper weighted sum with `intra_sentence_attention`.
        shape = list(output_token_representation.size())
        new_shape = shape[:-1] + [self._num_attention_heads, -1]
        # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
        output_token_representation = output_token_representation.view(*new_shape)
        # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
        output_token_representation = output_token_representation.permute(0, 2, 1, 3)

    # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
    attended_sentence = util.weighted_sum(output_token_representation,
                                          intra_sentence_attention)

    if self._num_attention_heads > 1:
        # Here we concatenate the weighted representation for each head. We'll accomplish this
        # just with a resize.
        # Shape: (batch_size, sequence_length, projection_dim)
        attended_sentence = attended_sentence.view(batch_size, sequence_length, -1)

    # Shape: (batch_size, sequence_length, combination_dim)
    combined_tensors = util.combine_tensors(self._combination, [tokens, attended_sentence])
    return self._output_projection(combined_tensors)
def forward(self, tokens, mask):  # pylint: disable=arguments-differ
    batch_size, sequence_length, _ = tokens.size()
    # Shape: (batch_size, sequence_length, sequence_length)
    similarity_matrix = self._matrix_attention(tokens, tokens)

    if self._num_attention_heads > 1:
        # In this case, the similarity matrix actually has shape
        # (batch_size, sequence_length, sequence_length, num_heads). To make the rest of the
        # logic below easier, we'll permute this to
        # (batch_size, sequence_length, num_heads, sequence_length).
        similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

    # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
    intra_sentence_attention = util.last_dim_softmax(similarity_matrix.contiguous(), mask)

    # Shape: (batch_size, sequence_length, projection_dim)
    output_token_representation = self._projection(tokens)

    if self._num_attention_heads > 1:
        # We need to split and permute the output representation to be
        # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
        # do a proper weighted sum with `intra_sentence_attention`.
        shape = list(output_token_representation.size())
        new_shape = shape[:-1] + [self._num_attention_heads, -1]
        # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
        output_token_representation = output_token_representation.view(*new_shape)
        # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
        output_token_representation = output_token_representation.permute(0, 2, 1, 3)

    # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
    attended_sentence = util.weighted_sum(output_token_representation,
                                          intra_sentence_attention)

    if self._num_attention_heads > 1:
        # Here we concatenate the weighted representation for each head. We'll accomplish this
        # just with a resize.
        # Shape: (batch_size, sequence_length, projection_dim)
        attended_sentence = attended_sentence.view(batch_size, sequence_length, -1)

    # Shape: (batch_size, sequence_length, combination_dim)
    combined_tensors = util.combine_tensors(self._combination, [tokens, attended_sentence])
    return self._output_projection(combined_tensors)
def forward(self,  # pylint: disable=arguments-differ
            matrix_1: torch.Tensor,
            matrix_2: torch.Tensor) -> torch.Tensor:
    # TODO(mattg): Remove the need for this tiling.
    # https://github.com/allenai/allennlp/pull/1235#issuecomment-391540133
    tiled_matrix_1 = matrix_1.unsqueeze(2).expand(matrix_1.size()[0],
                                                  matrix_1.size()[1],
                                                  matrix_2.size()[1],
                                                  matrix_1.size()[2])
    tiled_matrix_2 = matrix_2.unsqueeze(1).expand(matrix_2.size()[0],
                                                  matrix_1.size()[1],
                                                  matrix_2.size()[1],
                                                  matrix_2.size()[2])
    combined_tensors = util.combine_tensors(self._combination, [tiled_matrix_1, tiled_matrix_2])
    dot_product = torch.matmul(combined_tensors, self._weight_vector)
    return self._activation(dot_product + self._bias)
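# Hedged shape check (illustration only, not part of the module above): the tiling turns two
# matrices of shape (batch, rows_1, dim) and (batch, rows_2, dim) into a pair of
# (batch, rows_1, rows_2, dim) tensors so every row of matrix_1 is combined with every row
# of matrix_2. All sizes here are made up for the example.
import torch

batch, rows_1, rows_2, dim = 2, 3, 4, 5
matrix_1 = torch.randn(batch, rows_1, dim)
matrix_2 = torch.randn(batch, rows_2, dim)

tiled_1 = matrix_1.unsqueeze(2).expand(batch, rows_1, rows_2, dim)
tiled_2 = matrix_2.unsqueeze(1).expand(batch, rows_1, rows_2, dim)
assert tiled_1.shape == tiled_2.shape == (batch, rows_1, rows_2, dim)

# For an elementwise-product combination, the tiled product matches a broadcasted product.
assert torch.allclose(tiled_1 * tiled_2, matrix_1.unsqueeze(2) * matrix_2.unsqueeze(1))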
def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
    batch_size, sequence_length, _ = tokens.size()
    # Shape: (batch_size, sequence_length, sequence_length)
    similarity_matrix = self._matrix_attention(tokens, tokens)

    if self._num_attention_heads > 1:
        # In this case, the similarity matrix actually has shape
        # (batch_size, sequence_length, sequence_length, num_heads). To make the rest of the
        # logic below easier, we'll permute this to
        # (batch_size, sequence_length, num_heads, sequence_length).
        similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

    # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
    intra_sentence_attention = util.masked_softmax(similarity_matrix.contiguous(), mask)

    # Shape: (batch_size, sequence_length, projection_dim)
    output_token_representation = self._projection(tokens)

    if self._num_attention_heads > 1:
        # We need to split and permute the output representation to be
        # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
        # do a proper weighted sum with `intra_sentence_attention`.
        shape = list(output_token_representation.size())
        new_shape = shape[:-1] + [self._num_attention_heads, -1]
        # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
        output_token_representation = output_token_representation.view(*new_shape)
        # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
        output_token_representation = output_token_representation.permute(0, 2, 1, 3)

    # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
    attended_sentence = util.weighted_sum(output_token_representation,
                                          intra_sentence_attention)

    if self._num_attention_heads > 1:
        # Here we concatenate the weighted representation for each head. We'll accomplish this
        # just with a resize.
        # Shape: (batch_size, sequence_length, projection_dim)
        attended_sentence = attended_sentence.view(batch_size, sequence_length, -1)

    # Shape: (batch_size, sequence_length, combination_dim)
    combined_tensors = util.combine_tensors(self._combination, [tokens, attended_sentence])
    return self._output_projection(combined_tensors)
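# Hedged shape check (illustration only, not part of the encoder): the multi-head
# view/permute bookkeeping above, written with plain torch so the dimensions can be verified.
# The einsum is assumed to match what util.weighted_sum computes for this attention/matrix
# shape pairing; all sizes are made up for the example.
import torch

batch_size, seq_len, num_heads, proj_dim = 2, 5, 4, 16
head_dim = proj_dim // num_heads

projected = torch.randn(batch_size, seq_len, proj_dim)
attention = torch.softmax(torch.randn(batch_size, seq_len, num_heads, seq_len), dim=-1)

# (batch_size, seq_len, num_heads, head_dim) -> (batch_size, num_heads, seq_len, head_dim)
per_head = projected.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

# One weighted sum per head: sum over the attended-word dimension.
attended = torch.einsum("bqhw,bhwd->bqhd", attention, per_head)

# Concatenate the heads back together with a reshape.
attended = attended.reshape(batch_size, seq_len, -1)
assert attended.shape == (batch_size, seq_len, proj_dim)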
def forward(self,  # pylint: disable=arguments-differ
            sequence_tensor: torch.FloatTensor,
            indices: torch.LongTensor) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in indices.split(1, dim=-1)]
    start_embeddings = batched_index_select(sequence_tensor, span_starts)
    end_embeddings = batched_index_select(sequence_tensor, span_ends)

    combined_tensors = combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = bucket_values(span_ends - span_starts,
                                        num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)
    return combined_tensors
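# Hedged illustration (not from the original module): what the start/end gathering above does,
# written with plain torch.gather on a tiny example; batched_index_select_equivalent is a
# hypothetical stand-in for the library helper.
import torch

batch_size, seq_len, dim, num_spans = 2, 6, 4, 3
sequence_tensor = torch.randn(batch_size, seq_len, dim)
# Inclusive (start, end) token indices for each span, shape (batch_size, num_spans, 2).
span_indices = torch.tensor([[[0, 1], [2, 4], [5, 5]],
                             [[1, 1], [0, 3], [2, 5]]])

span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

def batched_index_select_equivalent(target, indices):
    # Expand indices to (batch_size, num_spans, dim) and gather along the sequence axis.
    return target.gather(1, indices.unsqueeze(-1).expand(-1, -1, target.size(-1)))

start_embeddings = batched_index_select_equivalent(sequence_tensor, span_starts)
end_embeddings = batched_index_select_equivalent(sequence_tensor, span_ends)
assert start_embeddings.shape == (batch_size, num_spans, dim)
assert end_embeddings.shape == (batch_size, num_spans, dim)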
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    start_embeddings = batched_index_select(sequence_tensor, span_starts)
    end_embeddings = batched_index_select(sequence_tensor, span_ends)

    combined_tensors = combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = bucket_values(span_ends - span_starts,
                                        num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()
    return combined_tensors
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # We want `exclusive` span starts, so we remove 1 from the forward span starts
    # as the AllenNLP ``SpanField`` is inclusive.
    # shape (batch_size, num_spans)
    exclusive_span_starts = span_starts - 1
    # shape (batch_size, num_spans, 1)
    start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
    # We want `exclusive` span ends for the backward direction
    # (so that the `start` of the span in that direction is exclusive), so
    # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

    # shape (batch_size, num_spans, 1)
    end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

    # As we added 1 to the span_ends to make them exclusive, which might have caused indices
    # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
    # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
    # The same argument follows for the exclusive span start indices.
    exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
    exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

    # We'll check the indices here at runtime, because it's difficult to debug
    # if this goes wrong and it's tricky to get right.
    if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
        raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                         f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                         f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                         f"{sequence_lengths}.")

    # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
    # Forward Direction: end indices are inclusive, so we can just use span_ends.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

    # Backward Direction: The backward start embeddings use the `forward` end
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
    # Backward Direction: The backward end embeddings use the `forward` start
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        # If we're using sentinels, we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with either the start sentinel,
        # or the end sentinel.
        float_end_sentinel_mask = end_sentinel_mask.float()
        float_start_sentinel_mask = start_sentinel_mask.float()
        forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                   + float_start_sentinel_mask * self._start_sentinel
        backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                    + float_end_sentinel_mask * self._end_sentinel

    # Now we combine the forward and backward spans in the manner specified by the
    # respective combinations and concatenate these representations.
    # Shape (batch_size, num_spans, forward_combination_dim)
    forward_spans = util.combine_tensors(self._forward_combination,
                                         [forward_start_embeddings, forward_end_embeddings])
    # Shape (batch_size, num_spans, backward_combination_dim)
    backward_spans = util.combine_tensors(self._backward_combination,
                                          [backward_start_embeddings, backward_end_embeddings])
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.float().unsqueeze(-1)
    return span_embeddings
def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
    combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
    dot_product = torch.matmul(combined_tensors, self._weight_vector)
    return self._activation(dot_product + self._bias)
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        if sequence_tensor.size(-1) != self._input_dim:
            raise ValueError(f"Dimension mismatch: expected ({self._input_dim}) "
                             f"received ({sequence_tensor.size(-1)}).")
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP `SpanField` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(f"Adjusted span indices must lie inside the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}.")

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = (start_embeddings * (1 - float_start_sentinel_mask)
                            + float_start_sentinel_mask * self._start_sentinel)

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
        combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()
    return combined_tensors
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # We want `exclusive` span starts, so we remove 1 from the forward span starts
    # as the AllenNLP ``SpanField`` is inclusive.
    # shape (batch_size, num_spans)
    exclusive_span_starts = span_starts - 1
    # shape (batch_size, num_spans, 1)
    start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
    # We want `exclusive` span ends for the backward direction
    # (so that the `start` of the span in that direction is exclusive), so
    # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

    # shape (batch_size, num_spans, 1)
    end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

    # As we added 1 to the span_ends to make them exclusive, which might have caused indices
    # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
    # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
    # The same argument follows for the exclusive span start indices.
    exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
    exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

    # We'll check the indices here at runtime, because it's difficult to debug
    # if this goes wrong and it's tricky to get right.
    if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
        raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                         f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                         f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                         f"{sequence_lengths}.")

    # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
    # Forward Direction: end indices are inclusive, so we can just use span_ends.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

    # Backward Direction: The backward start embeddings use the `forward` end
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
    # Backward Direction: The backward end embeddings use the `forward` start
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        # If we're using sentinels, we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with either the start sentinel,
        # or the end sentinel.
        float_end_sentinel_mask = end_sentinel_mask.float()
        float_start_sentinel_mask = start_sentinel_mask.float()
        forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                   + float_start_sentinel_mask * self._start_sentinel
        backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                    + float_end_sentinel_mask * self._end_sentinel

    # Now we combine the forward and backward spans in the manner specified by the
    # respective combinations and concatenate these representations.
    # Shape (batch_size, num_spans, forward_combination_dim)
    forward_spans = util.combine_tensors(self._forward_combination,
                                         [forward_start_embeddings, forward_end_embeddings])
    # Shape (batch_size, num_spans, backward_combination_dim)
    backward_spans = util.combine_tensors(self._backward_combination,
                                          [backward_start_embeddings, backward_end_embeddings])
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.float().unsqueeze(-1)
    return span_embeddings
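# Hedged illustration (not part of the extractor): how the exclusive start/end indices and
# sentinel masks used above behave on a tiny example, written with plain torch. The spans
# and sequence length here are made up.
import torch

sequence_length = 5
# Inclusive (start, end) spans; the first span starts at token 0 and the second ends at the
# final token, so both need a sentinel after the -1 / +1 adjustment.
span_starts = torch.tensor([[0, 2]])
span_ends = torch.tensor([[1, 4]])

exclusive_span_starts = span_starts - 1                  # [[-1, 1]]
start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

exclusive_span_ends = span_ends + 1                      # [[2, 5]]
end_sentinel_mask = (exclusive_span_ends == sequence_length).long().unsqueeze(-1)

# Out-of-range indices are zeroed out so the index select is safe; the corresponding
# embeddings are later overwritten with the start/end sentinel vectors.
exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))
exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
assert exclusive_span_starts.tolist() == [[0, 1]]
assert exclusive_span_ends.tolist() == [[2, 0]]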
def forward(self,  # type: ignore
            arc_indices: torch.LongTensor,
            token_representations: torch.FloatTensor = None,
            raw_tokens: List[List[str]] = None,
            labels: torch.LongTensor = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    """
    If ``token_representations`` is provided, ``tokens`` is not required.
    If ``token_representations`` is ``None``, then ``tokens`` is required.

    Parameters
    ----------
    arc_indices : torch.LongTensor
        A LongTensor of shape (batch_size, max_num_arcs, 2) with the token pairs
        to predict a label for, for each element in the batch.
    token_representations : torch.FloatTensor, optional (default = None)
        A tensor of shape (batch_size, sequence_length, representation_dim) with
        the representation of each token. If None, we use a contextualizer within
        this model to produce the token representation.
    raw_tokens : List[List[str]], optional (default = None)
        A batch of lists with the raw token strings. Used to compute
        token_representations, if it is None.
    labels : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels
        of shape ``(batch_size, num_arc_indices)``.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        A tensor of shape ``(batch_size, max_num_arcs, num_classes)`` representing
        unnormalized log probabilities of the classes.
    class_probabilities : torch.FloatTensor
        A tensor of shape ``(batch_size, max_num_arcs, num_classes)`` representing
        a distribution of the tag classes.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimized.
    """
    # Convert to LongTensor
    # TODO: add PR to ArrayField to preserve array types.
    arc_indices = arc_indices.long()

    if token_representations is None:
        if self._contextualizer is None:
            raise ConfigurationError(
                "token_representations not provided as input to the model, and no "
                "contextualizer was specified. Either add a contextualizer to your "
                "dataset reader (preferred if your contextualizer is frozen) or to "
                "this model (if you wish to train your contextualizer).")
        if raw_tokens is None:
            raise ValueError("Input raw_tokens is ``None`` --- make sure to set "
                             "include_raw_tokens in the DatasetReader to True.")
        if arc_indices is None:
            raise ValueError("Did not receive arc_indices as input, needed "
                             "if the contextualizer is within the model.")
        # Convert contextualizer output into a tensor
        # Shape: (batch_size, max_seq_len, representation_dim)
        token_representations, _ = pad_contextualizer_output(
            self._contextualizer(raw_tokens))

    # Move token representations to the same device as the
    # module (CPU or CUDA). TODO(nfliu): This only works if the module
    # is on one device.
    device = next(self._decoder._linear_layers[0].parameters()).device
    token_representations = token_representations.to(device)
    text_mask = get_text_mask_from_representations(token_representations)
    text_mask = text_mask.to(device)
    label_mask = self._get_label_mask_from_arc_indices(arc_indices)
    label_mask = label_mask.to(device)

    # Encode the token representations
    encoded_token_representations = self._encoder(token_representations, text_mask)

    batch_size = arc_indices.size(0)
    # Index into the encoded_token_representations to get two tensors corresponding
    # to the children and parent of the arcs. Each of these tensors is of shape
    # (batch_size, num_arc_indices, representation_dim)
    first_arc_indices = arc_indices[:, :, 0]
    range_vector = get_range_vector(batch_size, get_device_of(first_arc_indices)).unsqueeze(1)
    first_token_representations = encoded_token_representations[range_vector, first_arc_indices]
    first_token_representations = first_token_representations.contiguous()

    second_arc_indices = arc_indices[:, :, 1]
    range_vector = get_range_vector(batch_size, get_device_of(second_arc_indices)).unsqueeze(1)
    second_token_representations = encoded_token_representations[range_vector, second_arc_indices]
    second_token_representations = second_token_representations.contiguous()

    # Take the batch and produce two tensors fit for combining
    # Shape: (batch_size, num_arc_indices, combined_representation_dim)
    combined_tensor = combine_tensors(self._combination,
                                      [first_token_representations, second_token_representations])

    # Decode out a label from the combined tensor.
    # Shape: (batch_size, num_arc_indices, num_classes)
    logits = self._decoder(combined_tensor)
    class_probabilities = F.softmax(logits, dim=-1)
    output_dict = {"logits": logits, "class_probabilities": class_probabilities}

    if labels is not None:
        loss = sequence_cross_entropy_with_logits(logits, labels, label_mask,
                                                  average=self.loss_average)
        for name, metric in self.metrics.items():
            # When not running in error analysis mode, skip
            # metrics that start with "_"
            if not self.error_analysis and name.startswith("_"):
                continue
            metric(logits, labels, label_mask.float())
        output_dict["loss"] = loss
    return output_dict
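# Hedged illustration (not part of the model): the advanced-indexing trick used above to pull
# out one representation per arc endpoint, shown with plain torch; torch.arange stands in for
# get_range_vector, and all sizes are made up.
import torch

batch_size, seq_len, dim, num_arcs = 2, 5, 3, 4
encoded = torch.randn(batch_size, seq_len, dim)
first_arc_indices = torch.randint(0, seq_len, (batch_size, num_arcs))

# A column of batch indices, shape (batch_size, 1).
range_vector = torch.arange(batch_size).unsqueeze(1)

# Broadcasting the batch indices against the arc indices selects, for every batch element,
# the representation at each arc position.
first_token_representations = encoded[range_vector, first_arc_indices]
assert first_token_representations.shape == (batch_size, num_arcs, dim)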
def forward(
    self,
    sequence_tensor: torch.FloatTensor,
    span_indices: torch.LongTensor,
    sequence_mask: torch.BoolTensor = None,
    span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(
        int(self._input_dim / 2), dim=-1
    )
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # Exclusive span starts for the forward direction; shape (batch_size, num_spans, 1) masks.
    exclusive_span_starts = span_starts - 1
    start_sentinel_mask = (exclusive_span_starts == -1).unsqueeze(-1)
    # Exclusive span ends for the backward direction.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        sequence_lengths = torch.ones_like(
            sequence_tensor[:, 0, 0], dtype=torch.long
        ) * sequence_tensor.size(1)

    end_sentinel_mask = (exclusive_span_ends >= sequence_lengths.unsqueeze(-1)).unsqueeze(-1)

    # Zero out the indices that fell outside the sequence so the index select below is safe;
    # the corresponding embeddings are replaced by sentinels further down.
    exclusive_span_ends = exclusive_span_ends * ~end_sentinel_mask.squeeze(-1)
    exclusive_span_starts = exclusive_span_starts * ~start_sentinel_mask.squeeze(-1)

    if (exclusive_span_starts < 0).any() or (
        exclusive_span_ends > sequence_lengths.unsqueeze(-1)
    ).any():
        raise ValueError(
            f"Adjusted span indices must lie inside the length of the sequence tensor, "
            f"but found: exclusive_span_starts: {exclusive_span_starts}, "
            f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
            f"{sequence_lengths}."
        )

    # Forward direction: start indices are exclusive, end indices are inclusive.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(
        forward_sequence, exclusive_span_starts
    )
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

    # Backward direction: the backward starts use the `forward` end indices and the
    # backward ends use the `forward` start indices, because we are going backwards.
    backward_start_embeddings = util.batched_index_select(
        backward_sequence, exclusive_span_ends
    )
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        # Replace the elements which lay outside the sequence_tensor with the
        # start and end sentinels.
        forward_start_embeddings = (
            forward_start_embeddings * ~start_sentinel_mask
            + start_sentinel_mask * self._start_sentinel
        )
        backward_start_embeddings = (
            backward_start_embeddings * ~end_sentinel_mask
            + end_sentinel_mask * self._end_sentinel
        )

    # Combine the forward and backward spans and concatenate the two representations.
    forward_spans = util.combine_tensors(
        self._forward_combination, [forward_start_embeddings, forward_end_embeddings]
    )
    backward_spans = util.combine_tensors(
        self._backward_combination, [backward_start_embeddings, backward_end_embeddings]
    )
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        if self._bucket_widths:
            span_widths = util.bucket_values(
                span_ends - span_starts, num_total_buckets=self._num_width_embeddings
            )
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.unsqueeze(-1)
    return span_embeddings
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            choices_list: Dict[str, torch.LongTensor],
            choice_kb: Dict[str, torch.LongTensor],
            answer_text: Dict[str, torch.LongTensor],
            fact: Dict[str, torch.LongTensor],
            answer_spans: torch.IntTensor,
            relations: torch.IntTensor = None,
            relation_label: torch.IntTensor = None,
            answer_id: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    # B X C X Ct X D
    embedded_choice, choice_mask = get_embedding(choices_list, 1, self._text_field_embedder,
                                                 self._encoder, self._var_dropout)
    # B X C X D
    # agg_choice, agg_choice_mask = get_agg_rep(embedded_choice, choice_mask, 1,
    #                                           self._encoder, self._aggregate)
    num_choices = embedded_choice.size()[1]
    batch_size = embedded_choice.size()[0]

    # B X Qt X D
    embedded_question, question_mask = get_embedding(question, 0, self._text_field_embedder,
                                                     self._encoder, self._var_dropout)
    # B X D
    agg_question, agg_question_mask = get_agg_rep(embedded_question, question_mask, 0,
                                                  self._encoder, self._aggregate)

    # B X Ft X D
    embedded_fact, fact_mask = get_embedding(fact, 0, self._text_field_embedder,
                                             self._encoder, self._var_dropout)
    # B X D
    agg_fact, agg_fact_mask = get_agg_rep(embedded_fact, fact_mask, 0,
                                          self._encoder, self._aggregate)

    # ==============================================
    # Interaction between fact and question
    # ==============================================
    # B x Ft x Qt
    fact_question_att = self._attention(embedded_fact, embedded_question)
    fact_question_mask = self.add_dimension(question_mask, 1, fact_question_att.shape[1])
    masked_fact_question_att = replace_masked_values(fact_question_att, fact_question_mask, -1e7)
    # B X Ft
    fact_question_att_max = masked_fact_question_att.max(dim=-1)[0].squeeze(-1)
    fact_question_att_softmax = masked_softmax(fact_question_att_max, fact_mask)
    # B X D
    fact_question_att_rep = weighted_sum(embedded_fact, fact_question_att_softmax)
    # B*C X D
    cmerged_fact_question_att_rep = self.merge_dimensions(
        self.add_dimension(fact_question_att_rep, 1, num_choices))

    # ==============================================
    # Interaction between fact and answer choices
    # ==============================================
    # B*C X Ft X D
    cmerged_embedded_fact = self.merge_dimensions(
        self.add_dimension(embedded_fact, 1, num_choices))
    cmerged_fact_mask = self.merge_dimensions(
        self.add_dimension(fact_mask, 1, num_choices))
    # B*C X Ct X D
    cmerged_embedded_choice = self.merge_dimensions(embedded_choice)
    cmerged_choice_mask = self.merge_dimensions(choice_mask)

    # B*C X Ft X Ct
    cmerged_fact_choice_att = self._attention(cmerged_embedded_fact, cmerged_embedded_choice)
    cmerged_fact_choice_mask = self.add_dimension(cmerged_choice_mask, 1,
                                                  cmerged_fact_choice_att.shape[1])
    masked_cmerged_fact_choice_att = replace_masked_values(cmerged_fact_choice_att,
                                                           cmerged_fact_choice_mask, -1e7)
    # B*C X Ft
    cmerged_fact_choice_att_max = masked_cmerged_fact_choice_att.max(dim=-1)[0].squeeze(-1)
    cmerged_fact_choice_att_softmax = masked_softmax(cmerged_fact_choice_att_max,
                                                     cmerged_fact_mask)
    # B*C X D
    cmerged_fact_choice_att_rep = weighted_sum(cmerged_embedded_fact,
                                               cmerged_fact_choice_att_softmax)

    # ==============================================
    # Combined fact + choice + question + span rep
    # ==============================================
    if not self._ignore_spans and not self._ignore_ann:
        # B X A
        per_span_mask = (answer_spans >= 0).long()[:, :, 0]
        # B X A X D
        per_span_rep = self._span_extractor(embedded_fact, answer_spans, fact_mask, per_span_mask)
        # expanded_span_mask = per_span_mask.unsqueeze(-1).expand_as(per_span_rep)
        # B X D
        answer_span_rep = per_span_rep[:, 0, :]
        # B*C X D
        cmerged_span_rep = self.merge_dimensions(
            self.add_dimension(answer_span_rep, 1, num_choices))
        fact_choice_question_rep = (cmerged_fact_choice_att_rep +
                                    cmerged_fact_question_att_rep +
                                    cmerged_span_rep) / 3
    else:
        fact_choice_question_rep = (cmerged_fact_choice_att_rep +
                                    cmerged_fact_question_att_rep) / 2

    # B*C X D
    cmerged_fact_rep = masked_mean(
        cmerged_embedded_fact,
        cmerged_fact_mask.unsqueeze(-1).expand_as(cmerged_embedded_fact), 1)
    # B*C X D
    fact_question_combined_rep = combine_tensors(self._coverage_combination,
                                                 [fact_choice_question_rep, cmerged_fact_rep])
    # B X C X D
    new_size = [batch_size, num_choices, -1]
    fact_question_combined_rep = fact_question_combined_rep.contiguous().view(*new_size)
    # B X C
    coverage_score = self._coverage_ff(fact_question_combined_rep).squeeze(-1)
    logger.info("coverage_score" + str(coverage_score.shape))

    # ==============================================
    # Interaction between spans+choices and KB
    # ==============================================
    # B X C X K X Kt x D
    embedded_choice_kb, choice_kb_mask = get_embedding(choice_kb, 2, self._text_field_embedder,
                                                       self._encoder, self._var_dropout)
    num_kb = embedded_choice_kb.size()[2]
    # B X A X At X D
    embedded_answer, answer_mask = get_embedding(answer_text, 1, self._text_field_embedder,
                                                 self._encoder, self._var_dropout)
    # B X At X D
    embedded_answer = embedded_answer[:, 0, :, :]
    answer_mask = answer_mask[:, 0, :]

    # B*C*K X Kt X D
    ckmerged_embedded_choice_kb = self.merge_dimensions(self.merge_dimensions(embedded_choice_kb))
    ckmerged_choice_kb_mask = self.merge_dimensions(self.merge_dimensions(choice_kb_mask))
    # B*C X At X D
    cmerged_embedded_answer = self.merge_dimensions(
        self.add_dimension(embedded_answer, 1, num_choices))
    cmerged_answer_mask = self.merge_dimensions(
        self.add_dimension(answer_mask, 1, num_choices))
    # B*C*K X At X D
    ckmerged_embedded_answer = self.merge_dimensions(
        self.add_dimension(cmerged_embedded_answer, 1, num_kb))
    ckmerged_answer_mask = self.merge_dimensions(
        self.add_dimension(cmerged_answer_mask, 1, num_kb))
    # B*C*K X Ct X D
    ckmerged_embedded_choice = self.merge_dimensions(
        self.add_dimension(cmerged_embedded_choice, 1, num_kb))
    ckmerged_choice_mask = self.merge_dimensions(
        self.add_dimension(cmerged_choice_mask, 1, num_kb))
    logger.info("ckmerged_choice_mask" + str(ckmerged_choice_mask.shape))

    # == KB rep based on answer span ==
    if self._ignore_ann:
        # B*C*K X Ft X D
        ckmerged_embedded_fact = self.merge_dimensions(
            self.add_dimension(cmerged_embedded_fact, 1, num_kb))
        ckmerged_fact_mask = self.merge_dimensions(
            self.add_dimension(cmerged_fact_mask, 1, num_kb))
        # B*C*K X Kt x Ft
        ckmerged_kb_fact_att = self._attention(ckmerged_embedded_choice_kb, ckmerged_embedded_fact)
        ckmerged_kb_fact_mask = self.add_dimension(ckmerged_fact_mask, 1,
                                                   ckmerged_kb_fact_att.shape[1])
        masked_ckmerged_kb_fact_att = replace_masked_values(ckmerged_kb_fact_att,
                                                            ckmerged_kb_fact_mask, -1e7)
        # B*C*K X Kt
        ckmerged_kb_answer_att_max = masked_ckmerged_kb_fact_att.max(dim=-1)[0].squeeze(-1)
    else:
        # B*C*K X Kt x At
        ckmerged_kb_answer_att = self._attention(ckmerged_embedded_choice_kb,
                                                 ckmerged_embedded_answer)
        ckmerged_kb_answer_mask = self.add_dimension(ckmerged_answer_mask, 1,
                                                     ckmerged_kb_answer_att.shape[1])
        masked_ckmerged_kb_answer_att = replace_masked_values(ckmerged_kb_answer_att,
                                                              ckmerged_kb_answer_mask, -1e7)
        # B*C*K X Kt
        ckmerged_kb_answer_att_max = masked_ckmerged_kb_answer_att.max(dim=-1)[0].squeeze(-1)

    ckmerged_kb_answer_att_softmax = masked_softmax(ckmerged_kb_answer_att_max,
                                                    ckmerged_choice_kb_mask)
    # B*C*K X D
    kb_answer_att_rep = weighted_sum(ckmerged_embedded_choice_kb, ckmerged_kb_answer_att_softmax)

    # == KB rep based on answer choice ==
    # B*C*K X Kt x Ct
    ckmerged_kb_choice_att = self._attention(ckmerged_embedded_choice_kb, ckmerged_embedded_choice)
    ckmerged_kb_choice_mask = self.add_dimension(ckmerged_choice_mask, 1,
                                                 ckmerged_kb_choice_att.shape[1])
    masked_ckmerged_kb_choice_att = replace_masked_values(ckmerged_kb_choice_att,
                                                          ckmerged_kb_choice_mask, -1e7)
    # B*C*K X Kt
    ckmerged_kb_choice_att_max = masked_ckmerged_kb_choice_att.max(dim=-1)[0].squeeze(-1)
    ckmerged_kb_choice_att_softmax = masked_softmax(ckmerged_kb_choice_att_max,
                                                    ckmerged_choice_kb_mask)
    # B*C*K X D
    kb_choice_att_rep = weighted_sum(ckmerged_embedded_choice_kb, ckmerged_kb_choice_att_softmax)

    # B*C*K X D
    answer_choice_kb_combined_rep = combine_tensors(self._answer_choice_combination,
                                                    [kb_answer_att_rep, kb_choice_att_rep])
    logger.info("answer_choice_kb_combined_rep" + str(answer_choice_kb_combined_rep.shape))

    # ==============================================
    # Relation Predictions
    # ==============================================
    # B*C*K x R
    choice_kb_relation_rep = self._relation_predictor(answer_choice_kb_combined_rep)
    new_choice_kb_size = [batch_size * num_choices, num_kb, -1]
    # B*C*K
    merged_choice_kb_mask = (torch.sum(ckmerged_choice_kb_mask, dim=-1) > 0).float()

    if self._num_relations and not self._ignore_ann:
        if self._relation_projector:
            choice_kb_relation_pred = self._relation_projector(choice_kb_relation_rep)
        else:
            choice_kb_relation_pred = choice_kb_relation_rep
        # Aggregate the predictions
        # B*C*K
        choice_kb_relation_mask = self.add_dimension(merged_choice_kb_mask, -1,
                                                     choice_kb_relation_pred.shape[-1])
        choice_kb_relation_pred_masked = replace_masked_values(choice_kb_relation_pred,
                                                               choice_kb_relation_mask, -1e7)
        # B*C X K X R
        relation_pred_perkb = choice_kb_relation_pred_masked.contiguous().view(*new_choice_kb_size)
        # B*C X R
        relation_pred_max = relation_pred_perkb.max(dim=1)[0].squeeze(1)
        # B X C X R
        choice_relation_size = [batch_size, num_choices, -1]
        relation_label_logits = relation_pred_max.contiguous().view(*choice_relation_size)
        relation_label_probs = softmax(relation_label_logits, dim=-1)
        # B X C
        add_relation_predictions(self.vocab, relation_label_probs, metadata)
        # B X C X K X R
        choice_kb_relation_size = [batch_size, num_choices, num_kb, -1]
        relation_predictions = choice_kb_relation_rep.contiguous().view(*choice_kb_relation_size)
        add_tuple_predictions(relation_predictions, metadata)
        logger.info("relation_predictions" + str(relation_predictions.shape))
    else:
        relation_label_logits = None
        relation_label_probs = None

    if not self._ignore_relns:
        # B X C X D
        expanded_size = [batch_size, num_choices, -1]
        # Aggregate the relation representation
        if self._relation_projector or self._num_relations == 0 or self._ignore_ann:
            # B*C X K X D
            relation_rep_perkb = choice_kb_relation_rep.contiguous().view(*new_choice_kb_size)
            # B*C*K X D
            merged_relation_rep_mask = self.add_dimension(merged_choice_kb_mask, -1,
                                                          relation_rep_perkb.shape[-1])
            # B*C X K X D
            relation_rep_perkb_mask = merged_relation_rep_mask.contiguous().view(
                *relation_rep_perkb.size())
            # B*C X D
            agg_relation_rep = masked_mean(relation_rep_perkb, relation_rep_perkb_mask, dim=1)
            # B X C X D
            expanded_relation_rep = agg_relation_rep.contiguous().view(*expanded_size)
        else:
            expanded_relation_rep = relation_label_logits

        expanded_question_rep = agg_question.unsqueeze(1).expand(expanded_size)
        expanded_fact_rep = agg_fact.unsqueeze(1).expand(expanded_size)
        question_fact_rep = combine_tensors(self._combination,
                                            [expanded_question_rep, expanded_fact_rep])
        relation_score_rep = torch.cat([question_fact_rep, expanded_relation_rep], dim=-1)
        relation_score = self._reln_ff(relation_score_rep).squeeze(-1)
        choice_label_logits = (coverage_score + relation_score) / 2
    else:
        choice_label_logits = coverage_score

    logger.info("choice_label_logits" + str(choice_label_logits.shape))
    choice_label_probs = softmax(choice_label_logits, dim=-1)
    output_dict = {
        "label_logits": choice_label_logits,
        "label_probs": choice_label_probs,
        "metadata": metadata
    }
    if relation_label_logits is not None:
        output_dict["relation_label_logits"] = relation_label_logits
        output_dict["relation_label_probs"] = relation_label_probs

    if answer_id is not None or relation_label is not None:
        self.compute_loss_and_accuracy(answer_id, relation_label, relation_label_logits,
                                       choice_label_logits, output_dict)
    return output_dict
def path_encoding(x, y, combine_str, fforward, gate):
    z = combine_tensors(combine_str, [x, y])
    z = fforward(z)
    gatef = gate(z)
    return gatef * z
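# Hedged usage sketch (the feedforward and gate modules below are assumptions, not taken from
# the original code): calling path_encoding with a Linear feedforward and a sigmoid gate,
# assuming combine_tensors (from allennlp.nn.util) is in scope as in the snippets above.
# The "x,y,x*y" combination concatenates x, y, and their elementwise product, so the combined
# width is 3 * dim.
import torch
from torch import nn

dim = 8
x = torch.randn(4, dim)
y = torch.randn(4, dim)

fforward = nn.Linear(3 * dim, dim)                        # hypothetical feedforward
gate = nn.Sequential(nn.Linear(dim, dim), nn.Sigmoid())   # hypothetical gate

encoded = path_encoding(x, y, "x,y,x*y", fforward, gate)
assert encoded.shape == (4, dim)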