import plotnine as p9


def plot_metrics_comparison_lineplot_grid(dataframe,
                                          models_labels,
                                          metrics_labels,
                                          figure_size=(14, 4)):
    """
    We define a function to plot the grid.
    """

    return (
        # Define the plot.
        p9.ggplot(
            dataframe,
            p9.aes(x='threshold',
                   y='value',
                   group='variable',
                   color='variable',
                   shape='variable'))
        # Add the points and lines.
        + p9.geom_point() + p9.geom_line()
        # Rename the x axis and give some space to left and right.
        + p9.scale_x_discrete(name='Threshold', expand=(0, 0.2))
        # Rename the y axis, give some space on top and bottom, and print
        # the tick labels with 2 decimal digits.
        + p9.scale_y_continuous(name='Value',
                                expand=(0, 0.05),
                                labels=lambda l: ['{:.2f}'.format(x) for x in l])
        # Replace the names in the legend.
        + p9.scale_shape_discrete(
            name='Metric', labels=lambda l: [metrics_labels[x] for x in l])
        # Use a color-blind-friendly palette for the metrics.
        + p9.scale_color_brewer(name='Metric',
                                labels=lambda l: [metrics_labels[x] for x in l],
                                type='qual',
                                palette='Set2')
        # Place the plots in a grid, renaming the labels for rows and columns.
        + p9.facet_grid('iterations ~ model',
                        labeller=p9.labeller(
                            rows=lambda x: f'iters = {x}',
                            cols=lambda x: f'{models_labels[x]}'))
        # Define the theme for the plot.
        + p9.theme(
            # Remove the y axis name.
            axis_title_y=p9.element_blank(),
            # Set the size of x and y tick labels font.
            axis_text_x=p9.element_text(size=7),
            axis_text_y=p9.element_text(size=7),
            # Place the legend on top, without title, and reduce the margin.
            legend_title=p9.element_blank(),
            legend_position='top',
            legend_box_margin=2,
            # Set the size for the figure.
            figure_size=figure_size,
        ))
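
A minimal sketch of how this function might be called, assuming a long-format
dataframe whose columns match the aesthetics and facets above ('threshold',
'value', 'variable', 'iterations', 'model'); all values below are illustrative.

import pandas as pd

# Illustrative long-format input; the column names are the ones the aes()
# and facet_grid() calls above expect.
df = pd.DataFrame({
    'threshold': ['0.1', '0.1', '0.5', '0.5'],
    'value': [0.62, 0.70, 0.81, 0.78],
    'variable': ['precision', 'recall', 'precision', 'recall'],
    'iterations': [10, 10, 10, 10],
    'model': ['svm', 'svm', 'svm', 'svm'],
})

plot = plot_metrics_comparison_lineplot_grid(
    df,
    models_labels={'svm': 'SVM'},
    metrics_labels={'precision': 'Precision', 'recall': 'Recall'})
plot.save('metrics_grid.png', dpi=300)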
Example 2
import plotnine as plt9


def estimate_cutoffs_plot(output_file,
                          df_plt,
                          df_cell_estimate_cutoff,
                          df_fit=None,
                          scale_x_log10=False,
                          save_plot=True):
    """Plot UMI counts by sorted cell barcodes."""
    if min(df_plt['umi_counts']) <= 0:
        # Shift the counts so the minimum is 1; log10 (used on the y axis
        # below) is undefined for values <= 0.
        fix_log_scale = 1 - min(df_plt['umi_counts'])
        df_plt['umi_counts'] = df_plt['umi_counts'] + fix_log_scale
    gplt = plt9.ggplot()
    gplt = gplt + plt9.theme_bw()
    if len(df_plt) <= 50000:
        gplt = gplt + plt9.geom_point(mapping=plt9.aes(x='barcode',
                                                       y='umi_counts'),
                                      data=df_plt,
                                      alpha=0.05,
                                      size=0.1)
    else:
        gplt = gplt + plt9.geom_line(mapping=plt9.aes(x='barcode',
                                                      y='umi_counts'),
                                     data=df_plt,
                                     alpha=0.25,
                                     size=0.75,
                                     color='black')
    gplt = gplt + plt9.geom_vline(mapping=plt9.aes(xintercept='n_cells',
                                                   color='method'),
                                  data=df_cell_estimate_cutoff,
                                  alpha=0.75,
                                  linetype='dashdot')
    gplt = gplt + plt9.scale_color_brewer(palette='Dark2', type='qual')
    if scale_x_log10:
        gplt = gplt + plt9.scale_x_continuous(
            trans='log10', labels=comma_labels, minor_breaks=0)
    else:
        gplt = gplt + plt9.scale_x_continuous(labels=comma_labels,
                                              minor_breaks=0)
    gplt = gplt + plt9.scale_y_continuous(
        trans='log10', labels=comma_labels, minor_breaks=0)
    gplt = gplt + plt9.labs(title='',
                            y='UMI counts',
                            x='Barcode index, sorted by UMI count',
                            color='Cutoff')
    # Add the fit of the DropletUtils model. Compare against None explicitly:
    # `if df_fit:` raises on a DataFrame (ambiguous truth value).
    if df_fit is not None:
        gplt = gplt + plt9.geom_line(mapping=plt9.aes(x='x', y='y'),
                                     data=df_fit,
                                     alpha=1,
                                     color='yellow')
    if save_plot:
        gplt.save('{}.png'.format(output_file), dpi=300, width=5, height=4)
    return gplt
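
The snippet references a comma_labels formatter that is defined elsewhere in
the source module; a minimal version, assuming it renders axis breaks with
thousands separators, could look like this:

def comma_labels(breaks):
    # Format each axis break with a thousands separator, e.g. 10000 -> '10,000'.
    return ['{:,.0f}'.format(b) for b in breaks]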
Example 3
import os

import pandas as pd
from pandas.api.types import CategoricalDtype
from plotnine import (aes, geom_line, geom_point, ggplot, ggtitle,
                      scale_color_brewer, scale_shape_discrete, theme, xlab,
                      ylab)


def baseline_error():
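    # `jsonFile` is assumed to be a dict of lists loaded earlier in the
    # source script, with keys 'x', 'y' and 'method'.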
    al_baselines_len = len(jsonFile['x'])
    df = pd.DataFrame(jsonFile)

    method_type = CategoricalDtype(
        categories=['Baseline', 'PGD', 'Provable', 'Mix'], ordered=True)
    df['method'] = df['method'].astype(method_type)

    p = (
        ggplot(df) + aes(x='x', y='y', shape='method', color='method') +
        geom_point(size=4, stroke=0) + geom_line(size=1) +
        scale_shape_discrete(name='Method') +
        scale_color_brewer(type='qual', palette=2, name='Method') +
        xlab('Training Time (seconds)') + ylab('Adversarial Error') + theme(
            # Manual legend positioning, no longer used:
            # legend_position=(0.70, 0.65),
            # legend_background=element_rect(fill=(0, 0, 0, 0)),
            aspect_ratio=0.8) + ggtitle("Baseline Error"))
    fig_dir = '.'
    p.save(os.path.join(fig_dir, 'baselineErr.pdf'))
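
For illustration, the input `baseline_error` expects could be built like this
(the values are made up):

jsonFile = {
    'x': [10, 20, 10, 20],
    'y': [0.30, 0.25, 0.20, 0.15],
    'method': ['Baseline', 'Baseline', 'PGD', 'PGD'],
}
baseline_error()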
Example 4
df = pd.DataFrame([], columns=["method", "x", "y"])
df = df.append(
    pd.DataFrame(np.transpose([
        np.repeat("Provable", robust_arr2.shape[0]), robust_arr2[:, 0],
        robust_arr2[:, 1]
    ]),
                 columns=["method", "x", "y"]))
df = df.append(
    pd.DataFrame(np.transpose(
        [np.repeat("Mix", mix_arr2.shape[0]), mix_arr2[:, 0], mix_arr2[:, 1]]),
                 columns=["method", "x", "y"]))
df['x'] = pd.to_numeric(df['x'])
df['y'] = pd.to_numeric(df['y'])
p = (
    ggplot(df) + aes(x='x', y='y', shape='method', color='method') +
    # + geom_point(size=4, stroke=0)
    geom_line(size=1) + scale_shape_discrete(name='Method') +
    scale_color_brewer(type='qual', palette=2, name='Method') +
    xlab('Training Time (seconds)') + ylab('Error') +
    theme(aspect_ratio=0.8, ) + ggtitle("Baseline Error"))
p.save(test + "/baselineErr.png", verbose=False)

df = pd.DataFrame([], columns=["method", "x", "y"])
df = df.append(
    pd.DataFrame(np.transpose([
        np.repeat("Baseline", baseline_arr2.shape[0]), baseline_arr2[:, 0],
        baseline_arr2[:, 2]
    ]),
                 columns=["method", "x", "y"]))
df = df.append(
    pd.DataFrame(np.transpose([
        np.repeat("Madry", madry_arr2.shape[0]), madry_arr2[:, 0],
        madry_arr2[:, 2]
    ]),
                 columns=["method", "x", "y"]))
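
Note that DataFrame.append was removed in pandas 2.0; on current pandas the
same accumulation pattern can be written with pd.concat (assuming numpy and
pandas are imported as np and pd, as above), for example:

df = pd.concat([
    df,
    pd.DataFrame({
        'method': np.repeat('Madry', madry_arr2.shape[0]),
        'x': madry_arr2[:, 0],
        'y': madry_arr2[:, 2],
    }),
], ignore_index=True)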
Example 5
import argparse
import csv
import os

import harmonypy as hm
import pandas as pd
import plotnine as plt9

# __version__ is assumed to be defined at module level in the source script.


def main():
    """Run CLI."""
    parser = argparse.ArgumentParser(description="""
            Calculate and compare LISI across a series of reduced dims and
            categorical variables.
            """)

    parser.add_argument(
        '-v',
        '--version',
        action='version',
        version='%(prog)s {version}'.format(version=__version__))

    # parser.add_argument(
    #     '-h5', '--h5_anndata',
    #     action='store',
    #     dest='h5',
    #     required=True,
    #     help='H5 AnnData file.'
    # )

    parser.add_argument(
        '-rf',
        '--reduced_dims_tsv',
        action='store',
        dest='reduced_dims',
        required=True,
        help='List of tab-delimited files of reduced dimensions (e.g., PCs)\
            for each cell. First column is cell_barcode. List should be\
            split by "::" (e.g. file1.tsv.gz::file2.tsv.gz).')

    parser.add_argument(
        '-lbl',
        '--reduced_dims_tsv_labels',
        action='store',
        dest='reduced_dims_labels',
        required=True,
        help='String of labels for each reduced_dims_tsv file. List should be\
            split by "::".')

    parser.add_argument(
        '-mf',
        '--metadata_tsv',
        action='store',
        dest='metadata_tsv',
        required=True,
        help='Tab-delimited file of metadata for each cell. First column\
            is cell_barcode.')

    parser.add_argument(
        '-mv',
        '--metadata_columns',
        action='store',
        dest='metadata_columns',
        default='experiment_id',
        help='Comma separated string of categorical variables to calculate\
            LISI with.\
            (default: %(default)s)')

    parser.add_argument('-p',
                        '--perplexity',
                        action='store',
                        dest='perplexity',
                        default=30.0,
                        type=float,
                        help='Perplexity.\
            (default: %(default)s)')

    parser.add_argument(
        '-of',
        '--output_file',
        action='store',
        dest='of',
        default='',
        help='Basename of output files, assuming output in current working \
            directory.\
            (default: <metadata_tsv>-lisi)')

    options = parser.parse_args()

    # Fixed settings.
    # verbose = True

    # Get the out file base.
    out_file_base = options.of
    if out_file_base == '':
        base = os.path.basename(options.metadata_tsv)
        # Strip the file suffix. Note str.rstrip removes a *set* of
        # characters rather than a suffix, so it is not safe for this.
        for suffix in ('.tsv.gz', '.tsv'):
            if base.endswith(suffix):
                base = base[:-len(suffix)]
                break
        out_file_base = '{}-lisi'.format(base)

    # Get the columns to use
    lisi_columns = options.metadata_columns.split(',')
    # lisi_columns = ['experiment_id', 'batch']
    lisi_columns_dtype = dict(
        zip(lisi_columns, ['category'] * len(lisi_columns)))

    # Load the metadata file
    file_meta = options.metadata_tsv
    df_meta = pd.read_csv(file_meta,
                          sep='\t',
                          index_col='cell_barcode',
                          dtype=lisi_columns_dtype)

    # Load the reduced dims.
    files = options.reduced_dims.split('::')
    labels = options.reduced_dims_labels.split('::')
    assert len(files) == len(labels), 'ERROR: check files and labels input'

    # Make a dict of theoretical maximum LISI value for each label.
    lisi_limit = {}
    for col in lisi_columns:
        n_cat = len(df_meta[col].cat.categories)
        lisi_limit[col] = n_cat

    list_lisi = []
    for i in range(len(files)):
        df_reduced_dims = pd.read_csv(files[i],
                                      sep='\t',
                                      index_col='cell_barcode')

        # Run LISI and save the results to a dataframe, passing through the
        # perplexity requested on the command line.
        _df_lisi = pd.DataFrame(hm.compute_lisi(
            df_reduced_dims.loc[df_meta.index, :], df_meta[lisi_columns],
            lisi_columns, perplexity=options.perplexity),
                                columns=lisi_columns)
        _df_lisi['file'] = files[i]
        _df_lisi['label'] = labels[i]
        _df_lisi['cell_barcode'] = df_meta.index
        list_lisi.append(_df_lisi)

    # Make one long dataframe.
    df_lisi = pd.concat(list_lisi)
    # Make cell_barcode the first column.
    cols = list(df_lisi.columns)
    cols = [cols[-1]] + cols[:-1]

    # Save the results
    df_lisi[cols].to_csv('{}.tsv.gz'.format(out_file_base),
                         sep='\t',
                         index=False,
                         quoting=csv.QUOTE_NONNUMERIC,
                         na_rep='',
                         compression='gzip')

    # Compare the lisi distributions
    n_labels = len(labels)
    for lisi_column in lisi_columns:
        # Make density plot.
        gplt = plt9.ggplot(df_lisi,
                           plt9.aes(
                               fill='label',
                               x='label',
                               y=lisi_column,
                           ))
        gplt = gplt + plt9.theme_bw(base_size=12)
        gplt = gplt + plt9.geom_violin(alpha=0.9)
        gplt = gplt + plt9.geom_boxplot(
            group='label',
            position=plt9.position_dodge(width=.9),
            width=.1,
            fill='white',
            outlier_alpha=0  # Do not know how to totally remove outliers.
        )
        # Add a line at the theoretical maximum
        gplt = gplt + plt9.geom_hline(
            plt9.aes(yintercept=lisi_limit[lisi_column]))
        # gplt = gplt + plt9.facet_grid('{} ~ .'.format(label))
        gplt = gplt + plt9.labs(x='Reduced dimensions', y='LISI', title='')
        gplt = gplt + plt9.theme(
            axis_text_x=plt9.element_text(angle=-45, hjust=0))
        gplt = gplt + plt9.theme(legend_position='none')
        if n_labels != 0 and n_labels < 9:
            gplt = gplt + plt9.scale_fill_brewer(palette='Dark2', type='qual')
        gplt.save(
            '{}-{}-violin.png'.format(out_file_base, lisi_column),
            dpi=300,
            width=4 * (n_labels / 4),
            height=10,
            # height=4*(n_samples/4),
            limitsize=False)

        # Make ecdf.
        gplt = plt9.ggplot(df_lisi, plt9.aes(
            x=lisi_column,
            color='label',
        ))
        gplt = gplt + plt9.theme_bw(base_size=12)
        gplt = gplt + plt9.stat_ecdf(alpha=0.8)
        gplt = gplt + plt9.labs(
            x='LISI',
            y='Cumulative density',
            # color='Reduction',
            title='')
        if n_labels != 0 and n_labels < 9:
            gplt = gplt + plt9.scale_color_brewer(palette='Dark2', type='qual')
        gplt.save('{}-{}-ecdf.pdf'.format(out_file_base, lisi_column),
                  dpi=300,
                  width=10,
                  height=4,
                  limitsize=False)
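
Assuming the script is saved as compare_lisi.py (a hypothetical name; the
input file names below are likewise illustrative), a typical invocation using
the flags defined above might look like:

python compare_lisi.py \
    --reduced_dims_tsv pca.tsv.gz::harmony.tsv.gz \
    --reduced_dims_tsv_labels PCA::Harmony \
    --metadata_tsv metadata.tsv.gz \
    --metadata_columns experiment_id,batch \
    --output_file my_data-lisi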
Example 6
 from a given set of parameters, and for making comparison plots.

 The add_agent method supports both representative agents and mixes of heterogeneous
 agents. For representative agents, a name, an employment history, the point in
 the employment history to begin computing the consumption history, the point to
 begin computing the search history, and the parameter dictionary must be specified.
 For heterogeneous agents, a list of (weight, dictionary) tuples must be specified
 instead, with weights summing to 1.
"""

import plotnine as p9
param_path = "../Parameters/params_ui.json"
execfile("prelim.py")

###Plot aesthetics###
aes_color = p9.scale_color_brewer(type='qual', palette=2)
aes_color_alpha = 0.7
aes_glyphs = p9.scale_shape_manual(
    values=['o', '^', 's', 'D', 'v', 'x', 'P', '+'])
aes_fte_theme = p9.theme(
    axis_ticks=p9.element_blank(),
    panel_background=p9.element_rect(fill='white', color='white'),
    panel_border=p9.element_rect(color='white'),
    panel_grid_minor=p9.element_blank(),
    panel_grid_major=p9.element_blank(),
    legend_background=p9.element_rect(fill='white'),
    legend_text=p9.element_text(size=8),
    legend_key=p9.element_blank(),
    legend_title=p9.element_blank(),
    plot_title=p9.element_text(size=12, vjust=1.25, ha='center'),
    axis_text_x=p9.element_text(size=10),
Example 7
def make_single_likert_chart(survey_data, column, facet, labels, five_is_high=False):
    """Make an offset stacked barchart showing the number of respondents at each rank
        or value for a single column in the original data. Each facet is shown as
        a tick on the x-axis.

    Args:
        survey_data (pandas.DataFrame): Raw data read in from Kubernetes Survey
        column (str): Column of interest in the original data
        facet (str): Column used for grouping
        labels (list): List of strings to use as labels, corresponding
             to the numerical values given by the respondents.
        five_is_high (bool, optional): Defaults to False. If True,
            5 is considered the highest value in a ranking, otherwise
            it is taken as the lowest value.

    Returns:
        (plotnine.ggplot): Offset stacked barchart plot object which
            can be displayed in a notebook or saved out to a file
    """
    mid_point = 3
    cols = [column, facet]
    show_legend = True
    topic_data = survey_data[cols]

    topic_data_long = make_long(topic_data, facet)

    if not five_is_high:
        topic_data_long = topic_data_long.assign(rating=topic_data_long.rating * -1.0)
    x = topic_data_long.columns.tolist()
    x.remove("level_1")
    x.remove("level_0")

    if not five_is_high:
        mid_point *= -1

    top_cutoff = topic_data_long["rating"] >= mid_point
    bottom_cutoff = topic_data_long["rating"] <= mid_point

    top_scores = (
        topic_data_long[top_cutoff]
        .groupby(x)
        .count()
        .reset_index()
        .sort_index(ascending=False)
    )

    top_scores.loc[top_scores["rating"] == mid_point, "level_1"] = (
        top_scores[top_scores["rating"] == mid_point]["level_1"] / 2.0
    )
    top_scores = top_scores.merge(
        topic_data_long.groupby(facet).count().reset_index(), on=facet
    )
    top_scores = top_scores.assign(level_1=top_scores.level_1_x / top_scores.level_1_y)

    bottom_scores = topic_data_long[bottom_cutoff].groupby(x).count().reset_index()
    bottom_scores.loc[bottom_scores["rating"] == mid_point, "level_1"] = (
        bottom_scores[bottom_scores["rating"] == mid_point]["level_1"] / 2.0
    )
    bottom_scores = bottom_scores.merge(
        topic_data_long.groupby(facet).count().reset_index(), on=facet
    )
    bottom_scores = bottom_scores.assign(
        level_1=bottom_scores.level_1_x * -1 / bottom_scores.level_1_y
    )

    vp = (
        p9.ggplot(
            topic_data_long,
            p9.aes(x=facet, fill="factor(rating_x)", color="factor(rating_x)"),
        )
        + p9.geom_col(
            data=top_scores,
            mapping=p9.aes(y="level_1"),
            show_legend=show_legend,
            size=0.25,
            position=p9.position_stack(reverse=True),
        )
        + p9.geom_col(
            data=bottom_scores,
            mapping=p9.aes(y="level_1"),
            show_legend=show_legend,
            size=0.25,
        )
        + p9.geom_hline(yintercept=0, color="white")
        + p9.theme(
            axis_text_x=p9.element_text(angle=45, ha="right"),
            strip_text_y=p9.element_text(angle=0, ha="left"),
        )
        + p9.scale_x_discrete(
            limits=topic_data_long[facet].unique().tolist(),
            labels=[
                x.replace("_", " ") for x in topic_data_long[facet].unique().tolist()
            ],
        )
    )

    if five_is_high:
        vp = (
            vp
            + p9.scale_color_brewer(
                "div",
                "RdBu",
                limits=[1, 2, 3, 4, 5],
                labels=["\n".join(wrap(x, 15)) for x in labels],
            )
            + p9.scale_fill_brewer(
                "div",
                "RdBu",
                limits=[1, 2, 3, 4, 5],
                labels=["\n".join(wrap(x, 15)) for x in labels],
            )
        )
    else:
        vp = (
            vp
            + reverse_scale_fill_brewer(
                "div",
                "RdBu",
                limits=[-1, -2, -3, -4, -5],
                labels=["\n".join(wrap(x, 15)) for x in labels],
            )
            + reverse_scale_color_brewer(
                "div",
                "RdBu",
                limits=[-1, -2, -3, -4, -5],
                labels=["\n".join(wrap(x, 15)) for x in labels],
            )
        )

    return vp
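
A sketch of a call, assuming a survey frame where 'satisfaction' holds 1-5
ratings and 'role' is the grouping column (both names are illustrative), and
that the make_long and reverse_scale_*_brewer helpers this function depends
on are importable from the same module:

import pandas as pd

survey = pd.DataFrame({
    'satisfaction': [1, 3, 5, 4, 2, 5],
    'role': ['dev', 'dev', 'ops', 'ops', 'dev', 'ops'],
})
chart = make_single_likert_chart(
    survey,
    column='satisfaction',
    facet='role',
    labels=['Very low', 'Low', 'Neutral', 'High', 'Very high'],
    five_is_high=True)
chart.save('satisfaction_likert.png', dpi=300)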
Example 8
def make_likert_chart(
    survey_data,
    topic,
    labels,
    facet_by=[],
    max_value=5,
    max_is_high=False,
    wrap_facets=True,
    sort_x=False,
):
    """Make an offset stacked barchart showing the number of respondents at each rank or value for 
        all columns in the topic. Each column in the original data is a tick on the x-axis

    Args:
        survey_data (pandas.DataFrame): Raw data read in from Kubernetes Survey   
        topic (str): String that all questions of interest start with
        labels (list): List of strings to use as labels, corresponding
             to the numerical values given by the respondents.
        facet_by (list, optional): List of columns used for grouping
        max_value (int, optional): Defaults to 5. The maximum value a respondent can assign.
        max_is_high (bool, optional): Defaults to False. If True,
            max_value is considered the highest value in a ranking, otherwise
            it is taken as the lowest value.
        wrap_facets (bool, optional): Defaults to True. If True, the facet labels are
            wrapped.
        sort_x (bool, optional): Defaults to False. If True, the x-axis is sorted by the
            mean value for each column in the original data.

    Returns:
        (plotnine.ggplot): Offset stacked barchart plot object which 
            can be displayed in a notebook or saved out to a file
    """

    mid_point = math.ceil(max_value / 2)

    og_cols = [x for x in survey_data.columns if x.startswith(topic)]
    show_legend = True

    topic_data_long = get_single_year_data_subset(survey_data, topic, facet_by)

    if not max_is_high:
        topic_data_long = topic_data_long.assign(rating=topic_data_long.rating * -1.0)

        mid_point = -1 * mid_point

    top_scores, bottom_scores = split_for_likert(topic_data_long, mid_point)

    if facet_by:
        fix = False
        if "." in facet_by:
            facet_by.remove(".")
            fix = True

        top_scores = top_scores.merge(
            topic_data_long.groupby(facet_by).count().reset_index(), on=facet_by
        ).rename(columns={"rating_x": "rating", "level_0_x": "level_0"})
        top_scores = top_scores.assign(
            level_1=top_scores.level_1_x / (top_scores.level_1_y / len(og_cols))
        )

        bottom_scores = bottom_scores.merge(
            topic_data_long.groupby(facet_by).count().reset_index(), on=facet_by
        ).rename(columns={"rating_x": "rating", "level_0_x": "level_0"})
        bottom_scores = bottom_scores.assign(
            level_1=bottom_scores.level_1_x
            * -1
            / (bottom_scores.level_1_y / len(og_cols))
        )

        if fix:
            facet_by.append(".")

    else:
        bottom_scores = bottom_scores.assign(level_1=bottom_scores.level_1 * -1)

    if sort_x:
        x_sort_order = (
            topic_data_long.groupby("level_0")
            .mean()
            .sort_values("rating")
            .reset_index()["level_0"]
            .values.tolist()
        )
        x_sort_order.reverse()
    else:
        x_sort_order = topic_data_long["level_0"].unique().tolist()

    vp = (
        p9.ggplot(
            topic_data_long,
            p9.aes(x="level_0", fill="factor(rating)", color="factor(rating)"),
        )
        + p9.geom_col(
            data=top_scores,
            mapping=p9.aes(y="level_1"),
            show_legend=show_legend,
            size=0.25,
            position=p9.position_stack(reverse=True),
        )
        + p9.geom_col(
            data=bottom_scores,
            mapping=p9.aes(y="level_1"),
            show_legend=show_legend,
            size=0.25,
            position=p9.position_stack(),
        )
        + p9.geom_hline(yintercept=0, color="white")
        + p9.theme(
            axis_text_x=p9.element_text(angle=45, ha="right"),
            strip_text_y=p9.element_text(angle=0, ha="left"),
        )
        + p9.scale_x_discrete(
            limits=x_sort_order,
            labels=[
                "\n".join(
                    textwrap.wrap(x.replace(topic, "").replace("_", " "), width=35)[0:2]
                )
                for x in x_sort_order
            ],
        )
    )

    if max_is_high:
        vp = (
            vp
            + p9.scale_color_brewer(
                "div", "RdBu", limits=list(range(1, max_value + 1)), labels=labels
            )
            + p9.scale_fill_brewer(
                "div", "RdBu", limits=list(range(1, max_value + 1)), labels=labels
            )
        )

    else:
        vp = (
            vp
            + reverse_scale_fill_brewer(
                "div",
                "RdBu",
                limits=list(reversed(range(-max_value, 0))),
                labels=labels,
            )
            + reverse_scale_color_brewer(
                "div",
                "RdBu",
                limits=list(reversed(range(-max_value, 0))),
                labels=labels,
            )
        )

    if facet_by:
        if wrap_facets:
            vp = (
                vp
                + p9.facet_grid(facet_by, labeller=lambda x: "\n".join(wrap(x, 15)))
                + p9.theme(
                    strip_text_x=p9.element_text(
                        wrap=True, va="bottom", margin={"b": -0.5}
                    )
                )
            )
        else:
            vp = vp + p9.facet_grid(facet_by, space="free", labeller=lambda x: x)
    return vp
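
Usage mirrors the single-column variant above; for instance, given a suitable
survey_data frame whose columns of interest share the prefix 'tools_' and a
grouping column 'role' (illustrative names):

chart = make_likert_chart(
    survey_data,
    topic='tools_',
    labels=['Very low', 'Low', 'Neutral', 'High', 'Very high'],
    facet_by=['role'],
    max_value=5,
    max_is_high=True)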
Example 9
def make_likert_chart_multi_year(
    survey_data,
    topic,
    labels,
    facet_by=[],
    five_is_high=False,
    exclude_new_contributors=False,
):
    """Make an offset stacked barchart showing the number of respondents at each rank or value for 
        all columns in the topic. Each column in the topic is a facet, with the years displayed
        along the x-axis.

    Args:
        survey_data (pandas.DataFrame): Raw data read in from Kubernetes Survey   
        topic (str): String that all questions of interest start with
        labels (list): List of strings to use as labels, corresponding
             to the numerical values given by the respondents.
        facet_by (list, optional): List of columns used for grouping
        five_is_high (bool, optional): Defaults to False. If True,
            five is considered the highest value in a ranking, otherwise
            it is taken as the lowest value.
        exclude_new_contributors (bool, optional): Defaults to False. If True,
            do not include any responses from contributors with less than 
            one year of experience        

    Returns:
        (plotnine.ggplot): Offset stacked barchart plot object which 
            can be displayed in a notebook or saved out to a file
    """

    facet_by = copy(facet_by)
    og_cols = [x for x in survey_data.columns if x.startswith(topic)]
    show_legend = True

    topic_data_long = get_multi_year_data_subset(
        survey_data, topic, facet_by, exclude_new_contributors
    )

    if not five_is_high:
        topic_data_long = topic_data_long.assign(rating=topic_data_long.rating * -1.0)

    mid_point = 3 if five_is_high else -3
    top_scores, bottom_scores = split_for_likert(topic_data_long, mid_point)

    if facet_by:
        fix = False
        if "." in facet_by:
            facet_by.remove(".")
            fix = True

        # Calculate proportion for each rank
        top_scores = top_scores.merge(
            topic_data_long.groupby(facet_by + ["year"]).count().reset_index(),
            on=facet_by + ["year"],
        ).rename(columns={"rating_x": "rating", "level_0_x": "level_0"})
        top_scores = top_scores.assign(
            level_1=top_scores.level_1_x / (top_scores.level_1_y / len(og_cols))
        )

        bottom_scores = bottom_scores.merge(
            topic_data_long.groupby(facet_by + ["year"]).count().reset_index(),
            on=facet_by + ["year"],
        ).rename(columns={"rating_x": "rating", "level_0_x": "level_0"})
        bottom_scores = bottom_scores.assign(
            level_1=bottom_scores.level_1_x
            * -1
            / (bottom_scores.level_1_y / len(og_cols))
        )

        if fix:
            facet_by.append(".")
    else:
        # Calculate proportion for each rank
        top_scores = top_scores.merge(
            topic_data_long.groupby(["year"]).count().reset_index(), on=["year"]
        ).rename(columns={"rating_x": "rating", "level_0_x": "level_0"})
        top_scores = top_scores.assign(
            level_1=top_scores.level_1_x / (top_scores.level_1_y / len(og_cols))
        )

        bottom_scores = bottom_scores.merge(
            topic_data_long.groupby(["year"]).count().reset_index(), on=["year"]
        ).rename(columns={"rating_x": "rating", "level_0_x": "level_0"})
        bottom_scores = bottom_scores.assign(
            level_1=bottom_scores.level_1_x
            * -1
            / (bottom_scores.level_1_y / len(og_cols))
        )

    vp = (
        p9.ggplot(
            topic_data_long,
            p9.aes(x="factor(year)", fill="factor(rating)", color="factor(rating)"),
        )
        + p9.geom_col(
            data=top_scores,
            mapping=p9.aes(y="level_1"),
            show_legend=show_legend,
            size=0.25,
            position=p9.position_stack(reverse=True),
        )
        + p9.geom_col(
            data=bottom_scores,
            mapping=p9.aes(y="level_1"),
            show_legend=show_legend,
            size=0.25,
            position=p9.position_stack(),
        )
        + p9.geom_hline(yintercept=0, color="white")
    )

    if five_is_high:
        vp = (
            vp
            + p9.scale_color_brewer(
                "div", "RdBu", limits=[1, 2, 3, 4, 5], labels=labels
            )
            + p9.scale_fill_brewer("div", "RdBu", limits=[1, 2, 3, 4, 5], labels=labels)
            + p9.theme(
                axis_text_x=p9.element_text(angle=45, ha="right"),
                strip_text_y=p9.element_text(angle=0, ha="left"),
            )
        )
    else:
        vp = (
            vp
            + p9.scale_color_brewer(
                "div", "RdBu", limits=[-5, -4, -3, -2, -1], labels=labels
            )
            + p9.scale_fill_brewer(
                "div", "RdBu", limits=[-5, -4, -3, -2, -1], labels=labels
            )
            + p9.theme(strip_text_y=p9.element_text(angle=0, ha="left"))
        )

    if facet_by:
        # Guard the removal: "." is only present if it was re-appended above,
        # and list.remove raises ValueError when the item is missing.
        if "." in facet_by:
            facet_by.remove(".")
    else:
        facet_by.append(".")

    vp = (
        vp
        + p9.facet_grid(
            facet_by + ["level_0"],
            labeller=lambda x: "\n".join(
                wrap(
                    x.replace(topic, "").replace("_", " ").replace("/", "/ ").strip(),
                    15,
                )
            ),
        )
        + p9.theme(
            strip_text_x=p9.element_text(wrap=True, ma="left"), panel_spacing_x=0.1
        )
    )

    return vp
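Example 10

# `tss`, `ts`, `arrays`, `titles`, `title` and `oname` are assumed to be
# defined earlier in the source script; arr2df converts the arrays into a
# long dataframe with columns 'tag', 'hr' and 'v'.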
print([_.shape for _ in tss])
print(ts.shape)

# convert data into dataframe
df = arr2df(arrays, ts, titles)  # , n=50)

df2 = df.loc[:, ['tag', 'hr', 'v']].groupby(
    ['tag', 'hr']).quantile(q=[.9, .99, 1]).unstack()
df2.columns = ['q090', 'q099', 'q100']

df2 = df2.reset_index().melt(id_vars=['tag', 'hr'],
                             var_name='q',
                             value_name='v')

df2['g'] = df2['tag'] + df2['q'].astype(str)

hrmax = df['hr'].max()
p = (ggplot(df2) +
     geom_line(aes('hr', 'v', color='tag', alpha='q', size='q', group='g')) +
     scale_x_continuous(
         breaks=np.arange(0, hrmax, 24),
         minor_breaks=np.arange(0, hrmax, 6),
     ) + scale_color_brewer(type='qual', palette='Set1') +
     scale_alpha_manual(np.linspace(1, 0.2, num=3)) +
     scale_size_manual([2, 1, .5]) + labs(title=title, y='conc (ppb)'))
p.save(oname)

print(df2['v'].max())
pp = p + scale_y_log10(limits=[1.0, df2['v'].max()])
pp.save(oname[:-4] + '_log.png')
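Example 11

# `kl_divergence_df` is assumed to have been read earlier in the script from
# a comparison-stats TSV (cf. the pd.read_csv call at the end of this excerpt).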
g = (p9.ggplot(
    kl_divergence_df.replace({
        "biorxiv_vs_pmc": "bioRxiv-PMC",
        "biorxiv_vs_nytac": "bioRxiv-NYTAC",
        "pmc_vs_nytac": "PMC-NYTAC",
    }).rename(index=str, columns={"comparison": "Comparison"})) + p9.aes(
        x="factor(num_terms)",
        y="KL_divergence",
        fill="Comparison",
        color="Comparison",
        group="Comparison",
    ) + p9.geom_point(size=2) + p9.geom_line(linetype="dashed") +
     p9.scale_fill_brewer(type="qual", palette="Paired", direction=-1) +
     p9.scale_color_brewer(
         type="qual",
         palette="Paired",
         direction=-1,
     ) + p9.labs(
         x="Number of terms evaluated",
         y="Kullback–Leibler Divergence",
     ) + p9.theme_seaborn(
         context="paper",
         style="ticks",
         font_scale=1.8,
     ) + p9.theme(figure_size=(11, 8.5), text=p9.element_text(family="Arial")))
g.save("output/svg_files/corpora_kl_divergence.svg")
g.save("output/figures/corpora_kl_divergence.png", dpi=500)
print(g)

kl_divergence_special_char_df = pd.read_csv(
    "output/comparison_stats/corpora_kl_divergence_special_chars_removed.tsv",