def inference(self, tokens, token_types, valid_length, p_mask,
              start_top_n: int = 5, end_top_n: int = 5):
    """Get the inference result with beam search

    Parameters
    ----------
    tokens
        The input tokens. Shape (batch_size, sequence_length)
    token_types
        The input token types. Shape (batch_size, sequence_length)
    valid_length
        The valid length of the tokens. Shape (batch_size,)
    p_mask
        The mask which indicates that some tokens won't be used in the calculation.
        Shape (batch_size, sequence_length)
    start_top_n
        The number of candidates to select for the start position.
    end_top_n
        The number of candidates to select for the end position.

    Returns
    -------
    start_top_logits
        The top start logits. Shape (batch_size, start_top_n)
    start_top_index
        Index of the top start logits. Shape (batch_size, start_top_n)
    end_top_logits
        The top end logits. Shape (batch_size, end_top_n)
    end_top_index
        Index of the top end logits. Shape (batch_size, end_top_n)
    """
    # Shape (batch_size, sequence_length, C)
    if self.use_segmentation:
        contextual_embeddings = self.backbone(tokens, token_types, valid_length)
    else:
        contextual_embeddings = self.backbone(tokens, valid_length)
    scores = self.qa_outputs(contextual_embeddings)
    start_scores = scores[:, :, 0]
    end_scores = scores[:, :, 1]
    start_logits = masked_logsoftmax(start_scores, mask=p_mask, axis=-1)
    end_logits = masked_logsoftmax(end_scores, mask=p_mask, axis=-1)
    # start_top_logits / start_top_index have shape (batch_size, start_top_n)
    start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
                                                    ret_typ='both')
    # end_top_logits / end_top_index have shape (batch_size, end_top_n). The start and end
    # candidates are selected independently here, i.e. the end positions are not
    # conditioned on the chosen start positions.
    end_top_logits, end_top_index = mx.npx.topk(end_logits, k=end_top_n, axis=-1,
                                                ret_typ='both')
    return start_top_logits, start_top_index, end_top_logits, end_top_index
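
# --- Illustrative only: turning the top-n candidates above into a span prediction. ---
# This is a minimal post-processing sketch in plain NumPy, not part of the model.
# It pairs the start/end candidates by summing their log-probabilities and drops
# spans that end before they start or exceed `max_answer_length`, in the spirit of
# the usual SQuAD-style decoding. The function name and `max_answer_length` are
# hypothetical; the actual evaluation script may post-process differently.
import numpy as onp  # plain NumPy, kept separate from the MXNet `np` alias


def decode_best_span(start_top_logits, start_top_index,
                     end_top_logits, end_top_index, max_answer_length=30):
    """Pick the highest-scoring valid (start, end) span for a single example."""
    best_score, best_span = -onp.inf, (0, 0)
    for s_logit, s_idx in zip(start_top_logits, start_top_index):
        for e_logit, e_idx in zip(end_top_logits, end_top_index):
            # Skip spans that end before they start or that are too long.
            if e_idx < s_idx or e_idx - s_idx + 1 > max_answer_length:
                continue
            score = s_logit + e_logit
            if score > best_score:
                best_score, best_span = score, (int(s_idx), int(e_idx))
    return best_span, best_score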
def get_end_logits(self, contextual_embedding, start_positions, p_mask):
    """
    Parameters
    ----------
    contextual_embedding
        Shape (batch_size, sequence_length, C)
    start_positions
        Shape (batch_size, N)
        We process multiple candidates simultaneously
    p_mask
        Shape (batch_size, sequence_length)

    Returns
    -------
    end_logits
        Shape (batch_size, N, sequence_length)
    """
    # Select the features at the start_positions
    # start_features will have shape (batch_size, N, C)
    start_features = select_vectors_by_position(contextual_embedding, start_positions)
    # Concatenate the start_features and the contextual_embedding
    contextual_embedding = np.expand_dims(contextual_embedding, axis=1)  # (B, 1, T, C)
    start_features = np.expand_dims(start_features, axis=2)  # (B, N, 1, C)
    concat_features = np.concatenate([npx.broadcast_like(start_features,
                                                         contextual_embedding, 2, 2),
                                      npx.broadcast_like(contextual_embedding,
                                                         start_features, 1, 1)],
                                     axis=-1)  # (B, N, T, 2C)
    end_scores = self.end_scores(concat_features)
    end_scores = np.squeeze(end_scores, -1)
    end_logits = masked_logsoftmax(end_scores, mask=np.expand_dims(p_mask, axis=1),
                                   axis=-1)
    return end_logits
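
# --- Illustrative only: the broadcasting pattern used in get_end_logits, in plain NumPy. ---
# A tiny shape check showing how the (B, N, 1, C) start features and the (B, 1, T, C)
# contextual embedding tile against each other before the channel-wise concatenation
# that yields the (B, N, T, 2C) pairwise features. The sizes below are arbitrary and
# the snippet is a sketch, not part of the model.
import numpy as onp  # plain NumPy, kept separate from the MXNet `np` alias

B, N, T, C = 2, 3, 5, 4
ctx = onp.random.rand(B, 1, T, C)     # contextual embedding, expanded on axis 1
start = onp.random.rand(B, N, 1, C)   # start-position features, expanded on axis 2
start_tiled = onp.broadcast_to(start, (B, N, T, C))  # repeat each start feature T times
ctx_tiled = onp.broadcast_to(ctx, (B, N, T, C))      # repeat the sequence for each start
pairwise = onp.concatenate([start_tiled, ctx_tiled], axis=-1)
assert pairwise.shape == (B, N, T, 2 * C)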
def forward(self, tokens, token_types, valid_length, p_mask):
    """
    Parameters
    ----------
    tokens
        Shape (batch_size, seq_length)
        The merged input tokens
    token_types
        Shape (batch_size, seq_length)
        Token types for the sequences, used to indicate whether the word belongs to the
        first sentence or the second one.
    valid_length
        Shape (batch_size,)
        Valid length of the sequence. This is used to mask the padded tokens.
    p_mask
        Shape (batch_size, seq_length)
        The mask which indicates that some tokens won't be used in the calculation.

    Returns
    -------
    start_logits
        Shape (batch_size, sequence_length)
        The log-softmax scores that the position is the start position.
    end_logits
        Shape (batch_size, sequence_length)
        The log-softmax scores that the position is the end position.
    """
    # Get contextual embedding with the shape (batch_size, sequence_length, C)
    if self.use_segmentation:
        contextual_embeddings = self.backbone(tokens, token_types, valid_length)
    else:
        contextual_embeddings = self.backbone(tokens, valid_length)
    scores = self.qa_outputs(contextual_embeddings)
    start_scores = scores[:, :, 0]
    end_scores = scores[:, :, 1]
    start_logits = masked_logsoftmax(start_scores, mask=p_mask, axis=-1)
    end_logits = masked_logsoftmax(end_scores, mask=p_mask, axis=-1)
    return start_logits, end_logits
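
# --- Illustrative only: how the log-softmax outputs of forward() map to a span loss. ---
# Since start_logits / end_logits are already log-probabilities over positions, the
# training loss is just the negative log-likelihood at the gold start/end indices,
# averaged over the two heads. Plain NumPy sketch; `gold_start` / `gold_end` are
# hypothetical names for integer arrays of shape (batch_size,).
import numpy as onp  # plain NumPy, kept separate from the MXNet `np` alias


def qa_span_nll(start_logits, end_logits, gold_start, gold_end):
    """Per-example span loss: mean of start and end negative log-likelihood."""
    batch_idx = onp.arange(start_logits.shape[0])
    start_nll = -start_logits[batch_idx, gold_start]
    end_nll = -end_logits[batch_idx, gold_end]
    return 0.5 * (start_nll + end_nll)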
def get_start_logits(self, contextual_embedding, p_mask):
    """
    Parameters
    ----------
    contextual_embedding
        Shape (batch_size, sequence_length, C)
    p_mask
        Shape (batch_size, sequence_length)

    Returns
    -------
    start_logits
        Shape (batch_size, sequence_length)
    """
    start_scores = np.squeeze(self.start_scores(contextual_embedding), -1)
    start_logits = masked_logsoftmax(start_scores, mask=p_mask, axis=-1)
    return start_logits
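
# --- Illustrative only: assumed behaviour of masked_logsoftmax. ---
# The helper is defined elsewhere in the repository; this plain-NumPy sketch reflects
# the common masked-softmax convention (positions where the mask is 0 are excluded by
# pushing their scores to a large negative value before the log-softmax). Both the
# function name and the mask convention are assumptions, so check the helper's own
# docstring for the exact semantics of `p_mask`.
import numpy as onp  # plain NumPy, kept separate from the MXNet `np` alias


def masked_logsoftmax_sketch(scores, mask, axis=-1):
    """Log-softmax over `axis`, assuming mask == 1 keeps a position and 0 excludes it."""
    neg_inf = -1e18
    masked_scores = onp.where(mask.astype(bool), scores, neg_inf)
    shifted = masked_scores - masked_scores.max(axis=axis, keepdims=True)
    log_z = onp.log(onp.exp(shifted).sum(axis=axis, keepdims=True))
    return shifted - log_z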