def _predict(self, final_score, attn, attn2, input_mask):
    """Turn sentence-level and token-level scores into a 0/1 prediction mask.

    A (token, type) cell is predicted only if the type passes the
    sentence-level threshold AND the token passes that type's token-level
    threshold.

    Args:
        final_score: sentence-level score per type; the original shape notes
            give [*, L] (L = number of types) -- TODO confirm with caller.
        attn: token-level score per (token, type); the original shape notes
            give [*, slen, L]. NOTE: masked IN PLACE by `input_mask`, so the
            caller's tensor is modified as a side effect.
        attn2: currently unused (a second attention view that the original
            code had disabled); kept for interface compatibility.
        input_mask: 1/0 mask over tokens; per the original note it is
            expected to already exclude idx=0 (the artificial root).

    Returns:
        Float mask with the same shape as `attn`, with type index 0
        (nil-type) and masked-out tokens forced to 0.
    """
    conf = self.conf
    # not for invalid tokens (in-place on purpose: kept from the original)
    # todo(note): input_mask should already exclude idx=0
    attn *= input_mask.unsqueeze(-1)
    # ----- sentence level: which types are active at all
    # todo(note): kthvalue is not available at gpu and only support smallest,
    # so the k-th largest value is obtained as the min of the top-k.
    # Clamp k to the dimension size: topk raises if k exceeds it.
    sent_k = min(conf.pred_sent_max_num, BK.get_shape(final_score, -1))
    s_thresh0 = final_score.topk(sent_k, -1)[0].min(-1, keepdim=True)[0]  # [*, 1]
    s_thresh = s_thresh0.clamp(min=conf.pred_sent_abs_thresh)
    sent_pred_mask = (final_score >= s_thresh).float()  # [*, L]
    # ----- token level: for each type, which tokens are triggers
    tok_dim = -2  # reduce over the token dimension
    tok_k = min(conf.pred_tok_max_num, BK.get_shape(attn, tok_dim))
    t_thresh0 = attn.topk(tok_k, tok_dim)[0].min(tok_dim, keepdim=True)[0]  # [*, 1, L]
    t_max_value = attn.max(tok_dim, keepdim=True)[0]  # [*, 1, L]
    # the effective threshold is the strictest of: k-th largest value,
    # relative-to-max, and max minus an absolute gap
    t_thresh1 = BK.max_elem(t_thresh0, t_max_value * conf.pred_tok_rel_thresh)
    t_thresh = BK.max_elem(t_thresh1, t_max_value - conf.pred_tok_abs_thresh)
    tok_pred_mask = (attn >= t_thresh).float()  # [*, slen, L]
    # ----- combine both levels
    final_pred_mask = tok_pred_mask * sent_pred_mask.unsqueeze(-2)  # [*, slen, L]
    final_pred_mask[:, :, 0] = 0.  # not for nil-type
    final_pred_mask *= input_mask.unsqueeze(-1)
    return final_pred_mask
