def make_chart_organisational_diversity(
    org_coeffs,
    num_orgs,
    metric_params,
    org_type_lookup,
    paper_counts,
    save=True,
    fig_num=14,
):
    """Plot comparing the organisational diversity coefficients

    Args:
        org_coeffs: df with regression coefficients per organisation
        num_orgs: number of organisations to display
        metric_params: columns of org_coeffs to use (must include
            beta plus lower / upper bounds)
        org_type_lookup: lookup between organisation name and type
        paper_counts: paper counts indexed by organisation
        save: if we want to save the figure
        fig_num: figure id
    """

    # Regression coefficients sorted (lowest betas first)
    selected = (org_coeffs[metric_params].sort_values("beta").head(
        n=num_orgs).reset_index(drop=False))

    selected["org_type"] = selected["index"].map(org_type_lookup)
    selected["order"] = range(0, len(selected))

    # Paper counts by organisation
    recent_papers_orgs = (paper_counts.loc[selected["index"]].reset_index(
        name="papers").rename(columns={"index": "org"}))
    recent_papers_orgs["order"] = range(0, len(recent_papers_orgs))
    recent_papers_orgs["org_type"] = recent_papers_orgs["org"].map(
        org_type_lookup)

    # Bar chart of coefficients
    b_ch = (alt.Chart(selected).mark_bar().encode(
        y=alt.Y("index", sort=alt.EncodingSortField("order"), title=""),
        x=alt.X("beta", title="Coefficient on diversity"),
        # FIX: colour encoding must use alt.Color, not alt.X
        color=alt.Color("org_type", title="Organisation type"),
    )).properties(width=150, height=600)

    # Error bars for the coefficient confidence intervals
    b_err = (alt.Chart(selected).mark_errorbar().encode(
        y=alt.Y(
            "index",
            sort=alt.EncodingSortField("order"),
            title="",
            axis=alt.Axis(ticks=False, labels=False),
        ),
        x=alt.X("lower", title=""),
        x2="upper",
    )).properties(width=150, height=600)

    # Side bar chart of recent activity (paper counts)
    b_act = (alt.Chart(recent_papers_orgs).mark_bar().encode(
        y=alt.Y(
            "org",
            title=None,
            sort=alt.EncodingSortField("order"),
            axis=alt.Axis(labels=False, ticks=False),
        ),
        x=alt.X("papers"),
        color="org_type",
    )).properties(width=100, height=600)

    out = (b_ch + b_err).resolve_scale(y="independent")
    out_2 = alt.hconcat(out, b_act, spacing=0).resolve_scale(y="shared")

    if save is True:
        save_altair(out_2, f"fig_{fig_num}_comp", driv)

    return out_2
def plot_microtrends(
    trend_table, t0=2010, t1=2021, window=10, save=False, name="microtrends_example"
):
    """Plot microtrends:
    Args:
        trend_table: table with trends in a topic (datetime index)
        t0, t1: first and last year (t1 is exclusive — np.arange(t0, t1))
        window: window for rolling averages
        save: if we want to save
        name: name (only needed if we want to save)
    """

    # Filter to the year range, smooth with a rolling mean and melt to long
    ex_trends_norm = (
        trend_table.loc[[x.year in np.arange(t0, t1) for x in trend_table.index]]
        # FIX: the window size was hardcoded to 10, ignoring the parameter
        .rolling(window=window)
        .mean()
        .dropna()
        .reset_index(drop=False)
        .melt(id_vars="date")
    )

    tr = (
        alt.Chart(ex_trends_norm)
        .mark_line()
        .encode(x="date:T", y="value", color="variable")
    )

    if save is True:
        save_altair(tr, f"fig_{name}", driv)

    return tr
def visualise_tsne(tsne_df, save=True, fig_num=15):
    """Visualise tsne plot

    Args:
        tsne_df: df with x/y tsne coordinates plus org_type, top,
            activity, size and index columns
        save: if we want to save the figure
        fig_num: figure id
    """
    # Shared positional encoding (axes hidden: tsne coordinates are arbitrary)
    tsne_base = alt.Chart(tsne_df).encode(
        x=alt.X("x:Q", title="", axis=alt.Axis(ticks=False, labels=False)),
        y=alt.Y("y:Q", title="", axis=alt.Axis(ticks=False, labels=False)),
    )

    tsne_points = ((
        tsne_base.mark_point(
            filled=True, opacity=0.5, stroke="black",
            strokeOpacity=0.5).encode(
                color=alt.Color("org_type", title="Organisation type"),
                strokeWidth=alt.Stroke("top",
                                       scale=alt.Scale(range=[0, 1]),
                                       legend=None),
                size=alt.Size("activity:Q", title="Number of papers"),
                facet=alt.Facet("size",
                                columns=2,
                                title="Number of organisations in plot"),
                tooltip=["index"],
            )).interactive().resolve_scale(
                y="independent", x="independent").properties(width=250,
                                                             height=250))

    if save is True:
        # FIX: use the fig_num parameter instead of a hardcoded figure id
        save_altair(tsne_points, f"fig_{fig_num}_tsne", driv)

    return tsne_points
