Exemple #1
0
def local_meanfield(global_natparam, node_potentials):
    r"""Run local mean-field inference for the GMM-structured latent model.

    Args:
        global_natparam: pair ``(dirichlet_natparam, niw_natparams)`` holding
            the global natural parameters \eta_{\theta}^0.
        node_potentials: sequence of pieces unpacked into
            ``gaussian.pack_dense`` to form the node potentials r(\phi, y).

    Returns:
        4-tuple ``(local_stats, prior_stats, natparam, kl)`` with
        ``local_stats = (label_stats, gaussian_stats)``,
        ``prior_stats = (dirichlet_stats, niw_stats)``,
        ``natparam = (label_natparam, gaussian_natparam)`` and
        ``kl = label_kl + gaussian_kl``.
    """
    dirichlet_natparam, niw_natparams = global_natparam
    packed_potentials = gaussian.pack_dense(*node_potentials)

    # Expected global statistics under the current global factors:
    #   E_{q(\pi)}[t(\pi)], q(\pi) Dirichlet, t(\pi) = [log \pi_1, log \pi_2, ...]
    #   E_{q(\mu, \Sigma)}[t(\mu, \Sigma)], q(\mu, \Sigma) NIW
    expected_label_global = dirichlet.expectedstats(dirichlet_natparam)
    expected_gaussian_global = niw.expectedstats(niw_natparams)

    # Iterate to the mean-field fixed point on the unboxed potentials (getval).
    fixed_point_labels = meanfield_fixed_point(
        expected_label_global, expected_gaussian_global,
        getval(packed_potentials))

    # One final pass from the fixed point using the boxed potentials, so these
    # outputs depend directly on node_potentials.
    gaussian_natparam, gaussian_stats, gaussian_kl = gaussian_meanfield(
        expected_gaussian_global, packed_potentials, fixed_point_labels)
    label_natparam, label_stats, label_kl = label_meanfield(
        expected_label_global, expected_gaussian_global, gaussian_stats)

    # Sum the conditionally iid local statistics into GMM prior statistics.
    prior_stats = (np.sum(label_stats, 0),
                   np.tensordot(label_stats, gaussian_stats, [0, 0]))

    return ((label_stats, gaussian_stats),
            prior_stats,
            (label_natparam, gaussian_natparam),
            label_kl + gaussian_kl)
Exemple #2
0
def local_meanfield(global_natparam, gaussian_suff_stats):
    r"""Update the label factor from precomputed Gaussian sufficient statistics.

    This variant takes the Gaussian sufficient statistics directly. It
    therefore runs no Gaussian mean-field step; only ``label_meanfield`` is
    evaluated.

    Args:
        global_natparam: pair ``(dirichlet_natparam, niw_natparams)``, i.e.
            \eta_{\theta}^0.
        gaussian_suff_stats: per-datapoint Gaussian sufficient statistics,
            contracted over axis 0. The original author noted shape
            ``(batch_size, 4, 4)``; TODO confirm for other data dimensions.

    Returns:
        3-tuple ``(prior_stats, label_natparam, label_kl)`` where
        ``prior_stats = (dirichlet_stats, niw_stats)``.
    """
    dirichlet_natparam, niw_natparams = global_natparam

    # Expected global statistics under the current global factors:
    #   E_{q(\pi)}[t(\pi)], q(\pi) Dirichlet, t(\pi) = [log \pi_1, log \pi_2, ...]
    #   E_{q(\mu, \Sigma)}[t(\mu, \Sigma)], q(\mu, \Sigma) NIW
    label_global = dirichlet.expectedstats(dirichlet_natparam)
    gaussian_globals = niw.expectedstats(niw_natparams)

    # label_stats = E_{q(z)}[t(z)] (categorical expected statistics; the
    # original author noted shape (batch_size, K)).
    label_natparam, label_stats, label_kl = label_meanfield(
        label_global, gaussian_globals, gaussian_suff_stats)

    # Sum the conditionally iid local statistics into GMM prior statistics.
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_suff_stats, [0, 0])

    return (dirichlet_stats, niw_stats), label_natparam, label_kl
