Ejemplo n.º 1
0
def plot_district_age_distribution(percentiles,
                                   ylabel,
                                   fmt,
                                   phi=50,
                                   vax_policy="random",
                                   N_jk=None,
                                   n=5,
                                   district_spacing=1.5,
                                   age_spacing=0.1,
                                   rotation=0):
    """Plot per-age-bin outcome percentiles for the first ``n`` districts.

    Each district is drawn as a cluster of seven error bars, one per age bin,
    iterated from bin 6 down to bin 0. Each cluster is centred at
    ``district_spacing * i`` and the bins are spread by ``age_spacing``.
    Values are multiplied by ``USD`` and optionally divided by the bin's
    population.

    Args:
        percentiles: mapping keyed by ``(district, phi, vax_policy)``. Each
            value is indexed as ``ylls[row, age_bin]``. Row 1 is plotted as
            the point; rows 0 and 2 are the lower and upper error-bar ends.
        ylabel: y-axis label text.
        fmt: matplotlib format string for the error-bar markers.
        phi, vax_policy: select which entry of ``percentiles`` is plotted.
        N_jk: optional table indexed as ``N_jk[f"N_{age_bin}"][district]``.
            When given, every value is divided by it. ``None`` leaves values
            unscaled.
        n: maximum number of districts, taken from the head of
            ``districts_to_run.index``.
        district_spacing: horizontal distance between district clusters.
        age_spacing: horizontal distance between age bins within a cluster.
        rotation: rotation of the district tick labels.
    """
    fig = plt.figure()
    district_ordering = list(districts_to_run.index)[:n]
    # Use the number of districts actually available, so tick positions and
    # tick labels always have the same length.
    n_shown = len(district_ordering)
    for (i, district) in enumerate(district_ordering):
        ylls = percentiles[district, phi, vax_policy]
        for j in range(7):
            age_bin = 6 - j
            # Compare with None explicitly: N_jk is indexed like a DataFrame,
            # and the truth value of a DataFrame raises ValueError.
            divisor = N_jk[f"N_{age_bin}"][district] if N_jk is not None else 1
            lo, md, hi = ylls[0, age_bin], ylls[1, age_bin], ylls[2, age_bin]
            plt.errorbar(x=[district_spacing * i + age_spacing * (j - 3)],
                         y=md * USD / divisor,
                         yerr=[[(md - lo) * USD / divisor],
                               [(hi - md) * USD / divisor]],
                         fmt=fmt,
                         color=age_group_colors[age_bin],
                         figure=fig,
                         # label only the first cluster so each bin appears once in the legend
                         label=None if i > 0 else age_bin_labels[age_bin],
                         ms=12,
                         elinewidth=5)
    # Ticks sit at the cluster centres, so they must follow district_spacing.
    plt.xticks([district_spacing * k for k in range(n_shown)],
               district_ordering,
               rotation=rotation,
               fontsize="20")
    plt.yticks(fontsize="20")
    plt.legend(title="age bin",
               title_fontsize="20",
               fontsize="20",
               ncol=7,
               loc="lower center",
               bbox_to_anchor=(0.5, 1))
    ymin, ymax = plt.ylim()
    if n_shown > 1:
        # separators midway between adjacent district clusters
        plt.vlines(x=[district_spacing * (k + 0.5) for k in range(n_shown - 1)],
                   ymin=ymin,
                   ymax=ymax,
                   color="gray",
                   alpha=0.5,
                   linewidths=2)
    # re-apply the limits captured before drawing the separators
    plt.ylim(ymin, ymax)
    plt.gca().grid(False, axis="x")
    plt.PlotDevice().title(f"\n{vax_policy} demand curves").ylabel(
        f"{ylabel}\n")
