Beispiel #1
0
def test_calib():
    '''
    Run one calibrated sim under the 'all_remote' school scenario with no PCR testing.

    The sim parameters come from the first entry of ``calibfile``, with the
    entry's 'index' used as the random seed. A ``schools_manager`` intervention
    built from the scenario is added, the sim is run, the output of
    ``evaluate_sim`` is printed, and the current figure is saved to 'sim.png'.
    When ``debug`` is truthy, people are kept after the run, the 'es'/'ms'/'hs'
    scenario entries get ``verbose=True``, and the 'overview' plot is drawn.

    Returns:
        The run sim.
    '''
    entry = sc.loadjson(calibfile)[0]
    params = sc.dcp(entry['pars'])
    params['rand_seed'] = int(entry['index'])

    scen = generate_scenarios()['all_remote']
    testing = generate_testing()['None']

    # Attach the same testing spec to every school-type entry that has a spec
    # (only the values are needed, so iterate them directly)
    for spec in scen.values():
        if spec is not None:
            spec['testing'] = testing
    scen['testing'] = testing
    scen['es']['verbose'] = scen['ms']['verbose'] = scen['hs']['verbose'] = debug

    sim = cs.create_sim(params, pop_size=pop_size, folder=folder)

    sm = cvsch.schools_manager(scen)
    sim['interventions'] += [sm]

    sim.run(keep_people=debug)  # keep people only when debugging, to allow inspection

    stats = evaluate_sim(sim)
    print(stats)

    if debug:
        sim.plot(to_plot='overview')
    else:
        sim.plot()

    cv.savefig('sim.png')

    return sim
Beispiel #2
0
def make():
    '''
    Run a default covasim Sim with plotting enabled, then save the current figure.

    Returns:
        The value returned by ``cv.savefig`` (presumably the saved filename).
    '''
    print('Making figure')
    default_sim = cv.Sim()
    default_sim.run(do_plot=True)
    return cv.savefig(comments='good figure')
        interval=90,  # Number of days between tick marks
        dateformat='%m/%Y',  # Date format for ticks
        fig_args={'figsize': (14, 8)},  # Size of the figure (x and y)
        axis_args={'left': 0.15},  # Space on left side of plot
    )

    msim.plot_result('r_eff', **plot_customizations)
    #sim.plot_result('r_eff')
    pl.axhline(1.0, linestyle='--', c=[0.8, 0.4, 0.4], alpha=0.8,
               lw=4)  # Add a line for the R_eff = 1 cutoff
    pl.title('')
    pl.savefig('R.pdf')

    msim.plot_result('cum_deaths', **plot_customizations)
    pl.title('')
    cv.savefig('Deaths.pdf')

    msim.plot_result('new_infections', **plot_customizations)
    pl.title('')
    cv.savefig('Infections.pdf')

    msim.plot_result('cum_diagnoses', **plot_customizations)
    pl.title('')
    cv.savefig('Diagnoses.pdf')

##for calibration figures
#msim.plot_result('cum_deaths', interval=20, fig_args={'figsize':(12,7)}, axis_args={'left':0.15})
#pl.title('')
#cv.savefig('Deaths.png')

# msim.plot_result('cum_diagnoses', interval=20, fig_args={'figsize':(12,7)}, axis_args={'left':0.15})
                    scatter=False,
                    color=c)
    ax.set_xlabel('School size (students)')
    ax.set_ylabel('Infection on First Day (%)')
    yt = [0.0, 0.25, 0.50, 0.75]
    ax.set_yticks(yt)
    ax.set_yticklabels([int(100 * t) for t in yt])
    ax.grid(True)

    lmap = {
        'None': 'No PCR testing',
        '2020-10-26': 'One week prior',
        '2020-10-27': 'Six days prior',
        '2020-10-28': 'Five days prior',
        '2020-10-29': 'Four days prior',
        '2020-10-30': 'Three days prior',
        '2020-10-31': 'Two days prior',
        '2020-11-01': 'One day prior',
        '2020-11-02': 'On the first day of school'
    }

    handles, labels = ax.get_legend_handles_labels()
    handles = [handles[-1]] + handles[:-1]
    labels = [lmap[labels[-1]]] + [lmap[l] for l in labels[:-1]]

    ax.legend(handles, labels)
    plt.tight_layout()
    cv.savefig(os.path.join(imgdir, 'PCR_Days_Sweep.png'), dpi=300)

    fig.tight_layout()
Beispiel #5
0
        sims,
        ax[2],
        calib=True,
        label='Cumulative infections\n(modeled)',
        ylabel='',
        flabel=False)
# Add the modeled number of active infections to the existing panel ax[2]
plotter('n_infectious',
        sims,
        ax[2],
        calib=True,
        label='Active infections\n(modeled)',
        ylabel='Estimated infections',
        flabel=False)
pl.legend(loc='upper left', frameon=False)

# d. cumulative deaths
# Manually placed axes: left edge at xgapl + xgapm + dx1, bottom at ygapb, size dx2 x dy
ax[3] = pl.axes([xgapl + xgapm + dx1, ygapb, dx2, dy])
format_ax(ax[3], sim)
plotter('cum_deaths',
        sims,
        ax[3],
        calib=True,
        label='Deaths\n(modeled)',
        ylabel='Cumulative deaths',
        flabel=False)
pl.legend(loc='upper left', frameon=False)
#pl.ylim([0, 10e3])

# Write the assembled calibration figure
cv.savefig(f'{figsfolder}/fig2_calibration.pdf')

# Timing report against the start time T (presumably from an earlier sc.tic())
sc.toc(T)
def plot_reff_combined(num_param_set, date_of_file):
    '''
    Grouped bar chart of mean R_eff per school reopening strategy, with one bar
    per infectivity/susceptibility sensitivity assumption.

    For each sensitivity case in ``sens`` and each strategy in the module-level
    ``strategies``, the R_eff mean and standard-deviation series are fetched with
    ``combine_results_dfs``. Rows are treated as consecutive days starting
    2020-07-01, and the rows from 'Aug 30' up to (but excluding) 'Dec 01' are
    averaged. Bars show the averaged mean, with error bars of half the averaged
    standard deviation.

    Args:
        num_param_set: passed through to ``combine_results_dfs``
        date_of_file: passed through to ``combine_results_dfs`` and used to name
            the outputs ``r_eff_<date_of_file>.png`` and ``r_eff_<date_of_file>.pdf``
    '''
    bar_colors = ['forestgreen', 'mediumpurple', 'tab:orange', 'maroon']

    case = '50_cases_re_0.9'
    rel_trans = 'under10_0.5trans'
    beta_layer = '3xschool_beta_layer'
    rel_sus = 'rel_sus'
    sens = [rel_trans, None, rel_sus]  # beta_layer is currently left out

    name = 'Infectivity/Susceptibility Assumptions'
    sens_labels = {
        None: 'Baseline: children under 10 as infectious as adults; '
              '\nper-contact infectivity in school 20% relative to households;'
              '\n0-20 year olds 33-66% less susceptible than 20+',
        rel_trans: '50% reduced transmission in children under 10',
        beta_layer: 'Same infectivity per contact in schools as households',
        rel_sus: '0-20 year olds as susceptible as 20+'
    }

    # One (days x strategies) frame of means and one of stds per sensitivity case
    mean_frames = []
    std_frames = []
    for sen in sens:
        mean_series = []
        std_series = []
        for strat in strategies:
            mean_series.append(combine_results_dfs(strat, num_param_set, 'r_eff', True, date_of_file, case, sen))
            std_series.append(combine_results_dfs(strat, num_param_set, 'r_eff', False, date_of_file, case, sen))

        df_comb = pd.DataFrame(mean_series).transpose()
        df_comb.columns = strats
        df_comb_std = pd.DataFrame(std_series).transpose()
        df_comb_std.columns = strats

        mean_frames.append(df_comb)
        std_frames.append(df_comb_std)

    # Map 'Mon DD' labels to row positions; row 0 is 1 July 2020
    start = dt.datetime(2020, 7, 1)
    day_index = {}
    for offset in range(len(df_comb)):
        day_index[(start + dt.timedelta(days=offset)).strftime('%b %d')] = offset

    # Average each frame over the analysis window
    window = slice(day_index['Aug 30'], day_index['Dec 01'])
    mean_frames = [frame.iloc[window].mean(axis=0) for frame in mean_frames]
    std_frames = [frame.iloc[window].mean(axis=0) for frame in std_frames]

    positions = np.arange(len(strategy_labels))
    offsets = [-.2, 0, .2]  # per-sensitivity bar shift within each strategy group
    fs = 24

    fig, ax = plt.subplots(figsize=(13, 9))
    fig.subplots_adjust(left=0.07, right=0.99, top=0.96, bottom=0.15)
    for i, sen in enumerate(sens):
        ax.bar(positions + offsets[i], mean_frames[i].values, yerr=std_frames[i].values / 2, width=0.2,
               label=sens_labels[sen], color=bar_colors[i], alpha=0.87)
    ax.axhline(y=1, xmin=0, xmax=1, color='black', ls='--')  # R_eff = 1 reference

    ax.set_ylabel('Effective Reproductive Number', size=fs)
    ax.set_ylim(0.7, 1.4)
    legend = ax.legend(fontsize=fs, title=name)
    legend.set_title(name, prop={'size': fs})
    ax.set_xticks(positions)
    ax.set_xlim(-0.45, 6.45)
    ax.tick_params(labelsize=21)
    ax.set_xticklabels(strategy_labels_brief.values(), fontsize=20)

    cv.savefig(f'r_eff_{date_of_file}.png')
    plt.savefig(f'r_eff_{date_of_file}.pdf')
