예제 #1
0
def plot_pointplot(plot_df, y_axis_label="", use_log10=False, limits=(0, 3.2)):
    """
    Plots the odds-ratio pointplot: one point range per lemma, lemmas ordered
    by ascending odds ratio, with a dashed grey reference line at 1.

    Arguments:
        plot_df - the dataframe that contains the odds ratio and lemmas; it
            needs the columns "lemma", "odds_ratio", "lower_odds" and
            "upper_odds"
        y_axis_label - the label for the y axis
        use_log10 - use log10 for the y axis?
        limits - (lower, upper) limits of the linear y axis. Ignored when
            use_log10 is True. The default is an immutable tuple so no
            mutable object is shared between calls.

    Returns:
        plotnine.ggplot object (coordinates are flipped, so the lemmas run
        along the vertical axis)
    """
    # Lemma order on the discrete axis: ascending odds ratio.
    lemma_order = plot_df.sort_values("odds_ratio", ascending=True).lemma.tolist()
    # `limits` only applies to the linear scale.
    y_scale = (p9.scale_y_log10() if use_log10 else
               p9.scale_y_continuous(limits=limits))

    graph = (
        p9.ggplot(plot_df, p9.aes(x="lemma", y="odds_ratio")) +
        p9.geom_pointrange(p9.aes(ymin="lower_odds", ymax="upper_odds"),
                           position=p9.position_dodge(width=1),
                           size=0.3,
                           color="#253494") +
        p9.scale_x_discrete(limits=lemma_order) +
        y_scale +
        # Reference line where the odds ratio equals 1.
        p9.geom_hline(p9.aes(yintercept=1), linetype='--', color='grey') +
        p9.coord_flip() + p9.theme_seaborn(
            context='paper', style="ticks", font_scale=1, font='Arial') +
        p9.theme(
            # 640 x 480
            figure_size=(6.66, 5),
            panel_grid_minor=p9.element_blank(),
            axis_title=p9.element_text(size=12),
            axis_text_x=p9.element_text(size=10)) +
        p9.labs(x=None, y=y_axis_label))
    return graph
예제 #2
0
def plot_bargraph(count_plot_df, plot_df):
    """
    Plot lemma counts as horizontal bars, one panel per repository.

    Arguments:
        count_plot_df - dataframe with the columns "lemma", "count" and
            "repository"; "count" is cast to int before plotting
        plot_df - dataframe with the columns "lemma" and "odds_ratio"; the
            lemmas are placed on the axis in ascending odds-ratio order

    Returns:
        plotnine.ggplot object
    """
    counts = count_plot_df.astype({"count": int})
    lemma_order = plot_df.sort_values("odds_ratio", ascending=True).lemma.tolist()

    components = [
        p9.geom_col(position=p9.position_dodge(width=0.5), fill="#253494"),
        p9.coord_flip(),
        p9.facet_wrap("repository", scales='free_x'),
        p9.scale_x_discrete(limits=lemma_order),
        p9.scale_y_continuous(labels=custom_format('{:,.0g}')),
        p9.labs(x=None),
        p9.theme_seaborn(context='paper', style="ticks", font="Arial",
                         font_scale=0.95),
        p9.theme(
            # 640 x 480
            figure_size=(6.66, 5),
            strip_background=p9.element_rect(fill="white"),
            strip_text=p9.element_text(size=12),
            axis_title=p9.element_text(size=12),
            axis_text_x=p9.element_text(size=10),
        ),
    ]

    graph = p9.ggplot(counts, p9.aes(x="lemma", y="count"))
    for component in components:
        graph = graph + component
    return graph
예제 #3
0
def make_single_bar_chart_multi_year(survey_data, column, facet, proportionally=False):
    """Make a barchart showing the number of respondents responding to a single column.

    Bars are colored by the survey year they correspond to and are placed
    along the x axis by the values of ``facet``.

    Args:
        survey_data (pandas.DataFrame): Raw data read in from Kubernetes Survey
        column (str): Column to plot responses to
        facet (str): Column used for grouping; its values form the x axis
        proportionally (bool, optional): Defaults to False. If True,
            the bar heights are determined proportionally to the
            total number of responses in that facet and year.

    Returns:
        (plotnine.ggplot): Plot object which can be displayed in a notebook or saved out to a file

    """
    show_legend = False
    # ``facet`` is a single column name: it is used directly as the x
    # aesthetic and indexed as a Series below. The grouping keys are
    # therefore built as an explicit list; ``facet + ["year"]`` would raise
    # TypeError (str + list).
    group_cols = [facet, "year"]
    topic_data = survey_data[[column, facet, "year"]]

    topic_data_long = make_long(topic_data, facet, multi_year=True)
    # Only responses with rating == 1 contribute to the bar heights.
    positive = topic_data_long[topic_data_long.rating == 1]

    if proportionally:
        proportions = (
            positive.groupby(group_cols).sum()
            / topic_data_long.groupby(group_cols).sum()
        ).reset_index()
    else:
        proportions = positive.groupby(group_cols).count().reset_index()

    facet_levels = topic_data_long[facet].unique().tolist()

    return (
        p9.ggplot(proportions, p9.aes(x=facet, fill="year", y="level_1"))
        + p9.geom_bar(show_legend=show_legend, stat="identity")
        + p9.theme(
            axis_text_x=p9.element_text(angle=45, ha="right"),
            strip_text_y=p9.element_text(angle=0, ha="left"),
        )
        + p9.scale_x_discrete(
            limits=facet_levels,
            # Show underscores in category names as spaces.
            labels=[level.replace("_", " ") for level in facet_levels],
        )
    )
 def create(self, file_path: str) -> None:
     """Draw the design-pattern counts as a bar chart and save it.

     Bars follow the row order of ``self._data``. Each bar is annotated with
     its 'fraction' value formatted as a percentage. The figure is written
     to ``file_path`` with width=24 and height=8.
     """
     data = self._data
     chart = (
         ggplot(data, aes(x="pattern", y="count", label="fraction"))
         + geom_bar(stat="identity", fill="#1e4f79")
         + geom_text(va='bottom', size=24, format_string='{:.1%}')
         + scale_x_discrete(limits=data["pattern"])
         + scale_y_continuous(labels=comma_format(), expand=[0.1, 0])
         + ggtitle("Design Pattern Counts")
         + xlab("Design Pattern")
         + ylab("Count")
         + theme_classic(base_size=32, base_family="Helvetica")
         + theme(text=element_text(size=32),
                 axis_text_x=element_text(rotation=45, ha="right"))
     )
     chart.save(file_path, width=24, height=8)
 def create(self, file_path: str) -> None:
     """Draw a bar chart of how often each 'count' value occurs and save it.

     Bars are labelled with their occurrence counts. The x axis lists a
     fixed set of project counts, and the y axis spans 0-10. The figure is
     written to ``file_path`` with width=14 and height=7.
     """
     project_counts = [
         "1", "2", "3", "5", "26", "52", "97", "100", "300", "537"
     ]
     chart = (
         ggplot(self._data, aes(x="count", label="..count.."))
         + geom_bar(fill="#1e4f79")
         + geom_text(stat="count", va='bottom', size=24)
         + scale_x_discrete(limits=project_counts)
         + scale_y_continuous(breaks=[0, 5, 10], limits=[0, 10])
         + ggtitle("Case Study Sizes")
         + xlab("Number of Projects")
         + ylab("Number of Case Studies")
         + theme_classic(base_size=28, base_family="Helvetica")
         + theme(text=element_text(size=28))
     )
     chart.save(file_path, width=14, height=7)
def plot_metrics_comparison_lineplot_grid(dataframe,
                                          models_labels,
                                          metrics_labels,
                                          figure_size=(14, 4)):
    """
    Plot every metric against the threshold as points joined by lines.

    The plots are arranged in a grid with one row per number of iterations
    and one column per model.

    Arguments:
        dataframe - long-format data with the columns 'threshold', 'value',
            'variable' (the metric), 'iterations' and 'model'
        models_labels - mapping from a model value to its display name
        metrics_labels - mapping from a metric value to its display name
        figure_size - (width, height) of the figure

    Returns:
        plotnine.ggplot object
    """

    def two_decimals(values):
        return ['{:.2f}'.format(v) for v in values]

    def metric_names(values):
        return [metrics_labels[v] for v in values]

    components = [
        # Points joined by lines, one series per metric.
        p9.geom_point(),
        p9.geom_line(),
        # Axis names, padding on both sides, tick labels with 2 decimals.
        p9.scale_x_discrete(name='Threshold', expand=(0, 0.2)),
        p9.scale_y_continuous(name='Value',
                              expand=(0, 0.05),
                              labels=two_decimals),
        # Legend entries show the metric display names; Set2 is used as a
        # color-blind-friendly qualitative palette.
        p9.scale_shape_discrete(name='Metric', labels=metric_names),
        p9.scale_color_brewer(name='Metric',
                              labels=metric_names,
                              type='qual',
                              palette='Set2'),
        p9.facet_grid('iterations ~ model',
                      labeller=p9.labeller(
                          rows=lambda x: f'iters = {x}',
                          cols=lambda x: f'{models_labels[x]}')),
        p9.theme(
            axis_title_y=p9.element_blank(),
            axis_text_x=p9.element_text(size=7),
            axis_text_y=p9.element_text(size=7),
            # Legend on top, untitled, with a small margin.
            legend_title=p9.element_blank(),
            legend_position='top',
            legend_box_margin=2,
            figure_size=figure_size,
        ),
    ]

    plot = p9.ggplot(
        dataframe,
        p9.aes(x='threshold',
               y='value',
               group='variable',
               color='variable',
               shape='variable'))
    for component in components:
        plot = plot + component
    return plot
 def create(self, file_path: str) -> None:
     """Draw the number of classes per category as a bar chart and save it.

     Bars follow the row order of ``self._data`` and are annotated with the
     'percent' column. The figure is written to ``file_path`` with width=7
     and height=7.
     """
     data = self._data
     chart = (
         ggplot(data, aes(x="category", y="count", label="percent"))
         + geom_bar(stat="identity", fill="#1e4f79")
         + geom_text(va='bottom', size=24)
         + scale_x_discrete(limits=data["category"])
         + scale_y_continuous(labels=comma_format(), expand=[0.1, 0])
         + ggtitle("Classes per Category")
         + xlab("Category")
         + ylab("Number of Classes")
         + theme_classic(base_size=32, base_family="Helvetica")
         + theme(text=element_text(size=32),
                 axis_text_x=element_text(rotation=45, ha="right"))
     )
     chart.save(file_path, width=7, height=7)
