Code example #1
File: main_figures.py  Project: COVID-IWG/epimargin
def plot_component_breakdowns(color,
                              white,
                              colorlabel,
                              whitelabel,
                              semilogy=False,
                              ylabel="WTP (USD)"):
    fig, ax = plt.subplots()
    ax.bar(range(7),
           white * USD,
           bottom=color * USD,
           color="white",
           edgecolor=age_group_colors,
           linewidth=2,
           figure=fig)
    ax.bar(range(7),
           color * USD,
           color=age_group_colors,
           edgecolor=age_group_colors,
           linewidth=2,
           figure=fig)
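    # the two zero-height bars below exist only to create white/black legend
    # swatches for the stacked components drawn above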
    ax.bar(range(7), [0],
           label=whitelabel,
           color="white",
           edgecolor="black",
           linewidth=2)
    ax.bar(range(7), [0],
           label=colorlabel,
           color="black",
           edgecolor="black",
           linewidth=2)

    plt.xticks(range(7), age_bin_labels, fontsize="20")
    plt.yticks(fontsize="20")
    plt.legend(ncol=4,
               fontsize="20",
               loc="lower center",
               bbox_to_anchor=(0.5, 1))
    plt.PlotDevice().ylabel(f"{ylabel}\n")
    if semilogy: plt.semilogy()
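
A usage sketch for the function above. It depends on module-level names from main_figures.py that are not shown in this excerpt (the epimargin plotting wrapper plt, the USD conversion factor, age_group_colors, and age_bin_labels), so the values below are placeholders and the component arrays are purely illustrative:

import numpy as np
import epimargin.plots as plt

# placeholder stand-ins for the module-level globals used by the function
USD = 1 / 75                                    # assumed INR-to-USD factor
age_group_colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
                    "#9467bd", "#8c564b", "#e377c2"]  # one color per age bin
age_bin_labels = ["0-17", "18-29", "30-39", "40-49", "50-59", "60-69", "70+"]

# hypothetical per-age-bin WTP components (in INR)
component_a = np.array([2.1, 3.4, 3.8, 4.0, 4.2, 3.9, 3.1]) * 1e5
component_b = np.array([1.0, 1.2, 1.5, 1.8, 2.0, 2.2, 2.4]) * 1e5

plot_component_breakdowns(component_a, component_b,
                          colorlabel="component A", whitelabel="component B",
                          semilogy=True)
plt.show()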
Code example #2
            state = state_name_lookup[state]
            try:
                return np.median(
                    np.load(
                        src /
                        f"YLL_{state}_{district}_phi{phi}_{vax_policy}.npz")
                    ['arr_0'])
            except FileNotFoundError:
                # return np.nan
                return 0

        districts = districts_to_run.copy()\
            .assign(YLL = districts_to_run.index.map(load_median_YLL))\
            .assign(YLL_per_mn = lambda df: df["YLL"]/(df["N_tot"]/1e6))

        fig, ax = plt.subplots(1, 1)
        # note: the first binning below is immediately overridden by the second
        scheme = mapclassify.UserDefined(
            districts.YLL_per_mn,
            [0, 125, 250, 400, 600, 900, 1200, 1600, 2500, 5000, 7500])  # ~deciles
        scheme = mapclassify.UserDefined(
            districts.YLL_per_mn,
            [0, 600, 1200, 2000, 2400, 3000, 3600, 4000, 4500, 5550, 9000])  # ~deciles
        districts["category"] = scheme.yb
        india.join(districts["category"].astype(int))\
            .drop(labels = "Andaman And Nicobar Islands")\
            .plot(
            column = "category",
            linewidth = 0.1,
            edgecolor = "k",
Code example #3
        data={
            "dates": dth_dates,
            "Rt_pred": dth_Rt_pred,
            "Rt_CI_upper": dth_Rt_CI_upper,
            "Rt_CI_lower": dth_Rt_CI_lower,
            "T_pred": dth_T_pred,
            "T_CI_upper": dth_T_CI_upper,
            "T_CI_lower": dth_T_CI_lower,
            "total_cases": dth_total_cases[2:],
            "new_cases_ts": dth_new_cases_ts,
        })
    dth_estimates["anomaly"] = dth_estimates["dates"].isin(
        set(dth_anomaly_dates))
    print("  + Rt (dth) today:", inf_Rt_pred[-1])

    fig, axs = plt.subplots(1, 2, sharey=True)
    plt.sca(axs[0])
    plt.Rt(inf_dates, inf_Rt_pred, inf_Rt_CI_lower, inf_Rt_CI_upper, CI)\
        .axis_labels("date", "$R_t$")
    plt.title("estimated from infections",
              loc="left",
              fontdict=plt.theme.label)

    # fig, axs = plt.subplots(3, 1, sharex = True)
    # plt.sca(axs[0])
    # plt.plot(dth_dates, delhi_dD_smoothed[2:], color = "orange")
    # plt.title("d$D$/d$t$", loc = "left", fontdict = plt.theme.label)

    # plt.sca(axs[1])
    # plt.plot(dth_dates, np.diff(delhi_dD_smoothed)[1:], color = "red")
    # plt.title("d$^2D$/d$t^2$", loc = "left", fontdict = plt.theme.label)
Code example #4
File: main.py  Project: COVID-IWG/covid-metrics-infra
def generate_report(state_code: str):
    print(f"Received request for {state_code}.")
    state = state_code_lookup[state_code]
    normalized_state = state.replace(" and ", " And ").replace(" & ", " And ")
    blobs = {
        f"pipeline/est/{state_code}_state_Rt.csv":
        f"/tmp/state_Rt_{state_code}.csv",
        f"pipeline/est/{state_code}_district_Rt.csv":
        f"/tmp/district_Rt_{state_code}.csv",
        f"pipeline/commons/maps/{state_code}.json":
        f"/tmp/state_{state_code}.geojson"
    } if normalized_state not in dissolved_states else {
        f"pipeline/est/{state_code}_state_Rt.csv":
        f"/tmp/state_Rt_{state_code}.csv",
    }
    for (blob_name, filename) in blobs.items():
        bucket.blob(blob_name).download_to_filename(filename)
    print(f"Downloaded estimates for {state_code}.")

    state_Rt = pd.read_csv(f"/tmp/state_Rt_{state_code}.csv",
                           parse_dates=["dates"],
                           index_col=0)

    plt.close("all")
    dates = [pd.Timestamp(date).to_pydatetime() for date in state_Rt.dates]
    plt.Rt(dates, state_Rt.Rt_pred, state_Rt.Rt_CI_lower, state_Rt.Rt_CI_upper, CI)\
        .axis_labels("date", "$R_t$")\
        .title(f"{state}: $R_t$ over time", ha = "center", x = 0.5)\
        .adjust(left = 0.11, bottom = 0.16)
    plt.gcf().set_size_inches(3840 / 300, 1986 / 300)
    plt.savefig(f"/tmp/{state_code}_Rt_timeseries.png")
    plt.close()
    print(f"Generated timeseries plot for {state_code}.")

    # check output is at least 50 KB
    timeseries_size_kb = os.stat(
        f"/tmp/{state_code}_Rt_timeseries.png").st_size / 1000
    print(f"Timeseries artifact size: {timeseries_size_kb} kb")
    assert timeseries_size_kb > 50
    bucket.blob(
        f"pipeline/rpt/{state_code}_Rt_timeseries.png").upload_from_filename(
            f"/tmp/{state_code}_Rt_timeseries.png", content_type="image/png")

    if normalized_state not in (island_states + dissolved_states):
        district_Rt = pd.read_csv(f"/tmp/district_Rt_{state_code}.csv",
                                  parse_dates=["dates"],
                                  index_col=0)
        latest_Rt = district_Rt[district_Rt.dates == district_Rt.dates.max()]\
            .set_index("district")["Rt_pred"].to_dict()
        top10 = [(k, "> 3.0" if v > 3 else f"{v:.2f}") for (k, v) in sorted(
            latest_Rt.items(), key=lambda t: t[1], reverse=True)[:10]]

        gdf = gpd.read_file(f"/tmp/state_{state_code}.geojson")
        gdf["Rt"] = gdf.district.map(latest_Rt)
        fig, ax = plt.subplots()
        fig.set_size_inches(3840 / 300, 1986 / 300)
        plt.choropleth(gdf, title = None, mappable = plt.get_cmap(0.75, 2.5), fig = fig, ax = ax)\
            .adjust(left = 0)
        plt.sca(fig.get_axes()[0])
        plt.PlotDevice(fig).title(f"{state}: $R_t$ by district",
                                  ha="center",
                                  x=0.5)
        plt.axis('off')
        plt.savefig(f"/tmp/{state_code}_Rt_choropleth.png", dpi=300)
        plt.close()
        print(f"Generated choropleth for {state_code}.")

        # check output is at least 100 KB
        choropleth_size_kb = os.stat(
            f"/tmp/{state_code}_Rt_choropleth.png").st_size / 1000
        print(f"Choropleth artifact size: {choropleth_size_kb} kb")
        assert choropleth_size_kb > 100
        bucket.blob(f"pipeline/rpt/{state_code}_Rt_choropleth.png"
                    ).upload_from_filename(
                        f"/tmp/{state_code}_Rt_choropleth.png",
                        content_type="image/png")
    else:
        print(f"Skipped choropleth for {state_code}.")

    # note: top10 is computed in the district-level branch above; an island
    # state that is not also listed in dissolved_states would reach this block
    # with top10 undefined
    if normalized_state not in dissolved_states:
        fig, ax = plt.subplots(1, 1)
        ax.axis('tight')
        ax.axis('off')
        table = ax.table(cellText=top10,
                         colLabels=["district", "$R_t$"],
                         loc='center',
                         cellLoc="center")
        table.scale(1, 2)
        for (row, col), cell in table.get_celld().items():
            if (row == 0):
                cell.set_text_props(fontfamily=plt.theme.label["family"],
                                    fontsize=plt.theme.label["size"],
                                    fontweight="semibold")
            else:
                cell.set_text_props(fontfamily=plt.theme.label["family"],
                                    fontsize=plt.theme.label["size"],
                                    fontweight="light")
        plt.PlotDevice().title(f"{state}: top districts by $R_t$",
                               ha="center",
                               x=0.5)
        plt.savefig(f"/tmp/{state_code}_Rt_top10.png", dpi=600)
        plt.close()
        print(f"Generated top 10 district listing for {state_code}.")

        # check output is at least 50 KB
        top10_size_kb = os.stat(
            f"/tmp/{state_code}_Rt_top10.png").st_size / 1000
        print(f"Top 10 listing artifact size: {top10_size_kb} kb")
        assert top10_size_kb > 50
        bucket.blob(
            f"pipeline/rpt/{state_code}_Rt_top10.png").upload_from_filename(
                f"/tmp/{state_code}_Rt_top10.png", content_type="image/png")
    else:
        print(f"Skipped top 10 district listing for {state_code}.")

    # sleep for 15 seconds to ensure the images finish saving
    time.sleep(15)

    print(f"Uploaded artifacts for {state_code}.")
    return "OK!"
Code example #5
File: amravati_facet.py  Project: COVID-IWG/epimargin
    .rename(columns = lambda _:_.replace("_api", ""))\
    .sort_values(["state", "district"])\
    .set_index(["state", "district"])

yticks = {
    "Surat", "Dhule", "Nashik", "Mumbai", "Pune", "Delhi", "Kolkata", "Chennai"
}

xticks = {
    "Surat", "Narmada", "Mumbai", "Thane", "Pune", "Aurangabad", "Parbhani",
    "Nanded", "Yavatmal", "Chennai"
}

pop_density = pd.read_csv(data / "popdensity.csv").set_index(
    ["state", "district"])
fig, ax_nest = plt.subplots(ncols=ncols, nrows=nrows)
for (j, i) in product(range(nrows), range(ncols)):
    if (i + 1, j + 1) in coords.values():
        continue
    ax_nest[j, i].axis("off")

for ((state, district), (x, y)) in coords.items():
    plt.sca(ax_nest[y - 1, x - 1])
    urban_share = int(
        (1 - serodist.loc[state, ("New " if district == "Delhi" else "") +
                          district]["rural_share"].mean()) * 100)
    density = pop_density.loc[state, district].density
    rt_data = district_estimates.loc[state, district].set_index(
        "dates")["Feb 1, 2021":]
    plt.Rt(rt_data.index,
           rt_data.Rt_pred,
Code example #6
File: max_rt_choro.py  Project: COVID-IWG/epimargin
import epimargin.plots as plt
import pandas as pd
import geopandas as gpd

rt = pd.read_csv("data/india_states_max_Rt.csv")
rt.state = rt.state.str.replace("&", "and")
gdf = gpd.read_file("data/india.json").dissolve("st_nm")

mappable = plt.get_cmap(1, 4, "viridis")
fig, ax = plt.subplots()
gdf["pt"] = gdf["geometry"].centroid
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# ax.set_title(title, loc="left", fontdict=label_font)
# the merge below attaches max_Rt from the CSV to the dissolved state shapes;
# without it, gdf["max_Rt"] is undefined
gdf = gdf.merge(rt, left_on="st_nm", right_on="state")
gdf.plot(color=[mappable.to_rgba(_) for _ in gdf["max_Rt"]],
         ax=ax,
         edgecolors="black",
         linewidth=0.5,
         missing_kwds={
             "color": "dimgray",
             "edgecolor": "white"
         })

for (_, row) in gdf.iterrows():
    label = label_fn(row)
    a1 = ax.annotate(s=f"{label}{Rt_c}",
                     xy=list(row["pt"].coords)[0],
                     ha="center",
                     fontfamily=note_font["family"],
Code example #7
plt.legend(prop=plt.theme.label, handlelength=1, framealpha=0)
plt.PlotDevice()\
    .axis_labels(x = "age group", y = "CFR (log-scaled)")\
    .l_title("CFR in India (adjusted for reporting)")\
    .r_title("source:\nICMR")\
    .adjust(left = 0.11, bottom = 0.15, right = 0.95)
plt.semilogy()
plt.show()

# fig 3
india_data = pd.read_csv(results / "india_data.csv", parse_dates = ["dt"])\
    .query("State == 'TT'")\
    .set_index("dt")\
    .sort_index()

fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)

plt.sca(axs[0, 0])
plt.scatter(india_data.index, india_data["cfr_2week"], color="black", s=2)
plt.title("2-week lag", loc="left", fontdict=plt.theme.label)

plt.sca(axs[0, 1])
plt.scatter(india_data.index, india_data["cfr_maxcor"], color="black", s=2)
plt.title("10-day lag", loc="left", fontdict=plt.theme.label)

plt.sca(axs[1, 0])
plt.scatter(india_data.index, india_data["cfr_1week"], color="black", s=2)
plt.title("1-week lag", loc="left", fontdict=plt.theme.label)

plt.sca(axs[1, 1])
plt.scatter(india_data.index, india_data["cfr_same"], color="black", s=2)
Code example #8
        .rename(columns = schema)\
        .dropna(how = 'all')\
        .query("age.str.strip() != ''", engine = "python")
parse_datetimes(cases.loc[:, "confirmed"])
cases.regency = cases.regency.str.title().map(
    lambda s: regency_names.get(s, s))
cases.age = cases.age.apply(parse_age)
cases = cases.dropna(subset=["age"])
cases["age_bin"] = pd.cut(cases.age,
                          bins=[0] + list(range(20, 80, 10)) + [100])
age_ts = cases[["age_bin",
                "confirmed"]].groupby(["age_bin",
                                       "confirmed"]).size().sort_index()
ss_max_rts = {}

fig, axs = plt.subplots(4, 2, sharex=True, sharey=True)
(dates, Rt_pred, Rt_CI_upper, Rt_CI_lower, T_pred, T_CI_upper, T_CI_lower, total_cases, new_cases_ts, anomalies, anomaly_dates)\
    = analytical_MPVS(age_ts.sum(level = 1), CI = CI, smoothing = notched_smoothing(window = 5), totals = False)
plt.sca(axs.flat[0])
plt.Rt(dates, Rt_pred, Rt_CI_upper, Rt_CI_lower,
       CI).annotate(f"all ages").adjust(left=0.04,
                                        right=0.96,
                                        top=0.95,
                                        bottom=0.05,
                                        hspace=0.3,
                                        wspace=0.15)
r = pd.Series(Rt_pred, index=dates)
ss_max_rts["all"] = r[r.index.month_name() == "April"].max()

for (age_bin,
     ax) in zip(age_ts.index.get_level_values(0).categories, axs.flat[1:]):