def _denominator_score(self, emissions: torch.Tensor,
                           mask: torch.ByteTensor) -> torch.Tensor:
        """
        Compute the log partition function with the forward algorithm.

        Parameters:
            emissions: (batch_size, sequence_length, num_tags)
            mask: Show padding tags. 0 don't calculate score. (batch_size, sequence_length)
        Returns:
            scores: (batch_size)
        """
        batch_size, sequence_length, num_tags = emissions.data.shape

        # Switch to time-major layout so emissions[t] / mask[t] select a step.
        emissions = emissions.transpose(0, 1).contiguous()
        mask = mask.float().transpose(0, 1).contiguous()

        # Initialization: start transitions plus the first emission scores.
        alpha = self.start_transitions.view(1, num_tags) + emissions[0]

        for step in range(1, sequence_length):
            # Broadcasting produces (batch_size, num_tags(prev), num_tags(next)):
            prev_score = alpha.unsqueeze(2)              # (batch_size, num_tags, 1)
            trans_score = self.transitions.unsqueeze(0)  # (1, num_tags, num_tags)
            emit_score = emissions[step].unsqueeze(1)    # (batch_size, 1, num_tags)

            combined = prev_score + emit_score + trans_score  # (batch_size, num_tags, num_tags)

            # Reduce over previous tags; padded positions carry alpha forward.
            step_mask = mask[step].view(batch_size, 1)
            alpha = (log_sum_exp(combined, 1) * step_mask +
                     alpha * (1 - step_mask))

        # Fold in the end transitions and reduce over the tag dimension.
        final = alpha + self.end_transitions.view(1, num_tags)
        return log_sum_exp(final)  # (batch_size,)
    def _numerator_score(self, emissions: torch.Tensor,
                         marginal_tags: torch.LongTensor,
                         mask: torch.ByteTensor) -> torch.Tensor:
        """
        Compute the forward score restricted to the tags allowed by marginal_tags.

        Parameters:
            emissions: (batch_size, sequence_length, num_tags)
            marginal_tags:  (batch_size, sequence_length, num_tags)
            mask:  Show padding tags. 0 don't calculate score. (batch_size, sequence_length)
        Returns:
            scores: (batch_size)
        """

        batch_size, sequence_length, num_tags = emissions.data.shape

        # Time-major layout: index 0 selects a time step.
        emissions = emissions.transpose(0, 1).contiguous()
        mask = mask.float().transpose(0, 1).contiguous()
        marginal_tags = marginal_tags.float().transpose(0, 1)

        # log(0) = -inf marks disallowed tags; replace with a large negative
        # finite score (IMPOSSIBLE_SCORE) so later arithmetic stays finite.
        log_marginal_tags = torch.log(marginal_tags)
        log_marginal_tags[log_marginal_tags ==
                          -float('inf')] = IMPOSSIBLE_SCORE

        # Start transition score and first emission, restricted to allowed tags
        alpha = self.start_transitions + emissions[0] + log_marginal_tags[0]

        for i in range(1, sequence_length):
            # Emissions scores
            emissions_score = emissions[i].view(batch_size, 1, num_tags)
            # Transition scores. Broadcasting handles the batch dimension;
            # the original expand(...).clone() allocated a needless copy.
            transition_scores = self.transitions.view(1, num_tags, num_tags)
            # Broadcast alpha
            broadcast_alpha = alpha.view(batch_size, num_tags, 1)

            # Add all scores
            inner = broadcast_alpha + emissions_score + transition_scores  # (batch_size, num_tags, num_tags)

            # Reduce over previous tags; padded positions carry alpha forward.
            alpha = (log_sum_exp(inner, 1) * mask[i].view(batch_size, 1) +
                     alpha * (1 - mask[i]).view(batch_size, 1))
            # Restrict the running score to tags allowed at this position.
            alpha += log_marginal_tags[i] * mask[i].view(batch_size, 1)

        # Add end transition score (view broadcasts over the batch; the
        # unused last_tag_indexes computation was dead code and is removed)
        stops = alpha + self.end_transitions.view(1, num_tags)
        return log_sum_exp(stops)  # (batch_size,)
    def marginal_probabilities(
            self,
            emissions: torch.Tensor,
            mask: Optional[torch.ByteTensor] = None) -> torch.FloatTensor:
        """
        Compute per-position tag marginals via the forward-backward algorithm.

        Parameters:
            emissions: (batch_size, sequence_length, num_tags)
            mask:  Show padding tags. 0 don't calculate score. (batch_size, sequence_length)
        Returns:
            marginal_probabilities: (sequence_length, batch_size, num_tags)
        """
        if mask is None:
            # No mask supplied: treat every position of every sequence as valid.
            batch_size, sequence_length, _ = emissions.data.shape
            mask = torch.ones([batch_size, sequence_length],
                              dtype=torch.uint8,
                              device=emissions.device)

        # Forward (alpha) and backward (beta) log scores,
        # each (sequence_length, batch_size, num_tags).
        alpha = self._forward_algorithm(emissions,
                                        mask,
                                        reverse_direction=False)
        beta = self._forward_algorithm(emissions, mask, reverse_direction=True)

        # Log partition function from the final forward scores.
        z = log_sum_exp(alpha[-1] + self.end_transitions, dim=1)

        # log p(tag at t) = alpha_t + beta_t - log Z
        log_proba = alpha + beta - z.view(1, -1, 1)
        return torch.exp(log_proba)
    def _forward_algorithm(
            self,
            emissions: torch.Tensor,
            mask: torch.ByteTensor,
            reverse_direction: bool = False) -> torch.FloatTensor:
        """
        Run the forward (alpha) or backward (beta) recursion over the sequence.

        The same recurrence is reused for both directions: for the backward
        pass the transition matrix and the emission broadcast axes are
        transposed so that summation runs over the *next* tag instead of the
        previous one.

        Parameters:
            emissions: (batch_size, sequence_length, num_tags)
            mask:  Show padding tags. 0 don't calculate score. (batch_size, sequence_length)
            reverse_direction: This parameter decide algorithm direction.
                False computes forward scores (alpha), True computes backward
                scores (beta).
        Returns:
            log_probabilities: (sequence_length, batch_size, num_tags)
        """
        batch_size, sequence_length, num_tags = emissions.data.shape

        # Time-major emissions with a broadcast axis inserted for the tag
        # dimension being summed over.
        broadcast_emissions = emissions.transpose(0, 1).unsqueeze(
            2).contiguous()  # (sequence_length, batch_size, 1, num_tags)
        mask = mask.float().transpose(
            0, 1).contiguous()  # (sequence_length, batch_size)
        broadcast_transitions = self.transitions.unsqueeze(
            0)  # (1, num_tags, num_tags)
        sequence_iter = range(1, sequence_length)

        # backward algorithm
        if reverse_direction:
            # Transpose transitions matrix and emissions so the recurrence
            # aggregates over the following position's tags.
            broadcast_transitions = broadcast_transitions.transpose(
                1, 2)  # (1, num_tags, num_tags)
            broadcast_emissions = broadcast_emissions.transpose(
                2, 3)  # (sequence_length, batch_size, num_tags, 1)
            sequence_iter = reversed(sequence_iter)

            # It is beta: initialized from the end transitions only (the
            # emission of the last position is not included in beta).
            log_proba = [self.end_transitions.expand(batch_size, num_tags)]
        # forward algorithm
        else:
            # It is alpha: start transitions plus the first emission.
            log_proba = [
                emissions.transpose(0, 1)[0] +
                self.start_transitions.view(1, -1)
            ]

        for i in sequence_iter:
            # Broadcast log probability of the previously processed step
            # (the *later* step when running backward).
            broadcast_log_proba = log_proba[-1].unsqueeze(
                2)  # (batch_size, num_tags, 1)

            # Add all scores
            # inner: (batch_size, num_tags, num_tags)
            # broadcast_log_proba:   (batch_size, num_tags, 1)
            # broadcast_transitions: (1, num_tags, num_tags)
            # broadcast_emissions:   (batch_size, 1, num_tags)
            # (in the reverse direction the emission/transition axes are
            # transposed, so the sum below runs over the other tag axis)
            inner = broadcast_log_proba \
                    + broadcast_transitions \
                    + broadcast_emissions[i]

            # Append log proba; padded positions carry the previous value
            # forward unchanged.
            log_proba.append(
                (log_sum_exp(inner, 1) * mask[i].view(batch_size, 1) +
                 log_proba[-1] * (1 - mask[i]).view(batch_size, 1)))

        # Backward scores were collected last-to-first; restore time order.
        if reverse_direction:
            log_proba.reverse()

        return torch.stack(log_proba)