def plotter(key, sims, ax, label='', ylabel='', low_q=0.05, high_q=0.95, startday=None):
    '''
    Draw the across-sim median of ``results[key]`` with a shaded quantile band.

    The color is looked up in ``cv.get_colors()`` using the part of ``key`` after
    the first underscore. Plotting goes to pyplot's current axes; ``ax`` only
    receives the x tick positions (fixed dates from 2020-03-01 to 2020-09-01,
    converted to day indices with ``sims[0].day``).

    Args:
        key (str): result key present in every ``sim.results``
        sims (list): simulations to aggregate
        ax: axes whose x ticks are set
        label (str): legend label for the median line
        ylabel (str): y-axis label
        low_q (float): lower quantile of the band
        high_q (float): upper quantile of the band
        startday: accepted but not used
    '''
    result_type = key.split('_')[1]
    color = cv.get_colors()[result_type]

    stacked = np.array([run.results[key].values for run in sims])
    middle = pl.median(stacked, axis=0)
    lower = pl.quantile(stacked, q=low_q, axis=0)
    upper = pl.quantile(stacked, q=high_q, axis=0)
    reference = sims[0]

    days = np.arange(len(middle))
    pl.fill_between(days, lower, upper, facecolor=color, alpha=0.2, label=None)
    pl.plot(days, middle, c=color, label=label, lw=4, alpha=1.0)
    sc.setylim()

    tick_dates = ['2020-03-01', '2020-05-01', '2020-07-01', '2020-09-01']
    ax.set_xticks(pl.array([reference.day(when) for when in tick_dates]))
    pl.ylabel(ylabel)
    return
def fixaxis(sim, useSI=True, boxoff=False):
    '''
    Make the plotting more consistent -- add a legend and ensure the axes start at 0.

    Acts on pyplot's current axes: adds a legend, rescales y with ``sc.setylim``,
    and sets the x range to [0, sim['n_days'] + 0.5].

    Args:
        sim: object indexable by 'n_days'
        useSI (bool): accepted but not used
        boxoff (bool): if True, also call ``sc.boxoff()`` on the current axes
    '''
    right_pad = 0.5  # extra room past the last day so the final point isn't clipped
    pl.legend()
    sc.setylim()
    x_limits = (0, sim['n_days'] + right_pad)
    pl.xlim(x_limits)
    if not boxoff:
        return
    sc.boxoff()  # Turn off top and right lines
def plot(self):
    '''
    Plot the S, E, I and R time series of this object against ``self.t`` in a new figure.

    The y-axis is reset to start at 0 and its labels use thousands separators.
    '''
    pl.figure()
    for compartment in ('S', 'E', 'I', 'R'):
        pl.plot(self.t, getattr(self, compartment), label=compartment)
    pl.legend()
    pl.xlabel('Day')
    pl.ylabel('People')
    sc.setylim()      # Reset y-axis to start at 0
    sc.commaticks()   # Use commas in the y-axis labels
    return
def format_axs(axs, key=None):
    '''
    Format axes nicely.

    For every axes in ``axs``: x ticks every 21 units, tick labels rendered as
    '%b-%d' dates offset from ``refsim['start_day']`` (a name defined elsewhere
    in the file), frameless legend, box off, y rescaled and comma tick labels.
    The second axes (index 1) gets its legend anchored at (1.05, 1.05).

    Args:
        axs (iterable): axes to format
        key: accepted but not used
    '''
    stride = 21  # distance between x ticks, in the same units as the x data

    def _day_to_label(x, pos):
        origin = refsim['start_day']
        return (origin + dt.timedelta(days=x)).strftime('%b-%d')

    formatter = ticker.FuncFormatter(_day_to_label)

    for idx, axis in enumerate(axs):
        anchor = (1.05, 1.05) if idx == 1 else None  # Move legend up a bit
        lo, hi = axis.get_xlim()
        axis.set_xticks(np.arange(lo, hi, stride))
        axis.xaxis.set_major_formatter(formatter)
        axis.legend(frameon=False, bbox_to_anchor=anchor)
        sc.boxoff(ax=axis)
        sc.setylim(ax=axis)
        sc.commaticks(ax=axis)
    return
def plotter(key, sims, ax, label='', ylabel='', low_q=0.05, high_q=0.95, subsample=2):
    '''
    Draw the across-sim median of ``results[key]`` with a shaded quantile band.

    The first two points are skipped for ``r_eff``. Plotting goes to pyplot's
    current axes; ``ax`` only receives fixed date ticks (2020-03-01 to
    2021-05-01), converted to day indices with ``sims[0].day``.

    Args:
        key (str): result key present in every ``sim.results``
        sims (list): simulations to aggregate; ``sims[0]`` is used for date conversion
        ax: axes whose x ticks are set
        label (str): legend label for the median line
        ylabel (str): y-axis label
        low_q (float): lower quantile of the band
        high_q (float): upper quantile of the band
        subsample (int): no longer used (the data overlay was removed); kept for call compatibility
    '''
    which = key.split('_')[1]
    try:
        color = cv.get_colors()[which]
    except Exception:  # no predefined color for this result type (e.g. 'eff' from 'r_eff')
        color = [0.5, 0.5, 0.5]

    yarr = np.array([s.results[key].values for s in sims])
    best = pl.median(yarr, axis=0)
    low = pl.quantile(yarr, q=low_q, axis=0)
    high = pl.quantile(yarr, q=high_q, axis=0)

    # Reference sim for date -> day conversion. It was previously commented out
    # together with the data overlay, leaving ``sim`` undefined below.
    sim = sims[0]
    tvec = np.arange(len(best))

    start = 2 if key == 'r_eff' else 0  # the first r_eff values are not meaningful to show
    pl.fill_between(tvec[start:], low[start:], high[start:], facecolor=color, alpha=0.2, label=None)
    pl.plot(tvec[start:], best[start:], c=color, label=label, lw=4, alpha=1.0)
    sc.setylim()

    tick_dates = ['2020-03-01', '2020-06-01', '2020-09-01', '2020-12-01', '2021-03-01', '2021-05-01']
    datemarks = pl.array([sim.day(d) for d in tick_dates])
    ax.set_xticks(datemarks)
    pl.ylabel(ylabel)
    return