예제 #8
0
def setup_heatmap0(df: pd.DataFrame, format_string, axis_text):
    """Build an annotated heatmap from ``df``.

    Tiles are placed at ('col', 'row'). They are filled by the 'scale' column
    on a green-yellow-red gradient, with missing values drawn grey and no
    colour legend. Each tile is labelled with 'value' formatted by
    ``format_string``. When ``axis_text`` is falsy the tick labels are
    hidden; otherwise they are drawn in bold. Ticks, axis titles and grid
    lines are always hidden.
    """
    # https://stackoverflow.com/a/62161556/819272
    # Plotnine does not support changing the position of any axis.
    if axis_text:
        tick_labels = p9.element_text(face='bold')
    else:
        tick_labels = p9.element_blank()

    cell_text = p9.geom_text(p9.aes(label='value'),
                             format_string=format_string,
                             size=7)
    fill_scale = p9.scale_fill_gradientn(
        colors=['#63BE7B', '#FFEB84', '#F8696B'],
        na_value='#CCCCCC',
        guide=False)
    bare_theme = p9.theme(axis_text=tick_labels,
                          axis_ticks=p9.element_blank(),
                          axis_title=p9.element_blank(),
                          panel_grid=p9.element_blank())

    return (p9.ggplot(df, p9.aes(y='row', x='col'))
            + p9.coord_equal()
            + p9.geom_tile(p9.aes(fill='scale'))
            + cell_text
            + p9.scale_y_discrete(drop=False)
            + p9.scale_x_discrete(drop=False)
            + fill_scale
            + bare_theme)
def plot_preprocessing_boxplot_bymodel(dataframe,
                                       models_labels,
                                       metrics_labels,
                                       groups_labels,
                                       figure_size=(14, 4)):
    """
    Draw dodged box plots of each metric's values, split by group, with one
    facet row per model.

    Arguments:
        dataframe - long-format data with the columns 'variable' (the
            metric), 'value', 'group' and 'model'
        models_labels - mapping from a model value to its display name
        metrics_labels - mapping from a metric value to its display name
        groups_labels - mapping from a group value to its display name
        figure_size - (width, height) of the figure

    Returns:
        plotnine.ggplot object
    """

    def metric_names(values):
        return [metrics_labels[v] for v in values]

    def group_names(values):
        return [groups_labels[v] for v in values]

    def two_decimals(values):
        return ['{:.2f}'.format(v) for v in values]

    components = [
        p9.geom_boxplot(position='dodge'),
        p9.scale_x_discrete(name='Metric', labels=metric_names),
        p9.scale_y_continuous(name='Value',
                              expand=(0, 0.05),
                              labels=two_decimals),
        # Set2 is used as a color-blind-friendly qualitative palette.
        p9.scale_fill_brewer(name='Group',
                             labels=group_names,
                             type='qual',
                             palette='Set2'),
        # One row per model, each with its own y range.
        p9.facet_grid('model ~ .',
                      scales='free_y',
                      labeller=p9.labeller(
                          rows=lambda x: f'{models_labels[x]}')),
        p9.theme(
            axis_title_x=p9.element_blank(),
            axis_title_y=p9.element_blank(),
            axis_text_x=p9.element_text(size=7),
            axis_text_y=p9.element_text(size=7),
            # Legend on top, untitled, with a small margin.
            legend_title=p9.element_blank(),
            legend_position='top',
            legend_box_margin=2,
            figure_size=figure_size,
        ),
    ]

    plot = p9.ggplot(dataframe, p9.aes(x='variable', y='value', fill='group'))
    for component in components:
        plot = plot + component
    return plot
def plot_distributions_bar_plot_grid(dataframe, figure_size=(14, 4)):
    """
    Plot, per threshold, dodged bars counting the 0 ('Stable') and 1
    ('Unstable') values, with the count printed above each bar and one
    column of panels per number of iterations.

    Arguments:
        dataframe - data with the columns 'threshold', 'value' (0 or 1) and
            'iterations'
        figure_size - (width, height) of the figure

    Returns:
        plotnine.ggplot object
    """
    stability_colors = {0: '#009e73', 1: '#d55e00'}
    stability_names = {0: 'Stable', 1: 'Unstable'}

    def stability_labels(values):
        return [stability_names[v] for v in values]

    components = [
        p9.geom_bar(position='dodge'),
        # Count labels on top of the dodged bars.
        p9.geom_text(p9.aes(label='stat(count)'),
                     stat='count',
                     position=p9.position_dodge(0.9),
                     size=7,
                     va='bottom'),
        p9.scale_x_discrete(name='Threshold'),
        # expand = (mul_bottom, add_bottom, mul_top, add_top): extra room on
        # top for the count labels.
        p9.scale_y_continuous(name='Count', expand=(0, 0, 0, 500)),
        p9.scale_fill_manual(values=stability_colors, labels=stability_labels),
        p9.facet_grid('. ~ iterations',
                      labeller=p9.labeller(cols=lambda x: f'iters = {x}')),
        p9.theme(
            axis_title_y=p9.element_blank(),
            axis_text_x=p9.element_text(size=7),
            axis_text_y=p9.element_text(size=7),
            # Legend on top, untitled, with a small margin.
            legend_title=p9.element_blank(),
            legend_position='top',
            legend_box_margin=2,
            figure_size=figure_size,
        ),
    ]

    plot = p9.ggplot(dataframe, p9.aes(x='threshold', fill='value'))
    for component in components:
        plot = plot + component
    return plot
예제 #11
0
def plot_scale(df: pd.DataFrame,
               sweep_vars: Sequence[str] = None) -> gg.ggplot:
    """Plot the best episode observed for each height_threshold.

    The maximum 'best_episode' is taken per height_threshold (and per sweep
    variable, if any). Points are coloured by whether that maximum exceeds
    GOOD_EPISODE, and the result is faceted over ``sweep_vars``.

    Args:
      df: raw results; passed through cp_swingup_preprocess first.
      sweep_vars: optional extra columns to group and facet by.

    Returns:
      The plot returned by plotting.facet_sweep_plot.
    """
    processed = cp_swingup_preprocess(df_in=df)

    keys = ['height_threshold', *(sweep_vars or [])]
    best_df = processed.groupby(keys)['best_episode'].max().reset_index()

    plot = (
        gg.ggplot(best_df)
        + gg.aes(x='factor(height_threshold)',
                 y='best_episode',
                 colour=f'best_episode > {GOOD_EPISODE}')
        + gg.geom_point(size=5, alpha=0.8)
        + gg.scale_colour_manual(values=['#d73027', '#313695'])
        # Invisible line at y=0 (axis hack), presumably to keep 0 in range.
        + gg.geom_hline(gg.aes(yintercept=0.0), alpha=0)
        + gg.scale_x_discrete(breaks=[0, 0.25, 0.5, 0.75, 1.0])
        + gg.ylab(f'best return in first {NUM_EPISODES} episodes')
        + gg.xlab('height threshold')
    )
    return plotting.facet_sweep_plot(plot, sweep_vars)
예제 #12
0
def _hourly_attack_counts(frame):
    """Count rows per 'hour', ignoring rows whose 'year' or 'month' is missing.

    Returns a DataFrame whose first column is 'hour' and whose second column
    is the per-hour count, renamed from 'Unnamed: 0' to 'n'.
    NOTE(review): assumes the first non-'hour' column of ``frame`` is the
    'Unnamed: 0' index column left over from reading a CSV -- confirm.
    """
    dated = frame[frame.year.notna() & frame.month.notna()]
    counts = pd.DataFrame(dated.groupby("hour", as_index=False).count())
    counts = counts.iloc[:, 0:2]
    return counts.rename(columns={"Unnamed: 0": "n"})


def _hourly_attack_plot(counts, with_line):
    """Build the red per-hour plot of ``counts`` ('hour' vs 'n').

    Points are always drawn; when ``with_line`` is true they are also
    joined by a line.
    """
    plot = (p9.ggplot(data=counts, mapping=p9.aes(x='hour', y='n'))
            + p9.geom_point(color='red', size=10))
    if with_line:
        plot = plot + p9.geom_line(color='red', size=1)
    return (plot
            + p9.theme_classic()
            + p9.theme(axis_text=p9.element_text(size=40),
                       axis_title=p9.element_text(size=40, face='bold'))
            + p9.coord_cartesian(xlim=(1, 25))
            + p9.labs(x='Hours', y='No. of attacks')
            + p9.scale_x_discrete(limits=(range(1, 25))))


def _save_hourly_plot(plot, counts, filename):
    """Save ``plot`` into "pdf/iteration/" unless ``counts`` is empty."""
    if len(counts) > 0:
        # ggplot.save is already bound to ``plot``. Passing ``plot=`` as well
        # would forward it to matplotlib's savefig, which (in recent
        # matplotlib versions) rejects unknown keyword arguments.
        plot.save(filename=filename,
                  path="pdf/iteration/",
                  width=25, height=5,
                  dpi=320)
    else:
        print('Plot not created; no data found.')


def day_night_attacks(Data, Data_m):
    """Plot and save the number of symptom ('sy') entries per hour of day.

    Two JPEG files are written to "pdf/iteration/":
      * 'Graph_3.jpeg' from ``Data_m`` (monthly data, points only);
      * 'Graph_ALL_3.jpeg' from ``Data`` (all data, points joined by a line).
    Rows without a year or month are ignored. A figure is skipped, with a
    message, when it has nothing to plot.

    Returns:
        None
    """
    print('======= Creating day_night_attacks =======')
    # Keep only symptom entries, for the all-time and the monthly data.
    counts_all = _hourly_attack_counts(Data[(Data.Group == 'sy')])
    counts_month = _hourly_attack_counts(Data_m[(Data_m.Group == 'sy')])

    plot_all = _hourly_attack_plot(counts_all, with_line=True)
    plot_month = _hourly_attack_plot(counts_month, with_line=False)

    # Monthly Graph_3.
    _save_hourly_plot(plot_month, counts_month, 'Graph_3.jpeg')
    # All-time Graph_3. Emptiness is checked on the counts actually plotted,
    # not on the unfiltered rows, so an empty figure is never saved.
    _save_hourly_plot(plot_all, counts_all, 'Graph_ALL_3.jpeg')

    print('=================================day_night_attacks DONE =============================')
# Notebook-style preview of the first rows (the result is discarded when run
# as a plain script).
category_sim_df.head()

# Persist the per-category similarity table as a tab-separated file.
category_sim_df.to_csv("output/category_cossim_95_ci.tsv", sep="\t", index=False)

# Point-range plot of each category's PC1 cosine similarity with its
# lower/upper bounds (a 95% CI, according to the output file names).
g = (
    p9.ggplot(category_sim_df)
    + p9.aes(
        x="category",
        y="pca1_cossim",
        ymin="pca1_cossim_lower",
        ymax="pca1_cossim_upper",
    )
    + p9.geom_pointrange()
    # Put the categories on the vertical axis.
    + p9.coord_flip()
    + p9.theme_bw()
    # Categories in reversed row order; after coord_flip this presumably
    # lists the first row at the top -- confirm visually.
    + p9.scale_x_discrete(limits=category_sim_df.category.tolist()[::-1])
    + p9.theme(
        figure_size=(11, 8.5),
        text=p9.element_text(size=12),
        panel_grid_major_y=p9.element_blank(),
    )
    + p9.labs(y="PC1 Cosine Similarity")
)
g.save("output/pca_plots/figures/category_pca1_95_ci.png", dpi=500)
print(g)

