def plot_epoch(task):
    """Plot latent space, reconstruction and count-sum statistics of the test set.

    Meant to be called once per epoch: when ``task`` is given, plotting only
    happens for epochs ``< 5`` and then every fifth epoch; ``task=None``
    plots unconditionally and labels the figure as epoch 0.

    Reads the module-level names ``X_test``, ``y_test``, ``f_z``, ``f_w``,
    ``Z_names``, ``n_classes``, ``V`` and ``FIGURE_PATH``.

    Parameters
    ----------
    task : object or None
        Any object with a ``curr_epoch`` attribute, or ``None``.

    The figure is saved to ``FIGURE_PATH/latent_<epoch>.png``.  Unlike the
    previous version, the function returns normally afterwards instead of
    calling ``exit()`` (which terminated the whole process after the first
    plotted epoch).
    """
    if task is None:
        curr_epoch = 0
    else:
        curr_epoch = task.curr_epoch
        # plot every epoch for the first five, then every fifth epoch only
        if not (curr_epoch < 5 or curr_epoch % 5 == 0):
            return
    # fixed seed: every epoch's figure uses the same PCA seeds and samples
    rand = np.random.RandomState(seed=1234)
    # upper bound for the PCA random seeds (10e8 == 1e9), passed as an int
    max_seed = int(10e8)

    X, y = X_test, y_test
    n_data = X.shape[0]
    Z = f_z(X)
    W, W_stdev_mcmc, W_stdev_analytic = f_w(X)

    # presumably one PCA shared by X and W (cf. the "separated PCA" title
    # used for W_pca_2 below) -- TODO confirm against fast_pca
    X_pca, W_pca_1 = fast_pca(X,
                              W,
                              n_components=2,
                              random_state=rand.randint(max_seed))
    W_pca_2 = fast_pca(W, n_components=2, random_state=rand.randint(max_seed))
    # per-sample totals: sum over every non-sample axis
    X_count_sum = np.sum(X, axis=tuple(range(1, X.ndim)))
    W_count_sum = np.sum(W, axis=-1)

    # linear / sorted-linear / sorted-log views used by every series plot
    series_config = [
        dict(xscale='linear', yscale='linear', sort_by=None),
        dict(xscale='linear', yscale='linear', sort_by='expected'),
        dict(xscale='log', yscale='log', sort_by='expected'),
    ]

    n_visual_samples = 8
    # 13 rows of summary plots + 3 rows per visualized sample
    nrow = 13 + n_visual_samples * 3
    V.plot_figure(nrow=int(nrow * 1.8), ncol=18)
    with V.plot_gridSpec(nrow=nrow + 3, ncol=6, hspace=0.8) as grid:
        # plot the latent space
        for i, (z, name) in enumerate(zip(Z, Z_names)):
            if z.shape[1] > 2:
                z = fast_pca(z,
                             n_components=2,
                             random_state=rand.randint(max_seed))
            ax = V.subplot(grid[:3, (i * 2):(i * 2 + 2)])
            V.plot_scatter(x=z[:, 0],
                           y=z[:, 1],
                           color=y,
                           marker=y,
                           n_samples=4000,
                           ax=ax,
                           legend_enable=False,
                           legend_ncol=n_classes)
            ax.set_title(name, fontsize=12)
        # plot the reconstruction
        for i, (x, name) in enumerate(
                zip([X_pca, W_pca_1, W_pca_2], [
                    'Original data', 'Reconstruction',
                    'Reconstruction (separated PCA)'
                ])):
            ax = V.subplot(grid[3:6, (i * 2):(i * 2 + 2)])
            V.plot_scatter(x=x[:, 0],
                           y=x[:, 1],
                           color=y,
                           marker=y,
                           n_samples=4000,
                           ax=ax,
                           legend_enable=i == 1,
                           legend_ncol=n_classes,
                           title=name)
        # plot the reconstruction count sum
        for i, (x, count_sum, name) in enumerate(
                zip([X_pca, W_pca_1], [X_count_sum, W_count_sum], [
                    'Original data (Count-sum)', 'Reconstruction (Count-sum)'
                ])):
            ax = V.subplot(grid[6:10, (i * 3):(i * 3 + 3)])
            V.plot_scatter(x=x[:, 0],
                           y=x[:, 1],
                           val=count_sum,
                           n_samples=2000,
                           marker=y,
                           ax=ax,
                           size=8,
                           legend_enable=i == 0,
                           legend_ncol=n_classes,
                           title=name,
                           colorbar=True,
                           fontsize=10)
        # plot the count-sum series (summed over all test samples)
        count_sum_observed = np.sum(X, axis=0).ravel()
        count_sum_expected = np.sum(W, axis=0)
        count_sum_stdev_explained = np.sum(W_stdev_mcmc, axis=0)
        count_sum_stdev_total = np.sum(W_stdev_analytic, axis=0)
        for i, kws in enumerate(series_config):
            ax = V.subplot(grid[10:10 + 3, (i * 2):(i * 2 + 2)])
            V.plot_series_statistics(count_sum_observed,
                                     count_sum_expected,
                                     explained_stdev=count_sum_stdev_explained,
                                     total_stdev=count_sum_stdev_total,
                                     fontsize=8,
                                     title="Count-sum" if i == 0 else None,
                                     **kws)
        # plot the mean and variances of a few random test samples
        curr_grid_index = 13
        ids = rand.permutation(n_data)[:n_visual_samples]
        for i in ids:
            observed = X[i].ravel()
            expected = W[i]
            stdev_explained = W_stdev_mcmc[i]
            stdev_total = W_stdev_analytic[i]
            for j, kws in enumerate(series_config):
                ax = V.subplot(grid[curr_grid_index:curr_grid_index + 3,
                                    (j * 2):(j * 2 + 2)])
                V.plot_series_statistics(
                    observed,
                    expected,
                    explained_stdev=stdev_explained,
                    total_stdev=stdev_total,
                    fontsize=8,
                    title=("Test Sample #%d" % i) if j == 0 else None,
                    **kws)
            curr_grid_index += 3
    V.plot_save(os.path.join(FIGURE_PATH, 'latent_%d.png' % curr_epoch),
                dpi=200,
                log=True)
