from typing import List, Optional

import torch


def from_mask(select: torch.ByteTensor) -> 'SparseSequence':
    # Build a SparseSequence from a (batch_size, sequence_length) selection mask.
    dbatch, dseq = select.size()
    return SparseSequence(dbatch, dseq, select)
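# Hypothetical usage sketch: `SparseSequence` is assumed to be defined
# elsewhere as a container over (batch size, sequence length, selection mask).
import torch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.uint8)
sparse = from_mask(mask)  # wraps a batch of 2 sequences padded to length 4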
def _forward_algorithm(self,
                       emissions: torch.Tensor,
                       mask: torch.ByteTensor,
                       reverse_direction: bool = False) -> torch.FloatTensor:
    """Compute forward (alpha) or backward (beta) log probabilities.

    Parameters:
        emissions: (batch_size, sequence_length, num_tags)
        mask: padding mask; positions with 0 do not contribute to the score.
            (batch_size, sequence_length)
        reverse_direction: if True, run the backward algorithm instead of
            the forward algorithm.
    Returns:
        log_probabilities: (sequence_length, batch_size, num_tags)
    """
    batch_size, sequence_length, num_tags = emissions.data.shape

    broadcast_emissions = emissions.transpose(0, 1).unsqueeze(
        2).contiguous()  # (sequence_length, batch_size, 1, num_tags)
    mask = mask.float().transpose(
        0, 1).contiguous()  # (sequence_length, batch_size)
    broadcast_transitions = self.transitions.unsqueeze(
        0)  # (1, num_tags, num_tags)
    sequence_iter = range(1, sequence_length)

    # Backward algorithm
    if reverse_direction:
        # Transpose transitions matrix and emissions
        broadcast_transitions = broadcast_transitions.transpose(
            1, 2)  # (1, num_tags, num_tags)
        broadcast_emissions = broadcast_emissions.transpose(
            2, 3)  # (sequence_length, batch_size, num_tags, 1)
        sequence_iter = reversed(sequence_iter)

        # It is beta
        log_proba = [self.end_transitions.expand(batch_size, num_tags)]
    # Forward algorithm
    else:
        # It is alpha
        log_proba = [
            emissions.transpose(0, 1)[0] + self.start_transitions.view(1, -1)
        ]

    for i in sequence_iter:
        # Broadcast log probability
        broadcast_log_proba = log_proba[-1].unsqueeze(
            2)  # (batch_size, num_tags, 1)

        # Add all scores
        # inner: (batch_size, num_tags, num_tags)
        # broadcast_log_proba: (batch_size, num_tags, 1)
        # broadcast_transitions: (1, num_tags, num_tags)
        # broadcast_emissions: (batch_size, 1, num_tags)
        inner = broadcast_log_proba \
            + broadcast_transitions \
            + broadcast_emissions[i]

        # Append log proba; padded positions carry the previous value forward
        log_proba.append(
            (log_sum_exp(inner, 1) * mask[i].view(batch_size, 1)
             + log_proba[-1] * (1 - mask[i]).view(batch_size, 1)))

    if reverse_direction:
        log_proba.reverse()

    return torch.stack(log_proba)
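# `log_sum_exp` is referenced above but not defined in this snippet. A minimal
# numerically stable sketch consistent with the call sites (equivalent to
# torch.logsumexp(tensor, dim)):
import torch

def log_sum_exp(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    # Subtract the per-slice max before exponentiating to avoid overflow.
    max_score, _ = tensor.max(dim)
    return max_score + (tensor - max_score.unsqueeze(dim)).exp().sum(dim).log()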
def restricted_viterbi_decode(
        self,
        emissions: torch.Tensor,
        possible_tags: torch.ByteTensor,
        mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
    """Viterbi decoding restricted to a per-position set of allowed tags.

    Parameters:
        emissions: (batch_size, sequence_length, num_tags)
        possible_tags: (batch_size, sequence_length, num_tags);
            a 0 marks a tag that may not be used at that position.
        mask: padding mask; positions with 0 do not contribute to the score.
            (batch_size, sequence_length)
    Returns:
        best_tags_list: the best tag sequence for each sample in the batch
    """
    batch_size, sequence_length, num_tags = emissions.data.shape

    if mask is None:
        mask = torch.ones([batch_size, sequence_length],
                          dtype=torch.uint8,
                          device=emissions.device)

    emissions = emissions.transpose(0, 1).contiguous()
    mask = mask.transpose(0, 1).contiguous()
    possible_tags = possible_tags.float().transpose(0, 1).contiguous()

    # Start transition score and first emission
    first_possible_tag = possible_tags[0]
    score = self.start_transitions + emissions[0]  # (batch_size, num_tags)
    score[first_possible_tag == 0] = IMPOSSIBLE_SCORE

    history = []
    for i in range(1, sequence_length):
        current_possible_tags = possible_tags[i - 1]
        next_possible_tags = possible_tags[i]

        # Emission score; disallowed next tags get an impossible score
        emissions_score = emissions[i]
        emissions_score[next_possible_tags == 0] = IMPOSSIBLE_SCORE
        emissions_score = emissions_score.view(batch_size, 1, num_tags)

        # Transition score; block transitions from and to disallowed tags
        transition_scores = self.transitions.view(
            1, num_tags, num_tags).expand(batch_size, num_tags,
                                          num_tags).clone()
        transition_scores[current_possible_tags == 0] = IMPOSSIBLE_SCORE
        transition_scores.transpose(
            1, 2)[next_possible_tags == 0] = IMPOSSIBLE_SCORE

        broadcast_score = score.view(batch_size, num_tags, 1)
        next_score = broadcast_score + transition_scores + emissions_score
        next_score, indices = next_score.max(dim=1)

        score = torch.where(mask[i].unsqueeze(1), next_score, score)
        history.append(indices)

    # Add end transition score
    score += self.end_transitions

    # Compute the best path for each sample
    seq_ends = mask.long().sum(dim=0) - 1
    best_tags_list = []
    for idx in range(batch_size):
        _, best_last_tag = score[idx].max(dim=0)
        best_tags = [best_last_tag.item()]
        for hist in reversed(history[:seq_ends[idx]]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(best_last_tag.item())
        best_tags.reverse()
        best_tags_list.append(best_tags)

    return best_tags_list
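# Hypothetical usage sketch: `model` is assumed to be a CRF module exposing
# the method above, and IMPOSSIBLE_SCORE a large negative constant (e.g. -1e4)
# defined elsewhere in the file.
import torch

batch_size, seq_len, num_tags = 2, 4, 3
emissions = torch.randn(batch_size, seq_len, num_tags)
possible_tags = torch.ones(batch_size, seq_len, num_tags, dtype=torch.uint8)
possible_tags[0, 1, 2] = 0  # forbid tag 2 at position 1 of sample 0
best = model.restricted_viterbi_decode(emissions, possible_tags)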
def _viterbi_decode(self, emissions: torch.FloatTensor,
                    mask: torch.ByteTensor) -> List[List[int]]:
    # emissions: (seq_length, batch_size, num_tags)
    # mask: (seq_length, batch_size)
    assert emissions.dim() == 3 and mask.dim() == 2
    assert emissions.shape[:2] == mask.shape
    assert emissions.size(2) == self.num_tags
    assert mask[0].all()

    seq_length, batch_size = mask.shape

    # Start transition and first emission
    # shape: (batch_size, num_tags)
    score = self.start_transitions + emissions[0]
    history = []

    # score is a tensor of size (batch_size, num_tags) where for every batch,
    # value at column j stores the score of the best tag sequence so far that ends
    # with tag j
    # history saves where the best tags candidate transitioned from; this is used
    # when we trace back the best tag sequence

    # Viterbi algorithm recursive case: we compute the score of the best tag sequence
    # for every possible next tag
    for i in range(1, seq_length):
        # Broadcast viterbi score for every possible next tag
        # shape: (batch_size, num_tags, 1)
        broadcast_score = score.unsqueeze(2)

        # Broadcast emission score for every possible current tag
        # shape: (batch_size, 1, num_tags)
        broadcast_emission = emissions[i].unsqueeze(1)

        # Compute the score tensor of size (batch_size, num_tags, num_tags) where
        # for each sample, entry at row i and column j stores the score of the best
        # tag sequence so far that ends with transitioning from tag i to tag j and emitting
        # shape: (batch_size, num_tags, num_tags)
        next_score = broadcast_score + self.transitions + broadcast_emission

        # Find the maximum score over all possible current tag
        # shape: (batch_size, num_tags)
        next_score, indices = next_score.max(dim=1)

        # Set score to the next score if this timestep is valid (mask == 1)
        # and save the index that produces the next score
        # shape: (batch_size, num_tags)
        score = torch.where(mask[i].unsqueeze(1), next_score, score)
        history.append(indices)

    # End transition score
    # shape: (batch_size, num_tags)
    score += self.end_transitions

    # Now, compute the best path for each sample
    # shape: (batch_size,)
    seq_ends = mask.long().sum(dim=0) - 1
    best_tags_list = []

    for idx in range(batch_size):
        # Find the tag which maximizes the score at the last timestep; this is our best tag
        # for the last timestep
        _, best_last_tag = score[idx].max(dim=0)
        best_tags = [best_last_tag.item()]

        # We trace back where the best last tag comes from, append that to our best tag
        # sequence, and trace it back again, and so on
        for hist in reversed(history[:seq_ends[idx]]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(best_last_tag.item())

        # Reverse the order because we start from the last timestep
        best_tags.reverse()
        best_tags_list.append(best_tags)

    return best_tags_list
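# This method mirrors the internals of the pytorch-crf package; if that
# package is installed, the same decoding is reachable through its public API
# (a usage sketch, not part of the code above):
import torch
from torchcrf import CRF

crf = CRF(num_tags=5)
emissions = torch.randn(3, 2, 5)            # (seq_length, batch_size, num_tags)
mask = torch.ones(3, 2, dtype=torch.uint8)  # no padding in this toy batch
print(crf.decode(emissions, mask=mask))     # best tag sequence per sample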
def _viterbi_decode_nbest(self, emissions: torch.FloatTensor,
                          mask: torch.ByteTensor,
                          nbest: int,
                          pad_tag: Optional[int] = None) -> List[List[List[int]]]:
    # emissions: (seq_length, batch_size, num_tags)
    # mask: (seq_length, batch_size)
    # return: (nbest, batch_size, seq_length)
    if pad_tag is None:
        pad_tag = 0

    device = emissions.device
    seq_length, batch_size = mask.shape

    # Start transition and first emission
    # shape: (batch_size, num_tags)
    score = self.start_transitions + emissions[0]
    history_idx = torch.zeros((seq_length, batch_size, self.num_tags, nbest),
                              dtype=torch.long, device=device)
    oor_idx = torch.zeros((batch_size, self.num_tags, nbest),
                          dtype=torch.long, device=device)
    oor_tag = torch.full((seq_length, batch_size, nbest), pad_tag,
                         dtype=torch.long, device=device)

    # - score is a tensor of size (batch_size, num_tags) where for every batch,
    #   value at column j stores the score of the best tag sequence so far that ends
    #   with tag j
    # - history_idx saves where the best tags candidate transitioned from; this is used
    #   when we trace back the best tag sequence
    # - oor_idx saves the best tags candidate transitioned from at the positions
    #   where mask is 0, i.e. out of range (oor)

    # Viterbi algorithm recursive case: we compute the score of the best tag sequence
    # for every possible next tag
    for i in range(1, seq_length):
        if i == 1:
            broadcast_score = score.unsqueeze(-1)
            broadcast_emission = emissions[i].unsqueeze(1)
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission
        else:
            broadcast_score = score.unsqueeze(-1)
            broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2)
            # shape: (batch_size, num_tags, nbest, num_tags)
            next_score = broadcast_score + \
                self.transitions.unsqueeze(1) + broadcast_emission

        # Find the top `nbest` maximum score over all possible current tag
        # shape: (batch_size, nbest, num_tags)
        next_score, indices = next_score.view(
            batch_size, -1, self.num_tags).topk(nbest, dim=1)

        if i == 1:
            score = score.unsqueeze(-1).expand(-1, -1, nbest)
            indices = indices * nbest

        # convert to shape: (batch_size, num_tags, nbest)
        next_score = next_score.transpose(2, 1)
        indices = indices.transpose(2, 1)

        # Set score to the next score if this timestep is valid (mask == 1)
        # and save the index that produces the next score
        # shape: (batch_size, num_tags, nbest)
        score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1),
                            next_score, score)
        indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1),
                              indices, oor_idx)
        history_idx[i - 1] = indices

    # End transition score
    # shape: (batch_size, num_tags, nbest)
    end_score = score + self.end_transitions.unsqueeze(-1)
    _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)

    # shape: (batch_size,)
    seq_ends = mask.long().sum(dim=0) - 1

    # insert the best tag at each sequence end (last position with mask == 1)
    history_idx = history_idx.transpose(1, 0).contiguous()
    history_idx.scatter_(
        1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
        end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
    history_idx = history_idx.transpose(1, 0).contiguous()

    # The most probable path for each sequence
    best_tags_arr = torch.zeros((seq_length, batch_size, nbest),
                                dtype=torch.long, device=device)
    best_tags = torch.arange(nbest, dtype=torch.long, device=device) \
        .view(1, -1).expand(batch_size, -1)
    for idx in range(seq_length - 1, -1, -1):
        best_tags = torch.gather(history_idx[idx].view(batch_size, -1),
                                 1, best_tags)
        best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest

    return torch.where(mask.unsqueeze(-1), best_tags_arr, oor_tag).permute(2, 1, 0)
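# Hypothetical usage sketch for the n-best decoder above; `crf` is assumed to
# be an instance of the CRF class these methods belong to (how nbest decoding
# is exposed publicly varies between pytorch-crf revisions). The result is a
# (nbest, batch_size, seq_length) tensor whose first slice agrees with the
# 1-best Viterbi path on unpadded positions.
import torch

emissions = torch.randn(3, 2, crf.num_tags)  # (seq_length, batch_size, num_tags)
mask = torch.ones(3, 2, dtype=torch.uint8)
paths = crf._viterbi_decode_nbest(emissions, mask, nbest=2)
print(paths.shape)  # torch.Size([2, 2, 3])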
def _viterbi_decode(self, emissions: torch.FloatTensor,
                    mask: torch.ByteTensor) -> List[List[int]]:
    # emissions: (seq_length, batch_size, num_tags)
    # mask: (seq_length, batch_size)
    assert emissions.dim() == 3 and mask.dim() == 2
    assert emissions.shape[:2] == mask.shape
    assert emissions.size(2) == self.num_tags
    assert mask[0].all()

    seq_length, batch_size = mask.shape

    # self.start_transitions: scores for transitioning from start to every
    # other tag (excluding end)
    score = self.start_transitions + emissions[0]
    history = []

    # The commented-out loop below is the forward-algorithm variant of the
    # same recursion: replacing max with logsumexp computes the partition
    # function instead of the best path.
    # for i in range(1, seq_length):
    #     # shape: (batch_size, num_tags, 1)
    #     broadcast_score = score.unsqueeze(dim=2)
    #     # shape: (batch_size, 1, num_tags)
    #     broadcast_emissions = emissions[i].unsqueeze(1)
    #     next_score = broadcast_score + self.transitions + broadcast_emissions
    #     next_score = torch.logsumexp(next_score, dim=1)
    #     score = torch.where(mask[i].unsqueeze(1), next_score, score)

    for i in range(1, seq_length):
        broadcast_score = score.unsqueeze(2)
        broadcast_emission = emissions[i].unsqueeze(1)
        next_score = broadcast_score + self.transitions + broadcast_emission
        next_score, indices = next_score.max(dim=1)
        score = torch.where(mask[i].unsqueeze(1), next_score, score)
        history.append(indices)

    score += self.end_transitions
    seq_ends = mask.long().sum(dim=0) - 1
    best_tags_list = []

    for idx in range(batch_size):
        _, best_last_tag = score[idx].max(dim=0)
        best_tags = [best_last_tag.item()]
        # history[:seq_ends[idx]] has length seq_ends[idx]
        for hist in reversed(history[:seq_ends[idx]]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(best_last_tag.item())
        best_tags.reverse()
        best_tags_list.append(best_tags)

    return best_tags_list
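# Toy illustration of the masking step shared by all the decoders above: at a
# padded timestep (mask == 0) torch.where keeps the old running score untouched.
import torch

score = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
next_score = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
mask_i = torch.tensor([True, False])  # sample 1 is padding at this timestep
print(torch.where(mask_i.unsqueeze(1), next_score, score))
# tensor([[5., 6.],
#         [3., 4.]])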