def make_chart_topic_spec(
    topic_rca,
    topic_mix,
    arxiv_cat_lookup,
    topic_thres=0.05,
    topic_n=150,
    save=False,
    fig_n="extra_1",
):
    """Visualises prevalence of topics in a category
    Args:
        topic_rca: relative specialisation of topics in categories
        topic_mix: df of topic weights per paper (first column assumed
            to be an id; topics from column 1 onwards)
        arxiv_cat_lookup: lookup between category ids and names
        topic_thres: threshold for topic
        topic_n: number of topics to consider
        save: if we want to save the figure
        fig_n: figure id

    """
    logging.info("Extracting topic counts")
    # Visualise topic distributions: melt to long (index = category id)
    topic_counts_long = topic_rca.reset_index(drop=False).melt(id_vars="index")

    # Extract top topics: count papers where each topic weight exceeds the
    # threshold, then keep the topic_n most frequent topics
    top_topics = list(
        topic_mix.iloc[:, 1:]
        .applymap(lambda x: x > topic_thres)
        .sum(axis=0)
        .sort_values(ascending=False)[:topic_n]
        .index
    )

    # Focus on those for the long topic
    topic_counts_long_ = topic_counts_long.loc[
        topic_counts_long["variable"].isin(top_topics)
    ]

    # Add nice names for category
    # NOTE(review): split(" ") yields a *list* of words per category name,
    # which is then used as the facet field below — confirm this is intended
    # (a scalar label such as split(" ")[0] may have been meant)
    topic_counts_long_["arx_cat"] = [
        x.split(" ") for x in topic_counts_long_["index"].map(arxiv_cat_lookup)
    ]

    # One small bar chart per category, topics sorted by overall frequency
    topic_spec = (
        alt.Chart(topic_counts_long_)
        .mark_bar(color="red")
        .encode(
            y=alt.Y(
                "variable", sort=top_topics, axis=alt.Axis(labels=False, ticks=False)
            ),
            x="value",
            facet=alt.Facet("arx_cat", columns=5),
            tooltip=["variable", "value"],
        )
    ).properties(width=100, height=100)

    if save is True:
        save_altair(topic_spec, f"fig_{fig_n}_topic_specialisations", driv)

    return topic_spec
def make_chart_type_evol(porgs, save=True, fig_number=7):
    """Plots evolution of org types"""

    # Count papers by AI flag, organisation type and date
    type_counts = (
        porgs.groupby(["is_ai", "org_type", "date"])
        .size()
        .reset_index(name="count")
    )

    # Wide format: one column per organisation type
    type_counts_wide = type_counts.pivot_table(
        index=["date", "is_ai"], columns="org_type", values="count"
    ).fillna(0)

    # Row-normalise into shares of activity
    type_shares = type_counts_wide.apply(lambda row: row / row.sum(), axis=1)

    # Back to long format for plotting
    type_shares_long = type_shares.reset_index(drop=False).melt(
        id_vars=["is_ai", "date"]
    )

    # Keep post-2000 observations only
    type_shares_long = type_shares_long.loc[
        [d.year > 2000 for d in type_shares_long["date"]]
    ]

    # Human-readable category label
    type_shares_long["category"] = [
        "AI" if flag is True else "Not AI" for flag in type_shares_long["is_ai"]
    ]

    # Visualise: rolling mean of shares for a subset of organisation types,
    # one column per research category
    chart = (
        alt.Chart(type_shares_long)
        .transform_window(
            m="mean(value)", frame=[-10, 0], groupby=["is_ai", "org_type"]
        )
        .transform_filter(
            alt.FieldOneOfPredicate(
                "org_type", ["Company", "Nonprofit", "Government", "Healthcare"]
            )
        )
        .mark_line()
        .encode(
            x=alt.X("date:T", title=""),
            y=alt.Y("m:Q", title="Share of activity"),
            color=alt.Color(
                "org_type",
                title="Type of organisation",
                sort=alt.EncodingSortField("count", "sum", order="descending"),
            ),
            column=alt.Column(
                "category", title="Research category", sort=["Not AI", "AI"]
            ),
        )
        .resolve_scale(y="independent")
        .properties(width=200, height=150)
    )

    if save is True:
        save_altair(chart, f"fig_{fig_number}_type_evol", driv)

    return chart