g = (
    p9.ggplot(category_sim_df)
    + p9.aes(
        x="category",
        y="pca2_cossim",
예제 #14
0
                              na_rm=True)

            g += p9.scale_fill_manual(values=ez_colors(g.n_groups('group')))
            g += p9.scale_color_manual(values=ez_colors(g.n_groups('group')))

    # set facets
    if facet_x is not None and facet_y is None:
        g += p9.facet_wrap('~facet_x')
    if facet_x is not None and facet_y is not None:
        g += p9.facet_grid('facet_y~facet_x')

    # set x scale
    if g.column_is_timestamp('x'):
        g += p9.scale_x_datetime()
    elif g.column_is_categorical('x'):
        g += p9.scale_x_discrete()
    else:
        g += p9.scale_x_continuous(labels=ez_labels)

    # set y scale
    g += p9.scale_y_continuous(labels=ez_labels)

    # set axis labels
    g += \
        p9.xlab(names['x']) + \
        p9.ylab(names['y'])

    # set theme
    g += theme_ez(figure_size=figure_size,
                  base_size=base_size,
                  legend_title=p9.element_text(text=names['group'],
예제 #15
0
def line_plot(df,
              x,
              y,
              group=None,
              facet_x=None,
              facet_y=None,
              aggfun='sum',
              err=None,
              show_points=False,
              base_size=10,
              figure_size=(6, 3)):
    '''
    Aggregates data in df and plots one or more columns as a line chart.

    Parameters
    ----------
    df : pd.DataFrame
      input dataframe (a copy is used; df itself is not modified)
    x : str
      quoted expression to be plotted on the x axis; '.index' plots the
      dataframe index
    y : str or list of str
      quoted expression(s) to be plotted on the y axis; with several
      expressions each one becomes its own line (and legend entry)
    group : str
      quoted expression to be used as group (ie color); not allowed together
      with more than one y expression
    facet_x : str
      quoted expression to be used as facet
    facet_y : str
      quoted expression to be used as facet; only used when facet_x is set
    aggfun : str or fun
      function to be used for aggregating (eg sum, mean, median ...)
    err : str
      quoted expression to be used as error shaded area (y - err .. y + err);
      not allowed together with more than one y expression
    show_points : bool
      show/hide markers
    base_size : int
      base size for theme_ez
    figure_size : tuple of int
      figure size

    Returns
    -------
    g : EZPlot
      EZplot object

    Raises
    ------
    ValueError
      if group or err is given together with more than one y expression
    '''

    # With several y columns the columns themselves become the color groups
    # (see the melt below), so a user group / error band cannot be combined.
    if group is not None and isinstance(y, list) and len(y) > 1:
        log.error(
            "groups can be specified only when a single y column is present")
        raise ValueError(
            "groups can be specified only when a single y column is present")

    if err is not None and isinstance(y, list) and len(y) > 1:
        log.error(
            "err can be specified only when a single y column is present")
        raise ValueError(
            "err can be specified only when a single y column is present")

    # A one-element list is treated exactly like a single y expression.
    if isinstance(y, list) and len(y) == 1:
        y = y[0]

    # create a copy of the data
    dataframe = df.copy()

    # define groups and variables; remove and store (eventual) names
    # names: display titles per role; groups: grouping expressions per role;
    # variables: value expressions per role. unname (defined elsewhere)
    # presumably splits an optional display name from the expression.
    names = {}
    groups = {}
    variables = {}

    for label, var in zip(['x', 'group', 'facet_x', 'facet_y'],
                          [x, group, facet_x, facet_y]):
        names[label], groups[label] = unname(var)

    # fix special cases
    # '.index' plots the dataframe index; its name (or '') becomes the x title.
    if x == '.index':
        groups['x'] = '.index'
        names[
            'x'] = dataframe.index.name if dataframe.index.name is not None else ''

    if isinstance(y, list):

        # Each y expression is aggregated under a temporary column 'y_<i>'.
        ys = []
        for i, var in enumerate(y):
            ys.append('y_{}'.format(i))
            names['y_{}'.format(i)], variables['y_{}'.format(i)] = unname(var)

        # aggregate data
        tmp_gdata = agg_data(dataframe,
                             variables,
                             groups,
                             aggfun,
                             fill_groups=True)
        # Melt to long format: one row per (x, facets, y column); the former
        # column name goes to 'group' and is then replaced by its display name.
        groups_present = [
            c for c in ['x', 'facet_x', 'facet_y'] if c in tmp_gdata.columns
        ]
        gdata = pd.melt(tmp_gdata,
                        groups_present,
                        var_name='group',
                        value_name='y')
        gdata['group'] = gdata['group'].replace(
            {var: names[var]
             for var in ys})

        # update values for plotting
        names['y'] = 'Value'
        names['group'] = 'Variable'
        group = 'Variable'

    else:

        names['y'], variables['y'] = unname(y)
        if err is not None:
            names['err'], variables['err'] = unname(err)

        # aggregate data
        gdata = agg_data(dataframe,
                         variables,
                         groups,
                         aggfun,
                         fill_groups=True)

    # reorder columns
    gdata = gdata[[
        c for c in ['x', 'y', 'err', 'group', 'facet_x', 'facet_y']
        if c in gdata.columns
    ]]
    # Bounds of the shaded error band: y +/- err.
    if err is not None:
        gdata['ymax'] = gdata['y'] + gdata['err']
        gdata['ymin'] = gdata['y'] - gdata['err']

    # init plot obj
    g = EZPlot(gdata)

    # set groups
    if group is None:
        # Single series: group=1 keeps all points on one line, drawn in the
        # first ez color.
        g += p9.geom_line(p9.aes(x="x", y="y"),
                          group=1,
                          colour=ez_colors(1)[0])
        if show_points:
            g += p9.geom_point(p9.aes(x="x", y="y"),
                               group=1,
                               colour=ez_colors(1)[0])
        if err is not None:
            g += p9.geom_ribbon(p9.aes(x="x", ymax="ymax", ymin="ymin"),
                                group=1,
                                fill=ez_colors(1)[0],
                                alpha=0.2)
    else:
        # One line (and optional ribbon) per group value, with one ez color
        # per group.
        g += p9.geom_line(
            p9.aes(x="x", y="y", group="factor(group)",
                   colour="factor(group)"))
        if show_points:
            g += p9.geom_point(p9.aes(x="x", y="y", colour="factor(group)"))
        if err is not None:
            g += p9.geom_ribbon(p9.aes(x="x",
                                       ymax="ymax",
                                       ymin="ymin",
                                       fill="factor(group)"),
                                alpha=0.2)
        g += p9.scale_color_manual(values=ez_colors(g.n_groups('group')))
        g += p9.scale_fill_manual(values=ez_colors(g.n_groups('group')))

    # set facets
    # facet_y is only used together with facet_x (grid); facet_x alone wraps.
    if facet_x is not None and facet_y is None:
        g += p9.facet_wrap('~facet_x')
    if facet_x is not None and facet_y is not None:
        g += p9.facet_grid('facet_y~facet_x')

    # set x scale
    if g.column_is_timestamp('x'):
        g += p9.scale_x_datetime()
    elif g.column_is_categorical('x'):
        g += p9.scale_x_discrete()
    else:
        g += p9.scale_x_continuous(labels=ez_labels)

    # set y scale
    g += p9.scale_y_continuous(labels=ez_labels)

    # set axis labels
    g += \
      p9.xlab(names['x']) + \
      p9.ylab(names['y'])

    # set theme
    g += theme_ez(figure_size=figure_size,
                  base_size=base_size,
                  legend_title=p9.element_text(text=names['group'],
                                               size=base_size))

    return g
예제 #16
0
    def barchart_make(roi, df, list_rois, config, ylimit, save_function,
                      find_ylim_function):
        """Build, and depending on ``config`` save, the bar chart of one ROI.

        Args:
            roi: index into ``list_rois`` selecting the ROI to plot.
            df: summary table; rows whose 'index' column equals the ROI name
                are plotted from the 'Mean' and 'Conf_Int_95' columns.
            list_rois: sequence of ROI names.
            config: settings object providing the x/colour columns, labels,
                dpi, fill colours and ``use_same_axis_limits``.
            ylimit: shared upper y limit; 0 means it has not been set.
            save_function: called as save_function(figure, name, config,
                folder, 'barchart') to write the figure.
            find_ylim_function: called as find_ylim_function(name, figure,
                'yaxis') to measure the figure's y limit.

        Returns:
            The value returned by find_ylim_function when it was called,
            otherwise 0.
        """
        roi_name = list_rois[roi]
        x_col = config.single_roi_fig_x_axis

        # Rows of this ROI sorted by the x column; drop the old (grouping)
        # index.
        roi_df = (df.loc[df['index'] == roi_name]
                  .sort_values([x_col])
                  .reset_index(drop=True))
        # A categorical keeps that sorted order on the axis.
        roi_df[x_col] = pd.Categorical(roi_df[x_col],
                                       categories=roi_df[x_col].unique())

        def dodge():
            # A separate position object for each layer that dodges.
            return pltn.position_dodge(preserve='single', width=0.8)

        mapping = pltn.aes(x=x_col,
                           y='Mean',
                           ymin="Mean-Conf_Int_95",
                           ymax="Mean+Conf_Int_95",
                           fill='factor({colour})'.format(
                               colour=config.single_roi_fig_colour))
        styling = pltn.theme(
            panel_grid_major_x=pltn.element_line(alpha=0),
            axis_title_x=pltn.element_text(weight='bold', color='black',
                                           size=20),
            axis_title_y=pltn.element_text(weight='bold', color='black',
                                           size=20),
            axis_text_y=pltn.element_text(size=20, color='black'),
            legend_title=pltn.element_text(size=20, color='black'),
            legend_text=pltn.element_text(size=18, color='black'),
            subplots_adjust={'right': 0.85},
            legend_position=(0.9, 0.8),
            dpi=config.plot_dpi)

        components = [
            pltn.theme_538(),
            pltn.geom_col(position=dodge(), width=0.8, na_rm=True),
            pltn.geom_errorbar(size=1, position=dodge()),
            pltn.labs(x=config.single_roi_fig_label_x,
                      y=config.single_roi_fig_label_y,
                      fill=config.single_roi_fig_label_fill),
            # Tick labels are hidden; the x values are written as text at
            # y=-0.7 instead.
            pltn.scale_x_discrete(labels=[]),
            styling,
            pltn.geom_text(pltn.aes(y=-.7, label=x_col),
                           color='black',
                           size=20,
                           va='top'),
            pltn.scale_fill_manual(
                values=config.colorblind_friendly_plot_colours),
        ]
        figure = pltn.ggplot(roi_df, mapping)
        for component in components:
            figure = figure + component

        if ylimit:
            # Same y limit for every bar chart.
            figure += pltn.ylim(None, ylimit)
            roi_name += '_same_ylim'

        first_pass = ylimit == 0
        mode = config.use_same_axis_limits
        returned_ylim = 0
        if first_pass and mode in ('Same limits', 'Create both'):
            returned_ylim = find_ylim_function(roi_name, figure, 'yaxis')
            if mode == 'Same limits':
                # This pass only measures the limit; nothing is saved.
                return returned_ylim

        folder = 'Different_yaxis' if first_pass else 'Same_yaxis'
        save_function(figure, roi_name, config, folder, 'barchart')

        return returned_ylim
        'aupr_upper': lambda x: x.aupr_mean + (critical_val * x.aupr_std)/pd.np.sqrt(x.lf_num_len),
        'aupr_lower': lambda x: x.aupr_mean - (critical_val * x.aupr_std)/pd.np.sqrt(x.lf_num_len)
    })
)
# Notebook-style preview of the first two rows (discarded outside a notebook).
dev_disc_df.head(2)