def plot(self, to_plot=None, do_save=None, fig_path=None, fig_args=None, plot_args=None,
         scatter_args=None, axis_args=None, legend_args=None, as_dates=True, dateformat=None,
         interval=None, n_cols=1, font_size=18, font_family=None, use_grid=True,
         use_commaticks=True, do_show=True, verbose=None):
    '''
    Plot the results -- can supply arguments for both the figure and the plots.

    Args:
        to_plot        (dict): Nested dict of results to plot; see default_sim_plots for structure
        do_save        (bool or str): Whether or not to save the figure. If a string, save to that filename.
        fig_path       (str): Path to save the figure
        fig_args       (dict): Dictionary of kwargs to be passed to pl.figure()
        plot_args      (dict): Dictionary of kwargs to be passed to pl.plot()
        scatter_args   (dict): Dictionary of kwargs to be passed to pl.scatter()
        axis_args      (dict): Dictionary of kwargs to be passed to pl.subplots_adjust()
        legend_args    (dict): Dictionary of kwargs to be passed to pl.legend()
        as_dates       (bool): Whether to plot the x-axis as dates or time points
        dateformat     (str): Date format for the x tick labels when as_dates is True, e.g. '%B %d' (default '%b-%d')
        interval       (int): Interval between tick marks
        n_cols         (int): Number of columns of subpanels to use for subplot
        font_size      (int): Size of the font
        font_family    (str): Font face
        use_grid       (bool): Whether or not to plot gridlines
        use_commaticks (bool): Plot y-axis with commas rather than scientific notation
        do_show        (bool): Whether or not to show the figure
        verbose        (bool): Display a bit of extra information

    Returns:
        fig: Figure handle
    '''
    if verbose is None:
        verbose = self['verbose']
    sc.printv('Plotting...', 1, verbose)

    if to_plot is None:
        to_plot = cvd.default_sim_plots
    to_plot = sc.odict(to_plot)  # In case it's supplied as a dict

    # Handle input arguments -- merge user input with defaults
    fig_args = sc.mergedicts({'figsize': (16, 14)}, fig_args)
    plot_args = sc.mergedicts({'lw': 3, 'alpha': 0.7}, plot_args)
    scatter_args = sc.mergedicts({'s': 70, 'marker': 's'}, scatter_args)
    axis_args = sc.mergedicts(
        {'left': 0.1, 'bottom': 0.05, 'right': 0.9, 'top': 0.97, 'wspace': 0.2, 'hspace': 0.25},
        axis_args)
    legend_args = sc.mergedicts({'loc': 'best'}, legend_args)
    if dateformat is None:
        dateformat = '%b-%d'  # the format previously hard-coded, so the default output is unchanged

    fig = pl.figure(**fig_args)
    pl.subplots_adjust(**axis_args)
    pl.rcParams['font.size'] = font_size
    if font_family:
        pl.rcParams['font.family'] = font_family

    res = self.results  # Shorten since heavily used

    if as_dates:
        @ticker.FuncFormatter
        def date_formatter(x, pos):
            # Tick value x is a day offset from the simulation start date
            return (self['start_day'] + dt.timedelta(days=x)).strftime(dateformat)

    # Plot everything -- subplot() needs an integer grid size, np.ceil returns a float
    n_rows = int(np.ceil(len(to_plot) / n_cols))
    for p, title, keylabels in to_plot.enumitems():
        ax = pl.subplot(n_rows, n_cols, p + 1)
        for key in keylabels:
            label = res[key].name
            this_color = res[key].color
            y = res[key].values
            pl.plot(res['t'], y, label=label, **plot_args, c=this_color)
            if self.data is not None and key in self.data:
                # Convert from data date to model output index based on model start date
                data_t = (self.data.index - self['start_day']) / np.timedelta64(1, 'D')
                pl.scatter(data_t, self.data[key], c=[this_color], **scatter_args)
        if self.data is not None and len(self.data):
            # Invisible point so the legend gets a single black 'Data' entry
            pl.scatter(pl.nan, pl.nan, c=[(0, 0, 0)], label='Data', **scatter_args)
        pl.legend(**legend_args)
        pl.grid(use_grid)
        sc.setylim()
        if use_commaticks:
            sc.commaticks()
        pl.title(title)

        # Optionally reset tick marks (useful for e.g. plotting weeks/months)
        if interval:
            xmin, xmax = ax.get_xlim()
            ax.set_xticks(pl.arange(xmin, xmax + 1, interval))

        # Set xticks as dates
        if as_dates:
            ax.xaxis.set_major_formatter(date_formatter)
            if not interval:
                ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Plot interventions
        for intervention in self['interventions']:
            intervention.plot(self, ax)

    # Ensure the figure actually renders or saves
    if do_save:
        if fig_path is None:  # No figpath provided - see whether do_save is a figpath
            fig_path = do_save if isinstance(do_save, str) else 'covasim.png'
        fig_path = sc.makefilepath(fig_path)  # Ensure it's valid, including creating the folder
        pl.savefig(fig_path)

    if do_show:
        pl.show()
    else:
        pl.close(fig)

    return fig
def plotter(key, sims, ax, ys=None, calib=False, label='', ylabel='', low_q=0.1, high_q=0.9):
    '''
    Plot the median of ``results[key]`` across sims with a quantile band and any matching data.

    Args:
        key (str): result key; the part after the first underscore selects the color
        sims (list): simulations; ``sims[0]`` supplies ``data`` and ``start_day``
        ax: axes whose x limits/ticks are read and set
        ys (list): optional precomputed series; if None they are collected from ``sims``
        calib (bool): calibration layout -- drops the last point (last 2 for r_eff) and starts x at 13
        label (str): legend label for the median line
        ylabel (str): y-axis label
        low_q (float): lower quantile of the band
        high_q (float): upper quantile of the band
    '''
    which = key.split('_')[1]
    try:
        color = cv.get_colors()[which]
    except Exception:  # no predefined color for this result type
        color = [0.5, 0.5, 0.5]
    if which == 'deaths':
        color = [0.5, 0.0, 0.0]

    if ys is None:
        ys = [s.results[key].values for s in sims]
    yarr = np.array(ys)
    best = pl.median(yarr, axis=0)  # median across sims
    low = pl.quantile(yarr, q=low_q, axis=0)
    high = pl.quantile(yarr, q=high_q, axis=0)
    sim = sims[0]  # For having a sim to refer to

    # Formatting parameters
    plot_args = {'lw': 3, 'alpha': 0.8}
    fill_args = {'alpha': 0.2}

    tvec = np.arange(len(best))
    if calib:
        end = -2 if key == 'r_eff' else -1  # drop the trailing point(s) when calibrating
    else:
        end = None

    pl.fill_between(tvec[:end], low[:end], high[:end], facecolor=color, **fill_args)
    pl.plot(tvec[:end], best[:end], c=color, label=label, **plot_args)

    if key in sim.data:
        # Data dates converted to day offsets from the sim start
        data_t = np.array((sim.data.index - sim['start_day']) / np.timedelta64(1, 'D'))
        pl.plot(data_t, sim.data[key], 'o', c=color, markersize=10, label='Data')

    if calib:
        xlims = pl.xlim()
        pl.xlim([13, xlims[1] - 1])
    else:
        pl.xlim([0, 94])
    sc.setylim()

    xmin, xmax = ax.get_xlim()
    ax.set_xticks(pl.arange(xmin + 2, xmax, 7))  # one tick every 7 days (same in both modes)
    pl.ylabel(ylabel)
    pl.legend(loc='upper left')
    return
