예제 #1
0
def plot_countsum_series(original,
                         imputed,
                         p=None,
                         reduce_axis=0,
                         title=None,
                         ax=None):
    """Plot the log count-sum of the original data against imputed statistics.

    Parameters
    ----------
    original : array-like, [n_samples, n_genes]
        The original count matrix.
    imputed : tuple/list of 3 arrays, or array of shape [3, n_samples, n_genes]
        The prediction as ``(expected, stdev_total, stdev_explained)``, each
        component being [n_samples, n_genes].
    p : array-like, [n_samples, n_genes], optional
        Dropout probability. When given, its mean along ``reduce_axis`` is
        overlaid with ``_show_zero_inflated_pi``.
    reduce_axis : int
        Axis that is summed over before plotting.
    title : str, optional
        Title forwarded to ``plot_series_statistics``.
    ax : optional
        Anything accepted by ``visual.to_axis``.

    Returns
    -------
    The Axes the series was drawn on.

    Raises
    ------
    ValueError
        If ``imputed`` is neither a tuple/list nor a 3-D array.
    """
    # Always normalise ``ax`` (like the sibling plotting helpers) so subplot
    # specifications are converted, not only ``None``.
    ax = visual.to_axis(ax)
    reduce_axis = int(reduce_axis)

    # The three prediction statistics come either as a sequence or stacked
    # along the first axis of a 3-D array.
    if isinstance(imputed, (tuple, list)):
        assert len(imputed) == 3, \
            "imputed must be (expected, stdev_total, stdev_explained)"
        expected, stdev_total, stdev_explained = imputed
    elif getattr(imputed, 'ndim', None) == 3:
        assert imputed.shape[0] == 3, \
            "a 3-D imputed array must have shape [3, n_samples, n_genes]"
        expected, stdev_total, stdev_explained = (imputed[0], imputed[1],
                                                  imputed[2])
    else:
        raise ValueError(
            "imputed must be a (expected, stdev_total, stdev_explained) "
            "tuple/list or a 3-D array, got %s" % type(imputed).__name__)

    count_sum_observed = np.log1p(np.sum(original, axis=reduce_axis))
    count_sum_expected = np.log1p(np.sum(expected, axis=reduce_axis))
    count_sum_stdev_total = np.log1p(np.sum(stdev_total, axis=reduce_axis))
    count_sum_stdev_explained = np.log1p(
        np.sum(stdev_explained, axis=reduce_axis))
    # Dropout probability is averaged (not summed) along the reduced axis.
    p_sum = None if p is None else np.mean(p, axis=reduce_axis)

    ax, handles, indices = plot_series_statistics(
        count_sum_observed,
        count_sum_expected,
        explained_stdev=count_sum_stdev_explained,
        total_stdev=count_sum_stdev_total,
        fontsize=8,
        ax=ax,
        legend_enable=False,
        title=title,
        # keep the spines when a dropout overlay will be added
        despine=p_sum is None,
        return_handles=True,
        return_indices=True,
        xscale='linear',
        yscale='linear',
        sort_by='expected')
    if p_sum is not None:
        _show_zero_inflated_pi(p_sum, ax, handles, indices)
    ax.legend(handles=handles, loc='best', markerscale=4, fontsize=8)
    return ax
