def scatter_plot2(df1, df2, xcol, ycol, domain, color1='black', color2='red', xname=None, yname=None, log=False, width=6, height=6, clamp=True, tickCount=5):
    """Build a scatter plot that overlays two data frames on shared axes.

    Both frames are drawn with the same ``xcol``/``ycol`` mapping; ``df1``
    uses ``color1`` and ``df2`` uses ``color2``. Rug marks are added on the
    top and right sides, and dashed guide lines are drawn for the diagonal
    and for the upper bound of ``domain`` on each axis.

    Parameters:
        df1, df2: data frames holding the ``xcol`` and ``ycol`` columns.
        xcol, ycol: names of the columns mapped to the x and y axes.
        domain: two-element sequence used as the limits of both axes.
        color1, color2: point/rug colors for ``df1`` and ``df2``.
        xname, yname: axis labels; default to ``xcol`` / ``ycol``.
        log: use log10 scales instead of continuous ones.
        width, height: figure size passed to the theme.
        clamp: when true, values above ``domain[1]`` are replaced by
            ``domain[1]`` in deep copies of the inputs (the caller's frames
            are left untouched). Values below ``domain[0]`` are not altered.
        tickCount: accepted for interface compatibility; not used here.

    Returns:
        The assembled plot object.
    """
    assert len(domain) == 2

    point_size = 1.5
    dash_pattern = (0, (6, 2))

    xname = xcol if xname is None else xname
    yname = ycol if yname is None else yname

    # formatter for axes' labels
    label_format = mizani.custom_format('{:n}')

    if clamp:  # clamp overflowing values if required
        upper = domain[1]
        clamped = []
        for frame in (df1, df2):
            frame = frame.copy(deep=True)
            for col in (xcol, ycol):
                frame.loc[frame[col] > upper, col] = upper
            clamped.append(frame)
        df1, df2 = clamped

    # pick the pair of axis scales matching the requested mode
    if log:
        make_x_scale, make_y_scale = p9.scale_x_log10, p9.scale_y_log10
    else:
        make_x_scale, make_y_scale = p9.scale_x_continuous, p9.scale_y_continuous

    plot = p9.ggplot(df1)

    # Components are added strictly in this order: it fixes the drawing order.
    components = [
        # scatter layers
        p9.aes(x=xcol, y=ycol),
        p9.geom_point(size=point_size, na_rm=True, color=color1, alpha=0.5),
        p9.geom_point(size=point_size, na_rm=True, data=df2, color=color2, alpha=0.5),
        p9.labs(x=xname, y=yname),
        # rug plots
        p9.geom_rug(na_rm=True, sides="tr", color=color1, alpha=0.05),
        p9.geom_rug(na_rm=True, sides="tr", data=df2, color=color2, alpha=0.05),
        # axis scales
        make_x_scale(limits=domain, labels=label_format),
        make_y_scale(limits=domain, labels=label_format),
        # theming
        p9.theme_bw(),
        p9.theme(panel_grid_major=p9.element_line(color='#666666', alpha=0.5)),
        p9.theme(panel_grid_minor=p9.element_blank()),
        p9.theme(figure_size=(width, height)),
        p9.theme(text=p9.element_text(size=24, color="black")),
        # guide lines: diagonal, then vertical and horizontal rules at the upper bound
        p9.geom_abline(intercept=0, slope=1, linetype=dash_pattern),
        p9.geom_vline(xintercept=domain[1], linetype=dash_pattern),
        p9.geom_hline(yintercept=domain[1], linetype=dash_pattern),
    ]
    for component in components:
        plot += component

    return plot
def test_aesthetics():
    """Stack five rug layers, each exercising a different aesthetic or side.

    The themed plot is compared with the name 'aesthetics' (presumably
    resolved to a baseline image by the test harness).
    """
    rug_layers = [
        geom_rug(aes('x', 'y'), size=2),
        geom_rug(aes('x+2*n', 'y+2*n', alpha='z'), size=2, sides='tr'),
        geom_rug(aes('x+4*n', 'y+4*n', linetype='factor(z)'), size=2, sides='t'),
        geom_rug(aes('x+6*n', 'y+6*n', color='factor(z)'), size=2, sides='b'),
        geom_rug(aes('x+8*n', 'y+8*n', size='z'), sides='tblr'),
    ]

    p = ggplot(df)
    for layer in rug_layers:
        p = p + layer

    assert p + _theme == 'aesthetics'