def plotter(key, sims, ax, ys=None, calib=False, label='', ylabel='', low_q=0.025, high_q=0.975,
            flabel=True, startday=None, subsample=2, chooseseed=None):
    '''
    Plot one result across sims (central line plus quantile band), overlay data, and print stats.

    The drawn range runs from ``startday`` (or day 0) up to, but not including,
    ``calibration_end``, a name defined elsewhere in the file.

    Args:
        key (str): result key; the part after the first underscore selects the color
        sims (list): simulations; ``sims[0]`` supplies ``data`` and date conversion
        ax: axes that receive the fixed monthly date ticks
        ys (list): optional precomputed series; if None they are collected from ``sims``
        calib (bool): accepted for call compatibility; no longer changes anything
        label (str): legend label for the central line
        ylabel (str): y-axis label
        low_q (float): lower quantile of the band
        high_q (float): upper quantile of the band
        flabel (bool): whether the band gets a legend entry
        startday: first date to draw (anything ``sim.day`` accepts); None draws from day 0
        subsample (int): plot every ``subsample``-th data point
        chooseseed (int): if given, draw ``sims[chooseseed]`` instead of the median
    '''
    which = key.split('_')[1]
    try:
        color = cv.get_colors()[which]
    except Exception:  # no predefined color for this result type
        color = [0.5, 0.5, 0.5]
    if which == 'diagnoses':
        color = [0.03137255, 0.37401, 0.63813918, 1.]
    elif which == '':  # NOTE(review): only true for keys like 'a__b'; possibly meant another result type
        color = [0.82400815, 0., 0., 1.]

    if ys is None:
        ys = [s.results[key].values for s in sims]
    yarr = np.array(ys)
    if chooseseed is not None:
        best = sims[chooseseed].results[key].values
    else:
        best = pl.median(yarr, axis=0)
    low = pl.quantile(yarr, q=low_q, axis=0)
    high = pl.quantile(yarr, q=high_q, axis=0)

    sim = sims[0]  # For having a sim to refer to
    tvec = np.arange(len(best))

    if key in sim.data:
        data_t = np.array((sim.data.index - sim['start_day']) / np.timedelta64(1, 'D'))
        inds = np.arange(0, len(data_t), subsample)
        pl.plot(data_t[inds], sim.data[key][inds], 'd', c=color, markersize=15, alpha=0.75, label='Data')

    # Slice with the integer day index; slicing with the raw ``startday`` value
    # (e.g. a date string) raises TypeError.
    start = sim.day(startday) if startday is not None else None
    end = sim.day(calibration_end)

    fill_label = '95% projected interval' if flabel else None
    pl.fill_between(tvec[start:end], low[start:end], high[start:end], facecolor=color, alpha=0.2, label=fill_label)
    pl.plot(tvec[start:end], best[start:end], c=color, label=label, lw=4, alpha=1.0)

    # Print some stats
    if key == 'cum_infections':
        jul25 = sim.day("2020-07-25")
        print(f'Estimated {which} on July 25: {best[jul25]} (95%: {low[jul25]}-{high[jul25]})')
        print(f'Estimated {which} overall: {best[end]} (95%: {low[end]}-{high[end]})')
    elif key == 'n_infectious':
        peakday = sc.findnearest(best, max(best))
        peakval = max(best)
        print(f'Estimated peak {which} on {sim.date(peakday)}: {peakval} (95%: {low[peakday]}-{high[peakday]})')
        print(f'Estimated {which} on last day: {best[end]} (95%: {low[end]}-{high[end]})')
    elif key == 'cum_diagnoses':
        print(f'Estimated {which} overall: {best[end]} (95%: {low[end]}-{high[end]})')

    sc.setylim()
    pl.ylabel(ylabel)
    tick_dates = ('2020-07-01', '2020-08-01', '2020-09-01', '2020-10-01')
    datemarks = pl.array([sim.day(d) for d in tick_dates]) * 1.
    ax.set_xticks(datemarks)
    return
