Example #1
    def inference(self, tokens, token_types, valid_length, p_mask,
                  start_top_n: int = 5, end_top_n: int = 5):
        """Get the inference result with beam search

        Parameters
        ----------
        tokens
            The input tokens. Shape (batch_size, sequence_length)
        token_types
            The input token types. Shape (batch_size, sequence_length)
        valid_length
            The valid length of the tokens. Shape (batch_size,)
        p_mask
            The mask which indicates that some tokens won't be used in the calculation.
            Shape (batch_size, sequence_length)
        start_top_n
            The number of candidates to select for the start position.
        end_top_n
            The number of candidates to select for the end position.

        Returns
        -------
        start_top_logits
            The top start logits.
            Shape (batch_size, start_top_n)
        start_top_index
            Index of the top start logits.
            Shape (batch_size, start_top_n)
        end_top_logits
            The top end logits.
            Shape (batch_size, end_top_n)
        end_top_index
            Index of the top end logits.
            Shape (batch_size, end_top_n)
        """
        # Shape (batch_size, sequence_length, C)
        if self.use_segmentation:
            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
        else:
            contextual_embeddings = self.backbone(tokens, valid_length)
        scores = self.qa_outputs(contextual_embeddings)
        start_scores = scores[:, :, 0]
        end_scores = scores[:, :, 1]
        start_logits = masked_logsoftmax(start_scores, mask=p_mask, axis=-1)
        end_logits = masked_logsoftmax(end_scores, mask=p_mask, axis=-1)
        # The shape of start_top_index will be (..., start_top_n)
        start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
                                                        ret_typ='both')
        # The shape of end_top_index will be (..., end_top_n). The end candidates are
        # selected independently of the start positions.
        end_top_logits, end_top_index = mx.npx.topk(end_logits, k=end_top_n, axis=-1,
                                                    ret_typ='both')
        return start_top_logits, start_top_index, end_top_logits, end_top_index
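
A minimal standalone sketch of the top-k selection step above, run on made-up log-probabilities instead of real masked_logsoftmax output; the array values and the batch/sequence sizes are illustrative assumptions.

from mxnet import np, npx
npx.set_np()

# Pretend start log-probabilities for a batch of one sequence of length 4.
start_logits = np.array([[-0.1, -2.3, -0.7, -4.0]])
# ret_typ='both' returns both the top-k values and their positions along the last axis.
top_logits, top_index = npx.topk(start_logits, k=2, axis=-1, ret_typ='both')
print(top_logits.shape, top_index.shape)  # (1, 2) (1, 2) == (batch_size, start_top_n)
print(top_index)                          # positions 0 and 2 hold the two largest logits
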
Example #2
    def get_end_logits(self, contextual_embedding, start_positions, p_mask):
        """

        Parameters
        ----------
        contextual_embedding
            Shape (batch_size, sequence_length, C)
        start_positions
            Shape (batch_size, N)
            We process multiple candidates simultaneously
        p_mask
            Shape (batch_size, sequence_length)

        Returns
        -------
        end_logits
            Shape (batch_size, N, sequence_length)
        """
        # Select the features at the start_positions
        # start_feature will have shape (batch_size, N, C)
        start_features = select_vectors_by_position(contextual_embedding, start_positions)
        # Concatenate the start_feature and the contextual_embedding
        contextual_embedding = np.expand_dims(contextual_embedding, axis=1)  # (B, 1, T, C)
        start_features = np.expand_dims(start_features, axis=2)  # (B, N, 1, C)
        concat_features = np.concatenate(
            [npx.broadcast_like(start_features, contextual_embedding, 2, 2),
             npx.broadcast_like(contextual_embedding, start_features, 1, 1)],
            axis=-1)  # (B, N, T, 2C)
        end_scores = self.end_scores(concat_features)
        end_scores = np.squeeze(end_scores, -1)
        end_logits = masked_logsoftmax(end_scores, mask=np.expand_dims(p_mask, axis=1),
                                       axis=-1)
        return end_logits
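
A hedged shape walk-through of the start/position pairing above, using made-up sizes for B, N, T, and C; it only exercises the two broadcast_like calls and the concatenation, not the dense end_scores layer.

from mxnet import np, npx
npx.set_np()

B, N, T, C = 2, 3, 5, 4
contextual_embedding = np.zeros((B, 1, T, C))  # after expand_dims on axis 1
start_features = np.zeros((B, N, 1, C))        # after expand_dims on axis 2
# Tile the N start candidates across all T positions, and the full sequence
# across all N candidates, then pair them on the channel axis.
concat_features = np.concatenate(
    [npx.broadcast_like(start_features, contextual_embedding, 2, 2),
     npx.broadcast_like(contextual_embedding, start_features, 1, 1)],
    axis=-1)
print(concat_features.shape)  # (2, 3, 5, 8) == (B, N, T, 2C)
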
Example #3
    def forward(self, tokens, token_types, valid_length, p_mask):
        """

        Parameters
        ----------
        tokens
            Shape (batch_size, seq_length)
            The merged input tokens
        token_types
            Shape (batch_size, seq_length)
            Token types for the sequences, used to indicate whether the word belongs to the
            first sentence or the second one.
        valid_length
            Shape (batch_size,)
            Valid length of the sequence. This is used to mask the padded tokens.
        p_mask
            Shape (batch_size, sequence_length)
            The mask which indicates that some tokens won't be used in the calculation.

        Returns
        -------
        start_logits
            Shape (batch_size, sequence_length)
            The log-softmax scores that each position is the start position.
        end_logits
            Shape (batch_size, sequence_length)
            The log-softmax scores that each position is the end position.
        """
        # Get contextual embedding with the shape (batch_size, sequence_length, C)
        if self.use_segmentation:
            contextual_embeddings = self.backbone(tokens, token_types,
                                                  valid_length)
        else:
            contextual_embeddings = self.backbone(tokens, valid_length)
        scores = self.qa_outputs(contextual_embeddings)
        start_scores = scores[:, :, 0]
        end_scores = scores[:, :, 1]
        start_logits = masked_logsoftmax(start_scores, mask=p_mask, axis=-1)
        end_logits = masked_logsoftmax(end_scores, mask=p_mask, axis=-1)
        return start_logits, end_logits
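
The forward pass only returns per-position log-probabilities; a common follow-up step, sketched here as an assumption rather than taken from the snippet, is to score every valid (start, end) pair and keep the best span. The max_answer_length cap and the plain Python loops are illustrative choices.

from mxnet import np, npx
npx.set_np()

def decode_best_span(start_logits, end_logits, max_answer_length=30):
    """Pick the (start, end) pair with the highest summed log-probability per sample."""
    best_spans = []
    for s_row, e_row in zip(start_logits.asnumpy(), end_logits.asnumpy()):
        seq_length = len(s_row)
        best, best_score = (0, 0), float('-inf')
        for i in range(seq_length):
            for j in range(i, min(i + max_answer_length, seq_length)):
                score = s_row[i] + e_row[j]
                if score > best_score:
                    best, best_score = (i, j), score
        best_spans.append(best)
    return best_spans

# Dummy logits with shape (batch_size=2, seq_length=8), standing in for the model output.
spans = decode_best_span(np.random.normal(size=(2, 8)), np.random.normal(size=(2, 8)))
print(spans)  # e.g. [(1, 3), (0, 5)]
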
Example #4
    def get_start_logits(self, contextual_embedding, p_mask):
        """

        Parameters
        ----------
        contextual_embedding
            Shape (batch_size, sequence_length, C)

        Returns
        -------
        start_logits
            Shape (batch_size, sequence_length)
        """
        start_scores = np.squeeze(self.start_scores(contextual_embedding), -1)
        start_logits = masked_logsoftmax(start_scores, mask=p_mask, axis=-1)
        return start_logits
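
A small shape sketch of the projection step, assuming self.start_scores is a Dense(1, flatten=False) layer as the trailing squeeze on the last axis suggests; the layer definition and the sizes below are assumptions, not part of the snippet.

from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()

B, T, C = 2, 8, 16
start_scores_layer = nn.Dense(1, flatten=False)  # per-position projection from C to 1
start_scores_layer.initialize()
contextual_embedding = np.random.normal(size=(B, T, C))
start_scores = np.squeeze(start_scores_layer(contextual_embedding), -1)
print(start_scores.shape)  # (2, 8) == (batch_size, sequence_length)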