# In[7]:


# Mean AUROC with lower/upper error bars against the number of label
# functions: one facet and one colour per relation, one line type per model.
g = ( 
    p9.ggplot(dev_disc_df, p9.aes(x="factor(lf_num)", y="auroc_mean", linetype="model", color="relation"))
    + p9.geom_point()
    + p9.geom_errorbar(p9.aes(ymin="auroc_lower", ymax="auroc_upper"))
    + p9.geom_line(p9.aes(group="model"))
    # Fixed order of the label-function counts on the x axis.
    + p9.scale_x_discrete(limits=[0, 1, 6, 11, 16, 'All'])
    # Fixed colour per relation; guide=False hides the colour legend (the
    # relation is already named by the facet strips).
    + p9.scale_color_manual(values={
        "DaG": mcolors.to_hex(color_map["DaG"]),
        'CtD': mcolors.to_hex(color_map["CtD"]),
        "CbG": mcolors.to_hex(color_map["CbG"]),
        "GiG": mcolors.to_hex(color_map["GiG"]),
        }, guide=False)
    + p9.facet_wrap("relation")
    + p9.labs(
        title="Disc Model Performance (Tune Set)",
    )
    + p9.xlab("Number of Label Functions")
    + p9.ylab("AUROC")
    + p9.theme_bw()
)
print(g)
예제 #18
0
    })

# In[6]:

# One row per record of journal_type_records, saved as an xz-compressed TSV.
journal_paper_df = pd.DataFrame.from_records(journal_type_records)
journal_paper_df.to_csv("output/pubmed_central_journal_paper_map.tsv.xz",
                        sep="\t",
                        index=False,
                        compression="xz")
journal_paper_df.head()

# In[7]:

# Number of distinct journals (notebook display).
journal_paper_df.journal.unique().shape

# # Types of Articles Contained in PMC

# In[3]:

# value_counts() sorts by descending frequency. Reversing the list and keeping
# the last 15 entries gives the 15 most common article types in ascending
# order, so the most common type is the last axis limit.
journal_article_type_list = journal_paper_df['article_type'].value_counts(
).index.tolist()[::-1]
journal_article_type_list = journal_article_type_list[-15:]

# Filter with isin() rather than interpolating the list's repr into a query
# string, which breaks for article types containing quotes or other
# characters the query parser treats specially.
g = (p9.ggplot(journal_paper_df[journal_paper_df['article_type'].isin(
    journal_article_type_list)]) +
     p9.aes(x="article_type") + p9.geom_bar(position="dodge") +
     p9.scale_x_discrete(limits=journal_article_type_list) + p9.coord_flip() +
     p9.theme_bw())
g.save("output/figures/article_type.png", dpi=500)
print(g)
예제 #19
0
def MDplot(Data,
           Names=None,
           Ordering='Default',
           Scaling=None,
           Fill='darkblue',
           RobustGaussian=True,
           GaussianColor='magenta',
           Gaussian_lwd=1.5,
           BoxPlot=False,
           BoxColor='darkred',
           MDscaling='width',
           LineColor='black',
           LineSize=0.01,
           QuantityThreshold=40,
           UniqueValuesThreshold=12,
           SampleSize=500000,
           SizeOfJitteredPoints=1,
           OnlyPlotOutput=True,
           ValueColumn=None,
           ClassColumn=None):
    """
    Plots a mirrored density plot (MD-plot) for each numeric column

    Args:
        Data (dataframe): dataframe containing data. Each column is one
                          variable (wide table format, for long table format
                          see ValueColumn and ClassColumn)
        Names (list): list of column names (will be used if data is not a
                      dataframe)
        Ordering (str): 'Default', 'Columnwise', 'Alphabetical' or 'Statistics'
        Scaling (str): scaling method, one of: Percentalize, CompleteRobust,
                                               Robust, Log
        Fill (str): color of MD-Plot
        RobustGaussian (bool): draw a gaussian distribution if column is
                               gaussian
        GaussianColor (str): color for gaussian distribution
        Gaussian_lwd (float): line width of gaussian distribution
        BoxPlot (bool): draw box-plot
        BoxColor (str): color for box-plots
        MDscaling (str): scale of ggplot violin
        LineColor (str): outline color of the violins and color of the
                         jittered points
        LineSize (float): line width of ggplot violin
        QuantityThreshold (int): minimal number of rows
        UniqueValuesThreshold (int): minimal number of unique values per
                                     column
        SampleSize (int): number of samples used if number of rows is larger
                          than SampleSize
        SizeOfJitteredPoints (float): point size for columns that are drawn
                                      as a jitter plot instead of an MD-plot
        OnlyPlotOutput (bool): if True than returning only ggplot object,
                               if False than returning dictionary containing
                               ggplot object and additional infos
        ValueColumn (str): name of the column of values to be plotted
                           (data in long table format)
        ClassColumn (str): name of the column with class identifiers for the
                           value column (data in long table format)

    Returns:
        ggplot object or dictionary containing ggplot object and additional
        infos

    Raises:
        ValueError: if Ordering is not one of the supported values.
        Exception: if Data cannot be converted into a dataframe, or if
                   ValueColumn / ClassColumn is not a column of Data.
    """
    # Fail early instead of with an UnboundLocalError on `rangfolge` later.
    if Ordering not in ("Default", "Columnwise", "Alphabetical",
                        "Statistics"):
        raise ValueError("Ordering must be one of 'Default', 'Columnwise', "
                         "'Alphabetical' or 'Statistics'")

    if not isinstance(Data, pd.DataFrame):
        try:
            if Names is not None:
                Data = pd.DataFrame(Data, columns=Names)
            else:
                Data = pd.DataFrame(Data)
                Data = Data.rename(columns=lambda c: "C_" + str(c))
        except Exception as err:
            raise Exception(
                "Data cannot be converted into pandas dataframe") from err
    else:
        Data = Data.reset_index(drop=True)

    # Long table format -> wide table format (one column per class).
    if ValueColumn is not None and ClassColumn is not None:
        lstCols = list(Data.columns)
        if ValueColumn not in lstCols:
            raise Exception("ValueColumn not contained in dataframe")
        if ClassColumn not in lstCols:
            raise Exception("ClassColumn not contained in dataframe")

        lstClasses = list(Data[ClassColumn].unique())
        DataWide = pd.DataFrame()
        for strClass in lstClasses:
            dfTemp = Data[Data[ClassColumn] == strClass].copy()\
                .reset_index(drop=True)
            dfTemp = dfTemp.rename(columns={ValueColumn: strClass})
            dfTemp = dfTemp[[strClass]]
            if len(DataWide) == 0:
                DataWide = dfTemp
            else:
                DataWide = DataWide.join(dfTemp, how='outer')
        Data = DataWide.copy()

    for strCol in list(Data.columns):
        if not is_numeric_dtype(Data[strCol]):
            # str() so that non-string column labels do not raise TypeError
            print("Deleting non numeric column: " + str(strCol))
            Data = Data.drop([strCol], axis=1)
        elif abs(Data[strCol].sum()) == np.inf:
            print("Deleting infinite column: " + str(strCol))
            Data = Data.drop([strCol], axis=1)

    Data = Data.rename_axis("index", axis="index")\
        .rename_axis("variable", axis="columns")
    dvariables = Data.shape[1]
    nCases = Data.shape[0]

    if nCases > SampleSize:
        print('Data has more cases than "SampleSize". Drawing a sample for '
              'faster computation. You can omit this by setting '
              '"SampleSize=len(data)".')
        sampledIndex = np.sort(
            np.random.choice(list(Data.index), size=SampleSize, replace=False))
        Data = Data.loc[sampledIndex]
        # Keep nCases in sync with the sampled data: it is later used as the
        # length of arrays assigned to columns and as a subsampling bound.
        nCases = Data.shape[0]

    # Rename columns to non-numeric string names. This happens BEFORE the
    # per-column counts and before `lstCols` is taken, so that every
    # per-column lookup below uses the final column names.
    dctCols = {}
    for strCol in Data.columns:
        try:
            float(strCol)
            dctCols[strCol] = "C_" + str(strCol)
        except (TypeError, ValueError):
            dctCols[strCol] = str(strCol)
    Data = Data.rename(columns=dctCols)
    lstCols = list(Data.columns)

    nPerVar = Data.apply(lambda x: len(x.dropna()))
    nUniquePerVar = Data.apply(lambda x: len(list(x.dropna().unique())))

    if Scaling == "Percentalize":
        Data = Data.apply(lambda x: 100 * (x - x.min()) / (x.max() - x.min()))
    if Scaling == "CompleteRobust":
        Data = robust_normalization(Data, centered=True, capped=True)
    if Scaling == "Robust":
        Data = robust_normalization(Data, centered=False, capped=False)
    if Scaling == "Log":
        Data = signed_log(Data, base="Ten")
        if RobustGaussian == True:
            RobustGaussian = False
            print("log with robust gaussian does not work, because mean and "
                  "variance is not valid description for log normal data")