Ejemplo n.º 2
0
def plot_mobility(series, label, stringency = None, until = None, annotation = "Google Mobility Data; baseline mobility measured from Jan 3 - Feb 6"):
    """Plot smoothed Google mobility categories with the national lockdown shaded.

    Draws six smoothed "% change from baseline" series from ``series``. When
    ``stringency`` is given, it also draws the India and US lockdown
    stringency indices on a twin y-axis. The right x-limit is taken from, in
    order of preference, ``until``, the latest stringency date, or the last
    mobility date.

    ``label`` and ``annotation`` are accepted but are not used by this
    function.
    """
    category_labels = (
        ("retail_and_recreation", "Retail/Recreation"),
        ("grocery_and_pharmacy", "Grocery/Pharmacy"),
        ("parks", "Parks"),
        ("transit_stations", "Transit Stations"),
        ("workplaces", "Workplaces"),
        ("residential", "Residential"),
    )
    for (category, legend_label) in category_labels:
        column = getattr(series, f"{category}_percent_change_from_baseline")
        plt.plot(series.date, smoothed(column), label=legend_label)

    mobility_axes = plt.gca()
    if stringency is not None:
        # Stringency indices go on a secondary y-axis with their own legend.
        plt.sca(mobility_axes.twinx())
        india = stringency.query("CountryName == 'India'")
        usa = stringency.query("(CountryName == 'United States') & (RegionName.isnull())", engine="python")
        plt.plot(india.Date, india.StringencyIndex, 'k--', alpha=0.6, label="IN Measure Stringency")
        plt.plot(usa.Date, usa.StringencyIndex, 'k.', alpha=0.6, label="US Measure Stringency")
        plt.PlotDevice().ylabel("lockdown stringency index", rotation=-90, labelpad=50)
        plt.legend()
        plt.sca(mobility_axes)
    plt.legend(loc="lower right")

    # shade the national lockdown window and label it near the bottom
    lockdown_start = pd.to_datetime("March 24, 2020")
    lockdown_end = pd.to_datetime("June 1, 2020")
    plt.fill_betweenx((-100, 60), lockdown_start, lockdown_end, color="black", alpha=0.05, zorder=-1)
    plt.text(s="national lockdown", x=pd.to_datetime("April 27, 2020"), y=-90, fontdict=plt.theme.note, ha="center", va="top")
    plt.PlotDevice().xlabel("\ndate").ylabel("% change in mobility\n")
    plt.ylim(-100, 60)

    if until:
        x_right = pd.Timestamp(until)
    else:
        x_right = series.date.iloc[-1] if stringency is None else stringency.Date.max()
    plt.xlim(left=series.date.iloc[0], right=x_right)
Ejemplo n.º 3
0
def plot_state_age_distribution(percentiles,
                                ylabel,
                                fmt,
                                district_spacing=1.5,
                                n=5,
                                age_spacing=0.1,
                                rotation=0,
                                ymin=0,
                                ymax=1000):
    """Plot per-age-bin outcome percentiles for the ``n`` highest-ranked states.

    States are ranked, descending, by the maximum of row 0 of their
    percentile array. The top ``n`` are drawn left to right, each as a
    cluster of seven error bars (age bins 6 down to 0).

    Args:
        percentiles: mapping of state -> 2-D array indexed ``[row, age_bin]``.
            Row 0 is plotted as the point; rows 1 and 2 are the lower and
            upper error-bar ends. These are presumably median / lower / upper
            percentiles; confirm against the producer.
        ylabel: y-axis label text.
        fmt: matplotlib format string for the error-bar markers.
        district_spacing: horizontal distance between state clusters.
        n: maximum number of states to show.
        age_spacing: horizontal distance between age bins within a cluster.
        rotation: rotation of the state tick labels.
        ymin, ymax: fixed y-axis limits; also the extent of the separators.
    """
    fig = plt.figure()
    state_ordering = sorted(percentiles.keys(),
                            key=lambda k: percentiles[k][0].max(),
                            reverse=True)
    # Only the plotted states get tick labels. Passing every state's name for
    # n tick positions produces a label/tick length mismatch.
    shown_states = state_ordering[:n]
    for (i, state) in enumerate(shown_states):
        ylls = percentiles[state]
        for j in range(7):
            age_bin = 6 - j
            md, lo, hi = ylls[0, age_bin], ylls[1, age_bin], ylls[2, age_bin]
            plt.errorbar(x=[district_spacing * i + age_spacing * (j - 3)],
                         y=md,
                         yerr=[[md - lo], [hi - md]],
                         fmt=fmt,
                         color=agebin_colors[age_bin],
                         figure=fig,
                         # label only the first cluster so each bin appears once in the legend
                         label=None if i > 0 else agebin_labels[age_bin],
                         ms=12,
                         elinewidth=5)
    n_shown = len(shown_states)
    # Ticks sit at the cluster centres, so they must follow district_spacing.
    plt.xticks([district_spacing * k for k in range(n_shown)],
               shown_states,
               rotation=rotation,
               fontsize="20")
    plt.yticks(fontsize="20")
    plt.legend(fontsize="20",
               ncol=7,
               loc="lower center",
               bbox_to_anchor=(0.5, 1))
    if n_shown > 1:
        # separators midway between adjacent state clusters
        plt.vlines(x=[district_spacing * (k + 0.5) for k in range(n_shown - 1)],
                   ymin=ymin,
                   ymax=ymax,
                   color="gray",
                   alpha=0.5,
                   linewidths=4)
    plt.ylim(ymin, ymax)
    plt.gca().grid(False, axis="x")
    plt.PlotDevice().ylabel(f"{ylabel}\n")
