def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    """Labeled projective marginals.

    Collapses the label dimension, delegates the unlabeled projective
    marginal computation to the external `marginal_proj` routine (on
    plain array values), then spreads each arc marginal back over the
    labels proportionally to p(label | arc).

    scores_expr: assumes shape [BS, m, h, L] — TODO confirm against callers.
    mask_expr: unused here; kept for signature parity with `nmarginal_unproj`.
    Returns normalized labeled marginals, same shape as `scores_expr`.
    """
    assert labeled
    with BK.no_grad_env():
        # log-sum-exp over the label dim L -> unlabeled arc scores [BS, m, h]
        unlabeled_scores = BK.logsumexp(scores_expr, dim=-1)
        # run the projective marginal algorithm outside the graph, on raw values
        unlabeled_arr = BK.get_value(unlabeled_scores)
        unlabeled_marginals = BK.input_real(marginal_proj(unlabeled_arr, lengths_arr, False))
        # p(label | arc) = exp(score - logsumexp); multiply back in -> [BS, m, h, L]
        label_weights = BK.exp(scores_expr - unlabeled_scores.unsqueeze(-1))
        labeled_marginals = unlabeled_marginals.unsqueeze(-1) * label_weights
        return _ensure_margins_norm(labeled_marginals)
def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True):
    """Labeled non-projective marginals via the Matrix-Tree theorem.

    Builds the Laplacian L = D - A from exponentiated unlabeled arc scores,
    inverts its (0,0)-minor, and reads marginals off the inverse
    (marginal(m,h) = d(logZ)/d(score(m,h))). Finally spreads arc marginals
    over labels and masks out invalid entries.

    scores_expr: assumes shape [BS, m, h, L] (modifier, head, label) — TODO confirm.
    mask_expr: assumes shape [BS, m] (or broadcastable) with 1.=valid token — TODO confirm.
    lengths_arr: unused in this routine.
    Returns normalized float32 labeled marginals, same shape as `scores_expr`.
    """
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # todo(warn): it seems that float32 is not enough for the inverse when the
        # model gets better (scores get more diverse) -> compute in float64 throughout.
        diag1_m = BK.eye(maxlen).double()  # [m, h]
        scores_expr_d = scores_expr.double()
        mask_expr_d = mask_expr.double()
        invalid_pad_expr_d = 1.-mask_expr_d  # [*, m, h] — 1. where padded
        # invalid = diagonal (self-loop) OR padded modifier OR padded head, clamped to {0,1}
        full_invalid_d = (diag1_m + invalid_pad_expr_d.unsqueeze(-1) + invalid_pad_expr_d.unsqueeze(-2)).clamp(0., 1.)
        # row 0 (the artificial ROOT as a modifier) is never a valid arc target
        full_invalid_d[:, 0] = 1.

        # first make it unlabeled by sum-exp over the label dim
        scores_unlabeled = BK.logsumexp(scores_expr_d, dim=-1)  # [BS, m, h]
        # force very small values at diag entries and padded ones so exp() -> ~0
        scores_unlabeled_diag_neg = scores_unlabeled + Constants.REAL_PRAC_MIN * full_invalid_d
        # Subtract the row max for numerical stability before exp. [BS, m, h]
        # todo(+N): since one and only one head is selected per modifier, subtracting
        # the max cancels out in the marginals — it is like left-multiplying A by a
        # diagonal matrix Q.
        scores_unlabeled_max = (scores_unlabeled_diag_neg.max(-1)[0] * mask_expr_d).unsqueeze(-1)  # [BS, m, 1]
        scores_exp_unlabeled = BK.exp(scores_unlabeled_diag_neg - scores_unlabeled_max)
        # NOTE: an optional pruning step (zeroing entries < 1e-10 to co-work with the
        # minus-max trick) was tried here and left disabled.
        # force 0 at diag entries (again, in case REAL_PRAC_MIN underflowed imperfectly)
        scores_exp_unlabeled *= (1. - diag1_m)
        # assign non-zero values (value does not matter for masked-out rows) to the
        # (padded-modifier, head=0) entries to keep the matrix invertible
        scores_exp_unlabeled[:, :, 0] += (1. - mask_expr_d)  # the value does not matter?
        # construct the Laplacian (Kirchhoff) matrix: L = D - A
        A = scores_exp_unlabeled
        A_sum = A.sum(dim=-1, keepdim=True)  # [BS, m, 1]
        # todo(0): can this avoid a singular matrix? feels like adding aug-values to
        # the h==0 (COL0) to-root scores.
        # todo(+N): there are cases where the original matrix is not invertible
        # (no solutions for trees)!! — the epsilon below is a practical guard.
        A_sum += 1e-6
        # A_sum += A_sum * 1e-4 + 1e-6

        D = A_sum.expand(scores_shape[:-1])*diag1_m  # [BS, m, h] — degree matrix
        L = D - A  # [BS, m, h]
        # get the (0,0)-minor: by the Matrix-Tree theorem its determinant is the partition Z
        LM00 = L[:, 1:, 1:]  # [BS, m-1, h-1]
        # (a debug block asserting no NaNs in LM00 on CPU used to live here)

        # det and inverse; LU decomposition would hit two birds with one stone, but
        # the old `gesv`-based path (and a direct `.inverse()`, pytorch>=1.0) are
        # deprecated/disabled; `BK.get_inverse` is the current abstraction.
        # todo(warn): the gesv path lacked P, but the partition should always be non-negative!
        diag1_m00 = BK.eye(maxlen-1).double()
        LM00_inv = BK.get_inverse(LM00, diag1_m00)
        # d(logZ)/d(LM00) = (LM00^-1)^T
        LM00_grad = LM00_inv.transpose(-1, -2)  # [BS, m-1, h-1]
        # marginal(m,h) = d(logZ)/d(score(m,h))
        #              = d(logZ)/d(LM00) * d(LM00)/d(score(m,h)) = INV_mm - INV_mh
        # pad back to full size (row/col 0 for ROOT), then take the diagonal-minus-entry form
        LM00_grad_pad = BK.pad(LM00_grad, [1,0,1,0], 'constant', 0.)  # [BS, m, h]
        LM00_grad_pad_sum = (LM00_grad_pad * diag1_m).sum(dim=-1, keepdim=True)  # [BS, m, 1]
        marginals_unlabeled = A * (LM00_grad_pad_sum - LM00_grad_pad)  # [BS, m, h]
        # make sure each row sums to 1. (ROOT row has no real head; pin (0,0))
        marginals_unlabeled[:, 0, 0] = 1.

        # finally, get labeled results: spread arc marginals over labels by p(label|arc)
        marginals_labeled = marginals_unlabeled.unsqueeze(-1) * BK.exp(scores_expr_d - scores_unlabeled.unsqueeze(-1))
        # (a second debug block checking marginals_unlabeled for NaNs used to live here)

        # zero out invalid entries and go back to plain float32
        masked_marginals_labeled = marginals_labeled * (1.-full_invalid_d).unsqueeze(-1)
        ret = masked_marginals_labeled.float()
        return _ensure_margins_norm(ret)