'd', c='k', markersize=12, alpha=0.75, label='Data') toplot = plotdict['new_diagnoses'][l][date_inds[0]:date_inds[1]] pl.plot(tvec, toplot, c=colors[i], label=l, lw=4, alpha=1.0) #low = plotdict_l['new_diagnoses'][l][date_inds[0]:date_inds[1]] #high = plotdict_h['new_diagnoses'][l][date_inds[0]:date_inds[1]] #pl.fill_between(tvec, low, high, facecolor=colors[i], alpha=0.2) pl.ylabel('Daily new infections') ax = pl.gca() ax.set_xticks(datemarks) cv.date_formatter(start_day=start_day, ax=ax, dateformat=dateformat) sc.setylim() sc.commaticks() pl.legend(frameon=False) sc.boxoff() # Plot B: R_eff pl.subplot(2, 2, 2) colors = pl.cm.GnBu([0.9, 0.6, 0.3]) for i, l in enumerate(labels): toplot = plotdict['r_eff'][l][date_inds[0]:date_inds[1]] pl.plot(tvec, toplot, c=colors[i], label=l, lw=4, alpha=1.0) low = plotdict_l['r_eff'][l][date_inds[0]:date_inds[1]] high = plotdict_h['r_eff'][l][date_inds[0]:date_inds[1]] pl.fill_between(tvec, low, high, facecolor=colors[i], alpha=0.2) pl.ylabel('R') pl.axhline(1, linestyle=':', c='k', alpha=0.3)
def plot():
    '''
    Build "Fig. 2: Transmission dynamics" and return the figure.

    Panels, placed with absolute ``pl.axes`` rectangles:
      a -- daily transmissions per contact layer (ts_ax), plus two pie insets of layer
           shares for days before 2020-03-12 and from 2020-03-23 onward
      b -- percent of transmissions by number of transmissions per case (disp_ax),
           counted with ``tt.count_targets(end_day=mar12)``, plus a pie inset
      c -- cumulative percent of transmissions vs. percent of primary infections (cum_ax)
      d -- percent of transmissions by (transmission date - symptom onset date) (symp_ax),
           plus a pie inset

    Takes no arguments. It reads names defined elsewhere in the file: ``sim``,
    ``tt`` (presumably a Covasim TransTree -- verify), ``wf`` (subtracted from font
    sizes), ``pieargs`` and the modules ``pl``, ``sc``, ``cv``, ``np``, ``mdates``.

    Returns:
        fig: the created figure
    '''
    fig = pl.figure(num='Fig. 2: Transmission dynamics', figsize=(20,14))

    # Axes layout in figure coordinates: y positions, widths and heights per row
    piey, tsy, r3y = 0.68, 0.50, 0.07
    piedx, tsdx, r3dx = 0.2, 0.9, 0.25
    piedy, tsdy, r3dy = 0.2, 0.47, 0.35
    pie1x, pie2x = 0.12, 0.65
    tsx = 0.07
    dispx, cumx, sympx = tsx, 0.33+tsx, 0.66+tsx
    ts_ax = pl.axes([tsx, tsy, tsdx, tsdy])
    pie_ax1 = pl.axes([pie1x, piey, piedx, piedy])
    pie_ax2 = pl.axes([pie2x, piey, piedx, piedy])
    symp_ax = pl.axes([sympx, r3y, r3dx, r3dy])
    disp_ax = pl.axes([dispx, r3y, r3dx, r3dy])
    cum_ax = pl.axes([cumx, r3y, r3dx, r3dy])

    # Panel letters
    off = 0.06
    txtdispx, txtcumx, txtsympx = dispx-off, cumx-off, sympx-off+0.02
    tsytxt = tsy+tsdy
    r3ytxt = r3y+r3dy
    labelsize = 40-wf
    pl.figtext(txtdispx, tsytxt, 'a', fontsize=labelsize)
    pl.figtext(txtdispx, r3ytxt, 'b', fontsize=labelsize)
    pl.figtext(txtcumx, r3ytxt, 'c', fontsize=labelsize)
    pl.figtext(txtsympx, r3ytxt, 'd', fontsize=labelsize)

    #%% Fig. 2A -- Time series plot

    layer_keys = list(sim.layer_keys())
    layer_mapping = {k:i for i,k in enumerate(layer_keys)}
    n_layers = len(layer_keys)
    colors = sc.gridcolors(n_layers)

    # Per-day, per-layer transmission counts, each weighted by sim.rescale_vec on that day
    layer_counts = np.zeros((sim.npts, n_layers))
    for source_ind, target_ind in tt.count_transmissions():
        dd = tt.detailed[target_ind]
        date = dd['date']
        layer_num = layer_mapping[dd['layer']]
        layer_counts[date, layer_num] += sim.rescale_vec[date]

    mar12 = cv.date('2020-03-12')
    mar23 = cv.date('2020-03-23')
    mar12d = sim.day(mar12)
    mar23d = sim.day(mar23)

    # NOTE(review): hard-coded names assume the layer order/count of layer_keys -- confirm they match
    labels = ['Household', 'School', 'Workplace', 'Community', 'LTCF']
    for l in range(n_layers):
        ts_ax.plot(sim.datevec, layer_counts[:,l], c=colors[l], lw=3, label=labels[l])
    sc.setylim(ax=ts_ax)
    sc.boxoff(ax=ts_ax)
    ts_ax.set_ylabel('Transmissions per day')
    ts_ax.set_xlim([sc.readdate('2020-01-18'), sc.readdate('2020-06-09')])
    ts_ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))
    ts_ax.set_xticks([sim.date(d, as_date=True) for d in np.arange(0, sim.day('2020-06-09'), 14)])
    ts_ax.legend(frameon=False, bbox_to_anchor=(0.85,0.1))

    # Vertical markers (and labels just above the top of the axes) for the two policy dates
    color = [0.2, 0.2, 0.2]
    ts_ax.axvline(mar12, c=color, linestyle='--', alpha=0.4, lw=3)
    ts_ax.axvline(mar23, c=color, linestyle='--', alpha=0.4, lw=3)
    yl = ts_ax.get_ylim()
    labely = yl[1]*1.015
    ts_ax.text(mar12, labely, 'Schools close ', color=color, alpha=0.9, style='italic', horizontalalignment='center')
    ts_ax.text(mar23, labely, ' Stay-at-home', color=color, alpha=0.9, style='italic', horizontalalignment='center')

    #%% Fig. 2A inset -- Pie charts

    # Percent of transmissions per layer: days [0, mar12d) vs. days [mar23d, end]
    pre_counts = layer_counts[0:mar12d, :].sum(axis=0)
    post_counts = layer_counts[mar23d:, :].sum(axis=0)
    pre_counts = pre_counts/pre_counts.sum()*100
    post_counts = post_counts/post_counts.sum()*100

    lpre = [
        f'Household\n{pre_counts[0]:0.1f}%',
        f'School\n{pre_counts[1]:0.1f}% ',
        f'Workplace\n{pre_counts[2]:0.1f}% ',
        f'Community\n{pre_counts[3]:0.1f}%',
        f'LTCF\n{pre_counts[4]:0.1f}%',
    ]

    lpost = [
        f'Household\n{post_counts[0]:0.1f}%',
        f'School\n{post_counts[1]:0.1f}%',
        f'Workplace\n{post_counts[2]:0.1f}%',
        f'Community\n{post_counts[3]:0.1f}%',
        f'LTCF\n{post_counts[4]:0.1f}%',
    ]

    pie_ax1.pie(pre_counts, colors=colors, labels=lpre, **pieargs)
    pie_ax2.pie(post_counts, colors=colors, labels=lpost, **pieargs)

    pie_ax1.text(0, 1.75, 'Transmissions by layer\nbefore schools closed', style='italic', horizontalalignment='center')
    pie_ax2.text(0, 1.75, 'Transmissions by layer\nafter stay-at-home', style='italic', horizontalalignment='center')

    #%% Fig. 2B -- histogram by overdispersion

    # Process targets
    n_targets = tt.count_targets(end_day=mar12)

    # Handle bins -- one integer bin per possible count, 0..max
    max_infections = n_targets.max()
    edges = np.arange(0, max_infections+2)

    # Analysis
    counts = np.histogram(n_targets, edges)[0]
    bins = edges[:-1] # Remove last bin since it's an edge
    norm_counts = counts/counts.sum()
    raw_counts = counts*bins  # transmissions contributed by each bin (cases x count)
    total_counts = raw_counts/raw_counts.sum()*100  # percent of all transmissions per bin
    n_bins = len(bins)
    index = np.linspace(0, 100, len(n_targets))  # percentile position of each case once sorted
    sorted_arr = np.sort(n_targets)
    sorted_sum = np.cumsum(sorted_arr)
    sorted_sum = sorted_sum/sorted_sum.max()*100  # cumulative percent of transmissions
    change_inds = sc.findinds(np.diff(sorted_arr) != 0)  # last position before each count value changes

    # Presumably sc.vectocolor draws from the current default colormap set here -- verify
    pl.set_cmap('Spectral_r')
    sscolors = sc.vectocolor(n_bins)

    width = 1.0
    for i in range(n_bins):
        disp_ax.bar(bins[i], total_counts[i], width=width, facecolor=sscolors[i])
    disp_ax.set_xlabel('Number of transmissions per case')
    disp_ax.set_ylabel('Proportion of transmissions (%)')
    # NOTE(review): without ax= this presumably targets pyplot's current axes (cum_ax, the
    # last one created), not disp_ax; disp_ax is handled explicitly a few lines below
    sc.boxoff()
    disp_ax.set_xlim([0.5, 32.5])
    disp_ax.set_xticks(np.arange(0, 32.5, 4))
    sc.boxoff(ax=disp_ax)

    # Inset pie: percent of transmissions from cases causing 1-2, 3-4, 5-7 and >7 transmissions
    dpie_ax = pl.axes([dispx+0.05, 0.20, 0.2, 0.2])
    trans1 = total_counts[1:3].sum()
    trans2 = total_counts[3:5].sum()
    trans3 = total_counts[5:8].sum()
    trans4 = total_counts[8:].sum()
    labels = [
        f'1-2:\n{trans1:0.0f}%',
        f' 3-4:\n {trans2:0.0f}%',
        f'5-7: \n{trans3:0.0f}%\n',
        f'>7: \n{trans4:0.0f}%\n',
    ]
    # Overrides labeldistance; NOTE(review): whether 1.2 is smaller than the pieargs value isn't visible here
    dpie_args = sc.mergedicts(pieargs, dict(labeldistance=1.2)) # Slightly smaller label distance
    dpie_ax.pie([trans1, trans2, trans3, trans4], labels=labels, colors=sscolors[[0,4,7,12]], **dpie_args)

    #%% Fig. 2C -- cumulative distribution function

    rev_ind = 100 - index
    n_change_inds = len(change_inds)
    change_bins = bins[counts>0][1:]  # non-empty count values, skipping the smallest one
    for i in range(n_change_inds):
        ib = int(change_bins[i])
        ci = change_inds[i]
        ici = index[ci]
        sci = sorted_sum[ci]
        color = sscolors[ib]
        if i>0:
            # Line segment from the previous change point, colored by this bin
            cim1 = change_inds[i-1]
            icim1 = index[cim1]
            scim1 = sorted_sum[cim1]
            cum_ax.plot([icim1, ici], [scim1, sci], lw=4, c=color)
        cum_ax.scatter([ici], [sci], s=150, zorder=50-i, c=[color], edgecolor='w', linewidth=0.2)
        if ib<=6 or ib in [8, 10, 25]:
            # Label selected points with their transmission count, with hand-tuned offsets
            xoff = 5 - 2*(ib==1) + 3*(ib>=10) + 1*(ib>=20)
            yoff = 2*(ib==1)
            cum_ax.text(ici-xoff, sci+yoff, ib, fontsize=18-wf, color=color)
    cum_ax.set_xlabel('Proportion of primary infections (%)')
    cum_ax.set_ylabel('Proportion of transmissions (%)')
    xmin = -2
    ymin = -2
    cum_ax.set_xlim([xmin, 102])
    cum_ax.set_ylim([ymin, 102])
    sc.boxoff(ax=cum_ax)

    # Draw horizontal lines and annotations
    ancol1 = [0.2, 0.2, 0.2]
    ancol2 = sscolors[0]
    ancol3 = sscolors[6]

    # Last sorted positions where the cumulative share is 0, <=20 and <=50 percent
    i01 = sc.findlast(sorted_sum==0)
    i20 = sc.findlast(sorted_sum<=20)
    i50 = sc.findlast(sorted_sum<=50)
    cum_ax.plot([xmin, index[i01]], [0, 0], '--', lw=2, c=ancol1)
    cum_ax.plot([xmin, index[i20], index[i20]], [20, 20, ymin], '--', lw=2, c=ancol2)
    cum_ax.plot([xmin, index[i50], index[i50]], [50, 50, ymin], '--', lw=2, c=ancol3)

    # Compute mean number of transmissions for 80% and 50% thresholds
    q80 = sc.findfirst(np.cumsum(total_counts)>20) # Count corresponding to 80% of cumulative infections (100-80)
    q50 = sc.findfirst(np.cumsum(total_counts)>50) # Count corresponding to 50% of cumulative infections
    # Mean count among cases in bins q and above, weighted by the case frequency
    n80, n50 = [sum(bins[q:]*norm_counts[q:]/norm_counts[q:].sum()) for q in [q80, q50]]

    # Plot annotations
    kw = dict(bbox=dict(facecolor='w', alpha=0.9, lw=0), fontsize=20-wf)
    cum_ax.text(2, 3, f'{index[i01]:0.0f}% of infections\ndo not transmit', c=ancol1, **kw)
    cum_ax.text(8, 23, f'{rev_ind[i20]:0.0f}% of infections cause\n80% of transmissions\n(mean: {n80:0.1f} per infection)', c=ancol2, **kw)
    cum_ax.text(14, 53, f'{rev_ind[i50]:0.0f}% of infections cause\n50% of transmissions\n(mean: {n50:0.1f} per infection)', c=ancol3, **kw)

    #%% Fig. 2D -- histogram by date of symptom onset

    # Calculate
    asymp_count = 0
    symp_counts = {}  # (transmission date - symptom date) -> weighted count
    minind = -5
    maxind = 15
    for _, target_ind in tt.transmissions:
        dd = tt.detailed[target_ind]
        date = dd['date']
        delta = sim.rescale_vec[date] # Increment counts by this much
        if dd['s']:
            if tt.detailed[dd['source']]['date'] <= date: # Skip dynamical scaling reinfections
                sdate = dd['s']['date_symptomatic']
                if np.isnan(sdate):
                    asymp_count += delta  # no symptom date -> counted as asymptomatic
                else:
                    ind = int(date - sdate)
                    if ind not in symp_counts:
                        symp_counts[ind] = 0
                    symp_counts[ind] += delta

    # Convert to an array. Offsets below minind go into the first bin (minind-1); offsets
    # above maxind go into the last bin, which is also the bin for maxind itself
    xax = np.arange(minind-1, maxind+1)
    sympcounts = np.zeros(len(xax))
    for i,val in symp_counts.items():
        if i<minind:
            ind = 0
        elif i>maxind:
            ind = -1
        else:
            ind = sc.findinds(xax==i)[0]
        sympcounts[ind] += val

    # Plot
    total_count = asymp_count + sympcounts.sum()
    sympcounts = sympcounts/total_count*100
    presymp = sc.findinds(xax<=0)[-1]  # position of offset 0; offsets < 0 are presymptomatic

    colors = ['#eed15b', '#ee943a', '#c3211a']

    asymp_frac = asymp_count/total_count*100
    pre_frac = sympcounts[:presymp].sum()
    symp_frac = sympcounts[presymp:].sum()  # includes day 0
    symp_ax.bar(xax[0]-2, asymp_frac, label='Asymptomatic', color=colors[0])
    symp_ax.bar(xax[:presymp], sympcounts[:presymp], label='Presymptomatic', color=colors[1])
    symp_ax.bar(xax[presymp:], sympcounts[presymp:], label='Symptomatic', color=colors[2])
    symp_ax.set_xlabel('Days since symptom onset')
    symp_ax.set_ylabel('Proportion of transmissions (%)')
    symp_ax.set_xticks([minind-3, 0, 5, 10, maxind])
    symp_ax.set_xticklabels(['Asymp.', '0', '5', '10', f'>{maxind}'])
    sc.boxoff(ax=symp_ax)

    spie_ax = pl.axes([sympx+0.05, 0.20, 0.2, 0.2])
    labels = [f'Asymp-\ntomatic\n{asymp_frac:0.0f}%', f' Presymp-\n tomatic\n {pre_frac:0.0f}%', f'Symp-\ntomatic\n{symp_frac:0.0f}%']
    spie_ax.pie([asymp_frac, pre_frac, symp_frac], labels=labels, colors=colors, **pieargs)

    return fig
