def __call__(self, logits: torch.Tensor, mask: torch.ByteTensor) -> torch.LongTensor:
    """
    Decode sequence logits of shape (batch_size, seq_len, num_label) into label indices.

    The unmasked timesteps of each sequence are decoded by
    ``BIO.decode_one_sequence_logits_to_label`` with ``self._label_vocabulary``
    (the original note describes this as a per-timestep max decode). Masked
    positions in the output are filled with ``self._label_vocabulary.padding_index``.

    :param logits: shape (batch_size, seq_len, num_label)
    :param mask: shape (batch_size, seq_len), holding 0 or 1; ``None`` means
        every position is valid. It is moved to ``logits.device`` if needed.
    :return: decoded label indices, shape (batch_size, seq_len), on
        ``logits.device``. The result contains padding indices, so use
        ``mask`` to extract the actual label indices.
    :raises RuntimeError: if ``logits`` is not 3-D or ``mask`` is not 2-D.
    """
    if logits.dim() != 3:
        raise RuntimeError(
            f"logits shape 错误, 应该是 (B, seq_len, num_label), "
            f"而现在是 {logits.shape}")

    if (mask is not None) and (mask.dim() != 2):
        raise RuntimeError(f"mask shape 错误, 应该是 (B, seq_len), "
                           f"而现在是 {mask.shape}")

    if mask is None:
        # Build the default mask on the logits' device; a CPU mask would make
        # masked_select / masked_scatter fail for GPU logits.
        mask = torch.ones(size=(logits.size(0), logits.size(1)),
                          dtype=torch.long,
                          device=logits.device)

    # Keep the mask on the same device as the logits for all masked ops below.
    mask_bool = mask.to(device=logits.device).bool()

    batch_indices = list()
    for sequence_logits, sequence_mask1d in zip(logits, mask_bool):
        # Expand to 2-D, (seq_len, 1), so it broadcasts over the label dim.
        sequence_mask2d = torch.unsqueeze(sequence_mask1d, dim=-1)

        assert sequence_logits.dim() == sequence_mask2d.dim(), \
            f"sequence_logits dim: {sequence_logits.dim()} 与 sequence_mask.dim: {sequence_mask2d.dim()} 不匹配"

        # Keep only the unmasked timesteps: (num_valid, num_label).
        sequence_logits = torch.masked_select(sequence_logits, mask=sequence_mask2d)
        sequence_logits = sequence_logits.contiguous().view(-1, logits.size(-1))

        sequence_labels, sequence_label_indices = BIO.decode_one_sequence_logits_to_label(
            sequence_logits=sequence_logits,
            vocabulary=self._label_vocabulary)

        sequence_label_indices = torch.tensor(sequence_label_indices,
                                              dtype=torch.long,
                                              device=logits.device)

        # Scatter the decoded indices back into a full-length row that is
        # pre-filled with the padding index.
        padding = torch.full_like(sequence_mask1d,
                                  fill_value=self._label_vocabulary.padding_index,
                                  dtype=torch.long)
        sequence_label_indices = padding.masked_scatter(sequence_mask1d, sequence_label_indices)

        batch_indices.append(sequence_label_indices)

    batch_indices = torch.stack(batch_indices, dim=0)
    return batch_indices
def _compute_normalizer(self, emissions: torch.Tensor,
                        mask: torch.ByteTensor) -> torch.Tensor:
    """
    Compute the CRF log partition function with the forward algorithm.

    :param emissions: emission scores, shape (seq_length, batch_size, num_tags).
    :param mask: shape (seq_length, batch_size); the first timestep must be
        valid for every sample.
    :return: shape (batch_size,); for each sample, the log-sum-exp of the
        scores of every tag sequence over its valid timesteps, including
        start and end transitions.
    """
    assert emissions.dim() == 3 and mask.dim() == 2
    assert emissions.shape[:2] == mask.shape
    assert emissions.size(2) == self.num_tags
    assert mask[0].bool().all()

    is_valid = mask.bool()

    # alpha[b, j]: log-sum-exp of all partial tag sequences that end in tag j
    # at the current timestep. Shape: (batch_size, num_tags).
    alpha = self.start_transitions + emissions[0]

    for step_emission, step_valid in zip(emissions[1:], is_valid[1:]):
        # pairwise[b, i, j]: end in tag i, transition i -> j, then emit j.
        # Shape: (batch_size, num_tags, num_tags).
        pairwise = alpha.unsqueeze(2) + self.transitions + step_emission.unsqueeze(1)
        # Marginalise over the previous tag i in log space.
        candidate = torch.logsumexp(pairwise, dim=1)
        # Padded timesteps leave alpha untouched.
        alpha = torch.where(step_valid.unsqueeze(1), candidate, alpha)

    # Add the end transitions, then sum over the final tag.
    return torch.logsumexp(alpha + self.end_transitions, dim=1)
