def compute_elbo_smm(y, reconstructions, theta, phi_tilde, x_k_samps, log_z_given_y_phi, decoder_type):
    """Build the ELBO for a model with a latent Student-t mixture (SMM) prior.

    ELBO = (negative reconstruction error) - (regulariser), where the
    regulariser is the r_nk-weighted Monte-Carlo estimate of
    E[log q_phi(x, z=k | y) - log p_theta(x, z=k)] over the samples in
    ``x_k_samps``, and r_nk = exp(log_z_given_y_phi).

    Args:
        y: observed minibatch.
        reconstructions: pair ``(means, out_2)`` from the decoder; ``out_2`` holds
            Gaussian variances for ``decoder_type == 'standard'`` or Bernoulli
            logits for ``decoder_type == 'bernoulli'``.
        theta: SMM parameters. ``theta[0]`` is converted with
            ``dirichlet.natural_to_standard``, ``theta[1:3]`` is unpacked by
            ``unpack_smm`` into component locations and scales, and ``theta[3]``
            holds the Student-t degrees of freedom.
        phi_tilde: pair ``(eta1, eta2)`` of natural parameters of q(x | y, z=k).
            ``eta2`` must have a fully known static rank-4 shape whose first
            three dims are (N, K, L); ``eta1`` is reshaped to (N, K, L).
        x_k_samps: samples of x per data point and component
            (assumed (N, K, S, L) — confirm against the caller).
        log_z_given_y_phi: log q(z=k | y); assumed (N, K), since it is expanded
            on axis 2 and added to per-sample terms.
        decoder_type: ``'standard'`` or ``'bernoulli'``.

    Returns:
        ``(elbo, details)``. ``elbo`` is the ELBO tensor. ``details`` is a
        ``tf.tuple`` of (negative reconstruction error, r-weighted sum of the
        sample-mean log numerator, r-weighted sum of the sample-mean log
        denominator, regulariser term), used for summaries.

    Raises:
        NotImplementedError: if ``decoder_type`` is not supported.
    """
    # Validate before any graph ops are created, and report the rejected value.
    if decoder_type not in ('standard', 'bernoulli'):
        raise NotImplementedError('Unsupported decoder_type: {!r}'.format(decoder_type))

    # ELBO for latent SMM
    with tf.name_scope('elbo'):
        # unpack theta and compute the prior quantities
        mu_theta, sigma_theta = unpack_smm(theta[1:3])
        alpha_k = dirichlet.natural_to_standard(theta[0])
        expected_log_pi_theta = dirichlet.expected_log_pi(alpha_k)
        dof = theta[3]  # Student-t degrees of freedom

        # Do not propagate gradients into the Dirichlet expectation or the degrees of freedom.
        # NOTE(review): unlike compute_elbo, mu_theta/sigma_theta are not wrapped in
        # stop_gradient here — confirm this is intended.
        with tf.name_scope('block_backprop'):
            expected_log_pi_theta = tf.stop_gradient(expected_log_pi_theta)
            dof = tf.stop_gradient(dof)

        r_nk = tf.exp(log_z_given_y_phi)

        # compute negative reconstruction error, weighted by r_nk (use VAE function)
        means, out_2 = reconstructions  # out_2 is either gaussian variances or bernoulli logits.
        if decoder_type == 'standard':
            neg_reconstruction_error = vae.expected_diagonal_gaussian_loglike(y, means, out_2, weights=r_nk)
        else:  # 'bernoulli' (validated above)
            neg_reconstruction_error = vae.expected_bernoulli_loglike(y, out_2, r_nk=r_nk)

        eta1_phi_tilde, eta2_phi_tilde = phi_tilde
        N, K, L, _ = eta2_phi_tilde.get_shape().as_list()
        eta1_phi_tilde = tf.reshape(eta1_phi_tilde, (N, K, L))

        with tf.name_scope('compute_regularizer'):
            # compute E[log q_phi(x,z=k|y)]
            with tf.name_scope('log_numerator'):
                log_N_x_given_phi = gaussian.log_probability_nat_per_samp(x_k_samps, eta1_phi_tilde, eta2_phi_tilde)
                log_numerator = log_N_x_given_phi + tf.expand_dims(log_z_given_y_phi, axis=2)

            # compute E[log p_theta(x,z=k)]
            with tf.name_scope('log_denominator'):
                log_N_x_given_theta = student_t.log_probability_per_samp(x_k_samps, mu_theta, sigma_theta, dof)
                log_denominator = log_N_x_given_theta + tf.expand_dims(
                    tf.expand_dims(expected_log_pi_theta, axis=0), axis=2)

            regularizer_term = tf.reduce_mean(
                tf.reduce_sum(
                    tf.reduce_sum(
                        tf.multiply(tf.expand_dims(r_nk, axis=2), log_numerator - log_denominator),
                        axis=1),  # r_nk-weighted sum over components
                    axis=0)  # sum over data points
            )  # mean over samples

        elbo = tf.subtract(neg_reconstruction_error, regularizer_term, name='elbo')

        with tf.name_scope('elbo_summaries'):
            details = tf.tuple((neg_reconstruction_error,
                                tf.reduce_sum(tf.multiply(r_nk, tf.reduce_mean(log_numerator, -1))),
                                tf.reduce_sum(tf.multiply(r_nk, tf.reduce_mean(log_denominator, -1))),
                                regularizer_term), name='debug')

        return elbo, details
