def _my_loss_prob(self, score_expr, gold_idxes_expr, entropy_lambda: float, loss_mask, neg_reweight: bool):
    probs = BK.softmax(score_expr, -1)  # [*, NLab]
    log_probs = BK.log(probs + 1e-8)
    # first the plain NLL loss
    nll_loss = -BK.gather_one_lastdim(log_probs, gold_idxes_expr).squeeze(-1)
    # next the special loss
    if entropy_lambda > 0.:
        negative_entropy = probs * log_probs  # [*, NLab]
        last_dim = BK.get_shape(score_expr, -1)
        confusion_matrix = 1. - BK.eye(last_dim)  # [NLab, NLab]
        entropy_mask = confusion_matrix[gold_idxes_expr]  # [*, NLab]
        entropy_loss = (negative_entropy * entropy_mask).sum(-1)
        final_loss = nll_loss + entropy_lambda * entropy_loss
    else:
        final_loss = nll_loss
    # reweight?
    if neg_reweight:
        golden_prob = BK.gather_one_lastdim(probs, gold_idxes_expr).squeeze(-1)
        is_full_nil = (gold_idxes_expr == 0.).float()
        not_full_nil = 1. - is_full_nil
        count_pos = (loss_mask * not_full_nil).sum()
        count_neg = (loss_mask * is_full_nil).sum()
        prob_pos = (loss_mask * not_full_nil * golden_prob).sum()
        prob_neg = (loss_mask * is_full_nil * golden_prob).sum()
        neg_weight = prob_pos / (count_pos + count_neg - prob_neg + 1e-8)
        final_weights = not_full_nil + is_full_nil * neg_weight
        # todo(note): the final mask will be applied outside
        final_loss = final_loss * final_weights
    return final_loss
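# A minimal plain-PyTorch sketch of the loss above (assuming the BK wrappers map onto
# torch ops, which the surrounding code suggests but does not confirm): NLL on the gold
# label plus an entropy term restricted to the non-gold labels. All names here are
# illustrative only.
import torch

def _sketch_entropy_regularized_nll(scores, gold_idxes, entropy_lambda=0.1):
    # scores: [*, NLab]; gold_idxes: [*] (long)
    probs = torch.softmax(scores, -1)                             # [*, NLab]
    log_probs = torch.log(probs + 1e-8)
    nll = -log_probs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)
    neg_entropy = probs * log_probs                               # [*, NLab]
    non_gold_mask = 1. - torch.eye(scores.shape[-1])[gold_idxes]  # zero at the gold column
    # adding sum_{l != gold} p_l * log p_l pushes the non-gold mass toward high entropy
    return nll + entropy_lambda * (neg_entropy * non_gold_mask).sum(-1)

# e.g.: _sketch_entropy_regularized_nll(torch.randn(4, 5), torch.tensor([0, 2, 1, 4]))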
def _score(self, repr_t, attn_t, mask_t):
    conf = self.conf
    # -----
    repr_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
    repr_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
    scores0 = self.dps_node.paired_score(repr_m, repr_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
    # mask at the outside
    slen = BK.get_shape(mask_t, -1)
    score_mask = BK.constants(BK.get_shape(scores0)[:-1], 1.)  # [bs, len_q, len_k]
    score_mask *= (1. - BK.eye(slen))  # no diag
    score_mask *= mask_t.unsqueeze(-1)  # input mask at len_q
    score_mask *= mask_t.unsqueeze(-2)  # input mask at len_k
    NEG = Constants.REAL_PRAC_MIN
    scores1 = scores0 + NEG * (1. - score_mask.unsqueeze(-1))  # [bs, len_q, len_k, 1+N]
    # add fixed idx0 scores if set
    if conf.fix_s0:
        fix_s0_mask_t = BK.input_real(self.dps_s0_mask)  # [1+N]
        scores1 = (1. - fix_s0_mask_t) * scores1 + fix_s0_mask_t * conf.fix_s0_val  # [bs, len_q, len_k, 1+N]
    # minus s0
    if conf.minus_s0:
        scores1 = scores1 - scores1.narrow(-1, 0, 1)  # minus the idx=0 scores
    return scores1, score_mask
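# Sketch of the score-mask construction above with toy sizes (plain torch; shapes assumed
# from the comments): start from all-ones, knock out the diagonal, then the padded
# positions along both the query and key axes.
import torch

_mask_t = torch.tensor([[1., 1., 1., 0.]])               # [bs=1, slen=4], last token is pad
_score_mask = 1. - torch.eye(4)                           # no self loop on the diagonal
_score_mask = _score_mask * _mask_t.unsqueeze(-1)         # mask along len_q (rows)
_score_mask = _score_mask * _mask_t.unsqueeze(-2)         # mask along len_k (columns)
# invalid cells then receive a very negative additive term, as in:
#   scores1 = scores0 + NEG * (1. - score_mask.unsqueeze(-1))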
def __call__(self, query, key, accu_attn, mask_k, mask_qk, rel_dist):
    conf = self.conf
    # == calculate the dot-product scores
    # calculate the three: [bs, len_?, head*D]; also add the static ones if needed
    query_up, key_up = self.affine_q(query), self.affine_k(key)  # [*, len?, head?*Dqk]
    query_up, key_up = self._shape_project(query_up, True), self._shape_project(key_up, True)  # [*, head?, len_?, D]
    # original scores
    scores = BK.matmul(query_up, BK.transpose(key_up, -1, -2)) / self._att_scale_qk  # [*, head?, len_q, len_k]
    # == adding rel_dist ones
    if conf.use_rel_dist:
        scores = self.dist_helper(query_up, key_up, rel_dist=rel_dist, input_scores=scores)
    # transpose
    scores = scores.transpose(-2, -3).transpose(-1, -2)  # [*, len_q, len_k, head?]
    # == unhead score
    if conf.use_unhead_score:
        scores_t0, scores_t1 = BK.split(scores, [1, self.head_count], -1)  # [*, len_q, len_k, 1|head]
        scores = scores_t0 + scores_t1  # [*, len_q, len_k, head]
    # == combining with history accumulated attns
    if conf.use_lambq and accu_attn is not None:
        # todo(note): here we only consider "query" and "head"; would it be necessary for "key"?
        lambq_vals = self.lambq_aff(query)  # [*, len_q, head]; if using, e.g., relu as the activation, this >= 0
        scores -= lambq_vals.unsqueeze(-2) * accu_attn
    # == score offset
    if conf.use_soff:
        # todo(note): here we only consider "query" and "head"; key may be handled by "unhead_score"
        score_offset_t = self.soff_aff(query)  # [*, len_q, 1+head]
        score_offset_t0, score_offset_t1 = BK.split(score_offset_t, [1, self.head_count], -1)  # [*, len_q, 1|head]
        scores -= score_offset_t0.unsqueeze(-2)
        scores -= score_offset_t1.unsqueeze(-2)  # still [*, len_q, len_k, head]
    # == apply mask & no-self-loop
    # NEG_INF = Constants.REAL_PRAC_MIN
    NEG_INF = -1000.  # this should be enough
    NEG_INF2 = -2000.  # this should be enough
    if mask_k is not None:  # [*, 1, len_k, 1]
        scores += (1. - mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF2
    if mask_qk is not None:  # [*, len_q, len_k, 1]
        scores += (1. - mask_qk).unsqueeze(-1) * NEG_INF2
    if self.no_self_loop:
        query_len = BK.get_shape(query, -2)
        assert query_len == BK.get_shape(key, -2), "Shape not matched for no_self_loop"
        scores += BK.eye(query_len).unsqueeze(-1) * NEG_INF  # [len_q, len_k, 1]
    return scores.contiguous()  # [*, len_q, len_k, head]
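# The core of the scorer above, written out in plain torch with made-up sizes: scaled
# dot-product per head, then the transpose chain that moves the head dim last.
import torch

_bs, _nhead, _lq, _lk, _d = 2, 4, 5, 5, 16
_q = torch.randn(_bs, _nhead, _lq, _d)                    # [*, head, len_q, D]
_k = torch.randn(_bs, _nhead, _lk, _d)                    # [*, head, len_k, D]
_scores = _q @ _k.transpose(-1, -2) / (_d ** 0.5)         # [*, head, len_q, len_k]
_scores = _scores.transpose(-2, -3).transpose(-1, -2)     # [*, len_q, len_k, head]
assert list(_scores.shape) == [_bs, _lq, _lk, _nhead]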
def _init_fixed_mask(self, enc_mask_arr):
    tmp_device = BK.CPU_DEVICE
    # by token mask
    mask_ct = BK.input_real(enc_mask_arr, device=tmp_device)  # [*, len]
    full_mask_ct = mask_ct.unsqueeze(-1) * mask_ct.unsqueeze(-2)  # [*, len-mod, len-head]
    # no self loop
    full_mask_ct *= (1. - BK.eye(self.max_slen, device=tmp_device))
    # no root as mod; todo(warn): assume it is 3D
    full_mask_ct[:, 0, :] = 0.
    return full_mask_ct
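# Tiny demo of the fixed-mask logic (plain torch, one 4-token sentence): the pairwise
# mask is the outer product of the token mask, minus the diagonal, with row 0 zeroed
# so that the root can never act as a modifier.
import torch

_m = torch.tensor([[1., 1., 1., 0.]])                     # [bs=1, len=4]
_full = _m.unsqueeze(-1) * _m.unsqueeze(-2)               # [bs, len-mod, len-head]
_full = _full * (1. - torch.eye(4))                       # no self loop
_full[:, 0, :] = 0.                                       # root not as mod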
def _make_final_valid(self, valid_expr, mask_expr):
    maxlen = BK.get_shape(mask_expr, -1)
    # first apply masks
    mask_expr_byte = mask_expr.byte()
    valid_expr &= mask_expr_byte.unsqueeze(-1)
    valid_expr &= mask_expr_byte.unsqueeze(-2)
    # then diag
    mask_diag = 1 - BK.eye(maxlen).byte()
    valid_expr &= mask_diag
    # root not as mod; todo(note): here no [0,0] since no need
    valid_expr[:, 0] = 0
    return valid_expr.float()
def _make_final_valid(self, valid_expr, mask_expr):
    maxlen = BK.get_shape(mask_expr, -1)
    # first apply masks
    mask_expr_byte = mask_expr.byte()
    valid_expr &= mask_expr_byte.unsqueeze(-1)
    valid_expr &= mask_expr_byte.unsqueeze(-2)
    # then diag
    mask_diag = 1 - BK.eye(maxlen).byte()
    valid_expr &= mask_diag
    # root not as mod
    valid_expr[:, 0] = 0
    # only allow root->root (for grandparent feature)
    valid_expr[:, 0, 0] = 1
    return valid_expr
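# A bool-tensor sketch contrasting the two _make_final_valid variants above (plain torch;
# bool stands in for the byte dtype used by the code): both clear padded cells, the
# diagonal, and the root-as-modifier row, but only the second variant re-opens the
# [0, 0] cell for the grandparent feature.
import torch

_valid = torch.ones(1, 4, 4, dtype=torch.bool)
_mask = torch.tensor([[True, True, True, False]])
_valid &= _mask.unsqueeze(-1) & _mask.unsqueeze(-2)       # padded tokens out
_valid &= ~torch.eye(4, dtype=torch.bool)                 # no self loop
_valid[:, 0] = False                                      # root never a modifier...
_valid[:, 0, 0] = True                                    # ...except root->root (2nd variant only)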
def _score(self, enc_expr, mask_expr):
    # -----
    def _special_score(one_score):  # specially change ablpair scores into [bs, m, h, *]
        root_score = one_score[:, :, 0].unsqueeze(2)  # [bs, rlen, 1, *]
        tmp_shape = BK.get_shape(root_score)
        tmp_shape[1] = 1  # [bs, 1, 1, *]
        padded_root_score = BK.concat([BK.zeros(tmp_shape), root_score], dim=1)  # [bs, rlen+1, 1, *]
        final_score = BK.concat([padded_root_score, one_score.transpose(1, 2)], dim=2)  # [bs, rlen+1[m], rlen+1[h], *]
        return final_score
    # -----
    if self.use_ablpair:
        input_mask_expr = (mask_expr.unsqueeze(-1) * mask_expr.unsqueeze(-2))[:, 1:]  # [bs, rlen, rlen+1]
        arc_score = self.scorer.transform_and_arc_score(enc_expr, input_mask_expr)  # [bs, rlen, rlen+1, 1]
        lab_score = self.scorer.transform_and_lab_score(enc_expr, input_mask_expr)  # [bs, rlen, rlen+1, Lab]
        # put root-scores for both directions
        arc_score = _special_score(arc_score)
        lab_score = _special_score(lab_score)
    else:
        # todo(+2): for training, we can simply select and lab-score
        arc_score = self.scorer.transform_and_arc_score(enc_expr, mask_expr)  # [bs, m, h, 1]
        lab_score = self.scorer.transform_and_lab_score(enc_expr, mask_expr)  # [bs, m, h, Lab]
    # mask out diag scores
    diag_mask = BK.eye(BK.get_shape(arc_score, 1))
    diag_mask[0, 0] = 0.
    diag_add = Constants.REAL_PRAC_MIN * (diag_mask.unsqueeze(-1).unsqueeze(0))  # [1, m, h, 1]
    arc_score += diag_add
    lab_score += diag_add
    return arc_score, lab_score
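# Shape walk-through of _special_score with toy sizes (plain torch): the h=0 (root)
# column is pulled out, padded with a zero row at m=0, and re-attached in front of the
# transposed scores, yielding a square table that includes the root position.
import torch

_bs, _rlen, _L = 1, 3, 2
_one_score = torch.randn(_bs, _rlen, _rlen + 1, _L)                 # [bs, rlen, rlen+1, *]
_root = _one_score[:, :, 0].unsqueeze(2)                            # [bs, rlen, 1, *]
_root = torch.cat([torch.zeros(_bs, 1, 1, _L), _root], dim=1)       # [bs, rlen+1, 1, *]
_final = torch.cat([_root, _one_score.transpose(1, 2)], dim=2)      # [bs, rlen+1, rlen+1, *]
assert list(_final.shape) == [_bs, _rlen + 1, _rlen + 1, _L]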
def prune_with_scores(arc_score, label_score, mask_expr, pconf: PruneG1Conf, arc_marginals=None):
    prune_use_topk, prune_use_marginal, prune_labeled, prune_perc, prune_topk, prune_gap, prune_mthresh, prune_mthresh_rel = \
        pconf.pruning_use_topk, pconf.pruning_use_marginal, pconf.pruning_labeled, pconf.pruning_perc, pconf.pruning_topk, \
        pconf.pruning_gap, pconf.pruning_mthresh, pconf.pruning_mthresh_rel
    full_score = arc_score + label_score
    final_valid_mask = BK.constants(BK.get_shape(arc_score), 0, dtype=BK.uint8).squeeze(-1)
    # (put as argument) arc_marginals = None  # [*, mlen, hlen]
    if prune_use_marginal:
        if arc_marginals is None:  # not provided, calculate from scores
            if prune_labeled:
                # arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).max(-1)[0]
                # use the sum of label marginals instead of the max
                arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1)
            else:
                arc_marginals = nmarginal_unproj(arc_score, mask_expr, None, labeled=True).squeeze(-1)
        if prune_mthresh_rel:
            # relative value
            max_arc_marginals = arc_marginals.max(-1)[0].log().unsqueeze(-1)
            m_valid_mask = (arc_marginals.log() - max_arc_marginals) > float(np.log(prune_mthresh))
        else:
            # absolute value
            m_valid_mask = (arc_marginals > prune_mthresh)  # [*, len-m, len-h]
        final_valid_mask |= m_valid_mask
    if prune_use_topk:
        # prune by "in topk" and "gap-to-top less than gap" for each mod
        if prune_labeled:  # take the max among the label dim
            tmp_arc_score, _ = full_score.max(-1)
        else:
            # todo(note): may be modified in place, but that does not matter since it will finally be masked
            tmp_arc_score = arc_score.squeeze(-1)
        # first apply the mask
        mask_value = Constants.REAL_PRAC_MIN
        mask_mul = (mask_value * (1. - mask_expr))  # [*, len]
        tmp_arc_score += mask_mul.unsqueeze(-1)
        tmp_arc_score += mask_mul.unsqueeze(-2)
        maxlen = BK.get_shape(tmp_arc_score, -1)
        tmp_arc_score += mask_value * BK.eye(maxlen)
        prune_topk = min(prune_topk, int(maxlen * prune_perc + 1), maxlen)
        if prune_topk >= maxlen:
            topk_arc_score = tmp_arc_score
        else:
            topk_arc_score, _ = BK.topk(tmp_arc_score, prune_topk, dim=-1, sorted=False)  # [*, len, k]
        min_topk_arc_score = topk_arc_score.min(-1)[0].unsqueeze(-1)  # [*, len, 1]
        max_topk_arc_score = topk_arc_score.max(-1)[0].unsqueeze(-1)  # [*, len, 1]
        arc_score_thresh = BK.max_elem(min_topk_arc_score, max_topk_arc_score - prune_gap)  # [*, len, 1]
        t_valid_mask = (tmp_arc_score > arc_score_thresh)  # [*, len-m, len-h]
        final_valid_mask |= t_valid_mask
    return final_valid_mask, arc_marginals
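# Minimal sketch of the per-modifier "topk + gap" rule above (plain torch): a head
# survives if its score beats both the k-th best score and the best score minus the gap.
import torch

_scores = torch.randn(1, 5, 5)                            # [bs, len-m, len-h]
_k, _gap = 3, 2.0
_topk_vals, _ = _scores.topk(_k, dim=-1)                  # [bs, len-m, k]
_thresh = torch.maximum(_topk_vals.min(-1, keepdim=True)[0],
                        _topk_vals.max(-1, keepdim=True)[0] - _gap)
_keep = _scores > _thresh                                 # same strict ">" as the code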
def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # todo(warn): it seems that float32 is not enough for the inverse when the model gets better
        #  (the scores get more diverse)
        diag1_m = BK.eye(maxlen).double()  # [m, h]
        scores_expr_d = scores_expr.double()
        mask_expr_d = mask_expr.double()
        invalid_pad_expr_d = 1. - mask_expr_d
        full_invalid_d = (diag1_m + invalid_pad_expr_d.unsqueeze(-1) + invalid_pad_expr_d.unsqueeze(-2)).clamp(0., 1.)  # [*, m, h]
        full_invalid_d[:, 0] = 1.
        #
        # first make it unlabeled by logsumexp
        scores_unlabeled = BK.logsumexp(scores_expr_d, dim=-1)  # [BS, m, h]
        # force small values at the diag entries and the padded ones
        scores_unlabeled_diag_neg = scores_unlabeled + Constants.REAL_PRAC_MIN * full_invalid_d
        # subtract the max element to keep the larger values numerically stable  # [BS, m, h]
        # todo(+N): since one and only one head is selected, minus-by-max should give the same marginals;
        #  it cancels out, since this is like left-multiplying A by a diagonal Q
        scores_unlabeled_max = (scores_unlabeled_diag_neg.max(-1)[0] * mask_expr_d).unsqueeze(-1)  # [BS, m, 1]
        scores_exp_unlabeled = BK.exp(scores_unlabeled_diag_neg - scores_unlabeled_max)
        # # todo(0): co-work with minus-max, force too-small values to 0 (serves as pruning, the gap is ?*ln(10))
        # scores_exp_unlabeled *= (1 - (scores_exp_unlabeled < 1e-10)).double()
        # force 0 at the diag entries (again)
        scores_exp_unlabeled *= (1. - diag1_m)
        # assign non-zero values (which values does not matter) to (h=0, invalid) to keep the matrix invertible
        scores_exp_unlabeled[:, :, 0] += (1. - mask_expr_d)  # the value does not matter?
        # construct the L (or K) matrix: L = D - A
        A = scores_exp_unlabeled
        A_sum = A.sum(dim=-1, keepdim=True)  # [BS, m, 1]
        # =====
        # todo(0): can this avoid a singular matrix? feels like adding aug-values to the h==0 (COL0) to-root scores
        # todo(+N): there are cases where the original matrix is not invertible (no solutions for trees)!!
        A_sum += 1e-6
        # A_sum += A_sum * 1e-4 + 1e-6
        D = A_sum.expand(scores_shape[:-1]) * diag1_m  # [BS, m, h]
        L = D - A  # [BS, m, h]
        # get the minor00 matrix
        LM00 = L[:, 1:, 1:]  # [BS, m-1, h-1]
        # # Debug1
        # try:
        #     # for idx in range(scores_shape[0]):
        #     #     one_det = float(LM00[idx].det())
        #     #     assert not math.isnan(one_det)
        #     # LM00_CPU = LM00.cpu()
        #     # LM00_CPU_inv = LM00_CPU.inverse()
        #     scores_exp_unlabeled_CPU = scores_exp_unlabeled.cpu()
        #     LM00_CPU = LM00.cpu()
        #     assert BK.has_nan(LM00_CPU) == 0
        # except:
        #     assert False, "Problem here"
        # det and inverse; using LU decomposition to kill two birds with one stone
        diag1_m00 = BK.eye(maxlen - 1).double()
        # deprecated operation:
        # LM00_inv, LM00_lu = diag1_m00.gesv(LM00)  # [BS, m-1, h-1]
        # # todo(warn): lacking P here, but the partition should always be non-negative!
        # LM00_det = BK.abs((LM00_lu * diag1_m00).sum(-1).prod(-1))  # [BS, ]
        # d(logZ)/d(LM00) = (LM00^-1)^T
        # # directly inverse (needs pytorch >= 1.0)
        # LM00_inv = LM00.inverse()
        LM00_inv = BK.get_inverse(LM00, diag1_m00)
        LM00_grad = LM00_inv.transpose(-1, -2)  # [BS, m-1, h-1]
        # marginal(m,h) = d(logZ)/d(score(m,h)) = d(logZ)/d(LM00) * d(LM00)/d(score(m,h)) = INV_mm - INV_mh
        # padding and minus
        LM00_grad_pad = BK.pad(LM00_grad, [1, 0, 1, 0], 'constant', 0.)  # [BS, m, h]
        LM00_grad_pad_sum = (LM00_grad_pad * diag1_m).sum(dim=-1, keepdim=True)  # [BS, m, 1]
        marginals_unlabeled = A * (LM00_grad_pad_sum - LM00_grad_pad)  # [BS, m, h]
        # make sure each row sums to 1.
        marginals_unlabeled[:, 0, 0] = 1.
        # finally, get the labeled results
        marginals_labeled = marginals_unlabeled.unsqueeze(-1) * BK.exp(scores_expr_d - scores_unlabeled.unsqueeze(-1))
        # # Debug2
        # try:
        #     scores_exp_unlabeled_CPU = scores_exp_unlabeled.cpu()
        #     LM00_CPU = LM00.cpu()
        #     marginals_unlabeled_CPU = marginals_unlabeled.cpu()
        #     assert BK.has_nan(marginals_unlabeled_CPU) == 0
        #     # global last_lm00, last_marginals
        #     last_lm00 = LM00_CPU
        #     last_marginals = marginals_unlabeled_CPU
        # except:
        #     assert False, "Problem here"
        # back to plain float32
        masked_marginals_labeled = marginals_labeled * (1. - full_invalid_d).unsqueeze(-1)
        ret = masked_marginals_labeled.float()
        return _ensure_margins_norm(ret)
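# Toy check of the Matrix-Tree computation above (plain torch, one sentence, unlabeled):
# build A of exp-scores, form the Laplacian minor, and recover marginals via the inverse;
# each non-root row of the marginals should sum to 1.
import torch

_n = 4                                                    # 1 root + 3 real tokens
_A = torch.rand(_n, _n).double()                          # A[m, h] = exp score of head h for mod m
_A = _A * (1. - torch.eye(_n).double())                   # no self loops
_A[0, :] = 0.                                             # the root takes no head
_L = torch.diag(_A.sum(-1)) - _A                          # L = D - A
_inv = _L[1:, 1:].inverse()                               # minor00, as in the code
_g = _inv.t()                                             # d(logZ)/d(LM00)
_g_pad = torch.nn.functional.pad(_g, (1, 0, 1, 0))        # pad back to [m, h]
_marg = _A * (torch.diag(_g_pad).unsqueeze(-1) - _g_pad)  # marginal(m,h) = A*(INV_mm - INV_mh)
assert torch.allclose(_marg[1:].sum(-1), torch.ones(_n - 1).double())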