Exemplo n.º 1
0
 def _update_plot(self, axis, view):
     """Draw ``view`` with the seaborn call selected by ``self.plot_type``.

     Handled plot types are ``'regplot'``, ``'boxplot'``, ``'violinplot'``,
     ``'interact'``, ``'corrplot'``, ``'lmplot'``, ``'pairplot'``,
     ``'pairgrid'`` and ``'facetgrid'``; any other value is delegated to
     the parent class implementation.

     Parameters
     ----------
     axis : object
         Forwarded as ``ax=`` to the seaborn calls that take an axes.
     view : object
         Must expose ``data`` and, depending on the plot type, the column
         names ``x``, ``y`` and ``x2``.

     Notes
     -----
     For the grid types every style key containing ``'map'`` is removed
     from the keyword arguments and treated as the name of a method on the
     resulting grid. Its value is a sequence whose first item names a
     plotting function (looked up on ``sns`` first, then on ``plt``) and
     whose remaining items are passed to the grid method as positional
     arguments. Afterwards the previous ``self.handles['fig']`` is closed
     and replaced by the current pyplot figure.
     """
     # Work on a shallow copy: the pops below must not strip options from
     # the shared ``self.style`` (which would lose them on the next redraw).
     style = dict(self.style)
     if self.plot_type == 'regplot':
         sns.regplot(x=view.x,
                     y=view.y,
                     data=view.data,
                     ax=axis,
                     **style)
     elif self.plot_type == 'boxplot':
         # These two options are dropped before calling seaborn's boxplot.
         style.pop('return_type', None)
         style.pop('figsize', None)
         sns.boxplot(view.data[view.y],
                     view.data[view.x],
                     ax=axis,
                     **style)
     elif self.plot_type == 'violinplot':
         sns.violinplot(view.data[view.y],
                        view.data[view.x],
                        ax=axis,
                        **style)
     elif self.plot_type == 'interact':
         sns.interactplot(view.x,
                          view.x2,
                          view.y,
                          data=view.data,
                          ax=axis,
                          **style)
     elif self.plot_type == 'corrplot':
         sns.corrplot(view.data, ax=axis, **style)
     elif self.plot_type == 'lmplot':
         # NOTE(review): confirm that the targeted seaborn version accepts
         # ``ax`` for lmplot.
         sns.lmplot(x=view.x,
                    y=view.y,
                    data=view.data,
                    ax=axis,
                    **style)
     elif self.plot_type in ['pairplot', 'pairgrid', 'facetgrid']:
         # Snapshot the keys first: popping while iterating the live key
         # view raises RuntimeError on Python 3.
         map_opts = [(k, style.pop(k)) for k in list(style.keys())
                     if 'map' in k]
         if self.plot_type == 'pairplot':
             g = sns.pairplot(view.data, **style)
         elif self.plot_type == 'pairgrid':
             g = sns.PairGrid(view.data, **style)
         elif self.plot_type == 'facetgrid':
             g = sns.FacetGrid(view.data, **style)
         for opt, args in map_opts:
             plot_fn = getattr(sns, args[0]) if hasattr(
                 sns, args[0]) else getattr(plt, args[0])
             getattr(g, opt)(plot_fn, *args[1:])
         # The grid was drawn on a figure of its own, so swap the handle.
         plt.close(self.handles['fig'])
         self.handles['fig'] = plt.gcf()
     else:
         super(SNSFramePlot, self)._update_plot(axis, view)