Beispiel #7
0
    c1 = [0.3, 0.3, 0.6]  # diags
    c2 = [0.6, 0.7, 0.9]  #diags
    pl.bar(X, pos, width=w, label='Data', facecolor=c1)
    pl.bar(XX, mpbest, width=w, label='Model', facecolor=c2)
    for i, ix in enumerate(XX):
        pl.plot([ix, ix], [mplow[i], mphigh[i]], c='k')
    ax[0].set_xticks((X + XX) / 2)
    ax[0].set_xticklabels(x)
    pl.xlabel('Age')
    pl.ylabel('Diagnoses')
    sc.boxoff(ax[0])
    pl.legend(frameon=False, bbox_to_anchor=(0.3, 0.7))

    # Deaths
    ax[1] = pl.axes([xl + dx + xm, yb, dx, dy])
    c1 = [0.5, 0.0, 0.0]  # deaths
    c2 = [0.9, 0.4, 0.3]  # deaths
    pl.bar(X, deaths, width=w, label='Data', facecolor=c1)
    pl.bar(XX, mdbest, width=w, label='Model', facecolor=c2)
    for i, ix in enumerate(XX):
        pl.plot([ix, ix], [mdlow[i], mdhigh[i]], c='k')
    ax[1].set_xticks((X + XX) / 2)
    ax[1].set_xticklabels(x)
    pl.xlabel('Age')
    pl.ylabel('Deaths')
    sc.boxoff(ax[1])
    pl.legend(frameon=False, bbox_to_anchor=(0.3, 0.7))

    plotname = 'uk_stats_by_age_agg.png' if aggregate else 'uk_stats_by_age.png'
    cv.savefig(plotname, dpi=100)
            ax[pn].set_ylabel('R')
    elif pn in range(ncols, ncols * 2):
        plotter('cum_deaths', sims[pn % ncols], ax[pn])
        ax[pn].set_ylim(0, 150_000)
        if (pn % ncols) == 0:
            ax[pn].set_ylabel('Total deaths')
    else:
        plotter('new_infections', sims[pn % ncols], ax[pn])
        ax[pn].set_ylim(0, 250_000)
        if (pn % ncols) == 0:
            ax[pn].set_ylabel('New infections')

    if pn not in range(ncols):
        ax[pn].set_xticklabels([])

# Save the UK school-scenarios grid drawn by the preceding loop
cv.savefig(f'{figsfolder}/fig_UK_school_scens.png', dpi=100)

################################################################################
# ## Fig 3
################################################################################
pl.figure(figsize=(24, 12))
#font_size = 24
#pl.rcParams['font.size'] = font_size

# Subplot sizes
# Gaps/margins for manually placed axes, presumably in figure-fraction units.
# NOTE(review): suffixes look like l=left, m=middle, r=right, b=bottom, t=top;
# confirm against the pl.axes(...) calls that consume them.
xgapl = 0.06
xgapm = 0.1
xgapr = 0.01
ygapb = 0.11
ygapm = 0.1
ygapt = 0.02
Beispiel #9
0
def _summarize_runs(values, low_q, high_q):
    '''Return (median, low_q-quantile, high_q-quantile) of ``values`` across runs (axis 0).'''
    arr = np.array(values)
    return (np.median(arr, axis=0),
            np.quantile(arr, q=low_q, axis=0),
            np.quantile(arr, q=high_q, axis=0))


def _plot_age_panel(position, ages, data_vals, model_best, model_low, model_high,
                    colors, title, ylabel, width=4, offset=2):
    '''
    Draw one "data vs. model by age" bar panel in cell ``position`` of a 4x2 subplot grid.

    Data bars are centred at ``ages - offset`` and model bars at
    ``ages + width - offset``. Each model bar gets a black whisker from
    ``model_low[i]`` to ``model_high[i]``.

    Args:
        position: subplot index within the 4x2 grid (e.g. 7 or 8)
        ages: numpy array of age values from the data (also used as x ticks)
        data_vals: observed values per age bin
        model_best, model_low, model_high: model central value and bounds per age bin
        colors: (data_facecolor, model_facecolor)
        title: panel title
        ylabel: y-axis label
        width: bar width
        offset: horizontal shift applied to both bar series

    Returns:
        The axes drawn on.
    '''
    data_color, model_color = colors
    ax = pl.subplot(4, 2, position)
    model_x = ages + width - offset
    pl.bar(ages - offset, data_vals, width=width, label='Data', facecolor=data_color)
    pl.bar(model_x, model_best, width=width, label='Model', facecolor=model_color)
    for i, ix in enumerate(model_x):
        pl.plot([ix, ix], [model_low[i], model_high[i]], c='k')
    ax.set_xticks(ages.tolist())
    pl.title(title)
    pl.xlabel('Age')
    pl.ylabel(ylabel)
    pl.legend()
    return ax