Esempio n. 2
0
def plot_monitoring_epoch(X, X_drop, y, Z, Z_drop, W_outputs, W_drop_outputs,
                          pi, pi_drop, row_name, dropout_percentage,
                          curr_epoch, ds_name, labels, save_dir):
    """Save monitoring figures comparing clean and dropped-out test data.

    Files written to ``save_dir`` (``<N>`` is ``curr_epoch``):

    * ``latent_epoch<N>.png`` -- PCA scatters of the latent space and, when
      ``W_outputs`` is given, of the original data and the reconstruction,
      both without and with dropout.
    * ``countsum_epoch<N>.png`` -- count-sum scatters and count-sum series
      statistics (only when ``W_outputs`` is given).
    * ``samples_epoch<N>.png`` -- series statistics of 8 random test samples
      (only when ``W_outputs`` has exactly 3 entries).
    * ``image_epoch<N>.png`` -- 28x28 image grid of originals and
      reconstructions (only when ``'mnist'`` is in ``ds_name`` and
      ``W_outputs`` is given).

    Parameters
    ----------
    X, X_drop : array
        Test data without / with dropout, one row per sample.
    y : array
        Class indices, or a 2-D matrix reduced with ``argmax``; mapped to
        display names through ``labels``.
    Z, Z_drop : array
        Latent representations of ``X`` / ``X_drop``.
    W_outputs, W_drop_outputs : list of arrays or None
        Reconstruction outputs ordered ``[W, W_stdev_total,
        W_stdev_explained]``.  Only the first entry is required; the stdev
        entries enable the series-statistics plots.  ``None`` plots only the
        latent space.
    pi, pi_drop : array or None
        Zero-inflation probabilities, or ``None``.
    row_name : sequence
        Name of every (non-downsampled) test row, used in sample titles.
    dropout_percentage : float
        Dropout rate; shown in titles as ``dropout_percentage * 100`` percent.
    curr_epoch : int
        Epoch number embedded in the output file names.
    ds_name : str
        Dataset name (enables the MNIST image grid).
    labels : sequence
        Display name for every class index.
    save_dir : str
        Output directory.
    """
    from matplotlib import pyplot as plt
    if y.ndim == 2:
        y = np.argmax(y, axis=-1)
    y = np.array([labels[i] for i in y])
    dropout_percentage_text = '%g%%' % (dropout_percentage * 100)
    has_reconstruction = W_outputs is not None

    Z_pca = fast_pca(Z, n_components=2, random_state=5218)
    Z_pca_drop = fast_pca(Z_drop, n_components=2, random_state=5218)
    if has_reconstruction:
        X_pca, X_pca_drop, W_pca, W_pca_drop = fast_pca(X,
                                                        X_drop,
                                                        W_outputs[0],
                                                        W_drop_outputs[0],
                                                        n_components=2,
                                                        random_state=5218)
    # ====== downsampling ====== #
    rand = np.random.RandomState(seed=5218)
    n_test_samples = len(y)
    ids = np.arange(n_test_samples, dtype='int32')
    if n_test_samples > 8000:
        ids = rand.choice(ids, size=8000, replace=False)
    # original row index of every kept sample; row_name is NOT downsampled,
    # so it must be looked up through this mapping
    sample_ids = ids
    # ====== scatter configuration ====== #
    config = dict(size=6, labels=None)
    y = y[ids]
    X = X[ids]
    X_drop = X_drop[ids]
    Z_pca = Z_pca[ids]
    Z_pca_drop = Z_pca_drop[ids]
    # the reconstruction arrays only exist when W_outputs was provided
    if has_reconstruction:
        X_pca = X_pca[ids]
        X_pca_drop = X_pca_drop[ids]
        W_pca = W_pca[ids]
        W_pca_drop = W_pca_drop[ids]
        W_outputs = [w[ids] for w in W_outputs]
        W_drop_outputs = [w[ids] for w in W_drop_outputs]
    if pi is not None:
        pi = pi[ids]
        pi_drop = pi_drop[ids]
    # linear and sorted-linear views used by every series plot
    series_config = [
        dict(xscale='linear', yscale='linear', sort_by=None),
        dict(xscale='linear', yscale='linear', sort_by='expected')
    ]
    # ====== plotting NO reconstruction ====== #
    if not has_reconstruction:
        plot_figure(nrow=8, ncol=20)
        fast_scatter(x=Z_pca,
                     y=y,
                     title="[PCA] Test data latent space",
                     enable_legend=True,
                     ax=(1, 2, 1),
                     **config)
        fast_scatter(x=Z_pca_drop,
                     y=y,
                     title="[PCA][Dropped:%s] Test data latent space" %
                     dropout_percentage_text,
                     ax=(1, 2, 2),
                     **config)
    # ====== plotting WITH reconstruction ====== #
    else:
        plot_figure(nrow=16, ncol=20)
        # original test data WITHOUT dropout
        fast_scatter(x=X_pca,
                     y=y,
                     title="[PCA][Test Data] Original",
                     ax=(2, 3, 1),
                     **config)
        fast_scatter(x=W_pca,
                     y=y,
                     title="Reconstructed",
                     ax=(2, 3, 2),
                     **config)
        fast_scatter(x=Z_pca,
                     y=y,
                     title="Latent space",
                     ax=(2, 3, 3),
                     **config)
        # original test data WITH dropout
        fast_scatter(x=X_pca_drop,
                     y=y,
                     title="[PCA][Dropped:%s][Test Data] Original" %
                     dropout_percentage_text,
                     ax=(2, 3, 4),
                     **config)
        fast_scatter(x=W_pca_drop,
                     y=y,
                     title="Reconstructed",
                     ax=(2, 3, 5),
                     enable_legend=True,
                     **config)
        fast_scatter(x=Z_pca_drop,
                     y=y,
                     title="Latent space",
                     ax=(2, 3, 6),
                     **config)
    # format the file name before joining so a '%' in save_dir is harmless
    plot_save(os.path.join(save_dir, 'latent_epoch%d.png' % curr_epoch),
              dpi=180,
              clear_all=True,
              log=True)
    # ====== plot count-sum ====== #
    if has_reconstruction:
        X_countsum = _clip_count_sum(np.sum(X, axis=-1))
        W_countsum = _clip_count_sum(np.sum(W_outputs[0], axis=-1))
        X_drop_countsum = _clip_count_sum(np.sum(X_drop, axis=-1))
        W_drop_countsum = _clip_count_sum(np.sum(W_drop_outputs[0], axis=-1))

        if pi is not None:
            # average zero-inflation probability per sample
            pi_sum = np.mean(pi, axis=-1)
            pi_drop_sum = np.mean(pi_drop, axis=-1)
        # plot the reconstruction count sum
        plot_figure(nrow=3 * 5 + 8, ncol=18)
        # 2 (or 3 with pi) scatter rows of height 3, then 3 series rows of
        # height 3 each preceded by 1 padding row, plus 1 spare row
        with plot_gridSpec(nrow=3 * (2 if pi is None else 3) + 4 * 3 + 1,
                           ncol=6,
                           wspace=1.0,
                           hspace=0.8) as grid:
            kws = dict(colorbar=True,
                       fontsize=10,
                       size=10,
                       marker=y,
                       n_samples=1200)
            # without dropout
            ax = subplot(grid[:3, 0:3])
            plot_scatter(x=X_pca,
                         val=X_countsum,
                         ax=ax,
                         legend_enable=False,
                         title='Original data (Count-sum)',
                         **kws)
            ax = subplot(grid[:3, 3:6])
            plot_scatter(x=W_pca,
                         val=W_countsum,
                         ax=ax,
                         legend_enable=False,
                         title='Reconstruction (Count-sum)',
                         **kws)
            # with dropout
            ax = subplot(grid[3:6, 0:3])
            plot_scatter(x=X_pca_drop,
                         val=X_drop_countsum,
                         ax=ax,
                         legend_enable=pi is None,
                         legend_ncol=len(labels),
                         title='[Dropped:%s]Original data (Count-sum)' %
                         dropout_percentage_text,
                         **kws)
            ax = subplot(grid[3:6, 3:6])
            plot_scatter(x=W_pca_drop,
                         val=W_drop_countsum,
                         ax=ax,
                         legend_enable=False,
                         title='[Dropped:%s]Reconstruction (Count-sum)' %
                         dropout_percentage_text,
                         **kws)
            row_start = 6
            # zero-inflated pi
            if pi is not None:
                ax = subplot(grid[6:9, 0:3])
                plot_scatter(x=X_pca,
                             val=pi_sum,
                             ax=ax,
                             legend_enable=True,
                             legend_ncol=len(labels),
                             title='Zero-inflated probabilities',
                             **kws)
                ax = subplot(grid[6:9, 3:6])
                plot_scatter(x=X_pca,
                             val=pi_drop_sum,
                             ax=ax,
                             legend_enable=False,
                             title='[Dropped:%s]Zero-inflated probabilities' %
                             dropout_percentage_text,
                             **kws)
                row_start += 3

            # plot the count-sum series
            def plot_count_sum_series(x, w, p, row_start, tit):
                # Plot count-sum series of `x` vs. reconstruction `w` in one
                # grid row; skipped when `w` lacks the two stdev entries.
                if len(w) != 3:  # no statistics provided
                    return
                expected, stdev_total, stdev_explained = w
                count_sum_observed = np.sum(x, axis=0)
                count_sum_expected = np.sum(expected, axis=0)
                count_sum_stdev_total = np.sum(stdev_total, axis=0)
                count_sum_stdev_explained = np.sum(stdev_explained, axis=0)
                if p is not None:
                    p_sum = np.mean(p, axis=0)
                for i, series_kws in enumerate(series_config):
                    ax = subplot(grid[row_start:row_start + 3,
                                      (i * 3):(i * 3 + 3)])
                    ax, handles, indices = plot_series_statistics(
                        count_sum_observed,
                        count_sum_expected,
                        explained_stdev=count_sum_stdev_explained,
                        total_stdev=count_sum_stdev_total,
                        fontsize=8,
                        ax=ax,
                        legend_enable=False,
                        title=tit if i == 0 else None,
                        despine=p is None,
                        return_handles=True,
                        return_indices=True,
                        **series_kws)
                    if p is not None:
                        _show_zero_inflated_pi(p_sum, ax, handles, indices)
                    plt.legend(handles=handles,
                               loc='best',
                               markerscale=4,
                               fontsize=8)

            # each series block is 3 rows high preceded by 1 padding row,
            # i.e. they start at row_start + 1, + 5 and + 9
            plot_count_sum_series(x=X,
                                  w=W_outputs,
                                  p=pi,
                                  row_start=row_start + 1,
                                  tit="Count-sum X_original - W_original")
            plot_count_sum_series(
                x=X_drop,
                w=W_drop_outputs,
                p=pi_drop,
                row_start=row_start + 5,
                tit="[Dropped:%s]Count-sum X_drop - W_dropout" %
                dropout_percentage_text)
            plot_count_sum_series(
                x=X,
                w=W_drop_outputs,
                p=pi_drop,
                row_start=row_start + 9,
                tit="[Dropped:%s]Count-sum X_original - W_dropout" %
                dropout_percentage_text)
        plot_save(os.path.join(save_dir, 'countsum_epoch%d.png' % curr_epoch),
                  dpi=180,
                  clear_all=True,
                  log=True)
    # ====== plot series of samples ====== #
    if has_reconstruction and len(W_outputs) == 3:
        # NOTE: zero-inflation overlay is deliberately turned off here
        pi = None

        n_visual_samples = 8
        plot_figure(nrow=3 * n_visual_samples + 8, ncol=25)
        col_width = 5
        with plot_gridSpec(nrow=3 * n_visual_samples,
                           ncol=4 * col_width,
                           wspace=5.0,
                           hspace=1.0) as grid:
            curr_grid_index = 0
            for i in rand.permutation(len(X))[:n_visual_samples]:
                observed = X[i]
                # same order as documented: [W, W_stdev_total,
                # W_stdev_explained] (the previous version swapped the two
                # stdevs here, unlike the count-sum plots)
                expected, stdev_total, stdev_explained = [
                    w[i] for w in W_outputs
                ]
                expected_drop, stdev_total_drop, stdev_explained_drop = [
                    w[i] for w in W_drop_outputs
                ]
                # `i` indexes the downsampled arrays; map back for row_name
                sample_name = row_name[sample_ids[i]]
                if pi is not None:
                    p_zi = pi[i]
                    p_zi_drop = pi_drop[i]
                # compare to W_original
                for j, kws in enumerate(series_config):
                    ax = subplot(grid[curr_grid_index:curr_grid_index + 3,
                                      (j * col_width):(j * col_width +
                                                       col_width)])
                    ax, handles, indices = plot_series_statistics(
                        observed,
                        expected,
                        explained_stdev=stdev_explained,
                        total_stdev=stdev_total,
                        fontsize=8,
                        legend_enable=False,
                        despine=pi is None,
                        title=("'%s' X_original - W_original" %
                               sample_name) if j == 0 else None,
                        return_handles=True,
                        return_indices=True,
                        **kws)
                    if pi is not None:
                        _show_zero_inflated_pi(p_zi, ax, handles, indices)
                    plt.legend(handles=handles,
                               loc='best',
                               markerscale=4,
                               fontsize=8)
                # compare to W_dropout (right half of the grid)
                col_start = col_width * 2
                for j, kws in enumerate(series_config):
                    ax = subplot(
                        grid[curr_grid_index:curr_grid_index + 3,
                             (col_start +
                              j * col_width):(col_start + j * col_width +
                                              col_width)])
                    ax, handles, indices = plot_series_statistics(
                        observed,
                        expected_drop,
                        explained_stdev=stdev_explained_drop,
                        total_stdev=stdev_total_drop,
                        fontsize=8,
                        legend_enable=False,
                        despine=pi is None,
                        title=("[Dropped:%s]'%s' X_original - W_dropout" %
                               (dropout_percentage_text, sample_name))
                        if j == 0 else None,
                        return_handles=True,
                        return_indices=True,
                        **kws)
                    if pi is not None:
                        _show_zero_inflated_pi(p_zi_drop, ax, handles, indices)
                    plt.legend(handles=handles,
                               loc='best',
                               markerscale=4,
                               fontsize=8)
                curr_grid_index += 3
        plot_save(os.path.join(save_dir, 'samples_epoch%d.png' % curr_epoch),
                  dpi=180,
                  clear_all=True,
                  log=True)
    # ====== special case for mnist ====== #
    if 'mnist' in ds_name and has_reconstruction:
        plot_figure(nrow=3, ncol=18)
        # sampling without replacement needs at least n_images rows
        n_images = min(32, X.shape[0])
        ids = rand.choice(np.arange(X.shape[0], dtype='int32'),
                          size=n_images,
                          replace=False)
        meta_data = [("Org", X[ids]), ("Rec", W_outputs[0][ids]),
                     ("OrgDropout", X_drop[ids]),
                     ("RecDropout", W_drop_outputs[0][ids])]
        count = 1
        for name, data in meta_data:
            for i in range(n_images):
                x = data[i].reshape(28, 28)
                plt.subplot(4, n_images, count)
                show_image(x)
                if i == 0:
                    plt.ylabel(name, fontsize=8)
                count += 1
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        plot_save(os.path.join(save_dir, 'image_epoch%d.png' % curr_epoch),
                  dpi=180,
                  clear_all=True,
                  log=True)