Beispiel #1
0
 def _add_margin_inplaced(self, shape, hit_idxes0, hit_idxes1, hit_idxes2,
                          hit_labels, query_idxes0, query_idxes1,
                          query_idxes2, arc_scores, lab_scores,
                          arc_margin: float, lab_margin: float):
     """Subtract margins, in place, from scores of queries that hit gold.

     A tensor of size ``shape`` built with ``BK.constants(shape, 0.)`` is
     filled with ``arc_margin`` at the hit positions
     ``(hit_idxes0, hit_idxes1, hit_idxes2)``. It is then gathered at the
     query positions ``(query_idxes0, query_idxes1, query_idxes2)``, and the
     gathered values are subtracted from ``arc_scores``. Queries that do not
     coincide with a hit position gather 0, so they are unchanged.

     If ``lab_scores`` is not None, an index tensor of size ``shape`` holding
     ``hit_labels`` at the hit positions is gathered the same way. Row ``i``
     of ``lab_scores`` then has ``lab_margin`` subtracted at the column given
     by the label gathered for query ``i``. Queries matching no hit gather
     the fill value 0 (documented as the padding idx), so column 0 is the
     one reduced for them.

     ``arc_scores`` and ``lab_scores`` are modified in place; returns None.
     """
     hit_key = (hit_idxes0, hit_idxes1, hit_idxes2)
     query_key = (query_idxes0, query_idxes1, query_idxes2)
     # arc part: arc_margin at gold positions, 0 everywhere else
     margin_mat = BK.constants(shape, 0.)
     margin_mat[hit_key] = arc_margin
     arc_scores -= margin_mat[query_key]
     if lab_scores is None:
         return
     # label part: gold label at gold positions, 0 (the padding idx) elsewhere
     label_mat = BK.constants_idx(shape, 0)
     label_mat[hit_key] = hit_labels
     query_gold_labels = label_mat[query_key]
     num_queries = BK.get_shape(query_gold_labels, 0)
     # one (row, gold-label column) pair per query
     lab_scores[BK.arange_idx(num_queries), query_gold_labels] -= lab_margin
