def local_meanfield(global_natparam, node_potentials):
    """Run the local mean-field update for the GMM latent variables.

    Args:
        global_natparam: pair ``(dirichlet_natparam, niw_natparams)`` with the
            natural parameters of the global Dirichlet and NIW factors.
        node_potentials: tuple of arrays, packed into dense Gaussian node
            potentials with ``gaussian.pack_dense``.

    Returns:
        ``(local_stats, prior_stats, natparam, kl)`` where
        ``local_stats = (label_stats, gaussian_stats)``,
        ``prior_stats = (dirichlet_stats, niw_stats)``,
        ``natparam = (label_natparam, gaussian_natparam)`` and ``kl`` is
        ``label_kl + gaussian_kl``.
    """
    dirichlet_natparam, niw_natparams = global_natparam
    packed_potentials = gaussian.pack_dense(*node_potentials)

    # Expected statistics of the current global factors.
    expected_label_global = dirichlet.expectedstats(dirichlet_natparam)
    expected_gaussian_globals = niw.expectedstats(niw_natparams)

    # Solve the mean-field fixed point on the unboxed (getval) potentials.
    fixed_point_labels = meanfield_fixed_point(
        expected_label_global, expected_gaussian_globals,
        getval(packed_potentials))

    # Final pass at the optimum using the boxed potentials, so these values
    # depend directly on node_potentials.
    gaussian_natparam, gaussian_stats, gaussian_kl = gaussian_meanfield(
        expected_gaussian_globals, packed_potentials, fixed_point_labels)
    label_natparam, label_stats, label_kl = label_meanfield(
        expected_label_global, expected_gaussian_globals, gaussian_stats)

    # Sufficient statistics for the GMM prior, summed over axis 0 (the
    # conditionally iid data points).
    prior_stats = (np.sum(label_stats, 0),
                   np.tensordot(label_stats, gaussian_stats, [0, 0]))

    return ((label_stats, gaussian_stats),
            prior_stats,
            (label_natparam, gaussian_natparam),
            label_kl + gaussian_kl)
def local_meanfield(global_natparam, node_potentials):
    r"""Mean-field update of the local label and Gaussian factors.

    Args:
        global_natparam: :math:`\eta_\theta^0`, the pair
            ``(dirichlet_natparam, niw_natparams)``.
        node_potentials: :math:`r(\phi, y)` (presumably the recognition
            network output), a tuple packed with ``gaussian.pack_dense``.

    Returns:
        ``(local_stats, prior_stats, natparam, kl)`` where
        ``local_stats = (label_stats, gaussian_stats)``,
        ``prior_stats = (dirichlet_stats, niw_stats)``,
        ``natparam = (label_natparam, gaussian_natparam)`` and ``kl`` is
        ``label_kl + gaussian_kl``.
    """
    dir_eta, niw_eta = global_natparam
    potentials = gaussian.pack_dense(*node_potentials)

    # E_{q(pi)}[t(pi)]: q(pi) is the Dirichlet posterior with parameter
    # dir_eta and t(pi) = [log pi_1, log pi_2, ...].
    E_label = dirichlet.expectedstats(dir_eta)
    # E_{q(mu, Sigma)}[t(mu, Sigma)]: q(mu, Sigma) is the NIW posterior.
    E_gaussian = niw.expectedstats(niw_eta)

    # Mean-field fixed point computed from the unboxed potentials.
    z_stats_star = meanfield_fixed_point(E_label, E_gaussian, getval(potentials))

    # Values that depend directly on the boxed potentials at the optimum.
    x_natparam, x_stats, x_kl = gaussian_meanfield(E_gaussian, potentials, z_stats_star)
    z_natparam, z_stats, z_kl = label_meanfield(E_label, E_gaussian, x_stats)

    # Sufficient statistics for the GMM prior (sum across conditionally iid points).
    dirichlet_stats = np.sum(z_stats, 0)
    niw_stats = np.tensordot(z_stats, x_stats, [0, 0])

    local_stats = (z_stats, x_stats)
    prior_stats = (dirichlet_stats, niw_stats)
    natparam = (z_natparam, x_natparam)
    kl = z_kl + x_kl
    return local_stats, prior_stats, natparam, kl