def make_chart_topic_trends(
    topic_trends, arxiv_cat_lookup, year_sort=2020, save=True, fig_n=4
):
    """Topic trend chart

    Args:
        topic_trends: long df with date, topic_cat and value columns
            (mutated in place: order / topic_cat_clean columns are added)
        arxiv_cat_lookup: lookup between category ids and names
        year_sort: year whose totals are used to sort the categories
        save: if we want to save the figure
        fig_n: figure id
    """

    # Sort topics by the year of interest
    topics_sorted = (
        topic_trends.loc[[x.year == year_sort for x in topic_trends["date"]]]
        .groupby("topic_cat")["value"]
        .sum()
        .sort_values(ascending=False)
        .index.tolist()
    )

    # Rank of each category in the sorted list.
    # FIX: use a dict lookup instead of re-scanning topics_sorted per row
    # (same values, O(n) instead of O(n * n_categories))
    rank_lookup = {cat: rank for rank, cat in enumerate(topics_sorted)}
    topic_trends["order"] = [
        rank_lookup[x] for x in topic_trends["topic_cat"]
    ]

    # Create clean category names
    topic_trends["topic_cat_clean"] = [
        arxiv_cat_lookup[x][:50] + "..." for x in topic_trends["topic_cat"]
    ]

    # Create clean topic sorted names
    topics_sorted_2 = [arxiv_cat_lookup[x][:50] + "..." for x in topics_sorted]

    # Stacked area/bar chart of topic shares over time
    evol_sh = (
        alt.Chart(topic_trends)
        .mark_bar(stroke="grey", strokeWidth=0.1)
        .encode(
            x="date:T",
            y=alt.Y("value", scale=alt.Scale(domain=[0, 1])),
            color=alt.Color(
                "topic_cat_clean",
                sort=topics_sorted_2,
                title="Source category",
                scale=alt.Scale(scheme="tableau20"),
                legend=alt.Legend(columns=2),
            ),
            order=alt.Order("order", sort="descending"),
            tooltip=["topic_cat_clean"],
        )
    ).properties(width=400)

    if save is True:
        save_altair(evol_sh, f"fig_{fig_n}_topic_trends", driv)

    return evol_sh
# Example #7 (stray notebook-export artifact; commented out so the module imports cleanly)
def make_agg_trend(arx, save=True):
    """Makes first plot

    Args:
        arx: papers df with date and is_ai columns
        save: if we want to save the figure

    Returns:
        A vconcat of (paper totals, AI share) charts and the trends table
    """
    # First chart: trends
    ai_bool_lookup = {False: "Other categories", True: "AI"}

    # Totals
    ai_trends = (
        arx.groupby(["date", "is_ai"]).size().reset_index(name="Number of papers")
    )
    ai_trends["is_ai"] = ai_trends["is_ai"].map(ai_bool_lookup)

    # Shares
    ai_shares = (
        ai_trends.pivot_table(index="date", columns="is_ai", values="Number of papers")
        .fillna(0)
        .reset_index(drop=False)
    )
    # NOTE(review): the denominator sums every numeric column of ai_shares
    # row-wise ("AI" + "Other categories"); confirm the "date" column is
    # excluded by pandas' numeric handling, otherwise the share is distorted
    ai_shares["share"] = ai_shares["AI"] / ai_shares.sum(axis=1)

    #  Make chart: rolling-mean paper counts per category (top panel)
    at_ch = (
        alt.Chart(ai_trends)
        .transform_window(
            roll="mean(Number of papers)", frame=[-5, 5], groupby=["is_ai"]
        )
        .mark_line()
        .encode(
            x=alt.X("date:T", title="", axis=alt.Axis(labels=False, ticks=False)),
            y=alt.Y("roll:Q", title=["Number", "of papers"]),
            color=alt.Color("is_ai:N", title="Category"),
        )
        .properties(width=350, height=120)
    )
    # Rolling-mean AI share of all arXiv activity (bottom panel)
    as_ch = (
        alt.Chart(ai_shares)
        .transform_window(roll="mean(share)", frame=[-5, 5])
        .mark_line()
        .encode(
            x=alt.X("date:T", title=""),
            y=alt.Y("roll:Q", title=["AI as share", "of all arXiv"]),
        )
    ).properties(width=350, height=120)

    ai_trends_chart = alt.vconcat(at_ch, as_ch, spacing=0)

    if save is True:
        save_altair(ai_trends_chart, "fig_1_ai_trends", driver=driv)

    return ai_trends_chart, ai_trends