def expected_values(self):
    """Return the expected mixture log-weights, component means and covariances.

    Reads the natural parameters ``self.A``, ``self.b``, ``self.beta`` and
    ``self.v_hat`` (converted with ``niw.natural_to_standard``) and the
    Dirichlet parameter ``self.alpha`` (converted with
    ``dirichlet.natural_to_standard``).

    Returns:
        ``(exp_log_pi, exp_m, exp_C)``:
        - ``exp_log_pi``: ``dirichlet.expected_log_pi`` of the standard
          Dirichlet parameters;
        - ``exp_m``: the standard-form mean ``m`` (op ``'expected_mean'``);
        - ``exp_C``: the inverse of ``v * sym(C^-1)``, where
          sym(A) = (A + A^T) / 2. The intermediate op is named
          ``'expected_precision'``. NOTE(review): this is the inverse of that
          precision rather than an explicit E[Sigma] — confirm it is the
          intended quantity.
    """
    _, mean_std, scale_std, dof_std = niw.natural_to_standard(self.A, self.b, self.beta, self.v_hat)
    dirichlet_std = dirichlet.natural_to_standard(self.alpha)
    log_weights = dirichlet.expected_log_pi(dirichlet_std)

    with tf.name_scope('niw_expectation'):
        mean_out = tf.identity(mean_std, 'expected_mean')

        # symmetrise C^-1 to remove numerical asymmetry from the inversion
        scale_inv = tf.matrix_inverse(scale_std)
        scale_inv_t = tf.matrix_transpose(scale_inv)
        scale_inv_sym = tf.divide(tf.add(scale_inv, scale_inv_t), 2., name='C_inv_symmetrised')

        # broadcast the per-component scalar v over both matrix axes
        dof_col = tf.expand_dims(dof_std, 1)
        dof_mat = tf.expand_dims(dof_col, 2)
        precision = tf.multiply(scale_inv_sym, dof_mat, name='expected_precision')
        cov_out = tf.matrix_inverse(precision, name='expected_covariance')

    return log_weights, mean_out, cov_out
