def _add_margin_inplaced(self, shape, hit_idxes0, hit_idxes1, hit_idxes2, hit_labels,
                         query_idxes0, query_idxes1, query_idxes2, arc_scores, lab_scores,
                         arc_margin: float, lab_margin: float):
    """Subtract margins, in place, from candidate scores whose coordinates match gold ones.

    Arc part: a dense float table of shape ``shape`` is zero everywhere except at the
    hit (gold) coordinates, where it holds ``arc_margin``. It is gathered at the query
    coordinates and the result is subtracted in place from ``arc_scores``.

    Label part (only when ``lab_scores`` is not None): a dense index table of shape
    ``shape`` holds ``hit_labels`` at the hit coordinates and 0 elsewhere (0 is the
    padding idx). It is gathered at the query coordinates, giving one label per query
    row. ``lab_margin`` is then subtracted from ``lab_scores[row, label]`` for every
    row, so rows that do not hit a gold coordinate lose the margin on label 0.

    Args:
        shape: shape for the dense lookup tables. It is indexed with three index
            tensors, so presumably 3-dim (e.g. [batch, len, len]); TODO confirm.
        hit_idxes0, hit_idxes1, hit_idxes2: coordinates of the gold factors.
        hit_labels: label written at each hit coordinate. Only read when
            ``lab_scores`` is not None.
        query_idxes0, query_idxes1, query_idxes2: coordinates of the scored
            candidates. Presumably aligned row-by-row with ``arc_scores`` and
            ``lab_scores``; verify against the caller.
        arc_scores: candidate arc scores, updated in place via ``-=``.
        lab_scores: candidate label scores (rows x labels), updated in place, or None.
        arc_margin: value subtracted from matching arc scores.
        lab_margin: value subtracted from the looked-up label score of each row.

    Returns:
        None; the effect is the in-place update of ``arc_scores`` / ``lab_scores``.
    """
    hit_key = (hit_idxes0, hit_idxes1, hit_idxes2)
    query_key = (query_idxes0, query_idxes1, query_idxes2)
    # arc: dense margin table, then gather at the queried coordinates
    margin_table = BK.constants(shape, 0.)
    margin_table[hit_key] = arc_margin
    arc_scores -= margin_table[query_key]
    if lab_scores is None:
        return
    # label: dense table of gold labels; 0 (the padding idx) where there is no hit
    label_table = BK.constants_idx(shape, 0)
    label_table[hit_key] = hit_labels
    query_labels = label_table[query_key]
    row_ids = BK.arange_idx(BK.get_shape(query_labels, 0))
    lab_scores[row_ids, query_labels] -= lab_margin