# Example #8 (stray notebook-export artifact; commented out so the module imports cleanly)
def make_chart_sector_boxplot(df, save=True, fig_num=12):
    """Boxplots comparing diversity of companies and academia"""

    # Grid of boxplots: metric rows x parametre-set columns,
    # one box per organisation category
    boxplot_grid = (
        alt.Chart(df)
        .mark_boxplot()
        .encode(
            x=alt.X("category:N", title="", sort=["Company", "Education"]),
            y=alt.Y("score:Q", scale=alt.Scale(zero=False)),
            row=alt.Row(
                "metric",
                title="Metric",
                sort=["balance", "weitzman", "rao-stirling"],
            ),
            column=alt.Column("parametre_set", title="Parametre set"),
            color=alt.Color("category", title="Organisation type"),
        )
        .resolve_scale(y="independent")
        .properties(height=100, width=100)
    )

    if save is True:
        save_altair(boxplot_grid, f"fig_{fig_num}_div_sect_multiple", driv)

    return boxplot_grid
# Example #9 (stray notebook-export artifact; commented out so the module imports cleanly)
def make_chart_distribution_centrality(shares_long,
                                       centrality_ranked_all,
                                       saving=True,
                                       fig_n=11):
    """Plots evolution of shares of topics and centrality averages
    at different positions of the distribution
    Args:
        shares_long: df with shares of topic activity accounted at
            different points of the distribution
        centrality_ranked_all: df with mean centralities for topics
            at different points of the distribution
        saving: if we want to save the figure
        fig_n: figure id
    """

    # Evolution of activity shares by rank position
    shares_evol = (alt.Chart(shares_long).mark_line().encode(
        x=alt.X("variable:N",
                title="",
                axis=alt.Axis(labels=False, ticks=False)),
        y=alt.Y("value", title=["Share of activity", "accounted by rank"]),
        # FIX: colour encoding must use alt.Color, not alt.X
        color=alt.Color("index:N", title="Position in distribution"),
    )).properties(width=300, height=170)

    # Mean eigenvector centrality per year and rank segment
    line = (alt.Chart(centrality_ranked_all).transform_aggregate(
        m="mean(eigenvector_z)",
        groupby=["year",
                 "rank_segment"]).mark_line().encode(x=alt.X("year:N",
                                                             title=""),
                                                     y="m:Q",
                                                     color="rank_segment:N"))

    # Error band around the centrality means
    band = (alt.Chart(centrality_ranked_all).mark_errorband().encode(
        x="year:N",
        y=alt.Y("eigenvector_z", title=["Mean eigenvector", "centrality"]),
        color=alt.Color("rank_segment:N", title="Position in distribution"),
    ))

    eigen_evol_linech = (band + line).properties(width=300, height=170)

    div_comp = alt.vconcat(shares_evol, eigen_evol_linech,
                           spacing=0).resolve_scale(x="shared",
                                                    color="independent")

    if saving is True:
        save_altair(div_comp, f"fig_{fig_n}_div_comps_evol", driv)

    return div_comp
# Example #10 (stray notebook-export artifact; commented out so the module imports cleanly)
def make_chart_diversity_evol(results, save=True, fig_n=10):
    """Plots evolution of diversity"""
    # One row per diversity metric, one line per parameter set
    evol_chart = (
        alt.Chart(results)
        .mark_line(opacity=0.9)
        .encode(
            x=alt.X("year:O", title=""),
            y=alt.Y("score:Q", scale=alt.Scale(zero=False), title="z-score"),
            row=alt.Row(
                "diversity_metric",
                title="Diversity metric",
                sort=["balance", "weitzman", "rao_stirling"],
            ),
            color=alt.Color("parametre_set:N", title="Parameter set"),
        )
        .resolve_scale(y="independent")
        .properties(width=250, height=80)
    )

    if save is True:
        save_altair(evol_chart, f"fig_{fig_n}_div_evol", driv)

    return evol_chart
