Exemplo n.º 1
0
def _finish_joint_plot(plot_obj, joint_grid, names, title, log, maxvalx, figformat):
    """Label, title and save one joint plot; shared by the hex, dot and kde variants.

    When ``log`` is truthy the x values are treated as log10-transformed:
    ticks are placed at the log10 position of every power of ten
    (10**0 .. 10**9) that does not exceed ten times 10**maxvalx, and are
    labelled with the untransformed number.
    """
    joint_grid.set_axis_labels(names[0], names[1])
    if log:
        plot_obj.title = plot_obj.title + " after log transformation of read lengths"
        upper = 10 * (10**maxvalx)
        ticks = [10**i for i in range(10) if not 10**i > upper]
        joint_grid.ax_joint.set_xticks(np.log10(ticks))
        joint_grid.ax_marg_x.set_xticks(np.log10(ticks))
        joint_grid.ax_joint.set_xticklabels(ticks)
    # Leave room above the axes for the figure-level title.
    plt.subplots_adjust(top=0.90)
    joint_grid.fig.suptitle(title or "{} vs {} plot".format(names[0], names[1]), fontsize=25)
    plot_obj.fig = joint_grid
    plot_obj.save(format=figformat)


def scatter(x, y, names, path, plots, color="#4CB391", figformat="png",
            stat=None, log=False, minvalx=0, minvaly=0, title=None,
            plot_settings=None, xmax=None, ymax=None):
    """Create bivariate plots.

    Create up to four types of bivariate plots of x vs y, containing marginal
    summaries:
    -A scatter plot with histograms on axes
    -A hexagonal binned plot with histograms on axes
    -A kernel density plot with density curves on axes
    -A pauvre-style plot using code from https://github.com/conchoecia/pauvre

    Parameters:
    x, y -- the data to plot. ``x`` must expose ``.size`` and ``.index`` and
        both must support indexing by index labels (presumably pandas
        Series -- confirm against callers).
    names -- two labels, used for the axes, titles and log messages. The
        pauvre plot is only made when this equals the list
        ['Read lengths', 'Average read quality'].
    path -- prefix of the output files; "_hex.", "_dot.", "_kde." or
        "_pauvre." plus ``figformat`` is appended.
    plots -- mapping with the keys "hex", "dot", "kde" and "pauvre"; a truthy
        value enables that plot type.
    color -- colour passed to seaborn.
    figformat -- output file format / extension.
    stat -- passed to ``sns.jointplot`` as ``stat_func``.
    log -- when truthy, x is treated as log10-transformed and the x ticks are
        relabelled with untransformed values. The pauvre plot is skipped
        unless ``log is False``.
    minvalx, minvaly -- lower axis limits.
    title -- figure title; a default is generated when falsy.
    plot_settings -- extra keyword arguments for ``sns.set``; ``None`` means
        no extra settings.
    xmax, ymax -- upper axis limits; a falsy value (``None`` or 0) falls back
        to the maximum of the data.

    Returns a list of the Plot objects created, which is empty when
    ``contains_variance`` rejects the data.

    NOTE(review): ``stat_func`` and ``shade_lowest`` are believed to be
    deprecated or removed in newer seaborn releases -- confirm against the
    pinned seaborn version.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if plot_settings is None:
        plot_settings = {}
    logging.info("Nanoplotter: Creating {} vs {} plots using statistics from {} reads.".format(
        names[0], names[1], x.size))
    if not contains_variance([x, y], names):
        return []
    sns.set(style="ticks", **plot_settings)
    maxvalx = xmax or np.amax(x)
    maxvaly = ymax or np.amax(y)

    plots_made = []

    if plots["hex"]:
        hex_plot = Plot(
            path=path + "_hex." + figformat,
            title="{} vs {} plot using hexagonal bins".format(names[0], names[1]))
        plot = sns.jointplot(
            x=x,
            y=y,
            kind="hex",
            color=color,
            stat_func=stat,
            space=0,
            xlim=(minvalx, maxvalx),
            ylim=(minvaly, maxvaly),
            height=10)
        _finish_joint_plot(hex_plot, plot, names, title, log, maxvalx, figformat)
        plots_made.append(hex_plot)

    sns.set(style="darkgrid", **plot_settings)
    if plots["dot"]:
        dot_plot = Plot(
            path=path + "_dot." + figformat,
            title="{} vs {} plot using dots".format(names[0], names[1]))
        plot = sns.jointplot(
            x=x,
            y=y,
            kind="scatter",
            color=color,
            stat_func=stat,
            xlim=(minvalx, maxvalx),
            ylim=(minvaly, maxvaly),
            space=0,
            height=10,
            joint_kws={"s": 1})
        _finish_joint_plot(dot_plot, plot, names, title, log, maxvalx, figformat)
        plots_made.append(dot_plot)

    if plots["kde"]:
        # The density estimate is drawn from a random subsample of at most 2000 points.
        idx = np.random.choice(x.index, min(2000, len(x)), replace=False)
        kde_plot = Plot(
            path=path + "_kde." + figformat,
            title="{} vs {} plot using a kernel density estimation".format(names[0], names[1]))
        plot = sns.jointplot(
            x=x[idx],
            y=y[idx],
            kind="kde",
            # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
            clip=((0, np.inf), (0, np.inf)),
            xlim=(minvalx, maxvalx),
            ylim=(minvaly, maxvaly),
            space=0,
            color=color,
            stat_func=stat,
            shade_lowest=False,
            height=10)
        _finish_joint_plot(kde_plot, plot, names, title, log, maxvalx, figformat)
        plots_made.append(kde_plot)

    if plots["pauvre"] and names == ['Read lengths', 'Average read quality'] and log is False:
        pauvre_plot = Plot(
            path=path + "_pauvre." + figformat,
            title="{} vs {} plot using pauvre-style @conchoecia".format(names[0], names[1]))
        sns.set(style="white", **plot_settings)
        # margin_plot receives the output path itself; pauvre_plot.save() is not called here.
        margin_plot(df=pd.DataFrame({"length": x, "meanQual": y}),
                    Y_AXES=False,
                    title=title or "Length vs Quality in Pauvre-style",
                    plot_maxlen=None,
                    plot_minlen=0,
                    plot_maxqual=None,
                    plot_minqual=0,
                    lengthbin=None,
                    qualbin=None,
                    BASENAME="whatever",
                    path=pauvre_plot.path,
                    fileform=[figformat],
                    dpi=600,
                    TRANSPARENT=True,
                    QUIET=True)
        plots_made.append(pauvre_plot)
    plt.close("all")
    return plots_made
Exemplo n.º 2
0
def violin_or_box_plot(df,
                       y,
                       figformat,
                       path,
                       y_name,
                       title=None,
                       plot="violin",
                       log=False,
                       palette=None):
    """Draw a violin, box or ridge comparison plot of column ``y`` of ``df``.

    The data is grouped by the 'dataset' column of ``df``. ``plot`` selects
    the plot type ('violin', 'box' or 'ridge'); any other value logs an error
    and terminates via ``sys.exit``. Violin and box plots are handed to
    ``process_violin_and_box``; the ridge plot is saved here directly.

    Returns a one-element list holding the Plot object.
    """
    out_file = path + "NanoComp_" + y.replace(' ', '_') + '.' + figformat
    comp = Plot(path=out_file, title="Comparing {}".format(y))
    if y == "quals":
        comp.title = "Comparing base call quality scores"

    if plot == 'ridge':
        logging.info("Nanoplotter: Creating ridges plot for {}.".format(y))
        ridge_fig, axes = joypy.joyplot(df,
                                        by="dataset",
                                        column=y,
                                        title=title or comp.title,
                                        x_range=[-0.05, np.amax(df[y])])
        comp.fig = ridge_fig
        # Only the last axes gets its tick labels adjusted.
        bottom_ax = axes[-1]
        if log:
            # Replace each tick label by ten to the power of its value.
            exponents = [float(label.get_text()) for label in bottom_ax.get_xticklabels()]
            bottom_ax.set_xticklabels([10**e for e in exponents])
        bottom_ax.set_xticklabels(bottom_ax.get_xticklabels(),
                                  rotation=30,
                                  ha='center')
        comp.save(format=figformat)
    elif plot in ('violin', 'box'):
        if plot == 'violin':
            logging.info("Nanoplotter: Creating violin plot for {}.".format(y))
            ax = sns.violinplot(x="dataset",
                                y=y,
                                data=df,
                                inner=None,
                                cut=0,
                                palette=palette,
                                linewidth=0)
        else:
            logging.info("Nanoplotter: Creating box plot for {}.".format(y))
            ax = sns.boxplot(x="dataset",
                             y=y,
                             data=df,
                             palette=palette)
        # Both plot types share the same post-processing call.
        process_violin_and_box(ax=ax,
                               log=log,
                               plot_obj=comp,
                               title=title,
                               y_name=y_name,
                               figformat=figformat,
                               ymax=np.amax(df[y]))
    else:
        message = "Unknown comp plot type {}".format(plot)
        logging.error(message)
        sys.exit(message)
    plt.close("all")
    return [comp]