def plot(self, to_plot=None, do_save=None, fig_path=None, fig_args=None, plot_args=None,
         axis_args=None, fill_args=None, as_dates=True, interval=None, dateformat=None,
         font_size=18, font_family=None, grid=True, commaticks=True, do_show=True,
         sep_figs=False, verbose=None):
    '''
    Plot the results -- can supply arguments for both the figure and the plots.

    Args:
        to_plot     (dict): Dict of results to plot; see default_scen_plots for structure
        do_save     (bool or str): Whether or not to save the figure(s). If a string and
                    fig_path is not given, it is used as the filename.
        fig_path    (str): Path to save the figure. With sep_figs, each figure is saved as
                    '<stem>_<i><ext>'.
        fig_args    (dict): Dictionary of kwargs to be passed to pl.figure()
        plot_args   (dict): Dictionary of kwargs to be passed to pl.plot()
        axis_args   (dict): Dictionary of kwargs to be passed to pl.subplots_adjust()
        fill_args   (dict): Dictionary of kwargs to be passed to pl.fill_between()
        as_dates    (bool): Whether to plot the x-axis as dates or time points
        interval    (int): Interval between tick marks
        dateformat  (str): Date string format, e.g. '%B %d'
        font_size   (int): Size of the font
        font_family (str): Font face
        grid        (bool): Whether or not to plot gridlines
        commaticks  (bool): Plot y-axis with commas rather than scientific notation
        do_show     (bool): Whether or not to show the figure
        sep_figs    (bool): Whether to show separate figures for different results instead of subplots
        verbose     (bool): Display a bit of extra information

    Returns:
        fig: Figure handle, or a list of figure handles if sep_figs is True
    '''
    if verbose is None:
        verbose = self['verbose']
    sc.printv('Plotting...', 1, verbose)

    if to_plot is None:
        to_plot = default_scen_plots
    to_plot = sc.odict(sc.dcp(to_plot))  # In case it's supplied as a dict

    # Handle input arguments -- merge user input with defaults
    fig_args = sc.mergedicts({'figsize': (16, 12)}, fig_args)
    plot_args = sc.mergedicts({'lw': 3, 'alpha': 0.7}, plot_args)
    axis_args = sc.mergedicts(
        {'left': 0.10, 'bottom': 0.05, 'right': 0.95, 'top': 0.90, 'wspace': 0.5, 'hspace': 0.25},
        axis_args)
    fill_args = sc.mergedicts({'alpha': 0.2}, fill_args)

    # In sep_figs mode each result gets its own figure. ``fig`` is only bound in the
    # single-figure case, so the close/return logic below must not rely on it otherwise.
    figs = []
    fig = None
    if not sep_figs:
        fig = pl.figure(**fig_args)
        pl.subplots_adjust(**axis_args)
    pl.rcParams['font.size'] = font_size
    if font_family:
        pl.rcParams['font.family'] = font_family

    # %% Plotting
    for rk, reskey, title in to_plot.enumitems():
        if sep_figs:
            figs.append(pl.figure(**fig_args))
            pl.subplots_adjust(**axis_args)
            ax = pl.subplot(111)
        else:
            ax = pl.subplot(len(to_plot), 1, rk + 1)

        resdata = self.allres[reskey]
        for scenkey, scendata in resdata.items():
            pl.fill_between(self.tvec, scendata.low, scendata.high, **fill_args)
            pl.plot(self.tvec, scendata.best, label=scendata.name, **plot_args)

        # Format each axis once, after all scenarios are drawn. Repeating
        # set_xticks(arange(xmin, xmax+1, ...)) per scenario can widen the x-limits.
        pl.title(title)
        if rk == 0 or sep_figs:  # every separate figure needs its own legend
            pl.legend(loc='best')
        sc.setylim()
        pl.grid(grid)
        if commaticks:
            sc.commaticks()

        # Optionally reset tick marks (useful for e.g. plotting weeks/months)
        if interval:
            xmin, xmax = ax.get_xlim()
            ax.set_xticks(pl.arange(xmin, xmax + 1, interval))

        # Set xticks as dates
        if as_dates:
            xticks = ax.get_xticks()
            xticklabels = self.base_sim.inds2dates(xticks, dateformat=dateformat)
            ax.set_xticklabels(xticklabels)

    all_figs = figs if sep_figs else [fig]

    # Ensure the figure actually renders or saves
    if do_save:
        if fig_path is None:  # No figpath provided - see whether do_save is a figpath
            fig_path = do_save if isinstance(do_save, str) else 'covasim_scenarios.png'
        fig_path = sc.makefilepath(fig_path)  # Ensure it's valid, including creating the folder
        if sep_figs:
            import os
            stem, ext = os.path.splitext(fig_path)
            for i, this_fig in enumerate(figs):
                this_fig.savefig(f'{stem}_{i}{ext}')
        else:
            fig.savefig(fig_path)

    if do_show:
        pl.show()
    else:
        for this_fig in all_figs:
            pl.close(this_fig)

    return figs if sep_figs else fig