if lbl_tr is not None: entr_tr, prty_tr = purity(tf.exp(log_z_given_y_phi), lbl_tr) tf.summary.scalar('entropy_tr', entr_tr) tf.summary.scalar('purity_tr', prty_tr) # useful values for tensorboard and plotting with tf.name_scope('plotting_prep'): if 'smm' in config['method']: mu, sigma = svae.unpack_smm(theta[1:3]) else: beta_k, m_k, C_k, v_k = niw.natural_to_standard( theta[1], theta[2], theta[3], theta[4]) mu, sigma = niw.expected_values((beta_k, m_k, C_k, v_k)) alpha_k = dirichlet.natural_to_standard(theta[0]) expected_log_pi = dirichlet.expected_log_pi(alpha_k) pi_theta = tf.exp(expected_log_pi) theta_plot = mu, sigma, pi_theta q_z_given_y_phi = tf.exp(log_z_given_y_phi) neg_normed_elbo = -tf.divide(tf.reduce_sum(tower_elbo), size_minibatch) tf.summary.scalar('elbo/elbo_normed', neg_normed_elbo) tf.summary.scalar( 'elbo/neg_rec_err', tf.divide(tf.reduce_sum(tower_neg_rec_err), size_minibatch)) tf.summary.scalar( 'elbo/regularizer', tf.divide(tf.reduce_sum(tower_regularizer), size_minibatch))
def compute_elbo(y, reconstructions, theta, phi_tilde, x_k_samps, log_z_given_y_phi, decoder_type):
    """Build the ELBO for a model with a latent Gaussian mixture (GMM) prior.

    ELBO = (negative reconstruction error) - (regulariser), where the
    regulariser is the r_nk-weighted Monte-Carlo estimate of
    E[log q_phi(x, z=k | y) - log p_theta(x, z=k)] over the samples in
    ``x_k_samps``, and r_nk = exp(log_z_given_y_phi). Gradients are not
    propagated into the GMM prior parameters.

    Args:
        y: observed minibatch.
        reconstructions: pair ``(means, out_2)`` from the decoder; ``out_2`` holds
            Gaussian variances for ``decoder_type == 'standard'`` or Bernoulli
            logits for ``decoder_type == 'bernoulli'``.
        theta: GMM prior parameters. ``theta[0]`` is converted with
            ``dirichlet.natural_to_standard``; ``theta[1:]`` are passed to
            ``niw.natural_to_standard``.
        phi_tilde: pair ``(eta1, eta2)`` of natural parameters of q(x | y, z=k).
            ``eta2`` must have a fully known static rank-4 shape whose first
            three dims are (N, K, L); ``eta1`` is reshaped to (N, K, L).
        x_k_samps: samples of x with a fully known static rank-4 shape
            (N, K, S, L) (unpacked below).
        log_z_given_y_phi: log q(z=k | y); assumed (N, K), since it is expanded
            on axis 2 and added to per-sample terms.
        decoder_type: ``'standard'`` or ``'bernoulli'``.

    Returns:
        ``(elbo, details)``. ``elbo`` is the ELBO tensor. ``details`` is a
        ``tf.tuple`` of (negative reconstruction error, r-weighted sum of the
        sample-mean log numerator, r-weighted sum of the sample-mean log
        denominator, regulariser term), used for summaries.

    Raises:
        NotImplementedError: if ``decoder_type`` is not supported.
    """
    # Validate before any graph ops are created, and report the rejected value.
    if decoder_type not in ('standard', 'bernoulli'):
        raise NotImplementedError('Unsupported decoder_type: {!r}'.format(decoder_type))

    # ELBO for latent GMM
    with tf.name_scope('elbo'):
        # convert theta into expected Gaussian natural parameters and expected log mixture weights
        with tf.name_scope('expct_theta_to_nat'):
            beta_k, m_k, C_k, v_k = niw.natural_to_standard(*theta[1:])
            mu, sigma = niw.expected_values((beta_k, m_k, C_k, v_k))
            eta1_theta, eta2_theta = gaussian.standard_to_natural(mu, sigma)
            alpha_k = dirichlet.natural_to_standard(theta[0])
            expected_log_pi_theta = dirichlet.expected_log_pi(alpha_k)

        # do not backpropagate through GMM
        with tf.name_scope('block_backprop'):
            eta1_theta = tf.stop_gradient(eta1_theta)
            eta2_theta = tf.stop_gradient(eta2_theta)
            expected_log_pi_theta = tf.stop_gradient(expected_log_pi_theta)

        r_nk = tf.exp(log_z_given_y_phi)

        # compute negative reconstruction error, weighted by r_nk (use VAE function)
        means, out_2 = reconstructions  # out_2 is either gaussian variances or bernoulli logits.
        if decoder_type == 'standard':
            neg_reconstruction_error = vae.expected_diagonal_gaussian_loglike(y, means, out_2, weights=r_nk)
        else:  # 'bernoulli' (validated above)
            neg_reconstruction_error = vae.expected_bernoulli_loglike(y, out_2, r_nk=r_nk)

        eta1_phi_tilde, eta2_phi_tilde = phi_tilde
        N, K, L, _ = eta2_phi_tilde.get_shape().as_list()
        eta1_phi_tilde = tf.reshape(eta1_phi_tilde, (N, K, L))
        # The batch size used for tiling the prior is read from the samples (only N is needed).
        N, _, _, _ = x_k_samps.get_shape().as_list()

        with tf.name_scope('compute_regularizer'):
            # compute E[log q_phi(x,z=k|y)]
            with tf.name_scope('log_numerator'):
                log_N_x_given_phi = gaussian.log_probability_nat_per_samp(x_k_samps, eta1_phi_tilde, eta2_phi_tilde)
                log_numerator = log_N_x_given_phi + tf.expand_dims(log_z_given_y_phi, axis=2)

            # compute E[log p_theta(x,z=k)]; prior parameters are tiled over the N data points
            with tf.name_scope('log_denominator'):
                log_N_x_given_theta = gaussian.log_probability_nat_per_samp(
                    x_k_samps,
                    tf.tile(tf.expand_dims(eta1_theta, axis=0), [N, 1, 1]),
                    tf.tile(tf.expand_dims(eta2_theta, axis=0), [N, 1, 1, 1]))
                log_denominator = log_N_x_given_theta + tf.expand_dims(
                    tf.expand_dims(expected_log_pi_theta, axis=0), axis=2)

            regularizer_term = tf.reduce_mean(
                tf.reduce_sum(
                    tf.reduce_sum(
                        tf.multiply(tf.expand_dims(r_nk, axis=2), log_numerator - log_denominator),
                        axis=1),  # r_nk-weighted sum over components
                    axis=0)  # sum over minibatch
            )  # mean over samples

        elbo = tf.subtract(neg_reconstruction_error, regularizer_term, name='elbo')

        with tf.name_scope('elbo_summaries'):
            details = tf.tuple((neg_reconstruction_error,
                                tf.reduce_sum(tf.multiply(r_nk, tf.reduce_mean(log_numerator, -1))),
                                tf.reduce_sum(tf.multiply(r_nk, tf.reduce_mean(log_denominator, -1))),
                                regularizer_term), name='debug')

        return elbo, details