Ejemplo n.º 4
0
def outcomes_per_policy(percentiles,
                        metric_label,
                        fmt,
                        phis=(25, 50, 100, 200),
                        reference=(25, "no_vax"),
                        reference_color=no_vax_color,
                        vax_policies=("contact", "random", "mortality"),
                        policy_colors=(
                            contactrate_vax_color, random_vax_color,
                            mortality_vax_color
                        ),
                        policy_labels=(
                            "contact rate priority", "random assignment",
                            "mortality priority"
                        ),
                        spacing=0.2):
    """Plot an outcome metric for each (phi, vaccination policy) pair.

    A reference point, labelled "no vaccination", is drawn at x = 0 together
    with a dotted horizontal line at its central value. Each phi then gets a
    group of error bars, one per policy, centred on integer tick ``i``
    (1-based).

    Args:
        percentiles: mapping keyed by ``reference`` and by ``(phi, policy)``.
            Each value unpacks as ``(md, lo, hi)``: the plotted point and the
            lower and upper error-bar ends.
        metric_label: y-axis label text.
        fmt: matplotlib format string for the markers.
        phis: phi values, one x-group each, in plotting order.
        reference: key of the reference (no vaccination) entry.
        reference_color: colour of the reference point and line.
        vax_policies, policy_colors, policy_labels: parallel sequences
            describing each policy within a group.
        spacing: horizontal distance between policies inside a group.
    """
    fig = plt.figure()
    n_phis = len(phis)

    md, lo, hi = percentiles[reference]
    *_, bars = plt.errorbar(x=[0],
                            y=[md],
                            yerr=[[md - lo], [hi - md]],
                            figure=fig,
                            fmt=fmt,
                            color=reference_color,
                            label="no vaccination",
                            ms=12,
                            elinewidth=5)
    for bar in bars:
        bar.set_alpha(0.5)
    # dotted reference level spanning every phi group (xmax == 5 for 4 phis)
    plt.hlines(md,
               xmin=-1,
               xmax=n_phis + 1,
               linestyles="dotted",
               colors=reference_color)

    # Offset policies symmetrically around the group's integer tick.
    centre = (len(vax_policies) - 1) / 2
    for (i, phi) in enumerate(phis, start=1):
        for (j, (vax_policy, color, label)) in enumerate(
                zip(vax_policies, policy_colors, policy_labels)):
            md, lo, hi = percentiles[phi, vax_policy]
            *_, bars = plt.errorbar(x=[i + spacing * (j - centre)],
                                    y=[md],
                                    yerr=[[md - lo], [hi - md]],
                                    figure=fig,
                                    fmt=fmt,
                                    color=color,
                                    # i starts at 1: label the first group only
                                    label=label if i == 1 else None,
                                    ms=12,
                                    elinewidth=5)
            for bar in bars:
                bar.set_alpha(0.5)

    plt.legend(ncol=4,
               fontsize="20",
               loc="lower center",
               bbox_to_anchor=(0.5, 1))
    # Raw f-string: "\phi" is LaTeX, not a Python escape sequence.
    plt.xticks(range(n_phis + 1),
               [rf"$\phi = {phi}$%" for phi in (0, *phis)],
               fontsize="20")
    plt.yticks(fontsize="20")
    plt.PlotDevice().ylabel(f"{metric_label}\n")
    plt.gca().grid(False, axis="x")
    ymin, ymax = plt.ylim()
    # separators between adjacent groups
    plt.vlines(x=[0.5 + k for k in range(n_phis)],
               ymin=ymin,
               ymax=ymax,
               color="gray",
               alpha=0.5,
               linewidths=2)
    plt.ylim(ymin, ymax)
    plt.xlim(-0.5, n_phis + 1.5)