def plot_calibration(sims, date, do_save=0):
    '''
    Draw the model calibration figure and the model estimates figure for a set of sims.

    Figure 1 (calibration) has three rows: tests and diagnoses per day,
    cumulative diagnoses, and cumulative deaths. A bottom row compares
    diagnoses and deaths by age: the data come from the first sim's analyzer,
    and the model median plus 10th-90th percentile whiskers are taken across
    all sims. Figure 2 (estimates) shows cumulative infections/recoveries,
    active infections, daily infections/recoveries, and R_eff.

    Args:
        sims: non-empty sequence of run sims. Each is assumed to have an
            age-histogram analyzer as ``sim['analyzers'][0]``, exposing ``.data``
            (with 'age', 'cum_diagnoses', 'cum_deaths' columns) and ``.hists``
            (entries with 'diagnosed' and 'dead'). TODO confirm against the analyzer class.
        date: string used in the output filenames
        do_save: if truthy, save the figures as ``calibration_<date>_fig1.png``
            and ``calibration_<date>_fig2.png``
    '''
    # Reference sim for axis formatting and intervention markers. The age loop
    # below deliberately does not rebind this name (it previously leaked the
    # last sim into figure 2).
    ref_sim = sims[0]

    fig1_path = f'calibration_{date}_fig1.png'
    fig2_path = f'calibration_{date}_fig2.png'
    fig_args  = {'figsize': (16, 14)}
    axis_args = {'left': 0.10, 'bottom': 0.05, 'right': 0.95, 'top': 0.93, 'wspace': 0.25, 'hspace': 0.40}

    # Quantiles for the model uncertainty whiskers in the age panels
    low_q = 0.1
    high_q = 0.9

    # Figure 1: Calibration
    pl.figure(**fig_args)
    pl.subplots_adjust(**axis_args)
    pl.figtext(0.42, 0.95, 'Model calibration', fontsize=30)

    #%% Figure 1, panel 1
    ax = pl.subplot(4,1,1)
    format_ax(ax, ref_sim)
    plotter('new_tests', sims, ax, calib=True, label='Number of tests per day', ylabel='Tests')
    plotter('new_diagnoses', sims, ax, calib=True, label='Number of diagnoses per day', ylabel='Tests')

    #%% Figure 1, panel 2
    ax = pl.subplot(4,1,2)
    format_ax(ax, ref_sim)
    plotter('cum_diagnoses', sims, ax, calib=True, label='Cumulative diagnoses', ylabel='People')

    #%% Figure 1, panel 3
    ax = pl.subplot(4,1,3)
    format_ax(ax, ref_sim)
    plotter('cum_deaths', sims, ax, calib=True, label='Cumulative deaths', ylabel='Deaths')

    #%% Figure 1, panels 4A and 4B
    # Data from the first sim's analyzer; model values from the final histogram of every sim
    age_data = sims[0]['analyzers'][0].data
    agehists = [this_sim['analyzers'][0].hists[-1] for this_sim in sims]

    x = age_data['age'].values
    pos = age_data['cum_diagnoses'].values
    death = age_data['cum_deaths'].values

    mpbest, mplow, mphigh = _summarize_runs([h['diagnosed'] for h in agehists], low_q, high_q)
    mdbest, mdlow, mdhigh = _summarize_runs([h['dead'] for h in agehists], low_q, high_q)

    _plot_age_panel(7, x, pos, mpbest, mplow, mphigh,
                    colors=([0.3, 0.3, 0.6], [0.6, 0.7, 0.9]),
                    title='Diagnosed cases by age', ylabel='Cases')
    _plot_age_panel(8, x, death, mdbest, mdlow, mdhigh,
                    colors=([0.5, 0.0, 0.0], [0.9, 0.4, 0.3]),
                    title='Deaths by age', ylabel='Deaths')

    # Tidy up
    if do_save:
        cv.savefig(fig1_path)

    # Figure 2: Projections
    pl.figure(**fig_args)
    pl.subplots_adjust(**axis_args)
    pl.figtext(0.42, 0.95, 'Model estimates', fontsize=30)

    #%% Figure 2, panel 1
    ax = pl.subplot(4,1,1)
    format_ax(ax, ref_sim)
    plotter('cum_infections', sims, ax, calib=True, label='Cumulative infections', ylabel='People')
    plotter('cum_recoveries', sims, ax, calib=True, label='Cumulative recoveries', ylabel='People')

    #%% Figure 2, panel 2
    ax = pl.subplot(4,1,2)
    format_ax(ax, ref_sim)
    plotter('n_infectious', sims, ax, calib=True, label='Number of active infections', ylabel='People')
    plot_intervs(ref_sim, labels=True)

    #%% Figure 2, panel 3
    ax = pl.subplot(4,1,3)
    format_ax(ax, ref_sim)
    plotter('new_infections', sims, ax, calib=True, label='Infections per day', ylabel='People')
    plotter('new_recoveries', sims, ax, calib=True, label='Recoveries per day', ylabel='People')
    plot_intervs(ref_sim)

    #%% Figure 2, panel 4
    ax = pl.subplot(4,1,4)
    format_ax(ax, ref_sim)
    plotter('r_eff', sims, ax, calib=True, label='Effective reproductive number', ylabel=r'$R_{eff}$')

    pl.ylim([0, 4])
    xlims = pl.xlim()
    pl.plot(xlims, [1, 1], 'k')  # R_eff = 1 threshold
    plot_intervs(ref_sim)

    # Tidy up
    if do_save:
        cv.savefig(fig2_path)

    return
# Plot D: severe cases
# Bottom-right panel of a 2x2 grid: daily new severe cases for each scenario label.
# The data markers are drawn once, from the first sim of the first label's msim.
pl.subplot(2, 2, 4)
colors = pl.cm.YlOrBr([0.9, 0.6, 0.3])  # 3 colours, indexed by label position (more than 3 labels would IndexError)
for i, l in enumerate(labels):
    if i == 0:
        # NOTE(review): a step of 1 keeps every point, so this does not actually downsample
        ds = np.arange(0, len(tvec_d), 1)  # Downsample
        thissim = msims[l].sims[0]
        datatoplot = thissim.data['new_severe'][date_inds_d[0]:date_inds_d[1]]
        pl.plot(tvec_d[ds],
                datatoplot[ds],
                'd',
                c='k',
                markersize=12,
                alpha=0.75,
                label='Data')
    # Central line from plotdict, plus a shaded band between plotdict_l and plotdict_h
    toplot = plotdict['new_severe'][l][date_inds[0]:date_inds[1]]
    pl.plot(tvec, toplot, c=colors[i], label=l, lw=4, alpha=1.0)
    low = plotdict_l['new_severe'][l][date_inds[0]:date_inds[1]]
    high = plotdict_h['new_severe'][l][date_inds[0]:date_inds[1]]
    pl.fill_between(tvec, low, high, facecolor=colors[i], alpha=0.2)
# Axis formatting for this panel, then save the whole figure
pl.ylabel('Daily Hospitalisations')
ax = pl.gca()
ax.set_xticks(datemarks)
cv.date_formatter(start_day=start_day, ax=ax, dateformat=dateformat)
sc.setylim()
sc.commaticks()
pl.legend(frameon=False)
sc.boxoff()
cv.savefig('figs_Vac/uk_scens_plots.png')
# Frac in-person days lost
# Reshape the per-group 'perc_inperson_days_lost_<group>' columns into long format
# (one row per key1/key2/key3/Group), then strip the column prefix so 'Group' holds the group key
d = pd.melt(df, id_vars=['key1', 'key2', 'key3'], value_vars=[f'perc_inperson_days_lost_{gkey}' for gkey in grp_dict.keys()], var_name='Group', value_name='Days lost (%)')
d.replace( {'Group': {f'perc_inperson_days_lost_{gkey}':gkey for gkey in grp_dict.keys()}}, inplace=True)
d = d.loc[d['key1']==school_scenario] # K-5 only
# One bar-chart row per group: bars by key2 (ordered by test_order), hue by key3 (ordered by sens_order)
g = sns.FacetGrid(data=d, row='Group', height=4, aspect=3, row_order=['Teachers & Staff', 'Students'], legend_out=False)
g.map_dataframe( sns.barplot, x='key2', y='Days lost (%)', hue='key3', order=test_order, hue_order=sens_order, palette='Set2')
g.set_titles(row_template="{row_name}", fontsize=24)
g.set_axis_labels(y_var="Days lost (%)")
plt.tight_layout()

# Shrink each facet to 90% of its height and shift it up by (axi+1)*10% of that height,
# freeing space below the last facet for the legend placed there
for axi, ax in enumerate(g.axes.flat):
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + (axi+1)*box.height * 0.1, box.width, box.height * 0.9])
g.axes.flat[1].legend(loc='upper center',bbox_to_anchor=(0.48,-0.16), ncol=4, fontsize=14)

cv.savefig(os.path.join(imgdir, '3mInPersonDaysLost_countermeasures.png'), dpi=300)


