Example #1
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits.clone(), value)
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     logits[(value == 0) & (logits == -float('inf'))] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
Example #2
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     y = (value - self.loc) / self.scale
     Z = (self.scale.log() +
          0.5 * self.df.log() +
          0.5 * math.log(math.pi) +
          torch.lgamma(0.5 * self.df) -
          torch.lgamma(0.5 * (self.df + 1.)))
     return -0.5 * (self.df + 1.) * torch.log1p(y**2. / self.df) - Z
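A quick way to sanity-check a hand-rolled density like the one above is to compare it with the corresponding torch.distributions implementation on a few points. The sketch below restates the formula with free df, loc and scale tensors in place of the distribution's attributes; it is an illustrative check, not part of the original class.

import math
import torch

def student_t_log_prob(value, df, loc, scale):
    # Same expression as above, written as a free function.
    y = (value - loc) / scale
    Z = (scale.log() + 0.5 * df.log() + 0.5 * math.log(math.pi) +
         torch.lgamma(0.5 * df) - torch.lgamma(0.5 * (df + 1.)))
    return -0.5 * (df + 1.) * torch.log1p(y**2. / df) - Z

df, loc, scale = torch.tensor(4.0), torch.tensor(1.0), torch.tensor(2.0)
value = torch.randn(5)
ref = torch.distributions.StudentT(df, loc, scale).log_prob(value)
assert torch.allclose(student_t_log_prob(value, df, loc, scale), ref, atol=1e-6)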
Example #3
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     log_factorial_n = math.lgamma(self.total_count + 1)
     log_factorial_k = torch.lgamma(value + 1)
     log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
     max_val = (-self.logits).clamp(min=0.0)
     # Note that: torch.log1p(-self.probs) = max_val - torch.log1p((self.logits + 2 * max_val).exp())
     return (log_factorial_n - log_factorial_k - log_factorial_nmk +
             value * self.logits + self.total_count * max_val -
             self.total_count * torch.log1p((self.logits + 2 * max_val).exp()))
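The commented identity is what keeps the normalizer finite for large |logits|. A standalone numerical check of it, using free tensors rather than the distribution's attributes:

import torch

# log1p(-probs) == max_val - log1p(exp(logits + 2 * max_val)), with max_val = clamp(-logits, min=0)
logits = torch.tensor([-8.0, -2.0, 0.0, 2.0, 8.0], dtype=torch.float64)
max_val = (-logits).clamp(min=0.0)
stable = max_val - torch.log1p((logits + 2 * max_val).exp())
assert torch.allclose(stable, torch.log1p(-torch.sigmoid(logits)))

# At extreme logits the sigmoid saturates and the direct form underflows to -inf,
# while the rewritten form stays finite.
big = torch.tensor([40.0])
print(torch.log1p(-torch.sigmoid(big)))                                               # tensor([-inf])
print((-big).clamp(min=0.0) - torch.log1p((big + 2 * (-big).clamp(min=0.0)).exp()))   # roughly -40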
Example #4
 def logprob(self, x):
   return torch.sum(
     self.shape * torch.log(self.rate)
     - torch.lgamma(self.shape)
     + (self.shape - 1) * torch.log(x)
     - self.rate * x
   )
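A hedged cross-check of the summed Gamma log-density above against torch.distributions.Gamma, assuming shape, rate and x are positive tensors of the same size:

import torch

shape = torch.tensor([2.0, 0.5, 3.0])
rate = torch.tensor([1.0, 2.0, 0.3])
x = torch.tensor([0.7, 1.4, 5.0])

handwritten = torch.sum(shape * torch.log(rate) - torch.lgamma(shape) +
                        (shape - 1) * torch.log(x) - rate * x)
reference = torch.distributions.Gamma(concentration=shape, rate=rate).log_prob(x).sum()
assert torch.allclose(handwritten, reference)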
Example #5
    def __init__(self, shape, rate, mean=None, variance=None):
        """
        Args:
            shape (torch.autograd.Variable): row vector of shape parameters
            rate (torch.autograd.Variable): row vector of rate parameters
        """
        super(InverseGamma, self).__init__()
        if mean is not None:
            assert variance is not None, "Must provide variance with mean"
            assert isinstance(mean, Variable), \
                "Provide mean as Pytorch Variable"
            assert isinstance(variance, Variable), \
                "Provide variance as Pytorch Variable"
            # Compute shape & rate from provided mean & var
            shape = th.pow(mean, 2) / variance + 2.0
            rate = th.pow(mean * (shape - 1.0), -1)
        assert isinstance(shape, Variable), "Provide shape as torch Variable"
        assert isinstance(rate, Variable), "Provide rate as torch Variable"
        if len(shape.size()) < 2:  # Given as 1d tensor
            shape = shape.view(1, -1)
        if len(rate.size()) < 2:  # Given as 1d tensor
            rate = rate.view(1, -1)
        self.shape = shape
        self.rate = rate
        self._coefficient = th.pow(rate, -shape) / \
                            Variable(th.exp(th.lgamma(shape.data)))
        self._dim = shape.size(1)

        # For RNG using gamma distribution:
        self._gamma_shape = self.shape.data.numpy().flatten()
        self._gamma_scale = self.rate.data.numpy().flatten()
Example #6
    def loss(self, partitions, lambdas, gamma):
        """
            Computes loss

            inputs:
                    partitions - torch.Tensor, size = (batch_size, seq_len, number of classes + 1)
                    lambdas - torch.Tensor, size = (batch_size, seq_len, number of classes), model output
                    gamma - torch.Tensor, size = (n_clusters, batch_size), probabilities p(k|x_n)

            outputs:
                    loss - torch.Tensor, size = (1), sum of output log likelihood weighted with convoluted gamma
                           and prior distribution log likelihood
        """
        # computing poisson parameters
        dts = partitions[:, 0, 0].to(self.device)
        dts = dts[None, :, None, None].to(self.device)
        tmp = lambdas * dts

        # preparing partitions
        p = partitions[None, :, :, 1:].to(self.device)

        # computing log likelihoods of every timestamp
        tmp1 = tmp - p * torch.log(tmp + self.epsilon) + torch.lgamma(p + 1)

        # computing log likelihoods of data points
        tmp2 = torch.sum(tmp1, dim=(2, 3))

        # computing loss
        tmp3 = gamma.to(self.device) * tmp2
        loss = torch.sum(tmp3)

        return loss
Example #7
    def eval_p_x_z(self, X, F, BG, noise='gamma'):

        target = X[:, X.shape[1] // 2][:, None]
        pred = F

        if BG is not None:
            bg_facs = BG.mean(-1).mean(-1)
            bg_map = BG.mean(0)
            bg_map /= bg_map.mean(-1).mean(-1)
            BG_res = bg_map[None, :, :] * bg_facs[:, None, None]
            pred = (pred.reshape([self.batch_size, self.n_samples,
                                  target.shape[-2], target.shape[-1]])
                    + BG_res[:, None] * self.ll_pars['backg_max'] + self.ll_pars['baseline'])
        else:
            pred = (pred.reshape([self.batch_size, self.n_samples,
                                  target.shape[-2], target.shape[-1]])
                    + self.ll_pars['backg'])

        if noise == 'poisson':
            return (target * torch.log(pred + 1e-6) - pred - torch.lgamma(target + 1)).sum(-1).sum(-1)

        if noise == 'gamma':
            target = torch.clamp(target - self.ll_pars['baseline'], 1, np.inf)
            pred = torch.clamp(pred - self.ll_pars['baseline'], 1, np.inf)
            k = pred / self.ll_pars['theta']
            return ((k - 1) * torch.log(target) - target / self.ll_pars['theta']
                    - k * torch.log(self.ll_pars['theta']) - torch.lgamma(k)).sum(-1).sum(-1)
Example #8
def NB_log_prob(x, mu, theta, eps=1e-8):
    """
    Adapted from https://github.com/YosefLab/scVI/blob/master/scvi/models/log_likelihood.py
    """

    log_theta_mu_eps = torch.log(theta + mu + eps)

    res = (
        theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + x * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )

    return res
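As a hedged sanity check, the (mu, theta) parametrization above should agree (up to the eps terms) with torch.distributions.NegativeBinomial, taking total_count = theta and logits = log(mu) - log(theta); the values below are arbitrary test inputs.

import torch

x = torch.tensor([0., 1., 3., 10.])
mu = torch.tensor([0.5, 2.0, 4.0, 7.0])
theta = torch.tensor([1.5, 0.8, 2.5, 1.2])

ours = NB_log_prob(x, mu, theta)
ref = torch.distributions.NegativeBinomial(
    total_count=theta, logits=torch.log(mu) - torch.log(theta)).log_prob(x)
assert torch.allclose(ours, ref, atol=1e-4)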
Example #9
 def log_prob(self, value):
     # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
     # The probability of a correlation matrix is proportional to
     #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
     # Additionally, the Jacobian of the transformation from Cholesky factor to
     # correlation matrix is:
     #   prod(L_ii ^ (D - i))
     # So the probability of a Cholesky factor is proportional to
     #   prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
     # with order_i = 2 * concentration - 2 + D - i
     if self._validate_args:
         self._validate_sample(value)
     diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
     order = torch.arange(2, self.dim + 1)
     order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
     unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
     # Compute normalization constant (page 1999 of [1])
     dm1 = self.dim - 1
     alpha = self.concentration + 0.5 * dm1
     denominator = torch.lgamma(alpha) * dm1
     numerator = torch.mvlgamma(alpha - 0.5, dm1)
     # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
     # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
     # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
     pi_constant = 0.5 * dm1 * math.log(math.pi)
     normalize_term = pi_constant + numerator - denominator
     return unnormalized_log_pdf - normalize_term
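A minimal usage sketch, assuming a PyTorch build that ships torch.distributions.LKJCholesky: draw a Cholesky factor of a random correlation matrix and evaluate its log-density with a log_prob of this form.

import torch

d = torch.distributions.LKJCholesky(dim=3, concentration=torch.tensor(2.0))
L = d.sample()        # lower-triangular; rows have unit norm, so L @ L.T has a unit diagonal
print(L @ L.T)        # the corresponding correlation matrix
print(d.log_prob(L))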
Example #10
 def log_prob(self, value):
     value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
     if self._validate_args:
         self._validate_sample(value)
     return (torch.xlogy(self.concentration, self.rate) +
             torch.xlogy(self.concentration - 1, value) -
             self.rate * value - torch.lgamma(self.concentration))
Example #11
        def exp_log_inverse_gamma(shape, exp_rate, exp_log_rate, exp_log_x,
                                  exp_x_inverse):
            """
            Calculates the expectation of the log of an inverse gamma distribution p under
            the posterior distribution q
            E_q[log p(x | shape, rate)]


            Args:
            shape: float, the shape parameter of the gamma distribution
            exp_rate: torch tensor, the expectation of the rate parameter under q
            exp_log_rate: torch tensor, the expectation of the log of the rate parameter under q
            exp_log_x: torch tensor, the expectation of the log of the random variable under q
            exp_x_inverse: torch tensor, the expectation of the inverse of the random variable under q

            Returns:
            exp_log: torch tensor, E_q[log p(x | shape, rate)]
            """
            exp_log = - torch.lgamma(shape) + shape * exp_log_rate - (shape + 1) * exp_log_x\
                      -exp_rate * exp_x_inverse

            # We need to sum over all components since this is a vectorized implementation.
            # That is, we compute the sum over the individual expected values. For example,
            # in the horseshoe BLR model we have one local shrinkage parameter for each weight
            # and therefore one expected value for each of these shrinkage parameters.
            return torch.sum(exp_log)
Example #12
 def log_norm(self):
     scale, shape, rates = self.params.scale, self.params.shape, \
         self.params.rates
     dim = rates.shape[-1]
     return (dim * torch.lgamma(shape) \
         - shape * rates.log().sum(dim=-1, keepdim=True) \
         - .5 * dim * scale.log()).sum(dim=-1)
Example #13
def log_radius(p_tot: float, p_k: float, x: torch.Tensor) -> torch.Tensor:
    """Computes the radius of the holes to introduce in the training data

    Args:
        p_tot (float): proportion of total expected points in B_1 U B_2 U ... B_K
        p_k (float): proportion of points sampled as {B_k}_{k=1}^{K} centers
        x (torch.Tensor): sets of points

    Returns:
        torch.Tensor: radius of any B_k
    """
    log_d = log_density(x)
    D = x.shape[1]

    # log_r = torch.pow(
    #     (p_tot / p_k)
    #     * (torch.lgamma(torch.Tensor([D / 2 + 1])) - log_d).exp()
    #     / (pi ** (D / 2)),
    #     1 / D,
    # )
    log_r = (
        1
        / D
        * (
            t_log(p_tot / p_k)
            + torch.lgamma(torch.Tensor([D / 2 + 1]))
            - log_d
            - t_log(pi ** (D / 2))
        )
    )
    return log_r
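Rearranged, the expression says the radius is chosen so that the local point density times the volume of a D-dimensional ball of radius r equals p_tot / p_k. A standalone check of that rearrangement, using a hypothetical stand-in value for log_density(x) (that helper is not shown here) and a plain log in place of t_log:

import math
import torch

D = 5
p_tot, p_k = 0.2, 0.05
log_d = torch.tensor(1.3)   # hypothetical log point-density value

log_r = (1 / D) * (math.log(p_tot / p_k) + torch.lgamma(torch.tensor([D / 2 + 1]))
                   - log_d - math.log(math.pi ** (D / 2)))
# Volume of a D-ball: pi^(D/2) * r^D / Gamma(D/2 + 1)
log_ball_volume = (D / 2) * math.log(math.pi) + D * log_r - torch.lgamma(torch.tensor([D / 2 + 1]))
assert torch.allclose((log_d + log_ball_volume).exp(), torch.tensor([p_tot / p_k]))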
Example #14
 def _log_norm(self, natural_parameters=None):
     if natural_parameters is None:
         natural_parameters = self.natural_parameters
     means, scales, shape, rate = self.to_std_parameters(natural_parameters)
     dim = means.shape[1]
     return torch.lgamma(shape) - shape * rate.log() \
         - .5 * dim * scales.log().sum()
Example #15
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     log_factorial_n = torch.lgamma(self.total_count + 1)
     log_factorial_k = torch.lgamma(value + 1)
     log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
     # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
     #     (case logit < 0)              = k * logit - n * log1p(e^logit)
     #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
     #                                   = k * logit - n * logit - n * log1p(e^-logit)
     #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
     normalize_term = (self.total_count * _clamp_by_zero(self.logits) +
                       self.total_count *
                       torch.log1p(torch.exp(-torch.abs(self.logits))) -
                       log_factorial_n)
     return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
Example #16
    def log_prob(self, value):
        """
        Log-density of a multivariate Student-t with diagonal scale; the commented-out
        block below is the equivalent dense-covariance computation.
        """

        local_loc, local_scale, value = torch.broadcast_tensors(
            self.loc, self.scale, value
        )
        # local_cov = self.scale_to_cov(local_scale)
        # logdet = torch.logdet(local_cov)
        # diff = (value - local_loc).unsqueeze(-1)
        # isoval = torch.matmul(local_cov.inverse(), diff)
        # isoval = torch.matmul(diff.transpose(-1, -2), isoval)
        # isoval = isoval.squeeze(-1)
        # isoval = isoval.squeeze(-1)

        # central_term = (
        #     -0.5 * (self.d + self.df) * (1.0 + (1.0 / self.df * isoval)).log()
        # )
        # res = (
        #     torch.lgamma(0.5 * (self.df + self.d))
        #     - torch.lgamma(0.5 * self.df)
        #     - 0.5 * self.d * (self.df.log() + np.log(np.pi))
        #     - 0.5 * logdet
        #     + central_term
        # )

        local_cov = (local_scale ** 2.0).unsqueeze(-1)
        logdet = local_scale.log().sum(-1, keepdims=True)
        logdet = 2.0 * logdet
        diff = (value - local_loc).unsqueeze(-1)
        isoval = (1.0 / local_cov) * diff
        isoval = torch.matmul(diff.transpose(-1, -2), isoval)
        isoval = isoval.squeeze(-1)
        # isoval = isoval.squeeze(-1)

        central_term = (
            -0.5 * (self.d + self.df) * (1.0 + (1.0 / self.df * isoval)).log()
        )
        res = (
            torch.lgamma(0.5 * (self.df + self.d))
            - torch.lgamma(0.5 * self.df)
            - 0.5 * self.d * (self.df.log() + np.log(np.pi))
            - 0.5 * logdet
            + central_term
        )
        return res.squeeze(-1)
Example #17
 def pdfvectors_from_rvectors(self, rvecs):
     dim = rvecs.shape[-1] // 2
     shape = (EPS + rvecs[:, :dim]).exp()
     rate = (EPS + rvecs[:, dim:]).exp()
     lnorm = torch.lgamma(shape) - shape * torch.log(rate)
     lnorm = torch.sum(lnorm, dim=-1, keepdim=True)
     retval = torch.cat([-rate, shape - 1, -lnorm], dim=-1)
     return retval
Example #18
 def __log_surface_area(self):
     '''
     :return: natural log of the surface area of the unit-radius, (self._dim + 1)-dimensional hypersphere
     '''
     # torch.lgamma(): the natural log of the gamma function
     return math.log(2) + (
         (self._dim + 1) / 2) * math.log(math.pi) - torch.lgamma(
             torch.tensor([(self._dim + 1) / 2], device=self.device))
Example #19
    def existing_method(x, mu, theta, pi, eps=1e-8):
        case_zero = (F.softplus((- pi + theta * torch.log(theta + eps) - theta * torch.log(theta + mu + eps))) -
                     F.softplus(-pi))

        case_non_zero = - pi - F.softplus(-pi) + \
            theta * torch.log(theta + eps) - \
            theta * torch.log(theta + mu + eps) + \
            x * torch.log(mu + eps) - \
            x * torch.log(theta + mu + eps) + \
            torch.lgamma(x + theta) - \
            torch.lgamma(theta) - \
            torch.lgamma(x + 1)

        res = torch.mul((x < eps).type(torch.float32), case_zero) + \
            torch.mul((x > eps).type(torch.float32), case_non_zero)

        return torch.sum(res, dim=-1)
Example #20
def kl_divergence(alpha, num_classes, device=None):
    if not device:
        device = get_device()
    beta = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    S_alpha = torch.sum(alpha, dim=1, keepdim=True)
    S_beta = torch.sum(beta, dim=1, keepdim=True)
    lnB = torch.lgamma(S_alpha) - \
        torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
    lnB_uni = torch.sum(torch.lgamma(beta), dim=1,
                        keepdim=True) - torch.lgamma(S_beta)

    dg0 = torch.digamma(S_alpha)
    dg1 = torch.digamma(alpha)

    kl = torch.sum(
        (alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni
    return kl
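Hedged cross-check: the quantity above is KL(Dirichlet(alpha) || Dirichlet(1, ..., 1)), so on CPU it should match PyTorch's registered Dirichlet-to-Dirichlet KL divergence (kl_divergence below refers to the function defined above).

import torch
from torch.distributions import Dirichlet
from torch.distributions.kl import kl_divergence as torch_kl

alpha = torch.tensor([[2.0, 0.5, 3.0, 1.2]])
ours = kl_divergence(alpha, num_classes=alpha.shape[1], device="cpu")   # the function above
ref = torch_kl(Dirichlet(alpha), Dirichlet(torch.ones_like(alpha)))
assert torch.allclose(ours.squeeze(-1), ref, atol=1e-5)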
Example #21
File: prme.py Project: zan12/prme
 def global_bound(self):
    ln_p_ell = (-self.D_ell * self.K / 2 * np.log(2 * np.pi * self.b0)
                - torch.norm(self.ell.data).pow(2) / 2 / self.b0)
    ln_p_v = ((self.alpha0 - 1) * torch.sum(torch.log(1 - self.v))
              + (self.K - 1) * (np.log(1 + self.alpha0) - np.log(1) - np.log(self.alpha0)))
    E_ln_eta = torch.digamma(self.gamma) - torch.digamma(
        torch.mm(torch.sum(self.gamma, dim=1, keepdim=True),
                 torch.ones(1, self.D_vocab).to(self.device)))
    E_ln_p_eta = (self.K * gammaln(self.D_vocab * self.gamma0)
                  - self.K * self.D_vocab * gammaln(self.gamma0)
                  + (self.gamma0 - 1) * torch.sum(E_ln_eta))
    H_eta = (-torch.sum(torch.lgamma(torch.sum(self.gamma, dim=1)))
             + torch.sum(torch.lgamma(self.gamma))
             - torch.sum((self.gamma - 1) * E_ln_eta))
    l_global = (ln_p_ell.float().to(self.device) + ln_p_v.data
                + E_ln_p_eta.float().to(self.device) + H_eta)
    return l_global
Example #22
    def forward(self, y_pred, y_true):
        eps = 1e-6
        y_pred = y_pred.view(y_pred.shape[0], -1)
        y_true = y_true.view(y_true.shape[0], -1)

        p_l = -y_true + y_pred * torch.log(y_true +
                                           eps) - torch.lgamma(y_pred + 1)
        return -torch.mean(p_l)
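A standalone check with free tensors: the per-element term above is the Poisson log-pmf with rate y_true and count y_pred (note the roles relative to the more common convention of predicting the rate), so up to the eps shift the loss equals the mean negative log-probability under torch.distributions.Poisson.

import torch

eps = 1e-6
y_true = torch.tensor([[1.0, 4.0, 0.5, 2.0]])   # used as the Poisson rate
y_pred = torch.tensor([[2.0, 3.0, 1.0, 0.0]])   # used as the observed count
p_l = -y_true + y_pred * torch.log(y_true + eps) - torch.lgamma(y_pred + 1)
ref = torch.distributions.Poisson(y_true + eps).log_prob(y_pred)
assert torch.allclose(-p_l.mean(), -ref.mean(), atol=1e-5)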
Example #23
    def loss(self, ops, y, **kwargs):
        alpha, beta, gamma, v = ops
        twoBlambda = 2 * beta * (1 + v)
        error = torch.abs(y - gamma)

        nll = 0.5 * torch.log(np.pi / v) \
              - alpha * torch.log(twoBlambda) \
              + (alpha + 0.5) * torch.log(v * (y - gamma) ** 2 + twoBlambda) \
              + torch.lgamma(alpha) \
              - torch.lgamma(alpha + 0.5)

        if self.kl_reg:
            kl = self.get_reg_kl(**kwargs)  # TODO: add support for this
            reg = error * kl
        else:
            reg = error * (2 * v + alpha)
        return (nll + self.reg_coefficient * reg).mean()
Example #24
def log_zinb_positive(x: torch.Tensor,
                      mu: torch.Tensor,
                      theta: torch.Tensor,
                      pi: torch.Tensor,
                      eps=1e-8):
    """
    Log likelihood (scalar) of a minibatch according to a zinb model.
    Parameters
    ----------
    x
        Data
    mu
        mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
    theta
        inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
    pi
        logit of the dropout parameter (real support) (shape: minibatch x vars)
    eps
        numerical stability constant
    Notes
    -----
    We parametrize the bernoulli using the logits, hence the softplus functions appearing.
    """
    # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless of batch or labels)
    if theta.ndimension() == 1:
        theta = theta.view(
            1,
            theta.size(0))  # In this case, we reshape theta for broadcasting

    softplus_pi = F.softplus(-pi)  #  uses log(sigmoid(x)) = -softplus(-x)
    log_theta_eps = torch.log(theta + eps)
    log_theta_mu_eps = torch.log(theta + mu + eps)
    pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps)

    case_zero = F.softplus(pi_theta_log) - softplus_pi
    mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero)

    case_non_zero = (-softplus_pi + pi_theta_log + x *
                     (torch.log(mu + eps) - log_theta_mu_eps) +
                     torch.lgamma(x + theta) - torch.lgamma(theta) -
                     torch.lgamma(x + 1))
    mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero)

    res = mul_case_zero + mul_case_non_zero

    return res
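A hedged sanity check (reusing the imports above and NB_log_prob from Example #8, or any equivalent NB log-likelihood): driving the dropout logit pi to a large negative value switches the zero-inflation component off, so the ZINB log-likelihood should collapse to the plain negative binomial one.

import torch

x = torch.tensor([[0., 2., 5.], [1., 0., 3.]])
mu = torch.rand_like(x) * 5 + 0.1
theta = torch.rand_like(x) * 2 + 0.5
pi = torch.full_like(x, -30.0)   # dropout probability sigmoid(-30) is effectively zero

zinb_ll = log_zinb_positive(x, mu, theta, pi)
nb_ll = NB_log_prob(x, mu, theta)
assert torch.allclose(zinb_ll, nb_ll, atol=1e-4)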
Example #25
def zinb(x: torch.Tensor,
         mu: torch.Tensor,
         theta: torch.Tensor,
         pi: torch.Tensor,
         eps=1e-8):
    """Computes zero inflated negative binomial loss.

       Parameters
       ----------
       x: torch.Tensor
            Torch Tensor of ground truth data.
       mu: torch.Tensor
            Torch Tensor of means of the negative binomial (has to be positive support).
       theta: torch.Tensor
            Torch Tensor of inverses dispersion parameter (has to be positive support).
       pi: torch.Tensor
            Torch Tensor of logits of the dropout parameter (real support)
       eps: Float
            numerical stability constant.

       Returns
       -------
       If 'mean' is 'True' ZINB loss value gets returned, otherwise Torch tensor of losses gets returned.
    """
    # theta is the dispersion rate. If .ndimension() == 1, it is shared for all cells (regardless of batch or labels)
    if theta.ndimension() == 1:
        theta = theta.view(
            1,
            theta.size(0))  # In this case, we reshape theta for broadcasting

    softplus_pi = F.softplus(-pi)  #  uses log(sigmoid(x)) = -softplus(-x)
    log_theta_eps = torch.log(theta + eps)
    log_theta_mu_eps = torch.log(theta + mu + eps)
    pi_theta_log = -pi + theta * (log_theta_eps - log_theta_mu_eps)

    case_zero = F.softplus(pi_theta_log) - softplus_pi
    mul_case_zero = torch.mul((x < eps).type(torch.float32), case_zero)

    case_non_zero = (-softplus_pi + pi_theta_log + x *
                     (torch.log(mu + eps) - log_theta_mu_eps) +
                     torch.lgamma(x + theta) - torch.lgamma(theta) -
                     torch.lgamma(x + 1))
    mul_case_non_zero = torch.mul((x > eps).type(torch.float32), case_non_zero)

    res = mul_case_zero + mul_case_non_zero
    return res
Example #26
def _ggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    gamma = torch.arange(0.2, 10 + 0.001, 0.001).to(x)
    r_table = (torch.lgamma(1. / gamma) + torch.lgamma(3. / gamma) - 2 * torch.lgamma(2. / gamma)).exp()
    r_table = r_table.repeat(x.size(0), 1)

    sigma_sq = x.pow(2).mean(dim=(-1, -2))
    sigma = sigma_sq.sqrt().squeeze(dim=-1)

    assert not torch.isclose(sigma, torch.zeros_like(sigma)).all(), \
        'Expected image with non zero variance of pixel values'

    E = x.abs().mean(dim=(-1, -2))
    rho = sigma_sq / E ** 2

    indexes = (rho - r_table).abs().argmin(dim=-1)
    solution = gamma[indexes]
    return solution, sigma
Example #27
 def log_prob(self, value):
     if self.use_pykeops:
         formula = "wj+Log(Step(xi-gj-IntCst(1)))+(ai-IntCst(1))*Log(IfElse(xi-gj-IntCst(1),xi-gj,xi))-bi*(xi-gj)"
         variables = [
             "wj = Vj(1)",
             "gj = Vj(1)",
             "ai = Vi(1)",
             "bi = Vi(1)",
             "xi = Vi(1)",
         ]
         dtype = self.concentration.dtype
         my_routine = Genred(
             formula,
             variables,
             reduction_op="LogSumExp",
             axis=1,
         )
         concentration, value, rate = torch.broadcast_tensors(
             self.concentration, value, self.rate)
         shape = value.shape
         result = my_routine(
             self.offset_logits.reshape(-1, 1),
             self.offset_samples.reshape(-1, 1).to(dtype),
             concentration.reshape(-1, 1),
             rate.reshape(-1, 1).contiguous(),
             value.reshape(-1, 1).to(dtype),
             backend=self.device_pykeops,
         )
         result = result.reshape(shape)
         result = (self.concentration * torch.log(self.rate) -
                   torch.lgamma(self.concentration) + result)
     else:
         value = torch.as_tensor(value).unsqueeze(-1)
         concentration = self.concentration.unsqueeze(-1)
         mask = value > self.offset_samples
         new_value = torch.where(mask, value - self.offset_samples,
                                 value.new_ones(()))
         obs_logits = (concentration * torch.log(self.rate) +
                       (concentration - 1) * torch.log(new_value) -
                       self.rate * (new_value) -
                       torch.lgamma(concentration))
         result = obs_logits + self.offset_logits + torch.log(mask)
         result = torch.logsumexp(result, -1)
     event_dims = tuple(-i for i in range(1, len(self.event_shape) + 1))
     return result.sum(event_dims)
Example #28
def log_mixture_nb(x, mu_1, mu_2, theta_1, theta_2, pi, eps=1e-8):
    """
    Note: All inputs should be torch Tensors
    log likelihood (scalar) of a minibatch according to a mixture nb model.
    pi is the probability to be in the first component.

    For totalVI, the first component should be background.

    Variables:
    mu1: mean of the first negative binomial component (has to be positive support) (shape: minibatch x genes)
    theta1: first inverse dispersion parameter (has to be positive support) (shape: minibatch x genes)
    mu2: mean of the second negative binomial (has to be positive support) (shape: minibatch x genes)
    theta2: second inverse dispersion parameter (has to be positive support) (shape: minibatch x genes)
        If None, assume one shared inverse dispersion parameter.
    eps: numerical stability constant
    """
    if theta_2 is not None:
        log_nb_1 = log_nb_positive(x, mu_1, theta_1)
        log_nb_2 = log_nb_positive(x, mu_2, theta_2)
    # this is intended to reduce repeated computations
    else:
        theta = theta_1
        if theta.ndimension() == 1:
            theta = theta.view(1, theta.size(
                0))  # In this case, we reshape theta for broadcasting

        log_theta_mu_1_eps = torch.log(theta + mu_1 + eps)
        log_theta_mu_2_eps = torch.log(theta + mu_2 + eps)
        lgamma_x_theta = torch.lgamma(x + theta)
        lgamma_theta = torch.lgamma(theta)
        lgamma_x_plus_1 = torch.lgamma(x + 1)

        log_nb_1 = (theta * (torch.log(theta + eps) - log_theta_mu_1_eps) + x *
                    (torch.log(mu_1 + eps) - log_theta_mu_1_eps) +
                    lgamma_x_theta - lgamma_theta - lgamma_x_plus_1)
        log_nb_2 = (theta * (torch.log(theta + eps) - log_theta_mu_2_eps) + x *
                    (torch.log(mu_2 + eps) - log_theta_mu_2_eps) +
                    lgamma_x_theta - lgamma_theta - lgamma_x_plus_1)

    logsumexp = torch.logsumexp(torch.stack((log_nb_1, log_nb_2 - pi)), dim=0)
    softplus_pi = F.softplus(-pi)

    log_mixture_nb = logsumexp - softplus_pi

    return log_mixture_nb
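The logsumexp/softplus combination at the end implements a numerically stable two-component mixture whose first-component weight is sigmoid(pi), i.e. pi enters as a logit. A standalone check of that identity on arbitrary log-likelihood values:

import torch
import torch.nn.functional as F

# With w = sigmoid(pi): log(w * exp(a) + (1 - w) * exp(b)) == logsumexp([a, b - pi]) - softplus(-pi)
pi = torch.tensor([-1.5, 0.0, 2.0])
a = torch.tensor([-3.0, -1.0, -0.5])
b = torch.tensor([-2.0, -4.0, -1.5])
w = torch.sigmoid(pi)
direct = torch.log(w * a.exp() + (1 - w) * b.exp())
trick = torch.logsumexp(torch.stack((a, b - pi)), dim=0) - F.softplus(-pi)
assert torch.allclose(direct, trick, atol=1e-6)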
Example #29
def log_bnb_positive(x, mu, theta, alpha, eps=1e-8):  #
    """
    x
    mu: px_rate=l*scale, scale==Softmax(Linear())
    theta: px_r
    px_alpha
    
    
    Note: All inputs should be torch Tensors
    log likelihood (scalar) of a minibatch according to a nb model.

    Variables:
    mu: mean of the negative binomial (has to be positive support) (shape: minibatch x genes)
    theta: inverse dispersion parameter (has to be positive support) (shape: minibatch x genes)
    alpha: alpha+1 of beta distribution
    eps: numerical stability constant
    """
    if theta.ndimension() == 1:
        theta = theta.view(
            1,
            theta.size(0))  # In this case, we reshape theta for broadcasting

    beta = mu * alpha / (theta + eps)

    def log_beta_func(a, b):
        return torch.lgamma(a + eps) + torch.lgamma(b +
                                                    eps) - torch.lgamma(a + b +
                                                                        eps)

    '''res = (
        torch.lgamma(alpha+theta+eps)
        - torch.lgamma(theta+eps)
        - log_beta_func(alpha, beta)
        + (theta-1)*torch.log(x+eps)
        - (theta+alpha)*torch.log(beta+x)
    )'''
    res = (torch.lgamma(x + theta + eps) - torch.lgamma(theta + eps) -
           torch.lgamma(x + 1 + eps) - log_beta_func(alpha - 1, beta) +
           log_beta_func(alpha - 1 + theta, beta + x))

    #print('input', alpha.shape, beta.shape, alpha.max(), alpha.min(), beta.max(), beta.min())
    #print('log bnb', torch.lgamma(alpha+theta+eps), torch.lgamma(theta+1+eps), torch.lgamma(x+eps), log_beta_func(alpha, beta), log_beta_func(alpha+x, beta+x))
    #print('sum', torch.sum(res, dim=-1))

    return torch.sum(res, dim=-1)  #/alpha.shape[1]
Example #30
File: loss.py Project: dnbaker/mdn
def nb_loss(x, mu, theta, eps=EPS):
    """
    log likelihood (scalar) of a minibatch according to a nb model.
    Variables:
    mu: mean of the negative binomial (has to be positive support) (shape: minibatch x genes)
    theta: inverse dispersion parameter (has to be positive support) (shape: minibatch x genes)
    eps: numerical stability constant
    """
    if theta.ndimension() == 1:
        theta = theta.unsqueeze(0)

    log_theta_mu_eps = torch.log(theta + mu + eps)

    res = (theta * (torch.log(theta + eps) - log_theta_mu_eps) + x *
           (torch.log(mu + eps) - log_theta_mu_eps) + torch.lgamma(x + theta) -
           torch.lgamma(theta) - torch.lgamma(x + 1))

    return torch.sum(res, dim=-1)
Example #31
    def log_mu_inner_sum(sigma, c, dim, k):
        dim = __to_tensor__(dim)

        a = torch.lgamma(dim) - torch.lgamma(dim - k) - torch.lgamma(k + 1)
        b = (dim - 1 - 2 * k).pow(2) * c * sigma.pow(2) / 2

        # Integral
        mu = (dim - 1 - 2 * k) * torch.sqrt(c) * sigma.pow(2)

        int_r_a = 2 * sigma**2 * \
            torch.exp(-mu**2 / (2 * sigma**2))

        int_r_b = np.sqrt(np.pi / 2) * mu * sigma * \
            (1 + torch.erf(mu / (np.sqrt(2) * sigma)))

        log_int_r = torch.log(int_r_a + int_r_b)

        return a + b + log_int_r
Example #32
File: xedl.py Project: hsljc/ae-dnn
def kl_dirichlet(alpha, beta):
    """Computes the KL-Divergence between two dirichlet distributions.

    Args:
        alpha: (N x K)-Tensor where the K-dimension describes the parameters of the dirichlet.
        beta: (N x K)-Tensor where the K-dimension describes the parameters of the dirichlet.
    """
    alpha_0 = torch.sum(alpha, dim=-1, keepdim=True)
    beta_0 = torch.sum(beta, dim=-1, keepdim=True)
    t1 = torch.lgamma(alpha_0) - torch.sum(
        torch.lgamma(alpha), dim=-1, keepdim=True)
    t2 = torch.lgamma(beta_0) - torch.sum(
        torch.lgamma(beta), dim=-1, keepdim=True)
    t3 = torch.sum(
        (alpha - beta) * (torch.digamma(alpha) - torch.digamma(alpha_0)),
        dim=-1,
        keepdim=True)
    return t1 - t2 + t3
Example #33
File: lkj.py Project: ucals/pyro
    def lkj_constant(self, eta, K):
        if self._lkj_constant is not None:
            return self._lkj_constant

        Km1 = K - 1

        constant = torch.lgamma(eta.add(0.5 * Km1)).mul(Km1)

        k = torch.linspace(start=1,
                           end=Km1,
                           steps=Km1,
                           dtype=eta.dtype,
                           device=eta.device)
        constant -= (k.mul(math.log(math.pi) * 0.5) +
                     torch.lgamma(eta.add(0.5 * (Km1 - k)))).sum()

        self._lkj_constant = constant
        return constant
Example #34
File: kl.py Project: lxlhh/pytorch
 def f(a, b, c, d):
     return -d * a / c + b * a.log() - torch.lgamma(b) + (b - 1) * torch.digamma(d) + (1 - b) * c.log()
Example #35
 def entropy(self):
     return (self.alpha - torch.log(self.beta) + torch.lgamma(self.alpha) +
             (1.0 - self.alpha) * torch.digamma(self.alpha))
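A quick check of this closed form against torch.distributions.Gamma's entropy, with free alpha/beta tensors standing in for the attributes:

import torch

alpha = torch.tensor([0.5, 1.0, 3.0])
beta = torch.tensor([2.0, 0.7, 1.5])
handwritten = (alpha - torch.log(beta) + torch.lgamma(alpha) +
               (1.0 - alpha) * torch.digamma(alpha))
assert torch.allclose(handwritten, torch.distributions.Gamma(alpha, beta).entropy())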
Example #36
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     return (self.alpha * torch.log(self.beta) +
             (self.alpha - 1) * torch.log(value) -
             self.beta * value - torch.lgamma(self.alpha))
Example #37
 def _log_normalizer(self, x, y):
     return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
Example #38
 def _log_normalizer(self, x, y):
     return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
Example #39
 def entropy(self):
     k = self.alpha.size(-1)
     a0 = self.alpha.sum(-1)
     return (torch.lgamma(self.alpha).sum(-1) - torch.lgamma(a0) -
             (k - a0) * torch.digamma(a0) -
             ((self.alpha - 1.0) * torch.digamma(self.alpha)).sum(-1))
Example #40
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     return (self.concentration * torch.log(self.rate) +
             (self.concentration - 1) * torch.log(value) -
             self.rate * value - torch.lgamma(self.concentration))
Example #41
 def digamma(x):
     """Finite difference approximation of digamma."""
     eps = x * 0.01
     return (torch.lgamma(x + eps) - torch.lgamma(x - eps)) / (2 * eps)
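For intuition, the central difference above can be compared directly with torch.digamma; this is a standalone restatement of the helper, not part of it.

import torch

x = torch.tensor([0.5, 1.0, 3.0, 10.0], dtype=torch.float64)
eps = x * 0.01
approx = (torch.lgamma(x + eps) - torch.lgamma(x - eps)) / (2 * eps)
print(approx - torch.digamma(x))   # small residuals, roughly 1e-5 to 1e-4 here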
Example #42
 def _log_normalizer(self, x):
     return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
Example #43
 def entropy(self):
     k = self.concentration.size(-1)
     a0 = self.concentration.sum(-1)
     return (torch.lgamma(self.concentration).sum(-1) - torch.lgamma(a0) -
             (k - a0) * torch.digamma(a0) -
             ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1))
Example #44
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4
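Hedged cross-check against PyTorch's registered Gamma-to-Gamma KL divergence, assuming _kl_gamma_gamma above is in scope:

import torch
from torch.distributions import Gamma, kl_divergence

p = Gamma(torch.tensor([2.0, 0.5]), torch.tensor([1.0, 3.0]))
q = Gamma(torch.tensor([1.5, 1.0]), torch.tensor([2.0, 0.8]))
assert torch.allclose(_kl_gamma_gamma(p, q), kl_divergence(p, q))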
Example #45
 def entropy(self):
     lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1))
     return (self.scale.log() +
             0.5 * (self.df + 1) *
             (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) +
             0.5 * self.df.log() + lbeta)
Example #46
 def entropy(self):
     return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) +
             (1.0 - self.concentration) * torch.digamma(self.concentration))
Example #47
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     return ((torch.log(value) * (self.concentration - 1.0)).sum(-1) +
             torch.lgamma(self.concentration.sum(-1)) -
             torch.lgamma(self.concentration).sum(-1))
Example #48
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     return ((torch.log(value) * (self.alpha - 1.0)).sum(-1) +
             torch.lgamma(self.alpha.sum(-1)) -
             torch.lgamma(self.alpha).sum(-1))