Exemplo n.º 2
0
 def _update_plot(self, axis, view):
     """Render ``view`` through the seaborn routine named by ``self.plot_type``.

     The keyword arguments for the seaborn call come from
     ``self._process_style`` applied to the style entry at
     ``self.cyclic_index``. Plot types not listed below are handed to the
     parent class implementation.

     Parameters
     ----------
     axis : object
         Forwarded as ``ax=`` to the seaborn calls that take an axes.
     view : object
         Must expose ``data`` and, depending on the plot type, the column
         names ``x``, ``y`` and ``x2``.
     """
     opts = self._process_style(self.style[self.cyclic_index])
     kind = self.plot_type

     if kind == 'factorplot':
         # ``x2`` becomes the hue variable when it is set.
         factor_kwargs = dict(opts)
         if view.x2:
             factor_kwargs['hue'] = view.x2
         sns.factorplot(x=view.x, y=view.y, data=view.data, **factor_kwargs)
         return

     if kind == 'regplot':
         sns.regplot(x=view.x, y=view.y, data=view.data, ax=axis, **opts)
         return

     if kind == 'boxplot':
         for dropped in ('return_type', 'figsize'):
             opts.pop(dropped, None)
         sns.boxplot(view.data[view.y], view.data[view.x], ax=axis, **opts)
         return

     if kind == 'violinplot':
         # Without an x column the whole frame is plotted.
         positional = ((view.data[view.y], view.data[view.x])
                       if view.x else (view.data,))
         sns.violinplot(*positional, ax=axis, **opts)
         return

     if kind == 'interact':
         sns.interactplot(view.x, view.x2, view.y,
                          data=view.data, ax=axis, **opts)
         return

     if kind == 'corrplot':
         sns.corrplot(view.data, ax=axis, **opts)
         return

     if kind == 'lmplot':
         sns.lmplot(x=view.x, y=view.y, data=view.data, ax=axis, **opts)
         return

     # Grid-style plots: plot type -> attribute of ``sns`` that builds it.
     grid_builders = {
         'pairplot': 'pairplot',
         'pairgrid': 'PairGrid',
         'facetgrid': 'FacetGrid',
     }
     if kind not in grid_builders:
         super(SNSFramePlot, self)._update_plot(axis, view)
         return

     # Keys containing 'map' name grid methods; take them out of the
     # constructor kwargs (iterating a snapshot of the keys).
     mapped = [(key, opts.pop(key))
               for key in list(opts.keys()) if 'map' in key]
     grid = getattr(sns, grid_builders[kind])(view.data, **opts)
     for method_name, spec in mapped:
         fn_name = spec[0]
         provider = sns if hasattr(sns, fn_name) else plt
         plot_fn = getattr(provider, fn_name)
         getattr(grid, method_name)(plot_fn, *spec[1:])

     # Replace the stored figure handle with the current pyplot figure.
     plt.close(self.handles['fig'])
     self.handles['fig'] = plt.gcf()
