Example #1
def plot():
    outdir = 'output/protobowl/'
    pathlib.Path(outdir).mkdir(parents=True, exist_ok=True)

    df = load_protobowl()
    df.result = df.result.apply(lambda x: x is True)
    df['log_n_records'] = df.user_n_records.apply(np.log)

    df_user_grouped = df.groupby('uid')
    user_stat = df_user_grouped.agg(np.mean)
    print('{} users'.format(len(user_stat)))
    print('{} records'.format(len(df)))
    max_color = user_stat.log_n_records.max()
    user_stat['alpha'] = pd.Series(
        user_stat.log_n_records.apply(lambda x: x / max_color), index=user_stat.index)

    # 2D user plot
    p0 = ggplot(user_stat) \
        + geom_point(aes(x='relative_position', y='result',
                     size='user_n_records', color='log_n_records', alpha='alpha'),
                     show_legend={'color': False, 'alpha': False, 'size': False}) \
        + scale_color_gradient(high='#e31a1c', low='#ffffcc') \
        + labs(x='Average buzzing position', y='Accuracy') \
        + theme(aspect_ratio=1)
    p0.save(os.path.join(outdir, 'protobowl_users.pdf'))
    # p0.draw()
    print('p0 done')

    # histogram of number of records
    p1 = ggplot(user_stat, aes(x='log_n_records', y='..density..')) \
        + geom_histogram(color='#e6550d', fill='#fee6ce') \
        + geom_density() \
        + labs(x='Log number of records', y='Density') \
        + theme(aspect_ratio=0.3)
    p1.save(os.path.join(outdir, 'protobowl_hist.pdf'))
    # p1.draw()
    print('p1 done')

    # histogram of accuracy
    p2 = ggplot(user_stat, aes(x='result', y='..density..')) \
        + geom_histogram(color='#31a354', fill='#e5f5e0') \
        + geom_density() \
        + labs(x='Accuracy', y='Density') \
        + theme(aspect_ratio=0.3)
    p2.save(os.path.join(outdir, 'protobowl_acc.pdf'))
    # p2.draw()
    print('p2 done')

    # histogram of buzzing position
    p3 = ggplot(user_stat, aes(x='relative_position', y='..density..')) \
        + geom_histogram(color='#3182bd', fill='#deebf7') \
        + geom_density() \
        + labs(x='Average buzzing position', y='Density') \
        + theme(aspect_ratio=0.3)
    p3.save(os.path.join(outdir, 'protobowl_pos.pdf'))
    # p3.draw()
    print('p3 done')
Example #2
def test_add_element_blank():
    # Adding onto a blanked themeable
    theme1 = theme_gray() + theme(axis_line_x=l1)  # not blank
    theme2 = theme1 + theme(axis_line_x=blank)     # blank
    theme3 = theme2 + theme(axis_line_x=l3)        # not blank
    theme4 = theme_gray() + theme(axis_line_x=l3)  # for comparison
    assert theme3 != theme1
    assert theme3 != theme2
    assert theme3 == theme4  # blanking cleans the slate

    # When a themeable is blanked, the apply method
    # is replaced with the blank method.
    th2 = theme2.themeables['axis_line_x']
    th3 = theme3.themeables['axis_line_x']
    assert th2.apply.__name__ == 'blank'
    assert th3.apply.__name__ == 'apply'
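
A minimal stand-alone sketch of the same blanking behaviour, written against plotnine's public API (the colours are arbitrary; this snippet is not part of the original test file):

from plotnine import theme, theme_gray, element_line, element_blank

# Set a concrete x-axis line, blank it, then set a fresh element.
base = theme_gray() + theme(axis_line_x=element_line(color='red'))
blanked = base + theme(axis_line_x=element_blank())     # themeable suppressed
restored = blanked + theme(axis_line_x=element_line(color='blue'))
# Per the test above, `restored` compares equal to
# theme_gray() + theme(axis_line_x=element_line(color='blue')).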
Example #3
def test_aesthetics():
    df = pd.DataFrame({
            'a': range(5),
            'b': 2,
            'c': 3,
            'd': 4,
            'e': 5,
            'f': 6,
            'g': 7,
            'h': 8,
            'i': 9
        })

    p = (ggplot(df, aes(y='a')) +
         geom_point(aes(x='b')) +
         geom_point(aes(x='c', size='a')) +
         geom_point(aes(x='d', alpha='a'),
                    size=10, show_legend=False) +
         geom_point(aes(x='e', shape='factor(a)'),
                    size=10, show_legend=False) +
         geom_point(aes(x='f', color='factor(a)'),
                    size=10, show_legend=False) +
         geom_point(aes(x='g', fill='a'), stroke=0,
                    size=10, show_legend=False) +
         geom_point(aes(x='h', stroke='a'), fill='white',
                    color='green', size=10) +
         geom_point(aes(x='i', shape='factor(a)'),
                    fill='brown', stroke=2, size=10, show_legend=False) +
         theme(subplots_adjust={'right': 0.85}))

    assert p == 'aesthetics'
Example #4
def test_quantiles_width_dodge():
    p = (ggplot(df, aes('x')) +
         geom_violin(aes(y='y'),
                     draw_quantiles=[.25, .75], size=2) +
         geom_violin(aes(y='y+25'), color='green',
                     width=0.5, size=2) +
         geom_violin(aes(y='y+50', fill='factor(y%2)'),
                     size=2) +
         theme(subplots_adjust={'right': 0.85}))
    assert p == 'quantiles_width_dodge'
Example #5
    def __init__(self, base_size=11, base_family='DejaVu Sans'):
        theme_light.__init__(self, base_size, base_family)
        self.add_theme(theme(
            axis_ticks=element_line(color='#DDDDDD', size=0.5),
            panel_border=element_rect(fill='None', color='#838383',
                                      size=1),
            strip_background=element_rect(
                fill='#DDDDDD', color='#838383', size=1),
            strip_text_x=element_text(color='black'),
            strip_text_y=element_text(color='black', angle=-90)
        ), inplace=True)
Example #6
def test_add_complete_partial():
    theme1 = theme_gray()
    theme2 = theme1 + theme(axis_line_x=element_line())
    assert theme2 != theme1
    assert theme2.themeables != theme1.themeables
    assert theme2.rcParams == theme1.rcParams

    # specific difference
    for name in theme2.themeables:
        if name == 'axis_line_x':
            assert theme2.themeables[name] != theme1.themeables[name]
        else:
            assert theme2.themeables[name] == theme1.themeables[name]
Example #7
def test_add_element_heirarchy():
    # parent themeable modifies child themeable
    theme1 = theme_gray() + theme(axis_line_x=l1)  # child
    theme2 = theme1 + theme(axis_line=l2)          # parent
    theme3 = theme1 + theme(axis_line_x=l3)        # child, for comparison
    assert theme2.themeables['axis_line_x'] == \
        theme3.themeables['axis_line_x']

    theme1 = theme_gray() + theme(axis_line_x=l1)  # child
    theme2 = theme1 + theme(line=l2)               # grand-parent
    theme3 = theme1 + theme(axis_line_x=l3)        # child, for comparison
    assert theme2.themeables['axis_line_x'] == \
        theme3.themeables['axis_line_x']

    # child themeable does not affect parent
    theme1 = theme_gray() + theme(axis_line=l1)    # parent
    theme2 = theme1 + theme(axis_line_x=l2)        # child
    theme3 = theme1 + theme(axis_line=l3)          # parent, for comparison
    assert theme3.themeables['axis_line'] != \
        theme2.themeables['axis_line']
Example #8
def test_params():
    p = (ggplot(df, aes('x')) +
         geom_boxplot(df[:m], aes(y='y'), size=2, notch=True) +
         geom_boxplot(df[m:2*m], aes(y='y'), size=2,
                      notch=True, notchwidth=0.8) +
         # outliers
         geom_boxplot(df[2*m:3*m], aes(y='y'), size=2,
                      outlier_size=4, outlier_color='green') +
         geom_boxplot(df[2*m:3*m], aes(y='y+25'), size=2,
                      outlier_size=4, outlier_alpha=0.5) +
         geom_boxplot(df[2*m:3*m], aes(y='y+60'), size=2,
                      outlier_size=4, outlier_shape='D') +
         # position dodge
         geom_boxplot(df[3*m:4*m], aes(y='y', fill='factor(y%2)')) +
         theme(subplots_adjust={'right': 0.85})
         )
    assert p == 'params'
Example #9
def theme_cognoma(fontsize_mult=1):
    import plotnine as gg

    return (gg.theme_bw(base_size=14 * fontsize_mult) +
        gg.theme(
            line=gg.element_line(color="#4d4d4d"),
            rect=gg.element_rect(fill="white", color=None),
            text=gg.element_text(color="black"),
            axis_ticks=gg.element_line(color="#4d4d4d"),
            legend_key=gg.element_rect(color=None),
            panel_border=gg.element_rect(color="#4d4d4d"),
            panel_grid=gg.element_line(color="#b3b3b3"),
            panel_grid_major_x=gg.element_blank(),
            panel_grid_minor=gg.element_blank(),
            strip_background=gg.element_rect(fill="#FEF2E2", color="#4d4d4d"),
            axis_text=gg.element_text(size=12 * fontsize_mult, color="#4d4d4d"),
            axis_title_x=gg.element_text(size=13 * fontsize_mult, color="#4d4d4d"),
            axis_title_y=gg.element_text(size=13 * fontsize_mult, color="#4d4d4d")
        ))
Example #10
def all_stack(fold=BUZZER_DEV_FOLD):
    df_rnn = stack('output/buzzer/RNNBuzzer', 'RNN', fold)
    df_mlp = stack('output/buzzer/MLPBuzzer', 'MLP', fold)
    df_thr = stack('output/buzzer/ThresholdBuzzer', 'Threshold', fold)
    df = df_rnn.append(df_mlp, ignore_index=True)
    df = df.append(df_thr, ignore_index=True)
    model_type = CategoricalDtype(
        categories=['Threshold', 'MLP', 'RNN'])
    df['Model'] = df['Model'].astype(model_type)
    p = (
        ggplot(df)
        + geom_area(aes(x='Position', y='Frequency', fill='Buzzing'))
        + facet_grid('~ Model')
        + theme_fs()
        + theme(
            aspect_ratio=1,
        )
        + scale_fill_brewer(type='div', palette=7)
    )
    p.save('output/buzzer/{}_stack.pdf'.format(fold))
Example #11
def create_confidence_plot(conf_df):
    plt = (
        ggplot(conf_df)
        + aes(x='x', color='Method', fill='Method')
        + geom_density(alpha=.45)
        + facet_wrap('Task', nrow=4)
        + xlab('Confidence')
        + scale_color_manual(values=COLORS)
        + scale_fill_manual(values=COLORS)
        + theme_fs()
        + theme(
            axis_text_y=element_blank(),
            axis_ticks_major_y=element_blank(),
            axis_title_y=element_blank(),
            legend_title=element_blank(),
            legend_position='top',
            legend_box='horizontal',
        )
    )
    return plt
Example #12
def protobowl(fold=BUZZER_DEV_FOLD):
    df_rnn = pickle.load(
        open('output/buzzer/RNNBuzzer/{}_protobowl.pkl'.format(fold), 'rb'))
    df_rnn = df_rnn.groupby(['Possibility', 'Outcome'])
    df_rnn = df_rnn.size().reset_index().rename(columns={0: 'Count'})
    df_rnn['Model'] = pd.Series(['RNN' for _ in range(len(df_rnn))], index=df_rnn.index)

    df_mlp = pickle.load(
        open('output/buzzer/MLPBuzzer/{}_protobowl.pkl'.format(fold), 'rb'))
    df_mlp = df_mlp.groupby(['Possibility', 'Outcome'])
    df_mlp = df_mlp.size().reset_index().rename(columns={0: 'Count'})
    df_mlp['Model'] = pd.Series(['MLP' for _ in range(len(df_mlp))], index=df_mlp.index)

    df_thr = pickle.load(
        open('output/buzzer/ThresholdBuzzer/{}_protobowl.pkl'.format(fold), 'rb'))
    df_thr = df_thr.groupby(['Possibility', 'Outcome'])
    df_thr = df_thr.size().reset_index().rename(columns={0: 'Count'})
    df_thr['Model'] = pd.Series(['Threshold' for _ in range(len(df_thr))], index=df_thr.index)

    df = df_rnn.append(df_mlp, ignore_index=True)
    df = df.append(df_thr, ignore_index=True)

    outcome_type = CategoricalDtype(categories=[15, 10, 5, 0, -5, -10, -15])
    df['Outcome'] = df['Outcome'].astype(outcome_type)
    model_type = CategoricalDtype(
        categories=['Threshold', 'MLP', 'RNN'])
    df['Model'] = df['Model'].astype(model_type)

    p = (
        ggplot(df)
        + geom_col(aes(x='Possibility', y='Count', fill='Outcome'),
                   width=0.7)
        + facet_grid('Model ~')
        + coord_flip()
        + theme_fs()
        + theme(aspect_ratio=0.17)
        + scale_fill_brewer(type='div', palette=7)
    )

    output_path = 'output/buzzer/{}_protobowl.pdf'.format(fold)
    p.save(output_path)
Example #13
def create_length_plot(len_df, legend_position='right', legend_box='vertical'):
    mean_len_df = len_df.groupby(['Task', 'Method']).mean().reset_index()
    mean_len_df[' '] = 'Mean Length'

    plt = (
        ggplot(len_df)
        + aes(x='x', fill='Method', y='..density..')
        + geom_histogram(binwidth=2, position='identity', alpha=.6)
        + geom_text(
            aes(x='x', y=.22, label='x', color='Method'),
            mean_len_df,
            inherit_aes=False,
            format_string='{:.1f}',
            show_legend=False
        )
        + geom_segment(
            aes(x='x', xend='x', y=0, yend=.205, linetype=' '),
            mean_len_df,
            inherit_aes=False, color='black'
        )
        + scale_linetype_manual(['dashed'])
        + facet_wrap('Task')
        + xlim(0, 20) + ylim(0, .23)
        + xlab('Example Length') + ylab('Frequency')
        + scale_color_manual(values=COLORS)
        + scale_fill_manual(values=COLORS)
        + theme_fs()
        + theme(
            aspect_ratio=1,
            legend_title=element_blank(),
            legend_position=legend_position,
            legend_box=legend_box,
        )
    )

    return plt
Example #14
    def plot_detected_tags(self, data, options):
        tmp = data["tags"][["tag", "matched", "tag_id"]].copy()
        tmp.loc[:, "prop_matched"] = -1
        tmp.loc[:, "n_tags"] = -1

        # Get all proportions by tags
        tags_summary = (
            tmp.groupby("tag")
            .apply(self.get_proportions, by=["matched"])
            .reset_index(drop=True)
        )
        # convert tags to category
        tags_summary.tag = tags_summary.tag.astype("category")
        # Get list of all tags
        all_tags = tags_summary.tag.unique()
        matched = tags_summary.loc[tags_summary.matched == 1]
        # Get list of all matched tags
        matched_tags = matched.tag.unique()
        unmatched = []
        # Add a row for all unmatched tags
        for tag in all_tags:
            if tag not in matched_tags:
                row = tags_summary.loc[
                    (tags_summary.matched == 0) & (tags_summary.tag == tag)
                ]
                row.prop_matched = 0
                unmatched.append(row)
        unmatched.append(matched)
        # Create final object with unmatched and matched tags
        m_df = pd.concat(unmatched)

        # Sort values
        m_df = m_df.sort_values(["prop_matched", "tag"])
        m_df.tag = m_df.tag.cat.reorder_categories(m_df.tag.to_list())

        plt = ggplot(
            data=m_df,
            mapping=aes(
                x="tag",
                y="prop_matched",
                fill="tag",  # "factor(species, ordered=False)",
            ),
        )
        plot_width = 10 + len(m_df.tag.unique()) * 0.75
        plt = (
            plt
            + geom_bar(stat="identity", show_legend=True, position=position_dodge())
            + xlab("Species")
            + ylab("Proportion of annotation matched")
            + geom_text(
                mapping=aes(label="lbl_matched"), position=position_dodge(width=0.9),
            )
            + theme_classic()
            + theme(
                axis_text_x=element_text(angle=90, vjust=1, hjust=1, margin={"r": -30}),
                plot_title=element_text(
                    weight="bold", size=14, margin={"t": 10, "b": 10}
                ),
                figure_size=(plot_width, 10),
                text=element_text(size=12, weight="bold"),
            )
            + ggtitle(
                (
                    "Proportion of tags detected for model {}, database {}, class {}\n"
                    + "with detector options {}"
                ).format(
                    options["scenario_info"]["model"],
                    options["scenario_info"]["database"],
                    options["scenario_info"]["class"],
                    options,
                )
            )
        )

        return plt
Example #15
def test_add_empty_theme_element():
    # An empty theme element does not alter the theme
    theme1 = theme_gray() + theme(axis_line_x=element_line(color='red'))
    theme2 = theme1 + theme(axis_line_x=element_line())
    assert theme1 == theme2
Example #16
def plt9_tilt_xlab(angle=45):
    from plotnine import theme, element_text
    return theme(axis_text_x=element_text(rotation=angle, hjust=1))
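
A hypothetical usage of this helper (the data frame below is invented purely for illustration) is to add it to a plot like any other theme:

import pandas as pd
from plotnine import ggplot, aes, geom_col

df = pd.DataFrame({'category': ['alpha', 'beta', 'gamma'], 'value': [3, 1, 2]})
# Tilt long x-axis labels by 45 degrees so they do not overlap.
p = ggplot(df, aes('category', 'value')) + geom_col() + plt9_tilt_xlab(45)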
Example #17
import os
import matplotlib.pyplot as plt
from matplotlib.testing.compare import compare_images
from matplotlib import cbook
import six

from plotnine import ggplot, theme


TOLERANCE = 2           # Default tolerance for the tests
DPI = 72                # Default DPI for the tests

# This partial theme modifies all themes that are used in
# the test. It is limited to setting the size of the test
# images. Should a test require a larger or smaller figure
# size, the dpi or aspect_ratio should be modified.
test_theme = theme(figure_size=(640/DPI, 480/DPI))
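
# Hypothetical illustration (not part of the original module): test_theme is
# meant to be added on top of whatever theme a test plot already carries, so
# that every baseline image is rendered at the same 640x480-pixel figure size.
def _example_plot_with_test_size(df):
    # df is a hypothetical DataFrame with 'x' and 'y' columns.
    from plotnine import aes, geom_point
    return ggplot(df, aes('x', 'y')) + geom_point() + test_theme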

if not os.path.exists(os.path.join(
        os.path.dirname(__file__), 'baseline_images')):
    raise IOError(
        "The baseline image directory does not exist. "
        "This is most likely because the test data is not installed. "
        "You may need to install plotnine from source to get the "
        "test data.")


def raise_no_baseline_image(filename):
    raise Exception("Baseline image {} is missing".format(filename))


def ggplot_equals(gg, right):
Example #18
import pandas as pd

from plotnine import ggplot, aes, geom_col, theme

_theme = theme(subplots_adjust={'right': 0.80})

df = pd.DataFrame({'x': ['b', 'd', 'c', 'a'], 'y': [1, 2, 3, 4]})


def test_reorder():
    p = (ggplot(df, aes('reorder(x, y)', 'y', fill='reorder(x, y, True)')) +
         geom_col())
    assert p + _theme == 'reorder'


def test_reorder_index():
    # The dataframe is created with ordering according to the y
    # variable. So the x index should be ordered acc. to y too
    p = (ggplot(df, aes('reorder(x, x.index)', 'y')) + geom_col())
    assert p + _theme == 'reorder_index'
Example #19
# Concatenate input and simulated dataframes together
combined_data_df = pd.concat([input_data_UMAPencoded_df, simulated_data_UMAPencoded_df])

# Plot
fig = ggplot(combined_data_df, aes(x='1', y='2'))
fig += geom_point(aes(color='experiment_id'), alpha=0.1)
fig += facet_wrap('~dataset')
fig += labs(x='UMAP 1',
            y='UMAP 2',
            title='UMAP of original and simulated data (gene space)')
fig += theme_bw()
fig += theme(
    legend_title_align="center",
    plot_background=element_rect(fill='white'),
    legend_key=element_rect(fill='white', colour='white'),
    legend_title=element_text(family='sans-serif', size=15),
    legend_text=element_text(family='sans-serif', size=12),
    plot_title=element_text(family='sans-serif', size=15),
    axis_text=element_text(family='sans-serif', size=12),
    axis_title=element_text(family='sans-serif', size=15)
)
fig += guides(colour=guide_legend(override_aes={'alpha': 1}))
fig += scale_color_manual(['red', '#bdbdbd'])
fig += geom_point(data=combined_data_df[combined_data_df['experiment_id'] == example_id],
                  alpha=0.1, 
                  color='red')

print(fig)
ggsave(plot=fig, filename=experiment_simulated_file)

Example #20
levels = ["level_3", "level_4a", "level_4b", "pycytominer_select"]
metrics = ["mean", "median", "sum"]

# Set output directory
output_fig_dir = pathlib.Path("figures", batch)
output_fig_dir.mkdir(parents=True, exist_ok=True)

# Set plotting defaults
dpi = 500
height = 3.5
width = 6

# Set common plotnine theme
theme_summary = gg.theme_bw() + gg.theme(
    axis_text_x=gg.element_blank(),
    axis_text_y=gg.element_text(size=6),
    axis_title=gg.element_text(size=8),
    strip_background=gg.element_rect(colour="black", fill="#fdfff4"),
)

# In[4]:

# Load Data
results_files = {}
for level in levels:
    file_names = build_filenames(input_dir, level=level)
    metric_df = {}
    for metric in file_names:
        df = pd.read_csv(file_names[metric], sep="\t", index_col=0)
        metric_df[metric] = df

    results_files[level] = metric_df
Example #21
def plot_portfolio(portfolio_df,
                   figure_size=(12, 4),
                   line_size=1.5,
                   date_text_size=7):
    """
    Given a daily snapshot of virtual purchases plot both overall and per-stock
    performance. Return a tuple of figures representing the performance as inline data.
    """
    assert portfolio_df is not None
    #print(portfolio_df)
    portfolio_df['date'] = pd.to_datetime(portfolio_df['date'])
    avg_profit_over_period = portfolio_df.filter(
        items=['stock', 'stock_profit']).groupby('stock').mean()
    avg_profit_over_period['contribution'] = [
        'positive' if profit >= 0.0 else 'negative'
        for profit in avg_profit_over_period.stock_profit
    ]
    avg_profit_over_period = avg_profit_over_period.drop(
        'stock_profit',
        axis='columns')  # dont want to override actual profit with average
    portfolio_df = portfolio_df.merge(avg_profit_over_period,
                                      left_on='stock',
                                      right_index=True,
                                      how='inner')
    #print(portfolio_df)

    # 1. overall performance
    df = portfolio_df.filter(items=[
        'portfolio_cost', 'portfolio_worth', 'portfolio_profit', 'date'
    ])
    df = df.melt(id_vars=['date'], var_name='field')
    plot = (
        p9.ggplot(df, p9.aes('date', 'value', group='field', color='field')) +
        p9.labs(x='', y='$ AUD') + p9.geom_line(size=1.5) +
        p9.facet_wrap('~ field', nrow=3, ncol=1, scales='free_y') +
        p9.theme(axis_text_x=p9.element_text(angle=30, size=date_text_size),
                 figure_size=figure_size,
                 legend_position='none'))
    overall_figure = plot_as_inline_html_data(plot)

    df = portfolio_df.filter(
        items=['stock', 'date', 'stock_profit', 'stock_worth', 'contribution'])
    melted_df = df.melt(id_vars=['date', 'stock', 'contribution'],
                        var_name='field')
    all_dates = sorted(melted_df['date'].unique())
    df = melted_df[melted_df['date'] == all_dates[-1]]
    df = df[df['field'] == 'stock_profit']  # only latest profit is plotted
    df['contribution'] = [
        'positive' if profit >= 0.0 else 'negative' for profit in df['value']
    ]

    # 2. plot contributors ie. winners and losers
    plot = (p9.ggplot(df, p9.aes('stock', 'value', fill='stock')) +
            p9.geom_bar(stat='identity') + p9.labs(x='', y='$ AUD') +
            p9.facet_grid('contribution ~ field') +
            p9.theme(legend_position='none', figure_size=figure_size))
    profit_contributors = plot_as_inline_html_data(plot)

    # 3. per purchased stock performance
    plot = (
        p9.ggplot(melted_df,
                  p9.aes('date', 'value', group='stock', colour='stock')) +
        p9.xlab('') + p9.geom_line(size=1.0) +
        p9.facet_grid('field ~ contribution', scales="free_y") + p9.theme(
            axis_text_x=p9.element_text(angle=30, size=date_text_size),
            figure_size=figure_size,
            panel_spacing=
            0.5,  # more space between plots to avoid tick mark overlap
            subplots_adjust={'right': 0.8}))
    stock_figure = plot_as_inline_html_data(plot)
    return overall_figure, stock_figure, profit_contributors
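
A hedged usage sketch; `snapshot_df` below is a hypothetical daily-snapshot frame with the columns the function expects and is not defined in the original code:

# snapshot_df: hypothetical frame with date, stock, stock_profit, stock_worth,
# portfolio_cost, portfolio_worth and portfolio_profit columns.
overall_fig, stock_fig, contributors_fig = plot_portfolio(
    snapshot_df, figure_size=(12, 4), date_text_size=7)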
Example #22
    def plot(df: 'DataFrame',
             group_colname: str = None,
             time_colname: str = None,
             max_num_groups: int = 1,
             split_dt: Optional[np.datetime64] = None,
             **kwargs) -> 'DataFrame':
        """
        :param df: The output of `.to_dataframe()`.
        :param group_colname: The name of the group-column.
        :param time_colname: The name of the time-column.
        :param max_num_groups: Max. number of groups to plot; if the number of groups in the dataframe is greater than
        this, a random subset will be taken.
        :param split_dt: If supplied, will draw a vertical line at this date (useful for showing pre/post validation).
        :param kwargs: Further keyword arguments to pass to `plotnine.theme` (e.g. `figure_size=(x,y)`)
        :return: A plot of the predicted and actual values.
        """

        from plotnine import (
            ggplot, aes, geom_line, geom_ribbon, facet_grid, facet_wrap, theme_bw, theme, ylab, geom_vline
        )

        is_components = ('process' in df.columns and 'state_element' in df.columns)

        if group_colname is None:
            group_colname = 'group'
            if group_colname not in df.columns:
                raise TypeError("Please specify group_colname")
        if time_colname is None:
            time_colname = 'time'
            if 'time' not in df.columns:
                raise TypeError("Please specify time_colname")

        df = df.copy()
        if df[group_colname].nunique() > max_num_groups:
            subset_groups = df[group_colname].drop_duplicates().sample(max_num_groups).tolist()
            if len(subset_groups) < df[group_colname].nunique():
                print("Subsetting to groups: {}".format(subset_groups))
            df = df.loc[df[group_colname].isin(subset_groups), :]
        num_groups = df[group_colname].nunique()

        aes_kwargs = {'x': time_colname}
        if is_components:
            aes_kwargs['group'] = 'state_element'

        plot = (
                ggplot(df, aes(**aes_kwargs)) +
                geom_line(aes(y='mean'), color='#4C6FE7', size=1.5, alpha=.75) +
                geom_ribbon(aes(ymin='lower', ymax='upper'), color=None, alpha=.25) +
                ylab("")
        )

        if is_components:
            num_processes = df['process'].nunique()
            if num_groups > 1 and num_processes > 1:
                raise ValueError("Cannot plot components for > 1 group and > 1 processes.")
            elif num_groups == 1:
                plot = plot + facet_wrap(f"~ measure + process", scales='free_y', labeller='label_both')
                if 'figure_size' not in kwargs:
                    from plotnine.facets.facet_wrap import n2mfrow
                    nrow, _ = n2mfrow(len(df[['process', 'measure']].drop_duplicates().index))
                    kwargs['figure_size'] = (12, nrow * 2.5)
            else:
                plot = plot + facet_grid(f"{group_colname} ~ measure", scales='free_y', labeller='label_both')
                if 'figure_size' not in kwargs:
                    kwargs['figure_size'] = (12, num_groups * 2.5)

            if (df.groupby('measure')['process'].nunique() <= 1).all():
                plot = plot + geom_line(aes(y='mean', color='state_element'), size=1.5)

        else:
            if 'actual' in df.columns:
                plot = plot + geom_line(aes(y='actual'))
            if num_groups > 1:
                plot = plot + facet_grid(f"{group_colname} ~ measure", scales='free_y', labeller='label_both')
            else:
                plot = plot + facet_wrap("~measure", scales='free_y', labeller='label_both')

            if 'figure_size' not in kwargs:
                kwargs['figure_size'] = (12, 5)

        if split_dt:
            plot = plot + geom_vline(xintercept=np.datetime64(split_dt), linetype='dashed')

        return plot + theme_bw() + theme(**kwargs)
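
A hedged usage sketch based only on the docstring above; `pred_df` and the assumption that the method can be called as a plain function are illustrative, not taken from the original code:

import numpy as np

# pred_df: hypothetical output of `.to_dataframe()` with 'group', 'time',
# 'mean', 'lower', 'upper' (and optionally 'actual') columns.
p = plot(
    pred_df,
    group_colname='group',
    time_colname='time',
    max_num_groups=3,
    split_dt=np.datetime64('2021-01-01'),
    figure_size=(12, 8),   # forwarded to plotnine.theme via **kwargs
)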
Example #23
def control_list(in_file=None,
                 out_dir=None,
                 reference_gene_file=None,
                 log2=False,
                 page_width=None,
                 page_height=None,
                 user_img_file=None,
                 page_format=None,
                 pseudo_count=1,
                 set_colors=None,
                 dpi=300,
                 rug=False,
                 jitter=False,
                 skip_first=False):
    # -------------------------------------------------------------------------
    #
    # Check in_file content
    #
    # -------------------------------------------------------------------------

    for p, line in enumerate(in_file):

        line = chomp(line)
        line = line.split("\t")

        if len(line) > 2:
            message("Need a two columns file.",
                    type="ERROR")
        if skip_first:
            if p == 0:
                continue
        try:
            fl = float(line[1])
        except ValueError:
            msg = "It seems that column 2 of input file"
            msg += " contains non numeric values. "
            msg += "Check that no header is present and that "
            msg += "columns are ordered properly. "
            msg += "Or use '--skip-first'. "
            message(msg, type="ERROR")

        if log2:
            fl = fl + pseudo_count
            if fl <= 0:
                message("Can not log transform negative/zero values. Add a pseudo-count.",
                        type="ERROR")

    # -------------------------------------------------------------------------
    #
    # Check colors
    #
    # -------------------------------------------------------------------------

    set_colors = set_colors.split(",")

    if len(set_colors) != 2:
        message("Need two colors. Please fix.", type="ERROR")

    mcolors_name = mcolors.cnames

    for i in set_colors:
        if i not in mcolors_name:
            if not is_hex_color(i):
                message(i + " is not a valid color. Please fix.", type="ERROR")

    # -------------------------------------------------------------------------
    #
    # Preparing output files
    #
    # -------------------------------------------------------------------------

    # Preparing pdf file name
    file_out_list = make_outdir_and_file(out_dir, ["control_list.txt",
                                                   "reference_list.txt",
                                                   "diagnostic_diagrams." + page_format],
                                         force=True)

    control_file, reference_file_out, img_file = file_out_list

    if user_img_file is not None:

        os.unlink(img_file.name)
        img_file = user_img_file

        if not img_file.name.endswith(page_format):
            msg = "Image format should be: {f}. Please fix.".format(f=page_format)
            message(msg, type="ERROR")

        test_path = os.path.abspath(img_file.name)
        test_path = os.path.dirname(test_path)

        if not os.path.exists(test_path):
            os.makedirs(test_path)

    # -------------------------------------------------------------------------
    #
    # Read the reference list
    #
    # -------------------------------------------------------------------------

    try:
        reference_genes = pd.read_csv(reference_gene_file.name, sep="\t", header=None)
    except pd.errors.EmptyDataError:
        message("No genes in --reference-gene-file.", type="ERROR")

    reference_genes.rename(columns={reference_genes.columns.values[0]: 'gene'}, inplace=True)

    # -------------------------------------------------------------------------
    #
    # Delete duplicates
    #
    # -------------------------------------------------------------------------

    before = len(reference_genes)
    reference_genes = reference_genes.drop_duplicates(['gene'])
    after = len(reference_genes)

    msg = "%d duplicate lines have been deleted in reference file."
    message(msg % (before - after))

    # -------------------------------------------------------------------------
    #
    # Read expression data and add the pseudo_count
    #
    # -------------------------------------------------------------------------

    if skip_first:
        exp_data = pd.read_csv(in_file.name, sep="\t",
                               header=None, index_col=None,
                               skiprows=[0], names=['exprs'])
    else:

        exp_data = pd.read_csv(in_file.name, sep="\t", names=['exprs'], index_col=0)

    exp_data.exprs = exp_data.exprs.values + pseudo_count

    # -------------------------------------------------------------------------
    #
    # log transformation
    #
    # -------------------------------------------------------------------------

    ylabel = 'Expression'

    if log2:
        if len(exp_data.exprs.values[exp_data.exprs.values == 0]):
            message("Can't use log transformation on zero or negative values. Use -p.",
                    type="ERROR")
        else:
            exp_data.exprs = np.log2(exp_data.exprs.values)
            ylabel = 'log2(Expression)'

    # -------------------------------------------------------------------------
    #
    # Are reference gene found in control list
    #
    # -------------------------------------------------------------------------

    # Sort in increasing order
    exp_data = exp_data.sort_values('exprs')

    # Vector with positions indicating which genes in the
    # expression data list are found in reference_genes

    reference_genes_found = [x for x in reference_genes['gene'] if x in exp_data.index]

    msg = "Found %d genes of the reference in the provided signal file" % len(reference_genes_found)
    message(msg)

    not_found = [x for x in reference_genes['gene'] if x not in exp_data.index]

    if len(not_found):
        if len(not_found) == len(reference_genes):
            message("Genes from reference file where not found in signal file (n=%d)." % len(not_found), type="ERROR")
        else:
            message("List of reference genes not found :%s" % not_found)
    else:
        message("All reference genes were found.")

    # -------------------------------------------------------------------------
    #
    # Search for genes with matched signal
    #
    # -------------------------------------------------------------------------

    exp_data_save = exp_data.copy()

    control_list = list()

    nb_candidate_left = exp_data.shape[0] - len(reference_genes_found)

    message("Searching for genes with matched signal.")

    if nb_candidate_left < len(reference_genes_found):
        message("Not enough element to perform selection. Exiting", type="ERROR")

    for i in reference_genes_found:
        not_candidates = reference_genes_found + control_list
        not_candidates = list(set(not_candidates))

        diff = abs(exp_data.loc[i] - exp_data)
        control_list.extend(diff.loc[np.setdiff1d(diff.index, not_candidates)].idxmin(axis=0, skipna=True).tolist())

    # -------------------------------------------------------------------------
    #
    # Prepare a dataframe for plotting
    #
    # -------------------------------------------------------------------------

    message("Preparing a dataframe for plotting.")

    reference = exp_data_save.loc[reference_genes_found].sort_values('exprs')
    reference = reference.assign(genesets=['Reference'] * reference.shape[0])

    control = exp_data_save.loc[control_list].sort_values('exprs')
    control = control.assign(genesets=['Control'] * control.shape[0])

    data = pd.concat([reference, control])
    data['sets'] = pd.Series(['sets' for x in data.index.tolist()], index=data.index)
    data['genesets'] = Categorical(data['genesets'])

    # -------------------------------------------------------------------------
    #
    # Diagnostic plots
    #
    # -------------------------------------------------------------------------

    p = ggplot(data, aes(x='sets', y='exprs', fill='genesets'))

    p += scale_fill_manual(values=dict(zip(['Reference', 'Control'], set_colors)))

    p += geom_violin(color=None)

    p += xlab('Gene sets') + ylab(ylabel)

    p += facet_wrap('~genesets')

    if rug:
        p += geom_rug()

    if jitter:
        p += geom_jitter()

    p += theme_bw()
    p += theme(axis_text_x=element_blank())

    # -------------------------------------------------------------------------
    # Turn warnings off. Both pandas and plotnine use warnings for deprecated
    # functions. I need to turn them off, although I'm not really satisfied with
    # this solution...
    # -------------------------------------------------------------------------

    def fxn():
        warnings.warn("deprecated", DeprecationWarning)

    # -------------------------------------------------------------------------
    #
    # Saving
    #
    # -------------------------------------------------------------------------

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fxn()
        message("Saving diagram to file : " + img_file.name)
        message("Be patient. This may be long for large datasets.")

        try:
            p.save(filename=img_file.name, width=page_width, height=page_height, dpi=dpi, limitsize=False)
        except PlotnineError as err:
            message("Plotnine message: " + err.message)
            message("Plotnine encountered an error.", type="ERROR")

    # -------------------------------------------------------------------------
    #
    # write results
    #
    # -------------------------------------------------------------------------

    exp_data_save.loc[reference_genes_found].sort_values('exprs').to_csv(reference_file_out.name, sep="\t")
    exp_data_save.loc[control_list].sort_values('exprs').to_csv(control_file.name, sep="\t")
Example #24
def n_es_genes(df: pd.DataFrame,
               annotation: pd.Series,
               figsize: tuple = None) -> p9.ggplot:
    """Plot distribution of number of ES genes per group
    
    Computes the number of ES genes per column, e.g. cell(-type) 
    and plots the distribution for the groups specified
    by the annotation.
    
    Parameters
    ----------
    df : DataFrame
        Dataframe containing positive ES weights, ideally use only ESmu.
    annotation : Series
        Annotation to group dataframe cell(-types) by in the violin plots.
    figsize : (float, float), optional (default: None)
        Specify width and height of plot.
    
    Returns
    -------
    p : ggplot
        A plotnine ggplot

    """

    ### Count number of non-zero values, i.e. ESw > 0
    df = df.astype(bool).sum(axis=0)

    ### Map column labels to annotation
    if type(annotation) is pd.DataFrame:
        annotation = annotation.iloc[:, 0]

    # remove duplicates
    annotation = annotation.loc[~annotation.index.duplicated(keep='first')]

    df.index = df.index.map(annotation, na_action="ignore").values.astype(str)

    # Constants, height and width of plot.
    if figsize is None:
        W = min((df.index.nunique(), 10))
        H = 6.4  # plotnine default height
    else:
        W, H = figsize

    ### Convert to tidy / long format
    # Org:
    #       ABC  ACBG  ACMB
    # POMC  0.0   0.5   0.9
    # AGRP  0.2   0.0   0.0
    # LEPR  0.1   0.1   0.4

    # Tidy:
    #   gene_name annotation    es_weight
    # 1 POMC      ABC           0.0
    # 2 AGRP      ABC           0.2
    # 3 LEPR      ABC           0.1
    df_tidy = df.copy()
    df_tidy.index.name = None
    df_tidy = pd.melt(df_tidy.reset_index(),
                      id_vars="index",
                      var_name="annotation",
                      value_name="count")

    ### Compute the mean count of ES genes
    mean_count = df_tidy["count"].mean(axis=0)

    ### Plot
    p = (
        ### data
        p9.ggplot(
            data=df_tidy,
            mapping=p9.aes(x="index", y="count", fill="index", label="index"),
        )

        ### theming
        + p9.theme_classic() + p9.theme(
            figure_size=(W, H), axis_text_x=p9.element_text(rotation=75)) +
        p9.labs(
            x="",  # e.g. "Cell-type"
            y="Number of ES genes",  # e.g. "ES weight"
        )

        ### viz
        + p9.geom_violin(scale="width", show_legend=False) +
        p9.geom_jitter(width=0.1, height=0, show_legend=False) +
        p9.geom_hline(yintercept=mean_count,
                      color="blue",
                      linetype="dashed",
                      show_legend=False))

    return p
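
A self-contained toy example of calling this function (the gene and cell-type values below are invented purely for illustration):

import pandas as pd

# Toy ES-weight matrix: rows are genes, columns are cells/cell-types.
toy_es = pd.DataFrame(
    {'ABC': [0.0, 0.2, 0.1], 'ACBG': [0.5, 0.0, 0.1], 'ACMB': [0.9, 0.0, 0.4]},
    index=['POMC', 'AGRP', 'LEPR'])
# Annotation mapping each column label to a broader group.
toy_annotation = pd.Series({'ABC': 'Neuron', 'ACBG': 'Neuron', 'ACMB': 'Glia'})

p = n_es_genes(toy_es, toy_annotation, figsize=(6, 4))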
Example #25
    def plot_char_percent_vs_accuracy_smooth(self, expo=False):
        if expo:
            p = (ggplot(self.char_plot_df) +
                 facet_wrap('Guessing_Model', nrow=1) +
                 aes(x='char_percent', y='correct', color='Dataset') +
                 stat_smooth(
                     method='mavg', se=False, method_args={'window': 200}) +
                 scale_y_continuous(breaks=np.linspace(0, 1, 11)) +
                 scale_x_continuous(breaks=[0, .5, 1]) +
                 xlab('Percent of Question Revealed') + ylab('Accuracy') +
                 theme(legend_position='top'))
            if os.path.exists('data/external/human_gameplay.json'):
                with open('data/external/human_gameplay.json') as f:
                    gameplay = json.load(f)
                    control_correct_positions = gameplay[
                        'control_correct_positions']
                    control_wrong_positions = gameplay[
                        'control_wrong_positions']
                    control_positions = control_correct_positions + control_wrong_positions
                    control_positions = np.array(control_positions)
                    control_result = np.array(
                        len(control_correct_positions) * [1] +
                        len(control_wrong_positions) * [0])
                    argsort_control = np.argsort(control_positions)
                    control_x = control_positions[argsort_control]
                    control_sorted_result = control_result[argsort_control]
                    control_y = control_sorted_result.cumsum(
                    ) / control_sorted_result.shape[0]
                    control_df = pd.DataFrame({
                        'correct': control_y,
                        'char_percent': control_x
                    })
                    control_df['Dataset'] = 'Test Questions'
                    control_df['Guessing_Model'] = ' Human'

                    adv_correct_positions = gameplay['adv_correct_positions']
                    adv_wrong_positions = gameplay['adv_wrong_positions']
                    adv_positions = adv_correct_positions + adv_wrong_positions
                    adv_positions = np.array(adv_positions)
                    adv_result = np.array(
                        len(adv_correct_positions) * [1] +
                        len(adv_wrong_positions) * [0])
                    argsort_adv = np.argsort(adv_positions)
                    adv_x = adv_positions[argsort_adv]
                    adv_sorted_result = adv_result[argsort_adv]
                    adv_y = adv_sorted_result.cumsum(
                    ) / adv_sorted_result.shape[0]
                    adv_df = pd.DataFrame({
                        'correct': adv_y,
                        'char_percent': adv_x
                    })
                    adv_df['Dataset'] = 'Challenge Questions'
                    adv_df['Guessing_Model'] = ' Human'

                    human_df = pd.concat([control_df, adv_df])
                    p = p + (geom_line(data=human_df))

            return p
        else:
            return (
                ggplot(self.char_plot_df) +
                aes(x='char_percent', y='correct', color='Guessing_Model') +
                stat_smooth(
                    method='mavg', se=False, method_args={'window': 500}) +
                scale_y_continuous(breaks=np.linspace(0, 1, 21)))
Example #26
    int(grouped_candidates_pred_df.hetionet.value_counts()[1]),
    "relation":
    "CtD"
})
datarows.append({
    "edges": (grouped_candidates_pred_df.query(
        "pred_max > @optimal_threshold").hetionet.value_counts()[0]),
    "in_hetionet":
    "Novel",
    "relation":
    "CtD"
})
edges_df = pd.DataFrame.from_records(datarows)
edges_df

# In[26]:

g = (p9.ggplot(edges_df, p9.aes(x="relation", y="edges", fill="in_hetionet")) +
     p9.geom_col(position="dodge") +
     p9.geom_text(
         p9.aes(label=edges_df.apply(
             lambda x: f"{x['edges']} ({x['recall']*100:.0f}%)"
             if not math.isnan(x['recall']) else f"{x['edges']}",
             axis=1)),
         position=p9.position_dodge(width=1),
         size=9,
         va="bottom") +
     p9.scale_y_log10() +
     p9.theme(axis_text_y=p9.element_blank(),
              axis_ticks_major=p9.element_blank(),
              rect=p9.element_blank()))
print(g)
Example #27
    + geom_errorbar(all_svcca[all_svcca['Group'] == 'uncorrected'],
                  aes(x=lst_num_experiments, ymin='ymin', ymax='ymax'),
                   color='darkgrey') \
    + geom_line(threshold,
                aes(x=lst_num_experiments, y='score'),
                linetype='dashed',
                size=1,
                color="darkgrey",
                show_legend=False) \
    + labs(x = "Number of Experiments",
           y = "Similarity score (SVCCA)",
           title = "Similarity across varying numbers of experiments") \
    + theme(plot_title=element_text(weight='bold'),
            plot_background=element_rect(fill="white"),
            panel_background=element_rect(fill="white"),
            panel_grid_major_x=element_line(color="lightgrey"),
            panel_grid_major_y=element_line(color="lightgrey"),
            axis_line=element_line(color="grey"),
            legend_key=element_rect(fill='white', colour='white')
           ) \
    + scale_color_manual(['#b3e5fc'])

print(g)
ggsave(plot=g, filename=svcca_uncorrected_file, dpi=300)

# In[9]:

# Plot - uncorrected only black
lst_num_experiments = list(all_svcca.index[0:int(len(all_svcca.index) / 2)])

threshold = pd.DataFrame(pd.np.tile(permuted_score,
                                    (len(lst_num_experiments), 1)),
Example #28
def plot_cellranger_vs_cellbender(samplename, raw_cellranger_mtx,
                                  filtered_cellranger_mtx,
                                  cellbender_unfiltered_h5, fpr,
                                  n_expected_cells, n_total_droplets_included,
                                  out_dir):
    """compare cellranger raw vs cellranger filtered vs cellbender outputs"""
    logging.info('samplename ' + str(samplename))
    logging.info('raw_cellranger_mtx ' + str(raw_cellranger_mtx))
    logging.info('filtered_cellranger_mtx ' + str(filtered_cellranger_mtx))
    logging.info('cellbender_unfiltered_h5 ' + str(cellbender_unfiltered_h5))
    logging.info('fpr ' + str(fpr))
    logging.info('n_expected_cells ' + str(n_expected_cells))
    logging.info('n_total_droplets_included ' + str(n_total_droplets_included))
    logging.info('out_dir ' + str(out_dir))

    # Make the output directory if it does not exist.
    if out_dir == '':
        out_dir = os.getcwd()
    else:
        os.makedirs(out_dir, exist_ok=True)
        out_dir = out_dir + '/fpr_' + fpr
        os.makedirs(out_dir, exist_ok=True)
        os.makedirs(out_dir + '/' + samplename, exist_ok=True)
        logging.info(out_dir)
    # logging.info(df.head())

    # Get compression opts for pandas
    compression_opts = 'gzip'
    if LooseVersion(pd.__version__) > '1.0.0':
        compression_opts = dict(method='gzip', compresslevel=9)

    # read cellranger raw
    adata_cellranger_raw = sc.read_10x_mtx(raw_cellranger_mtx,
                                           var_names='gene_symbols',
                                           make_unique=True,
                                           cache=False,
                                           cache_compression=compression_opts)

    # First filter out any cells that have 0 total counts
    zero_count_cells_cellranger_raw = adata_cellranger_raw.obs_names[np.where(
        adata_cellranger_raw.X.sum(axis=1) == 0)[0]]
    # sc.pp.filter_cells(adata, min_counts=1, inplace=True) # Minimum number of counts required for a cell to pass filtering.
    logging.info(
        "_cellranger_raw: Filtering {}/{} cells with 0 counts.".format(
            len(zero_count_cells_cellranger_raw), adata_cellranger_raw.n_obs))
    adata_cellranger_raw = adata_cellranger_raw[
        adata_cellranger_raw.obs_names.difference(
            zero_count_cells_cellranger_raw, sort=False)]

    sc.pp.calculate_qc_metrics(adata_cellranger_raw, inplace=True)

    logging.info('cellranger raw n barcodes(.obs) x cells(.var) .X.shape:')
    logging.info(adata_cellranger_raw.X.shape)
    logging.info('cellranger raw .obs:')
    logging.info(adata_cellranger_raw.obs)
    logging.info('cellranger raw .var:')
    logging.info(adata_cellranger_raw.var)

    df_total_counts = pd.DataFrame(data=adata_cellranger_raw.obs.sort_values(
        by=['total_counts'], ascending=False).total_counts)
    df_total_counts['barcode_row_number'] = df_total_counts.reset_index(
    ).index + 1
    df_total_counts['barcodes'] = df_total_counts.index
    df_total_counts_cellranger_raw = df_total_counts
    df_total_counts_cellranger_raw['dataset'] = 'Cellranger Raw'

    logging.info(df_total_counts)
    # read cellranger filtered
    adata_cellranger_filtered = sc.read_10x_mtx(
        filtered_cellranger_mtx,
        var_names='gene_symbols',
        make_unique=True,
        cache=False,
        cache_compression=compression_opts)

    # First filter out any cells that have 0 total counts
    zero_count_cells_cellranger_filtered = adata_cellranger_filtered.obs_names[
        np.where(adata_cellranger_filtered.X.sum(axis=1) == 0)[0]]
    # sc.pp.filter_cells(adata, min_counts=1, inplace=True) # Minimum number of counts required for a cell to pass filtering.
    logging.info(
        "_cellranger_filtered: Filtering {}/{} cells with 0 counts.".format(
            len(zero_count_cells_cellranger_filtered),
            adata_cellranger_filtered.n_obs))
    adata_cellranger_filtered = adata_cellranger_filtered[
        adata_cellranger_filtered.obs_names.difference(
            zero_count_cells_cellranger_filtered, sort=False)]

    sc.pp.calculate_qc_metrics(adata_cellranger_filtered, inplace=True)

    logging.info(
        'cellranger filtered n barcodes(.obs) x cells(.var) .X.shape:')
    logging.info(adata_cellranger_filtered.X.shape)
    logging.info('cellranger filtered .obs:')
    logging.info(adata_cellranger_filtered.obs.columns)
    logging.info(adata_cellranger_filtered.obs)
    logging.info('cellranger filtered .var:')
    logging.info(adata_cellranger_filtered.var)

    df_total_counts = pd.DataFrame(
        data=adata_cellranger_filtered.obs.sort_values(
            by=['total_counts'], ascending=False).total_counts)
    df_total_counts['barcodes'] = df_total_counts.index
    df_total_counts['barcode_row_number'] = df_total_counts.reset_index(
    ).index + 1
    df_total_counts_cellranger_filtered = df_total_counts
    df_total_counts_cellranger_filtered['dataset'] = 'Cellranger Filtered'

    logging.info(df_total_counts)
    # read cellbender output
    adata_cellbender = anndata_from_h5(cellbender_unfiltered_h5,
                                       analyzed_barcodes_only=True)

    # First filter out any cells that have 0 total counts
    zero_count_cells_cellbender_filtered = adata_cellbender.obs_names[np.where(
        adata_cellbender.X.sum(axis=1) == 0)[0]]
    # sc.pp.filter_cells(adata, min_counts=1, inplace=True) # Minimum number of counts required for a cell to pass filtering.
    logging.info(
        "_cellbender_filtered: Filtering {}/{} cells with 0 counts.".format(
            len(zero_count_cells_cellbender_filtered), adata_cellbender.n_obs))
    adata_cellbender = adata_cellbender[adata_cellbender.obs_names.difference(
        zero_count_cells_cellbender_filtered, sort=False)]

    sc.pp.calculate_qc_metrics(adata_cellbender, inplace=True)

    logging.info(
        'cellbender cellbender.n barcodes(.obs) x cells(.var) .X.shape:')
    logging.info(adata_cellbender.X.shape)
    logging.info('cellbender cellbender.obs:')
    logging.info(adata_cellbender.obs)
    logging.info('cellbender cellbender.var:')
    logging.info(adata_cellbender.var)

    df_total_counts = pd.DataFrame(data=adata_cellbender.obs.sort_values(
        by=['total_counts'], ascending=False).total_counts)
    df_total_counts['barcodes'] = df_total_counts.index
    df_total_counts['barcode_row_number'] = df_total_counts.reset_index(
    ).index + 1
    df_total_counts_cellbender = df_total_counts
    df_total_counts_cellbender['dataset'] = 'Cellbender'

    logging.info(df_total_counts)

    # df_total_counts_cellranger_filtered.rename(columns={"total_counts": "cellranger_filtered_total_counts"})
    df_cellranger_cellbender = pd.merge(
        df_total_counts_cellranger_filtered,
        df_total_counts_cellbender,
        how='outer',
        left_index=True,
        right_index=True,
        suffixes=('_cellranger',
                  '_cellbender')).sort_values(by=['total_counts_cellbender'],
                                              ascending=False)
    logging.info(df_cellranger_cellbender)
    df_cellranger_cellbender[['cellranger', 'cellbender']] = np.where(
        df_cellranger_cellbender[[
            'total_counts_cellranger', 'total_counts_cellbender'
        ]].isnull(), 0, 1)

    #df_cellranger_cellbender.to_csv('df_cellranger_cellbender.csv', index=True, index_label='barcode')

    grouped = df_cellranger_cellbender[['cellranger', 'cellbender']].groupby(
        ["cellranger", "cellbender"]).size().reset_index(name='counts')
    logging.info(grouped.columns)
    #grouped.to_csv('cellranger_cellbender.csv', index=False)

    df_cellranger_cellbender[
        'barcode_row_number'] = df_cellranger_cellbender.reset_index(
        ).index + 1

    ### plot UMI counts descending order
    df_merged = pd.concat([
        df_total_counts_cellranger_raw, df_total_counts_cellranger_filtered,
        df_total_counts_cellbender
    ])
    #df_merged.to_csv('df_merged.csv', index=True, index_label='barcode')

    df_vline = pd.DataFrame(
        data={
            'x': [int(n_expected_cells),
                  int(n_total_droplets_included)],
            'color': ['expected-cells', 'total-droplets-included']
        })

    gplt = ggplot(df_merged, aes(x='barcode_row_number', y='total_counts')) \
        + geom_point() \
        + geom_vline(df_vline, aes(xintercept='x', color='color')) \
        + theme_bw() + facet_wrap('dataset') \
        + labs(x='Barcodes (ordered by descending cell total counts)', color='Cellbender input',
               y='Cell total counts', title='Cells filtered out by Cellranger or Cellbender') \
        + scale_y_continuous(trans='log10',minor_breaks=0) + scale_x_continuous(trans='log10',minor_breaks=0)
    gplt.save(out_dir + '/' + samplename + '/barcode_vs_total_counts.png',
              width=12,
              height=5,
              dpi=300)  # dpi=300,

    df_cellranger_cellbender_count = grouped  # pd.read_csv('cellranger_cellbender.csv')

    df = pd.merge(df_merged,
                  df_cellranger_cellbender[['cellranger', 'cellbender']],
                  how='left',
                  left_index=True,
                  right_index=True)
    df = pd.merge(df,
                  df_cellranger_cellbender_count,
                  how='left',
                  left_on=['cellranger', 'cellbender'],
                  right_on=['cellranger', 'cellbender'])
    df["counts"].fillna(df['counts'].isnull().sum(), inplace=True)
    df["counts"] = df["counts"].astype(int)
    # df.replace({"counts": {""}  }, inplace=True)

    df["filtered"] = df["cellranger"].astype(
        str) + '-' + df["cellbender"].astype(str)
    df.replace(
        {
            "filtered": {
                "nan-nan": 'Cellranger Raw only',
                "1.0-1.0": "Cellranger Filtered + Cellbender",
                "1.0-0.0": "Cellranger Filtered only",
                "0.0-1.0": "Cellbender only",
                "0.0-0.0": "0.0-0.0"
            }
        },
        inplace=True)
    df["filtered"] = df["filtered"] + ', n=' + df["counts"].astype(str)
    df['filtered'].value_counts()
    df.replace(
        {
            "dataset": {
                "cellbender": "Cellbender output",
                "cellranger_raw": "Cellranger Raw output",
                "cellranger_filtered": "Cellranger Filtered output"
            }
        },
        inplace=True)


    gplt = ggplot(df, aes(x='filtered', y='total_counts', color='filtered')) \
        + geom_boxplot() \
        + theme_bw() \
        + facet_wrap('dataset') \
        + theme(axis_text_x=element_blank()) \
        + scale_y_continuous(trans='log10',minor_breaks=0) \
        + labs(color='n cells in intersection of datasets', x='', y='Cell total counts', title='Total cell counts compared across datasets (facets)')
    gplt.save(out_dir + '/' + samplename +
              '/boxplots_cellranger_vs_cellbender.png',
              width=12,
              height=5,
              dpi=300)  # dpi=300,

    # plot difference cellbender filtered vs cellranger filtered for common cells between the 2 datasets
    df_cellranger_cellbender = df_cellranger_cellbender[
        df_cellranger_cellbender['cellranger'] == 1]
    df_cellranger_cellbender = df_cellranger_cellbender[
        df_cellranger_cellbender['cellbender'] == 1]

    # Subset the datasets to the relevant barcodes.
    adata_cellbender_common = adata_cellbender[
        df_cellranger_cellbender.index.values]
    adata_cellranger_filtered_common = adata_cellranger_filtered[
        df_cellranger_cellbender.index.values]
    # Put count matrices into 'layers' in anndata for clarity.
    adata = adata_cellbender_common
    adata.layers['counts_cellbender'] = adata_cellbender_common.X.copy()
    adata.layers['counts_raw'] = adata_cellranger_filtered_common.X.copy()
    # Get the differences in counts per cell
    X_raw_minus_cb = adata.layers['counts_raw'] - adata.layers[
        'counts_cellbender']
    X_raw_minus_cb = abs(X_raw_minus_cb)
    # Get the top most different genes
    df_diff_genes = pd.DataFrame(data=adata.var.gene_symbols.values)
    df_diff_genes['ensembl_id'] = adata.var.index
    df_diff_genes['gene_symbol'] = adata.var.gene_symbols.values
    df_diff_genes['dif_across_cells'] = np.asarray(
        X_raw_minus_cb.sum(axis=0)).reshape(-1)
    df_diff_genes = df_diff_genes.sort_values('dif_across_cells',
                                              ascending=False).head(n=100)
    #df_diff_genes.to_csv('df_diff_genes.csv', index=True)
    top_genes = df_diff_genes['ensembl_id']
    top_genes_symbols = df_diff_genes['gene_symbol']
    logging.info('top_genes:')
    logging.info(top_genes)

    logging.info(adata_cellbender_common.var.index)
    adata_cellbender_common = adata_cellbender[
        df_cellranger_cellbender.index.values, top_genes].to_df()
    adata_cellbender_common['barcode'] = adata_cellbender_common.index
    adata_cellbender_common = pd.melt(adata_cellbender_common,
                                      ignore_index=True,
                                      id_vars=['barcode'],
                                      var_name='ensembl_id',
                                      value_name='count')
    adata_cellbender_common = pd.merge(
        adata_cellbender_common,
        df_diff_genes[['ensembl_id', 'gene_symbol']],
        how='left',
        left_on='ensembl_id',
        right_on='ensembl_id')
    adata_cellbender_common = adata_cellbender_common.sort_values(
        by=['barcode', 'ensembl_id'], ascending=False)
    adata_cellbender_common['dataset'] = 'Cellbender'
    #adata_cellbender_common.to_csv('adata_cellbender_common.csv', index=True)

    logging.info(adata_cellranger_filtered.var.index)
    adata_cellranger_filtered_common = adata_cellranger_filtered[
        df_cellranger_cellbender.index.values, top_genes_symbols].to_df()
    adata_cellranger_filtered_common[
        'barcode'] = adata_cellranger_filtered_common.index
    adata_cellranger_filtered_common = pd.melt(
        adata_cellranger_filtered_common,
        ignore_index=True,
        id_vars=['barcode'],
        var_name='gene_symbol',
        value_name='count')
    adata_cellranger_filtered_common = pd.merge(
        adata_cellranger_filtered_common,
        df_diff_genes[['ensembl_id', 'gene_symbol']],
        how='left',
        left_on='gene_symbol',
        right_on='gene_symbol')
    adata_cellranger_filtered_common['dataset'] = 'Cellranger Filtered'
    adata_cellranger_filtered_common = adata_cellranger_filtered_common.sort_values(
        by=['barcode', 'ensembl_id'], ascending=False)
    adata_cellranger_filtered_common = adata_cellranger_filtered_common[
        adata_cellbender_common.columns]
    #adata_cellranger_filtered_common.to_csv('adata_cellranger_filtered_common.csv', index=True)

    logging.info(adata_cellranger_raw.var.index)
    adata_cellranger_raw_common = adata_cellranger_raw[
        df_cellranger_cellbender.index.values, top_genes_symbols].to_df()
    adata_cellranger_raw_common['barcode'] = adata_cellranger_raw_common.index
    adata_cellranger_raw_common = pd.melt(adata_cellranger_raw_common,
                                          ignore_index=True,
                                          id_vars=['barcode'],
                                          var_name='gene_symbol',
                                          value_name='count')
    adata_cellranger_raw_common = pd.merge(
        adata_cellranger_raw_common,
        df_diff_genes[['ensembl_id', 'gene_symbol']],
        how='left',
        left_on='gene_symbol',
        right_on='gene_symbol')
    adata_cellranger_raw_common['dataset'] = 'Cellranger Raw'
    adata_cellranger_raw_common = adata_cellranger_raw_common.sort_values(
        by=['barcode', 'ensembl_id'], ascending=False)
    adata_cellranger_raw_common = adata_cellranger_raw_common[
        adata_cellbender_common.columns]
    #adata_cellranger_raw_common.to_csv('adata_cellranger_raw_common.csv', index=True)

    logging.info(adata_cellranger_raw_common['gene_symbol'] ==
                 adata_cellbender_common['gene_symbol'])
    logging.info(adata_cellranger_raw_common['ensembl_id'] ==
                 adata_cellbender_common['ensembl_id'])

    adata_filtered_cellbender_diff = adata_cellbender_common.copy()
    adata_filtered_cellbender_diff['count'] = adata_cellranger_filtered_common[
        'count'] - adata_cellbender_common['count']
    adata_filtered_cellbender_diff[
        'dataset'] = 'Cellranger Filtered - Cellbender'

    adata_raw_cellbender_diff = adata_cellbender_common.copy()
    adata_raw_cellbender_diff['count'] = adata_cellranger_raw_common[
        'count'] - adata_cellbender_common['count']
    adata_raw_cellbender_diff['dataset'] = 'Cellranger Raw - Cellbender'

    df_merged = pd.concat([
        adata_cellbender_common, adata_cellranger_filtered_common,
        adata_cellranger_raw_common, adata_filtered_cellbender_diff,
        adata_raw_cellbender_diff
    ],
                          ignore_index=True)

    gplt = ggplot(df_merged, aes(x='gene_symbol',y='count')) \
        + geom_boxplot() \
        + theme_bw() \
        + theme(axis_text_x = element_text(angle = 90, hjust = 1, size= 6)) \
        + facet_wrap('dataset', scales = 'free', ncol = 1) \
        + labs(x='Genes (top 100 Genes most different between Cellranger Filtered counts and Cellbender filtered counts)', y='Cell total counts', title='Total cell counts compared across most different genes (x-axis) and datasets (facets)')
    gplt.save(out_dir + '/' + samplename +
              '/boxplot_topgenes_cellranger_vs_cellbender.png',
              width=10,
              height=20,
              dpi=300)  # dpi=300,
    logging.info('script done.')
Example #29
0
    def plot_tag_repartition(self, data, options):
        tag_df = data["tags"]
        if not "background" in tag_df.columns:
            tag_df["background"] = False
        test = tag_df[["tag", "matched", "background", "id"]].copy()
        test.loc[:, "prop_matched"] = -1
        test.loc[:, "prop_background"] = -1
        test.loc[:, "lbl_matched"] = ""
        test.loc[:, "lbl_background"] = ""
        test.loc[:, "n_tags"] = -1

        n_total = test.shape[0]
        n_matched = test.matched.value_counts()
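        # Overall totals passed to the facet labeller below to annotate the
        # matched / unmatched panels.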

        tags_summary = (
            test.groupby("tag")
            .apply(self.get_proportions, by=["matched", "background"])
            .reset_index(drop=True)
        )
        tags_summary = tags_summary.sort_values(["tag", "matched", "background"])

        plt = ggplot(
            data=tags_summary,
            mapping=aes(
                x="tag",  # "factor(species, ordered=False)",
                y="n_tags",
                fill="background",
                ymax=max(tags_summary.n_tags) + 35,  # "factor(species, ordered=False)",
            ),
        )
        plot_width = 10 + len(tags_summary.tag.unique()) * 0.75
        plt = (
            plt
            + geom_bar(stat="identity", show_legend=True, position=position_dodge())
            + facet_wrap(
                "matched",
                nrow=1,
                ncol=2,
                scales="fixed",
                labeller=(lambda x: self.get_matched_label(x, n_total, n_matched)),
            )
            + xlab("Species")
            + ylab("Number of annotations")
            + geom_text(
                mapping=aes(label="lbl_background"), position=position_dodge(width=0.9),
            )
            + geom_text(
                mapping=aes(y=max(tags_summary.n_tags) + 30, label="lbl_matched",)
            )
            + theme_classic()
            + theme(
                axis_text_x=element_text(angle=90, vjust=1, hjust=1, margin={"r": -30}),
                plot_title=element_text(
                    weight="bold", size=14, margin={"t": 10, "b": 10}
                ),
                figure_size=(plot_width, 10),
                text=element_text(size=12, weight="bold"),
            )
            + ggtitle(
                (
                    "Tag repartition for model {}, database {}, class {}\n"
                    + "with detector options {}"
                ).format(
                    options["scenario_info"]["model"],
                    options["scenario_info"]["database"],
                    options["scenario_info"]["class"],
                    options,
                )
            )
        )

        return plt
Example #30
0
    def __plot(
        self,
        plot_data,
        x,
        y,
        colour,
        lbl_x,
        lbl_y,
        facet,
        facet_scales,
        facet_by,
        smoothed,
        points,
        error_bars,
        save,
    ):
        cbbPalette = [
            "#000000",
            "#E69F00",
            "#56B4E9",
            "#009E73",
            "#0072B2",
            "#D55E00",
            "#CC79A7",
        ]
        plt = ggplot(data=plot_data, mapping=aes(x=x, y=y, colour=colour))
        plt += xlab(lbl_x)
        plt += ylab(lbl_y)
        # + facet_grid("site~", scales="free")
        # + geom_line()
        if facet:
            # TODO: use facet as save
            nrow, ncol = self.get_facet_rows(plot_data, facet_by)
            plt += facet_wrap(facet_by,
                              nrow=nrow,
                              ncol=ncol,
                              scales=facet_scales)
        if points:
            plt += geom_point()
        if error_bars:
            # TODO use generic way to compute them
            pass
            # self.plt += geom_errorbar(aes(ymin="ACI_mean - ACI_std", ymax="ACI_mean + ACI_std"))
        # TODO: use smooth as save
        if smoothed:
            plt += geom_smooth(
                method="mavg",
                se=False,
                method_args={
                    "window": 4,
                    "center": True,
                    "min_periods": 1
                },
            )
        else:
            plt += geom_line()
        plt += scale_colour_manual(values=cbbPalette, guide=False)
        plt += scale_x_continuous(labels=label_x)

        plt += theme(figure_size=(15, 18), dpi=150)

        if save:
            plt.save(**save)
        return plt
    # plot the x axis titles
    + p9.geom_vline(xintercept=[2.5, 14.5, 26.5, 38.5, 50.5, 62.5, 74.5]) +
    p9.geom_text(label="2014", x=8.5, y=0, color="black") +
    p9.geom_text(label="2015", x=20.5, y=0, color="black") +
    p9.geom_text(label="2016", x=32.5, y=0, color="black") +
    p9.geom_text(label="2017", x=44.5, y=0, color="black") +
    p9.geom_text(label="2018", x=56.5, y=0, color="black") +
    p9.geom_text(label="2019", x=68.5, y=0, color="black")

    # Plot the overall proportion published
    + p9.geom_hline(
        yintercept=0.4196, linetype='solid', color=color_mapper['2018']) +
    p9.geom_hline(yintercept=published / posted,
                  linetype="solid",
                  color=color_mapper['2020ML']) +
    p9.annotate("text", x=8.5, y=0.395, label="overall: 0.4196", size=8) +
    p9.annotate("text",
                x=8.5,
                y=0.48,
                label=f"overall: {published/posted:.4f}",
                size=8) +
    p9.theme_seaborn(style='ticks', context='paper', font_scale=1.5) +
    p9.theme(figure_size=(10, 4.5),
             axis_text_x=p9.element_blank(),
             axis_title_x=p9.element_text(margin={"t": 15})) +
    p9.labs(y="Proportion Published", x="Month"))
g.save("output/figures/publication_rate.svg", dpi=250)
g.save("output/figures/publication_rate.png", dpi=250)
print(g)
    + geom_errorbar(all_svcca,
                  aes(x=lst_num_experiments, ymin='ymin', ymax='ymax'),
                   color='darkgrey') \
    + geom_line(threshold, 
                aes(x=lst_num_experiments, y='score'), 
                linetype='dashed',
                size=1.5,
                color="darkgrey",
                show_legend=False) \
    + labs(x = "Number of Partitions", 
           y = "Similarity score (SVCCA)", 
           title = "Similarity across varying numbers of partitions") \
    + theme(plot_title=element_text(weight='bold'),
            plot_background=element_rect(fill="white"),
            panel_background=element_rect(fill="white"),
            panel_grid_major_x=element_line(color="lightgrey"),
            panel_grid_major_y=element_line(color="lightgrey"),
            axis_line=element_line(color="grey"),
            legend_key=element_rect(fill='white', colour='white')
           ) \
    + scale_color_manual(['#1976d2', '#b3e5fc'])

print(panel_A)
ggsave(plot=panel_A, filename=svcca_file, device="svg", dpi=300)
ggsave(plot=panel_A, filename=svcca_png_file, device="svg", dpi=300)


# ## Uncorrected PCA panel

# In[11]:
Example #33
0
p = (
    gg.ggplot(filter_melt_df,
              gg.aes(x='lane', y='filtration', fill='num_variants_cat')) +
    gg.geom_bar(stat='identity', position='dodge') +
    gg.facet_wrap('~ final_id') +
    gg.scale_fill_manual(name='Filtration Step',
                         values=['#1b9e77', '#d95f02', '#7570b3', '#e7298a'],
                         labels=['All Variants',
                                 'Common Variants',
                                 'Depth (< {} reads)'.format(replicate_filter_min_depth_count),
                                 'Depth (> {} reads)'.format(replicate_filter_max_depth_count)]) + 
    gg.xlab('Sample') +
    gg.ylab('Final Number of Variants') +
    gg.theme_bw() +
    gg.theme(axis_text_x=gg.element_text(angle='90'),
             axis_text=gg.element_text(size=8),
             axis_title=gg.element_text(size=14))
    )
p


# In[13]:


figure_file = os.path.join('figures', 'replicates_filtration_results.pdf')
gg.ggsave(p, figure_file, height=5.5, width=6.5, dpi=500)


# In[14]:

Example #34
0
def search_room(dataframe: pd.DataFrame) -> bool:

    # Search top 100
    top100 = st.sidebar.checkbox(
        "Filter top 100 apartments",
        help="filter only the top 100 apartments by price",
    )

    # Search by price
    min_price, max_price = st.sidebar.slider(
        "Search apartments by price",
        min(dataframe.price),
        max(dataframe.price),
        (min(dataframe.price), max(dataframe.price)),
        help="Insert the min and max price",
    )

    # Search by review_scores_rating

    # Search by room type

    # Search by Beds

    # Search by Beds

    # Search by Bathrooms

    # Search by Accomodates

    # Select columns for plot
    to_select = st.sidebar.multiselect(
        "Select the columns you want to display",
        list(dataframe.columns),
        list(dataframe.columns),
        help="Select the columns you want to consider",
    )

    if top100:
        dataframe = dataframe.groupby("price").head(100)

    dataframe_filtered = dataframe[to_select]

    dataframe_filtered = dataframe_filtered.loc[
        dataframe.price.between(min_price, max_price)
    ]
    # Launch the data visualization
    main_room_type(dataframe_filtered)

    st.sidebar.markdown("Select plot axis")
    axis1 = st.sidebar.selectbox(
        "Select first axis", list(dataframe_filtered.columns)
    )
    axis2 = st.sidebar.selectbox(
        "Select second axis", list(dataframe_filtered.columns)
    )

    scatterplot = st.sidebar.button(
        "Scatterplot", key="bscatterplot", help="Launch the scatterplot"
    )
    if scatterplot:
        fig = px.scatter(dataframe_filtered, x=axis1, y=axis2)
        st.markdown(f"Plot with: {axis1}, {axis2}")
        st.plotly_chart(fig)
        st.markdown("Raw data used")

        st.dataframe(
            dataframe_filtered.style.highlight_max(axis=0)
            .format({axis2: "{:.2%}"})
            .highlight_null(null_color="red")
            .set_caption("Result table with all the data filtered")
        )
        return True

    barplot = st.sidebar.button(
        "Barplot", key="bggplot", help="Launch the ggplot"
    )
    if barplot:

        st.markdown(
            "To launch this plot please remember to select all the columns in the data"
        )
        # plot_folder_path = os.path.join(get_folder_path("."), "plots")

        fig = (
            pn.ggplot(dataframe_filtered)
            + pn.aes(x=axis1, fill=axis2)
            + pn.geom_bar()
            + pn.theme(axis_text_x=pn.element_text(angle=45, hjust=1))
        )

        st.markdown("### Barplot")
        st.markdown(f"Displaying: {axis1} over {axis2}")
        st.pyplot(
            pn.ggplot.draw(fig),
            clear_figure=True,
            width=100,
            height=200,
            dpi=600,
        )
        # st.image(fig_path)
        # st.write(fig)

        # st.pyplot(fig)

    histogram = st.sidebar.button(
        "Histogram", key="bp9histogram", help="Launch the ggplot histogram"
    )
    if histogram:
        fig = (
            pn.ggplot(dataframe_filtered)
            + pn.aes(x="price")
            + pn.geom_histogram(fill="blue", colour="black", bins=30)
            + pn.xlim(0, 200)
        )

        st.markdown("### Histogram")
        st.markdown(f"Displaying: {axis1} over {axis2}")
        st.pyplot(
            pn.ggplot.draw(fig),
            clear_figure=True,
            width=100,
            height=200,
            dpi=600,
        )

    density = st.sidebar.button(
        "Density", key="bp9density", help="Launch the ggplot density"
    )
    if density:

        fig = (
            pn.ggplot(dataframe_filtered.head(1000))
            + pn.aes(x="price")
            + pn.geom_density(fill="blue", colour="black", alpha=0.5)
            + pn.xlim(0, 200)
        )

        st.markdown("### Density Plot")
        st.pyplot(
            pn.ggplot.draw(fig),
            clear_figure=True,
            width=100,
            height=200,
            dpi=600,
        )

    latlong = st.sidebar.button(
        "Latitude-Longitude",
        key="bp9latlon",
        help="Launch the ggplot latitude and longitude categorical comparison",
    )
    if latlong:
        # color categorical variable
        fig = (
            pn.ggplot(
                dataframe_filtered,
                pn.aes(x="latitude", y="longitude", colour="room_type"),
            )
            + pn.geom_point(alpha=0.5)
        )

        st.markdown("### Color categorical variable")
        st.pyplot(
            pn.ggplot.draw(fig),
            clear_figure=True,
            width=100,
            height=200,
            dpi=600,
        )

        return True

    return False
Example #35
0
gradient = (
    (0.99, 0.88, 0.87),
    (0.98, 0.62, 0.71),
    (0.86, 0.20, 0.59),
    bcolor, bcolor,
    bcolor_darker, bcolor_darker)

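# The three layers below draw from different slices of the data:
# sparse points from the first third (every 9th row), a continuous line
# through the middle third, and sparse bars from the last third (every 12th row).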
df1 = df[:n//3:9]
df2 = df[n//3:2*n//3]
df3 = df[2*n//3::12]

p = (ggplot(aes('x', 'y', color='y', fill='y'))
     + annotate(geom='label', x=0.295, y=0.495, label='pl  tnine',
                label_size=1.5, label_padding=.1, size=24,
                fill=bcolor_lighter, color=bcolor)
     + geom_point(df1, size=8, stroke=0, show_legend=False)
     + geom_line(df2, size=2, color=bcolor_darker, show_legend=False)
     + geom_bar(df3, aes('x+.06'), stat='identity', size=0, show_legend=False)

     + scale_color_gradientn(colors=gradient)
     + scale_fill_gradientn(colors=gradient)
     + theme_void()
     + theme(figure_size=(3.6, 3.6)))

p.save('logo.pdf', pad_inches=-0.04)

# Remove the project name
p.layers = p.layers.__class__(p.layers[1:])
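# (constructing via the layers' own class keeps a plotnine Layers object
# rather than a plain list after slicing)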
p.save('logo-small.pdf', pad_inches=-0.04)
# Plot total number of cells per well
cell_count_totalcells_df = (cell_count_df.groupby(
    ["x_loc", "y_loc", "well", "site_location",
     "site"])["total_cell_count"].mean().reset_index())

plate = cell_count_df["plate"].unique()[0]

os.makedirs(output_figuresdir, exist_ok=True)
by_well_gg = (
    gg.ggplot(cell_count_totalcells_df, gg.aes(x="x_loc", y="y_loc")) +
    gg.geom_point(gg.aes(fill="total_cell_count"), size=10) +
    gg.geom_text(gg.aes(label="site_location"), color="lightgrey") +
    gg.facet_wrap("~well") + gg.coord_fixed() + gg.theme_bw() +
    gg.ggtitle(f"Total Cells/Well\n{plate}") + gg.theme(
        axis_text=gg.element_blank(),
        axis_title=gg.element_blank(),
        strip_background=gg.element_rect(colour="black", fill="#fdfff4"),
    ) + gg.labs(fill="Cells") + gg.scale_fill_cmap(name="magma"))

output_file = pathlib.Path(output_figuresdir,
                           "plate_layout_cells_count_per_well.png")
if check_if_write(output_file, force, throw_warning=True):
    by_well_gg.save(output_file, dpi=300, verbose=False)

# Plot cell category ratios per well
ratio_df = pd.pivot_table(
    cell_count_df,
    values="cell_count",
    index=["site", "plate", "well", "site_location", "x_loc", "y_loc"],
    columns=["Cell_Quality"],
)
Example #37
0
    def plot_char_percent_vs_accuracy_smooth(self, expo=False, no_models=False, columns=False):
        if self.y_max is not None:
            limits = [0, float(self.y_max)]
            eprint(f'Setting limits to: {limits}')
        else:
            limits = [0, 1]
        if expo:
            if os.path.exists('data/external/all_human_gameplay.json') and not self.no_humans:
                with open('data/external/all_human_gameplay.json') as f:
                    all_gameplay = json.load(f)
                    frames = []
                    for event, name in [('parents', 'Intermediate'), ('maryland', 'Expert'), ('live', 'National')]:
                        if self.merge_humans:
                            name = 'Human'
                        gameplay = all_gameplay[event]
                        if event != 'live':
                            control_correct_positions = gameplay['control_correct_positions']
                            control_wrong_positions = gameplay['control_wrong_positions']
                            control_positions = control_correct_positions + control_wrong_positions
                            control_positions = np.array(control_positions)
                            control_result = np.array(len(control_correct_positions) * [1] + len(control_wrong_positions) * [0])
                            argsort_control = np.argsort(control_positions)
                            control_x = control_positions[argsort_control]
                            control_sorted_result = control_result[argsort_control]
                            control_y = control_sorted_result.cumsum() / control_sorted_result.shape[0]
                            control_df = pd.DataFrame({'correct': control_y, 'char_percent': control_x})
                            control_df['Dataset'] = 'Regular Test'
                            control_df['Guessing_Model'] = f' {name}'
                            frames.append(control_df)

                        adv_correct_positions = gameplay['adv_correct_positions']
                        adv_wrong_positions = gameplay['adv_wrong_positions']
                        adv_positions = adv_correct_positions + adv_wrong_positions
                        adv_positions = np.array(adv_positions)
                        adv_result = np.array(len(adv_correct_positions) * [1] + len(adv_wrong_positions) * [0])
                        argsort_adv = np.argsort(adv_positions)
                        adv_x = adv_positions[argsort_adv]
                        adv_sorted_result = adv_result[argsort_adv]
                        adv_y = adv_sorted_result.cumsum() / adv_sorted_result.shape[0]
                        adv_df = pd.DataFrame({'correct': adv_y, 'char_percent': adv_x})
                        adv_df['Dataset'] = 'IR Adversarial'
                        adv_df['Guessing_Model'] = f' {name}'
                        frames.append(adv_df)

                        if len(gameplay['advneural_correct_positions']) > 0:
                            adv_correct_positions = gameplay['advneural_correct_positions']
                            adv_wrong_positions = gameplay['advneural_wrong_positions']
                            adv_positions = adv_correct_positions + adv_wrong_positions
                            adv_positions = np.array(adv_positions)
                            adv_result = np.array(len(adv_correct_positions) * [1] + len(adv_wrong_positions) * [0])
                            argsort_adv = np.argsort(adv_positions)
                            adv_x = adv_positions[argsort_adv]
                            adv_sorted_result = adv_result[argsort_adv]
                            adv_y = adv_sorted_result.cumsum() / adv_sorted_result.shape[0]
                            adv_df = pd.DataFrame({'correct': adv_y, 'char_percent': adv_x})
                            adv_df['Dataset'] = 'RNN Adversarial'
                            adv_df['Guessing_Model'] = f' {name}'
                            frames.append(adv_df)

                    human_df = pd.concat(frames)
                    human_vals = sort_humans(list(human_df['Guessing_Model'].unique()))
                    human_dtype = CategoricalDtype(human_vals, ordered=True)
                    human_df['Guessing_Model'] = human_df['Guessing_Model'].astype(human_dtype)
                    dataset_dtype = CategoricalDtype(['Regular Test', 'IR Adversarial', 'RNN Adversarial'], ordered=True)
                    human_df['Dataset'] = human_df['Dataset'].astype(dataset_dtype)

            if no_models:
                p = ggplot(human_df) + geom_point(shape='.')
            else:
                df = self.char_plot_df
                if 1 not in self.rounds:
                    df = df[df['Dataset'] != 'Round 1 - IR Adversarial']
                if 2 not in self.rounds:
                    df = df[df['Dataset'] != 'Round 2 - IR Adversarial']
                    df = df[df['Dataset'] != 'Round 2 - RNN Adversarial']
                p = ggplot(df)
                if self.save_df is not None:
                    eprint(f'Saving df to: {self.save_df}')
                    df.to_json(self.save_df)

                if os.path.exists('data/external/all_human_gameplay.json') and not self.no_humans:
                    eprint('Loading human data')
                    p = p + geom_line(data=human_df)

            if columns:
                facet_conf = facet_wrap('Guessing_Model', ncol=1)
            else:
                facet_conf = facet_wrap('Guessing_Model', nrow=1)

            if not no_models:
                if self.mvg_avg_char:
                    chart = stat_smooth(method='mavg', se=False, method_args={'window': 400})
                else:
                    chart = stat_summary_bin(fun_data=mean_no_se, bins=20, shape='.', linetype='None', size=0.5)
            else:
                chart = None

            p = (
                p + facet_conf
                + aes(x='char_percent', y='correct', color='Dataset')
            )
            if chart is not None:
                p += chart
            p = (
                p
                + scale_y_continuous(breaks=np.linspace(0, 1, 6))
                + scale_x_continuous(breaks=[0, .5, 1])
                + coord_cartesian(ylim=limits)
                + xlab('Percent of Question Revealed')
                + ylab('Accuracy')
                + theme(
                    #legend_position='top', legend_box_margin=0, legend_title=element_blank(),
                    strip_text_x=element_text(margin={'t': 6, 'b': 6, 'l': 1, 'r': 5})
                )
                + scale_color_manual(values=['#FF3333', '#66CC00', '#3333FF', '#FFFF33'], name='Questions')
            )
            if self.title != '':
                p += ggtitle(self.title)

            return p
        else:
            if self.save_df is not None:
                eprint(f'Saving df to: {self.save_df}')
                self.char_plot_df.to_json(self.save_df)
            return (
                ggplot(self.char_plot_df)
                + aes(x='char_percent', y='correct', color='Guessing_Model')
                + stat_smooth(method='mavg', se=False, method_args={'window': 500})
                + scale_y_continuous(breaks=np.linspace(0, 1, 6))
                + coord_cartesian(ylim=limits)
            )
Example #38
0
        alpha=0.8,
        size=0.6
    )
    + gg.theme_bw()
    + gg.xlab("UMAP (X)")
    + gg.ylab("UMAP (Y)")
    + gg.ggtitle("Four Clone Dataset - Merged")
    + gg.facet_wrap("~Metadata_plate_ID")
    + gg.scale_fill_manual(
        name="Batch",
        values=["#1b9e77", "#d95f02", "#7570b3"],
        labels=['Batch 5', "Batch 6", "Batch 7"]
    )
    + gg.scale_shape_manual(name="Treatment", values=[".", "+"])
    + gg.theme(
        strip_text=gg.element_text(size=6, color="black"),
        strip_background=gg.element_rect(colour="black", fill="#fdfff4"),
    )
)
    
file = os.path.join("figures", "umap", "four_clone_umap_plate_facet")
for extension in save_file_extensions:
    umap_batch_facet_gg.save(filename='{}{}'.format(file, extension), height=3, width=3.5, dpi=400)

umap_batch_facet_gg


# In[11]:


# Visualize UMAP results
clone_facet_gg = (
Example #39
0
def test_add_partial_complete():
    theme1 = theme(axis_line_x=element_line())
    theme2 = theme_gray()
    theme3 = theme1 + theme2
    assert theme3 == theme2
Example #40
0
# # BioRxiv Research Article Categories

# Categories assigned to each research article. Neuroscience dominates the majority of articles, as expected.

# +
category_list = metadata_df.category.value_counts().index.tolist()[::-1]

# plotnine doesn't implement a reverse keyword for scale_x_discrete,
# so pass the reversed category list as the limits instead
g = (p9.ggplot(metadata_df, p9.aes(x="category")) + p9.geom_bar(
    size=10, fill="#253494", position=p9.position_dodge(width=3)) +
     p9.scale_x_discrete(limits=category_list) + p9.coord_flip() +
     p9.theme_seaborn(
         context="paper", style="ticks", font="Arial", font_scale=1) +
     p9.theme(text=p9.element_text(size=12)))
g.save("output/figures/preprint_category.svg")
g.save("output/figures/preprint_category.png", dpi=300)
print(g)
# -

metadata_df["category"].value_counts()

# # New, Confirmatory, Contradictory Results?

# +
heading_list = metadata_df.heading.value_counts().index.tolist()[::-1]

g = (p9.ggplot(metadata_df, p9.aes(x="heading")) +
     p9.geom_bar(size=10, fill="#253494") +
     p9.scale_x_discrete(limits=heading_list) + p9.coord_flip() +
Example #41
0
        robust_arr2[:, 1]
    ]),
                 columns=["method", "x", "y"]))
df = df.append(
    pd.DataFrame(np.transpose(
        [np.repeat("Mix", mix_arr2.shape[0]), mix_arr2[:, 0], mix_arr2[:, 1]]),
                 columns=["method", "x", "y"]))
df['x'] = pd.to_numeric((df['x']))
df['y'] = pd.to_numeric((df['y']))
p = (
    ggplot(df) + aes(x='x', y='y', shape='method', color='method')
    # + geom_point(size=4, stroke=0)
    + geom_line(size=1) + scale_shape_discrete(name='Method') +
    scale_color_brewer(type='qual', palette=2, name='Method') +
    xlab('Training Time (seconds)') + ylab('Error') +
    theme(aspect_ratio=0.8, ) + ggtitle("Baseline Error"))
p.save(test + "/baselineErr.png", verbose=False)

df = pd.DataFrame([], columns=["method", "x", "y"])
df = df.append(
    pd.DataFrame(np.transpose([
        np.repeat("Provable", robust_arr2.shape[0]), robust_arr2[:, 0],
        robust_arr2[:, 2]
    ]),
                 columns=["method", "x", "y"]))
df = df.append(
    pd.DataFrame(np.transpose(
        [np.repeat("Mix", mix_arr2.shape[0]), mix_arr2[:, 0], mix_arr2[:, 2]]),
                 columns=["method", "x", "y"]))
df['x'] = pd.to_numeric((df['x']))
df['y'] = pd.to_numeric((df['y']))
Example #42
0
def test_facet_grid_scales_free_y():
    p = (g
         + facet_grid(['.', 'var1>2'], scales='free_y')
         + theme(panel_spacing_x=0.3))
    assert p == 'facet_grid_scales_free_y'
Example #43
0
def test_facet_grid_scales_free_x():
    p = (g
         + facet_grid(['var1>2', '.'], scales='free_x')
         + theme(panel_spacing_y=0.3))
    assert p == 'facet_grid_scales_free_x'
Example #44
0
def make_plots(leak_df, time_df, site_df, sim_n, spin_up, output_directory):
    """
    This function makes a set of standard plots to output at end of simulation.
    """
    # Temporarily mute warnings
    warnings.filterwarnings('ignore')
    pn.theme_set(pn.theme_linedraw())

    # Chop off spin-up year (only for plots, still exists in raw output)
    time_df_adj = time_df.iloc[spin_up:, ]

    # Timeseries plots
    plot_time_1 = (
        pn.ggplot(time_df_adj, pn.aes('datetime', 'daily_emissions_kg')) +
        pn.geom_line(size=2) +
        pn.ggtitle('Daily emissions from all sites (kg)') + pn.ylab('') +
        pn.xlab('') + pn.scale_x_datetime(labels=date_format('%Y')) + pn.theme(
            panel_border=pn.element_rect(colour="black", fill=None, size=2),
            panel_grid_minor_x=pn.element_blank(),
            panel_grid_major_x=pn.element_blank(),
            panel_grid_minor_y=pn.element_line(
                colour='black', linewidth=0.5, alpha=0.3),
            panel_grid_major_y=pn.element_line(
                colour='black', linewidth=1, alpha=0.5)))

    plot_time_1.save(output_directory + '/plot_time_emissions_' + sim_n +
                     '.png',
                     width=10,
                     height=3,
                     dpi=300)

    plot_time_2 = (pn.ggplot(time_df_adj, pn.aes('datetime', 'active_leaks')) +
                   pn.geom_line(size=2) +
                   pn.ggtitle('Number of active leaks at all sites') +
                   pn.ylab('') + pn.xlab('') +
                   pn.scale_x_datetime(labels=date_format('%Y')) +
                   pn.theme(panel_border=pn.element_rect(
                       colour="black", fill=None, size=2),
                            panel_grid_minor_x=pn.element_blank(),
                            panel_grid_major_x=pn.element_blank(),
                            panel_grid_minor_y=pn.element_line(
                                colour='black', linewidth=0.5, alpha=0.3),
                            panel_grid_major_y=pn.element_line(
                                colour='black', linewidth=1, alpha=0.5)))

    plot_time_2.save(output_directory + '/plot_time_active_' + sim_n + '.png',
                     width=10,
                     height=3,
                     dpi=300)

    # Site-level plots
    plot_site_1 = (
        pn.ggplot(site_df, pn.aes('cum_frac_sites', 'cum_frac_emissions')) +
        pn.geom_line(size=2) + pn.theme(
            panel_border=pn.element_rect(colour="black", fill=None, size=2),
            panel_grid_minor_x=pn.element_blank(),
            panel_grid_major_x=pn.element_blank(),
            panel_grid_minor_y=pn.element_line(
                colour='black', linewidth=0.5, alpha=0.3),
            panel_grid_major_y=pn.element_line(
                colour='black', linewidth=1, alpha=0.5)) +
        pn.xlab('Cumulative fraction of sites') +
        pn.ylab('Cumulative fraction of emissions') +
        pn.ggtitle('Empirical cumulative distribution of site-level emissions')
    )

    plot_site_1.save(output_directory + '/site_cum_dist_' + sim_n + '.png',
                     width=5,
                     height=4,
                     dpi=300)

    # Leak plots
    plot_leak_1 = (pn.ggplot(leak_df, pn.aes('days_active')) +
                   pn.geom_histogram(colour='gray') +
                   pn.theme(panel_border=pn.element_rect(
                       colour="black", fill=None, size=2),
                            panel_grid_minor_x=pn.element_blank(),
                            panel_grid_major_x=pn.element_blank(),
                            panel_grid_minor_y=pn.element_line(
                                colour='black', linewidth=0.5, alpha=0.3),
                            panel_grid_major_y=pn.element_line(
                                colour='black', linewidth=1, alpha=0.5)) +
                   pn.ggtitle('Distribution of leak duration') +
                   pn.xlab('Number of days the leak was active') +
                   pn.ylab('Count'))
    plot_leak_1.save(output_directory + '/leak_active_hist' + sim_n + '.png',
                     width=5,
                     height=4,
                     dpi=300)

    plot_leak_2 = (pn.ggplot(
        leak_df, pn.aes('cum_frac_leaks', 'cum_frac_rate', colour='status')) +
                   pn.geom_line(size=2) +
                   pn.scale_colour_hue(h=0.15, l=0.25, s=0.9) +
                   pn.theme(panel_border=pn.element_rect(
                       colour="black", fill=None, size=2),
                            panel_grid_minor_x=pn.element_blank(),
                            panel_grid_major_x=pn.element_blank(),
                            panel_grid_minor_y=pn.element_line(
                                colour='black', linewidth=0.5, alpha=0.3),
                            panel_grid_major_y=pn.element_line(
                                colour='black', linewidth=1, alpha=0.5)) +
                   pn.xlab('Cumulative fraction of leak sources') +
                   pn.ylab('Cumulative leak rate fraction') +
                   pn.ggtitle('Fractional cumulative distribution'))

    plot_leak_2.save(output_directory + '/leak_cum_dist1_' + sim_n + '.png',
                     width=4,
                     height=4,
                     dpi=300)

    plot_leak_3 = (pn.ggplot(
        leak_df, pn.aes('cum_frac_leaks', 'cum_rate', colour='status')) +
                   pn.geom_line(size=2) +
                   pn.scale_colour_hue(h=0.15, l=0.25, s=0.9) +
                   pn.theme(panel_border=pn.element_rect(
                       colour="black", fill=None, size=2),
                            panel_grid_minor_x=pn.element_blank(),
                            panel_grid_major_x=pn.element_blank(),
                            panel_grid_minor_y=pn.element_line(
                                colour='black', linewidth=0.5, alpha=0.3),
                            panel_grid_major_y=pn.element_line(
                                colour='black', linewidth=1, alpha=0.5)) +
                   pn.scale_y_continuous(trans='log10') +
                   pn.xlab('Cumulative fraction of leak sources') +
                   pn.ylab('Cumulative emissions (kg/day)') +
                   pn.ggtitle('Absolute cumulative distribution'))

    plot_leak_3.save(output_directory + '/leak_cum_dist2_' + sim_n + '.png',
                     width=4,
                     height=4,
                     dpi=300)

    return
Example #45
0
def test_facet_grid_scales_free_y_formula_dot_notation():
    p = (g+facet_grid('. ~ var1>2', scales='free_y')
         + theme(panel_spacing_x=0.3))
    assert p == 'facet_grid_scales_free_y'
Example #46
0
print(len(df))
print(len(set(df.qid)))

user_stat['log_n_records'] = pd.Series(user_stat.n_records.apply(np.log),
                                       index=user_stat.index)
max_color = user_stat.log_n_records.max()
user_stat['alpha'] = pd.Series(
    user_stat.log_n_records.apply(lambda x: x / max_color), index=user_stat.index)


p0 = ggplot(user_stat) \
        + geom_point(aes(x='ratio', y='accuracy',
                     size='n_records', color='log_n_records', alpha='alpha'),
                     show_legend={'color': False, 'alpha': False, 'size': False}) \
        + scale_color_gradient(high='#e31a1c', low='#ffffcc') \
        + theme(aspect_ratio=1)
p0.save('protobowl_users.pdf')
# p0.draw()
print('p0 done')


p1 = ggplot(user_stat, aes(x='log_n_records', y='..density..')) \
        + geom_histogram(color='#e6550d', fill='#fee6ce') \
        + geom_density() \
        + theme(aspect_ratio=0.3)
p1.save('protobowl_hist.pdf')
# p1.draw()
print('p1 done')


p2 = ggplot(user_stat, aes(x='accuracy', y='..density..')) \
    by_well_gg = (
        gg.ggplot(
            cell_count_totalcells_df.loc[
                cell_count_totalcells_df["site"].str.contains(plate)
            ],
            gg.aes(x="x_loc", y="y_loc"),
        )
        + gg.geom_point(gg.aes(fill="total_cell_count"), shape="s", size=6)
        + gg.geom_text(gg.aes(label="site_location"), color="lightgrey", size=6)
        + gg.facet_wrap("~well")
        + gg.coord_fixed()
        + gg.theme_bw()
        + gg.ggtitle(f"Total Cells/Well\n{plate}")
        + gg.theme(
            axis_text=gg.element_blank(),
            axis_title=gg.element_blank(),
            strip_background=gg.element_rect(colour="black", fill="#fdfff4"),
        )
        + gg.labs(fill="Cells")
        + gg.scale_fill_cmap(name="Number of Cells")
    )

    output_file = pathlib.Path(
        output_figuresdir, f"plate_layout_cells_count_per_well_{plate}.png"
    )
    if check_if_write(output_file, force, throw_warning=True):
        by_well_gg.save(output_file, dpi=300, verbose=False)

# Plot cell category ratios per well and empty cells per well
ratio_df = pd.pivot_table(
    cell_count_df,
Example #48
0
from __future__ import absolute_import, division, print_function

import pandas as pd

from plotnine import ggplot, aes, geom_crossbar, theme

n = 4
df = pd.DataFrame({
    'x': [1] * n,
    'ymin': range(1, 2 * n + 1, 2),
    'y': [i + 0.1 + i / 10 for i in range(1, 2 * n + 1, 2)],
    'ymax': range(2, 2 * n + 2, 2),
    'z': range(n)
})

_theme = theme(facet_spacing={'right': 0.85})


def test_aesthetics():
    p = (ggplot(df, aes(y='y', ymin='ymin', ymax='ymax')) +
         geom_crossbar(aes('x'), size=2) + geom_crossbar(
             aes('x+1', alpha='z'), fill='green', width=0.2, size=2) +
         geom_crossbar(aes('x+2', linetype='factor(z)'), size=2) +
         geom_crossbar(aes('x+3', color='factor(z)'), size=2) +
         geom_crossbar(aes('x+4', size='z')))

    assert p + _theme == 'aesthetics'
Example #49
0
def test_add_partial_complete():
    theme1 = theme(axis_line_x=element_line())
    theme2 = theme_gray()
    theme3 = theme1 + theme2
    assert theme3 == theme2
Example #50
0
import os
import types

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.testing.compare import compare_images

from plotnine import ggplot, theme

TOLERANCE = 2  # Default tolerance for the tests
DPI = 72  # Default DPI for the tests

# This partial theme modifies all themes that are used in
# the test. It is limited to setting the size of the test
# images Should a test require a larger or smaller figure
# size, the dpi or aspect_ratio should be modified.
test_theme = theme(figure_size=(640 / DPI, 480 / DPI))

if not os.path.exists(
        os.path.join(os.path.dirname(__file__), 'baseline_images')):
    raise IOError(
        "The baseline image directory does not exist. "
        "This is most likely because the test data is not installed. "
        "You may need to install plotnine from source to get the "
        "test data.")


def raise_no_baseline_image(filename):
    raise Exception("Baseline image {} is missing".format(filename))


def ggplot_equals(gg, right):
Example #51
0
def error_comparison():
    char_frames = {}
    first_frames = {}
    full_frames = {}
    train_times = {}
    use_wiki = {}
    best_accuracies = {}
    for p in glob.glob(f'output/guesser/best/qanta.guesser*/guesser_report_guesstest.pickle', recursive=True):
        with open(p, 'rb') as f:
            report = pickle.load(f)
            name = report['guesser_name']
            params = report['guesser_params']
            train_times[name] = params['training_time']
            use_wiki[name] = params['use_wiki'] if 'use_wiki' in params else False
            char_frames[name] = report['char_df']
            first_frames[name] = report['first_df']
            full_frames[name] = report['full_df']
            best_accuracies[name] = (report['first_accuracy'], report['full_accuracy'])
    first_df = pd.concat([f for f in first_frames.values()]).sort_values('score', ascending=False).groupby(['guesser', 'qanta_id']).first().reset_index()
    first_df['position'] = ' Start'
    full_df = pd.concat([f for f in full_frames.values()]).sort_values('score', ascending=False).groupby(['guesser', 'qanta_id']).first().reset_index()
    full_df['position'] = 'End'
    compare_df = pd.concat([first_df, full_df])
    compare_df = compare_df[compare_df.guesser != 'qanta.guesser.vw.VWGuesser']
    compare_results = {}
    comparisons = ['qanta.guesser.dan.DanGuesser', 'qanta.guesser.rnn.RnnGuesser', 'qanta.guesser.elasticsearch.ElasticSearchGuesser']
    cr_rows = []
    for (qnum, position), group in compare_df.groupby(['qanta_id', 'position']):
        group = group.set_index('guesser')
        correct_guessers = []
        wrong_guessers = []
        for name in comparisons:
            if group.loc[name].correct == 1:
                correct_guessers.append(name)
            else:
                wrong_guessers.append(name)
        if len(correct_guessers) > 3:
            raise ValueError('this should be unreachable')
        elif len(correct_guessers) == 3:
            cr_rows.append({'qnum': qnum, 'Position': position, 'model': 'All', 'Result': 'Correct'})
        elif len(correct_guessers) == 0:
            cr_rows.append({'qnum': qnum, 'Position': position, 'model': 'All', 'Result': 'Wrong'})
        elif len(correct_guessers) == 1:
            cr_rows.append({
                'qnum': qnum, 'Position': position,
                'model': to_shortname(correct_guessers[0]),
                'Result': 'Correct'
            })
        else:
            cr_rows.append({
                'qnum': qnum, 'Position': position,
                'model': to_shortname(wrong_guessers[0]),
                'Result': 'Wrong'
            })
    cr_df = pd.DataFrame(cr_rows)
    # samples = cr_df[(cr_df.Position == ' Start') & (cr_df.Result == 'Correct') & (cr_df.model == 'RNN')].qnum.values
    # for qid in samples:
    #     q = lookup[qid]
    #     print(q['first_sentence'])
    #     print(q['page'])
    #     print()
    p = (
        ggplot(cr_df)
        + aes(x='model', fill='Result') + facet_grid(['Result', 'Position']) #+ facet_wrap('Position', labeller='label_both')
        + geom_bar(aes(y='(..count..) / sum(..count..)'), position='dodge')
        + labs(x='Models', y='Fraction with Corresponding Result') + coord_flip()
        + theme_fs() + theme(aspect_ratio=.6)
    )
    p.save('output/plots/guesser_error_comparison.pdf')
Example #52
0
def test_add_empty_theme_element():
    # An empty theme element does not alter the theme
    theme1 = theme_gray() + theme(axis_line_x=element_line(color='red'))
    theme2 = theme1 + theme(axis_line_x=element_line())
    assert theme1 == theme2
Example #53
0
from __future__ import absolute_import, division, print_function

import pandas as pd

from plotnine import ggplot, aes, geom_count, theme

_theme = theme(subplots_adjust={'right': 0.85})

df = pd.DataFrame({
    'x': list('aaaaaaaaaabbbbbbbbbbcccccccccc'),
    'y': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
          1, 1, 1, 1, 1, 6, 6, 8, 10, 10,
          1, 1, 2, 4, 4, 4, 4, 9, 9, 9]})


def test_discrete_x():
    p = ggplot(df, aes('x', 'y')) + geom_count()

    assert p + _theme == 'discrete_x'


def test_discrete_y():
    p = ggplot(df, aes('y', 'x')) + geom_count()

    assert p + _theme == 'discrete_y'


def test_continuous_x_y():
    p = ggplot(df, aes('y', 'y')) + geom_count()

    assert p + _theme == 'continuous_x_y'
    + geom_line(threshold,
                aes(x=lst_num_partitions, y='score'),
                linetype='dashed',
                size=1,
                color="darkgrey",
                show_legend=False) \
    + labs(x = "Number of Partitions",
           y = "Similarity score (SVCCA)",
           title = "Similarity across varying numbers of partitions") \
    + theme(
            plot_background=element_rect(fill="white"),
            panel_background=element_rect(fill="white"),
            panel_grid_major_x=element_line(color="lightgrey"),
            panel_grid_major_y=element_line(color="lightgrey"),
            axis_line=element_line(color="grey"),
            legend_key=element_rect(fill='white', colour='white'),
            legend_title=element_text(family='sans-serif', size=15),
            legend_text=element_text(family='sans-serif', size=12),
            plot_title=element_text(family='sans-serif', size=15),
            axis_text=element_text(family='sans-serif', size=12),
            axis_title=element_text(family='sans-serif', size=15)
           ) \
    + scale_color_manual(['#1976d2', '#b3e5fc'])

print(panel_A)
ggsave(plot=panel_A, filename=svcca_file, device="svg", dpi=300)
ggsave(plot=panel_A, filename=svcca_png_file, device="svg", dpi=300)

# ### Uncorrected PCA

# In[44]: