Ejemplo n.º 1
0
    def variance(self, n_s):
        """
        Stochastic approximator of predictive variance.
        Follows "Massively Scalable GPs".

        Each iteration draws Gaussian probe vectors, solves a linear system
        with ``self.opt.cg`` and maps the result back through ``self.Ks``;
        the per-point sample variance of those vectors is subtracted from
        ``kron_list_diag(self.Ks)`` (presumably the prior diagonal).

        Args:
            n_s (int): Number of iterations to run stochastic approximation

        Returns:
            tuple: ``(clipped, var)`` where ``clipped`` is the flattened
            ``kron_list_diag(self.Ks) - var`` clipped to ``[0, 1e12]`` and
            ``var`` is the raw per-point sample variance.
        """
        if self.root_eigdecomp is None:
            self.sqrt_eig()
        root_K = self.root_eigdecomp
        if self.obs_idx is not None:
            # Partial grid: keep only the rows for observed grid points.
            root_K = root_K[self.obs_idx, :]

        # Loop-invariant quantities.
        sqrt_W = np.sqrt(self.W)
        sqrt_noise = np.sqrt(self.noise)
        diag = kron_list_diag(self.Ks)

        samples = []
        for _ in range(n_s):
            g_m = np.random.normal(size=self.m)
            g_n = np.random.normal(size=self.n)
            # W^{1/2} is applied element-wise, consistent with how it scales
            # ``r`` below; a ``.dot`` here would reduce a vector W to a scalar.
            right_side = sqrt_W * np.dot(root_K, g_m) + sqrt_noise * g_n
            r = self.opt.cg(self.Ks, right_side)
            if self.obs_idx is not None:
                # Scatter observed-point values back onto the full grid.
                Wr = np.zeros(self.m)
                Wr[self.obs_idx] = sqrt_W * r
            else:
                Wr = sqrt_W * r
            samples.append(kron_mvp(self.Ks, Wr))
        var = np.var(samples, axis=0)
        return np.clip(diag - var, 0, 1e12).flatten(), var
Ejemplo n.º 2
0
    def __init__(self, X, y, kernels, likelihood, mu=None, obs_idx=None):
        """
        Initialize the model and its variational parameters.

        Args:
            X (): data
            y (): responses
            kernels (): kernel functions (presumably one per grid
                dimension -- confirm against the base class)
            likelihood (): likelihood function. Requires log_like() function
            mu (): prior mean
            obs_idx (): if dealing with partial grid, indices of grid that
                are observed
        """
        super(MFSVI, self).__init__(X, y, kernels, likelihood, mu, obs_idx)
        # q_S starts as the log of kron_list_diag(self.Ks), i.e. presumably
        # the log prior variance at each grid point.
        self.q_S = np.log(kron_list_diag(self.Ks))
        # Placeholders; both start out as None.
        self.mu_params = None
        self.s_params = None
Ejemplo n.º 3
0
    def marginal(self):
        """
        Compute the (approximate) log marginal likelihood.

        Runs ``self.solve()`` and/or ``self.eig_decomp()`` first when
        ``self.alpha`` or ``self.eigvals`` have not been computed yet.

        Returns:
            ``-0.5 * sum(log(kron_list_diag(eigvals) + noise))
            - 0.5 * (y - mu) . alpha`` minus the summed kernel
            ``log_prior`` values. The constant ``-n/2 * log(2*pi)`` is not
            included.
        """
        if self.alpha is None:
            self.solve()
        if self.eigvals is None:
            self.eig_decomp()

        # On a partial grid only the observed entries of the mean are used.
        mean = self.mu if self.obs_idx is None else self.mu[self.obs_idx]

        log_det_term = 0.5 * np.sum(
            np.log(kron_list_diag(self.eigvals) + self.noise))
        data_fit_term = 0.5 * np.dot(self.y - mean, self.alpha)
        # NOTE(review): the prior is subtracted. If log_prior returns
        # log p(theta), a log-posterior would add it instead -- confirm.
        prior_term = sum(kern.log_prior(kern.params) for kern in self.kernels)
        return -log_det_term - data_fit_term - prior_term
Ejemplo n.º 4
0
    def variance_pmap(self, n_s=30):
        """
        Stochastic approximator of predictive variance.
        Follows "Massively Scalable GPs".

        Args:
            n_s (int): Number of iterations to run stochastic approximation
                (default 30).

        Returns:
            Flattened ``kron_list_diag(self.Ks) + self.noise`` minus the
            per-point sample variance of the probe vectors, clipped below
            at 0.
        """
        # Recompute the eigendecomposition if either half of it is missing.
        # Explicit ``is None`` checks on both avoid relying on the truthiness
        # of ``self.eigvals`` (ambiguous for ndarrays).
        if self.eigvals is None or self.eigvecs is None:
            self.eig_decomp()

        Q = self.eigvecs
        Q_t = [v.T for v in Q]
        # np.sqrt of a negative eigenvalue yields NaN; nan_to_num maps it to 0.
        Vr = [np.nan_to_num(np.sqrt(e)) for e in self.eigvals]

        diag = kron_list_diag(self.Ks) + self.noise
        sqrt_noise = np.sqrt(self.noise)
        samples = []

        for _ in range(n_s):
            g_m = np.random.normal(size=self.m)
            g_n = np.random.normal(size=self.n)

            # Q diag(Vr) Q^T g_m via Kronecker matrix-vector products.
            Kroot_g = kron_mvp(Q, kron_mvp(Vr, kron_mvp(Q_t, g_m)))
            if self.obs_idx is not None:
                Kroot_g = Kroot_g[self.obs_idx]
            right_side = Kroot_g + sqrt_noise * g_n

            r = self.cg_opt.cg(self.Ks, right_side)
            if self.obs_idx is not None:
                # Scatter observed-point values back onto the full grid.
                Wr = np.zeros(self.m)
                Wr[self.obs_idx] = r
            else:
                Wr = r
            samples.append(kron_mvp(self.Ks, Wr))

        est = np.var(samples, axis=0)
        return np.clip(diag - est, 0, a_max=None).flatten()
Ejemplo n.º 5
0
 def construct_Ks(self, kernels=None, noise=1e-2):
     """
     Build the kernel factors via the parent class, add diagonal jitter,
     and cache their inverses plus derived quantities.

     Args:
         kernels: not referenced in this method (the parent
             ``construct_Ks`` is called without arguments).
         noise (float): value added to the diagonal of every factor in
             ``self.Ks``.
     """
     super(SVIBase, self).construct_Ks()
     jittered = []
     for K in self.Ks:
         # K + noise * I for each factor.
         jittered.append(K + noise * np.eye(K.shape[0]))
     self.Ks = jittered
     self.K_invs = list(map(np.linalg.inv, self.Ks))
     self.k_inv_diag = kron_list_diag(self.K_invs)
     self.det_K = self.log_det_K()