Exemplo n.º 3
0
def plot_stackdist(data, size=.5, aspect=12, x_labels=None,
                   y_labels=None, palette=None, g=None):
    """
    Plot stacked distributions (a.k.a Joy plot) of data.

    Parameters
    ----------

    data : list of ndarrays
        A list of 2D numpy arrays with dimensions (n_observations, n_features)
    size : scalar, optional (default: .5)
        Height (in inches) of each facet. See also: ``aspect``.
    aspect : scalar, optional (default: 12)
        Aspect ratio of each facet, so that ``aspect * size`` gives the width
        of each facet in inches.
    x_labels : list, optional (default: None)
        A list of str labels for feature-axis
    y_labels : list, optional (default: None)
        A list of str labels for y-axis
    palette : list of colors, optional (default: None)
        List of colors for plots. If ``None``, the default MSMExplorer colors
        are used.
    g : Seaborn.FacetGrid, optional (default: None)
        Pre-initialized FacetGrid to use for plotting. When given it is used
        as-is (``size``, ``aspect`` and ``palette`` are then ignored); its
        data must provide an ``"x"`` column and it should have one facet row
        per array in ``data`` and one facet column per feature.

    Returns
    -------
    g : Seaborn.FacetGrid
        Seaborn FacetGrid of the stacked distributions.

    Raises
    ------
    ValueError
        If ``data`` is empty.

    """

    if len(data) == 0:
        raise ValueError('data must contain at least one array')

    n_sets = len(data)
    n_feat = data[0].shape[1]

    if g is None:
        # Long-form table: one row per value, tagged with its feature index
        # ("f") and the index of the array it came from ("g").
        x = np.concatenate([d.ravel() for d in data], axis=0)
        feature_idx = np.concatenate(
            [(np.ones_like(d) * np.arange(d.shape[1])).ravel()
             for d in data], axis=0)
        set_idx = np.concatenate([i * np.ones_like(d.ravel()).astype(int)
                                  for i, d in enumerate(data)], axis=0)
        df = pd.DataFrame(dict(x=x, f=feature_idx, g=set_idx))

        if not palette:
            palette = list(palettes.msme_rgb.values())[::-1]

        # Initialize the FacetGrid object
        # NOTE(review): ``size`` was renamed in later seaborn releases --
        # confirm against the seaborn version this project targets.
        g = sns.FacetGrid(df, row="g", col="f", hue="f",
                          aspect=aspect, size=size, palette=palette)

    # Draw the densities in a few steps.
    # Facet position of the next kdeplot call, advanced once per call.
    # Kept in a closure-local dict instead of module globals so that calls
    # neither leak state into the module nor interfere with each other.
    position = {'row': 0, 'col': 0}

    def kdeplot(x, color='w', **kwargs):
        if color != 'w':
            # Shade the feature colour by row so stacked rows stay distinct.
            color = sns.light_palette(
                color, n_colors=n_sets + 1)[position['row'] + 1]
        sns.kdeplot(x, color=color, **kwargs)

        position['col'] = (position['col'] + 1) % n_feat

        if position['col'] == 0:
            position['row'] = (position['row'] + 1) % n_sets

    g.map(kdeplot, "x", clip_on=False, shade=True, alpha=1., bw=.2)
    g.map(kdeplot, "x", clip_on=False, color='w', lw=2, bw=.2)

    # Add y labels
    g.set_titles("")
    g.set_xlabels("")
    for i, ax in enumerate(g.axes):
        if y_labels is not None:
            ax[0].text(0, .2, y_labels[i], fontweight="bold", color='k',
                       ha="left", va="center", transform=ax[0].transAxes)
        for j, a in enumerate(ax):
            a.set_facecolor((0, 0, 0, 0))
            if i == 0 and x_labels is not None:
                a.set_title(x_labels[j])

    # Set the subplots to overlap
    g.fig.subplots_adjust(hspace=-.25, wspace=0.1)

    # Remove axes details that don't play well with overlap
    g.set(yticks=[])
    g.set(xticks=[])
    g.despine(bottom=False, left=True)
    return g
