    def forward(self, tensor_1: torch.Tensor,
                tensor_2: torch.Tensor) -> torch.Tensor:
        combined_tensors = util.combine_tensors(self._combination,
                                                [tensor_1, tensor_2])
        dot_product = torch.matmul(combined_tensors, self._weight_vector)
        return torch.matmul(self._activation(dot_product + self._bias),
                            self._V)
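For reference, here is a minimal sketch of what the `combination` string used throughout these snippets expands to, assuming `util` is `allennlp.nn.util`; the tensor shapes are made up for illustration.

import torch
from allennlp.nn import util

tensor_1 = torch.randn(2, 3, 4)   # (batch_size, sequence_length, dim)
tensor_2 = torch.randn(2, 3, 4)

# "x,y,x*y" concatenates tensor_1, tensor_2 and their elementwise product
# along the last dimension, so the combined dim is 4 + 4 + 4 = 12.
combined = util.combine_tensors("x,y,x*y", [tensor_1, tensor_2])
print(combined.shape)  # torch.Size([2, 3, 12])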
Example No. 2
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        if not self._use_exclusive_start_indices:
            start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        else:
            # We want `exclusive` span starts, so we remove 1 from the forward span starts
            # as the AllenNLP ``SpanField`` is inclusive.
            # shape (batch_size, num_spans)
            exclusive_span_starts = span_starts - 1
            # shape (batch_size, num_spans, 1)
            start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
            exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

            # We'll check the indices here at runtime, because it's difficult to debug
            # if this goes wrong and it's tricky to get right.
            if (exclusive_span_starts < 0).any():
                raise ValueError(f"Adjusted span indices must lie inside the the sequence tensor, "
                                 f"but found: exclusive_span_starts: {exclusive_span_starts}.")

            start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

            # We're using sentinels, so we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with the start sentinel.
            float_start_sentinel_mask = start_sentinel_mask.float()
            start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel

        combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()
        return combined_tensors
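A minimal usage sketch, assuming this `forward` belongs to a span extractor configured like AllenNLP's `EndpointSpanExtractor`; the sizes below are purely illustrative.

import torch
from allennlp.modules.span_extractors import EndpointSpanExtractor

extractor = EndpointSpanExtractor(input_dim=8, combination="x,y")
sequence_tensor = torch.randn(2, 5, 8)               # (batch_size, sequence_length, input_dim)
span_indices = torch.LongTensor([[[0, 1], [2, 4]],
                                 [[1, 1], [3, 4]]])  # (batch_size, num_spans, 2), inclusive ends
span_embeddings = extractor(sequence_tensor, span_indices)
print(span_embeddings.shape)  # torch.Size([2, 2, 16]) for combination="x,y"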
Example No. 3
    def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
        # TODO(mattg): Remove the need for this tiling.
        # https://github.com/allenai/allennlp/pull/1235#issuecomment-391540133
        tiled_vector = vector.unsqueeze(1).expand(vector.size()[0],
                                                  matrix.size()[1],
                                                  vector.size()[1])

        combined_tensors = util.combine_tensors(self._combination, [tiled_vector, matrix])
        dot_product = torch.matmul(combined_tensors, self._weight_vector)
        return self._activation(dot_product + self._bias)
Example No. 4
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # If more than two words, take the average of each constituent
        if x.size()[1] > 1:
            x = torch.mean(x, dim=1).unsqueeze(1)

        if y.size()[1] > 1:
            y = torch.mean(y, dim=1).unsqueeze(1)

        combined = util.combine_tensors(self._combination, [x, y])
        product = torch.matmul(combined, self._weight_vector)
        return self._activation(product)
Example No. 5
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
        assert mask is not None, "TalkativeIntraSentenceAttentionEncoder requires a mask to be provided."
        batch_size, sequence_length, _ = tokens.size()
        # Shape: (batch_size, sequence_length, sequence_length)
        similarity_matrix = self._matrix_attention(tokens, tokens)

        if self._num_attention_heads > 1:
            # In this case, the similarity matrix actually has shape
            # (batch_size, sequence_length, sequence_length, num_heads).  To make the rest of the
            # logic below easier, we'll permute this to
            # (batch_size, sequence_length, num_heads, sequence_length).
            similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

        # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
        similarity_matrix = similarity_matrix.contiguous()
        temp_mask = mask.unsqueeze(1)
        # Shape: (batch_size, sequence_length, projection_dim)
        output_token_representation = self._projection(tokens)
        self.report_unnormalized_log_attn_weights(
            similarity_matrix * temp_mask, output_token_representation)

        intra_sentence_attention = util.masked_softmax(
            similarity_matrix.contiguous(), mask)

        if self._num_attention_heads > 1:
            # We need to split and permute the output representation to be
            # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
            # do a proper weighted sum with `intra_sentence_attention`.
            shape = list(output_token_representation.size())
            new_shape = shape[:-1] + [self._num_attention_heads, -1]
            # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
            output_token_representation = output_token_representation.view(
                *new_shape)
            # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
            output_token_representation = output_token_representation.permute(
                0, 2, 1, 3)

        # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
        attended_sentence = util.weighted_sum(output_token_representation,
                                              intra_sentence_attention)

        if self._num_attention_heads > 1:
            # Here we concatenate the weighted representation for each head.  We'll accomplish this
            # just with a resize.
            # Shape: (batch_size, sequence_length, projection_dim)
            attended_sentence = attended_sentence.view(batch_size,
                                                       sequence_length, -1)

        # Shape: (batch_size, sequence_length, combination_dim)
        combined_tensors = util.combine_tensors(self._combination,
                                                [tokens, attended_sentence])
        return self._output_projection(combined_tensors)
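The head-splitting reshape above is easy to get wrong, so here is an isolated sketch of the same `view`/`permute` pattern with hypothetical sizes.

import torch

batch_size, sequence_length, projection_dim, num_heads = 2, 5, 12, 3
x = torch.randn(batch_size, sequence_length, projection_dim)

# Split the projection dimension into heads ...
x = x.view(batch_size, sequence_length, num_heads, projection_dim // num_heads)
# ... and move the head dimension in front of the sequence dimension.
x = x.permute(0, 2, 1, 3)
print(x.shape)  # torch.Size([2, 3, 5, 4])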
Example No. 6
    def forward(self, tokens, mask):  # pylint: disable=arguments-differ
        batch_size, sequence_length, _ = tokens.size()
        # Shape: (batch_size, sequence_length, sequence_length)
        similarity_matrix = self._matrix_attention(tokens, tokens)

        if self._num_attention_heads > 1:
            # In this case, the similarity matrix actually has shape
            # (batch_size, sequence_length, sequence_length, num_heads).  To make the rest of the
            # logic below easier, we'll permute this to
            # (batch_size, sequence_length, num_heads, sequence_length).
            similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

        # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
        intra_sentence_attention = util.last_dim_softmax(
            similarity_matrix.contiguous(), mask)

        # Shape: (batch_size, sequence_length, projection_dim)
        output_token_representation = self._projection(tokens)

        if self._num_attention_heads > 1:
            # We need to split and permute the output representation to be
            # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
            # do a proper weighted sum with `intra_sentence_attention`.
            shape = list(output_token_representation.size())
            new_shape = shape[:-1] + [self._num_attention_heads, -1]
            # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
            output_token_representation = output_token_representation.view(
                *new_shape)
            # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
            output_token_representation = output_token_representation.permute(
                0, 2, 1, 3)

        # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
        attended_sentence = util.weighted_sum(output_token_representation,
                                              intra_sentence_attention)

        if self._num_attention_heads > 1:
            # Here we concatenate the weighted representation for each head.  We'll accomplish this
            # just with a resize.
            # Shape: (batch_size, sequence_length, projection_dim)
            attended_sentence = attended_sentence.view(batch_size,
                                                       sequence_length, -1)

        # Shape: (batch_size, sequence_length, combination_dim)
        combined_tensors = util.combine_tensors(self._combination,
                                                [tokens, attended_sentence])
        return self._output_projection(combined_tensors)
Example No. 7
    def forward(self,  # pylint: disable=arguments-differ
                matrix_1: torch.Tensor,
                matrix_2: torch.Tensor) -> torch.Tensor:
        # TODO(mattg): Remove the need for this tiling.
        # https://github.com/allenai/allennlp/pull/1235#issuecomment-391540133
        tiled_matrix_1 = matrix_1.unsqueeze(2).expand(matrix_1.size()[0],
                                                      matrix_1.size()[1],
                                                      matrix_2.size()[1],
                                                      matrix_1.size()[2])
        tiled_matrix_2 = matrix_2.unsqueeze(1).expand(matrix_2.size()[0],
                                                      matrix_1.size()[1],
                                                      matrix_2.size()[1],
                                                      matrix_2.size()[2])

        combined_tensors = util.combine_tensors(self._combination, [tiled_matrix_1, tiled_matrix_2])
        dot_product = torch.matmul(combined_tensors, self._weight_vector)
        return self._activation(dot_product + self._bias)
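A minimal usage sketch, assuming this is the `forward` of something like AllenNLP's `LinearMatrixAttention` (where this tiling pattern appears); dimensions are illustrative.

import torch
from allennlp.modules.matrix_attention import LinearMatrixAttention

attention = LinearMatrixAttention(tensor_1_dim=4, tensor_2_dim=4, combination="x,y")
matrix_1 = torch.randn(2, 3, 4)   # (batch_size, num_rows_1, dim)
matrix_2 = torch.randn(2, 6, 4)   # (batch_size, num_rows_2, dim)
similarities = attention(matrix_1, matrix_2)
print(similarities.shape)         # torch.Size([2, 3, 6])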
Example No. 8
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
        batch_size, sequence_length, _ = tokens.size()
        # Shape: (batch_size, sequence_length, sequence_length)
        similarity_matrix = self._matrix_attention(tokens, tokens)

        if self._num_attention_heads > 1:
            # In this case, the similarity matrix actually has shape
            # (batch_size, sequence_length, sequence_length, num_heads).  To make the rest of the
            # logic below easier, we'll permute this to
            # (batch_size, sequence_length, num_heads, sequence_length).
            similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

        # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
        intra_sentence_attention = util.masked_softmax(similarity_matrix.contiguous(), mask)

        # Shape: (batch_size, sequence_length, projection_dim)
        output_token_representation = self._projection(tokens)

        if self._num_attention_heads > 1:
            # We need to split and permute the output representation to be
            # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
            # do a proper weighted sum with `intra_sentence_attention`.
            shape = list(output_token_representation.size())
            new_shape = shape[:-1] + [self._num_attention_heads, -1]
            # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
            output_token_representation = output_token_representation.view(*new_shape)
            # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
            output_token_representation = output_token_representation.permute(0, 2, 1, 3)

        # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
        attended_sentence = util.weighted_sum(output_token_representation,
                                              intra_sentence_attention)

        if self._num_attention_heads > 1:
            # Here we concatenate the weighted representation for each head.  We'll accomplish this
            # just with a resize.
            # Shape: (batch_size, sequence_length, projection_dim)
            attended_sentence = attended_sentence.view(batch_size, sequence_length, -1)

        # Shape: (batch_size, sequence_length, combination_dim)
        combined_tensors = util.combine_tensors(self._combination, [tokens, attended_sentence])
        return self._output_projection(combined_tensors)
    def forward(self, # pylint: disable=arguments-differ
                sequence_tensor: torch.FloatTensor,
                indicies: torch.LongTensor) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in indicies.split(1, dim=-1)]
        start_embeddings = batched_index_select(sequence_tensor, span_starts)
        end_embeddings = batched_index_select(sequence_tensor, span_ends)

        combined_tensors = combine_tensors(self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = bucket_values(span_ends - span_starts,
                                            num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings], -1)

        return combined_tensors
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [
            index.squeeze(-1) for index in span_indices.split(1, dim=-1)
        ]

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        start_embeddings = batched_index_select(sequence_tensor, span_starts)
        end_embeddings = batched_index_select(sequence_tensor, span_ends)

        combined_tensors = combine_tensors(self._combination,
                                           [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = bucket_values(
                    span_ends - span_starts,
                    num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()
        return combined_tensors
Example No. 11
    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(int(
            self._input_dim / 2),
                                                                    dim=-1)
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [
            index.squeeze(-1) for index in span_indices.split(1, dim=-1)
        ]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (
            exclusive_span_starts == -1).long().unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(
                sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length size of the sequence_tensor.
            sequence_lengths = util.ones_like(
                sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(
            -1)).long().unsqueeze(-1)

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * (
            1 - end_sentinel_mask.squeeze(-1))
        exclusive_span_starts = exclusive_span_starts * (
            1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (
                exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
            raise ValueError(
                f"Adjusted span indices must lie inside the length of the sequence tensor, "
                f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                f"{sequence_lengths}.")

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(
            forward_sequence, exclusive_span_starts)
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(
            forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(
            backward_sequence, exclusive_span_ends)
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(
            backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            float_end_sentinel_mask = end_sentinel_mask.float()
            float_start_sentinel_mask = start_sentinel_mask.float()
            forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel
            backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                        + float_end_sentinel_mask * self._end_sentinel

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(
            self._forward_combination,
            [forward_start_embeddings, forward_end_embeddings])
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(
            self._backward_combination,
            [backward_start_embeddings, backward_end_embeddings])
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_ends - span_starts,
                    num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.float().unsqueeze(-1)
        return span_embeddings
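A minimal usage sketch, assuming this `forward` comes from something like AllenNLP's `BidirectionalEndpointSpanExtractor` applied to the output of a bidirectional encoder; sizes are illustrative.

import torch
from allennlp.modules.span_extractors import BidirectionalEndpointSpanExtractor

extractor = BidirectionalEndpointSpanExtractor(input_dim=8)  # forward/backward halves of size 4
sequence_tensor = torch.randn(2, 5, 8)                       # (batch_size, sequence_length, input_dim)
span_indices = torch.LongTensor([[[1, 3], [2, 4]],
                                 [[0, 0], [1, 4]]])          # inclusive span ends
span_embeddings = extractor(sequence_tensor, span_indices)
print(span_embeddings.shape)  # torch.Size([2, 2, 8]) with the default "y-x"/"x-y" combinations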
Example No. 12
    def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
        combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
        dot_product = torch.matmul(combined_tensors, self._weight_vector)
        return self._activation(dot_product + self._bias)
Example No. 13
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.LongTensor = None,
        span_indices_mask: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [
            index.squeeze(-1) for index in span_indices.split(1, dim=-1)
        ]

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        if not self._use_exclusive_start_indices:
            if sequence_tensor.size(-1) != self._input_dim:
                raise ValueError(
                    f"Dimension mismatch: expected ({self._input_dim}) "
                    f"but received ({sequence_tensor.size(-1)}).")
            start_embeddings = util.batched_index_select(
                sequence_tensor, span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor,
                                                       span_ends)

        else:
            # We want `exclusive` span starts, so we remove 1 from the forward span starts
            # as the AllenNLP `SpanField` is inclusive.
            # shape (batch_size, num_spans)
            exclusive_span_starts = span_starts - 1
            # shape (batch_size, num_spans, 1)
            start_sentinel_mask = (
                exclusive_span_starts == -1).long().unsqueeze(-1)
            exclusive_span_starts = exclusive_span_starts * (
                1 - start_sentinel_mask.squeeze(-1))

            # We'll check the indices here at runtime, because it's difficult to debug
            # if this goes wrong and it's tricky to get right.
            if (exclusive_span_starts < 0).any():
                raise ValueError(
                    f"Adjusted span indices must lie inside the the sequence tensor, "
                    f"but found: exclusive_span_starts: {exclusive_span_starts}."
                )

            start_embeddings = util.batched_index_select(
                sequence_tensor, exclusive_span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor,
                                                       span_ends)

            # We're using sentinels, so we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with the start sentinel.
            float_start_sentinel_mask = start_sentinel_mask.float()
            start_embeddings = (
                start_embeddings * (1 - float_start_sentinel_mask) +
                float_start_sentinel_mask * self._start_sentinel)

        combined_tensors = util.combine_tensors(
            self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_ends - span_starts,
                    num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            combined_tensors = torch.cat(
                [combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()

        return combined_tensors
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length size of the sequence_tensor.
            sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
            raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                             f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                             f"{sequence_lengths}.")

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            float_end_sentinel_mask = end_sentinel_mask.float()
            float_start_sentinel_mask = start_sentinel_mask.float()
            forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel
            backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                        + float_end_sentinel_mask * self._end_sentinel

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(self._forward_combination,
                                             [forward_start_embeddings, forward_end_embeddings])
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(self._backward_combination,
                                              [backward_start_embeddings, backward_end_embeddings])
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.float().unsqueeze(-1)
        return span_embeddings
Example No. 15
    def forward(
            self,  # type: ignore
            arc_indices: torch.LongTensor,
            token_representations: torch.FloatTensor = None,
            raw_tokens: List[List[str]] = None,
            labels: torch.LongTensor = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        """
        If ``token_representations`` is provided, ``raw_tokens`` is not required. If
        ``token_representations`` is ``None``, then ``raw_tokens`` is required.

        Parameters
        ----------
        arc_indices : torch.LongTensor
            A LongTensor of shape (batch_size, max_num_arcs, 2) with the token pairs
            to predict a label for in each element of the batch.
        token_representations : torch.FloatTensor, optional (default = None)
            A tensor of shape (batch_size, sequence_length, representation_dim) with
            the representation of each token. If None, we use a contextualizer
            within this model to produce the token representations.
        raw_tokens : List[List[str]], optional (default = None)
            A batch of lists with the raw token strings. Used to compute
            token_representations if it is None.
        labels : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_arc_indices)``.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, max_num_arcs, num_classes)`` representing
            unnormalized log probabilities of the classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, max_num_arcs, num_classes)`` representing
            a distribution of the tag classes.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimized.
        """
        # Convert to LongTensor
        # TODO: add PR to ArrayField to preserve array types.
        arc_indices = arc_indices.long()
        if token_representations is None:
            if self._contextualizer is None:
                raise ConfigurationError(
                    "token_representations not provided as input to the model, and no "
                    "contextualizer was specified. Either add a contextualizer to your "
                    "dataset reader (preferred if your contextualizer is frozen) or to "
                    "this model (if you wish to train your contextualizer).")
            if raw_tokens is None:
                raise ValueError(
                    "Input raw_tokens is ``None`` --- make sure to set "
                    "include_raw_tokens in the DatasetReader to True.")
            if arc_indices is None:
                raise ValueError(
                    "Did not recieve arc_indices as input, needed "
                    "if the contextualizer is within the model.")
            # Convert contextualizer output into a tensor
            # Shape: (batch_size, max_seq_len, representation_dim)
            token_representations, _ = pad_contextualizer_output(
                self._contextualizer(raw_tokens))

        # Move token representations to the same device as the
        # module (CPU or CUDA). TODO(nfliu): This only works if the module
        # is on one device.
        device = next(self._decoder._linear_layers[0].parameters()).device
        token_representations = token_representations.to(device)
        text_mask = get_text_mask_from_representations(token_representations)
        text_mask = text_mask.to(device)
        label_mask = self._get_label_mask_from_arc_indices(arc_indices)
        label_mask = label_mask.to(device)

        # Encode the token representations
        encoded_token_representations = self._encoder(token_representations,
                                                      text_mask)

        batch_size = arc_indices.size(0)

        # Index into the encoded_token_representations to get two tensors corresponding
        # to the children and parent of the arcs. Each of these tensors is of shape
        # (batch_size, num_arc_indices, representation_dim)
        first_arc_indices = arc_indices[:, :, 0]
        range_vector = get_range_vector(
            batch_size, get_device_of(first_arc_indices)).unsqueeze(1)
        first_token_representations = encoded_token_representations[
            range_vector, first_arc_indices]
        first_token_representations = first_token_representations.contiguous()

        second_arc_indices = arc_indices[:, :, 1]
        range_vector = get_range_vector(
            batch_size, get_device_of(second_arc_indices)).unsqueeze(1)
        second_token_representations = encoded_token_representations[
            range_vector, second_arc_indices]
        second_token_representations = second_token_representations.contiguous(
        )

        # Take the batch and produce two tensors fit for combining
        # Shape: (batch_size, num_arc_indices, combined_representation_dim)
        combined_tensor = combine_tensors(
            self._combination,
            [first_token_representations, second_token_representations])

        # Decode out a label from the combined tensor.
        # Shape: (batch_size, num_arc_indices, num_classes)
        logits = self._decoder(combined_tensor)
        class_probabilities = F.softmax(logits, dim=-1)
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        if labels is not None:
            loss = sequence_cross_entropy_with_logits(
                logits, labels, label_mask, average=self.loss_average)
            for name, metric in self.metrics.items():
                # When not running in error analysis mode, skip
                # metrics that start with "_"
                if not self.error_analysis and name.startswith("_"):
                    continue
                metric(logits, labels, label_mask.float())
            output_dict["loss"] = loss
        return output_dict
Example No. 16
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.FloatTensor:

        forward_sequence, backward_sequence = sequence_tensor.split(
            int(self._input_dim / 2), dim=-1
        )

        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        exclusive_span_starts = span_starts - 1
        start_sentinel_mask = (exclusive_span_starts == -1).unsqueeze(-1)

        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length of the sequence_tensor.
            sequence_lengths = (
                torch.ones_like(sequence_tensor[:, 0, 0], dtype=torch.long)
                * sequence_tensor.size(1)
            )

        end_sentinel_mask = (exclusive_span_ends >= sequence_lengths.unsqueeze(-1)).unsqueeze(-1)

        exclusive_span_ends = exclusive_span_ends * ~end_sentinel_mask.squeeze(-1)
        exclusive_span_starts = exclusive_span_starts * ~start_sentinel_mask.squeeze(-1)

        if (exclusive_span_starts < 0).any() or (
            exclusive_span_ends > sequence_lengths.unsqueeze(-1)
        ).any():
            raise ValueError(
                f"Adjusted span indices must lie inside the length of the sequence tensor, "
                f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                f"{sequence_lengths}."
            )

        forward_start_embeddings = util.batched_index_select(
            forward_sequence, exclusive_span_starts
        )
        forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

        backward_start_embeddings = util.batched_index_select(
            backward_sequence, exclusive_span_ends
        )
        backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

        if self._use_sentinels:
            forward_start_embeddings = (
                forward_start_embeddings * ~start_sentinel_mask
                + start_sentinel_mask * self._start_sentinel
            )
            backward_start_embeddings = (
                backward_start_embeddings * ~end_sentinel_mask
                + end_sentinel_mask * self._end_sentinel
            )

        forward_spans = util.combine_tensors(
            self._forward_combination, [forward_start_embeddings, forward_end_embeddings]
        )
        backward_spans = util.combine_tensors(
            self._backward_combination, [backward_start_embeddings, backward_end_embeddings]
        )
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_ends - span_starts, num_total_buckets=self._num_width_embeddings
                )
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.unsqueeze(-1)
        return span_embeddings
Example No. 17
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            choices_list: Dict[str, torch.LongTensor],
            choice_kb: Dict[str, torch.LongTensor],
            answer_text: Dict[str, torch.LongTensor],
            fact: Dict[str, torch.LongTensor],
            answer_spans: torch.IntTensor,
            relations: torch.IntTensor = None,
            relation_label: torch.IntTensor = None,
            answer_id: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # B X C X Ct X D
        embedded_choice, choice_mask = get_embedding(choices_list, 1,
                                                     self._text_field_embedder,
                                                     self._encoder,
                                                     self._var_dropout)
        # B X C X D
        # agg_choice, agg_choice_mask = get_agg_rep(embedded_choice, choice_mask, 1, self._encoder, self._aggregate)
        num_choices = embedded_choice.size()[1]
        batch_size = embedded_choice.size()[0]
        # B X Qt X D
        embedded_question, question_mask = get_embedding(
            question, 0, self._text_field_embedder, self._encoder,
            self._var_dropout)
        # B X D
        agg_question, agg_question_mask = get_agg_rep(embedded_question,
                                                      question_mask, 0,
                                                      self._encoder,
                                                      self._aggregate)

        # B X Ft X D
        embedded_fact, fact_mask = get_embedding(fact, 0,
                                                 self._text_field_embedder,
                                                 self._encoder,
                                                 self._var_dropout)
        # B X D
        agg_fact, agg_fact_mask = get_agg_rep(embedded_fact, fact_mask, 0,
                                              self._encoder, self._aggregate)

        # ==============================================
        # Interaction between fact and question
        # ==============================================
        # B x Ft x Qt
        fact_question_att = self._attention(embedded_fact, embedded_question)
        fact_question_mask = self.add_dimension(question_mask, 1,
                                                fact_question_att.shape[1])
        masked_fact_question_att = replace_masked_values(
            fact_question_att, fact_question_mask, -1e7)
        # B X Ft
        fact_question_att_max = masked_fact_question_att.max(
            dim=-1)[0].squeeze(-1)
        fact_question_att_softmax = masked_softmax(fact_question_att_max,
                                                   fact_mask)
        # B X D
        fact_question_att_rep = weighted_sum(embedded_fact,
                                             fact_question_att_softmax)
        # B*C X D
        cmerged_fact_question_att_rep = self.merge_dimensions(
            self.add_dimension(fact_question_att_rep, 1, num_choices))

        # ==============================================
        # Interaction between fact and answer choices
        # ==============================================

        # B*C X Ft X D
        cmerged_embedded_fact = self.merge_dimensions(
            self.add_dimension(embedded_fact, 1, num_choices))
        cmerged_fact_mask = self.merge_dimensions(
            self.add_dimension(fact_mask, 1, num_choices))

        # B*C X Ct X D
        cmerged_embedded_choice = self.merge_dimensions(embedded_choice)
        cmerged_choice_mask = self.merge_dimensions(choice_mask)

        # B*C X Ft X Ct
        cmerged_fact_choice_att = self._attention(cmerged_embedded_fact,
                                                  cmerged_embedded_choice)
        cmerged_fact_choice_mask = self.add_dimension(
            cmerged_choice_mask, 1, cmerged_fact_choice_att.shape[1])
        masked_cmerged_fact_choice_att = replace_masked_values(
            cmerged_fact_choice_att, cmerged_fact_choice_mask, -1e7)

        # B*C X Ft
        cmerged_fact_choice_att_max = masked_cmerged_fact_choice_att.max(
            dim=-1)[0].squeeze(-1)
        cmerged_fact_choice_att_softmax = masked_softmax(
            cmerged_fact_choice_att_max, cmerged_fact_mask)

        # B*C X D
        cmerged_fact_choice_att_rep = weighted_sum(
            cmerged_embedded_fact, cmerged_fact_choice_att_softmax)

        # ==============================================
        # Combined fact + choice + question + span rep
        # ==============================================
        if not self._ignore_spans and not self._ignore_ann:
            # B X A
            per_span_mask = (answer_spans >= 0).long()[:, :, 0]
            # B X A X D
            per_span_rep = self._span_extractor(embedded_fact, answer_spans,
                                                fact_mask, per_span_mask)
            # expanded_span_mask = per_span_mask.unsqueeze(-1).expand_as(per_span_rep)

            # B X D
            answer_span_rep = per_span_rep[:, 0, :]

            # B*C X D
            cmerged_span_rep = self.merge_dimensions(
                self.add_dimension(answer_span_rep, 1, num_choices))
            fact_choice_question_rep = (cmerged_fact_choice_att_rep +
                                        cmerged_fact_question_att_rep +
                                        cmerged_span_rep) / 3

        else:
            fact_choice_question_rep = (cmerged_fact_choice_att_rep +
                                        cmerged_fact_question_att_rep) / 2
        # B*C X D
        cmerged_fact_rep = masked_mean(
            cmerged_embedded_fact,
            cmerged_fact_mask.unsqueeze(-1).expand_as(cmerged_embedded_fact),
            1)
        # B*C X D
        fact_question_combined_rep = combine_tensors(
            self._coverage_combination,
            [fact_choice_question_rep, cmerged_fact_rep])

        # B X C X  D
        new_size = [batch_size, num_choices, -1]
        fact_question_combined_rep = fact_question_combined_rep.contiguous(
        ).view(*new_size)
        # B X C
        coverage_score = self._coverage_ff(fact_question_combined_rep).squeeze(
            -1)
        logger.info("coverage_score" + str(coverage_score.shape))

        # ==============================================
        # Interaction between spans+choices and KB
        # ==============================================

        # B X C X K X Kt x D
        embedded_choice_kb, choice_kb_mask = get_embedding(
            choice_kb, 2, self._text_field_embedder, self._encoder,
            self._var_dropout)
        num_kb = embedded_choice_kb.size()[2]

        # B X A X At X D
        embedded_answer, answer_mask = get_embedding(answer_text, 1,
                                                     self._text_field_embedder,
                                                     self._encoder,
                                                     self._var_dropout)
        # B X At X D
        embedded_answer = embedded_answer[:, 0, :, :]
        answer_mask = answer_mask[:, 0, :]

        # B*C*K X Kt X D
        ckmerged_embedded_choice_kb = self.merge_dimensions(
            self.merge_dimensions(embedded_choice_kb))
        ckmerged_choice_kb_mask = self.merge_dimensions(
            self.merge_dimensions(choice_kb_mask))

        # B*C X At X D
        cmerged_embedded_answer = self.merge_dimensions(
            self.add_dimension(embedded_answer, 1, num_choices))
        cmerged_answer_mask = self.merge_dimensions(
            self.add_dimension(answer_mask, 1, num_choices))
        # B*C*K X At X D
        ckmerged_embedded_answer = self.merge_dimensions(
            self.add_dimension(cmerged_embedded_answer, 1, num_kb))
        ckmerged_answer_mask = self.merge_dimensions(
            self.add_dimension(cmerged_answer_mask, 1, num_kb))
        # B*C*K X Ct X D
        ckmerged_embedded_choice = self.merge_dimensions(
            self.add_dimension(cmerged_embedded_choice, 1, num_kb))
        ckmerged_choice_mask = self.merge_dimensions(
            self.add_dimension(cmerged_choice_mask, 1, num_kb))
        logger.info("ckmerged_choice_mask" + str(ckmerged_choice_mask.shape))

        # == KB rep based on answer span ==
        if self._ignore_ann:
            # B*C*K X Ft X D
            ckmerged_embedded_fact = self.merge_dimensions(
                self.add_dimension(cmerged_embedded_fact, 1, num_kb))
            ckmerged_fact_mask = self.merge_dimensions(
                self.add_dimension(cmerged_fact_mask, 1, num_kb))
            # B*C*K X Kt x Ft
            ckmerged_kb_fact_att = self._attention(ckmerged_embedded_choice_kb,
                                                   ckmerged_embedded_fact)
            ckmerged_kb_fact_mask = self.add_dimension(
                ckmerged_fact_mask, 1, ckmerged_kb_fact_att.shape[1])
            masked_ckmerged_kb_fact_att = replace_masked_values(
                ckmerged_kb_fact_att, ckmerged_kb_fact_mask, -1e7)

            # B*C*K X Kt
            ckmerged_kb_answer_att_max = masked_ckmerged_kb_fact_att.max(
                dim=-1)[0].squeeze(-1)
        else:
            # B*C*K X Kt x At
            ckmerged_kb_answer_att = self._attention(
                ckmerged_embedded_choice_kb, ckmerged_embedded_answer)
            ckmerged_kb_answer_mask = self.add_dimension(
                ckmerged_answer_mask, 1, ckmerged_kb_answer_att.shape[1])
            masked_ckmerged_kb_answer_att = replace_masked_values(
                ckmerged_kb_answer_att, ckmerged_kb_answer_mask, -1e7)

            # B*C*K X Kt
            ckmerged_kb_answer_att_max = masked_ckmerged_kb_answer_att.max(
                dim=-1)[0].squeeze(-1)

        ckmerged_kb_answer_att_softmax = masked_softmax(
            ckmerged_kb_answer_att_max, ckmerged_choice_kb_mask)

        # B*C*K X D
        kb_answer_att_rep = weighted_sum(ckmerged_embedded_choice_kb,
                                         ckmerged_kb_answer_att_softmax)

        # == KB rep based on answer choice ==
        # B*C*K X Kt x Ct
        ckmerged_kb_choice_att = self._attention(ckmerged_embedded_choice_kb,
                                                 ckmerged_embedded_choice)
        ckmerged_kb_choice_mask = self.add_dimension(
            ckmerged_choice_mask, 1, ckmerged_kb_choice_att.shape[1])
        masked_ckmerged_kb_choice_att = replace_masked_values(
            ckmerged_kb_choice_att, ckmerged_kb_choice_mask, -1e7)

        # B*C*K X Kt
        ckmerged_kb_choice_att_max = masked_ckmerged_kb_choice_att.max(
            dim=-1)[0].squeeze(-1)
        ckmerged_kb_choice_att_softmax = masked_softmax(
            ckmerged_kb_choice_att_max, ckmerged_choice_kb_mask)

        # B*C*K X D
        kb_choice_att_rep = weighted_sum(ckmerged_embedded_choice_kb,
                                         ckmerged_kb_choice_att_softmax)

        # B*C*K X D
        answer_choice_kb_combined_rep = combine_tensors(
            self._answer_choice_combination,
            [kb_answer_att_rep, kb_choice_att_rep])
        logger.info("answer_choice_kb_combined_rep" +
                    str(answer_choice_kb_combined_rep.shape))

        # ==============================================
        # Relation Predictions
        # ==============================================

        # B*C*K x R
        choice_kb_relation_rep = self._relation_predictor(
            answer_choice_kb_combined_rep)
        new_choice_kb_size = [batch_size * num_choices, num_kb, -1]
        # B*C*K
        merged_choice_kb_mask = (torch.sum(ckmerged_choice_kb_mask, dim=-1) >
                                 0).float()
        if self._num_relations and not self._ignore_ann:
            if self._relation_projector:
                choice_kb_relation_pred = self._relation_projector(
                    choice_kb_relation_rep)
            else:
                choice_kb_relation_pred = choice_kb_relation_rep

            # Aggregate the predictions
            # B*C*K
            choice_kb_relation_mask = self.add_dimension(
                merged_choice_kb_mask, -1, choice_kb_relation_pred.shape[-1])
            choice_kb_relation_pred_masked = replace_masked_values(
                choice_kb_relation_pred, choice_kb_relation_mask, -1e7)
            # B*C X K X R
            relation_pred_perkb = choice_kb_relation_pred_masked.contiguous(
            ).view(*new_choice_kb_size)
            # B*C X R
            relation_pred_max = relation_pred_perkb.max(dim=1)[0].squeeze(1)

            # B X C X R
            choice_relation_size = [batch_size, num_choices, -1]
            relation_label_logits = relation_pred_max.contiguous().view(
                *choice_relation_size)
            relation_label_probs = softmax(relation_label_logits, dim=-1)
            # B X C
            add_relation_predictions(self.vocab, relation_label_probs,
                                     metadata)
            # B X C X K X R
            choice_kb_relation_size = [batch_size, num_choices, num_kb, -1]
            relation_predictions = choice_kb_relation_rep.contiguous().view(
                *choice_kb_relation_size)
            add_tuple_predictions(relation_predictions, metadata)
            logger.info("relation_predictions" +
                        str(relation_predictions.shape))
        else:
            relation_label_logits = None
            relation_label_probs = None

        if not self._ignore_relns:
            # B X C X D
            expanded_size = [batch_size, num_choices, -1]
            # Aggregate the relation representation
            if self._relation_projector or self._num_relations == 0 or self._ignore_ann:
                # B*C X K X D
                relation_rep_perkb = choice_kb_relation_rep.contiguous().view(
                    *new_choice_kb_size)
                # B*C*K X D
                merged_relation_rep_mask = self.add_dimension(
                    merged_choice_kb_mask, -1, relation_rep_perkb.shape[-1])
                # B*C X K X D
                relation_rep_perkb_mask = merged_relation_rep_mask.contiguous(
                ).view(*relation_rep_perkb.size())
                # B*C X D
                agg_relation_rep = masked_mean(relation_rep_perkb,
                                               relation_rep_perkb_mask,
                                               dim=1)
                # B X C X D
                expanded_relation_rep = agg_relation_rep.contiguous().view(
                    *expanded_size)
            else:
                expanded_relation_rep = relation_label_logits

            expanded_question_rep = agg_question.unsqueeze(1).expand(
                expanded_size)
            expanded_fact_rep = agg_fact.unsqueeze(1).expand(expanded_size)
            question_fact_rep = combine_tensors(
                self._combination, [expanded_question_rep, expanded_fact_rep])

            relation_score_rep = torch.cat(
                [question_fact_rep, expanded_relation_rep], dim=-1)
            relation_score = self._reln_ff(relation_score_rep).squeeze(-1)
            choice_label_logits = (coverage_score + relation_score) / 2
        else:
            choice_label_logits = coverage_score
        logger.info("choice_label_logits" + str(choice_label_logits.shape))

        choice_label_probs = softmax(choice_label_logits, dim=-1)
        output_dict = {
            "label_logits": choice_label_logits,
            "label_probs": choice_label_probs,
            "metadata": metadata
        }
        if relation_label_logits is not None:
            output_dict["relation_label_logits"] = relation_label_logits
            output_dict["relation_label_probs"] = relation_label_probs

        if answer_id is not None or relation_label is not None:
            self.compute_loss_and_accuracy(answer_id, relation_label,
                                           relation_label_logits,
                                           choice_label_logits, output_dict)
        return output_dict
Example No. 18
def path_encoding(x, y, combine_str, fforward, gate):
    z = combine_tensors(combine_str, [x, y])
    z = fforward(z)
    gatef = gate(z)
    return gatef * z
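A minimal usage sketch of `path_encoding`; the feedforward and gate modules below are hypothetical stand-ins, and `combine_tensors` is assumed to be `allennlp.nn.util.combine_tensors`.

import torch

# Assumes path_encoding (above) and its combine_tensors dependency
# (allennlp.nn.util.combine_tensors) are importable in this scope.
x = torch.randn(2, 10)
y = torch.randn(2, 10)
# "x,y" concatenates the inputs, so the feedforward takes 20 input features.
fforward = torch.nn.Sequential(torch.nn.Linear(20, 10), torch.nn.ReLU())
gate = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Sigmoid())
encoded = path_encoding(x, y, "x,y", fforward, gate)
print(encoded.shape)  # torch.Size([2, 10])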