def local_meanfield(global_natparam, gaussian_suff_stats):
    r"""Update the categorical label factors given Gaussian sufficient statistics.

    Args:
        global_natparam: :math:`\eta_\theta^0`, the pair
            ``(dirichlet_natparam, niw_natparams)`` of global Dirichlet and NIW
            natural parameters.
        gaussian_suff_stats: per-datapoint Gaussian sufficient statistics with
            the data index on axis 0 (it is contracted over axis 0 below).
            Earlier notes gave its shape as ``(batch_size, 4, 4)``; TODO
            confirm for the actual latent dimension.

    Returns:
        ``(prior_stats, natparam, kl)`` where
        ``prior_stats = (dirichlet_stats, niw_stats)`` are the GMM-prior
        sufficient statistics summed over axis 0, ``natparam`` is the label
        natural parameter and ``kl`` is the label KL term.
    """
    dirichlet_natparam, niw_natparams = global_natparam

    # label_global = E_{q(pi)}[t(pi)]: q(pi) is the Dirichlet posterior and
    # t(pi) = [log pi_1, log pi_2, ...].
    label_global = dirichlet.expectedstats(dirichlet_natparam)
    # gaussian_globals = E_{q(mu, Sigma)}[t(mu, Sigma)] under the NIW posterior.
    gaussian_globals = niw.expectedstats(niw_natparams)

    # label_stats = E_{q(z)}[t(z)], the categorical expected statistics
    # (presumably shape (batch_size, K)).
    label_natparam, label_stats, label_kl = label_meanfield(
        label_global, gaussian_globals, gaussian_suff_stats)

    # Sufficient statistics for the GMM prior (sum across the iid axis 0).
    dirichlet_stats = np.sum(label_stats, 0)
    niw_stats = np.tensordot(label_stats, gaussian_suff_stats, [0, 0])

    return (dirichlet_stats, niw_stats), label_natparam, label_kl
def plot_components(ax, params):
    """Draw one ellipse per GMM component, faded by its mixing weight.

    Args:
        ax: a ``plt.Axes`` to draw on. Any other value is iterated as one
            existing line object per component, which is handed to
            ``plot_ellipse`` (presumably to update an earlier plot).
        params: triple ``(pgm_params, loglike_params, recogn_params)``. Only
            ``pgm_params = (dirichlet_natparams, niw_natparams)`` is used.
    """
    (dirichlet_natparams, niw_natparams), _loglike_params, _recogn_params = params

    # exp(E[log pi]) rescaled so an average component gets weight 1, capped at 1.
    expected_probs = np.exp(dirichlet.expectedstats(dirichlet_natparams))
    weights = np.minimum(1., expected_probs / np.sum(expected_probs) * num_clusters)
    components = map(get_component, niw.expectedstats(niw_natparams))

    if isinstance(ax, plt.Axes):
        lines = repeat(None)
    else:
        lines = ax

    for weight, (mu, Sigma), line in zip(weights, components, lines):
        plot_ellipse(ax, weight, mu, Sigma, line)
def plot_components(ax, params):
    """Draw one ellipse per GMM component, faded by its mixing weight.

    Args:
        ax: a ``plt.Axes`` to draw on. Any other value is iterated as one
            existing line object per component, which is handed to
            ``plot_ellipse`` (presumably to update an earlier plot).
        params: the PGM parameters themselves, i.e. the pair
            ``(dirichlet_natparams, niw_natparams)``.
    """
    dirichlet_natparams, niw_natparams = params

    # exp(E[log pi]) rescaled so an average component gets weight 1, capped at 1.
    expected_probs = np.exp(dirichlet.expectedstats(dirichlet_natparams))
    weights = np.minimum(1., expected_probs / np.sum(expected_probs) * num_clusters)
    components = map(get_component, niw.expectedstats(niw_natparams))

    lines = repeat(None) if isinstance(ax, plt.Axes) else ax
    for weight, (mu, Sigma), line in zip(weights, components, lines):
        plot_ellipse(ax, weight, mu, Sigma, line)
def plot_density(density_ax, params):
    """Overlay each GMM component's decoded density on a scatter of ``data``.

    Args:
        density_ax: matplotlib axes to draw on.
        params: triple ``(pgm_params, loglike_params, recogn_params)`` with
            ``pgm_params = (dirichlet_natparams, niw_natparams)``.
            ``loglike_params`` is passed to ``decode_density`` and
            ``recogn_params`` is unused.
    """
    pgm_params, loglike_params, _recogn_params = params
    dirichlet_natparams, niw_natparams = pgm_params

    # exp(E[log pi]) rescaled so an average component gets weight 1, capped at 1.
    expected_probs = np.exp(dirichlet.expectedstats(dirichlet_natparams))
    weights = np.minimum(1., expected_probs / np.sum(expected_probs) * num_clusters)
    components = map(get_component, niw.expectedstats(niw_natparams))
    num_samples = 1000

    # Plot the data once, then freeze the limits so every component's hexbin
    # receives the same xlim/ylim.
    density_ax.plot(data[:, 0], data[:, 1], color='k', marker='.', linestyle='')
    xlim, ylim = density_ax.get_xlim(), density_ax.get_ylim()

    for idx, (weight, (mu, Sigma)) in enumerate(zip(weights, components)):
        # Fixed seed so the sampled latent points are reproducible across calls.
        samples = npr.RandomState(0).multivariate_normal(mu, Sigma, num_samples)
        density = decode_density(samples, loglike_params, decode, 75. * weight)
        plot_transparent_hexbin(density_ax, density, xlim, ylim, gridsize,
                                colors[idx % len(colors)])

    density_ax.set_xlim(xlim)
    density_ax.set_ylim(ylim)
