def curves(start, county, n_weeks=3, model_i=35, save_plot=False):
    """Plot the nowcast/forecast curve of one county and record its trend probability.

    Loads the daily data window and the posterior predictive samples of model
    ``model_i`` for the window starting ``start`` days after 2020-01-28. Draws
    the predicted mean with 25-75% and 5-95% quantile bands against the
    reported counts. Then writes a signed trend probability for the county into
    ``../figures/YYYY_MM_DD/metadata.csv``, where the folder name is the
    window's start day.

    Parameters
    ----------
    start : int or str
        Offset in days from 2020-01-28 to the first day of the window.
    county : str
        County name, looked up in ``make_county_dict()`` to get the county ID.
    n_weeks : int or str
        Window length parameter: the test period starts ``n_weeks * 7 + 5``
        days after the start day and the plot extends five days beyond that.
    model_i : int or str
        Index of the model whose predictions and trace are loaded.
    save_plot : bool
        If True, save the figure as ``curve_<countyID>.png`` (200 dpi) in the
        day folder and close it.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    with open('../data/counties/counties.pkl', "rb") as f:
        counties = pkl.load(f)

    start = int(start)
    n_weeks = int(n_weeks)
    model_i = int(model_i)

    countyByName = make_county_dict()

    disease = "covid19"
    prediction_region = "germany"

    # day_0 is the first day of the test period, day_p5 five days later.
    start_day = pd.Timestamp('2020-01-28') + pd.Timedelta(days=start)
    i_start_day = 0
    day_0 = start_day + pd.Timedelta(days=n_weeks * 7 + 5)
    day_p5 = day_0 + pd.Timedelta(days=5)

    day_folder_path = "../figures/{}_{}_{}".format(
        str(start_day)[:4], str(start_day)[5:7], str(start_day)[8:10])
    metadata_path = day_folder_path + "/metadata.csv"
    Path(day_folder_path).mkdir(parents=True, exist_ok=True)

    # Create the per-day metadata file (one row per county) on first use.
    if not os.path.isfile(metadata_path):
        df = pd.DataFrame(data=[int(key) for key in counties], columns=["countyID"])
        df["probText"] = ""
        df.to_csv(metadata_path)

    # Colors for the curves.
    C1 = "#D55E00"  # predicted mean
    C2 = "#E69F00"  # quantile bands

    fig = plt.figure(figsize=(12, 6))
    grid = plt.GridSpec(1, 1, top=0.9, bottom=0.2, left=0.07, right=0.97,
                        hspace=0.25, wspace=0.15)

    data = load_daily_data_n_weeks(start, n_weeks, disease, prediction_region, counties)
    _, target, _, _ = split_data(data, train_start=start_day, test_start=day_0,
                                 post_test=day_p5)
    county_ids = target.columns
    county_id = countyByName[county]

    # Posterior predictive samples reshaped to (samples, days, 412); the last
    # axis lines up with target.columns.
    res = load_pred_model_window(model_i, start, n_weeks)
    n_days = (day_p5 - start_day).days
    prediction_samples = np.reshape(res['y'], (res['y'].shape[0], -1, 412))
    prediction_samples = prediction_samples[:, i_start_day:i_start_day + n_days, :]

    # Observed dates extended by the days up to (excluding) day_p5.
    ext_index = pd.DatetimeIndex(
        list(target.index)
        + list(pd.date_range(target.index[-1] + timedelta(1), day_p5 - timedelta(1))))

    prediction_quantiles = quantiles(prediction_samples, (5, 25, 75, 95))

    def as_frame(values):
        # Wrap a (days, counties) array with the shared date index and county columns.
        return pd.DataFrame(data=values, index=ext_index, columns=target.columns)

    prediction_mean = as_frame(np.mean(prediction_samples, axis=0))
    prediction_q25 = as_frame(prediction_quantiles[25])
    prediction_q75 = as_frame(prediction_quantiles[75])
    prediction_q5 = as_frame(prediction_quantiles[5])
    prediction_q95 = as_frame(prediction_quantiles[95])

    ax = fig.add_subplot(grid[0, 0])
    dates = [pd.Timestamp(d) for d in ext_index]

    # Predicted mean and interquartile band.
    p_pred = ax.plot_date(dates, prediction_mean[county_id], "-", color=C1,
                          linewidth=2.0, zorder=4)
    p_quant = ax.fill_between(dates, prediction_q25[county_id], prediction_q75[county_id],
                              facecolor=C2, alpha=0.5, zorder=1)
    ax.plot_date(dates, prediction_q25[county_id], ":", color=C2, linewidth=2.0, zorder=3)
    ax.plot_date(dates, prediction_q75[county_id], ":", color=C2, linewidth=2.0, zorder=3)

    # Reported counts cover every plotted date except the last five.
    p_real = ax.plot_date(dates[:-5], target[county_id], "k.")

    # Vertical markers half a day before the 5th- and 10th-to-last dates.
    half_day = pd.Timedelta(12, unit='h')
    ax.axvline(dates[-5] - half_day, ls='-', lw=2, c='dodgerblue')
    ax.axvline(dates[-10] - half_day, ls='--', lw=2, c='lightskyblue')

    ax.set_ylabel("Fallzahlen/Tag nach Meldedatum", fontsize=16)
    ax.tick_params(axis="both", direction='out', size=6, labelsize=16, length=6)
    ticks = [start_day + pd.Timedelta(days=offset) for offset in range(0, 41, 5)]
    labels = [tick.strftime("%d.%m.%Y") for tick in ticks]
    plt.xticks(ticks, labels)
    plt.setp(ax.get_xticklabels()[-4], color="red")
    plt.setp(ax.get_xticklabels(), rotation=45)
    ax.autoscale(True)

    # Outer 5-95% band.
    p_quant2 = ax.fill_between(dates, prediction_q5[county_id], prediction_q95[county_id],
                               facecolor=C2, alpha=0.25, zorder=0)
    ax.plot_date(dates, prediction_q5[county_id], ":", color=C2, alpha=0.5,
                 linewidth=2.0, zorder=1)
    ax.plot_date(dates, prediction_q95[county_id], ":", color=C2, alpha=0.5,
                 linewidth=2.0, zorder=1)

    # Posterior probability that the county's second trend weight is positive.
    i_county = county_ids.get_loc(county_id)
    trace = load_trace_window(disease, model_i, start, n_weeks)
    trend_params = pm.trace_to_dataframe(trace, varnames=["W_t_t"]).values
    # -1 infers the number of posterior samples instead of hard-coding 1000.
    # NOTE(review): assumes W_t_t columns are ordered (county, weight) — confirm.
    trend_w2 = np.reshape(trend_params, (-1, 412, 2))[:, i_county, 1]
    prob2 = np.mean(trend_w2 > 0)

    # Axis limits.
    ylimmax = max(3 * target[county_id].max(), 10)
    ax.set_ylim([-(1 / 30) * ylimmax, ylimmax])
    ax.set_xlim([start_day, day_p5 - pd.Timedelta(days=1)])

    # Raw strings keep the literal "\%" (LaTeX escape) without an invalid-escape warning.
    ax.legend([p_real[0], p_pred[0], p_quant, p_quant2],
              ["Daten RKI", "Modell", r"25\%-75\%-Quantil", r"5\%-95\%-Quantil"],
              fontsize=16, loc="upper left")

    # "Nowcast"/"Forecast" arrow boxes in data coordinates, next to the
    # 5th- and 4th-to-last x ticks (not perfectly positioned).
    fontsize_bluebox = 18
    fig.text(ax.get_xticks()[-5] + 0.65, ax.get_ylim()[1], "Nowcast",
             ha="left", va="top", fontsize=fontsize_bluebox,
             bbox=dict(facecolor='lightskyblue', boxstyle='rarrow'),
             transform=ax.transData)
    fig.text(ax.get_xticks()[-4] + 0.65, ax.get_ylim()[1], "Forecast",
             ha="left", va="top", fontsize=fontsize_bluebox,
             bbox=dict(facecolor='dodgerblue', boxstyle='rarrow'),
             transform=ax.transData)

    # Store a signed percentage in the metadata "probText" column:
    # +P when P(weight > 0) >= 0.5, otherwise -(100 - P).
    df = pd.read_csv(metadata_path, index_col=0)
    county_ix = df["countyID"][df["countyID"] == int(county_id)].index[0]
    if prob2 >= 0.5:
        prob_val = prob2 * 100
    else:
        prob_val = -(100 - prob2 * 100)
    df.loc[county_ix, "probText"] = prob_val
    df.to_csv(metadata_path)

    plt.tight_layout()

    if save_plot:
        plt.savefig("{}/curve_{}.png".format(day_folder_path, county_id), dpi=200)
        plt.close()

    return fig
def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.05, quartiles=True,
               rhat=True, main=None, xtitle=None, xlim=None, ylabels=None,
               chain_spacing=0.05, vline=0, gs=None, plot_transformed=False,
               **plot_kwargs):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    Parameters
    ----------
    trace_obj : NpTrace or MultiTrace object
        Trace(s) from an MCMC sample.
    varnames : list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity).
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
        The outer interval drawn is the HPD interval returned by ``hpd``.
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being
        suppressed; passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label.
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.05).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None, in which case a new one
        is created. Intervals are drawn in ``gs[0]``; the R-hat panel, when
        plotted, goes in ``gs[1]``.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements, forwarded to ``_plot_tree``;
        'fontsize' is also applied to the y tick labels and title. Currently
        accepts 'fontsize', 'linewidth', 'color', 'marker', and 'markersize'.

    Returns
    -------
    gs : matplotlib GridSpec
    """
    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    # Range for x-axis, grown as intervals are processed
    plotrange = None

    nchains = trace_obj.nchains

    if varnames is None:
        varnames = get_default_varnames(trace_obj.varnames, plot_transformed)

    plot_rhat = (rhat and nchains > 1)

    if gs is None:
        # Initialize plot, reserving a narrow column for R-hat when needed
        if plot_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # Subplot for credible intervals. Created for caller-supplied grids too;
    # otherwise interval_plot stayed None and _plot_tree failed.
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = quantiles(trace_obj, qlist, transform=transform, squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    # Per-chain vertical offsets (0, -s, +1.5s, ...) are the same for every
    # variable, so compute them once.
    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i for i in range(nchains - 1)]

    labels = []
    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable (row on the y-axis)
        var = 0
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]
            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for the outer quantiles
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange:
                plotrange = [min(plotrange[0], np.min(quants)),
                             max(plotrange[1], np.max(quants))]
            else:
                plotrange = [np.min(quants), np.max(quants)]

            # Number of elements in current variable
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Labels are collected from the first chain only
            if j == 0:
                if k > 1:
                    labels += _var_str(varname, np.shape(value))
                else:
                    labels.append(varname)

            # Y coordinate with chain offset
            y = -var + offset[j]

            if k > 1:
                # Multivariate node: one row per element
                for q in np.transpose(quants).squeeze():
                    interval_plot = _plot_tree(interval_plot, y, q, quartiles, **plot_kwargs)
                    y -= 1
            else:
                interval_plot = _plot_tree(interval_plot, y, quants, quartiles, **plot_kwargs)

            var += k

    labels = ylabels if ylabels is not None else labels

    # Update margins to fit the longest label
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    interval_plot.set_ylim(-var + 0.5, 0.5)

    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-row for row in range(len(labels))])
    interval_plot.set_yticklabels(labels, fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    plot_title = ""
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    if plot_title:
        interval_plot.set_title(plot_title, fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Hide y-axis tick marks via the tick line artists; the old
    # tick1On/tick2On attributes no longer have any effect in matplotlib.
    for tick in interval_plot.yaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    # Generate Gelman-Rubin plot
    if plot_rhat:
        _make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels, varnames,
                        plot_transformed)

    return gs
def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.05, quartiles=True,
               rhat=True, main=None, xtitle=None, xlim=None, ylabels=None,
               chain_spacing=0.05, vline=0, gs=None, plot_transformed=False):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    Parameters
    ----------
    trace_obj : NpTrace or MultiTrace object
        Trace(s) from an MCMC sample.
    varnames : list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity).
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
        The outer interval drawn is the HPD interval returned by ``hpd``.
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being
        suppressed; passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label.
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.05).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None, in which case a new one
        is created. Intervals are drawn in ``gs[0]``; the R-hat panel, when
        plotted, goes in ``gs[1]``.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).

    Returns
    -------
    gs : matplotlib GridSpec
    """
    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    # Range for x-axis, grown as intervals are processed
    plotrange = None

    nchains = trace_obj.nchains

    if varnames is None:
        # NOTE(review): the other forestplot variants pass trace_obj.varnames;
        # confirm which argument the imported get_default_varnames expects.
        varnames = get_default_varnames(trace_obj, plot_transformed)

    plot_rhat = (rhat and nchains > 1)

    if gs is None:
        # Initialize plot, reserving a narrow column for R-hat when needed
        if plot_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # Subplot for credible intervals. Created for caller-supplied grids too;
    # otherwise interval_plot stayed None and _plot_tree failed.
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = quantiles(trace_obj, qlist, transform=transform, squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    # Per-chain vertical offsets are the same for every variable; compute once.
    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i for i in range(nchains - 1)]

    labels = []
    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable (row on the y-axis)
        var = 0
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]
            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for the outer quantiles
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange:
                plotrange = [min(plotrange[0], np.min(quants)),
                             max(plotrange[1], np.max(quants))]
            else:
                plotrange = [np.min(quants), np.max(quants)]

            # Number of elements in current variable
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Labels are collected from the first chain only
            if j == 0:
                if k > 1:
                    labels += _var_str(varname, np.shape(value))
                else:
                    labels.append(varname)

            # Y coordinate with chain offset
            y = -var + offset[j]

            if k > 1:
                # Multivariate node: one row per element
                for q in np.transpose(quants).squeeze():
                    interval_plot = _plot_tree(interval_plot, y, q, quartiles)
                    y -= 1
            else:
                interval_plot = _plot_tree(interval_plot, y, quants, quartiles)

            var += k

    labels = ylabels if ylabels is not None else labels

    # Update margins to fit the longest label
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    interval_plot.set_ylim(-var + 0.5, 0.5)

    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-row for row in range(len(labels))])
    interval_plot.set_yticklabels(labels)

    # Add title
    plot_title = ""
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    if plot_title:
        interval_plot.set_title(plot_title)

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Hide y-axis tick marks via the tick line artists; the old
    # tick1On/tick2On attributes no longer have any effect in matplotlib.
    for tick in interval_plot.yaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle='--')

    # Generate Gelman-Rubin plot
    if plot_rhat:
        _make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels, varnames,
                        plot_transformed)

    return gs
def forestplot(trace, models=None, varnames=None, transform=identity_transform,
               alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
               xlim=None, ylabels=None, colors='C0', chain_spacing=0.1, vline=0,
               gs=None, plot_transformed=False, plot_kwargs=None):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals from a
    trace or list of traces.

    Parameters
    ----------
    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list (optional)
        List with names for the models in the list of traces. Useful when
        plotting more than one trace.
    varnames : list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity).
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
        The outer interval drawn is the HPD interval returned by ``hpd``.
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being
        suppressed; passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label.
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively
        a string can be passed. If the string is `cycle`, a color per model is
        chosen from matplotlib's color cycle. If a single color is passed, e.g.
        'k', 'C2', 'red', this color is used for all models. Defaults to 'C0'.
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.1).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None, in which case a new one
        is created. Intervals are drawn in ``gs[0]``; the R-hat panel, when
        any trace has more than one chain, goes in ``gs[1]``.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'marker', and 'markersize'.

    Returns
    -------
    gs : matplotlib GridSpec

    Raises
    ------
    ValueError
        If ``models`` is given and its length differs from the number of traces.
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    if models is None:
        if len(trace) > 1:
            models = ['m_{}'.format(i) for i in range(len(trace))]
        else:
            models = ['']
    elif len(models) != len(trace):
        raise ValueError("The number of names for the models does not match "
                         "the number of models")

    if colors == 'cycle':
        colors = ['C{}'.format(i % 10) for i in range(len(models))]
    elif isinstance(colors, str):
        colors = [colors for i in range(len(models))]

    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    nchains = [tr.nchains for tr in trace]

    if varnames is None:
        # Union of default variable names across traces, in first-seen order
        varnames = []
        for tr in trace:
            for v in get_default_varnames(tr.varnames, plot_transformed):
                if v not in varnames:
                    varnames.append(v)

    plot_rhat = [rhat and nch > 1 for nch in nchains]
    show_rhat = np.any(plot_rhat)

    if gs is None:
        if show_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # R-hat panel. Built for caller-supplied grids too; previously gr_plot only
    # existed when gs was None, so passing gs raised NameError.
    gr_plot = None
    if show_rhat:
        gr_plot = plt.subplot(gs[1])
        # Separate calls: before matplotlib 3.5 the 2nd positional argument of
        # set_xticks was `minor`, not the labels.
        gr_plot.set_xticks((1.0, 1.5, 2.0))
        gr_plot.set_xticklabels(("1", "1.5", "2+"))
        gr_plot.set_xlim(0.9, 2.1)
        gr_plot.set_yticks([])
        gr_plot.set_title('R-hat')

    # Subplot for credible intervals (created after R-hat so it stays current)
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = []
    hpd_intervals = []
    for tr in trace:
        trace_quantiles.append(quantiles(tr, qlist, transform=transform, squeeze=False))
        hpd_intervals.append(hpd(tr, alpha, transform=transform, squeeze=False))

    labels = []
    var = 0
    all_quants = []
    # Alternating shading alpha per variable when several models are compared
    bands = [0.05, 0] * len(varnames)
    var_old = 0.5
    for v_idx, v in enumerate(varnames):
        for h, tr in enumerate(trace):
            if v not in tr.varnames:
                # Keep a (blank) row so models stay aligned
                labels.append(models[h] + ' ' + v)
                var += 1
                continue

            # Per-chain vertical offsets for this trace
            offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i
                            for i in range(nchains[h] - 1)]

            for j, chain in enumerate(tr.chains):
                var_quantiles = trace_quantiles[h][chain][v]
                quants = [var_quantiles[vq] for vq in qlist]
                var_hpd = hpd_intervals[h][chain][v].T

                # Substitute HPD interval for the outer quantiles
                quants[0] = var_hpd[0].T
                quants[-1] = var_hpd[1].T

                # Collected to derive the x-axis range afterwards
                all_quants.extend(np.ravel(quants))

                # Number of elements in current variable
                value = tr.get_values(v, chains=[chain])[0]
                k = np.size(value)

                # Labels are collected from the first chain only
                if j == 0:
                    if k > 1:
                        names = _var_str(v, np.shape(value))
                        names[0] = models[h] + ' ' + names[0]
                        labels += names
                    else:
                        labels.append(models[h] + ' ' + v)

                # Y coordinate with chain offset
                y = -var + offset[j]

                if k > 1:
                    # Multivariate node: one row per element
                    qs = np.moveaxis(np.array(quants), 0, -1).squeeze()
                    for q in qs.reshape(-1, len(quants)):
                        interval_plot = _plot_tree(interval_plot, y, q, quartiles,
                                                   colors[h], plot_kwargs)
                        y -= 1
                else:
                    interval_plot = _plot_tree(interval_plot, y, quants, quartiles,
                                               colors[h], plot_kwargs)

            # Generate Gelman-Rubin plot (values are capped at 2)
            if plot_rhat[h]:
                R = gelman_rubin(tr, [v])
                if k > 1:
                    Rval = dict2pd(R, 'rhat').values
                    gr_plot.plot([min(r, 2) for r in Rval],
                                 [-(row + var) for row in range(k)], 'o',
                                 color=colors[h], markersize=4)
                else:
                    gr_plot.plot(min(R[v], 2), -var, 'o', color=colors[h], markersize=4)

            var += k

        if len(trace) > 1:
            band_bottom = y - chain_spacing - 0.5
            interval_plot.axhspan(var_old, band_bottom, facecolor='k', alpha=bands[v_idx])
            # gr_plot is absent when no trace has more than one chain
            if gr_plot is not None:
                gr_plot.axhspan(var_old, band_bottom, facecolor='k', alpha=bands[v_idx])
            var_old = band_bottom

    if ylabels is not None:
        labels = ylabels

    # Update margins to fit the longest label
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis for forestplot and R-hat
    interval_plot.set_ylim(-var + 0.5, 0.5)
    if gr_plot is not None:
        gr_plot.set_ylim(-var + 0.5, 0.5)

    plotrange = [np.min(all_quants), np.max(all_quants)]
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-row for row in range(len(labels))])
    interval_plot.set_yticklabels(labels, fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    else:
        plot_title = ""

    interval_plot.set_title(plot_title, fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Hide y-axis tick marks via the tick line artists; the old
    # tick1On/tick2On attributes no longer have any effect in matplotlib.
    for tick in interval_plot.yaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    return gs
features = model.evaluate_features(target_train.index, target_train.columns) trend_features = features["temporal_trend"].swaplevel(0, 1).loc["09162"] periodic_features = features["temporal_seasonal"].swaplevel(0, 1).loc["09162"] #t_all = t_all_b if disease == "borreliosis" else t_all_cr trace = load_trace(disease, use_age, use_eastwest) trend_params = pm.trace_to_dataframe(trace, varnames=["W_t_t"]) periodic_params = pm.trace_to_dataframe(trace, varnames=["W_t_s"]) TT = trend_params.values.dot(trend_features.values.T) TP = periodic_params.values.dot(periodic_features.values.T) TTP = TT + TP TT_quantiles = quantiles(TT, (25, 75)) TP_quantiles = quantiles(TP, (25, 75)) TTP_quantiles = quantiles(TTP, (25, 75)) dates = [n.wednesday() for n in target_train.index.values] # Temporal periodic effect ax_p = fig.add_subplot(grid[0, i]) ax_p.fill_between(dates, np.exp(TP_quantiles[25]), np.exp(TP_quantiles[75]), alpha=0.5, zorder=1, facecolor=C1) ax_p.plot_date(dates,
def curves_no_periodic(model_i=15, prediction_day=30, save_plot=False):
    """Plot reported vs. predicted daily infections for two counties plus a map.

    Uses the final nowcast predictions (``load_final_nowcast_pred``). Draws one
    panel each for Düsseldorf and Leipzig, showing the predicted mean, the
    25-75% and 5-95% quantile bands and the reported counts from 2020-04-10
    onward. Below them, a choropleth of the predicted mean on the
    10th-to-last day, with both counties outlined in red and annotated.

    Parameters
    ----------
    model_i : int
        Used only in the output file name ``../figures/curves_<model_i>.pdf``.
    prediction_day : int
        Unused.
    save_plot : bool
        If True, save the figure as that PDF.
    """
    with open('../data/counties/counties.pkl', "rb") as f:
        counties = pkl.load(f)

    # Map extent passed to plot_counties.
    xlim = (5.5, 15.5)
    ylim = (47, 56)

    countyByName = OrderedDict([('Düsseldorf', '05111'), ('Leipzig', '14713'),
                                ('Nürnberg', '09564'), ('München', '09162')])
    plot_county_names = {"covid19": ["Düsseldorf", "Leipzig"]}

    # Colors for curves
    C1 = "#D55E00"  # predicted mean
    C2 = "#E69F00"  # quantile bands

    fig = plt.figure(figsize=(12, 14))
    grid = plt.GridSpec(3, 1, top=0.9, bottom=0.1, left=0.07, right=0.97,
                        hspace=0.25, wspace=0.15, height_ratios=[1, 1, 1.75])

    i = 0
    disease = "covid19"
    prediction_region = "germany"

    data = load_daily_data(disease, prediction_region, counties)

    start_day = pd.Timestamp('2020-04-10')
    i_start_day = (start_day - data.index.min()).days
    day_0 = pd.Timestamp('2020-05-21')
    day_p5 = day_0 + pd.Timedelta(days=5)

    _, target, _, _ = split_data(data, train_start=start_day, test_start=day_0,
                                 post_test=day_p5)

    # Posterior predictive samples reshaped to (samples, days, 412), sliced
    # to the plotted window.
    res = load_final_nowcast_pred()
    n_days = (day_p5 - start_day).days
    prediction_samples = np.reshape(res['y'], (res['y'].shape[0], -1, 412))
    prediction_samples = prediction_samples[:, i_start_day:i_start_day + n_days, :]

    # Observed dates extended by the days up to (excluding) day_p5.
    ext_index = pd.DatetimeIndex(
        list(target.index)
        + list(pd.date_range(target.index[-1] + timedelta(1), day_p5 - timedelta(1))))

    prediction_quantiles = quantiles(prediction_samples, (5, 25, 75, 95))

    def as_frame(values):
        # Wrap a (days, counties) array with the shared date index and county columns.
        return pd.DataFrame(data=values, index=ext_index, columns=target.columns)

    prediction_mean = as_frame(np.mean(prediction_samples, axis=0))
    prediction_q25 = as_frame(prediction_quantiles[25])
    prediction_q75 = as_frame(prediction_quantiles[75])
    prediction_q5 = as_frame(prediction_quantiles[5])
    prediction_q95 = as_frame(prediction_quantiles[95])

    map_ax = fig.add_subplot(grid[2, i])
    map_ax.set_position(grid[2, i].get_position(fig).translated(0, -0.05))
    label_day = prediction_mean.index[-5]
    map_ax.set_xlabel("{}.{}.{}".format(label_day.day, label_day.month, label_day.year),
                      fontsize=22)

    # Choropleth map of the predicted mean, highlighted counties outlined in red.
    highlighted = plot_county_names[disease]
    plot_counties(map_ax, counties, prediction_mean.iloc[-10].to_dict(),
                  edgecolors=dict(zip(map(countyByName.get, highlighted),
                                      ["red"] * len(highlighted))),
                  xlim=xlim, ylim=ylim, contourcolor="black", background=False,
                  xticks=False, yticks=False, grid=False, frame=True,
                  ylabel=False, xlabel=False, lw=2)
    map_ax.set_rasterized(True)

    # Ticks every 10 days from 02.03.2020 to 21.05.2020. Real Timestamps (as in
    # `curves`) instead of strings, so the date axis need not parse
    # non-zero-padded text such as '2020-05-1'.
    ticks = list(pd.date_range('2020-03-02', '2020-05-21', freq='10D'))
    tick_labels = [tick.strftime("%d.%m.%Y") for tick in ticks]

    for j, name in enumerate(plot_county_names[disease]):
        ax = fig.add_subplot(grid[j, i])
        county_id = countyByName[name]
        dates = [pd.Timestamp(d) for d in ext_index]

        # Predicted mean and interquartile band
        p_pred = ax.plot_date(dates, prediction_mean[county_id], "-", color=C1,
                              linewidth=2.0, zorder=4)
        p_quant = ax.fill_between(dates, prediction_q25[county_id], prediction_q75[county_id],
                                  facecolor=C2, alpha=0.5, zorder=1)
        ax.plot_date(dates, prediction_q25[county_id], ":", color=C2, linewidth=2.0, zorder=3)
        ax.plot_date(dates, prediction_q75[county_id], ":", color=C2, linewidth=2.0, zorder=3)

        # Reported counts cover every plotted date except the last five
        p_real = ax.plot_date(dates[:-5], target[county_id], "k.")

        # Markers at the 5th- and 10th-to-last dates
        ax.axvline(dates[-5], ls='-', lw=2, c='cornflowerblue')
        ax.axvline(dates[-10], ls='--', lw=2, c='cornflowerblue')

        if j == 1:
            ax.set_xlabel("Time", fontsize=20)
        ax.tick_params(axis="both", direction='out', size=6, labelsize=16, length=6)
        plt.xticks(ticks, tick_labels)
        # Only the bottom panel shows x tick labels
        plt.setp(ax.get_xticklabels(), visible=j > 0, rotation=45)

        # Annotate the county on the map next to its centroid
        cent = np.array(counties[county_id]["shape"].centroid.coords[0])
        txt = map_ax.annotate(name, cent, cent + 0.5, color="white",
                              arrowprops=dict(facecolor='white', shrink=0.001, headwidth=3),
                              fontsize=26, fontweight='bold', fontname="Arial")
        txt.set_path_effects([PathEffects.withStroke(linewidth=2, foreground='black')])

        # Right limit is the last plotted day. The previous pd.Timedelta(1)
        # was one nanosecond, leaving an empty day at the right edge.
        ax.set_xlim([start_day, day_p5 - pd.Timedelta(days=1)])
        ax.autoscale(False)

        # Outer 5-95% band
        p_quant2 = ax.fill_between(dates, prediction_q5[county_id], prediction_q95[county_id],
                                   facecolor=C2, alpha=0.25, zorder=0)
        ax.plot_date(dates, prediction_q5[county_id], ":", color=C2, alpha=0.5,
                     linewidth=2.0, zorder=1)
        ax.plot_date(dates, prediction_q95[county_id], ":", color=C2, alpha=0.5,
                     linewidth=2.0, zorder=1)

        if i == 0 and j == 0:
            # Raw strings keep the literal "\%" without an invalid-escape warning
            ax.legend([p_real[0], p_pred[0], p_quant, p_quant2],
                      ["reported", "predicted", r"25\%-75\% quantile", r"5\%-95\% quantile"],
                      fontsize=12, loc="upper right")

        fig.text(0, 1 + 0.025,
                 r"$\textbf{" + str(i + 2) + "ABC"[j] + " "
                 + plot_county_names["covid19"][j] + r"}$",
                 fontsize=22, transform=ax.transAxes)

    fig.text(0, 0.95, r"$\textbf{" + str(i + 2) + r"C}$", fontsize=22,
             transform=map_ax.transAxes)
    fig.text(0.01, 0.66, "Reported/predicted infections", va='center',
             rotation='vertical', fontsize=22)

    if save_plot:
        plt.savefig("../figures/curves_{}.pdf".format(model_i))
# if disease == "borreliosis": # data = data[data.index >= parse_yearweek("2013-KW1")] _, target, _, _ = split_data( data, train_start=pd.Timestamp( 2020, 1, 28), test_start=pd.Timestamp( 2020, 3, 30), post_test=pd.Timestamp( 2020, 3, 31)) # plots for the training period! # _, _, _, target = split_data(data) county_ids = target.columns res = load_pred(disease, use_age, use_eastwest) n_days = 62 # for now; get from timestamps up top! prediction_samples = np.reshape(res['y'], (res['y'].shape[0], n_days, -1)) # TODO: figure out where quantiles comes from and if its pymc3, how to replace it prediction_quantiles = quantiles(prediction_samples, (5, 25, 75, 95)) prediction_mean = pd.DataFrame( data=np.mean( prediction_samples, axis=0), index=target.index, columns=target.columns) prediction_q25 = pd.DataFrame( data=prediction_quantiles[25], index=target.index, columns=target.columns) prediction_q75 = pd.DataFrame( data=prediction_quantiles[75], index=target.index, columns=target.columns)
def curves(start, n_weeks=3, model_i=35, save_plot=False):
    """Plot a choropleth map of predicted new infections per 100,000 inhabitants.

    The window predictions of model ``model_i`` are averaged over posterior
    samples; the row ten entries before the end of the extended date index is
    divided by each county's 2018 population (``counties[id]['demographics']
    [('total', 2018)]``) and drawn with ``plot_counties``.

    NOTE(review): this definition re-uses the name ``curves`` and therefore
    shadows the ``curves(start, county, ...)`` defined at the top of this
    module; only this map version is reachable under that name.

    Parameters
    ----------
    start : int or str
        Offset in days from 2020-01-28 of the first training day.
    n_weeks : int or str
        Length of the training window in weeks.
    model_i : int or str
        Index of the model whose window predictions are loaded.
    save_plot : bool
        If true, write ``map.png`` and ``map.csv`` into
        ``../figures/<YYYY>_<MM>_<DD>/`` (date of the first training day).

    Returns
    -------
    matplotlib.figure.Figure
        The map figure. It is closed via ``plt.close()`` before returning so
        repeated calls do not accumulate open pyplot figures.
    """
    with open('../data/counties/counties.pkl', "rb") as f:
        counties = pkl.load(f)

    start = int(start)
    n_weeks = int(n_weeks)
    model_i = int(model_i)

    # Axis limits handed to plot_counties; presumably lon/lat bounds of Germany.
    xlim = (5.5, 15.5)
    ylim = (47, 56)

    fig = plt.figure(figsize=(6, 8))
    grid = plt.GridSpec(1, 1)

    disease = "covid19"
    prediction_region = "germany"
    data = load_daily_data_n_weeks(start, n_weeks, disease, prediction_region, counties)

    start_day = pd.Timestamp('2020-01-28') + pd.Timedelta(days=start)
    day_0 = start_day + pd.Timedelta(days=n_weeks * 7 + 5)
    day_p5 = day_0 + pd.Timedelta(days=5)

    _, target, _, _ = split_data(
        data, train_start=start_day, test_start=day_0, post_test=day_p5)

    # Posterior samples reshaped to (samples, days, 412 county columns) and
    # cut to the days between start_day and day_p5.
    res = load_pred_model_window(model_i, start, n_weeks)
    n_days = (day_p5 - start_day).days
    prediction_samples = np.reshape(res['y'], (res['y'].shape[0], -1, 412))
    prediction_samples = prediction_samples[:, :n_days, :]

    # Reported dates followed by the remaining days up to (excluding) day_p5.
    ext_index = pd.DatetimeIndex(
        list(target.index)
        + list(pd.date_range(target.index[-1] + timedelta(1), day_p5 - timedelta(1))))
    prediction_mean = pd.DataFrame(
        data=np.mean(prediction_samples, axis=0),
        index=ext_index,
        columns=target.columns)

    map_ax = fig.add_subplot(grid[0, 0])

    # iloc[-10] is five days before day_0 if ext_index is contiguous daily
    # (TODO confirm); data.iloc[-1] is the last row of the loaded data.
    predicted = prediction_mean.iloc[-10]
    reported = data.iloc[-1]

    # Normalise by county label rather than by position, so each value stays
    # paired with its county regardless of the column order of the frames.
    # This also avoids writing into a (possibly view) row of prediction_mean.
    map_keys = list(counties)
    population = pd.Series(
        [counties[key]['demographics'][('total', 2018)] for key in map_keys],
        index=map_keys,
        dtype='float64')
    map_vals = predicted.loc[map_keys].astype('float64') / population * 100000
    map_rki = reported.loc[map_keys].astype('float64') / population * 100000

    map_df = pd.DataFrame({
        "countyID": map_keys,
        "newInf100k": map_vals.to_numpy(),
        "newInf100k_RKI": map_rki.to_numpy(),
    })

    # Plot the choropleth map.
    plot_counties(map_ax,
                  counties,
                  map_vals.to_dict(),
                  edgecolors=None,
                  xlim=xlim,
                  ylim=ylim,
                  contourcolor="black",
                  background=False,
                  xticks=False,
                  yticks=False,
                  grid=False,
                  frame=True,
                  ylabel=False,
                  xlabel=False,
                  lw=2)
    map_ax.set_rasterized(True)

    # German legend label: "new infections per 100,000 inhabitants".
    fig.text(0.71, 0.17, "Neuinfektionen \n pro 100.000 \n Einwohner",
             fontsize=14, color=[0.3, 0.3, 0.3])

    if save_plot:
        # Same folder name as str(start_day) slicing: YYYY_MM_DD.
        day_folder_path = "../figures/{}".format(start_day.strftime("%Y_%m_%d"))
        Path(day_folder_path).mkdir(parents=True, exist_ok=True)
        plt.savefig("{}/map.png".format(day_folder_path), dpi=300)
        map_df.to_csv("{}/map.csv".format(day_folder_path))

    plt.close()
    return fig
def _pad_leading_nan(values, n_rows):
    """Return 2-D *values* with *n_rows* all-NaN rows prepended (columns unchanged)."""
    return np.pad(values, ((n_rows, 0), (0, 0)), "constant", constant_values=np.nan)


def _add_count_columns(columns, label, frame, county_id, n_people):
    """Store one county's series from *frame* under *label* and, scaled to
    cases per 100,000 inhabitants, under ``label + " 100k"``."""
    values = frame.loc[:, county_id].values
    columns[label] = values
    columns[label + " 100k"] = np.multiply(np.divide(values, n_people), 100000)


def plotdata_csv(start, n_weeks, csv_path, counties, output_dir):
    """Export per-county prediction curves and a map summary as CSV files.

    For every county id in ``make_county_dict()`` a file
    ``<output_dir>/<county_id>.csv`` is written, indexed by date. It holds:

    * the raw model prediction, the trend-model prediction and the trend-model
      7-day incidence (mean and 5/25/75/95 % quantiles, each also per 100,000
      inhabitants);
    * the reported counts ("RKI Meldedaten") and their 7-day rolling sum;
    * the boolean columns ``is_nowcast``, ``is_high`` and ``is_prediction``.

    ``<output_dir>/map.csv`` gets one row per county with the nowcast and
    reported values at the last reported date (raw and per 100,000).

    Parameters
    ----------
    start : int
        Offset in days from 2020-01-28 of the first training day.
    n_weeks : int
        Length of the training window in weeks.
    csv_path : str
        Case data file passed to ``load_data_n_weeks``.
    counties : dict
        County metadata keyed by county id;
        ``counties[id]["demographics"][("total", 2018)]`` is the population.
    output_dir : str
        Directory that receives the CSV files (it is not created here).
    """
    county_by_name = make_county_dict()
    data = load_data_n_weeks(start, n_weeks, csv_path)

    start_day = pd.Timestamp("2020-01-28") + pd.Timedelta(days=start)
    day_0 = start_day + pd.Timedelta(days=n_weeks * 7 + 5)
    day_m5 = day_0 - pd.Timedelta(days=5)
    day_p5 = day_0 + pd.Timedelta(days=5)

    _, target, _, _ = split_data(
        data, train_start=start_day, test_start=day_0, post_test=day_p5)

    def as_samples(values):
        # (samples, days, 412 county columns)
        return np.reshape(values, (values.shape[0], -1, 412))

    # Load our prediction samples
    res = load_predictions(start, n_weeks)
    res_trend = load_trend_predictions(start, n_weeks)
    samples = as_samples(res["y"])
    samples_mu = as_samples(res["μ"])
    samples_trend = as_samples(res_trend["y"])
    samples_trend_mu = as_samples(res_trend["μ"])
    inc7_samples = sample_x_days_incidence_by_county(samples_trend, 7)
    inc7_samples_mu = sample_x_days_incidence_by_county(samples_trend_mu, 7)

    # Reported dates followed by the remaining days up to (excluding) day_p5.
    ext_index = pd.DatetimeIndex(
        list(target.index)
        + list(pd.date_range(target.index[-1] + timedelta(1), day_p5 - timedelta(1))))
    # Rows of ext_index after the last reported day (no reported data there).
    n_future = len(ext_index) - len(target.index)

    def to_frame(values):
        return pd.DataFrame(data=values, index=ext_index, columns=target.columns)

    # TODO: figure out if we want to replace quantiles function (newer pymc3 versions don't support it)
    levels = (5, 25, 75, 95)
    q_raw = quantiles(samples, levels)
    q_trend = quantiles(samples_trend, levels)
    q_inc7 = quantiles(inc7_samples, levels)

    # Statistic name -> frame. Insertion order (Mean, Q25, Q75, Q5, Q95)
    # determines the CSV column order.
    raw = {"Mean": to_frame(np.mean(samples_mu, axis=0))}
    trend = {"Mean": to_frame(np.mean(samples_trend_mu, axis=0))}
    # The 7-day sums have 6 fewer rows than ext_index (presumably the first
    # days lack a full window); prepend NaN rows to align them.
    inc7 = {"Mean": to_frame(_pad_leading_nan(np.mean(inc7_samples_mu, axis=0), 6))}
    for q in (25, 75, 5, 95):
        stat = "Q{}".format(q)
        raw[stat] = to_frame(q_raw[q])
        trend[stat] = to_frame(q_trend[q])
        inc7[stat] = to_frame(_pad_leading_nan(q_inc7[q].astype(float), 6))

    rki_7day = target.rolling(7).sum()
    ref_date = target.iloc[-1].name
    at_ref_date = raw["Mean"].index == ref_date
    nowcast_vals = raw["Mean"].loc[at_ref_date]
    nowcast7day_vals = inc7["Mean"].loc[at_ref_date]
    rki_vals = target.iloc[-1]
    rki_7day_vals = rki_7day.iloc[-1]

    # Date flags are the same for every county.
    is_nowcast = (day_m5 <= ext_index) & (ext_index < day_0)
    is_prediction = day_0 <= ext_index
    future_nan = np.repeat(np.nan, n_future)

    map_data = {name: [] for name in (
        "countyID", "newInf100k", "7DayInf100k", "newInf100k_RKI",
        "7DayInf100k_RKI", "newInfRaw", "7DayInfRaw", "newInfRaw_RKI",
        "7DayInfRaw_RKI")}

    for county_id in county_by_name.values():
        n_people = counties[county_id]["demographics"][("total", 2018)]
        rki_data = np.append(target.loc[:, county_id].values, future_nan)
        rki_data7day = np.append(rki_7day.loc[:, county_id].values, future_nan)

        nowcast = nowcast_vals[county_id].item()
        nowcast_7day = nowcast7day_vals[county_id].item()
        rki = rki_vals[county_id].item()
        rki_7 = rki_7day_vals[county_id].item()

        map_data["countyID"].append(county_id)
        map_data["newInf100k"].append(nowcast / n_people * 100000)
        map_data["7DayInf100k"].append(nowcast_7day / n_people * 100000)
        map_data["newInf100k_RKI"].append(rki / n_people * 100000)
        map_data["7DayInf100k_RKI"].append(rki_7 / n_people * 100000)
        map_data["newInfRaw"].append(nowcast)
        map_data["7DayInfRaw"].append(nowcast_7day)
        map_data["newInfRaw_RKI"].append(rki)
        map_data["7DayInfRaw_RKI"].append(rki_7)

        columns = {}
        for prefix, frames in (("Raw Prediction", raw),
                               ("Trend Prediction", trend),
                               ("Trend 7Day Prediction", inc7)):
            for stat, frame in frames.items():
                _add_count_columns(columns, "{} {}".format(prefix, stat),
                                   frame, county_id, n_people)
        columns["RKI Meldedaten"] = rki_data
        columns["RKI 7Day Incidence"] = rki_data7day
        columns["is_nowcast"] = is_nowcast
        # Reported value above the trend model's 95 % quantile (NaN -> False).
        columns["is_high"] = np.less(trend["Q95"].loc[:, county_id].values, rki_data)
        columns["is_prediction"] = is_prediction

        county_data = pd.DataFrame(columns, index=ext_index)
        county_data.to_csv(os.path.join(output_dir, "{}.csv".format(county_id)))

    map_df = pd.DataFrame(map_data)
    map_df.to_csv(os.path.join(output_dir, "map.csv"))
def forestplot(trace, models=None, varnames=None, transform=identity_transform,
               alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
               xlim=None, ylabels=None, colors='C0', chain_spacing=0.1, vline=0,
               gs=None, plot_transformed=False, plot_kwargs=None):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals from a trace
    or list of traces.

    Parameters
    ----------
    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list (optional)
        List with names for the models in the list of traces. Useful when
        plotting more than one trace.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively
        a string can be passed. If the string is `cycle`, a color per model is
        chosen automatically from matplotlib's color cycle. If a single color
        is passed, e.g. 'k', 'C2', 'red', this color is used for all models.
        Defaults to 'C0' (blueish in most matplotlib styles).
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.1).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None. When R-hat is plotted,
        ``gs[1]`` receives the R-hat panel and ``gs[0]`` the intervals.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'marker', and 'markersize'.

    Returns
    -------
    gs : matplotlib GridSpec
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    if models is None:
        if len(trace) > 1:
            models = ['m_{}'.format(i) for i in range(len(trace))]
        else:
            models = ['']
    elif len(models) != len(trace):
        raise ValueError("The number of names for the models does not match "
                         "the number of models")

    # Guard the string comparison so array-like colors are not compared elementwise.
    if isinstance(colors, str):
        if colors == 'cycle':
            colors = ['C{}'.format(i % 10) for i in range(len(models))]
        else:
            colors = [colors for _ in range(len(models))]

    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    nchains = [tr.nchains for tr in trace]

    if varnames is None:
        varnames = []
        for tr in trace:
            for v in get_default_varnames(tr.varnames, plot_transformed):
                if v not in varnames:
                    varnames.append(v)

    plot_rhat = [rhat and nch > 1 for nch in nchains]

    # Initialize plot grid (a caller-supplied gs is used as-is).
    if gs is None:
        if np.any(plot_rhat):
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    if np.any(plot_rhat):
        gr_plot = plt.subplot(gs[1])
        # Set labels separately: before matplotlib 3.5 the second positional
        # argument of set_xticks is `minor`, not the labels.
        gr_plot.set_xticks((1.0, 1.5, 2.0))
        gr_plot.set_xticklabels(("1", "1.5", "2+"))
        gr_plot.set_xlim(0.9, 2.1)
        gr_plot.set_yticks([])
        gr_plot.set_title('R-hat')

    # Subplot for confidence intervals
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = []
    hpd_intervals = []
    for tr in trace:
        trace_quantiles.append(quantiles(tr, qlist, transform=transform,
                                         squeeze=False))
        hpd_intervals.append(hpd(tr, alpha, transform=transform,
                                 squeeze=False))

    labels = []
    var = 0
    all_quants = []
    # Alternating background shading per variable (only drawn for >1 trace).
    bands = [0.05, 0] * len(varnames)
    var_old = 0.5
    for v_idx, v in enumerate(varnames):
        for h, tr in enumerate(trace):
            if v not in tr.varnames:
                labels.append(models[h] + ' ' + v)
                var += 1
                continue

            for j, chain in enumerate(tr.chains):
                var_quantiles = trace_quantiles[h][chain][v]
                quants = [var_quantiles[vq] for vq in qlist]
                var_hpd = hpd_intervals[h][chain][v].T

                # Substitute HPD interval for quantile
                quants[0] = var_hpd[0].T
                quants[-1] = var_hpd[1].T

                # Ensure x-axis contains range of current interval
                all_quants.extend(np.ravel(quants))

                # Number of elements in current variable
                value = tr.get_values(v, chains=[chain])[0]
                k = np.size(value)

                # Append variable name(s) to list
                if j == 0:
                    if k > 1:
                        names = _var_str(v, np.shape(value))
                        names[0] = models[h] + ' ' + names[0]
                        labels += names
                    else:
                        labels.append(models[h] + ' ' + v)

                # Add spacing for each chain, if more than one
                offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i
                                for i in range(nchains[h] - 1)]

                # Y coordinate with offset
                y = - var + offset[j]

                # Deal with multivariate nodes
                if k > 1:
                    qs = np.moveaxis(np.array(quants), 0, -1).squeeze()
                    for q in qs.reshape(-1, len(quants)):
                        # Multiple y values
                        interval_plot = _plot_tree(interval_plot, y, q, quartiles,
                                                   colors[h], plot_kwargs)
                        y -= 1
                else:
                    interval_plot = _plot_tree(interval_plot, y, quants, quartiles,
                                               colors[h], plot_kwargs)

            # Generate Gelman-Rubin plot
            if plot_rhat[h]:
                R = gelman_rubin(tr, [v])
                if k > 1:
                    Rval = dict2pd(R, 'rhat').values
                    gr_plot.plot([min(r, 2) for r in Rval],
                                 [-(idx + var) for idx in range(k)], 'o',
                                 color=colors[h], markersize=4)
                else:
                    gr_plot.plot(min(R[v], 2), -var, 'o', color=colors[h],
                                 markersize=4)
            var += k

        if len(trace) > 1:
            interval_plot.axhspan(var_old, y - chain_spacing - 0.5,
                                  facecolor='k', alpha=bands[v_idx])
            if np.any(plot_rhat):
                gr_plot.axhspan(var_old, y - chain_spacing - 0.5,
                                facecolor='k', alpha=bands[v_idx])
            var_old = y - chain_spacing - 0.5

    if ylabels is not None:
        labels = ylabels

    # Update margins
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis for forestplot and R-hat
    interval_plot.set_ylim(- var + 0.5, 0.5)
    if np.any(plot_rhat):
        gr_plot.set_ylim(- var + 0.5, 0.5)

    plotrange = [np.min(all_quants), np.max(all_quants)]
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-pos for pos in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    else:
        plot_title = ""

    interval_plot.set_title(plot_title, fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. Assigning Tick.tick1On/tick2On no longer has
    # any effect in current matplotlib; tick_params also covers ticks created
    # later at draw time.
    interval_plot.tick_params(axis='y', which='both', left=False, right=False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    return gs
def curves_appendix(model_i=15, save_plot=False):
    """Plot reported vs. predicted daily infections for 25 counties (appendix figure).

    Uses the final model predictions (``load_final_pred()``) from 2020-03-01 to
    five days past 2020-05-21. Each of the 25 panels in a 5x5 grid shows:

    * the posterior mean;
    * the 25-75 % and 5-95 % quantile bands;
    * the reported counts;
    * two vertical lines at the last 5th and 10th date of the extended index.

    Parameters
    ----------
    model_i : int
        Only used in the output file name.
    save_plot : bool
        If true, save to ``../figures/curves_covid19_appendix_<model_i>.pdf``.
    """
    with open('../data/counties/counties.pkl', "rb") as f:
        counties = pkl.load(f)

    county_by_name = {
        'Düsseldorf': '05111', 'Recklinghausen': '05562', "Hannover": "03241",
        "Hamburg": "02000", "Berlin-Mitte": "11001", "Osnabrück": "03404",
        "Frankfurt (Main)": "06412", "Görlitz": "14626", "Stuttgart": "08111",
        "Potsdam": "12054", "Köln": "05315", "Aachen": "05334",
        "Rostock": "13003", "Flensburg": "01001", "Frankfurt (Oder)": "12053",
        "Lübeck": "01003", "Münster": "05515", "Berlin Neukölln": "11008",
        'Göttingen': "03159", "Cottbus": "12052", "Erlangen": "09562",
        "Regensburg": "09362", "Bayreuth": "09472", "Bautzen": "14625",
        'Nürnberg': '09564', 'München': '09162', "Würzburg": "09679",
        "Deggendorf": "09271", "Ansbach": "09571", "Rottal-Inn": "09277",
        "Passau": "09275", "Schwabach": "09565", "Memmingen": "09764",
        "Erlangen-Höchstadt": "09572", "Nürnberger Land": "09574",
        'Roth': "09576", 'Starnberg': "09188", 'Berchtesgadener Land': "09172",
        'Schweinfurt': "09678", "Augsburg": "09772",
        'Neustadt a.d.Waldnaab': "09374", "Fürstenfeldbruck": "09179",
        'Rosenheim': "09187", "Straubing": "09263", "Erding": "09177",
        "Tirschenreuth": "09377", 'Miltenberg': "09676",
        'Neumarkt i.d.OPf.': "09373", 'Heinsberg': "05370",
        'Landsberg am Lech': "09181", "Tübingen": "08416",
        "Bielefeld": "05711",
    }
    plot_county_names = {
        "covid19": [
            "Düsseldorf", "Heinsberg", "Hannover", "München", "Hamburg",
            "Berlin-Mitte", "Osnabrück", "Frankfurt (Main)", "Görlitz",
            "Stuttgart", "Landsberg am Lech", "Köln", "Rottal-Inn", "Rostock",
            "Flensburg", "Frankfurt (Oder)", "Lübeck", "Münster",
            "Berlin Neukölln", "Göttingen", "Bielefeld", "Tübingen",
            "Augsburg", "Bayreuth", "Nürnberg"
        ]
    }

    # colors for curves
    C1 = "#D55E00"
    C2 = "#E69F00"

    disease = "covid19"
    prediction_region = "germany"

    data = load_daily_data(disease, prediction_region, counties)

    start_day = pd.Timestamp('2020-03-01')
    i_start_day = (start_day - data.index.min()).days
    day_0 = pd.Timestamp('2020-05-21')
    day_p5 = day_0 + pd.Timedelta(days=5)

    _, target, _, _ = split_data(data,
                                 train_start=start_day,
                                 test_start=day_0,
                                 post_test=day_p5)

    # Posterior samples reshaped to (samples, days, 412 county columns) and
    # cut to the days from start_day to day_p5.
    res = load_final_pred()
    n_days = (day_p5 - start_day).days
    prediction_samples = np.reshape(res['y'], (res['y'].shape[0], -1, 412))
    prediction_samples = prediction_samples[:, i_start_day:i_start_day + n_days, :]
    prediction_quantiles = quantiles(prediction_samples, (5, 25, 75, 95))

    # Reported dates followed by the remaining days up to (excluding) day_p5.
    ext_index = pd.DatetimeIndex(
        list(target.index)
        + list(pd.date_range(target.index[-1] + timedelta(1), day_p5 - timedelta(1))))

    def to_frame(values):
        return pd.DataFrame(data=values, index=ext_index, columns=target.columns)

    prediction_mean = to_frame(np.mean(prediction_samples, axis=0))
    prediction_q25 = to_frame(prediction_quantiles[25])
    prediction_q75 = to_frame(prediction_quantiles[75])
    prediction_q5 = to_frame(prediction_quantiles[5])
    prediction_q95 = to_frame(prediction_quantiles[95])

    dates = [pd.Timestamp(day) for day in ext_index]

    # One tick every 10 days from start_day, labelled with the day offset so
    # the ticks agree with the "days since Mar. 01" axis label.
    tick_offsets = range(0, 90, 10)
    ticks = [start_day + pd.Timedelta(days=offset) for offset in tick_offsets]
    tick_labels = [str(offset) for offset in tick_offsets]

    fig = plt.figure(figsize=(12, 12))
    grid = plt.GridSpec(5, 5, top=0.90, bottom=0.11, left=0.07, right=0.92,
                        hspace=0.2, wspace=0.3)

    for j, name in enumerate(plot_county_names[disease]):
        ax = fig.add_subplot(grid[np.unravel_index(j, (5, 5))])
        county_id = county_by_name[name]

        # plot our predictions w/ quartiles
        p_pred = ax.plot(dates, prediction_mean[county_id], "-", color=C1,
                         linewidth=2.0, zorder=4)
        p_quant = ax.fill_between(dates, prediction_q25[county_id],
                                  prediction_q75[county_id], facecolor=C2,
                                  alpha=0.5, zorder=1)
        ax.plot(dates, prediction_q25[county_id], ":", color=C2,
                linewidth=2.0, zorder=3)
        ax.plot(dates, prediction_q75[county_id], ":", color=C2,
                linewidth=2.0, zorder=3)

        # Reported counts cover all but the last 5 dates of ext_index.
        p_real = ax.plot(dates[:-5], target[county_id], "k.")

        ax.set_title(name, fontsize=18)
        ax.set_xticks(ticks)
        ax.set_xticklabels(tick_labels)
        ax.tick_params(axis="both", direction='out', size=2, labelsize=14)
        plt.setp(ax.get_xticklabels(), visible=False)
        if j > 19:
            # Bottom row: show every other tick label, rotated.
            plt.setp(ax.get_xticklabels(), rotation=60)
            plt.setp(ax.get_xticklabels()[::2], visible=True)
        ax.autoscale(False)

        # The x-axis is in date units and autoscaling is off, so the 5-95 %
        # band must be drawn against `dates` (integer day numbers would fall
        # far outside the visible range).
        p_quant2 = ax.fill_between(dates, prediction_q5[county_id],
                                   prediction_q95[county_id], facecolor=C2,
                                   alpha=0.25, zorder=0)
        ax.plot(dates, prediction_q5[county_id], ":", color=C2, alpha=0.5,
                linewidth=2.0, zorder=1)
        ax.plot(dates, prediction_q95[county_id], ":", color=C2, alpha=0.5,
                linewidth=2.0, zorder=1)

        # Plot blue line for indicating where predictions start.
        ax.axvline(dates[-5], ls='-', lw=2, c='cornflowerblue')
        ax.axvline(dates[-10], ls='--', lw=2, c='cornflowerblue')

    # Raw strings: "\%" escapes the percent sign for the TeX renderer.
    plt.legend(
        [p_real[0], p_pred[0], p_quant, p_quant2],
        ["reported", "predicted", r"25\%-75\% quantile", r"5\%-95\% quantile"],
        fontsize=16,
        ncol=5,
        loc="upper center",
        bbox_to_anchor=(0, -0.01, 1, 1),
        bbox_transform=plt.gcf().transFigure)
    fig.text(0.5, 0.02, "Time [days since Mar. 01]", ha='center', fontsize=22)
    fig.text(0.01, 0.46, "Reported/predicted infections", va='center',
             rotation='vertical', fontsize=22)

    if save_plot:
        plt.savefig("../figures/curves_{}_appendix_{}.pdf".format(disease, model_i))
def temporal_contribution(model_i=15, combinations=combinations, save_plot=False,
                          county_id="09162"):
    """Plot the temporal trend, periodic and combined contributions of the final model.

    Trend and periodic features are evaluated for one county from
    2020-04-01 onwards. They are multiplied with the posterior weights
    ``W_t_t`` / ``W_t_s`` of ``load_final_trace()``, exponentiated and plotted
    in three stacked panels:

    * combined, with 25-75 % band and 2.5/97.5 % lines;
    * trend mean;
    * periodic mean.

    Parameters
    ----------
    model_i : int
        Index into ``combinations``; also used in the output file name.
    combinations : sequence
        Model configurations ``(use_ia, use_report_delay, use_demographics,
        trend_order, periodic_order)``.
    save_plot : bool
        If true, save to ``../figures/temporal_contribution_<model_i>.pdf``.
    county_id : str
        County whose feature rows are used (default "09162", as before).

    Note: calls ``plt.style.use('ggplot')``, which changes the global style.
    """
    use_ia, use_report_delay, use_demographics, trend_order, periodic_order = \
        combinations[model_i]
    # NOTE(review): use_demographics, trend_order and periodic_order are not
    # used; the model below is built with include_demographics=True and
    # polynomial order 4, presumably to match load_final_trace(). Confirm.

    plt.style.use('ggplot')

    with open('../data/counties/counties.pkl', "rb") as f:
        county_info = pkl.load(f)

    C1 = "#D55E00"
    C2 = "#E69F00"

    # With report delay the grid gets a fourth row; presumably meant for a
    # report-delay panel, which is currently not drawn.
    if use_report_delay:
        fig = plt.figure(figsize=(25, 10))
        grid = plt.GridSpec(4, 1, top=0.93, bottom=0.12, left=0.11, right=0.97,
                            hspace=0.28, wspace=0.30)
    else:
        fig = plt.figure(figsize=(16, 10))
        grid = plt.GridSpec(3, 1, top=0.93, bottom=0.12, left=0.11, right=0.97,
                            hspace=0.28, wspace=0.30)

    disease = "covid19"
    prediction_region = "germany"

    data = load_daily_data(disease, prediction_region, county_info)
    first_day = pd.Timestamp('2020-04-01')
    last_day = data.index.max()

    _, target_train, _, _ = split_data(
        data,
        train_start=first_day,
        test_start=last_day - pd.Timedelta(days=1),
        post_test=last_day + pd.Timedelta(days=1))

    tspan = (target_train.index[0], target_train.index[-1])

    model = BaseModel(
        tspan,
        county_info,
        ["../data/ia_effect_samples/{}_{}.pkl".format(disease, i) for i in range(100)],
        include_ia=use_ia,
        include_report_delay=use_report_delay,
        include_demographics=True,
        trend_poly_order=4,
        periodic_poly_order=4)

    features = model.evaluate_features(target_train.index, target_train.columns)
    trend_features = features["temporal_trend"].swaplevel(0, 1).loc[county_id]
    periodic_features = features["temporal_seasonal"].swaplevel(0, 1).loc[county_id]

    trace = load_final_trace()
    trend_params = pm.trace_to_dataframe(trace, varnames=["W_t_t"])
    periodic_params = pm.trace_to_dataframe(trace, varnames=["W_t_s"])

    # Rows: posterior samples, columns: feature rows (days); exponentiated for plotting.
    TT = trend_params.values.dot(trend_features.values.T)
    TP = periodic_params.values.dot(periodic_features.values.T)
    TTP = TT + TP
    TTP_quantiles = quantiles(TTP, (2.5, 25, 75, 97.5))

    dates = [pd.Timestamp(day) for day in target_train.index.values]
    first_date = min(dates)
    days = [(day - first_date).days for day in dates]

    # Temporal trend+periodic effect
    ax_tp = fig.add_subplot(grid[0, 0])
    ax_tp.fill_between(days, np.exp(TTP_quantiles[25]), np.exp(TTP_quantiles[75]),
                       alpha=0.5, zorder=1, facecolor=C1)
    ax_tp.plot(days, np.exp(TTP.mean(axis=0)), "-", color=C1, lw=2, zorder=5)
    ax_tp.plot(days, np.exp(TTP_quantiles[25]), "-", color=C2, lw=2, zorder=3)
    ax_tp.plot(days, np.exp(TTP_quantiles[75]), "-", color=C2, lw=2, zorder=3)
    ax_tp.plot(days, np.exp(TTP_quantiles[2.5]), "--", color=C2, lw=2, zorder=3)
    ax_tp.plot(days, np.exp(TTP_quantiles[97.5]), "--", color=C2, lw=2, zorder=3)
    ax_tp.tick_params(axis="x", rotation=45)

    # Temporal trend effect
    ax_t = fig.add_subplot(grid[1, 0], sharex=ax_tp)
    ax_t.plot(days, np.exp(TT.mean(axis=0)), "-", color=C1, lw=2, zorder=5)
    ax_t.tick_params(axis="x", rotation=45)
    ax_t.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    # Temporal periodic effect
    ax_p = fig.add_subplot(grid[2, 0], sharex=ax_tp)
    ax_p.plot(days, np.exp(TP.mean(axis=0)), "-", color=C1, lw=2, zorder=5)
    ax_p.set_ylim([-0.0001, 0.001])
    ax_p.tick_params(axis="x", rotation=45)
    ax_p.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    ax_p.set_xlabel("time [days]", fontsize=22)
    ax_p.set_ylabel("periodic\ncontribution", fontsize=22)
    ax_t.set_ylabel("trend\ncontribution", fontsize=22)
    ax_tp.set_ylabel("combined\ncontribution", fontsize=22)

    ax_t.set_xlim(days[0], days[-1])
    ax_t.tick_params(labelbottom=False, labelleft=True, labelsize=18, length=6)
    ax_p.tick_params(labelbottom=True, labelleft=True, labelsize=18, length=6)
    ax_tp.tick_params(labelbottom=False, labelleft=True, labelsize=18, length=6)

    if save_plot:
        fig.savefig("../figures/temporal_contribution_{}.pdf".format(model_i))