# Attack rate
for aspect, fontsize in zip([2.5, 3], [12,14]):
    d = pd.melt(df, id_vars=['key1', 'key2', 'key3'], value_vars=[f'attackrate_{gkey}' for gkey in grp_dict.keys()], var_name='Group', value_name='Cum Inc (%)')
    d.replace( {'Group': {f'attackrate_{gkey}':gkey for gkey in grp_dict.keys()}}, inplace=True)
    d = d.loc[d['key1']==school_scenario] # K-5 only
    g = sns.FacetGrid(data=d, row='Group', height=4, aspect=aspect, row_order=['Teachers & Staff', 'Students'], legend_out=False) # col='key1', 
    g.map_dataframe( sns.barplot, x='key2', y='Cum Inc (%)', hue='key3', order=test_order, hue_order=sens_order, palette='Set2')
    g.set_titles(row_template="{row_name}")
    g.set_axis_labels(y_var="3-Month Attack Rate (%)")
    plt.tight_layout()

    for axi, ax in enumerate(g.axes.flat):
        box = ax.get_position()
Beispiel #12
0
        interval=90,  # Number of days between tick marks
        dateformat='%m/%Y',  # Date format for ticks
        fig_args={'figsize': (14, 8)},  # Size of the figure (x and y)
        axis_args={'left': 0.15},  # Space on left side of plot
    )

    msim.plot_result('r_eff', **plot_customizations)
    #sim.plot_result('r_eff')
    pl.axhline(1.0, linestyle='--', c=[0.8, 0.4, 0.4], alpha=0.8,
               lw=4)  # Add a line for the R_eff = 1 cutoff
    pl.title('')
    pl.savefig('R.pdf')

    msim.plot_result('cum_deaths', **plot_customizations)
    pl.title('')
    cv.savefig('Deaths.pdf')

    msim.plot_result('new_infections', **plot_customizations)
    pl.title('')
    cv.savefig('Casespdf')

    #msim.plot_result('cum_diagnoses', **plot_customizations)
    #pl.title('')
    #cv.savefig('Diagnoses18_68.pdf')

##for calibration figures
#msim.plot_result('cum_deaths', interval=20, fig_args={'figsize':(12,7)}, axis_args={'left':0.15})
#pl.title('')
#cv.savefig('Deaths.png')

# msim.plot_result('cum_diagnoses', interval=20, fig_args={'figsize':(12,7)}, axis_args={'left':0.15})
    symp_ax.bar(xax[0]-2, asymp_frac, label='Asymptomatic', color=colors[0])
    symp_ax.bar(xax[:presymp], sympcounts[:presymp], label='Presymptomatic', color=colors[1])
    symp_ax.bar(xax[presymp:], sympcounts[presymp:], label='Symptomatic', color=colors[2])
    symp_ax.set_xlabel('Days since symptom onset')
    symp_ax.set_ylabel('Proportion of transmissions (%)')
    symp_ax.set_xticks([minind-3, 0, 5, 10, maxind])
    symp_ax.set_xticklabels(['Asymp.', '0', '5', '10', f'>{maxind}'])
    sc.boxoff(ax=symp_ax)

    spie_ax = pl.axes([sympx+0.05, 0.20, 0.2, 0.2])
    labels = [f'Asymp-\ntomatic\n{asymp_frac:0.0f}%', f' Presymp-\n tomatic\n {pre_frac:0.0f}%', f'Symp-\ntomatic\n{symp_frac:0.0f}%']
    spie_ax.pie([asymp_frac, pre_frac, symp_frac], labels=labels, colors=colors, **pieargs)

    return fig

# Actually plot
fig = plot()


#%% Tidy up

# Save before showing: a blocking show() closes the figure, so saving afterwards would write an empty one
if do_save:
    cv.savefig(fig_path, dpi=150)

if do_show:
    pl.show()

# Timing report against the start time T (presumably from an earlier sc.tic())
sc.toc(T)

print('Done.')
    if do_plot:

        # Make individual plots
        plot_customizations = dict(
            interval   = 90, # Number of days between tick marks
            dateformat = '%m/%Y', # Date format for ticks
            fig_args   = {'figsize':(14,8)}, # Size of the figure (x and y)
            axis_args  = {'left':0.15}, # Space on left side of plot
            )

        for mkey, msim in final_msims.items():

            msim.plot_result('r_eff', do_show=do_show, **plot_customizations)
            pl.axhline(1.0, linestyle='--', c=[0.8,0.4,0.4], alpha=0.8, lw=4) # Add a line for the R_eff = 1 cutoff
            pl.title('')
            if do_save: cv.savefig(f'R_eff_{mkey}.png')

            msim.plot_result('cum_deaths', do_show=do_show, **plot_customizations)
            pl.title('')
            if do_save: cv.savefig(f'Deaths_{mkey}.png')

            msim.plot_result('new_infections', do_show=do_show, **plot_customizations)
            pl.title('')
            if do_save: cv.savefig(f'Infections_{mkey}.png')

            msim.plot_result('cum_diagnoses', do_show=do_show, **plot_customizations)
            pl.title('')
            if do_save: cv.savefig(f'Diagnoses_{mkey}.png')


Beispiel #15
0
        ax[pn].set_xticks(datemarks)

    if pn == 1:
        ax[pn] = sns.swarmplot(x="variable",
                               y="value",
                               data=df2,
                               color="grey",
                               alpha=0.5)
        ax[pn] = sns.violinplot(x="variable",
                                y="value",
                                data=df2,
                                color="lightblue",
                                alpha=0.5,
                                inner=None)
        ax[pn] = sns.pointplot(x="variable",
                               y="value",
                               data=df2,
                               ci=None,
                               color="steelblue",
                               markers='D',
                               scale=1.2)

        ax[pn].set_ylabel('Cumulative infections, 1 Dec 2020 - 1 Mar 2021')
        ax[pn].set_xlabel('Symptomatic testing rate')

cv.savefig(f'{figsfolder}/fig4_multiscens.pdf')

# For each threshold index: median, then the 2.5th and 97.5th percentiles of cuminf[tn]
print([np.median(cuminf[tn]) for tn in range(len(thresholds))])
print([np.quantile(cuminf[tn], q=0.025) for tn in range(len(thresholds))])
print([np.quantile(cuminf[tn], q=0.975) for tn in range(len(thresholds))])
sc.toc(T)
Beispiel #16
0
        if pn==4: pl.legend(loc='upper right', frameon=False, fontsize=20)
        if pn not in [0,5,10,15]:
            ax[pn].set_yticklabels([])
        else:
            ax[pn].set_ylabel('New infections')
        if pn not in range(nv):
            ax[pn].set_xticklabels([])
        else:
            xmin, xmax = ax[pn].get_xlim()
            ax[pn].set_xticks(pl.arange(xmin+5, xmax, 40))

    if thisfig==resultsfolder: figname = figsfolder+'/fig2_grid.png'
    elif thisfig==sensfolder: figname = figsfolder+'/figS2_grid.png'

    cv.savefig(figname, dpi=100)


#d = {'testing': [0.067]*nv*nm+[0.1]*nv*nm+[0.15]*nv*nm+[0.19]*nv*nm, 'tracing': [0.0]*nm+[0.25]*nm+[0.5]*nm+[0.75]*nm+[1.0]*nm+[0.0]*nm+[0.25]*nm+[0.5]*nm+[0.75]*nm+[1.0]*nm+[0.0]*nm+[0.25]*nm+[0.5]*nm+[0.75]*nm+[1.0]*nm+[0.0]*nm+[0.25]*nm+[0.5]*nm+[0.75]*nm+[1.0]*nm, 'masks': [0.0,0.25,0.5,0.75]*nt*nv}
#d['val'] = []
#for t in tlevels:
#    for v in vlevels:
#        d['val'].extend(sc.sigfig(results['cum_infections'][t][v],3))
#import pandas as pd
#df = pd.DataFrame(d)
#df.to_excel('sweepresults.xlsx')