Beispiel #2
0
 def _decode(self, mb_insts: List[ParseInstance], mb_enc_expr,
             mb_valid_expr, mb_go1_pack, training: bool, margin: float):
     """Score candidate parts for a minibatch and decode one tree per instance.

     Outline (matching the inline ``step`` comments):
       1. ``self.helper.get_cand_features(mb_valid_expr)`` gives parallel
          candidate index tensors ``batch_idxes, m_idxes, h_idxes,
          sib_idxes, gp_idxes``.
       2. ``self._get_basic_score`` scores every candidate. It returns arc
          scores and label scores; the system counts as labeled iff the
          label scores are not None.
          - When ``training``, gold parts are extracted from ``mb_insts``.
            ``self._add_margin_inplaced`` is then applied with
            ``margin / self.margin_div``: always for (m,h), and also for
            (m,sib) / (m,gp) when ``self.use_sib`` / ``self.use_gp``.
          - If ``mb_go1_pack`` is given, its label scores are added to the
            candidate label scores (labeled systems only). Its arc scores
            are not added; they are passed to the decoder instead.
          - Labeled systems take the max over the label dimension and add
            it to the arc scores.
       3. ``self.helper.decode_one`` is called per instance, with length
          ``len(inst) + 1`` (extra slot for the root). It receives the full
          candidate pack plus the mask ``batch_idxes == sid``.
       4. The predicted heads are converted back to index tensors. For
          labeled systems, each predicted modifier takes the max-scoring
          label of the candidate(s) flagged by ``self._get_hit_mask``.

     Args:
         mb_insts: minibatch of instances.
         mb_enc_expr: passed straight to ``self._get_basic_score``.
         mb_valid_expr: validity tensor. It is sliced as ``[sid, :slen,
             :slen]`` below, so dim 0 is the instance (presumably shaped
             [bs, slen, slen] -- confirm).
         mb_go1_pack: ``(go1_arc_scores, go1_lab_scores)`` or None.
         training: if True, apply margins and build ``gold_pack``.
         margin: base margin; divided by ``self.margin_div`` before use.

     Returns:
         ``(res_heads, res_labels, gold_pack, pred_pack)``, where:
         - ``res_heads`` holds one ``decode_one`` result per instance.
         - ``res_labels`` is ``BK.get_value`` of an ``[mb_size, max_slen]``
           label-index tensor, or None if the system is unlabeled.
         - ``gold_pack`` is None when not ``training``.
         - Each pack is ``(mb_size, b, m, h, sib, gp, lab)``.

     NOTE: ``arc_scores`` / ``lab_scores`` are mutated in place (margins,
     go1 label scores), so the order of steps 2.2-2.4 matters.
     """
     # =====
     use_sib, use_gp = self.use_sib, self.use_gp
     # =====
     mb_size = len(mb_insts)
     mat_shape = BK.get_shape(mb_valid_expr)
     # last dim of the validity tensor; sizes the predicted-label table below
     max_slen = mat_shape[-1]
     # step 1: extract the candidate features
     # (parallel index tensors, one entry per candidate part)
     batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes = self.helper.get_cand_features(
         mb_valid_expr)
     # =====
     # step 2: high order scoring
     # step 2.1: basic scoring, [*], [*, Lab]
     # arc_scores: one score per candidate; lab_scores: per-candidate label
     # scores, or None for an unlabeled system
     arc_scores, lab_scores = self._get_basic_score(mb_enc_expr,
                                                    batch_idxes, m_idxes,
                                                    h_idxes, sib_idxes,
                                                    gp_idxes)
     cur_system_labeled = (lab_scores is not None)
     # step 2.2: margin
     # get gold labels, which can be useful for later calculating loss
     if training:
         # gold parts as index tensors; None entries stay None
         gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = \
             [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_insts(mb_insts, use_sib, use_gp)]
         # add the margins to the scores: (m,h), (m,sib), (m,gp)
         # _add_margin_inplaced mutates arc_scores / lab_scores in place,
         # lowering the scores of candidates that coincide with gold parts.
         cur_margin = margin / self.margin_div
         self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes,
                                   gold_h_idxes, gold_lab_idxes,
                                   batch_idxes, m_idxes, h_idxes,
                                   arc_scores, lab_scores, cur_margin,
                                   cur_margin)
         if use_sib:
             self._add_margin_inplaced(mat_shape, gold_b_idxes,
                                       gold_m_idxes, gold_sib_idxes,
                                       gold_lab_idxes, batch_idxes, m_idxes,
                                       sib_idxes, arc_scores, lab_scores,
                                       cur_margin, cur_margin)
         if use_gp:
             self._add_margin_inplaced(mat_shape, gold_b_idxes,
                                       gold_m_idxes, gold_gp_idxes,
                                       gold_lab_idxes, batch_idxes, m_idxes,
                                       gp_idxes, arc_scores, lab_scores,
                                       cur_margin, cur_margin)
         # may be useful for later training
         gold_pack = (mb_size, gold_b_idxes, gold_m_idxes, gold_h_idxes,
                      gold_sib_idxes, gold_gp_idxes, gold_lab_idxes)
     else:
         gold_pack = None
     # step 2.3: o1scores
     if mb_go1_pack is not None:
         go1_arc_scores, go1_lab_scores = mb_go1_pack
         # todo(note): go1_arc_scores is not added here, but as the input to the dec-algo
         if cur_system_labeled:
             # gather first-order label scores at each candidate's (b, m, h)
             # and add them in place
             lab_scores += go1_lab_scores[batch_idxes, m_idxes, h_idxes]
     else:
         go1_arc_scores = None
     # step 2.4: max out labels; todo(+N): or using logsumexp here?
     if cur_system_labeled:
         # best label score and its index for each candidate
         max_lab_scores, max_lab_idxes = lab_scores.max(-1)
         final_scores = arc_scores + max_lab_scores  # [*], final input arc scores
     else:
         max_lab_idxes = None
         # NOTE: this aliases arc_scores (no copy is made)
         final_scores = arc_scores
     # =====
     # step 3: actual decode
     res_heads = []
     for sid, inst in enumerate(mb_insts):
         slen = len(inst) + 1  # plus one for the artificial root
         # per-instance validity mask (and optional first-order arc scores),
         # extracted via BK.get_value for the decoder
         arr_o1_masks = BK.get_value(mb_valid_expr[sid, :slen, :slen].int())
         arr_o1_scores = BK.get_value(
             go1_arc_scores[sid, :slen, :slen].double()) if (
                 go1_arc_scores is not None) else None
         # marks the candidates that belong to this instance; the full
         # candidate pack is passed along with it
         cur_bidx_mask = (batch_idxes == sid)
         input_pack = [m_idxes, h_idxes, sib_idxes, gp_idxes, final_scores]
         one_heads = self.helper.decode_one(slen, self.projective,
                                            arr_o1_masks, arr_o1_scores,
                                            input_pack, cur_bidx_mask)
         res_heads.append(one_heads)
     # =====
     # step 4: get labels back and pred_pack
     # predicted parts as index tensors; the label slot is discarded here and
     # filled from the candidate label argmax below
     pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, _ = \
         [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_preds(res_heads, None, use_sib, use_gp)]
     if cur_system_labeled:
         # obtain hit components
         # a candidate is kept only if it hits on (m,h) and, when enabled,
         # also on (m,sib) and (m,gp)
         pred_hit_mask = self._get_hit_mask(mat_shape, pred_b_idxes,
                                            pred_m_idxes, pred_h_idxes,
                                            batch_idxes, m_idxes, h_idxes)
         if use_sib:
             pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes,
                                                 pred_m_idxes,
                                                 pred_sib_idxes,
                                                 batch_idxes, m_idxes,
                                                 sib_idxes)
         if use_gp:
             pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes,
                                                 pred_m_idxes,
                                                 pred_gp_idxes, batch_idxes,
                                                 m_idxes, gp_idxes)
         # get pred labels (there should be only one hit per mod!)
         # Positions without a hit keep the fill value 0.
         # NOTE(review): if several candidates hit the same (b, m), the label
         # written depends on the backend's duplicate-index assignment
         # semantics.
         pred_labels = BK.constants_idx([mb_size, max_slen], 0)
         pred_labels[batch_idxes[pred_hit_mask],
                     m_idxes[pred_hit_mask]] = max_lab_idxes[pred_hit_mask]
         res_labels = BK.get_value(pred_labels)
         # label of each predicted (b, m) pair, for pred_pack
         pred_lab_idxes = pred_labels[pred_b_idxes, pred_m_idxes]
     else:
         res_labels = None
         pred_lab_idxes = None
     pred_pack = (mb_size, pred_b_idxes, pred_m_idxes, pred_h_idxes,
                  pred_sib_idxes, pred_gp_idxes, pred_lab_idxes)
     # return
     return res_heads, res_labels, gold_pack, pred_pack