def _my_loss_prob(self, score_expr, gold_idxes_expr, entropy_lambda: float, loss_mask, neg_reweight: bool):
    probs = BK.softmax(score_expr, -1)  # [*, NLab]
    log_probs = BK.log(probs + 1e-8)
    # first the plain NLL loss
    nll_loss = -BK.gather_one_lastdim(log_probs, gold_idxes_expr).squeeze(-1)
    # then the special entropy-regularization loss
    if entropy_lambda > 0.:
        negative_entropy = probs * log_probs  # [*, NLab]
        last_dim = BK.get_shape(score_expr, -1)
        confusion_matrix = 1. - BK.eye(last_dim)  # [NLab, NLab]
        entropy_mask = confusion_matrix[gold_idxes_expr]  # [*, NLab]
        entropy_loss = (negative_entropy * entropy_mask).sum(-1)
        final_loss = nll_loss + entropy_lambda * entropy_loss
    else:
        final_loss = nll_loss
    # reweight the negative (NIL) instances?
    if neg_reweight:
        golden_prob = BK.gather_one_lastdim(probs, gold_idxes_expr).squeeze(-1)
        is_full_nil = (gold_idxes_expr == 0.).float()
        not_full_nil = 1. - is_full_nil
        count_pos = (loss_mask * not_full_nil).sum()
        count_neg = (loss_mask * is_full_nil).sum()
        prob_pos = (loss_mask * not_full_nil * golden_prob).sum()
        prob_neg = (loss_mask * is_full_nil * golden_prob).sum()
        neg_weight = prob_pos / (count_pos + count_neg - prob_neg + 1e-8)
        final_weights = not_full_nil + is_full_nil * neg_weight
        # todo(note): the final mask will be applied outside
        final_loss = final_loss * final_weights
    return final_loss
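
# --- Hedged usage sketch (illustration only, not part of this module) --------
# A minimal, self-contained re-creation of the loss above in plain PyTorch,
# assuming BK is a thin wrapper over torch (BK.softmax ~ torch.softmax,
# BK.gather_one_lastdim ~ torch.gather on the last dim, BK.eye ~ torch.eye).
# The function name and defaults below are illustrative, not from this codebase.
#
# import torch
#
# def demo_loss_prob(scores, gold_idxes, entropy_lambda=0.1):
#     # scores: [*, NLab] raw logits; gold_idxes: [*] gold label indices (long)
#     probs = torch.softmax(scores, -1)                      # [*, NLab]
#     log_probs = torch.log(probs + 1e-8)
#     nll = -log_probs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)
#     # penalize probability mass spread over the non-gold labels:
#     # take -H(p) elementwise and mask out the gold column before summing
#     neg_entropy = probs * log_probs                        # [*, NLab]
#     mask = 1. - torch.eye(scores.shape[-1])[gold_idxes]    # [*, NLab]
#     return nll + entropy_lambda * (neg_entropy * mask).sum(-1)
#
# # e.g. demo_loss_prob(torch.randn(4, 5), torch.tensor([0, 2, 1, 4]))
# ------------------------------------------------------------------------------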
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    with BK.no_grad_env():
        self.refresh_batch(False)
        # ===== calculate
        scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(insts, False)
        full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.)
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.)
        # normalizing scores
        full_score = None
        final_exp_score = False  # whether to report probabilities via exp
        if self.norm_local and self.loss_prob:
            full_score = BK.log_softmax(full_arc_score, -1).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
            final_exp_score = True
        elif self.norm_hlocal and self.loss_prob:
            # normalize at the m dimension, ignoring each node's self-finish step
            full_score = BK.log_softmax(full_arc_score, -2).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
        elif self.norm_single and self.loss_prob:
            if self.conf.iconf.dec_single_neg:
                # todo(+2): add all-neg for prob explanation
                full_arc_probs = BK.sigmoid(full_arc_score)
                full_label_probs = BK.sigmoid(full_label_score)
                fake_arc_scores = BK.log(full_arc_probs) - BK.log(1. - full_arc_probs)
                fake_label_scores = BK.log(full_label_probs) - BK.log(1. - full_label_probs)
                full_score = fake_arc_scores.unsqueeze(-1) + fake_label_scores
            else:
                full_score = BK.logsigmoid(full_arc_score).unsqueeze(-1) + BK.logsigmoid(full_label_score)
                final_exp_score = True
        else:
            full_score = full_arc_score.unsqueeze(-1) + full_label_score
        # decode
        mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for MST decoding
        mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
        if final_exp_score:
            mst_scores_arr = np.exp(mst_scores_arr)
        # jpos prediction (direct indexing, no label conversion as in parsing)
        jpos_preds_expr = jpos_pack[2]
        has_jpos_pred = jpos_preds_expr is not None
        jpos_preds_arr = BK.get_value(jpos_preds_expr) if has_jpos_pred else None
        # ===== assign
        info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
        mst_real_labels = self.pred2real_labels(mst_labels_arr)
        for one_idx, one_inst in enumerate(insts):
            cur_length = mst_lengths[one_idx]
            one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-valued heads
            one_inst.pred_labels.build_vals(mst_real_labels[one_idx][:cur_length], self.label_vocab)
            one_inst.pred_par_scores.set_vals(mst_scores_arr[one_idx][:cur_length])
            if has_jpos_pred:
                one_inst.pred_poses.build_vals(jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab)
        return info
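
# --- Hedged shape sketch (illustration only, not part of this module) --------
# Shows how the "norm_local" branch above composes arc and label scores into a
# single log-probability tensor: log P(h, lab | m) = log P(h | m) + log P(lab | m, h).
# Written in plain torch in place of the BK wrapper; shapes, names, and the
# batchless layout are assumptions for illustration.
#
# import torch
#
# def demo_combine_scores(arc_score, label_score):
#     # arc_score: [slen, slen] (modifier m x head h); label_score: [slen, slen, NLab]
#     log_arc = torch.log_softmax(arc_score, -1).unsqueeze(-1)  # [slen, slen, 1]
#     log_label = torch.log_softmax(label_score, -1)            # [slen, slen, NLab]
#     # broadcasting adds the arc log-prob to every label's log-prob;
#     # with final_exp_score=True, decoded scores are later exp()-ed back to probs
#     return log_arc + log_label                                # [slen, slen, NLab]
#
# # e.g. demo_combine_scores(torch.randn(6, 6), torch.randn(6, 6, 40))
# ------------------------------------------------------------------------------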