Exemple #3
0
def local_meanfield(global_natparam, node_potentials):
    """Run local mean-field inference and collect GMM prior statistics.

    Args:
        global_natparam: pair ``(dirichlet_natparam, niw_natparams)``.
        node_potentials: sequence of pieces unpacked into
            ``gaussian.pack_dense``.

    Returns:
        ``((label_stats, gaussian_stats), (dirichlet_stats, niw_stats),
        (label_natparam, gaussian_natparam), label_kl + gaussian_kl)``.
    """
    dirichlet_np, niw_nps = global_natparam
    potentials = gaussian.pack_dense(*node_potentials)

    # Expected global parameters from the current global factors.
    e_label = dirichlet.expectedstats(dirichlet_np)
    e_gauss = niw.expectedstats(niw_nps)

    # The fixed point is found on unboxed potentials; only the final
    # gaussian/label updates below see the boxed ``potentials``.
    converged_labels = meanfield_fixed_point(e_label, e_gauss,
                                             getval(potentials))

    gauss_natparam, gauss_stats, gauss_kl = \
        gaussian_meanfield(e_gauss, potentials, converged_labels)
    lbl_natparam, lbl_stats, lbl_kl = \
        label_meanfield(e_label, e_gauss, gauss_stats)

    # Sufficient statistics for the GMM prior, summed over data points.
    dirichlet_stats = np.sum(lbl_stats, 0)
    niw_stats = np.tensordot(lbl_stats, gauss_stats, [0, 0])

    total_kl = lbl_kl + gauss_kl
    return ((lbl_stats, gauss_stats),
            (dirichlet_stats, niw_stats),
            (lbl_natparam, gauss_natparam),
            total_kl)
Exemple #4
0
 def plot_components(ax, params):
     """Draw one weighted ellipse per mixture component of the GMM prior.

     Args:
         ax: a ``plt.Axes`` to draw on, or otherwise an iterable of
             per-component ``line`` objects forwarded to ``plot_ellipse``
             (presumably existing artists to update -- TODO confirm).
         params: 3-tuple ``(pgm_params, loglike_params, recogn_params)``;
             only ``pgm_params = (dirichlet_natparams, niw_natparams)`` is used.
     """
     pgm_params, _, _ = params
     dirichlet_natparams, niw_natparams = pgm_params

     # A component holding exactly a 1/num_clusters share gets weight 1;
     # larger shares are clipped to 1.
     unnormalized = np.exp(dirichlet.expectedstats(dirichlet_natparams))
     weights = np.minimum(1., unnormalized / np.sum(unnormalized) * num_clusters)
     components = map(get_component, niw.expectedstats(niw_natparams))

     if isinstance(ax, plt.Axes):
         lines = repeat(None)
     else:
         lines = ax

     for weight, (mean, cov), line in zip(weights, components, lines):
         plot_ellipse(ax, weight, mean, cov, line)
Exemple #5
0
 def plot_components(ax, params):
     """Draw one weighted ellipse per mixture component of the GMM prior.

     Args:
         ax: a ``plt.Axes`` to draw on, or otherwise an iterable of
             per-component ``line`` objects forwarded to ``plot_ellipse``
             (presumably existing artists to update -- TODO confirm).
         params: the PGM parameters themselves, i.e. the pair
             ``(dirichlet_natparams, niw_natparams)``.
     """
     dirichlet_natparams, niw_natparams = params

     # A component holding exactly a 1/num_clusters share gets weight 1;
     # larger shares are clipped to 1.
     unnormalized = np.exp(dirichlet.expectedstats(dirichlet_natparams))
     weights = np.minimum(1., unnormalized / np.sum(unnormalized) * num_clusters)
     components = map(get_component, niw.expectedstats(niw_natparams))
     lines = repeat(None) if isinstance(ax, plt.Axes) else ax
     for weight, (mu, Sigma), line in zip(weights, components, lines):
         plot_ellipse(ax, weight, mu, Sigma, line)