Ejemplo n.º 5
0
                    figure=fig,
                    zorder=10,
                    linewidth=2)
# Finish the per-capita infection-rate panel: tick styling, a single-row
# legend above the axes, date tick formatting, axis limits and labels.
plt.yticks(fontsize="20")
plt.xticks(rotation=0, fontsize="20")
plt.legend(
    [scatter_TN, plot_TN, scatter_TT, plot_TT],
    ["Tamil Nadu (raw)", "Tamil Nadu (smoothed)",
     "India (raw)", "India (smoothed)"],
    loc="lower center",
    bbox_to_anchor=(0.5, 1),
    ncol=4,
    fontsize="20",
    framealpha=1,
    handlelength=0.75,
)
plt.gca().xaxis.set_major_formatter(plt.bY_FMT)
plt.gca().xaxis.set_minor_formatter(plt.bY_FMT)
plt.xlim(left=pd.Timestamp("March 1, 2020"), right=pd.Timestamp("April 15, 2021"))
plt.ylim(bottom=0)
plt.PlotDevice().xlabel("\ndate").ylabel("per-capita infection rate\n")
plt.show()

# 1B: per capita vaccination rates
# Daily series from Jan 1, 2021 through simulation_start. Missing days are
# filled with 0, and 2021-03-15 is dropped (the original author notes it
# holds a NaN).
vax = (
    load_vax_data()
    .reindex(
        pd.date_range(start=pd.Timestamp("Jan 1, 2021"), end=simulation_start, freq="D"),
        fill_value=0,
    )
    [pd.Timestamp("Jan 1, 2021"):simulation_start]
    .drop(labels=[pd.Timestamp("2021-03-15")])
)