# Example #11 (stray notebook-export artifact; commented out so the module imports cleanly)
def make_chart_sector_comparison(results_df, save=True, fig_num=12):
    """Barcharts comparing diversity of companies and academia"""

    # Grid of bars: metric columns x parametre-set rows,
    # comparing company vs education scores
    comparison_chart = (
        alt.Chart(results_df)
        .mark_bar(color="red", opacity=0.7, stroke="grey")
        .encode(
            x=alt.X("category", title="Category", sort=["Company", "Education"]),
            y=alt.Y("score", scale=alt.Scale(zero=False), title="Score"),
            column=alt.Column(
                "metric",
                title="Metric",
                sort=["balance", "weitzman", "rao_stirling"],
            ),
            row=alt.Row("parametre_set", title="Parametre set"),
        )
        .resolve_scale(y="independent")
        .properties(height=75, width=100)
    )

    if save is True:
        save_altair(comparison_chart, f"fig_{fig_num}_div_sect", driv)

    return comparison_chart
# Example #12 (stray notebook-export artifact; commented out so the module imports cleanly)
def make_cat_trend(linech, save=True, fig_n=2):
    """Makes chart 2

    Args:
        linech: long df with index (date), n, category_clean and type columns
        save: if we want to save the figure
        fig_n: figure id
    """

    # Rolling-mean paper counts per category and source, faceted by category
    ai_subtrends_chart = (
        alt.Chart(linech)
        .transform_window(
            roll="mean(n)", frame=[-10, 0], groupby=["category_clean", "type"]
        )
        .mark_line()
        .encode(
            x=alt.X("index:T", title=""),
            # FIX: the y channel must use alt.Y, not alt.X
            y=alt.Y("roll:Q", title="Number of papers"),
            color=alt.Color("type", title="Source"),
        )
        .properties(width=200, height=100)
    ).facet(alt.Facet("category_clean", title="Category"), columns=2)

    if save is True:
        save_altair(ai_subtrends_chart, f"fig_{fig_n}_ai_subtrends", driver=driv)

    return ai_subtrends_chart
def make_chart_type_comparison(porgs, save=True, fig_number=6):
    """Plots evolution of activity by organisation type"""
    # Counts activity by type
    type_counts = (
        porgs.groupby(["is_ai", "org_type"]).size().reset_index(name="count")
    )

    # Pivot wide, normalise each is_ai column to percentages, melt back long
    type_shares_long = (
        type_counts.pivot_table(index="org_type", columns="is_ai", values="count")
        .apply(lambda col: 100 * col / col.sum())
        .reset_index(drop=False)
        .melt(id_vars="org_type")
    )

    # Add clean AI name
    type_shares_long["category"] = [
        "AI" if flag is True else "Not AI" for flag in type_shares_long["is_ai"]
    ]

    # Shared encoding base: org types sorted by total activity
    base = alt.Chart(type_shares_long).encode(
        y=alt.Y(
            "org_type",
            title=["Type of", "organisation"],
            sort=alt.EncodingSortField("value", "sum", order="descending"),
        ),
        x=alt.X("value", title="% activity"),
    )

    # One point per category (AI / Not AI)
    points = base.mark_point(filled=True).encode(
        color="category", shape="category"
    )

    # Dashed connector between the two points of each org type
    connectors = base.mark_line(
        stroke="grey", strokeWidth=1, strokeDash=[1, 1]
    ).encode(detail="org_type")

    comparison_chart = (points + connectors).properties(height=100)

    if save is True:
        save_altair(comparison_chart, f"fig_{fig_number}_type_comp", driv)

    return comparison_chart
