Example #1
    def forward(self, scores, target, mask, aid):
        """
        Calculate the negative log likelihood for the conditional random field.

        Parameters
        ----------
        scores: ``torch.FloatTensor``, required.
            the potential score for the conditional random field, of shape (seq_len, batch_size, from_tag_size, to_tag_size).
        target: ``torch.LongTensor``, required.
            the positive path for the conditional random field, of shape (seq_len, batch_size).
        mask: ``torch.ByteTensor``, required.
            the mask for the unpadded sentence parts, of shape (seq_len, batch_size).
        aid: ``int``, required.
            the annotator index, used to weight the loss when ``self.task == 'maWeightAnnotator'``.

        Returns
        -------
        loss: ``torch.FloatTensor``.
            The NLL loss.
        """
        seq_len = scores.size(0)
        bat_size = scores.size(1)

        # Score of the gold path: pick the (from_tag, to_tag) entry chosen by
        # ``target`` at every step, drop padded positions, and sum.
        tg_energy = torch.gather(scores.view(seq_len, bat_size, -1), 2,
                                 target.unsqueeze(2)).view(seq_len, bat_size)
        tg_energy = tg_energy.masked_select(mask).sum()

        # Forward algorithm: initialize the partition with the scores leaving
        # the start tag at the first time step.
        seq_iter = enumerate(scores)

        _, inivalues = next(seq_iter)
        partition = inivalues[:, self.start_tag, :].squeeze(1).clone()

        for idx, cur_values in seq_iter:
            # Add the running partition over the previous (from) tags.
            cur_values = cur_values + partition.unsqueeze(2).expand(
                bat_size, self.tagset_size, self.tagset_size)

            # Log-sum-exp over the from-tag dimension.
            cur_partition = utils.log_sum_exp(cur_values)

            # Only update the partition at unpadded positions.
            mask_idx = mask[idx, :].view(bat_size,
                                         1).expand(bat_size, self.tagset_size)
            partition.masked_scatter_(mask_idx,
                                      cur_partition.masked_select(mask_idx))

        # Log partition function, read off at the end tag and summed over the batch.
        partition = partition[:, self.end_tag].sum()

        # NLL = log Z - score of the gold path.
        loss = partition - tg_energy

        # Weight the loss by the softmax-normalized score of annotator ``aid``.
        if self.task == 'maWeightAnnotator':
            antor_score = F.softmax(self.antor_score, dim=0)
            loss = loss * antor_score[aid]

        if self.average_batch:
            return loss / bat_size
        else:
            return loss
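
Both examples rely on utils.log_sum_exp, whose definition is not shown. A minimal stand-in, consistent with how it is called above (it must collapse the from-tag dimension of a (batch, from_tag, to_tag) tensor into a (batch, to_tag) result), could look like the sketch below; the name and the dim argument are assumptions rather than the repository's actual helper, and in current PyTorch torch.logsumexp(vec, dim=1) gives the same result.

import torch

def log_sum_exp(vec, dim=1):
    # Numerically stable log-sum-exp over `dim`: subtract the per-slice max,
    # exponentiate, sum, take the log, and add the max back.
    max_score, _ = vec.max(dim=dim, keepdim=True)
    return (vec - max_score).exp().sum(dim=dim).log() + max_score.squeeze(dim)

# Sanity check against PyTorch's built-in on a (batch, from_tag, to_tag) tensor.
vec = torch.randn(4, 7, 7)
print(torch.allclose(log_sum_exp(vec), torch.logsumexp(vec, dim=1)))  # True
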
Example #2
    def forward(self, scores, targets, mask, a_mask):
        """
        Calculate the negative log likelihood for the conditional random field.

        Parameters
        ----------
        scores: ``torch.FloatTensor``, required.
            the potential score for the conditional random field, of shape (a_num, seq_len, batch_size, from_tag_size, to_tag_size).
        targets: ``torch.LongTensor``, required.
            the positive path for the conditional random field, of shape (a_num, seq_len, batch_size).
        mask: ``torch.ByteTensor``, required.
            the mask for the unpadded sentence parts, of shape (seq_len, batch_size).
        a_mask: ``torch.ByteTensor``, required.
            the mask for the valid annotators, of shape (a_num, batch_size).

        Returns
        -------
        loss: ``torch.FloatTensor``.
            The NLL loss.
        """
        a_num = scores.size(0)
        seq_len = scores.size(1)
        bat_size = scores.size(2)

        mask_not = mask == 0

        # Per-annotator, per-sentence losses, of shape (a_num, bat_size).
        losses = a_mask.clone().float()

        for aid in range(a_num):
            target = targets[aid]
            score = scores[aid]

            # Score of the gold path for this annotator: zero out padded
            # positions and sum over the sequence, keeping one score per sentence.
            tg_energy = torch.gather(score.view(seq_len, bat_size, -1), 2,
                                     target.unsqueeze(2)).view(
                                         seq_len, bat_size)
            tg_energy = tg_energy.masked_fill_(mask_not, 0).sum(dim=0)

            # Forward algorithm over this annotator's scores.
            seq_iter = enumerate(score)

            _, inivalues = next(seq_iter)
            partition = inivalues[:, self.start_tag, :].squeeze(1).clone()

            for idx, cur_values in seq_iter:
                cur_values = cur_values + partition.unsqueeze(2).expand(
                    bat_size, self.tagset_size, self.tagset_size)

                cur_partition = utils.log_sum_exp(cur_values)

                mask_idx = mask[idx, :].view(bat_size,
                                             1).expand(bat_size,
                                                       self.tagset_size)
                partition.masked_scatter_(
                    mask_idx, cur_partition.masked_select(mask_idx))

            # Per-sentence log partition and NLL for this annotator.
            partition = partition[:, self.end_tag]  # [bat_size]
            losses[aid] = partition - tg_energy

        # Keep only the losses of annotators that actually labeled each sentence.
        loss = losses.masked_select(a_mask).sum()

        if self.average_batch:
            return loss / bat_size
        else:
            return loss
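
A minimal usage sketch for the multi-annotator variant, assuming `crf` is an instance of the module that defines this forward and that `targets` stores flattened (from_tag * tagset_size + to_tag) indices, which is what the gather over the flattened last two score dimensions implies. The sizes below are arbitrary, and bool masks stand in for the ByteTensor masks named in the docstring.

import torch

a_num, seq_len, bat_size, tagset_size = 3, 10, 4, 7

# Potential scores for every annotator, time step, sentence, and tag transition.
scores = torch.randn(a_num, seq_len, bat_size, tagset_size, tagset_size)

# Each target entry indexes a flattened (from_tag, to_tag) pair:
# from_tag * tagset_size + to_tag.
targets = torch.randint(0, tagset_size * tagset_size, (a_num, seq_len, bat_size))

mask = torch.ones(seq_len, bat_size, dtype=torch.bool)   # True for real tokens, False for padding
a_mask = torch.ones(a_num, bat_size, dtype=torch.bool)   # True if annotator a labeled sentence b

# loss = crf(scores, targets, mask, a_mask)  # `crf`: hypothetical instance of the module above
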