plt.plot(vax.index,
         vax["Tamil Nadu"] / N_TN,
Ejemplo n.º 6
0
# data prep
# Load the timeseries JSON, flatten its nested records into columns, and
# replace missing values with 0.
with (data/'timeseries.json').open("rb") as fp:
    df = flat_table.normalize(pd.read_json(fp)).fillna(0)
# Split dotted column names (e.g. "a.b.c") into a column MultiIndex.
df.columns = df.columns.str.split('.', expand = True)
# NOTE(review): assumes the flattened "index" column lands under a
# None/NaN sub-level after the split; this extracts it as a 1-D date array.
dates = np.squeeze(df["index"][None].values)
# Index rows by those dates, move column levels 1 and 2 into the row index,
# and drop the "UN" column (presumably an unassigned-state code; confirm).
df = df.drop(columns = "index").set_index(dates).stack([1, 2]).drop("UN", axis = 1)

# National-level rows (no sub-region) of the Google mobility data.
series = mobility[mobility.sub_region_1.isna()]

# Figure 1: retail/recreation mobility with the lockdown window shaded on the
# left axis, and smoothed national daily confirmed cases on a twin axis.
plt.plot(series.date, smoothed(series.retail_and_recreation_percent_change_from_baseline), label="Retail/Recreation")
plt.fill_betweenx((-100, 60), pd.to_datetime("March 24, 2020"), pd.to_datetime("June 1, 2020"), color="black", alpha=0.05, zorder=-1)
plt.text(s="national lockdown", x=pd.to_datetime("April 27, 2020"), y=-20, fontdict=plt.note_font, ha="center", va="top")
plt.ylim(-100, 10)
plt.xlim(series.date.min(), series.date.max())
plt.legend(loc='upper right')
lax = plt.gca()
plt.sca(lax.twinx())
tt_daily_confirmed = df["TT"][:, "delta", "confirmed"]
plt.plot(tt_daily_confirmed.index, smoothed(tt_daily_confirmed.values), label="Daily Cases", color=plt.PRED_PURPLE)
plt.legend(loc='lower right')
plt.PlotDevice().ylabel("new cases", rotation=-90, labelpad=50)
plt.sca(lax)
(plt.PlotDevice()
    .title("\nIndia Mobility and Case Count Trends")
    .annotate("Google Mobility Data + Covid19India.org")
    .xlabel("\ndate")
    .ylabel("% change in mobility\n"))
plt.show()

# Figure 2: the same mobility line and lockdown shading, without cases.
plt.plot(series.date, smoothed(series.retail_and_recreation_percent_change_from_baseline), label="Retail/Recreation")
plt.fill_betweenx((-100, 60), pd.to_datetime("March 24, 2020"), pd.to_datetime("June 1, 2020"), color="black", alpha=0.05, zorder=-1)
plt.text(s="national lockdown", x=pd.to_datetime("April 27, 2020"), y=-20, fontdict=plt.note_font, ha="center", va="top")
plt.ylim(-100, 10)
Ejemplo n.º 7
0
            for p in tqdm(params)
        }

        # TEV percentiles converted to billions of USD (scaled by USD / 1e9),
        # one error-bar group per phi and vaccination policy, with the
        # phi = 25 "novax" entry as the reference.
        # NOTE(review): the reference key here is "novax"; confirm it matches
        # the keys of TEV_percentiles.
        outcomes_per_policy(
            {k: v * USD / (1e9)
             for (k, v) in TEV_percentiles.items()},
            "TEV (USD, billions)",
            "D",
            reference=(25, "novax"),
            phis=[25, 50, 100, 200],
            vax_policies=["contact", "random", "mortality"],
            policy_colors=[
                contactrate_vax_color, random_vax_color, mortality_vax_color
            ],
            policy_labels=["contact rate", "random", "mortality"])
        # show plain y tick values rather than an offset-based notation
        plt.gca().ticklabel_format(axis="y", useOffset=False)
        plt.show()

    ## 2D: state x age
    if "2D" in figs_to_run or "TEV_state_age" in figs_to_run or run_all:
        # focus_state_TEV = {
        #     state: aggregate_dynamic_percentiles_by_age(src, f"total_TEV_{state}*phi50_random.npz", sum_axis = 0, pct_axis = 0)
        #     for state in tqdm([state_name_lookup[_] for _ in  focus_states])
        # }
        # placeholder per-state TEV while the aggregation above is disabled
        focus_state_TEV = {state: np.array(0) for state in focus_states}
        # Per-state population in each age bin: sum the N_0..N_6 columns over
        # a state's districts. DataFrame.sum(level=...) was removed in
        # pandas 2.0; groupby(level=0, sort=False).sum() is the equivalent.
        focus_state_agepop = districts_to_run.loc[focus_states].filter(
            regex="N_[0-6]", axis=1).groupby(level=0, sort=False).sum()
        for (state, district) in districts_to_run.loc[focus_states].index:
            state_code = state_name_lookup[state]
            # the district's share of its state's population in each age bin
            state_age_weight = districts_to_run.loc[state, district].filter(
                regex="N_[0-6]", axis=0) / focus_state_agepop.loc[state]
