Ejemplo n.º 1
0
def plot_distro(data, data_descriptor=None, ax=None, grid_columns=None, x_tics=None, y_tics=None, linewidth=2,
                box_plot=False, violin_plot=False, bar_plot=False, point_plot=False, count_plot=False,
                swarm_plot=False, strip_plot=False,
                y_range=None, x_log=False, y_log=False, sep=None, horizontal=False, order=None, hue_order=None,
                box_outlier=5, box_whis=1.5, box_notch=False,
                violin_scale='area', violin_inner='box', violin_cut=2, violin_split=False, violin_scale_hue=False,
                estimator=plot_utils.ESTIMATORS['mean'], ci=95, capsize=0.2,
                strip_jitter=True, points_colour=None,
                point_markers='o', point_marker_size=2,
                title=None, x_label=None, y_label=None, rotate_x_tics=None, bold_x_tics=False,
                hide_x_tick_marks=False, hide_y_tick_marks=False,
                hide_x_ticks=False, hide_y_ticks=False,
                label=None, show_legend=True, legend_out=False, legend_out_pad=None,
                despine=True, style='whitegrid_ticks', fontsize=16, colours=None, palette=None,
                reverse_palette=False, ncolours=None, figsize=None, fig_padding=0.1, dpi=None,
                output=None, out_format=None,
                box_kwargs=None, violin_kwargs=None, bar_kwargs=None, point_kwargs=None, count_kwargs=None,
                swarm_kwargs=None, strip_kwargs=None, legend_kwargs=None, kwargs=None,
                color=None):
    """Draw one or more seaborn distribution plots, then show or save the figure.

    Exactly one data source must be given:

    * ``data`` -- a ``pandas.DataFrame`` with a ``'y'`` column (``'x'`` for a
      count plot) and optional ``'x'`` / ``'hue'`` columns, or a 3-tuple of
      lists ``(x, y, hue)``; for a single plot use ``data=(None, [...], None)``.
    * ``data_descriptor`` -- a filename followed by 0-based column indices
      separated by ``?`` (``file?y``, ``file?x?y``, ``file?x?y?hue``; for a
      count plot ``file?x`` or ``file?x?hue``), or a bare filename of a
      grid-like file used together with ``grid_columns`` and ``x_tics``.

    Plot kinds are selected with the ``*_plot`` flags.  Box, violin, bar,
    swarm and strip plots may be overlaid; count and point plots must be used
    alone.  The shared ``kwargs`` and the per-plot ``*_kwargs`` dictionaries
    are combined with the defaults through ``plot_utils.merged_kwargs``
    (presumably later dictionaries take precedence -- see that helper).

    If ``ax`` is given the plots are drawn on it and its figure is the one
    that is laid out, saved/shown and closed; otherwise a new figure is
    created.  ``color`` is accepted for compatibility but not used.

    When ``output`` is set the figure is written there (``out_format``,
    ``dpi``, ``fig_padding``); otherwise ``plt.show()`` is called.  The
    figure is closed afterwards in both cases.

    Raises:
        PlotsError: for missing, contradictory or malformed arguments.
    """
    if not (box_plot or violin_plot or bar_plot or point_plot or count_plot
            or swarm_plot or strip_plot):
        raise PlotsError(
            message=
            'Specify a plot to plot: box or violin or bar or point or count or swarm or strip'
        )
    if count_plot and (box_plot or violin_plot or bar_plot or point_plot
                       or swarm_plot or strip_plot):
        raise PlotsError(
            message='Count plot cannot be combined with any other plot')
    if point_plot and (box_plot or violin_plot or bar_plot or count_plot
                       or swarm_plot or strip_plot):
        raise PlotsError(
            message='Point plot cannot be combined with any other plot')

    # PARSED DATA IS SUPPLIED AS pd.DataFrame({'x':[] , 'y':[] , 'hue':[] }) OR ([...], [...], [...])
    if data is not None:
        if data_descriptor is not None:
            raise PlotsError(
                message=
                'You can specify only one of the mutually exclusive arguments: "data" or "data_descriptor"'
            )
        if grid_columns is not None:
            raise PlotsError(
                message=
                'The grid_columns option is only used when the data is supplied as a filename'
            )
        if isinstance(data, pd.DataFrame):
            # A count plot needs the x column; every other plot needs y.
            # (Kept separate from the tuple branch so a valid DataFrame is
            # used as-is instead of being treated as an (x, y, hue) tuple.)
            required_column = 'x' if count_plot else 'y'
            if required_column not in data:
                raise PlotsError(
                    message=
                    'The dataframe has to have a "y" column, optionally also "x" and "hue" columns'
                    if not count_plot else
                    'The dataframe has to have a "x" column, optionally also a "hue" column'
                )
        else:
            # ([...], [...], [...])
            if len(data) != 3 or ((data[1] is None and not count_plot) or
                                  (data[0] is None and count_plot)):
                raise PlotsError(
                    message=
                    'The data should be a pandas.DataFrame or a 3-tuple of lists (x, y, hue). '
                    +
                    ('The y list cannot be None, x and hue are optional (can be None).'
                     if not count_plot else
                     'The x list cannot be None, hue is optional (can be None), y is ignored (set it None).'
                     ))
            x, y, hue = data
            data = make_xyhue_dataframe(x, y, hue)

    # DATA IS SUPPLIED AS A FILENAME - THE FILE IS A THREE_COLUMN_FILE OR A GRID_LIKE_FILE
    else:
        if data_descriptor is None:
            raise PlotsError(
                message='You must specify "data" or "data_descriptor"')
        if '?' not in data_descriptor:
            # GRID_LIKE_FILE: the columns are chosen via grid_columns and each
            # one is labelled by the matching entry of x_tics.
            if grid_columns is None:
                raise PlotsError(
                    message=
                    'You must specify columns for y, or x and y, or x, y and hue, or specify grid-columns: filename?y or filename?x?y or filename?x?y?hue or filename --grid_columns x1 x2 x3'
                )
            if x_tics is None:
                raise PlotsError(
                    message=
                    'When specifying --grid_columns, you have to specify also the labels for xtics (--xtics)'
                )
            if len(x_tics) != len(grid_columns):
                raise PlotsError(
                    message=
                    'The number of columns (--grid_columns) differ from the number of xtics (--xtics)'
                )
            x, y, hue = read_grid_as_xyhue(data_descriptor,
                                           grid_columns,
                                           x_tics,
                                           sep=sep,
                                           comment='#')
            data = make_xyhue_dataframe(x, y, hue)
            # The x categories already come from x_tics; clearing it stops the
            # single-category x_tics handling below from running again.
            x_tics = None
        else:
            # THREE_COLUMN_FILE: filename?col[?col[?col]] with 0-based indices.
            if grid_columns is not None:
                raise PlotsError(
                    message=
                    'When specifying columns using ?, you cannot specify --grid_columns.'
                )
            data_descriptor_split = data_descriptor.split('?')
            filename = data_descriptor_split[0]
            if len(data_descriptor_split) == 2:
                columns_names = {
                    'y': int(data_descriptor_split[1])
                } if not count_plot else {
                    'x': int(data_descriptor_split[1])
                }
            elif len(data_descriptor_split) == 3:
                columns_names = {
                    'x': int(data_descriptor_split[1]),
                    'y': int(data_descriptor_split[2])
                } if not count_plot else {
                    'x': int(data_descriptor_split[1]),
                    'hue': int(data_descriptor_split[2])
                }
            elif len(data_descriptor_split) == 4:
                if count_plot:
                    raise PlotsError(
                        message='For count plot, you can only specify x and hue'
                    )
                columns_names = {
                    'x': int(data_descriptor_split[1]),
                    'y': int(data_descriptor_split[2]),
                    'hue': int(data_descriptor_split[3])
                }
            else:
                raise PlotsError(
                    message=
                    'You can specify only up to 3 columns for x, y and hue: filename or filename?y or filename?x?y or filename?x?y?hue'
                )
            # Pair each name with its column index, ordered by column index.
            names, columns = zip(
                *sorted(columns_names.items(), key=lambda item: item[1]))
            data = plot_utils.read_table(filename,
                                         usecols=columns,
                                         names=names,
                                         sep=sep)

    if x_tics is not None:
        if 'x' in data:
            raise PlotsError(
                message=
                'You specified the x-categories in your data, thus you cannot use the xtics option'
            )
        elif len(x_tics) != 1:
            raise PlotsError(
                message=
                'You can specify only one x-category using xtics (unless you use a grid-like file input)'
            )
        else:
            # Work on a copy so a caller-supplied DataFrame is not modified.
            data = data.copy()
            data['x'] = np.array([x_tics[0]] * len(data['y']))

    if not show_legend and (legend_out or legend_kwargs is not None):
        raise PlotsError(
            message=
            'If you hide the legend (--hide_legend or show_legend=False), you cannot set it outside (legend_out) or set it properties (legend_kwargs)'
        )
    if y_range is not None and (len(y_range) != 2 or y_range[0] >= y_range[1]):
        raise PlotsError(
            message=
            'You need to provide exactly two numbers to set yrange: "min max"')
    for my_order, variable in ((order, 'x'), (hue_order, 'hue')):
        if my_order is not None:
            if variable not in data:
                raise PlotsError(
                    message=
                    'You specified order for %s but your data does not contain %s'
                    % (variable, variable))
            set_variable = set(data[variable])
            if len(my_order) != len(set_variable) or set(
                    my_order) != set_variable:
                raise PlotsError(
                    message='The specified order does not match %s' % variable)
    if ci is not None and ci != 'std' and not callable(ci) and (ci < 0
                                                                or ci > 100):
        raise PlotsError(message='"ci" must be None or within 0 and 100')

    if colours is not None:
        if palette is not None:
            raise PlotsError(
                message=
                'You can specify only one of the mutually exclusive arguments: "colours" or "palette"'
            )
        if ncolours is not None:
            raise PlotsError(
                message=
                'You cannot specify "ncolours" when you specified "colours"')
    if figsize is not None and len(figsize) != 2:
        raise PlotsError(
            message=
            'You need to provide exactly two numbers to set figure size: "width height"'
        )

    font = plot_utils.init_plot_style(style,
                                      fontsize,
                                      colours,
                                      palette,
                                      reverse_palette,
                                      ncolours,
                                      hide_x_tick_marks=hide_x_tick_marks,
                                      hide_y_tick_marks=hide_y_tick_marks)
    # Draw into the caller's figure when an axes is supplied; creating a fresh
    # figure here would leave it blank and that blank figure would be saved.
    fig = plt.figure() if ax is None else ax.figure
    if figsize is not None:
        fig.set_figwidth(figsize[0])
        fig.set_figheight(figsize[1])
    if dpi is not None: fig.set_dpi(dpi)

    default_kwargs = {
        'ax': ax,
        'x': 'x' if 'x' in data else None,
        'y': 'y' if 'y' in data else None,
        'hue': 'hue' if 'hue' in data else None,
        'data': data,
        'orient': 'h' if horizontal else 'v',
        'order': order,
        'hue_order': hue_order,
        'linewidth': linewidth
    }

    if box_plot:
        axs = sb.boxplot(**plot_utils.merged_kwargs(
            default_kwargs,
            dict(fliersize=box_outlier,
                 whis=box_whis,
                 notch=box_notch,
                 flierprops={'marker': 'o'}), kwargs, box_kwargs))
    if violin_plot:
        axs = sb.violinplot(**plot_utils.merged_kwargs(
            default_kwargs,
            dict(scale=violin_scale,
                 inner=violin_inner,
                 cut=violin_cut,
                 split=violin_split,
                 scale_hue=violin_scale_hue), kwargs, violin_kwargs))
    if bar_plot:
        axs = sb.barplot(**plot_utils.merged_kwargs(
            default_kwargs, dict(estimator=estimator, ci=ci, capsize=capsize),
            kwargs, bar_kwargs))
    if point_plot:
        axs = sb.pointplot(**plot_utils.merged_kwargs(
            default_kwargs,
            dict(markers=point_markers,
                 estimator=estimator,
                 ci=ci,
                 capsize=capsize), kwargs, point_kwargs))
    if count_plot:
        axs = sb.countplot(
            **plot_utils.merged_kwargs(default_kwargs, kwargs, count_kwargs))
    if swarm_plot:
        axs = sb.swarmplot(**plot_utils.merged_kwargs(
            default_kwargs, dict(edgecolor='black', linewidth=1),
            dict(facecolor=points_colour) if points_colour is not None else {},
            kwargs, swarm_kwargs))
    if strip_plot:
        axs = sb.stripplot(**plot_utils.merged_kwargs(
            default_kwargs,
            dict(edgecolor='black',
                 linewidth=1,
                 jitter=strip_jitter,
                 split=True),
            dict(facecolor=points_colour) if points_colour is not None else {},
            kwargs, strip_kwargs))

    if point_plot:
        plt.setp(axs.collections, sizes=[point_marker_size])
        plt.setp(axs.lines, linewidth=linewidth)
    if y_range is not None: axs.set_ylim(y_range[0], y_range[1])
    axs.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
    if y_tics is not None: axs.set_yticks(y_tics)
    if x_log: axs.set_xscale('log')
    if y_log: axs.set_yscale('log')
    if x_label is not None:
        axs.set_xlabel(x_label, labelpad=8, fontproperties=font.get('b'))
    else:
        axs.set_xlabel('', labelpad=8)
    if y_label is not None:
        axs.set_ylabel(y_label, labelpad=10, fontproperties=font.get('b'))
    else:
        axs.set_ylabel('', labelpad=10)
    if title is not None:
        ttl = axs.set_title(title, fontproperties=font.get('b'))
        ttl.set_position([.5, 1.05])
    if label is not None:
        # label is (text, x, y) with x/y in axes coordinates.
        axs.text(label[1],
                 label[2],
                 label[0],
                 horizontalalignment='left',
                 verticalalignment='top',
                 transform=axs.transAxes,
                 fontproperties=font.get('b'))
    plt.setp(axs.get_xticklabels(), rotation=rotate_x_tics)
    plot_utils.set_fontproperties(font.get('b' if bold_x_tics else 'n'),
                                  axs.get_xticklabels())
    plot_utils.set_fontproperties(font.get('n'), axs.get_yticklabels())
    if hide_x_ticks: axs.xaxis.set_major_locator(ticker.NullLocator())
    if hide_y_ticks: axs.yaxis.set_major_locator(ticker.NullLocator())
    if despine or style in ['despine_ticks', 'whitegrid_ticks']:
        sb.despine(ax=axs, top=True, right=True)

    if show_legend:
        legend_handles, legend_labels = axs.get_legend_handles_labels()
        num_plots = sum([
            box_plot, violin_plot, bar_plot, point_plot, count_plot,
            swarm_plot, strip_plot
        ])
        if num_plots != 1:
            # Presumably every overlaid plot contributes its own copy of the
            # hue entries; keep only the first share.  Integer division is
            # required: a float slice index is a TypeError on Python 3.
            legend_handles = legend_handles[:len(legend_handles) // num_plots]
            legend_labels = legend_labels[:len(legend_labels) // num_plots]
        legend_out_kwargs = dict(
            bbox_to_anchor=(1, 1), loc=2,
            borderpad=legend_out_pad) if legend_out else None
        legend_handles, legend_labels = legend_handles[::-1], legend_labels[::-1]
        axs.legend(
            legend_handles, legend_labels,
            **plot_utils.merged_kwargs(dict(prop=font.get('n')),
                                       legend_out_kwargs, legend_kwargs))

    fig.tight_layout()

    if output is not None:
        # Legends are part of matplotlib's tight bounding box by default, so
        # no extra artists are passed ("additional_artists" is not a savefig
        # argument and recent matplotlib rejects unknown keywords).
        fig.savefig(output,
                    format=out_format,
                    dpi=dpi,
                    bbox_inches='tight',
                    pad_inches=fig_padding)
    else:
        plt.show()
    plt.close(fig)
Ejemplo n.º 2
0
def plot_curves(data, data_descriptors=None, labels=None, linewidth=2, y_tics=None,
                x_range=None, y_range=None, x_log=False, y_log=False, sep='\t',
                title=None, x_label=None, y_label=None, rotate_x_tics=None, bold_x_tics=False,
                hide_x_tick_marks=False, hide_y_tick_marks=False,
                hide_x_ticks=False, hide_y_ticks=False,
                label=None, show_legend=True, legend_out=False, legend_out_pad=None, show_auc=True,
                despine=False,
                style='despine_ticks', font=None, fontsize=16, colours=None, palette=None,
                reverse_palette=False, ncolours=None,
                figsize=None, fig_padding=0.1, dpi=None, output=None, out_format=None,
                legend_kwargs=None, kwargs=None):
    """Plot several x/y curves on one axes, then show or save the figure.

    Args:
        data: a list of curves, each a pair of equally long lists
            ``[([x11, x12, ...], [y11, y12, ...]), ([x21, ...], [y21, ...]), ...]``.
            Mutually exclusive with ``data_descriptors``.
        data_descriptors: strings handed to ``parse_data`` (with ``sep``);
            the total number of ``?`` column markers must be even, since the
            columns are presumably consumed as x/y pairs.
        labels: one legend label per curve; defaults to ``'curve 1'``,
            ``'curve 2'``, ...
        show_auc: append ``' (AUC = ...)'`` computed by ``ml_utils.calc_AUC``
            to every legend label.
        label: free text drawn in the top-left corner of the axes.
        font: accepted for compatibility but ignored; fonts come from
            ``plot_utils.init_plot_style``.
        kwargs: extra keyword arguments for every ``plot`` call, merged with
            the defaults via ``plot_utils.merged_kwargs``.

    When ``output`` is set the figure is written there (``out_format``,
    ``dpi``, ``fig_padding``); otherwise ``plt.show()`` is called.  The
    figure is closed afterwards.

    Raises:
        PlotsError: for missing, contradictory or malformed arguments,
            including unpaired x/y lists and a wrong number of labels.
    """
    if data is not None:
        for x, y in data:
            if len(x) != len(y):
                raise PlotsError(message='unpaired samples detected: x size: ' +
                                 str(len(x)) + ' y size: ' + str(len(y)))
        if data_descriptors is not None:
            raise PlotsError(
                message=
                'You can specify only one of the mutually exclusive arguments: "data" or "data_descriptors"'
            )
    else:
        if data_descriptors is None:
            raise PlotsError(
                message='You must specify "data" or "data_descriptors"')
        total_lists = sum(descriptor.count('?') for descriptor in data_descriptors)
        if total_lists % 2 != 0:
            raise PlotsError(
                message='you need to specify an even number of columns ' +
                str(total_lists))

    if not show_legend and (legend_out or legend_kwargs is not None):
        raise PlotsError(
            message=
            'If you hide the legend (--hide_legend or show_legend=False), you cannot set it outside (legend_out) or set it properties (legend_kwargs)'
        )
    for axis_range, option_name in ((x_range, 'xrange'), (y_range, 'yrange')):
        if axis_range is not None and (len(axis_range) != 2 or
                                       axis_range[0] >= axis_range[1]):
            raise PlotsError(
                message=
                'You need to provide exactly two numbers to set %s: "min max"'
                % option_name)
    if colours is not None and palette is not None:
        raise PlotsError(
            message=
            'You can specify only one of the mutually exclusive arguments: "colours" or "palette"'
        )
    if figsize is not None and len(figsize) != 2:
        raise PlotsError(
            message=
            'You need to provide exactly two numbers to set figure size: "width height"'
        )

    if data is None:
        data = parse_data(data_descriptors, sep=sep)
    if labels is not None and len(data) != len(labels):
        raise PlotsError(message='wrong number of labels: ' + str(len(data)) +
                         ', ' + str(len(labels)))

    # Named "fonts" so it does not masquerade as the (ignored) font argument.
    fonts = plot_utils.init_plot_style(style,
                                       fontsize,
                                       colours,
                                       palette,
                                       reverse_palette,
                                       ncolours,
                                       hide_x_tick_marks=hide_x_tick_marks,
                                       hide_y_tick_marks=hide_y_tick_marks)
    fig = plt.figure()
    if figsize is not None:
        fig.set_figwidth(figsize[0])
        fig.set_figheight(figsize[1])
    if dpi is not None: fig.set_dpi(dpi)
    axs = plt.gca()

    for i, curve in enumerate(data):
        legend_label = labels[i] if labels is not None else 'curve ' + str(i + 1)
        if show_auc:
            # Materialise the pairs: zip() is a one-shot iterator on Python 3,
            # which breaks calc_AUC if it needs len(), indexing or two passes.
            auc = ml_utils.calc_AUC(list(zip(curve[0], curve[1])))
            legend_label += ' (AUC = ' + ('%.3f' % auc) + ')'
        axs.plot(
            curve[0], curve[1],
            **plot_utils.merged_kwargs(
                {
                    'label': legend_label,
                    'linewidth': linewidth
                }, kwargs))
    if x_range is not None: axs.set_xlim(x_range[0], x_range[1])
    if y_range is not None: axs.set_ylim(y_range[0], y_range[1])
    if y_tics is not None: axs.set_yticks(y_tics)
    if x_log: axs.set_xscale('log')
    if y_log: axs.set_yscale('log')
    if x_label is not None:
        axs.set_xlabel(x_label, labelpad=15, fontproperties=fonts.get('b'))
    if y_label is not None:
        axs.set_ylabel(y_label, labelpad=15, fontproperties=fonts.get('b'))
    if title is not None: axs.set_title(title, fontproperties=fonts.get('b'))
    if label is not None:
        axs.text(0.05,
                 0.95,
                 label,
                 horizontalalignment='left',
                 verticalalignment='top',
                 transform=axs.transAxes,
                 fontproperties=fonts.get('n'))
    plt.setp(axs.get_xticklabels(), rotation=rotate_x_tics)
    plot_utils.set_fontproperties(fonts.get('b' if bold_x_tics else 'n'),
                                  axs.get_xticklabels())
    plot_utils.set_fontproperties(fonts.get('n'), axs.get_yticklabels())
    if hide_x_ticks: axs.xaxis.set_major_locator(ticker.NullLocator())
    if hide_y_ticks: axs.yaxis.set_major_locator(ticker.NullLocator())
    if despine or style in ['despine_ticks', 'whitegrid_ticks']:
        sb.despine(top=True, right=True)

    # A single legend call: lower right by default, or anchored outside the
    # axes (upper-left corner at the axes' top-right) when legend_out is set.
    if show_legend:
        legend_out_kwargs = dict(
            bbox_to_anchor=(1, 1), loc=2,
            borderpad=legend_out_pad) if legend_out else None
        axs.legend(**plot_utils.merged_kwargs(
            dict(prop=fonts.get('n'), loc='lower right'), legend_out_kwargs,
            legend_kwargs))

    fig.tight_layout()
    if output is not None:
        # Legends are part of matplotlib's tight bounding box by default;
        # "additional_artists" is not a savefig argument and recent
        # matplotlib rejects unknown keywords.
        fig.savefig(output,
                    format=out_format,
                    dpi=dpi,
                    bbox_inches='tight',
                    pad_inches=fig_padding)
    else:
        plt.show()
    plt.close(fig)
Ejemplo n.º 3
0
def plot_hist(data, data_descriptors=None, labels=None, auto_labels=False, y_tics=None, linewidth=2, nbins=None,
              binw=None, groups=None, factors=None, ncols=None,
              x_range=None, y_range=None, x_log=False, y_log=False, sep='\t', hist=True, kde=True, rug=True,
              normed=True, title=None, x_label=None, y_label=None, rotate_x_tics=None, bold_x_tics=False,
              hide_x_tick_marks=False, hide_y_tick_marks=False,
              hide_x_ticks=False, hide_y_ticks=False,
              label=None, show_binw=True, show_legend=True, legend_out=False, legend_out_pad=None,
              despine=True, style='despine_ticks',
              fontsize=16, colours=None, palette=None, reverse_palette=False, ncolours=None, figsize=None,
              fig_padding=0.1, dpi=None, output=None, out_format=None, legend_kwargs=None, kwargs=None):
    """Draw histograms / KDE curves / rug plots in a grid of subplots.

    ``data`` is a list of subplots, each a list of histograms (value lists)::

        [ [ [1, 1, 2, 2] ], [ [0.1, 0.2, 0.3], [0.3, 0.4, 0.5] ] ]
        --------------------------data----------------------------
          ---subplot_1----  --------------subplot_2-------------
            -histogram-       ---histogram---  ---histogram---

    Alternatively ``data_descriptors`` holds ``filenameA?columnX`` strings
    (0-based columns; a single descriptor may list several ``?`` columns),
    read with ``parse_data``.  For file input, ``groups`` (how many data
    items go into each subplot) and ``factors`` (passed on to
    ``parse_data``; each data item then gets its own subplot) are mutually
    exclusive; neither may be combined with ``data``.

    Other arguments:
        labels / auto_labels: legend label per data item, or labels generated
            by ``parse_data``.  With ``data`` and no labels, entries are
            unlabelled.
        nbins / binw: number of bins or bin width (mutually exclusive), both
            resolved by ``get_bins``.
        hist / kde / rug / normed: which seaborn ``distplot`` layers to draw
            and whether to normalise the histogram.
        ncols: subplot columns; by default up to 3 subplots share one row,
            more are laid out two per row.
        label: text drawn in every subplot's top-left corner; with
            ``show_binw`` the bin width is appended to it.
        kwargs: extra ``distplot`` keyword arguments merged via
            ``plot_utils.merged_kwargs``.

    When ``output`` is set the figure is written there; otherwise
    ``plt.show()`` is called.  The figure is closed afterwards.

    Raises:
        PlotsError: for missing, contradictory or malformed arguments.
        ValueError: if user labels do not match the labels parsed from files.
    """
    if data is not None:
        if data_descriptors is not None:
            raise PlotsError(message='You can specify only one of the mutually exclusive arguments: "data" or "data_descriptors"')
        if groups is not None:
            raise PlotsError(message='You cannot specify "groups" if "data" is specified')
        if factors is not None:
            raise PlotsError(message='You cannot specify "factors" if "data" is specified')
        data_items_count = sum(len(subplot) for subplot in data)
    else:
        if data_descriptors is None:
            raise PlotsError(message='You must specify "data" or "data_descriptors"')
        if len(data_descriptors) == 1:
            data_items_count = data_descriptors[0].count('?')
        else:
            data_items_count = len(data_descriptors)
            for descriptor in data_descriptors:
                if descriptor.count('?') != 1:
                    raise PlotsError(message='data filenames and 0-based columns must be in this format: filenameA?columnX')

    if groups is not None and factors is not None:
        raise PlotsError(message='You can specify only one of the mutually exclusive arguments: "groups" or "factors"')
    if groups is not None and sum(groups) != data_items_count:
        raise PlotsError(message='Number of data items specified as "data_descriptors" must be equal to the sum of counts specified by "groups"')
    if auto_labels and labels is not None:
        raise PlotsError(message='You cannot specify "labels" if "auto_labels" is set')

    if factors is not None:
        if len(factors) == 1:
            factors_count = factors[0].count('?')
        else:
            factors_count = len(factors)
            for factor in factors:
                if factor.count('?') != 1:
                    raise PlotsError(message='factor filenames and 0-based columns must be in this format: filenameA?columnX')
        if factors_count != 1 and factors_count != data_items_count:
            raise PlotsError(message='Number of data items specified as "data_descriptors" must be equal to the number of factors specified as "factors"')

    if labels is not None and factors is None and len(labels) != data_items_count:
        raise PlotsError(message='Number of data items specified as "data_descriptors" must be equal to the number of labels specified as "label"')

    if not show_legend and (legend_out or legend_kwargs is not None):
        raise PlotsError(message='If you hide the legend (--hide_legend or show_legend=False), you cannot set it outside (legend_out) or set it properties (legend_kwargs)')
    for axis_range, option_name in ((x_range, 'xrange'), (y_range, 'yrange')):
        if axis_range is not None and (len(axis_range) != 2 or axis_range[0] >= axis_range[1]):
            raise PlotsError(message='You need to provide exactly two numbers to set %s: "min max"' % option_name)
    if nbins is not None and binw is not None:
        raise PlotsError(message='You can specify only one of the mutually exclusive arguments: "bins" or "binw"')
    if colours is not None:
        if palette is not None:
            raise PlotsError(message='You can specify only one of the mutually exclusive arguments: "colours" or "palette"')
        if ncolours is not None:
            raise PlotsError(message='You cannot specify "ncolours" when you specified "colours"')
    if figsize is not None and len(figsize) != 2:
        raise PlotsError(message='You need to provide exactly two numbers to set figure size: "width height"')
    if not hist and not kde and not rug:
        raise PlotsError(message='You need to plot at least one of a) histogram or b) KDE density or c) a rug plot')

    if y_label is None:
        y_label = 'Density' if kde or normed else 'Frequency'
    if ncols is None:
        if groups is not None:
            ncols = len(groups) if len(groups) < 4 else 2
        elif factors is not None:
            ncols = data_items_count if data_items_count < 4 else 2
        else:
            ncols = 1

    if data is None:
        if groups is not None:
            force_groups = groups
        elif factors is not None:
            force_groups = [1] * data_items_count
        else:
            force_groups = [data_items_count]
        data, parsed_labels = parse_data(data_descriptors, groups=force_groups, factors=factors,
                                         auto_labels=auto_labels, sep=sep)
        if labels is None:
            labels = parsed_labels
        if len(labels) != len(parsed_labels):
            raise ValueError('Number of labels (%d) does not match the number of parsed data items (%d)'
                             % (len(labels), len(parsed_labels)))

    font = plot_utils.init_plot_style(style, fontsize, colours, palette, reverse_palette, ncolours,
                                      hide_x_tick_marks=hide_x_tick_marks, hide_y_tick_marks=hide_y_tick_marks)
    nrows = (len(data) + ncols - 1) // ncols  # ceiling division
    fig, axs = plt.subplots(nrows, ncols, sharex=False, sharey=False, squeeze=False)
    if figsize is not None:
        fig.set_figwidth(figsize[0])
        fig.set_figheight(figsize[1])
    if dpi is not None:
        fig.set_dpi(dpi)

    def draw(target_ax, values, plot_label, plot_bins):
        """Draw one seaborn distplot of ``values`` on ``target_ax``."""
        default_kwargs = dict(a=values, ax=target_ax, hist=hist, kde=kde, rug=rug, norm_hist=normed,
                              label=plot_label)
        sb.distplot(kde_kws={'lw': linewidth}, hist_kws={'linewidth': linewidth},
                    **plot_utils.merged_kwargs(default_kwargs, dict(bins=plot_bins), kwargs))

    base_text = label if label is not None else ''
    label_idx = 0  # index into labels, running across all subplots
    for i, group_data in enumerate(data):
        r, c = divmod(i, ncols)
        ax = axs[r][c]
        panel_text = base_text
        # NOTE(review): the bin width returned for one subplot is passed back
        # as the requested binw for the next one (when nbins is None), so
        # later subplots may share the first computed width -- confirm that
        # this is intended.
        bins, binw = get_bins(group_data, nbins=nbins, binw=binw if nbins is None else None,
                              minimum=x_range[0] if x_range is not None else None,
                              maximum=x_range[1] if x_range is not None else None)
        first_label_idx = label_idx
        try:
            for data_list in group_data:
                draw(ax, data_list, labels[label_idx] if labels is not None else None, bins)
                label_idx += 1
        except ValueError:
            # distplot rejected this group (e.g. the computed bins): discard
            # whatever was already drawn on the panel and redraw every list of
            # the group, each with its own values and label, using 10 bins.
            ax.cla()
            label_idx = first_label_idx
            panel_text += '\nWARNING: number of bins is 10'
            for data_list in group_data:
                draw(ax, data_list, labels[label_idx] if labels is not None else None, 10)
                label_idx += 1

        if x_range is not None: ax.set_xlim(x_range[0], x_range[1])
        if y_range is not None: ax.set_ylim(y_range[0], y_range[1])
        if y_tics is not None: ax.set_yticks(y_tics)
        if x_log: ax.set_xscale('log')
        if y_log: ax.set_yscale('log')
        if x_label is not None: ax.set_xlabel(x_label, labelpad=15, fontproperties=font.get('b'))
        if y_label is not None: ax.set_ylabel(y_label, labelpad=15, fontproperties=font.get('b'))
        if title is not None: ax.set_title(title, fontproperties=font.get('b'))
        if show_binw and binw is not None:
            # Only separate from existing text; an empty label gets no blank lines.
            panel_text += ('\n\n' if panel_text else '') + 'bin width ' + str(binw if nbins is None else round(binw, 4))
        if panel_text:
            ax.text(0.05, 0.95, panel_text, horizontalalignment='left', verticalalignment='top',
                    transform=ax.transAxes, fontproperties=font.get('n'))
        if show_legend:
            # Inside the panel by default, or anchored outside its top-right
            # corner when legend_out is set.
            legend_out_kwargs = dict(bbox_to_anchor=(1, 1), loc=2, borderpad=legend_out_pad) if legend_out else None
            ax.legend(**plot_utils.merged_kwargs(dict(prop=font.get('n')), legend_out_kwargs, legend_kwargs))
        plt.setp(ax.get_xticklabels(), rotation=rotate_x_tics)
        plot_utils.set_fontproperties(font.get('b' if bold_x_tics else 'n'), ax.get_xticklabels())
        plot_utils.set_fontproperties(font.get('n'), ax.get_yticklabels())
        if hide_x_ticks: ax.xaxis.set_major_locator(ticker.NullLocator())
        if hide_y_ticks: ax.yaxis.set_major_locator(ticker.NullLocator())

    # Only the last row can be partially filled; hide its unused panels.
    for empty in range(c + 1, ncols):
        axs[r][empty].axis('off')

    if despine or style in ['despine_ticks', 'whitegrid_ticks']:
        sb.despine(fig=fig, top=True, right=True)

    fig.tight_layout()
    if output is not None:
        # Legends are part of matplotlib's tight bounding box by default;
        # "additional_artists" is not a savefig argument and recent
        # matplotlib rejects unknown keywords.
        fig.savefig(output, format=out_format, dpi=dpi, bbox_inches='tight', pad_inches=fig_padding)
    else:
        plt.show()
    plt.close(fig)
Ejemplo n.º 4
0
def plot_scatter(x, y, data_descriptors=None, linewidth=2, y_tics=None, x_range=None, y_range=None, x_log=False, y_log=False, sep='\t', title=None, x_label=None, y_label=None, rotate_x_tics=None, bold_x_tics=False,\
    hide_x_tick_marks=False, hide_y_tick_marks=False,\
    hide_x_ticks=False, hide_y_ticks=False,\
    auto_labels=False, label=None, show_corr=True, show_legend=False, legend_out=False, legend_out_pad=None, fit_reg=False, ci=None, despine=True, style='despine_ticks',\
    fontsize=16, colours=None, palette=None, reverse_palette=False, ncolours=None, figsize=None, fig_padding=0.1, dpi=None, output=None, out_format=None, legend_kwargs=None, kwargs=None):
    '''
	Draw one or more scatter plots (optionally with a regression fit) on one axes.

	Each (xx, yy) pair taken from zip(x, y) is drawn with seaborn.regplot on the
	current axes of a new figure.

	Parameters
	----------
	x, y : sequences of samples (e.g. [[...], [...]]) of equal length, where
	    x[i] and y[i] must have the same size. Either both are given or both
	    are None; when both are None the data is read with parse_data from
	    data_descriptors.
	data_descriptors : list with either one descriptor containing exactly two
	    '?' columns, or two descriptors containing exactly one '?' column each.
	    Mutually exclusive with x/y.
	auto_labels : take the axis labels from parse_data; cannot be combined with
	    x_label / y_label.
	show_corr : when exactly one sample pair is plotted, print "r = ..." in the
	    top-right corner (value from ml_utils.calc_r).
	ci : None or a number within [0, 100], passed to seaborn.regplot.
	show_legend, legend_out, legend_out_pad, legend_kwargs : legend control;
	    legend_out places the legend to the right, outside the axes.
	x_range, y_range : (min, max) with min < max.
	figsize : (width, height).
	output : path to save the figure to; when None the figure is shown.
	kwargs : extra keyword arguments merged over the regplot defaults.

	Raises
	------
	PlotsError
	    On inconsistent or invalid arguments (see the checks below).
	'''

    if (x is None and y is not None) or (x is not None and y is None):
        raise PlotsError(message='one sample is None: ' +
                         ('x is None' if x is None else 'y is None'))
    elif x is not None and y is not None:
        if len(x) != len(y):
            raise PlotsError(message='Non-matching number of plots')
        # Without this check an empty input would skip the plotting loop and
        # fail later with an UnboundLocalError on `axs`.
        if len(x) == 0:
            raise PlotsError(message='No samples to plot: "x" and "y" are empty')
        for xx, yy in zip(x, y):
            if len(xx) != len(yy):
                raise PlotsError(
                    message='unpaired samples detected: x size: ' +
                    str(len(xx)) + ' y size: ' + str(len(yy)))
        if data_descriptors is not None:
            raise PlotsError(
                message=
                'You can specify only one of the mutually exclusive arguments: "x" and "y" or "data_descriptors"'
            )
    else:
        if data_descriptors is None:
            raise PlotsError(
                message='You must specify "x" and "y" or "data_descriptors"')
        if len(data_descriptors) == 1:
            if data_descriptors[0].count('?') != 2:
                raise PlotsError(
                    message=
                    'When using a single data_descriptor, you need to specify exactly two columns with ?'
                )
        elif len(data_descriptors) == 2:
            if data_descriptors[0].count(
                    '?') != 1 or data_descriptors[1].count('?') != 1:
                raise PlotsError(
                    message=
                    'When using two data_descriptors, you need to specify exactly one column for each with ?'
                )
        else:
            raise PlotsError(
                message=
                'Either specify one data_descriptor with two columns (2 x ?) or two data_descriptors with one column each (1 x ?)'
            )
    if auto_labels and (x_label is not None or y_label is not None):
        raise PlotsError(
            message=
            'You cannot specify "xlabel" or "ylabel" if "auto_labels" is set')
    if ci is not None and (ci < 0 or ci > 100):
        raise PlotsError(message='"ci" must be None or within 0 and 100')

    if not show_legend and (legend_out or legend_kwargs is not None):
        raise PlotsError(
            message=
            'If you hide the legend (--hide_legend or show_legend=False), you cannot set it outside (legend_out) or set it properties (legend_kwargs)'
        )
    for x_y_range in [(x_range, 'xrange'), (y_range, 'yrange')]:
        if x_y_range[0] is not None and (len(x_y_range[0]) != 2 or
                                         x_y_range[0][0] >= x_y_range[0][1]):
            raise PlotsError(
                message=
                'You need to provide exactly two numbers to set %s: "min max"'
                % x_y_range[1])
    if colours is not None and palette is not None:
        raise PlotsError(
            message=
            'You can specify only one of the mutually exclusive arguments: "colours" or "palette"'
        )
    if figsize is not None and len(figsize) != 2:
        raise PlotsError(
            message=
            'You need to provide exactly two numbers to set figure size: "width height"'
        )

    if x is None and y is None:
        x, y, parsed_label_x, parsed_label_y = parse_data(
            data_descriptors, auto_labels=auto_labels, sep=sep)
        if x_label is None: x_label = parsed_label_x
        if y_label is None: y_label = parsed_label_y
        # parse_data yields a single sample pair; wrap it so the plotting loop
        # below handles both input modes uniformly.
        x, y = [x], [y]

    font = plot_utils.init_plot_style(style,
                                      fontsize,
                                      colours,
                                      palette,
                                      reverse_palette,
                                      ncolours,
                                      hide_x_tick_marks=hide_x_tick_marks,
                                      hide_y_tick_marks=hide_y_tick_marks)
    fig = plt.figure()
    if figsize is not None:
        fig.set_figwidth(figsize[0])
        fig.set_figheight(figsize[1])
    if dpi is not None: fig.set_dpi(dpi)
    for xx, yy in zip(x, y):
        default_kwargs = dict(x=np.array(xx),
                              y=np.array(yy),
                              data=None,
                              x_estimator=None,
                              x_bins=None,
                              x_ci='ci',
                              scatter=True,
                              fit_reg=fit_reg,
                              ci=ci,
                              n_boot=1000,
                              units=None,
                              order=1,
                              logistic=False,
                              lowess=False,
                              robust=False,
                              logx=False,
                              x_partial=None,
                              y_partial=None,
                              truncate=False,
                              dropna=False,
                              x_jitter=None,
                              y_jitter=None,
                              label=None,
                              color=None,
                              marker='o',
                              ax=None)
        # ax=None makes every regplot call draw on the same current axes.
        axs = sb.regplot(line_kws={'linewidth': linewidth},
                         **plot_utils.merged_kwargs(default_kwargs, kwargs))
    if x_range is not None: axs.set_xlim(x_range[0], x_range[1])
    if y_range is not None: axs.set_ylim(y_range[0], y_range[1])
    if y_tics is not None: axs.set_yticks(y_tics)
    if x_log: axs.set_xscale('log')
    if y_log: axs.set_yscale('log')
    if x_label is not None:
        axs.set_xlabel(x_label, labelpad=15, fontproperties=font.get('b'))
    if y_label is not None:
        axs.set_ylabel(y_label, labelpad=15, fontproperties=font.get('b'))
    if title is not None: axs.set_title(title, fontproperties=font.get('b'))
    if label is not None:
        axs.text(0.05,
                 0.95,
                 label,
                 horizontalalignment='left',
                 verticalalignment='top',
                 transform=axs.transAxes,
                 fontproperties=font.get('n'))
    if show_corr and len(x) == 1:
        r, _ = ml_utils.calc_r(x[0], y[0])
        # "r" is drawn with the 'i' font and the value with the 'n' font as
        # two separate text objects in the top-right corner.
        axs.text(0.85,
                 0.95,
                 'r',
                 horizontalalignment='left',
                 verticalalignment='top',
                 transform=axs.transAxes,
                 fontproperties=font.get('i'))
        axs.text(0.87,
                 0.95,
                 '= ' + ('%.3f' % r),
                 horizontalalignment='left',
                 verticalalignment='top',
                 transform=axs.transAxes,
                 fontproperties=font.get('n'))
    plt.setp(axs.get_xticklabels(), rotation=rotate_x_tics)
    plot_utils.set_fontproperties(font.get('b' if bold_x_tics else 'n'),
                                  axs.get_xticklabels())
    plot_utils.set_fontproperties(font.get('n'), axs.get_yticklabels())
    if hide_x_ticks: axs.xaxis.set_major_locator(ticker.NullLocator())
    if hide_y_ticks: axs.yaxis.set_major_locator(ticker.NullLocator())
    if despine or style in ['despine_ticks', 'whitegrid_ticks']:
        sb.despine(top=True, right=True)

    #Legend is a little tricky if the option is to be outside of the plot
    artists = []
    if show_legend:
        # Anchor the legend's upper-left corner (loc=2) at the axes'
        # upper-right corner, i.e. just outside the plot area.
        legend_out_kwargs = dict(
            bbox_to_anchor=(1, 1), loc=2,
            borderpad=legend_out_pad) if legend_out else None
        legend = axs.legend(
            **plot_utils.merged_kwargs(dict(
                prop=font.get('n')), legend_out_kwargs, legend_kwargs))
        artists.append(legend)

    plt.tight_layout()
    if output is not None:
        # bbox_extra_artists is the savefig keyword that adds these artists
        # (e.g. an outside legend) to the bbox_inches='tight' computation.
        plt.savefig(output,
                    format=out_format,
                    dpi=dpi,
                    bbox_extra_artists=artists,
                    bbox_inches='tight',
                    pad_inches=fig_padding)
    else:
        plt.show()
    plt.close()