#_______________________________________________Roboust Gaussian and Statistics
    if RobustGaussian == True or Ordering == "Statistics":
        Data = Data.replace([np.inf, -np.inf], np.nan)

        if nCases < 50:
            warnings.warn("Sample is maybe too small for statistical testing")

        factor = sum(abs(norm.ppf(q)) for q in (0.25, 0.75))
        std = Data.std()

        dfQuartile = Data.apply(
            lambda x: mquantiles(x, [0.25, 0.75], alphap=0.5, betap=0.5))
        # DataFrame.append was removed in pandas 2.0 -> use concat.
        iqrRow = (dfQuartile.loc[1] - dfQuartile.loc[0]).to_frame().T
        dfQuartile = pd.concat([dfQuartile, iqrRow], ignore_index=True)
        dfQuartile.index = ["low", "hi", "iqr"]
        dfMinMax = Data.apply(
            lambda x: mquantiles(x, [0.001, 0.999], alphap=0.5, betap=0.5))
        dfMinMax.index = ["min", "max"]

        shat = pd.Series()
        mhat = pd.Series()
        nonunimodal = pd.Series()
        skewed = pd.Series()
        bimodalprob = pd.Series()
        isuniformdist = pd.Series()
        nSample = max([10000, nCases])
        normaldist = np.empty((nSample, dvariables))
        normaldist[:] = np.nan
        normaldist = pd.DataFrame(normaldist, columns=lstCols)

        for strCol in lstCols:
            shat[strCol] = min(
                [std[strCol], dfQuartile[strCol].loc["iqr"] / factor])
            mhat[strCol] = trim_mean(Data[strCol].dropna(), 0.1)

            # Choose the vector to test (None = too few finite values).
            if nCases > 45000 and nPerVar[strCol] > 8:
                # statistical testing does not work with too many cases
                sampledIndex = np.sort(
                    np.random.choice(list(Data.index),
                                     size=45000,
                                     replace=False))
                vec = Data[strCol].loc[sampledIndex]
            elif nPerVar[strCol] < 8:
                warnings.warn("Sample of finite values to small to calculate "
                              "agostino.test or dip.test for " + strCol)
                vec = None
            else:
                vec = Data[strCol]

            if vec is not None and \
                    nUniquePerVar[strCol] > UniqueValuesThreshold:
                nonunimodal[strCol] = dip.diptst(vec.dropna(), numt=100)[1]
                skewed[strCol] = skewtest(vec)[1]
                args = (dfMinMax[strCol].loc["min"],
                        dfMinMax[strCol].loc["max"]
                        - dfMinMax[strCol].loc["min"])
                isuniformdist[strCol] = kstest(vec, "uniform", args)[1]
                bimodalprob[strCol] = bimodal(vec)["Bimodal"]
            else:
                if vec is not None:
                    print("Not enough unique values for statistical testing, "
                          "thus output of testing is ignored.")
                nonunimodal[strCol] = 1
                skewed[strCol] = 1
                isuniformdist[strCol] = 0
                bimodalprob[strCol] = 0

            if isuniformdist[strCol] < 0.05 and nonunimodal[strCol] > 0.05 \
                    and skewed[strCol] > 0.05 and bimodalprob[strCol] < 0.05 \
                    and nPerVar[strCol] > QuantityThreshold \
                    and nUniquePerVar[strCol] > UniqueValuesThreshold:
                normaldist[strCol] = np.random.normal(mhat[strCol],
                                                      shat[strCol], nSample)
                # Blank out draws outside the observed data range. min/max are
                # computed once (the old per-element lambda was O(n^2)).
                inRange = normaldist[strCol].between(Data[strCol].min(),
                                                     Data[strCol].max())
                normaldist[strCol] = normaldist[strCol].where(inRange)
        nonunimodal[nonunimodal == 0] = 0.0000000001
        skewed[skewed == 0] = 0.0000000001
        effectStrength = (-10 * np.log(skewed) - 10 * np.log(nonunimodal)) / 2

#______________________________________________________________________Ordering
    if Ordering == "Default":
        bimodalprob = pd.Series()
        for strCol in lstCols:
            if nCases > 45000 and nPerVar[strCol] > 8:
                sampledIndex = np.sort(
                    np.random.choice(list(Data.index),
                                     size=45000,
                                     replace=False))
                vec = Data[strCol].loc[sampledIndex]
                bimodalprob[strCol] = bimodal(vec)["Bimodal"]
            elif nPerVar[strCol] < 8:
                bimodalprob[strCol] = 0
            else:
                bimodalprob[strCol] = bimodal(Data[strCol])["Bimodal"]
        if len(list(bimodalprob.unique())) < 2 and dvariables > 1 \
                and RobustGaussian == True:
            rangfolge = list(effectStrength.sort_values(ascending=False).index)
            print("Using statistics for ordering instead of default")
        else:
            rangfolge = list(bimodalprob.sort_values(ascending=False).index)
    elif Ordering == "Columnwise":
        rangfolge = lstCols
    elif Ordering == "Alphabetical":
        rangfolge = sorted(lstCols)
    else:  # "Statistics" (validated at the top)
        rangfolge = list(effectStrength.sort_values(ascending=False).index)

#________________________________________________________________Data Reshaping
    useJitter = nPerVar.min() < QuantityThreshold \
        or nUniquePerVar.min() < UniqueValuesThreshold
    if useJitter:
        warnings.warn("Some columns have less than " + str(QuantityThreshold) +
                      " data points or less than " +
                      str(UniqueValuesThreshold) +
                      " unique values. Changing from MD-plot to Jitter-Plot "
                      "for these columns.")
        dataDensity = Data.copy()
        mm = Data.median()
        for strCol in lstCols:
            if nPerVar[strCol] < QuantityThreshold \
                    or nUniquePerVar[strCol] < UniqueValuesThreshold:
                if mm[strCol] != 0:
                    dataDensity[strCol] = mm[strCol] \
                        * np.random.uniform(-0.001, 0.001, nCases) + mm[strCol]
                else:
                    dataDensity[strCol] = np.random.uniform(
                        -0.001, 0.001, nCases)
        # Generates in the cases where pdf cannot be estimated a scatter plot
        dataJitter = dataDensity.copy()
        # Delete all scatters for features where distributions can be estimated
        for strCol in lstCols:
            if nPerVar[strCol] >= QuantityThreshold \
                    and nUniquePerVar[strCol] >= UniqueValuesThreshold:
                dataJitter[strCol] = np.nan
        # apply ordering
        dataframe = dataDensity[rangfolge].reset_index()\
            .melt(id_vars=["index"])
    else:
        dataframe = Data[rangfolge].reset_index().melt(id_vars=["index"])

    dctCols = {"index": "ID", "variable": "Variables", "value": "Values"}
    dataframe = dataframe.rename(columns=dctCols)

#______________________________________________________________________Plotting
    plot = p9.ggplot(dataframe, p9.aes(x="Variables", group="Variables",
                                       y="Values")) \
        + p9.scale_x_discrete(limits=rangfolge)

    plot = plot + p9.geom_violin(stat=stat_pde_density(scale=MDscaling),
                                 fill=Fill, colour=LineColor,
                                 size=LineSize, trim=True) \
        + p9.theme(axis_text_x=p9.element_text(rotation=90))

    if useJitter:
        dataframejitter = dataJitter[rangfolge].reset_index()\
            .melt(id_vars=["index"])
        dataframejitter = dataframejitter.rename(columns=dctCols)
        plot = plot + p9.geom_jitter(
            size=SizeOfJitteredPoints,
            data=dataframejitter,
            colour=LineColor,
            mapping=p9.aes(x="Variables", group="Variables", y="Values"),
            position=p9.position_jitter(0.15))

    if RobustGaussian == True:
        dfTemp = normaldist[rangfolge].reset_index().melt(id_vars=["index"])
        dfTemp = dfTemp.rename(columns=dctCols)
        if not dfTemp["Values"].isnull().all():
            plot = plot + p9.geom_violin(
                data=dfTemp,
                mapping=p9.aes(x="Variables", group="Variables", y="Values"),
                colour=GaussianColor,
                alpha=0,
                scale=MDscaling,
                size=Gaussian_lwd,
                na_rm=True,
                trim=True,
                fill=None,
                position="identity",
                width=1)

    if BoxPlot == True:
        plot = plot + p9.stat_boxplot(geom="errorbar", width=0.5,
                                      color=BoxColor) \
            + p9.geom_boxplot(width=1, outlier_colour=None, alpha=0,
                              fill='#ffffff', color=BoxColor,
                              position="identity")

    if OnlyPlotOutput == True:
        return plot
    print(plot)
    return {
        "Ordering": rangfolge,
        "DataOrdered": Data[rangfolge],
        "ggplotObj": plot
    }
예제 #20
0
def main():
    """Draw violin plots of one column taken from each input file.

    Reads every file in ``args.infile`` with separator ``args.sep``, takes
    column ``args.col`` (1-based) from each, stacks them into a long table
    with columns ``args.x_name`` (source label) and ``args.y_name`` (value),
    and writes ``<outpref>.violin.<img>`` using plotnine (``args.use_p9``)
    or seaborn.
    """
    args = UserInput()

    y_lim = np.array(args.y_lim, dtype=np.float32) if args.y_lim else None
    # Figure size (width, height) in inches; None keeps the library default.
    # Previously a missing --size left `size` as None and crashed on size[0].
    size = np.array(args.size, dtype=np.float32) if args.size else None

###################################

    df_list = [
        pd.read_csv(f, sep=args.sep, skipinitialspace=True)
        for f in args.infile
    ]

    ## take column `args.col` (1-based) from each input file
    lg_list = []
    for idx, df in enumerate(df_list):
        xdf = pd.DataFrame(df.iloc[:, int(args.col) - 1])

        if args.col_names:
            xdf.columns = [args.col_names[idx]]

        lg_list.append(pd.melt(xdf))

    lg_df = pd.concat(lg_list)
    lg_df.columns = [args.x_name, args.y_name]
    print(lg_df)

    ## plotnine method
    if args.use_p9:
        import plotnine as p9
        Quant = [.25, .5, .75]

        if y_lim is not None:
            set_ylim = p9.ylim(y_lim)
        else:
            set_ylim = p9.ylim(
                [lg_df[args.y_name].min(), lg_df[args.y_name].max()])

        df_plot = (p9.ggplot(
            lg_df, p9.aes(x=args.x_name, y=args.y_name, fill=args.x_name)) +
                   p9.geom_violin(
                       width=.75, draw_quantiles=Quant, show_legend=False) +
                   p9.ggtitle(args.title) + p9.theme_classic() + set_ylim +
                   p9.scale_x_discrete(limits=args.col_names) +
                   p9.theme(text=p9.element_text(size=12, color='black'),
                            axis_text_x=p9.element_text(angle=33),
                            panel_grid_major_y=p9.element_line(color='gray',
                                                               alpha=.5)))

        # width/height of None lets ggsave fall back to the theme's size.
        width, height = (None, None) if size is None else (size[0], size[1])
        p9.ggsave(filename='{0}.violin.{1}'.format(args.outpref, args.img),
                  plot=df_plot,
                  dpi=int(args.dpi),
                  format=args.img,
                  width=width,
                  height=height,
                  units='in',
                  verbose=False)

    else:
        ## Seaborn method
        import seaborn as sns
        sns.set(style='whitegrid')

        ax = sns.violinplot(x=args.x_name,
                            y=args.y_name,
                            data=lg_df,
                            linewidth=1,
                            inner='box')
        if args.title:
            ax.set_title(args.title)
        if y_lim is not None:
            ax.set(ylim=y_lim)
        # savefig() has no `figsize` argument; size the figure itself instead.
        if size is not None:
            ax.figure.set_size_inches(size[0], size[1])

        plt.savefig('{0}.violin.{1}'.format(args.outpref, args.img),
                    format=args.img,
                    dpi=int(args.dpi))
        plt.clf()
