def forward(self, scores, target, mask, aid):
    """
    Calculate the negative log likelihood for the conditional random field.

    Parameters
    ----------
    scores: ``torch.FloatTensor``, required.
        The potential scores for the conditional random field, of shape
        (seq_len, batch_size, from_tag_size, to_tag_size).
    target: ``torch.LongTensor``, required.
        The positive (gold) path for the conditional random field, of shape
        (seq_len, batch_size); each entry indexes a flattened
        (from_tag, to_tag) transition.
    mask: ``torch.ByteTensor``, required.
        The mask for the unpadded sentence parts, of shape
        (seq_len, batch_size).
    aid: ``int``, required.
        Annotator id; selects the annotator weight applied to the loss when
        ``self.task == 'maWeightAnnotator'`` (otherwise unused).

    Returns
    -------
    loss: ``torch.FloatTensor``.
        The NLL loss (scalar), averaged over the batch when
        ``self.average_batch`` is set.
    """
    seq_len = scores.size(0)
    bat_size = scores.size(1)

    # Gold-path energy: gather the transition score selected by `target`
    # at each timestep, drop padded positions, and sum to a scalar.
    tg_energy = torch.gather(
        scores.view(seq_len, bat_size, -1), 2,
        target.unsqueeze(2)).view(seq_len, bat_size)
    tg_energy = tg_energy.masked_select(mask).sum()

    # Forward algorithm: accumulate the log partition function over time.
    seq_iter = enumerate(scores)
    _, inivalues = seq_iter.__next__()
    # Initialize with the transitions leaving the start tag.
    partition = inivalues[:, self.start_tag, :].squeeze(1).clone()

    for idx, cur_values in seq_iter:
        # Broadcast the running partition over the to-tag dimension.
        cur_values = cur_values + partition.unsqueeze(2).expand(
            bat_size, self.tagset_size, self.tagset_size)
        cur_partition = utils.log_sum_exp(cur_values)
        # Only advance the partition at unpadded positions; padded ones
        # keep their previous value via the in-place masked scatter.
        mask_idx = mask[idx, :].view(bat_size, 1).expand(
            bat_size, self.tagset_size)
        partition.masked_scatter_(
            mask_idx, cur_partition.masked_select(mask_idx))

    # Total log partition: score of reaching the end tag, summed over batch.
    partition = partition[:, self.end_tag].sum()
    loss = partition - tg_energy

    if self.task == 'maWeightAnnotator':
        # Weight the loss by the softmax-normalized score of this annotator.
        antor_score = F.softmax(self.antor_score, dim=0)
        loss = loss * antor_score[aid]

    if self.average_batch:
        return loss / bat_size
    return loss
def forward(self, scores, targets, mask, a_mask):
    """
    Calculate the negative log likelihood for the conditional random field,
    summed over multiple annotators.

    Parameters
    ----------
    scores: ``torch.FloatTensor``, required.
        The potential scores for the conditional random field, of shape
        (a_num, seq_len, batch_size, from_tag_size, to_tag_size).
    targets: ``torch.LongTensor``, required.
        The positive (gold) paths, one per annotator, of shape
        (a_num, seq_len, batch_size).
    mask: ``torch.ByteTensor``, required.
        The mask for the unpadded sentence parts, of shape
        (seq_len, batch_size).
    a_mask: ``torch.ByteTensor``, required.
        The mask for the valid annotators, of shape (a_num, batch_size).

    Returns
    -------
    loss: ``torch.FloatTensor``.
        The NLL loss summed over valid (annotator, sentence) pairs,
        averaged over the batch when ``self.average_batch`` is set.
    """
    a_num = scores.size(0)
    seq_len = scores.size(1)
    bat_size = scores.size(2)
    mask_not = mask == 0

    # Per-(annotator, sentence) NLL. Cloning a_mask only provides a float
    # tensor of the right shape/device; every entry is overwritten below.
    losses = a_mask.clone().float()
    for aid in range(a_num):
        target = targets[aid]
        score = scores[aid]

        # Gold-path energy per sentence: gather the transition selected by
        # `target` at each timestep, zero padded positions, sum over time.
        tg_energy = torch.gather(
            score.view(seq_len, bat_size, -1), 2,
            target.unsqueeze(2)).view(seq_len, bat_size)
        tg_energy = tg_energy.masked_fill_(mask_not, 0).sum(dim=0)

        # Forward algorithm: accumulate the log partition function.
        seq_iter = enumerate(score)
        _, inivalues = seq_iter.__next__()
        # Initialize with the transitions leaving the start tag.
        partition = inivalues[:, self.start_tag, :].squeeze(1).clone()
        for idx, cur_values in seq_iter:
            cur_values = cur_values + partition.unsqueeze(2).expand(
                bat_size, self.tagset_size, self.tagset_size)
            cur_partition = utils.log_sum_exp(cur_values)
            # Only advance the partition at unpadded positions; padded ones
            # keep their previous value via the in-place masked scatter.
            mask_idx = mask[idx, :].view(bat_size, 1).expand(
                bat_size, self.tagset_size)
            partition.masked_scatter_(
                mask_idx, cur_partition.masked_select(mask_idx))

        partition = partition[:, self.end_tag]  # [bat_size]
        losses[aid] = partition - tg_energy

    # Keep only the losses contributed by valid annotators.
    loss = losses.masked_select(a_mask).sum()

    if self.average_batch:
        return loss / bat_size
    return loss