def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs):
    """Run forward (and optionally backward) on one batch and report the loss.

    The loss depends on the normalization mode (`norm_local`, `norm_single`,
    `norm_global`, `norm_hlocal`) and the loss type (`loss_prob`,
    `loss_hinge`, `loss_mr`) set on `self`.

    Args:
        annotated_insts: batch of gold-annotated instances; this method reads
            `heads.vals`, `labels.idxes` and `len()` from each one, plus
            `unprojs` or `get_children_mask_arr()` in some modes.
        training: forwarded to `refresh_batch` and the scorers; when true the
            loss is back-propagated.
        loss_factor: forwarded to `BK.backward`.
        **kwargs: accepted but not used here.

    Returns:
        dict with "sent" (number of instances), "tok" (sum of instance
        lengths) and "loss_sum" (float value of the summed masked loss).

    Raises:
        NotImplementedError: for global normalization combined with `loss_mr`.
    """
    self.refresh_batch(training)
    margin = self.margin.value
    # ----- gold heads and labels, padded and turned into index tensors
    heads_arr, _ = self.predict_padder.pad(
        [inst.heads.vals for inst in annotated_insts])
    labels_arr, _ = self.predict_padder.pad(
        [self.real2pred_labels(inst.labels.idxes) for inst in annotated_insts])
    gold_heads = BK.input_idx(heads_arr)  # [BS, Len]
    gold_labels = BK.input_idx(labels_arr)  # [BS, Len]
    # ----- scoring
    score_pack, mask_expr, jpos_pack = self._prepare_score(annotated_insts, training)
    arc_scores = self._score_arc_full(score_pack, mask_expr, training, margin, gold_heads)
    # ----- losses, per normalization mode
    batch_losses = None
    if self.norm_local or self.norm_single:
        # margin has already been added inside the scorers
        label_scores = self._score_label_selected(
            score_pack, mask_expr, training, margin, gold_heads, gold_labels)
        if self.loss_prob or self.loss_hinge:
            use_hinge = not self.loss_prob
            head_losses = label_losses = None
            if self.norm_local:
                local_loss = BK.loss_hinge if use_hinge else BK.loss_nll
                head_losses = local_loss(arc_scores, gold_heads)
                label_losses = local_loss(label_scores, gold_labels)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                # the margin is only handed over for the hinge variant
                hinge_kwargs = {"margin": margin} if use_hinge else {}
                head_losses = self._losses_single(
                    arc_scores, gold_heads, single_sample,
                    is_hinge=use_hinge, **hinge_kwargs)
                label_losses = self._losses_single(
                    label_scores, gold_labels, single_sample,
                    is_hinge=use_hinge, **hinge_kwargs)
            # simply adding
            batch_losses = head_losses + label_losses
        elif self.loss_mr:
            # special treatment: loss = 1 - P(gold head) * P(gold label)
            # root and pad will be excluded later
            # todo(warn): have problem since steps will be quite small, not used!
            gold_head_prob = BK.gather_one_lastdim(
                BK.softmax(arc_scores, dim=-1), gold_heads).squeeze(-1)  # [bs, m]
            gold_label_prob = BK.gather_one_lastdim(
                BK.softmax(label_scores, dim=-1), gold_labels).squeeze(-1)  # [bs, m]
            batch_losses = (mask_expr - gold_head_prob * gold_label_prob)  # let loss>=0
    elif self.norm_global:
        label_scores_full = self._score_label_full(
            score_pack, mask_expr, training, margin, gold_heads, gold_labels)
        # for this one, use the merged full score
        merged_scores = arc_scores.unsqueeze(-1) + label_scores_full  # [BS, m, h, L]
        # +=1 to include ROOT for mst decoding
        mst_lengths = np.asarray(
            [len(inst) + 1 for inst in annotated_insts], dtype=np.int32)
        if self.loss_prob:
            marginals = self._marginal(merged_scores, mask_expr, mst_lengths)  # [BS, m, h, L]
            batch_losses = self._losses_global_prob(
                merged_scores, gold_heads, gold_labels, marginals, mask_expr)
            if self.alg_proj:
                # todo(+N): deal with search-error-like problem, discard unproj neg
                # losses (score>weighted-avg), but this might be too loose,
                # although the unproj edges are few?
                unproj_arr, _ = self.predict_padder.pad(
                    [inst.unprojs for inst in annotated_insts])
                gold_unproj = BK.input_real(unproj_arr)  # [BS, Len]
                loss_floor = Constants.REAL_PRAC_MIN * (1. - gold_unproj)
                batch_losses = BK.max_elem(batch_losses, loss_floor)
        elif self.loss_hinge:
            pred_heads_arr, pred_labels_arr, _ = self._decode(
                merged_scores, mask_expr, mst_lengths)
            pred_heads = BK.input_idx(pred_heads_arr)  # [BS, Len]
            pred_labels = BK.input_idx(pred_labels_arr)  # [BS, Len]
            batch_losses = self._losses_global_hinge(
                merged_scores, gold_heads, gold_labels,
                pred_heads, pred_labels, mask_expr)
        elif self.loss_mr:
            # todo(+N): Loss = -Reward = \sum marginals, which requires gradients
            # on marginal-one-edge, or marginal-two-edges
            raise NotImplementedError("Not implemented for global-loss + mr.")
    elif self.norm_hlocal:
        # label losses are the same as in the local case
        label_scores = self._score_label_selected(
            score_pack, mask_expr, training, margin, gold_heads, gold_labels)
        label_losses = BK.loss_nll(label_scores, gold_labels)
        # arc loss: normalize over modifiers for each head
        children_arr, _ = self.hlocal_padder.pad(
            [inst.get_children_mask_arr() for inst in annotated_insts])
        children_masks = BK.input_real(children_arr)  # [bs, h, m]
        # todo(warn): use prod rather than sum, but still only an approximation
        # for the top-down
        head_logprobs = BK.log_softmax(arc_scores, -2).transpose(-1, -2)
        arc_losses = -BK.sum(head_logprobs * children_masks, dim=-1)  # [bs, h]
        # including the root-head is important
        arc_losses[:, 1] += arc_losses[:, 0]
        batch_losses = arc_losses + label_losses
    # ----- optional jpos loss (the same mask as parsing)
    jpos_losses = jpos_pack[1]
    if jpos_losses is not None:
        batch_losses += jpos_losses
    # ----- reduce: apply the mask and drop the first (ROOT) position
    loss_sum = BK.sum((batch_losses * mask_expr)[:, 1:])
    num_sent = len(annotated_insts)
    num_valid_tok = sum(len(inst) for inst in annotated_insts)
    # normalize either per token or per sentence
    divisor = num_valid_tok if self.conf.tconf.loss_div_tok else num_sent
    final_loss = loss_sum / divisor
    info = {
        "sent": num_sent,
        "tok": num_valid_tok,
        "loss_sum": float(BK.get_value(loss_sum)),
    }
    if training:
        BK.backward(final_loss, loss_factor)
    return info
