def _score(self, input_expr, input_mask, scores_aug_tok=None, scores_aug_sent=None):
    # token level attention and score
    # calculate the attention
    query_tok = self.query_tok  # [L, D]
    query_tok_t = query_tok.transpose(0, 1)  # [D, L]
    att_scores = BK.matmul(input_expr, query_tok_t)  # [*, slen, L]
    att_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN
    if scores_aug_tok is not None:  # margin
        att_scores += scores_aug_tok
    attn = BK.softmax(att_scores, -2)  # [*, slen, L]
    score_tok = (att_scores * attn).sum(-2)  # [*, L]
    # token level labeling softmax
    attn2 = BK.softmax(att_scores.view(BK.get_shape(att_scores)[:-2] + [-1]), -1)  # [*, slen*L]
    # sent level score
    query_sent = self.query_sent  # [L, D]
    context_sent = input_expr[:, 0] + input_expr[:, -1]  # [*, D], simply adding the two ends
    score_sent = BK.matmul(context_sent, query_sent.transpose(0, 1))  # [*, L]
    # combine
    if self.lambda_score_tok < 0.:
        context_tok = BK.matmul(input_expr.transpose(-1, -2), attn).transpose(-1, -2).contiguous()  # [*, L, D]
        # 4*[*,L,D] -> [*, L]
        cur_lambda_score_tok = self.score_gate(
            [context_tok, query_tok.unsqueeze(0), context_sent.unsqueeze(-2), query_sent.unsqueeze(0)]).squeeze(-1)
    else:
        cur_lambda_score_tok = self.lambda_score_tok
    final_score = score_tok * cur_lambda_score_tok + score_sent * (1. - cur_lambda_score_tok)
    if scores_aug_sent is not None:
        final_score += scores_aug_sent
    if self.conf.score_sigmoid:  # margin
        final_score = BK.sigmoid(final_score)
    return final_score, attn, attn2  # [*, L], [*, slen, L], [*, slen*L]
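# The gated combination above can be hard to read inline; the following is a minimal,
# self-contained sketch of the same token/sentence score interpolation in plain torch,
# ASSUMING the BK backend mirrors torch (BK.matmul/BK.softmax ~ torch.matmul/torch.softmax).
# All names, shapes, and the fixed gate value here are hypothetical, for illustration only.
def _demo_token_sent_combine():
    import torch
    B, slen, L, D = 2, 5, 3, 4  # batch size, sentence length, num labels, hidden dim
    input_expr = torch.randn(B, slen, D)  # contextualized token representations
    query_tok = torch.randn(L, D)  # per-label token-level query
    query_sent = torch.randn(L, D)  # per-label sentence-level query
    att_scores = torch.matmul(input_expr, query_tok.t())  # [B, slen, L]
    attn = torch.softmax(att_scores, dim=-2)  # attention over tokens, per label
    score_tok = (att_scores * attn).sum(-2)  # attention-weighted token score: [B, L]
    context_sent = input_expr[:, 0] + input_expr[:, -1]  # sum of the two ends: [B, D]
    score_sent = torch.matmul(context_sent, query_sent.t())  # [B, L]
    lam = 0.5  # fixed interpolation weight; the real model can instead predict it via score_gate
    return score_tok * lam + score_sent * (1. - lam)  # [B, L]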
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    with BK.no_grad_env():
        self.refresh_batch(False)
        # ===== calculate
        scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(insts, False)
        full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.)
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.)
        # normalizing scores
        full_score = None
        final_exp_score = False  # whether to provide PROB by exp
        if self.norm_local and self.loss_prob:
            full_score = BK.log_softmax(full_arc_score, -1).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
            final_exp_score = True
        elif self.norm_hlocal and self.loss_prob:
            # normalize at m dimension, ignoring each node's self-finish step
            full_score = BK.log_softmax(full_arc_score, -2).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
        elif self.norm_single and self.loss_prob:
            if self.conf.iconf.dec_single_neg:
                # todo(+2): add all-neg for prob explanation
                full_arc_probs = BK.sigmoid(full_arc_score)
                full_label_probs = BK.sigmoid(full_label_score)
                fake_arc_scores = BK.log(full_arc_probs) - BK.log(1. - full_arc_probs)
                fake_label_scores = BK.log(full_label_probs) - BK.log(1. - full_label_probs)
                full_score = fake_arc_scores.unsqueeze(-1) + fake_label_scores
            else:
                full_score = BK.logsigmoid(full_arc_score).unsqueeze(-1) + BK.logsigmoid(full_label_score)
                final_exp_score = True
        else:
            full_score = full_arc_score.unsqueeze(-1) + full_label_score
        # decode
        mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
        mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
        if final_exp_score:
            mst_scores_arr = np.exp(mst_scores_arr)
        # jpos prediction (directly index, no conversion as in parsing)
        jpos_preds_expr = jpos_pack[2]
        has_jpos_pred = jpos_preds_expr is not None
        jpos_preds_arr = BK.get_value(jpos_preds_expr) if has_jpos_pred else None
        # ===== assign
        info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
        mst_real_labels = self.pred2real_labels(mst_labels_arr)
        for one_idx, one_inst in enumerate(insts):
            cur_length = mst_lengths[one_idx]
            one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(mst_real_labels[one_idx][:cur_length], self.label_vocab)
            one_inst.pred_par_scores.set_vals(mst_scores_arr[one_idx][:cur_length])
            if has_jpos_pred:
                one_inst.pred_poses.build_vals(jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab)
        return info
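# A minimal, self-contained sketch of the local normalization taken in the first
# branch above, in plain torch (ASSUMING BK.log_softmax mirrors torch.log_softmax;
# names and shapes are hypothetical). Arc scores are normalized over candidate heads
# and label scores over labels; their sum is a joint log-prob per (head, label), which
# is why exp-ing the decoded scores (final_exp_score) recovers probabilities.
def _demo_local_normalization():
    import torch
    B, slen, L = 2, 4, 3  # batch size, length incl. ROOT, num labels
    full_arc_score = torch.randn(B, slen, slen)  # [B, mod, head]
    full_label_score = torch.randn(B, slen, slen, L)  # [B, mod, head, label]
    full_score = torch.log_softmax(full_arc_score, -1).unsqueeze(-1) \
        + torch.log_softmax(full_label_score, -1)  # joint log P(head, label | mod)
    probs = full_score.exp()
    # sanity check: for each modifier, the joint distribution sums to one
    assert torch.allclose(probs.sum(dim=(-1, -2)), torch.ones(B, slen))
    return full_score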