예제 #21
0
def scatter_plot(df,
                 x,
                 y,
                 group=None,
                 facet_x=None,
                 facet_y=None,
                 base_size=10,
                 figure_size=(6, 3),
                 **kwargs):
    '''
    Draw the rows of df as a scatter chart (optionally grouped and faceted).

    Parameters
    ----------
    df : pd.DataFrame
      input dataframe
    x : str
      quoted expression to be plotted on the x axis ('.index' uses the index)
    y : str
      quoted expression to be plotted on the y axis
    group : str
      quoted expression to be used as group (ie color)
    facet_x : str
      quoted expression to be used as facet
    facet_y : str
      quoted expression to be used as facet
    base_size : int
      base size for theme_ez
    figure_size :tuple of int
      figure size
    **kwargs:
      additional kwargs passed to geom_point

    Returns
    -------
    g : EZPlot
      EZplot object

    '''
    frame = df.copy()

    # split each (possibly named) expression into display name and expression
    names, groups, variables = {}, {}, {}
    roles = (('x', x), ('group', group), ('facet_x', facet_x),
             ('facet_y', facet_y))
    for role, expr in roles:
        names[role], groups[role] = unname(expr)
    names['y'], variables['y'] = unname(y)

    # '.index' plots against the dataframe index
    if x == '.index':
        groups['x'] = '.index'
        index_name = frame.index.name
        names['x'] = '' if index_name is None else index_name

    # evaluate expressions (no aggregation) and put columns in canonical order
    gdata = agg_data(frame, variables, groups, None, fill_groups=True)
    wanted = ('x', 'y', 'group', 'facet_x', 'facet_y')
    gdata = gdata[[col for col in wanted if col in gdata.columns]]

    if group is not None:
        gdata['group_x'] = \
            gdata['group'].astype('str') + '_' + gdata['x'].astype(str)

    g = EZPlot(gdata)

    # points, coloured per group when a group is given
    if group is not None:
        grouped_aes = p9.aes(x="x", y="y", group="factor(group)",
                             color="factor(group)")
        g += p9.geom_point(grouped_aes, **kwargs)
        g += p9.scale_color_manual(values=ez_colors(g.n_groups('group')))
    else:
        g += p9.geom_point(p9.aes(x="x", y="y"),
                           colour=ez_colors(1)[0],
                           **kwargs)

    # facets (facet_y alone is ignored)
    if facet_x is not None:
        if facet_y is None:
            g += p9.facet_wrap('~facet_x')
        else:
            g += p9.facet_grid('facet_y~facet_x')

    # pick a scale per axis from the column type
    for axis in ('x', 'y'):
        if g.column_is_timestamp(axis):
            g += getattr(p9, f'scale_{axis}_datetime')()
        elif g.column_is_categorical(axis):
            g += getattr(p9, f'scale_{axis}_discrete')()
        else:
            g += getattr(p9, f'scale_{axis}_continuous')(labels=ez_labels)

    g += p9.xlab(names['x']) + p9.ylab(names['y'])

    legend_title = p9.element_text(text=names['group'], size=base_size)
    g += theme_ez(figure_size=figure_size,
                  base_size=base_size,
                  legend_title=legend_title)

    return g
예제 #22
0
tags_summary.rename(columns={"species": "count"}, inplace=True)

tags_summary["tag_duration"] = tags_summary.tag_duration.astype(int)
tags_summary["duration"] = tags_summary.tag_duration.astype(str) + "s"
tags_summary = tags_summary.reindex(list(SPECIES_LABELS.keys()))
tags_summary.reset_index(inplace=True)
tags_summary

# Tick labels with spaces turned into line breaks. This must be defined
# BEFORE the plots below use it (it used to be defined after them, which
# raised a NameError on a fresh run).
xlabels = [lab.replace(" ", "\n") for lab in SPECIES_LABELS.values()]
xlabels

# (bar height column, y-axis label, bar text column, output file)
for y_col, y_label, text_col, out_file in (
        ("tag_duration", "Duration of annotations (s)", "count",
         "species_repartition_duration_mini.png"),
        ("count", "Number of annotations", "duration",
         "species_repartition_count_mini.png"),
):
    (ggplot(data=tags_summary,
            mapping=aes(x="factor(species, ordered=False)",
                        y=y_col,
                        fill="factor(species, ordered=False)")) +
     geom_bar(stat="identity", show_legend=False) + xlab("Species") +
     ylab(y_label) +
     geom_text(mapping=aes(label=text_col), nudge_y=15) + theme_classic() +
     scale_x_discrete(limits=SPECIES_LIST, labels=xlabels)).save(
         out_file, width=10, height=8)
print(tags_summary)
예제 #23
0
def area_plot(df,
              x,
              y,
              group=None,
              facet_x=None,
              facet_y=None,
              aggfun='sum',
              fill=False,
              sort_groups=True,
              base_size=10,
              figure_size=(6, 3)):
    '''
    Aggregates data in df and plots as a stacked area chart.

    Parameters
    ----------
    df : pd.DataFrame
      input dataframe
    x : str
      quoted expression to be plotted on the x axis ('.index' uses the index)
    y : str
      quoted expression to be plotted on the y axis
    group : str
      quoted expression to be used as group (ie color)
    facet_x : str
      quoted expression to be used as facet
    facet_y : str
      quoted expression to be used as facet
    aggfun : str or fun
      function to be used for aggregating (eg sum, mean, median ...)
    fill : bool
      plot shares for each group instead of absolute values
    sort_groups : bool
      sort groups by the sum of their value (otherwise alphabetical order is used)
    base_size : int
      base size for theme_ez
    figure_size :tuple of int
      figure size

    Returns
    -------
    g : EZPlot
      EZplot object

    '''

    # create a copy of the data
    dataframe = df.copy()

    # define groups and variables; remove and store (eventual) names
    names = {}
    groups = {}
    variables = {}

    for label, var in zip(['x', 'group', 'facet_x', 'facet_y'],
                          [x, group, facet_x, facet_y]):
        names[label], groups[label] = unname(var)
    names['y'], variables['y'] = unname(y)

    # fix special cases
    if x == '.index':
        groups['x'] = '.index'
        names[
            'x'] = dataframe.index.name if dataframe.index.name is not None else ''

    # aggregate data and reorder columns
    gdata = agg_data(dataframe, variables, groups, aggfun, fill_groups=True)
    # Assign back instead of `gdata['y'].fillna(0, inplace=True)`: that
    # chained in-place call does not update gdata under pandas Copy-on-Write.
    gdata['y'] = gdata['y'].fillna(0)
    gdata = gdata[[
        c for c in ['x', 'y', 'group', 'facet_x', 'facet_y']
        if c in gdata.columns
    ]]

    if fill:
        # normalise y to shares within each x (and facet) position
        groups_to_normalize = [
            c for c in ['x', 'facet_x', 'facet_y'] if c in gdata.columns
        ]
        total_values = gdata \
            .groupby(groups_to_normalize)['y'] \
            .sum() \
            .reset_index() \
            .rename(columns = {'y':'tot_y'})
        gdata = pd.merge(gdata, total_values, on=groups_to_normalize)
        # EPSILON guards against division by zero for all-zero positions
        gdata['y'] = gdata['y'] / (gdata['tot_y'] + EPSILON)
        gdata.drop('tot_y', axis=1, inplace=True)
        ylabeller = percent_labels
    else:
        ylabeller = ez_labels

    # get plot object
    g = EZPlot(gdata)

    # determine order and create a categorical type
    if sort_groups:
        sort_data_groups(g)

    # get colors
    colors = np.flip(ez_colors(g.n_groups('group')))

    # set groups
    if group is None:
        g += p9.geom_area(p9.aes(x="x", y="y"),
                          colour=None,
                          fill=ez_colors(1)[0],
                          na_rm=True)
    else:
        g += p9.geom_area(p9.aes(x="x",
                                 y="y",
                                 group="factor(group)",
                                 fill="factor(group)"),
                          colour=None,
                          na_rm=True)
        g += p9.scale_fill_manual(values=colors)

    # set facets
    if facet_x is not None and facet_y is None:
        g += p9.facet_wrap('~facet_x')
    if facet_x is not None and facet_y is not None:
        g += p9.facet_grid('facet_y~facet_x')

    # set x scale
    if g.column_is_timestamp('x'):
        g += p9.scale_x_datetime()
    elif g.column_is_categorical('x'):
        g += p9.scale_x_discrete()
    else:
        g += p9.scale_x_continuous(labels=ez_labels)

    # set y scale (extra top headroom only for absolute values)
    g += p9.scale_y_continuous(labels=ylabeller,
                               expand=[0, 0, 0.1 * (not fill) + 0.03, 0])

    # set axis labels
    g += \
        p9.xlab(names['x']) + \
        p9.ylab(names['y'])

    # set theme
    g += theme_ez(figure_size=figure_size,
                  base_size=base_size,
                  legend_title=p9.element_text(text=names['group'],
                                               size=base_size))

    if sort_groups:
        g += p9.guides(fill=p9.guide_legend(reverse=True),
                       color=p9.guide_legend(reverse=True))

    return g
