Example #1
def MultiChannelRateDistortion(W, device='cpu', eps=0.1):
    '''
    W has shape (N, C, H, W)
    '''
    m, c, _, _ = W.shape

    W = W.reshape(m, c, -1).permute(1, 2, 0)
    mean = W.mean(dim=2)
    W = W - mean.unsqueeze(dim=2)

    # W in shape(c, n, m)
    n = W.shape[1]
    I = torch.eye(n).to(device)
    a = n / (m * eps**2)
    # rate = torch.logdet(I.unsqueeze(0) + a * W.matmul(W.transpose(2,1))).sum() / (m*n*c) * (m+n)
    # rate = rate + torch.log(1 + mean.mul(mean).sum(dim=1)/eps**2).sum() / (m*c)
    rate = torch.logdet(I.unsqueeze(0) +
                        a * W.matmul(W.transpose(2, 1))).mean()

    rate = rate / 2.

    return rate
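
A minimal usage sketch (my own addition, not part of the original snippet; assumes torch is imported and the function above is in scope). The input is interpreted as a batch of feature maps of shape (N, C, H, W) and the result is a single scalar coding-rate estimate.

import torch

feats = torch.randn(64, 8, 4, 4)               # 64 samples, 8 channels, 4x4 maps
rate = MultiChannelRateDistortion(feats, device='cpu', eps=0.5)
print(rate.item())                             # a single scalar
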
Example #2
    def forward(self, z, reverse=False):
        # shape
        batch_size, group_size, n_of_groups = z.size()

        W = self.conv.weight.squeeze()

        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Reverse computation
                W_inverse = W.float().inverse()
                W_inverse = Variable(W_inverse[..., None])
                if (z.type() == 'torch.cuda.HalfTensor'
                        or z.type() == 'torch.HalfTensor'):
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            return z
        else:
            # Forward computation
            log_det_W = batch_size * n_of_groups * torch.logdet(W.float())
            z = self.conv(z)
            return z, log_det_W
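
The forward pass above adds batch_size * n_of_groups * logdet(W) to the log-likelihood. A small standalone check (not from the source, assuming only torch): for a 1x1 conv with weight W applied to z of shape (1, C, T), the Jacobian is block diagonal with T copies of W, so its log-determinant is T * logdet(W).

import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian

C, T = 4, 5
W = torch.randn(C, C)
W = W @ W.T + torch.eye(C)                     # any weight with det(W) > 0 works here
z = torch.randn(1, C, T)

def flat_conv(z_flat):
    # 1x1 conv viewed as a map from flattened input to flattened output
    return F.conv1d(z_flat.view(1, C, T), W[..., None]).reshape(-1)

J = jacobian(flat_conv, z.reshape(-1))         # full (C*T) x (C*T) Jacobian
assert torch.allclose(torch.logdet(J), T * torch.logdet(W), atol=1e-4)
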
Example #3
    def test_inv_quad_logdet(self):
        # Forward
        lazy_tensor = self.create_lazy_tensor(with_solves=True, with_logdet=True)
        evaluated = self.evaluate_lazy_tensor(lazy_tensor)
        flattened_evaluated = evaluated.view(-1, *lazy_tensor.matrix_shape)

        vecs = lazy_tensor.eager_rhss[0].clone().detach().requires_grad_(True)
        vecs_copy = lazy_tensor.eager_rhss[0].clone().detach().requires_grad_(True)

        with gpytorch.settings.num_trace_samples(128), warnings.catch_warnings(record=True) as ws:
            res_inv_quad, res_logdet = lazy_tensor.inv_quad_logdet(inv_quad_rhs=vecs, logdet=True)
            self.assertFalse(any(issubclass(w.category, ExtraComputationWarning) for w in ws))
        res = res_inv_quad + res_logdet

        actual_inv_quad = evaluated.inverse().matmul(vecs_copy).mul(vecs_copy).sum(-2).sum(-1)
        actual_logdet = torch.cat(
            [torch.logdet(flattened_evaluated[i]).unsqueeze(0) for i in range(lazy_tensor.batch_shape.numel())]
        ).view(lazy_tensor.batch_shape)
        actual = actual_inv_quad + actual_logdet

        diff = (res - actual).abs() / actual.abs().clamp(1, math.inf)
        self.assertLess(diff.max().item(), 15e-2)
Example #4
    def forward(self, z, reverse=False):
        # shape
        batch_size, group_size, n_of_groups = z.size()

        W = self.conv.weight.squeeze()

        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Reverse computation
                W_inverse = W.float().inverse()
                W_inverse = Variable(W_inverse[..., None])
                if "HalfTensor" in z.type():
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
                self.conv.weight.data = self.W_inverse
            z = self.conv(z)
            return z
        else:
            # Forward computation
            log_det_W = batch_size * n_of_groups * torch.logdet(W)
            z = self.conv(z)
            return z, log_det_W
Example #5
    def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
        """
        b, c, t = x.size()
        assert c % self.num_splits == 0
        if x_mask is None:
            x_mask = 1
            x_len = torch.ones((b, ), dtype=x.dtype, device=x.device) * t
        else:
            x_len = torch.sum(x_mask, [1, 2])

        x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t)
        x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits,
                                                       c // self.num_splits, t)

        if reverse:
            if self.weight_inv is not None:
                weight = self.weight_inv
            else:
                weight = torch.inverse(
                    self.weight.float()).to(dtype=self.weight.dtype)
            logdet = None
        else:
            weight = self.weight
            if self.no_jacobian:
                logdet = 0
            else:
                logdet = torch.logdet(
                    self.weight) * (c / self.num_splits) * x_len  # [b]

        weight = weight.view(self.num_splits, self.num_splits, 1, 1)
        z = F.conv2d(x, weight)

        z = z.view(b, 2, self.num_splits // 2, c // self.num_splits, t)
        z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
        return z, logdet
Example #6
def lpdf_real_dirichlet(r, lpdf_p):
    """
    Remember to perform `r -= r.max()` before zeroing out the gradient.
    """

    # The rank is the dim(r) - 1, since the remaining parameter
    # must be one minus the sum of the other parameters.
    rank = r.size().numel() - 1
    J = torch.empty([rank, rank], dtype=torch.float64)
    p = torch.softmax(r, 0)

    for i in range(rank):
        for j in range(i + 1):
            if i == j:
                # J[i, j] = torch.exp(x[i]) * (sum_x - torch.exp(x[j])) / (sum_x ** 2)
                J[i, i] = p[i] * (1 - p[i])
            else:
                # tmp = torch.exp(x[i] + x[j]) / (sum_x ** 2)
                tmp = -p[i] * p[j]
                J[i, j] = tmp
                J[j, i] = tmp

    return lpdf_p(p) + torch.logdet(J)
Example #7
    def test_inv_quad_logdet_no_reduce(self):
        # Forward
        lazy_tensor = self.create_lazy_tensor()
        evaluated = self.evaluate_lazy_tensor(lazy_tensor)
        flattened_evaluated = evaluated.view(-1, *lazy_tensor.matrix_shape)

        vecs = torch.randn(*lazy_tensor.batch_shape, lazy_tensor.size(1), 3, requires_grad=True)
        vecs_copy = vecs.clone().detach_().requires_grad_(True)

        with gpytorch.settings.num_trace_samples(128):
            res_inv_quad, res_logdet = lazy_tensor.inv_quad_logdet(
                inv_quad_rhs=vecs, logdet=True, reduce_inv_quad=False
            )
        res = res_inv_quad.sum(-1) + res_logdet

        actual_inv_quad = evaluated.inverse().matmul(vecs_copy).mul(vecs_copy).sum(-2).sum(-1)
        actual_logdet = torch.cat(
            [torch.logdet(flattened_evaluated[i]).unsqueeze(0) for i in range(lazy_tensor.batch_shape.numel())]
        ).view(lazy_tensor.batch_shape)
        actual = actual_inv_quad + actual_logdet

        diff = (res - actual).abs() / actual.abs().clamp(1, math.inf)
        self.assertLess(diff.max().item(), 15e-2)
Example #8
    def forward(self, z, reverse=False):
        # shape
        batch_size, group_size, n_of_groups = z.size()

        W = self.conv.weight.squeeze()

        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Reverse computation
                W_inverse = W.float().inverse()
                W_inverse = W_inverse[..., None]
                self.W_inverse = W_inverse
            z = torch.nn.functional.conv1d(z,
                                           self.W_inverse,
                                           bias=None,
                                           stride=1,
                                           padding=0)
            return z
        else:
            # Forward computation
            log_det_W = batch_size * n_of_groups * torch.logdet(W)
            z = self.conv(z)
            return z, log_det_W
Example #9
    def forward(self, post: Posterior, comp: Tensor) -> Tensor:
        r"""Calculate approximated log evidence, i.e., log(P(D|theta))

        Args:
            post: training posterior distribution from self.model
            comp: Comparisons pairs, see PairwiseGP.__init__ for more details

        Returns:
            The approximated evidence, i.e., the marginal log likelihood
        """

        model = self.model
        if comp is not model.comparisons:
            raise RuntimeError("Must train on training data")

        f_max = post.mean
        log_posterior = model._posterior_f(f_max)
        part1 = -log_posterior

        part2 = model.covar @ model.likelihood_hess
        eye = torch.eye(part2.size(-1),
                        dtype=model.datapoints.dtype,
                        device=model.datapoints.device).expand(part2.shape)
        part2 = part2 + eye
        part2 = -0.5 * torch.logdet(part2)

        evidence = part1 + part2

        # Sum up mll first so that when adding prior probs it won't
        # propagate and double count
        evidence = evidence.sum()

        # Add log probs of priors on the (functions of) parameters
        for _, prior, closure, _ in self.named_priors():
            evidence = evidence.add(prior.log_prob(closure()).sum())

        return evidence
Example #10
def JointEntropy(X, Y, device='cpu', eps=0.1):
    '''
    This is the alpha version of the matrix MR.
    X should be in shape R(M, C1, H, W)
    Y should be in shape R(M, C2, h, w)
    '''
    m = X.shape[0]

    X = X.reshape(m, -1)
    Y = Y.reshape(m, -1)

    X = torch.cat((X, Y), dim=1)
    mean = X.mean(dim=0)
    X = X - mean.unsqueeze(dim=0)

    # X in shape (m, d) after concatenation
    I = torch.eye(m).to(device)
    a = X.shape[1] / (m * eps**2)

    rate = torch.logdet(I + a * X.matmul(X.T))

    torch.cuda.empty_cache()

    return rate / 2
Example #11
    def forward(self, z, reverse=False):
        # shape
        z = z.permute(0,2,1)
        batch_size, group_size, n_of_groups = z.size()

        W = self.conv.weight.squeeze()
        if reverse:
            # if not hasattr(self, 'W_inverse'):
                # Reverse computation
            W_inverse = W.float().inverse()
            W_inverse = Variable(W_inverse[..., None])
            if z.type() == 'torch.cuda.HalfTensor':
                W_inverse = W_inverse.half()
            self.W_inverse = W_inverse
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            z = z.permute(0,2,1)
            return z
        else:
            # Forward computation
            log_det_W = torch.ones_like(z[:,0,:]) * torch.logdet(W)
            z = self.conv(z)
            z = z.permute(0,2,1)

            return z, log_det_W
Example #12
    def forward(self, z, reverse=False):
        batch_size, group_size, n_of_groups = z.size()
        # Here, group size refers to the channel dimension of the data, and each matrix of channel dimension (i.e. [:, i, :]) is
        # going to be multiplied by the W matrix. The n_of_groups is number of groups, and that's *how many* matrices there are
        # in the channel dimension, and each one will be multiplied by W
        # The larger the group_size value is, the more thorough the "mixing" of the variables is before going back to the AC layer
        # In the extreme case of group_size=1, the variables are never permuted before going back to AC, and the flow would work
        # terribly because nothing would change order, so the same values would go into the WN each step of flow. In the other
        # extreme, where n_of_groups=1, mixing is maximized, but we run the risk of our W matrix being too big and possibly getting
        # numerical instability when we try to invert it, since we're not using very high precision to represent it (float32).

        W = self.conv.weight.squeeze()

        if reverse:
            if not hasattr(self, 'W_inv'):
                W_inv = W.float().inverse()
                W_inv = Variable(W_inv[..., None])
                self.W_inv = W_inv
            z = F.conv1d(z, self.W_inv, bias=None, stride=1, padding=0)
            return z
        else:
            log_det_w = batch_size * n_of_groups * torch.logdet(W)
            z = self.conv(z)
            return z, log_det_w
Example #13
def MultiChannelRateDistortion_Label(W, Pi, device='cpu', eps=0.1):
    m, c, _, _ = W.shape

    W = W.reshape(m, c, -1).permute(1, 2, 0)

    k, _ = Pi.shape
    n = W.shape[1]
    I = torch.eye(n).to(device)
    # W in shape(c, n, m)
    rate = 0
    for i in range(k):
        trPi = Pi[i].sum() + 1e-8
        a = n / (trPi * eps**2)
        W_k = W.matmul(torch.diag(Pi[i]))
        mean = W_k.sum(dim=2) / trPi
        W_k = W_k - mean.unsqueeze(2)

        rate = rate + torch.logdet(
            I.unsqueeze(0) + a * W_k.matmul(torch.diag(Pi[i])).matmul(
                W_k.transpose(2, 1))).sum() / (m * n * c) * (trPi)

    rate = rate / 2

    return rate
Example #14
    def _get_logdet_loss(self, M, delta=1e-5):
        G = _get_Gram(M)
        return torch.logdet(G + delta * torch.eye(G.size(-1)).repeat(G.size(0), 1, 1)).mean()  # .cuda()
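
A short sketch of why the delta * I jitter above matters (not from the source; the real _get_Gram is not shown, so a stand-in Gram matrix is built directly): a rank-deficient Gram matrix makes torch.logdet blow up, while the jittered version stays finite.

import torch

M = torch.randn(2, 5, 3)                       # batch of 5x3 feature blocks
G = M @ M.transpose(1, 2)                      # 5x5 Gram matrices of rank <= 3
print(torch.logdet(G))                         # -inf or nan (numerically singular)
delta = 1e-5
print(torch.logdet(G + delta * torch.eye(G.size(-1)).repeat(G.size(0), 1, 1)))  # finite
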
Example #15
    def discriminative_loss(self, Z: torch.tensor, gamma: float = 1) -> torch.tensor:
        m, p = Z.shape
        identity = torch.eye(p).cuda()
        return torch.logdet(identity + gamma * p / (m * self.eps) * Z.T.matmul(Z)) / 2
Example #16
    def forward(self) -> Tuple[torch.Tensor, int]:
        unstable = 0

        # get sigma/mean or each level
        sigma_w, mu_w = self.implied_sigma_mu()
        sigma_l2, mu_l2 = self.implied_sigma_mu(suffix="_l2")

        # decompose into 11, 12, 21, 22
        sigma_b, sigma_xx, sigma_yx, mu_b, mu_x = self._split(sigma_l2, mu_l2)

        # cluster FIML -2 * logL (without constants)
        loss = torch.zeros(1, dtype=mu_w.dtype, device=mu_w.device)

        # go through each cluster separately
        data_ys_available = ~torch.isnan(self.data_ys)
        cache_S_ij = {}
        cache_S_j_R_j = {}
        sigma_b_logdet = None
        sigma_b_inv = None
        if not self.naive_implementation:
            sigma_b_logdet = torch.logdet(sigma_b)
            sigma_b_inv = torch.inverse(sigma_b)
        for cluster_slice, batches in self.missing_patterns:
            # get cluster data and define R_j for current cluster j
            cluster_x = self.data_xs[cluster_slice.start, :]
            R_j_index = ~torch.isnan(cluster_x)
            no_cluster = ~R_j_index.any()

            # cache
            key = (
                tuple(R_j_index.tolist()),
                tuple([
                    tuple(x)
                    for x in data_ys_available[cluster_slice, :].tolist()
                ]),
            )
            sigma_j_logdet, sigma_j_inv = cache_S_j_R_j.get(key, (None, None))

            # define S_ij and S_j
            S_ijs = []
            eye_w = torch.eye(mu_w.shape[0],
                              dtype=mu_w.dtype,
                              device=mu_w.device)
            lambda_ijs_logdet_sum = 0.0
            lambda_ijs_inv = []
            A_j = torch.zeros_like(sigma_w)
            for batch_slice in batches:
                size = batch_slice.stop - batch_slice.start
                available = data_ys_available[batch_slice.start]
                S_ij = eye_w[available, :]
                S_ijs.extend([S_ij] * size)

                if self.naive_implementation or sigma_j_logdet is not None:
                    continue

                key_S_ij = tuple(available.tolist())
                lambda_ij_inv, lambda_ij_logdet, a_j = cache_S_ij.get(
                    key_S_ij, (None, None, None))

                if lambda_ij_inv is None:
                    lambda_ij = sigma_w  # no missing data
                    if S_ij.shape[0] != eye_w.shape[0]:
                        # missing data
                        lambda_ij = S_ij.mm(sigma_w.mm(S_ij.t()))
                    lambda_ij_inv = torch.inverse(lambda_ij)
                    lambda_ij_logdet = torch.logdet(lambda_ij)

                    if S_ij.shape[0] != eye_w.shape[0]:
                        # missing data
                        a_j = S_ij.t().mm(lambda_ij_inv.mm(S_ij))
                    else:
                        a_j = lambda_ij_inv
                    cache_S_ij[key_S_ij] = lambda_ij_inv, lambda_ij_logdet, a_j

                lambda_ijs_inv.extend([lambda_ij_inv] * size)
                lambda_ijs_logdet_sum = lambda_ijs_logdet_sum + lambda_ij_logdet * size
                A_j = A_j + a_j * size

            S_j = torch.cat(S_ijs, dim=0)

            # means
            y_j = torch.cat([
                self.data_ys[cluster_slice, :][data_ys_available[
                    cluster_slice, :]][:, None],
                cluster_x[R_j_index, None],
            ])
            mu_y = mu_w + mu_b
            mu_j = torch.cat([S_j.mm(mu_y), mu_x[R_j_index]])
            mean_diff = y_j - mu_j
            G_yj = mean_diff.mm(mean_diff.t())

            if sigma_j_logdet is None and not self.naive_implementation:
                sigma_b_inv_A_j = sigma_b_inv + A_j
                B_j = torch.inverse(sigma_b_inv_A_j)
                C_j = eye_w - A_j.mm(B_j)
                D_j = C_j.mm(A_j)
                lambda_inv = block_diag(lambda_ijs_inv)
                V_j_inv = lambda_inv - lambda_inv.mm(
                    S_j.mm(B_j.mm(S_j.t().mm(lambda_inv))))

                if no_cluster:
                    # no cluster
                    sigma_11_j = V_j_inv
                    sigma_21_j = torch.empty(0,
                                             device=sigma_11_j.device,
                                             dtype=sigma_11_j.dtype)
                    sigma_22_1 = torch.empty([0, 0],
                                             device=sigma_11_j.device,
                                             dtype=sigma_11_j.dtype)
                    sigma_22_inv = sigma_21_j

                else:
                    # normal case
                    sigma_22_1 = (sigma_xx - sigma_yx.t().mm(
                        D_j.mm(sigma_yx)))[R_j_index, :][:, R_j_index]
                    sigma_22_inv = torch.inverse(sigma_22_1)
                    sigma_jyx = S_j.mm(sigma_yx[:, R_j_index])
                    sigma_11_j = (V_j_inv.mm(
                        sigma_jyx.mm(sigma_22_inv.mm(
                            sigma_jyx.t().mm(V_j_inv)))) + V_j_inv)
                    sigma_21_j = -sigma_22_inv.mm(sigma_jyx.t().mm(V_j_inv))

                sigma_j_inv = torch.cat(
                    [
                        torch.cat([sigma_11_j, sigma_21_j]),
                        torch.cat([sigma_21_j.t(), sigma_22_inv]),
                    ],
                    dim=1,
                )

                sigma_j_logdet = (lambda_ijs_logdet_sum + sigma_b_logdet +
                                  torch.logdet(sigma_b_inv_A_j) +
                                  torch.logdet(sigma_22_1))
                cache_S_j_R_j[key] = (sigma_j_logdet, sigma_j_inv)

            elif sigma_j_logdet is None:
                # naive
                sigma_j = S_j.mm(sigma_b.mm(S_j.t())) + block_diag(
                    [S_ij.mm(sigma_w.mm(S_ij.t())) for S_ij in S_ijs])
                if not no_cluster:
                    sigma_j_12 = S_j.mm(sigma_yx[:, R_j_index])
                    sigma_j_21 = sigma_j_12.t()
                    sigma_j_22 = sigma_xx[R_j_index, :][:, R_j_index]
                    sigma_j = torch.cat(
                        [
                            torch.cat([sigma_j, sigma_j_21]),
                            torch.cat([sigma_j_12, sigma_j_22]),
                        ],
                        dim=1,
                    )
                sigma_j_logdet = torch.logdet(sigma_j)
                sigma_j_inv = torch.inverse(sigma_j)
                cache_S_j_R_j[key] = (sigma_j_logdet, sigma_j_inv)

            loss_current = sigma_j_logdet + torch.trace(sigma_j_inv.mm(G_yj))
            unstable += loss_current.detach().item() < 0
            loss = loss + loss_current.clamp(min=0.0)

        return loss, unstable
Example #17
    def Gaussian_NLLH(pred, cov, label):
        # Assumes mean, matmul, sub, inverse, and logdet are imported from torch.
        diff = sub(label, pred)
        loss = mean(matmul(diff.unsqueeze(1),
                           matmul(inverse(cov), diff.unsqueeze(2))) + logdet(cov))
        return loss
Example #18
def log_det_other(x):
    return torch.logdet(x)
Example #19
def general_kl_divergence(mu_1, cov_1, mu_2, cov_2):
    mu_diff = (mu_1 - mu_2).view(-1, 1)
    cov_2_inverse = torch.inverse(cov_2)
    return -0.5 * (torch.logdet(cov_1) - torch.logdet(cov_2) + mu_1.size(0) -
                   torch.trace(cov_1 @ cov_2_inverse) -
                   mu_diff.t() @ cov_2_inverse @ mu_diff)
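
A quick sanity check of general_kl_divergence (a sketch, not from the source): the KL divergence of a Gaussian with itself is zero, and it is strictly positive when the two Gaussians differ.

import torch

d = 3
mu = torch.zeros(d)
A = torch.randn(d, d)
cov = A @ A.T + torch.eye(d)                   # a valid covariance matrix

same = general_kl_divergence(mu, cov, mu, cov)
assert torch.allclose(same, torch.zeros(1, 1), atol=1e-5)

other = general_kl_divergence(mu + 1.0, 2 * cov, mu, cov)
assert other.item() > 0
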
Example #20
    def _logdetgrad(self, z, x):
        """Returns logdet|dz/dx|."""

        with torch.enable_grad():
            if (self.brute_force or not self.training) and (x.ndimension() == 2 and x.shape[1] <= 10):
                x = x.requires_grad_(True)
                z = z.requires_grad_(True)
                Fx = x + self.nnet_x(x)
                Jx = batch_jacobian(Fx, x)
                logdet_x = torch.logdet(Jx)

                Fz = z + self.nnet_z(z)
                Jz = batch_jacobian(Fz, z)
                logdet_z = torch.logdet(Jz)

                return (logdet_x - logdet_z).view(-1, 1)
            if self.n_dist == 'geometric':
                geom_p = torch.sigmoid(self.geom_p).item()
                sample_fn = lambda m: geometric_sample(geom_p, m)
                rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset)
            elif self.n_dist == 'poisson':
                lamb = self.lamb.item()
                sample_fn = lambda m: poisson_sample(lamb, m)
                rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset)

            if self.training:
                if self.n_power_series is None:
                    # Unbiased estimation.
                    lamb = self.lamb.item()
                    n_samples = sample_fn(self.n_samples)
                    n_power_series = max(n_samples) + self.n_exact_terms
                    coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \
                        sum(n_samples >= k - self.n_exact_terms) / len(n_samples)
                else:
                    # Truncated estimation.
                    n_power_series = self.n_power_series
                    coeff_fn = lambda k: 1.
            else:
                # Unbiased estimation with more exact terms.

                lamb = self.lamb.item()
                n_samples = sample_fn(self.n_samples)
                n_power_series = max(n_samples) + self.n_exact_terms_test
                coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms_test) * \
                    sum(n_samples >= k - self.n_exact_terms_test) / len(n_samples)

            if not self.exact_trace:
                ####################################
                # Power series with trace estimator.
                ####################################
                # vareps_x = torch.randn_like(x)
                # vareps_z = torch.randn_like(z)
                vareps_x = torch.distributions.bernoulli.Bernoulli(torch.Tensor([0.5])).sample(x.shape).reshape(x.shape).to(x) * 2 - 1
                vareps_z = torch.distributions.bernoulli.Bernoulli(torch.Tensor([0.5])).sample(z.shape).reshape(z.shape).to(z) * 2 - 1

                # Choose the type of estimator.
                if self.training and self.neumann_grad:
                    estimator_fn = neumann_logdet_estimator
                else:
                    estimator_fn = basic_logdet_estimator

                # Do backprop-in-forward to save memory.
                if self.training and self.grad_in_forward:
                    logdet_x = mem_eff_wrapper(
                        estimator_fn, self.nnet_x, x, n_power_series, vareps_x, coeff_fn, self.training
                    )
                    logdet_z = mem_eff_wrapper(
                        estimator_fn, self.nnet_z, z, n_power_series, vareps_z, coeff_fn, self.training
                    )
                    logdetgrad = logdet_x - logdet_z
                else:
                    x = x.requires_grad_(True)
                    z = z.requires_grad_(True)
                    Fx = self.nnet_x(x)
                    Fz = self.nnet_z(z)
                    logdet_x = estimator_fn(Fx, x, n_power_series, vareps_x, coeff_fn, self.training)
                    logdet_z = estimator_fn(Fz, z, n_power_series, vareps_z, coeff_fn, self.training)
                    logdetgrad = logdet_x - logdet_z
            else:
                ############################################
                # Power series with exact trace computation.
                ############################################
                x = x.requires_grad_(True)
                z = z.requires_grad_(True)
                Fx = self.nnet_x(x)
                Jx = batch_jacobian(Fx, x)
                logdetJx = batch_trace(Jx)
                Jx_k = Jx
                for k in range(2, n_power_series + 1):
                    Jx_k = torch.bmm(Jx, Jx_k)
                    logdetJx = logdetJx + (-1)**(k+1) / k * coeff_fn(k) * batch_trace(Jx_k)
                Fz = self.nnet_z(z)
                Jz = batch_jacobian(Fz, z)
                logdetJz = batch_trace(Jz)
                Jz_k = Jz
                for k in range(2, n_power_series + 1):
                    Jz_k = torch.bmm(Jz, Jz_k)
                    logdetJz = logdetJz + (-1)**(k+1) / k * coeff_fn(k) * batch_trace(Jz_k)
                logdetgrad = logdetJx - logdetJz

            if self.training and self.n_power_series is None:
                self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples))
                estimator = logdetgrad.detach()
                self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom))
                self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom))
            return logdetgrad.view(-1, 1) 
Example #21
    def _logits_proposal_posterior(
        means_pp: Tensor,
        precisions_pp: Tensor,
        covariances_pp: Tensor,
        logits_p: Tensor,
        means_p: Tensor,
        precisions_p: Tensor,
        logits_d: Tensor,
        means_d: Tensor,
        precisions_d: Tensor,
    ):
        """
        Return the component weights (i.e. logits) of the proposal posterior.

        Args:
            means_pp: Means of the proposal posterior.
            precisions_pp: Precision matrices of the proposal posterior.
            covariances_pp: Covariance matrices of the proposal posterior.
            logits_p: Component weights (i.e. logits) of the proposal distribution.
            means_p: Means of the proposal distribution.
            precisions_p: Precision matrices of the proposal distribution.
            logits_d: Component weights (i.e. logits) of the density estimator.
            means_d: Means of the density estimator.
            precisions_d: Precision matrices of the density estimator.

        Returns: Component weights of the proposal posterior. L*K terms.
        """

        num_comps_p = precisions_p.shape[1]
        num_comps_d = precisions_d.shape[1]

        # Compute log(alpha_i * beta_j)
        logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1)
        logits_d_rep = logits_d.repeat(1, num_comps_p)
        logit_factors = logits_p_rep + logits_d_rep

        # Compute sqrt(det()/(det()*det()))
        logdet_covariances_pp = torch.logdet(covariances_pp)
        logdet_covariances_p = -torch.logdet(precisions_p)
        logdet_covariances_d = -torch.logdet(precisions_d)

        # Repeat the proposal and density estimator terms such that there are LK terms.
        # Same trick as has been used above.
        logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave(
            num_comps_d, dim=1
        )
        logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p)

        log_sqrt_det_ratio = 0.5 * (
            logdet_covariances_pp
            - (logdet_covariances_p_rep + logdet_covariances_d_rep)
        )

        # Compute for proposal, density estimator, and proposal posterior:
        # mu_i.T * P_i * mu_i
        exponent_p = batched_mixture_vmv(precisions_p, means_p)
        exponent_d = batched_mixture_vmv(precisions_d, means_d)
        exponent_pp = batched_mixture_vmv(precisions_pp, means_pp)

        # Extend proposal and density estimator exponents to get LK terms.
        exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1)
        exponent_d_rep = exponent_d.repeat(1, num_comps_p)
        exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp)

        logits_pp = logit_factors + log_sqrt_det_ratio + exponent

        return logits_pp
Example #22
    print("running tests")
    N = 4
    batch = 1

    _dtype = t.float64

    tmp1 = t.randn(N, N, dtype=_dtype)
    tmp1 = tmp1 @ tmp1.t()
    tmp2 = t.randn(N, N, dtype=_dtype)
    tmp2 = tmp2 @ tmp2.t()

    #### Test KFac
    Kpt = KFac(PositiveDefiniteMatrix(tmp1), PositiveDefiniteMatrix(tmp2))
    Knp = t.tensor(np.kron(tmp1.numpy(), tmp2.numpy()))
    assert t.allclose(Kpt.full(), Knp)
    assert t.allclose(Kpt.logdet(), t.logdet(Kpt.full()))
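    # Side note (not part of the original test): the Kronecker log-det identity
    # can also be checked directly. For A (n x n) and B (m x m),
    # det(kron(A, B)) = det(A)**m * det(B)**n, so with both factors N x N here
    # logdet(Knp) = N * (logdet(tmp1) + logdet(tmp2)).
    assert t.allclose(t.logdet(Knp), N * (t.logdet(tmp1) + t.logdet(tmp2)))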

    #### Test other
    m1 = Scale(N, 0.5, dtype=_dtype)
    m2 = PositiveDefiniteMatrix(tmp1)
    m3 = m2.chol()
    m4 = m3.t()
    m5 = Product(m1, m2, m3)

    def tests(W):
        #### Test matrix-vector operations
        x = t.randn(batch, N, 1, dtype=_dtype)

        # Test mm
        assert t.allclose(W.full() @ x, W(x))
Example #23
    def log_normalizer(self) -> th.FloatTensor:
        return th.logdet(self.L + th.eye(self.N))
Example #24
    def _logits_posterior(
        means_post: Tensor,
        precisions_post: Tensor,
        covariances_post: Tensor,
        logits_pp: Tensor,
        means_pp: Tensor,
        precisions_pp: Tensor,
        logits_d: Tensor,
        means_d: Tensor,
        precisions_d: Tensor,
    ):
        r"""
        Return the component weights (i.e. logits) of the MoG posterior.

        $\alpha_k^\prime = \frac{ \alpha_k \exp(-0.5 c_k) }{ \sum_{j} \alpha_j \exp(-0.5 c_j) }$
        with
        $c_k = \log\det(S_k) - \log\det(S_0) - \log\det(S_k^\prime)
             + m_k^T P_k m_k - m_0^T P_0 m_0 - {m_k^\prime}^T P_k^\prime m_k^\prime$
        (see eqs. (25, 26) in Appendix C of [1])

        Args:
            means_post: Means of the posterior.
            precisions_post: Precision matrices of the posterior.
            covariances_post: Covariance matrices of the posterior.
            logits_pp: Component weights (i.e. logits) of the proposal prior.
            means_pp: Means of the proposal prior.
            precisions_pp: Precision matrices of the proposal prior.
            logits_d: Component weights (i.e. logits) of the density estimator.
            means_d: Means of the density estimator.
            precisions_d: Precision matrices of the density estimator.

        Returns: Component weights of the proposal posterior.
        """

        num_comps_pp = precisions_pp.shape[1]
        num_comps_d = precisions_d.shape[1]

        # Compute the ratio of the logits similar to eq (10) in Appendix A.1 of [2]
        logits_pp_rep = logits_pp.repeat_interleave(num_comps_d, dim=1)
        logits_d_rep = logits_d.repeat(1, num_comps_pp)
        logit_factors = logits_d_rep - logits_pp_rep

        # Compute the log-determinants
        logdet_covariances_post = torch.logdet(covariances_post)
        logdet_covariances_pp = -torch.logdet(precisions_pp)
        logdet_covariances_d = -torch.logdet(precisions_d)

        # Repeat the proposal and density estimator terms such that there are LK terms.
        # Same trick as has been used above.
        logdet_covariances_pp_rep = logdet_covariances_pp.repeat_interleave(
            num_comps_d, dim=1
        )
        logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_pp)

        log_sqrt_det_ratio = 0.5 * (  # similar to eq (14) in Appendix A.1 of [2]
            logdet_covariances_post
            + logdet_covariances_pp_rep
            - logdet_covariances_d_rep
        )

        # Compute for proposal, density estimator, and proposal posterior:
        exponent_pp = utils.batched_mixture_vmv(
            precisions_pp, means_pp  # m_0 in eq (26) in Appendix C of [1]
        )
        exponent_d = utils.batched_mixture_vmv(
            precisions_d, means_d  # m_k in eq (26) in Appendix C of [1]
        )
        exponent_post = utils.batched_mixture_vmv(
            precisions_post, means_post  # m_k^\prime in eq (26) in Appendix C of [1]
        )

        # Extend proposal and density estimator exponents to get LK terms.
        exponent_pp_rep = exponent_pp.repeat_interleave(num_comps_d, dim=1)
        exponent_d_rep = exponent_d.repeat(1, num_comps_pp)
        exponent = -0.5 * (
            exponent_d_rep - exponent_pp_rep - exponent_post  # eq (26) in [1]
        )

        logits_post = logit_factors + log_sqrt_det_ratio + exponent
        return logits_post
Example #25
    def entropy(self):
        simple_tst = StudentT_torch(self.df)
        H = self.coeff * torch.logdet(self.S) + self.d * simple_tst.entropy()
        return H
Example #26
    def get_laplace_logdet_term(self):
        inverse_posterior_covariance = self.get_inverse_posterior_covariance()
        logdet_term = -0.5 * torch.logdet(inverse_posterior_covariance)

        return logdet_term
Example #27
    def logdet(self):
        return torch.logdet(self)
Example #28
def logdet_torch(A):
    #return torch.log(torch.det(A))
    return torch.logdet(A)
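
A related sketch (not from the source): torch.logdet returns nan when det(A) < 0 and -inf when det(A) == 0, so torch.linalg.slogdet, which returns the sign and log|det| separately, is often the more robust choice.

import torch

A = torch.diag(torch.tensor([-2.0, 3.0]))      # det(A) = -6
print(torch.logdet(A))                         # nan
sign, logabsdet = torch.linalg.slogdet(A)
print(sign, logabsdet)                         # -1, log(6)
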
Example #29
def blogdet(A):
    return torch.stack([torch.logdet(a) for a in A])
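
Note (an addition, not from the source): torch.logdet already accepts batched input of shape (*, n, n), so the loop above is equivalent to a single direct call.

import torch

A = torch.randn(5, 3, 3)
A = A @ A.transpose(-1, -2) + torch.eye(3)     # batch of positive definite matrices
assert torch.allclose(blogdet(A), torch.logdet(A))
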
Example #30
    def log_prob(self, X):
        logdetX = torch.logdet(X)
        expon = torch.trace(torch.matmul(self.W_inv, X))
        return self.C + 0.5 * (self.df - self.p - 1) * logdetX - 0.5 * expon
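
As a cross-check (a sketch under assumptions, not from the source): if self.df, self.p, self.W_inv and self.C follow the standard Wishart parameterisation with scale matrix V (so W_inv = V^{-1} and C is the usual normalising constant), the density above should agree with torch.distributions.Wishart (available in recent PyTorch versions).

import math
import torch

df, p = 7.0, 3
A = torch.randn(p, p, dtype=torch.float64)
V = A @ A.T + torch.eye(p, dtype=torch.float64)        # scale matrix
B = torch.randn(p, p, dtype=torch.float64)
X = B @ B.T + torch.eye(p, dtype=torch.float64)        # point at which to evaluate

# Normalising constant of the Wishart density (the assumed meaning of self.C).
C = (-0.5 * df * p * math.log(2.0)
     - 0.5 * df * torch.logdet(V)
     - torch.mvlgamma(torch.tensor(df / 2.0, dtype=torch.float64), p))
manual = C + 0.5 * (df - p - 1) * torch.logdet(X) - 0.5 * torch.trace(V.inverse() @ X)

ref = torch.distributions.Wishart(df=torch.tensor(df, dtype=torch.float64),
                                  covariance_matrix=V).log_prob(X)
assert torch.allclose(manual, ref)
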