################################################################################################################
# Figure 3: bar plot of cumulative infections
################################################################################################################
Beispiel #17
0
import covasim as cv

# Load a pre-generated synthpops population file and plot its people object
fn = 'inputs/kc_synthpops_clustered_withstaff_seed0.ppl'

# pop_size is an agent count, so pass an integer (previously the float literal 225e3)
sim = cv.Sim(pop_size=225_000, pop_type='synthpops', popfile=fn, load_pop=True)
sim.initialize()  # loads the population from popfile without running the sim
sim.people.plot()

cv.savefig('people_plot.png')
Beispiel #18
0
    # Save the key figures
    plot_customizations = dict(
        interval=90,  # Number of days between tick marks
        dateformat='%m/%Y',  # Date format for ticks
        fig_args={'figsize': (14, 8)},  # Size of the figure (x and y)
        axis_args={'left': 0.15},  # Space on left side of plot
    )

    msim.plot_result('r_eff', **plot_customizations)
    #sim.plot_result('r_eff')
    pl.axhline(1.0, linestyle='--', c=[0.8, 0.4, 0.4], alpha=0.8,
               lw=4)  # Add a line for the R_eff = 1 cutoff
    pl.title('')
    pl.savefig('%s/test%s-trace%s-R.pdf' % (scenario, args.test, args.trace))

    msim.plot_result('cum_deaths', **plot_customizations)
    pl.title('')
    cv.savefig('%s/test%s-trace%s-Deaths.pdf' %
               (scenario, args.test, args.trace))

    msim.plot_result('new_infections', **plot_customizations)
    pl.title('')
    cv.savefig('%s/test%s-trace%s-Infections.pdf' %
               (scenario, args.test, args.trace))

    msim.plot_result('cum_diagnoses', **plot_customizations)
    pl.title('')
    cv.savefig('%s/test%s-trace%s-Diagnoses.pdf' %
               (scenario, args.test, args.trace))
            ci = 1

        rvec = []
        for sim in msim.sims:
            r = sim.results[ch].values
            ps = 100000 / sim.pars['pop_size'] * sim.pars['pop_scale']
            if ch in ['new_diagnoses']:
                r *= 14 * ps
            if ch in ['new_infections', 'new_tests']:
                r *= ps
            rvec.append(r)

        ax = axv[ri, ci]
        med = np.median(rvec, axis=0)
        sd = np.std(rvec, axis=0)
        ax.fill_between(sim.results['date'],
                        med + 2 * sd,
                        med - 2 * sd,
                        color='lightgray')
        ax.plot(sim.results['date'], med + 2 * sd, color='gray', lw=1)
        ax.plot(sim.results['date'], med - 2 * sd, color='gray', lw=1)
        ax.plot(sim.results['date'], med, color='k', lw=2)
        if 'ref' in info and info['ref'] is not None:
            ax.axhline(y=info['ref'], color='r', ls='--', lw=2)
        ax.set_title(info['title'])

        ri += 1

    f.tight_layout()
    cv.savefig(os.path.join(imgdir, f'calib.png'), dpi=300)
        time_num_new = np.append(
            time_num_new,
            sim.people.date_diagnosed[idx] - sim.people.date_symptomatic[idx])
        yield_num_new = yield_num_new + sim.results['test_yield'].values
        idx_dia = sim.people.diagnosed
        idx = ~np.isnan(sim.people.date_symptomatic)
        stage_new['mild'] += sum(idx[idx_dia]) / sum(idx)
        idx = ~np.isnan(sim.people.date_severe)
        stage_new['sev'] += sum(idx[idx_dia]) / sum(idx)
        idx = ~np.isnan(sim.people.date_critical)
        stage_new['crit'] += sum(idx[idx_dia]) / sum(idx)
        sc.toc(t)
        sim.plot(to_plot={'test': ['new_tests']})
        pl.show()
        if do_save:
            cv.savefig('testScalingNum.png')
            pl.close()

        sim = single_sim_new(rand_seed=i, dist=None)
        t = sc.tic()
        sim.run()
        idx = sim.people.diagnosed
        time_num_flat = np.append(
            time_num_flat,
            sim.people.date_diagnosed[idx] - sim.people.date_symptomatic[idx])
        yield_num_flat = yield_num_flat + sim.results['test_yield'].values
        idx_dia = sim.people.diagnosed
        idx = ~np.isnan(sim.people.date_symptomatic)
        stage_flat['mild'] += sum(idx[idx_dia]) / sum(idx)
        idx = ~np.isnan(sim.people.date_severe)
        stage_flat['sev'] += sum(idx[idx_dia]) / sum(idx)
    print('Deaths: ', msim.results['cum_deaths'][-1])
    print('Infections: ', msim.results['cum_infectious'][-1])

    # Save the key figures
    plot_customizations = dict(
        interval=90,  # Number of days between tick marks
        dateformat='%Y/%m',  # Date format for ticks
        fig_args={'figsize': (14, 8)},  # Size of the figure (x and y)
        axis_args={'left': 0.15},  # Space on left side of plot
    )

    msim.plot_result('r_eff', **plot_customizations)
    pl.axhline(1.0, linestyle='--', c=[0.8, 0.4, 0.4], alpha=0.8,
               lw=4)  # Add a line for the R_eff = 1 cutoff
    pl.title('')
    cv.savefig('R_eff.png')

    msim.plot_result('cum_deaths', **plot_customizations)
    pl.title('')
    cv.savefig('Deaths.png')

    msim.plot_result('new_infections', **plot_customizations)
    pl.title('')
    cv.savefig('Infections.png')

    msim.plot_result('cum_diagnoses', **plot_customizations)
    pl.title('')
    cv.savefig('Diagnoses.png')

    msim.plot_result('new_tests', **plot_customizations)
    cv.savefig('Test.png')