dd = tt.detailed[target_ind] date = dd['date'] layer_num = layer_mapping[dd['layer']] layer_counts[date, layer_num] += sim.rescale_vec[date] lockdown1 = [sc.readdate('2020-03-23'),sc.readdate('2020-05-31')] lockdown2 = [sc.readdate('2020-11-05'),sc.readdate('2020-12-03')] lockdown3 = [sc.readdate('2021-01-04'),sc.readdate('2021-02-08')] labels = ['Household', 'School', 'Workplace', 'Community'] for l in range(n_layers): ax.plot(sim.datevec, layer_counts[:,l], c=colors[l], lw=3, label=labels[l]) ax.axvspan(lockdown1[0], lockdown1[1], color='steelblue', alpha=0.2, lw=0) ax.axvspan(lockdown2[0], lockdown2[1], color='steelblue', alpha=0.2, lw=0) ax.axvspan(lockdown3[0], lockdown3[1], color='lightblue', alpha=0.2, lw=0) sc.setylim(ax=ax) sc.boxoff(ax=ax) ax.set_ylabel('Transmissions per day') ax.set_xlim([sc.readdate('2020-01-21'), sc.readdate('2021-03-01')]) ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%y')) datemarks = pl.array([sim.day('2020-02-01'), sim.day('2020-03-01'), sim.day('2020-04-01'), sim.day('2020-05-01'), sim.day('2020-06-01'), sim.day('2020-07-01'), sim.day('2020-08-01'), sim.day('2020-09-01'), sim.day('2020-10-01'), sim.day('2020-11-01'), sim.day('2020-12-01'), sim.day('2021-01-01'), sim.day('2021-02-01'), sim.day('2021-03-01')]) ax.set_xticks([sim.date(d, as_date=True) for d in datemarks]) ax.legend(frameon=False) yl = ax.get_ylim() labely = yl[1]*1.015
def plotter(key, sims, ax, ys=None, calib=False, label='', ylabel='', low_q=0.025, high_q=0.975, flabel=True, subsample=2):
    '''
    Plot a single time series with uncertainty.

    Draws the across-sim median of ``results[key]`` with a quantile band and every
    ``subsample``-th data point. It then stores the computed series in ``plotres[key]``.

    Relies on names defined elsewhere in the file: ``sims_cutoff`` (only sims with
    index < sims_cutoff are aggregated), ``day_stride`` (x tick spacing) and
    ``plotres`` (output dict).

    Args:
        key (str): result key; the part after the first underscore selects the color
        sims (list): simulations; ``sims[0]`` supplies ``data`` and ``start_day``
        ax: axes whose x ticks are set
        ys (list): optional precomputed series; if None they are collected from ``sims``
        calib (bool): accepted but not used
        label (str): legend label for the median line
        ylabel (str): y-axis label
        low_q (float): lower quantile of the band
        high_q (float): upper quantile of the band
        flabel (bool): whether the band gets a legend entry
        subsample (int): plot every ``subsample``-th data point
    '''
    which = key.split('_')[1]
    try:
        color = cv.get_colors()[which]
    except Exception:  # no predefined color for this result type
        color = [0.5, 0.5, 0.5]
    if which == 'deaths':
        color = [0.5, 0.0, 0.0]

    if ys is None:
        ys = [s.results[key].values for i, s in enumerate(sims) if i < sims_cutoff]
    yarr = np.array(ys)
    best = pl.median(yarr, axis=0)  # median across sims
    low = pl.quantile(yarr, q=low_q, axis=0)
    high = pl.quantile(yarr, q=high_q, axis=0)

    sim = sims[0]  # For having a sim to refer to
    tvec = np.arange(len(best))

    data, data_t = None, None
    if key in sim.data:
        data_t = np.array((sim.data.index - sim['start_day']) / np.timedelta64(1, 'D'))
        inds = np.arange(0, len(data_t), subsample)
        data = sim.data[key][inds]
        pl.plot(data_t[inds], data, 'd', c=color, markersize=10, alpha=0.5, label='Data')

    if flabel:
        if which == 'infections':
            fill_label = '95% predic-\ntion interval'
        else:
            fill_label = '95% prediction\ninterval'
    else:
        fill_label = None

    # Trim the beginning for r_eff and actually plot
    start = 2 if key == 'r_eff' else 0
    pl.fill_between(tvec[start:], low[start:], high[start:], facecolor=color, alpha=0.2, label=fill_label)
    pl.plot(tvec[start:], best[start:], c=color, label=label, lw=4, alpha=1.0)

    sc.setylim()
    xmin, xmax = ax.get_xlim()
    ax.set_xticks(np.arange(xmin, xmax, day_stride))
    pl.ylabel(ylabel)

    # NOTE(review): ``data`` is subsampled but ``data_t`` is the full-length time vector,
    # so their lengths differ when subsample > 1 -- confirm how consumers of plotres use them.
    plotres[key] = sc.objdict(dict(tvec=tvec, best=best, low=low, high=high, data=data, data_t=data_t))
    return