Example #3
0
def test_aesthetics():
    """Stack five rug layers, each exercising a different aesthetic or side.

    On Python 2 the comparison carries a tolerance of 4 because of a small
    displacement in the y-axis text; otherwise the plain name is used.
    """
    rug_layers = (
        geom_rug(aes('x', 'y'), size=2),
        geom_rug(aes('x+2*n', 'y+2*n', alpha='z'), size=2, sides='tr'),
        geom_rug(aes('x+4*n', 'y+4*n', linetype='factor(z)'), size=2, sides='t'),
        geom_rug(aes('x+6*n', 'y+6*n', color='factor(z)'), size=2, sides='b'),
        geom_rug(aes('x+8*n', 'y+8*n', size='z'), sides='tblr'),
    )

    p = ggplot(df)
    for layer in rug_layers:
        p = p + layer

    # Small displacement in y-axis text on Python 2
    expected = ('aesthetics', {'tol': 4}) if six.PY2 else 'aesthetics'
    assert p + _theme == expected
Example #4
0
def test_coord_flip():
    """Draw a left-side rug with flipped coordinates and compare to 'coord_flip'."""
    left_rug = geom_rug(aes('x', 'y'), size=2, sides='l')
    p = ggplot(df) + left_rug + coord_flip()

    assert p + _theme == 'coord_flip'
Example #5
0
def test_aesthetics():
    """Check rug layers that vary alpha, linetype, color and size.

    Each layer is offset by a multiple of ``n`` and drawn on a different
    set of sides. On Python 2 a tolerance of 4 is attached to the
    comparison because of a small displacement in the y-axis text.
    """
    base_rug = geom_rug(aes('x', 'y'), size=2)
    alpha_rug = geom_rug(aes('x+2*n', 'y+2*n', alpha='z'), size=2, sides='tr')
    linetype_rug = geom_rug(aes('x+4*n', 'y+4*n', linetype='factor(z)'), size=2, sides='t')
    color_rug = geom_rug(aes('x+6*n', 'y+6*n', color='factor(z)'), size=2, sides='b')
    size_rug = geom_rug(aes('x+8*n', 'y+8*n', size='z'), sides='tblr')

    p = ggplot(df) + base_rug + alpha_rug + linetype_rug + color_rug + size_rug

    if not six.PY2:
        assert p + _theme == 'aesthetics'
    else:
        # Small displacement in y-axis text
        assert p + _theme == ('aesthetics', {'tol': 4})
Example #6
0
    def _generate_plot(self, x, y, xlabel=None, ylabel=None, title=None):
        """Build a plot of the data, the fitted curve and the knots.

        The plot contains, in order: the raw points (continuous outcomes
        only), a rug along the bottom, a black line of ``self.predict``
        evaluated on ``utils.bin_array(x)``, and one "x" marker per entry
        of ``self.knots`` placed at its predicted value.

        Parameters:
            x, y: values plotted on the x and y axes.
            xlabel, ylabel, title: optional labels; each is applied only
                when it is not ``None``.

        Returns:
            The assembled plot object.
        """
        frame = pd.DataFrame({"x": x, "y": y, "fit": self.predict(x)})
        plot = gg.ggplot(frame, gg.aes("x", "y"))

        # Raw points are only drawn for continuous outcomes.
        #
        # For binary outcomes the intent was to show log odds with
        #     gg.stat_summary_bin(geom="point", fun_y=np.mean, color="steelblue")
        # but an ongoing plotnine bug appears to prevent that from working,
        # so no point layer is added in that case.
        if self.outcome_type == "continuous":
            plot += gg.geom_point(color="steelblue", alpha=1 / 4)

        plot += gg.geom_rug(sides='b')

        # Fitted curve evaluated on the binned x values
        curve = pd.DataFrame({
            "x_axis": utils.bin_array(x),
            "y_axis": self.predict(utils.bin_array(x)),
        })
        plot += gg.geom_line(
            data=curve,
            mapping=gg.aes("x_axis", "y_axis"),
            size=1,
            color="black",
        )

        # One marker per knot, positioned on the fitted curve
        for knot in self.knots:
            marker_position = gg.aes(x=knot, y=self.predict(knot))
            plot += gg.geom_point(marker_position, shape="x", size=4, color="darkblue")

        # Optional labelling, applied in x / y / title order
        optional_labels = ((xlabel, gg.xlab), (ylabel, gg.ylab), (title, gg.ggtitle))
        for text, make_label in optional_labels:
            if text is not None:
                plot += make_label(text)

        return plot
