Example #1
def soft_clamp(x: torch.Tensor, _min=None, _max=None):
    # clamp tensor values while maintaining the gradient
    if _max is not None:
        x = _max - F.softplus(_max - x)
    if _min is not None:
        x = _min + F.softplus(x - _min)
    return x
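A minimal usage sketch (assuming the usual torch / torch.nn.functional imports, which the snippet itself does not show):

import torch
import torch.nn.functional as F

x = torch.linspace(-5., 5., steps=11, requires_grad=True)
y = soft_clamp(x, _min=-1.0, _max=1.0)
y.sum().backward()
print(y)       # values pulled smoothly toward [-1, 1] (the clamp is soft, so a small overshoot is expected)
print(x.grad)  # gradients stay nonzero even for inputs far outside the range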
Example #2
    def loss(self):
        # copied directly from the GPy implementation
        num_inducing = self.inducing.size(0)
        num_data = self.yn.size(0)

        variance = F.softplus(self.free_variance)
        Kmm = self.kernel(self.inducing)
        Knn = self.kernel(self.Xn, diag=True)
        Knm = self.kernel(self.Xn, self.inducing)
        U = Knm

        Lm = stable_cholesky(Kmm)
        LiUT = triangular_solve(U.t(), Lm)

        # per-point noise term: diag(Knn) + variance - diag(Knm Kmm^{-1} Kmn)
        sigma_star = Knn.squeeze() + variance - torch.sum(LiUT ** 2, 0)
        beta = stable_divide(1., sigma_star)

        tmp = LiUT * torch.sqrt(beta)
        A = tmp @ tmp.t() + torch.eye(num_inducing, device=self.device)
        LA = stable_cholesky(A)

        URiy = (U.t() * beta) @ self.yn
        tmp = triangular_solve(URiy, Lm)
        b = triangular_solve(tmp, LA)

        # assemble the negative log-likelihood, following the GPy implementation
        loss = (0.5 * num_data * LOG_2PI
                + torch.sum(torch.log(torch.diag(LA)))
                - 0.5 * torch.sum(torch.log(beta))
                + 0.5 * torch.sum((self.yn.t() * torch.sqrt(beta)) ** 2)
                - 0.5 * torch.sum(b ** 2))
        return loss
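Examples #2–#5 also rely on helpers (stable_cholesky, triangular_solve, stable_divide, stable_sqrt, stable_log, logdet, LOG_2PI) defined elsewhere in the project. Below is a minimal sketch of plausible implementations, included only so the snippets can be read standalone; the jitter/eps values and exact semantics are assumptions (stable_log and logdet are omitted, presumably a guarded torch.log and a log-determinant computed from a Cholesky factor):

import math

import torch

LOG_2PI = math.log(2 * math.pi)

def stable_cholesky(K, jitter=1e-6):
    # Cholesky factor of K with a small diagonal jitter for numerical stability
    eye = torch.eye(K.size(-1), dtype=K.dtype, device=K.device)
    return torch.linalg.cholesky(K + jitter * eye)

def triangular_solve(b, L):
    # solve L x = b with L lower-triangular, i.e. return L^{-1} b
    return torch.linalg.solve_triangular(L, b, upper=False)

def stable_divide(numerator, denominator, eps=1e-10):
    # guard against division by (near-)zero
    return numerator / denominator.clamp_min(eps)

def stable_sqrt(x, eps=1e-10):
    # guard against sqrt of values pushed slightly negative by round-off
    return torch.sqrt(x.clamp_min(eps))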
Example #3
    def forward(self, Xt, diag=False):
        Kmm = self.kernel(self.inducing)
        Kmn = self.kernel(self.inducing, self.Xn)
        Ktm = self.kernel(Xt, self.inducing)
        Knn = self.kernel(self.Xn, diag=True)
        variance = F.softplus(self.free_variance)
        Lm = stable_cholesky(Kmm)
        LiUT = triangular_solve(Kmn, Lm)
        sigma_star = Knn.squeeze() + variance - torch.sum(LiUT ** 2, 0)
        sigma_star_sqrt_inv = stable_sqrt(stable_divide(1., sigma_star))

        Lmi_Kmn = LiUT
        sigma_Knm_Lmi = sigma_star_sqrt_inv.reshape(-1, 1) * Lmi_Kmn.t()
        woodbury_chol = stable_cholesky(
            torch.eye(self.inducing.size(0), device=Xt.device) + sigma_Knm_Lmi.t() @ sigma_Knm_Lmi)

        Lmi_Kmt = triangular_solve(Ktm.t(), Lm)
        left = triangular_solve(Lmi_Kmt, woodbury_chol)
        tmp = sigma_Knm_Lmi.t() @ (sigma_star_sqrt_inv.unsqueeze(-1) * self.yn)
        right = triangular_solve(tmp, woodbury_chol)
        mean = left.t() @ right

        if diag:
            # return the per-point predictive variance instead of the full covariance
            Ktt_diag = self.kernel(Xt, diag=True)
            tmp = triangular_solve(Lmi_Kmt, woodbury_chol)
            var = Ktt_diag - torch.sum(Lmi_Kmt ** 2, dim=0).unsqueeze(-1) + torch.sum(tmp ** 2, dim=0).unsqueeze(-1)
            return mean, var
        else:
            Ktt = self.kernel(Xt, Xt)
            tmp = triangular_solve(Lmi_Kmt, woodbury_chol)
            cov = Ktt - Lmi_Kmt.t() @ Lmi_Kmt + tmp.t() @ tmp
            return mean, cov
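For reference, writing \(\Lambda = \operatorname{diag}(\sigma_\star)\) for the per-point noise computed above and \(L_m L_m^\top = K_{mm}\), the quantities assembled here correspond to the Woodbury-form predictive equations (a reading of the code, not stated in the original):

\[
W = I + L_m^{-1} K_{mn} \Lambda^{-1} K_{nm} L_m^{-\top}, \qquad
\mu_* = K_{*m} L_m^{-\top} W^{-1} L_m^{-1} K_{mn} \Lambda^{-1} y,
\]
\[
\Sigma_* = K_{**} - K_{*m} K_{mm}^{-1} K_{m*} + K_{*m} L_m^{-\top} W^{-1} L_m^{-1} K_{m*},
\]

with the diag=True branch returning only \(\operatorname{diag}(\Sigma_*)\).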
Example #4
    def forward(self, Xt, diag=False):
        Kmm = self.kernel(self.inducing)
        Kmn = self.kernel(self.inducing, self.Xn)
        variance = F.softplus(self.free_variance)
        sq_var_i = stable_sqrt(1. / variance)

        L = stable_cholesky(Kmm)
        A = sq_var_i * triangular_solve(Kmn, L)
        B = torch.eye(self.inducing.size(0),
                      device=self.inducing.device) + A @ A.t()
        Lb = stable_cholesky(B)
        c = sq_var_i * triangular_solve(A @ self.yn, Lb)

        Ktm = self.kernel(Xt, self.inducing)
        Li_Kmt = triangular_solve(Ktm.t(), L)
        Lbi_Li_Kmt = triangular_solve(Li_Kmt, Lb)
        mean = Lbi_Li_Kmt.t() @ c

        if diag:
            # return the per-point predictive variance instead of the full covariance
            Ktt_diag = self.kernel(Xt, diag=True)
            var = Ktt_diag - torch.sum(Li_Kmt**2, 0).unsqueeze(-1) + torch.sum(
                Lbi_Li_Kmt**2, 0).unsqueeze(-1)
            return mean, var
        else:
            Ktt = self.kernel(Xt, Xt)
            cov = Ktt - Li_Kmt.t() @ Li_Kmt + Lbi_Li_Kmt.t() @ Lbi_Li_Kmt
            return mean, cov
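This variant uses a constant noise \(\sigma^2\) in place of the per-point \(\sigma_\star\) of Example #3. In the same notation, with \(A = \sigma^{-1} L^{-1} K_{mn}\) and \(B = I + A A^\top\) (again a reading of the code, not stated in the original):

\[
\mu_* = \sigma^{-2} K_{*m} L^{-\top} B^{-1} L^{-1} K_{mn} y, \qquad
\Sigma_* = K_{**} - K_{*m} K_{mm}^{-1} K_{m*} + K_{*m} L^{-\top} B^{-1} L^{-1} K_{m*}.
\]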
Example #5
    def loss(self):
        Kmm = self.kernel(self.inducing)
        Kmn = self.kernel(self.inducing, self.Xn)
        Knn = self.kernel(self.Xn, diag=True)
        variance = F.softplus(self.free_variance)
        sq_var_i = torch.sqrt(1. / variance)

        L = stable_cholesky(Kmm)
        A = sq_var_i * triangular_solve(Kmn, L)
        B = torch.eye(self.inducing.size(0),
                      device=self.inducing.device) + A @ A.t()
        Lb = stable_cholesky(B)
        c = sq_var_i * triangular_solve(A @ self.yn, Lb)
        # twice the loss (negative log-likelihood bound); halved on return
        double_loss = (self.Xn.size(0) * LOG_2PI
                       + logdet(Lb)
                       + self.Xn.size(0) * stable_log(variance)
                       + torch.sum((sq_var_i * self.yn)**2)
                       - torch.sum(c**2)
                       + 1. / variance * torch.sum(Knn)
                       - torch.sum(A**2))
        return 0.5 * double_loss
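The returned quantity matches the standard collapsed variational bound for sparse GP regression (negated so it can be minimized), assuming logdet(Lb) returns \(\log\lvert B\rvert\) from the Cholesky factor Lb:

\[
2 \cdot \text{loss} = N \log 2\pi + \log\lvert B\rvert + N \log \sigma^2
+ \sigma^{-2} y^\top y - c^\top c
+ \sigma^{-2} \operatorname{tr}(K_{nn}) - \operatorname{tr}(A A^\top),
\]

with \(A = \sigma^{-1} L^{-1} K_{mn}\), \(B = I + A A^\top\), \(c = \sigma^{-1} L_B^{-1} A y\), \(L = \operatorname{chol}(K_{mm})\), \(L_B = \operatorname{chol}(B)\), and \(N\) the number of training points.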