예제 #2
0
def plot_latents_binary(Z,
                        y,
                        labels_name,
                        title=None,
                        elev=None,
                        azim=None,
                        algo='tsne',
                        ax=None,
                        show_legend=True,
                        size=12,
                        fontsize=12,
                        show_scores=True,
                        enable_argmax=True,
                        enable_separated=False):
    """Scatter a 2-D embedding of the latents coloured by their labels.

    Parameters
    ----------
    Z : 2-D array, [n_samples, n_latents]
        Latent representation, reduced to 2-D with ``dimension_reduction``.
    y : [n_samples, n_classes] label values, or [n_samples] class indices
        Hard labels are obtained with ``argmax`` over the last axis when 2-D.
    labels_name : sequence of str
        Class names; its length is the number of classes.
    title : str, optional
        Title suffix; it is prefixed with ``[algo]`` and, when
        ``show_scores`` is set, followed by the clustering scores.
    elev, azim : optional
        Accepted for API compatibility; not used by this function.
    algo : str
        Dimension-reduction algorithm name.
    ax : optional
        Anything accepted by ``to_axis``.
    show_legend, size, fontsize : plot options for the main scatter.
    show_scores : bool
        Append ``clustering_scores`` of the embedding to the title.
    enable_argmax : bool
        Draw the main scatter coloured by the hard (argmax) labels.
    enable_separated : bool
        Additionally draw one scatter per class on a new figure, coloured by
        ``K.log_norm`` of that column of ``y`` (requires a 2-D ``y``).

    Returns
    -------
    The Axes of the main scatter plot.
    """
    from matplotlib import pyplot as plt
    if title is None:
        title = ''
    title = '[%s]%s' % (algo, title)
    ax = to_axis(ax)
    # ====== Downsample if the data is huge ====== #
    Z, y = downsample_data(Z, y)
    # ====== checking inputs ====== #
    assert Z.ndim == 2, Z.shape
    assert Z.shape[0] == y.shape[0]
    num_classes = len(labels_name)
    # ====== preprocessing ====== #
    Z = dimension_reduction(Z, algo=algo)
    # hard labels, computed once for both the scores and the scatter plot
    y_argmax = np.argmax(y, axis=-1) if y.ndim == 2 else y
    # ====== clustering metrics ====== #
    if show_scores:
        scores = clustering_scores(latent=Z,
                                   labels=y_argmax,
                                   n_labels=num_classes)
        title += '\n' + ''.join(
            '%s:%.2f ' % (k, v)
            for k, v in sorted(scores.items(), key=lambda x: x[0]))
    # ====== plotting ====== #
    if enable_argmax:
        fast_scatter(x=Z,
                     y=y_argmax,
                     labels=labels_name,
                     ax=ax,
                     size=size,
                     title=title,
                     fontsize=fontsize,
                     enable_legend=bool(show_legend))
        ax.grid(False)
    # ====== plot each protein ====== #
    if enable_separated:
        # NOTE(review): indexes y[:, i], so a 1-D ``y`` fails here.
        colormap = 'Reds'  # bwr
        ncol = 5 if num_classes <= 20 else 9
        nrow = int(np.ceil(num_classes / ncol))
        fig = plot_figure(nrow=4 * nrow, ncol=20)
        for i, lab in enumerate(labels_name):
            val = K.log_norm(y[:, i], axis=0)
            plot_scatter(x=Z[:, 0],
                         y=Z[:, 1],
                         val=val / np.sum(val),
                         ax=(nrow, ncol, i + 1),
                         color=colormap,
                         size=size,
                         alpha=0.8,
                         fontsize=8,
                         grid=False,
                         title=lab)

        plt.grid(False)
        # big title
        plt.suptitle(title, fontsize=fontsize)
        # show the colorbar
        import matplotlib as mpl
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
        # ``pyplot.get_cmap`` is available in both old and new releases.
        cmap = plt.get_cmap(colormap)
        norm = mpl.colors.Normalize(vmin=0., vmax=1.)
        cb1 = mpl.colorbar.ColorbarBase(cbar_ax,
                                        cmap=cmap,
                                        norm=norm,
                                        orientation='vertical')
        cb1.set_label('Protein markers level')
    return ax