def plot():
    '''
    Build "Fig. 1: Calibration" and return the figure.

    Panels (letters placed with ``pl.figtext``):
      a -- cumulative diagnoses (ax1) with an age-histogram inset (ax1s)
      b -- cumulative deaths (ax2) with an age-histogram inset (ax2s)
      c -- cumulative and active infections with SCAN estimates (ax3)
      d -- effective reproduction number (ax4)
      e -- 2x2 kernel-density grid comparing ``df1`` and ``df2`` for each key in ``keys``
      f -- SafeGraph school vs. workplace/community mobility (ax5)

    Takes no arguments. It reads names defined elsewhere in the file: ``sims``,
    ``base_sim``, ``format_ax``, ``plotter``, ``plot_intervs``, ``scan_file``,
    ``safegraph_file``, ``keys``, ``df1``, ``df2``, ``mapping``, ``day_stride``, and the
    modules ``pl``, ``sc``, ``np``, ``pd``, ``sns``.

    Returns:
        fig: the created figure
    '''
    # Create the figure
    fig = pl.figure(num='Fig. 1: Calibration', figsize=(24, 20))
    tx1, ty1 = 0.005, 0.97
    tx2, ty2 = 0.545, 0.66
    ty3 = 0.34
    fsize = 40
    pl.figtext(tx1, ty1, 'a', fontsize=fsize)
    pl.figtext(tx1, ty2, 'b', fontsize=fsize)
    pl.figtext(tx2, ty1, 'c', fontsize=fsize)
    pl.figtext(tx2, ty2, 'd', fontsize=fsize)
    pl.figtext(tx1, ty3, 'e', fontsize=fsize)
    pl.figtext(tx2, ty3, 'f', fontsize=fsize)

    #%% Fig. 1A: diagnoses
    x0, y0, dx, dy = 0.055, 0.73, 0.47, 0.24
    ax1 = pl.axes([x0, y0, dx, dy])
    format_ax(ax1, base_sim)
    plotter('cum_diagnoses', sims, ax1, calib=True, label='Model', ylabel='Cumulative diagnoses')
    pl.legend(loc='lower right', frameon=False)

    #%% Fig. 1B: deaths
    y0b = 0.42
    ax2 = pl.axes([x0, y0b, dx, dy])
    format_ax(ax2, base_sim)
    plotter('cum_deaths', sims, ax2, calib=True, label='Model', ylabel='Cumulative deaths')
    pl.legend(loc='lower right', frameon=False)

    #%% Fig. 1A-B inserts (histograms)
    # Last histogram from the first analyzer of each sim; age_data is taken from the first sim.
    # NOTE: after this loop ``sim`` is the LAST element of ``sims``; it is used by panels c, d and f below.
    agehists = []
    for s, sim in enumerate(sims):
        agehist = sim['analyzers'][0]
        if s == 0:
            age_data = agehist.data
        agehists.append(agehist.hists[-1])

    # Observed data
    x = age_data['age'].values
    pos = age_data['cum_diagnoses'].values
    death = age_data['cum_deaths'].values

    # Model outputs
    mposlist = []
    mdeathlist = []
    for hists in agehists:
        mposlist.append(hists['diagnosed'])
        mdeathlist.append(hists['dead'])
    mposarr = np.array(mposlist)
    mdeatharr = np.array(mdeathlist)

    # Median and 95% range across sims for each age bin
    low_q = 0.025
    high_q = 0.975
    mpbest = pl.median(mposarr, axis=0)
    mplow = pl.quantile(mposarr, q=low_q, axis=0)
    mphigh = pl.quantile(mposarr, q=high_q, axis=0)
    mdbest = pl.median(mdeatharr, axis=0)
    mdlow = pl.quantile(mdeatharr, q=low_q, axis=0)
    mdhigh = pl.quantile(mdeatharr, q=high_q, axis=0)

    # Bar width and offset: data bars at x-off, model bars at x+w-off (side by side)
    w = 4
    off = 2

    # Insets
    x0s, y0s, dxs, dys = 0.11, 0.84, 0.17, 0.13
    ax1s = pl.axes([x0s, y0s, dxs, dys])
    c1 = [0.3, 0.3, 0.6]
    c2 = [0.6, 0.7, 0.9]
    xx = x + w - off
    pl.bar(x - off, pos, width=w, label='Data', facecolor=c1)
    pl.bar(xx, mpbest, width=w, label='Model', facecolor=c2)
    for i, ix in enumerate(xx):
        pl.plot([ix, ix], [mplow[i], mphigh[i]], c='k')  # 95% range error bar on each model bar
    ax1s.set_xticks(np.arange(0, 81, 20))
    pl.xlabel('Age')
    pl.ylabel('Cases')
    sc.boxoff(ax1s)
    pl.legend(frameon=False, bbox_to_anchor=(0.7, 1.1))

    y0sb = 0.53
    ax2s = pl.axes([x0s, y0sb, dxs, dys])
    c1 = [0.5, 0.0, 0.0]
    c2 = [0.9, 0.4, 0.3]
    pl.bar(x - off, death, width=w, label='Data', facecolor=c1)
    pl.bar(x + w - off, mdbest, width=w, label='Model', facecolor=c2)
    for i, ix in enumerate(xx):
        pl.plot([ix, ix], [mdlow[i], mdhigh[i]], c='k')
    ax2s.set_xticks(np.arange(0, 81, 20))
    pl.xlabel('Age')
    pl.ylabel('Deaths')
    sc.boxoff(ax2s)
    pl.legend(frameon=False)
    sc.boxoff(ax1s)  # NOTE(review): repeats the boxoff already applied to ax1s above

    #%% Fig. 1C: infections
    x0, dx = 0.60, 0.38
    ax3 = pl.axes([x0, y0, dx, dy])
    format_ax(ax3, sim)

    # Plot SCAN data: survey proportions scaled by pop_size, plotted at the midpoint day
    # of each survey window, with a vertical lower-upper range
    pop_size = 2.25e6
    scan = pd.read_csv(scan_file)
    for i, r in scan.iterrows():
        label = "Data" if i == 0 else None  # only the first point gets a legend entry
        ts = np.mean(sim.day(r['since'], r['to']))
        low = r['lower'] * pop_size
        high = r['upper'] * pop_size
        mean = r['mean'] * pop_size
        ax3.plot([ts] * 2, [low, high], alpha=1.0, color='k', zorder=1000)
        ax3.plot(ts, mean, 'o', markersize=7, color='k', alpha=0.5, label=label, zorder=1000)

    # Plot simulation
    plotter('cum_infections', sims, ax3, calib=True, label='Cumulative\ninfections\n(modeled)', ylabel='Infections')
    plotter('n_infectious', sims, ax3, calib=True, label='Active\ninfections\n(modeled)', ylabel='Infections', flabel=False)
    pl.legend(loc='upper left', frameon=False)
    pl.ylim([0, 130e3])
    plot_intervs(sim)

    #%% Fig. 1D: R_eff
    ax4 = pl.axes([x0, y0b, dx, dy])
    format_ax(ax4, sim, key='r_eff')
    plotter('r_eff', sims, ax4, calib=True, label='$R_{eff}$ (modeled)', ylabel=r'Effective reproduction number')
    pl.axhline(1, linestyle='--', lw=3, c='k', alpha=0.5)  # R = 1 threshold
    pl.legend(loc='upper right', frameon=False)
    plot_intervs(sim)

    #%% Fig. 1E

    # Do the plotting -- the 2x2 subplot grid is confined to the lower-left region of the figure
    pl.subplots_adjust(left=0.04, right=0.52, bottom=0.03, top=0.35, wspace=0.12, hspace=0.50)

    for i, k in enumerate(keys):
        eax = pl.subplot(2, 2, i + 1)

        c1 = [0.2, 0.5, 0.8]
        c2 = [1.0, 0.5, 0.0]
        c3 = [0.1, 0.6, 0.1]

        # NOTE(review): newer seaborn versions deprecate ``shade`` in favour of ``fill``
        sns.kdeplot(df1[k], shade=1, linewidth=3, label='', color=c1, alpha=0.5)
        sns.kdeplot(df2[k], shade=0, linewidth=3, label='', color=c2, alpha=0.5)

        pl.title(mapping[k])
        pl.xlabel('')
        pl.yticks([])
        if not i % 4:  # NOTE(review): with a 2x2 grid only i == 0 matches; i % 2 may have been intended
            pl.ylabel('Density')

        # Leave 30% headroom for the text annotations
        yfactor = 1.3
        yl = pl.ylim()
        pl.ylim([yl[0], yl[1] * yfactor])

        m1 = np.median(df1[k])
        m2 = np.median(df2[k])
        m1std = df1[k].std()
        m2std = df2[k].std()
        pl.axvline(m1, c=c1, ymax=0.9, lw=3, linestyle='--')
        pl.axvline(m2, c=c2, ymax=0.9, lw=3, linestyle='--')

        # NOTE(review): fmt is defined but never called in this function
        def fmt(lab, val, std=-1):
            if val < 0.1:
                valstr = f'{lab} = {val:0.4f}'
            elif val < 1.0:
                valstr = f'{lab} = {val:0.2f}±{std:0.2f}'
            else:
                valstr = f'{lab} = {val:0.1f}±{std:0.1f}'
            if std < 0:
                valstr = valstr.split('±')[0] # Discard STD if not used
            return valstr

        if k.startswith('bc'):
            pl.xlim([0, 100])
        elif k == 'beta':
            pl.xlim([3, 5])
        elif k.startswith('tn'):
            pl.xlim([0, 50])
        else:
            raise Exception(f'Please assign key {k}')

        # Horizontal text position as a fraction of the x range; keys missing here raise KeyError
        xl = pl.xlim()
        xfmap = dict(
            beta=0.15,
            bc_wc1=0.30,
            bc_lf=0.35,
            tn=0.55,
        )

        xf = xfmap[k]
        x0 = xl[0] + xf * (xl[1] - xl[0])

        # Text rows are placed relative to the pre-expansion y maximum
        ypos1 = yl[1] * 0.97
        ypos2 = yl[1] * 0.77
        ypos3 = yl[1] * 0.57

        if k == 'beta': # Use 2 s.f. instead of 1
            pl.text(x0, ypos1, f'M: {m1:0.2f} ± {m1std:0.2f}', c=c1)
            pl.text(x0, ypos2, f'N: {m2:0.2f} ± {m2std:0.2f}', c=c2)
            pl.text(x0, ypos3, rf'$\Delta$: {(m2-m1):0.2f} ± {(m1std+m2std):0.2f}', c=c3)
        else:
            pl.text(x0, ypos1, f'M: {m1:0.1f} ± {m1std:0.1f}', c=c1)
            pl.text(x0, ypos2, f'N: {m2:0.1f} ± {m2std:0.1f}', c=c2)
            pl.text(x0, ypos3, rf'$\Delta$: {(m2-m1):0.1f} ± {(m1std+m2std):0.1f}', c=c3)

        sc.boxoff(ax=eax)

    #%% Fig. 1F: SafeGraph
    x0, y0c, dyc = 0.60, 0.03, 0.30
    ax5 = pl.axes([x0, y0c, dx, dyc])
    format_ax(ax5, sim, key='r_eff')  # NOTE(review): reuses the r_eff axis formatting for the mobility panel -- confirm intended
    fn = safegraph_file
    df = pd.read_csv(fn)
    week = df['week']
    days = sim.day(week.values.tolist())
    s = df['p.tot.schools'].values * 100  # proportions -> percent
    w = df['p.tot.no.schools'].values * 100

    # From Fig. 2
    colors = sc.gridcolors(5)
    wcolor = colors[3] # Work color/community
    scolor = colors[1] # School color

    pl.plot(days, w, 'd-', c=wcolor, markersize=15, lw=3, alpha=0.9, label='Workplace and\ncommunity mobility data')
    pl.plot(days, s, 'd-', c=scolor, markersize=15, lw=3, alpha=0.9, label='School mobility data')
    sc.setylim()
    xmin, xmax = ax5.get_xlim()
    ax5.set_xticks(np.arange(xmin, xmax, day_stride))
    pl.ylabel('Relative mobility (%)')
    pl.legend(loc='upper right', frameon=False)
    plot_intervs(sim)

    return fig