Exemple #6
0
 def plot_density(density_ax, params, num_samples=1000):
     """Plot the data and a decoded density hexbin for each mixture component.

     Args:
         density_ax: axes to draw on. The original body referenced an
             undefined ``density_axis`` instead of this parameter.
         params: 3-tuple ``(pgm_params, loglike_params, recogn_params)``;
             ``recogn_params`` is unused.
         num_samples: number of latent samples drawn per component.
     """
     pgm_params, loglike_params, _ = params
     dirichlet_natparams, niw_natparams = pgm_params

     # A component holding exactly a 1/num_clusters share gets weight 1;
     # larger shares are clipped to 1.
     unnormalized = np.exp(dirichlet.expectedstats(dirichlet_natparams))
     weights = np.minimum(1., unnormalized / np.sum(unnormalized) * num_clusters)
     components = map(get_component, niw.expectedstats(niw_natparams))

     # The data overlay is identical for every component, so draw it once.
     density_ax.plot(data[:, 0], data[:, 1], color='k', marker='.', linestyle='')

     for idx, (weight, (mu, Sigma)) in enumerate(zip(weights, components)):
         # Reseeded per component (as originally written) for reproducible samples.
         samples = npr.RandomState(0).multivariate_normal(mu, Sigma, num_samples)
         # 75. * weight is the scale argument of decode_density (see its definition).
         density = decode_density(samples, loglike_params, decode, 75. * weight)
         # Limits are re-read each iteration because hexbin drawing may change them.
         xlim, ylim = density_ax.get_xlim(), density_ax.get_ylim()
         plot_transparent_hexbin(density_ax, density, xlim, ylim, gridsize,
                                 colors[idx % len(colors)])
Exemple #7
0
 def check_expectedstats(natparam):
     """Assert that ``expectedstats`` matches ``grad(logZ)`` at ``natparam``.

     The closed-form expected statistics are compared with the gradient of
     the log-partition function ``logZ`` using ``np.allclose``.
     """
     closed_form = expectedstats(natparam)
     from_gradient = grad(logZ)(natparam)
     assert np.allclose(closed_form, from_gradient)
Exemple #8
0
def prior_expectedstats(gmm_natparam):
    """Return the expected statistics of both GMM prior factors.

    Args:
        gmm_natparam: pair ``(dirichlet_natparam, niw_natparams)``.

    Returns:
        ``(dirichlet.expectedstats(dirichlet_natparam),
        niw.expectedstats(niw_natparams))``.
    """
    dirichlet_part, niw_part = gmm_natparam
    return (dirichlet.expectedstats(dirichlet_part),
            niw.expectedstats(niw_part))
Exemple #9
0
def prior_expectedstats(gmm_natparam):
    """Map GMM prior natural parameters to expected sufficient statistics.

    The input must be the 2-tuple ``(dirichlet_natparam, niw_natparams)``.
    The result is a 2-tuple holding the Dirichlet factor's statistics
    followed by the NIW factors' statistics.
    """
    (dir_np, niw_nps) = gmm_natparam
    expected_dir = dirichlet.expectedstats(dir_np)
    expected_niw = niw.expectedstats(niw_nps)
    result = (expected_dir, expected_niw)
    return result
Exemple #10
0
def pgm_expectedstats(global_natparam):
    """Return expected statistics for the Dirichlet and NIW global factors.

    Args:
        global_natparam: pair ``(dirichlet_natparam, niw_natparams)``.

    Returns:
        2-tuple of ``dirichlet.expectedstats`` and ``niw.expectedstats``
        applied to the respective halves.
    """
    dirichlet_np, niw_nps = global_natparam
    label_expected = dirichlet.expectedstats(dirichlet_np)
    gaussian_expected = niw.expectedstats(niw_nps)
    return label_expected, gaussian_expected
Exemple #11
0
 def check_expectedstats(natparam):
     """Assert closed-form ``expectedstats`` agrees with ``grad(logZ)``.

     ``expectedstats(natparam)`` is passed first to ``np.allclose``, so the
     relative tolerance is taken against the gradient value.
     """
     stats = expectedstats(natparam)
     assert np.allclose(stats, grad(logZ)(natparam))