예제 #24
0
def hist_plot(df,
              x,
              y=None,
              group = None,
              facet_x = None,
              facet_y = None,
              w='1',
              bins=21,
              bin_width = None,
              position = 'stack',
              normalize = False,
              sort_groups=True,
              base_size=10,
              figure_size=(6, 3)):

    '''
    Plot a 1-d or 2-d histogram

    Parameters
    ----------
    df : pd.DataFrame
      input dataframe
    x : str
      quoted expression to be plotted on the x axis
    y : str
      quoted expression to be plotted on the y axis. If this is specified the histogram will be 2-d.
    group : str
      quoted expression to be used as group (ie color); not allowed together with y
    facet_x : str
      quoted expression to be used as facet
    facet_y : str
      quoted expression to be used as facet
    w : str
      quoted expression representing histogram weights (default is 1)
    bins : int or tuple
      number of bins to be used
    bin_width : float or tuple
      bin width to be used (exactly one of bins / bin_width must be given)
    position : str
      if groups are present, choose between `stack`, `overlay` or `dodge`
    normalize : bool
      normalize histogram counts so that they sum to 1 / (bin area), per
      group/facet when those are present
    sort_groups : bool
      sort groups by the sum of their value (otherwise alphabetical order is used)
    base_size : int
      base size for theme_ez
    figure_size :tuple of int
      figure size

    Returns
    -------
    g : EZPlot
      EZplot object

    Raises
    ------
    NotImplementedError
      if position is not recognized
    ValueError
      if bins/bin_width are both or neither given, or if y and group are
      both given

    '''

    if position not in ['overlay', 'stack', 'dodge']:
        log.error("position not recognized")
        raise NotImplementedError("position not recognized")

    if (bins is None) and (bin_width is None):
        log.error("Either bins or bin_width should be defined")
        raise ValueError("Either bins or bin_width should be defined")

    if (bins is not None) and (bin_width is not None):
        log.error("Only one between bins or bin_width should be defined")
        raise ValueError("Only one between bins or bin_width should be defined")

    if (y is not None) and (group is not None):
        log.error("y and group cannot be requested at the same time")
        raise ValueError("y and group cannot be requested at the same time")

    # normalise bins / bin_width to (x, y) pairs
    if y is None:
        bins = (bins, bins)
        bin_width = (bin_width, bin_width)
    else:
        if type(bins) not in [tuple, list]:
            bins = (bins, bins)
        if type(bin_width) not in [tuple, list]:
            bin_width = (bin_width, bin_width)

    # create a copy of the data
    dataframe = df.copy()

    # define groups and variables; remove and store (eventual) names
    names = {}
    groups = {}
    variables = {}

    for label, var in zip(['x', 'y', 'group', 'facet_x', 'facet_y'], [x, y, group, facet_x, facet_y]):
        names[label], groups[label] = unname(var)
    names['w'], variables['w'] = unname(w)

    # set column names and evaluate expressions
    tmp_df = agg_data(dataframe, variables, groups, None, fill_groups=False)

    # redefine groups and variables; remove and store (eventual) names
    new_groups = {c:c for c in tmp_df.columns if c in ['x', 'y', 'group', 'facet_x', 'facet_y']}
    non_xy_groups = [k for k in new_groups.keys() if k not in ['x', 'y']]
    new_variables = {'w':'w'}

    # bin data (if necessary); object (categorical-like) columns are not binned
    if tmp_df['x'].dtypes != np.dtype('O'):
        tmp_df['x'], _, bin_width_x = bin_data(tmp_df['x'], bins[0], bin_width[0])
    else:
        bin_width_x = 1
    if y is not None and tmp_df['y'].dtypes != np.dtype('O'):
        tmp_df['y'], _, bin_width_y = bin_data(tmp_df['y'], bins[1], bin_width[1])
    else:
        bin_width_y = 1

    # aggregate data and reorder columns
    gdata = agg_data(tmp_df, new_variables, new_groups, 'sum', fill_groups=True)
    gdata.fillna(0, inplace=True)
    gdata = gdata[[c for c in ['x', 'y', 'w', 'group', 'facet_x', 'facet_y'] if c in gdata.columns]]

    # normalize
    if normalize:
        if len(non_xy_groups)==0:
            gdata['w'] = gdata['w']/(gdata['w'].sum()*bin_width_x*bin_width_y)
        else:
            # transform() keeps gdata's index; groupby().apply() prepends the
            # group keys on pandas >= 2.0 and cannot be assigned back.
            group_totals = gdata.groupby(non_xy_groups)['w'].transform('sum')
            gdata['w'] = gdata['w']/(group_totals*bin_width_x*bin_width_y)

    # start plotting
    g = EZPlot(gdata)
    # determine order and create a categorical type
    if (group is not None) and sort_groups:
        if g.column_is_categorical('x'):
            g.sort_group('x', 'w', ascending=False)
        g.sort_group('group', 'w')
        g.sort_group('facet_x', 'w', ascending=False)
        g.sort_group('facet_y', 'w', ascending=False)
        colors = np.flip(ez_colors(g.n_groups('group')))
    elif (group is not None):
        colors = ez_colors(g.n_groups('group'))

    if y is None:
        # set groups
        if group is None:
            g += p9.geom_bar(p9.aes(x="x", y="w"),
                             stat = 'identity',
                             colour = None,
                             fill = ez_colors(1)[0])
        else:
            g += p9.geom_bar(p9.aes(x="x", y="w",
                                    group="factor(group)",
                                    fill="factor(group)"),
                             colour=None,
                             stat = 'identity',
                             **POSITION_KWARGS[position])
            g += p9.scale_fill_manual(values=colors)

        # set facets
        if facet_x is not None and facet_y is None:
            g += p9.facet_wrap('~facet_x')
        if facet_x is not None and facet_y is not None:
            g += p9.facet_grid('facet_y~facet_x')

        # set x scale
        if g.column_is_categorical('x'):
            g += p9.scale_x_discrete()
        else:
            g += p9.scale_x_continuous(labels=ez_labels)

        # set y scale
        g += p9.scale_y_continuous(labels=ez_labels)

        # set axis labels
        g += \
            p9.xlab(names['x']) + \
            p9.ylab('Counts')

        # set theme
        g += theme_ez(figure_size=figure_size,
                      base_size=base_size,
                      legend_title=p9.element_text(text=names['group'], size=base_size))

        if sort_groups:
            g += p9.guides(fill=p9.guide_legend(reverse=True))

    else:
        # 2-d histogram: tiles coloured by (weighted) count
        g += p9.geom_tile(p9.aes(x="x", y="y", fill='w'),
                          stat = 'identity',
                          colour = None)

        # set facets
        if facet_x is not None and facet_y is None:
            g += p9.facet_wrap('~facet_x')
        if facet_x is not None and facet_y is not None:
            g += p9.facet_grid('facet_y~facet_x')

        # set x scale
        if g.column_is_categorical('x'):
            g += p9.scale_x_discrete()
        else:
            g += p9.scale_x_continuous(labels=ez_labels)

        # set y scale
        if g.column_is_categorical('y'):
            g += p9.scale_y_discrete()
        else:
            g += p9.scale_y_continuous(labels=ez_labels)

        # set axis labels
        g += \
            p9.xlab(names['x']) + \
            p9.ylab(names['y'])

        # set theme
        g += theme_ez(figure_size=figure_size,
                      base_size=base_size,
                      legend_title=p9.element_text(text='Counts', size=base_size))

    return g
예제 #25
0
def density_plot(df,
                 x,
                 group=None,
                 facet_x=None,
                 facet_y=None,
                 position='overlay',
                 sort_groups=True,
                 base_size=10,
                 figure_size=(6, 3),
                 **stat_kwargs):
    '''
    Plot a 1-d density plot

    Parameters
    ----------
    df : pd.DataFrame
      input dataframe
    x : str
      quoted expression to be plotted on the x axis
    group : str
      quoted expression to be used as group (ie color)
    facet_x : str
      quoted expression to be used as facet
    facet_y : str
      quoted expression to be used as facet (only used together with facet_x)
    position : str
      if groups are present, choose between `stack` or `overlay`
    sort_groups : bool
      if True, the group legend entries are listed in reverse order
    base_size : int
      base size for theme_ez
    figure_size :tuple of int
      figure size
    stat_kwargs : kwargs
      kwargs for the density stat

    Returns
    -------
    g : EZPlot
      EZplot object

    Raises
    ------
    NotImplementedError
      if `position` is neither `overlay` nor `stack`
    '''

    if position not in ['overlay', 'stack']:
        log.error("position not recognized")
        raise NotImplementedError("position not recognized")

    # create a copy of the data
    dataframe = df.copy()

    # define groups and variables; remove and store (eventual) names
    names = {}
    groups = {}
    variables = {}

    for label, var in zip(['x', 'group', 'facet_x', 'facet_y'],
                          [x, group, facet_x, facet_y]):
        names[label], groups[label] = unname(var)

    # fix special cases
    if x == '.index':
        groups['x'] = '.index'
        names[
            'x'] = dataframe.index.name if dataframe.index.name is not None else ''

    # aggregate data and reorder columns
    gdata = agg_data(dataframe, variables, groups, None, fill_groups=False)
    gdata = gdata[[
        c for c in ['x', 'group', 'facet_x', 'facet_y'] if c in gdata.columns
    ]]

    # start plotting
    g = EZPlot(gdata)

    # one colour per group
    colors = ez_colors(g.n_groups('group'))

    # set groups
    if group is None:
        g += p9.geom_density(p9.aes(x="x"),
                             stat=p9.stats.stat_density(**stat_kwargs),
                             colour=ez_colors(1)[0],
                             fill=ez_colors(1)[0],
                             **POSITION_KWARGS[position])
    else:
        g += p9.geom_density(p9.aes(x="x",
                                    group="factor(group)",
                                    colour="factor(group)",
                                    fill="factor(group)"),
                             stat=p9.stats.stat_density(**stat_kwargs),
                             **POSITION_KWARGS[position])
        g += p9.scale_fill_manual(values=colors, reverse=False)
        g += p9.scale_color_manual(values=colors, reverse=False)

    # set facets
    if facet_x is not None and facet_y is None:
        g += p9.facet_wrap('~facet_x')
    if facet_x is not None and facet_y is not None:
        g += p9.facet_grid('facet_y~facet_x')

    # set x scale
    if g.column_is_categorical('x'):
        g += p9.scale_x_discrete()
    else:
        g += p9.scale_x_continuous(labels=ez_labels)

    # set y scale
    g += p9.scale_y_continuous(labels=ez_labels)

    # set axis labels
    g += \
        p9.xlab(names['x']) + \
        p9.ylab('Density')

    # set theme
    g += theme_ez(figure_size=figure_size,
                  base_size=base_size,
                  legend_title=p9.element_text(text=names['group'],
                                               size=base_size))

    if sort_groups:
        # fill and colour are mapped to the same variable, so reverse both:
        # plotnine only merges legends whose keys are in the same order, and
        # reversing fill alone would draw two separate legends.
        g += p9.guides(fill=p9.guide_legend(reverse=True),
                       colour=p9.guide_legend(reverse=True))

    return g
예제 #26
0
            half_life_ci_l=lambda x: pd.to_timedelta(x.half_life_ci_l, "D"),
            half_life_ci_u=lambda x: pd.to_timedelta(x.half_life_ci_u, "D"),
        ),
        p9.aes(
            x="category",
            y="half_life_time",
            ymin="half_life_ci_l",
            ymax="half_life_ci_u",
        ),
    )
    + p9.geom_col(fill="#1f78b4")
    + p9.geom_errorbar()
    + p9.scale_x_discrete(
        limits=(
            category_half_life.query("category!='none'")
            .sort_values("half_life_time")
            .category.tolist()[::-1]
        ),
    )
    + p9.scale_y_timedelta(labels=timedelta_format("d"))
    + p9.coord_flip()
    + p9.labs(
        x="Preprint Categories",
        y="Time Until 50% of Preprints are Published",
        title="Preprint Category Half-Life",
    )
    + p9.theme_seaborn(context="paper", style="white", font_scale=1, font="Arial")
    + p9.theme(axis_ticks_minor_x=p9.element_blank(), text=p9.element_text(size=12))
)
g.save("output/preprint_category_halflife.svg")
g.save("output/preprint_category_halflife.png", dpi=600)
예제 #27
0
metadata_df["author_type"].value_counts()

# # BioRxiv Research Article Categories

# Number of research articles assigned to each category. As expected,
# neuroscience dominates the articles.

# In[9]:

