Example #1
 def _show_latents_labels(Z, Y, title):
     plt.figure(figsize=(5 * n_latents, 5), dpi=150)
     for idx, z in enumerate(Z):
         if len(z.batch_shape) == 0:
             # scalar batch: tile the mean and use sampled deviations in place of stddev
             mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
             stddev = z.sample(Y.shape[0]) - mean
         else:
             mean = flatten(z.mean())
             stddev = flatten(z.stddev())
         y = np.argmax(Y, axis=-1)
         data = [[], [], []]
         for y_i, c in zip(np.unique(y), colors):
             mask = (y == y_i)
             data[0].append(np.mean(mean[mask], 0))
             data[1].append(np.mean(stddev[mask], 0))
             data[2].append([labels_true[y_i]] * mean.shape[1])
         vs.plot_scatter(
             x=np.concatenate(data[0], 0),
             y=np.concatenate(data[1], 0),
             color=np.concatenate(data[2], 0),
             ax=(1, n_latents, idx + 1),
             size=15 if mean.shape[1] < 128 else 8,
             title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
             **styles)
     plt.tight_layout()
Example #2
def fast_scatter(x,
                 y,
                 labels,
                 title,
                 azim=None,
                 elev=None,
                 ax=None,
                 enable_legend=False,
                 size=18,
                 fontsize=12):
    y = np.squeeze(y)
    if y.ndim == 1:
        pass
    elif y.ndim == 2:  # provided one-hot vectors
        y = np.argmax(y, axis=-1)
    else:
        raise ValueError("No support for `y` shape: %s" % str(y.shape))
    # ====== get colors and legends ====== #
    if labels is not None:
        y = [labels[int(i)] for i in y]
        num_classes = len(labels)
    else:
        num_classes = len(np.unique(y))
    # int(np.ceil(num_classes / 2)) if num_classes <= 20 else num_classes // 5
    plot_scatter(x=x,
                 color=y,
                 marker=y,
                 size=size,
                 azim=azim,
                 elev=elev,
                 legend_enable=enable_legend,
                 legend_ncol=3,
                 fontsize=fontsize,
                 ax=ax,
                 title=title)
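A minimal usage sketch for `fast_scatter` above, on synthetic data (it assumes this example's `plot_scatter` backend is importable):

import numpy as np

x = np.random.randn(200, 2)            # 2-D coordinates
y = np.random.randint(0, 3, size=200)  # integer class ids
fast_scatter(x, y, labels=['a', 'b', 'c'], title='demo', enable_legend=True)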
Example #3
 def plot_uncertainty_statistics(self, factors=None):
   r"""
    factors : list of Integer or String. The index or name of the factors
      taken into account in the analysis.
   """
   factors = self._check_factors(factors)
   zmean = np.concatenate(self.representations_mean, axis=0)
   zstd = np.sqrt(np.concatenate(self.representations_variance, axis=0))
   labels = self.factors_name[factors]
   factors = np.concatenate(self.original_factors, axis=0)[:, factors]
   X = np.arange(zmean.shape[0])
   # create the figure
   nrow = self.n_representations
   ncol = len(labels)
   fig = vs.plot_figure(nrow=nrow * 4, ncol=ncol * 4, dpi=80)
   plot = 1
   for row, (code, mean,
             std) in enumerate(zip(self.codes_name, zmean.T, zstd.T)):
     # prepare the code
     ids = np.argsort(mean)
     mean, std = mean[ids], std[ids]
     # show the factors
     for col, (name, y) in enumerate(zip(labels, factors.T)):
       axes = []
       # the variance
       ax = vs.plot_subplot(nrow, ncol, plot)
       ax.plot(mean, color='g', linestyle='--')
       ax.fill_between(X, mean - 2 * std, mean + 2 * std, alpha=0.2, color='b')
       if col == 0:
         ax.set_ylabel(code)
       if row == 0:
         ax.set_title(name)
       axes.append(ax)
       # factor
       y = y[ids]
       ax = ax.twinx()
       vs.plot_scatter(x=X,
                       y=y,
                       val=y,
                       size=12,
                       color='bwr',
                       alpha=0.5,
                       grid=False)
       axes.append(ax)
       # update plot index
       for ax in axes:
         ax.tick_params(axis='both',
                        which='both',
                        top=False,
                        bottom=False,
                        left=False,
                        right=False,
                        labeltop=False,
                        labelleft=False,
                        labelright=False,
                        labelbottom=False)
       plot += 1
   fig.tight_layout()
   self.add_figure("uncertainty_stats", fig)
   return self
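The ±2σ uncertainty band at the heart of Example #3 is plain matplotlib; a standalone sketch with synthetic means and stddevs:

import numpy as np
from matplotlib import pyplot as plt

mean = np.sort(np.random.randn(100))      # sorted latent means
std = 0.3 * np.abs(np.random.randn(100))  # matching stddevs
X = np.arange(len(mean))
ax = plt.subplot(1, 1, 1)
ax.plot(X, mean, color='g', linestyle='--')
ax.fill_between(X, mean - 2 * std, mean + 2 * std, alpha=0.2, color='b')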
Example #4
def plot(train, score, title, applying_pca=False):
    if applying_pca:
        pca = PCA(n_components=NUM_DIM)
        pca.fit(train)
        train = pca.transform(train)
        score = pca.transform(score)
    plot_figure(nrow=6, ncol=12)
    plot_scatter(x=train[:, 0],
                 y=train[:, 1],
                 z=None if NUM_DIM < 3 or train.shape[1] < 3 else train[:, 2],
                 size=POINT_SIZE,
                 color=y_train_color,
                 marker=y_train_marker,
                 fontsize=12,
                 legend=legends,
                 title='[train]' + str(title),
                 ax=(1, 2, 1))
    plot_scatter(x=score[:, 0],
                 y=score[:, 1],
                 z=None if NUM_DIM < 3 or score.shape[1] < 3 else score[:, 2],
                 size=POINT_SIZE,
                 color=y_score_color,
                 marker=y_score_marker,
                 fontsize=12,
                 legend=legends,
                 title='[score]' + str(title),
                 ax=(1, 2, 2))
Example #5
 def plot_divergence(self,
                     X=OMIC.transcriptomic,
                     omic=OMIC.proteomic,
                     algo='tsne',
                     n_pairs=18,
                     ncol=6):
     r""" Select the most diverged pair within given `omic`, use `X` as
 coordinate and the pair's value as intensity for plotting the scatter
 heatmap. """
     om1 = OMIC.parse(X)
     om2 = OMIC.parse(omic)
     ## prepare the coordinate
     X = self.dimension_reduce(om1, n_components=2, algo=algo)
     n_points = X.shape[0]
     ## prepare the value
     y = self.numpy(om2)
     varnames = self.get_var_names(om2)
     ## check correlation type
     corr_fn = lambda m, n: (spearmanr(m, n, nan_policy='omit').correlation
                             + pearsonr(m, n)[0]) / 2
     ## create the correlation matrix
     corr_ids = []
     corr = []
     for i in range(y.shape[1]):
         for j in range(i + 1, y.shape[1]):
             corr_ids.append((i, j))
             corr.append(corr_fn(y[:, i], y[:, j]))
     ## sorting and select the smallest correlated pairs
     sort_ids = np.argsort(corr)[:int(n_pairs)]
     corr = np.array(corr)[sort_ids]
     corr_ids = np.array(corr_ids)[sort_ids]
     ## plotting
      nrow = int(np.ceil(n_pairs / ncol))
     fig = plt.figure(figsize=(ncol * 3, nrow * 3))
     for idx, ((i, j), c) in enumerate(zip(corr_ids, corr)):
         name1 = varnames[i]
         name2 = varnames[j]
         y1 = y[:, i]
         y1 = (y1 - np.min(y1)) / (np.max(y1) - np.min(y1))
         y2 = y[:, j]
         y2 = (y2 - np.min(y2)) / (np.max(y2) - np.min(y2))
         val = y1 - y2
         vs.plot_scatter(X,
                         color='bwr',
                         size=20 if n_points < 1000 else
                         (100000 / n_points),
                         val=val,
                         alpha=0.6,
                         cbar=True,
                         cbar_ticks=[name2, 'Others', name1],
                         cbar_horizontal=True,
                         fontsize=8,
                         ax=(nrow, ncol, idx + 1))
     ## adjust and save
     self.add_figure("divergence_%s_%s_%s" % (om1.name, om2.name, algo),
                     fig)
     return self
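The blended correlation used by `plot_divergence` averages the Spearman and Pearson coefficients; isolated on synthetic vectors it reads:

import numpy as np
from scipy.stats import pearsonr, spearmanr

m, n = np.random.randn(100), np.random.randn(100)
corr = (spearmanr(m, n, nan_policy='omit').correlation + pearsonr(m, n)[0]) / 2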
Example #6
def evaluate_latent(fn, feeder, title):
    y_true = []
    Z = []
    for outputs in Progbar(feeder.set_batch(batch_mode='file'),
                           name=title,
                           print_report=True,
                           print_summary=False,
                           count_func=lambda x: x[-1].shape[0]):
        name = str(outputs[0])
        idx = int(outputs[1])
        data = outputs[2:]
        assert idx == 0
        y_true.append(name)
        Z.append(fn(*data))
    Z = np.concatenate(Z, axis=0)
    # ====== visualize spectrogram ====== #
    if Z.ndim >= 3:
        sample = np.random.choice(range(len(Z)), size=3, replace=False)
        spec = Z[sample.astype('int32')]
        y = [y_true[int(i)] for i in sample]
        plot_figure(nrow=6, ncol=6)
        for i, (s, tit) in enumerate(zip(spec, y)):
            s = s.reshape(len(s), -1)
            plot_spectrogram(s.T, ax=(1, 3, i + 1), title=tit)
    # ====== visualize each point ====== #
    # flatten to 2D
    Z = np.reshape(Z, newshape=(len(Z), -1))
    # tsne if necessary
    if Z.shape[-1] > 3:
        Z = fast_tsne(Z,
                      n_components=3,
                      n_jobs=8,
                      random_state=K.get_rng().randint(0, 10e8))
    # color and marker
    Z_color = [digit_color_map[i.split('_')[-1]] for i in y_true]
    Z_marker = [gender_marker_map[i.split('_')[1]] for i in y_true]
    plot_figure(nrow=6, ncol=20)
    for i, azim in enumerate((15, 60, 120)):
        plot_scatter(x=Z[:, 0],
                     y=Z[:, 1],
                     z=Z[:, 2],
                     ax=(1, 3, i + 1),
                     size=4,
                     color=Z_color,
                     marker=Z_marker,
                     azim=azim,
                     legend=legends if i == 1 else None,
                     legend_ncol=11,
                     fontsize=10,
                     title=title)
    plot_save(os.path.join(FIG_PATH, '%s.pdf' % title))
Example #7
 def _show_latents(Z, title):
     plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
     for idx, z in enumerate(Z):
         mean = flatten(z.mean())
         stddev = flatten(z.stddev())
         if mean.ndim == 2:
             mean = np.mean(mean, 0)
             stddev = np.mean(stddev, 0)
         vs.plot_scatter(
             x=mean,
             y=stddev,
             ax=(1, n_latents, idx + 1),
             size=15 if len(mean) < 128 else 8,
             title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
             **styles)
Example #8
def plot(train, score, title, applying_pca=False):
  if applying_pca:
    pca = PCA(n_components=NUM_DIM)
    pca.fit(train)
    train = pca.transform(train)
    score = pca.transform(score)
  plot_figure(nrow=6, ncol=12)
  plot_scatter(x=train[:, 0], y=train[:, 1],
               z=None if NUM_DIM < 3 or train.shape[1] < 3 else train[:, 2],
               size=POINT_SIZE, color=y_train_color, marker=y_train_marker,
               fontsize=12, legend=legends,
               title='[train]' + str(title),
               ax=(1, 2, 1))
  plot_scatter(x=score[:, 0], y=score[:, 1],
               z=None if NUM_DIM < 3 or score.shape[1] < 3 else score[:, 2],
               size=POINT_SIZE, color=y_score_color, marker=y_score_marker,
               fontsize=12, legend=legends,
               title='[score]' + str(title),
               ax=(1, 2, 2))
Example #9
def plot_latent(data, idx, title):
  mean, samples = data
  # only the mean
  ax = vs.subplot(2, N_MODELS, idx)
  vs.plot_scatter(fn_dim_reduction(mean),
                  ax=ax,
                  title='[Only Mean]' + title,
                  legend_enable=False,
                  **fig_config)
  # mean and sample (single t-SNE)
  ax = vs.subplot(2, N_MODELS, idx + N_MODELS)
  z = np.concatenate([np.expand_dims(mean, axis=0), samples], axis=0)
  z = np.reshape(fn_dim_reduction(z.reshape(-1, z.shape[-1])),
                 z.shape[:-1] + (2,))
  for i in z[1:]:
    vs.plot_scatter(i, ax=ax, legend_enable=False, **fig_config1)
  vs.plot_scatter(z[0],
                  ax=ax,
                  title='[Both Mean and Samples]' + title,
                  legend_enable=True,
                  **fig_config)
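The reshape trick above embeds the mean and all samples in a single reduction so they share one 2-D space; a runnable sketch with synthetic arrays (PCA stands in for the example's `fn_dim_reduction`):

import numpy as np
from sklearn.decomposition import PCA

# stand-in for the example's `fn_dim_reduction` (any (N, D) -> (N, 2) reducer)
fn_dim_reduction = PCA(n_components=2).fit_transform
mean = np.random.randn(500, 16)                     # (N, D) latent means
samples = np.random.randn(8, 500, 16)               # (S, N, D) latent samples
z = np.concatenate([mean[np.newaxis], samples], 0)  # (S + 1, N, D)
z2d = fn_dim_reduction(z.reshape(-1, z.shape[-1]))  # one joint embedding
z2d = z2d.reshape(z.shape[:-1] + (2,))              # back to (S + 1, N, 2)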
Example #10
  def plot_uncertainty_scatter(self, factors=None, n_samples=2, algo='tsne'):
    r""" Plotting the scatter points of the mean and sampled latent codes,
    colored by the factors.

    Arguments:
      factors : list of Integer or String. The index or name of the factors
        taken into account in the analysis.
    """
    factors = self._check_factors(factors)
    # this includes train and test data separately
    z_mean = np.concatenate(self.representations_mean)
    z_var = np.concatenate(
        [np.mean(var, axis=1) for var in self.representations_variance])
    z_samples = [
        z for z in np.concatenate(self.representations_sample(int(n_samples)),
                                  axis=1)
    ]
    F = np.concatenate(self.original_factors, axis=0)[:, factors]
    labels = self.factors_name[factors]
    # preprocessing
    inputs = tuple([z_mean] + z_samples)
    Z = dimension_reduce(*inputs,
                         algo=algo,
                         n_components=2,
                         return_model=False,
                         combined=True,
                         random_state=self.randint)
    V = utils.discretizing(z_var[:, np.newaxis], n_bins=10).ravel()
    # the figure
    nrow = 3
    ncol = int(np.ceil(len(labels) / nrow))
    fig = vs.plot_figure(nrow=nrow * 4, ncol=ncol * 4, dpi=80)
    for idx, (name, y) in enumerate(zip(labels, F.T)):
      ax = vs.plot_subplot(nrow, ncol, idx + 1)
      for i, x in enumerate(Z):
        kw = dict(val=y,
                  color="coolwarm",
                  ax=ax,
                  x=x,
                  grid=False,
                  legend_enable=False,
                  centroids=True,
                  fontsize=12)
        if i == 0:  # the mean value
          vs.plot_scatter(size=V,
                          size_range=(8, 80),
                          alpha=0.3,
                          linewidths=0,
                          cbar=True,
                          cbar_horizontal=True,
                          title=name,
                          **kw)
        else:  # the samples
          vs.plot_scatter_text(size=8,
                               marker='x',
                               alpha=0.8,
                               weight='light',
                               **kw)
    # fig.tight_layout()
    self.add_figure("uncertainty_scatter_%s" % algo, fig)
    return self
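Sizing the mean points by their variance is the one non-obvious step; assuming `utils.discretizing` bins like `np.digitize` and `size_range=(8, 80)` is a linear mapping, a standalone sketch:

import numpy as np

z_var = np.random.rand(1000)                       # per-point variance
edges = np.linspace(z_var.min(), z_var.max(), 10)  # 10 bins
V = np.digitize(z_var, edges)                      # bin index per point
sizes = 8 + (80 - 8) * (V - V.min()) / (V.max() - V.min())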
Example #11
    def plot_scatter(self,
                     X=OMIC.transcriptomic,
                     color_by=OMIC.proteomic,
                     marker_by=None,
                     clustering='kmeans',
                     legend=True,
                     dimension_reduction='tsne',
                     max_scatter_points=5000,
                     ax=None,
                     fig=None,
                     title='',
                     return_figure=False):
        r""" Scatter plot of dimension using binarized protein labels

    Arguments:
      X : instance of OMIC.
        which OMIC data used for coordinates
      color_by : instance of OMIC.
        which OMIC data will be used for coloring the points
      marker_by : instance of OMIC.
        which OMIC data will be used for selecting the marker type
        (e.g. dot, square, triangle ...)
      clustering : {'kmeans', 'knn', 'pca', 'tsne', 'umap', 'louvain'}.
        Clustering algorithm, in case algorithm in ('pca', 'tsne', 'umap'),
        perform dimension reduction before clustering.
        Note: clustering is only applied in case of continuous data.
      dimension_reduction : {'tsne', 'umap', 'pca', None}.
        Dimension reduction algorithm. If None, just take the first two
        dimensions
    """
        ax = vs.to_axis2D(ax, fig=fig)
        omic = OMIC.parse(X)
        omic_name = omic.name
        max_scatter_points = int(max_scatter_points)
        ## prepare data
        X = self.dimension_reduce(omic,
                                  n_components=2,
                                  algo=dimension_reduction)
        color_name, colors = _process_omics(self,
                                            color_by,
                                            clustering=clustering,
                                            allow_none=True)
        marker_name, markers = _process_omics(self,
                                              marker_by,
                                              clustering=clustering,
                                              allow_none=True)
        ## downsampling
        if max_scatter_points > 0:
            ids = np.random.permutation(X.shape[0])[:max_scatter_points]
            X = X[ids]
            if colors is not None:
                colors = colors[ids]
            if markers is not None:
                markers = markers[ids]
        n_points = X.shape[0]
        ## plotting
        kw = dict(color='b')
        if colors is not None:
            if is_categorical_dtype(colors):  # categorical values
                kw['color'] = colors
            else:  # integral values
                kw['val'] = colors
                kw['color'] = 'bwr'
        name = '_'.join(str(i) for i in [omic_name, color_name, marker_name])
        title = f"[{dimension_reduction}-{name}]{title}"
        vs.plot_scatter(X,
                        marker='.' if markers is None else markers,
                        size=88 if n_points < 1000 else (120000 / n_points),
                        alpha=0.8,
                        legend_enable=bool(legend),
                        grid=False,
                        ax=ax,
                        title=title,
                        **kw)
        fig = ax.get_figure()
        if return_figure:
            return fig
        self.add_figure(f"scatter_{name}_{str(dimension_reduction).lower()}",
                        fig)
        return self
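The downsampling step above generalizes: draw one random permutation and slice every per-point attribute with the same indices so colors and markers stay aligned. A sketch:

import numpy as np

X = np.random.randn(20000, 2)
colors = np.random.randint(0, 4, size=20000)
ids = np.random.permutation(X.shape[0])[:5000]  # unbiased subset, no replacement
X, colors = X[ids], colors[ids]                 # same ids for every attribute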
Example #12
    def plot_imputation_scatter(self,
                                test=True,
                                pca=False,
                                color_by_library=True):
        start_time = time.time()
        n_system = len(self) + 2  # add the original and the corrupted
        data_type = 'test' if test else 'train'

        if n_system <= 5:
            nrow = 1
            ncol = n_system
        else:
            nrow = 2
            ncol = int(np.ceil(n_system / 2))

        X_org = self.posteriors[0].X_test_org if test else self.posteriors[
            0].X_train_org
        X_crr = self.posteriors[0].X_test if test else self.posteriors[
            0].X_train
        y = self.posteriors[0].y_test if test else self.posteriors[0].y_train
        labels = self.posteriors[0].labels
        is_binary_classes = self.posteriors[0].is_binary_classes
        allV = [X_org, X_crr] + [
            pos.V_test if test else pos.V_train for pos in self.posteriors
        ]
        assert X_org.shape == X_crr.shape and all(v.shape == X_org.shape
                                                  for v in allV)
        all_names = ["[%s]Original" % data_type,
                     "[%s]Corrupted" % data_type
                     ] + [i.short_id_lines for i in self.posteriors]

        # downsample to at most 5000 points for visualization
        if len(X_org) > 5000:
            np.random.seed(5218)
            ids = np.random.permutation(X_org.shape[0])[:5000]
            allV = [v[ids] for v in allV]
            y = y[ids]

        if is_binary_classes:
            y = np.argmax(y, axis=-1)
        else:
            y = ProbabilisticEmbedding().fit_transform(y)
            y = np.argmax(y, axis=-1)

        allV = [log_norm(v) for v in allV]  # log-normalize everything

        fig = plt.figure(figsize=(min(20, 5 * ncol) + 2, nrow * 5))
        for idx, (name, v) in enumerate(zip(all_names, allV)):
            ax = plt.subplot(nrow, ncol, idx + 1)
            n = np.sum(v, axis=-1)
            v = fast_pca(v, n_components=2) if pca else fast_tsne(
                v, n_components=2)
            with catch_warnings_ignore(Warning):
                if color_by_library:
                    plot_scatter(x=v,
                                 val=n,
                                 ax=ax,
                                 size=8,
                                 legend_enable=False,
                                 grid=False,
                                 title=name)
                else:
                    plot_scatter(x=v,
                                 color=[labels[i] for i in y],
                                 marker=[labels[i] for i in y],
                                 ax=ax,
                                 size=8,
                                 legend_enable=True if idx == 0 else False,
                                 grid=False,
                                 title=name)

        with catch_warnings_ignore(Warning):
            plt.tight_layout()
        self.add_figure(
            'imputation_scatter_%s_%s' %
            ('lib' if color_by_library else 'cell', data_type), fig)
        return self._log(
            'plot_imputation_scatter[%s] %s(s)' %
            (data_type, ctext(time.time() - start_time, 'lightyellow')))
Example #13
def evaluate(vae,
             ds,
             expdir: str,
             title: str,
             batch_size: int = 32,
             seed: int = 1):
    from odin.bay.vi import Correlation
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    tanh = ds.name.lower() == 'celeba'
    ## data for training semi-supervised
    # careful don't allow any data leakage!
    train = ds.create_dataset('train',
                              batch_size=batch_size,
                              label_percent=True,
                              shuffle=False,
                              normalize='tanh' if tanh else 'probs')
    data = [(vae.encode(x, training=False), y)
            for x, y in tqdm(train, desc=title)]
    x_semi_train = tf.concat(
        [tf.concat([i.mean(), _ymean(j)], axis=1) for (i, j), _ in data],
        axis=0).numpy()
    y_semi_train = tf.concat([i for _, i in data], axis=0).numpy()
    # shuffle
    ids = rand.permutation(x_semi_train.shape[0])
    x_semi_train = x_semi_train[ids]
    y_semi_train = y_semi_train[ids]
    ## data for testing
    test = ds.create_dataset('test',
                             batch_size=batch_size,
                             label_percent=True,
                             shuffle=False,
                             normalize='tanh' if tanh else 'probs')
    prog = tqdm(test, desc=title)
    llk_x = []
    llk_y = []
    z = []
    y_true = []
    y_pred = []
    x_true = []
    x_pred = []
    x_org, x_rec = [], []
    for x, y in prog:
        px, (qz, qy) = vae(x, training=False)
        y_true.append(y)
        y_pred.append(_ymean(qy))
        z.append(qz.mean())
        llk_x.append(px.log_prob(x))
        llk_y.append(qy.log_prob(y))
        if rand.uniform() < 0.005 or len(x_org) < 2:
            x_org.append(x)
            x_rec.append(px.mean())
    ## llk
    llk_x = tf.reduce_mean(tf.concat(llk_x, axis=0)).numpy()
    llk_y = tf.reduce_mean(tf.concat(llk_y, axis=0)).numpy()
    ## the latents
    z = tf.concat(z, axis=0).numpy()
    y_true = tf.concat(y_true, axis=0).numpy()
    y_pred = tf.concat(y_pred, axis=0).numpy()
    x_semi_test = tf.concat([z, y_pred], axis=-1).numpy()
    # shuffle
    ids = rand.permutation(z.shape[0])
    z = z[ids]
    y_true = y_true[ids]
    y_pred = y_pred[ids]
    x_semi_test = x_semi_test[ids]
    ## saving reconstruction images
    x_org = tf.concat(x_org, axis=0).numpy()
    x_rec = tf.concat(x_rec, axis=0).numpy()
    ids = rand.permutation(x_org.shape[0])
    x_org = x_org[ids][:36]
    x_rec = x_rec[ids][:36]
    vmin = x_rec.reshape((36, -1)).min(axis=1).reshape((36, 1, 1, 1))
    vmax = x_rec.reshape((36, -1)).max(axis=1).reshape((36, 1, 1, 1))
    if tanh:
        x_org = (x_org + 1.) / 2.
    x_rec = (x_rec - vmin) / (vmax - vmin)
    if x_org.shape[-1] == 1:  # grayscale image
        x_org = np.squeeze(x_org, -1)
        x_rec = np.squeeze(x_rec, -1)
    else:  # color image
        x_org = np.transpose(x_org, (0, 3, 1, 2))
        x_rec = np.transpose(x_rec, (0, 3, 1, 2))
    plt.figure(figsize=(15, 8))
    ax = plt.subplot(1, 2, 1)
    vs.plot_images(x_org, grids=(6, 6), ax=ax, title='Original')
    ax = plt.subplot(1, 2, 2)
    vs.plot_images(x_rec, grids=(6, 6), ax=ax, title='Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    if ds.name in ('mnist', 'fashionmnist', 'celeba'):
        true = np.argmax(y_true, axis=-1)
        pred = np.argmax(y_pred, axis=-1)
        y_semi_train = np.argmax(y_semi_train, axis=-1)
        y_semi_test = true
        labels_name = ds.labels
    else:  # shapes3d dsprites
        true = y_true[:, 2].astype(np.int32)
        pred = y_pred[:, 2].astype(np.int32)
        y_semi_train = y_semi_train[:, 2].astype(np.int32)
        y_semi_test = true
        if ds.name == 'shapes3d':
            labels_name = ['cube', 'cylinder', 'sphere', 'round']
        elif ds.name == 'dsprites':
            labels_name = ['square', 'ellipse', 'heart']
    plt.figure(figsize=(8, 8))
    vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                             labels=labels_name,
                             cbar=True,
                             fontsize=10,
                             title=title)
    labels = np.array([labels_name[i] for i in true])
    labels_pred = np.array([labels_name[i] for i in pred])
    ## save arrays for later inspectation
    np.savez_compressed(f'{expdir}/arrays',
                        x_train=x_semi_train,
                        y_train=y_semi_train,
                        x_test=x_semi_test,
                        y_test=y_semi_test,
                        zdim=z.shape[1],
                        labels=labels_name)
    print(f'Export arrays to "{expdir}/arrays.npz"')
    ## semi-supervised
    with open(f'{expdir}/results.txt', 'w') as f:
        print(f'Export results to "{expdir}/results.txt"')
        f.write(f'Steps: {vae.step.numpy()}\n')
        f.write(f'llk_x: {llk_x}\n')
        f.write(f'llk_y: {llk_y}\n')
        for p in [0.004, 0.06, 0.2, 0.99]:
            x_train, x_test, y_train, y_test = train_test_split(
                x_semi_train,
                y_semi_train,
                train_size=int(np.round(p * x_semi_train.shape[0])),
                random_state=1,
            )
            m = LogisticRegression(max_iter=3000, random_state=1)
            m.fit(x_train, y_train)
            # write the report
            f.write(f'{m.__class__.__name__} Number of labels: '
                    f'{p} {x_train.shape[0]}/{x_test.shape[0]}')
            f.write('\nValidation:\n')
            f.write(
                classification_report(y_true=y_test, y_pred=m.predict(x_test)))
            f.write('\nTest:\n')
            f.write(
                classification_report(y_true=y_semi_test,
                                      y_pred=m.predict(x_semi_test)))
            f.write('------------\n')
    ## scatter plot
    n_points = 4000
    # tsne plot
    tsne = DimReduce.TSNE(z[:n_points], n_components=2)
    kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.6)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels[:n_points], title=f'[True-tSNE]{title}', **kw)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels_pred[:n_points],
                    title=f'[Pred-tSNE]{title}',
                    **kw)
    # pca plot
    pca = DimReduce.PCA(z, n_components=2)
    kw = dict(x=pca[:, 0], y=pca[:, 1], grid=False, size=12.0, alpha=0.6)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels, title=f'[True-PCA]{title}', **kw)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels_pred, title=f'[Pred-PCA]{title}', **kw)
    ## factors plot
    corr = (Correlation.Spearman(z, y_true) +
            Correlation.Pearson(z, y_true)) / 2.
    best_z = np.argsort(np.abs(corr), axis=0)[-2:]
    style = dict(size=15.0, alpha=0.6, grid=False)
    for fi, (z1, z2) in enumerate(best_z.T):
        plt.figure(figsize=(8, 4))
        ax = plt.subplot(1, 2, 1)
        vs.plot_scatter(x=z[:n_points, z1],
                        y=z[:n_points, z2],
                        val=y_true[:n_points, fi],
                        ax=ax,
                        title=ds.labels[fi],
                        **style)
        ax = plt.subplot(1, 2, 2)
        vs.plot_scatter(x=z[:n_points, z1],
                        y=z[:n_points, z2],
                        val=y_pred[:n_points, fi],
                        ax=ax,
                        title=ds.labels[fi],
                        **style)
        plt.tight_layout()
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)
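The factor plots at the end pick, for each factor, the two latent dimensions with the largest absolute blended correlation; isolated on a synthetic matrix:

import numpy as np

corr = np.random.randn(32, 5)                   # [zdim, n_factors]
best_z = np.argsort(np.abs(corr), axis=0)[-2:]  # two strongest latents per factor
for fi, (z1, z2) in enumerate(best_z.T):
    print(f'factor {fi}: latent dims {z1} and {z2}')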
Example #14
0
def plot_latents_binary(Z,
                        y,
                        labels_name,
                        title=None,
                        elev=None,
                        azim=None,
                        algo='tsne',
                        ax=None,
                        show_legend=True,
                        size=12,
                        fontsize=12,
                        show_scores=True,
                        enable_argmax=True,
                        enable_separated=False):
    from matplotlib import pyplot as plt
    if title is None:
        title = ''
    title = '[%s]%s' % (algo, title)
    ax = to_axis(ax)
    # ====== Downsample if the data is huge ====== #
    Z, y = downsample_data(Z, y)
    # ====== checking inputs ====== #
    assert Z.ndim == 2, Z.shape
    assert Z.shape[0] == y.shape[0]
    num_classes = len(labels_name)
    # ====== preprocessing ====== #
    Z = dimension_reduction(Z, algo=algo)
    # ====== clustering metrics ====== #
    if show_scores:
        scores = clustering_scores(
            latent=Z,
            labels=np.argmax(y, axis=-1) if y.ndim == 2 else y,
            n_labels=num_classes)
        title += '\n'
        for k, v in sorted(scores.items(), key=lambda x: x[0]):
            title += '%s:%.2f ' % (k, v)
    # ====== plotting ====== #
    if enable_argmax:
        y_argmax = np.argmax(y, axis=-1) if y.ndim == 2 else y
        fast_scatter(x=Z,
                     y=y_argmax,
                     labels=labels_name,
                     ax=ax,
                     size=size,
                     title=title,
                     fontsize=fontsize,
                     enable_legend=bool(show_legend))
        ax.grid(False)
    # ====== plot each protein ====== #
    if enable_separated:
        colormap = 'Reds'  # bwr
        ncol = 5 if num_classes <= 20 else 9
        nrow = int(np.ceil(num_classes / ncol))
        fig = plot_figure(nrow=4 * nrow, ncol=20)
        for i, lab in enumerate(labels_name):
            val = K.log_norm(y[:, i], axis=0)
            plot_scatter(x=Z[:, 0],
                         y=Z[:, 1],
                         val=val / np.sum(val),
                         ax=(nrow, ncol, i + 1),
                         color=colormap,
                         size=size,
                         alpha=0.8,
                         fontsize=8,
                         grid=False,
                         title=lab)

        plt.grid(False)
        # big title
        plt.suptitle(title, fontsize=fontsize)
        # show the colorbar
        import matplotlib as mpl
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cmap = mpl.cm.get_cmap(name=colormap)
        norm = mpl.colors.Normalize(vmin=0., vmax=1.)
        cb1 = mpl.colorbar.ColorbarBase(cbar_ax,
                                        cmap=cmap,
                                        norm=norm,
                                        orientation='vertical')
        cb1.set_label('Protein markers level')
Example #15
import os

import numpy as np
import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

from odin import ml
from odin import visual as vs

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

tf.random.set_seed(8)
np.random.seed(8)

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

X_umap = ml.fast_umap(X_train, X_test)
X_tsne = ml.fast_tsne(X_train, X_test)
X_pca = ml.fast_pca(X_train, X_test, n_components=2)

styles = dict(size=12, alpha=0.6, centroids=True)

vs.plot_figure(6, 12)
vs.plot_scatter(x=X_pca[0], color=y_train, ax=(1, 2, 1), **styles)
vs.plot_scatter(x=X_pca[1], color=y_test, ax=(1, 2, 2), **styles)

vs.plot_figure(6, 12)
vs.plot_scatter(x=X_tsne[0], color=y_train, ax=(1, 2, 1), **styles)
vs.plot_scatter(x=X_tsne[1], color=y_test, ax=(1, 2, 2), **styles)

vs.plot_figure(6, 12)
vs.plot_scatter(x=X_umap[0], color=y_train, ax=(1, 2, 1), **styles)
vs.plot_scatter(x=X_umap[1], color=y_test, ax=(1, 2, 2), **styles)

vs.plot_save()
Example #16
    def plot_disentanglement_scatter(self,
                                     factor_omic='proteomic',
                                     pairs=PROTEIN_PAIR_NEGATIVE,
                                     corr_matrix=None,
                                     n_pairs=10,
                                     latents_per_pair=5,
                                     magnify=2):
        r""" Select the most differentiated pairs of `factor_omic`, then,
    select the most correlated latent to each factor within the pair,
    use the latents' coordination for plotting the scatter points, and,
    use the `factor_omic` values for coloring the heatmap.

    Arguments:
      factor_omic : `OMIC`, the OMIC used as groundtruth factors
      pairs : list of `(factor_name1, factor_name2)` (optional)
        This determines which pairs will be plotted.
      corr_matrix : correlation matrix between `factor_omic` and the latents;
        its shape must be `[n_factor_omic, n_latents]`.
        This determines which latent dimensions will be selected for plotting
        the pairs.
      magnify : a Scalar (default: 2)
        a constant magnifying the color divergence of the `factor_omic`;
        the higher the value, the more strongly small differences
        show up in the colors.

    Example:
    ```
    pairs = post.get_marker_pairs()
    corr = post.get_correlation_matrix('proteomic', 'latent')
    post.plot_disentanglement_scatter('proteomic',
                                      pairs=pairs,
                                      corr_matrix=corr,
                                      magnify=2)
    post.plot_disentanglement_scatter('iproteomic',
                                      pairs=pairs,
                                      corr_matrix=corr,
                                      magnify=2)
    ```
    """
        factor_omic = OMIC.parse(factor_omic)
        var_ids = self.dataset.get_var_indices(factor_omic)
        ### marker pairs
        if pairs is None:
            pairs = self.get_marker_pairs(omic1=factor_omic,
                                          omic2=None,
                                          most_correlated=False,
                                          remove_duplicated=True,
                                          n=int(n_pairs))
        else:
            pairs = [(name1, name2) for name1, name2 in pairs
                     if name1 in var_ids and name2 in var_ids]
            assert len(pairs) > 0
        ### correlation matrix
        if corr_matrix is None:
            corr_matrix = self.get_correlation_matrix(factor_omic, 'latent',
                                                      'average')
        shape = (self.dataset.get_dim(factor_omic),
                 self.dataset.get_dim(OMIC.latent))
        assert corr_matrix.shape == shape, \
          (f"Correlation matrix must has shape {shape} but given matrix "
           f"with shape {corr_matrix.shape}")
        ### getting all latents for each pair
        latents_per_pair = int(latents_per_pair)
        omic2latent = {}
        for name1, name2 in pairs:
            latents = []
            seen = set()
            # sort in descending order
            for i1, i2 in _iter_2list(
                    np.argsort(corr_matrix[var_ids[name1]])[::-1],
                    np.argsort(corr_matrix[var_ids[name2]])[::-1]):
                if i1 != i2 and i1 not in seen and i2 not in seen:
                    seen.add(i1)
                    seen.add(i2)
                    latents.append([i1, i2])
                if len(latents) == latents_per_pair:
                    break
            omic2latent[(name1, name2)] = latents
        ### plotting
        ncol = 5
        X = self.dataset.get_omic(omic=factor_omic)
        Z = self.latents.mean().numpy()
        latent_names = self.dataset.get_var_names(OMIC.latent)
        norm = lambda x: (x - np.min(x)) / (np.max(x) - np.min(x))
        for (name1, name2), pairs in omic2latent.items():
            nrow = int(np.ceil(len(pairs) / ncol))
            fig = vs.plot_figure(nrow=nrow * 3.3, ncol=ncol * 4, dpi=80)
            # normalize the factor OMIC for color values
            x1 = X[:, var_ids[name1]]
            x2 = X[:, var_ids[name2]]
            x = norm(x1) - norm(x2)
            x = np.clip(x * magnify, -1., 1.)
            # get latents' coordinates
            for idx, (i1, i2) in enumerate(pairs):
                z1 = Z[:, i1]
                z2 = Z[:, i2]
                ax = vs.plot_scatter(x=z1,
                                     y=z2,
                                     val=x,
                                     ax=(nrow, ncol, idx + 1),
                                     cbar=True,
                                     cbar_ticks=[name1, 'others', name2],
                                     cbar_labrotation=-60,
                                     ticks_off=True,
                                     fontsize=8,
                                     max_n_points=2000,
                                     size=16)
                # xticks
                v = max(np.abs(np.min(z1)), np.abs(np.max(z1)))
                ticks = np.linspace(-v, v, num=5)
                ax.set_xticks(ticks)
                ax.set_xticklabels([f"{i:.2g}" for i in ticks], fontsize=8)
                ax.set_xlabel(latent_names[i1], fontsize=10)
                # yticks
                v = max(np.abs(np.min(z2)), np.abs(np.max(z2)))
                ticks = np.linspace(-v, v, num=5)
                ax.set_yticks(ticks)
                ax.set_yticklabels([f"{i:.2g}" for i in ticks], fontsize=8)
                ax.set_ylabel(latent_names[i2], fontsize=10)
            # final title
            fig.suptitle(f"[{factor_omic.name}] {name1}-{name2}", fontsize=10)
            fig.tight_layout(rect=[0.0, 0.02, 1.0, 0.98])
            self.add_figure(name=f"scatter_{factor_omic.name}_{name1}_{name2}",
                            fig=fig)
        return self
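The color values in Example #16 come from min-max normalizing each factor of the pair, taking the difference, and clipping after magnification; standalone with synthetic factors:

import numpy as np

norm = lambda v: (v - np.min(v)) / (np.max(v) - np.min(v))
x1, x2 = np.random.rand(1000), np.random.rand(1000)  # the two factor values
magnify = 2
val = np.clip((norm(x1) - norm(x2)) * magnify, -1., 1.)  # diverging value in [-1, 1]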
Example #17
from sklearn.manifold import TSNE

# odin import paths below are inferred from usage in this snippet
from odin import fuel as F, visual
from odin.ml import MiniBatchPCA
from odin.utils import UnitTimer, TemporaryDirectory

iris = F.load_iris()
print(iris)
pca = MiniBatchPCA()

X = iris['X'][:]

i = 0
while i < X.shape[0]:
    x = X[i:i + 20]
    i += 20
    pca.partial_fit(x)
    print("Fitting PCA ...")

with UnitTimer():
    for i in range(8):
        x = pca.transform(X)

with UnitTimer():
    for i in range(8):
        x = pca.transform_mpi(X, keep_order=True, ncpu=1, n_components=2)
print("Output shape:", x.shape)

colors = ['r' if i == 0 else ('b' if i == 1 else 'g')
          for i in iris['y'][:]]
visual.plot_scatter(x[:, 0], x[:, 1], color=colors, size=8)
visual.plot_save('/tmp/tmp.pdf')
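For reference, scikit-learn's IncrementalPCA plays the same role as MiniBatchPCA here; a hedged equivalent of the fitting loop above, reusing the same X:

from sklearn.decomposition import IncrementalPCA

ipca = IncrementalPCA(n_components=2)
for start in range(0, X.shape[0], 20):
    ipca.partial_fit(X[start:start + 20])  # same 20-row batches as above
x2d = ipca.transform(X)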
Example #18
def plot_monitoring_epoch(X, X_drop, y, Z, Z_drop, W_outputs, W_drop_outputs,
                          pi, pi_drop, row_name, dropout_percentage,
                          curr_epoch, ds_name, labels, save_dir):
    # Order of W_outputs: [W, W_stdev_total, W_stdev_explained]
    from matplotlib import pyplot as plt
    if y.ndim == 2:
        y = np.argmax(y, axis=-1)
    y = np.array([labels[i] for i in y])
    dropout_percentage_text = '%g%%' % (dropout_percentage * 100)

    Z_pca = fast_pca(Z, n_components=2, random_state=5218)
    Z_pca_drop = fast_pca(Z_drop, n_components=2, random_state=5218)
    if W_outputs is not None:
        X_pca, X_pca_drop, W_pca, W_pca_drop = fast_pca(X,
                                                        X_drop,
                                                        W_outputs[0],
                                                        W_drop_outputs[0],
                                                        n_components=2,
                                                        random_state=5218)
    # ====== downsampling ====== #
    rand = np.random.RandomState(seed=5218)
    n_test_samples = len(y)
    ids = np.arange(n_test_samples, dtype='int32')
    if n_test_samples > 8000:
        ids = rand.choice(ids, size=8000, replace=False)
    # ====== scatter configuration ====== #
    config = dict(size=6, labels=None)
    y = y[ids]

    X = X[ids]
    X_drop = X_drop[ids]
    Z_pca = Z_pca[ids]
    X_pca = X_pca[ids]
    W_pca = W_pca[ids]

    W_outputs = [w[ids] for w in W_outputs]
    W_drop_outputs = [w[ids] for w in W_drop_outputs]
    Z_pca_drop = Z_pca_drop[ids]
    X_pca_drop = X_pca_drop[ids]
    W_pca_drop = W_pca_drop[ids]

    if pi is not None:
        pi = pi[ids]
        pi_drop = pi_drop[ids]
    # ====== plotting NO reconstruction ====== #
    if W_outputs is None:
        plot_figure(nrow=8, ncol=20)
        fast_scatter(x=Z_pca,
                     y=y,
                     title="[PCA] Test data latent space",
                     enable_legend=True,
                     ax=(1, 2, 1),
                     **config)
        fast_scatter(x=Z_pca_drop,
                     y=y,
                     title="[PCA][Dropped:%s] Test data latent space" %
                     dropout_percentage_text,
                     ax=(1, 2, 2),
                     **config)
    # ====== plotting WITH reconstruction ====== #
    else:
        plot_figure(nrow=16, ncol=20)
        # original test data WITHOUT dropout
        fast_scatter(x=X_pca,
                     y=y,
                     title="[PCA][Test Data] Original",
                     ax=(2, 3, 1),
                     **config)
        fast_scatter(x=W_pca,
                     y=y,
                     title="Reconstructed",
                     ax=(2, 3, 2),
                     **config)
        fast_scatter(x=Z_pca,
                     y=y,
                     title="Latent space",
                     ax=(2, 3, 3),
                     **config)
        # original test data WITH dropout
        fast_scatter(x=X_pca_drop,
                     y=y,
                     title="[PCA][Dropped:%s][Test Data] Original" %
                     dropout_percentage_text,
                     ax=(2, 3, 4),
                     **config)
        fast_scatter(x=W_pca_drop,
                     y=y,
                     title="Reconstructed",
                     ax=(2, 3, 5),
                     enable_legend=True,
                     **config)
        fast_scatter(x=Z_pca_drop,
                     y=y,
                     title="Latent space",
                     ax=(2, 3, 6),
                     **config)
    plot_save(os.path.join(save_dir, 'latent_epoch%d.png') % curr_epoch,
              dpi=180,
              clear_all=True,
              log=True)
    # ====== plot count-sum ====== #
    if W_outputs is not None:
        X_countsum = _clip_count_sum(np.sum(X, axis=-1))
        W_countsum = _clip_count_sum(np.sum(W_outputs[0], axis=-1))
        X_drop_countsum = _clip_count_sum(np.sum(X_drop, axis=-1))
        W_drop_countsum = _clip_count_sum(np.sum(W_drop_outputs[0], axis=-1))
        series_config = [
            dict(xscale='linear', yscale='linear', sort_by=None),
            dict(xscale='linear', yscale='linear', sort_by='expected')
        ]

        if pi is not None:
            pi_sum = np.mean(pi, axis=-1)
            pi_drop_sum = np.mean(pi_drop, axis=-1)
        # plot the reconstruction count sum
        plot_figure(nrow=3 * 5 + 8, ncol=18)
        with plot_gridSpec(nrow=3 * (2 if pi is None else 3) + 4 * 3 + 1,
                           ncol=6,
                           wspace=1.0,
                           hspace=0.8) as grid:
            kws = dict(colorbar=True,
                       fontsize=10,
                       size=10,
                       marker=y,
                       n_samples=1200)
            # without dropout
            ax = subplot(grid[:3, 0:3])
            plot_scatter(x=X_pca,
                         val=X_countsum,
                         ax=ax,
                         legend_enable=False,
                         title='Original data (Count-sum)',
                         **kws)
            ax = subplot(grid[:3, 3:6])
            plot_scatter(x=W_pca,
                         val=W_countsum,
                         ax=ax,
                         legend_enable=False,
                         title='Reconstruction (Count-sum)',
                         **kws)
            # with dropout
            ax = subplot(grid[3:6, 0:3])
            plot_scatter(x=X_pca_drop,
                         val=X_drop_countsum,
                         ax=ax,
                         legend_enable=True if pi is None else False,
                         legend_ncol=len(labels),
                         title='[Dropped:%s]Original data (Count-sum)' %
                         dropout_percentage_text,
                         **kws)
            ax = subplot(grid[3:6, 3:6])
            plot_scatter(x=W_pca_drop,
                         val=W_drop_countsum,
                         ax=ax,
                         legend_enable=False,
                         title='[Dropped:%s]Reconstruction (Count-sum)' %
                         dropout_percentage_text,
                         **kws)
            row_start = 6
            # zero-inflated pi
            if pi is not None:
                ax = subplot(grid[6:9, 0:3])
                plot_scatter(x=X_pca,
                             val=pi_sum,
                             ax=ax,
                             legend_enable=True,
                             legend_ncol=len(labels),
                             title='Zero-inflated probabilities',
                             **kws)
                ax = subplot(grid[6:9, 3:6])
                plot_scatter(x=X_pca,
                             val=pi_drop_sum,
                             ax=ax,
                             legend_enable=False,
                             title='[Dropped:%s]Zero-inflated probabilities' %
                             dropout_percentage_text,
                             **kws)
                row_start += 3

            # plot the count-sum series
            def plot_count_sum_series(x, w, p, row_start, tit):
                if len(w) != 3:  # no statistics provided
                    return
                expected, stdev_total, stdev_explained = w
                count_sum_observed = np.sum(x, axis=0)
                count_sum_expected = np.sum(expected, axis=0)
                count_sum_stdev_total = np.sum(stdev_total, axis=0)
                count_sum_stdev_explained = np.sum(stdev_explained, axis=0)
                if p is not None:
                    p_sum = np.mean(p, axis=0)
                for i, kws in enumerate(series_config):
                    ax = subplot(grid[row_start:row_start + 3,
                                      (i * 3):(i * 3 + 3)])
                    ax, handles, indices = plot_series_statistics(
                        count_sum_observed,
                        count_sum_expected,
                        explained_stdev=count_sum_stdev_explained,
                        total_stdev=count_sum_stdev_total,
                        fontsize=8,
                        ax=ax,
                        legend_enable=False,
                        title=tit if i == 0 else None,
                        despine=True if p is None else False,
                        return_handles=True,
                        return_indices=True,
                        **kws)
                    if p is not None:
                        _show_zero_inflated_pi(p_sum, ax, handles, indices)
                    plt.legend(handles=handles,
                               loc='best',
                               markerscale=4,
                               fontsize=8)

            # add one row extra padding
            row_start += 1
            plot_count_sum_series(x=X,
                                  w=W_outputs,
                                  p=pi,
                                  row_start=row_start,
                                  tit="Count-sum X_original - W_original")
            row_start += 1
            plot_count_sum_series(
                x=X_drop,
                w=W_drop_outputs,
                p=pi_drop,
                row_start=row_start + 3,
                tit="[Dropped:%s]Count-sum X_drop - W_dropout" %
                dropout_percentage_text)
            row_start += 1
            plot_count_sum_series(
                x=X,
                w=W_drop_outputs,
                p=pi_drop,
                row_start=row_start + 6,
                tit="[Dropped:%s]Count-sum X_original - W_dropout" %
                dropout_percentage_text)
        plot_save(os.path.join(save_dir, 'countsum_epoch%d.png') % curr_epoch,
                  dpi=180,
                  clear_all=True,
                  log=True)
    # ====== plot series of samples ====== #
    if W_outputs is not None and len(W_outputs) == 3:
        # NOTE: turn off pi here
        pi = None

        n_visual_samples = 8
        plot_figure(nrow=3 * n_visual_samples + 8, ncol=25)
        col_width = 5
        with plot_gridSpec(nrow=3 * n_visual_samples,
                           ncol=4 * col_width,
                           wspace=5.0,
                           hspace=1.0) as grid:
            curr_grid_index = 0
            for i in rand.permutation(len(X))[:n_visual_samples]:
                observed = X[i]
                # unpack per the stated order: [W, W_stdev_total, W_stdev_explained]
                expected, stdev_total, stdev_explained = [
                    w[i] for w in W_outputs
                ]
                expected_drop, stdev_total_drop, stdev_explained_drop = [
                    w[i] for w in W_drop_outputs
                ]
                if pi is not None:
                    p_zi = pi[i]
                    p_zi_drop = pi_drop[i]
                # compare to W_original
                for j, kws in enumerate(series_config):
                    ax = subplot(grid[curr_grid_index:curr_grid_index + 3,
                                      (j * col_width):(j * col_width +
                                                       col_width)])
                    ax, handles, indices = plot_series_statistics(
                        observed,
                        expected,
                        explained_stdev=stdev_explained,
                        total_stdev=stdev_total,
                        fontsize=8,
                        legend_enable=False,
                        despine=True if pi is None else False,
                        title=("'%s' X_original - W_original" %
                               row_name[i]) if j == 0 else None,
                        return_handles=True,
                        return_indices=True,
                        **kws)
                    if pi is not None:
                        _show_zero_inflated_pi(p_zi, ax, handles, indices)
                    plt.legend(handles=handles,
                               loc='best',
                               markerscale=4,
                               fontsize=8)
                # compare to W_dropout
                for j, kws in enumerate(series_config):
                    col_start = col_width * 2
                    ax = subplot(
                        grid[curr_grid_index:curr_grid_index + 3,
                             (col_start +
                              j * col_width):(col_start + j * col_width +
                                              col_width)])
                    ax, handles, indices = plot_series_statistics(
                        observed,
                        expected_drop,
                        explained_stdev=stdev_explained_drop,
                        total_stdev=stdev_total_drop,
                        fontsize=8,
                        legend_enable=False,
                        despine=True if pi is None else False,
                        title=("[Dropped:%s]'%s' X_original - W_dropout" %
                               (dropout_percentage_text, row_name[i]))
                        if j == 0 else None,
                        return_handles=True,
                        return_indices=True,
                        **kws)
                    if pi is not None:
                        _show_zero_inflated_pi(p_zi_drop, ax, handles, indices)
                    plt.legend(handles=handles,
                               loc='best',
                               markerscale=4,
                               fontsize=8)
                curr_grid_index += 3
        plot_save(os.path.join(save_dir, 'samples_epoch%d.png') % curr_epoch,
                  dpi=180,
                  clear_all=True,
                  log=True)
    # ====== special case for mnist ====== #
    if 'mnist' in ds_name and W_outputs is not None:
        plot_figure(nrow=3, ncol=18)
        n_images = 32
        ids = rand.choice(np.arange(X.shape[0], dtype='int32'),
                          size=n_images,
                          replace=False)
        meta_data = [("Org", X[ids]), ("Rec", W_outputs[0][ids]),
                     ("OrgDropout", X_drop[ids]),
                     ("RecDropout", W_drop_outputs[0][ids])]
        count = 1
        for name, data in meta_data:
            for i in range(n_images):
                x = data[i].reshape(28, 28)
                plt.subplot(4, n_images, count)
                show_image(x)
                if i == 0:
                    plt.ylabel(name, fontsize=8)
                count += 1
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        plot_save(os.path.join(save_dir, 'image_epoch%d.png') % curr_epoch,
                  dpi=180,
                  clear_all=True,
                  log=True)
Example #19
def plot_evaluate_reconstruction(X,
                                 W,
                                 y_raw,
                                 y_prob,
                                 X_row,
                                 X_col,
                                 labels,
                                 title,
                                 pi=None,
                                 enable_image=True,
                                 enable_tsne=True,
                                 enable_sparsity=True):
    """
  pi : zero-inflated rate (imputation rate), or dropout probabilities
  """
    print("Evaluate: [Reconstruction]", ctext(title, 'lightyellow'))
    from matplotlib import pyplot as plt
    fontsize = 12

    W_stdev_total, W_stdev_explained = None, None
    if isinstance(W, (tuple, list)):
        if len(W) == 1:
            W = W[0]
        elif len(W) == 3:
            W, W_stdev_total, W_stdev_explained = W
        else:
            raise RuntimeError()
    elif W.ndim == 3:
        W, W_stdev_total, W_stdev_explained = W[0], W[1], W[2]
    # convert the prediction to integer
    # W = W.astype('int32')

    assert (X.shape[0] == W.shape[0] == y_raw.shape[0]) and \
    (X.shape == W.shape) and \
    (y_raw.shape == y_prob.shape)

    X, X_row, W, y_raw, y_prob, pi = downsample_data(X, X_row, W, y_raw,
                                                     y_prob, pi)
    y_argmax = np.argmax(y_prob, axis=-1)

    # ====== prepare count-sum ====== #
    X_log = K.log_norm(X, axis=1)
    W_log = K.log_norm(W, axis=1)

    X_gene_countsum = np.sum(X, axis=0)
    X_cell_countsum = np.sum(X, axis=1)
    X_gene_nzeros = np.sum(X == 0, axis=0)
    X_cell_nzeros = np.sum(X == 0, axis=1)

    gene_sort = np.argsort(X_gene_countsum)
    cell_sort = np.argsort(X_cell_countsum)

    W_gene_countsum = np.sum(W, axis=0)
    W_cell_countsum = np.sum(W, axis=1)
    W_gene_nzeros = np.sum(W == 0, axis=0)
    W_cell_nzeros = np.sum(W == 0, axis=1)

    X_col_sorted = X_col[gene_sort] if X_col is not None else None
    X_row_sorted = X_row[cell_sort] if X_row is not None else None

    if pi is not None:
        pi_cell_countsum = np.mean(pi, axis=1)
        pi_gene_countsum = np.mean(pi, axis=0)

    # ====== Compare image ====== #
    if enable_image:
        _RAND = np.random.RandomState(seed=87654321)
        n_img = 12
        n_img_row = min(3, X.shape[0] // n_img)
        n_row_per_row = 2 if pi is None else 3
        plot_figure(nrow=n_img_row * 4, ncol=18)
        count = 1
        all_ids = _RAND.choice(np.arange(0, X.shape[0]),
                               size=n_img * n_img_row,
                               replace=False)

        for img_row in range(n_img_row):
            ids = all_ids[img_row * n_img:(img_row + 1) * n_img]

            # plot original images
            for col, i in enumerate(ids):
                ax = plt.subplot(n_row_per_row * n_img_row, n_img, count)
                show_image(X[i])
                if col == 0:
                    plt.ylabel("Original")
                if X_row is not None:
                    ax.set_title(X_row[i], fontsize=8)
                count += 1
            # plot reconstructed images
            for col, i in enumerate(ids):
                plt.subplot(n_row_per_row * n_img_row, n_img, count)
                show_image(W[i])
                if col == 0:
                    plt.ylabel("Reconstructed")
                count += 1
            # plot zero-inflated rate
            if pi is not None:
                for col, i in enumerate(ids):
                    plt.subplot(n_row_per_row * n_img_row, n_img, count)
                    show_image(pi[i], is_probability=True)
                    if col == 0:
                        plt.ylabel("$p_{zero-inflated}$")
                    count += 1
        plt.tight_layout()
    # ====== compare the T-SNE plot ====== #
    if enable_tsne:

        def pca_and_tsne(x, w):
            x_pca, w_pca = fast_pca(x,
                                    w,
                                    n_components=512,
                                    random_state=87654321)
            x_tsne = fast_tsne(x_pca, n_components=2, random_state=87654321)
            w_tsne = fast_tsne(w_pca, n_components=2, random_state=87654321)
            return x_pca[:, :2], x_tsne, w_pca[:, :2], w_tsne

        # transforming the data
        (X_cell_pca, X_cell_tsne, W_cell_pca,
         W_cell_tsne) = pca_and_tsne(X_log, W_log)
        (X_gene_pca, X_gene_tsne, W_gene_pca,
         W_gene_tsne) = pca_and_tsne(X_log.T, W_log.T)
        # prepare the figure
        n_plot = 3 + 2  # 3 for cells, 2 for genes
        if pi is not None:
            n_plot += 2  # 2 more rows for pi
        plot_figure(nrow=n_plot * 5, ncol=18)
        # Cells
        fast_scatter(x=X_cell_pca,
                     y=y_argmax,
                     labels=labels,
                     title="[PCA]Original Cell Data",
                     ax=(n_plot, 2, 1),
                     enable_legend=False)
        fast_scatter(x=W_cell_pca,
                     y=y_argmax,
                     labels=labels,
                     title="[PCA]Reconstructed Cell Data",
                     ax=(n_plot, 2, 2),
                     enable_legend=False)

        fast_scatter(x=X_cell_tsne,
                     y=y_argmax,
                     labels=labels,
                     title="[t-SNE]Original Cell Data",
                     ax=(n_plot, 2, 3),
                     enable_legend=True)
        fast_scatter(x=W_cell_tsne,
                     y=y_argmax,
                     labels=labels,
                     title="[t-SNE]Reconstructed Cell Data",
                     ax=(n_plot, 2, 4),
                     enable_legend=False)

        fast_log = lambda x: K.log_norm(x, axis=0)

        plot_scatter(
            x=X_cell_tsne,
            val=fast_log(X_cell_countsum),
            title="[t-SNE]Original Cell Data + Original Cell Countsum",
            ax=(n_plot, 2, 5),
            colorbar=True)
        plot_scatter(
            x=X_cell_tsne,
            val=fast_log(W_cell_countsum),
            title="[t-SNE]Original Cell Data + Reconstructed Cell Countsum",
            ax=(n_plot, 2, 6),
            colorbar=True)
        # Genes
        plot_scatter(x=X_gene_pca,
                     val=fast_log(X_gene_countsum),
                     title="[PCA]Original Gene Data + Original Gene Countsum",
                     ax=(n_plot, 2, 7),
                     colorbar=True)
        plot_scatter(
            x=W_gene_pca,
            val=fast_log(X_gene_countsum),
            title="[PCA]Reconstructed Gene Data + Original Gene Countsum",
            ax=(n_plot, 2, 8),
            colorbar=True)

        plot_scatter(
            x=X_gene_tsne,
            val=fast_log(X_gene_countsum),
            title="[t-SNE]Original Gene Data + Original Gene Countsum",
            ax=(n_plot, 2, 9),
            colorbar=True)
        plot_scatter(
            x=X_gene_tsne,
            val=fast_log(W_gene_countsum),
            title="[t-SNE]Original Gene Data + Reconstructed Gene Countsum",
            ax=(n_plot, 2, 10),
            colorbar=True)
        # zero-inflation rate
        if pi is not None:
            plot_scatter(
                x=X_cell_tsne,
                val=X_cell_countsum,
                title="[t-SNE]Original Cell Data + Original Cell Countsum",
                ax=(n_plot, 2, 11),
                colorbar=True)
            plot_scatter(
                x=X_cell_tsne,
                val=pi_cell_countsum,
                title="[t-SNE]Original Cell Data + Zero-inflated rate",
                ax=(n_plot, 2, 12),
                colorbar=True)

            plot_scatter(
                x=X_gene_tsne,
                val=X_gene_countsum,
                title="[t-SNE]Original Gene Data + Original Gene Countsum",
                ax=(n_plot, 2, 13),
                colorbar=True)
            plot_scatter(
                x=X_gene_tsne,
                val=pi_gene_countsum,
                title="[t-SNE]Original Gene Data + Zero-inflated rate",
                ax=(n_plot, 2, 14),
                colorbar=True)
        plt.tight_layout()
    # ******************** sparsity ******************** #
    if enable_sparsity:
        plot_figure(nrow=8, ncol=8)
        # ====== sparsity ====== #
        z = (X.ravel() == 0).astype('int32')
        z_res = (W.ravel() == 0).astype('int32')
        plot_confusion_matrix(ax=None,
                              cm=confusion_matrix(y_true=z,
                                                  y_pred=z_res,
                                                  labels=(0, 1)),
                              labels=('Not Zero', 'Zero'),
                              colorbar=True,
                              fontsize=fontsize + 4,
                              title="Sparsity")
Example #20
0
Z = []
U = []
Z_hat = []
Y = []
for x, y in tqdm(valid):
    qz_x, qu_z, qz_u = vae.encode_two_stages(x)
    Z.append(qz_x.mean())
    U.append(qu_z.mean())
    Z_hat.append(qz_u.mean())
    Y.append(np.argmax(y, axis=-1))
Z = np.concatenate(Z, 0)[:5000]
U = np.concatenate(U, 0)[:5000]
Z_hat = np.concatenate(Z_hat, 0)[:5000]
Y = np.concatenate(Y, 0)[:5000]

plt.figure(figsize=(15, 5), dpi=150)
vs.plot_scatter(fast_tsne(Z), color=Y, grid=False, ax=(1, 3, 1))
vs.plot_scatter(fast_tsne(U), color=Y, grid=False, ax=(1, 3, 2))
vs.plot_scatter(fast_tsne(Z_hat), color=Y, grid=False, ax=(1, 3, 3))
plt.tight_layout()

ids = np.argsort(np.mean(qz_x.stddev(), 0))
ids_u = np.argsort(np.mean(qu_z.stddev(), 0))

plt.figure(figsize=(10, 10), dpi=200)
plot_latent_stats(mean=np.mean(qz_x.mean(), 0)[ids],
                  stddev=np.mean(qz_x.stddev(), 0)[ids],
                  ax=(3, 1, 1),
                  name='q(z|x)')
plot_latent_stats(mean=np.mean(qu_z.mean(), 0)[ids_u],
                  stddev=np.mean(qu_z.stddev(), 0)[ids_u],
                  ax=(3, 1, 2),
                  name='q(u|z)')  # closing argument inferred; the source snippet is truncated here
Example #21
0
    def plot_scatter(
        self,
        factor_index: Union[int, str],
        classifier: Optional[Literal['svm', 'tree', 'logistic', 'knn', 'lda',
                                     'gbt']] = None,
        classifier_kw: Dict[str, Any] = {},
        dimension_reduction: Literal['pca', 'umap', 'tsne', 'knn',
                                     'kmean'] = 'tsne',
        max_samples: Optional[int] = 2000,
        return_figure: bool = False,
        ax: Optional['Axes'] = None,
        seed: int = 1,
    ) -> Union['Figure', Posterior]:
        """Plot dimension reduced scatter points of the sample set.

    Parameters
    ----------
    classifier : {'svm', 'tree', 'logistic', 'knn', 'lda', 'gbt'}, optional
        classifier for ploting decision contour of each factor, by default None
    classifier_kw : Dict[str, Any], optional
        keyword arguments for the classifier, by default {}
    dimension_reduction : {'pca', 'umap', 'tsne', 'knn', 'kmean'}, optional
        method for dimension reduction, by default 'tsne'
    factor_indices : Optional[Union[int, str, List[Union[int, str]]]], optional
        indicator of which factor will be plotted, by default None
    max_samples : Optional[int], optional
        maximum number of samples to be plotted, by default 2000
    return_figure : bool, optional
        return the figure or add it to the Visualizer for later processing,
        by default False
    seed : int, optional
        seed for random state, by default 1

    Returns
    -------
    Figure or Posterior
        return a `matplotlib.pyplot.Figure` if `return_figure=True` else return
        self for method chaining.
    """
        ## get all relevant factors
        if isinstance(factor_index, string_types):
            factor_index = self.factor_names.index(factor_index)
        factor_index = int(factor_index)
        f = self.factors[:, factor_index]
        name = self.factor_names[factor_index]
        categorical = self.is_categorical(factor_index)
        f_norm = (f - np.mean(f, axis=0)) / np.std(f, axis=0)
        ## reduce latents dimension
        z = self.dimension_reduce(algorithm=dimension_reduction, seed=seed)
        x_min, x_max = np.min(z[:, 0]), np.max(z[:, 0])
        y_min, y_max = np.min(z[:, 1]), np.max(z[:, 1])
        ## downsample
        if isinstance(max_samples, Number):
            max_samples = int(max_samples)
            if max_samples < z.shape[0]:
                rand = np.random.RandomState(seed=seed)
                ids = rand.choice(np.arange(z.shape[0], dtype=np.int32),
                                  size=max_samples,
                                  replace=False)
                z = z[ids]
                f = f[ids]
        ## train classifier if provided
        n_samples = z.shape[0]
        if classifier is not None:
            xx, yy = np.meshgrid(np.linspace(x_min, x_max, n_samples),
                                 np.linspace(y_min, y_max, n_samples))
            xy = np.c_[xx.ravel(), yy.ravel()]
        ## plotting
        ax = vs.to_axis(ax, is_3D=False)
        cmap = 'bwr'
        # scatter plot
        vs.plot_scatter(x=z, val=f, color=cmap, ax=ax, size=10., alpha=0.5)
        ax.grid(False)
        ax.tick_params(axis='both',
                       bottom=False,
                       top=False,
                       left=False,
                       right=False)
        ax.set_title(name)
        # classifier boundary
        if classifier is not None:
            model = linear_classifier(z,
                                      f,
                                      algo=classifier,
                                      seed=seed,
                                      **classifier_kw)
            ax.contourf(xx,
                        yy,
                        model.predict(xy).reshape(xx.shape),
                        cmap=cmap,
                        alpha=0.4)
        if return_figure:
            return plt.gcf()
        return self.add_figure(
            name=f'scatter_{dimension_reduction}_{str(classifier).lower()}',
            fig=plt.gcf())
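# Hypothetical usage of the `plot_scatter` method above, assuming `post` is a
# fitted Posterior-like object exposing `factor_names` and `factors`; the
# variable names here are illustrative, not part of a verified API.
fig = post.plot_scatter(factor_index=0,
                        classifier='svm',
                        dimension_reduction='tsne',
                        max_samples=1000,
                        return_figure=True,
                        seed=1)
fig.savefig('scatter_factor0.png', dpi=150)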
Example #22
0
def plot_latents_protein_pairs(Z,
                               y,
                               labels_name,
                               all_pairs=False,
                               title=None,
                               elev=None,
                               azim=None,
                               algo='tsne',
                               show_colorbar=False):
    r""" Label `y` is multi-classes
  i.e. each samples could belong to multiple classes at once

  Returns:
    fig : matplotlib.Figure or None
        if no pair found, return None, otherwise, the
        figure used to plot all protein pairs
    figsize : (`float`, `float`), optional (default=`None`)
      width, height in inches

  """
    labels_name = [standardize_protein_name(i) for i in labels_name]

    if title is None:
        title = ''
    title = '[%s]%s' % (algo, title)
    # ====== Downsample if the data is huge ====== #
    Z, y = downsample_data(Z, y)
    # ====== checking inputs ====== #
    assert Z.ndim == 2, Z.shape
    assert Z.shape[0] == y.shape[0]
    # ====== preprocessing ====== #
    Z = dimension_reduction(Z, algo=algo)

    # ====== select proteins ====== #
    def logit(p):
        eps = np.finfo('float32').eps
        p = np.copy(p)
        p[p == 0] = eps
        p[p == 1] = 1 - eps
        return np.log(p / (1 - p))

    # normalize to 0, 1
    y_min = np.min(y, axis=0, keepdims=True)
    y_max = np.max(y, axis=0, keepdims=True)
    y = (y - y_min) / (y_max - y_min)

    # select most 2 different proteins to create pairs
    labels_index = {name: i for i, name in enumerate(labels_name)}
    pairs = []
    if all_pairs:
        pairs = set([
            '*'.join(sorted(i)) for i in itertools.product(
                labels_index.keys(), labels_index.keys()) if i[0] != i[1]
        ])
        pairs = [i.split('*') for i in pairs]
    else:
        for i, j in PROTEIN_PAIR_NEGATIVE:
            if i in labels_index and j in labels_index:
                pairs.append((i, j))
    n_pairs = len(pairs)
    if n_pairs == 0:
        return None
    # we could handle 5 pairs in 1 row, no problem
    ncol = min(5, n_pairs)
    nrow = int(np.ceil(n_pairs / ncol))
    fig = plt.figure(figsize=(ncol * 4, int(nrow * 3.6)))

    for idx, pair in enumerate(pairs):
        ax = plt.subplot(nrow, ncol, idx + 1)
        # polarize y level
        val = np.hstack((y[:, labels_index[pair[0]]][:, np.newaxis],
                         y[:, labels_index[pair[1]]][:, np.newaxis]))
        # red means closer to 1, i.e. protein pair[1]
        # blue means closer to -1, i.e. protein pair[0]
        val = logit(val[:, 1]) - logit(val[:, 0])
        # normalize again to [-1, 1]
        val = 2 * (val - np.min(val)) / (np.max(val) - np.min(val)) - 1
        # ====== plotting ====== #
        plot_scatter(x=Z[:, 0],
                     y=Z[:, 1],
                     val=val,
                     legend_enable=False,
                     color='bwr',
                     size=8,
                     elev=elev,
                     azim=azim,
                     alpha=1.,
                     fontsize=8,
                     grid=False,
                     ax=ax,
                     colorbar=True,
                     colorbar_horizontal=True,
                     colorbar_ticks=[pair[0], 'Others', pair[1]],
                     title='/'.join(pair))
    plt.suptitle(title)
    return fig
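# Hypothetical call of `plot_latents_protein_pairs` above; `Z_latent` and
# `y_protein` are illustrative arrays, and `protein_names` must match the
# columns of `y_protein`.
fig = plot_latents_protein_pairs(Z=Z_latent,
                                 y=y_protein,
                                 labels_name=protein_names,
                                 all_pairs=True,
                                 algo='tsne')
if fig is not None:
    fig.savefig('protein_pairs.png', dpi=150)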
Example #23
0
# NOTE: `F` (datasets) and `visual` (plotting) are odin modules assumed from
# the original script's context, added here so the snippet is self-contained.
from odin import fuel as F, visual
from odin.ml import MiniBatchPCA
from sklearn.manifold import TSNE
from odin.utils import UnitTimer, TemporaryDirectory

iris = F.load_iris()
print(iris)
pca = MiniBatchPCA()

X = iris['X'][:]

i = 0
while i < X.shape[0]:
    x = X[i:i + 20]
    i += 20
    pca.partial_fit(x)
    print("Fitting PCA ...")

with UnitTimer():
    for i in range(8):
        x = pca.transform(X)

with UnitTimer():
    for i in range(8):
        x = pca.transform_mpi(X, keep_order=True, ncpu=1, n_components=2)
print("Output shape:", x.shape)

colors = ['r' if i == 0 else ('b' if i == 1 else 'g') for i in iris['y'][:]]
visual.plot_scatter(x[:, 0], x[:, 1], color=colors, size=8)
visual.plot_save('/tmp/tmp.pdf')
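# The same streaming pattern with a widely available API: sklearn's
# IncrementalPCA also exposes `partial_fit` for chunk-by-chunk updates, shown
# here as a point of comparison.
from sklearn.decomposition import IncrementalPCA

ipca = IncrementalPCA(n_components=2)
for start in range(0, X.shape[0], 20):
    ipca.partial_fit(X[start:start + 20])  # update components one chunk at a time
x2 = ipca.transform(X)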
Example #24
0
def plot_epoch(task):
    if task is None:
        curr_epoch = 0
    else:
        curr_epoch = task.curr_epoch
        if not (curr_epoch < 5 or curr_epoch % 5 == 0):
            return
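    # NOTE: this function relies on globals defined elsewhere in the script:
    # X_test/y_test (held-out data), f_z/f_w (encoder and decoder functions),
    # Z_names, n_classes, V (plotting module) and FIGURE_PATH.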
    rand = np.random.RandomState(seed=1234)

    X, y = X_test, y_test
    n_data = X.shape[0]
    Z = f_z(X)
    W, W_stdev_mcmc, W_stdev_analytic = f_w(X)

    X_pca, W_pca_1 = fast_pca(X,
                              W,
                              n_components=2,
                              random_state=rand.randint(10e8))
    W_pca_2 = fast_pca(W, n_components=2, random_state=rand.randint(10e8))
    X_count_sum = np.sum(X, axis=tuple(range(1, X.ndim)))
    W_count_sum = np.sum(W, axis=-1)

    n_visual_samples = 8
    nrow = 13 + n_visual_samples * 3
    V.plot_figure(nrow=int(nrow * 1.8), ncol=18)
    with V.plot_gridSpec(nrow=nrow + 3, ncol=6, hspace=0.8) as grid:
        # plot the latent space
        for i, (z, name) in enumerate(zip(Z, Z_names)):
            if z.shape[1] > 2:
                z = fast_pca(z,
                             n_components=2,
                             random_state=rand.randint(10e8))
            ax = V.subplot(grid[:3, (i * 2):(i * 2 + 2)])
            V.plot_scatter(x=z[:, 0],
                           y=z[:, 1],
                           color=y,
                           marker=y,
                           n_samples=4000,
                           ax=ax,
                           legend_enable=False,
                           legend_ncol=n_classes)
            ax.set_title(name, fontsize=12)
        # plot the reconstruction
        for i, (x, name) in enumerate(
                zip([X_pca, W_pca_1, W_pca_2], [
                    'Original data', 'Reconstruction',
                    'Reconstruction (separated PCA)'
                ])):
            ax = V.subplot(grid[3:6, (i * 2):(i * 2 + 2)])
            V.plot_scatter(x=x[:, 0],
                           y=x[:, 1],
                           color=y,
                           marker=y,
                           n_samples=4000,
                           ax=ax,
                           legend_enable=i == 1,
                           legend_ncol=n_classes,
                           title=name)
        # plot the reconstruction count sum
        for i, (x, count_sum, name) in enumerate(
                zip([X_pca, W_pca_1], [X_count_sum, W_count_sum], [
                    'Original data (Count-sum)', 'Reconstruction (Count-sum)'
                ])):
            ax = V.subplot(grid[6:10, (i * 3):(i * 3 + 3)])
            V.plot_scatter(x=x[:, 0],
                           y=x[:, 1],
                           val=count_sum,
                           n_samples=2000,
                           marker=y,
                           ax=ax,
                           size=8,
                           legend_enable=i == 0,
                           legend_ncol=n_classes,
                           title=name,
                           colorbar=True,
                           fontsize=10)
        # plot the count-sum series
        count_sum_observed = np.sum(X, axis=0).ravel()
        count_sum_expected = np.sum(W, axis=0)
        count_sum_stdev_explained = np.sum(W_stdev_mcmc, axis=0)
        count_sum_stdev_total = np.sum(W_stdev_analytic, axis=0)
        for i, kws in enumerate([
                dict(xscale='linear', yscale='linear', sort_by=None),
                dict(xscale='linear', yscale='linear', sort_by='expected'),
                dict(xscale='log', yscale='log', sort_by='expected')
        ]):
            ax = V.subplot(grid[10:10 + 3, (i * 2):(i * 2 + 2)])
            V.plot_series_statistics(count_sum_observed,
                                     count_sum_expected,
                                     explained_stdev=count_sum_stdev_explained,
                                     total_stdev=count_sum_stdev_total,
                                     fontsize=8,
                                     title="Count-sum" if i == 0 else None,
                                     **kws)
        # plot the mean and variances
        curr_grid_index = 13
        ids = rand.permutation(n_data)
        ids = ids[:n_visual_samples]
        for i in ids:
            observed, expected, stdev_explained, stdev_total = \
                X[i], W[i], W_stdev_mcmc[i], W_stdev_analytic[i]
            observed = observed.ravel()
            for j, kws in enumerate([
                    dict(xscale='linear', yscale='linear', sort_by=None),
                    dict(xscale='linear', yscale='linear', sort_by='expected'),
                    dict(xscale='log', yscale='log', sort_by='expected')
            ]):
                ax = V.subplot(grid[curr_grid_index:curr_grid_index + 3,
                                    (j * 2):(j * 2 + 2)])
                V.plot_series_statistics(observed,
                                         expected,
                                         explained_stdev=stdev_explained,
                                         total_stdev=stdev_total,
                                         fontsize=8,
                                         title="Test Sample #%d" %
                                         i if j == 0 else None,
                                         **kws)
            curr_grid_index += 3
    V.plot_save(os.path.join(FIGURE_PATH, 'latent_%d.png' % curr_epoch),
                dpi=200,
                log=True)
    # exit()  # NOTE: likely a debugging leftover; exiting here would stop training after the first plot
Example #25
0
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
    n_rows = int(np.sqrt(n_images))
    is_semi = vae.is_semi_supervised()
    is_hierarchical = vae.is_hierarchical()
    ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
    ## prepare
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    ## data for training semi-supervised
    train = ds.create_dataset('train', **ds_kw)
    (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
     y_pred_train, z_train, pz_train) = _call(vae,
                                              ds=train,
                                              rand=rand,
                                              take_count=take_count,
                                              n_images=n_images,
                                              verbose=True)
    ## data for testing
    test = ds.create_dataset('test', **ds_kw)
    (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
     z_test, pz_test) = _call(vae,
                              ds=test,
                              rand=rand,
                              take_count=take_count,
                              n_images=n_images,
                              verbose=True)
    # === 0. plotting latent-factor pairs
    for idx, z in enumerate(z_test):
        z = z.mean()
        f = y_true_test
        corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
        plot_latents_pairs(z, f, corr_mat, ds.labels)
        vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
    # === 0. latent traverse plot
    x_travs = x_org_test
    if x_travs.ndim == 3:  # grayscale image
        x_travs = np.expand_dims(x_travs, -1)
    else:  # color image
        x_travs = np.transpose(x_travs, (0, 2, 3, 1))
    x_travs = x_travs[rand.permutation(x_travs.shape[0])]
    n_visual_samples = 5
    n_traverse_points = 21
    n_top_latents = 10
    plt.figure(figsize=(8, 3 * n_visual_samples))
    for i in range(n_visual_samples):
        images = vae.sample_traverse(x_travs[i:i + 1],
                                     min_val=-np.min(z_test[0].mean()),
                                     max_val=np.max(z_test[0].mean()),
                                     n_best_latents=n_top_latents,
                                     n_traverse_points=n_traverse_points,
                                     mode='linear')
        images = as_tuple(images)[0]
        images = _prepare_images(images.mean().numpy(), normalize=True)
        vs.plot_images(images,
                       grids=(n_top_latents, n_traverse_points),
                       ax=(n_visual_samples, 1, i + 1))
        if i == 0:
            plt.title('Latents traverse')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
    # === 0. prior sampling plot
    images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    plt.figure(figsize=(5, 5))
    vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
    # === 1. reconstruction plot
    plt.figure(figsize=(15, 15))
    vs.plot_images(x_org_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 1),
                   title='[Train]Original')
    vs.plot_images(x_rec_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 2),
                   title='[Train]Reconstructed')
    vs.plot_images(x_org_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 3),
                   title='[Test]Original')
    vs.plot_images(x_rec_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 4),
                   title='[Test]Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    label_type = ds.label_type
    if label_type == 'categorical':
        labels_name = ds.labels
        true = np.argmax(y_true_test, axis=-1)
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
            labels_pred = np.array([labels_name[i] for i in pred])
    elif label_type == 'factor':  # dsprites, shapes3d
        labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
          if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
        true = y_true_test[:, 2].astype('int32')
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = get_ymean(y_pred_test)[:, 2].astype('int32')
            labels_pred = np.array([labels_name[i] for i in pred])
    else:  # CelebA
        raise NotImplementedError
    ## confusion matrix
    if is_semi:
        plt.figure(figsize=(8, 8))
        acc = accuracy_score(y_true=true, y_pred=pred)
        vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                                 labels=labels_name,
                                 cbar=True,
                                 fontsize=10,
                                 title=f'{title} Acc:{acc:.2f}')
    ## save arrays for later inspections
    with open(f'{expdir}/arrays', 'wb') as f:
        pickle.dump(
            dict(z_train=z_train,
                 y_pred_train=y_pred_train,
                 y_true_train=y_true_train,
                 z_test=z_test,
                 y_pred_test=y_pred_test,
                 y_true_test=y_true_test,
                 labels=labels_name,
                 ds=ds.name,
                 label_type=label_type), f)
    print(f'Exported arrays to "{expdir}/arrays"')
    ## semi-supervised
    z_mean_train = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
    z_mean_test = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
    # === 2. scatter points latents plot
    n_points = 5000
    ids = rand.permutation(len(labels_true))[:n_points]
    Y_true = labels_true[ids]
    Y_pred = labels_pred[ids]
    # tsne plot
    n_latents = 0 if len(z_train) == 1 else len(z_train)
    for name, X in zip(
        ['all'] + [f'latents{i}'
                   for i in range(n_latents)], [z_mean_test[ids]] +
        [z_test[i].mean().numpy()[ids] for i in range(n_latents)]):
        print(f'Plot scatter points for {name}')
        X = X.reshape(X.shape[0], -1)  # flatten to 2D
        X = Pipeline([('zscore', StandardScaler()),
                      ('pca', PCA(min(X.shape[1], 512),
                                  random_state=seed))]).fit_transform(X)
        tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
        kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
        plt.figure(figsize=(12, 6))
        vs.plot_scatter(color=Y_true,
                        title=f'[True]{title}-{name}',
                        ax=(1, 2, 1),
                        **kw)
        vs.plot_scatter(color=Y_pred,
                        title=f'[Pred]{title}-{name}',
                        ax=(1, 2, 2),
                        **kw)
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)

    # === 3. show the latents statistics
    n_latents = len(z_train)
    colors = sns.color_palette(n_colors=len(np.unique(labels_true)))  # one color per class
    styles = dict(grid=False,
                  ticks_off=False,
                  alpha=0.6,
                  xlabel='mean',
                  ylabel='stddev')

    # scatter between latents and labels (assume categorical distribution)
    def _show_latents_labels(Z, Y, title):
        plt.figure(figsize=(5 * n_latents, 5), dpi=150)
        for idx, z in enumerate(Z):
            if len(z.batch_shape) == 0:
                mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
                stddev = z.sample(Y.shape[0]) - mean
            else:
                mean = flatten(z.mean())
                stddev = flatten(z.stddev())
            y = np.argmax(Y, axis=-1)
            data = [[], [], []]
            for y_i, c in zip(np.unique(y), colors):
                mask = (y == y_i)
                data[0].append(np.mean(mean[mask], 0))
                data[1].append(np.mean(stddev[mask], 0))
                data[2].append([labels_true[y_i]] * mean.shape[1])
            vs.plot_scatter(
                x=np.concatenate(data[0], 0),
                y=np.concatenate(data[1], 0),
                color=np.concatenate(data[2], 0),
                ax=(1, n_latents, idx + 1),
                size=15 if mean.shape[1] < 128 else 8,
                title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
                **styles)
        plt.tight_layout()

    # simple scatter mean-stddev each latents
    def _show_latents(Z, title):
        plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
        for idx, z in enumerate(Z):
            mean = flatten(z.mean())
            stddev = flatten(z.stddev())
            if mean.ndim == 2:
                mean = np.mean(mean, 0)
                stddev = np.mean(stddev, 0)
            vs.plot_scatter(
                x=mean,
                y=stddev,
                ax=(1, n_latents, idx + 1),
                size=15 if len(mean) < 128 else 8,
                title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
                **styles)

    _show_latents_labels(z_test, y_true_test, 'post')
    _show_latents_labels(pz_test, y_true_test, 'prior')
    _show_latents(z_test, 'post')
    _show_latents(pz_test, 'prior')

    # KL statistics
    vs.plot_figure()
    for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
        kl = []
        qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
        pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
        for s, e in minibatch(batch_size=8, n=100):
            z = qz.sample(e - s)
            # don't do this on GPU, it explodes!
            kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
        kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
        # per sample
        kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
        kl_samples = logsumexp(kl_samples, 0)
        plt.subplot(n_latents, 2, idx * 2 + 1)
        sns.histplot(kl_samples, bins=50)
        plt.title(f'Z#{idx} KL per sample (nats)')
        # per latent
        kl_latents = np.mean(flatten(logsumexp(kl, 0)), 0)
        plt.subplot(n_latents, 2, idx * 2 + 2)
        plt.plot(np.sort(kl_latents))
        plt.title(f'Z#{idx} KL per dim (nats)')
    plt.tight_layout()

    vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)
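# Hypothetical driver for `evaluate` above; `create_vae_and_dataset` is a
# stand-in for however the VariationalAutoencoder and ImageDataset are
# constructed elsewhere, not a real helper in the library.
vae, ds = create_vae_and_dataset()  # assumed helper
evaluate(vae,
         ds=ds,
         expdir='/tmp/vae_eval',
         title='mnist_vae',
         batch_size=64,
         n_images=36,
         seed=1)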
Example #26
0
    def plot_correlation_scatter(self,
                                 omic1=OMIC.transcriptomic,
                                 omic2=OMIC.proteomic,
                                 var_names1='auto',
                                 var_names2='auto',
                                 is_marker_pairs=True,
                                 log1=True,
                                 log2=True,
                                 max_scatter_points=200,
                                 top=3,
                                 bottom=3,
                                 title='',
                                 return_figure=False):
        r""" Mapping from omic1 to omic2

    Arguments:
      omic1, omic2 : instance of OMIC.
        With `omic1` represent the x-axis, and `omic2` represent the y-axis.
      var_names1, var_names2 : list of variable names for `omic1` and `omic2`,
        or 'auto' to use the markers of each OMIC.
    """
        omic1 = OMIC.parse(omic1)
        omic2 = OMIC.parse(omic2)
        if isinstance(var_names1, string_types) and var_names1 == 'auto':
            var_names1 = omic1.markers
        if isinstance(var_names2, string_types) and var_names2 == 'auto':
            var_names2 = omic2.markers
        if var_names1 is None or var_names2 is None:
            is_marker_pairs = False
        max_scatter_points = int(max_scatter_points)
        # get all correlations
        corr = self.get_correlation(omic1, omic2)
        corr_map = {(x[0], x[1]): (0 if np.isnan(x[2]) else x[2],
                                   0 if np.isnan(x[3]) else x[3])
                    for x in corr}
        om1_names = self.get_var_names(omic1)
        om2_names = self.get_var_names(omic2)
        om1_idx = {j: i for i, j in enumerate(om1_names)}
        om2_idx = {j: i for i, j in enumerate(om2_names)}
        # extract the data and normalization
        X1 = self.numpy(omic1)
        library = np.sum(X1, axis=1, keepdims=True)
        library = discretizing(library, n_bins=10, strategy='quantile').ravel()
        if log1:
            s = np.sum(X1, axis=1, keepdims=True)
            X1 = np.log1p(X1 / s * np.median(s))
        X2 = self.numpy(omic2)
        if log2:
            s = np.sum(X2, axis=1, keepdims=True)
            X2 = np.log1p(X2 / s * np.median(s))
        ### getting the marker pairs
        all_pairs = []
        # coordinate marker pairs
        if is_marker_pairs:
            pairs = [(i1, i2) for i1, i2 in zip(var_names1, var_names2)
                     if i1 in om1_idx and i2 in om2_idx]
            var_names1 = [i for i, _ in pairs]
            var_names2 = [i for _, i in pairs]
        # filter omic2
        if var_names2 is not None:
            var_names2 = [i for i in var_names2 if i in om2_names]
        else:
            var_names2 = om2_names
        assert len(var_names2) > 0, \
          (f"None of the variables {var_names2} is contained in variable list "
           f"of OMIC {omic2.name}")
        nrow = len(var_names2)
        # filter omic1
        if var_names1 is not None:
            var_names1 = [i for i in var_names1 if i in om1_names]
            ncol = len(var_names1)
            assert len(var_names1) > 0, \
              (f"None of the variables {var_names1} is contained in variable list "
               f"of OMIC {omic1.name}")
            for name2 in var_names2:
                for name1 in var_names1:
                    all_pairs.append((om1_idx[name1], om2_idx[name2]))
        else:
            # top and bottom correlation pairs
            top = int(top)
            bottom = int(bottom)
            ncol = top + bottom
            # pick all top and bottom of omic1 coordinated to omic2
            for name in var_names2:
                i2 = om2_idx[name]
                pairs = sorted([[sum(corr_map[(i1, i2)]), i1]
                                for i1 in range(len(om1_names))])
                for _, i1 in pairs[-top:][::-1] + pairs[:bottom][::-1]:
                    all_pairs.append((i1, i2))
        ### downsampling scatter points
        if max_scatter_points > 0:
            ids = np.random.permutation(len(X1))[:max_scatter_points]
        else:
            ids = np.arange(len(X1), dtype=np.int32)
        ### plotting
        fig = plt.figure(figsize=(ncol * 2, nrow * 2 + 2), dpi=80)
        for i, pair in enumerate(all_pairs):
            ax = plt.subplot(nrow, ncol, i + 1)
            p, s = corr_map[pair]
            idx1, idx2 = pair
            x1 = X1[:, idx1]
            x2 = X2[:, idx2]
            crow = i // ncol
            ccol = i % ncol
            if is_marker_pairs:
                color = 'salmon' if crow == ccol else 'blue'
            else:
                color = 'salmon' if ccol < top else 'blue'
            vs.plot_scatter(x=x1[ids],
                            y=x2[ids],
                            color=color,
                            ax=ax,
                            size=library[ids],
                            size_range=(6, 30),
                            legend_enable=False,
                            linewidths=0.,
                            cbar=False,
                            alpha=0.3)
            # per-panel title: the omic1 variable name and its correlation values
            ax.set_title(f"{om1_names[idx1]}\n$p={p:.2g}$ $s={s:.2g}$",
                         fontsize=8)
            # beginning of every column
            if i % ncol == 0:
                ax.set_ylabel(f"{om2_names[idx2]}", fontsize=8, weight='bold')
        ## big title
        plt.suptitle(f"[x:{omic1.name}_y:{omic2.name}]{title}", fontsize=10)
        fig.tight_layout(rect=[0.0, 0.02, 1.0, 0.98])
        ### store and return
        if return_figure:
            return fig
        self.add_figure(
            f"corr_{omic1.name}{'log' if log1 else 'raw'}_"
            f"{omic2.name}{'log' if log2 else 'raw'}", fig)
        return self
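# Hypothetical usage of `plot_correlation_scatter`, assuming `sco` is an omic
# container of the kind this method is defined on (e.g. a single-cell multi-omic
# dataset object); the call simply mirrors the defaults documented above.
fig = sco.plot_correlation_scatter(omic1=OMIC.transcriptomic,
                                   omic2=OMIC.proteomic,
                                   var_names1='auto',
                                   var_names2='auto',
                                   max_scatter_points=200,
                                   return_figure=True)
fig.savefig('corr_scatter.png', dpi=120)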
Example #27
0
 def callback():
     trainer = get_current_trainer()
     x, y = x_test[:1000], y_test[:1000]
     px, qz = vae(x, training=False)
     # latents
     qz_mean = tf.reduce_mean(qz.mean(), axis=0)
     qz_std = tf.reduce_mean(qz.stddev(), axis=0)
     w = tf.reduce_sum(decoder.trainable_variables[0], axis=(0, 1, 2))
     # plot the latents and its weights
     fig = plt.figure(figsize=(6, 4), dpi=200)
     ax = plt.gca()
     l1 = ax.plot(qz_mean,
                  label='mean',
                  linewidth=1.0,
                  linestyle='--',
                  marker='o',
                  markersize=4,
                  color='r',
                  alpha=0.5)
     l2 = ax.plot(qz_std,
                  label='std',
                  linewidth=1.0,
                  linestyle='--',
                  marker='o',
                  markersize=4,
                  color='g',
                  alpha=0.5)
     ax1 = ax.twinx()
     l3 = ax1.plot(w,
                   label='weight',
                   linewidth=1.0,
                   linestyle='--',
                   marker='o',
                   markersize=4,
                   color='b',
                   alpha=0.5)
     lines = l1 + l2 + l3
     labs = [l.get_label() for l in lines]
     ax.grid(True)
     ax.legend(lines, labs)
     img_qz = vs.plot_to_image(fig)
     # reconstruction
     fig = plt.figure(figsize=(5, 5), dpi=120)
     vs.plot_images(np.squeeze(px.mean().numpy()[:25], axis=-1),
                    grids=(5, 5))
     img_res = vs.plot_to_image(fig)
     # latents
     fig = plt.figure(figsize=(5, 5), dpi=200)
     z = fast_umap(qz.mean().numpy())
     vs.plot_scatter(z, color=y, size=12.0, alpha=0.4)
     img_umap = vs.plot_to_image(fig)
     # gradients
     grads = [(k, v) for k, v in trainer.last_train_metrics.items()
              if '_grad/' in k]
     encoder_grad = sum(v for k, v in grads if 'Encoder' in k)
     decoder_grad = sum(v for k, v in grads if 'Decoder' in k)
     return dict(reconstruct=img_res,
                 umap=img_umap,
                 latents=img_qz,
                 qz_mean=qz_mean,
                 qz_std=qz_std,
                 w_decoder=w,
                 llk_test=tf.reduce_mean(px.log_prob(x)),
                 encoder_grad=encoder_grad,
                 decoder_grad=decoder_grad)
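# Note: `plot_to_image` converts a matplotlib figure into an image tensor, so
# the dict returned by `callback` can be logged (e.g. to TensorBoard) by the
# surrounding training loop; the exact trainer hookup is assumed, not shown here.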
Example #28
0
def plot_latents_pairs(z: np.ndarray,
                       f: np.ndarray,
                       correlation: np.ndarray,
                       labels: List[str],
                       n_points: int = 1000,
                       seed: int = 1):
    n_latents, n_factors = correlation.shape
    assert z.shape[1] == n_latents
    assert f.shape[1] == n_factors
    assert z.shape[0] == f.shape[0]
    rand = np.random.RandomState(seed=seed)
    ids = rand.permutation(z.shape[0])
    z = np.asarray(z)[ids][:n_points]
    f = np.asarray(f)[ids][:n_points]
    ## find the best latents for each labels
    f2z = {
        f_idx: z_idx
        for f_idx, z_idx in enumerate(np.argmax(correlation, axis=0))
    }
    ## special cases
    selected_labels = set(labels)
    n_pairs = len(selected_labels) * (len(selected_labels) - 1) // 2
    ## plotting each pairs
    ncol = 2
    nrow = n_pairs
    fig = plt.figure(figsize=(ncol * 3.5, nrow * 3))
    c = 1
    styles = dict(size=10,
                  alpha=0.8,
                  color='bwr',
                  cbar=True,
                  cbar_nticks=5,
                  cbar_ticks_rotation=0,
                  cbar_fontsize=8,
                  fontsize=10,
                  grid=False)
    for f1 in range(n_factors):
        for f2 in range(f1 + 1, n_factors):
            if (labels[f1] not in selected_labels
                    or labels[f2] not in selected_labels):
                continue
            z1 = f2z[f1]
            z2 = f2z[f2]
            vs.plot_scatter(x=z[:, z1],
                            y=z[:, z2],
                            val=f[:, f1].astype(np.float32),
                            xlabel=f'Z{z1}',
                            ylabel=f'Z{z2}',
                            cbar_title=labels[f1],
                            ax=(nrow, ncol, c),
                            **styles)
            vs.plot_scatter(x=z[:, z1],
                            y=z[:, z2],
                            val=f[:, f2].astype(np.float32),
                            xlabel=f'Z{z1}',
                            ylabel=f'Z{z2}',
                            cbar_title=labels[f2],
                            ax=(nrow, ncol, c + 1),
                            **styles)
            c += 2
    plt.tight_layout()
    return fig
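# A runnable sketch for `plot_latents_pairs` on synthetic data; the Spearman
# correlation matrix is computed with scipy here, standing in for the library's
# `Correlation.Spearman` helper used elsewhere in this document.
import numpy as np
from scipy.stats import spearmanr

z_demo = np.random.randn(500, 8)                         # latents
f_demo = z_demo[:, :3] + 0.1 * np.random.randn(500, 3)   # correlated factors
corr = spearmanr(z_demo, f_demo).correlation[:8, 8:]     # [n_latents, n_factors]
fig = plot_latents_pairs(z=z_demo,
                         f=f_demo,
                         correlation=corr,
                         labels=['f0', 'f1', 'f2'],
                         n_points=300,
                         seed=1)
fig.savefig('latent_factor_pairs.png', dpi=120)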
Example #29
0
else:
    y_pred = lda.predict_topics(X_test, hard_topics=True, verbose=False)
counts = Counter(y_pred)
y_pred_labels = np.array([f"#{i}({counts[i]})" for i in y_pred])

y_true = np.argmax(y_test, axis=-1)
counts = Counter(y_true)
y_true_labels = np.array([f"{sc.labels[i]}({counts[i]})" for i in y_true])

scores_text = ", ".join([
    f"{key}:{val:.2f}"
    for key, val in unsupervised_clustering_scores(factors=y_true,
                                                   predictions=y_pred).items()
])

fig = plt.figure(figsize=(14, 8))
vs.plot_scatter(x=x_,
                color=y_pred_labels,
                size=12.0,
                ax=(1, 2, 1),
                title="Topics")
vs.plot_scatter(x=x_,
                color=y_true_labels,
                size=12.0,
                ax=(1, 2, 2),
                title="Celltype")
plt.suptitle(f"[{algo}]{os.path.basename(LOGDIR)}\n{scores_text}")
plt.tight_layout(rect=[0.05, 0.0, 1.0, 0.95])
fig.savefig(f"{LOGDIR}.png", dpi=200)
print("Saved image:", f"{LOGDIR}.png")