def local_meanfield(global_natparam, node_potentials):
    """Local mean-field inference for the Dirichlet / NIW mixture model.

    Args:
        global_natparam: pair ``(dirichlet_natparam, niw_natparams)`` -- the
            global natural parameters (eta_theta^0 in the original notes).
        node_potentials: tuple of arrays (r(phi, y) in the original notes),
            packed into a dense array with ``gaussian.pack_dense``.

    Returns:
        ``(local_stats, prior_stats, natparam, kl)`` where
        ``local_stats = (label_stats, gaussian_stats)``,
        ``prior_stats = (sum of label_stats over axis 0,
        tensordot(label_stats, gaussian_stats) over axis 0)``,
        ``natparam = (label_natparam, gaussian_natparam)`` and
        ``kl = label_kl + gaussian_kl``.
    """
    dirichlet_natparam, niw_natparams = global_natparam
    packed_potentials = gaussian.pack_dense(*node_potentials)

    # Expected global statistics under the current global factors:
    #   Dirichlet factor: E_{q(pi)}[t(pi)], with t(pi) = [log pi_1, log pi_2, ...]
    #   NIW factor:       E_{q(mu, Sigma)}[t(mu, Sigma)]
    label_global = dirichlet.expectedstats(dirichlet_natparam)
    gaussian_globals = niw.expectedstats(niw_natparams)

    # Solve the mean-field fixed point on the *unboxed* potentials (getval),
    # so the iteration itself is not tracked by autodiff.
    converged_label_stats = meanfield_fixed_point(
        label_global, gaussian_globals, getval(packed_potentials))

    # One more sweep at the optimum using the *boxed* potentials, so the
    # returned quantities depend directly on node_potentials.
    gaussian_natparam, gaussian_stats, gaussian_kl = gaussian_meanfield(
        gaussian_globals, packed_potentials, converged_label_stats)
    label_natparam, label_stats, label_kl = label_meanfield(
        label_global, gaussian_globals, gaussian_stats)

    # Sufficient statistics for the GMM prior: sum over the conditionally
    # iid local variables (axis 0).
    prior_stats = (np.sum(label_stats, 0),
                   np.tensordot(label_stats, gaussian_stats, [0, 0]))

    return ((label_stats, gaussian_stats),
            prior_stats,
            (label_natparam, gaussian_natparam),
            label_kl + gaussian_kl)
def local_meanfield(global_stats, node_potentials):
    """Local mean-field inference by iterating ``meanfield_update`` to a fixed point.

    Args:
        global_stats: pair ``(label_global, gaussian_globals)`` of expected
            global statistics.
        node_potentials: tuple of arrays, packed into a dense array with
            ``gaussian.pack_dense``.

    Returns:
        ``(local_stats, prior_stats, natparam, kl)`` where
        ``local_stats = (label_stats, gaussian_stats)``,
        ``prior_stats = (sum of label_stats over axis 0,
        tensordot(label_stats, gaussian_stats) over axis 0)``,
        ``natparam = (label_natparam, gaussian_natparam)`` and ``kl`` is
        computed by ``local_kl`` from the unboxed globals and the fixed point.
    """
    label_global, gaussian_globals = global_stats
    node_potentials = gaussian.pack_dense(*node_potentials)

    # Tuple parameters (``def f((a, b))`` / ``lambda (a, b): ...``) were removed
    # in Python 3 (PEP 3113); unpack explicitly so this runs on both 2 and 3.
    def make_fpfun(params):
        # Use the params passed in by fixed_point, not the enclosing locals.
        fp_label_global, fp_gaussian_globals, fp_node_potentials = params

        def fpfun(state):
            # state is (local_natparam, local_stats, kl); only the first
            # entry of local_stats (the label stats) feeds the update.
            _local_natparam, local_stats, _kl = state
            return meanfield_update(fp_label_global, fp_gaussian_globals,
                                    fp_node_potentials, local_stats[0])
        return fpfun

    x0 = initialize_meanfield(label_global, gaussian_globals, node_potentials)

    def kl_diff(a, b):
        # Convergence measure: absolute change in the KL (third state entry).
        return abs(a[2] - b[2])

    (label_natparam, gaussian_natparam), (label_stats, gaussian_stats), _ = \
        fixed_point(make_fpfun, (label_global, gaussian_globals, node_potentials),
                    x0, kl_diff, tol=1e-3)

    # Sufficient statistics for the GMM prior (sum across conditionally iid
    # local variables, axis 0).
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_stats, [0, 0])

    local_stats = label_stats, gaussian_stats
    prior_stats = dirichlet_stats, niw_stats
    natparam = label_natparam, gaussian_natparam
    kl = local_kl(getval(gaussian_globals), getval(label_global),
                  label_natparam, gaussian_natparam, label_stats, gaussian_stats)
    return local_stats, prior_stats, natparam, kl
