def call(self, state):
    # Get mean and standard deviation from the policy network
    a1 = self.dense1_layer(state)
    a2 = self.dense2_layer(a1)
    mu = self.mean_layer(a2)

    # The standard deviation must be non-negative, so the network outputs
    # log stdev, which can take any value in [-inf, inf].
    log_sigma = self.stdev_layer(a2)
    sigma = tf.exp(log_sigma)

    # Use the re-parameterization trick to sample an action from the policy:
    # tfp's Normal.sample() draws mu + sigma * noise with noise ~ N(0, 1),
    # so gradients flow back through mu and sigma.
    dist = Normal(mu, sigma)
    action_ = dist.sample()

    # Apply the tanh squashing to keep the Gaussian sample bounded in (-1, 1).
    action = tf.tanh(action_)

    # Calculate the log probability with the tanh change-of-variables correction.
    log_pi_ = dist.log_prob(action_)
    log_pi = log_pi_ - tf.reduce_sum(
        tf.math.log(1 - action**2 + eps), axis=1, keepdims=True)
    return action, log_pi
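# A minimal standalone sketch (separate from the class above) checking the tanh
# change-of-variables correction against tfp's Tanh bijector. The values of
# `mu`, `sigma`, and `eps` below are illustrative, not taken from the model.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

mu = tf.constant([[0.2, -0.5]])
sigma = tf.constant([[0.3, 0.8]])
eps = 1e-6

base = tfd.Normal(loc=mu, scale=sigma)
raw_action = base.sample()    # pre-tanh sample (reparameterized)
action = tf.tanh(raw_action)  # squashed into (-1, 1)

# Manual correction: log pi(a) = log N(u|mu, sigma) - sum_d log(1 - tanh(u_d)^2)
log_pi_manual = tf.reduce_sum(
    base.log_prob(raw_action) - tf.math.log(1.0 - action**2 + eps),
    axis=1, keepdims=True)

# Reference: the same density via a TransformedDistribution with a Tanh bijector.
squashed = tfd.TransformedDistribution(distribution=base, bijector=tfb.Tanh())
log_pi_ref = tf.reduce_sum(squashed.log_prob(action), axis=1, keepdims=True)

print(log_pi_manual.numpy(), log_pi_ref.numpy())  # should agree up to eps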
def total_correlation(z_samples: Tensor, qZ_X: Distribution) -> Tensor:
  r"""Estimate of total correlation using Gaussian distribution on a batch.

  We need to compute the expectation over a batch of:
  `E_j [log(q(z(x_j))) - log(prod_l q(z(x_j)_l))]`

  We ignore the constants as they do not matter for the minimization.
  The constant should be equal to
  `(num_latents - 1) * log(batch_size * dataset_size)`.

  If `alpha = gamma = 1`, Eq(4) can be written as `ELBO + (1 - beta) * TC`
  (i.e. `(1. - beta) * total_correlation(z_sampled, qZ_X)`).

  Parameters
  ----------
  z_samples : Tensor
      shape `[batch_size, num_latents]` - tensor with sampled representation.
  qZ_X : Distribution
      the posterior distribution, shape `[batch_size, num_latents]`.

  Note
  ----
  This involves calculating the pair-wise distance; memory complexity is up to
  `O(n*n*d)`.

  Returns
  -------
  Total correlation estimated on a batch.

  References
  ----------
  Chen, R.T.Q., Li, X., Grosse, R., Duvenaud, D., 2019. Isolating Sources of
      Disentanglement in Variational Autoencoders. arXiv:1802.04942 [cs, stat].
  Github code https://github.com/google-research/disentanglement_lib
  """
  gaus = Normal(loc=tf.expand_dims(qZ_X.mean(), 0),
                scale=tf.expand_dims(qZ_X.stddev(), 0))
  # Compute log(q(z(x_j)|x_i)) for every sample in the batch, which is a
  # tensor of size [batch_size, batch_size, num_latents]. In the following
  # comments, [batch_size, batch_size, num_latents] are indexed by [j, i, l].
  log_qz_prob = gaus.log_prob(tf.expand_dims(z_samples, 1))
  # Compute log prod_l p(z(x_j)_l) = sum_l(log(sum_i(q(z(x_j)_l|x_i)))
  # + constant) for each sample in the batch, which is a vector of size
  # [batch_size,].
  log_qz_product = tf.reduce_sum(
      tf.reduce_logsumexp(log_qz_prob, axis=1, keepdims=False),
      axis=1,
      keepdims=False)
  # Compute log(q(z(x_j))) as log(sum_i(q(z(x_j)|x_i))) + constant =
  # log(sum_i(prod_l q(z(x_j)_l|x_i))) + constant.
  log_qz = tf.reduce_logsumexp(
      tf.reduce_sum(log_qz_prob, axis=2, keepdims=False),
      axis=1,
      keepdims=False)
  return tf.reduce_mean(log_qz - log_qz_product)
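# A minimal usage sketch for `total_correlation` above, assuming `Normal` is
# tfp's Normal as imported elsewhere in this file. The batch size, number of
# latents, and random parameters are illustrative.
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

batch_size, num_latents = 32, 10
loc = tf.random.normal((batch_size, num_latents))
scale = tf.nn.softplus(tf.random.normal((batch_size, num_latents)))

qZ_X = Normal(loc=loc, scale=scale)  # posterior q(z|x), one row per example
z_samples = qZ_X.sample()            # [batch_size, num_latents]

tc = total_correlation(z_samples, qZ_X)  # scalar TC estimate on this batch
print(float(tc))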
def compute_KL_univariate_prior(self, univariateprior, samples):
    """
    :param univariateprior: (mean, std) of the assumed univariate prior, i.e. Normal(mean2, std2)
    :param samples: samples drawn from the posterior q(w|theta), i.e. Normal(self.mu, self.sigma)
    :return: single-sample Monte-Carlo estimate of KL(posterior || prior)
    """
    samples = tf.reshape(samples, [-1])  # flatten vector
    (mean2, std2) = univariateprior
    prior = Normal(mean2, std2)
    posterior = Normal(self.mu, self.sigma)
    q_theta = tf.reduce_sum(posterior.log_prob(samples))
    p_d = tf.reduce_sum(prior.log_prob(samples))
    KL = tf.subtract(q_theta, p_d)
    return KL
def elbo_components(self, inputs, training=None, mask=None):
  X_u, y_u, X_l, y_l = _prepare_elbo(self, inputs, training=training, mask=mask)
  y_l = tf.clip_by_value(y_l, 1e-8, 1. - 1e-8)
  px_z_u, (qz_x_u, qzc_x_u, qy_zx_u) = self(X_u, training=training)
  px_z_l, (qz_x_l, qzc_x_l, qy_zx_l) = self(X_l, training=training)
  z_exc = tf.concat(
      [tf.convert_to_tensor(qz_x_u), tf.convert_to_tensor(qz_x_l)], axis=0)
  z_c = tf.concat(
      [tf.convert_to_tensor(qzc_x_u), tf.convert_to_tensor(qzc_x_l)], axis=0)
  # Convert y to one-hot vector and sample y for those without labels
  y_sup = y_l
  y_uns = tf.convert_to_tensor(qy_zx_u)
  y = tf.concat((y_uns, y_sup), axis=0)
  # log q(y|z_c)
  h = tf.concat([qy_zx_u.logits, qy_zx_l.logits], axis=0)
  log_q_y_zc = tf.reduce_sum(h * y, axis=1)
  # log p(x|z)
  log_p_x_z = tf.concat([px_z_u.log_prob(X_u), px_z_l.log_prob(X_l)], axis=0)
  # log p(z_c|y)
  pzc_y = self.regressor(y)
  log_p_zc_y = pzc_y.log_prob(z_c)
  # log p(z_\c)
  dist = Normal(tf.cast(0., self.dtype), 1.)
  log_p_zexc = tf.reduce_sum(dist.log_prob(z_exc), axis=-1)
  # log p(z|y)
  log_p_z_y = log_p_zc_y + log_p_zexc
  # log q(y|x) (draw `n_resamples` points from q(z_c|x); supervised samples only)
  h = qzc_x_l.sample(self.n_resamples)
  h = tf.reshape(h, (-1, h.shape[-1]))
  qy_x = self.classify(h, training=training)
  qy_x_logits = tf.reshape(qy_x.logits,
                           (self.n_resamples, -1, qy_x.logits.shape[-1]))
  # log-mean-exp of the classifier logits over the n_resamples Monte-Carlo samples
  h = tf.reduce_logsumexp(qy_x_logits, axis=0) - tf.math.log(
      tf.cast(self.n_resamples, self.dtype))
  log_q_y_x = tf.reduce_sum(h * y_l, axis=1)
  # log q(z|x)
  log_qz_x = tf.concat([qz_x_u.log_prob(qz_x_u), qz_x_l.log_prob(qz_x_l)],
                       axis=0)
  log_qzc_x = tf.concat(
      [qzc_x_u.log_prob(qzc_x_u), qzc_x_l.log_prob(qzc_x_l)], axis=0)
  log_q_z_x = log_qz_x + log_qzc_x
  # Calculate the lower bound
  n_uns = ps.shape(X_u)[0]
  h = log_p_x_z + log_p_z_y - log_q_y_zc - log_q_z_x
  coef_sup = tf.math.exp(log_q_y_zc[n_uns:] - log_q_y_x)
  coef_uns = tf.ones((n_uns,), dtype=self.dtype)
  coef = tf.concat((coef_uns, coef_sup), axis=0)
  zeros = tf.zeros((n_uns,), dtype=self.dtype)
  lb = coef * h + tf.concat((zeros, log_q_y_x), axis=0)
  return {'elbo': lb}, {}
def compute_KL_univariate_prior(univariateprior, theta, sample):
    """
    :param univariateprior: (mean, std) of the assumed univariate prior Normal(m, s)
    :param theta: (mean, std) used to create the posterior q(w|theta), i.e. Normal(mean, std)
    :param sample: a sample drawn from the posterior
    :return: single-sample Monte-Carlo estimate of KL(posterior || prior)
    """
    sample = tf.reshape(sample, [-1])  # flatten vector
    (mean, std) = theta
    mean = tf.reshape(mean, [-1])
    std = tf.reshape(std, [-1])
    posterior = Normal(mean, std)
    (mean2, std2) = univariateprior
    prior = Normal(mean2, std2)
    q_theta = tf.reduce_sum(posterior.log_prob(sample))
    p_d = tf.reduce_sum(prior.log_prob(sample))
    KL = tf.subtract(q_theta, p_d)
    return KL
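# A small sanity-check sketch for the Monte-Carlo KL above: the function returns a
# single-sample estimate, and averaging many such samples should approach the
# analytic Gaussian KL. The parameter values below are illustrative.
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal, kl_divergence

q_mu, q_std = tf.constant([0.5, -0.2]), tf.constant([0.3, 0.7])
p_mu, p_std = tf.constant(0.0), tf.constant(1.0)

posterior = Normal(q_mu, q_std)
prior = Normal(p_mu, p_std)

w = posterior.sample(10000)  # [10000, 2] samples of the weight vector
mc_kl = tf.reduce_mean(
    tf.reduce_sum(posterior.log_prob(w) - prior.log_prob(w), axis=-1))
analytic_kl = tf.reduce_sum(kl_divergence(posterior, prior))

print(float(mc_kl), float(analytic_kl))  # should be close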
def get_KL_univariate_prior(univariateprior, theta, sample):
    """
    :param univariateprior: (mean, std) of the assumed univariate prior Normal(m, s)
    :param theta: (mean, std) used to create the posterior q(w|theta), i.e. Normal(mean, std)
    :param sample: a sample drawn from the posterior
    :return: single-sample Monte-Carlo estimate of KL(posterior || prior)
    """
    sample = tf.reshape(sample, [-1])  # flatten vector
    (mean, std) = theta
    (mean2, std2) = univariateprior
    prior = Normal(mean2, std2)
    posterior = Normal(mean, std)
    q_theta = tf.reduce_sum(posterior.log_prob(sample))
    p_d = tf.reduce_sum(prior.log_prob(sample))
    KL = tf.subtract(q_theta, p_d)
    return KL
def get_KL_multivariate_prior(self, multivariateprior, theta, sample):
    """
    :param multivariateprior: (std1, std2) of a two-component prior,
        i.e. a mixture of Normal(0, std1) and Normal(0, std2)
    :param theta: (mean, std) used to create the posterior q(w|theta), i.e. Normal(mean, std)
    :param sample: a sample drawn from the posterior
    :return: single-sample Monte-Carlo estimate of KL(posterior || prior)
    """
    sample = tf.reshape(sample, [-1])  # flatten vector
    (mean, std) = theta
    posterior = Normal(mean, std)
    (std1, std2) = multivariateprior
    prior1 = Normal(0, std1)
    prior2 = Normal(0, std2)
    q_theta = tf.reduce_sum(posterior.log_prob(sample))
    p1 = tf.reduce_sum(prior1.log_prob(sample))
    p2 = tf.reduce_sum(prior2.log_prob(sample))  # this is wrong, need to work this out
    KL = tf.subtract(q_theta, tf.reduce_logsumexp([p1, p2]))
    return KL
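# A hedged sketch (not part of the original code) of the per-weight scale-mixture
# prior of Blundell et al. (2015): log p(w) must be formed element-wise as
# log(pi * N(w|0, std1) + (1 - pi) * N(w|0, std2)) before summing over weights,
# rather than summing the component log-probabilities first. The mixture weight
# `pi_mix` is an assumption (equal weights by default).
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

def scale_mixture_log_prob(sample, std1, std2, pi_mix=0.5):
    sample = tf.reshape(sample, [-1])
    log_p1 = Normal(0., std1).log_prob(sample) + tf.math.log(pi_mix)
    log_p2 = Normal(0., std2).log_prob(sample) + tf.math.log(1. - pi_mix)
    # per-element log-density of the mixture, then sum over all weights
    return tf.reduce_sum(tf.reduce_logsumexp(tf.stack([log_p1, log_p2], 0), 0))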
def compute_KL_univariate_prior(self, univariateprior, theta, sample):
    """
    :param univariateprior: (mean, std) of the assumed univariate prior Normal(m, s)
    :param theta: (mean, std) used to create the posterior q(w|theta), i.e. Normal(mean, std)
    :param sample: a sample drawn from the posterior
    :return: single-sample Monte-Carlo estimate of KL(posterior || prior)
    """
    sample = tf.reshape(sample, [-1])  # flatten vector
    (mean, std) = theta
    mean = tf.reshape(mean, [-1])
    std = tf.reshape(std, [-1])
    posterior = Normal(mean, std)
    (mean2, std2) = univariateprior
    prior = Normal(mean2, std2)
    q_theta = tf.reduce_sum(posterior.log_prob(sample))
    p_d = tf.reduce_sum(prior.log_prob(sample))
    KL = tf.subtract(q_theta, p_d)
    print("computed KL loss " + self.layer_name)
    return KL
class NormalPyramid(Distribution):
  # `means` is indexed by *, y, x, channel.
  # `base_sigma` is a scalar std used for the 'raw' pixels; subsequent pyramid
  # levels use smaller stds.

  def __init__(self,
               means,
               base_sigma,
               levels=None,
               validate_args=False,
               allow_nan_stats=True,
               name='NormalPyramid'):
    with ops.name_scope(name, values=[means, base_sigma]) as ns:
      self._means = array_ops.identity(means, name='means')
      self._base_sigma = array_ops.identity(base_sigma, name='base_sigma')
      self._base_dist = Normal(loc=self._means, scale=self._base_sigma)
      self._standard_normal = Normal(loc=0., scale=1.)
      self._levels = levels
      super(NormalPyramid, self).__init__(
          dtype=tf.float32,
          parameters={'means': means, 'base_sigma': base_sigma},
          reparameterization_type=FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=ns)

  def _log_prob(self, x):
    # The resulting density is indexed by *, i.e. we aggregate (mean) over
    # x, y, channel, and pyramid levels.
    z = (x - self._means) / self._base_sigma
    z_shape = list(map(int, z.get_shape()))
    z_pyramid = gaussian_pyramid(z, self._levels)
    return sum(
        tf.reduce_mean(self._standard_normal.log_prob(z_level),
                       axis=[-3, -2, -1])  # ** check the rescaling here!
        for level_index, z_level in enumerate(z_pyramid)) / len(z_pyramid)

  def _sample_n(self, n, seed=None):
    return self._base_dist._sample_n(n, seed)

  def _mean(self):
    return self._means

  def _mode(self):
    return self._means
def latent_units(model: VariationalModel, valid_ds: tf.data.Dataset):
  weights = model.weights
  Qz = []
  Pz = []
  for x, y in valid_ds.take(10):
    _call(model, x, y, decode=True)
    qz, pz = model.get_latents(return_prior=True)
    Qz.append(as_tuple(qz))
    Pz.append(as_tuple(pz))
  n_latents = len(Qz[0])
  for i in range(n_latents):
    qz: Sequence[Distribution] = [q[i] for q in Qz]
    pz: Sequence[Distribution] = [p[i] for p in Pz]
    # tracking KL
    kld = []
    for q, p in zip(qz, pz):
      q = Normal(loc=q.mean(), scale=q.stddev())
      z = q.sample()
      if isinstance(p, Vamprior):
        C = p.C
        p = p.distribution  # [n_components, zdim]
        p = Normal(loc=p.mean(), scale=p.stddev())
        kld.append(q.log_prob(z) -
                   (tf.reduce_logsumexp(p.log_prob(tf.expand_dims(z, 1)), 1) - C))
      else:
        p = Normal(loc=p.mean(), scale=p.stddev())
        kld.append(q.log_prob(z) - p.log_prob(z))
    kld = tf.reshape(tf.reduce_mean(tf.concat(kld, 0), axis=0), -1).numpy()
    # mean and stddev
    mean = tf.reduce_mean(tf.concat([d.mean() for d in qz], axis=0), 0)
    stddev = tf.reduce_mean(tf.concat([d.stddev() for d in qz], axis=0), 0)
    mean = tf.reshape(mean, -1).numpy()
    stddev = tf.reshape(stddev, -1).numpy()
    zdim = mean.shape[0]
    # the figure
    plt.figure(figsize=(12, 5), dpi=50)
    lines = []
    ids = np.argsort(stddev)
    styles = dict(marker='o', markersize=2, linewidth=0, alpha=0.8)
    lines += plt.plot(mean[ids], label='mean', color='r', **styles)
    lines += plt.plot(stddev[ids], label='stddev', color='b', **styles)
    plt.grid(False)
    # show weights if they exist
    plt.twinx()
    lines += plt.plot(kld[ids],
                      label='KL(q|p)',
                      linestyle='--',
                      linewidth=1.0,
                      alpha=0.6)
    for w in weights:
      name = w.name
      if w.shape.rank > 0 and w.shape[0] == zdim and '/kernel' in name:
        w = tf.linalg.norm(tf.reshape(w, (w.shape[0], -1)), axis=1).numpy()
        lines += plt.plot(w[ids],
                          label=name.split(':')[0],
                          linestyle='--',
                          alpha=0.6)
    plt.grid(False)
    plt.legend(lines, [ln.get_label() for ln in lines], fontsize=6)
    # save summary
    tf.summary.image(f'z{i}', vs.plot_to_image(plt.gcf()), step=model.step)
    tf.summary.histogram(f'z{i}/mean', mean, step=model.step)
    tf.summary.histogram(f'z{i}/stddev', stddev, step=model.step)
    tf.summary.histogram(f'z{i}/kld', kld, step=model.step)
# Presumed missing imports for the snippet below (gym, Keras layers/activations,
# and Keras backend ops for expand_dims/clip); only the last three imports
# appeared in the original.
import gym
from tensorflow.keras.activations import relu, softplus
from tensorflow.keras.backend import clip, expand_dims
from tensorflow.keras.layers import Concatenate, Dense, Input
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.models import Model
from tensorflow_probability.python.distributions import Normal

env = gym.make('MountainCarContinuous-v0')
state = env.reset()

state_input = Input(state.shape, name='state_input')

actor_dense1_layer = Dense(4, activation=relu, name='actor_dense1')
actor_dense1 = actor_dense1_layer(state_input)
actor_out_layer = Dense(2, activation=None, name='actor_nn_out')
actor_out = actor_out_layer(actor_dense1)

mean = actor_out[:, 0]
std = softplus(actor_out[:, 1])
dist = Normal(mean, std, name='dist')
action = dist.sample((), name='sample')
action_log_prob = dist.log_prob(action, name='log_prob')
action_log_prob = expand_dims(action_log_prob)
action = expand_dims(action)
action = clip(action, -1.0, 1.0)

critic_concat = Concatenate()([state_input, action])
critic_dense1_layer = Dense(units=8, activation=relu, name='critic_dense1')
critic_dense1 = critic_dense1_layer(critic_concat)
critic_out_layer = Dense(1, activation=None)
critic_out = critic_out_layer(critic_dense1)

actor_critic_model = Model(inputs=[state_input],
                           outputs=[action, action_log_prob, critic_out])

critic_loss_op = MeanSquaredError(name='critic_loss')
actor_vars = (actor_dense1_layer.trainable_variables +
              actor_out_layer.trainable_variables)
critic_vars = (critic_dense1_layer.trainable_variables +
               critic_out_layer.trainable_variables)

GAMES = 100
LR = 1e-4
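# A small eager-mode sketch of what the actor head above computes for one state:
# the two raw outputs become the Gaussian mean and (softplus-transformed) std,
# from which an action and its log-probability are drawn. The logits value is
# illustrative, not produced by the trained network.
import tensorflow as tf
from tensorflow.keras.activations import softplus
from tensorflow_probability.python.distributions import Normal

actor_nn_out = tf.constant([[0.3, -1.2]])  # pretend output of the actor_nn_out layer
mean, std = actor_nn_out[:, 0], softplus(actor_nn_out[:, 1])
dist = Normal(mean, std)
a = dist.sample()
print(float(a), float(dist.log_prob(a)))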
def get_log_prob(self, states, actions):
    mean, std = self._get_dist(states)
    dist = Normal(mean, std)
    log_prob = tf.reduce_sum(dist.log_prob(actions), -1)
    return log_prob
def call(self, states, **kwargs):
    mean, std = self._get_dist(states)
    dist = Normal(mean, std)
    action = dist.sample()
    log_prob = tf.reduce_sum(dist.log_prob(action), -1)
    return action, log_prob
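# A standalone sketch of the pattern used by `call` / `get_log_prob` above:
# a diagonal Gaussian over actions whose joint log-probability is the sum of the
# per-dimension log-probabilities. The shapes below are illustrative.
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

mean = tf.zeros((4, 3))  # [batch, action_dim]
std = tf.ones((4, 3))
dist = Normal(mean, std)
action = dist.sample()                                # [4, 3]
log_prob = tf.reduce_sum(dist.log_prob(action), -1)   # [4]
print(log_prob.shape)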
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
  n_rows = int(np.sqrt(n_images))
  is_semi = vae.is_semi_supervised()
  is_hierarchical = vae.is_hierarchical()
  ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
  ## prepare
  rand = np.random.RandomState(seed=seed)
  if not os.path.exists(expdir):
    os.makedirs(expdir)
  ## data for training semi-supervised
  train = ds.create_dataset('train', **ds_kw)
  (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
   y_pred_train, z_train, pz_train) = _call(vae,
                                            ds=train,
                                            rand=rand,
                                            take_count=take_count,
                                            n_images=n_images,
                                            verbose=True)
  ## data for testing
  test = ds.create_dataset('test', **ds_kw)
  (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
   z_test, pz_test) = _call(vae,
                            ds=test,
                            rand=rand,
                            take_count=take_count,
                            n_images=n_images,
                            verbose=True)
  # === 0. plotting latent-factor pairs
  for idx, z in enumerate(z_test):
    z = z.mean()
    f = y_true_test
    corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
    plot_latents_pairs(z, f, corr_mat, ds.labels)
    vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
  # === 0. latent traverse plot
  x_travs = x_org_test
  if x_travs.ndim == 3:  # grayscale image
    x_travs = np.expand_dims(x_travs, -1)
  else:  # color image
    x_travs = np.transpose(x_travs, (0, 2, 3, 1))
  x_travs = x_travs[rand.permutation(x_travs.shape[0])]
  n_visual_samples = 5
  n_traverse_points = 21
  n_top_latents = 10
  plt.figure(figsize=(8, 3 * n_visual_samples))
  for i in range(n_visual_samples):
    images = vae.sample_traverse(x_travs[i:i + 1],
                                 min_val=-np.min(z_test[0].mean()),
                                 max_val=np.max(z_test[0].mean()),
                                 n_best_latents=n_top_latents,
                                 n_traverse_points=n_traverse_points,
                                 mode='linear')
    images = as_tuple(images)[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    vs.plot_images(images,
                   grids=(n_top_latents, n_traverse_points),
                   ax=(n_visual_samples, 1, i + 1))
    if i == 0:
      plt.title('Latents traverse')
  plt.tight_layout()
  vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
  # === 0. prior sampling plot
  images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
  images = _prepare_images(images.mean().numpy(), normalize=True)
  plt.figure(figsize=(5, 5))
  vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
  # === 1. reconstruction plot
  plt.figure(figsize=(15, 15))
  vs.plot_images(x_org_train,
                 grids=(n_rows, n_rows),
                 ax=(2, 2, 1),
                 title='[Train]Original')
  vs.plot_images(x_rec_train,
                 grids=(n_rows, n_rows),
                 ax=(2, 2, 2),
                 title='[Train]Reconstructed')
  vs.plot_images(x_org_test,
                 grids=(n_rows, n_rows),
                 ax=(2, 2, 3),
                 title='[Test]Original')
  vs.plot_images(x_rec_test,
                 grids=(n_rows, n_rows),
                 ax=(2, 2, 4),
                 title='[Test]Reconstructed')
  plt.tight_layout()
  ## prepare the labels
  label_type = ds.label_type
  if label_type == 'categorical':
    labels_name = ds.labels
    true = np.argmax(y_true_test, axis=-1)
    labels_true = np.array([labels_name[i] for i in true])
    labels_pred = labels_true
    if is_semi:
      pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
      labels_pred = np.array([labels_name[i] for i in pred])
  elif label_type == 'factor':  # dsprites, shapes3d
    labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
      if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
    true = y_true_test[:, 2].astype('int32')
    labels_true = np.array([labels_name[i] for i in true])
    labels_pred = labels_true
    if is_semi:
      pred = get_ymean(y_pred_test)[:, 2].astype('int32')
      labels_pred = np.array([labels_name[i] for i in pred])
  else:  # CelebA
    raise NotImplementedError
  ## confusion matrix
  if is_semi:
    plt.figure(figsize=(8, 8))
    acc = accuracy_score(y_true=true, y_pred=pred)
    vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                             labels=labels_name,
                             cbar=True,
                             fontsize=10,
                             title=f'{title} Acc:{acc:.2f}')
  ## save arrays for later inspections
  with open(f'{expdir}/arrays', 'wb') as f:
    pickle.dump(
        dict(z_train=z_train,
             y_pred_train=y_pred_train,
             y_true_train=y_true_train,
             z_test=z_test,
             y_pred_test=y_pred_test,
             y_true_test=y_true_test,
             labels=labels_name,
             ds=ds.name,
             label_type=label_type), f)
  print(f'Exported arrays to "{expdir}/arrays"')
  ## semi-supervised
  z_mean_train = np.concatenate(
      [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
  z_mean_test = np.concatenate(
      [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
  # === 2. scatter points latents plot
  n_points = 5000
  ids = rand.permutation(len(labels_true))[:n_points]
  Y_true = labels_true[ids]
  Y_pred = labels_pred[ids]
  # tsne plot
  n_latents = 0 if len(z_train) == 1 else len(z_train)
  for name, X in zip(
      ['all'] + [f'latents{i}' for i in range(n_latents)],
      [z_mean_test[ids]] +
      [z_test[i].mean().numpy()[ids] for i in range(n_latents)]):
    print(f'Plot scatter points for {name}')
    X = X.reshape(X.shape[0], -1)  # flatten to 2D
    X = Pipeline([('zscore', StandardScaler()),
                  ('pca', PCA(min(X.shape[1], 512),
                              random_state=seed))]).fit_transform(X)
    tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
    kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
    plt.figure(figsize=(12, 6))
    vs.plot_scatter(color=Y_true,
                    title=f'[True]{title}-{name}',
                    ax=(1, 2, 1),
                    **kw)
    vs.plot_scatter(color=Y_pred,
                    title=f'[Pred]{title}-{name}',
                    ax=(1, 2, 2),
                    **kw)
  ## save all plot
  vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)
  # === 3. show the latents statistics
  n_latents = len(z_train)
  colors = sns.color_palette(n_colors=len(labels_true))
  styles = dict(grid=False,
                ticks_off=False,
                alpha=0.6,
                xlabel='mean',
                ylabel='stddev')

  # scatter between latents and labels (assume categorical distribution)
  def _show_latents_labels(Z, Y, title):
    plt.figure(figsize=(5 * n_latents, 5), dpi=150)
    for idx, z in enumerate(Z):
      if len(z.batch_shape) == 0:
        mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
        stddev = z.sample(Y.shape[0]) - mean
      else:
        mean = flatten(z.mean())
        stddev = flatten(z.stddev())
      y = np.argmax(Y, axis=-1)
      data = [[], [], []]
      for y_i, c in zip(np.unique(y), colors):
        mask = (y == y_i)
        data[0].append(np.mean(mean[mask], 0))
        data[1].append(np.mean(stddev[mask], 0))
        data[2].append([labels_true[y_i]] * mean.shape[1])
      vs.plot_scatter(
          x=np.concatenate(data[0], 0),
          y=np.concatenate(data[1], 0),
          color=np.concatenate(data[2], 0),
          ax=(1, n_latents, idx + 1),
          size=15 if mean.shape[1] < 128 else 8,
          title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
          **styles)
    plt.tight_layout()

  # simple scatter of mean vs stddev for each latent
  def _show_latents(Z, title):
    plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
    for idx, z in enumerate(Z):
      mean = flatten(z.mean())
      stddev = flatten(z.stddev())
      if mean.ndim == 2:
        mean = np.mean(mean, 0)
        stddev = np.mean(stddev, 0)
      vs.plot_scatter(
          x=mean,
          y=stddev,
          ax=(1, n_latents, idx + 1),
          size=15 if len(mean) < 128 else 8,
          title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
          **styles)

  _show_latents_labels(z_test, y_true_test, 'post')
  _show_latents_labels(pz_test, y_true_test, 'prior')
  _show_latents(z_test, 'post')
  _show_latents(pz_test, 'prior')
  # KL statistics
  vs.plot_figure()
  for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
    kl = []
    qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
    pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
    for s, e in minibatch(batch_size=8, n=100):
      z = qz.sample(e - s)  # don't do this on GPU, it explodes!
      kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
    kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
    # per sample
    kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
    kl_samples = logsumexp(kl_samples, 0)
    plt.subplot(n_latents, 2, idx * 2 + 1)
    sns.histplot(kl_samples, bins=50)
    plt.title(f'Z#{idx} KL per sample (nats)')
    # per latent
    kl_latents = np.mean(flatten(logsumexp(kl, 0)), 0)
    plt.subplot(n_latents, 2, idx * 2 + 2)
    plt.plot(np.sort(kl_latents))
    plt.title(f'Z#{idx} KL per dim (nats)')
  plt.tight_layout()
  vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        z = self.normal.sample(n)
        if return_pre_tanh_value:
            return tf.math.tanh(z), z
        else:
            return tf.math.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        :param value: some value, x
        :param pre_tanh_value: arctanh(x)
        :return: log-probability of `value` under the tanh-transformed Normal
        """
        if pre_tanh_value is None:
            # arctanh(x) = 0.5 * log((1 + x) / (1 - x))
            pre_tanh_value = tf.math.log((1 + value) / (1 - value)) / 2
        return self.normal.log_prob(pre_tanh_value) - tf.math.log(
            1 - value * value + self.epsilon)

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample()
        if return_pretanh_value:
            return tf.math.tanh(z), z
        else:
            return tf.math.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        z = (self.normal_mean + self.normal_std *
             Normal(tf.zeros_like(self.normal_mean),
                    tf.ones_like(self.normal_std)).sample())
        if return_pretanh_value:
            return tf.math.tanh(z), z
        else:
            return tf.math.tanh(z)
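# A short usage sketch of `TanhNormal` above: draw a reparameterized sample
# together with its pre-tanh value and evaluate the log-probability with and
# without passing that pre-tanh value. The mean/std parameters are illustrative.
import tensorflow as tf

dist = TanhNormal(normal_mean=tf.zeros((2, 3)), normal_std=tf.ones((2, 3)))
action, pre_tanh = dist.rsample(return_pretanh_value=True)
lp_exact = dist.log_prob(action, pre_tanh_value=pre_tanh)
lp_recovered = dist.log_prob(action)  # recomputes arctanh(action) internally
print(lp_exact.shape, lp_recovered.shape)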