Example #7
0
# Attach the cluster assignments so they can drive both grouping and color.
tsne_results_df[clusters_colname] = clusters

plot_title = 'SHAP-Based Clusters in T-SNE SHAP Space'
x_axis_label = 'T-SNE Component 1'
y_axis_label = 'T-SNE Component 2'

# Axis limits span the observed range of the first two columns.
xlim, ylim = (
    [tsne_results_df.iloc[:, position].min(), tsne_results_df.iloc[:, position].max()]
    for position in (0, 1)
)

# Build the plot one component at a time: points, rug, per-cluster ellipse,
# limits, theme, title and axis labels.
plot = p9.ggplot(
    tsne_results_df,
    p9.aes(
        y=tsne_results_df.columns[1],
        x=tsne_results_df.columns[0],
        group=clusters_colname,
        color=clusters_colname,
    ),
)
plot += p9.geom_point(size=2)
plot += p9.geom_rug()
plot += p9.stat_ellipse()
plot += p9.xlim(*xlim)
plot += p9.ylim(*ylim)
# Alternative color scales kept for reference:
#   p9.scale_color_gradient(low='blue', high='yellow')
#   p9.scale_color_manual(values=colors)
plot += p9.theme_light(base_size=18)
plot += p9.ggtitle(plot_title)
plot += p9.labs(y=y_axis_label, x=x_axis_label)

# Save to disk, then display the saved image.
plot_filename = 'shap_clusters.png'
plot.save(plot_filename, width=10, height=10)
from IPython.display import Image
Image(filename=plot_filename)
# In[18]:

gg.options.figure_size = (6.4, 4.8)

# Make sure to drop duplicates of redundant gene, perturbation, and cell line columns
# Not removing replicates will put more weight on genes with more measurements