def _decode(self, mb_insts: List[ParseInstance], mb_enc_expr, mb_valid_expr, mb_go1_pack,
            training: bool, margin: float):
    """Score high-order candidate factors, decode each instance, then recover labels.

    Steps, in order:
      1. Extract the candidate factors (batch, mod, head, sib, gp index tensors) with
         ``self.helper.get_cand_features(mb_valid_expr)``.
      2. Score them with ``self._get_basic_score``.
         - When ``training``, subtract ``margin / self.margin_div`` via
           ``self._add_margin_inplaced`` for the head factor, plus the sib and gp
           factors when ``self.use_sib`` / ``self.use_gp`` are set.
         - When ``mb_go1_pack`` is given and the system is labeled, add its label
           scores gathered at (batch, mod, head).
         - The o1 arc scores are not added here; they are passed to the decoder.
      3. Take the max over labels and add it to the arc scores. Then decode every
         instance with ``self.helper.decode_one``.
      4. For labeled systems, write the best label of every candidate whose factors
         all match the predicted structure into a [mb_size, max_slen] table.

    Returns:
        Tuple ``(res_heads, res_labels, gold_pack, pred_pack)``:
        res_heads: list with one ``decode_one`` result per instance.
        res_labels: ``BK.get_value`` of the label table (0 where nothing was
            written), or None when the system is unlabeled.
        gold_pack: ``(mb_size, b, m, h, sib, gp, lab)`` gold index tensors, or None
            when not training.
        pred_pack: ``(mb_size, b, m, h, sib, gp, lab)`` index tensors built from the
            predictions. ``lab`` is None when unlabeled.
    """
    use_sib, use_gp = self.use_sib, self.use_gp

    def _to_idx_list(raw_seq):
        # Wrap each non-None entry as an index tensor; None entries pass through.
        return [(None if z is None else BK.input_idx(z)) for z in raw_seq]

    mb_size = len(mb_insts)
    mat_shape = BK.get_shape(mb_valid_expr)
    max_slen = mat_shape[-1]
    # ----- step 1: candidate factors
    cand_b, cand_m, cand_h, cand_sib, cand_gp = self.helper.get_cand_features(mb_valid_expr)
    # ----- step 2.1: basic scores; arc [*] and label [*, Lab] (label part may be None)
    arc_scores, lab_scores = self._get_basic_score(
        mb_enc_expr, cand_b, cand_m, cand_h, cand_sib, cand_gp)
    is_labeled = (lab_scores is not None)
    # ----- step 2.2: margins (training only); gold indexes are also kept for the loss
    gold_pack = None
    if training:
        gold_b, gold_m, gold_h, gold_sib, gold_gp, gold_lab = _to_idx_list(
            self._get_idxes_from_insts(mb_insts, use_sib, use_gp))
        step_margin = margin / self.margin_div
        # (m,h) always; (m,sib) and (m,gp) only when those factors are enabled
        margin_factors = (
            (True, gold_h, cand_h),
            (use_sib, gold_sib, cand_sib),
            (use_gp, gold_gp, cand_gp),
        )
        for enabled, gold_x, cand_x in margin_factors:
            if enabled:
                self._add_margin_inplaced(
                    mat_shape, gold_b, gold_m, gold_x, gold_lab,
                    cand_b, cand_m, cand_x, arc_scores, lab_scores,
                    step_margin, step_margin)
        gold_pack = (mb_size, gold_b, gold_m, gold_h, gold_sib, gold_gp, gold_lab)
    # ----- step 2.3: first-order scores
    go1_arc_scores = None
    if mb_go1_pack is not None:
        # The o1 arc scores are not added here; they become the decoder's o1 input.
        go1_arc_scores, go1_lab_scores = mb_go1_pack
        if is_labeled:
            lab_scores += go1_lab_scores[cand_b, cand_m, cand_h]
    # ----- step 2.4: reduce labels by max (todo(+N): or logsumexp?)
    if not is_labeled:
        best_lab_idxes = None
        final_scores = arc_scores
    else:
        best_lab_scores, best_lab_idxes = lab_scores.max(-1)
        final_scores = arc_scores + best_lab_scores  # [*], the decoder's arc input

    # ----- step 3: per-instance decoding
    def _decode_inst(sid, inst):
        slen = len(inst) + 1  # +1 for the artificial root
        o1_masks = BK.get_value(mb_valid_expr[sid, :slen, :slen].int())
        if go1_arc_scores is None:
            o1_scores = None
        else:
            o1_scores = BK.get_value(go1_arc_scores[sid, :slen, :slen].double())
        sent_mask = (cand_b == sid)
        cand_pack = [cand_m, cand_h, cand_sib, cand_gp, final_scores]
        return self.helper.decode_one(slen, self.projective, o1_masks, o1_scores,
                                      cand_pack, sent_mask)

    res_heads = [_decode_inst(sid, inst) for sid, inst in enumerate(mb_insts)]
    # ----- step 4: prediction indexes and labels
    pred_b, pred_m, pred_h, pred_sib, pred_gp, _ = _to_idx_list(
        self._get_idxes_from_preds(res_heads, None, use_sib, use_gp))
    if not is_labeled:
        res_labels = None
        pred_lab = None
    else:
        # candidates whose every enabled factor agrees with the prediction
        hit_mask = self._get_hit_mask(mat_shape, pred_b, pred_m, pred_h, cand_b, cand_m, cand_h)
        for enabled, pred_x, cand_x in ((use_sib, pred_sib, cand_sib), (use_gp, pred_gp, cand_gp)):
            if enabled:
                hit_mask &= self._get_hit_mask(
                    mat_shape, pred_b, pred_m, pred_x, cand_b, cand_m, cand_x)
        # there should be only one hit per mod
        pred_labels = BK.constants_idx([mb_size, max_slen], 0)
        pred_labels[cand_b[hit_mask], cand_m[hit_mask]] = best_lab_idxes[hit_mask]
        res_labels = BK.get_value(pred_labels)
        pred_lab = pred_labels[pred_b, pred_m]
    pred_pack = (mb_size, pred_b, pred_m, pred_h, pred_sib, pred_gp, pred_lab)
    return res_heads, res_labels, gold_pack, pred_pack