Example #1
0
 def forward(self, input_t: BK.Expr, edges: BK.Expr, mask_t: BK.Expr):
     """One gated graph layer: per-edge gated aggregation, then node dropout and optional add&norm.

     input_t: [*, L, D'] node representations; edges: [*, L, L] signed edge labels
     (0 means no edge); mask_t: [*, L] validity mask. Returns [*, L, D].
     (Shapes follow the inline comments of the surrounding code — confirm against callers.)
     """
     conf = self.conf
     hid_size, num_types = conf._isize, conf.type_num
     seq_len = BK.get_shape(edges, -1)
     # collapse raw edge labels to {0,1,2} (negative/none/positive) ...
     edge_sign = edges.clamp(min=-1, max=1) + 1
     # ... and shift them to non-negative indices for the bias tables
     edge_bias_idx = edges + num_types  # offset to positive!
     # hidden part: project each node into 3 sign-specific slices, then pick one slice per edge
     proj = BK.matmul(input_t, self.W_hid)
     proj = proj.view(BK.get_shape(input_t)[:-1] + [3, hid_size])  # [*, L, 3, D]
     proj = proj.unsqueeze(-4).expand(-1, seq_len, -1, -1, -1)  # [*, L, L, 3, D]
     picked = BK.gather_first_dims(proj.contiguous(), edge_sign.unsqueeze(-1), -2).squeeze(-2)  # [*, L, L, D]
     hid = picked + self.b_hid[edge_bias_idx]  # add per-edge-label bias -> [*, L, L, D]
     # gate part: one scalar gate per node pair, also selected by edge sign
     gate_score = BK.matmul(input_t, self.W_gate)  # [*, L, 3]
     gate_score = gate_score.unsqueeze(-3).expand(-1, seq_len, -1, -1)  # [*, L, L, 3]
     gate_sel = gate_score.gather(-1, edge_sign.unsqueeze(-1))  # [*, L, L, 1]
     gate_bias = self.b_gate[edge_bias_idx].unsqueeze(-1)  # [*, L, L, 1]
     gate = BK.sigmoid(gate_sel + gate_bias)
     # keep only real edges (or diagonal self-connections) whose source token is valid
     keep_mask = ((edges != 0) | (BK.eye(seq_len) > 0)).float() * mask_t.unsqueeze(-2)  # [*,L,L]
     gate = gate * keep_mask.unsqueeze(-1)  # [*,L,L,1]
     # aggregate gated neighbor messages, nonlinearity, node-level dropout
     out = self.drop_node(BK.relu((hid * gate).sum(-2)))  # [*, L, D]
     # residual connection + layer-norm, if configured
     if self.ln is not None:
         out = self.ln(out + input_t)
     return out
Example #2
0
def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True):
    """Marginal probabilities of non-projective dependency arcs via the Matrix-Tree theorem.

    Args:
        scores_expr: [BS, m, h, lab] labeled arc scores (m = modifier/row, h = head/col).
        mask_expr: [BS, len] float mask, 1. for real tokens and 0. for padding.
        lengths_arr: unused in this implementation; kept for signature compatibility.
        labeled: must be True — only labeled marginals are implemented.

    Returns:
        [BS, m, h, lab] float32 marginals (invalid entries zeroed), re-normalized
        by _ensure_margins_norm. Runs entirely without gradient tracking.
    """
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # todo(warn): it seems that float32 is not enough for inverse when the model gets better (scores gets more diversed)
        diag1_m = BK.eye(maxlen).double()        # [m, h]
        scores_expr_d = scores_expr.double()
        mask_expr_d = mask_expr.double()
        invalid_pad_expr_d = 1.-mask_expr_d
        # invalid if: self-loop (diag), padded modifier, or padded head; [*, m, h]
        full_invalid_d = (diag1_m + invalid_pad_expr_d.unsqueeze(-1) + invalid_pad_expr_d.unsqueeze(-2)).clamp(0., 1.)
        # row 0 is always invalid (presumably the artificial ROOT takes no head — confirm against callers)
        full_invalid_d[:, 0] = 1.
        #
        # first make it unlabeled by sum-exp over the label dimension
        scores_unlabeled = BK.logsumexp(scores_expr_d, dim=-1)    # [BS, m, h]
        # force small values at diag entries and padded ones
        scores_unlabeled_diag_neg = scores_unlabeled + Constants.REAL_PRAC_MIN * full_invalid_d
        # minus the row-wise max for numerical stability; [BS, m, h]
        # todo(+N): since one and only one Head is selected, thus minus by Max will make it the same?
        #  I think it will be canceled out since this is like left-mul A by a diag Q
        scores_unlabeled_max = (scores_unlabeled_diag_neg.max(-1)[0] * mask_expr_d).unsqueeze(-1)   # [BS, m, 1]
        scores_exp_unlabeled = BK.exp(scores_unlabeled_diag_neg - scores_unlabeled_max)
        # force 0 at diag entries (again)
        scores_exp_unlabeled *= (1. - diag1_m)
        # assign non-zero values (does not matter) to (0, invalid) to make the matrix invertible
        scores_exp_unlabeled[:, :, 0] += (1. - mask_expr_d)      # the value does not matter?
        # construct the Laplacian (Kirchhoff) matrix: L = D - A
        A = scores_exp_unlabeled
        A_sum = A.sum(dim=-1, keepdim=True)                 # [BS, m, 1]
        # todo(0): can this avoid singular matrix: feels like adding aug-values to h==0(COL0) to-root scores.
        # todo(+N): there are cases that the original matrix is not inversable (no solutions for trees)!!
        A_sum += 1e-6
        #
        D = A_sum.expand(scores_shape[:-1])*diag1_m         # [BS, m, h]
        L = D - A                                           # [BS, m, h]
        # first minor: drop the root row/column
        LM00 = L[:, 1:, 1:]          # [BS, m-1, h-1]
        # inverse via the backend helper (solving against the identity)
        diag1_m00 = BK.eye(maxlen-1).double()
        LM00_inv = BK.get_inverse(LM00, diag1_m00)
        # d(logZ)/d(LM00) = (LM00^-1)^T
        LM00_grad = LM00_inv.transpose(-1, -2)              # [BS, m-1, h-1]
        # marginal(m,h) = d(logZ)/d(score(m,h)) = d(logZ)/d(LM00) * d(LM00)/d(score(m,h)) = INV_mm - INV_mh
        # pad the minor back to full size, then take (diagonal-entry minus entry) per row
        LM00_grad_pad = BK.pad(LM00_grad, [1,0,1,0], 'constant', 0.)                    # [BS, m, h]
        LM00_grad_pad_sum = (LM00_grad_pad * diag1_m).sum(dim=-1, keepdim=True)     # [BS, m, 1]
        marginals_unlabeled = A * (LM00_grad_pad_sum - LM00_grad_pad)                         # [BS, m, h]
        # make sure each row sums to 1.
        marginals_unlabeled[:, 0, 0] = 1.
        # finally, distribute the unlabeled marginal over labels (softmax weights over lab)
        marginals_labeled = marginals_unlabeled.unsqueeze(-1) * BK.exp(scores_expr_d - scores_unlabeled.unsqueeze(-1))
        # zero out invalid entries and go back to plain float32
        masked_marginals_labeled = marginals_labeled * (1.-full_invalid_d).unsqueeze(-1)
        ret = masked_marginals_labeled.float()
        return _ensure_margins_norm(ret)