示例#1
0
 def _my_loss_prob(self, score_expr, gold_idxes_expr, entropy_lambda: float,
                   loss_mask, neg_reweight: bool):
     """Probability-based per-item loss: NLL, plus optional entropy and reweighting.

     Args:
         score_expr: unnormalized label scores; softmax is taken over the last dim.
         gold_idxes_expr: gold label indices used to gather from the last dim.
         entropy_lambda: if > 0, weight of the extra term summing ``p * log p``
             over all labels except the gold one.
         loss_mask: item mask, only used for the statistics of the reweighting.
         neg_reweight: if true, scale the loss of items whose gold index is 0.

     Returns:
         The un-reduced loss; the final mask is NOT applied here (the caller
         is expected to apply it).
     """
     label_probs = BK.softmax(score_expr, -1)  # [*, NLab]
     label_logprobs = BK.log(label_probs + 1e-8)  # epsilon keeps the log finite
     # plain negative log-likelihood of the gold label
     gold_logprob = BK.gather_one_lastdim(label_logprobs,
                                          gold_idxes_expr).squeeze(-1)
     final_loss = -gold_logprob
     # optional extra term: negative entropy restricted to the non-gold labels
     if entropy_lambda > 0.:
         num_labels = BK.get_shape(score_expr, -1)
         off_diag = 1. - BK.eye(num_labels)  # [NLab, NLab], zero on the diagonal
         non_gold_mask = off_diag[gold_idxes_expr]  # [*, NLab], zero at gold
         neg_entropy_terms = label_probs * label_logprobs  # [*, NLab]
         non_gold_neg_entropy = (neg_entropy_terms * non_gold_mask).sum(-1)
         final_loss = final_loss + entropy_lambda * non_gold_neg_entropy
     if not neg_reweight:
         return final_loss
     # reweighting: items with gold index 0 get a shared data-dependent weight
     gold_prob = BK.gather_one_lastdim(label_probs,
                                       gold_idxes_expr).squeeze(-1)
     zero_gold = (gold_idxes_expr == 0.).float()
     nonzero_gold = 1. - zero_gold
     masked_nonzero = loss_mask * nonzero_gold
     masked_zero = loss_mask * zero_gold
     count_pos = masked_nonzero.sum()
     count_neg = masked_zero.sum()
     prob_pos = (masked_nonzero * gold_prob).sum()
     prob_neg = (masked_zero * gold_prob).sum()
     zero_weight = prob_pos / (count_pos + count_neg - prob_neg + 1e-8)
     item_weights = nonzero_gold + zero_gold * zero_weight
     # note: the final mask is applied outside of this function
     return final_loss * item_weights
