# These are methods of a CRF module that holds three learned parameters:
# self.start_transitions (num_tags,), self.transitions (num_tags, num_tags),
# and self.end_transitions (num_tags,).
from typing import Optional

import torch


def _denominator_score(self, emissions: torch.Tensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
    """Compute the log partition function with the forward algorithm.

    Parameters:
        emissions: (batch_size, sequence_length, num_tags)
        mask: Marks padding positions; entries equal to 0 are excluded
            from the score. (batch_size, sequence_length)
    Returns:
        scores: (batch_size,)
    """
    batch_size, sequence_length, num_tags = emissions.shape

    emissions = emissions.transpose(0, 1).contiguous()
    mask = mask.float().transpose(0, 1).contiguous()

    # Start transition score and first emission score
    alpha = self.start_transitions.view(1, num_tags) + emissions[0]

    for i in range(1, sequence_length):
        emissions_score = emissions[i].view(
            batch_size, 1, num_tags)  # (batch_size, 1, num_tags)
        transition_scores = self.transitions.view(
            1, num_tags, num_tags)  # (1, num_tags, num_tags)
        broadcast_alpha = alpha.view(
            batch_size, num_tags, 1)  # (batch_size, num_tags, 1)

        # (batch_size, num_tags, num_tags)
        inner = broadcast_alpha + emissions_score + transition_scores

        # Update alpha only where the position is not padding
        alpha = (log_sum_exp(inner, 1) * mask[i].view(batch_size, 1) +
                 alpha * (1 - mask[i]).view(batch_size, 1))

    # Add end transition score
    stops = alpha + self.end_transitions.view(1, num_tags)

    return log_sum_exp(stops)  # (batch_size,)
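# The methods above call a module-level `log_sum_exp` helper that is not shown
# in this excerpt. Below is a minimal sketch of what such a helper could look
# like (an assumption, not the author's exact code): a numerically stable
# reduction equivalent to torch.logsumexp over the given dimension. The
# default dim=-1 matches the calls that omit a dimension, which expect a
# reduction over the tag axis.
def log_sum_exp(tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Numerically stable log(sum(exp(tensor))) along `dim`."""
    # Subtract the per-slice max before exponentiating to avoid overflow;
    # equivalent to torch.logsumexp(tensor, dim=dim).
    max_score, _ = tensor.max(dim, keepdim=True)
    stable = tensor - max_score
    return max_score.squeeze(dim) + stable.exp().sum(dim).log()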
def _numerator_score(self, emissions: torch.Tensor,
                     marginal_tags: torch.LongTensor,
                     mask: torch.ByteTensor) -> torch.Tensor:
    """Compute the score summed over all paths consistent with the
    (possibly partial) annotation.

    Parameters:
        emissions: (batch_size, sequence_length, num_tags)
        marginal_tags: Multi-hot tensor marking which tags are allowed
            at each position. (batch_size, sequence_length, num_tags)
        mask: Marks padding positions; entries equal to 0 are excluded
            from the score. (batch_size, sequence_length)
    Returns:
        scores: (batch_size,)
    """
    batch_size, sequence_length, num_tags = emissions.shape

    emissions = emissions.transpose(0, 1).contiguous()
    mask = mask.float().transpose(0, 1).contiguous()
    marginal_tags = marginal_tags.float().transpose(0, 1).contiguous()

    # log(0) = -inf would poison the recursion, so disallowed tags get a
    # large negative constant instead
    log_marginal_tags = torch.log(marginal_tags)
    log_marginal_tags[log_marginal_tags == -float('inf')] = IMPOSSIBLE_SCORE

    # Start transition score and first emission, restricted to allowed tags
    alpha = self.start_transitions + emissions[0] + log_marginal_tags[0]

    for i in range(1, sequence_length):
        log_next_marginal_tags = log_marginal_tags[i]  # (batch_size, num_tags)

        # Emission scores
        emissions_score = emissions[i].view(batch_size, 1, num_tags)
        # Transition scores
        transition_scores = self.transitions.view(1, num_tags, num_tags)
        # Broadcast alpha
        broadcast_alpha = alpha.view(batch_size, num_tags, 1)

        # Add all scores: (batch_size, num_tags, num_tags)
        inner = broadcast_alpha + emissions_score + transition_scores

        alpha = (log_sum_exp(inner, 1) * mask[i].view(batch_size, 1) +
                 alpha * (1 - mask[i]).view(batch_size, 1))
        # Restrict to the tags allowed at this position
        alpha += log_next_marginal_tags * mask[i].view(batch_size, 1)

    # Add end transition score
    stops = alpha + self.end_transitions.view(1, num_tags)

    return log_sum_exp(stops)  # (batch_size,)
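# `_numerator_score` relies on a module-level IMPOSSIBLE_SCORE constant and a
# multi-hot `marginal_tags` tensor encoding the partial annotation. A minimal
# sketch of both (the constant's value and the example data are assumptions
# for illustration, not taken from this excerpt):
IMPOSSIBLE_SCORE = -10000.0  # large negative stand-in for log(0)

# One sequence of length 3 over 4 tags, partially annotated:
marginal_tags = torch.zeros(1, 3, 4, dtype=torch.long)
marginal_tags[0, 0, 2] = 1       # position 0: annotated as tag 2
marginal_tags[0, 1, :] = 1       # position 1: unlabeled, any tag allowed
marginal_tags[0, 2, [0, 3]] = 1  # position 2: ambiguous between tags 0 and 3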
def marginal_probabilities(
        self,
        emissions: torch.Tensor,
        mask: Optional[torch.ByteTensor] = None) -> torch.FloatTensor:
    """Compute per-position marginal tag probabilities with the
    forward-backward algorithm.

    Parameters:
        emissions: (batch_size, sequence_length, num_tags)
        mask: Marks padding positions; entries equal to 0 are excluded
            from the score. (batch_size, sequence_length)
    Returns:
        marginal_probabilities: (sequence_length, batch_size, num_tags)
    """
    if mask is None:
        batch_size, sequence_length, _ = emissions.shape
        mask = torch.ones([batch_size, sequence_length],
                          dtype=torch.uint8,
                          device=emissions.device)

    alpha = self._forward_algorithm(emissions, mask, reverse_direction=False)
    beta = self._forward_algorithm(emissions, mask, reverse_direction=True)

    # Partition function from the final forward variables
    z = log_sum_exp(alpha[-1] + self.end_transitions, dim=1)

    # P(y_t = k) = exp(alpha_t(k) + beta_t(k) - log Z)
    proba = alpha + beta - z.view(1, -1, 1)

    return torch.exp(proba)
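# A sketch of calling `marginal_probabilities` (the `crf` instance and its
# construction are assumptions here; only the shapes follow from the code
# above):
#
#   batch_size, sequence_length, num_tags = 2, 5, 4
#   emissions = torch.randn(batch_size, sequence_length, num_tags)
#   mask = torch.ones(batch_size, sequence_length, dtype=torch.uint8)
#   mask[1, 3:] = 0  # second sequence has only 3 real tokens
#
#   proba = crf.marginal_probabilities(emissions, mask)
#   print(proba.shape)   # torch.Size([5, 2, 4]): (seq_len, batch, num_tags)
#   print(proba.sum(2))  # each unmasked position sums to 1 over the tags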
def _forward_algorithm(self,
                       emissions: torch.Tensor,
                       mask: torch.ByteTensor,
                       reverse_direction: bool = False) -> torch.FloatTensor:
    """Run the forward (or backward) recursion over the tag lattice.

    Parameters:
        emissions: (batch_size, sequence_length, num_tags)
        mask: Marks padding positions; entries equal to 0 are excluded
            from the score. (batch_size, sequence_length)
        reverse_direction: If True, run the backward recursion (computes
            beta); otherwise run the forward recursion (computes alpha).
    Returns:
        log_probabilities: (sequence_length, batch_size, num_tags)
    """
    batch_size, sequence_length, num_tags = emissions.shape

    broadcast_emissions = emissions.transpose(0, 1).unsqueeze(
        2).contiguous()  # (sequence_length, batch_size, 1, num_tags)
    mask = mask.float().transpose(
        0, 1).contiguous()  # (sequence_length, batch_size)
    broadcast_transitions = self.transitions.unsqueeze(
        0)  # (1, num_tags, num_tags)
    sequence_iter = range(1, sequence_length)

    # Backward algorithm
    if reverse_direction:
        # Transpose the transition matrix and emissions
        broadcast_transitions = broadcast_transitions.transpose(
            1, 2)  # (1, num_tags, num_tags)
        broadcast_emissions = broadcast_emissions.transpose(
            2, 3)  # (sequence_length, batch_size, num_tags, 1)
        sequence_iter = reversed(sequence_iter)

        # It is beta
        log_proba = [self.end_transitions.expand(batch_size, num_tags)]
    # Forward algorithm
    else:
        # It is alpha
        log_proba = [
            emissions.transpose(0, 1)[0] + self.start_transitions.view(1, -1)
        ]

    for i in sequence_iter:
        # Broadcast log probability
        broadcast_log_proba = log_proba[-1].unsqueeze(
            2)  # (batch_size, num_tags, 1)

        # Add all scores (shapes shown for the forward direction)
        # inner: (batch_size, num_tags, num_tags)
        # broadcast_log_proba: (batch_size, num_tags, 1)
        # broadcast_transitions: (1, num_tags, num_tags)
        # broadcast_emissions[i]: (batch_size, 1, num_tags)
        inner = broadcast_log_proba \
            + broadcast_transitions \
            + broadcast_emissions[i]

        # Append the new log probabilities, keeping the previous value
        # at padded positions
        log_proba.append(
            log_sum_exp(inner, 1) * mask[i].view(batch_size, 1) +
            log_proba[-1] * (1 - mask[i]).view(batch_size, 1))

    if reverse_direction:
        log_proba.reverse()

    return torch.stack(log_proba)
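# A brute-force sanity check of the forward-backward marginals on a tiny
# example, enumerating every tag path explicitly. The `crf` instance is an
# assumption; the parameter names (start_transitions, transitions,
# end_transitions) come from the methods above.
import itertools


def brute_force_check(crf, seq_len: int = 3, num_tags: int = 4) -> None:
    e = torch.randn(1, seq_len, num_tags)

    # Score every possible tag path explicitly
    path_scores = {}
    for tags in itertools.product(range(num_tags), repeat=seq_len):
        s = crf.start_transitions[tags[0]] + e[0, 0, tags[0]]
        for t in range(1, seq_len):
            s = s + crf.transitions[tags[t - 1], tags[t]] + e[0, t, tags[t]]
        path_scores[tags] = s + crf.end_transitions[tags[-1]]

    z = torch.logsumexp(torch.stack(list(path_scores.values())), dim=0)

    # Marginal P(y_1 = k) by enumeration vs. the dynamic program
    p1 = torch.zeros(num_tags)
    for tags, s in path_scores.items():
        p1[tags[1]] += torch.exp(s - z)

    proba = crf.marginal_probabilities(e)  # (seq_len, 1, num_tags)
    assert torch.allclose(p1, proba[1, 0], atol=1e-4)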