示例#1
0
    def log_prob(self, value):
        """Return the log-density evaluated at ``value``.

        The result is assembled as::

            lgamma((p + df) / 2) - (p / 2) * log(pi * (df - 2))
                - lgamma(df / 2) - half_log_det
                - ((df + p) / 2) * log(1 + M / (df - 2))

        where ``p`` is the number of elements in the trailing event
        dimensions of ``value - self.loc``, ``M`` is the value returned by
        ``_batch_mahalanobis`` (summed over extra event dimensions when the
        event has more than one), and ``half_log_det`` is the rescaled sum
        of the log-diagonal of ``self._unbroadcasted_scale_tril``.

        NOTE(review): the ``df - 2`` rescaling suggests ``scale_tril``
        factors a covariance rather than a Student-t shape matrix -- confirm
        against the class definition.
        """
        if self._validate_args:
            self._validate_sample(value)

        scale_tril = self._unbroadcasted_scale_tril
        residual = value - self.loc
        maha = _batch_mahalanobis(scale_tril, residual)

        event_rank = len(self._event_shape)
        event_numel = residual.size()[-event_rank:].numel()
        if event_rank > 1:
            # Collapse every event dimension except the one already
            # consumed by the Mahalanobis helper.
            maha = maha.sum(tuple(range(-event_rank + 1, 0)))

        log_diag = scale_tril.diagonal(dim1=-2, dim2=-1).log()
        if event_rank > log_diag.dim():
            # The factor has fewer dims than the event: reduce everything.
            diag_total = log_diag.sum()
            diag_numel = log_diag.numel()
        else:
            diag_total = log_diag.sum(tuple(range(-event_rank, 0)))
            diag_numel = log_diag.size()[-event_rank:].numel()
        # Rescale so the term accounts for all ``event_numel`` elements even
        # when the stored factor covers only part of the event.
        half_log_det = diag_total * (event_numel / diag_numel)

        lambda_ = self.df - 2.
        log_gamma_joint = torch.lgamma((event_numel + self.df) / 2.)
        log_scale = (event_numel / 2.) * torch.log(math.pi * lambda_)
        log_gamma_df = torch.lgamma(self.df / 2.)
        tail = ((self.df + event_numel) / 2.) * torch.log(1 + maha / lambda_)
        return (log_gamma_joint - log_scale - log_gamma_df
                - half_log_det - tail)
示例#2
0
 def log_prob(self, value):
     """Return the Gaussian log-density of ``value``.

     Combines the half log-determinant (sum of the log-diagonal of
     ``self.cholesky``), the ``log(2 * pi)`` normalising constant scaled by
     ``self._event_shape[0]``, and the value returned by
     ``_batch_mahalanobis`` for ``value - self.mu``.
     """
     if self._validate_args:
         self._validate_sample(value)

     residual = value - self.mu
     chol = self.cholesky
     maha = _batch_mahalanobis(chol, residual)
     half_log_det = chol.diagonal(dim1=-2, dim2=-1).log().sum(-1)

     event_dim = self._event_shape[0]
     quadratic_and_const = event_dim * math.log(2 * math.pi) + maha
     return -half_log_det - 0.5 * quadratic_and_const
示例#3
0
 def log_prob(self, value):
     """Return the log-density of ``value`` under this distribution.

     ``value`` is first mapped through ``_standard_normal_quantile``; the
     result is ``-0.5 * (M - ||z||^2) - half_log_det`` where ``z`` is the
     mapped value, ``M`` is the output of ``_batch_mahalanobis`` for the
     wrapped multivariate normal's scale factor, and ``half_log_det`` is the
     sum of that factor's log-diagonal.
     """
     if self._validate_args:
         self._validate_sample(value)

     z = _standard_normal_quantile(value)
     scale_tril = self.multivariate_normal._unbroadcasted_scale_tril

     half_log_det = scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
     excess = _batch_mahalanobis(scale_tril, z)
     # Remove the identity-covariance quadratic form; done in place on the
     # freshly returned tensor.
     excess -= z.pow(2).sum(-1)
     return -0.5 * excess - half_log_det
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Compute the squared Mahalanobis distance
    :math:`x.T @ inv(W @ W.T + D) @ x` via the "Woodbury matrix identity"::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C = I + W.T @ inv(D) @ W` is the capacitance matrix.

    ``D`` is used as a vector of diagonal entries (it divides elementwise).
    NOTE(review): ``capacitance_tril`` is presumably the lower Cholesky
    factor of :math:`C` -- confirm against the caller.
    """
    # Diagonal contribution: x.T @ inv(D) @ x.
    diag_quadratic = (x.pow(2) / D).sum(-1)

    # Low-rank correction: (W.T @ inv(D) @ x) measured against C.
    projected = _batch_mv(W.transpose(-1, -2) / D.unsqueeze(-2), x)
    lowrank_quadratic = _batch_mahalanobis(capacitance_tril, projected)

    return diag_quadratic - lowrank_quadratic
示例#5
0
def log_prob(loc, scale_tril, value, dim):
    """Return the Gaussian log-density of ``value``.

    Args:
        loc: Location subtracted from ``value``.
        scale_tril: Factor passed to ``_batch_mahalanobis``; the sum of its
            log-diagonal supplies the half log-determinant.
        value: Point(s) at which to evaluate the density.
        dim: Multiplier for the ``log(2 * pi)`` normalising constant
            (presumably the event dimensionality -- confirm with callers).

    Returns:
        ``-0.5 * (dim * log(2 * pi) + M) - half_log_det``.
    """
    maha = _batch_mahalanobis(scale_tril, value - loc)
    log_diag = scale_tril.diagonal(dim1=-2, dim2=-1).log()
    half_log_det = log_diag.sum(-1)

    quadratic_and_const = dim * math.log(2 * math.pi) + maha
    return -0.5 * quadratic_and_const - half_log_det