示例#2
0
文件: dpar.py 项目: ValentinaPy/zmsp
 def _score(self, repr_t, attn_t, mask_t):
     """Compute masked pairwise scores and the pair mask that was applied.

     Args:
         repr_t: input representations, fed to the two pre-affine layers.
         attn_t: passed through to ``dps_node.paired_score`` as ``inputp``.
         mask_t: input mask whose last dim is the sequence length.

     Returns:
         ``(scores, pair_mask)``; per the original shape notes the scores are
         [bs, len_q, len_k, 1+N] and the mask is [bs, len_q, len_k].
     """
     cfg = self.conf
     # separate projections for the two sides of each pair
     side_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
     side_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
     raw_scores = self.dps_node.paired_score(
         side_m, side_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
     # build the pair mask: start from ones, then knock out diag and padding
     seq_len = BK.get_shape(mask_t, -1)
     pair_mask = BK.constants(BK.get_shape(raw_scores)[:-1], 1.)
     mask_factors = (
         1. - BK.eye(seq_len),  # no diagonal (self pairs)
         mask_t.unsqueeze(-1),  # input mask broadcast over the last axis
         mask_t.unsqueeze(-2),  # input mask broadcast over the middle axis
     )
     for factor in mask_factors:
         pair_mask *= factor
     # masked-out pairs receive a very negative additive penalty
     penalty = Constants.REAL_PRAC_MIN
     out_scores = raw_scores + penalty * (1. - pair_mask.unsqueeze(-1))
     # optionally overwrite selected entries of the last dim with a fixed value
     if cfg.fix_s0:
         s0_select = BK.input_real(self.dps_s0_mask)  # [1+N]
         out_scores = (
             1. - s0_select) * out_scores + s0_select * cfg.fix_s0_val
     # optionally make all scores relative to the idx=0 score
     if cfg.minus_s0:
         out_scores = out_scores - out_scores.narrow(-1, 0, 1)
     return out_scores, pair_mask
示例#3
0
 def __call__(self, query, key, accu_attn, mask_k, mask_qk, rel_dist):
     """Compute (unnormalized) multi-head attention scores.

     Args:
         query, key: inputs to the query/key affine layers.
         accu_attn: accumulated attention from history, or None.
         mask_k: optional key-side mask (1. means valid).
         mask_qk: optional query-key pair mask (1. means valid).
         rel_dist: forwarded to ``dist_helper`` when ``conf.use_rel_dist``.

     Returns:
         Contiguous scores, [*, len_q, len_k, head] per the original notes.
     """
     cfg = self.conf
     # == dot-product scores
     q_lin = self.affine_q(query)  # [*, len_q, head?*Dqk]
     k_lin = self.affine_k(key)  # [*, len_k, head?*Dqk]
     q_heads = self._shape_project(q_lin, True)  # [*, head?, len_q, D]
     k_heads = self._shape_project(k_lin, True)  # [*, head?, len_k, D]
     k_heads_t = BK.transpose(k_heads, -1, -2)
     att = BK.matmul(q_heads, k_heads_t) / self._att_scale_qk  # [*, head?, len_q, len_k]
     # == relative-distance contribution
     if cfg.use_rel_dist:
         att = self.dist_helper(q_heads,
                                k_heads,
                                rel_dist=rel_dist,
                                input_scores=att)
     # move the head axis to the end: [*, len_q, len_k, head?]
     att = att.transpose(-2, -3).transpose(-1, -2)
     # == "unhead" score: one extra leading slot is added to every head
     if cfg.use_unhead_score:
         shared_part, per_head_part = BK.split(att, [1, self.head_count], -1)
         att = shared_part + per_head_part  # [*, len_q, len_k, head]
     # == subtract a query-dependent multiple of the accumulated attention
     if cfg.use_lambq and accu_attn is not None:
         # note: only "query" and "head" are considered here, not "key"
         lambq = self.lambq_aff(query)  # [*, len_q, head]
         att -= lambq.unsqueeze(-2) * accu_attn
     # == subtract query-dependent score offsets
     if cfg.use_soff:
         # note: only "query" and "head" are considered here
         offsets = self.soff_aff(query)  # [*, len_q, 1+head]
         shared_off, per_head_off = BK.split(
             offsets, [1, self.head_count], -1)  # [*, len_q, 1|head]
         att -= shared_off.unsqueeze(-2)
         att -= per_head_off.unsqueeze(-2)  # still [*, len_q, len_k, head]
     # == masking: fixed large negatives instead of Constants.REAL_PRAC_MIN
     self_loop_penalty = -1000.
     mask_penalty = -2000.
     if mask_k is not None:
         att += (1. - mask_k).unsqueeze(-2).unsqueeze(-1) * mask_penalty
     if mask_qk is not None:
         att += (1. - mask_qk).unsqueeze(-1) * mask_penalty
     if self.no_self_loop:
         q_len = BK.get_shape(query, -2)
         assert q_len == BK.get_shape(
             key, -2), "Shape not matched for no_self_loop"
         att += BK.eye(q_len).unsqueeze(-1) * self_loop_penalty  # [len_q, len_k, 1]
     return att.contiguous()
示例#4
0
 def _init_fixed_mask(self, enc_mask_arr):
     """Build the fixed pairwise validity mask on the CPU device.

     A pair is valid when both tokens are unmasked, it is not on the diagonal,
     and its index along dim 1 is not 0. Assumes the result is 3D.
     """
     cpu = BK.CPU_DEVICE
     token_mask = BK.input_real(enc_mask_arr, device=cpu)  # [*, len]
     # outer product of the token mask with itself: [*, len-mod, len-head]
     pair_mask = token_mask.unsqueeze(-1) * token_mask.unsqueeze(-2)
     # forbid self loops
     off_diag = 1. - BK.eye(self.max_slen, device=cpu)
     pair_mask *= off_diag
     # index 0 (root) is never a modifier; warning: assumes a 3D tensor
     pair_mask[:, 0, :] = 0.
     return pair_mask
示例#5
0
文件: s2p.py 项目: ValentinaPy/zmsp
 def _make_final_valid(self, valid_expr, mask_expr):
     """Restrict ``valid_expr`` in place and return it converted to float.

     Entries are cleared where either token is masked out, on the diagonal,
     and for index 0 along dim 1. Note that ``valid_expr`` itself is modified.
     """
     seq_len = BK.get_shape(mask_expr, -1)
     token_keep = mask_expr.byte()
     # apply the token mask along both pair axes
     for axis in (-1, -2):
         valid_expr &= token_keep.unsqueeze(axis)
     # drop the diagonal
     off_diag = 1 - BK.eye(seq_len).byte()
     valid_expr &= off_diag
     # root is never a modifier; [0, 0] is not re-enabled here (not needed)
     valid_expr[:, 0] = 0
     return valid_expr.float()
示例#6
0
文件: g2p.py 项目: ValentinaPy/zmsp
 def _make_final_valid(self, valid_expr, mask_expr):
     """Restrict ``valid_expr`` in place and return the same (mutated) tensor.

     Entries are cleared where either token is masked out, on the diagonal,
     and for index 0 along dim 1, except that [0, 0] is switched back on.
     """
     seq_len = BK.get_shape(mask_expr, -1)
     token_keep = mask_expr.byte()
     # apply the token mask along both pair axes
     for axis in (-1, -2):
         valid_expr &= token_keep.unsqueeze(axis)
     # drop the diagonal
     off_diag = 1 - BK.eye(seq_len).byte()
     valid_expr &= off_diag
     # root is never a modifier ...
     valid_expr[:, 0] = 0
     # ... except root->root, which is allowed (for the grandparent feature)
     valid_expr[:, 0, 0] = 1
     return valid_expr
示例#7
0
文件: dec.py 项目: ValentinaPy/zmsp
    def _score(self, enc_expr, mask_expr):
        """Compute arc and label scores with a large negative added on the diagonal.

        When ``self.use_ablpair`` is set, the scorer is run on a mask with the
        first row dropped and the result is rearranged to the [bs, m, h, *]
        layout. The diagonal penalty skips position [0, 0].

        Returns:
            ``(arc_score, lab_score)``.
        """

        def _to_mod_head_layout(pair_score):
            # rearranges ablpair scores into [bs, m, h, *]
            root_part = pair_score[:, :, 0].unsqueeze(2)  # [bs, rlen, 1, *]
            pad_shape = BK.get_shape(root_part)
            pad_shape[1] = 1  # [bs, 1, 1, *]
            first_col = BK.concat([BK.zeros(pad_shape), root_part],
                                  dim=1)  # [bs, rlen+1, 1, *]
            # [bs, rlen+1[m], rlen+1[h], *]
            return BK.concat([first_col, pair_score.transpose(1, 2)], dim=2)

        if self.use_ablpair:
            pair_mask = mask_expr.unsqueeze(-1) * mask_expr.unsqueeze(-2)
            pair_mask = pair_mask[:, 1:]  # [bs, rlen, rlen+1]
            arc_raw = self.scorer.transform_and_arc_score(
                enc_expr, pair_mask)  # [bs, rlen, rlen+1, 1]
            lab_raw = self.scorer.transform_and_lab_score(
                enc_expr, pair_mask)  # [bs, rlen, rlen+1, Lab]
            # insert the root scores for both directions
            arc_score = _to_mod_head_layout(arc_raw)
            lab_score = _to_mod_head_layout(lab_raw)
        else:
            # todo(+2): for training, we could simply select the lab-score
            arc_score = self.scorer.transform_and_arc_score(
                enc_expr, mask_expr)  # [bs, m, h, 1]
            lab_score = self.scorer.transform_and_lab_score(
                enc_expr, mask_expr)  # [bs, m, h, Lab]
        # penalize diagonal entries, leaving [0, 0] untouched
        num_tok = BK.get_shape(arc_score, 1)
        diag_flags = BK.eye(num_tok)
        diag_flags[0, 0] = 0.
        diag_penalty = Constants.REAL_PRAC_MIN * (
            diag_flags.unsqueeze(-1).unsqueeze(0))  # [1, m, h, 1]
        arc_score += diag_penalty
        lab_score += diag_penalty
        return arc_score, lab_score
示例#8
0
 def prune_with_scores(arc_score,
                       label_score,
                       mask_expr,
                       pconf: PruneG1Conf,
                       arc_marginals=None):
     """Select candidate arcs to keep, by marginal threshold and/or top-k score.

     The returned mask is the union (logical OR) of the enabled criteria.

     Args:
         arc_score: arc scores with a trailing singleton dim.
         label_score: label scores, added to ``arc_score`` for labeled pruning.
         mask_expr: token mask, [*, len]; 1. means valid.
         pconf: pruning configuration (the ``pruning_*`` fields are read).
         arc_marginals: optional precomputed arc marginals, [*, mlen, hlen];
             computed from the scores when marginal pruning is on and this is None.

     Returns:
         ``(valid_mask, arc_marginals)``; ``arc_marginals`` is passed back
         unchanged (possibly None) when it was not computed here.
     """
     # read all settings up front
     use_topk = pconf.pruning_use_topk
     use_marginal = pconf.pruning_use_marginal
     labeled = pconf.pruning_labeled
     perc = pconf.pruning_perc
     topk = pconf.pruning_topk
     gap = pconf.pruning_gap
     mthresh = pconf.pruning_mthresh
     mthresh_rel = pconf.pruning_mthresh_rel
     full_score = arc_score + label_score
     keep_mask = BK.constants(BK.get_shape(arc_score), 0,
                              dtype=BK.uint8).squeeze(-1)
     # ----- criterion 1: marginal probability threshold
     if use_marginal:
         if arc_marginals is None:  # not provided, calculate from scores
             if labeled:
                 # sum of the label marginals (rather than their max)
                 labeled_marginals = nmarginal_unproj(full_score,
                                                      mask_expr,
                                                      None,
                                                      labeled=True)
                 arc_marginals = labeled_marginals.sum(-1)
             else:
                 unlabeled_marginals = nmarginal_unproj(arc_score,
                                                        mask_expr,
                                                        None,
                                                        labeled=True)
                 arc_marginals = unlabeled_marginals.squeeze(-1)
         if mthresh_rel:
             # relative: compare (in log space) against the best entry per row
             log_best = arc_marginals.max(-1)[0].log().unsqueeze(-1)
             log_ratio = arc_marginals.log() - log_best
             marginal_keep = log_ratio > float(np.log(mthresh))
         else:
             # absolute threshold
             marginal_keep = arc_marginals > mthresh  # [*, len-m, len-h]
         keep_mask |= marginal_keep
     # ----- criterion 2: in top-k and within `gap` of the best, per modifier
     if use_topk:
         if labeled:  # best label per arc
             cand_score, _ = full_score.max(-1)
         else:
             # note: the in-place adds below may write through to `arc_score`;
             # the original note says this does not matter since those
             # entries get masked later anyway
             cand_score = arc_score.squeeze(-1)
         # push padded positions and the diagonal to a very negative value
         neg_value = Constants.REAL_PRAC_MIN
         pad_penalty = neg_value * (1. - mask_expr)  # [*, len]
         cand_score += pad_penalty.unsqueeze(-1)
         cand_score += pad_penalty.unsqueeze(-2)
         maxlen = BK.get_shape(cand_score, -1)
         cand_score += neg_value * BK.eye(maxlen)
         topk = min(topk, int(maxlen * perc + 1), maxlen)
         if topk >= maxlen:
             top_scores = cand_score
         else:
             top_scores, _ = BK.topk(cand_score, topk, dim=-1,
                                     sorted=False)  # [*, len, k]
         row_min = top_scores.min(-1)[0].unsqueeze(-1)  # [*, len, 1]
         row_max = top_scores.max(-1)[0].unsqueeze(-1)  # [*, len, 1]
         # effective threshold: the tighter of "k-th best" and "best - gap"
         thresh = BK.max_elem(row_min, row_max - gap)  # [*, len, 1]
         topk_keep = cand_score > thresh  # [*, len-m, len-h]
         keep_mask |= topk_keep
     return keep_mask, arc_marginals
示例#9
0
文件: nmst.py 项目: ValentinaPy/zmsp
def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True):
    """Compute labeled arc marginals from arc scores, without gradients.

    The unlabeled marginals are obtained from the inverse of a minor of the
    matrix ``L = D - A`` built from exponentiated scores (this looks like the
    usual Matrix-Tree-theorem construction for non-projective trees -- confirm
    against the callers' expectations), and are then distributed over labels
    with a softmax on the label dim.

    Args:
        scores_expr: labeled scores, indexed as [BS, m, h, L] in this function.
        mask_expr: token mask, indexed as [BS, len]; 1. means valid.
        lengths_arr: not used in this function.
        labeled: must be truthy (checked with an assert).

    Returns:
        float32 marginals with invalid entries (diagonal, padding, and
        modifier index 0) zeroed, passed through ``_ensure_margins_norm``.
    """
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # todo(warn): float32 seems not to be enough for the inverse when the model gets better
        #  (scores get more diverse), hence everything below is done in double precision
        diag1_m = BK.eye(maxlen).double()  # [m, h]
        scores_expr_d = scores_expr.double()
        mask_expr_d = mask_expr.double()
        invalid_pad_expr_d = 1. - mask_expr_d
        # [*, m, h]: 1. where the entry is on the diagonal or touches a padded token
        full_invalid_d = (diag1_m + invalid_pad_expr_d.unsqueeze(-1) +
                          invalid_pad_expr_d.unsqueeze(-2)).clamp(0., 1.)
        # index 0 along dim 1 (m) is always invalid
        full_invalid_d[:, 0] = 1.
        #
        # first make it unlabeled by log-sum-exp over the label dim
        scores_unlabeled = BK.logsumexp(scores_expr_d, dim=-1)  # [BS, m, h]
        # force very small values at diag entries and padded ones
        scores_unlabeled_diag_neg = scores_unlabeled + Constants.REAL_PRAC_MIN * full_invalid_d
        # subtract the per-row max before exp for numerical stability; [BS, m, 1]
        # todo(+N): since one and only one head is selected per row, subtracting the max should cancel out
        #  (it is like left-multiplying A by a diagonal matrix Q)
        scores_unlabeled_max = (scores_unlabeled_diag_neg.max(-1)[0] *
                                mask_expr_d).unsqueeze(-1)  # [BS, m, 1]
        scores_exp_unlabeled = BK.exp(scores_unlabeled_diag_neg -
                                      scores_unlabeled_max)
        # # todo(0): co-work with minus-max, force too small values to be 0 (serve as pruning, the gap is ?*ln(10)).
        # scores_exp_unlabeled *= (1 - (scores_exp_unlabeled<1e-10)).double()
        # force 0 at diag entries (again)
        scores_exp_unlabeled *= (1. - diag1_m)
        # add a non-zero value at column 0 for padded rows, to make the matrix invertible
        scores_exp_unlabeled[:, :, 0] += (1. - mask_expr_d
                                          )  # the exact value presumably does not matter
        # construct the L (or K) matrix: L = D - A
        A = scores_exp_unlabeled
        A_sum = A.sum(dim=-1, keepdim=True)  # [BS, m, 1]
        # todo(0): can this avoid a singular matrix? it is like adding aug-values to the h==0 (COL0) to-root scores.
        # todo(+N): there are cases where the original matrix is not invertible (no solutions for trees)!!
        A_sum += 1e-6
        # (alternative that was considered) A_sum += A_sum * 1e-4 + 1e-6
        #
        # D: row sums of A placed on the diagonal
        D = A_sum.expand(scores_shape[:-1]) * diag1_m  # [BS, m, h]
        L = D - A  # [BS, m, h]
        # get the minor00 matrix (drop row 0 and column 0)
        LM00 = L[:, 1:, 1:]  # [BS, m-1, h-1]
        # inverse of the minor; an earlier version used the (now deprecated) `gesv` LU solve:
        #   LM00_inv, LM00_lu = diag1_m00.gesv(LM00)                        # [BS, m-1, h-1]
        #   LM00_det = BK.abs((LM00_lu*diag1_m00).sum(-1).prod(-1))         # [BS, ], lacking P
        # and `LM00.inverse()` is another option (needs pytorch >= 1.0)
        diag1_m00 = BK.eye(maxlen - 1).double()
        LM00_inv = BK.get_inverse(LM00, diag1_m00)
        # d(logZ)/d(LM00) = (LM00^-1)^T
        LM00_grad = LM00_inv.transpose(-1, -2)  # [BS, m-1, h-1]
        # marginal(m,h) = d(logZ)/d(score(m,h)) = d(logZ)/d(LM00) * d(LM00)/d(score(m,h)) = INV_mm - INV_mh
        # pad one zero row/column at the front to get back to full size, then subtract
        LM00_grad_pad = BK.pad(LM00_grad, [1, 0, 1, 0], 'constant',
                               0.)  # [BS, m, h]
        # diagonal entries (INV_mm) per row
        LM00_grad_pad_sum = (LM00_grad_pad * diag1_m).sum(
            dim=-1, keepdim=True)  # [BS, m, 1]
        marginals_unlabeled = A * (LM00_grad_pad_sum - LM00_grad_pad
                                   )  # [BS, m, h]
        # set [0, 0] so that row 0 also sums to 1.
        marginals_unlabeled[:, 0, 0] = 1.
        # finally, get labeled results: multiply by the softmax over labels
        marginals_labeled = marginals_unlabeled.unsqueeze(-1) * BK.exp(
            scores_expr_d - scores_unlabeled.unsqueeze(-1))
        # zero out the invalid entries, then go back to plain float32
        masked_marginals_labeled = marginals_labeled * (
            1. - full_invalid_d).unsqueeze(-1)
        ret = masked_marginals_labeled.float()
        return _ensure_margins_norm(ret)