def make_chart_company_activity(porgs,
                                papers,
                                top_c=15,
                                t0=2005,
                                roll_w=-8,
                                save=True,
                                fig_num=8):
    """Chart evolution of company activity

    Args:
        porgs: df with paper-organisation pairs (expects is_ai, org_type,
            org_name, article_id and date columns)
        papers: papers df used to build an article id -> date lookup
        top_c: number of top companies to chart
        t0: first year to show
        roll_w: lower bound of the rolling-mean window (in periods)
        save: if we want to save the figure
        fig_num: figure id
    """
    logging.info("Preparing data")
    # Create a paper - date lookup
    paper_year_date = create_paper_dates_dict(papers)["date"]

    # Find paper IDs for the top_c companies by AI activity
    top_comps = (porgs.loc[(porgs["is_ai"] == True)
                           & (porgs["org_type"] == "Company")]
                 ["org_name"].value_counts().head(n=top_c).index)

    comp_papers = {
        org: set(porgs.loc[porgs["org_name"] == org]["article_id"])
        for org in top_comps
    }

    # Concatenate them in a dataframe
    # (one column per company, indexed by date, values = paper counts)
    comp_trends = (pd.DataFrame([
        pd.Series(
            [paper_year_date[x] for x in ser if x in paper_year_date.keys()],
            name=n,
        ).value_counts() for n, ser in comp_papers.items()
    ]).fillna(0).T)

    comp_trends_long = (comp_trends.reset_index(drop=False).melt(
        id_vars="index").assign(indicator="Total AI papers"))

    # Extract top 20 comps for 2020 (we use this to order the chart later)
    comps_2020 = (comp_trends.loc[[x.year == 2020 for x in comp_trends.index
                                   ]].sum().sort_values(ascending=False))
    # FIX: the lookup must map company name -> rank; keying by rank made the
    # later .map() over company names yield all-NaN order values
    comps_order = {name: n for n, name in enumerate(comps_2020.index)}

    # Normalise with paper counts for all AI
    logging.info("Normalising")
    all_ai_counts = (porgs.drop_duplicates("article_id").query("is_ai == True")
                     ["date"].value_counts())

    comp_trends_norm = comp_trends.apply(lambda x: x / all_ai_counts).fillna(0)

    comp_trends_share_long = (comp_trends_norm.reset_index(drop=False).melt(
        id_vars="index").assign(indicator="Share of all AI"))

    # Some tidying up before plotting
    # (note: comp_trends is rebound here from wide counts to the long df)
    comp_trends = pd.concat([comp_trends_long, comp_trends_share_long])
    comp_trends_recent = comp_trends.loc[[(x.year > t0)
                                          for x in comp_trends["index"]]]
    comp_trends_recent["order"] = comp_trends_recent["variable"].map(
        comps_order)

    date_domain = list(pd.to_datetime(["2007-01-01", "2020-07-01"]))

    logging.info("Plotting")
    # Create chart: stacked areas of rolling-mean activity, one row per
    # indicator (totals and shares)
    comp_evol_chart = ((alt.Chart(comp_trends_recent).mark_area(
        stroke="black", strokeWidth=0.1, clip=True).transform_window(
            roll="mean(value)", frame=[roll_w, 0],
            groupby=["indicator"]).encode(
                x=alt.X("index:T",
                        title="",
                        scale=alt.Scale(domain=date_domain)),
                y=alt.Y("roll:Q", title=""),
                color=alt.Color(
                    "variable",
                    title="Company",
                    scale=alt.Scale(scheme="tableau20"),
                    sort=comps_2020.index.tolist(),
                ),
                order=alt.Order("order"),
                row=alt.Row("indicator", sort=["total", "share"]),
            )).properties(height=200).resolve_scale(y="independent"))

    if save is True:
        save_altair(comp_evol_chart, f"fig_{fig_num}_company_evol", driv)

    return comp_evol_chart
