Example #1
def local_meanfield(global_natparam, node_potentials):
    # global_natparam = \eta_{\theta}^0
    # node_potentials = r(\phi, y)
    
    dirichlet_natparam, niw_natparams = global_natparam
    node_potentials = gaussian.pack_dense(*node_potentials)

    #### compute expected global parameters using current global factors
    # label_global = E_{q(\pi)}[t(\pi)], where q(\pi) is the Dirichlet posterior with
    # natural parameter dirichlet_natparam and t(\pi) = [log \pi_1, log \pi_2, ...]
    # gaussian_globals = E_{q(\mu, \Sigma)}[t(\mu, \Sigma)], where q(\mu, \Sigma) is the
    # normal-inverse-Wishart (NIW) posterior
    label_global = dirichlet.expectedstats(dirichlet_natparam)
    gaussian_globals = niw.expectedstats(niw_natparams)

    #### compute mean field fixed point using unboxed node_potentials
    label_stats = meanfield_fixed_point(label_global, gaussian_globals, getval(node_potentials))

    #### compute values that depend directly on boxed node_potentials at optimum
    gaussian_natparam, gaussian_stats, gaussian_kl = \
        gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
    label_natparam, label_stats, label_kl = \
        label_meanfield(label_global, gaussian_globals, gaussian_stats)

    #### collect sufficient statistics for gmm prior (sum across conditional iid)
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_stats, [0, 0])

    local_stats = label_stats, gaussian_stats
    prior_stats = dirichlet_stats, niw_stats
    natparam = label_natparam, gaussian_natparam
    kl = label_kl + gaussian_kl

    return local_stats, prior_stats, natparam, kl
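The call to meanfield_fixed_point is not part of this excerpt; it can be read as block coordinate ascent that alternates the same gaussian_meanfield and label_meanfield updates used above until the local KL stops changing. The sketch below reuses those helpers (plus an initializer like the initialize_meanfield seen in Example #2); its signature, tolerance, and iteration cap are assumptions, not the library's code.

def meanfield_fixed_point_sketch(label_global, gaussian_globals, node_potentials,
                                 tol=1e-3, max_iter=100):
    # rough sketch only: signature, initializer, and stopping rule are assumptions
    label_stats = initialize_meanfield(label_global, gaussian_globals, node_potentials)
    kl = np.inf
    for _ in range(max_iter):
        # update q(x) given the current responsibilities q(z)
        _, gaussian_stats, gaussian_kl = \
            gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
        # update q(z) given the new expected Gaussian statistics
        _, label_stats, label_kl = \
            label_meanfield(label_global, gaussian_globals, gaussian_stats)
        kl, prev_kl = label_kl + gaussian_kl, kl
        if abs(prev_kl - kl) < tol:
            break
    return label_stats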
Example #2
def local_meanfield(global_stats, node_potentials):
    label_global, gaussian_globals = global_stats
    node_potentials = gaussian.pack_dense(*node_potentials)

    def make_fpfun(args):
        label_global, gaussian_globals, node_potentials = args
        # state is the (local_natparam, local_stats, kl) triple; pass label_stats along
        return lambda state: meanfield_update(
            label_global, gaussian_globals, node_potentials, state[1][0])

    x0 = initialize_meanfield(label_global, gaussian_globals, node_potentials)

    kl_diff = lambda a, b: abs(a[2]-b[2])

    (label_natparam, gaussian_natparam), (label_stats, gaussian_stats), _ = \
        fixed_point(make_fpfun, (label_global, gaussian_globals, node_potentials), x0, kl_diff, tol=1e-3)

    # collect sufficient statistics for gmm prior (sum across conditional iid)
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_stats, [0, 0])

    local_stats = label_stats, gaussian_stats
    prior_stats = dirichlet_stats, niw_stats
    natparam = label_natparam, gaussian_natparam
    kl = local_kl(getval(gaussian_globals), getval(label_global),
        label_natparam, gaussian_natparam, label_stats, gaussian_stats)

    return local_stats, prior_stats, natparam, kl
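The fixed_point helper itself is not shown in this excerpt. A generic helper consistent with the call above might look like the sketch below: build the update function once from its arguments, then iterate from x0 until the supplied distance between successive states drops below tol. This is an illustrative assumption, not the library's implementation.

def fixed_point(make_fun, args, x0, distance, tol=1e-3, max_iter=100):
    # build the state-update function once, e.g. the closure returned by make_fpfun
    update = make_fun(args)
    x_prev, x = x0, update(x0)
    for _ in range(max_iter):
        # stop when successive states agree under the caller's distance (here kl_diff)
        if distance(x_prev, x) < tol:
            break
        x_prev, x = x, update(x)
    return x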
Example #3
def local_meanfield(global_natparam, node_potentials):
    dirichlet_natparam, niw_natparams = global_natparam
    node_potentials = gaussian.pack_dense(*node_potentials)

    # compute expected global parameters using current global factors
    label_global = dirichlet.expectedstats(dirichlet_natparam)
    gaussian_globals = niw.expectedstats(niw_natparams)

    # compute mean field fixed point using unboxed node_potentials
    label_stats = meanfield_fixed_point(label_global, gaussian_globals, getval(node_potentials))

    # compute values that depend directly on boxed node_potentials at optimum
    gaussian_natparam, gaussian_stats, gaussian_kl = \
        gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
    label_natparam, label_stats, label_kl = \
        label_meanfield(label_global, gaussian_globals, gaussian_stats)

    # collect sufficient statistics for gmm prior (sum across conditional iid)
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_stats, [0, 0])

    local_stats = label_stats, gaussian_stats
    prior_stats = dirichlet_stats, niw_stats
    natparam = label_natparam, gaussian_natparam
    kl = label_kl + gaussian_kl

    return local_stats, prior_stats, natparam, kl
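Independent of the mean-field details, the statistic collection at the end is a responsibility-weighted sum over the data axis: np.tensordot(label_stats, gaussian_stats, [0, 0]) contracts the datapoint dimension, giving one weighted sum of packed Gaussian statistics per mixture component. A small self-contained shape check is below; the array shapes are illustrative assumptions based on the dense packing used above.

import numpy as np

N, K, D = 5, 3, 2
label_stats = np.random.dirichlet(np.ones(K), size=N)       # (N, K) responsibilities
gaussian_stats = np.random.randn(N, D + 2, D + 2)            # packed per-point statistics

dirichlet_stats = np.sum(label_stats, 0)                     # (K,) expected counts
niw_stats = np.tensordot(label_stats, gaussian_stats, [0, 0])  # (K, D+2, D+2)

# the tensordot is the per-component responsibility-weighted sum over datapoints
manual = sum(label_stats[n, :, None, None] * gaussian_stats[n] for n in range(N))
assert np.allclose(niw_stats, manual)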
Example #4
def mc_elbo(pgm_params, i):
    # here nn_potentials are just the sufficient stats of the data
    x = get_batch(i)
    xxT = np.einsum('ij,ik->ijk', x, x)
    n = np.ones(x.shape[0]) if x.ndim == 2 else 1.
    nn_potentials = pack_dense(xxT, x, n, n)
    saved.stats, global_kl, local_kl = run_inference(
        pgm_prior, pgm_params, nn_potentials)
    return (-global_kl - num_batches * local_kl) / num_datapoints  # CHECK
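mc_elbo returns a minibatch estimate of (part of) the ELBO, written so it can be differentiated with autograd. A minimal sketch of driving it with plain stochastic gradient ascent follows; the step size, the ascent loop, and the assumption that pgm_params is a flat tuple of arrays are illustrative only, and the actual training loop is not shown in this excerpt.

from autograd import grad

elbo_grad = grad(mc_elbo)   # gradient w.r.t. pgm_params, mirroring its structure

def ascent_step(pgm_params, i, step_size=1e-3):
    # one stochastic gradient-ascent step on the minibatch ELBO estimate;
    # pgm_params is treated as a flat tuple of arrays here for simplicity
    g = elbo_grad(pgm_params, i)
    return tuple(p + step_size * dp for p, dp in zip(pgm_params, g))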
Example #5
def check_params(natparam):
    natparam2 = pack_dense(*unpack_dense(natparam))
    assert np.allclose(natparam, natparam2)
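check_params verifies that unpack_dense is an exact inverse of pack_dense. The toy pair below illustrates that round-trip property by block-embedding a matrix, a vector, and two scalars into one dense matrix; the layout here is an illustrative assumption, not the library's actual packing.

import numpy as np

def toy_pack_dense(A, b, c, d):
    # embed A, b, and two scalars into a single (n+2, n+2) matrix
    n = b.shape[0]
    out = np.zeros((n + 2, n + 2))
    out[:n, :n] = A
    out[:n, n] = b
    out[n, n] = c
    out[n + 1, n + 1] = d
    return out

def toy_unpack_dense(packed):
    # recover the four pieces from their fixed positions
    n = packed.shape[0] - 2
    return packed[:n, :n], packed[:n, n], packed[n, n], packed[n + 1, n + 1]

A, b = np.eye(3), np.arange(3.)
packed = toy_pack_dense(A, b, 1., 2.)
assert np.allclose(packed, toy_pack_dense(*toy_unpack_dense(packed)))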
Example #6
def rand_gaussian(n):
    J = rand_psd(n) + n * np.eye(n)
    h = npr.randn(n)
    return pack_dense(-1./2*J, h)
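rand_psd is not defined in this excerpt; a typical construction is A Aᵀ, which is positive semidefinite, and adding n * np.eye(n) in rand_gaussian then keeps the smallest eigenvalue well away from zero, so J is a valid precision matrix. The pair (-J/2, h) passed to pack_dense are the Gaussian natural parameters corresponding to the sufficient statistics (x xᵀ, x). The helper below is an illustrative assumption, not necessarily the repository's version.

import numpy as np
import numpy.random as npr

def rand_psd(n):
    # A A^T is symmetric positive semidefinite for any square A
    A = npr.randn(n, n)
    return np.dot(A, A.T)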