# Categories ordered from least to most frequent. scale_x_discrete in
# plotnine has no `reverse` option, so the frequency-sorted index produced
# by value_counts() is reversed by hand.
category_list = list(reversed(metadata_df["category"].value_counts().index.tolist()))

g = (
    p9.ggplot(metadata_df, p9.aes(x="category"))
    + p9.geom_bar(size=10, fill="#253494", position=p9.position_dodge(width=3))
    + p9.scale_x_discrete(limits=category_list)
    + p9.coord_flip()
    + p9.theme_seaborn(context="paper", style="ticks", font="Arial", font_scale=1)
)
g.save("output/figures/preprint_category.png", dpi=500)
print(g)

# In[10]:

metadata_df["category"].value_counts()

# # New, Confirmatory, Contradictory Results?

# In[11]:

# Headings ordered from least to most frequent (same manual reversal as above).
heading_list = list(reversed(metadata_df["heading"].value_counts().index.tolist()))
예제 #28
0
def marginal_plot(df,
                  x,
                  y,
                  group = None,
                  facet_x = None,
                  facet_y = None,
                  aggfun = 'sum',
                  bins=21,
                  use_quantiles = False,
                  label_pos='auto',
                  label_function=ez_labels,
                  sort_groups=True,
                  base_size=10,
                  figure_size=(6, 3)):

    '''
    Bin the data in a df and plot it using lines.

    `x` is binned (optionally by quantiles, separately within each
    group/facet), `y` is aggregated with `aggfun` inside every bin, and the
    result is drawn as one line per group.

    Parameters
    ----------
    df : pd.DataFrame
      input dataframe
    x : str
      quoted expression to be plotted on the x axis
    y : str
      quoted expression to be plotted on the y axis
    group : str
      quoted expression to be used as group (ie color)
    facet_x : str
      quoted expression to be used as facet
    facet_y : str
      quoted expression to be used as facet (only used together with facet_x)
    aggfun : str or fun
      function to be used for aggregating (eg sum, mean, median ...)
    bins : int or tuple
      number of bins to be used
    use_quantiles : bool
      bin data using quantiles
    label_pos : str
      Use count label on each point. Choose between None, 'auto' or 'force'.
      The label is the number of input rows falling in the bin. 'auto' shows
      labels only when `group` is None and `bins` is a single number <= 21.
    label_function : callable
      labelling function
    sort_groups : bool
      sort groups by the sum of their value (otherwise alphabetical order is used)
    base_size : int
      base size for theme_ez
    figure_size :tuple of int
      figure size

    Returns
    -------
    g : EZPlot
      EZplot object

    Raises
    ------
    NotImplementedError
      if `label_pos` is not one of None, 'auto' or 'force'
    '''

    if label_pos not in [None, 'auto', 'force']:
        log.error("label_pos not recognized")
        raise NotImplementedError("label_pos not recognized")
    elif label_pos == 'auto':
        # `bins` may be a tuple: only a scalar bin count can be compared with
        # the 21-bin threshold (tuple <= int raises TypeError).
        show_labels = bool(group is None and np.ndim(bins) == 0 and bins <= 21)
    else:
        show_labels = label_pos == 'force'

    # create a copy of the data
    dataframe = df.copy()

    # define groups and variables; remove and store (eventual) names
    names = {}
    groups = {}
    variables = {}

    for label, var in zip(['x', 'group', 'facet_x', 'facet_y'], [x,  group, facet_x, facet_y]):
        names[label], groups[label] = unname(var)
    names['y'], variables['y'] = unname(y)

    # set column names and evaluate expressions
    tmp_df = agg_data(dataframe, variables, groups, None, fill_groups=False)

    # redefine groups and variables; remove and store (eventual) names
    new_groups = {c:c for c in tmp_df.columns if c in ['x', 'group', 'facet_x', 'facet_y']}
    new_variables = {'y': 'y'}

    # bin data
    if use_quantiles:
        quantile_groups = [c for c in tmp_df.columns if c in ['group', 'facet_x', 'facet_y']]
        if quantile_groups:
            # transform (unlike apply) returns a result aligned with tmp_df's
            # own index, so it can be assigned straight back to the column
            tmp_df['x'] = tmp_df.groupby(quantile_groups)['x'].transform(
                lambda s: qbin_data(s, bins))
        else:
            tmp_df['x'] = qbin_data(tmp_df['x'], bins)
    else:
        tmp_df['x'], _, _ = bin_data(tmp_df['x'], bins, None)

    # aggregate data and reorder columns
    gdata = agg_data(tmp_df, new_variables, new_groups, aggfun, fill_groups=False)

    # reorder columns
    gdata = gdata[[c for c in ['x', 'y', 'group', 'facet_x', 'facet_y'] if c in gdata.columns]]

    # init plot obj
    g = EZPlot(gdata)

    # determine order and create a categorical type
    if sort_groups:
        sort_data_groups(g)

    # get colors
    colors = np.flip(ez_colors(g.n_groups('group')))

    # set groups
    if group is None:
        g += p9.geom_line(p9.aes(x="x", y="y"), group=1, colour=colors[0])
        if show_labels:
            g += p9.geom_point(p9.aes(x="x", y="y"), group=1, colour=colors[0])
    else:
        g += p9.geom_line(p9.aes(x="x", y="y", group="factor(group)", colour="factor(group)"))
        if show_labels:
            g += p9.geom_point(p9.aes(x="x", y="y", colour="factor(group)"))
        g += p9.scale_color_manual(values=colors)

    # set labels
    if show_labels:
        # count the input rows that fall in each (bin, group, facet) cell
        groups_to_count = [c for c in tmp_df.columns if c in ['x', 'group', 'facet_x', 'facet_y']]
        tmp_df['counts'] = 1
        top_labels = tmp_df \
            .groupby(groups_to_count)['counts'] \
            .sum() \
            .reset_index()
        top_labels['label'] = label_function(top_labels['counts'])

        # make sure labels and data can be joined: give the key columns the
        # same categorical dtype as the plot data
        for c in ['group', 'facet_x', 'facet_y']:
            if c in tmp_df.columns:
                try:
                    top_labels[c] = pd.Categorical(top_labels[c].astype(str),
                                                   categories=g.data[c].cat.categories,
                                                   ordered=g.data[c].cat.ordered)
                except (AttributeError, KeyError, TypeError, ValueError):
                    # best effort: keep the column unchanged when the plot
                    # data column is not categorical or its categories
                    # cannot be applied
                    pass
        g.data = pd.merge(g.data, top_labels, on=groups_to_count, how='left')
        # push each label away from zero by 2% of the largest |y|
        g.data['label_pos'] = g.data['y'] + \
                    np.sign(g.data['y'])*g.data['y'].abs().max()*0.02

        g += p9.geom_text(p9.aes(x='x', y='label_pos', label='label'),
                          color="#000000",
                          size=base_size * 0.7,
                          ha='center',
                          va='bottom')
    # set facets
    if facet_x is not None and facet_y is None:
        g += p9.facet_wrap('~facet_x')
    if facet_x is not None and facet_y is not None:
        g += p9.facet_grid('facet_y~facet_x')

    # set x scale
    if g.column_is_timestamp('x'):
        g += p9.scale_x_datetime()
    elif g.column_is_categorical('x'):
        g += p9.scale_x_discrete()
    else:
        g += p9.scale_x_continuous(labels=ez_labels)

    # set y scale
    g += p9.scale_y_continuous(labels=ez_labels)

    # set axis labels
    g += \
        p9.xlab(names['x']) + \
        p9.ylab(names['y'])

    # set theme
    g += theme_ez(figure_size=figure_size,
                  base_size=base_size,
                  legend_title=p9.element_text(text=names['group'], size=base_size))
    return g
예제 #29
0
print("Best CV Fold")
print(model.scores_["polka"][:, best_result[0]])
model.scores_["polka"][:, best_result[0]].mean()

# One weight per principal component. Derive the component count from the
# fitted coefficients instead of hard-coding 50, so the frame columns and the
# axis limits always agree with the model.
n_components = len(model.coef_[0])
model_weights_df = pd.DataFrame.from_dict({
    "weight": model.coef_[0],
    "pc": list(range(1, n_components + 1)),
})
model_weights_df["pc"] = pd.Categorical(model_weights_df["pc"])
model_weights_df.head()

# axis limits are listed in descending component order
g = (p9.ggplot(model_weights_df, p9.aes(x="pc", y="weight")) +
     p9.geom_col(position=p9.position_dodge(width=5), fill="#253494") +
     p9.coord_flip() +
     p9.scale_x_discrete(limits=list(range(n_components, 0, -1))) +
     p9.theme_seaborn(
         context="paper", style="ticks", font_scale=1.1, font="Arial") +
     p9.theme(figure_size=(10, 8)) + p9.labs(title="Regression Model Weights",
                                             x="Principal Component",
                                             y="Model Weight"))
# g.save("output/figures/pca_log_regression_weights.svg")
# g.save("output/figures/pca_log_regression_weights.png", dpi=250)
print(g)

fold_features = model.coefs_paths_["polka"].transpose(1, 0, 2)
model_performance_df = pd.DataFrame.from_dict({
    "feat_num": ((fold_features.astype(bool).sum(axis=1)) > 0).sum(axis=1),
    "C":
    model.Cs_,
    "score":
category_half_life

# In[14]:

# Categories other than 'none', used both for the plotted data and for the
# axis ordering (longest half-life listed first in the limits).
half_life_df = category_half_life.query("category!='none'")

g = (
    p9.ggplot(
        half_life_df.assign(
            half_life_time=pd.to_timedelta(half_life_df.half_life_time, "D"),
            half_life_ci_l=pd.to_timedelta(half_life_df.half_life_ci_l, "D"),
            half_life_ci_u=pd.to_timedelta(half_life_df.half_life_ci_u, "D"),
        ),
        p9.aes(
            x="category",
            y="half_life_time",
            ymin="half_life_ci_l",
            ymax="half_life_ci_u",
        ),
    )
    + p9.geom_col(fill="#1f78b4")
    + p9.geom_errorbar()
    + p9.scale_x_discrete(
        limits=list(reversed(
            half_life_df.sort_values("half_life_time").category.tolist()
        )),
    )
    + p9.scale_y_timedelta(labels=timedelta_format("d"))
    + p9.coord_flip()
    + p9.labs(
        x="Preprint Categories",
        y="Time Until 50% of Preprints are Published",
        title="Preprint Category Half-Life",
    )
    + p9.theme_seaborn(context="paper", style="white", font_scale=1.2)
    + p9.theme(axis_ticks_minor_x=p9.element_blank())
)
g.save("output/preprint_category_halflife.svg", dpi=250)
g.save("output/preprint_category_halflife.png", dpi=250)
print(g)

# Take-home results:
#     1. The average amount of time for half of all preprints to be published is 348 days (~1 year)
#     2. Biophysics and biochemistry are two categories that take the least time to have half their preprints published