def plotter(key, sims, ax, ys=None, calib=False, label='', ylabel='', low_q=0.025, high_q=0.975,
            flabel=True, startday=None, subsample=2, chooseseed=None):
    '''
    Plot one result across sims (mean line plus quantile band) and overlay subsampled data.

    The drawn range runs from ``startday`` (or day 0) up to, but not including,
    ``today``, a name defined elsewhere in the file.

    Args:
        key (str): result key; the part after the first underscore selects the color
        sims (list): simulations; ``sims[0]`` supplies ``data`` and date conversion
        ax: axes whose x ticks are set (one every 28 units, starting at xmin + 2)
        ys (list): optional precomputed series; if None they are collected from ``sims``
        calib (bool): accepted for call compatibility; both modes used identical ticks
        label (str): legend label for the central line
        ylabel (str): y-axis label
        low_q (float): lower quantile of the band
        high_q (float): upper quantile of the band
        flabel (bool): whether the band gets a legend entry
        startday: first date to draw (anything ``sim.day`` accepts); None draws from day 0
        subsample (int): plot every ``subsample``-th data point
        chooseseed (int): if given, draw ``sims[chooseseed]`` instead of the mean
    '''
    which = key.split('_')[1]
    try:
        color = cv.get_colors()[which]
    except Exception:  # no predefined color for this result type
        color = [0.5, 0.5, 0.5]
    if which == 'diagnoses':
        color = [0.03137255, 0.37401, 0.63813918, 1.]
    elif which == '':  # NOTE(review): only true for keys like 'a__b'; possibly meant another result type
        color = [0.82400815, 0., 0., 1.]

    if ys is None:
        ys = [s.results[key].values for s in sims]
    yarr = np.array(ys)
    if chooseseed is not None:
        best = sims[chooseseed].results[key].values
    else:
        best = pl.mean(yarr, axis=0)
    low = pl.quantile(yarr, q=low_q, axis=0)
    high = pl.quantile(yarr, q=high_q, axis=0)

    sim = sims[0]  # For having a sim to refer to
    tvec = np.arange(len(best))

    if key in sim.data:
        data_t = np.array((sim.data.index - sim['start_day']) / np.timedelta64(1, 'D'))
        inds = np.arange(0, len(data_t), subsample)
        pl.plot(data_t[inds], sim.data[key][inds], 'd', c=color, markersize=10, alpha=0.5, label='Data')

    # Slice with the integer day index; slicing with the raw ``startday`` value
    # (e.g. a date string) raises TypeError.
    start = sim.day(startday) if startday is not None else None
    end = sim.day(today)

    fill_label = '95% projected interval' if flabel else None
    pl.fill_between(tvec[start:end], low[start:end], high[start:end], facecolor=color, alpha=0.2, label=fill_label)
    pl.plot(tvec[start:end], best[start:end], c=color, label=label, lw=4, alpha=1.0)

    sc.setylim()
    xmin, xmax = ax.get_xlim()
    ax.set_xticks(pl.arange(xmin + 2, xmax, 28))
    pl.ylabel(ylabel)
    return