Example #1
def to_image(X, grids):
    if X.shape[-1] == 1:  # grayscale image
        X = np.squeeze(X, axis=-1)
    else:  # color image
        X = np.transpose(X, (0, 3, 1, 2))
    nrows, ncols = grids
    fig = vs.plot_figure(nrows=nrows, ncols=ncols, dpi=100)
    vs.plot_images(X, grids=grids)
    image = vs.plot_to_image(fig)
    return image
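
The plotting helpers above (vs.plot_figure, vs.plot_images, vs.plot_to_image) presumably come from odin.visual. For context, here is a minimal NumPy/Matplotlib-only sketch of the same "tile a batch into one grid image" idea; tile_batch and figure_to_array are hypothetical helper names for illustration, not part of that library.

import io

import matplotlib.pyplot as plt
import numpy as np


def tile_batch(X, nrows, ncols):
    # tile a batch of images (N, H, W) into a single (nrows*H, ncols*W) canvas
    n, h, w = X.shape[:3]
    canvas = np.zeros((nrows * h, ncols * w) + X.shape[3:], dtype=X.dtype)
    for i in range(min(n, nrows * ncols)):
        r, c = divmod(i, ncols)
        canvas[r * h:(r + 1) * h, c * w:(c + 1) * w] = X[i]
    return canvas


def figure_to_array(fig):
    # render a Matplotlib figure into an RGB uint8 array (e.g. for TensorBoard)
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    rgba = plt.imread(buf)  # float32 RGBA in [0, 1]
    plt.close(fig)
    return (rgba[..., :3] * 255).astype(np.uint8)


# usage: a fake batch of 16 grayscale 28x28 images arranged on a 4x4 grid
X = np.random.rand(16, 28, 28, 1).astype(np.float32)
fig = plt.figure(figsize=(4, 4))
plt.imshow(tile_batch(np.squeeze(X, axis=-1), 4, 4), cmap='gray')
plt.axis('off')
grid_image = figure_to_array(fig)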
Example #2
def save_images(pX_Z, name, step, path):
    X = pX_Z.mean().numpy()
    if X.shape[-1] == 1:
        X = np.squeeze(X, axis=-1)
    else:
        X = np.transpose(X, (0, 3, 1, 2))
    fig = vs.plot_figure(nrow=16, ncol=16, dpi=60)
    vs.plot_images(X, fig=fig, title="[%s]#Iter: %d" % (name, step))
    fig.savefig(path, dpi=60)
    plt.close(fig)
    del X
Example #3
def callback(vae: vi.VariationalAutoencoder, x: np.ndarray, y: np.ndarray):
    trainer = get_current_trainer()
    px, qz = [], []
    X_i = []
    for x_i in tf.data.Dataset.from_tensor_slices(x).batch(64):
        outputs = vae(x_i, training=False)
        px.append(outputs[0])
        qz.append(outputs[1])
        X_i.append(x_i)
    # llk
    llk_test = tf.reduce_mean(
        tf.concat([p.log_prob(x_i) for p, x_i in zip(px, X_i)], axis=0))
    # latents
    qz_mean = tf.reduce_mean(tf.concat([q.mean() for q in qz], axis=0), axis=0)
    qz_std = tf.reduce_mean(tf.concat([q.stddev() for q in qz], axis=0),
                            axis=0)
    w = tf.reduce_sum(vae.decoder.trainable_variables[0], axis=1)
    # plot the latents and their weights
    fig = plt.figure(figsize=(6, 4), dpi=200)
    ax = plt.gca()
    l1 = ax.plot(qz_mean,
                 label='mean',
                 linewidth=1.0,
                 linestyle='--',
                 marker='o',
                 markersize=4,
                 color='r',
                 alpha=0.5)
    l2 = ax.plot(qz_std,
                 label='std',
                 linewidth=1.0,
                 linestyle='--',
                 marker='o',
                 markersize=4,
                 color='g',
                 alpha=0.5)
    ax1 = ax.twinx()
    l3 = ax1.plot(w,
                  label='weight',
                  linewidth=1.0,
                  linestyle='--',
                  marker='o',
                  markersize=4,
                  color='b',
                  alpha=0.5)
    lines = l1 + l2 + l3
    labs = [l.get_label() for l in lines]
    ax.grid(True)
    ax.legend(lines, labs)
    img_qz = vs.plot_to_image(fig)
    # reconstruction
    img = px[10].mean().numpy()
    if img.shape[-1] == 1:
        img = np.squeeze(img, axis=-1)
    fig = plt.figure(figsize=(8, 8), dpi=120)
    vs.plot_images(img, grids=(8, 8))
    img_reconstructed = vs.plot_to_image(fig)
    # latents traverse
    # TODO
    return dict(llk_test=llk_test,
                qz_mean=qz_mean,
                qz_std=qz_std,
                w_decoder=w,
                reconstructed=img_reconstructed,
                latents=img_qz)
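
The test log-likelihood above is just the per-example log_prob of the decoder distribution, averaged over mini-batches. Below is a self-contained TensorFlow Probability sketch of that pattern; the toy Gaussian "decoder" is a stand-in for illustration, not the odin VAE API.

import tensorflow as tf
from tensorflow_probability import distributions as tfd

x = tf.random.normal((1000, 8))  # stand-in test data
llk = []
for x_i in tf.data.Dataset.from_tensor_slices(x).batch(64):
    # pretend the decoder returned a unit-variance Gaussian centred at x_i
    px = tfd.Independent(tfd.Normal(loc=x_i, scale=1.),
                         reinterpreted_batch_ndims=1)
    llk.append(px.log_prob(x_i))  # one log-likelihood per example
llk_test = tf.reduce_mean(tf.concat(llk, axis=0))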
Example #4
def evaluate(vae,
             ds,
             expdir: str,
             title: str,
             batch_size: int = 32,
             seed: int = 1):
    from odin.bay.vi import Correlation
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    tanh = ds.name.lower() == 'celeba'
    ## data for semi-supervised training
    # careful: don't allow any data leakage!
    train = ds.create_dataset('train',
                              batch_size=batch_size,
                              label_percent=True,
                              shuffle=False,
                              normalize='tanh' if tanh else 'probs')
    data = [(vae.encode(x, training=False), y)
            for x, y in tqdm(train, desc=title)]
    x_semi_train = tf.concat(
        [tf.concat([i.mean(), _ymean(j)], axis=1) for (i, j), _ in data],
        axis=0).numpy()
    y_semi_train = tf.concat([i for _, i in data], axis=0).numpy()
    # shuffle
    ids = rand.permutation(x_semi_train.shape[0])
    x_semi_train = x_semi_train[ids]
    y_semi_train = y_semi_train[ids]
    ## data for testing
    test = ds.create_dataset('test',
                             batch_size=batch_size,
                             label_percent=True,
                             shuffle=False,
                             normalize='tanh' if tanh else 'probs')
    prog = tqdm(test, desc=title)
    llk_x = []
    llk_y = []
    z = []
    y_true = []
    y_pred = []
    x_true = []
    x_pred = []
    x_org, x_rec = [], []
    for x, y in prog:
        px, (qz, qy) = vae(x, training=False)
        y_true.append(y)
        y_pred.append(_ymean(qy))
        z.append(qz.mean())
        llk_x.append(px.log_prob(x))
        llk_y.append(qy.log_prob(y))
        if rand.uniform() < 0.005 or len(x_org) < 2:
            x_org.append(x)
            x_rec.append(px.mean())
    ## llk
    llk_x = tf.reduce_mean(tf.concat(llk_x, axis=0)).numpy()
    llk_y = tf.reduce_mean(tf.concat(llk_y, axis=0)).numpy()
    ## the latents
    z = tf.concat(z, axis=0).numpy()
    y_true = tf.concat(y_true, axis=0).numpy()
    y_pred = tf.concat(y_pred, axis=0).numpy()
    x_semi_test = tf.concat([z, y_pred], axis=-1).numpy()
    # shuffle
    ids = rand.permutation(z.shape[0])
    z = z[ids]
    y_true = y_true[ids]
    y_pred = y_pred[ids]
    x_semi_test = x_semi_test[ids]
    ## saving reconstruction images
    x_org = tf.concat(x_org, axis=0).numpy()
    x_rec = tf.concat(x_rec, axis=0).numpy()
    ids = rand.permutation(x_org.shape[0])
    x_org = x_org[ids][:36]
    x_rec = x_rec[ids][:36]
    vmin = x_rec.reshape((36, -1)).min(axis=1).reshape((36, 1, 1, 1))
    vmax = x_rec.reshape((36, -1)).max(axis=1).reshape((36, 1, 1, 1))
    if tanh:
        x_org = (x_org + 1.) / 2.
    x_rec = (x_rec - vmin) / (vmax - vmin)
    if x_org.shape[-1] == 1:  # grayscale image
        x_org = np.squeeze(x_org, -1)
        x_rec = np.squeeze(x_rec, -1)
    else:  # color image
        x_org = np.transpose(x_org, (0, 3, 1, 2))
        x_rec = np.transpose(x_rec, (0, 3, 1, 2))
    plt.figure(figsize=(15, 8))
    ax = plt.subplot(1, 2, 1)
    vs.plot_images(x_org, grids=(6, 6), ax=ax, title='Original')
    ax = plt.subplot(1, 2, 2)
    vs.plot_images(x_rec, grids=(6, 6), ax=ax, title='Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    if ds.name in ('mnist', 'fashionmnist', 'celeba'):
        true = np.argmax(y_true, axis=-1)
        pred = np.argmax(y_pred, axis=-1)
        y_semi_train = np.argmax(y_semi_train, axis=-1)
        y_semi_test = true
        labels_name = ds.labels
    else:  # shapes3d dsprites
        true = y_true[:, 2].astype(np.int32)
        pred = y_pred[:, 2].astype(np.int32)
        y_semi_train = y_semi_train[:, 2].astype(np.int32)
        y_semi_test = true
        if ds.name == 'shapes3d':
            labels_name = ['cube', 'cylinder', 'sphere', 'round']
        elif ds.name == 'dsprites':
            labels_name = ['square', 'ellipse', 'heart']
    plt.figure(figsize=(8, 8))
    vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                             labels=labels_name,
                             cbar=True,
                             fontsize=10,
                             title=title)
    labels = np.array([labels_name[i] for i in true])
    labels_pred = np.array([labels_name[i] for i in pred])
    ## save arrays for later inspection
    np.savez_compressed(f'{expdir}/arrays',
                        x_train=x_semi_train,
                        y_train=y_semi_train,
                        x_test=x_semi_test,
                        y_test=y_semi_test,
                        zdim=z.shape[1],
                        labels=labels_name)
    print(f'Export arrays to "{expdir}/arrays.npz"')
    ## semi-supervised
    with open(f'{expdir}/results.txt', 'w') as f:
        print(f'Export results to "{expdir}/results.txt"')
        f.write(f'Steps: {vae.step.numpy()}\n')
        f.write(f'llk_x: {llk_x}\n')
        f.write(f'llk_y: {llk_y}\n')
        for p in [0.004, 0.06, 0.2, 0.99]:
            x_train, x_test, y_train, y_test = train_test_split(
                x_semi_train,
                y_semi_train,
                train_size=int(np.round(p * x_semi_train.shape[0])),
                random_state=1,
            )
            m = LogisticRegression(max_iter=3000, random_state=1)
            m.fit(x_train, y_train)
            # write the report
            f.write(f'{m.__class__.__name__} Number of labels: '
                    f'{p} {x_train.shape[0]}/{x_test.shape[0]}')
            f.write('\nValidation:\n')
            f.write(
                classification_report(y_true=y_test, y_pred=m.predict(x_test)))
            f.write('\nTest:\n')
            f.write(
                classification_report(y_true=y_semi_test,
                                      y_pred=m.predict(x_semi_test)))
            f.write('------------\n')
    ## scatter plot
    n_points = 4000
    # tsne plot
    tsne = DimReduce.TSNE(z[:n_points], n_components=2)
    kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.6)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels[:n_points], title=f'[True-tSNE]{title}', **kw)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels_pred[:n_points],
                    title=f'[Pred-tSNE]{title}',
                    **kw)
    # pca plot
    pca = DimReduce.PCA(z, n_components=2)
    kw = dict(x=pca[:, 0], y=pca[:, 1], grid=False, size=12.0, alpha=0.6)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels, title=f'[True-PCA]{title}', **kw)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels_pred, title=f'[Pred-PCA]{title}', **kw)
    ## factors plot
    corr = (Correlation.Spearman(z, y_true) +
            Correlation.Pearson(z, y_true)) / 2.
    best_z = np.argsort(np.abs(corr), axis=0)[-2:]
    style = dict(size=15.0, alpha=0.6, grid=False)
    for fi, (z1, z2) in enumerate(best_z.T):
        plt.figure(figsize=(8, 4))
        ax = plt.subplot(1, 2, 1)
        vs.plot_scatter(x=z[:n_points, z1],
                        y=z[:n_points, z2],
                        val=y_true[:n_points, fi],
                        ax=ax,
                        title=ds.labels[fi],
                        **style)
        ax = plt.subplot(1, 2, 2)
        vs.plot_scatter(x=z[:n_points, z1],
                        y=z[:n_points, z2],
                        val=y_pred[:n_points, fi],
                        ax=ax,
                        title=ds.labels[fi],
                        **style)
        plt.tight_layout()
    ## save all plots
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)
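
The semi-supervised section above fits a logistic regression on increasingly large labelled fractions of the latent codes. Here is a self-contained sketch of that probe on synthetic stand-in data (the real code uses x_semi_train / y_semi_train); the stratify argument is added here so the tiniest split still contains both classes.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(1)
z = rng.randn(2000, 16)             # stand-in latent codes
y = (z[:, 0] > 0).astype(int)       # stand-in binary labels
for p in [0.004, 0.06, 0.2, 0.99]:  # same label fractions as above
    n_labelled = int(np.round(p * z.shape[0]))
    # stratify keeps both classes present even in the tiny labelled split
    x_tr, x_te, y_tr, y_te = train_test_split(
        z, y, train_size=n_labelled, random_state=1, stratify=y)
    clf = LogisticRegression(max_iter=3000, random_state=1).fit(x_tr, y_tr)
    print(f'p={p}: {n_labelled}/{x_te.shape[0]} labelled/held-out')
    print(classification_report(y_true=y_te, y_pred=clf.predict(x_te)))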
Example #5
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
    n_rows = int(np.sqrt(n_images))
    is_semi = vae.is_semi_supervised()
    is_hierarchical = vae.is_hierarchical()
    ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
    ## prepare
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    ## data for semi-supervised training
    train = ds.create_dataset('train', **ds_kw)
    (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
     y_pred_train, z_train, pz_train) = _call(vae,
                                              ds=train,
                                              rand=rand,
                                              take_count=take_count,
                                              n_images=n_images,
                                              verbose=True)
    ## data for testing
    test = ds.create_dataset('test', **ds_kw)
    (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
     z_test, pz_test) = _call(vae,
                              ds=test,
                              rand=rand,
                              take_count=take_count,
                              n_images=n_images,
                              verbose=True)
    # === 0. plotting latent-factor pairs
    for idx, z in enumerate(z_test):
        z = z.mean()
        f = y_true_test
        corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
        plot_latents_pairs(z, f, corr_mat, ds.labels)
        vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
    # === 0. latent traverse plot
    x_travs = x_org_test
    if x_travs.ndim == 3:  # grayscale image
        x_travs = np.expand_dims(x_travs, -1)
    else:  # color image
        x_travs = np.transpose(x_travs, (0, 2, 3, 1))
    x_travs = x_travs[rand.permutation(x_travs.shape[0])]
    n_visual_samples = 5
    n_traverse_points = 21
    n_top_latents = 10
    plt.figure(figsize=(8, 3 * n_visual_samples))
    for i in range(n_visual_samples):
        images = vae.sample_traverse(x_travs[i:i + 1],
                                     min_val=-np.min(z_test[0].mean()),
                                     max_val=np.max(z_test[0].mean()),
                                     n_best_latents=n_top_latents,
                                     n_traverse_points=n_traverse_points,
                                     mode='linear')
        images = as_tuple(images)[0]
        images = _prepare_images(images.mean().numpy(), normalize=True)
        vs.plot_images(images,
                       grids=(n_top_latents, n_traverse_points),
                       ax=(n_visual_samples, 1, i + 1))
        if i == 0:
            plt.title('Latents traverse')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
    # === 0. prior sampling plot
    images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    plt.figure(figsize=(5, 5))
    vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
    # === 1. reconstruction plot
    plt.figure(figsize=(15, 15))
    vs.plot_images(x_org_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 1),
                   title='[Train]Original')
    vs.plot_images(x_rec_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 2),
                   title='[Train]Reconstructed')
    vs.plot_images(x_org_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 3),
                   title='[Test]Original')
    vs.plot_images(x_rec_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 4),
                   title='[Test]Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    label_type = ds.label_type
    if label_type == 'categorical':
        labels_name = ds.labels
        true = np.argmax(y_true_test, axis=-1)
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
            labels_pred = np.array([labels_name[i] for i in pred])
    elif label_type == 'factor':  # dsprites, shapes3d
        labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
          if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
        true = y_true_test[:, 2].astype('int32')
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = get_ymean(y_pred_test)[:, 2].astype('int32')
            labels_pred = np.array([labels_name[i] for i in pred])
    else:  # CelebA
        raise NotImplementedError
    ## confusion matrix
    if is_semi:
        plt.figure(figsize=(8, 8))
        acc = accuracy_score(y_true=true, y_pred=pred)
        vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                                 labels=labels_name,
                                 cbar=True,
                                 fontsize=10,
                                 title=f'{title} Acc:{acc:.2f}')
    ## save arrays for later inspection
    with open(f'{expdir}/arrays', 'wb') as f:
        pickle.dump(
            dict(z_train=z_train,
                 y_pred_train=y_pred_train,
                 y_true_train=y_true_train,
                 z_test=z_test,
                 y_pred_test=y_pred_test,
                 y_true_test=y_true_test,
                 labels=labels_name,
                 ds=ds.name,
                 label_type=label_type), f)
    print(f'Exported arrays to "{expdir}/arrays"')
    ## semi-supervised
    z_mean_train = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
    z_mean_test = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
    # === 2. scatter points latents plot
    n_points = 5000
    ids = rand.permutation(len(labels_true))[:n_points]
    Y_true = labels_true[ids]
    Y_pred = labels_pred[ids]
    # tsne plot
    n_latents = 0 if len(z_train) == 1 else len(z_train)
    names = ['all'] + [f'latents{i}' for i in range(n_latents)]
    arrays = ([z_mean_test[ids]] +
              [z_test[i].mean().numpy()[ids] for i in range(n_latents)])
    for name, X in zip(names, arrays):
        print(f'Plot scatter points for {name}')
        X = X.reshape(X.shape[0], -1)  # flatten to 2D
        X = Pipeline([('zscore', StandardScaler()),
                      ('pca', PCA(min(X.shape[1], 512),
                                  random_state=seed))]).fit_transform(X)
        tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
        kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
        plt.figure(figsize=(12, 6))
        vs.plot_scatter(color=Y_true,
                        title=f'[True]{title}-{name}',
                        ax=(1, 2, 1),
                        **kw)
        vs.plot_scatter(color=Y_pred,
                        title=f'[Pred]{title}-{name}',
                        ax=(1, 2, 2),
                        **kw)
    ## save all plots
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)

    # === 3. show the latents statistics
    n_latents = len(z_train)
    colors = sns.color_palette(n_colors=len(labels_true))
    styles = dict(grid=False,
                  ticks_off=False,
                  alpha=0.6,
                  xlabel='mean',
                  ylabel='stddev')

    # scatter between latents and labels (assume categorical distribution)
    def _show_latents_labels(Z, Y, title):
        plt.figure(figsize=(5 * n_latents, 5), dpi=150)
        for idx, z in enumerate(Z):
            if len(z.batch_shape) == 0:
                mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
                stddev = z.sample(Y.shape[0]) - mean
            else:
                mean = flatten(z.mean())
                stddev = flatten(z.stddev())
            y = np.argmax(Y, axis=-1)
            data = [[], [], []]
            for y_i, c in zip(np.unique(y), colors):
                mask = (y == y_i)
                data[0].append(np.mean(mean[mask], 0))
                data[1].append(np.mean(stddev[mask], 0))
                data[2].append([labels_true[y_i]] * mean.shape[1])
            vs.plot_scatter(
                x=np.concatenate(data[0], 0),
                y=np.concatenate(data[1], 0),
                color=np.concatenate(data[2], 0),
                ax=(1, n_latents, idx + 1),
                size=15 if mean.shape[1] < 128 else 8,
                title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
                **styles)
        plt.tight_layout()

    # simple mean-stddev scatter for each latent
    def _show_latents(Z, title):
        plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
        for idx, z in enumerate(Z):
            mean = flatten(z.mean())
            stddev = flatten(z.stddev())
            if mean.ndim == 2:
                mean = np.mean(mean, 0)
                stddev = np.mean(stddev, 0)
            vs.plot_scatter(
                x=mean,
                y=stddev,
                ax=(1, n_latents, idx + 1),
                size=15 if len(mean) < 128 else 8,
                title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
                **styles)

    _show_latents_labels(z_test, y_true_test, 'post')
    _show_latents_labels(pz_test, y_true_test, 'prior')
    _show_latents(z_test, 'post')
    _show_latents(pz_test, 'prior')

    # KL statistics
    vs.plot_figure()
    for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
        kl = []
        qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
        pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
        for s, e in minibatch(batch_size=8, n=100):
            z = qz.sample(e - s)
            # don't do this on GPU, it explodes!
            kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
        kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
        # per sample
        kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
        kl_samples = logsumexp(kl_samples, 0)
        plt.subplot(n_latents, 2, idx * 2 + 1)
        sns.histplot(kl_samples, bins=50)
        plt.title(f'Z#{idx} KL per sample (nats)')
        # per latent
        kl_latents = np.mean(flatten(logsumexp(kl, 0)), 0)
        plt.subplot(n_latents, 2, idx * 2 + 2)
        plt.plot(np.sort(kl_latents))
        plt.title(f'Z#{idx} KL per dim (nats)')
    plt.tight_layout()

    vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)
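
The KL statistics above are Monte Carlo estimates built from the per-sample log-density ratio log q(z) - log p(z) under posterior samples; averaging that ratio gives the standard Monte Carlo KL estimate (the example aggregates it slightly differently, with logsumexp). A small self-contained check of the estimator against the analytic Gaussian KL, with arbitrary stand-in parameters:

import numpy as np
from tensorflow_probability import distributions as tfd

qz = tfd.Normal(loc=[0.5, -0.2], scale=[0.8, 1.3])  # stand-in posterior
pz = tfd.Normal(loc=0., scale=1.)                    # standard Normal prior
z = qz.sample(10000, seed=1)                         # Monte Carlo samples
kl_mc = np.mean((qz.log_prob(z) - pz.log_prob(z)).numpy(), axis=0)
kl_exact = tfd.kl_divergence(qz, pz).numpy()
print(kl_mc)     # per-dimension Monte Carlo estimate
print(kl_exact)  # should match to within sampling noise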
Example #6
def callback():
    # closes over vae, decoder, x_test, y_test and fast_umap from the enclosing scope
    trainer = get_current_trainer()
    x, y = x_test[:1000], y_test[:1000]
    px, qz = vae(x, training=False)
    # latents
    qz_mean = tf.reduce_mean(qz.mean(), axis=0)
    qz_std = tf.reduce_mean(qz.stddev(), axis=0)
    w = tf.reduce_sum(decoder.trainable_variables[0], axis=(0, 1, 2))
    # plot the latents and their weights
    fig = plt.figure(figsize=(6, 4), dpi=200)
    ax = plt.gca()
    l1 = ax.plot(qz_mean,
                 label='mean',
                 linewidth=1.0,
                 linestyle='--',
                 marker='o',
                 markersize=4,
                 color='r',
                 alpha=0.5)
    l2 = ax.plot(qz_std,
                 label='std',
                 linewidth=1.0,
                 linestyle='--',
                 marker='o',
                 markersize=4,
                 color='g',
                 alpha=0.5)
    ax1 = ax.twinx()
    l3 = ax1.plot(w,
                  label='weight',
                  linewidth=1.0,
                  linestyle='--',
                  marker='o',
                  markersize=4,
                  color='b',
                  alpha=0.5)
    lines = l1 + l2 + l3
    labs = [l.get_label() for l in lines]
    ax.grid(True)
    ax.legend(lines, labs)
    img_qz = vs.plot_to_image(fig)
    # reconstruction
    fig = plt.figure(figsize=(5, 5), dpi=120)
    vs.plot_images(np.squeeze(px.mean().numpy()[:25], axis=-1),
                   grids=(5, 5))
    img_res = vs.plot_to_image(fig)
    # latent UMAP scatter
    fig = plt.figure(figsize=(5, 5), dpi=200)
    z = fast_umap(qz.mean().numpy())
    vs.plot_scatter(z, color=y, size=12.0, alpha=0.4)
    img_umap = vs.plot_to_image(fig)
    # gradients
    grads = [(k, v) for k, v in trainer.last_train_metrics.items()
             if '_grad/' in k]
    encoder_grad = sum(v for k, v in grads if 'Encoder' in k)
    decoder_grad = sum(v for k, v in grads if 'Decoder' in k)
    return dict(reconstruct=img_res,
                umap=img_umap,
                latents=img_qz,
                qz_mean=qz_mean,
                qz_std=qz_std,
                w_decoder=w,
                llk_test=tf.reduce_mean(px.log_prob(x)),
                encoder_grad=encoder_grad,
                decoder_grad=decoder_grad)
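
The dual-axis plot in this callback (latent mean/std on the left axis, decoder weight norms on the right, one merged legend) is plain Matplotlib. A standalone sketch with random stand-in arrays:

import matplotlib.pyplot as plt
import numpy as np

qz_mean, qz_std = np.random.randn(32), np.random.rand(32)  # stand-in latent stats
w = np.random.rand(32) * 10.                                # stand-in weight norms

fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
l1 = ax.plot(qz_mean, 'r--o', markersize=4, alpha=0.5, label='mean')
l2 = ax.plot(qz_std, 'g--o', markersize=4, alpha=0.5, label='std')
ax1 = ax.twinx()      # second y-axis sharing the same x-axis
l3 = ax1.plot(w, 'b--o', markersize=4, alpha=0.5, label='weight')
lines = l1 + l2 + l3  # each plot() call returns a list of Line2D objects
ax.legend(lines, [l.get_label() for l in lines])
ax.grid(True)
plt.show()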