def prior_expectedstats(gmm_natparam):
    """Return expected sufficient statistics of the GMM prior factors.

    Args:
        gmm_natparam: pair ``(dirichlet_natparam, niw_natparams)``.

    Returns:
        Tuple of ``dirichlet.expectedstats`` and ``niw.expectedstats`` applied
        to the respective halves of ``gmm_natparam``.
    """
    dirichlet_part, niw_part = gmm_natparam
    return (dirichlet.expectedstats(dirichlet_part),
            niw.expectedstats(niw_part))
def pgm_expectedstats(global_natparam):
    """Map the global natural parameters to their expected statistics.

    Args:
        global_natparam: pair ``(dirichlet_natparam, niw_natparams)``.

    Returns:
        ``(dirichlet.expectedstats(...), niw.expectedstats(...))``.
    """
    natparam_dirichlet, natparam_niw = global_natparam
    stats_dirichlet = dirichlet.expectedstats(natparam_dirichlet)
    stats_niw = niw.expectedstats(natparam_niw)
    return stats_dirichlet, stats_niw
def check_expectedstats(natparam):
    """Assert that ``expectedstats`` agrees with the gradient of ``logZ``.

    Raises:
        AssertionError: if the two results are not ``np.allclose``.
    """
    closed_form = expectedstats(natparam)
    via_gradient = grad(logZ)(natparam)
    assert np.allclose(closed_form, via_gradient)
def plot(axs, data, params):
    """Plot data-space component densities and latent-space component ellipses.

    Appears intended to be called repeatedly, e.g. from a training callback.
    Hexbin collections on ``ax_data`` are removed, and line artists on both
    axes are updated in place instead of being re-created.

    Args:
        axs: pair ``(ax_data, ax_latent)`` of matplotlib axes.
        data: array whose columns 0 and 1 are plotted.
        params: triple ``(natparam, phi, psi)``.
            ``natparam = (dir_hypers, all_niw_hypers)`` holds the GMM prior
            natural parameters, ``phi`` is passed to ``decode_density`` and
            ``psi`` is passed to ``encode_mean``.
    """
    natparam, phi, psi = params
    ax_data, ax_latent = axs
    K = len(natparam[1])

    def plot_or_update(idx, ax, x, y, alpha=1, **kwargs):
        # Reuse the idx-th line artist of ax if it exists, otherwise create it.
        if len(ax.lines) > idx:
            ax.lines[idx].set_data((x, y))
            ax.lines[idx].set_alpha(alpha)
        else:
            ax.plot(x, y, alpha=alpha, **kwargs)

    dir_hypers, all_niw_hypers = natparam
    weights = normalize(np.exp(dirichlet.expectedstats(dir_hypers)))
    # Materialize once. A Python 3 ``map`` iterator would be exhausted by the
    # first loop, leaving the latent plot without ellipses.
    components = list(map(niw.expected_standard_params, all_niw_hypers))
    # Same ordering (by increasing weight) in both panels, so colors match.
    ordered = sorted(zip(weights, components), key=itemgetter(0))

    latent_locations = encode_mean(data, natparam, psi)

    ## data-space plot
    # Remove previous hexbins one by one. Slice assignment on
    # ``Axes.collections`` is not supported by newer matplotlib.
    for collection in list(ax_data.collections):
        collection.remove()
    # Line 0 of ax_data is the data; reuse it instead of stacking a new one per call.
    plot_or_update(0, ax_data, data[:, 0], data[:, 1],
                   color='k', marker='.', linestyle='', markersize=markersize)
    xlim, ylim = ax_data.get_xlim(), ax_data.get_ylim()
    for idx, (weight, (mu, Sigma)) in enumerate(ordered):
        # Fixed seed so the sampled latent points are reproducible across calls.
        samples = npr.RandomState(0).multivariate_normal(mu, Sigma, num_samples)
        density = decode_density(samples, phi, gmm_decode, 75. * weight)
        plot_transparent_hexbin(ax_data, density, xlim, ylim, gridsize,
                                colors[idx % len(colors)])
    ax_data.set_xlim(xlim)
    ax_data.set_ylim(ylim)

    ## latent-space plot: line 0 is the encoded data, lines 1..K the ellipses
    plot_or_update(0, ax_latent, latent_locations[:, 0], latent_locations[:, 1],
                   color='k', marker='.', linestyle='', markersize=markersize)
    for idx, (weight, (mu, Sigma)) in enumerate(ordered):
        x, y = generate_ellipse(mu, Sigma)
        plot_or_update(idx + 1, ax_latent, x, y, alpha=min(1., K * weight),
                       linestyle='-', linewidth=2, color=colors[idx % len(colors)])
    ax_latent.relim()
    ax_latent.autoscale_view(True, True, True)