Code example #1
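This helper computes the labeling loss: a plain NLL over softmax-normalized scores, optionally augmented with an entropy penalty over the non-gold labels (weighted by entropy_lambda), and an optional re-weighting of NIL (index 0) negatives so that abundant negatives do not dominate the positives.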
 def _my_loss_prob(self, score_expr, gold_idxes_expr, entropy_lambda: float,
                   loss_mask, neg_reweight: bool):
     probs = BK.softmax(score_expr, -1)  # [*, NLab]
     log_probs = BK.log(probs + 1e-8)
     # first plain NLL loss
     nll_loss = -BK.gather_one_lastdim(log_probs, gold_idxes_expr).squeeze(-1)
     # next the special loss
     if entropy_lambda > 0.:
         negative_entropy = probs * log_probs  # [*, NLab]
         last_dim = BK.get_shape(score_expr, -1)
         confusion_matrix = 1. - BK.eye(last_dim)  # [NLab, NLab]
         entropy_mask = confusion_matrix[gold_idxes_expr]  # [*, NLab]
         entropy_loss = (negative_entropy * entropy_mask).sum(-1)
         final_loss = nll_loss + entropy_lambda * entropy_loss
     else:
         final_loss = nll_loss
     # reweight?
     if neg_reweight:
         golden_prob = BK.gather_one_lastdim(probs, gold_idxes_expr).squeeze(-1)
         is_full_nil = (gold_idxes_expr == 0.).float()
         not_full_nil = 1. - is_full_nil
         count_pos = (loss_mask * not_full_nil).sum()
         count_neg = (loss_mask * is_full_nil).sum()
         prob_pos = (loss_mask * not_full_nil * golden_prob).sum()
         prob_neg = (loss_mask * is_full_nil * golden_prob).sum()
         neg_weight = prob_pos / (count_pos + count_neg - prob_neg + 1e-8)
         final_weights = not_full_nil + is_full_nil * neg_weight
         # todo(note): final mask will be applied at outside
         final_loss = final_loss * final_weights
     return final_loss
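
For reference, here is a minimal standalone PyTorch sketch of the same idea, for readers without the zmsp BK backend. The function name entropy_regularized_nll and the nil_idx parameter are illustrative names (not part of zmsp); the assumption that index 0 is the NIL label mirrors the snippet above.

import torch

def entropy_regularized_nll(scores, gold_idxes, entropy_lambda=0.,
                            loss_mask=None, neg_reweight=False, nil_idx=0):
    # scores: [*, NLab] raw logits; gold_idxes: [*] int64 gold label indexes
    probs = torch.softmax(scores, -1)
    log_probs = torch.log(probs + 1e-8)
    nll_loss = -log_probs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)
    if entropy_lambda > 0.:
        negative_entropy = probs * log_probs  # [*, NLab]
        # zero out the gold column: penalize confident mass on non-gold labels only
        entropy_mask = 1. - torch.eye(scores.shape[-1], device=scores.device)[gold_idxes]
        nll_loss = nll_loss + entropy_lambda * (negative_entropy * entropy_mask).sum(-1)
    if neg_reweight:
        if loss_mask is None:
            loss_mask = torch.ones_like(nll_loss)
        golden_prob = probs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)
        is_nil = (gold_idxes == nil_idx).float()
        count_pos = (loss_mask * (1. - is_nil)).sum()
        count_neg = (loss_mask * is_nil).sum()
        prob_pos = (loss_mask * (1. - is_nil) * golden_prob).sum()
        prob_neg = (loss_mask * is_nil * golden_prob).sum()
        # down-weight NIL items by (positive prob mass) / (effective negative count)
        neg_weight = prob_pos / (count_pos + count_neg - prob_neg + 1e-8)
        nll_loss = nll_loss * ((1. - is_nil) + is_nil * neg_weight)
    return nll_loss  # per-item loss; the caller applies loss_mask and reduces

For example, entropy_regularized_nll(torch.randn(4, 5), torch.tensor([0, 2, 1, 0]), entropy_lambda=0.1, neg_reweight=True) returns a per-item loss of shape [4].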
Code example #2
File: parser.py  Project: ValentinaPy/zmsp
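This inference method scores all arcs and labels for a batch, combines the two score tensors under the configured normalization scheme (local, head-local, or single-edge), runs MST decoding over the combined scores, and writes predicted heads, labels, scores, and optional joint-POS tags back onto each ParseInstance.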
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     with BK.no_grad_env():
         self.refresh_batch(False)
         # ===== calculate
         scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(insts, False)
         full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.)
         full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.)
         # normalizing scores
         full_score = None
         final_exp_score = False  # whether to exp the final scores into probabilities
         if self.norm_local and self.loss_prob:
             full_score = (BK.log_softmax(full_arc_score, -1).unsqueeze(-1)
                           + BK.log_softmax(full_label_score, -1))
             final_exp_score = True
         elif self.norm_hlocal and self.loss_prob:
             # normalize at the m dimension, ignoring each node's self-finish step
             full_score = (BK.log_softmax(full_arc_score, -2).unsqueeze(-1)
                           + BK.log_softmax(full_label_score, -1))
         elif self.norm_single and self.loss_prob:
             if self.conf.iconf.dec_single_neg:
                 # todo(+2): add all-neg for prob explanation
                 full_arc_probs = BK.sigmoid(full_arc_score)
                 full_label_probs = BK.sigmoid(full_label_score)
                 # logit(p) = log(p) - log(1-p) inverts the sigmoid back to raw scores
                 fake_arc_scores = BK.log(full_arc_probs) - BK.log(1. - full_arc_probs)
                 fake_label_scores = BK.log(full_label_probs) - BK.log(1. - full_label_probs)
                 full_score = fake_arc_scores.unsqueeze(-1) + fake_label_scores
             else:
                 full_score = (BK.logsigmoid(full_arc_score).unsqueeze(-1)
                               + BK.logsigmoid(full_label_score))
                 final_exp_score = True
         else:
             full_score = full_arc_score.unsqueeze(-1) + full_label_score
         # decode
         mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
         mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(
             full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
         if final_exp_score:
             mst_scores_arr = np.exp(mst_scores_arr)
         # jpos prediction (direct indexing, no label conversion as in parsing)
         jpos_preds_expr = jpos_pack[2]
         has_jpos_pred = jpos_preds_expr is not None
         jpos_preds_arr = BK.get_value(jpos_preds_expr) if has_jpos_pred else None
         # ===== assign
         info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
         mst_real_labels = self.pred2real_labels(mst_labels_arr)
         for one_idx, one_inst in enumerate(insts):
             cur_length = mst_lengths[one_idx]
             # directly int-val for heads
             one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])
             one_inst.pred_labels.build_vals(mst_real_labels[one_idx][:cur_length],
                                             self.label_vocab)
             one_inst.pred_par_scores.set_vals(mst_scores_arr[one_idx][:cur_length])
             if has_jpos_pred:
                 one_inst.pred_poses.build_vals(
                     jpos_preds_arr[one_idx][:cur_length],
                     self.bter.pos_vocab)
         return info
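
A note on the dec_single_neg branch above: BK.log(p) - BK.log(1. - p) is the logit function, which exactly inverts the sigmoid, so the "fake" scores recover the original raw arc/label scores from their probabilities. A quick standalone PyTorch check of that identity (illustrative only):

import torch

x = torch.randn(3, 4)                       # arbitrary raw scores
p = torch.sigmoid(x)                        # probabilities in (0, 1)
fake = torch.log(p) - torch.log(1. - p)     # logit(p)
assert torch.allclose(fake, x, atol=1e-4)   # recovers the original scores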