# Example #15 (stray notebook-export artifact; commented out so the module imports cleanly)
def make_cat_distr_chart(
    cat_sets, ai_joint, arxiv_cat_lookup, cats_to_plot=20, save=True, fig_n=3
):
    """Makes chart 3
    Args:
        cat_sets: df grouping ids by category
        ai_joint: ai ids
        arxiv_cat_lookup: lookup between arxiv cat ids and names
        cats_to_plot: Number of top categories to visualise
        save: Whether to save the plot
        fig_n: figure id

    """

    logging.info("Getting category frequencies")
    # Create a df with dummies for all categories
    cat_dums = []

    for c, v in cat_sets.items():
        d = pd.DataFrame(index=[x for x in v if x in ai_joint])
        d[c] = True
        cat_dums.append(d)

    cat_bin_df = pd.concat(cat_dums, axis=1, sort=True).fillna(0)

    # Category frequencies for all papers
    ai_freqs = (
        cat_bin_df.sum()
        .reset_index(name="paper_counts")
        .query("paper_counts > 0")
        .sort_values("paper_counts", ascending=False)
        .rename(columns={"index": "category_id"})
    )

    ai_freqs["category_name"] = [
        arxiv_cat_lookup[x][:40] + "..." for x in ai_freqs["category_id"]
    ]
    order_lookup = {cat: n for n, cat in enumerate(ai_freqs["category_id"])}
    ai_freqs["order"] = ai_freqs["category_id"].map(order_lookup)

    logging.info("Getting category overlaps")
    # Category frequencies for each category (co-occurrence counts,
    # dropping the category's own row)
    res = []
    for x in ai_freqs["category_id"]:

        ai_cat = cat_bin_df.loc[cat_bin_df[x] == True].sum().drop(x, axis=0)
        ai_cat.name = x
        res.append(ai_cat)

    # Row-normalised overlap matrix in long format
    hm_long = (
        pd.concat(res, axis=1)
        .apply(lambda x: x / x.sum(), axis=1)
        .loc[ai_freqs["category_id"]]
        .fillna(0)
        .reset_index(drop=False)
        .melt(id_vars="index")
    )

    hm_long["category_name_1"] = [
        arxiv_cat_lookup[x][:40] + "..." for x in hm_long["index"]
    ]
    hm_long["category_name_2"] = [
        arxiv_cat_lookup[x][:40] + "..." for x in hm_long["variable"]
    ]
    hm_long["order_1"] = hm_long["index"].map(order_lookup)
    hm_long["order_2"] = hm_long["variable"].map(order_lookup)
    hm_long["value"] = [100 * np.round(x, 4) for x in hm_long["value"]]

    # And plot
    logging.info("Plotting")
    # Barchart
    # We focus on the top cats_to_plot categories with AI papers
    ai_freq_bar = (
        alt.Chart(ai_freqs.loc[ai_freqs["order"] < cats_to_plot])
        .mark_bar(color="red", opacity=0.6, stroke="grey", strokeWidth=0.5)
        .encode(
            y=alt.Y("paper_counts", title="Number of papers"),
            x=alt.X(
                "category_name",
                title="",
                sort=alt.EncodingSortField("order"),
                axis=alt.Axis(labels=False, ticks=False),
            ),
        )
    ).properties(width=350, height=200)

    # HM
    # FIX: the heatmap hardcoded 20 instead of cats_to_plot, so the two
    # panels disagreed whenever cats_to_plot != 20
    ai_hm = (
        alt.Chart(
            hm_long.query("order_1 < @cats_to_plot").query("order_2 < @cats_to_plot")
        )
        .mark_rect()
        .encode(
            x=alt.X(
                "category_name_1",
                sort=alt.EncodingSortField("order_1"),
                title="arXiv category",
            ),
            y=alt.Y(
                "category_name_2",
                sort=alt.EncodingSortField("order_2"),
                title="arXiv category",
            ),
            color=alt.Color(
                "value", title=["% of articles in x-category", "with y-category"]
            ),
            tooltip=["category_name_2", "value"],
        )
    ).properties(width=350)

    cat_freqs_hm = (
        alt.vconcat(ai_freq_bar, ai_hm)
        .configure_concat(spacing=0)
        .resolve_scale(color="independent")
    )

    if save is True:
        save_altair(cat_freqs_hm, f"fig_{fig_n}_arxiv_categories", driv)

    return cat_freqs_hm
