def latent_units(model: VariationalModel,
                 valid_ds: tf.data.Dataset,
                 n_batches: int = 10):
    """Log per-unit statistics of every latent of ``model`` to TensorBoard.

    Up to ``n_batches`` batches of ``valid_ds`` are pushed through the model
    (``_call(model, x, y, decode=True)``) and the posterior/prior latents are
    collected with ``model.get_latents(return_prior=True)``.  For each latent
    ``i`` the following summaries are written at ``step=model.step``:

    * image ``z{i}``: per unit, sorted by posterior stddev, the averaged
      posterior mean and stddev, a single-sample Monte-Carlo estimate of
      ``log q(z) - log p(z)`` averaged over all collected samples, and the
      row-norms of every model kernel whose first dimension equals the
      number of units;
    * histograms ``z{i}/mean``, ``z{i}/stddev`` and ``z{i}/kld``.

    Args:
        model: the variational model to inspect.
        valid_ds: dataset yielding ``(x, y)`` pairs.
        n_batches: maximum number of batches taken from ``valid_ds``
            (default 10).

    Returns:
        None.  Nothing is logged when ``valid_ds`` yields no batch.
    """
    Qz = []
    Pz = []
    for x, y in valid_ds.take(n_batches):
        _call(model, x, y, decode=True)
        qz, pz = model.get_latents(return_prior=True)
        Qz.append(as_tuple(qz))
        Pz.append(as_tuple(pz))
    if not Qz:
        # empty dataset: nothing to summarise (Qz[0] below would fail)
        return
    # Read the weights after the forward passes, in case the model only
    # creates its variables on the first call.
    weights = model.weights
    n_latents = len(Qz[0])
    for i in range(n_latents):
        qz: Sequence[Distribution] = [q[i] for q in Qz]
        pz: Sequence[Distribution] = [p[i] for p in Pz]
        # tracking kl: one-sample estimate of log q(z) - log p(z) per batch
        kld = []
        for q, p in zip(qz, pz):
            q = Normal(loc=q.mean(), scale=q.stddev())
            z = q.sample()
            if isinstance(p, Vamprior):
                # mixture prior: log-sum-exp over the components, offset by
                # p.C (presumably log(n_components) -- TODO confirm)
                C = p.C
                p = p.distribution  # [n_components, zdim]
                p = Normal(loc=p.mean(), scale=p.stddev())
                log_pz = tf.reduce_logsumexp(
                    p.log_prob(tf.expand_dims(z, 1)), 1) - C
                kld.append(q.log_prob(z) - log_pz)
            else:
                p = Normal(loc=p.mean(), scale=p.stddev())
                kld.append(q.log_prob(z) - p.log_prob(z))
        kld = tf.reshape(tf.reduce_mean(tf.concat(kld, 0), axis=0),
                         -1).numpy()
        # mean and stddev, averaged over all collected samples
        mean = tf.reduce_mean(tf.concat([d.mean() for d in qz], axis=0), 0)
        stddev = tf.reduce_mean(
            tf.concat([d.stddev() for d in qz], axis=0), 0)
        mean = tf.reshape(mean, -1).numpy()
        stddev = tf.reshape(stddev, -1).numpy()
        zdim = mean.shape[0]
        # the figure
        fig = plt.figure(figsize=(12, 5), dpi=50)
        lines = []
        ids = np.argsort(stddev)
        styles = dict(marker='o', markersize=2, linewidth=0, alpha=0.8)
        lines += plt.plot(mean[ids], label='mean', color='r', **styles)
        lines += plt.plot(stddev[ids], label='stddev', color='b', **styles)
        plt.grid(False)
        # KL and (if any) kernel norms on a secondary y-axis
        plt.twinx()
        lines += plt.plot(kld[ids],
                          label='KL(q|p)',
                          linestyle='--',
                          linewidth=1.0,
                          alpha=0.6)
        for w in weights:
            name = w.name
            # only kernels whose first axis matches the latent units
            if w.shape.rank > 0 and w.shape[0] == zdim and '/kernel' in name:
                w = tf.linalg.norm(tf.reshape(w, (w.shape[0], -1)),
                                   axis=1).numpy()
                lines += plt.plot(w[ids],
                                  label=name.split(':')[0],
                                  linestyle='--',
                                  alpha=0.6)
        plt.grid(False)
        plt.legend(lines, [ln.get_label() for ln in lines], fontsize=6)
        # save summary
        tf.summary.image(f'z{i}', vs.plot_to_image(fig), step=model.step)
        # release the figure; harmless if plot_to_image already closed it
        plt.close(fig)
        tf.summary.histogram(f'z{i}/mean', mean, step=model.step)
        tf.summary.histogram(f'z{i}/stddev', stddev, step=model.step)
        tf.summary.histogram(f'z{i}/kld', kld, step=model.step)
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
    """Run ``vae`` on the train/test splits of ``ds`` and export diagnostics.

    Outputs written into ``expdir`` (created if missing):

    * ``latent{idx}_factor.pdf``: latent/factor pair plots with the Spearman
      correlation matrix, one per test latent;
    * ``latents_traverse.pdf``: latent traversals of 5 random test images;
    * ``arrays``: a pickle with latents, labels and predictions;
    * ``analysis.pdf`` and ``latents.pdf``: the figures created since the
      previous save (prior samples, reconstructions, confusion matrix,
      t-SNE scatters; latent mean/stddev scatters and KL plots), assuming
      ``vs.plot_save`` saves all open figures -- TODO confirm.

    Args:
        vae: the trained model.
        ds: dataset providing the ``'train'`` and ``'test'`` splits.
        expdir: output directory.
        title: prefix used in plot titles.
        batch_size: batch size of the evaluated datasets.
        take_count: forwarded to ``_call`` (presumably the number of batches,
            ``-1`` meaning all -- TODO confirm).
        n_images: number of images in the sample/reconstruction grids.
        seed: seed of the random permutations, the PCA and prior sampling.

    Raises:
        NotImplementedError: if ``ds.label_type`` is neither
            ``'categorical'`` nor ``'factor'``.
    """
    n_rows = int(np.sqrt(n_images))
    is_semi = vae.is_semi_supervised()
    ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
    ## prepare
    rand = np.random.RandomState(seed=seed)
    os.makedirs(expdir, exist_ok=True)
    ## data for training semi-supervised
    train = ds.create_dataset('train', **ds_kw)
    (_, _, x_org_train, x_rec_train, y_true_train, y_pred_train, z_train,
     _) = _call(vae,
                ds=train,
                rand=rand,
                take_count=take_count,
                n_images=n_images,
                verbose=True)
    ## data for testing
    test = ds.create_dataset('test', **ds_kw)
    (_, _, x_org_test, x_rec_test, y_true_test, y_pred_test, z_test,
     pz_test) = _call(vae,
                      ds=test,
                      rand=rand,
                      take_count=take_count,
                      n_images=n_images,
                      verbose=True)
    # === 0. plotting latent-factor pairs
    for idx, z in enumerate(z_test):
        z = z.mean()
        f = y_true_test
        corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
        plot_latents_pairs(z, f, corr_mat, ds.labels)
        vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
    # === 0. latent traverse plot
    x_travs = x_org_test
    if x_travs.ndim == 3:  # grayscale image
        x_travs = np.expand_dims(x_travs, -1)
    else:  # color image
        x_travs = np.transpose(x_travs, (0, 2, 3, 1))
    x_travs = x_travs[rand.permutation(x_travs.shape[0])]
    n_visual_samples = 5
    n_traverse_points = 21
    n_top_latents = 10
    # traverse over the observed [min, max] range of the first test latent's
    # posterior mean (loop-invariant, so computed once)
    z0_mean = z_test[0].mean()
    traverse_min = np.min(z0_mean)
    traverse_max = np.max(z0_mean)
    plt.figure(figsize=(8, 3 * n_visual_samples))
    for i in range(n_visual_samples):
        images = vae.sample_traverse(x_travs[i:i + 1],
                                     min_val=traverse_min,
                                     max_val=traverse_max,
                                     n_best_latents=n_top_latents,
                                     n_traverse_points=n_traverse_points,
                                     mode='linear')
        images = as_tuple(images)[0]
        images = _prepare_images(images.mean().numpy(), normalize=True)
        vs.plot_images(images,
                       grids=(n_top_latents, n_traverse_points),
                       ax=(n_visual_samples, 1, i + 1))
        if i == 0:
            plt.title('Latents traverse')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
    # === 0. prior sampling plot
    images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    plt.figure(figsize=(5, 5))
    vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
    # === 1. reconstruction plot
    plt.figure(figsize=(15, 15))
    vs.plot_images(x_org_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 1),
                   title='[Train]Original')
    vs.plot_images(x_rec_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 2),
                   title='[Train]Reconstructed')
    vs.plot_images(x_org_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 3),
                   title='[Test]Original')
    vs.plot_images(x_rec_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 4),
                   title='[Test]Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    label_type = ds.label_type
    if label_type == 'categorical':
        labels_name = ds.labels
        true = np.argmax(y_true_test, axis=-1)
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
            labels_pred = np.array([labels_name[i] for i in pred])
    elif label_type == 'factor':  # dsprites, shapes3d
        labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
            if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
        true = y_true_test[:, 2].astype('int32')
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = get_ymean(y_pred_test)[:, 2].astype('int32')
            labels_pred = np.array([labels_name[i] for i in pred])
    else:  # CelebA
        raise NotImplementedError
    ## confusion matrix
    if is_semi:
        plt.figure(figsize=(8, 8))
        acc = accuracy_score(y_true=true, y_pred=pred)
        vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                                 labels=labels_name,
                                 cbar=True,
                                 fontsize=10,
                                 title=f'{title} Acc:{acc:.2f}')
    ## save arrays for later inspections
    with open(f'{expdir}/arrays', 'wb') as f:
        pickle.dump(
            dict(z_train=z_train,
                 y_pred_train=y_pred_train,
                 y_true_train=y_true_train,
                 z_test=z_test,
                 y_pred_test=y_pred_test,
                 y_true_test=y_true_test,
                 labels=labels_name,
                 ds=ds.name,
                 label_type=label_type), f)
    print(f'Exported arrays to "{expdir}/arrays"')
    ## all test latents' posterior means, concatenated along the unit axis
    z_mean_test = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
    # === 2. scatter points latents plot
    n_points = 5000
    ids = rand.permutation(len(labels_true))[:n_points]
    Y_true = labels_true[ids]
    Y_pred = labels_pred[ids]
    # tsne plot ('all' latents, plus each latent when there are several)
    n_latents = 0 if len(z_train) == 1 else len(z_train)
    for name, X in zip(
            ['all'] + [f'latents{i}' for i in range(n_latents)],
            [z_mean_test[ids]] +
            [z_test[i].mean().numpy()[ids] for i in range(n_latents)]):
        print(f'Plot scatter points for {name}')
        X = X.reshape(X.shape[0], -1)  # flatten to 2D
        # PCA n_components may not exceed min(n_samples, n_features)
        n_pca = min(X.shape[0], X.shape[1], 512)
        X = Pipeline([('zscore', StandardScaler()),
                      ('pca', PCA(n_pca, random_state=seed))]).fit_transform(X)
        tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
        kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
        plt.figure(figsize=(12, 6))
        vs.plot_scatter(color=Y_true,
                        title=f'[True]{title}-{name}',
                        ax=(1, 2, 1),
                        **kw)
        vs.plot_scatter(color=Y_pred,
                        title=f'[Pred]{title}-{name}',
                        ax=(1, 2, 2),
                        **kw)
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)
    # === 3. show the latents statistics
    n_latents = len(z_train)
    styles = dict(grid=False,
                  ticks_off=False,
                  alpha=0.6,
                  xlabel='mean',
                  ylabel='stddev')

    def _class_name(k):
        """Display name of class index ``k``; ``str(k)`` if out of range."""
        return labels_name[k] if 0 <= k < len(labels_name) else str(k)

    # scatter between latents and labels (assume categorical distribution)
    def _show_latents_labels(Z, Y, tag):
        plt.figure(figsize=(5 * n_latents, 5), dpi=150)
        y = np.argmax(Y, axis=-1)  # class index of every sample
        classes = np.unique(y)
        for idx, z in enumerate(Z):
            if len(z.batch_shape) == 0:
                # unbatched distribution: broadcast its mean to every sample
                # and use per-sample deviations of draws as the spread
                mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
                stddev = z.sample(Y.shape[0]) - mean
            else:
                mean = flatten(z.mean())
                stddev = flatten(z.stddev())
            data = [[], [], []]
            for y_i in classes:
                mask = (y == y_i)
                data[0].append(np.mean(mean[mask], 0))
                data[1].append(np.mean(stddev[mask], 0))
                # one label per unit, naming the class (not a sample)
                data[2].append([_class_name(y_i)] * mean.shape[1])
            vs.plot_scatter(
                x=np.concatenate(data[0], 0),
                y=np.concatenate(data[1], 0),
                color=np.concatenate(data[2], 0),
                ax=(1, n_latents, idx + 1),
                size=15 if mean.shape[1] < 128 else 8,
                title=f'[Test-{tag}]#{idx} - {mean.shape[1]} (units)',
                **styles)
        plt.tight_layout()

    # simple scatter mean-stddev each latents
    def _show_latents(Z, tag):
        plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
        for idx, z in enumerate(Z):
            mean = flatten(z.mean())
            stddev = flatten(z.stddev())
            if mean.ndim == 2:
                # average over the sample axis -> one point per unit
                mean = np.mean(mean, 0)
                stddev = np.mean(stddev, 0)
            vs.plot_scatter(
                x=mean,
                y=stddev,
                ax=(1, n_latents, idx + 1),
                size=15 if len(mean) < 128 else 8,
                title=f'[Test-{tag}]#{idx} - {len(mean)} (units)',
                **styles)

    _show_latents_labels(z_test, y_true_test, 'post')
    _show_latents_labels(pz_test, y_true_test, 'prior')
    _show_latents(z_test, 'post')
    _show_latents(pz_test, 'prior')
    # KL statistics: Monte-Carlo estimate E_q[log q(z) - log p(z)]
    vs.plot_figure()
    for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
        kl = []
        qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
        pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
        # 100 MC samples drawn in chunks of 8
        for s, e in minibatch(batch_size=8, n=100):
            z = qz.sample(e - s)
            # don't do this in GPU, it explodes!
            kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
        kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
        # per sample: sum over event axes, then average over MC samples
        # (a log-sum-exp here would not estimate the KL)
        kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
        kl_samples = np.mean(kl_samples, 0)
        plt.subplot(n_latents, 2, idx * 2 + 1)
        sns.histplot(kl_samples, bins=50)
        plt.title(f'Z#{idx} KL per sample (nats)')
        # per latent: average over MC samples, then over the batch
        kl_latents = np.mean(flatten(np.mean(kl, 0)), 0)
        plt.subplot(n_latents, 2, idx * 2 + 2)
        plt.plot(np.sort(kl_latents))
        plt.title(f'Z#{idx} KL per dim (nats)')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)