Ejemplo n.º 1
0
 def latent_units(model: VariationalModel, valid_ds: tf.data.Dataset):
     weights = model.weights
     Qz = []
     Pz = []
     for x, y in valid_ds.take(10):
         _call(model, x, y, decode=True)
         qz, pz = model.get_latents(return_prior=True)
         Qz.append(as_tuple(qz))
         Pz.append(as_tuple(pz))
     n_latents = len(Qz[0])
     for i in range(n_latents):
         qz: Sequence[Distribution] = [q[i] for q in Qz]
         pz: Sequence[Distribution] = [p[i] for p in Pz]
         # tracking kl
         kld = []
         for q, p in zip(qz, pz):
             q = Normal(loc=q.mean(), scale=q.stddev())
             z = q.sample()
             if isinstance(p, Vamprior):
                 C = p.C
                 p = p.distribution  # [n_components, zdim]
                 p = Normal(loc=p.mean(), scale=p.stddev())
                 kld.append(
                     q.log_prob(z) - (tf.reduce_logsumexp(
                         p.log_prob(tf.expand_dims(z, 1)), 1) - C))
             else:
                 p = Normal(loc=p.mean(), scale=p.stddev())
                 kld.append(q.log_prob(z) - p.log_prob(z))
         kld = tf.reshape(tf.reduce_mean(tf.concat(kld, 0), axis=0),
                          -1).numpy()
         # mean and stddev
         mean = tf.reduce_mean(tf.concat([d.mean() for d in qz], axis=0), 0)
         stddev = tf.reduce_mean(
             tf.concat([d.stddev() for d in qz], axis=0), 0)
         mean = tf.reshape(mean, -1).numpy()
         stddev = tf.reshape(stddev, -1).numpy()
         zdim = mean.shape[0]
         # the figure
         plt.figure(figsize=(12, 5), dpi=50)
         lines = []
         ids = np.argsort(stddev)
         styles = dict(marker='o', markersize=2, linewidth=0, alpha=0.8)
         lines += plt.plot(mean[ids], label='mean', color='r', **styles)
         lines += plt.plot(stddev[ids], label='stddev', color='b', **styles)
         plt.grid(False)
         # show weights if exists
         plt.twinx()
         lines += plt.plot(kld[ids],
                           label='KL(q|p)',
                           linestyle='--',
                           linewidth=1.0,
                           alpha=0.6)
         for w in weights:
             name = w.name
             if w.shape.rank > 0 and w.shape[
                     0] == zdim and '/kernel' in name:
                 w = tf.linalg.norm(tf.reshape(w, (w.shape[0], -1)),
                                    axis=1).numpy()
                 lines += plt.plot(w[ids],
                                   label=name.split(':')[0],
                                   linestyle='--',
                                   alpha=0.6)
         plt.grid(False)
         plt.legend(lines, [ln.get_label() for ln in lines], fontsize=6)
         # save summary
         tf.summary.image(f'z{i}',
                          vs.plot_to_image(plt.gcf()),
                          step=model.step)
         tf.summary.histogram(f'z{i}/mean', mean, step=model.step)
         tf.summary.histogram(f'z{i}/stddev', stddev, step=model.step)
         tf.summary.histogram(f'z{i}/kld', kld, step=model.step)
