Example #1
    def forward(self, gt, preds, att, collate=True):
        if self.num_branches == 1:
            preds = [preds]

        loss_regular = super(MultiCrossEntropyMultiBranchWithDT,
                             self).forward(gt, preds, collate=collate)

        k = att[:, 1]
        att = att[:, 0]

        alpha_curr = self.alpha / k
        # log multivariate Beta function: k * lgamma(alpha) - lgamma(k * alpha)
        # (torch.mvlgamma(x, 1) is equivalent to torch.lgamma(x))
        lbeta = k * torch.mvlgamma(alpha_curr, 1) - torch.mvlgamma(
            k * alpha_curr, 1)
        l1 = torch.mean((1 - alpha_curr) * att + lbeta)

        if collate:
            l1 = self.att_weight * l1
            loss_all = l1 + loss_regular
        else:
            loss_regular.append(l1)
            loss_all = loss_regular

        return loss_all
Example #2
def dirichlet_kl_div(alpha1, alpha2=to_var(torch.FloatTensor([0.0]))):
    # alpha0[b, k] = sum_k alpha1[b, k], broadcast back to alpha1's shape
    # (to_var is an external helper that wraps tensors for the right device)
    alpha0 = to_var(torch.Tensor(alpha1.shape[1], alpha1.shape[0]))
    alpha0 = torch.transpose(alpha0.copy_(torch.sum(alpha1, 1)), 1, 0)
    try:
        return torch.mvlgamma(torch.sum(alpha1, 1), 1) - torch.mvlgamma(torch.sum(alpha2, 1), 1) - \
               torch.sum(torch.mvlgamma(alpha1, 1), 1) + torch.sum(torch.mvlgamma(alpha2, 1), 1) + \
               torch.sum((alpha1 - alpha2) * (torch.digamma(alpha1) - torch.digamma(alpha0)), 1)
    except RuntimeError:
        print(alpha0)
        print(alpha1)
        print(alpha2)
        sys.exit(-1)
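A quick sanity check of this closed form against PyTorch's built-in Dirichlet KL (a minimal sketch; it assumes plain [B, K] concentration tensors rather than to_var-wrapped inputs):

import torch
from torch.distributions import Dirichlet, kl_divergence

a1 = torch.rand(4, 3) + 0.5  # [B, K] concentrations
a2 = torch.rand(4, 3) + 0.5
alpha0 = a1.sum(1, keepdim=True)  # same broadcast role as alpha0 above
manual = (torch.lgamma(a1.sum(1)) - torch.lgamma(a2.sum(1))
          - torch.lgamma(a1).sum(1) + torch.lgamma(a2).sum(1)
          + ((a1 - a2) * (torch.digamma(a1) - torch.digamma(alpha0))).sum(1))
print(torch.allclose(manual, kl_divergence(Dirichlet(a1), Dirichlet(a2))))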
Example #3
def kl_diag_wishart(p: DiagonalWishart, q: DiagonalWishart):
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two Diagonal Wisharts with\
                          different event shapes cannot be computed")
    log_det_term = -(0.5 * q.df) * torch.div(p.scale_diag,
                                             q.scale_diag).log().sum(dim=-1)
    tr_term = (0.5 * p.df) * (
        torch.div(p.scale_diag, q.scale_diag).sum(dim=-1) - p.dimensionality)
    mvlgamma_term = torch.mvlgamma(0.5 * q.df,
                                   q.dimensionality) - torch.mvlgamma(
                                       0.5 * p.df, p.dimensionality)
    digamma_term = 0.5 * (p.df - q.df) * mvdigamma(0.5 * p.df,
                                                   p.dimensionality)
    return log_det_term + tr_term + mvlgamma_term + digamma_term
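The mvdigamma helper used above is not part of torch; a minimal implementation consistent with the derivative of torch.mvlgamma (an assumption about the helper's contract):

import torch

def mvdigamma(x, p):
    # d/dx log Gamma_p(x) = sum_{i=0}^{p-1} digamma(x - i/2)
    i = torch.arange(p, dtype=x.dtype, device=x.device)
    return torch.digamma(x.unsqueeze(-1) - 0.5 * i).sum(-1)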
Example #4
 def log_prob(self, value):
     # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
     # The probability of a correlation matrix is proportional to
     #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
     # Additionally, the Jacobian of the transformation from Cholesky factor to
     # correlation matrix is:
     #   prod(L_ii ^ (D - i))
      # So the probability of a Cholesky factor is proportional to
     #   prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
     # with order_i = 2 * concentration - 2 + D - i
     if self._validate_args:
         self._validate_sample(value)
     diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
     order = torch.arange(2, self.dim + 1, device=self.concentration.device)
     order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
     unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
     # Compute normalization constant (page 1999 of [1])
     dm1 = self.dim - 1
     alpha = self.concentration + 0.5 * dm1
     denominator = torch.lgamma(alpha) * dm1
     numerator = torch.mvlgamma(alpha - 0.5, dm1)
     # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
     # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
     # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
     pi_constant = 0.5 * dm1 * math.log(math.pi)
     normalize_term = pi_constant + numerator - denominator
     return unnormalized_log_pdf - normalize_term
Example #5
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
     log_beta = math.log(self.beta) if isinstance(self.beta, Number) else self.beta.log()
      log_gamma = math.lgamma(1.0 / self.beta) if isinstance(self.beta, Number) else torch.mvlgamma(1.0 / self.beta, 1)
     return -((torch.abs(value - self.loc) / (self.scale)) ** self.beta) + log_beta - log_scale - math.log(2) - log_gamma
Example #6
 def __init__(self, nu, K, validate_args=False):
     TModule.__init__(self)
     if K.dim() < 2:
         raise ValueError("K must be at least 2-dimensional")
     n = K.shape[-1]
     if K.shape[-2] != K.shape[-1]:
         raise ValueError("K must be square")
     if isinstance(nu, Number):
         nu = torch.tensor(float(nu))
     if torch.any(nu <= n):
         raise ValueError("Must have nu > n - 1")
     self.n = torch.tensor(n, dtype=torch.long, device=nu.device)
     batch_shape = nu.shape
     event_shape = torch.Size([n, n])
     # normalization constant
     logdetK = torch.logdet(K) if K.dim() == 2 else torch.stack(
         [torch.logdet(k) for k in K])
     C = -(nu / 2) * (logdetK + n * math.log(2)) - torch.mvlgamma(nu / 2, n)
     K_inv = torch.inverse(K) if K.dim() == 2 else torch.stack(
         [torch.inverse(k) for k in K])
     # need to assign values before registering as buffers to make argument validation work
     self.nu = nu
     self.K_inv = K_inv
     self.C = C
     super(WishartPrior, self).__init__(batch_shape,
                                        event_shape,
                                        validate_args=validate_args)
     # now need to delete to be able to register buffer
     del self.nu, self.K_inv, self.C
     self.register_buffer("nu", nu)
     self.register_buffer("K_inv", K_inv)
     self.register_buffer("C", C)
Example #7
 def __init__(self, nu, K, validate_args=False):
     TModule.__init__(self)
     if K.dim() < 2:
         raise ValueError("K must be at least 2-dimensional")
     n = K.shape[-1]
     if isinstance(nu, Number):
         nu = torch.tensor(float(nu))
     if torch.any(nu <= 0):
         raise ValueError("Must have nu > 0")
     self.n = torch.tensor(n, dtype=torch.long, device=nu.device)
     batch_shape = nu.shape
     event_shape = torch.Size([n, n])
     # normalization constant
     c = (nu + n - 1) / 2
     logdetK = torch.logdet(K) if K.dim() == 2 else torch.stack(
         [torch.logdet(k) for k in K])
     C = c * (logdetK - n * math.log(2)) - torch.mvlgamma(c, n)
     # need to assign values before registering as buffers to make argument validation work
     self.nu = nu
     self.K = K
     self.C = C
     super(InverseWishartPrior, self).__init__(batch_shape,
                                               event_shape,
                                               validate_args=validate_args)
     # now need to delete to be able to register buffer
     del self.nu, self.K, self.C
     self.register_buffer("nu", nu)
     self.register_buffer("K", K)
     self.register_buffer("C", C)
     self._log_transform = False
Example #8
 def entropy(self):
     nu = self.df  # has shape (batch_shape)
     p = self._event_shape[-1]  # has singleton shape
     V = self.covariance_matrix  # has shape (batch_shape x event_shape)
     return ((p + 1) * self._unbroadcasted_scale_tril.diagonal(
         dim1=-2, dim2=-1).log().sum(-1) + 0.5 * p * (p + 1) * _log_2 +
             torch.mvlgamma(0.5 * nu, p=p) - 0.5 *
             (nu - p - 1) * _mvdigamma(0.5 * nu, p=p) + 0.5 * nu * p)
Example #9
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     nu = self.df  # has shape (batch_shape)
     p = self._event_shape[-1]  # has singleton shape
     return (
         - nu * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
         - torch.mvlgamma(nu / 2, p=p)
         + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
         - torch.cholesky_solve(value, self._unbroadcasted_scale_tril).diagonal(dim1=-2, dim2=-1).sum(dim=-1) / 2
     )
Example #10
    def log_prob(self, x):
        nu = self.nu
        p = self.p

        res = ((nu - p - 1) / 2) * t.logdet(x)
        res -= (1 / 2) * (t.inverse(self.K) * x).sum((-1, -2))
        res -= (nu * p / 2) * math.log(2)
        res -= (nu / 2) * t.logdet(self.K)
        res -= t.mvlgamma(nu / 2 * t.ones(()), p)

        return res
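A quick check of this density against torch.distributions.Wishart (a minimal sketch; assumes PyTorch >= 1.11, which ships the Wishart distribution):

import math
import torch

nu, p = 7.0, 3
A = torch.randn(p, p)
K = A @ A.T + p * torch.eye(p)  # well-conditioned scale matrix
B = torch.randn(p, p)
x = B @ B.T + p * torch.eye(p)  # positive-definite sample value

manual = ((nu - p - 1) / 2) * torch.logdet(x) \
    - 0.5 * (torch.inverse(K) * x).sum() \
    - (nu * p / 2) * math.log(2) \
    - (nu / 2) * torch.logdet(K) \
    - torch.mvlgamma(torch.tensor(nu / 2), p)
ref = torch.distributions.Wishart(torch.tensor(nu), covariance_matrix=K).log_prob(x)
print(torch.allclose(manual, ref, atol=1e-4))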
Example #11
def D_alpha_u_n(n, alpha_n, DU_DY, Y):
    '''
    Compute D^alpha(y)_y u(y) at y=y_n.
    (Note: the fractional derivative at y_0 is not defined when 0<alpha(y_0)<1.
    In such cases, the output is 0.)

    :param n      : integer n in [0, N].
    :param alpha_n: alpha(y_n). Must be in [0, 1].
    :param DU_DY  : du/dy(y_n) for n = 0, 1, ..., N.
    :param Y      : y_n for n = 0, 1, ..., N.
    :return       : scalar.
    '''
    torch_gamma = lambda x: torch.exp(torch.mvlgamma(x, p=1))

    def _D_0_u_n(n, DU_DY, Y):

        # integrate du/dy from y_0 to y_n using the composite trapezoidal rule
        res = torch.dot(Y[1:n + 1] - Y[:n], DU_DY[1:n + 1] + DU_DY[:n]) / 2.
        return res

    def _D_1_u_n(n, DU_DY, Y):
        return DU_DY[n]

    def _D_alpha_u_n(n, alpha_n, DU_DY, Y):
        a = alpha_n
        fac = 1. / torch_gamma(2 - a)

        b = lambda k: (Y[n] - Y[k])**(1 - a) - (Y[n] - Y[k + 1])**(1 - a)

        kk = torch.arange(n - 1)
        sum_ = torch.dot(b(kk),
                         DU_DY[kk]) + (Y[n] - Y[n - 1])**(1 - a) * DU_DY[n - 1]

        return fac * sum_

    if n == 0:
        if alpha_n == 0:
            res = _D_0_u_n(n, DU_DY, Y).view(alpha_n.shape)
        elif alpha_n == 1:
            res = _D_1_u_n(n, DU_DY, Y).view(alpha_n.shape)
        else:
            res = torch.zeros(alpha_n.shape)  # derivative undefined at y_0

    else:
        if alpha_n == 0:
            res = _D_0_u_n(n, DU_DY, Y).view(alpha_n.shape)
        elif alpha_n == 1:
            res = _D_1_u_n(n, DU_DY, Y).view(alpha_n.shape)
        else:
            res = _D_alpha_u_n(n, alpha_n, DU_DY, Y)

    return res
Example #12
    def log_prob(self, value):
        chol = self.cholesky_factor

        scale = torch.bmm(chol, chol.transpose(-2, -1))
        log_normalizer = (self.df * self._dim / 2.) * np.log(2) +\
                         (self.df / 2.) * torch.logdet(scale) +\
                         torch.mvlgamma(self.df / 2., self._dim)

        numerator_logdet = (self.df - self._dim - 1) / 2. * torch.logdet(value)
        # trace term needs scale^{-1} @ value, i.e. (chol chol^T)^{-1} @ value
        choleskied_value = torch.cholesky_solve(value, chol)
        numerator_logtrace = -1 / 2 * torch.diagonal(
            choleskied_value, dim1=-2, dim2=-1).sum(-1)
        log_numerator = numerator_logdet + numerator_logtrace
        return log_numerator - log_normalizer
Example #13
    def log_prob(self, x, AAT=None):
        if AAT is None:
            AAT = x.inv(self.K.full())
        else:
            AAT = AAT.full()

        nu = self.nu
        p = self.K.N

        res = -((nu + p + 1) / 2) * x.logdet()
        #modified by multiplying by nu
        res += (nu / 2) * self.K.logdet() + (p * nu / 2) * t.log(
            t.ones((), device=x.device) * nu)
        #modified by multiplying by nu
        res -= (nu / 2) * (AAT).diagonal(dim1=-1, dim2=-2).sum(-1)
        res -= (nu * p / 2) * math.log(2)
        res -= t.mvlgamma(nu / 2 * t.ones(()), p)

        return res
Example #14
 def _log_normalizer(self, eta1, eta2):
     D = self.event_shape[-1]
     a = -(eta1 + .5 * (D + 1))
     return torch.mvlgamma(a, D) - a * _posdef_logdet(-eta2)
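The _posdef_logdet helper is external; a minimal version consistent with how it is used here (an assumed contract: the log-determinant of a positive-definite matrix):

import torch

def _posdef_logdet(m):
    # logdet(m) = 2 * sum(log(diag(L))) for m = L @ L.T
    L = torch.linalg.cholesky(m)
    return 2.0 * L.diagonal(dim1=-2, dim2=-1).log().sum(-1)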
Example #15
 def log_normalizer(self):
     D = self.event_shape[-1]
     return torch.mvlgamma(self.concentration, D) + self.concentration \
         * (D * math.log(2.) + 2. * LA.triangular_logdet(self.scale_tril))
Example #16
def logbeta(a, b):
    return torch.mvlgamma(a, 1) + torch.mvlgamma(b, 1) - torch.mvlgamma(
        a + b, 1)
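Since torch.mvlgamma(x, 1) reduces to torch.lgamma(x), this is the ordinary log-Beta function; a quick check:

import torch

a, b = torch.rand(5) + 0.5, torch.rand(5) + 0.5
assert torch.allclose(logbeta(a, b),
                      torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b))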
Example #17
def dirichlet_logpdf(x, alpha):
    one = to_var(torch.FloatTensor([1.0]))
    return torch.mvlgamma(torch.sum(alpha, 1), 1) - torch.sum(torch.mvlgamma(alpha, 1), 1) + \
           torch.sum((alpha - one) * torch.log(x), 1)
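A quick check against the built-in density (a minimal sketch; it reimplements the formula with plain tensors, sidestepping the external to_var helper):

import torch

alpha = torch.rand(4, 3) + 0.5
x = torch.distributions.Dirichlet(alpha).sample()
manual = (torch.lgamma(alpha.sum(1)) - torch.lgamma(alpha).sum(1)
          + ((alpha - 1.0) * x.log()).sum(1))
print(torch.allclose(manual, torch.distributions.Dirichlet(alpha).log_prob(x)))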
Example #18
 def pointwise_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return (
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         torch.tan(a),
         torch.tanh(a),
         torch.trunc(a),
         torch.xlogy(f, g),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     )
Example #19
 def _lprob(cls, x, dim, df, sm):
     a = th.logdet(x) * (df - dim - 1) / 2 - th.trace(sm.inverse() @ x) / 2
     b = np.log(2) * df * dim / 2 + th.logdet(sm) * df / 2 + th.mvlgamma(
         df / 2, dim)
     return a - b
Example #20
 def log_normalizer(self):
     D = self.event_shape[-1]
     return torch.mvlgamma(.5 * self.df, D) + .5 * self.df \
         * (D * _LOG_2 - 2. * _triangular_logdet(self.scale_tril))
Example #21
 def _log_normalizer(self, x, y):
     p = self._event_shape[-1]
     return ((y + (p + 1) / 2) *
             (-torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p) +
             torch.mvlgamma(y + (p + 1) / 2, p=p))
Example #22
def Split(X, XXT, argmax, Nk, sons_LL_b, X_sons_b, X_father_b, father_LL_b,
          C1, c1_temp, clusters_LR, it_split, m_v_sons_b, m_v_father_b,
          b_sons_b, b_father_b, SigmaXY_b, SigmaXY_i_b, SIGMAxylab_b, Nk_b,
          X1_b, X2_00_b, X2_01_b, X2_11_b):
    it_split=it_split+1
    K_C_Split=torch.max(argmax[:,1])+1
    if(Nk.shape[0]>K_C_Split):
        K_C_Split=Nk.shape[0]
    Nk_s = torch.zeros(K_C_Split).float().to(Global.device)
    Nk.zero_()
    a_prior_sons = Nk_s
    Global.psi_prior_sons = torch.mul(torch.pow(a_prior_sons, 2).unsqueeze(1), torch.eye(2).reshape(-1, 4).to(Global.device))
    Global.ni_prior_sons = (Global.C_prior * a_prior_sons) - 3
    Nk.index_add_(0, argmax[:, 0], Global.ones)
    Nk = Nk + 0.0000000001
    Nk_s.index_add_(0, argmax[:, 1], Global.ones)
    Nk_s = Nk_s + 0.0000000001



    sons_LL=sons_LL_b[0:Nk_s.shape[0]].zero_()
    X_sons=X_sons_b[0:Nk_s.shape[0]].zero_()
    X_father=X_father_b[0:Global.K_C+1].zero_()
    father_LL=father_LL_b[0:Global.K_C+1].zero_()


    X_sons.index_add_(0,argmax[:,1],X[:,0:2])
    X_father.index_add_(0,argmax[:,0],X[:,0:2])
    sons_LL[:,0]= -torch.pow(X_sons[:,0],2)
    sons_LL[:,1]= -torch.mul(X_sons[:,0],X_sons[:,1])
    sons_LL[:,2]= -sons_LL[:,1]
    sons_LL[:,3]= -torch.pow(X_sons[:,1],2)

    father_LL[:, 0] = -torch.pow(X_father[:,0], 2)
    father_LL[:, 1] = -torch.mul(X_father[:,0], X_father[:,1])
    father_LL[:, 2] = -father_LL[:, 1]
    father_LL[:, 3] = -torch.pow(X_father[:,1], 2)


    sons_LL.index_add_(0, argmax[:,1], XXT)
    father_LL.index_add_(0,argmax[:,0],XXT)

    ni_sons=torch.add(Global.ni_prior_sons,Nk_s)[0:sons_LL.shape[0]]
    ni_father=torch.add(Global.ni_prior,Nk)[0:father_LL.shape[0]]
    psi_sons=torch.add(sons_LL.reshape(-1,4),Global.psi_prior_sons)[0:ni_sons.shape[0]]
    psi_father=torch.add(father_LL.reshape(-1,4),Global.psi_prior)[0:ni_father.shape[0]]
    ni_sons[(ni_sons <= 1).nonzero()] = 2.00000001
    ni_father[(ni_father <= 1).nonzero()] = 2.00000001

    gamma_sons=torch.mvlgamma((ni_sons/2),2)
    gamma_father=torch.mvlgamma((ni_father/2),2)
    det_psi_sons=0.00000001+torch.add(torch.mul(psi_sons[:, 0], psi_sons[:, 3]),-torch.mul(psi_sons[:, 1], psi_sons[:, 2]))
    det_psi_father=0.00000001+torch.add(torch.mul(psi_father[:, 0], psi_father[:, 3]),-torch.mul(psi_father[:, 1], psi_father[:, 2]))
    det_psi_sons[(det_psi_sons <= 0).nonzero()] = 0.00000001
    det_psi_father[(det_psi_father <= 0).nonzero()] = 0.00000001

    det_psi_prior_sons=0.00000001+torch.add(torch.mul(Global.psi_prior_sons[:, 0], Global.psi_prior_sons[:, 3]),-torch.mul(Global.psi_prior_sons[:, 1], Global.psi_prior_sons[:, 2]))
    det_psi_prior_father=0.00000001+torch.add(torch.mul(Global.psi_prior[:, 0], Global.psi_prior[:, 3]),-torch.mul(Global.psi_prior[:, 1], Global.psi_prior[:, 2]))
    det_psi_prior_sons[(det_psi_prior_sons <= 0).nonzero()] = 0.00000001
    det_psi_prior_father[(det_psi_prior_father <= 0).nonzero()] = 0.00000001

    Global.ni_prior_sons[(Global.ni_prior_sons <= 1).nonzero()] = 2.00000001
    Global.ni_prior[(Global.ni_prior <= 1).nonzero()] = 2.00000001
    gamma_prior_sons=torch.mvlgamma((Global.ni_prior_sons / 2),2)
    gamma_prior_father=torch.mvlgamma((Global.ni_prior / 2),2)

    ll_sons= -(torch.mul(torch.log((Global.PI)),(Nk_s)))+ \
             torch.add(gamma_sons,-gamma_prior_sons) + \
             torch.mul(torch.log(det_psi_prior_sons), (Global.ni_prior_sons * 0.5)) - \
             torch.mul(torch.log(det_psi_sons),(ni_sons * 0.5))+\
             torch.log(Nk_s[0:sons_LL.shape[0]])

    # prior term uses the prior determinant, mirroring ll_sons above
    ll_father= -(torch.mul(torch.log((Global.PI)),(Nk)))+ \
               torch.add(gamma_father,-gamma_prior_father) + \
               torch.mul(torch.log(det_psi_prior_father), (Global.ni_prior * 0.5)) - \
               torch.mul(torch.log(det_psi_father),ni_father * 0.5) +\
               torch.log(Nk[0:father_LL.shape[0]])

    ll_sons_min=torch.min(ll_sons[1:ll_sons.shape[0]].masked_select(~(torch.isinf(ll_sons[1:ll_sons.shape[0]])^torch.isnan(ll_sons[1:ll_sons.shape[0]]))))
    ll_sons_max=torch.max(ll_sons[1:ll_sons.shape[0]].masked_select(~(torch.isinf(ll_sons[1:ll_sons.shape[0]])^torch.isnan(ll_sons[1:ll_sons.shape[0]]))))
    ll_father_min=torch.min(ll_father[1:ll_father.shape[0]].masked_select(~(torch.isinf(ll_father[1:ll_father.shape[0]])^torch.isnan(ll_father[1:ll_father.shape[0]]))))
    ll_father_max=torch.max(ll_father[1:ll_father.shape[0]].masked_select(~(torch.isinf(ll_father[1:ll_father.shape[0]])^torch.isnan(ll_father[1:ll_father.shape[0]]))))

    ll_sons_min=torch.min(ll_sons_min,ll_father_min)
    ll_sons_max=torch.max(ll_sons_max,ll_father_max)


    ll_sons=torch.div(torch.add(ll_sons,-ll_sons_min),(ll_sons_max-ll_sons_min))*(-1000)+0.1
    ll_father=torch.div(torch.add(ll_father,-ll_sons_min),(ll_sons_max-ll_sons_min))*(-1000)+0.1




    alpha=torch.Tensor([float(1000000)]).to(Global.device)
    beta=torch.Tensor([Global.int_scale*alpha+Global.int_scale]).to(Global.device)
    Nk.zero_()
    Nk_s.zero_()
    Nk.index_add_(0, argmax[:, 0], Global.ones)
    Nk = Nk + 0.0000000001
    Nk_s.index_add_(0, argmax[:, 1], Global.ones)
    Nk_s = Nk_s + 0.0000000001
    v_father=Nk
    v_sons=Nk_s




    m_v_sons=m_v_sons_b[0:Nk_s.shape[0]].zero_()
    m_v_father=m_v_father_b[0:Nk.shape[0]].zero_()
    b_sons = b_sons_b[0:Nk_s.shape[0]].zero_()
    b_father = b_father_b[0:Nk.shape[0]].zero_()
    m_v_sons.index_add_(0, argmax[:, 1], X[:,2:])
    m_v_father.index_add_(0, argmax[:, 0], X[:,2:])
    m_sons=torch.div(m_v_sons,v_sons.unsqueeze(1))
    m_father=torch.div(m_v_father,v_father.unsqueeze(1))
    a_sons=torch.add(Nk_s/2,alpha).unsqueeze(1)
    a_father=torch.add(Nk/2,alpha).unsqueeze(1)
    b_sons.index_add_(0, argmax[:, 1], torch.pow(X[:, 2:],2))
    b_father.index_add_(0, argmax[:, 0],torch.pow(X[:, 2:],2))
    b_sons=b_sons/2
    b_father=b_father/2
    b_sons.add_(torch.add(beta,-torch.mul(torch.pow(m_sons,2),v_sons.unsqueeze(1))/2))
    b_father.add_(torch.add(beta,-torch.mul(torch.pow(m_father,2),v_father.unsqueeze(1))/2))

    gamma_2_sons=torch.mvlgamma(a_sons,1)
    gamma_2_father=torch.mvlgamma(a_father,1)

    ll_2_sons=(0.5*torch.log(v_sons).unsqueeze(1))+\
              (torch.log(beta)*alpha)-\
              (a_sons*torch.log(b_sons))+\
              gamma_2_sons-\
              ((torch.mul(torch.log(Global.PI),Nk_s/2))+(0.301*Nk_s)).unsqueeze(1)
    # same Normal-Gamma marginal form as ll_2_sons above
    ll_2_father=0.5*torch.log(v_father).unsqueeze(1)+ \
                (torch.log(beta)*alpha)- \
                (a_father*torch.log(b_father))+ \
                gamma_2_father- \
                ((torch.mul(torch.log(Global.PI),Nk/2))+(0.301*Nk)).unsqueeze(1)

    ll_2_sons=torch.sum(ll_2_sons,1)[0:ll_sons.shape[0]]
    ll_2_father = torch.sum(ll_2_father, 1)[0:ll_father.shape[0]]

    ll_sons_min = torch.min(ll_2_sons[1:ll_2_sons.shape[0]].masked_select(~(torch.isnan(ll_2_sons[1:ll_2_sons.shape[0]])^torch.isinf(ll_2_sons[1:ll_2_sons.shape[0]]))))
    ll_sons_max = torch.max(ll_2_sons[1:ll_2_sons.shape[0]].masked_select(~(torch.isnan(ll_2_sons[1:ll_2_sons.shape[0]])^torch.isinf(ll_2_sons[1:ll_2_sons.shape[0]]))))
    ll_father_min = torch.min(ll_2_father.masked_select(~(torch.isnan(ll_2_father)^torch.isinf(ll_2_father))))
    ll_father_max = torch.max(ll_2_father.masked_select(~(torch.isnan(ll_2_father)^torch.isinf(ll_2_father))))



    ll_2_sons = torch.div(torch.add(ll_2_sons, -ll_sons_min), (ll_sons_max - ll_sons_min))*(-1000) + 0.1
    ll_2_father = torch.div(torch.add(ll_2_father, -ll_father_min), (ll_father_max - ll_father_min))*(-1000) + 0.1

    ll_sons.add_(ll_2_sons)
    ll_father.add_(ll_2_father)




    gamma_1_sons=torch.mvlgamma(Nk_s,1)
    gamma_1_father=torch.mvlgamma(Nk,1)
    ll_sons=torch.where(Nk_s[0:ll_sons.shape[0]]<35,Global.zeros[0:ll_sons.shape[0]]- torch.Tensor([float("inf")]).to(Global.device),ll_sons)
    ind_sons=clusters_LR[0:gamma_1_sons.shape[0]].long()
    ind_sons[ind_sons>ll_sons.shape[0]-1]=0 #TODO:Check if relevant
    prob=(Global.ALPHA_MS)+\
         ((ll_sons[ind_sons[:,0]]+\
           gamma_1_sons[ind_sons[:,0]]+\
           ll_sons[ind_sons[:,1]]+\
           gamma_1_sons[ind_sons[:,1]])[0:gamma_1_father.shape[0]-1]-
          ((gamma_1_father)+ll_father)[0:gamma_1_father.shape[0]-1])


    idx_rand=torch.where(torch.exp(prob) > 1.0, Global.N_index[0:prob.shape[0]].long(),Global.zeros[0:prob.shape[0]].long()).nonzero()[:, 0]
    if(Global.Print):
        print("Idx Split Size: ",idx_rand.shape[0])
    left = torch.zeros(Global.K_C + 1, 2).int().to(Global.device)
    left[:, 0] = Global.N_index[0:Global.K_C + 1]
    left[idx_rand,1]=1
    pixels_to_change = left[argmax[:,0],1]

    original=torch.where(pixels_to_change==1,argmax[:, 1],argmax[:,0])


    return argmax,Nk,original
Example #23
 def log_normalizer(self):
     log_normalizer_1 = 0.5 * self.df * self.dimensionality * math.log(2)
     log_normalizer_2 = 0.5 * self.df * self.scale_diag.log().sum(dim=-1)
     log_normalizer_3 = torch.mvlgamma(0.5 * self.df, self.dimensionality)
     return log_normalizer_1 + log_normalizer_2 + log_normalizer_3
Example #24
def Merge(X,argmax,Nk,it_merge,temp_b,m_v_father_b,m_v_sons_b,b_father_b,b_sons_b):
        """Merge step

        **Parameters**:
         - :math:`X[N,D]` - Data matrix  [Number of pixels, Dimenstion of the data].

         - :math:`argmax[N,2]' Pixel to SP matrix

        **Returns**:
         - Update argmax, and split_lvl
        """
        padded_matrix = Global.Padding0(argmax[:,0].reshape(Global.HEIGHT,-1)).reshape(-1).to(Global.device)
        pair = torch.zeros(Global.K_C + 1).int().to(Global.device)
        left = torch.zeros(Global.K_C + 1,2).int().to(Global.device)
        left[:,0]=torch.arange(0,Global.K_C+1)
        left[:,0]=Global.N_index[0:Global.K_C+1]

        it_merge = 0
        if(it_merge%4==0):
            ind_left = torch.masked_select(Global.inside_padded, (
                        (padded_matrix[Global.inside_padded] != padded_matrix[Global.inside_padded - 1]) & (
                            padded_matrix[Global.inside_padded - 1] > 0)))
            left[padded_matrix[ind_left],1] = padded_matrix[ind_left - 1].int()
        if (it_merge%4 == 1):
            ind_left = torch.masked_select(Global.inside_padded, (
                    (padded_matrix[Global.inside_padded] != padded_matrix[Global.inside_padded + 1]) &(
                    padded_matrix[Global.inside_padded + 1] > 0)))
            left[padded_matrix[ind_left], 1] = padded_matrix[ind_left + 1].int()
        if (it_merge%4 == 2):
            ind_left = torch.masked_select(Global.inside_padded, (
                    (padded_matrix[Global.inside_padded] != padded_matrix[Global.inside_padded - (Global.WIDTH+2)]) & (
                    padded_matrix[Global.inside_padded - (Global.WIDTH+2)] > 0)))
            left[padded_matrix[ind_left], 1] = padded_matrix[ind_left - (Global.WIDTH+2)].int()
        if (it_merge%4 == 3):
            ind_left = torch.masked_select(Global.inside_padded, (
                    (padded_matrix[Global.inside_padded] != padded_matrix[Global.inside_padded + (Global.WIDTH+2)]) & (
                    padded_matrix[Global.inside_padded + (Global.WIDTH+2)] > 0)))
            left[padded_matrix[ind_left], 1] = padded_matrix[ind_left + (Global.WIDTH+2)].int()

        it_merge=it_merge+1


        for i in range(0, Global.K_C + 1):
            val = left[i, 1]
            if ((val > 0 )and (val!=i)):
                if ((pair[i] == 0) and (pair[val] == 0)):
                    if (val < i):
                        pair[i] = val
                        pair[val] = val
                    else:
                        pair[val] = i
                        pair[i] = i

        left[:,1]=pair

        Nk.zero_()
        Nk.index_add_(0, argmax[:, 0], Global.ones)
        Nk = Nk + 0.0000000001

        Nk_merged=torch.add(Nk,Nk[left[:,1].long()])
        alpha=torch.Tensor([float(1000000)]).to(Global.device)
        beta=torch.Tensor([Global.int_scale*alpha+Global.int_scale]).to(Global.device)
        v_father = Nk
        v_merged = Nk_merged


        m_v_father = m_v_father_b[0:Nk.shape[0]].zero_()
        b_father = b_father_b[0:Nk.shape[0]].zero_()

        m_v_father.index_add_(0, argmax[:, 0], X[:, 2:])
        m_v_merged = torch.add(m_v_father, m_v_father[left[:, 1].long()])


        m_merged = torch.div(m_v_merged, v_merged.unsqueeze(1))
        m_father = torch.div(m_v_father, v_father.unsqueeze(1))
        a_father = torch.add(Nk / 2, alpha).unsqueeze(1)
        a_merged = torch.add(Nk_merged / 2, alpha).unsqueeze(1)
        b_father.index_add_(0, argmax[:, 0], torch.pow(X[:, 2:], 2))
        b_merged=torch.add(b_father,b_father[left[:,1].long()])
        b_father=b_father/2
        b_merged=b_merged/2
        b_father.add_(torch.add(beta, -torch.mul(torch.pow(m_father, 2), v_father.unsqueeze(1)) / 2))
        b_merged.add_(torch.add(beta, -torch.mul(torch.pow(m_merged, 2), v_merged.unsqueeze(1)) / 2))


        gamma_2_merged = torch.mvlgamma(a_merged,1)
        gamma_2_father = torch.mvlgamma(a_father, 1)


        ll_2_merged=0.5*torch.log(v_merged).unsqueeze(1)+ \
                    (a_merged*torch.log(b_merged))+\
                    gamma_2_merged- \
                    ((torch.mul(torch.log(Global.PI),Nk_merged/2))+(0.301*Nk_merged)).unsqueeze(1)


        ll_2_father=0.5*torch.log(v_father).unsqueeze(1)+ \
                    (a_father*torch.log(b_father))+\
                    gamma_2_father- \
                    ((torch.mul(torch.log(Global.PI),Nk/2))+(0.301*Nk)).unsqueeze(1)


        ll_2_father = torch.sum(ll_2_father, 1)[0:ll_2_father.shape[0]]
        ll_2_merged = torch.sum(ll_2_merged, 1)[0:ll_2_merged.shape[0]]



        ll_merged_min = torch.min(ll_2_merged[1:ll_2_merged.shape[0]].masked_select(
            ~(torch.isnan(ll_2_merged[1:ll_2_merged.shape[0]]) ^ torch.isinf(ll_2_merged[1:ll_2_merged.shape[0]]))))
        ll_merged_max = torch.max(ll_2_merged[1:ll_2_merged.shape[0]].masked_select(
            ~(torch.isnan(ll_2_merged[1:ll_2_merged.shape[0]]) ^ torch.isinf(ll_2_merged[1:ll_2_merged.shape[0]]))))
        ll_father_min = torch.min(
            ll_2_father.masked_select(~(torch.isnan(ll_2_father) ^ torch.isinf(ll_2_father))))
        ll_father_max = torch.max(
            ll_2_father.masked_select(~(torch.isnan(ll_2_father) ^ torch.isinf(ll_2_father))))

        ll_merged_min=torch.min(ll_merged_min,ll_father_min)
        ll_merged_max=torch.max(ll_merged_max,ll_father_max)


        ll_2_merged = torch.div(torch.add(ll_2_merged, -ll_merged_min), (ll_merged_max - ll_merged_min)) * (-10000) + 0.1
        ll_2_father = torch.div(torch.add(ll_2_father, -ll_merged_min), (ll_merged_max - ll_merged_min))*(-10000) + 0.1

        gamma_alpha_2=torch.mvlgamma(torch.Tensor([Global.ALPHA_MS2/2]).to(Global.device),1)
        gamma_alpha=torch.mvlgamma(torch.Tensor([Global.ALPHA_MS2]).to(Global.device),1)

        gamma_father=torch.mvlgamma(Nk,1)
        gamma_add_father=torch.mvlgamma(Nk_merged,1)
        gamma_alpha_father=torch.mvlgamma(Nk+Global.ALPHA_MS2/2,1)
        gamma_add_alpha_merged = torch.mvlgamma(Nk_merged + Global.ALPHA_MS2, 1)

        prob = -Global.LOG_ALPHA_MS2+gamma_alpha-2*gamma_alpha_2 +\
               gamma_add_father-gamma_add_alpha_merged+ \
               gamma_alpha_father[left[:, 0].long()]-gamma_father[left[:, 0].long()]+ \
               gamma_alpha_father[left[:, 1].long()] - gamma_father[left[:, 1].long()] - 2 + \
               ll_2_merged[left[:, 0].long()] - ll_2_father[left[:, 0].long()] - ll_2_father[left[:, 1].long()]

        prob=torch.where(((left[:,0]==left[:,1])+(left[:,1]==0))>0,-torch.Tensor([float("inf")]).to(Global.device),prob)

        idx_rand=torch.where(torch.exp(prob) > 1.0, Global.N_index[0:prob.shape[0]].long(),Global.zeros[0:prob.shape[0]].long()).nonzero()[:, 0]

        pair[left[:,1].long()]=left[left[:,0].long()][:,0]

        left[:,1]=Global.N_index[0:Global.K_C+1]
        left[idx_rand.long(),1]=pair[idx_rand.long()]

        argmax[:,0] = left[argmax[:,0],1]


        if(Global.Print):
            print("Idx Merge Size: ",idx_rand.shape[0])

        Global.split_lvl[idx_rand]= Global.split_lvl[idx_rand]*2
        Global.split_lvl[left[idx_rand,1].long()]=Global.split_lvl[idx_rand]
Example #25
 def _log_normalizer(self, eta1, eta2):
     D = self.event_shape[-1]
     a = -eta1 - .5 * (D + 1)
     return torch.mvlgamma(a, D) - a * util.posdef_logdet(-eta2)[0]
Example #26
# i0
torch.i0(torch.arange(5, dtype=torch.float32))

# igamma/igammac
a1 = torch.tensor([4.0])
a2 = torch.tensor([3.0, 4.0, 5.0])
torch.igamma(a1, a2)
torch.igammac(a1, a2)

# mul/multiply
torch.mul(torch.randn(3), 100)
torch.multiply(torch.randn(4, 1), torch.randn(1, 4))

# mvlgamma
torch.mvlgamma(torch.empty(2, 3).uniform_(1, 2), 2)

# nan_to_num
w = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
torch.nan_to_num(w)
torch.nan_to_num(w, nan=2.0)
torch.nan_to_num(w, nan=2.0, posinf=1.0)

# neg/negative
torch.neg(torch.randn(5))

# nextafter
eps = torch.finfo(torch.float32).eps
torch.nextafter(torch.tensor([1.0, 2.0]),
                torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps])
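All of these examples lean on the definition of the multivariate log-gamma function, log Gamma_p(a) = p(p-1)/4 * log(pi) + sum_{j=1}^{p} lgamma(a + (1-j)/2), valid for a > (p-1)/2; a quick identity check:

import math
import torch

def mvlgamma_ref(a, p):
    # direct evaluation of the defining sum
    j = torch.arange(p, dtype=a.dtype)
    return (p * (p - 1) / 4) * math.log(math.pi) + \
           torch.lgamma(a.unsqueeze(-1) - 0.5 * j).sum(-1)

a = torch.empty(4).uniform_(2.0, 5.0)  # satisfies a > (p - 1) / 2 for p = 3
print(torch.allclose(torch.mvlgamma(a, 3), mvlgamma_ref(a, 3)))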