def forward_algo(self, scores, mask):
    """
    Forward algorithm: accumulate the log-partition score for every sentence.

    args:
        scores: crf scores, iterated over its first axis; every step is
            indexed as (bat_size, from_target, to_target)
        mask: padding mask, indexed by time step first; ``mask[idx]`` is
            reshaped to (bat_size, 1) before being handed to ``utils.switch``
    return:
        the ``self.end_tag`` column of the accumulated partition, one entry
        per sentence (not summed over the batch)
    """
    batch = scores.size(1)
    n_tags = self.tagset_size

    steps = iter(scores)
    # The first step has to leave from <start>, so only that "from" row is kept.
    running = next(steps)[:, self.start_tag, :]  # bat_size * to_target
    partition = running

    for idx, step_scores in enumerate(steps, start=1):
        # The previous to_target axis becomes the current from_target axis:
        # repeat the running scores along a new to_target axis.
        previous = running.contiguous().view(batch, n_tags, 1).expand(
            batch, n_tags, n_tags)
        # (bat_size * from_target * to_target) -> (bat_size * to_target)
        running = utils.log_sum_exp(step_scores + previous, n_tags)

        # NOTE(review): utils.switch presumably keeps `partition` where the
        # mask is 0 and takes `running` where it is 1 -- confirm in utils.
        # The recursion itself continues on the unmasked `running` value.
        step_mask = mask[idx].contiguous().view(batch, 1).expand(batch, n_tags)
        partition = utils.switch(
            partition.contiguous(), running.contiguous(), step_mask
        ).contiguous().view(batch, -1)

    # Only paths that finish in the end tag contribute.
    return partition[:, self.end_tag]
def predict_batch(self, ner_model, crf_no, f_f, f_p, b_f, b_p, w_f, tg, mask_v, len_v, corpus_mask_v, pred_method):
    """
    Score one batch with ``ner_model`` and decode it with ``self.decoder``.

    args:
        ner_model: model, called as
            ``ner_model(f_f, f_p, b_f, b_p, w_f, crf_no, corpus_mask_v)``;
            switched to eval mode first if it is still in training mode
        crf_no, f_f, f_p, b_f, b_p, w_f: forwarded to ``ner_model`` unchanged
        tg, len_v: not used by this method
        mask_v: padding mask; its ``.data`` is passed to the decoder
        corpus_mask_v: mask passed to the model and used to restrict the
            prediction (presumably marks the tags annotated in the corpus
            each sentence comes from -- confirm against the caller)
        pred_method: "M" masks the scores before decoding; "U" decodes the
            raw scores and then overwrites disallowed predictions with 'O'
    return:
        (decoded, scores): decoder output and the unmodified model scores
    raises:
        AssertionError: if ``pred_method`` is neither "M" nor "U"
    """
    if ner_model.training:
        ner_model.eval()

    scores = ner_model(f_f, f_p, b_f, b_p, w_f, crf_no, corpus_mask_v)
    assert pred_method in ["M", "U"]

    if pred_method == "M":
        # Whether or not a sigmoid is applied later, undesired scores are
        # pushed to a large negative value. The filler is created with the
        # dtype and device of `scores`, so this also works without CUDA.
        neg_inf_scores = torch.full_like(scores, -1e9)
        selected_scores = utils.switch(
            neg_inf_scores.contiguous(), scores.contiguous(), corpus_mask_v
        ).view(scores.shape)
        decoded = self.decoder.decode(selected_scores.data, mask_v.data)
        return decoded, scores

    if pred_method == "U":
        decoded = self.decoder.decode(scores.data, mask_v.data)
        # Move the mask slice to host memory once instead of once per element.
        annotated_mask = corpus_mask_v[:, :, 0].cpu().data.numpy()
        for i in range(decoded.shape[0]):
            for j in range(decoded.shape[1]):
                idx_annotated = np.where(annotated_mask[i, j])[0]
                # int() makes the membership test independent of whether the
                # decoder returns Python ints, numpy ints or 0-dim tensors.
                if int(decoded[i, j]) not in idx_annotated:
                    decoded[i, j] = self.l_map['O']
        return decoded, scores
def forward(self, scores, target, mask):
    """
    CRF loss: forward partition score minus the score of the golden path.

    args:
        scores (seq_len, bat_size, target_size_from, target_size_to) : crf scores
        target (seq_len, bat_size, 1) : golden state, used as an index into
            the flattened (from, to) axis of ``scores``
        mask (size seq_len, bat_size) : mask for padding
    return:
        loss, divided by bat_size when ``self.average_batch`` is set
    """
    seq_len, bat_size = scores.size(0), scores.size(1)
    n_tags = self.tagset_size

    # Sentence score: pick the golden transition at every step and keep only
    # the positions selected by the mask.
    flat_scores = scores.view(seq_len, bat_size, -1)
    golden = torch.gather(flat_scores, 2, target).view(seq_len, bat_size)
    tg_energy = golden.masked_select(mask).sum()

    # Forward partition score.
    steps = iter(scores)
    # The first step has to leave from <start>, so only that "from" row is kept.
    partition = next(steps)[:, self.start_tag, :].clone()  # bat_size * to_target

    for idx, step_scores in enumerate(steps, start=1):
        # The previous to_target axis becomes the current from_target axis:
        # repeat the partition along a new to_target axis.
        previous = partition.contiguous().view(bat_size, n_tags, 1).expand(
            bat_size, n_tags, n_tags)
        # (bat_size * from_target * to_target) -> (bat_size * to_target)
        candidate = utils.log_sum_exp(step_scores + previous, n_tags)

        # Per the original author's note: mask 0 keeps `partition`, mask 1
        # takes the new value -- confirm against utils.switch.
        step_mask = mask[idx].view(bat_size, 1).expand(bat_size, n_tags)
        partition = utils.switch(partition, candidate, step_mask).view(bat_size, -1)

    # Only paths that finish in the end tag contribute; sum over the batch.
    partition = partition[:, self.end_tag].sum()

    loss = partition - tg_energy
    if self.average_batch:
        loss = loss / bat_size
    return loss