def visualize_svae(ax, config, log_path, ratio_tr=0.7, nb_samples=20, grid_density=100,
                   window=((-20, 20), (-20, 20)), param_device='/cpu:0'):
    """Plot samples from a trained SVAE's generative model over a subset of the data.

    Rebuilds the model graph and restores the latest checkpoint found under
    ``log_path + '/' + generate_log_id(config)``. It then draws ``nb_samples``
    rounds of 100 samples from the learned prior, decodes them, and plots one
    2-d histogram per mixture component on ``ax``. Finally it overlays the first
    300 data points as a scatter plot. Data with more than two dimensions is
    projected to 2-d with a PCA fitted on those 300 points.

    Args:
        ax: matplotlib axes to draw on.
        config: experiment configuration. Keys read here are 'dataset', 'U',
            'L', 'K' and 'seed'; ``generate_log_id`` may read more.
        log_path: directory containing the per-experiment checkpoint folders.
        ratio_tr: train split ratio forwarded to ``make_minibatch``.
        nb_samples: number of sampling rounds.
        grid_density: number of histogram bins per axis.
        window: ``((xmin, xmax), (ymin, ymax))`` range of the histograms.
        param_device: device used for graph construction and model parameters.
    """
    with tf.device(param_device):
        if config['dataset'] in ['mnist', 'fashion']:
            binarise = True
            size_minibatch = 1024
            output_type = 'bernoulli'
        else:
            binarise = False
            size_minibatch = -1
            output_type = 'standard'

        # First we build the model graph so that we can load the learned parameters from a checkpoint.
        # Initialisations don't matter, they'll be overwritten with saver.restore().
        data, lbl, _, _ = make_minibatch(config['dataset'], path_datadir='../datasets', ratio_tr=ratio_tr,
                                         seed_split=0, size_minibatch=size_minibatch, size_testbatch=-1,
                                         binarise=binarise)

        # define nn-architecture
        encoder_layers = [(config['U'], tf.tanh), (config['U'], tf.tanh), (config['L'], 'natparam')]
        decoder_layers = [(config['U'], tf.tanh), (config['U'], tf.tanh), (int(data.get_shape()[1]), output_type)]
        sample_size = 100

        if config['dataset'] in ['mnist', 'fashion']:
            # entries equal to -1 become 0, every other entry becomes 1
            data = tf.where(tf.equal(data, -1), tf.zeros_like(data, dtype=tf.float32),
                            tf.ones_like(data, dtype=tf.float32))

        with tf.name_scope('model'):
            # Honour the caller's device choice (this was hard-coded to '/gpu:0').
            gmm_prior, theta = svae.init_mm(config['K'], config['L'], seed=config['seed'],
                                            param_device=param_device)
            theta_copied = niw.natural_to_standard(tf.identity(gmm_prior[1]), tf.identity(gmm_prior[2]),
                                                   tf.identity(gmm_prior[3]), tf.identity(gmm_prior[4]))
            _, sigma_k = niw.expected_values(theta_copied)
            pi_k_init = tf.nn.softmax(
                tf.random_normal(shape=(config['K'], ), mean=0.0, stddev=1., seed=config['seed']))
            L_k = tf.cholesky(sigma_k)
            mu_k = tf.random_normal(shape=(config['K'], config['L']), stddev=1, seed=config['seed'])
            with tf.variable_scope('phi_gmm'):
                mu_k = variable_on_device('mu_k', shape=None, initializer=mu_k, trainable=True,
                                          device=param_device)
                L_k = variable_on_device('L_k', shape=None, initializer=L_k, trainable=True,
                                         device=param_device)
                # Named 'log_pi_k' (presumably to match the checkpoint) although initialised with
                # softmax probabilities; the value is overwritten by saver.restore() anyway.
                pi_k = variable_on_device('log_pi_k', shape=None, initializer=pi_k_init, trainable=True,
                                          device=param_device)
            phi_gmm = mu_k, L_k, pi_k
            # Output unused here; presumably built so the encoder variables exist for saver.restore().
            _ = vae.make_encoder(data, layerspecs=encoder_layers, stddev_init=.1, seed=config['seed'])

        with tf.name_scope('random_sampling'):
            # compute expected theta_pgm
            beta_k, m_k, C_k, v_k = niw.natural_to_standard(*theta[1:])
            alpha_k = dirichlet.natural_to_standard(theta[0])
            mean, cov = niw.expected_values((beta_k, m_k, C_k, v_k))
            expected_log_pi = dirichlet.expected_log_pi(alpha_k)
            pi = tf.exp(expected_log_pi)

            # sample from the prior: x for every component, and a component index z per draw
            x_k_samples = tf.contrib.distributions.MultivariateNormalFullCovariance(
                loc=mean, covariance_matrix=cov).sample(sample_size)
            z_samples = tf.multinomial(logits=tf.reshape(tf.log(pi), (1, -1)), num_samples=sample_size,
                                       name='k_samples')
            z_samples = tf.squeeze(z_samples)
            assert z_samples.get_shape() == (sample_size, )
            assert x_k_samples.get_shape() == (sample_size, config['K'], config['L'])

            # compute reconstructions
            y_k_samples, _ = vae.make_decoder(x_k_samples, layerspecs=decoder_layers, stddev_init=.1,
                                              seed=config['seed'])
            assert y_k_samples.get_shape() == (sample_size, config['K'], data.get_shape()[1])

        with tf.name_scope('cluster_sample_data'):
            # Put the current variable scope into reuse mode before svae.predict, presumably so it
            # picks up the variables created above instead of creating new ones.
            tf.get_variable_scope().reuse_variables()
            _, clustering = svae.predict(data, phi_gmm, encoder_layers, decoder_layers, seed=config['seed'])

        # load trained model
        saver = tf.train.Saver()
        model_path = log_path + '/' + generate_log_id(config)
        print(model_path)
        latest_ckpnt = tf.train.latest_checkpoint(model_path)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess = tf.Session(config=sess_config)
        coord = tf.train.Coordinator()
        threads = []
        try:
            saver.restore(sess, latest_ckpnt)
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            collected_y_samps = []
            collected_z_samps = []
            for _ in range(nb_samples):
                y_samps, z_samps = sess.run((y_k_samples, z_samples))
                collected_y_samps.append(y_samps)
                collected_z_samps.append(z_samps)
            collected_y_samps = np.concatenate(collected_y_samps, axis=0)
            collected_z_samps = np.concatenate(collected_z_samps, axis=0)
            assert collected_y_samps.shape == (nb_samples * sample_size, config['K'], data.shape[1])
            assert collected_z_samps.shape == (nb_samples * sample_size, )

            # use 300 sample points from the dataset
            # NOTE(review): `clustering` is fetched but not used by the plotting below.
            data, lbl, clustering = sess.run((data[:300], lbl[:300], clustering[:300]))
        finally:
            # Stop the input-queue threads and release the session (both were previously leaked).
            coord.request_stop()
            coord.join(threads)
            sess.close()

    # compute PCA if necessary; the same projection is applied to data and decoded samples
    pca = PCA(n_components=2).fit(data) if data.shape[1] > 2 else None

    def to_2d(arr):
        return arr if pca is None else pca.transform(arr)

    data2d = to_2d(data)

    # (component index, 2-d samples) for every component that received at least one sample
    samples_2d = []
    for k in range(config['K']):
        samps_k = collected_y_samps[collected_z_samps == k, k, :]
        if samps_k.size > 0:
            samples_2d.append((k, to_2d(samps_k)))

    # plot 2d-histogram (one histogram for each of the K components)
    from matplotlib.colors import LogNorm
    for k, samples in samples_2d:
        # Colour by component index so colours stay stable when some component drew no samples.
        # NOTE(review): newer matplotlib versions use `density` instead of `normed` — confirm pinned version.
        ax.hist2d(samples[:, 0], samples[:, 1], bins=grid_density, range=window,
                  cmap=make_colormap(dark_colors[k % len(dark_colors)]), normed=True, norm=LogNorm())

    # overlay histogram with the sample datapoints, coloured and marked by their label (argmax of lbl)
    labels = np.argmax(lbl, axis=1)
    for c in np.unique(labels):
        in_class_c = (labels == c)
        color = bright_colors[int(c % len(bright_colors))]
        marker = markers[int(c % len(markers))]
        ax.scatter(data2d[in_class_c, 0], data2d[in_class_c, 1], c=color, marker=marker, s=data_dot_size,
                   linewidths=0)