def forward(self, input_t: BK.Expr, edges: BK.Expr, mask_t: BK.Expr):
    _isize = self.conf._isize
    _ntype = self.conf.type_num
    _slen = BK.get_shape(edges, -1)
    # --
    edges3 = edges.clamp(min=-1, max=1) + 1  # [*, L, L] in {0,1,2}: edge sign (negative / none / positive)
    edgesF = edges + _ntype  # offset to positive!
    # get hid
    hid0 = BK.matmul(input_t, self.W_hid).view(BK.get_shape(input_t)[:-1] + [3, _isize])  # [*, L, 3, D]
    hid1 = hid0.unsqueeze(-4).expand(-1, _slen, -1, -1, -1)  # [*, L, L, 3, D]
    hid2 = BK.gather_first_dims(hid1.contiguous(), edges3.unsqueeze(-1), -2).squeeze(-2)  # [*, L, L, D]
    hidB = self.b_hid[edgesF]  # [*, L, L, D]
    _hid = hid2 + hidB
    # get gate
    gate0 = BK.matmul(input_t, self.W_gate)  # [*, L, 3]
    gate1 = gate0.unsqueeze(-3).expand(-1, _slen, -1, -1)  # [*, L, L, 3]
    gate2 = gate1.gather(-1, edges3.unsqueeze(-1))  # [*, L, L, 1]
    gateB = self.b_gate[edgesF].unsqueeze(-1)  # [*, L, L, 1]
    _gate0 = BK.sigmoid(gate2 + gateB)
    _gmask0 = ((edges != 0) | (BK.eye(_slen) > 0)).float() * mask_t.unsqueeze(-2)  # [*, L, L]
    _gate = _gate0 * _gmask0.unsqueeze(-1)  # [*, L, L, 1]
    # combine
    h0 = BK.relu((_hid * _gate).sum(-2))  # [*, L, D]
    h1 = self.drop_node(h0)
    # add & norm?
    if self.ln is not None:
        h1 = self.ln(h1 + input_t)
    return h1
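# -- (added sketch, not part of the original module)
# A minimal standalone illustration of the direction-typed gather above, in plain
# PyTorch: torch.gather stands in for BK.gather_first_dims, and all sizes plus the
# helper name `_demo_edge_typed_gather` are assumptions for illustration only.
import torch

def _demo_edge_typed_gather():
    B, L, D, NT = 2, 5, 8, 4                  # batch, seq-len, hidden, label count (conf.type_num)
    input_t = torch.randn(B, L, D)
    # edges[b, m, h] in [-NT, NT]: 0 = no edge; sign = direction; |.| = label id
    edges = torch.randint(-NT, NT + 1, (B, L, L))
    W_hid = torch.randn(D, 3 * D)             # one projection per direction, as with W_hid above
    edges3 = edges.clamp(min=-1, max=1) + 1   # collapse labels to {0,1,2} by edge sign
    hid0 = (input_t @ W_hid).view(B, L, 3, D)                # [B, L, 3, D]
    hid1 = hid0.unsqueeze(1).expand(B, L, L, 3, D)           # broadcast over the other node
    sel = edges3.view(B, L, L, 1, 1).expand(B, L, L, 1, D)   # per-pair direction index
    hid2 = hid1.gather(-2, sel).squeeze(-2)                  # [B, L, L, D] message per (m, h)
    return hid2

print(_demo_edge_typed_gather().shape)  # torch.Size([2, 5, 5, 8])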
def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # todo(warn): float32 is not precise enough for the inverse once the model
        #  improves (the scores get more spread out), so compute in float64.
        diag1_m = BK.eye(maxlen).double()  # [m, h]
        scores_expr_d = scores_expr.double()
        mask_expr_d = mask_expr.double()
        invalid_pad_expr_d = 1. - mask_expr_d  # [*, m, h]
        full_invalid_d = (diag1_m + invalid_pad_expr_d.unsqueeze(-1)
                          + invalid_pad_expr_d.unsqueeze(-2)).clamp(0., 1.)
        full_invalid_d[:, 0] = 1.
        # --
        # first make the scores unlabeled via log-sum-exp over the label dimension
        scores_unlabeled = BK.logsumexp(scores_expr_d, dim=-1)  # [BS, m, h]
        # force very small values at diagonal and padded entries
        scores_unlabeled_diag_neg = scores_unlabeled + Constants.REAL_PRAC_MIN * full_invalid_d
        # subtract the per-row max for numerical stability with large scores; [BS, m, h]
        # todo(+N): since exactly one head is selected per modifier, the max shift
        #  should cancel out -- it amounts to left-multiplying A by a diagonal matrix Q.
        scores_unlabeled_max = (scores_unlabeled_diag_neg.max(-1)[0] * mask_expr_d).unsqueeze(-1)  # [BS, m, 1]
        scores_exp_unlabeled = BK.exp(scores_unlabeled_diag_neg - scores_unlabeled_max)
        # # todo(0): co-work with minus-max; force too-small values to 0 (acts as pruning, the gap is ?*ln(10))
        # scores_exp_unlabeled *= (1 - (scores_exp_unlabeled < 1e-10)).double()
        # force 0 at diagonal entries (again)
        scores_exp_unlabeled *= (1. - diag1_m)
        # assign non-zero values (the value does not matter) to (invalid-row, COL0) to keep the matrix invertible
        scores_exp_unlabeled[:, :, 0] += (1. - mask_expr_d)
        # construct the L (Kirchhoff/Laplacian) matrix: L = D - A
        A = scores_exp_unlabeled
        A_sum = A.sum(dim=-1, keepdim=True)  # [BS, m, 1]
        # =====
        # todo(0): can this avoid a singular matrix? feels like adding aug-values to the h==0 (COL0) to-root scores.
        # todo(+N): there are cases where the original matrix is not invertible (no tree solutions)!!
        A_sum += 1e-6
        # A_sum += A_sum * 1e-4 + 1e-6
        D = A_sum.expand(scores_shape[:-1]) * diag1_m  # [BS, m, h]
        L = D - A  # [BS, m, h]
        # take the (0,0) minor
        LM00 = L[:, 1:, 1:]  # [BS, m-1, h-1]
        # det and inverse together via LU decomposition (two birds with one stone);
        #  the deprecated `gesv` path lacked the permutation P, but the partition is always non-negative anyway.
        diag1_m00 = BK.eye(maxlen - 1).double()
        LM00_inv = BK.get_inverse(LM00, diag1_m00)
        # d(logZ)/d(LM00) = (LM00^{-1})^T
        LM00_grad = LM00_inv.transpose(-1, -2)  # [BS, m-1, h-1]
        # marginal(m,h) = d(logZ)/d(score(m,h)) = d(logZ)/d(LM00) * d(LM00)/d(score(m,h)) = A * (INV_mm - INV_mh)
        # pad back to [BS, m, h] and subtract
        LM00_grad_pad = BK.pad(LM00_grad, [1, 0, 1, 0], 'constant', 0.)  # [BS, m, h]
        LM00_grad_pad_sum = (LM00_grad_pad * diag1_m).sum(dim=-1, keepdim=True)  # [BS, m, 1]
        marginals_unlabeled = A * (LM00_grad_pad_sum - LM00_grad_pad)  # [BS, m, h]
        # each valid row now sums to 1.; fix the artificial root entry
        marginals_unlabeled[:, 0, 0] = 1.
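        # -- (added derivation note; Z and A as defined above)
        # Why the inverse-transpose gives marginals: by the Matrix-Tree Theorem,
        #  Z = det(LM00) sums the exp-scores of all non-projective trees, so
        #  logZ = logdet(LM00) and, by Jacobi's formula, d(logZ)/d(LM00) = (LM00^{-1})^T.
        # Each A[m,h] enters LM00 twice: +A[m,h] on the diagonal (through D) and
        #  -A[m,h] off-diagonal (for h >= 1), so the chain rule yields
        #  marginal(m,h) = A[m,h] * (G[m,m] - G[m,h]) with G the padded (LM00^{-1})^T,
        #  which is exactly the subtraction computed above.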
        # finally, get the labeled results by multiplying in the per-label conditional probabilities
        marginals_labeled = marginals_unlabeled.unsqueeze(-1) * BK.exp(scores_expr_d - scores_unlabeled.unsqueeze(-1))
        # mask out invalid entries and go back to plain float32
        masked_marginals_labeled = marginals_labeled * (1. - full_invalid_d).unsqueeze(-1)
        ret = masked_marginals_labeled.float()
        return _ensure_margins_norm(ret)
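# -- (added sketch, not part of the original code)
# A minimal standalone version of the Matrix-Tree marginal computation above, in
# plain PyTorch and without batching, masking, or labels; the helper name
# `mtt_marginals` and all sizes are assumptions for illustration only.
import torch

def mtt_marginals(scores: torch.Tensor) -> torch.Tensor:
    """scores: [m, h] unlabeled log-scores, with column 0 the artificial root.
    Returns marginals[m, h] = P(h is the head of m) over non-projective trees."""
    n = scores.size(0)
    A = scores.double().exp() * (1. - torch.eye(n, dtype=torch.double))  # no self-loops
    L = torch.diag(A.sum(dim=-1) + 1e-6) - A            # Laplacian L = D - A (+eps, as above)
    G = torch.linalg.inv(L[1:, 1:]).transpose(-1, -2)   # d(logZ)/d(LM00) = (LM00^{-1})^T
    Gp = torch.nn.functional.pad(G, (1, 0, 1, 0))       # pad back to [m, h]
    marg = A * (Gp.diagonal().unsqueeze(-1) - Gp)       # A * (G_mm - G_mh)
    marg[0, 0] = 1.                                     # fix the artificial root row
    return marg

# each modifier row is a distribution over heads (up to the eps augmentation):
m = mtt_marginals(torch.randn(5, 5))
assert torch.allclose(m.sum(-1), torch.ones(5, dtype=torch.double), atol=1e-4)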