def make_chart_topic_comparison(
    topic_mix,
    arxiv_cat_lookup,
    comparison_ids,
    selected_categories,
    comparison_names,
    topic_list,
    topic_category_map,
    highlights=False,
    highlight_topics=None,
    highlight_class_table="Company",
    save=True,
    fig_num=15,
):
    """Creates a chart that compares the topic specialisations
    of different groups of organisations
    Args:
        topic_mix: topic mix
        arxiv_cat_lookup: lookup between category ids and names
        comparison_ids: ids we want to compare
        selected_categories: arXiv categories to focus on
        comparison_names: names for the categories we are comparing
        topic_list: topics to consider
        topic_category_map: lookup between topics and categories
        highlights: if we want to highlight particular topics
        highlight_topics: which ones
        highlight_class_table: topics to highlight in the table
        save: if we want to save the figure
        fig_num: figure id
    """

    # Extract the representations of categories: share of papers with each
    # topic, for every group of ids being compared
    comp_topic_rel = pd.DataFrame([
        topic_rep(
            ids,
            topic_mix,
            selected_categories,
            topic_list=topic_list,
            topic_category_map=topic_category_map,
        )[1].loc[True] for ids in comparison_ids
    ])
    comparison_df = comp_topic_rel.T
    comparison_df.columns = comparison_names

    comparison_df_long = comparison_df.reset_index(drop=False).melt(
        id_vars="index")
    comparison_df_long["cat"] = comparison_df_long["index"].map(
        topic_category_map)

    # Topic order on the x axis: by category, then by total value
    order = (comparison_df_long.groupby(
        ["index", "cat"])["value"].sum().reset_index(drop=False).sort_values(
            by=["cat", "value"], ascending=[True, False])["index"].tolist())

    # Keep only topics in the selected arXiv categories
    comparison_df_filter = comparison_df_long.loc[
        comparison_df_long["cat"].isin(selected_categories)]

    comparison_df_filter["cat_clean"] = [
        arxiv_cat_lookup[x][:35] + "..." for x in comparison_df_filter["cat"]
    ]

    # Sort category facets by the mean company-academia difference
    diff_comp = (comparison_df_filter.pivot_table(
        index=["index", "cat_clean"], columns="variable",
        values="value").assign(
            diff=lambda x: x["company"] - x["academia"]).reset_index(
                drop=False).groupby("cat_clean")["diff"].mean().sort_values(
                    ascending=False).index.tolist())

    # Plot: one point per topic and organisation type
    comp_ch = (alt.Chart(comparison_df_filter).mark_point(
        filled=True, opacity=0.5, stroke="black", strokeWidth=0.5).encode(
            x=alt.X("index",
                    title="",
                    sort=order,
                    axis=alt.Axis(labels=False, ticks=False)),
            y=alt.Y("value", title=["Share of papers", "with topic"]),
            color=alt.Color("variable", title="Organisation type"),
            tooltip=["index"],
        ))

    # Dashed connector between the points of each topic
    comp_lines = (alt.Chart(comparison_df_filter).mark_line(
        strokeWidth=1, strokeDash=[1, 1], stroke="grey").encode(
            x=alt.X("index",
                    sort=order,
                    axis=alt.Axis(labels=False, ticks=False)),
            y="value",
            detail="index",
        ))

    # FIX: removed a pre-branch topic_comp_type assignment that was dead
    # code — both branches below unconditionally rebuild it

    if highlights is False:

        topic_comp_type = ((comp_ch + comp_lines).properties(
            width=200, height=150).facet(
                alt.Facet("cat_clean", sort=diff_comp, title="arXiv category"),
                columns=3,
            ).resolve_scale(x="independent"))

        if save is True:
            save_altair(topic_comp_type, f"fig_{fig_num}_topic_comp", driv)

        return topic_comp_type
    else:

        # Lookup for the selected categories (topic -> printable label)
        code_topic_lookup = {
            v: str(n + 1)
            for n, v in enumerate(highlight_topics)
        }

        # Add a label per topic for the selected topics
        comparison_df_filter["code"] = [
            code_topic_lookup[x]
            if x in code_topic_lookup.keys() else "no_label"
            for x in comparison_df_filter["index"]
        ]

        # Place each label just above the topic's highest point
        max_val = comparison_df_filter.groupby(
            "index")["value"].max().to_dict()
        comparison_df_filter["max"] = comparison_df_filter["index"].map(
            max_val)

        comp_text = (alt.Chart(comparison_df_filter).transform_filter(
            alt.datum.code != "no_label").mark_text(
                yOffset=-10, color="red").encode(
                    text=alt.Text("code"),
                    x=alt.X("index",
                            sort=order,
                            axis=alt.Axis(labels=False, ticks=False)),
                    y=alt.Y("max", title=""),
                    detail="index",
                ))

        topic_comp_type = ((comp_ch + comp_lines + comp_text).properties(
            width=200, height=150).facet(
                alt.Facet("cat_clean", sort=diff_comp, title="arXiv category"),
                columns=3,
            ).resolve_scale(x="independent"))

        if save is True:
            save_altair(topic_comp_type, "fig_9_topic_comp", driv)
            save_highlights_table(
                comparison_df_filter,
                highlight_topics,
                highlight_class_table,
                topic_category_map,
            )

        return topic_comp_type, comparison_df_filter
# Example #17 (stray notebook-export artifact; commented out so the module imports cleanly)
# Mean citations per year for company vs non-company papers, 2012-2019
cit_evol_df = (
    papers_an.groupby(['is_comp', 'year'])['citation_count']
    .mean()
    .unstack(level=0)
    .dropna()
    .loc[range(2012, 2020)]
    .stack()
    .reset_index(name='mean_citations')
    .assign(
        is_comp=lambda df: df['is_comp'].replace(
            {True: 'Company', False: 'Not company'}
        )
    )
)

# Line chart of mean citations by year and article type
cit_evol_chart = (
    alt.Chart(cit_evol_df)
    .mark_line(point=True)
    .encode(
        x=alt.X('year:O', title=None),
        y=alt.Y('mean_citations', title='Mean citations'),
        color=alt.Color('is_comp', title='Article type'),
    )
    .properties(width=400, height=200)
)

save_altair(cit_evol_chart, "fig_influence", driver=webd)

# -

# ### Regression

# +
# Steps: train model, get model results


# +
def get_model_results(model, name):
    """Extract results from a fitted model (stub — not yet implemented).

    Args:
        model: fitted model object
        name: label for the model
    """
    pass


def fit_model(papers_an, tm, comps_n):