Example #1
def test_tile_aesthetics():
    p = (ggplot(df, aes('x', 'y', width=1, height=1)) + geom_tile() +
         geom_tile(aes(y='y+2', alpha='z'), show_legend=False) +
         geom_tile(aes(y='y+4', fill='factor(z)')) +
         geom_tile(aes(y='y+6', color='factor(z+1)'), size=2) + geom_tile(
             aes(y='y+8', linetype='factor(z+2)'), color='yellow', size=2) +
         _theme)

    assert p == 'tile-aesthetics'
Example #2
def test_tile_aesthetics():
    p = (ggplot(df, aes('x', 'y', width=1, height=1)) +
         geom_tile() +
         geom_tile(aes(y='y+2', alpha='z'),
                   show_legend=False) +
         geom_tile(aes(y='y+4', fill='factor(z)')) +
         geom_tile(aes(y='y+6', color='factor(z+1)'), size=2) +
         geom_tile(aes(y='y+8', linetype='factor(z+2)'),
                   color='yellow', size=2) +
         _theme)

    assert p == 'tile-aesthetics'
Example #3
def plot_two_way_sdc(sdc_df: pd.DataFrame, alpha: float = .05, **kwargs):
    """
    Plots the results of an SDC analysis for a fixed window size in a 2D figure.

    In a similar fashion to a recurrence plot, the x and y axes represent the start indices of the x and y
    sequences. Only results with a p_value < alpha are shown; tile transparency encodes the magnitude of the
    score and color encodes the sign of the established relationship.

    Parameters
    ----------
    sdc_df
        Data frame as outputted by `compute_sdc` which will be used to plot the results.
    alpha
        Significance threshold. Only results with a p_value < alpha will be plotted.
    kwargs
        Keyword arguments to pass to `plotnine.theme` to customize the plot.

    Returns
    -------
    p9.ggplot.ggplot
        Plot
    """
    fragment_size = int(sdc_df.iloc[0]['stop_1'] - sdc_df.iloc[0]['start_1'])
    f = (sdc_df
         .loc[lambda dd: dd.p_value < alpha]
         .assign(r_str=lambda dd: dd['r'].apply(lambda x: '$r > 0$' if x > 0 else '$r < 0$'))
         .pipe(lambda dd: p9.ggplot(dd)
               + p9.aes('start_1', 'start_2', fill='r_str', alpha='abs(r)')
               + p9.geom_tile()
               + p9.scale_fill_manual(['#da2421', 'black'])
               + p9.scale_y_reverse()
               + p9.theme(**kwargs)
               + p9.guides(alpha=False)
               + p9.labs(x='$X_i$',
                         y='$Y_j$',
                         fill='$r$',
                         title=f'Two-Way SDC plot for $S = {fragment_size}$'
                               + r' and $\alpha =$' + f'{alpha}')))

    return f
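
A hedged usage sketch (not from the original project): the column names start_1, stop_1, start_2, r and p_value are inferred from the function body, and the random frame below merely stands in for real `compute_sdc` output.

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 200
sdc_df = pd.DataFrame({
    'start_1': rng.integers(0, 100, n),
    'start_2': rng.integers(0, 100, n),
    'r': rng.uniform(-1, 1, n),
    'p_value': rng.uniform(0, 1, n),
})
sdc_df['stop_1'] = sdc_df['start_1'] + 10  # fixed window size S = 10

fig = plot_two_way_sdc(sdc_df, alpha=0.05, figure_size=(6, 6))
fig.draw()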
Example #4
File: plots.py  Project: mappin/asxtrade
def make_sentiment_plot(sentiment_df, exclude_zero_bin=True, plot_text_labels=True):
    rows = []
    print(
        "Sentiment plot: exclude zero bins? {} show text? {}".format(
            exclude_zero_bin, plot_text_labels
        )
    )

    for column in filter(lambda c: c.startswith("bin_"), sentiment_df.columns):
        c = Counter(sentiment_df[column])
        date = column[4:]
        for bin_name, val in c.items():
            if exclude_zero_bin and (bin_name == "0.0" or not isinstance(bin_name, str)):
                continue
            bin_name = str(bin_name)
            assert isinstance(bin_name, str)
            val = int(val)
            rows.append(
                {
                    "date": datetime.strptime(date, "%Y-%m-%d"),
                    "bin": bin_name,
                    "value": val,
                }
            )

    df = pd.DataFrame.from_records(rows)
    # print(df['bin'].unique())
    # HACK TODO FIXME: should get from price_change_bins()...
    order = [
        "-1000.0",
        "-100.0",
        "-10.0",
        "-5.0",
        "-3.0",
        "-2.0",
        "-1.0",
        "-1e-06",
        "1e-06",
        "1.0",
        "2.0",
        "3.0",
        "5.0",
        "10.0",
        "25.0",
        "100.0",
        "1000.0",
    ]
    df["bin_ordered"] = pd.Categorical(df["bin"], categories=order)

    plot = (
        p9.ggplot(df, p9.aes("date", "bin_ordered", fill="value"))
        + p9.geom_tile(show_legend=False)
        + p9.theme_bw()
        + p9.xlab("")
        + p9.ylab("Percentage daily change")
        + p9.theme(axis_text_x=p9.element_text(angle=30, size=7), figure_size=(10, 5))
    )
    if plot_text_labels:
        plot = plot + p9.geom_text(p9.aes(label="value"), size=8, color="white")
    return plot_as_inline_html_data(plot)
Example #5
def test_fill_gradient(tmp_path):
    x = sample_data()
    p = (p9.ggplot(x, p9.aes(x='x', y='y', fill='z')) +
         p9.geom_tile())

    p1 = p + endktheme.plotnine.scale_fill_gradient_energinet()
    render(tmp_path, p1)

    p2 = p + endktheme.plotnine.scale_fill_gradient2_energinet()
    render(tmp_path, p2)
Example #6
def plot_corrtile(df, var=None, out=None, corr=None):
    r"""
    """
    return (
        df
        >> ggplot(aes(var, out))
        + geom_tile(aes(fill=corr))
        + scale_fill_gradient2(name="Corr", midpoint=0)
        + theme(axis_text_x=element_text(angle=270))
    )
Example #7
def plot_counts(counts: pd.DataFrame):
	counts.bc_l = pd.Categorical(counts.bc_l, bc_range(counts.bc_l))
	counts.bc_r = pd.Categorical(counts.bc_r, bc_range(counts.bc_r))
	log = np.log
	return (
		ggplot(counts, aes('bc_r', 'bc_l', fill='log(Count)'))
		+ geom_tile()
		# + scale_y_reverse()
		+ coord_fixed()
		+ scale_fill_cmap('inferno')
		+ theme(axis_text_x=element_text(angle=90, vjust=.5))
		+ labs(x='Right Barcode', y='Left Barcode')
	)
Example #8
def setup_heatmap0(df: pd.DataFrame, format_string, axis_text):
    # https://stackoverflow.com/a/62161556/819272
    # Plotnine does not support changing the position of any axis.
    return (p9.ggplot(df, p9.aes(y='row', x='col')) + p9.coord_equal() +
            p9.geom_tile(p9.aes(fill='scale')) + p9.geom_text(
                p9.aes(label='value'), format_string=format_string, size=7) +
            p9.scale_y_discrete(drop=False) + p9.scale_x_discrete(drop=False) +
            p9.scale_fill_gradientn(colors=['#63BE7B', '#FFEB84', '#F8696B'],
                                    na_value='#CCCCCC',
                                    guide=False) +
            p9.theme(axis_text=p9.element_blank()
                     if not axis_text else p9.element_text(face='bold'),
                     axis_ticks=p9.element_blank(),
                     axis_title=p9.element_blank(),
                     panel_grid=p9.element_blank()))
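
A hedged sketch of the long-format frame this helper expects: row/col are the discrete axes, value is the text printed in each cell, and scale drives the fill gradient. The column contents below are made up for illustration.

import numpy as np
import pandas as pd

values = np.arange(12, dtype=float).reshape(3, 4)
df = pd.DataFrame({
    'row': pd.Categorical([f'r{i}' for i in range(3) for _ in range(4)]),
    'col': pd.Categorical([f'c{j}' for _ in range(3) for j in range(4)]),
    'value': values.flatten(),
})
# normalize values into [0, 1] for the green-yellow-red gradient
df['scale'] = (df['value'] - df['value'].min()) / (df['value'].max() - df['value'].min())

p = setup_heatmap0(df, format_string='{:.0f}', axis_text=True)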
Example #9
def plot_correlations(df, cols):
    '''Takes a data frame and a list of columns.
       Returns a plot of pairwise correlations between the listed variables.
    '''

    axisnames = ["Variable 1", "Variable 2", "Correlation"]

    corr_df = pd.DataFrame(df[cols].corr().stack()).reset_index()
    corr_df.rename(columns=dict(zip(list(corr_df.columns), axisnames)),
                   inplace=True)

    return (p9.ggplot(corr_df,
                      p9.aes(axisnames[0], axisnames[1], fill=axisnames[2])) +
            p9.geom_tile(p9.aes(width=.95, height=.95)) +
            p9.theme(axis_text_x=p9.element_text(rotation=90)))
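
For reference, a minimal hedged example with synthetic data (column names are arbitrary):

import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
df = pd.DataFrame(rng.normal(size=(100, 3)), columns=['a', 'b', 'c'])

p = plot_correlations(df, ['a', 'b', 'c'])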
Example #10
def plot_patch(patch, vmin=None, vmax=None, round_digits=1, threshold=None):
    if len(patch.shape) == 1:
        patch = patch.reshape(1, -1)
    if threshold is None:
        threshold = patch.mean()
    vmin = vmin or patch.min()
    vmax = vmax or patch.max()

    tile_height = tile_width = 0.95

    hshift = 0
    vshift = 0.5 * tile_height

    plotr = pd.DataFrame({
        'x': (
            np.tile(np.arange(patch.shape[1]), patch.shape[0]).flatten()
            + hshift
        ),
        'y': (
            - np.repeat(np.arange(patch.shape[0]), patch.shape[1]).flatten()
            + vshift
        ),
        'value': np.round(patch.flatten(), round_digits),
        'color_text': patch.flatten() < threshold
    })

    return (
        p9.ggplot(p9.aes('x', 'y'))
        + p9.geom_tile(plotr, p9.aes(width=tile_width, height=tile_height))
        + p9.geom_text(plotr, p9.aes(label='value', color='color_text'))
        + p9.aes(fill='value')
        + p9.coord_equal(expand=False)
        + p9.theme_void()
        + p9.scales.scale_fill_gradient(
            high='#f0f0f0', low='#252525', guide=False)
        + p9.scales.scale_color_gray(breaks=[False, True], guide=False)
    )
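
A short hedged usage sketch; the array below is arbitrary. Values below the threshold (here the mean) get light text so they stay readable on dark tiles.

import numpy as np

patch = np.array([[0.12, 0.85, 0.33],
                  [0.91, 0.20, 0.74]])

p = plot_patch(patch, round_digits=2)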
Example #11
def make_sentiment_plot(sentiment_df,
                        exclude_zero_bin=True,
                        plot_text_labels=True):
    rows = []
    print("Sentiment plot: exclude zero bins? {} show text? {}".format(
        exclude_zero_bin, plot_text_labels))

    for column in filter(lambda c: c.startswith("bin_"), sentiment_df.columns):
        c = Counter(sentiment_df[column])
        date = column[4:]
        for bin_name, val in c.items():
            if exclude_zero_bin and (bin_name == "0.0"
                                     or not isinstance(bin_name, str)):
                continue
            bin_name = str(bin_name)
            assert isinstance(bin_name, str)
            val = int(val)
            rows.append({
                "date": datetime.strptime(date, "%Y-%m-%d"),
                "bin": bin_name,
                "value": val,
            })

    df = pd.DataFrame.from_records(rows)
    # print(df['bin'].unique())
    bins, labels = price_change_bins()  # pylint: disable=unused-variable
    order = filter(
        lambda s: s != "0.0", labels
    )  # don't show the no-change bin since it dominates the activity heatmap
    df["bin_ordered"] = pd.Categorical(df["bin"], categories=order)

    plot = p9.ggplot(df, p9.aes(
        "date", "bin_ordered", fill="value")) + p9.geom_tile(show_legend=False)
    if plot_text_labels:
        plot = plot + p9.geom_text(
            p9.aes(label="value"), size=8, color="white")
    return user_theme(plot, y_axis_label="Daily change (%)")
Example #12
def ggfuntile(f,
              d,
              xrng=(0, 1),
              yrng=(0, 1),
              limits=(0, 1),
              density=51,
              xlab="x",
              ylab="y",
              zlab="f",
              breaks=None,
              **kwargs):
    od = OrderedDict()
    od[xlab] = np.arange(xrng[0], xrng[1],
                         (xrng[1] - xrng[0]) / (density - 1.0))
    od[ylab] = np.arange(yrng[0], yrng[1],
                         (yrng[1] - yrng[0]) / (density - 1.0))
    ggdata = expandGrid(od)
    ggdata["z"] = [
        f(ggdata.iloc[i, 0], ggdata.iloc[i, 1]) for i in range(ggdata.shape[0])
    ]
    gg = ggplot(ggdata, aes(x=xlab, y=ylab))
    gg += geom_tile(aes(fill="z"))
    gg += scale_fill_gradientn(colors=[
        "black", "#202020", "#404040", "#808080", "white", "dodgerblue",
        "blue", "darkblue", "midnightblue"
    ],
                               name=zlab,
                               limits=limits)
    gg += theme_classic()
    gg += geom_point(data=d,
                     mapping=aes(shape="class"),
                     color="red",
                     size=2,
                     alpha=0.8)
    gg += scale_shape_manual(values=["x", "^"])
    return gg
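
A hedged sketch of how ggfuntile might be called, assuming expandGrid (the cartesian-product helper used above) is importable from the same module; f is any scalar function of two variables and d holds labelled points to overlay.

import pandas as pd

def f(x, y):
    # hypothetical surface to visualize
    return x * y

d = pd.DataFrame({'x': [0.2, 0.8],
                  'y': [0.3, 0.7],
                  'class': ['a', 'b']})

gg = ggfuntile(f, d, xlab='x', ylab='y', zlab='x*y', density=41)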
Example #13
    def pictures(self, mode='bw', subset=None, n_random=10):
        """Returns a picture of the selected images.

        Creates either a colored or a black-white picture of the selected
        images.

        Args:
            mode: Should the picture be black-white ('bw') or in color
                ('color')?
            subset: Optional list of picture indices that should be included in
                the dataframe. If specified, n_random will be ignored.
            n_random: Optional number of randomly selected images. If neither
                subset nor n_random are specified, all images will be included.

        Returns:
            A plotnine object including all pictures with their label.

        Raises:
            NotImplementedError: mode must be either 'bw' or 'color'."""
        dataframe = self.rgb_dataframe(subset=subset, n_random=n_random)
        if mode == 'bw':
            fill_key = 'rgb_bw'
        elif mode == 'color':
            fill_key = 'rgb'
        else:
            raise NotImplementedError("Pictures are either in black-white "
                                      "('bw') or in color ('color').")
        picture = (
            gg.ggplot(dataframe, gg.aes(x='x', y='y', fill=fill_key)) +
            gg.geom_tile() + gg.theme_void() +
            gg.theme(legend_position='none') + gg.scale_fill_manual(
                values={key: key
                        for key in dataframe[fill_key].unique()}) +
            gg.facet_wrap('image_id', labeller=self.labeller) +
            gg.scale_y_reverse() + gg.coord_fixed())
        return picture
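
The notable trick in this method is the identity fill: each tile's fill column already stores a hex color, and gg.scale_fill_manual(values={key: key ...}) maps every color string to itself. A self-contained hedged sketch of that idea with a hypothetical 2x2 "image":

import pandas as pd
import plotnine as gg

pixels = pd.DataFrame({
    'x': [0, 1, 0, 1],
    'y': [0, 0, 1, 1],
    'rgb': ['#ff0000', '#00ff00', '#0000ff', '#000000'],
})

p = (gg.ggplot(pixels, gg.aes(x='x', y='y', fill='rgb')) +
     gg.geom_tile() +
     gg.scale_fill_manual(values={c: c for c in pixels['rgb'].unique()}) +
     gg.theme_void() +
     gg.theme(legend_position='none') +
     gg.scale_y_reverse() +
     gg.coord_fixed())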
Example #14
def hist_plot(df,
              x,
              y=None,
              group = None,
              facet_x = None,
              facet_y = None,
              w='1',
              bins=21,
              bin_width = None,
              position = 'stack',
              normalize = False,
              sort_groups=True,
              base_size=10,
              figure_size=(6, 3)):

    '''
    Plot a 1-d or 2-d histogram

    Parameters
    ----------
    df : pd.DataFrame
      input dataframe
    x : str
      quoted expression to be plotted on the x axis
    y : str
      quoted expression to be plotted on the y axis. If this is specified the histogram will be 2-d.
    group : str
      quoted expression to be used as group (ie color)
    facet_x : str
      quoted expression to be used as facet
    facet_y : str
      quoted expression to be used as facet
    w : str
      quoted expression representing histogram weights (default is 1)
    bins : int or tuple
      number of bins to be used
    bin_width : float or tuple
      bin width to be used
    position : str
      if groups are present, choose between `stack`, `overlay` or `dodge`
    normalize : bool
      normalize histogram counts
    sort_groups : bool
      sort groups by the sum of their value (otherwise alphabetical order is used)
    base_size : int
      base size for theme_ez
    figure_size : tuple of int
      figure size

    Returns
    -------
    g : EZPlot
      EZplot object

    '''

    if position not in ['overlay', 'stack', 'dodge']:
        log.error("position not recognized")
        raise NotImplementedError("position not recognized")

    if (bins is None) and (bin_width is None):
        log.error("Either bins or bin_with should be defined")
        raise ValueError("Either bins or bin_with should be defined")

    if (bins is not None) and (bin_width is not None):
        log.error("Only one between bins or bin_with should be defined")
        raise ValueError("Only one between  bins or bin_with should be defined")

    if (y is not None) and (group is not None):
        log.error("y and group cannot be requested at the same time")
        raise ValueError("y and group cannot be requested at the same time")

    if y is None:
        bins = (bins, bins)
        bin_width = (bin_width, bin_width)
    else:
        if type(bins) not in [tuple, list]:
            bins = (bins, bins)
        if type(bin_width) not in [tuple, list]:
            bin_width = (bin_width, bin_width)

    # create a copy of the data
    dataframe = df.copy()

    # define groups and variables; remove and store (eventual) names
    names = {}
    groups = {}
    variables = {}

    for label, var in zip(['x', 'y', 'group', 'facet_x', 'facet_y'], [x, y, group, facet_x, facet_y]):
        names[label], groups[label] = unname(var)
    names['w'], variables['w'] = unname(w)

    # set column names and evaluate expressions
    tmp_df = agg_data(dataframe, variables, groups, None, fill_groups=False)

    # redefine groups and variables; remove and store (eventual) names
    new_groups = {c:c for c in tmp_df.columns if c in ['x', 'y', 'group', 'facet_x', 'facet_y']}
    non_xy_groups = [g for g  in new_groups.keys() if g not in ['x', 'y']]
    new_variables = {'w':'w'}

    # bin data (if necessary)
    if tmp_df['x'].dtypes != np.dtype('O'):
        tmp_df['x'], bins_x, bin_width_x= bin_data(tmp_df['x'], bins[0], bin_width[0])
    else:
        bin_width_x=1
    if y is not None:
        if tmp_df['y'].dtypes != np.dtype('O'):
            tmp_df['y'], bins_y, bin_width_y = bin_data(tmp_df['y'], bins[1], bin_width[1])
        else:
            bin_width_y=1
    else:
        bin_width_y=1

    # aggregate data and reorder columns
    gdata = agg_data(tmp_df, new_variables, new_groups, 'sum', fill_groups=True)
    gdata.fillna(0, inplace=True)
    gdata = gdata[[c for c in ['x', 'y', 'w', 'group', 'facet_x', 'facet_y'] if c in gdata.columns]]

    # normalize
    if normalize:
        if len(non_xy_groups)==0:
            gdata['w'] = gdata['w']/(gdata['w'].sum()*bin_width_x*bin_width_y)
        else:
            gdata['w'] = gdata.groupby(non_xy_groups)['w'].apply(lambda x: x/(x.sum()*bin_width_x*bin_width_y))

    # start plotting
    g = EZPlot(gdata)
    # determine order and create a categorical type
    if (group is not None) and sort_groups:
        if g.column_is_categorical('x'):
            g.sort_group('x', 'w', ascending=False)
        g.sort_group('group', 'w')
        g.sort_group('facet_x', 'w', ascending=False)
        g.sort_group('facet_y', 'w', ascending=False)
        if groups:
            colors = np.flip(ez_colors(g.n_groups('group')))
    elif (group is not None):
        colors = ez_colors(g.n_groups('group'))

    if y is None:
        # set groups
        if group is None:
            g += p9.geom_bar(p9.aes(x="x", y="w"),
                             stat = 'identity',
                             colour = None,
                             fill = ez_colors(1)[0])
        else:
            g += p9.geom_bar(p9.aes(x="x", y="w",
                                    group="factor(group)",
                                    fill="factor(group)"),
                             colour=None,
                             stat = 'identity',
                             **POSITION_KWARGS[position])
            g += p9.scale_fill_manual(values=colors)

        # set facets
        if facet_x is not None and facet_y is None:
            g += p9.facet_wrap('~facet_x')
        if facet_x is not None and facet_y is not None:
            g += p9.facet_grid('facet_y~facet_x')

        # set x scale
        if g.column_is_categorical('x'):
            g += p9.scale_x_discrete()
        else:
            g += p9.scale_x_continuous(labels=ez_labels)

        # set y scale
        g += p9.scale_y_continuous(labels=ez_labels)

        # set axis labels
        g += \
            p9.xlab(names['x']) + \
            p9.ylab('Counts')

        # set theme
        g += theme_ez(figure_size=figure_size,
                      base_size=base_size,
                      legend_title=p9.element_text(text=names['group'], size=base_size))

        if sort_groups:
            g += p9.guides(fill=p9.guide_legend(reverse=True))

    else:
        g += p9.geom_tile(p9.aes(x="x", y="y", fill='w'),
                          stat = 'identity',
                          colour = None)

        # set facets
        if facet_x is not None and facet_y is None:
            g += p9.facet_wrap('~facet_x')
        if facet_x is not None and facet_y is not None:
            g += p9.facet_grid('facet_y~facet_x')

        # set x scale
        if g.column_is_categorical('x'):
            g += p9.scale_x_discrete()
        else:
            g += p9.scale_x_continuous(labels=ez_labels)

        # set y scale
        if g.column_is_categorical('y'):
            g += p9.scale_y_discrete()
        else:
            g += p9.scale_y_continuous(labels=ez_labels)

        # set axis labels
        g += \
            p9.xlab(names['x']) + \
            p9.ylab(names['y'])

        # set theme
        g += theme_ez(figure_size=figure_size,
                      base_size=base_size,
                      legend_title=p9.element_text(text='Counts', size=base_size))

    return g
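
A hedged usage sketch for the 2-d case, which is the branch that draws geom_tile; it assumes the surrounding ezplot helpers (EZPlot, agg_data, bin_data, theme_ez, ...) are importable as in the module above.

import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
df = pd.DataFrame({'a': rng.normal(size=1000),
                   'b': rng.normal(size=1000)})

# x and y both given -> 2-d histogram rendered as a tile heatmap of counts
g = hist_plot(df, x='a', y='b', bins=25, normalize=True)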
Example #15
        B_list = filler_r_maxes[test_filename_to_filler_name(
            test_filenames[j])]
        print(FILLER_NAMES[i], test_filenames[j])
        A_proportions = compute_proportions(A_list)
        B_proportions = compute_proportions(B_list)
        read_words.append(test_filename_to_filler_name(test_filenames[j]))
        write_words.append(FILLER_NAMES_TOYLABELS[FILLER_NAMES[i]])
        correlations.append(np.corrcoef(A_proportions, B_proportions)[0, 1])

# TODO: Change this to ggplot.
data_df = pd.DataFrame({
    'read_words': read_words,
    'write_words': write_words,
    'correlations': correlations
})
plot = p9.ggplot(data_df, p9.aes(x='read_words', y='write_words', fill='correlations')) +\
        p9.geom_tile() +\
        p9.labels.labs(x='Read', y='Write') +\
        p9.theme(axis_text_x=p9.element_text(rotation=90))
plot.save(os.path.join(SAVE_DIR, "readwrite_correlations_" + trial_num),
          dpi=900)

# Make legend.
for role, color in query_colors.items():
    plt.plot([0, 0], [1, 1], color=color, label=role)
legend = plt.legend(ncol=1, bbox_to_anchor=(1.1, -0.1))
plt.savefig(os.path.join(SAVE_DIR, "decoding_legend_2col"),
            bbox_extra_artists=(legend, ),
            bbox_inches="tight",
            dpi=900)
Example #16
                                                 "data_length": pan.data_length,
                                                 "min": round(pan["min"] / dt["min"], 1),
                                                 "max": round(pan["max"] / dt["max"], 1),
                                                 "avg": round(pan["avg"] / dt["avg"], 1),
                                                 "q50": round(pan["q50"] / dt["q50"], 1)}))

comparison[["sel_col", "pos_col"]] = comparison["scenario"].apply(clean_scenarios)
comparison[["no_column", "data_length"]] = comparison[["no_column", "data_length"]].apply(pd.to_numeric, downcast="integer")

### Visual Exploration
saveformat = "png"
## Select

# pandas
plot = (gg.ggplot(pandas_sel, gg.aes("factor(no_column)", "factor(data_length)")) +
        gg.geom_tile(gg.aes(fill="q50")) +
        gg.geom_text(gg.aes(label="q50"), color="white", size=9) +
        gg.labs(y="# Rows", x="# Columns", title="Pandas median selection time") +
        gg.facet_grid("pos_col ~ sel_col") +
        gg.theme_bw() +
        gg.theme(legend_position=None))

gg.ggsave(plot, filename=os.path.join(path_n, "output", f"select_results_pandas.{saveformat}"), width=15, height=10)

# data.table
plot = (gg.ggplot(datatable_sel, gg.aes("factor(no_column)", "factor(data_length)")) +
        gg.geom_tile(gg.aes(fill="q50")) +
        gg.geom_text(gg.aes(label="q50"), color="white", size=9) +
        gg.labs(y="# Rows", x="# Columns", title="data.table median selection time") +
        gg.facet_grid("pos_col ~ sel_col") +
        gg.theme_bw() +
Example #17
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import settings
import plotnine as p9
import warnings
warnings.filterwarnings('ignore')

device_id = '000D6F000C1382DE1'

sql = f"""
SELECT *
FROM AH_USE_LOG_BYMINUTE
WHERE 1=1
AND DEVICE_ID = '{device_id}'
AND COLLECT_DATE >= '20191101'"""

df = pd.read_sql(sql, con=settings.conn)

(p9.ggplot(data=df, mapping=p9.aes(x='COLLECT_TIME', y='COLLECT_DATE')) +
 p9.geom_tile(p9.aes(fill='APPLIANCE_STATUS')))
Example #18
select_df = (pd.read_csv(
    select_file, sep="\t",
    index_col=0).reset_index().rename({
        "index": "feature"
    }, axis="columns").melt(
        id_vars="feature", var_name="plate",
        value_name="status").query("plate not in @nonuniform_plates"))

# Reorder data
select_df.plate = pd.Categorical(select_df.plate,
                                 categories=plate_order,
                                 ordered=True)
select_df.feature = pd.Categorical(select_df.feature,
                                   categories=feature_order,
                                   ordered=True)

print(select_df.shape)
select_df.head()

# In[15]:

feature_select_gg = (
    gg.ggplot(select_df, gg.aes(x="feature", y="plate", fill="status")) +
    gg.geom_tile(size=0.5) + gg.ggtitle("Feature Select Summary") +
    theme_summary + gg.theme(axis_text_y=gg.element_blank()))

output_file = pathlib.Path(f"{output_fig_dir}/feature_select_summary.png")
feature_select_gg.save(output_file, dpi=dpi, height=4, width=6)

feature_select_gg
Example #19
def ologram_merge_stats(inputfiles=None,
                        pdf_width=None,
                        pdf_height=None,
                        output=None,
                        labels=None):
    # -------------------------------------------------------------------------
    # Check user provided labels
    # -------------------------------------------------------------------------

    if labels is not None:

        labels = labels.split(",")

        for elmt in labels:
            if not re.search("^[A-Za-z0-9_]+$", elmt):
                message(
                    "Only alphanumeric characters and '_' allowed for --more-bed-labels",
                    type="ERROR")
        if len(labels) != len(inputfiles):
            message("--labels: the number of labels should be"
                    " the same as the number of input files ", type="ERROR")

        if len(labels) != len(set(labels)):
            message("Redundant labels not allowed.", type="ERROR")

    # -------------------------------------------------------------------------
    # Loop over input files
    # -------------------------------------------------------------------------

    df_list = list()
    df_label = list()

    for pos, infile in enumerate(inputfiles):
        message("Reading file : " + infile.name)
        # Read the dataset into a temporary dataframe
        df_tmp = pd.read_csv(infile, sep='\t', header=0, index_col=None)
        # Change name of 'feature_type' column.
        df_tmp = df_tmp.rename(index=str, columns={"feature_type": "Feature"})
        # Assign the name of the dataset to a new column

        if labels is None:
            file_short_name = os.path.basename(os.path.normpath(os.path.dirname(infile.name)))
            df_label += [file_short_name]
        else:
            file_short_name = labels[pos]
            df_label += [labels[pos]]

        df_tmp = df_tmp.assign(**{"dataset": [file_short_name] * df_tmp.shape[0]})
        # P-values of 0 or -1 are changed to 1e-320 and NaN, respectively
        df_tmp.loc[df_tmp['summed_bp_overlaps_pvalue'] == 0, 'summed_bp_overlaps_pvalue'] = 1e-320
        df_tmp.loc[df_tmp['summed_bp_overlaps_pvalue'] == -1, 'summed_bp_overlaps_pvalue'] = np.nan
        # Compute -log10(pval)
        df_tmp = df_tmp.assign(**{"-log_10(pval)": -np.log10(df_tmp.summed_bp_overlaps_pvalue)})

        # Which p-values are significant?
        # TODO: For now, draws all p-values. Add Benjamini-Hochberg correction, and distinguish between NaN and 0.
        df_tmp = df_tmp.assign(**{"pval_signif": df_tmp.summed_bp_overlaps_pvalue > 0})

        # Add the df to the list to be subsequently merged
        df_list += [df_tmp]



    if len(set(df_label)) < len(df_label):
        message('Enclosing directories are ambiguous and cannot be used as labels. You may use "--labels".',
                type="ERROR")

    # -------------------------------------------------------------------------
    # Concatenate dataframes (row bind)
    # -------------------------------------------------------------------------

    message("Merging dataframes.")
    df_merged = pd.concat(df_list, axis=0)

    # -------------------------------------------------------------------------
    # Plotting
    # -------------------------------------------------------------------------

    message("Plotting")
    my_plot = ggplot(data=df_merged,
                     mapping=aes(y='Feature', x='dataset'))
    my_plot += geom_tile(aes(fill = 'summed_bp_overlaps_log2_fold_change'))
    my_plot += scale_fill_gradient2()
    my_plot += labs(fill = "log2(fold change) for summed bp overlaps")

    # Points for p-val. Must be after geom_tile()
    my_plot += geom_point(data = df_merged.loc[df_merged['pval_signif']],
        mapping = aes(x='dataset',y='Feature',color = '-log_10(pval)'), size=4, shape ='D', inherit_aes = False)
    my_plot += scale_color_gradientn(colors = ["#160E00","#FFB025","#FFE7BD"])
    my_plot += labs(color = "-log10(p-value)")

    # Theming
    my_plot += theme_bw()
    my_plot += theme(panel_grid_major=element_blank(),
                     axis_text_x=element_text(rotation=90),
                     panel_border=element_blank(),
                     axis_ticks=element_blank())

    # -------------------------------------------------------------------------
    # Saving
    # -------------------------------------------------------------------------

    message("Saving")
    nb_ft = len(list(df_merged['Feature'].unique()))
    nb_datasets = len(list(df_merged['dataset'].unique()))

    if pdf_width is None:
        panel_width = 0.6
        pdf_width = panel_width * nb_datasets

        if pdf_width > 100:
            pdf_width = 100
            message("Setting --pdf-width to 100 (limit)")

    if pdf_height is None:
        panel_height = 0.6
        pdf_height = panel_height * nb_ft

        if pdf_height > 500:
            pdf_height = 500
            message("Setting --pdf-height to 500 (limit)")

    message("Page width set to " + str(pdf_width))
    message("Page height set to " + str(pdf_height))
    figsize = (pdf_width, pdf_height)

    # -------------------------------------------------------------------------
    # Turn warnings off. Both pandas and plotnine use warnings for deprecated
    # functions. I need to turn them off, although I'm not really satisfied
    # with this solution...
    # -------------------------------------------------------------------------

    def fxn():
        warnings.warn("deprecated", DeprecationWarning)

    # -------------------------------------------------------------------------
    # Saving
    # -------------------------------------------------------------------------

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fxn()

        message("Saving diagram to file : " + output.name)
        message("Be patient. This may be long for large datasets.")

        # NOTE : We must manually specify figure size with save_as_pdf_pages
        save_as_pdf_pages(filename=output.name,
                          plots=[my_plot + theme(figure_size=figsize)],
                          width=pdf_width,
                          height=pdf_height)
Example #20
File: heatmap.py  Project: pvtodorov/CELLEX
def heatmap(esw: pd.DataFrame,
            genes: list = None,
            annotations: list = None,
            figsize: tuple = None) -> p9.ggplot:
    """
    
    Args:
    esw             : DataFrame of ES weights
    genes          : a list of genes to include in the heatmap
    annotations    : a list of annotations to include in the heatmap
    figsize : (float, float), optional (default: None)
        Specify width and height of plot.
    Returns:
        g    : ggplot

    """

    df_tidy = esw

    ### Reduce dataframe to genes and annotations of interest
    if genes is not None:
        genes = [str.upper(s) for s in genes]
        idx = np.char.upper(df_tidy.index.values.astype(str))
        mask = np.isin(idx, genes)
        df_tidy = esw[mask]

    if annotations is not None:
        annotations = [str.upper(s) for s in annotations]
        cols = np.char.upper(df_tidy.columns.values.astype(str))
        mask = np.isin(cols, annotations)
        df_tidy = df_tidy.iloc[:, mask]

    # Constants, height and width of plot.
    if figsize is None:
        W = min((df_tidy.shape[0], df_tidy.shape[1], 20))
        H = min((df_tidy.shape[0], df_tidy.shape[1], 20))
    else:
        W, H = figsize

    ### Convert to tidy / long format if necessary
    # Org:
    #       ABC  ACBG  ACMB
    # POMC  0.0   0.5   0.9
    # AGRP  0.2   0.0   0.0
    # LEPR  0.1   0.1   0.4

    # Tidy:
    #   gene_name annotation    es_weight
    # 1 POMC      ABC           0.0
    # 2 AGRP      ABC           0.6
    # 3 LEPR      ABC           1.0

    df_tidy.index.name = None  # ensure that index name is none, so "index" is used for id_vars
    df_tidy = pd.melt(df_tidy.reset_index(),
                      id_vars="index",
                      var_name="annotation",
                      value_name="weight")

    ### Plot
    p = (
        ### data
        p9.ggplot(
            data=df_tidy,
            mapping=p9.aes(
                x="index", y="annotation", fill="weight", label="annotation"))

        ### theming
        + p9.theme_classic() + p9.theme(
            figure_size=(W, H),
            axis_text_x=p9.element_text(rotation=75),
        ) + p9.labs(
            x="",  # e.g. "Cell-type"
            y="",  # e.g. "ES weight"
        )

        ### viz
        + p9.geom_tile()
        # + p9.scale_fill_gradientn(colors=['#9ebcda','#8c6bb1','#88419d','#6e016b']) # light blue to purple
        + p9.scale_fill_gradientn(colors=['#ffffff', '#1E90FF'],
                                  limits=[0, 1])  # white to dodgerblue
    )

    return p
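
A minimal hedged example feeding heatmap a wide ES-weight frame like the one sketched in the docstring (gene and annotation names are hypothetical):

import pandas as pd

esw = pd.DataFrame(
    {'ABC':  [0.0, 0.2, 0.1],
     'ACBG': [0.5, 0.0, 0.1],
     'ACMB': [0.9, 0.0, 0.4]},
    index=['POMC', 'AGRP', 'LEPR'])

g = heatmap(esw, genes=['POMC', 'LEPR'], figsize=(4, 3))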