def prune_with_scores(arc_score, label_score, mask_expr, pconf: PruneG1Conf, arc_marginals=None):
    """Compute a mask of arcs that survive pruning.

    An arc is kept if it passes ANY enabled criterion: the marginal
    threshold (`pruning_use_marginal`) or the top-k / gap-to-best test
    (`pruning_use_topk`).

    Args:
        arc_score: arc scores with a trailing dimension that is squeezed
            here; presumably [*, mlen, hlen, 1] -- TODO confirm with caller.
            This tensor is NOT modified.
        label_score: label scores, added to `arc_score` to form the full
            (labeled) score.
        mask_expr: 1/0 token mask; the original notes give [*, len].
        pconf: pruning configuration (`pruning_*` fields).
        arc_marginals: optional precomputed arc marginals [*, mlen, hlen];
            computed from the scores when marginal pruning is on and this
            is None.

    Returns:
        (final_valid_mask, arc_marginals): the keep-mask [*, len-m, len-h]
        and the marginals that were used (None if none were given or
        computed).
    """
    # read every option up front, as the original did
    prune_use_topk = pconf.pruning_use_topk
    prune_use_marginal = pconf.pruning_use_marginal
    prune_labeled = pconf.pruning_labeled
    prune_perc = pconf.pruning_perc
    prune_topk = pconf.pruning_topk
    prune_gap = pconf.pruning_gap
    prune_mthresh = pconf.pruning_mthresh
    prune_mthresh_rel = pconf.pruning_mthresh_rel
    full_score = arc_score + label_score
    # start with nothing kept; each criterion ORs its survivors in
    final_valid_mask = BK.constants(BK.get_shape(arc_score), 0, dtype=BK.uint8).squeeze(-1)
    if prune_use_marginal:
        if arc_marginals is None:
            # not provided, calculate from scores
            if prune_labeled:
                # use sum of label marginals instead of max
                arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1)
            else:
                arc_marginals = nmarginal_unproj(arc_score, mask_expr, None, labeled=True).squeeze(-1)
        if prune_mthresh_rel:
            # relative value: compare against the best head in log space
            max_arc_marginals = arc_marginals.max(-1)[0].log().unsqueeze(-1)
            m_valid_mask = (arc_marginals.log() - max_arc_marginals) > float(np.log(prune_mthresh))
        else:
            # absolute value
            m_valid_mask = (arc_marginals > prune_mthresh)  # [*, len-m, len-h]
        final_valid_mask |= m_valid_mask
    if prune_use_topk:
        # prune by "in topk" and "gap-to-top less than gap" for each mod
        if prune_labeled:
            # take argmax among label dim
            base_arc_score, _ = full_score.max(-1)
        else:
            # squeeze may return a view of the caller's arc_score
            base_arc_score = arc_score.squeeze(-1)
        # first apply mask: push padded rows/columns and the diagonal far down
        mask_value = Constants.REAL_PRAC_MIN
        mask_mul = (mask_value * (1. - mask_expr))  # [*, len]
        maxlen = BK.get_shape(base_arc_score, -1)
        # The first add is out-of-place so the input arc_score is never
        # written to; the result is a fresh tensor, so the remaining
        # in-place adds are safe.
        tmp_arc_score = base_arc_score + mask_mul.unsqueeze(-1)
        tmp_arc_score += mask_mul.unsqueeze(-2)
        tmp_arc_score += mask_value * BK.eye(maxlen)
        prune_topk = min(prune_topk, int(maxlen * prune_perc + 1), maxlen)
        if prune_topk >= maxlen:
            topk_arc_score = tmp_arc_score
        else:
            topk_arc_score, _ = BK.topk(tmp_arc_score, prune_topk, dim=-1, sorted=False)  # [*, len, k]
        min_topk_arc_score = topk_arc_score.min(-1)[0].unsqueeze(-1)  # [*, len, 1]
        max_topk_arc_score = topk_arc_score.max(-1)[0].unsqueeze(-1)  # [*, len, 1]
        arc_score_thresh = BK.max_elem(min_topk_arc_score, max_topk_arc_score - prune_gap)  # [*, len, 1]
        t_valid_mask = (tmp_arc_score > arc_score_thresh)  # [*, len-m, len-h]
        final_valid_mask |= t_valid_mask
    return final_valid_mask, arc_marginals