cor_density_gg = (
    gg.ggplot(
        summary_corr_df.drop_duplicates(
            ["Metadata_cell_line", "Metadata_gene_name", "replicate_type"]
        ),
        gg.aes(x="correlation_guide")) + \
        gg.geom_density(gg.aes(fill="Metadata_cell_line"),
                        alpha=0.4) + \
        gg.geom_rug(gg.aes(color="Metadata_cell_line"),
                    show_legend={'color': False}) + \
        gg.theme_bw() + \
    gg.theme(
            subplots_adjust={"wspace": 0.2},
            axis_text=gg.element_text(size=7),
            axis_title=gg.element_text(size=9),
            strip_text=gg.element_text(size=6, color="black"),
            strip_background=gg.element_rect(colour="black", fill="#fdfff4"),
        ) + \
        gg.xlim([-0.5, 1]) + \
        gg.xlab("Median Correlation of All Guides Across Genes") + \
        gg.ylab("Density") + \
        gg.facet_wrap("~replicate_type", nrow=2, scales="free") + \
        gg.scale_fill_manual(name="Cell Line",
                             values=["#1b9e77", "#d95f02", "#7570b3"]) + \
        gg.scale_color_manual(name="Cell Line",
Example #9
0
def control_list(in_file=None,
                 out_dir=None,
                 reference_gene_file=None,
                 log2=False,
                 page_width=None,
                 page_height=None,
                 user_img_file=None,
                 page_format=None,
                 pseudo_count=1,
                 set_colors=None,
                 dpi=300,
                 rug=False,
                 jitter=False,
                 skip_first=False):
    """Pick control genes whose signal matches that of a reference gene set.

    The function validates the two-column signal file, loads the reference
    genes, and for every reference gene found in the signal file selects the
    not-yet-used gene with the closest expression value. It then draws a
    violin plot comparing the two sets and writes both lists to disk.

    Parameters:
        in_file: open file object for a tab-separated file with a gene
            column and a numeric signal column. Its lines are iterated
            for validation and its ``.name`` is re-read with pandas.
        out_dir: output directory handed to ``make_outdir_and_file``.
        reference_gene_file: file object whose ``.name`` is a
            tab-separated file; its first column holds the reference genes.
        log2: log2-transform the signal (after adding ``pseudo_count``).
        page_width, page_height, dpi: passed to the plot's ``save``.
        user_img_file: optional file object replacing the default image
            file; its name must end with ``page_format``.
        page_format: image extension used for the diagram file name.
        pseudo_count: value added to every signal value.
        set_colors: comma-separated string of exactly two colors
            (named colors or hex), used for 'Reference' and 'Control'.
        rug, jitter: add a rug / jitter layer to the diagram.
        skip_first: ignore the first line of ``in_file`` (header).

    Returns:
        None. Results are written to the reference, control and image files.

    NOTE(review): ``message(..., type="ERROR")`` is assumed to abort
    execution; several code paths below rely on that - confirm against the
    helper's definition.
    """
    # -------------------------------------------------------------------------
    #
    # Check in_file content
    #
    # -------------------------------------------------------------------------

    for p, line in enumerate(in_file):

        line = chomp(line)
        line = line.split("\t")

        if len(line) > 2:
            message("Need a two columns file.",
                    type="ERROR")
        if skip_first:
            if p == 0:
                continue

        # A line with fewer than two fields (e.g. a blank line) would
        # otherwise fail below with a bare IndexError on line[1].
        if len(line) < 2:
            message("Need a two columns file.",
                    type="ERROR")

        try:
            fl = float(line[1])
        except ValueError:
            msg = "It seems that column 2 of input file"
            msg += " contains non numeric values. "
            msg += "Check that no header is present and that "
            msg += "columns are ordered properly. "
            msg += "Or use '--skip-first'. "
            message(msg, type="ERROR")

        if log2:
            fl = fl + pseudo_count
            if fl <= 0:
                message("Can not log transform negative/zero values. Add a pseudo-count.",
                        type="ERROR")

    # -------------------------------------------------------------------------
    #
    # Check colors
    #
    # -------------------------------------------------------------------------

    set_colors = set_colors.split(",")

    if len(set_colors) != 2:
        message("Need two colors. Please fix.", type="ERROR")

    mcolors_name = mcolors.cnames

    for i in set_colors:
        if i not in mcolors_name:
            if not is_hex_color(i):
                message(i + " is not a valid color. Please fix.", type="ERROR")

    # -------------------------------------------------------------------------
    #
    # Preparing output files
    #
    # -------------------------------------------------------------------------

    # Preparing pdf file name
    file_out_list = make_outdir_and_file(out_dir, ["control_list.txt",
                                                   "reference_list.txt",
                                                   "diagnostic_diagrams." + page_format],
                                         force=True)

    control_file, reference_file_out, img_file = file_out_list

    if user_img_file is not None:

        # The user-supplied image file replaces the default one.
        os.unlink(img_file.name)
        img_file = user_img_file

        if not img_file.name.endswith(page_format):
            msg = "Image format should be: {f}. Please fix.".format(f=page_format)
            message(msg, type="ERROR")

        test_path = os.path.abspath(img_file.name)
        test_path = os.path.dirname(test_path)

        if not os.path.exists(test_path):
            os.makedirs(test_path)

    # -------------------------------------------------------------------------
    #
    # Read the reference list
    #
    # -------------------------------------------------------------------------

    try:
        reference_genes = pd.read_csv(reference_gene_file.name, sep="\t", header=None)
    except pd.errors.EmptyDataError:
        message("No genes in --reference-gene-file.", type="ERROR")

    reference_genes.rename(columns={reference_genes.columns.values[0]: 'gene'}, inplace=True)

    # -------------------------------------------------------------------------
    #
    # Delete duplicates
    #
    # -------------------------------------------------------------------------

    before = len(reference_genes)
    reference_genes = reference_genes.drop_duplicates(['gene'])
    after = len(reference_genes)

    msg = "%d duplicate lines have been deleted in reference file."
    message(msg % (before - after))

    # -------------------------------------------------------------------------
    #
    # Read expression data and add the pseudo_count
    #
    # -------------------------------------------------------------------------

    if skip_first:
        exp_data = pd.read_csv(in_file.name, sep="\t",
                               header=None, index_col=None,
                               skiprows=[0], names=['exprs'])
    else:

        exp_data = pd.read_csv(in_file.name, sep="\t", names=['exprs'], index_col=0)

    exp_data.exprs = exp_data.exprs.values + pseudo_count

    # -------------------------------------------------------------------------
    #
    # log transformation
    #
    # -------------------------------------------------------------------------

    ylabel = 'Expression'

    if log2:
        # Reject zero AND negative values, as the message states: either
        # would turn into -inf / NaN under np.log2.
        if (exp_data.exprs.values <= 0).any():
            message("Can't use log transformation on zero or negative values. Use -p.",
                    type="ERROR")
        else:
            exp_data.exprs = np.log2(exp_data.exprs.values)
            ylabel = 'log2(Expression)'

    # -------------------------------------------------------------------------
    #
    # Are reference gene found in control list
    #
    # -------------------------------------------------------------------------

    # Sort in increasing order
    exp_data = exp_data.sort_values('exprs')

    # Reference genes that are present in the expression data index
    reference_genes_found = [x for x in reference_genes['gene'] if x in exp_data.index]

    msg = "Found %d genes of the reference in the provided signal file" % len(reference_genes_found)
    message(msg)

    not_found = [x for x in reference_genes['gene'] if x not in exp_data.index]

    if len(not_found):
        if len(not_found) == len(reference_genes):
            message("Genes from reference file where not found in signal file (n=%d)." % len(not_found), type="ERROR")
        else:
            message("List of reference genes not found :%s" % not_found)
    else:
        message("All reference genes were found.")

    # -------------------------------------------------------------------------
    #
    # Search for genes with matched signal
    #
    # -------------------------------------------------------------------------

    exp_data_save = exp_data.copy()

    control_list = list()

    nb_candidate_left = exp_data.shape[0] - len(reference_genes_found)

    message("Searching for genes with matched signal.")

    if nb_candidate_left < len(reference_genes_found):
        message("Not enough element to perform selection. Exiting", type="ERROR")

    for i in reference_genes_found:
        # Reference genes and already selected controls cannot be picked.
        not_candidates = reference_genes_found + control_list
        not_candidates = list(set(not_candidates))

        # Absolute distance between this reference gene and every gene,
        # then keep the closest remaining candidate.
        diff = abs(exp_data.loc[i] - exp_data)
        control_list.extend(diff.loc[np.setdiff1d(diff.index, not_candidates)].idxmin(axis=0, skipna=True).tolist())

    # -------------------------------------------------------------------------
    #
    # Prepare a dataframe for plotting
    #
    # -------------------------------------------------------------------------

    message("Preparing a dataframe for plotting.")

    reference = exp_data_save.loc[reference_genes_found].sort_values('exprs')
    reference = reference.assign(genesets=['Reference'] * reference.shape[0])

    control = exp_data_save.loc[control_list].sort_values('exprs')
    control = control.assign(genesets=['Control'] * control.shape[0])

    data = pd.concat([reference, control])
    data['sets'] = pd.Series(['sets' for x in data.index.tolist()], index=data.index)
    data['genesets'] = Categorical(data['genesets'])

    # -------------------------------------------------------------------------
    #
    # Diagnostic plots
    #
    # -------------------------------------------------------------------------

    p = ggplot(data, aes(x='sets', y='exprs', fill='genesets'))

    p += scale_fill_manual(values=dict(zip(['Reference', 'Control'], set_colors)))

    p += geom_violin(color=None)

    p += xlab('Gene sets') + ylab(ylabel)

    p += facet_wrap('~genesets')

    if rug:
        p += geom_rug()

    if jitter:
        p += geom_jitter()

    p += theme_bw()
    p += theme(axis_text_x=element_blank())

    # -------------------------------------------------------------------------
    #
    # Saving
    #
    # -------------------------------------------------------------------------

    # Both pandas and plotnine use warnings for deprecated functions;
    # they are silenced while the diagram is being saved.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        message("Saving diagram to file : " + img_file.name)
        message("Be patient. This may be long for large datasets.")

        try:
            p.save(filename=img_file.name, width=page_width, height=page_height, dpi=dpi, limitsize=False)
        except PlotnineError as err:
            message("Plotnine message: " + err.message)
            message("Plotnine encountered an error.", type="ERROR")

    # -------------------------------------------------------------------------
    #
    # write results
    #
    # -------------------------------------------------------------------------

    exp_data_save.loc[reference_genes_found].sort_values('exprs').to_csv(reference_file_out.name, sep="\t")
    exp_data_save.loc[control_list].sort_values('exprs').to_csv(control_file.name, sep="\t")