Ejemplo n.º 2
0
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
    n_rows = int(np.sqrt(n_images))
    is_semi = vae.is_semi_supervised()
    is_hierarchical = vae.is_hierarchical()
    ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
    ## prepare
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    ## data for training semi-supervised
    train = ds.create_dataset('train', **ds_kw)
    (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
     y_pred_train, z_train, pz_train) = _call(vae,
                                              ds=train,
                                              rand=rand,
                                              take_count=take_count,
                                              n_images=n_images,
                                              verbose=True)
    ## data for testing
    test = ds.create_dataset('test', **ds_kw)
    (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
     z_test, pz_test) = _call(vae,
                              ds=test,
                              rand=rand,
                              take_count=take_count,
                              n_images=n_images,
                              verbose=True)
    # === 0. plotting latent-factor pairs
    for idx, z in enumerate(z_test):
        z = z.mean()
        f = y_true_test
        corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
        plot_latents_pairs(z, f, corr_mat, ds.labels)
        vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
    # === 0. latent traverse plot
    x_travs = x_org_test
    if x_travs.ndim == 3:  # grayscale image
        x_travs = np.expand_dims(x_travs, -1)
    else:  # color image
        x_travs = np.transpose(x_travs, (0, 2, 3, 1))
    x_travs = x_travs[rand.permutation(x_travs.shape[0])]
    n_visual_samples = 5
    n_traverse_points = 21
    n_top_latents = 10
    plt.figure(figsize=(8, 3 * n_visual_samples))
    for i in range(n_visual_samples):
        images = vae.sample_traverse(x_travs[i:i + 1],
                                     min_val=-np.min(z_test[0].mean()),
                                     max_val=np.max(z_test[0].mean()),
                                     n_best_latents=n_top_latents,
                                     n_traverse_points=n_traverse_points,
                                     mode='linear')
        images = as_tuple(images)[0]
        images = _prepare_images(images.mean().numpy(), normalize=True)
        vs.plot_images(images,
                       grids=(n_top_latents, n_traverse_points),
                       ax=(n_visual_samples, 1, i + 1))
        if i == 0:
            plt.title('Latents traverse')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
    # === 0. prior sampling plot
    images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    plt.figure(figsize=(5, 5))
    vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
    # === 1. reconstruction plot
    plt.figure(figsize=(15, 15))
    vs.plot_images(x_org_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 1),
                   title='[Train]Original')
    vs.plot_images(x_rec_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 2),
                   title='[Train]Reconstructed')
    vs.plot_images(x_org_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 3),
                   title='[Test]Original')
    vs.plot_images(x_rec_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 4),
                   title='[Test]Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    label_type = ds.label_type
    if label_type == 'categorical':
        labels_name = ds.labels
        true = np.argmax(y_true_test, axis=-1)
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
            labels_pred = np.array([labels_name[i] for i in pred])
    elif label_type == 'factor':  # dsprites, shapes3d
        labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
          if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
        true = y_true_test[:, 2].astype('int32')
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = get_ymean(y_pred_test)[:, 2].astype('int32')
            labels_pred = np.array([labels_name[i] for i in pred])
    else:  # CelebA
        raise NotImplementedError
    ## confusion matrix
    if is_semi:
        plt.figure(figsize=(8, 8))
        acc = accuracy_score(y_true=true, y_pred=pred)
        vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                                 labels=labels_name,
                                 cbar=True,
                                 fontsize=10,
                                 title=f'{title} Acc:{acc:.2f}')
    ## save arrays for later inspections
    with open(f'{expdir}/arrays', 'wb') as f:
        pickle.dump(
            dict(z_train=z_train,
                 y_pred_train=y_pred_train,
                 y_true_train=y_true_train,
                 z_test=z_test,
                 y_pred_test=y_pred_test,
                 y_true_test=y_true_test,
                 labels=labels_name,
                 ds=ds.name,
                 label_type=label_type), f)
    print(f'Exported arrays to "{expdir}/arrays"')
    ## semi-supervised
    z_mean_train = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
    z_mean_test = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
    # === 2. scatter points latents plot
    n_points = 5000
    ids = rand.permutation(len(labels_true))[:n_points]
    Y_true = labels_true[ids]
    Y_pred = labels_pred[ids]
    # tsne plot
    n_latents = 0 if len(z_train) == 1 else len(z_train)
    for name, X in zip(
        ['all'] + [f'latents{i}'
                   for i in range(n_latents)], [z_mean_test[ids]] +
        [z_test[i].mean().numpy()[ids] for i in range(n_latents)]):
        print(f'Plot scatter points for {name}')
        X = X.reshape(X.shape[0], -1)  # flatten to 2D
        X = Pipeline([('zscore', StandardScaler()),
                      ('pca', PCA(min(X.shape[1], 512),
                                  random_state=seed))]).fit_transform(X)
        tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
        kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
        plt.figure(figsize=(12, 6))
        vs.plot_scatter(color=Y_true,
                        title=f'[True]{title}-{name}',
                        ax=(1, 2, 1),
                        **kw)
        vs.plot_scatter(color=Y_pred,
                        title=f'[Pred]{title}-{name}',
                        ax=(1, 2, 2),
                        **kw)
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)

    # === 3. show the latents statistics
    n_latents = len(z_train)
    colors = sns.color_palette(n_colors=len(labels_true))
    styles = dict(grid=False,
                  ticks_off=False,
                  alpha=0.6,
                  xlabel='mean',
                  ylabel='stddev')

    # scatter between latents and labels (assume categorical distribution)
    def _show_latents_labels(Z, Y, title):
        plt.figure(figsize=(5 * n_latents, 5), dpi=150)
        for idx, z in enumerate(Z):
            if len(z.batch_shape) == 0:
                mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
                stddev = z.sample(Y.shape[0]) - mean
            else:
                mean = flatten(z.mean())
                stddev = flatten(z.stddev())
            y = np.argmax(Y, axis=-1)
            data = [[], [], []]
            for y_i, c in zip(np.unique(y), colors):
                mask = (y == y_i)
                data[0].append(np.mean(mean[mask], 0))
                data[1].append(np.mean(stddev[mask], 0))
                data[2].append([labels_true[y_i]] * mean.shape[1])
            vs.plot_scatter(
                x=np.concatenate(data[0], 0),
                y=np.concatenate(data[1], 0),
                color=np.concatenate(data[2], 0),
                ax=(1, n_latents, idx + 1),
                size=15 if mean.shape[1] < 128 else 8,
                title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
                **styles)
        plt.tight_layout()

    # simple scatter mean-stddev each latents
    def _show_latents(Z, title):
        plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
        for idx, z in enumerate(Z):
            mean = flatten(z.mean())
            stddev = flatten(z.stddev())
            if mean.ndim == 2:
                mean = np.mean(mean, 0)
                stddev = np.mean(stddev, 0)
            vs.plot_scatter(
                x=mean,
                y=stddev,
                ax=(1, n_latents, idx + 1),
                size=15 if len(mean) < 128 else 8,
                title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
                **styles)

    _show_latents_labels(z_test, y_true_test, 'post')
    _show_latents_labels(pz_test, y_true_test, 'prior')
    _show_latents(z_test, 'post')
    _show_latents(pz_test, 'prior')

    # KL statistics
    vs.plot_figure()
    for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
        kl = []
        qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
        pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
        for s, e in minibatch(batch_size=8, n=100):
            z = qz.sample(e - s)
            # don't do this in GPU, it explodes!
            kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
        kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
        # per sample
        kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
        kl_samples = logsumexp(kl_samples, 0)
        plt.subplot(n_latents, 2, idx * 2 + 1)
        sns.histplot(kl_samples, bins=50)
        plt.title(f'Z#{idx} KL per sample (nats)')
        # per latent
        kl_latents = np.mean(flatten(logsumexp(kl, 0)), 0)
        plt.subplot(n_latents, 2, idx * 2 + 2)
        plt.plot(np.sort(kl_latents))
        plt.title(f'Z#{idx} KL per dim (nats)')
    plt.tight_layout()

    vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)