def check(self, value):
    """Flag which matrices in ``value`` are positive-definite with a unit diagonal.

    A matrix passes when ``positive_definite.check`` accepts it and every
    entry on its main diagonal (taken over the last two dimensions) lies
    strictly within 1e-6 of 1.

    Args:
        value: tensor whose last two dimensions hold the matrices to test.

    Returns:
        Boolean tensor: ``positive_definite.check(value)`` combined with the
        per-matrix unit-diagonal test via ``&``.
    """
    diag = value.diagonal(dim1=-2, dim2=-1)
    # every diagonal entry must equal 1 up to an absolute tolerance of 1e-6
    has_unit_diag = (diag - 1).abs().lt(1e-6).all(dim=-1)
    # TODO: fix upstream - positive_definite has an extra dimension in front
    # of output shape
    return positive_definite.check(value) & has_unit_diag
def _log_prob(self, parameter): if not positive_definite.check(parameter): raise ValueError( "parameter must be positive definite for Wishart prior") return self.C + 0.5 * ( (self.nu - self.shape[0] - 1) * torch.log(torch.det(parameter)) - torch.trace(self.K_inv.matmul(parameter)))
def _log_prob(self, parameter): if not positive_definite.check(parameter): raise ValueError( "parameter must be positive definite for Inverse Wishart prior" ) return self.C - 0.5 * ( (self.nu + 2 * self.shape[0]) * torch.log(torch.det(parameter)) + torch.trace(torch.gesv(self.K, parameter)[0]))
def is_valid_correlation_matrix(Sigma, tol=1e-6):
    """Report whether ``Sigma`` is a valid correlation matrix.

    ``Sigma`` qualifies when ``positive_definite.check`` accepts it and every
    entry on its main diagonal differs from 1 by strictly less than ``tol``.

    Args:
        Sigma: matrix to test (``Sigma.diag()`` is used, so a 2-D tensor).
        tol: absolute tolerance for the unit-diagonal test.

    Returns:
        A Python ``bool``.
    """
    if not positive_definite.check(Sigma):
        # not positive definite: the diagonal is never inspected
        return False
    diagonal_deviation = (Sigma.diag() - 1).abs()
    return bool((diagonal_deviation < tol).all())
def __init__(self, nu, K):
    """Build a Wishart prior with ``nu`` degrees of freedom and scale matrix ``K``.

    Registers the buffers ``K_inv`` (inverse of ``K``), ``nu`` (as a
    one-element tensor) and ``C`` (log normalization constant).

    Args:
        nu: degrees of freedom; must satisfy ``nu > n - 1`` with ``n = K.shape[0]``.
        K: positive-definite scale matrix.

    Raises:
        ValueError: if ``K`` is not positive definite or ``nu <= n - 1``.
    """
    if not positive_definite.check(K):
        raise ValueError("K must be positive definite")
    n = K.shape[0]
    # The Wishart distribution is defined for real nu > n - 1, as the error
    # message states. The previous test `nu <= n` also rejected nu in (n - 1, n].
    if nu <= n - 1:
        raise ValueError("Must have nu > n - 1")
    super(WishartPrior, self).__init__()
    self.register_buffer("K_inv", torch.inverse(K))
    self.register_buffer("nu", torch.Tensor([nu]))
    # normalization constant:
    #   C = -(nu/2 * log|K| + nu*n/2 * log 2 + log_mv_gamma(n, nu/2))
    # torch.logdet avoids det(K) overflowing/underflowing for large or
    # large-scale K, which would make C infinite.
    C = -(nu / 2 * torch.logdet(K) + nu * n / 2 * math.log(2) + log_mv_gamma(n, nu / 2))
    self.register_buffer("C", C)
    # flag consumed elsewhere (semantics not visible in this block)
    self._log_transform = False
def __init__(self, nu, K):
    """Build an inverse-Wishart prior with parameter ``nu`` and matrix ``K``.

    Registers the buffers ``K``, ``nu`` (as a one-element tensor) and ``C``
    (log normalization constant).

    Args:
        nu: positive shape parameter.
        K: positive-definite matrix.

    Raises:
        ValueError: if ``K`` is not positive definite or ``nu <= 0``.
    """
    n = K.shape[0]
    if not positive_definite.check(K):
        raise ValueError("K must be positive definite")
    if nu <= 0:
        raise ValueError("Must have nu > 0")
    super(InverseWishartPrior, self).__init__()
    self.register_buffer("K", K)
    self.register_buffer("nu", torch.Tensor([nu]))
    # normalization constant, with c = (nu + n - 1) / 2:
    #   C = c * log|K| - c * n * log 2 - log_mv_gamma(n, c)
    # torch.logdet avoids det(K) overflowing/underflowing for large or
    # large-scale K, which would make C infinite.
    c = (nu + n - 1) / 2
    C = c * torch.logdet(K) - c * n * math.log(2) - log_mv_gamma(n, c)
    self.register_buffer("C", C)
    # flag consumed elsewhere (semantics not visible in this block)
    self._log_transform = False
def is_in_support(self, parameter):
    """Tell whether ``parameter`` lies in the prior's support.

    Args:
        parameter: candidate matrix.

    Returns:
        ``positive_definite.check(parameter)`` converted to a Python ``bool``.
    """
    pd_flag = positive_definite.check(parameter)
    return bool(pd_flag)