def decode_label_index_to_span(
        batch_sequence_label_index: torch.Tensor,
        mask: torch.ByteTensor,
        vocabulary: LabelVocabulary) -> List[List[Dict]]:
    """
    Decode a batch of label-index sequences into spans.

    Example, with batch_sequence_label_index of shape (B, seq_len) and the
    vocabulary (B-T: 0, I-T: 1, O: 2)::

        [[0, 1, 2], [2, 0, 1]]

    The corresponding label sequences are::

        [[B, I, O], [O, B, I]]

    These decode to::

        [[{"label": T, "begin": 0, "end": 2}],
         [{"label": T, "begin": 1, "end": 3}]]

    :param batch_sequence_label_index: shape (B, seq_len), label index sequences.
    :param mask: mask for ``batch_sequence_label_index``; ``None`` means every
        position is valid. It is moved to the label tensor's device if needed.
    :param vocabulary: label vocabulary, used to map indices back to label tokens.
    :return: one list of decoded spans per sequence in the batch.
    """
    device = batch_sequence_label_index.device

    if mask is None:
        # Create the default mask on the same device as the labels; a CPU mask
        # would make masked_select fail for GPU label tensors.
        mask = torch.ones(size=(batch_sequence_label_index.shape[0],
                                batch_sequence_label_index.shape[1]),
                          dtype=torch.long,
                          device=device)

    mask = mask.to(device=device).bool()

    spans = list()
    for sequence_label_index, mask1d in zip(batch_sequence_label_index, mask):
        # Drop padded positions before mapping indices back to label tokens.
        label_indices = torch.masked_select(sequence_label_index, mask=mask1d).tolist()
        sequence_label = [vocabulary.token(index) for index in label_indices]
        spans.append(decode_one_sequence_label_to_span(sequence_label=sequence_label))

    return spans
def _greedy_search(
        self,
        query: torch.FloatTensor,
        key: torch.FloatTensor,
        edge_head_score: torch.FloatTensor,
        edge_head_mask: torch.ByteTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Greedily predict edge heads and edge labels.

    For each query position, the head is the key position with the highest
    (masked) score. The edge type is the argmax over the last dimension of
    ``self._get_edge_type_score(query, key, edge_head)``.

    :param query: [batch_size, query_length, query_vector_dim]
    :param key: [batch_size, key_length, key_vector_dim]
    :param edge_head_score: [batch_size, query_length, key_length]; not modified.
    :param edge_head_mask: None or [batch_size, query_length, key_length];
        zero entries are excluded from head selection.
    :return:
        edge_head: [batch_size, query_length]
        edge_type: [batch_size, query_length]
    """
    if edge_head_mask is not None:
        # Out-of-place fill, so the caller's score tensor is not mutated and
        # autograd history is not overwritten.
        edge_head_score = edge_head_score.masked_fill(~edge_head_mask.bool(), self._minus_inf)

    _, edge_head = edge_head_score.max(dim=2)

    edge_type_score = self._get_edge_type_score(query, key, edge_head)
    _, edge_type = edge_type_score.max(dim=2)

    return edge_head, edge_type
def _viterbi_decode(self, emissions: torch.FloatTensor,
                    mask: torch.ByteTensor) -> torch.LongTensor:
    """
    Find the highest-scoring tag sequence for each sample (Viterbi decoding).

    :param emissions: emission scores, shape (seq_length, batch_size, num_tags).
    :param mask: shape (seq_length, batch_size); the first timestep must be
        valid for every sample.
    :return: LongTensor of shape (batch_size, seq_length) on the same device
        as ``emissions``. Row b holds the best tags for sample b's valid
        timesteps, and remaining positions are padded with -1.
    """
    assert emissions.dim() == 3 and mask.dim() == 2
    assert emissions.shape[:2] == mask.shape
    assert emissions.size(2) == self.num_tags
    assert mask[0].bool().all()

    mask = mask.bool()
    seq_length, batch_size = mask.shape

    # score[b, j]: score of the best tag sequence so far that ends with tag j.
    # Shape: (batch_size, num_tags).
    score = self.start_transitions + emissions[0]
    # history[i - 1][b, j]: best previous tag when tag j is chosen at step i.
    history = []

    for i in range(1, seq_length):
        # Shape: (batch_size, num_tags, 1)
        broadcast_score = score.unsqueeze(2)
        # Shape: (batch_size, 1, num_tags)
        broadcast_emission = emissions[i].unsqueeze(1)
        # next_score[b, i, j]: best sequence ending in tag i, then i -> j, emit j.
        next_score = broadcast_score + self.transitions + broadcast_emission
        # Best previous tag for every next tag. Shape: (batch_size, num_tags)
        next_score, indices = next_score.max(dim=1)
        # Only advance samples whose timestep i is valid.
        score = torch.where(mask[i].unsqueeze(1), next_score, score)
        history.append(indices)

    score += self.end_transitions

    # Index of each sample's last valid timestep. Shape: (batch_size,)
    seq_ends = mask.long().sum(dim=0) - 1

    best_tags_list = []
    for idx in range(batch_size):
        _, best_last_tag = score[idx].max(dim=0)
        best_tags = [best_last_tag.item()]
        # Follow back-pointers from the last valid timestep to the first.
        for hist in reversed(history[:seq_ends[idx].item()]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(best_last_tag.item())
        best_tags.reverse()
        # Pad to the full sequence length with -1.
        best_tags_list.append(best_tags + [-1] * (seq_length - len(best_tags)))

    # Build the result on the input's device instead of hard-coding .cuda().
    # view() keeps a (0, seq_length) shape for an empty batch.
    return torch.tensor(best_tags_list, dtype=torch.long,
                        device=emissions.device).view(batch_size, seq_length)