예제 #3
0
def plot_distance_heatmap(X,
                          labels,
                          labels_name=None,
                          lognorm=True,
                          colormap='hot',
                          ax=None,
                          legend_enable=True,
                          legend_loc='upper center',
                          legend_ncol=3,
                          legend_colspace=0.2,
                          fontsize=10,
                          show_colorbar=True,
                          title=None):
    r"""Draw the pairwise distance matrix of the samples as an image.

    Rows of ``X`` are length-normalised with ``K.length_norm`` and sorted by
    label. Their pairwise distances are then rendered with ``colormap``. A
    colour strip for the labels is drawn along the top and left edges, and
    the diagonal is painted black.

    Parameters
    ----------
    X : (n_samples, n_features)
      coordination for scatter points

    labels : (n_samples, n_classes) or (n_samples, 1) or (n_samples,)
      list of classes index, in case of binary classification,
      the list can be float value represent confidence value for
      positive class.

    labels_name : (n_classes,), optional
      list of classes' name, this will be used to determine
      number of classes (and to build the legend)

    lognorm : bool
      apply ``log1p`` to the distances before colouring.

    colormap : str
      matplotlib colormap name used for the distances.

    ax : optional
      anything accepted by ``to_axis``.

    legend_enable, legend_loc, legend_ncol, legend_colspace :
      legend options; the legend is only drawn when ``labels_name`` is given.

    fontsize : int
      font size of the title (legend uses ``fontsize - 1``).

    show_colorbar : bool
      draw a colorbar spanning the (clipped) distance range.

    title : str, optional

    Returns
    -------
    The matplotlib Axes that was drawn on.

    Raises
    ------
    ValueError
      if ``labels`` has more than 2 dimensions.

    Example
    -------
    # visualize_distance(latent_scVI, labels, "scVI")
    """
    from matplotlib import pyplot as plt
    from matplotlib.lines import Line2D
    X = K.length_norm(X, axis=-1, epsilon=np.finfo(X.dtype).eps)

    ax = to_axis(ax)
    n_samples, _ = X.shape

    # ====== processing labels ====== #
    labels = np.array(labels)
    if labels.ndim == 2:
        if labels.shape[1] == 1:
            labels = labels.ravel()
        else:
            labels = np.argmax(labels, axis=-1)
    elif labels.ndim > 2:
        raise ValueError("Only support 1-D or 2-D labels")

    labels_int = labels.astype('int32')
    is_continuous = (not np.all(labels_int == labels)
                     or (labels_name is not None and len(labels_name) == 2)
                     or len(np.unique(labels)) == 2)
    if is_continuous:
        # Float-valued labels (confidence) or binary classification.
        # Colormap.__call__ expects floats in [0, 1]; anything outside is
        # mapped to the under/over colour, so rescale to [0, 1].
        float_labels = labels.astype(np.float64)
        min_val = np.min(float_labels)
        value_range = np.max(float_labels) - min_val
        if value_range > 0:
            labels = (float_labels - min_val) / value_range
        else:
            # constant labels: avoid 0/0, use the colormap midpoint
            labels = np.full(labels.shape, 0.5)
        label_colormap = 'bwr'
    else:
        # integer labels are used directly as colormap indices
        labels = labels_int
        label_colormap = 'Dark2'

    # ====== sorting label and X ====== #
    # stable sort keeps the original order of samples sharing a label
    order = np.argsort(labels, kind='stable')
    order_X = np.asarray(X)[order]
    order_label = labels[order][:, np.newaxis]  # (n_samples, 1)

    distance = sp.spatial.distance_matrix(order_X, order_X)
    if bool(lognorm):
        distance = np.log1p(distance)
    # raise exact zeros (the diagonal, duplicates) to the smallest non-zero
    # distance so they don't dominate the colour scale
    non_zero = distance[np.nonzero(distance)]
    if non_zero.size > 0:
        distance = np.maximum(distance, np.min(non_zero))
    d_min, d_max = np.min(distance), np.max(distance)
    # rescale to [0, 1] so the image uses the full colormap and agrees
    # with the colorbar drawn from d_min to d_max
    if d_max > d_min:
        distance_unit = (distance - d_min) / (d_max - d_min)
    else:
        distance_unit = np.zeros_like(distance)

    # ====== convert data to image ====== #
    dist_cm = plt.get_cmap(colormap)
    distance_img = dist_cm(distance_unit)
    # diagonal black line (i.e. zero distance)
    diag = np.arange(n_samples)
    distance_img[diag, diag] = (0, 0, 0, 1)

    label_cm = plt.get_cmap(label_colormap)
    width = max(int(0.032 * n_samples), 8)
    horz_bar = np.repeat(label_cm(order_label.T), repeats=width, axis=0)
    vert_bar = np.repeat(label_cm(order_label), repeats=width, axis=1)

    final_img = np.zeros(shape=(n_samples + width, n_samples + width,
                                distance_img.shape[2]),
                         dtype=distance_img.dtype)
    final_img[width:, width:] = distance_img
    final_img[:width, width:] = horz_bar
    final_img[width:, :width] = vert_bar
    assert np.sum(final_img[:width, :width]) == 0, \
    "Something wrong with my spacial coordination when writing this code!"
    # ====== plotting ====== #
    ax.imshow(final_img)
    ax.axis('off')
    # ====== legend ====== #
    if labels_name is not None and bool(legend_enable):
        labels_name = np.asarray(labels_name)
        if len(labels_name) == 2:  # binary: the two ends of the label scale
            all_colors = np.array(
                (label_cm(np.min(labels)), label_cm(np.max(labels))))
        else:  # multiple classes
            all_colors = label_cm(list(range(len(labels_name))))
        legend_elements = [
            Line2D([0], [0],
                   marker='o',
                   color=color,
                   label=name,
                   linewidth=0,
                   linestyle=None,
                   markerfacecolor=color,
                   markersize=fontsize // 2)
            for color, name in zip(all_colors, labels_name)
        ]
        ax.legend(handles=legend_elements,
                  markerscale=1.,
                  scatterpoints=1,
                  scatteryoffsets=[0.375, 0.5, 0.3125],
                  loc=legend_loc,
                  bbox_to_anchor=(0.5, -0.01),
                  ncol=int(legend_ncol),
                  columnspacing=float(legend_colspace),
                  labelspacing=0.,
                  fontsize=fontsize - 1,
                  handletextpad=0.1)
    # ====== final configurations ====== #
    if title is not None:
        ax.set_title(str(title), fontsize=fontsize)
    if show_colorbar:
        plot_colorbar(colormap,
                      vmin=d_min,
                      vmax=d_max,
                      ax=ax,
                      orientation='vertical')
    return ax
예제 #4
0
    def plot_scatter(
        self,
        factor_index: Union[int, str],
        classifier: Optional[Literal['svm', 'tree', 'logistic', 'knn', 'lda',
                                     'gbt']] = None,
        classifier_kw: Optional[Dict[str, Any]] = None,
        dimension_reduction: Literal['pca', 'umap', 'tsne', 'knn',
                                     'kmean'] = 'tsne',
        max_samples: Optional[int] = 2000,
        return_figure: bool = False,
        ax: Optional['Axes'] = None,
        seed: int = 1,
    ) -> Union['Figure', Posterior]:
        """Plot dimension reduced scatter points of the sample set.

        Points are coloured by the values of one factor. Optionally, the
        decision regions of a classifier fitted on the 2-D embedding are
        drawn on top.

        Parameters
        ----------
        factor_index : Union[int, str]
            index of the factor, or its name (looked up in
            ``self.factor_names``), used to colour the points
        classifier : {'svm', 'tree', 'logistic', 'knn', 'lda', 'gbt'}, optional
            classifier for plotting decision contour of the factor,
            by default None
        classifier_kw : Dict[str, Any], optional
            keyword arguments for the classifier, by default None
            (no extra arguments)
        dimension_reduction : {'pca', 'umap', 'tsne', 'knn', 'kmean'}, optional
            method for dimension reduction, by default 'tsne'
        max_samples : Optional[int], optional
            maximum number of samples to be plotted (randomly chosen using
            ``seed``), by default 2000; a non-number plots every sample
        return_figure : bool, optional
            return the figure or add it to the Visualizer for later processing,
            by default False
        ax : Optional[Axes], optional
            axes to draw on (anything accepted by ``vs.to_axis``),
            by default None
        seed : int, optional
            seed for random state, by default 1

        Returns
        -------
        Figure or Posterior
            return a `matplotlib.pyplot.Figure` if `return_figure=True` else
            return self for method chaining.
        """
        # avoid a shared mutable default argument
        if classifier_kw is None:
            classifier_kw = {}
        ## resolve the factor
        if isinstance(factor_index, string_types):
            factor_index = self.factor_names.index(factor_index)
        factor_index = int(factor_index)
        f = self.factors[:, factor_index]
        name = self.factor_names[factor_index]
        ## reduce latents dimension
        z = self.dimension_reduce(algorithm=dimension_reduction, seed=seed)
        # bounds of the classifier grid come from the full embedding,
        # i.e. before downsampling
        x_min, x_max = np.min(z[:, 0]), np.max(z[:, 0])
        y_min, y_max = np.min(z[:, 1]), np.max(z[:, 1])
        ## downsample
        if isinstance(max_samples, Number):
            max_samples = int(max_samples)
            if max_samples < z.shape[0]:
                rand = np.random.RandomState(seed=seed)
                ids = rand.choice(np.arange(z.shape[0], dtype=np.int32),
                                  size=max_samples,
                                  replace=False)
                z = z[ids]
                f = f[ids]
        n_samples = z.shape[0]
        ## plotting
        ax = vs.to_axis(ax, is_3D=False)
        cmap = 'bwr'
        # scatter plot
        vs.plot_scatter(x=z, val=f, color=cmap, ax=ax, size=10., alpha=0.5)
        ax.grid(False)
        ax.tick_params(axis='both',
                       bottom=False,
                       top=False,
                       left=False,
                       right=False)
        ax.set_title(name)
        # classifier boundary, evaluated on an n_samples x n_samples grid
        if classifier is not None:
            xx, yy = np.meshgrid(np.linspace(x_min, x_max, n_samples),
                                 np.linspace(y_min, y_max, n_samples))
            xy = np.c_[xx.ravel(), yy.ravel()]
            model = linear_classifier(z,
                                      f,
                                      algo=classifier,
                                      seed=seed,
                                      **classifier_kw)
            ax.contourf(xx,
                        yy,
                        model.predict(xy).reshape(xx.shape),
                        cmap=cmap,
                        alpha=0.4)
        # the figure owning ``ax`` (which need not be the current figure)
        fig = ax.get_figure()
        if return_figure:
            return fig
        return self.add_figure(
            name=f'scatter_{dimension_reduction}_{str(classifier).lower()}',
            fig=fig)
예제 #5
0
def plot_countsum_comparison(original,
                             reconstructed,
                             imputed,
                             title,
                             comparing_axis=0,
                             ax=None):
    """Scatter log count-sums of reconstructed and imputed data against the
    original (corrupted) data, with dashed median lines and an identity line.

    Parameters
    ----------
    original : [n_samples, n_genes]
        The corrupted count matrix (x-axis).
    reconstructed : [n_samples, n_genes]
        Reconstruction; reduced with ``_mean`` before summing.
    imputed : [n_samples, n_genes]
        Imputation; reduced with ``_mean`` before summing.
    title : str
        Axes title.
    comparing_axis : int
        Axis summed over before comparing.
    ax : optional
        Anything accepted by ``visual.to_axis``.

    Returns
    -------
    The Axes that was drawn on.
    """
    ax = visual.to_axis(ax)

    original = original.sum(axis=comparing_axis)
    reconstructed = _mean(reconstructed).sum(axis=comparing_axis)
    imputed = _mean(imputed).sum(axis=comparing_axis)
    assert original.shape == reconstructed.shape == imputed.shape

    sorted_indices = np.argsort(original)

    original = np.log1p(original[sorted_indices])
    reconstructed = np.log1p(reconstructed[sorted_indices])
    imputed = np.log1p(imputed[sorted_indices])

    # ====== plotting the figures ====== #
    colors = seaborn.color_palette(palette='Set2', n_colors=3)

    # ``color=`` (not ``c=``) so an RGB tuple is never mistaken for values
    # to colour-map when there happen to be exactly 3 points
    ax.scatter(original, imputed, color=colors[1], s=3, alpha=0.3)
    ax.scatter(original, reconstructed, color=colors[2], s=3, alpha=0.3)
    # ====== plotting the median line ====== #
    # shared upper limit so the identity line is the square's diagonal
    _, xmax = ax.get_xlim()
    _, ymax = ax.get_ylim()
    max_val = max(xmax, ymax)

    # axhline's xmin/xmax are axes fractions in [0, 1], not data values;
    # the defaults already span the full width of the axes.
    medians = (("Corrupted Median", original, colors[0]),
               ("Imputed Median", imputed, colors[1]),
               ("Reconstructed Median", reconstructed, colors[2]))
    for label, values, color in medians:
        ax.axhline(y=np.median(values),
                   color=color,
                   linestyle='--',
                   linewidth=1.5,
                   label=label)
    # ====== adjust the aspect ====== #
    visual.plot_aspect(aspect='equal', adjustable='box', ax=ax)
    ax.set_xlim((0, max_val))
    ax.set_ylim((0, max_val))

    ax.plot((0, max_val), (0, max_val),
            color='black',
            linestyle=':',
            linewidth=1)
    # attach the legend to ``ax`` rather than whatever axes is current
    ax.legend(fontsize=8)
    ax.set_xlabel("Log-Count of the corrupted data")
    ax.set_ylabel("Log-Count of the reconstructed and imputed data")
    ax.set_title(title)

    return ax