Exemplo n.º 4
0
def roi_distributions(
    df_path,
    ascending=False,
    cmap='viridis',
    exclude_tissue_type=None,
    max_rois=7,
    save_as=False,
    small_roi_cutoff=8,
    start=0.0,
    stop=1.0,
    text_side='left',
    xlim=None,
    ylim=None,
):
    """Plot the distributions of values inside 3D image regions of interest.

    One density is drawn per 'Structure' value, stacked vertically with
    overlap. The function returns nothing; as a side effect it sets the
    global matplotlib rcParams ``xtick.major.size``, ``ytick.major.size``
    and ``axes.facecolor``.

    Parameters
    ----------

    df_path : str
        Path to a CSV file (read with `pandas.read_csv`) containing a
        'value' and a 'Structure' column, and a 'tissue type' column if
        `exclude_tissue_type` is used.
    ascending : boolean, optional
        Whether to plot the ROI distributions from lowest to highest mean
        (if `False` the ROI distributions are plotted from highest to lowest mean).
    cmap : string, optional
        Name of matplotlib colormap which to color the plot array with.
    exclude_tissue_type : list, optional
        What tissue types to discount from plotting.
        Values in this list will be checked on the 'tissue type' column of `df`.
        This is commonly used to exclude cerebrospinal fluid ROIs from plotting.
    max_rois : int, optional
        How many ROIs to limit the plot to.
        A falsy value (`None` or 0) disables the limit; the figure is then
        sized according to the number of ROIs actually plotted.
    save_as : str, optional
        Path to save the figure to.
    small_roi_cutoff : int, optional
        Minimum number of rows per 'Structure' value required to add the respective 'Structure' value to the plot
        (this corresponds to the minimum number of voxels which a ROI needs to have in order to be included in the plot).
    start : float, optional
        At which fraction of the colormap to start.
    stop : float, optional
        At which fraction of the colormap to stop.
    text_side : {'left', 'right'}, optional
        Which side of the plot to set the `df` 'Structure'-column values on.
    xlim : list, optional
        X-axis limits, passed to `seaborn.FacetGrid()`
    ylim : list, optional
        Y-axis limits, passed to `seaborn.FacetGrid()`

    Raises
    ------
    ValueError
        If `text_side` is neither 'left' nor 'right'.
    """

    # Validate before touching any global state or reading the file.
    if text_side not in ('left', 'right'):
        raise ValueError(
            "text_side must be 'left' or 'right', got {!r}".format(text_side)
        )

    mpl.rcParams["xtick.major.size"] = 0.0
    mpl.rcParams["ytick.major.size"] = 0.0
    mpl.rcParams["axes.facecolor"] = (0, 0, 0, 0)

    df_path = path.abspath(path.expanduser(df_path))

    df = pd.read_csv(df_path)
    if small_roi_cutoff:
        # Drop every structure that has fewer rows than the cutoff.
        counts = df['Structure'].value_counts()
        undersized = counts[counts < small_roi_cutoff].index
        df = df[~df['Structure'].isin(undersized)]
    df['mean'] = df.groupby('Structure')['value'].transform('mean')
    df = df.sort_values(['mean'], ascending=ascending)
    if exclude_tissue_type:
        df = df[~df['tissue type'].isin(exclude_tissue_type)]
    if max_rois:
        # Rows are sorted by mean, so the first unique names are the
        # top-ranked structures.
        keep = list(df['Structure'].unique())[:max_rois]
        df = df[df['Structure'].isin(keep)]
    structures = list(df['Structure'].unique())

    # Define colors
    cm_subsection = np.linspace(start, stop, len(structures))
    colormap = plt.get_cmap(cmap)
    pal = [colormap(x) for x in cm_subsection]

    # Facet sizing is based on `max_rois`; without a limit fall back to the
    # number of structures (at least 1, to keep the division defined).
    n_rois = max_rois if max_rois else max(len(structures), 1)

    # Initialize the FacetGrid object
    # NOTE(review): ``size`` was renamed in later seaborn releases --
    # confirm against the seaborn version this project targets.
    figsize = mpl.rcParams['figure.figsize']
    ratio = figsize[0] / float(figsize[1])
    g = sns.FacetGrid(
        df,
        row='Structure',
        hue='Structure',
        aspect=n_rois * ratio,
        size=figsize[1] / n_rois,
        palette=pal,
        xlim=xlim,
        ylim=ylim,
    )

    # Draw the densities in a few steps: filled density, white outline on
    # top, then a baseline at y=0.
    lw = mpl.rcParams['lines.linewidth']
    g.map(sns.kdeplot,
          'value',
          clip_on=False,
          gridsize=500,
          shade=True,
          alpha=1,
          lw=lw / 4. * 3,
          bw=.2)
    g.map(sns.kdeplot,
          'value',
          clip_on=False,
          gridsize=500,
          color="w",
          lw=lw,
          bw=.2)
    g.map(plt.axhline, y=0, lw=lw, clip_on=False)

    # Horizontal anchor (axes coordinates) of the structure name.
    anchor_x = 0 if text_side == 'left' else 1

    # Define and use a simple function to label the plot in axes coordinates
    def label(x, color, label):
        ax = plt.gca()
        text = ax.text(
            anchor_x,
            .04,
            label,
            fontweight="bold",
            color=color,
            ha=text_side,
            va="bottom",
            transform=ax.transAxes,
        )
        # White stroke behind the text keeps it legible over the densities.
        text.set_path_effects([
            path_effects.Stroke(linewidth=lw, foreground='w'),
            path_effects.Normal()
        ])

    g.map(label, 'value')

    # Set the subplots to overlap
    g.fig.subplots_adjust(hspace=-.25)

    # Remove axes details that don't play well with overlap
    g.set_titles("")
    g.set(yticks=[])
    g.despine(bottom=True, left=True)

    if save_as:
        save_as = path.abspath(path.expanduser(save_as))
        plt.savefig(save_as)