Ejemplo n.º 8
0
    # Urban share in percent: 100 * (1 - mean rural_share), truncated to int.
    # serodist lists Delhi under the name "New Delhi".
    urban_share = int(
        (1 - serodist.loc[state, ("New " if district == "Delhi" else "") +
                          district]["rural_share"].mean()) * 100)
    density = pop_density.loc[state, district].density
    # the district's estimates indexed by their "dates" column, from Feb 1, 2021 on
    rt_data = district_estimates.loc[state, district].set_index(
        "dates")["Feb 1, 2021":]
    # NOTE(review): bounds are passed upper-then-lower and 0.95 is presumably
    # the CI level; confirm against plt.Rt's signature.
    plt.Rt(rt_data.index,
           rt_data.Rt_pred,
           rt_data.RR_CI_upper,
           rt_data.RR_CI_lower,
           0.95,
           yaxis_colors=False,
           ymin=0.5,
           ymax=2.0)
    # keep the legend only on the panel at (x, y) == (4, 1)
    if (x, y) != (4, 1):
        plt.gca().get_legend().remove()
    plt.gca().set_xticks([
        pd.Timestamp("February 1, 2021"),
        pd.Timestamp("March 1, 2021"),
        pd.Timestamp("April 1, 2021")
    ])

    # district name on the left, urbanization and density on the right
    plt.PlotDevice()\
        .l_title(district, fontsize = 12)\
        .r_title(f"{urban_share}% urban, {density}/km$^2$", fontsize = 10)

    # Hide tick labels unless this district is listed in xticks / yticks
    # (these appear to be collections of district names; confirm).
    if district not in xticks:
        plt.gca().set_xticklabels([])
    if district not in yticks:
        plt.gca().set_yticklabels([])
Ejemplo n.º 9
0
                sns.color_palette("YlOrRd_r", n_colors=1 +
                                  len(variants) // 2)))))

# Stacked area chart of variant shares: one band per entry in `variants`,
# stacked in list order.
plt.stackplot(mutations.index,
              *(mutations[v] for v in variants),
              labels=variants,
              alpha=0.75)

# List legend entries in reverse so the last-stacked (top) band comes first.
# The legend sits outside the axes on the right, vertically centred.
handles, labels = plt.gca().get_legend_handles_labels()
lgnd = plt.gca().legend(
    list(reversed(handles)),
    list(reversed(labels)),
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    ncol=1,
    prop={'size': 12},
    framealpha=0,
    handlelength=0.6,
    handletextpad=1,
    labelspacing=1,
)
lgnd.set_title("variant", prop={"size": 14, "family": plt.theme.label["family"]})
Ejemplo n.º 10
0
    # Plot the R_t estimate with its interval for the current district, label
    # it, and save it under figs/ at 600 dpi. The figure is closed, not shown.
    plt.figure()
    # NOTE(review): the interval bounds are passed lower-then-upper here,
    # while other plt.Rt calls in this codebase pass the upper bound first;
    # confirm the argument order plt.Rt expects.
    plt.Rt(dates, Rt_pred, RR_CI_lower, RR_CI_upper, CI)\
        .ylabel("Estimated $R_t$")\
        .xlabel("Date")\
        .title(district)\
        .size(11, 8)\
        .save(figs/f"Rt_est_MP{district}.png", dpi=600, bbox_inches="tight")#\
    #.show()
    plt.close()

from matplotlib.dates import DateFormatter

# month abbreviation above the year on each date tick
formatter = DateFormatter("%b\n%Y")

# Raw vs. smoothed daily "Hospitalized" counts for Maharashtra.
f = notched_smoothing(window=smoothing)
plt.plot(ts.loc["Maharashtra"].index,
         ts.loc["Maharashtra"].Hospitalized,
         label="raw case counts from API",
         color="black")
plt.plot(ts.loc["Maharashtra"].index,
         f(ts.loc["Maharashtra"].Hospitalized),
         label="smoothed, seasonality-adjusted case counts",
         color="black",
         linestyle="dashed",
         alpha=0.5)
(plt.PlotDevice()
    .l_title("daily case counts in Maharashtra")
    .axis_labels(x="date", y="daily cases"))
plt.gca().xaxis.set_major_formatter(formatter)
plt.legend(prop=plt.theme.note, handlelength=1, framealpha=0)
plt.show()
Ejemplo n.º 11
0
            }

            # TEV percentiles in billions of USD (scaled by USD / 1e9) for
            # each phi and vaccination policy, with the phi = 25 "novax"
            # entry as the reference.
            outcomes_per_policy(
                {k: v * USD / (1e9)
                 for (k, v) in TEV_percentiles.items()},
                "TEV (USD, billions)",
                "D",
                reference=(25, "novax"),
                phis=[25, 50, 100, 200],
                vax_policies=["contact", "random", "mortality"],
                policy_colors=[
                    contactrate_vax_color, random_vax_color,
                    mortality_vax_color
                ],
                policy_labels=["contact rate", "random", "mortality"])
            # plain y tick values, fixed figure size, state-code title
            plt.gca().ticklabel_format(axis="y", useOffset=False)
            plt.gcf().set_size_inches((16.8, 9.92))
            plt.PlotDevice().l_title(f"{state_code}: tev")
            # one PNG per (state, district) under dst; then close every figure
            plt.savefig(dst / f"{state_code}_{district}_tev.png")
            plt.close("all")

    # For Bihar only: ensure the per-state output directory exists, then plot
    # the mean of each of the first 16 districts' phi25 "novax" consumption
    # arrays (the c_/cf_ prefix presumably means counterfactual consumption;
    # confirm).
    for (state, code) in [("Bihar", "BR")]:
        dst = dst0 / code
        dst.mkdir(exist_ok=True)
        # for district in simulation_initial_conditions.query(f"state == '{state}'").index.get_level_values(1).unique():
        for district in simulation_initial_conditions.loc[state].index[:16]:
            # "arr_0" is the key np.savez gives the first positional array
            cf_consumption = np.load(
                src / f"c_p0v0{code}_{district}_phi25_novax.npz")['arr_0']
            # average over axis 1; TODO confirm what that axis indexes
            cons_mean = np.mean(cf_consumption, axis=1)
            # NOTE(review): no savefig/close is visible here, so every
            # district draws onto the current axes; confirm that is intended.
            plt.plot(cons_mean)
            plt.PlotDevice().l_title(f"{code} {district}: mean consumption")
