Example #1
def local_meanfield(global_natparam, node_potentials):
    dirichlet_natparam, niw_natparams = global_natparam
    node_potentials = gaussian.pack_dense(*node_potentials)

    # compute expected global parameters using current global factors
    label_global = dirichlet.expectedstats(dirichlet_natparam)
    gaussian_globals = niw.expectedstats(niw_natparams)

    # compute mean field fixed point using unboxed node_potentials
    label_stats = meanfield_fixed_point(label_global, gaussian_globals, getval(node_potentials))

    # compute values that depend directly on boxed node_potentials at optimum
    gaussian_natparam, gaussian_stats, gaussian_kl = \
        gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
    label_natparam, label_stats, label_kl = \
        label_meanfield(label_global, gaussian_globals, gaussian_stats)

    # collect sufficient statistics for gmm prior (sum across conditional iid)
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_stats, [0, 0])

    local_stats = label_stats, gaussian_stats
    prior_stats = dirichlet_stats, niw_stats
    natparam = label_natparam, gaussian_natparam
    kl = label_kl + gaussian_kl

    return local_stats, prior_stats, natparam, kl
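
The getval call above detaches node_potentials while the mean-field fixed point is iterated; the final gaussian_meanfield and label_meanfield updates are then recomputed on the boxed (traced) node_potentials, so autograd differentiates only through that last step. By the envelope theorem, treating the inner optimizer as a constant still yields the correct gradient of the optimized objective. A minimal autograd sketch of that idea, with a toy objective L that is not from the repo:

import autograd.numpy as np
from autograd import grad
from autograd.tracer import getval  # older autograd versions expose getval from autograd.core

def objective(x):
    # inner optimum of L(x, .) is y* = x; compute it detached, as the code above does
    y_star = getval(x)
    # L(x, y) = -(y - x)^2 - x^2, differentiated only through its direct dependence on x
    return -np.sum((y_star - x) ** 2) - np.sum(x ** 2)

print(grad(objective)(np.array([1.0, 2.0])))  # [-2., -4.], i.e. d/dx of max_y L(x, y) = d/dx of -x^2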
Example #2
def local_meanfield(global_natparam, node_potentials):
    # global_natparam = \eta_{\theta}^0
    # node_potentials = r(\phi, y)
    
    dirichlet_natparam, niw_natparams = global_natparam
    node_potentials = gaussian.pack_dense(*node_potentials)

    #### compute expected global parameters using current global factors
    # label_global = E_{q(\pi)}[t(\pi)], where q(\pi) is the Dirichlet posterior with natural parameter dirichlet_natparam and t(\pi) = [log \pi_1, ..., log \pi_K]
    # gaussian_globals = E_{q(\mu, \Sigma)}[t(\mu, \Sigma)], where q(\mu, \Sigma) is the NIW posterior with natural parameters niw_natparams
    label_global = dirichlet.expectedstats(dirichlet_natparam)
    gaussian_globals = niw.expectedstats(niw_natparams)

    #### compute mean field fixed point using unboxed node_potentials
    label_stats = meanfield_fixed_point(label_global, gaussian_globals, getval(node_potentials))

    #### compute values that depend directly on boxed node_potentials at optimum
    gaussian_natparam, gaussian_stats, gaussian_kl = \
        gaussian_meanfield(gaussian_globals, node_potentials, label_stats)
    label_natparam, label_stats, label_kl = \
        label_meanfield(label_global, gaussian_globals, gaussian_stats)

    #### collect sufficient statistics for gmm prior (sum across conditional iid)
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_stats, [0, 0])

    local_stats = label_stats, gaussian_stats
    prior_stats = dirichlet_stats, niw_stats
    natparam = label_natparam, gaussian_natparam
    kl = label_kl + gaussian_kl

    return local_stats, prior_stats, natparam, kl
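
The comments above identify label_global with E_{q(\pi)}[log \pi]. For a Dirichlet posterior with concentration alpha, that expectation is digamma(alpha_k) - digamma(sum_j alpha_j). A standalone sketch of what dirichlet.expectedstats computes, assuming the common alpha - 1 natural-parameter convention (the repo's exact convention may differ by that shift):

import numpy as np
from scipy.special import digamma

def dirichlet_expectedstats_sketch(natparam):
    alpha = natparam + 1.                                  # assumes natural parameter alpha - 1
    return digamma(alpha) - digamma(alpha.sum(-1, keepdims=True))

# sanity check against a Monte Carlo estimate of E[log pi]
alpha = np.array([2.0, 3.0, 4.0])
samples = np.random.dirichlet(alpha, size=200000)
print(dirichlet_expectedstats_sketch(alpha - 1.))
print(np.log(samples).mean(0))  # agrees to a couple of decimal places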
Example #3
def local_meanfield(global_natparam, gaussian_suff_stats):
    # global_natparam = \eta_{\theta}^0
    dirichlet_natparam, niw_natparams = global_natparam
    #node_potentials = gaussian.pack_dense(*node_potentials)

    #### compute expected global parameters using current global factors
    # label_global = E_{q(\pi)}[t(\pi)], where q(\pi) is the Dirichlet posterior with natural parameter dirichlet_natparam and t(\pi) = [log \pi_1, ..., log \pi_K]
    # gaussian_globals = E_{q(\mu, \Sigma)}[t(\mu, \Sigma)], where q(\mu, \Sigma) is the NIW posterior with natural parameters niw_natparams
    # label_stats = E_{q(z)}[t(z)], the categorical expected statistics, shape (batch_size, K)
    # gaussian_suff_stats: packed Gaussian sufficient statistics, shape (batch_size, 4, 4)
    label_global = dirichlet.expectedstats(dirichlet_natparam)
    gaussian_globals = niw.expectedstats(niw_natparams)

    #### compute values that depend directly on boxed node_potentials at optimum
    label_natparam, label_stats, label_kl = \
        label_meanfield(label_global, gaussian_globals, gaussian_suff_stats)

    #### collect sufficient statistics for gmm prior (sum across conditional iid)
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_suff_stats, [0, 0])

    local_stats = label_stats, gaussian_suff_stats
    prior_stats = dirichlet_stats, niw_stats
    natparam = label_natparam
    kl = label_kl

    return prior_stats, natparam, kl
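
label_meanfield above implements the standard CAVI update for the cluster labels: q(z_n = k) is proportional to exp(E[log pi_k] + <E[t(mu_k, Sigma_k)], t(y_n)>). A hedged sketch of just that update, returning only the label statistics and assuming the packed (K, D+2, D+2) / (N, D+2, D+2) layouts used above (the helper name is illustrative, not the repo's):

import numpy as np

def label_meanfield_sketch(label_global, gaussian_globals, gaussian_stats):
    # logits[n, k] = E[log pi_k] + <E[t(mu_k, Sigma_k)], t(y_n)>
    logits = label_global + np.tensordot(gaussian_stats, gaussian_globals, [(1, 2), (1, 2)])
    logits -= logits.max(-1, keepdims=True)                    # numerical stability
    label_stats = np.exp(logits)
    return label_stats / label_stats.sum(-1, keepdims=True)    # rows sum to one: q(z_n = k)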
Example #4
def plot_components(ax, params):
    pgm_params, loglike_params, recogn_params = params
    dirichlet_natparams, niw_natparams = pgm_params
    normalize = lambda arr: np.minimum(1., arr / np.sum(arr) * num_clusters)
    weights = normalize(np.exp(dirichlet.expectedstats(dirichlet_natparams)))
    components = map(get_component, niw.expectedstats(niw_natparams))
    lines = repeat(None) if isinstance(ax, plt.Axes) else ax
    for weight, (mu, Sigma), line in zip(weights, components, lines):
        plot_ellipse(ax, weight, mu, Sigma, line)
Example #5
def plot_components(ax, params):
    pgm_params = params
    dirichlet_natparams, niw_natparams = pgm_params
    normalize = lambda arr: np.minimum(1., arr / np.sum(arr) * num_clusters)
    weights = normalize(np.exp(dirichlet.expectedstats(dirichlet_natparams)))
    components = map(get_component, niw.expectedstats(niw_natparams))
    S, m, kappa, nu = niw.natural_to_standard(niw_natparams)  # standard NIW parameters (not used below)
    lines = repeat(None) if isinstance(ax, plt.Axes) else ax
    for weight, (mu, Sigma), line in zip(weights, components, lines):
        plot_ellipse(ax, weight, mu, Sigma, line)
Example #6
def plot_density(density_ax, params):
    pgm_params, loglike_params, recogn_params = params
    dirichlet_natparams, niw_natparams = pgm_params
    normalize = lambda arr: np.minimum(1., arr / np.sum(arr) * num_clusters)
    weights = normalize(np.exp(dirichlet.expectedstats(dirichlet_natparams)))
    components = map(get_component, niw.expectedstats(niw_natparams))
    num_samples = 1000
    for idx, (weight, (mu, Sigma)) in enumerate(zip(weights, components)):
        samples = npr.RandomState(0).multivariate_normal(mu, Sigma, num_samples)
        density = decode_density(samples, loglike_params, decode, 75. * weight)
        density_ax.plot(data[:, 0], data[:, 1], color='k', marker='.', linestyle='')
        xlim, ylim = density_ax.get_xlim(), density_ax.get_ylim()
        plot_transparent_hexbin(density_ax, density, xlim, ylim, gridsize, colors[idx % len(colors)])
Example #7
def prior_expectedstats(gmm_natparam):
    dirichlet_natparam, niw_natparams = gmm_natparam
    dirichlet_expectedstats = dirichlet.expectedstats(dirichlet_natparam)
    niw_expectedstats = niw.expectedstats(niw_natparams)
    return dirichlet_expectedstats, niw_expectedstats
Example #8
File: gmm.py Project: mattjj/svae
def prior_expectedstats(gmm_natparam):
    dirichlet_natparam, niw_natparams = gmm_natparam
    dirichlet_expectedstats = dirichlet.expectedstats(dirichlet_natparam)
    niw_expectedstats = niw.expectedstats(niw_natparams)
    return dirichlet_expectedstats, niw_expectedstats
Example #9
def pgm_expectedstats(global_natparam):
    dirichlet_natparam, niw_natparams = global_natparam
    return dirichlet.expectedstats(dirichlet_natparam), niw.expectedstats(niw_natparams)
Example #10
def check_expectedstats(natparam):
    # expected sufficient statistics should equal the gradient of the log-normalizer
    E_stats1 = expectedstats(natparam)
    E_stats2 = grad(logZ)(natparam)
    assert np.allclose(E_stats1, E_stats2)
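
check_expectedstats verifies the exponential-family identity E[t(x)] = grad of log Z at the natural parameter. A concrete, self-contained version of the same check for a Dirichlet, written with autograd and assuming the alpha - 1 natural parameterization:

import autograd.numpy as np
from autograd import grad
from autograd.scipy.special import gammaln, digamma

def logZ(natparam):
    # Dirichlet log-normalizer for natural parameter eta = alpha - 1
    alpha = natparam + 1.
    return np.sum(gammaln(alpha)) - gammaln(np.sum(alpha))

def expectedstats(natparam):
    # closed-form E[log pi] under the same parameterization
    alpha = natparam + 1.
    return digamma(alpha) - digamma(np.sum(alpha))

natparam = np.array([1.0, 2.0, 3.0])
assert np.allclose(expectedstats(natparam), grad(logZ)(natparam))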
Example #11
def plot(axs, data, params):
    natparam, phi, psi = params
    ax_data, ax_latent = axs
    K = len(natparam[1])

    def plot_or_update(idx, ax, x, y, alpha=1, **kwargs):
        if len(ax.lines) > idx:
            ax.lines[idx].set_data((x, y))
            ax.lines[idx].set_alpha(alpha)
        else:
            ax.plot(x, y, alpha=alpha, **kwargs)

    dir_hypers, all_niw_hypers = natparam
    weights = normalize(np.exp(dirichlet.expectedstats(dir_hypers)))
    components = map(niw.expected_standard_params, all_niw_hypers)

    latent_locations = encode_mean(data, natparam, psi)
    reconstruction = decode_mean(latent_locations, phi)

    ## make data-space plot

    # ax_data.scatter(data[:, 0], data[:, 1], s=1, color='k', marker='.', zorder=2)
    # set_border_around_data(ax_data, data)
    ax_data.collections[:] = []
    ax_data.plot(data[:, 0], data[:, 1], 'k.', markersize=markersize)

    xlim, ylim = ax_data.get_xlim(), ax_data.get_ylim()
    for idx, (weight, (mu, Sigma)) in enumerate(
            sorted(zip(weights, components), key=itemgetter(0))):
        samples = npr.RandomState(0).multivariate_normal(
            mu, Sigma, num_samples)
        density = decode_density(samples, phi, gmm_decode, 75. * weight)
        plot_transparent_hexbin(ax_data, density, xlim, ylim, gridsize,
                                colors[idx % len(colors)])
    ax_data.set_xlim(xlim)
    ax_data.set_ylim(ylim)

    ## make latent space plot

    plot_or_update(0,
                   ax_latent,
                   latent_locations[:, 0],
                   latent_locations[:, 1],
                   color='k',
                   marker='.',
                   linestyle='',
                   markersize=markersize)
    # set_border_around_data(ax_latent, latent_locations)

    for idx, (weight, (mu, Sigma)) in enumerate(
            sorted(zip(weights, components), key=itemgetter(0))):
        x, y = generate_ellipse(mu, Sigma)
        plot_or_update(idx + 1,
                       ax_latent,
                       x,
                       y,
                       alpha=min(1., K * weight),
                       linestyle='-',
                       linewidth=2,
                       color=colors[idx % len(colors)])

    ax_latent.relim()
    ax_latent.autoscale_view(True, True, True)
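
generate_ellipse above traces a covariance ellipse for each component. A minimal sketch of one common way to do it (map the unit circle through a Cholesky factor of Sigma and shift by mu; the repo's exact scaling may differ):

import numpy as np

def generate_ellipse_sketch(mu, Sigma, num_points=100, num_sds=2.):
    t = np.linspace(0., 2 * np.pi, num_points)
    circle = np.vstack([np.cos(t), np.sin(t)])                        # (2, num_points) unit circle
    ellipse = num_sds * np.dot(np.linalg.cholesky(Sigma), circle) + mu[:, None]
    return ellipse[0], ellipse[1]                                     # x, y, as consumed by plot_or_update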