Esempio n. 1
0
 def forward_algo(self, scores, mask):
     """Run the CRF forward algorithm and return per-sequence log partitions.

     Args:
         scores: CRF scores indexed as ``scores[t][b, from_tag, to_tag]``;
             presumably shape (seq_len, bat_size, tagset_size, tagset_size)
             -- TODO confirm against caller.
         mask: per-step mask indexed as ``mask[t]``, reshaped here to
             (bat_size, 1); presumably shape (seq_len, bat_size).

     Returns:
         Tensor of shape (bat_size,): the log partition value of every
         sequence at ``self.end_tag``. It is NOT summed over the batch;
         callers reduce it themselves.
     """
     bat_size = scores.size(1)

     seq_iter = enumerate(scores)
     # The first step must start from <start>, so only the transitions out of
     # start_tag are kept.  inivalues: bat_size * from_target * to_target
     _, inivalues = next(seq_iter)
     partition = inivalues[:, self.start_tag, :]  # bat_size * to_target_size

     for idx, cur_values in seq_iter:
         # The previous to_target is the current from_target.  Extend the
         # *masked* running partition (not the unmasked recurrence) so that a
         # step whose mask is off contributes nothing to later steps.
         cur_values = cur_values + partition.contiguous().view(
             bat_size, self.tagset_size, 1).expand(
                 bat_size, self.tagset_size, self.tagset_size)
         # (bat_size * from_target * to_target) -> (bat_size * to_target)
         cur_partition = utils.log_sum_exp(cur_values, self.tagset_size)
         # Presumably utils.switch takes cur_partition where the mask is set
         # and keeps partition elsewhere -- confirm against utils.
         step_mask = mask[idx].contiguous().view(bat_size, 1).expand(
             bat_size, self.tagset_size)
         partition = utils.switch(partition.contiguous(),
                                  cur_partition.contiguous(),
                                  step_mask).contiguous().view(bat_size, -1)

     # Only paths that end at end_tag count.
     return partition[:, self.end_tag]
Esempio n. 2
0
    def get_loss(self, scores, target, mask):
        """Return the CRF loss (log partition minus gold score) averaged over the batch.

        Both the log partition and the gold-path score are summed over the
        batch before the difference is divided by bat_size.

        args:
            scores (seq_len, bat_size, target_size_from, target_size_to) : class score for CRF
            target (seq_len, bat_size, 1) : index into the flattened from*to score axis
            mask   (seq_len, bat_size) : mask for crf label
        """
        n_steps, n_batch = scores.size(0), scores.size(1)
        n_tags = self.tagset_size

        # Gold-path score: gather along the flattened (from, to) axis, then
        # sum only the unmasked positions.
        flat_scores = scores.view(n_steps, n_batch, -1)
        gold = torch.gather(flat_scores, 2, target).view(n_steps, n_batch)
        gold_score = gold.masked_select(mask).sum()

        steps = enumerate(scores)
        _, first_step = next(steps)
        # Running log-partition (bat_size * to_target), starting from
        # start_tag.  clone() because it is overwritten in place below.
        alpha = first_step[:, self.start_tag, :].clone()
        for t, step_scores in steps:
            step_scores = step_scores + alpha.contiguous().view(
                n_batch, n_tags, 1).expand(n_batch, n_tags, n_tags)
            next_alpha = utils.log_sum_exp(step_scores, n_tags)
            # Overwrite only the rows whose mask is set at step t.
            keep = mask[t, :].view(n_batch, 1).expand(n_batch, n_tags)
            alpha.masked_scatter_(keep, next_alpha.masked_select(keep))

        log_partition = alpha[:, self.end_tag].sum()
        return (log_partition - gold_score) / n_batch
Esempio n. 3
0
    def forward(self, scores, target, mask):
        """Compute the CRF loss: log partition minus the gold-path score.

        args:
            scores (seq_len, bat_size, target_size_from, target_size_to) : crf scores
            target (seq_len, bat_size, 1) : golden state
            mask (size seq_len, bat_size) : mask for padding
        return:
            loss, summed over the batch and divided by bat_size when
            ``self.average_batch`` is truthy
        """
        n_steps = scores.size(0)
        n_batch = scores.size(1)
        n_tags = self.tagset_size

        # Gold-path score, summed over unmasked positions.
        gold = torch.gather(scores.view(n_steps, n_batch, -1), 2, target)
        gold_score = gold.view(n_steps, n_batch).masked_select(mask).sum()

        steps = enumerate(scores)
        # The first step starts from <start>: keep only start_tag's row.
        _, first_step = next(steps)
        alpha = first_step[:, self.start_tag, :].clone()  # bat_size * to_target
        for t, step_scores in steps:
            # Previous to_target becomes current from_target.
            from_prev = alpha.contiguous().view(n_batch, n_tags, 1).expand(
                n_batch, n_tags, n_tags)
            next_alpha = utils.log_sum_exp(step_scores + from_prev, n_tags)
            # Mask selects between the old (0) and new (1) partition.
            step_mask = mask[t].view(n_batch, 1).expand(n_batch, n_tags)
            alpha = utils.switch(alpha, next_alpha, step_mask).view(n_batch, -1)

        # Only paths ending at end_tag count.
        log_partition = alpha[:, self.end_tag].sum()
        loss = log_partition - gold_score
        return loss / n_batch if self.average_batch else loss
Esempio n. 4
0
    def forward(self, scores, target, mask):
        """Compute the CRF loss for a batch of tag sequences.

        The loss is the log partition over paths that start at
        ``self.start_tag`` and end at ``self.end_tag``, minus the gold-path
        score. Both terms are summed over the batch.

        args:
            scores (seq_len, bat_size, target_size_from, target_size_to) : crf scores
            target (seq_len, bat_size, 1) : golden state
            mask (size seq_len, bat_size) : mask for padding
        return:
            loss; divided by bat_size when ``self.average_batch`` is truthy
        """
        seq_length, batch = scores.size(0), scores.size(1)
        tags = self.tagset_size

        # Sum of gold scores over unmasked positions.
        gold_flat = torch.gather(scores.view(seq_length, batch, -1), 2, target)
        gold_total = gold_flat.view(seq_length, batch).masked_select(mask).sum()

        timesteps = enumerate(scores)
        _, opening = next(timesteps)
        # clone(): the running values are overwritten in place below.
        running = opening[:, self.start_tag, :].clone()
        for step, transition in timesteps:
            extended = transition + running.contiguous().view(
                batch, tags, 1).expand(batch, tags, tags)
            updated = utils.log_sum_exp(extended, tags)
            # Copy updated values only into rows active at this step.
            active = mask[step, :].view(batch, 1).expand(batch, tags)
            running.masked_scatter_(active, updated.masked_select(active))

        total = running[:, self.end_tag].sum() - gold_total
        if not self.average_batch:
            return total
        return total / batch