def local_meanfield(global_natparam, node_potentials):
    """Local mean-field step for the Dirichlet / NIW mixture model.

    ``global_natparam`` is unpacked as ``(dirichlet_natparam, niw_natparams)``
    and ``node_potentials`` is packed with ``gaussian.pack_dense``. Returns
    ``(local_stats, prior_stats, natparam, kl)``:
    ``local_stats = (label_stats, gaussian_stats)``,
    ``prior_stats = (sum of label_stats over axis 0,
    tensordot(label_stats, gaussian_stats) over axis 0)``,
    ``natparam = (label_natparam, gaussian_natparam)``,
    ``kl = label_kl + gaussian_kl``.
    """
    dir_natparam, niw_natparam = global_natparam
    dense_potentials = gaussian.pack_dense(*node_potentials)

    # expected global statistics under the current global factors
    expected_label = dirichlet.expectedstats(dir_natparam)
    expected_gaussian = niw.expectedstats(niw_natparam)

    # fixed point solved on the unboxed (getval) potentials
    optimal_label_stats = meanfield_fixed_point(
        expected_label, expected_gaussian, getval(dense_potentials))

    # quantities that depend directly on the boxed potentials at the optimum
    g_natparam, g_stats, g_kl = gaussian_meanfield(
        expected_gaussian, dense_potentials, optimal_label_stats)
    l_natparam, l_stats, l_kl = label_meanfield(
        expected_label, expected_gaussian, g_stats)

    # prior sufficient statistics, summed over the conditionally iid axis 0
    dirichlet_part = np.sum(l_stats, 0)
    niw_part = np.tensordot(l_stats, g_stats, [0, 0])
    total_kl = l_kl + g_kl

    return (l_stats, g_stats), (dirichlet_part, niw_part), (l_natparam, g_natparam), total_kl
def mc_elbo(pgm_params, i):
    """Minibatch objective for batch ``i``, scaled per datapoint.

    Here the "network potentials" are just the sufficient statistics of the
    data: ``(x x^T, x, n, n)`` packed with ``pack_dense``.

    Side effect: stores the first element returned by ``run_inference`` in
    ``saved.stats``.

    Returns:
        ``(-global_kl - num_batches * local_kl) / num_datapoints``.
    """
    x = get_batch(i)
    # One outer product per leading index. The previous subscripts
    # 'ij,ik->ijk' only accepted 2-D x, so einsum raised for any other rank
    # and the scalar fallback for ``n`` below was unreachable.
    xxT = np.einsum('...i,...j->...ij', x, x)
    # One count per data vector (identical to np.ones(x.shape[0]) for 2-D x);
    # a single 1-D vector gets the scalar 1.
    n = np.ones(x.shape[:-1]) if x.ndim >= 2 else 1.
    nn_potentials = pack_dense(xxT, x, n, n)
    saved.stats, global_kl, local_kl = run_inference(
        pgm_prior, pgm_params, nn_potentials)
    # NOTE(review): the original flagged this scaling with "#CHECK" -- confirm
    # it matches the intended ELBO estimator.
    return (-global_kl - num_batches * local_kl) / num_datapoints
def check_params(natparam):
    """Assert that ``natparam`` survives an unpack_dense -> pack_dense round trip
    (compared with ``np.allclose``)."""
    components = unpack_dense(natparam)
    roundtrip = pack_dense(*components)
    assert np.allclose(natparam, roundtrip)
def rand_gaussian(n):
    """Random dense-packed Gaussian natural parameter of dimension ``n``.

    The matrix is ``rand_psd(n) + n * I``; the vector is standard normal. The
    packed pair is ``(-J / 2, h)``.
    """
    precision = rand_psd(n) + n * np.eye(n)  # drawn first, as before
    potential = npr.randn(n)
    return pack_dense(-0.5 * precision, potential)
def rand_gaussian(n):
    """Return ``pack_dense(-J/2, h)`` for random ``J = rand_psd(n) + n*I`` and
    standard-normal ``h`` of length ``n``."""
    shifted = n * np.eye(n)
    J = rand_psd(n) + shifted
    h = npr.randn(n)
    half_neg_J = -0.5 * J
    return pack_dense(half_neg_J, h)