Example #1
0
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    """Return labeled arc marginals using the projective marginal routine.

    The label axis of ``scores_expr`` is first collapsed with log-sum-exp,
    the resulting unlabeled scores are handed to ``marginal_proj`` as a
    plain array, and the unlabeled marginals that come back are spread over
    the labels again in proportion to ``exp(score - logsumexp(score))``.

    Args:
        scores_expr: labeled arc scores; per the shape notes in this file,
            ``[BS, m, h, L]``.
        mask_expr: accepted for signature parity; not read by this function.
        lengths_arr: forwarded untouched to ``marginal_proj``.
        labeled: must be truthy; enforced with an ``assert``.

    Returns:
        The result of ``_ensure_margins_norm`` applied to the labeled
        marginals (``[BS, m, h, L]`` per the shape notes).
    """
    assert labeled
    with BK.no_grad_env():
        # Collapse the label axis: log-sum-exp over L -> [BS, m, h].
        arc_logits = BK.logsumexp(scores_expr, dim=-1)
        # The projective routine works on array values, not backend tensors.
        arc_marginals_arr = marginal_proj(
            BK.get_value(arc_logits), lengths_arr, False)
        # Back to a backend tensor, with a trailing axis to broadcast over L.
        arc_marginals = BK.input_real(arc_marginals_arr).unsqueeze(-1)
        # Share of each label within its arc: exp(score - logsumexp(score)).
        label_share = BK.exp(scores_expr - arc_logits.unsqueeze(-1))
        # [BS, m, h, L]
        return _ensure_margins_norm(arc_marginals * label_share)
