Example #1
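This method scores each of the L label candidates at two granularities: a token-level score obtained as an attention-weighted average of per-position scores, and a sentence-level score computed from the two boundary token representations; the two are then interpolated with either a fixed weight or a learned gate.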
 def _score(self,
            input_expr,
            input_mask,
            scores_aug_tok=None,
            scores_aug_sent=None):
     # token level attention and score
     # calculate the attention
     query_tok = self.query_tok  # [L, D]
     query_tok_t = query_tok.transpose(0, 1)  # [D, L]
     att_scores = BK.matmul(input_expr, query_tok_t)  # [*, slen, L]
     # mask out padded positions with a very negative value
     att_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN
     if scores_aug_tok is not None:  # margin
         att_scores += scores_aug_tok
     attn = BK.softmax(att_scores, -2)  # [*, slen, L]
     score_tok = (att_scores * attn).sum(-2)  # [*, L]
     # token level labeling softmax
     attn2 = BK.softmax(att_scores.view(BK.get_shape(att_scores)[:-2] + [-1]), -1)  # [*, slen*L]
     # sent level score
     query_sent = self.query_sent  # [L, D]
     context_sent = input_expr[:, 0] + input_expr[:, -1]  # [*, D], simply adding the two ends
     score_sent = BK.matmul(context_sent, query_sent.transpose(0, 1))  # [*, L]
     # combine
     if self.lambda_score_tok < 0.:  # negative lambda: use a learned gate instead of a fixed weight
         context_tok = BK.matmul(input_expr.transpose(-1, -2), attn).transpose(-1, -2).contiguous()  # [*, L, D]
         # gate input: 4*[*, L, D] -> [*, L]
         cur_lambda_score_tok = self.score_gate([
             context_tok,
             query_tok.unsqueeze(0),
             context_sent.unsqueeze(-2),
             query_sent.unsqueeze(0)
         ]).squeeze(-1)
     else:
         cur_lambda_score_tok = self.lambda_score_tok
     final_score = score_tok * cur_lambda_score_tok + score_sent * (1. - cur_lambda_score_tok)
     if scores_aug_sent is not None:  # margin
         final_score += scores_aug_sent
     if self.conf.score_sigmoid:  # squash final scores into (0, 1)
         final_score = BK.sigmoid(final_score)
     return final_score, attn, attn2  # [*, L], [*, slen, L], [*, slen*L]
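For reference, the token-level half of this computation can be reproduced with plain torch ops. Below is a minimal self-contained sketch, assuming BK is a thin wrapper around the corresponding torch functions; the function and variable names here are illustrative only, not the library's API:

import torch

def token_level_score(input_expr, input_mask, query_tok, neg_inf=-1e30):
    # input_expr: [B, slen, D], input_mask: [B, slen], query_tok: [L, D]
    att_scores = torch.matmul(input_expr, query_tok.t())  # [B, slen, L]
    # push padded positions toward -inf so they receive ~zero attention
    att_scores = att_scores + (1. - input_mask).unsqueeze(-1) * neg_inf
    attn = torch.softmax(att_scores, dim=-2)  # attention over tokens, [B, slen, L]
    score_tok = (att_scores * attn).sum(-2)  # attention-weighted expected score, [B, L]
    return score_tok, attn

B, slen, D, L = 2, 5, 8, 3
x, q = torch.randn(B, slen, D), torch.randn(L, D)
mask = torch.ones(B, slen)
mask[1, 3:] = 0.  # pretend the second sentence has only 3 real tokens
scores, attn = token_level_score(x, mask, q)
print(scores.shape, attn.shape)  # torch.Size([2, 3]) torch.Size([2, 5, 3])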
Example #2
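This method runs batched inference for dependency parsing: it scores all (head, label) candidates, normalizes the scores according to the configured scheme (local, head-local, or single-edge), decodes trees with MST, and writes the predicted heads, labels, scores, and optional joint POS tags back onto the instances.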
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     with BK.no_grad_env():
         self.refresh_batch(False)
         # ===== calculate
         scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
             insts, False)
         full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                               False, 0.)
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, False, 0.)
         # normalizing scores
         full_score = None
         final_exp_score = False  # whether final scores are log-probs to be exp'ed into probabilities
         if self.norm_local and self.loss_prob:
             full_score = BK.log_softmax(full_arc_score, -1).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
             final_exp_score = True
         elif self.norm_hlocal and self.loss_prob:
             # normalize over the m (modifier) dimension, ignoring each node's self-finish step
             full_score = BK.log_softmax(full_arc_score, -2).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
         elif self.norm_single and self.loss_prob:
             if self.conf.iconf.dec_single_neg:
                 # todo(+2): add all-neg for prob explanation
                 # note: log(p) - log(1-p) with p = sigmoid(x) algebraically recovers the raw logit x
                 full_arc_probs = BK.sigmoid(full_arc_score)
                 full_label_probs = BK.sigmoid(full_label_score)
                 fake_arc_scores = BK.log(full_arc_probs) - BK.log(1. - full_arc_probs)
                 fake_label_scores = BK.log(full_label_probs) - BK.log(1. - full_label_probs)
                 full_score = fake_arc_scores.unsqueeze(-1) + fake_label_scores
             else:
                 full_score = BK.logsigmoid(full_arc_score).unsqueeze(-1) + BK.logsigmoid(full_label_score)
                 final_exp_score = True
         else:
             full_score = full_arc_score.unsqueeze(-1) + full_label_score
         # decode
         mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
         mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(
             full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
         if final_exp_score:
             mst_scores_arr = np.exp(mst_scores_arr)
         # jpos prediction (directly index, no converting as in parsing)
         jpos_preds_expr = jpos_pack[2]
         has_jpos_pred = jpos_preds_expr is not None
         jpos_preds_arr = BK.get_value(jpos_preds_expr) if has_jpos_pred else None
         # ===== assign
         info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
         mst_real_labels = self.pred2real_labels(mst_labels_arr)
         for one_idx, one_inst in enumerate(insts):
             cur_length = mst_lengths[one_idx]
             one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
             one_inst.pred_labels.build_vals(mst_real_labels[one_idx][:cur_length], self.label_vocab)
             one_inst.pred_par_scores.set_vals(mst_scores_arr[one_idx][:cur_length])
             if has_jpos_pred:
                 one_inst.pred_poses.build_vals(jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab)
         return info
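The locally-normalized branch above combines arc and label scores by a broadcast sum of two log-softmaxes. Here is a minimal PyTorch sketch of that combination under assumed [batch, modifier, head] and [batch, modifier, head, label] layouts (the library's actual BK calls and tensor shapes may differ):

import torch
import torch.nn.functional as F

B, slen, n_labels = 2, 4, 5
arc_score = torch.randn(B, slen, slen)  # one score per (modifier, head) pair
label_score = torch.randn(B, slen, slen, n_labels)  # one score per (modifier, head, label)

# local normalization: log P(head | mod) + log P(label | mod, head)
full_score = F.log_softmax(arc_score, dim=-1).unsqueeze(-1) + F.log_softmax(label_score, dim=-1)
print(full_score.shape)  # torch.Size([2, 4, 4, 5])

# the final_exp_score=True path exponentiates decoded log-scores into probabilities;
# each modifier's joint distribution over (head, label) sums to 1
probs = full_score.exp()
print(probs.view(B, slen, -1).sum(-1))  # ~ all ones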