Ejemplo n.º 12
0
# Probability of death over time, one column per age category (names taken
# from IN_age_structure's keys), indexed by dt and shown on a log scale.
PrD = (
    pd.DataFrame(prob_deathx).T
    .rename(columns=dict(enumerate(IN_age_structure.keys())))
    .assign(t=dt)
    .set_index("t")
)
PrD.plot()
plt.legend(title="Age category", title_fontsize=18, fontsize=16,
           framealpha=1, handlelength=1)
plt.xlim(right=pd.Timestamp("Jan 01, 2022"))
plt.PlotDevice().ylabel("Probability\n").xlabel("\nDate")
plt.subplots_adjust(left=0.12, bottom=0.12, right=0.94, top=0.96)
# major ticks only: automatic date locator with "Mon DD" labels, gridlines on
plt.gca().xaxis.set_minor_locator(mpl.ticker.NullLocator())
plt.gca().xaxis.set_minor_formatter(mpl.ticker.NullFormatter())
plt.gca().xaxis.set_major_locator(mdates.AutoDateLocator())
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
plt.yticks(fontsize="16")
plt.xticks(fontsize="16")
plt.gca().xaxis.grid(True, which="major")
plt.semilogy()
plt.ylim(bottom=1e-7)
plt.show()

# prob_death as a frame with a daily DatetimeIndex starting at simulation_start
PrD = pd.DataFrame(prob_death).set_index(
    pd.date_range(start=simulation_start, freq="D", periods=len(prob_death)))
plt.plot(PrD, color=TN_color, linewidth=2, label="probability of death")
plt.xlim(left=pd.Timestamp("Jan 01, 2021"), right=pd.Timestamp("Jan 01, 2022"))
# NOTE(review): the label says log-probability, but no log scale is applied
# in these lines; confirm it is set later or that the values are already logs.
plt.PlotDevice().ylabel("log-probability\n")