Example #1
 def _log_v(self):
     c = self.stickbreaking.posterior.params.concentrations
     c = c.reshape(self.n_components, -1, 2)[:, self.ordering, :]
     s_dig = torch.digamma(c.sum(dim=-1))
     log_v = torch.digamma(c[:, :, 0]) - s_dig
     log_1_v = torch.digamma(c[:, :, 1]) - s_dig
     return log_v, log_1_v
Example #2
    def uncertainty_metrics(logits):
        """Calculates mutual info, entropy of expected, and expected entropy, EPKL and Differential Entropy uncertainty metrics for
        the data x."""
        alphas = torch.exp(logits)
        alpha0 = torch.sum(alphas, dim=1, keepdim=True)
        probs = alphas / alpha0

        epkl = torch.squeeze((alphas.size(1) - 1.0) / alpha0)  # expected pairwise KL: (K - 1) / alpha_0

        dentropy = torch.sum(
            torch.lgamma(alphas) - (alphas - 1) * (torch.digamma(alphas) - torch.digamma(alpha0)),
            dim=1, keepdim=True) - torch.lgamma(alpha0)  # keepdim so lgamma(alpha0) broadcasts per sample

        conf = torch.max(probs, dim=1)[0]

        expected_entropy = -torch.sum(
            (alphas / alpha0) * (torch.digamma(alphas + 1) - torch.digamma(alpha0 + 1)),
            dim=1)
        entropy_of_exp = categorical_entropy_torch(probs)
        mutual_info = entropy_of_exp - expected_entropy

        uncertainties = {'confidence': conf,
                         'entropy_of_expected': entropy_of_exp,
                         'expected_entropy': expected_entropy,
                         'mutual_information': mutual_info,
                         'EPKL': epkl,
                         'differential_entropy': torch.squeeze(dentropy),
                         }

        return uncertainties
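A minimal usage sketch for the snippet above, treating uncertainty_metrics as a plain function (the indentation suggests it lives inside a class, so that is an assumption); the helper categorical_entropy_torch is not shown in the snippet, and the definition below (the usual -sum p log p) is only a guess at what it computes.

import torch

def categorical_entropy_torch(probs, epsilon=1e-10):
    # assumed helper: entropy of the expected categorical distribution
    return -torch.sum(probs * torch.log(probs + epsilon), dim=1)

logits = torch.randn(4, 3)            # fake Dirichlet logits, batch of 4, 3 classes
metrics = uncertainty_metrics(logits)
for name, value in metrics.items():
    print(name, value.shape)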
Example #3
 def kl_loss(self, alpha, beta):
     return torch.mean(torch.sum(
             self.prior_logbeta-logbeta(alpha, beta)
             + (alpha-self.alpha_prior)*torch.digamma(alpha)
             + (beta-self.beta_prior)*torch.digamma(beta)
             + (self.alpha_prior-alpha+self.beta_prior-beta)*torch.digamma(alpha+beta)
             , dim=1))
Example #4
def dirichlet_prior_network_uncertainty(logits, epsilon=1e-10):
    # based on original code from prior networks, see https://github.com/KaosEngineer/PriorNetworks/
    alphas = torch.exp(logits)
    alpha0 = torch.sum(alphas, axis=1, keepdim=True)
    probs = alphas / alpha0

    conf = torch.max(probs, axis=1)[0]

    entropy_of_exp = -torch.sum(probs * torch.log(probs + epsilon), axis=1)
    expected_entropy = -torch.sum(
        (alphas / alpha0) *
        (torch.digamma(alphas + 1) - torch.digamma(alpha0 + 1.0)),
        axis=1)
    mutual_info = entropy_of_exp - expected_entropy

    epkl = torch.squeeze((alphas.shape[1] - 1.0) / alpha0)

    dentropy = torch.sum(torch.lgamma(alphas) - (alphas - 1.0) *
                         (torch.digamma(alphas) - torch.digamma(alpha0)),
                         axis=1,
                         keepdim=True) - torch.lgamma(alpha0)

    uncertainty = {
        'confidence': conf,
        'entropy_of_expected': entropy_of_exp,
        'expected_entropy': expected_entropy,
        'mutual_information': mutual_info,
        'EPKL': epkl,
        'differential_entropy': torch.squeeze(dentropy),
    }

    return uncertainty
Example #5
def dirichlet_kl_divergence(alphas,
                            target_alphas,
                            precision=None,
                            target_precision=None,
                            epsilon=1e-8):
    # based on original code from prior networks, see https://github.com/KaosEngineer/PriorNetworks/

    if precision is None:
        precision = torch.sum(alphas, dim=1, keepdim=True)
    if target_precision is None:
        target_precision = torch.sum(target_alphas, dim=1, keepdim=True)

    precision_term = torch.lgamma(target_precision) - torch.lgamma(precision)
    assert torch.all(torch.isfinite(precision_term)).item()
    alphas_term = torch.sum(torch.lgamma(alphas + epsilon) -
                            torch.lgamma(target_alphas + epsilon) +
                            (target_alphas - alphas) *
                            (torch.digamma(target_alphas + epsilon) -
                             torch.digamma(target_precision + epsilon)),
                            dim=1,
                            keepdim=True)
    assert torch.all(torch.isfinite(alphas_term)).item()

    cost = torch.squeeze(precision_term + alphas_term)
    return cost
Example #6
    def update_q_z(self):
        '''
            Psi : Noise matrix (N X 1)
            E_ln_v
            E_ln_1_minus_v
            ln_rho : (M X N) matrix
        '''
        Psi = self.new_Psi()

        E_ln_v = torch.digamma(
            self.v_beta_a) - torch.digamma(self.v_beta_a + self.v_beta_b)
        E_ln_1_minus_v = torch.digamma(
            self.v_beta_b) - torch.digamma(self.v_beta_a + self.v_beta_b)

        tmp_sum = torch.zeros(self.M)
        for m in range(0, self.M):
            tmp_sum[m] += E_ln_v[m]
            for i in np.arange(0, m):
                tmp_sum[m] += E_ln_1_minus_v[i]

        ln_rho = -0.5 * ((((self.Y.repeat(self.M,1,1)-self.q_f_mean)**2)/Psi).sum(2) \
                    + torch.stack([torch.diag(self.q_f_sig[m]/Psi) for m in range(self.M)]) \
                    + self.D*torch.log(np.pi*2*Psi).repeat(self.M,1,1).sum(2)) \
                    + (tmp_sum).repeat(self.N,1).T

        self.q_z_pi = torch.exp(ln_rho)
        self.q_z_pi /= self.q_z_pi.sum(0)[None, :]
        self.q_z_pi[torch.isnan(self.q_z_pi)] = 1.0 / self.M
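For reference (not part of the original code), the quantities accumulated in tmp_sum are the standard stick-breaking expectations under the Beta posteriors q(v_m) = Beta(a_m, b_m), with a_m = v_beta_a[m] and b_m = v_beta_b[m]:

\[
\mathbb{E}[\ln v_m] = \psi(a_m) - \psi(a_m + b_m), \qquad
\mathbb{E}[\ln(1 - v_m)] = \psi(b_m) - \psi(a_m + b_m),
\]
\[
\texttt{tmp\_sum}[m] = \mathbb{E}[\ln \pi_m] = \mathbb{E}[\ln v_m] + \sum_{i < m} \mathbb{E}[\ln(1 - v_i)].
\]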
Example #7
    def _E_log_stick(tau, K):
        """
        @param tau: (K, 2)
        @return: ((K,), (K, K))

        where the first return value is E_log_stick, and the second is q
        """
        # we use the same indexing as in eq. (10)
        q = torch.zeros(K, K)

        # working in log space until the last step
        first_term = digamma(tau[:, 1])
        second_term = digamma(tau[:, 0]).cumsum(0) - digamma(tau[:, 0])
        third_term = digamma(tau.sum(1)).cumsum(0)
        q += (first_term + second_term - third_term).view(1, -1)
        q = torch.tril(q.exp())
        q = torch.nn.functional.normalize(q, p=1, dim=1)
        # NOTE: we should definitely detach q, since it's a computational aid
        # (i.e. already optimized to make our lower bound better)

        assert (q.sum(1) - torch.ones(K)
                ).abs().max().item() < 1e-6, "WTF normalize didn't work"
        q = q.detach()

        torch_e_logstick = InfiniteIBP._E_log_stick_from_q(q, tau)
        return torch_e_logstick, q
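Spelled out (writing tau[m, 0] as \(\tau_{m,1}\) and tau[m, 1] as \(\tau_{m,2}\)), the logits assembled from first_term, second_term and third_term are the usual multinomial lower-bound weights for the stick-breaking construction, which the cumulative sums, torch.tril and the row normalization then implement:

\[
q_{k,i} \propto \exp\!\Big( \psi(\tau_{i,2}) + \sum_{m < i} \psi(\tau_{m,1}) - \sum_{m \le i} \psi(\tau_{m,1} + \tau_{m,2}) \Big), \qquad i \le k.
\]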
Example #8
def dirichlet_kl_divergence(alphas, target_alphas, precision=None, target_precision=None,
                            epsilon=1e-8):
    """
    This function computes the Forward KL divergence between a model Dirichlet distribution
    and a target Dirichlet distribution based on the concentration (alpha) parameters of each.

    :param alphas: Tensor containing concentration parameters of model. Expected shape is batchsize X num_classes.
    :param target_alphas: Tensor containing target concentration parameters. Expected shape is batchsize X num_classes.
    :param precision: Optional argument. Can pass in precision of model. Expected shape is batchsize X 1.
        precision is the alpha_0 value representing the sum of all alpha_c's (for a normal DNN this is denominator of softmax)
    :param target_precision: Optional argument. Can pass in target precision. Expected shape is batchsize X 1
        target precision is the alpha_0 hyperparameter, to be chosen for target dirichlet distribution.
    :param epsilon: Smoothing factor for numerical stability. Default value is 1e-8
    :return: Tensor for Batchsize X 1 of forward KL divergences between target Dirichlet and model
    """
    if precision is None:
        precision = torch.sum(alphas, dim=1, keepdim=True)
    if target_precision is None:
        target_precision = torch.sum(target_alphas, dim=1, keepdim=True)

    print(target_precision, precision)
    precision_term = torch.lgamma(target_precision) - torch.lgamma(precision)
    print(torch.isfinite(precision_term))
    assert torch.all(torch.isfinite(precision_term)).item()
    alphas_term = torch.sum(torch.lgamma(alphas) - torch.lgamma(target_alphas)
                            + (target_alphas - alphas) * (torch.digamma(target_alphas)
                                                          - torch.digamma(
                target_precision)), dim=1, keepdim=True)
    print("alphas:", alphas_term)
    assert torch.all(torch.isfinite(alphas_term)).item()

    cost = torch.squeeze(precision_term + alphas_term)
    return cost
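A small sanity-check sketch (not from the original repo): since the function computes KL(Dir(target_alphas) || Dir(alphas)), its output should match PyTorch's built-in Dirichlet KL up to the epsilon smoothing. The tensor values are arbitrary.

import torch
from torch.distributions import Dirichlet, kl_divergence

alphas = torch.tensor([[2.0, 3.0, 5.0]])
target_alphas = torch.tensor([[10.0, 1.0, 1.0]])

manual = dirichlet_kl_divergence(alphas, target_alphas)                # forward KL, target || model
builtin = kl_divergence(Dirichlet(target_alphas), Dirichlet(alphas))   # reference value
print(manual, builtin)  # should agree up to the epsilon smoothing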
Example #9
 def entropy(self):
     k = self.concentration.size(-1)
     a0 = self.concentration.sum(-1)
     return (torch.lgamma(self.concentration).sum(-1) - torch.lgamma(a0) -
             (k - a0) * torch.digamma(a0) -
             ((self.concentration - 1.0) *
              torch.digamma(self.concentration)).sum(-1))
Example #10
 def fit_psi(self, M, x_idx):
      '''
        Fits the variational parameter for the latent indicator z_jkm ~ Multinomial(psi_jkm), where psi_jkm ∝ phi_km * theta_jk.
        @Args:
            M = list with the number of mutations in each sample
            x_idx = indices of the observed mutations in each sample
        @Returns:
            Updates the latent indicator variational parameter psi, shape = J x K x M_n
            (sample j, number of factors, and proportions of the specific mutations for each factor)
      '''
      for n, M_n in enumerate(M):
          # First calculate the expected log prob of phi (Expectation of Dirichlet), shape: factors x mutations
          E_ln_phi = torch.digamma(self.eta[:, x_idx[n]]) \
              - torch.digamma(torch.mm(torch.sum(self.eta, dim=1, keepdim=True),
                                       torch.ones(1, M_n).to(self.device)))

          # Then calculate the expected log prob of theta (Expectation of Gamma) - factor proportions in sample j, shape: K x M_n
          E_ln_theta_j = torch.mm((torch.digamma(self.theta_jk1[n, :]) + torch.log(
              self.theta_jk2[n, :])).unsqueeze(1), torch.ones(1, M_n).to(self.device))

          # Update indicators [K x M] + [K x M], normalised over factors
          psi_n = torch.exp(E_ln_phi + E_ln_theta_j) / torch.mm(torch.ones(self.K, 1).to(self.device),
                                                                torch.sum(torch.exp(E_ln_phi + E_ln_theta_j), dim=0, keepdim=True))
          # Add small noise to avoid NaN/Zeros
          self.psi[n] = psi_n.data + 1e-6
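The two "Expectation of ..." comments above correspond to the following identities (a reference note; the Gamma form assumes theta_jk1 and theta_jk2 are the shape and scale of the Gamma posterior, which is what the digamma-plus-log expression implies):

\[
\mathbb{E}_{q(\phi_k)}[\ln \phi_{km}] = \psi(\eta_{km}) - \psi\Big(\sum_{m'} \eta_{km'}\Big), \qquad
\mathbb{E}_{q(\theta_{jk})}[\ln \theta_{jk}] = \psi(a_{jk}) + \ln b_{jk},
\]
with \(\eta\) the Dirichlet parameters, \(a_{jk}\) = theta_jk1 and \(b_{jk}\) = theta_jk2.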
Example #11
    def update_q_z(self):
        sigma = torch.exp(self.log_p_y_sigma)
        E_ln_v = torch.digamma(
            self.v_beta_a) - torch.digamma(self.v_beta_a + self.v_beta_b)
        E_ln_1_minus_v = torch.digamma(
            self.v_beta_b) - torch.digamma(self.v_beta_a + self.v_beta_b)

        tmp_sum = torch.zeros(self.M)
        for m in range(0, self.M):
            tmp_sum[m] += E_ln_v[m]
            for i in np.arange(0, m):
                tmp_sum[m] += E_ln_1_minus_v[i]

        log_pi = -0.5/sigma * ((self.Y.repeat(self.M,1,1)-self.q_f_mean)**2).sum(2) \
                    -0.5/sigma * torch.stack([torch.diag(self.q_f_sig[m]) for m in range(self.M)]) \
                    -0.5*torch.log(np.pi*2*sigma)*self.D \
                    + (tmp_sum).repeat(self.N,1).T

        self.q_z_pi = torch.exp(log_pi)

        self.q_z_pi /= self.q_z_pi.sum(0)[None, :]

        self.q_z_pi[torch.isnan(self.q_z_pi)] = 1.0 / self.M

        n_digits = 3

        self.q_z_pi = (self.q_z_pi * 10**n_digits).round() / (10**n_digits)
Example #12
File: MFMM.py Project: russellkune/MFMM
    def ELBO_prior(self, W, q_A, q_gamma, q_alpha, q_mu, q_beta, q_sigma, q_pi,
                   phi):
        '''
        taking the expectation of 
        
        log p(A) + sum_k [ log p(mu_k) + log p(Sigma_k) ] + sum_i [log p(pi_i) + log p (gamma_i )] 
        + sum_i sum_j sum_k z_ijk [ log pi_ik + log p(alpha_ij) + log p(beta_ij)]
        
        ***NOTE THAT THIS TERM IS OFF BY CONSTANTS*** not the true elbo
        
        '''
        first_term = -0.5 * ((q_A[0]**2 + q_A[1].exp()**2).sum() +
                             (q_mu[0]**2 + q_mu[1].exp()**2).sum() +
                             (q_gamma[0]**2 + q_gamma[1].exp()**2).sum())

        second_term = (
            -2.0 *
            (torch.log(q_sigma[1] + self.epsilon) - torch.digamma(q_sigma[0]))
            - q_sigma[0] / (q_sigma[1] + self.epsilon)).sum()

        third_term1 = torch.digamma(q_pi) - torch.digamma(
            q_pi.sum(dim=1, keepdim=True))

        third_term = contract('ijk, ik ->', phi, third_term1)

        fourth_term = -0.5 * contract('ijk, ij ->', phi,
                                      q_alpha[0]**2 + q_alpha[1].exp()**2)

        fifth_term1 = -2.0 * (torch.log(q_beta[1] + self.epsilon) -
                              torch.digamma(q_beta[0] + self.epsilon)
                              ) - q_beta[0] / (q_beta[1] + self.epsilon)
        fifth_term = contract('ijk, ij->', phi, fifth_term1)

        return first_term + second_term + third_term + fourth_term + fifth_term
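The only digamma expression here, third_term1, is the usual expected log of Dirichlet-distributed mixture weights under the variational posterior q(pi_i) = Dir(q_pi[i]):

\[
\mathbb{E}_{q(\pi_i)}[\ln \pi_{ik}] = \psi\big(q^{\pi}_{ik}\big) - \psi\Big(\sum_{k'} q^{\pi}_{ik'}\Big),
\]

which is then contracted against the responsibilities phi.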
Example #13
File: edl.py Project: hsljc/ae-dnn
def edl_loss(evidence, y, epoch, n_classes):
    """Implementation of the EDL loss.

    Args:
        evidence: Predicted evidence.
        y: Ground truth labels.
        epoch: Current epoch starting with 0.
        n_classes: Number of classes.

    Returns:
        float: The loss defined by evidential deep learning.
    """
    device = y.device

    y_one_hot = torch.eye(n_classes, device=device)[y]
    alpha = evidence + 1
    S = alpha.sum(-1, keepdim=True)
    p_hat = alpha / S

    # compute the Bayes risk (expected squared error under the Dirichlet)
    bayes_risk = torch.sum((y_one_hot - p_hat)**2 + p_hat * (1 - p_hat) / S, -1)

    # KL-divergence term: remove the evidence for the target class, alpha_tilde = y + (1 - y) * alpha
    alpha_tilde = y_one_hot + (1 - y_one_hot) * alpha
    S_alpha_tilde = alpha_tilde.sum(-1, keepdim=True)
    t1 = torch.lgamma(S_alpha_tilde) - math.lgamma(n_classes) - torch.lgamma(alpha_tilde).sum(-1, keepdim=True)  # lgamma(n_classes) is ln Gamma of the flat prior's precision
    t2 = torch.sum((alpha_tilde - 1) * (torch.digamma(alpha_tilde) -
                                        torch.digamma(S_alpha_tilde)), dim=-1, keepdim=True)
    kl_div = t1 + t2

    lmbda = min((epoch + 1)/10, 1)
    loss = torch.mean(bayes_risk) + lmbda*torch.mean(kl_div)

    return loss
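A minimal usage sketch for edl_loss (not from the original repo); the batch size, class count and the ReLU evidence head are arbitrary choices for illustration.

import math
import torch

n_classes = 10
logits = torch.randn(32, n_classes)
evidence = torch.relu(logits)                  # non-negative evidence
y = torch.randint(0, n_classes, (32,))         # ground-truth class indices

loss = edl_loss(evidence, y, epoch=3, n_classes=n_classes)
print(loss.item())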
Example #14
File: vae.py Project: entn-at/padertorch
    def forward(self, inputs):
        mean, log_var = inputs
        qz = D.Normal(loc=mean, scale=torch.exp(0.5 * log_var))

        gaussians = self.gaussians

        # Patrick's thesis, Eq. 3.14
        # Eq. 2.21:
        term1 = torch.digamma(self.alpha_0 + self.counts.detach())  # + const.

        # Eq. 2.22:
        term2 = torch.digamma(
            (self.nu_0 + self.counts[:, None].detach()
             - torch.arange(self.feature_size).float().to(term1.device)) / 2
        ).sum(-1) - torch.log(self.nu_0 + self.counts.detach())
        # 0.5*ln|\nu*W| = 0.5*(ln\nu + ln|W|) is part of kl in term3

        # Eq. 3.15
        term3 = (
            kl_divergence(qz, gaussians)
            + 0.5 * self.feature_size / (self.kappa_0 + self.counts.detach())
        )

        log_rho = term1 + 0.5 * term2 - term3
        log_gamma = torch.log_softmax(log_rho, dim=-1).detach()
        gamma = torch.exp(log_gamma)
        log_rho_ = (gamma * log_rho).sum(-1)
        z = self.sample(mean, log_var)
        z, log_rho_, pool_indices, log_gamma = self.pool(
            z, log_rho_, log_gamma
        )
        return z, log_rho_, pool_indices, log_gamma
Example #15
    def update_qlatent(self, a_i, V_ji):

        # Out ← [?, 1, C, 1, 1, F, F, 1, 1]
        self.Elnpi_j = torch.digamma(self.alpha_j) \
            - torch.digamma(self.alpha_j.sum(dim=2, keepdim=True))

        # Out ← [?, 1, C, 1, 1, F, F, 1, 1] broadcasting diga_arg
        self.Elnlambda_j = self.reduce_poses(
            torch.digamma(.5*(self.nu_j - self.diga_arg))) \
                + self.Dlog2 + self.lndet_Psi_j

        if self.cov == 'diag':
            # Out ← [?, B, C, 1, 1, F, F, K, K]
            ElnQ = (self.D/self.kappa_j) + self.nu_j \
                * self.reduce_poses((1./self.invPsi_j) * (V_ji - self.m_j).pow(2))

        elif self.cov == 'full':
            # Out ← [?, B, C, 1, 1, F, F, K, K]
            Vm_j = V_ji - self.m_j
            ElnQ = (self.D/self.kappa_j) + self.nu_j * self.reduce_poses(
                Vm_j.transpose(3,4) * torch.inverse(
                    self.invPsi_j).permute(0,1,2,7,8,3,4,5,6) * Vm_j)

        # Out ← [?, B, C, 1, 1, F, F, 1, 1]
        lnp_j = .5*self.Elnlambda_j -.5*self.Dlog2pi -.5*ElnQ

        # Out ← [?, B, C, 1, 1, F, F, 1, 1] # normalise over out_caps j
        lnR_ij = lnp_j - torch.logsumexp(self.Elnpi_j + lnp_j, dim=2, keepdim=True)
        return torch.exp(lnR_ij)
Example #16
def KL(alpha, K):
    beta = torch.ones([1, K], dtype=torch.float32)
    S_alpha = torch.sum(alpha, dim=1, keepdim=True)    
    KL_val = torch.sum((alpha - beta)*(torch.digamma(alpha)-torch.digamma(S_alpha)),dim=1,keepdim=True) + \
         torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha),dim=1,keepdim=True) + \
         torch.sum(torch.lgamma(beta),dim=1,keepdim=True) - torch.lgamma(torch.sum(beta,dim=1,keepdim=True))
    return KL_val
Example #17
 def entropy(self):
     lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(
         0.5 * (self.df + 1))
     return (self.scale.log() + 0.5 * (self.df + 1) *
             (torch.digamma(0.5 *
                            (self.df + 1)) - torch.digamma(0.5 * self.df)) +
             0.5 * self.df.log() + lbeta)
Example #18
 def compute_kld_pres(self, result_obj):
     tau1 = result_obj['tau1']
     tau2 = result_obj['tau2']
     logits_zeta = result_obj['logits_zeta']
     psi1 = torch.digamma(tau1)
     psi2 = torch.digamma(tau2)
     psi12 = torch.digamma(tau1 + tau2)
     kld_1 = torch.lgamma(tau1 + tau2) - torch.lgamma(tau1) - torch.lgamma(tau2) - self.prior_pres_log_alpha
     kld_2 = (tau1 - self.prior_pres_alpha) * psi1
     kld_3 = (tau2 - 1) * psi2
     kld_4 = -(tau1 + tau2 - self.prior_pres_alpha - 1) * psi12
     zeta = torch.sigmoid(logits_zeta)
     log_zeta = nn_func.logsigmoid(logits_zeta)
     log1m_zeta = log_zeta - logits_zeta
     psi1_le_sum = psi1.cumsum(0)
     psi12_le_sum = psi12.cumsum(0)
     kappa1 = psi1_le_sum - psi12_le_sum
     psi1_lt_sum = torch.cat([torch.zeros([1, *psi1_le_sum.shape[1:]], device=zeta.device), psi1_le_sum[:-1]])
     logits_coef = psi2 + psi1_lt_sum - psi12_le_sum
     kappa2_list = []
     for idx in range(logits_coef.shape[0]):
         coef = torch.softmax(logits_coef[:idx + 1], dim=0)
         log_coef = nn_func.log_softmax(logits_coef[:idx + 1], dim=0)
         coef_le_sum = coef.cumsum(0)
         coef_lt_sum = torch.cat([torch.zeros([1, *coef_le_sum.shape[1:]], device=zeta.device), coef_le_sum[:-1]])
         part1 = (coef * psi2[:idx + 1]).sum(0)
         part2 = ((1 - coef_le_sum[:-1]) * psi1[:idx]).sum(0)
         part3 = -((1 - coef_lt_sum) * psi12[:idx + 1]).sum(0)
         part4 = -(coef * log_coef).sum(0)
         kappa2_list.append(part1 + part2 + part3 + part4)
     kappa2 = torch.stack(kappa2_list)
     kld_5 = zeta * (log_zeta - kappa1) + (1 - zeta) * (log1m_zeta - kappa2)
     kld = kld_1 + kld_2 + kld_3 + kld_4 + kld_5
     return kld.sum([0, *range(2, kld.ndim)])
Example #19
 def expected_entropy_from_alphas(alphas, alpha0=None):
     if alpha0 is None:
         alpha0 = torch.sum(alphas, dim=1, keepdim=True)
     expected_entropy = -torch.sum(
         (alphas / alpha0) * (torch.digamma(alphas + 1) - torch.digamma(alpha0 + 1)),
         dim=1)
     return expected_entropy
Example #20
    def differential_entropy(self, x):
        alphas = self.alphas(x)
        alpha0 = torch.sum(alphas, dim=1, keepdim=True)

        return torch.sum(
            torch.lgamma(alphas) - (alphas - 1) * (torch.digamma(alphas) - torch.digamma(alpha0)),
            dim=1) - torch.lgamma(alpha0).squeeze(1)  # squeeze so the result has shape (batch,)
Example #21
def KL_Beta(alpha1, beta1, alpha2, beta2):
    kl = torch.lgamma(alpha1 + beta1) + torch.lgamma(alpha2) + torch.lgamma(
        beta2) - torch.lgamma(alpha2 + beta2) - torch.lgamma(
            alpha1) - torch.lgamma(beta1) + (alpha1 - alpha2) * (torch.digamma(
                alpha1) - torch.digamma(alpha1 + beta1)) + (beta1 - beta2) * (
                    torch.digamma(beta1) - torch.digamma(alpha1 + beta1))
    return kl
Example #22
def dirichlet_kl_divergence(alphas,
                            target_alphas,
                            precision=None,
                            target_precision=None,
                            epsilon=1e-8):  # see supplementary C5
    """
    This function computes the Forward KL divergence between a model Dirichlet distribution
    and a target Dirichlet distribution based on the concentration (alpha) parameters of each.

    :param alphas: Tensor containing concentration parameters of model. Expected shape is batchsize X num_classes.
    :param target_alphas: Tensor containing target concentration parameters. Expected shape is batchsize X num_classes.
    :param precision: Optional argument. Can pass in precision of model. Expected shape is batchsize X 1
    :param target_precision: Optional argument. Can pass in target precision. Expected shape is batchsize X 1
    :param epsilon: Smoothing factor for numerical stability. Default value is 1e-8
    :return: Tensor for batchsize X 1 of forward KL divergences between target Dirichlet and model
    """
    if precision is None:
        precision = torch.sum(alphas, dim=1, keepdim=True)
    if target_precision is None:
        target_precision = torch.sum(target_alphas, dim=1, keepdim=True)
    precision_term = torch.lgamma(target_precision) - torch.lgamma(precision)
    alphas_term = torch.sum(torch.lgamma(alphas + epsilon) -
                            torch.lgamma(target_alphas + epsilon) +
                            (target_alphas - alphas) *
                            (torch.digamma(target_alphas + epsilon) -
                             torch.digamma(target_precision + epsilon)),
                            dim=1,
                            keepdim=True)

    cost = torch.squeeze(precision_term + alphas_term)
    return cost
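For reference, the quantity assembled from precision_term and alphas_term (ignoring the epsilon smoothing) is the forward KL between the target and model Dirichlets:

\[
\mathrm{KL}\big(\mathrm{Dir}(\hat\alpha)\,\|\,\mathrm{Dir}(\alpha)\big)
= \ln\Gamma(\hat\alpha_0) - \ln\Gamma(\alpha_0)
+ \sum_c \Big[ \ln\Gamma(\alpha_c) - \ln\Gamma(\hat\alpha_c)
+ (\hat\alpha_c - \alpha_c)\big(\psi(\hat\alpha_c) - \psi(\hat\alpha_0)\big) \Big],
\]
with \(\hat\alpha\) the target concentrations, \(\alpha_0 = \sum_c \alpha_c\) and \(\hat\alpha_0 = \sum_c \hat\alpha_c\).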
Example #23
    def forward(self, logits, targets):
        '''
        Compute loss: coeff * KL(Dir(alpha) || Dir(prior)) - expected log-likelihood of the target class
        '''
        alphas = torch.exp(logits)
        betas = torch.ones_like(logits) * self.prior

        # compute log-likelihood loss: psi(alpha_target) - psi(alpha_zero)
        a_ans = torch.gather(alphas, -1, targets.unsqueeze(-1)).squeeze(-1)
        a_zero = torch.sum(alphas, -1)
        ll_loss = torch.digamma(a_ans) - torch.digamma(a_zero)

        # compute kl loss (up to constants in beta): loss1 + loss2
        #       loss1 = log_gamma(alpha_zero) - \sum_k log_gamma(alpha_k)
        #       loss2 = \sum_k (alpha_k - beta_k) * (digamma(alpha_k) - digamma(alpha_zero))
        loss1 = torch.lgamma(a_zero) - torch.sum(torch.lgamma(alphas), -1)

        loss2 = torch.sum(
            (alphas - betas) *
            (torch.digamma(alphas) - torch.digamma(a_zero.unsqueeze(-1))), -1)
        kl_loss = loss1 + loss2

        loss = ((self.coeff * kl_loss - ll_loss)).mean()

        return loss
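Spelled out, the two pieces computed above are the expected log-likelihood of the target class under the predicted Dirichlet and the KL to the flat prior Dir(β):

\[
\mathbb{E}_{\mathrm{Dir}(\alpha)}[\ln p_y] = \psi(\alpha_y) - \psi(\alpha_0), \qquad
\mathrm{KL}\big(\mathrm{Dir}(\alpha)\,\|\,\mathrm{Dir}(\beta)\big)
= \ln\Gamma(\alpha_0) - \sum_k \ln\Gamma(\alpha_k)
+ \sum_k (\alpha_k - \beta_k)\big(\psi(\alpha_k) - \psi(\alpha_0)\big) + \text{const},
\]
where the constant collects the \(\ln\Gamma(\beta)\) terms that the code drops.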
Example #24
    def forward(self, a_i, V_ji):

        self.F_i = a_i.shape[-2:] # input capsule (B) votes feature map size (K)
        self.F_o = a_i.shape[-4:-2] # output capsule (C) feature map size (F)
        self.N = self.B*self.F_i[0]*self.F_i[1] # total num of lower level capsules

        # Out ← [1, B, C, 1, 1, 1, 1, 1, 1]
        R_ij = (1./self.C) * torch.ones(1,self.B,self.C,1,1,1,1,1,1, requires_grad=False).cuda()

        for i in range(self.iter): # routing iters

            # update capsule parameter distributions
            self.update_qparam(a_i, V_ji, R_ij)

            if i != self.iter-1: # skip last iter
                # update latent variable distributions (child to parent capsule assignments)
                R_ij = self.update_qlatent(a_i, V_ji)

        # Out ← [?, 1, C, 1, 1, F, F, 1, 1]
        self.Elnlambda_j = self.reduce_poses(
            torch.digamma(.5*(self.nu_j - self.diga_arg))) \
                + self.Dlog2 + self.lndet_Psi_j

        # Out ← [?, 1, C, 1, 1, F, F, 1, 1]
        self.Elnpi_j = torch.digamma(self.alpha_j) \
            - torch.digamma(self.alpha_j.sum(dim=2, keepdim=True))

        # subtract "- .5*ln|lmbda|" due to precision matrix, instead of adding "+ .5*ln|sigma|" for covariance matrix
        H_q_j = .5*self.D * torch.log(torch.tensor(2*np.pi*np.e)) - .5*self.Elnlambda_j # posterior entropy H[q*(mu_j, sigma_j)]

        # Out ← [?, 1, C, 1, 1, F, F, 1, 1] weighted negative entropy with optional beta params and R_j weight
        a_j = self.beta_a - (torch.exp(self.Elnpi_j) * H_q_j + self.beta_u) #* self.R_j

        # Out ← [?, C, F, F]
        a_j = a_j.squeeze()

        # Out ← [?, C, P*P, F, F] ← [?, 1, C, P*P, 1, F, F, 1, 1]
        self.m_j = self.m_j.squeeze()

        # so BN works in the classcaps layer
        if self.class_caps:
            # Out ← [?, C, 1, 1] ← [?, C]
            a_j = a_j[...,None,None]

            # Out ← [?, C, P*P, 1, 1] ← [?, C, P*P]
            self.m_j = self.m_j[...,None,None]
        # else:
        #     self.m_j = self.BN_v(self.m_j)

        # Out ← [?, C, P*P, F, F]
        self.m_j = self.BN_v(self.m_j) # use 'else' above to deactivate BN_v for class_caps

        # Out ← [?, C, P, P, F, F] ← [?, C, P*P, F, F]
        self.m_j = self.m_j.reshape(-1, self.C, self.P, self.P, *self.F_o)

        # Out ← [?, C, F, F]
        a_j = torch.sigmoid(self.BN_a(a_j))

        return a_j.squeeze(), self.m_j.squeeze() # propagate posterior means to next layer
Example #25
def eq4(evidence, target, epoch_num, K, annealing_step):
    alpha = evidence + 1
    S = torch.sum(alpha, dim=1, keepdim=True)
    loglikelihood = torch.sum(target * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
    annealing = torch.min(torch.tensor(
        1.0, dtype=torch.float32), torch.tensor(epoch_num / annealing_step, dtype=torch.float32))
    kl = annealing * KL((alpha - 1) * (1 - target) + 1, K)
    return torch.mean(loglikelihood + kl)
Example #26
 def entropy(self):
     alpha = self.base_dist.marginal_t.base_dist.concentration1
     beta = self.base_dist.marginal_t.base_dist.concentration0
     return -(
         self.log_normalizer()
         + self.scale
         * (math.log(2) + torch.digamma(alpha) - torch.digamma(alpha + beta))
     )
Example #27
File: categorical.py Project: xrick/beer
 def mean(self):
     c = self.stickbreaking.posterior.params.concentrations
     s_dig = torch.digamma(c.sum(dim=-1))
     log_v = torch.digamma(c[:, 0]) - s_dig
     log_1_v = torch.digamma(c[:, 1]) - s_dig
     log_prob = log_v
     log_prob[1:] += log_1_v[:-1].cumsum(dim=0)
     return log_prob.exp()[self.reverse_ordering]
Example #28
def lglh(alpha, gamma):
    len_doc = len(gamma)
    alpha_g = len_doc * (torch.lgamma(alpha.sum(0)) -
                         torch.lgamma(alpha).sum(0))
    gamma_g = torch.sum(
        (alpha - 1) * (torch.digamma(gamma) -
                       torch.digamma(gamma.sum(-1)).view(-1, 1)).sum(0))
    return alpha_g + gamma_g
Example #29
 def kl_divergence(alpha_t):
     num_cls = alpha_t.size(1)
     beta = torch.ones([1, num_cls], device=Evidential_DL.device)
     S_alpha_t = torch.sum(alpha_t, dim=1, keepdim=True)
     KL = torch.sum((alpha_t - beta) * (torch.digamma(alpha_t) - torch.digamma(S_alpha_t)), dim=1, keepdim=True) +\
          torch.lgamma(S_alpha_t) - torch.sum(torch.lgamma(alpha_t), dim=1, keepdim=True) +\
          torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(torch.sum(beta, dim=1, keepdim=True))
     return KL
Example #30
File: kl.py Project: Jsmilemsj/pytorch
def _kl_beta_beta(p, q):
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5
Example #31
 def cavi_nu(self, n, k, X, log_stick):
     N, K, D = X.shape[0], self.K, self.D
     first_term = (digamma(self.tau[:k+1, 0]) - digamma(self.tau.sum(1)[:k+1])).sum() - \
         log_stick[k]
     # this line is really slow
     other_prod = (self.nu[n] @ self.phi - self.nu[n, k] * self.phi[k])
     second_term = (-1. / (2 * self.sigma_n ** 2) * (self.phi_var[k].sum() + self.phi[k].pow(2).sum())) + \
         (self.phi[k] @ (X[n] - other_prod)) / (self.sigma_n ** 2)
     self._nu[n][k] = first_term + second_term
Example #32
File: kl.py Project: lxlhh/pytorch
 def f(a, b, c, d):
     return -d * a / c + b * a.log() - torch.lgamma(b) + (b - 1) * torch.digamma(d) + (1 - b) * c.log()
Example #33
 def entropy(self):
     k = self.concentration.size(-1)
     a0 = self.concentration.sum(-1)
     return (torch.lgamma(self.concentration).sum(-1) - torch.lgamma(a0) -
             (k - a0) * torch.digamma(a0) -
             ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1))
Example #34
File: kl.py Project: Jsmilemsj/pytorch
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4
Example #35
File: studentT.py Project: lxlhh/pytorch
 def entropy(self):
     lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1))
     return (self.scale.log() +
             0.5 * (self.df + 1) *
             (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) +
             0.5 * self.df.log() + lbeta)
Example #36
 def entropy(self):
     k = self.alpha.size(-1)
     a0 = self.alpha.sum(-1)
     return (torch.lgamma(self.alpha).sum(-1) - torch.lgamma(a0) -
             (k - a0) * torch.digamma(a0) -
             ((self.alpha - 1.0) * torch.digamma(self.alpha)).sum(-1))
Example #37
File: gamma.py Project: lxlhh/pytorch
 def entropy(self):
     return (self.alpha - torch.log(self.beta) + torch.lgamma(self.alpha) +
             (1.0 - self.alpha) * torch.digamma(self.alpha))
Example #38
 def entropy(self):
     return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) +
             (1.0 - self.concentration) * torch.digamma(self.concentration))