Example #1
 def _predict(self, final_score, attn, attn2, input_mask):
     # attn[:, 0] = 0.  # not for artificial root # todo(note): input_mask should already exclude idx=0
     attn *= input_mask.unsqueeze(-1)  # not for invalid tokens
     conf = self.conf
     # get sentence level types
     pred_sent_max_num = conf.pred_sent_max_num
     pred_sent_abs_thresh = conf.pred_sent_abs_thresh
     # todo(note): kthvalue is not available at gpu and only support smallest!!
     # s_thresh00 = final_score.kthvalue(pred_sent_max_num, -1, keepdim=True)[0]  # [*, 1]
     s_thresh0 = final_score.topk(pred_sent_max_num, -1)[0].min(-1, keepdim=True)[0]  # [*, 1]
     s_thresh = s_thresh0.clamp(min=pred_sent_abs_thresh)
     sent_pred_mask = (final_score >= s_thresh).float()  # [*, L]
     # get token level as triggers (each type how many triggers && each token how many types)
     pred_tok_max_num = conf.pred_tok_max_num
     pred_tok_rel_thresh = conf.pred_tok_rel_thresh
     pred_tok_abs_thresh = conf.pred_tok_abs_thresh
     all_tok_pred_masks = []
     # for a, r in zip([attn, attn2], [-2, -1]):
     for a, r in zip([attn], [-2]):
         # t_thresh00 = attn.kthvalue(pred_tok_max_num, r, keepdim=True)[0]  # [*, 1, L]
         t_thresh0 = a.topk(pred_tok_max_num, r)[0].min(r, keepdim=True)[0]  # [*, 1, L]
         t_max_value = a.max(r, keepdim=True)[0]  # [*, 1, L]
         t_thresh1 = BK.max_elem(t_thresh0, t_max_value * pred_tok_rel_thresh)
         t_thresh = BK.max_elem(t_thresh1, t_max_value - pred_tok_abs_thresh)
         one_tok_pred_mask = (a >= t_thresh).float()  # [*, slen, L]
         if r == -1:
             one_tok_pred_mask = one_tok_pred_mask.view(BK.get_shape(attn))  # back to dim=3
         all_tok_pred_masks.append(one_tok_pred_mask)
     # put them together
     # final_pred_mask = all_tok_pred_masks[0] * all_tok_pred_masks[1] * (sent_pred_mask.unsqueeze(-2))  # [*, slen, L]
     final_pred_mask = all_tok_pred_masks[0] * sent_pred_mask.unsqueeze(-2)  # [*, slen, L]
     final_pred_mask[:, :, 0] = 0.  # not for nil-type
     final_pred_mask *= input_mask.unsqueeze(-1)
     return final_pred_mask
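
A note on the thresholding trick used twice above: as the todo(note) says, torch.kthvalue only returns the k-th *smallest* value and (at the time) had no GPU kernel, so the k-th largest score is recovered as the minimum of the top-k instead. A minimal sketch in plain PyTorch (assuming BK is a thin wrapper over torch here; the helper name is illustrative, not from the source):

 import torch

 def kth_largest_threshold(scores: torch.Tensor, k: int, dim: int = -1) -> torch.Tensor:
     # the smallest of the top-k values is exactly the k-th largest value
     return scores.topk(k, dim)[0].min(dim, keepdim=True)[0]

 scores = torch.randn(2, 10)                # [batch, L]
 thresh = kth_largest_threshold(scores, 3)  # [batch, 1]
 pred_mask = (scores >= thresh).float()     # keeps the top-3 entries per row (plus ties)
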
Example #2
 def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs):
     self.refresh_batch(training)
     margin = self.margin.value
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in annotated_insts])
     gold_labels_arr, _ = self.predict_padder.pad(
         [self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
     # ===== calculate
     scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
         annotated_insts, training)
     full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                           training, margin,
                                           gold_heads_expr)
     #
     final_losses = None
     if self.norm_local or self.norm_single:
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         # already added margin previously
         losses_heads = losses_labels = None
         if self.loss_prob:
             if self.norm_local:
                 losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                 losses_labels = BK.loss_nll(select_label_score,
                                             gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=False)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=False)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_hinge:
             if self.norm_local:
                 losses_heads = BK.loss_hinge(full_arc_score,
                                              gold_heads_expr)
                 losses_labels = BK.loss_hinge(select_label_score,
                                               gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=True,
                                                    margin=margin)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=True,
                                                     margin=margin)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_mr:
             # special treatment!
             probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
             probs_labels = BK.softmax(select_label_score, dim=-1)  # [bs, m, L]
             # select
             probs_head_gold = BK.gather_one_lastdim(
                 probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
             probs_label_gold = BK.gather_one_lastdim(
                 probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
             # root and pad will be excluded later
             # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
             # todo(warn): have problem since steps will be quite small, not used!
             final_losses = mask_expr - probs_head_gold * probs_label_gold  # let loss >= 0
     elif self.norm_global:
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, training,
                                                   margin, gold_heads_expr,
                                                   gold_labels_expr)
         # for this one, use the merged full score
         full_score = full_arc_score.unsqueeze(
             -1) + full_label_score  # [BS, m, h, L]
         # +=1 to include ROOT for mst decoding
         mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                      dtype=np.int32)
         # do inference
         if self.loss_prob:
             marginals_expr = self._marginal(
                 full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
             final_losses = self._losses_global_prob(
                 full_score, gold_heads_expr, gold_labels_expr,
                 marginals_expr, mask_expr)
             if self.alg_proj:
                 # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                 #  but this might be too loose, although the unproj edges are few?
                 gold_unproj_arr, _ = self.predict_padder.pad(
                     [z.unprojs for z in annotated_insts])
                 gold_unproj_expr = BK.input_real(gold_unproj_arr)  # [BS, Len]
                 comparing_expr = Constants.REAL_PRAC_MIN * (1. - gold_unproj_expr)
                 final_losses = BK.max_elem(final_losses, comparing_expr)
         elif self.loss_hinge:
             pred_heads_arr, pred_labels_arr, _ = self._decode(
                 full_score, mask_expr, mst_lengths_arr)
             pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
             pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
             #
             final_losses = self._losses_global_hinge(
                 full_score, gold_heads_expr, gold_labels_expr,
                 pred_heads_expr, pred_labels_expr, mask_expr)
         elif self.loss_mr:
             # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
             raise NotImplementedError(
                 "Not implemented for global-loss + mr.")
     elif self.norm_hlocal:
         # firstly label losses are the same
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
         # then specially for arc loss
         children_masks_arr, _ = self.hlocal_padder.pad(
             [z.get_children_mask_arr() for z in annotated_insts])
         children_masks_expr = BK.input_real(children_masks_arr)  # [bs, h, m]
         # [bs, h]
         # todo(warn): use prod rather than sum, but still only an approximation for the top-down
         # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
         losses_arc = -BK.sum(
             BK.log_softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1)
         # including the root-head is important: index 0 is sliced off before the final
         # loss sum below, so fold the ROOT position's arc loss into position 1 to keep it
         losses_arc[:, 1] += losses_arc[:, 0]
         final_losses = losses_arc + losses_labels
     #
     # jpos loss? (the same mask as parsing)
     jpos_losses_expr = jpos_pack[1]
     if jpos_losses_expr is not None:
         final_losses += jpos_losses_expr
     # collect loss with mask, also excluding the first symbol of ROOT
     final_losses_masked = (final_losses * mask_expr)[:, 1:]
     final_loss_sum = BK.sum(final_losses_masked)
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_tok = sum(len(z) for z in annotated_insts)
     if self.conf.tconf.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     #
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     if training:
         BK.backward(final_loss, loss_factor)
     return info
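
Whichever normalization branch produces final_losses, the reduction at the end is shared: mask out padding, slice off the artificial ROOT at index 0, sum, and divide either by token count or by sentence count. A minimal stand-alone sketch of that reduction in plain PyTorch (illustrative names, not the project's API):

 import torch

 def reduce_parser_loss(per_tok_losses: torch.Tensor, mask: torch.Tensor,
                        num_sent: int, num_tok: int, div_by_tok: bool) -> torch.Tensor:
     # zero out pad positions, then drop index 0 (the artificial ROOT symbol)
     masked = (per_tok_losses * mask)[:, 1:]
     loss_sum = masked.sum()
     # per-token division makes updates length-invariant; per-sentence division
     # weights every sentence equally regardless of its length
     return loss_sum / num_tok if div_by_tok else loss_sum / num_sent

 losses = torch.rand(4, 8)  # [BS, Len] per-token losses
 mask = torch.ones(4, 8)
 mask[:, 6:] = 0.           # pretend the last two positions are padding
 loss = reduce_parser_loss(losses, mask, num_sent=4,
                           num_tok=int(mask[:, 1:].sum()), div_by_tok=True)
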
Example #3
 def prune_with_scores(arc_score,
                       label_score,
                       mask_expr,
                       pconf: PruneG1Conf,
                       arc_marginals=None):
     prune_use_topk, prune_use_marginal, prune_labeled, prune_perc, prune_topk, prune_gap, prune_mthresh, prune_mthresh_rel = \
         pconf.pruning_use_topk, pconf.pruning_use_marginal, pconf.pruning_labeled, pconf.pruning_perc, pconf.pruning_topk, \
         pconf.pruning_gap, pconf.pruning_mthresh, pconf.pruning_mthresh_rel
     full_score = arc_score + label_score
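     # arc_score is [*, len-m, len-h, 1] (hence the squeeze(-1) calls below) and
     # label_score is [*, len-m, len-h, L], so this add broadcasts over the label dim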
     final_valid_mask = BK.constants(BK.get_shape(arc_score), 0, dtype=BK.uint8).squeeze(-1)
     # (put as argument) arc_marginals = None  # [*, mlen, hlen]
     if prune_use_marginal:
         if arc_marginals is None:  # not provided, calculate from scores
             if prune_labeled:
                 # arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).max(-1)[0]
                 # use sum of label marginals instead of max
                 arc_marginals = nmarginal_unproj(full_score,
                                                  mask_expr,
                                                  None,
                                                  labeled=True).sum(-1)
             else:
                 arc_marginals = nmarginal_unproj(arc_score,
                                                  mask_expr,
                                                  None,
                                                  labeled=True).squeeze(-1)
         if prune_mthresh_rel:
             # relative value
             max_arc_marginals = arc_marginals.max(-1)[0].log().unsqueeze(-1)
             m_valid_mask = (arc_marginals.log() - max_arc_marginals) > float(np.log(prune_mthresh))
         else:
             # absolute value
             m_valid_mask = arc_marginals > prune_mthresh  # [*, len-m, len-h]
         final_valid_mask |= m_valid_mask
     if prune_use_topk:
         # prune by "in topk" and "gap-to-top less than gap" for each mod
         if prune_labeled:  # take argmax among label dim
             tmp_arc_score, _ = full_score.max(-1)
         else:
             # todo(note): may be modified in place, but that does not matter since it will be masked later anyway
             tmp_arc_score = arc_score.squeeze(-1)
         # first apply mask
         mask_value = Constants.REAL_PRAC_MIN
         mask_mul = (mask_value * (1. - mask_expr))  # [*, len]
         tmp_arc_score += mask_mul.unsqueeze(-1)
         tmp_arc_score += mask_mul.unsqueeze(-2)
         maxlen = BK.get_shape(tmp_arc_score, -1)
         tmp_arc_score += mask_value * BK.eye(maxlen)
         prune_topk = min(prune_topk, int(maxlen * prune_perc + 1), maxlen)
         if prune_topk >= maxlen:
             topk_arc_score = tmp_arc_score
         else:
             topk_arc_score, _ = BK.topk(tmp_arc_score,
                                         prune_topk,
                                         dim=-1,
                                         sorted=False)  # [*, len, k]
         min_topk_arc_score = topk_arc_score.min(-1)[0].unsqueeze(-1)  # [*, len, 1]
         max_topk_arc_score = topk_arc_score.max(-1)[0].unsqueeze(-1)  # [*, len, 1]
         arc_score_thresh = BK.max_elem(min_topk_arc_score, max_topk_arc_score - prune_gap)  # [*, len, 1]
         t_valid_mask = tmp_arc_score > arc_score_thresh  # [*, len-m, len-h]
         final_valid_mask |= t_valid_mask
     return final_valid_mask, arc_marginals
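
The top-k branch keeps a head candidate for each modifier only if it survives two criteria at once: it scores within the per-modifier top-k, and it is no more than prune_gap below the per-modifier best score. A sketch of that combined threshold in plain PyTorch (hypothetical helper; torch.max plays the role of BK.max_elem):

 import torch

 def topk_gap_mask(arc_scores: torch.Tensor, k: int, gap: float) -> torch.Tensor:
     # arc_scores: [*, len-m, len-h]; returns a boolean pruning mask of the same shape
     topk_vals = arc_scores.topk(min(k, arc_scores.size(-1)), dim=-1, sorted=False)[0]
     min_topk = topk_vals.min(-1, keepdim=True)[0]  # the k-th best score per modifier
     max_topk = topk_vals.max(-1, keepdim=True)[0]  # the best score per modifier
     thresh = torch.max(min_topk, max_topk - gap)   # elementwise max, like BK.max_elem
     # strict > mirrors the original code: the k-th-best entry itself is excluded
     return arc_scores > thresh

 scores = torch.randn(2, 5, 5)
 keep = topk_gap_mask(scores, k=3, gap=2.0)  # [2, 5, 5] head-candidate mask
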