def recompute_posterior_fr(alpha: np.ndarray, beta: np.ndarray, K: np.ndarray, jitter: float = 1e-5) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Recompute the full-rank posterior approximation and its KL divergence to the prior.

    The approximation has mean ``m = K @ alpha`` and covariance ``Sigma = L @ L.T``,
    where ``L`` is the triangular matrix unpacked from the flat vector ``beta``
    by ``choleskies._flat_to_triang_pure``.

    :param alpha: Alpha vector used to parametrize the posterior mean
    :param beta: Flattened triangular factor parametrizing the posterior covariance;
        must unpack to exactly one matrix
    :param K: prior covariance (N x N)
    :param jitter: value added to the diagonal of ``K`` before inverting it
        (defaults to the previously hard-coded 1e-5)
    :return: Tuple ``(m, L, L_inv, KL, dKL_db, dKL_da)``: posterior mean, triangular
        factor of the covariance, its inverse, KL(q || prior), and the KL gradients
        with respect to ``beta`` (flattened) and ``alpha``
    :raises ValueError: if ``beta`` does not unpack to exactly one triangular matrix
    """
    N = K.shape[0]
    L_stack = choleskies._flat_to_triang_pure(beta)
    # Validate explicitly: an `assert` would be stripped under `python -O` and
    # silently pick the first matrix of a larger stack.
    if L_stack.shape[0] != 1:
        raise ValueError(
            f"beta must encode exactly one triangular matrix, got {L_stack.shape[0]}"
        )
    L = L_stack[0, :, :]

    # Posterior mean. For the KL term 0.5 * m^T K^{-1} m with m = K alpha, the
    # gradient w.r.t. alpha is K alpha = m (up to the jitter added to K).
    m = K @ alpha
    dKL_da = m.copy()

    K_jit = K + np.eye(N) * jitter
    Kinv = np.linalg.inv(K_jit)
    L_inv = np.linalg.inv(L)

    # Posterior covariance and its inverse; inv(L L^T) = L^{-T} L^{-1} reuses L_inv.
    Sigma = L @ L.T
    Sigma_inv = L_inv.T @ L_inv

    # dKL/dSigma = 0.5 * (K^{-1} - Sigma^{-1}). Presumably dL_fr back-propagates a
    # gradient w.r.t. L L^T onto L — TODO confirm against its definition.
    dKL_db_triang = -dL_fr(L, 0.5 * (Sigma_inv - Kinv), None, None, None)
    dKL_db = choleskies._triang_to_flat_pure(dKL_db_triang)

    # log det(Sigma K^{-1}) = log det(Sigma) - log det(K + jitter I), evaluated in
    # log space so large or small determinants cannot overflow or underflow.
    logdet_Sigma = 2.0 * np.sum(np.log(np.abs(np.diag(L))))
    _, logdet_K = np.linalg.slogdet(K_jit)

    # KL(N(m, Sigma) || N(0, K))
    KL = 0.5 * (-N + (m.T @ Kinv @ m) + np.trace(Kinv @ Sigma) - (logdet_Sigma - logdet_K))
    return m, L, L_inv, KL, dKL_db, dKL_da
def test_flat_to_triang(self):
    """Pure-Python and Cython flat-to-triangular conversions of ``self.flat`` must agree."""
    from_pure = choleskies._flat_to_triang_pure(self.flat)
    from_cython = choleskies._flat_to_triang_cython(self.flat)
    np.testing.assert_allclose(from_pure, from_cython)
# settings #step_size= 1e-5/N #itt_max = 30000 num_samples = 1 num_samples_swa = 1 num_params = K means1 = np.ones((num_params,)) means2 = np.zeros((num_params,)) means = means_all[n*num_params:(n+1)*num_params] betas = betas_all[num_params*(num_params+1)*n//2:num_params*(num_params+1)*(n+1)//2] tmpL = np.tril(np.ones((K, K))) L = choleskies._flat_to_triang_pure(betas)[0,:] Sigma = L @ L.T #means = means_list[n] #sigmas = sigmas_list[n] params = [means.copy(), betas.copy()] means_vb_clr, betas_vb_clr = means.copy(), betas.copy() means_vb_swa, betas_vb_swa = means.copy(), betas.copy() means_vb_rms, betas_vb_rms = means.copy(), betas.copy() L_vb_clr = L.copy() L_vb_swa = L.copy() L_vb_rms = L.copy() params_constant_lr = [means_vb_clr, betas_vb_clr] params_rms = [means_vb_rms, betas_vb_rms]