Example #2
0
def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True):
    """Return labeled arc marginals computed through a matrix inverse.

    The label axis is collapsed with log-sum-exp, the unlabeled scores are
    exponentiated into a weight matrix ``A``, and ``L = D - A`` is built
    with the (slightly augmented) row sums of ``A`` on the diagonal. The
    minor of ``L`` without row/column 0 is inverted and the unlabeled
    marginals are read off as ``A * (INV_mm - INV_mh)``. This looks like the
    Matrix-Tree Theorem construction for non-projective trees -- confirm
    against the project's references. Labels are then re-attached in
    proportion to ``exp(score - logsumexp(score))``.

    All intermediate math is done in double precision; the result is cast
    back with ``.float()`` before being returned.

    Args:
        scores_expr: labeled arc scores; per the shape notes below,
            ``[BS, m, h, L]`` (batch, modifier, head, label).
        mask_expr: mask combined with ``[BS, m]`` values below; presumably
            1. for real tokens and 0. for padding -- verify against caller.
        lengths_arr: not read by this function.
        labeled: must be truthy; enforced with an ``assert``.

    Returns:
        The result of ``_ensure_margins_norm`` applied to the masked labeled
        marginals (``[BS, m, h, L]`` per the shape notes).
    """
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # todo(warn): float32 seems insufficient for the inverse once the model
        #  gets better (scores become more diverse), hence the double casts below.
        diag1_m = BK.eye(maxlen).double()  # [m, h] identity
        scores_expr_d = scores_expr.double()
        mask_expr_d = mask_expr.double()
        invalid_pad_expr_d = 1. - mask_expr_d
        # [*, m, h]: 1. where an arc is disallowed -- on the diagonal (m == h),
        # or when either the modifier or the head position is masked out.
        full_invalid_d = (diag1_m + invalid_pad_expr_d.unsqueeze(-1) +
                          invalid_pad_expr_d.unsqueeze(-2)).clamp(0., 1.)
        # Row m == 0 is marked invalid for every head
        # (presumably position 0 is the artificial root -- TODO confirm).
        full_invalid_d[:, 0] = 1.
        #
        # first make it unlabeled by sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr_d, dim=-1)  # [BS, m, h]
        # force small values at diag entries and padded ones
        # (assumes Constants.REAL_PRAC_MIN is a large negative number -- TODO confirm)
        scores_unlabeled_diag_neg = scores_unlabeled + Constants.REAL_PRAC_MIN * full_invalid_d
        # Subtract the per-row maximum before exp() for numerical stability.
        # The max is multiplied by the mask, so masked rows subtract 0.
        # [BS, m, h]
        # todo(+N): since one and only one head is selected per modifier, does
        #  subtracting the row max leave the result unchanged?
        #  I think it cancels out, since this is like left-multiplying A by a diagonal Q.
        scores_unlabeled_max = (scores_unlabeled_diag_neg.max(-1)[0] *
                                mask_expr_d).unsqueeze(-1)  # [BS, m, 1]
        scores_exp_unlabeled = BK.exp(scores_unlabeled_diag_neg -
                                      scores_unlabeled_max)
        # todo(0): idea, currently disabled -- together with minus-max, force
        #  too-small values (e.g. < 1e-10) to exactly 0 as a form of pruning.
        # force 0 at diag entries (again); in-place on scores_exp_unlabeled
        scores_exp_unlabeled *= (1. - diag1_m)
        # Add 1. to the head == 0 column of masked rows so that those rows are
        # non-zero and the matrix stays invertible.
        scores_exp_unlabeled[:, :, 0] += (1. - mask_expr_d
                                          )  # todo: the exact value presumably does not matter
        # construct L(or K) Matrix: L=D-A
        # NOTE: A is the same tensor object as scores_exp_unlabeled, not a copy.
        A = scores_exp_unlabeled
        A_sum = A.sum(dim=-1, keepdim=True)  # [BS, m, 1]
        # todo(0): can this avoid a singular matrix? It feels like adding
        #  augmenting values to the h == 0 (column 0) to-root scores.
        # todo(+N): there are cases where the original matrix is not invertible
        #  (no solutions for trees)!!
        A_sum += 1e-6
        # (alternative considered: A_sum += A_sum * 1e-4 + 1e-6)
        #
        # Place the row sums on the diagonal: expand to [BS, m, h], keep m == h.
        D = A_sum.expand(scores_shape[:-1]) * diag1_m  # [BS, m, h]
        L = D - A  # [BS, m, h]
        # get the minor00 matrix (drop row 0 and column 0)
        LM00 = L[:, 1:, 1:]  # [BS, m-1, h-1]
        #
        # Identity of matching size, handed to the inverse helper below.
        diag1_m00 = BK.eye(maxlen - 1).double()
        # History: an earlier version used the deprecated gesv (LU solve) to get
        # inverse and determinant together; LM00.inverse() needs pytorch >= 1.0.
        # d(logZ)/d(LM00) = (LM00^-1)^T
        # NOTE(review): BK.get_inverse is assumed to return the batched inverse
        # of LM00 -- confirm in the backend.
        LM00_inv = BK.get_inverse(LM00, diag1_m00)
        LM00_grad = LM00_inv.transpose(-1, -2)  # [BS, m-1, h-1]
        # marginal(m,h) = d(logZ)/d(score(m,h)) = d(logZ)/d(LM00) * d(LM00)/d(score(m,h)) = INV_mm - INV_mh
        # padding and minus: restore the dropped row/column 0 as zeros
        # (assumes BK.pad follows torch.nn.functional.pad argument order -- TODO confirm)
        LM00_grad_pad = BK.pad(LM00_grad, [1, 0, 1, 0], 'constant',
                               0.)  # [BS, m, h]
        # Keep only the diagonal entry of each row (the INV_mm term).
        LM00_grad_pad_sum = (LM00_grad_pad * diag1_m).sum(
            dim=-1, keepdim=True)  # [BS, m, 1]
        marginals_unlabeled = A * (LM00_grad_pad_sum - LM00_grad_pad
                                   )  # [BS, m, h]
        # make sure each row sum to 1.
        # (entry (0, 0) is zeroed again by the final mask, since row 0 is invalid)
        marginals_unlabeled[:, 0, 0] = 1.
        # finally, get labeled results: spread each arc marginal over its labels
        # in proportion to exp(score - logsumexp(score)).
        marginals_labeled = marginals_unlabeled.unsqueeze(-1) * BK.exp(
            scores_expr_d - scores_unlabeled.unsqueeze(-1))
        #
        # Zero out every disallowed arc (diagonal, padding, row 0), then go
        # back to plain float32.
        masked_marginals_labeled = marginals_labeled * (
            1. - full_invalid_d).unsqueeze(-1)
        ret = masked_marginals_labeled.float()
        return _ensure_margins_norm(ret)