Example #1
def plot_rate_bw(rate,
                 bw,
                 yinnerdomain,
                 yclientdomain,
                 use_legend,
                 xdomain=None,
                 rate_height=60,
                 bw_height=40):
    rate_chart = alt.Chart(rate).mark_line(color='#5e6472', clip=True).encode(
        x=alt.X('time',
                axis=alt.Axis(title='', labels=False),
                scale=alt.Scale(nice=False, domain=xdomain)),
        y=alt.Y('success_rate:Q',
                axis=alt.Axis(title=['Success', 'rate (s⁻¹)'])),
    ).properties(height=rate_height)

    client_comm = bw.apply(lambda row: 'Client' in row['variable'], axis=1)

    bw_inner_chart = alt.Chart(bw[~client_comm]).mark_line(clip=True).encode(
        x=alt.X("time",
                axis=alt.Axis(title='', labels=False),
                scale=alt.Scale(nice=False,
                                domain=([0, bw['time'].max()]
                                        if xdomain is None else xdomain))),
        y=alt.Y("value",
                axis=alt.Axis(title=''),
                scale=alt.Scale(domain=yinnerdomain)),
        color=alt.Color(
            'variable',
            legend=(alt.Legend(
                title=['Traffic', 'Direction']) if use_legend else None)),
        strokeDash=alt.StrokeDash('variable',
                                  legend=None)).properties(height=bw_height)

    bw_client_chart = alt.Chart(bw[client_comm]).mark_line(clip=True).encode(
        x=alt.X("time",
                axis=alt.Axis(title='Timestamp (s)'),
                scale=alt.Scale(nice=False,
                                domain=([0, bw['time'].max()]
                                        if xdomain is None else xdomain))),
        y=alt.Y("value",
                axis=alt.Axis(title=''),
                scale=alt.Scale(domain=yclientdomain)),
        color=alt.Color('variable',
                        legend=(alt.Legend(title='') if use_legend else None)),
        strokeDash=alt.StrokeDash('variable',
                                  legend=None)).properties(height=bw_height)

    upper = rate_chart
    lower = alt.vconcat(bw_inner_chart,
                        bw_client_chart,
                        title=alt.TitleParams('Bandwidth (MB/s)',
                                              orient='left',
                                              anchor='middle',
                                              dx=15)).resolve_scale(
                                                  color='shared',
                                                  strokeDash='independent')

    return alt.vconcat(upper, lower).resolve_scale(x=alt.ResolveMode('shared'))
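
A minimal usage sketch, assuming alt is Altair and the inputs are pandas frames shaped the way the encodings above expect: rate with time and success_rate columns, bw in long form with time, variable, and value columns (rows whose variable contains 'Client' go to the lower client chart). The data is illustrative only:

import altair as alt
import pandas as pd

# Hypothetical inputs matching the encodings used by plot_rate_bw.
rate = pd.DataFrame({"time": [0, 1, 2], "success_rate": [5.0, 6.5, 6.0]})
bw = pd.DataFrame({
    "time": [0, 1, 2, 0, 1, 2],
    "variable": ["Inner Up"] * 3 + ["Client Down"] * 3,
    "value": [1.2, 1.4, 1.1, 3.0, 2.8, 3.3],
})

chart = plot_rate_bw(rate,
                     bw,
                     yinnerdomain=[0, 2],
                     yclientdomain=[0, 4],
                     use_legend=True)
chart.save("rate_bw.html")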
Example #2
def make_plot(inpath):
    flows = infra.pd.read_parquet(inpath)
    flows = flows.reset_index()
    flows["MB"] = flows["bytes_total"] / (1000**2)
    user_total = flows[["user", "MB"]]
    user_total = user_total.groupby(["user"]).sum().reset_index()

    activity = infra.pd.read_parquet("data/clean/user_active_deltas.parquet")

    df = user_total.merge(activity[[
        "user", "days_online", "optimistic_days_online", "days_active"
    ]],
                          on="user")
    df["MB_per_online_day"] = df["MB"] / df["days_online"]
    df["MB_per_active_day"] = df["MB"] / df["days_active"]

    online_cdf_frame = compute_cdf(df,
                                   value_column="MB_per_online_day",
                                   base_column="user")
    online_cdf_frame = online_cdf_frame.rename(
        columns={"MB_per_online_day": "MB"})
    online_cdf_frame = online_cdf_frame.assign(type="Online Ratio")

    print(online_cdf_frame)
    print("Online median MB per Day", online_cdf_frame["MB"].median())

    active_cdf_frame = compute_cdf(df,
                                   value_column="MB_per_active_day",
                                   base_column="user")
    active_cdf_frame = active_cdf_frame.rename(
        columns={"MB_per_active_day": "MB"})
    active_cdf_frame = active_cdf_frame.assign(type="Active Ratio")

    print(active_cdf_frame)
    print("Active median MB per Day", active_cdf_frame["MB"].median())

    plot_frame = pd.concat([online_cdf_frame, active_cdf_frame])

    alt.Chart(plot_frame).mark_line(
        interpolate='step-after', clip=True).encode(
            x=alt.X(
                "MB",
                title="Mean MB per Day",
            ),
            y=alt.Y(
                "cdf",
                title="CDF of Users",
                scale=alt.Scale(type="linear", domain=(0, 1.0)),
            ),
            color=alt.Color("type"),
            strokeDash=alt.StrokeDash("type"),
        ).properties(width=500, ).save(
            "renders/bytes_per_online_day_per_user_cdf.png", scale_factor=2.0)
Example #3
def make_gain_chart(value: pd.DataFrame, fiat: str) -> alt.Chart:
    value_long = value.rename(
        {
            "cumsum_fiat": "Invested",
            "value_fiat": "Value"
        }, axis=1).melt(["datetime", "trigger_name"], ["Invested", "Value"])

    chart = (alt.Chart(value_long).mark_line().encode(
        x=alt.X("datetime", title="Time"),
        y=alt.Y("value", title=f"{fiat}"),
        strokeDash=alt.StrokeDash("variable",
                                  title="Variable",
                                  legend=alt.Legend(orient="bottom")),
        color=alt.Color("trigger_name",
                        title="Trigger",
                        legend=alt.Legend(orient="bottom")),
    ).interactive())
    return chart
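
For reference, a toy input in the shape make_gain_chart expects: one row per (datetime, trigger_name) pair, with the wide cumsum_fiat and value_fiat columns that the melt turns into variable/value. The numbers are made up:

import pandas as pd

value = pd.DataFrame({
    "datetime": pd.to_datetime(["2021-01-01", "2021-01-02"] * 2),
    "trigger_name": ["weekly"] * 2 + ["dip"] * 2,
    "cumsum_fiat": [100.0, 200.0, 100.0, 200.0],
    "value_fiat": [101.0, 215.0, 99.0, 190.0],
})
chart = make_gain_chart(value, fiat="EUR")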
Example #4
def make_change_vs_average_plot(inpath):
    grouped_flows = infra.pd.read_parquet(inpath)
    grouped_flows = grouped_flows.reset_index()

    grouped_flows["MB"] = grouped_flows["bytes_total"] / (1000**2)
    working_times = grouped_flows.loc[
        (grouped_flows["day_bin"] < "2019-07-30") |
        (grouped_flows["day_bin"] > "2019-08-31")]

    aggregate = working_times.groupby(["hour", "category"]).agg({"MB": "sum"})
    aggregate = aggregate.reset_index()
    category_total = working_times.groupby(["category"]).sum()
    category_total = category_total.reset_index()[["category", "MB"]]
    category_total = category_total.rename(columns={"MB": "category_total_MB"})

    aggregate = aggregate.merge(category_total, on="category")
    aggregate[
        "byte_density"] = aggregate["MB"] / aggregate["category_total_MB"]

    print(aggregate)
    print(category_total)

    alt.Chart(aggregate).mark_line().encode(
        x=alt.X(
            'hour:O',
            title="Hour of the Day",
        ),
        y=alt.Y(
            'byte_density:Q',
            title="Fraction of Category Bytes Per Hour",
        ),
        color=alt.Color(
            "category:N",
            scale=alt.Scale(scheme="tableau20"),
        ),
        strokeDash=alt.StrokeDash("category:N", ),
    ).save(
        "renders/bytes_per_time_of_day_category_relative_shift.png",
        scale_factor=2,
    )
Example #5
def make_plot(infile):
    purchases = infra.pd.read_parquet(infile)

    # Drop nulls from the first purchase
    clean_purchases = purchases.dropna()
    # Convert timedelta to seconds for altair compatibility
    clean_purchases["time_since_last_purchase"] = clean_purchases[
        "time_since_last_purchase"].transform(pd.Timedelta.total_seconds)
    clean_purchases = clean_purchases[[
        "user", "time_since_last_purchase", "amount_bytes"
    ]]

    aggregate = clean_purchases.groupby(["user"]).agg({
        "time_since_last_purchase":
        ["mean", lambda x: x.quantile(0.90), lambda x: x.quantile(0.99)]
    })
    # Flatten column names
    aggregate = aggregate.reset_index()
    aggregate.columns = [
        ' '.join(col).strip() for col in aggregate.columns.values
    ]
    aggregate = aggregate.rename(
        columns={
            "time_since_last_purchase mean": "mean",
            "time_since_last_purchase <lambda_0>": "q90",
            "time_since_last_purchase <lambda_1>": "q99",
        })

    # Compute a CDF since the specific user does not matter
    stats_mean = compute_cdf(aggregate, "mean", "user")
    stats_mean = stats_mean.rename(columns={"mean": "value"})
    stats_mean["type"] = "User's Mean"

    stats_q90 = compute_cdf(aggregate, "q90", "user")
    stats_q90 = stats_q90.rename(columns={"q90": "value"})
    stats_q90["type"] = "User's 90% Quantile"

    stats_q99 = compute_cdf(aggregate, "q99", "user")
    stats_q99 = stats_q99.rename(columns={"q99": "value"})
    stats_q99["type"] = "User's 99% Quantile"

    stats_frame = pd.concat([stats_mean, stats_q90, stats_q99])

    # Convert to Days
    stats_frame["value"] = stats_frame["value"] / 86400
    print(stats_frame)

    alt.Chart(stats_frame).mark_line(clip=True).encode(
        x=alt.X('value:Q',
                scale=alt.Scale(type="log", domain=(0.1, 80)),
                title="Time Between Purchases (Hours) (Log Scale)"),
        y=alt.Y(
            'cdf',
            title="Fraction of Users (CDF)",
            scale=alt.Scale(type="linear", domain=(0, 1.0)),
        ),
        color=alt.Color(
            "type",
            sort=None,
            legend=alt.Legend(
                title="",
                orient="bottom-right",
                fillColor="white",
                labelLimit=500,
                padding=5,
                strokeColor="black",
                columns=1,
            ),
        ),
        strokeDash=alt.StrokeDash(
            "type",
            sort=None,
        )).properties(
            width=500,
            height=200,
        ).save("renders/purchase_timing_per_user_cdf.png", scale_factor=2.0)
Example #6
def make_plot():
    transactions = infra.pd.read_parquet(
        "data/clean/transactions_DIV_none_INDEX_timestamp.parquet")
    purchases = transactions.loc[(transactions["kind"] == "purchase") |
                                 (transactions["kind"] == "admin_topup")]
    purchases = purchases[["timestamp", "amount_idr", "kind", "user"]]

    purchases[
        "amount_usd"] = purchases["amount_idr"] * infra.constants.IDR_TO_USD
    purchases = purchases.loc[purchases["kind"] == "purchase"]

    # Bin by days to limit the number of tuples
    purchases["day"] = purchases["timestamp"].dt.floor("d")
    purchases = purchases.drop(
        "timestamp", axis="columns").rename(columns={"day": "timestamp"})
    purchases = purchases.groupby(["timestamp", "user"]).sum().reset_index()
    purchases = purchases.assign(kind="Total Revenue")

    user_ranks = purchases.groupby("user").sum().reset_index()
    user_ranks["rank"] = user_ranks["amount_usd"].rank(method="min",
                                                       ascending=False)

    purchases = purchases.merge(user_ranks[["user", "rank"]],
                                on="user",
                                how="inner")

    purchases_no_top_5 = purchases.loc[purchases["rank"] > 5].copy()
    purchases_no_top_5["kind"] = "Revenue Sans Top 5"

    purchases_no_top_10 = purchases.loc[purchases["rank"] > 10].copy()
    purchases_no_top_10["kind"] = "Revenue Sans Top 10"

    purchases_no_top_15 = purchases.loc[purchases["rank"] > 15].copy()
    purchases_no_top_15["kind"] = "Revenue Sans Top 15"

    purchases_no_top_20 = purchases.loc[purchases["rank"] > 20].copy()
    purchases_no_top_20["kind"] = "Revenue Sans Top 20"

    finances = pd.concat([
        purchases,
        make_expenses(), purchases_no_top_5, purchases_no_top_10,
        purchases_no_top_15, purchases_no_top_20
    ])

    label_order = {
        "Costs": 1,
        "Total Revenue": 2,
        "Revenue Sans Top 5": 3,
        "Revenue Sans Top 10": 4,
        "Revenue Sans Top 15": 5,
        "Revenue Sans Top 20": 6,
    }

    finances = finances.sort_values(["timestamp", "kind"])
    finances = finances.groupby(["timestamp", "kind"]).sum().sort_index()
    finances = finances.reset_index()
    finances = finances.sort_values(
        ["kind"], key=lambda col: col.map(lambda x: label_order[x]))
    finances = finances.sort_values(
        ["timestamp"],
        kind="mergesort")  # Mergesort is stablely implemented : )
    finances = finances.reset_index()

    finances["amount_cum"] = finances.groupby("kind").cumsum()["amount_usd"]

    altair.Chart(finances).mark_line(interpolate="step-after").encode(
        x=altair.X(
            "timestamp:T",
            title="Time",
        ),
        y=altair.Y(
            "amount_cum",
            title="Amount (USD)",
        ),
        color=altair.Color(
            "kind",
            title="Cumulative:",
            sort=None,
            legend=altair.Legend(
                orient="top-left",
                fillColor="white",
                labelLimit=500,
                padding=5,
                strokeColor="black",
            ),
        ),
        strokeDash=altair.StrokeDash(
            "kind",
            sort=None,
        )).properties(width=500).save(
            "renders/revenue_over_time.png",
            scale_factor=2,
        )
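
The pair of sorts above (first by label_order, then a stable mergesort on timestamp) can also be written with an ordered categorical, which makes the intended legend order explicit. A sketch, assuming pandas is imported as pd:

finances["kind"] = pd.Categorical(finances["kind"],
                                  categories=list(label_order),
                                  ordered=True)
finances = finances.sort_values(["timestamp", "kind"])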
Example #7
def make_ul_dl_scatter_plot(infile):
    user_cat = infra.pd.read_parquet(infile)
    user_cat = user_cat.reset_index()

    # Filter to only those users who made purchases in the network with registered IPs
    users = infra.pd.read_parquet("data/clean/user_active_deltas.parquet")[[
        "user"
    ]]
    user_cat = users.merge(user_cat, on="user", how="left")

    # Compute total bytes for each user across categories
    user_totals = user_cat.groupby(["user"]).sum().reset_index()
    user_totals[
        "bytes_total"] = user_totals["bytes_up"] + user_totals["bytes_down"]

    user_cat = _find_user_top_category(user_cat)
    print(user_cat)

    user_totals = user_totals.merge(user_cat, on="user")
    print(user_totals)

    # Filter users by time in network to eliminate early incomplete samples
    user_active_ranges = infra.pd.read_parquet(
        "data/clean/user_active_deltas.parquet")[[
            "user", "days_since_first_active", "days_active", "days_online"
        ]]
    # Drop users that joined less than a week ago.
    users_to_analyze = user_active_ranges.loc[
        user_active_ranges["days_since_first_active"] > 7]
    # Drop users active for less than one day
    users_to_analyze = users_to_analyze.loc[
        users_to_analyze["days_active"] > 1, ]

    user_totals = user_totals.merge(users_to_analyze, on="user", how="inner")

    # Rank users by their online daily use.
    user_totals["bytes_avg_per_online_day"] = user_totals[
        "bytes_total"] / user_totals["days_online"]
    user_totals["rank_total"] = user_totals["bytes_total"].rank(method="min",
                                                                pct=False)
    user_totals["rank_daily"] = user_totals["bytes_avg_per_online_day"].rank(
        method="min", pct=False)

    # Normalize ul and dl by days online
    user_totals["bytes_up_avg_per_online_day"] = user_totals[
        "bytes_up"] / user_totals["days_online"]
    user_totals["bytes_down_avg_per_online_day"] = user_totals[
        "bytes_down"] / user_totals["days_online"]

    # take the minimum of days online and days active, since active is
    # partial-day aware, but online rounds up to whole days. Can be up to two
    # days off if the user joined late in the day and was last active early.
    user_totals["normalized_days_online"] = np.minimum(
        user_totals["days_online"],
        user_totals["days_active"]) / user_totals["days_active"]

    user_totals["MB_avg_per_online_day"] = user_totals[
        "bytes_avg_per_online_day"] / (1000**2)
    user_totals[
        "ul ratio"] = user_totals["bytes_up"] / user_totals["bytes_total"]
    user_totals[
        "dl ratio"] = user_totals["bytes_down"] / user_totals["bytes_total"]

    # Perform Regressions and Stats Analysis
    # Log-transform both variables so a power-law relationship becomes linear for OLS
    user_totals["log_ul_ratio"] = user_totals["ul ratio"].map(np.log)
    user_totals["log_mb_per_day"] = user_totals["MB_avg_per_online_day"].map(
        np.log)

    # Print log stats info
    x_log = user_totals["log_mb_per_day"]
    y_log = user_totals["log_ul_ratio"]
    x_log_with_const = sm.add_constant(x_log)
    estimate = sm.OLS(y_log, x_log_with_const)
    estimate_fit = estimate.fit()
    print("Stats info for log-transformded OLS linear fit")
    print("P value", estimate_fit.pvalues[1])
    print("R squared", estimate_fit.rsquared)
    print(estimate_fit.summary())

    # Print direct linear regression stats info
    x = user_totals["MB_avg_per_online_day"]
    y = user_totals["ul ratio"]
    x_with_const = sm.add_constant(x)
    estimate = sm.OLS(y, x_with_const)
    estimate_fit = estimate.fit()
    print("Stats info for direct OLS linear fit")
    print("P value", estimate_fit.pvalues[1])
    print("R squared", estimate_fit.rsquared)
    print(estimate_fit.summary())

    # Reshape to generate the column matrices expected by sklearn
    mb_array = user_totals["MB_avg_per_online_day"].values.reshape((-1, 1))
    ul_ratio_array = user_totals["ul ratio"].values.reshape((-1, 1))
    log_mb_array = user_totals["log_mb_per_day"].values.reshape((-1, 1))
    log_ul_array = user_totals["log_ul_ratio"].values.reshape((-1, 1))
    lin_regressor = LinearRegression()
    lin_regressor.fit(mb_array, ul_ratio_array)
    logt_regressor = LinearRegression()
    logt_regressor.fit(log_mb_array, log_ul_array)

    # Generate a regression plot
    uniform_x = np.linspace(start=mb_array.min(),
                            stop=mb_array.max(),
                            num=1000,
                            endpoint=True).reshape((-1, 1))
    predictions = lin_regressor.predict(uniform_x)
    log_x = np.log(uniform_x)
    logt_predictions = logt_regressor.predict(log_x)
    logt_predictions = np.exp(logt_predictions)

    regression_frame = pd.DataFrame({
        "regressionX": uniform_x.flatten(),
        "predictions": predictions.flatten()
    })
    regression_frame = regression_frame.assign(
        type="Linear(P<0.0001, R²=0.09)")

    logt_frame = pd.DataFrame({
        "regressionX": uniform_x.flatten(),
        "predictions": logt_predictions.flatten()
    })
    logt_frame = logt_frame.assign(
        type="Log Transformed Linear(P<0.0001, R²=0.19)")

    user_totals = user_totals.groupby(["user"]).first()

    scatter = alt.Chart(user_totals).mark_point(
        opacity=0.9, strokeWidth=1.5
    ).encode(
        x=alt.X(
            "MB_avg_per_online_day:Q",
            title="User's Average MB Per Day Online",
            # scale=alt.Scale(
            #     type="log",
            # ),
        ),
        y=alt.Y(
            "ul ratio:Q",
            title="Uplink/Total Bytes Ratio",
            # scale=alt.Scale(
            #     type="log",
            # ),
        ),
        # color=alt.Color(
        #     "top_category",
        #     scale=alt.Scale(scheme="category20"),
        #     sort="descending",
        # ),
        # shape=alt.Shape(
        #     "top_category",
        #     sort="descending",
        # )
    )

    regression = alt.Chart(logt_frame).mark_line(
        color="black", opacity=1).encode(
            x=alt.X("regressionX",
                    # scale=alt.Scale(
                    #     type="log",
                    # ),
                    ),
            y=alt.Y("predictions",
                    # scale=alt.Scale(
                    #     type="log",
                    # ),
                    ),
            strokeDash=alt.StrokeDash("type",
                                      title=None,
                                      legend=alt.Legend(
                                          orient="top-right",
                                          fillColor="white",
                                          labelLimit=500,
                                          padding=10,
                                          strokeColor="black",
                                      )))

    (regression + scatter).properties(width=500, ).save(
        "renders/dl_ul_ratio_per_user_scatter.png",
        scale_factor=2,
    )
Example #8
# %%
corr_between_tests["compare_type"] = np.where(
    corr_between_tests["match"].isin(["ctrl_ctrl", "rep1_rep1", "spont_spont"]),
    "same",
    "different",
)

selection = alt.selection_multi(fields=["match"], bind="legend")

base = alt.Chart(corr_between_tests[corr_between_tests.n > 2000]).encode(
    x="dimension", y="corr", color="match",
)

base.mark_line().encode(
    strokeDash=alt.StrokeDash("compare_type", sort="descending"),
    size=alt.condition(~selection, alt.value(1), alt.value(2)),
    opacity=alt.condition(~selection, alt.value(0.4), alt.value(1)),
    column="regions",
    row="n:N",
).properties(width=200, height=250).add_selection(selection)
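
# %% [markdown]
# corr_between_tests above comes from earlier cells not shown here. Below is
# a self-contained toy version of the same legend-bound selection pattern,
# with made-up data (selection_multi is the Altair 4 API; later releases
# rename it to selection_point).

# %%
import altair as alt
import pandas as pd

toy = pd.DataFrame({
    "dimension": list(range(5)) * 2,
    "corr": [0.9, 0.8, 0.7, 0.6, 0.5, 0.95, 0.85, 0.8, 0.7, 0.65],
    "match": ["ctrl_ctrl"] * 5 + ["ctrl_rep1"] * 5,
})
sel = alt.selection_multi(fields=["match"], bind="legend")
alt.Chart(toy).mark_line().encode(
    x="dimension", y="corr", color="match",
    # Clicking a legend entry highlights that line and fades the others.
    opacity=alt.condition(sel, alt.value(1.0), alt.value(0.3)),
).add_selection(sel)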

# %% [markdown]
# The same analysis, with spontaneous activities subtracted.

# %%
spks_nospont = (
    SubtractSpontAnalyzer(128).fit(loader.spks, loader.idx_spont).transform(loader.spks)
)
with cr.set_spks_source(spks_nospont[loader.istim.index, :]):
    df_un = cr.calc_cr([5000, 10000], no_rep)
Example #9
st.markdown(
    "Here rides throughtout the year are visible, binned by maximum distance and grouped by morning(9am-noon) \
    or afternoon(noon-3pm). It's interesting to see the dips around Christmas, February and March, with a \
    large spike in September.")
st.markdown('\n\n')
st.markdown('\n\n')
df_rides = pd.read_csv(
    "streamlit/data/afternoon_treatment_dist_bins_2010_dow.csv")
df_rides['pickup_date'] = pd.to_datetime(df_rides['pickup_date'])

# https://altair-viz.github.io/user_guide/times_and_dates.html
rides = alt.Chart(df_rides, title='Rides all year').mark_line().encode(
    x=alt.X('pickup_date:T', title='Date/time', axis=alt.Axis(tickCount=18, )),
    y=alt.Y('RidesCount:Q', title='Rides per time block'),
    color=alt.Color('comp_dist_bins:N', title='Distance bins (mi)'),
    strokeDash=alt.StrokeDash('post_treatment_time_dummy:O',
                              title='Afternoon Dummy')).properties(
                                  width=875, height=475).interactive()

st.altair_chart(rides)
st.info(
    'This plot is to help illustrate inherent trends for the mornings and afternoons throughout the year. Rides were also binned based on distance.'
)

st.markdown('\n\n')
st.markdown('\n\n')
st.header('Taxi rides and precipitation inquiry')
st.markdown('\n\n')
st.markdown(
    "Our goal was to show a causal relationship between rain and taxi ridership by counting the number of pickups before and after a rain event began. After we \
started we found that this had actually been studied already, by [Kamga et al](https://www.researchgate.net/publication/255982467_Hailing_in_the_Rain_Temporal_and_Weather-Related_Variations_in_Taxi_Ridership_and_Taxi_Demand-Supply_Equilibrium),\
[Sun et al](https://www.hindawi.com/journals/jat/2020/7081628/), and [Chen et al](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0183574).")
Example #10
def make_plot(inpath):
    flows = infra.pd.read_parquet(inpath)
    flows = flows.reset_index()

    activity = infra.pd.read_parquet("data/clean/user_active_deltas.parquet")

    # Drop users new to the network (first active less than a week ago).
    activity = activity.loc[activity["days_since_first_active"] >= 7, ]
    # Drop users active for less than 1 day
    activity = activity.loc[activity["days_active"] >= 1, ]

    # take the minimum of days online and days active, since active is
    # partial-day aware, but online rounds up to whole days. Can be up to two
    # days off if the user joined late in the day and was last active early.
    activity["online_ratio"] = np.minimum(
        activity["days_online"],
        (activity["days_active"] - activity["outage_impact_days"])) / (
            activity["days_active"] - activity["outage_impact_days"])

    flows["MB"] = flows["bytes_total"] / (1000**2)
    user_total = flows[["user", "MB"]]
    user_total = user_total.groupby(["user"]).sum().reset_index()
    df = user_total.merge(
        activity[["user", "online_ratio", "days_online"]],
        on="user",
    )
    df["MB_per_online_day"] = df["MB"] / df["days_online"]

    # Log transform for analysis
    df["log_MB_per_online_day"] = df["MB_per_online_day"].map(np.log)
    df["log_online_ratio"] = df["online_ratio"].map(np.log)

    # Print log stats info
    x_log = df["log_MB_per_online_day"]
    y_log = df["log_online_ratio"]
    x_log_with_const = sm.add_constant(x_log)
    estimate = sm.OLS(y_log, x_log_with_const)
    estimate_fit = estimate.fit()
    print("Stats info for log-transformded OLS linear fit")
    print("P value", estimate_fit.pvalues[1])
    print("R squared", estimate_fit.rsquared)
    print(estimate_fit.summary())

    # Print direct linear regression stats info
    x = df["MB_per_online_day"]
    y = df["online_ratio"]
    x_with_const = sm.add_constant(x)
    estimate = sm.OLS(y, x_with_const)
    estimate_fit = estimate.fit()
    print("Stats info for direct OLS linear fit")
    print("P value", estimate_fit.pvalues[1])
    print("R squared", estimate_fit.rsquared)
    print(estimate_fit.summary())

    # Reshape to generate the column matrices expected by sklearn
    mb_array = df["MB_per_online_day"].values.reshape((-1, 1))
    online_ratio_array = df["online_ratio"].values.reshape((-1, 1))
    log_mb_array = df["log_MB_per_online_day"].values.reshape((-1, 1))
    log_online_array = df["log_online_ratio"].values.reshape((-1, 1))
    lin_regressor = LinearRegression()
    lin_regressor.fit(mb_array, online_ratio_array)
    logt_regressor = LinearRegression()
    logt_regressor.fit(log_mb_array, log_online_array)

    # Generate a regression plot
    uniform_x = np.linspace(start=mb_array.min(),
                            stop=mb_array.max(),
                            num=1000,
                            endpoint=True).reshape((-1, 1))
    predictions = lin_regressor.predict(uniform_x)
    log_x = np.log(uniform_x)
    logt_predictions = logt_regressor.predict(log_x)
    logt_predictions = np.exp(logt_predictions)

    regression_frame = pd.DataFrame({
        "regressionX": uniform_x.flatten(),
        "predictions": predictions.flatten()
    })
    regression_frame = regression_frame.assign(type="Linear(P<0.01, R²=0.05)")

    logt_frame = pd.DataFrame({
        "regressionX": uniform_x.flatten(),
        "predictions": logt_predictions.flatten()
    })
    logt_frame = logt_frame.assign(
        type="Log Transformed Linear(P<0.005, R²=0.05)")
    regression_frame = pd.concat([regression_frame, logt_frame])

    scatter = alt.Chart(df).mark_point(opacity=0.9, strokeWidth=1.5).encode(
        x=alt.X(
            "MB_per_online_day",
            title="Mean MB per Day Online",
        ),
        y=alt.Y(
            "online_ratio",
            title="Online Days / Active Days",
            # scale=alt.Scale(type="linear", domain=(0, 1.0)),
        ),
    )

    regression = alt.Chart(logt_frame).mark_line(
        color="black", opacity=1).encode(
            x=alt.X("regressionX",
                    # scale=alt.Scale(
                    #     type="log",
                    # ),
                    ),
            y=alt.Y("predictions",
                    # scale=alt.Scale(
                    #     type="log",
                    # ),
                    ),
            strokeDash=alt.StrokeDash("type",
                                      title=None,
                                      legend=alt.Legend(
                                          orient='none',
                                          fillColor="white",
                                          labelLimit=500,
                                          padding=5,
                                          strokeColor="black",
                                          legendX=300,
                                          legendY=255,
                                      )),
        )

    (scatter + regression).properties(
        width=500,
        height=250,
    ).save("renders/rate_active_per_user.png", scale_factor=2.0)
Example #11
def main(time_period):
    ###### CUSTOMIZE COLOR THEME ######
    alt.themes.register("finastra", finastra_theme)
    alt.themes.enable("finastra")
    violet, fuchsia = ["#694ED6", "#C137A2"]

    ###### SET UP PAGE ######
    icon_path = "esg_ai_logo.png"
    st.set_page_config(page_title="ESG AI",
                       page_icon=icon_path,
                       layout='centered',
                       initial_sidebar_state="collapsed")
    _, logo, _ = st.beta_columns(3)
    logo.image(icon_path, width=200)
    style = ("text-align:center; padding: 0px; font-family: arial black;, "
             "font-size: 400%")
    title = f"<h1 style='{style}'>ESG<sup>AI</sup></h1><br><br>"
    st.write(title, unsafe_allow_html=True)

    ###### LOAD DATA ######
    with st.spinner(text="Fetching Data..."):
        data, companies = load_data(time_period)
    df_conn = data["conn"]
    df_data = data["data"]
    embeddings = data["embed"]

    ####### CREATE SIDEBAR CATEGORY FILTER######
    st.sidebar.title("Filter Options")
    date_place = st.sidebar.empty()
    esg_categories = st.sidebar.multiselect("Select News Categories",
                                            ["E", "S", "G"], ["E", "S", "G"])
    pub = st.sidebar.empty()
    num_neighbors = st.sidebar.slider("Number of Connections", 1, 20, value=8)

    ###### RUN COMPUTATIONS WHEN A COMPANY IS SELECTED ######
    company = st.selectbox("Select a Company to Analyze", companies)
    if company and company != "Select a Company":
        ###### FILTER ######
        df_company = df_data[df_data.Organization == company]
        diff_col = f"{company.replace(' ', '_')}_diff"
        esg_keys = ["E_score", "S_score", "G_score"]
        esg_df = get_melted_frame(data, esg_keys, keepcol=diff_col)
        ind_esg_df = get_melted_frame(data, esg_keys, dropcol="industry_tone")
        tone_df = get_melted_frame(data, ["overall_score"], keepcol=diff_col)
        ind_tone_df = get_melted_frame(data, ["overall_score"],
                                       dropcol="industry_tone")

        ###### DATE WIDGET ######
        start = df_company.DATE.min()
        end = df_company.DATE.max()
        selected_dates = date_place.date_input("Select a Date Range",
                                               value=[start, end],
                                               min_value=start,
                                               max_value=end,
                                               key=None)
        time.sleep(
            0.8)  #Allow user some time to select the two dates -- hacky :D
        start, end = selected_dates

        ###### FILTER DATA ######
        df_company = filter_company_data(df_company, esg_categories, start,
                                         end)
        esg_df = filter_on_date(esg_df, start, end)
        ind_esg_df = filter_on_date(ind_esg_df, start, end)
        tone_df = filter_on_date(tone_df, start, end)
        ind_tone_df = filter_on_date(ind_tone_df, start, end)
        date_filtered = filter_on_date(df_data, start, end)

        ###### PUBLISHER SELECT BOX ######
        publishers = df_company.SourceCommonName.sort_values().unique().tolist(
        )
        publishers.insert(0, "all")
        publisher = pub.selectbox("Select Publisher", publishers)
        df_company = filter_publisher(df_company, publisher)

        ###### DISPLAY DATA ######
        URL_Expander = st.beta_expander(f"View {company.title()} Data:", True)
        URL_Expander.write(f"### {len(df_company):,d} Matching Articles for " +
                           company.title())
        display_cols = [
            "DATE", "SourceCommonName", "URL", "Tone", "Polarity",
            "ActivityDensity", "SelfDensity"
        ]  #  "WordCount"
        URL_Expander.write(df_company[display_cols])

        ###### CHART: METRIC OVER TIME ######
        st.markdown("---")
        col1, col2 = st.beta_columns((1, 3))

        metric_options = [
            "Tone", "NegativeTone", "PositiveTone", "Polarity",
            "ActivityDensity", "WordCount", "Overall Score", "ESG Scores"
        ]
        line_metric = col1.radio("Choose Metric", options=metric_options)

        if line_metric == "ESG Scores":
            # Get ESG scores
            esg_df["WHO"] = company.title()
            ind_esg_df["WHO"] = "Industry Average"
            esg_plot_df = pd.concat([esg_df,
                                     ind_esg_df]).reset_index(drop=True)
            esg_plot_df.replace(
                {
                    "E_score": "Environment",
                    "S_score": "Social",
                    "G_score": "Governance"
                },
                inplace=True)

            metric_chart = alt.Chart(
                esg_plot_df, title="Trends Over Time").mark_line().encode(
                    x=alt.X("yearmonthdate(DATE):O", title="DATE"),
                    y=alt.Y("Score:Q"),
                    color=alt.Color("ESG",
                                    sort=None,
                                    legend=alt.Legend(title=None,
                                                      orient="top")),
                    strokeDash=alt.StrokeDash("WHO",
                                              sort=None,
                                              legend=alt.Legend(
                                                  title=None,
                                                  symbolType="stroke",
                                                  symbolFillColor="gray",
                                                  symbolStrokeWidth=4,
                                                  orient="top")),
                    tooltip=[
                        "DATE", "ESG",
                        alt.Tooltip("Score", format=".5f")
                    ])

        else:
            if line_metric == "Overall Score":
                line_metric = "Score"
                tone_df["WHO"] = company.title()
                ind_tone_df["WHO"] = "Industry Average"
                plot_df = pd.concat([tone_df,
                                     ind_tone_df]).reset_index(drop=True)
            else:
                df1 = df_company.groupby(
                    "DATE")[line_metric].mean().reset_index()
                df2 = filter_on_date(
                    df_data.groupby("DATE")[line_metric].mean().reset_index(),
                    start, end)
                df1["WHO"] = company.title()
                df2["WHO"] = "Industry Average"
                plot_df = pd.concat([df1, df2]).reset_index(drop=True)
            metric_chart = alt.Chart(
                plot_df, title="Trends Over Time").mark_line().encode(
                    x=alt.X("yearmonthdate(DATE):O", title="DATE"),
                    y=alt.Y(f"{line_metric}:Q",
                            scale=alt.Scale(type="linear")),
                    color=alt.Color("WHO", legend=None),
                    strokeDash=alt.StrokeDash(
                        "WHO",
                        sort=None,
                        legend=alt.Legend(
                            title=None,
                            symbolType="stroke",
                            symbolFillColor="gray",
                            symbolStrokeWidth=4,
                            orient="top",
                        ),
                    ),
                    tooltip=["DATE",
                             alt.Tooltip(line_metric, format=".3f")])
        metric_chart = metric_chart.properties(height=340,
                                               width=200).interactive()
        col2.altair_chart(metric_chart, use_container_width=True)

        ###### CHART: ESG RADAR ######
        col1, col2 = st.beta_columns((1, 2))
        avg_esg = data["ESG"]
        avg_esg.rename(columns={"Unnamed: 0": "Type"}, inplace=True)
        avg_esg.replace(
            {
                "T": "Overall",
                "E": "Environment",
                "S": "Social",
                "G": "Governance"
            },
            inplace=True)
        avg_esg["Industry Average"] = avg_esg.mean(axis=1)

        radar_df = avg_esg[["Type", company,
                            "Industry Average"]].melt("Type",
                                                      value_name="score",
                                                      var_name="entity")

        radar = px.line_polar(radar_df,
                              r="score",
                              theta="Type",
                              color="entity",
                              line_close=True,
                              hover_name="Type",
                              hover_data={
                                  "Type": True,
                                  "entity": True,
                                  "score": ":.2f"
                              },
                              color_discrete_map={
                                  "Industry Average": fuchsia,
                                  company: violet
                              })
        radar.update_layout(
            template=None,
            polar={
                "radialaxis": {
                    "showticklabels": False,
                    "ticks": ""
                },
                "angularaxis": {
                    "showticklabels": False,
                    "ticks": ""
                },
            },
            legend={
                "title": None,
                "yanchor": "middle",
                "orientation": "h"
            },
            title={
                "text": "<b>ESG Scores</b>",
                "x": 0.5,
                "y": 0.8875,
                "xanchor": "center",
                "yanchor": "top",
                "font": {
                    "family": "Futura",
                    "size": 23
                }
            },
            margin={
                "l": 5,
                "r": 5,
                "t": 0,
                "b": 0
            },
        )
        radar.update_layout(showlegend=False)
        col1.plotly_chart(radar, use_container_width=True)

        ###### CHART: DOCUMENT TONE DISTRIBUTION #####
        # add overall average
        dist_chart = alt.Chart(
            df_company, title="Document Tone "
            "Distribution").transform_density(
                density='Tone',
                as_=["Tone",
                     "density"]).mark_area(opacity=0.5, color="purple").encode(
                         x=alt.X('Tone:Q', scale=alt.Scale(domain=(-10, 10))),
                         y='density:Q',
                         tooltip=[
                             alt.Tooltip("Tone", format=".3f"),
                             alt.Tooltip("density:Q", format=".4f")
                         ]).properties(height=325, ).configure_title(
                             dy=-20).interactive()
        col2.markdown("### <br>", unsafe_allow_html=True)
        col2.altair_chart(dist_chart, use_container_width=True)

        ###### CHART: SCATTER OF ARTICLES OVER TIME #####
        # st.markdown("---")
        scatter = alt.Chart(df_company,
                            title="Article Tone").mark_circle().encode(
                                x="NegativeTone:Q",
                                y="PositiveTone:Q",
                                size="WordCount:Q",
                                color=alt.Color("Polarity:Q",
                                                scale=alt.Scale()),
                                tooltip=[
                                    alt.Tooltip("Polarity", format=".3f"),
                                    alt.Tooltip("NegativeTone", format=".3f"),
                                    alt.Tooltip("PositiveTone", format=".3f"),
                                    alt.Tooltip("DATE"),
                                    alt.Tooltip("WordCount", format=",d"),
                                    alt.Tooltip("SourceCommonName",
                                                title="Site")
                                ]).properties(height=450).interactive()
        st.altair_chart(scatter, use_container_width=True)

        ###### NUMBER OF NEIGHBORS TO FIND #####
        neighbor_cols = [f"n{i}_rec" for i in range(num_neighbors)]
        company_df = df_conn[df_conn.company == company]
        neighbors = company_df[neighbor_cols].iloc[0]

        ###### CHART: 3D EMBEDDING WITH NEIGHBORS ######
        st.markdown("---")
        color_f = lambda f: f"Company: {company.title()}" if f == company else (
            "Connected Company" if f in neighbors.values else "Other Company")
        embeddings["colorCode"] = embeddings.company.apply(color_f)
        point_colors = {
            f"Company: {company.title()}": violet,
            "Connected Company": fuchsia,
            "Other Company": "lightgrey"
        }
        fig_3d = px.scatter_3d(
            embeddings,
            x="0",
            y="1",
            z="2",
            color='colorCode',
            color_discrete_map=point_colors,
            opacity=0.4,
            hover_name="company",
            hover_data={c: False
                        for c in embeddings.columns},
        )
        fig_3d.update_layout(
            legend={
                "orientation": "h",
                "yanchor": "bottom",
                "title": None
            },
            title={
                "text": "<b>Company Connections</b>",
                "x": 0.5,
                "y": 0.9,
                "xanchor": "center",
                "yanchor": "top",
                "font": {
                    "family": "Futura",
                    "size": 23
                }
            },
            scene={
                "xaxis": {
                    "visible": False
                },
                "yaxis": {
                    "visible": False
                },
                "zaxis": {
                    "visible": False
                }
            },
            margin={
                "l": 0,
                "r": 0,
                "t": 0,
                "b": 0
            },
        )
        st.plotly_chart(fig_3d, use_container_width=True)

        ###### CHART: NEIGHBOR SIMILIARITY ######
        st.markdown("---")
        neighbor_conf = pd.DataFrame({
            "Neighbor":
            neighbors,
            "Confidence":
            company_df[[f"n{i}_conf" for i in range(num_neighbors)]].values[0]
        })
        conf_plot = alt.Chart(
            neighbor_conf, title="Connected Companies").mark_bar().encode(
                x="Confidence:Q",
                y=alt.Y("Neighbor:N", sort="-x"),
                tooltip=["Neighbor",
                         alt.Tooltip("Confidence", format=".3f")],
                color=alt.Color(
                    "Confidence:Q", scale=alt.Scale(),
                    legend=None)).properties(height=25 * num_neighbors +
                                             100).configure_axis(grid=False)
        st.altair_chart(conf_plot, use_container_width=True)
Example #12
def create_exploratory_visualisation(trial_id,
                                     directory,
                                     vis_data_file,
                                     match_data_file,
                                     violations_data_file,
                                     props_data_file,
                                     baseline_label='baseline',
                                     verde_label='verde'):
    """
    Uses altair to generate the exploratory visualisation.
    :param trial_id:
    :param directory:
    :param vis_data_file:
    :param match_data_file:
    :param violations_data_file:
    :param props_data_file:
    :param baseline_label:
    :param verde_label:
    :return:
    """

    vl_viewer = f'{trial_id}_view_one_vl.html?vl_json='

    # common data and transforms for first layer with marks for each vis model
    base = alt.Chart(os.path.basename(vis_data_file)).transform_calculate(
        rank="format(datum.rank,'03')",
        link=f"'{vl_viewer}' + datum.vl_spec_file").properties(
            width=250, height=alt.Step(30), title='visualisation rankings')

    # add a selectable square for each vis model
    select_models = alt.selection_multi(fields=['set', 'rank'])
    select_brush = alt.selection_interval()
    squares = base.mark_square(size=150).encode(
        alt.X('set:O',
              axis=alt.Axis(labelAngle=0,
                            title=None,
                            orient='top',
                            labelPadding=5)),
        alt.Y('rank:O', axis=alt.Axis(title=None)),
        tooltip=['set:N', 'rank:N', 'cost:Q'],
        opacity=alt.Opacity('has_match:O', legend=None),
        color=alt.condition(
            select_models | select_brush, alt.value('steelblue'),
            alt.value('lightgray'))).add_selection(select_models,
                                                   select_brush).interactive()

    # add a small circle with the hyperlink to the actual vis.
    # Shame that xOffset is not an encoding channel, so we have to do it in two steps...
    def make_circles(vis_set, offset):
        return base.transform_filter(datum.set == vis_set).mark_circle(
            size=25,
            xOffset=offset,
        ).encode(alt.X('set:O',
                       axis=alt.Axis(labelAngle=0,
                                     title=None,
                                     orient='top',
                                     labelPadding=5)),
                 alt.Y('rank:O'),
                 tooltip=['link:N'],
                 href='link:N',
                 color=alt.condition(select_models | select_brush,
                                     alt.value('steelblue'),
                                     alt.value('lightgray'))).interactive()

    baseline_circles = make_circles(baseline_label, -15)
    verde_circles = make_circles(verde_label, 15)

    # next layer is match lines, handle case of no matches
    if match_data_file:
        col_domain = ['not', 'with_equal_cost', 'with_different_cost']
        col_range_ = ['steelblue', 'green', 'red']
        match_lines = alt.Chart(
            os.path.basename(match_data_file)).mark_line().transform_calculate(
                rank="format(datum.rank,'03')").encode(
                    alt.X('set:O',
                          axis=alt.Axis(labelAngle=0,
                                        title=None,
                                        orient='top',
                                        labelPadding=5)),
                    alt.Y('rank:O'),
                    detail=['match:N', 'match_type:N'],
                    strokeDash=alt.StrokeDash(
                        'match_type:N',
                        scale=alt.Scale(domain=['verde_addition', 'exact'],
                                        range=[[5, 4], [1, 0]]),
                        legend=alt.Legend(orient='bottom')),
                    color=alt.condition(
                        select_models | select_brush,
                        alt.Color('crossed:N',
                                  scale=alt.Scale(domain=col_domain,
                                                  range=col_range_),
                                  legend=alt.Legend(orient='bottom')),
                        alt.value('lightgray')))
    else:
        match_lines = None

    # rules to connect models with the same cost
    cost_rules = base.mark_rule(strokeWidth=2).transform_aggregate(
        min_rank='min(rank)', max_rank='max(rank)',
        groupby=['set', 'cost'
                 ]).encode(alt.X('set:O',
                                 axis=alt.Axis(labelAngle=0,
                                               title=None,
                                               orient='top',
                                               labelPadding=5)),
                           alt.Y('min_rank:O'),
                           alt.Y2('max_rank:O'),
                           color=alt.condition(select_models | select_brush,
                                               alt.value('steelblue'),
                                               alt.value('lightgray')),
                           tooltip=['cost:Q', 'min_rank:O',
                                    'max_rank:O']).interactive()

    rank_chart = baseline_circles + verde_circles

    if match_lines:
        rank_chart = rank_chart + match_lines

    rank_chart = rank_chart + cost_rules + squares

    # chart to show violation occurrences and weights for selected vis models across sets
    def make_violation_chart(dimension, width_step):
        return alt.Chart(os.path.basename(violations_data_file)).mark_circle(
            color='red', ).transform_calculate(
                rank="format(datum.rank,'03')", ).transform_filter(
                    select_models).transform_filter(select_brush).encode(
                        x=alt.X(f'{dimension}:N',
                                axis=alt.Axis(labelAngle=0,
                                              title=None,
                                              orient='top',
                                              labelPadding=5)),
                        y=alt.Y('violation:N', axis=alt.Axis(title=None)),
                        size=alt.Size('num:Q', legend=None),
                        opacity=alt.Opacity('weight:Q',
                                            scale=alt.Scale(range=[0, 1]),
                                            legend=None),
                        tooltip=[
                            'set:N', 'rank:Q', 'violation:N', 'num:Q',
                            'weight:Q', 'cost_contrib:Q'
                        ]).properties(
                            width=alt.Step(width_step),
                            title=f'soft rule violations (x-{dimension})'
                        ).interactive()

    violation_set_chart = make_violation_chart('set', 40)
    violation_rank_chart = make_violation_chart('rank', 30)

    # chart to show prop occurrences for selected vis models across sets
    def make_prop_chart(dimension, width_step):
        return alt.Chart(os.path.basename(props_data_file)).mark_circle(
            size=50, color='green').transform_calculate(
                rank="format(datum.rank,'03')").transform_filter(
                    select_models).transform_filter(select_brush).encode(
                        x=alt.X(f'{dimension}:N',
                                axis=alt.Axis(labelAngle=0,
                                              title=None,
                                              orient='top',
                                              labelPadding=5)),
                        y=alt.Y('prop:N', axis=alt.Axis(title=None)),
                        tooltip=['prop:N']).properties(
                            width=alt.Step(width_step),
                            title=f'specification terms (x-{dimension})'
                        ).interactive()

    prop_set_chart = make_prop_chart('set', 40)
    prop_rank_chart = make_prop_chart('rank', 30)

    # glue them all together
    top_chart = rank_chart | violation_set_chart | prop_set_chart
    bottom_chart = violation_rank_chart | prop_rank_chart
    chart = top_chart & bottom_chart
    # put a timestamp
    ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    chart = chart.properties(title=f'{trial_id} {ts}')

    file_name = os.path.join(directory, 'vegalite',
                             f'{trial_id}_view_compare.html')
    logging.info(f'writing comparison visualisation to {file_name}')
    chart.save(file_name)
Example #13
def make_plot(inpath):
    activity = infra.pd.read_parquet(inpath)

    # Drop users that have been active less than a week.
    activity = activity.loc[activity["days_since_first_active"] >= 7, ]

    # Drop users active for less than one day
    activity = activity.loc[activity["days_active"] >= 1, ]

    # take the minimum of days online and days active, since active is
    # partial-day aware, but online rounds up to whole days. Can be up to two
    # days off if the user joined late in the day and was last active early.
    activity["optimistic_online_ratio"] = (
        np.minimum(activity["days_online"], activity["days_active"]) /
        (activity["days_active"] - activity["outage_impact_days"]))
    print(activity)

    # Compute a CDF since the specific user does not matter
    optimistic_cdf_frame = compute_cdf(activity,
                                       value_column="optimistic_online_ratio",
                                       base_column="user")
    optimistic_cdf_frame = optimistic_cdf_frame.rename(
        columns={"optimistic_online_ratio": "value"})
    optimistic_cdf_frame["type"] = "Optimistic (ignore outages)"

    # take the minimum of days online and days active, since active is
    # partial-day aware, but online rounds up to whole days. Can be up to two
    # days off if the user joined late in the day and was last active early.
    activity["observed_online_ratio"] = np.minimum(
        activity["days_online"],
        activity["days_active"]) / activity["days_active"]
    print(activity)

    # Compute a CDF since the specific user does not matter
    observed_cdf_frame = compute_cdf(activity,
                                     value_column="observed_online_ratio",
                                     base_column="user")
    observed_cdf_frame = observed_cdf_frame.rename(
        columns={"observed_online_ratio": "value"})
    observed_cdf_frame["type"] = "Observed"

    df = pd.concat([optimistic_cdf_frame, observed_cdf_frame])

    alt.Chart(df).mark_line(interpolate='step-after', clip=True).encode(
        x=alt.X('value:Q',
                scale=alt.Scale(type="linear", domain=(0, 1.00)),
                title="Online Days / Active Days"),
        y=alt.Y(
            'cdf',
            title="CDF of Users N={}".format(len(activity)),
            scale=alt.Scale(type="linear", domain=(0, 1.0)),
        ),
        color=alt.Color(
            "type",
            legend=alt.Legend(title=None),
        ),
        strokeDash=alt.StrokeDash(
            "type",
            legend=alt.Legend(
                orient="top-left",
                fillColor="white",
                labelLimit=500,
                padding=5,
                strokeColor="black",
            ),
        ),
    ).properties(
        width=500,
        height=200,
    ).save("renders/rate_online_when_active_cdf.png", scale_factor=2.0)
Example #14
def make_plot(infile):
    early_users = infra.pd.read_parquet(
        "data/clean/initial_user_balances_INDEX_none.parquet")
    registered_users = early_users.assign(timestamp=infra.constants.MIN_DATE)

    transactions = infra.pd.read_parquet(
        "data/clean/transactions_DIV_none_INDEX_timestamp.parquet"
    ).reset_index()

    registered_users = pd.concat([registered_users, transactions]).sort_values(
        "timestamp").groupby("user").first()
    registered_users = registered_users.reset_index().sort_values(
        "timestamp").reset_index()
    registered_users = registered_users.assign(temp=1)
    registered_users["count"] = registered_users["temp"].cumsum()
    registered_users = registered_users.drop(
        ["temp", "user"], axis="columns").rename(columns={"count": "user"})
    registered_users["day"] = registered_users["timestamp"].dt.floor("d")

    # Generate a dense dataframe with all days
    date_range = pd.DataFrame({
        "day":
        pd.date_range(infra.constants.MIN_DATE,
                      infra.constants.MAX_DATE,
                      freq="1D")
    })
    registered_users = date_range.merge(
        registered_users,
        how="left",
        left_on="day",
        right_on="day",
    ).fillna(method="ffill").dropna()

    user_days = infra.pd.read_parquet(infile)

    active_users = user_days.groupby("day")["user"].nunique()
    active_users = active_users.to_frame().reset_index()

    # Group weekly to capture the total number of unique users across the entire week and account for intermittent use.
    weekly_users = user_days.groupby(pd.Grouper(
        key="day", freq="W-MON"))["user"].nunique()
    weekly_users = weekly_users.to_frame().reset_index().rename(
        columns={"user": "******"})
    week_range = pd.DataFrame({
        "day":
        pd.date_range(infra.constants.MIN_DATE,
                      infra.constants.MAX_DATE,
                      freq="W-MON")
    })
    weekly_users = weekly_users.merge(week_range, on="day", how="outer")
    weekly_users = weekly_users.fillna(0)

    monthly_users = user_days.groupby(pd.Grouper(key="day",
                                                 freq="M"))["user"].nunique()
    monthly_users = monthly_users.to_frame().reset_index().rename(
        columns={"user": "******"})
    month_range = pd.DataFrame({
        "day":
        pd.date_range(infra.constants.MIN_DATE,
                      infra.constants.MAX_DATE,
                      freq="M")
    })
    monthly_users = monthly_users.merge(month_range, on="day", how="outer")
    monthly_users = monthly_users.fillna(0)

    # Join the active and registered users together
    users = active_users.merge(registered_users,
                               how="right",
                               left_on="day",
                               right_on="day",
                               suffixes=('_active', '_registered'))
    users = users.merge(weekly_users, how="outer", on="day")
    users = users.merge(monthly_users, how="outer", on="day")

    # For cohorts with no active users, fill zero.
    users["user_active"] = users["user_active"].fillna(value=0)

    users = users.rename(
        columns={
            "day": "date",
            "user_active": "Unique Daily Online",
            "user_registered": "Registered",
            "week_unique_users": "Unique Weekly Online",
            "month_unique_users": "Unique Monthly Online"
        })
    users = users.set_index("date").sort_index()
    users["Registered"] = users["Registered"].fillna(method="ffill")
    users["Unique Weekly Online"] = users["Unique Weekly Online"].fillna(
        method="bfill")
    users["Unique Monthly Online"] = users["Unique Monthly Online"].fillna(
        method="bfill")
    users = users.reset_index()

    # Limit graphs to the study period
    users = users.loc[users["date"] < infra.constants.MAX_DATE]

    # Compute a rolling average
    users["Active 7-Day Average"] = users["Unique Daily Online"].rolling(
        window=7, ).mean()

    # Get the data in a form that is easily plottable
    users = users.melt(id_vars=["date"],
                       value_vars=[
                           "Registered", "Unique Monthly Online",
                           "Unique Weekly Online", "Unique Daily Online"
                       ],
                       var_name="user_type",
                       value_name="num_users")
    # Drop the rolling average... it wasn't useful
    # users = users.melt(id_vars=["date"], value_vars=["Active", "Registered", "Active 7-Day Average", "Unique Weekly Active"], var_name="user_type", value_name="num_users")
    # Reset the types of the dataframe
    types = {"date": "datetime64", "num_users": "int64"}
    # Required since some rolling average entries are NaN before the average window is filled.
    users = users.dropna()
    users = users.astype(types)

    users = users.sort_values(["date", "num_users"])
    label_order = {
        "Registered": 1,
        "Unique Monthly Online": 2,
        "Unique Weekly Online": 3,
        "Unique Daily Online": 4,
    }
    # Mergesort is stably implemented : )
    users = users.sort_values(
        ["user_type"],
        key=lambda col: col.map(lambda x: label_order[x]),
        kind="mergesort",
    )
    users = users.reset_index()

    altair.Chart(users).mark_line(interpolate='step-after').encode(
        x=altair.X(
            "date:T",
            title="Time",
            axis=altair.Axis(
                labelSeparation=5,
                labelOverlap="parity",
            ),
        ),
        y=altair.Y(
            "num_users",
            title="User Count",
        ),
        color=altair.Color(
            "user_type",
            title="",
            sort=None,
            legend=altair.Legend(
                orient="top-left",
                fillColor="white",
                labelLimit=500,
                padding=10,
                strokeColor="black",
            ),
        ),
        strokeDash=altair.StrokeDash(
            "user_type",
            sort=None,
        ),
    ).properties(width=500).save("renders/users_per_week.png", scale_factor=2)