Beispiel #22
0
               fontsize=36, fontweight='bold', bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': 0.5, 'pad': 4})

    pl.figtext(xgapl + dx * ncols + xgapm + rlpad, ygapb + (ygapm + dy) * 0 + epsy,'            Masks: 30% EC        ',
               rotation=90, fontweight='bold',
               bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': 0.5, 'pad': 4})
    pl.figtext(xgapl + dx * ncols + xgapm + rlpad, ygapb + (ygapm + dy) * 1 + epsy,'            Masks: 15% EC        ',
               rotation=90, fontweight='bold',
               bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': 0.5, 'pad': 4})

    for pn in range(nplots):
        ax[pn] = pl.axes([xgapl + (dx + xgapm) * (pn % ncols), ygapb + (ygapm + dy) * (pn // ncols), dx, dy])
        print([xgapl + (dx + xgapm) * (pn % ncols), ygapb + (ygapm + dy) * (pn // ncols)])
        print(list(sims.keys())[pn])
        format_ax(ax[pn], sim)
        ax[pn].axvline(masks_begin, c=[0, 0, 0], linestyle='--', alpha=0.4, lw=3)
#        pl.figtext(xgapl + (dx + xgapm) * (pn % ncols) + dx, ygapb + (ygapm + dy) * (pn // ncols) + dy, '            24%         ')
        ax[pn].annotate(labels[tti_scen][pn], xy=(700, 340_000), xycoords='data', ha='right', va='top')
        plotter('new_infections', sims[pn], ax[pn])
        ax[pn].set_ylim(0, 150_000)
        ax[pn].set_yticks(np.arange(0, 150_000, 25_000))

        if (pn%ncols) != 0:
            ax[pn].set_yticklabels([])
        else:
            ax[pn].set_ylabel('New infections')
        if pn not in range(ncols):
            ax[pn].set_xticklabels([])

    cv.savefig(f'{figsfolder}/fig_scens_{tti_scen}TTI.png', dpi=100)

sc.toc(T)
import pandas as pd

# Empirical symptom-onset-to-swab delays. Rows whose 'Test Delay' is the
# spreadsheet error value '#NAME?' are dropped. That value forces pandas to
# read the column as strings, so the rest are converted back to numbers;
# otherwise they would be plotted as categorical x values.
data = pd.read_csv(swabfile)
data = data.loc[data['Test Delay'] != '#NAME?'].copy()
data['Test Delay'] = pd.to_numeric(data['Test Delay'])
pdf = cvu.get_pdf('lognormal', 10, 170)

# Check that not using a distribution gives the same answer as before
pl.figure()  # fresh figure so this and the next plot never share axes
pl.hist(time_prob_old, np.arange(-2.5, 25.5), density=True)  # unit-width bins centred on whole days
pl.hist(time_prob_flat, np.arange(-2.5, 25.5), density=True, alpha=.25)
pl.xlim([-2, 20])
pl.xlabel('Symptom onset to swab')
pl.ylabel('Percent of tests')
# Save before show(): a blocking show() closes the figure, so saving afterwards writes an empty image
if do_save:
    cv.savefig('testprobOld.png')
pl.show()
if do_save:
    pl.close()

# See how close the default distribution is to the WA data and to what the model
# produces
pl.figure()
# Labels are attached to each artist so legend entries cannot be mismatched
pl.hist(time_prob_new, np.arange(-2.5, 25.5), density=True, label='Sim Histogram')
pl.plot(data['Test Delay'], data['Percent'] / 100, label='Data')
pl.plot(np.arange(100), pdf.pdf(np.arange(100)), label='Distribution')
pl.xlim([-2, 20])
pl.xlabel('Symptom onset to swab')
pl.ylabel('Percent of tests')
pl.legend()
if do_save:
    cv.savefig('testprobEmperical.png')
pl.show()
if do_save:
    pl.close()
    layer_counts[date, layer_num] += sim.rescale_vec[date]

# [start, end] dates of the three periods shaded on the plot
lockdown1 = [sc.readdate('2020-03-23'),sc.readdate('2020-05-31')]
lockdown2 = [sc.readdate('2020-11-05'),sc.readdate('2020-12-03')]
lockdown3 = [sc.readdate('2021-01-04'),sc.readdate('2021-02-08')]

# One line per contact layer; labels[l] names column l of layer_counts
labels = ['Household', 'School', 'Workplace', 'Community']
for l in range(n_layers):
    ax.plot(sim.datevec, layer_counts[:,l], c=colors[l], lw=3, label=labels[l])
ax.axvspan(lockdown1[0], lockdown1[1], color='steelblue', alpha=0.2, lw=0)
ax.axvspan(lockdown2[0], lockdown2[1], color='steelblue', alpha=0.2, lw=0)
ax.axvspan(lockdown3[0], lockdown3[1], color='lightblue', alpha=0.2, lw=0)
sc.setylim(ax=ax)
sc.boxoff(ax=ax)
ax.set_ylabel('Transmissions per day')
ax.set_xlim([sc.readdate('2020-01-21'), sc.readdate('2021-03-01')])
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%y'))

# Monthly ticks (1st of each month, Feb 2020 - Mar 2021): sim day indices, converted back to dates for the date axis
datemarks = pl.array([sim.day('2020-02-01'), sim.day('2020-03-01'), sim.day('2020-04-01'), sim.day('2020-05-01'), sim.day('2020-06-01'),
                      sim.day('2020-07-01'), sim.day('2020-08-01'), sim.day('2020-09-01'), sim.day('2020-10-01'),
                      sim.day('2020-11-01'), sim.day('2020-12-01'), sim.day('2021-01-01'), sim.day('2021-02-01'), sim.day('2021-03-01')])
ax.set_xticks([sim.date(d, as_date=True) for d in datemarks])
ax.legend(frameon=False)


# NOTE(review): labely is not used anywhere in the visible code; possibly left over from removed annotations
yl = ax.get_ylim()
labely = yl[1]*1.015

cv.savefig(f'{figsfolder}/fig_trans.png', dpi=100)

sc.toc(T)
Beispiel #25
0
def test_misc():
    """Smoke-test a grab bag of Covasim utilities.

    Exercises data loading, date helpers, sim save/load, version and git
    checks, Poisson tests, location listing, save-version checks and PNG
    metadata. Scratch files written along the way are removed at the end.
    """
    sc.heading('Testing miscellaneous functions')

    # Scratch files produced by this test
    sim_file = 'test_misc.sim'
    json_file = 'test_misc.json'
    gitinfo_file = 'test_misc.gitinfo'
    png_file = 'test_misc.png'
    png_comment = 'Test comment'

    # Data loading: both supported formats, then the two failure modes
    for datafile in (csv_file, xlsx_file):
        cv.load_data(datafile)

    with pytest.raises(NotImplementedError):
        cv.load_data('example_data.unsupported_extension')

    with pytest.raises(ValueError):
        cv.load_data(xlsx_file, columns=['missing_column'])

    # Dates: string and datetime inputs must give the same date
    from_string = cv.date('2020-04-04')
    from_datetime = cv.date(sc.readdate('2020-04-04'))
    as_pair = cv.date('2020-04-04', from_datetime)
    assert from_string == from_datetime
    assert from_datetime == as_pair[0]

    # A list of tuples (internally a TypeError, re-raised as ValueError) and an
    # unparseable string must both be rejected
    for bad_input in ([(2020, 4, 4)], 'Not a date'):
        with pytest.raises(ValueError):
            cv.date(bad_input)

    cv.daydiff('2020-04-04')

    # A small sim to feed the remaining checks
    sim = cv.Sim(pop_size=500, verbose=0)
    sim.run()
    sim.plot(do_show=False)

    # Saving and loading
    cv.savefig(png_file, comments=png_comment)
    cv.save(filename=sim_file, obj=sim)
    cv.load(filename=sim_file)

    # Version checks: warn by default, raise with die=True
    cv.check_version('0.0.0')  # Nonsense version
    print('↑ Should complain about version')
    with pytest.raises(ValueError):
        cv.check_version('0.0.0', die=True)

    # Git checks
    cv.git_info(json_file)
    cv.git_info(json_file, check=True)

    # Poisson tests over every alternative and method, plus an invalid method
    count_a, count_b = 5, 8
    for alternative in ['two-sided', 'larger', 'smaller']:
        cv.poisson_test(count_a, count_b, alternative=alternative)
    for method in ['score', 'wald', 'sqrt', 'exact-cond']:
        cv.poisson_test(count_a, count_b, method=method)

    with pytest.raises(ValueError):
        cv.poisson_test(count_a, count_b, method='not a method')

    # Location listing, all and by name
    for location in (None, 'viet-nam'):
        cv.data.show_locations(location)

    # Save-version checks
    with pytest.raises(ValueError):
        cv.check_save_version('1.3.2', die=True)
    cv.check_save_version(cv.__version__, filename=gitinfo_file, comments='Test')

    # PNG metadata (optional: needs pillow)
    try:
        metadata = cv.get_png_metadata(png_file, output=True)
    except ImportError as err:
        print(
            f'Cannot test PNG function since pillow not installed ({str(err)}), skipping'
        )
    else:
        assert metadata['Covasim version'] == cv.__version__
        assert metadata['Covasim comments'] == png_comment

    # Tidy up
    remove_files(sim_file, json_file, png_file, gitinfo_file)

    return
Beispiel #26
0
def plot_attack_rate(date_of_file, cases, sens):
    """Plot grouped bar charts of infection rate (%) for staff and students.

    The top panel shows teachers + staff, the bottom panel shows students.
    Each reopening strategy gets one group of bars (``all_remote`` is
    excluded), with one bar per entry of ``cases``. The figure is written to
    ``attack_rate_{prev}_{date_of_file}.png`` and ``.pdf``.

    Args:
        date_of_file: Tag passed to ``outputs_df`` and embedded in the
            output filenames.
        cases: Case keys. Each is passed to ``outputs_df`` and looked up in
            ``inc_labels`` for the legend. Whether ``'1.1'`` occurs in
            ``cases[0]`` selects the ``by_cases_rising`` or
            ``by_cases_falling`` filename tag.
        sens: Optional sensitivity key, or None. It is passed to
            ``outputs_df``, appended to the filename tag, and looked up in
            ``sens_label``.
    """
    colors = ['lightseagreen', 'lightsteelblue', 'lightcoral']

    name = 'Cases per 100k people during the\ntwo weeks prior to school reopening'
    if '1.1' in cases[0]:
        prev = 'by_cases_rising'
        subtitle = '(Re > 1)'
    else:
        prev = 'by_cases_falling'
        subtitle = '(Re < 1)'
    label = inc_labels

    if sens is not None:
        prev = prev + '_' + sens
        # NOTE(review): subtitle is built but not drawn in this figure.
        subtitle += f'\n ({sens_label[sens]})'

    def _row(df, key, column):
        """Return ``column`` values for the row whose 'Unnamed: 0' is ``key``."""
        return df[df['Unnamed: 0'] == key][column].values

    staff_by_case = []
    students_by_case = []
    for case in cases:
        df = outputs_df(date_of_file, case, sens)
        scenario_strategies = df.columns[1:].tolist()
        scenario_strategies.remove('all_remote')

        # Population denominators are taken from the 'as_normal' column and
        # used for every strategy.
        num_staff = (_row(df, 'num_staff', 'as_normal')
                     + _row(df, 'num_teachers', 'as_normal'))
        num_students = _row(df, 'num_students', 'as_normal')

        staff = []
        students = []
        for strategy in scenario_strategies:
            num_staff_cases = (_row(df, 'num_staff_cases', strategy)
                               + _row(df, 'num_teacher_cases', strategy))
            staff.append(100 * num_staff_cases[0] / num_staff[0])

            num_student_cases = _row(df, 'num_student_cases', strategy)
            students.append(100 * num_student_cases[0] / num_students[0])

        # One-row frames: column h holds the rate for scenario_strategies[h]
        staff_by_case.append(pd.DataFrame(staff).transpose())
        students_by_case.append(pd.DataFrame(students).transpose())

    x = np.arange(len(scenario_strategies))

    width = [-.2, 0, .2]  # bar offsets per case within a strategy group
    width_text = [-.28, -.09, .12]  # value-label x offsets per case

    fig, axs = plt.subplots(nrows=2,
                            sharex=True,
                            sharey=False,
                            figsize=(13, 10))  # 13 x 9
    fig.subplots_adjust(hspace=0.3, right=0.9, bottom=0.15)

    panels = [(staff_by_case, 'Teachers and staff'),
              (students_by_case, 'Students')]
    for i, (ax, (rates_by_case, title)) in enumerate(zip(axs, panels)):
        for j, case in enumerate(cases):
            rates = rates_by_case[j].values[0]
            ax.bar(x + width[j],
                   rates,
                   width=0.2,
                   label=label[case],
                   color=colors[j])
            # Value labels above each bar. Pass the y position as a scalar;
            # the original passed a 1-element array here.
            for h, rate in enumerate(rates):
                ax.text(h + width_text[j],
                        0.5 + rate,
                        round(rate, 1),
                        fontsize=13)
        ax.set_title(title, size=20, horizontalalignment='center')
        ax.set_ylabel('COVID-19 infection rate (%)', size=20)
        ax.set_ylim([0, 27])  # fixed axis: rates above 27% are clipped
        ax.set_xticks(x)
        # NOTE(review): tick labels follow strategy_labels_brief's own order;
        # confirm it matches scenario_strategies.
        ax.set_xticklabels(strategy_labels_brief.values(), fontsize=17)  # 14
        if i == 0:
            leg_i = ax.legend(fontsize=20, title=name)
            leg_i.set_title(name, prop={'size': 20})

    cv.savefig(f'attack_rate_{prev}_{date_of_file}.png')
    plt.savefig(f'attack_rate_{prev}_{date_of_file}.pdf')
Beispiel #27
0
    msim.run(n_runs=20, par_args={'ncpus': 5})
    msim.reduce()
    msim.save(msimfile)
else:
    msim = cv.load(msimfile)

#%% Plotting
# Hide intervention markers on the base sim's plots.
for intervention in msim.base_sim['interventions']:
    intervention.do_plot = False

to_plot = ['cum_diagnoses', 'new_diagnoses', 'cum_deaths', 'new_deaths']
fig_args = {'figsize': (18, 18)}
scatter_args = {'alpha': 0.3, 'marker': 'o'}
dateformat = '%d %b'

fig = msim.plot(to_plot=to_plot, n_cols=1, fig_args=fig_args,
                scatter_args=scatter_args, dateformat=dateformat)

# Goodness of fit on cumulative diagnoses and deaths, equally weighted; the
# mismatch is averaged per day and over the two fitted quantities.
fit_weights = {'cum_diagnoses': 1, 'cum_deaths': 1}
fit = msim.base_sim.compute_fit(weights=fit_weights)
print('Average daily mismatch: ',
      fit.mismatch / msim.base_sim['n_days'] / 2 * 100)

legend_labels = ['Model', '80% modeled interval', 'Data']
for axis in fig.axes:
    axis.legend(legend_labels, loc='upper left')

if do_save:
    cv.savefig(figfile)

print('Done.')
Beispiel #28
0
def plot_dimensions(date_of_file, cases, sens):
    """Plot infection rate against distance-learning days for each strategy.

    For every case in ``cases``, each reopening strategy (``all_remote``
    excluded) is drawn as a circle. Marker colour comes from ``colors[i]``
    for the i-th strategy, and marker size grows with the case's position in
    ``cases``. A grey line joins all strategies except
    ``with_hybrid_scheduling``. Two side legends explain colour (strategy)
    and size (case). The figure is written to
    ``tradeoffs_{prev}_{date_of_file}.png`` and ``.pdf``.

    Args:
        date_of_file: Tag passed to ``outputs_df`` and embedded in the
            output filenames.
        cases: Case keys. Each is passed to ``outputs_df`` and looked up in
            ``inc_labels``. Whether ``'1.1'`` occurs in ``cases[0]`` selects
            the ``by_cases_rising`` or ``by_cases_falling`` filename tag.
        sens: Optional sensitivity key, or None. It is passed to
            ``outputs_df``, appended to the filename tag, and looked up in
            ``sens_label``.
    """
    name = 'Cases per 100k in last 14 days'
    if '1.1' in cases[0]:
        prev = 'by_cases_rising'
        subtitle = '(Re > 1)'
    else:
        prev = 'by_cases_falling'
        subtitle = '(Re < 1)'
    label = inc_labels

    if sens is not None:
        prev = prev + '_' + sens
        # NOTE(review): subtitle is built but not drawn in this figure.
        subtitle += f'\n ({sens_label[sens]})'

    def _row(df, key, column):
        """Return ``column`` values for the row whose 'Unnamed: 0' is ``key``."""
        return df[df['Unnamed: 0'] == key][column].values

    attack_rate_by_case = []
    perc_school_days_lost_by_case = []
    efficient_y_by_case = []
    efficient_x_by_case = []

    for case in cases:
        df = outputs_df(date_of_file, case, sens)
        scenario_strategies = df.columns[1:].tolist()
        scenario_strategies.remove('all_remote')

        attack_rate = []
        perc_school_days_lost = []
        efficient_y = []
        efficient_x = []
        # 'school_days_lost' under 'with_screening' is subtracted as a baseline
        # from every strategy except 'as_normal'.
        weekend_days = _row(df, 'school_days_lost', 'with_screening')
        total_school_days = _row(df, 'student_school_days', 'with_screening')
        for strategy in scenario_strategies:
            total = (_row(df, 'num_staff', strategy)
                     + _row(df, 'num_teachers', strategy)
                     + _row(df, 'num_students', strategy))
            num_cases = (_row(df, 'num_staff_cases', strategy)
                         + _row(df, 'num_teacher_cases', strategy)
                         + _row(df, 'num_student_cases', strategy))
            rate = 100 * num_cases[0] / total[0]
            attack_rate.append(rate)

            days_lost = _row(df, 'school_days_lost', strategy)[0]
            if strategy != 'as_normal':
                days_lost = days_lost - weekend_days[0]
            pct_lost = 100 * days_lost / total_school_days[0]
            perc_school_days_lost.append(pct_lost)

            # The grey connecting line skips hybrid scheduling
            if strategy != 'with_hybrid_scheduling':
                efficient_y.append(rate)
                efficient_x.append(pct_lost)

        # One-row frames: column i holds the value for scenario_strategies[i]
        attack_rate_by_case.append(pd.DataFrame(attack_rate).transpose())
        efficient_y_by_case.append(pd.DataFrame(efficient_y).transpose())
        efficient_x_by_case.append(pd.DataFrame(efficient_x).transpose())
        perc_school_days_lost_by_case.append(
            pd.DataFrame(perc_school_days_lost).transpose())

    n_strategies = len(scenario_strategies)

    size_min = 8
    size_max = 50
    num_sizes = len(cases)
    # linspace(endpoint=False) yields exactly num_sizes values,
    # size_min + i*(size_max - size_min)/num_sizes. np.arange with a float
    # step can yield an extra element through rounding, which would make the
    # size-legend loop below index past the end of `cases`.
    sizes = np.linspace(size_min, size_max, num_sizes, endpoint=False).tolist()

    left = 0.09
    right = 0.62
    bottom = 0.10
    top = 0.95
    fig = plt.figure(figsize=(13, 9))
    fig.subplots_adjust(left=left, right=right, top=top, bottom=bottom)
    ax = fig.add_subplot(111)
    alpha = 0.67
    fs = 24

    for j in range(len(cases)):
        ax.plot(efficient_x_by_case[j].iloc[0, :].values,
                efficient_y_by_case[j].iloc[0, :].values,
                linewidth=3,
                alpha=0.33,
                color='grey',
                linestyle='-')
        for i in range(n_strategies):
            ax.plot(
                perc_school_days_lost_by_case[j][i],
                attack_rate_by_case[j][i],
                marker='o',
                markersize=sizes[j],
                alpha=alpha,
                markerfacecolor=colors[i],
                markeredgewidth=0,
            )

    ax.set_xlabel('Days of Distance Learning (% of Total School Days)',
                  fontsize=fs)
    ax.set_ylabel('COVID-19 infection rate (%)',
                  fontsize=fs)  # 'Within-School Attack Rate (%)'
    ax.tick_params(labelsize=fs)

    # Strategies legend: invisible off-axis points carry the labels
    ax_left = right + 0.04
    ax_bottom = bottom + 0.02
    ax_right = 0.95
    ax_width = ax_right - ax_left
    ax_height = (top - bottom) / 2.
    ax_leg = fig.add_axes([ax_left + 0.01, ax_bottom, ax_width, ax_height])
    for s, strat in enumerate(scenario_strategies):
        ax_leg.plot(-5,
                    -5,
                    color=colors[s],
                    label=strategy_labels_brief2[strat])

    leg = ax_leg.legend(loc='center', fontsize=fs)  # loc=10,
    leg.draw_frame(False)
    ax_leg.axis('off')

    # Marker-size legend (one row per case), stacked above the strategies legend
    ax_bottom_2 = ax_bottom + ax_height + 0.0
    ax_leg_2 = fig.add_axes([ax_left, ax_bottom_2, ax_width, ax_height])

    ybase = 0.8
    ytop = 1
    yinterval = 0.7 * (ytop - ybase) / len(sizes)
    xi = 1  # x anchor for legend markers and text

    for i, (si, case) in enumerate(zip(sizes, cases)):
        yi = ybase + yinterval * i
        # 'none' suppresses the connecting line; linestyle=None would select
        # the rcParams default style instead.
        ax_leg_2.plot(xi * 1.5,
                      yi,
                      linestyle='none',
                      marker='o',
                      markersize=si,
                      markerfacecolor='white',
                      markeredgecolor='black')
        ax_leg_2.text(xi * 4,
                      yi,
                      label[case],
                      verticalalignment='center',
                      fontsize=fs)

    ax_leg_2.text(xi * 7 + 1.5,
                  0.95,
                  name,
                  horizontalalignment='center',
                  fontsize=fs)

    ax_leg_2.text(xi * 7 + 0.3,
                  0.75,
                  'School reopening scenario',
                  horizontalalignment='center',
                  fontsize=fs)

    ax_leg_2.axis('off')
    ax_leg_2.set_xlim(left=0, right=20)
    ax_leg_2.set_ylim(bottom=ybase * 0.9, top=ytop * 1.0)

    cv.savefig(f'tradeoffs_{prev}_{date_of_file}.png')
    plt.savefig(f'tradeoffs_{prev}_{date_of_file}.pdf')
                palette='tab10')
###g.add_legend(fontsize=14)
g.set_titles(row_template="{row_name}", fontsize=24)
g.set_axis_labels(y_var="Days lost (%)")
plt.tight_layout()

# Narrow every facet to 70% of its width to leave room for a legend on the right.
for ax in g.axes.flat:
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.7, box.height])

g.axes.flat[0].legend(loc='upper left', bbox_to_anchor=(1, 0.3))

cv.savefig(os.path.join(imgdir, '3mInPersonDaysLost_sensitivity.png'), dpi=300)

# Attack rate: long-format table with one row per (keys, group), with the
# 'attackrate_' prefix stripped from the group names.
attack_cols = [f'attackrate_{gkey}' for gkey in grp_dict.keys()]
attack_col_to_group = dict(zip(attack_cols, grp_dict.keys()))
d = pd.melt(df,
            id_vars=['key1', 'key2', 'key3'],
            value_vars=attack_cols,
            var_name='Group',
            value_name='Cum Inc (%)')
d = d.replace({'Group': attack_col_to_group})
d = d.loc[d['key1'] == 'k5']  # K-5 only
g = sns.FacetGrid(data=d,
                  row='Group',
                  height=4,
                  aspect=3,
        ax[pn] = pl.axes([
            xgapl + (dx + xgapm) * (pn % ncols),
            ygapb + (ygapm + dy) * (pn // ncols), dx, dy
        ])
        ax[pn] = sns.heatmap(dfs[res][scen],
                             xticklabels=4,
                             yticklabels=4,
                             cmap=sns.cm.rocket_r,
                             vmin=0,
                             vmax=cbar_lims[res],
                             cbar=pn == 0,
                             cbar_ax=None if pn else cbar_ax,
                             cbar_kws={'label': label})
        ax[pn].set_ylim(ax[pn].get_ylim()[::-1])

        if (pn % ncols) != 0:
            ax[pn].set_yticklabels([])
        else:
            ax[pn].set_ylabel('Symptomatic testing')
            ax[pn].set_yticklabels(
                [f'{int(i*100)}%' for i in np.linspace(0, 1, 6)], rotation=0)
        if pn not in range(ncols):
            ax[pn].set_xticklabels([])
        else:
            ax[pn].set_xlabel('% of contacts traced')
            ax[pn].set_xticklabels(
                [f'{int(i * 100)}%' for i in np.linspace(0, 1, 6)])

    cv.savefig(f'{figsfolder}/fig_sweeps_{res}.png', dpi=100)

sc.toc(T)