def curves(start, county, n_weeks=3, model_i=35, save_plot=False):
    """Plot reported cases and model prediction bands for a single county.

    The figure shows the values of ``target`` for the county (black dots),
    the mean of the model's prediction samples, and the 25-75 % and 5-95 %
    quantile bands of those samples.  Two vertical lines are drawn half a day
    before the fifth-last and the tenth-last plotted day, and "Nowcast" /
    "Forecast" boxes are placed at the top of the axes.

    Side effects: the folder ``../figures/YYYY_MM_DD`` (date of the first
    plotted day) is created if needed, a ``metadata.csv`` with one row per
    county is created in it if missing, and the county's row is updated with
    a signed percentage derived from the ``W_t_t`` trace samples (see the
    comments in the body).

    Parameters
    ----------
    start : int or str
        Offset in days from 2020-01-28 to the first plotted day.
    county : str
        County name; must be a key of the mapping returned by
        ``make_county_dict()``.
    n_weeks : int or str, optional
        Number of weeks handed to the data, prediction and trace loaders.
        The plotted window spans ``n_weeks * 7 + 10`` days.
    model_i : int or str, optional
        Model index handed to ``load_pred_model_window`` and
        ``load_trace_window``.
    save_plot : bool, optional
        If true, the figure is saved as ``curve_<county id>.png`` in the day
        folder.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn.  ``plt.close()`` has already been called
        when it is returned.
    """
    # NOTE(review): unpickling can execute arbitrary code; this assumes the
    # file is a trusted, locally generated artifact.
    with open('../data/counties/counties.pkl', "rb") as f:
        counties = pkl.load(f)
    start = int(start)
    n_weeks = int(n_weeks)
    model_i = int(model_i)

    disease = "covid19"
    prediction_region = "germany"
    # Size of the last axis of the stored sample arrays.  It has to equal the
    # number of columns of `target`, because those columns label it below.
    n_columns = 412

    county_by_name = make_county_dict()

    start_day = pd.Timestamp('2020-01-28') + pd.Timedelta(days=start)
    day_folder_path = "../figures/{}".format(start_day.strftime("%Y_%m_%d"))
    metadata_path = "{}/metadata.csv".format(day_folder_path)
    Path(day_folder_path).mkdir(parents=True, exist_ok=True)

    # Create the per-day metadata file on first use: one row per county id.
    if not os.path.isfile(metadata_path):
        metadata = pd.DataFrame(data=[int(key) for key in counties],
                                columns=["countyID"])
        metadata["probText"] = ""
        metadata.to_csv(metadata_path)

    # Colors for the mean curve and the quantile bands.
    color_mean = "#D55E00"
    color_band = "#E69F00"

    fig = plt.figure(figsize=(12, 6))
    grid = plt.GridSpec(
        1,
        1,
        top=0.9,
        bottom=0.2,
        left=0.07,
        right=0.97,
        hspace=0.25,
        wspace=0.15,
    )

    data = load_daily_data_n_weeks(start, n_weeks, disease, prediction_region,
                                   counties)

    day_0 = start_day + pd.Timedelta(days=n_weeks * 7 + 5)
    day_p5 = day_0 + pd.Timedelta(days=5)

    _, target, _, _ = split_data(data,
                                 train_start=start_day,
                                 test_start=day_0,
                                 post_test=day_p5)

    county_ids = target.columns
    county_id = county_by_name[county]

    # Load the prediction samples and cut them to the plotted window.
    res = load_pred_model_window(model_i, start, n_weeks)
    n_days = (day_p5 - start_day).days
    prediction_samples = np.reshape(
        res['y'], (res['y'].shape[0], -1, n_columns))[:, :n_days, :]

    # Extend the index of `target` day by day up to the last plotted day.
    ext_index = pd.DatetimeIndex(
        list(target.index) +
        list(pd.date_range(target.index[-1] + timedelta(1),
                           day_p5 - timedelta(1))))

    # TODO: figure out where quantiles comes from and if its pymc3, how to replace it
    prediction_quantiles = quantiles(prediction_samples, (5, 25, 75, 95))

    def as_frame(values):
        # Label a (day, county) array with the extended dates and county ids.
        return pd.DataFrame(data=values,
                            index=ext_index,
                            columns=target.columns)

    prediction_mean = as_frame(np.mean(prediction_samples, axis=0))
    prediction_q5 = as_frame(prediction_quantiles[5])
    prediction_q25 = as_frame(prediction_quantiles[25])
    prediction_q75 = as_frame(prediction_quantiles[75])
    prediction_q95 = as_frame(prediction_quantiles[95])

    ax = fig.add_subplot(grid[0, 0])
    dates = [pd.Timestamp(d) for d in ext_index]
    half_day = pd.Timedelta(12, unit='h')

    # Mean prediction.
    p_pred = ax.plot_date(dates,
                          prediction_mean[county_id],
                          "-",
                          color=color_mean,
                          linewidth=2.0,
                          zorder=4)

    # Inner (25-75 %) band with dotted edges.
    p_quant = ax.fill_between(dates,
                              prediction_q25[county_id],
                              prediction_q75[county_id],
                              facecolor=color_band,
                              alpha=0.5,
                              zorder=1)
    for band_edge in (prediction_q25, prediction_q75):
        ax.plot_date(dates,
                     band_edge[county_id],
                     ":",
                     color=color_band,
                     linewidth=2.0,
                     zorder=3)

    # Ground truth; `target` is plotted against all but the last five dates.
    p_real = ax.plot_date(dates[:-5], target[county_id], "k.")

    # Vertical markers, shifted by half a day so they sit between two days.
    ax.axvline(dates[-5] - half_day, ls='-', lw=2, c='dodgerblue')
    ax.axvline(dates[-10] - half_day, ls='--', lw=2, c='lightskyblue')

    ax.set_ylabel("Fallzahlen/Tag nach Meldedatum", fontsize=16)
    ax.tick_params(axis="both",
                   direction='out',
                   size=6,
                   labelsize=16,
                   length=6)

    # One tick every five days, labelled DD.MM.YYYY.
    ticks = [start_day + pd.Timedelta(days=offset)
             for offset in range(0, 45, 5)]
    labels = [tick.strftime("%d.%m.%Y") for tick in ticks]
    plt.xticks(ticks, labels)
    plt.setp(ax.get_xticklabels()[-4], color="red")
    plt.setp(ax.get_xticklabels(), rotation=45)

    ax.autoscale(True)

    # Outer (5-95 %) band, drawn fainter and below everything else.
    p_quant2 = ax.fill_between(dates,
                               prediction_q5[county_id],
                               prediction_q95[county_id],
                               facecolor=color_band,
                               alpha=0.25,
                               zorder=0)
    for band_edge in (prediction_q5, prediction_q95):
        ax.plot_date(dates,
                     band_edge[county_id],
                     ":",
                     color=color_band,
                     alpha=0.5,
                     linewidth=2.0,
                     zorder=1)

    # Share of trace samples whose second W_t_t coefficient for this county
    # is positive.  The sample count is inferred instead of hard-coded.
    i_county = county_ids.get_loc(county_id)
    trace = load_trace_window(disease, model_i, start, n_weeks)
    trend_params = pm.trace_to_dataframe(trace, varnames=["W_t_t"]).values
    trend_w2 = np.reshape(trend_params, (-1, n_columns, 2))[:, i_county, 1]
    prob2 = np.mean(trend_w2 > 0)

    # Axis limits: three times the observed maximum (at least 10), with a
    # small negative margin so zero counts stay visible.
    ylimmax = max(3 * (target[county_id]).max(), 10)
    ax.set_ylim([-(1 / 30) * ylimmax, ylimmax])
    ax.set_xlim([start_day, day_p5 - pd.Timedelta(days=1)])

    # Raw strings keep the backslash before '%' without relying on an
    # invalid escape sequence.
    ax.legend([p_real[0], p_pred[0], p_quant, p_quant2],
              ["Daten RKI", "Modell", r"25\%-75\%-Quantil",
               r"5\%-95\%-Quantil"],
              fontsize=16,
              loc="upper left")

    # Arrow boxes at the top of the axes (not perfectly positioned).
    fontsize_bluebox = 18
    for tick_pos, text, face in ((ax.get_xticks()[-5], "Nowcast",
                                  'lightskyblue'),
                                 (ax.get_xticks()[-4], "Forecast",
                                  'dodgerblue')):
        fig.text(tick_pos + 0.65,
                 ax.get_ylim()[1],
                 text,
                 ha="left",
                 va="top",
                 fontsize=fontsize_bluebox,
                 bbox=dict(facecolor=face, boxstyle='rarrow'),
                 transform=ax.transData)

    # Signed percentage: positive when at least half of the samples are
    # positive, otherwise the complementary percentage with a negative sign.
    if prob2 >= 0.5:
        prob_val = prob2 * 100
    else:
        prob_val = -(100 - prob2 * 100)

    print(county_id)
    df = pd.read_csv(metadata_path, index_col=0)
    # Address the row by its index label and the column by name rather than
    # treating the label as a position.
    county_ix = df.index[df["countyID"] == int(county_id)][0]
    df.loc[county_ix, "probText"] = prob_val
    df.to_csv(metadata_path)
    print(prob_val)

    plt.tight_layout()

    if save_plot:
        plt.savefig("{}/curve_{}.png".format(day_folder_path, county_id),
                    dpi=200)

    plt.close()
    return fig
Beispiel #2
0
def forestplot(trace_obj,
               varnames=None,
               transform=identity_transform,
               alpha=0.05,
               quartiles=True,
               rhat=True,
               main=None,
               xtitle=None,
               xlim=None,
               ylabels=None,
               chain_spacing=0.05,
               vline=0,
               gs=None,
               plot_transformed=False,
               **plot_kwargs):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    Parameters
    ----------

    trace_obj: NpTrace or MultiTrace object
        Trace(s) from an MCMC sample.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha (optional): float
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles (optional): bool
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat (optional): bool
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main (optional): string
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle (optional): string
        Label for x-axis. Defaults to no label
    xlim (optional): list or tuple
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels (optional): list or array
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    chain_spacing (optional): float
        Plot spacing between chains (defaults to 0.05).
    vline (optional): numeric
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None, in which case a new
        grid is created. The interval plot is drawn into ``gs[0]`` and the
        R-hat plot (if any) into ``gs[1]``; the grid's margins are updated.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'color', 'marker', and 'markersize'.

    Returns
    -------

    gs : matplotlib GridSpec

    """
    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    # Range for x-axis
    plotrange = None

    nchains = trace_obj.nchains

    if varnames is None:
        varnames = get_default_varnames(trace_obj.varnames, plot_transformed)

    plot_rhat = (rhat and nchains > 1)
    if gs is None:
        # Initialize plot
        if plot_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # Subplot for the intervals. It is created for caller-supplied grids as
    # well; otherwise the axes would still be None when drawing starts.
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = quantiles(trace_obj,
                                qlist,
                                transform=transform,
                                squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    # Vertical offset of each chain around the variable's row; it depends
    # only on the number of chains, so it is computed once.
    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1)**i
                    for i in range(nchains - 1)]

    # List for y-axis labels, filled while walking the first chain
    labels = []
    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable
        var = 0
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]

            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for quantile
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange is None:
                plotrange = [np.min(quants), np.max(quants)]
            else:
                plotrange = [
                    min(plotrange[0], np.min(quants)),
                    max(plotrange[1], np.max(quants))
                ]

            # Number of elements in current variable
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Append variable name(s) to list
            if j == 0:
                if k > 1:
                    labels += _var_str(varname, np.shape(value))
                else:
                    labels.append(varname)

            # Y coordinate with offset
            y = -var + offset[j]

            # Deal with multivariate nodes
            if k > 1:
                # Move the quantile axis last and flatten the variable's own
                # axes, so every row holds all quantiles of one element,
                # whatever the dimensionality of the variable.
                per_element = np.moveaxis(np.array(quants), 0, -1)
                for q in per_element.reshape(-1, len(quants)):
                    # Multiple y values
                    interval_plot = _plot_tree(interval_plot, y, q, quartiles,
                                               **plot_kwargs)
                    y -= 1
            else:
                interval_plot = _plot_tree(interval_plot, y, quants, quartiles,
                                           **plot_kwargs)

            # Increment index
            var += k

    labels = ylabels if ylabels is not None else labels

    # Update margins: the left margin grows with the longest label
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    interval_plot.set_ylim(-var + 0.5, 0.5)

    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    plot_title = ""
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    if plot_title:
        interval_plot.set_title(plot_title,
                                fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. tick_params is used because assigning the
    # Tick.tick1On / tick2On attributes has no effect on current matplotlib.
    interval_plot.tick_params(axis='y', which='major', left=False,
                              right=False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    # Generate Gelman-Rubin plot
    if plot_rhat:
        _make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels,
                        varnames, plot_transformed)

    return gs
Beispiel #3
0
def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.05, quartiles=True,
               rhat=True, main=None, xtitle=None, xlim=None, ylabels=None,
               chain_spacing=0.05, vline=0, gs=None, plot_transformed=False):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    Parameters
    ----------

    trace_obj: NpTrace or MultiTrace object
        Trace(s) from an MCMC sample.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha (optional): float
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles (optional): bool
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat (optional): bool
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main (optional): string
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle (optional): string
        Label for x-axis. Defaults to no label
    xlim (optional): list or tuple
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels (optional): list or array
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    chain_spacing (optional): float
        Plot spacing between chains (defaults to 0.05).
    vline (optional): numeric
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None, in which case a new
        grid is created. The interval plot is drawn into ``gs[0]`` and the
        R-hat plot (if any) into ``gs[1]``; the grid's margins are updated.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).

    Returns
    -------

    gs : matplotlib GridSpec

    """
    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    # Range for x-axis
    plotrange = None

    nchains = trace_obj.nchains

    if varnames is None:
        varnames = get_default_varnames(trace_obj, plot_transformed)

    plot_rhat = (rhat and nchains > 1)
    if gs is None:
        # Initialize plot
        if plot_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # Subplot for the intervals. It is created for caller-supplied grids as
    # well; otherwise the axes would still be None when drawing starts.
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = quantiles(trace_obj, qlist, transform=transform, squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    # Vertical offset of each chain around the variable's row; it depends
    # only on the number of chains, so it is computed once.
    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i for i in range(nchains - 1)]

    # List for y-axis labels, filled while walking the first chain
    labels = []
    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable
        var = 0
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]

            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for quantile
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange is None:
                plotrange = [np.min(quants), np.max(quants)]
            else:
                plotrange = [min(plotrange[0], np.min(quants)),
                             max(plotrange[1], np.max(quants))]

            # Number of elements in current variable
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Append variable name(s) to list
            if j == 0:
                if k > 1:
                    labels += _var_str(varname, np.shape(value))
                else:
                    labels.append(varname)

            # Y coordinate with offset
            y = -var + offset[j]

            # Deal with multivariate nodes
            if k > 1:
                # Move the quantile axis last and flatten the variable's own
                # axes, so every row holds all quantiles of one element,
                # whatever the dimensionality of the variable.
                per_element = np.moveaxis(np.array(quants), 0, -1)
                for q in per_element.reshape(-1, len(quants)):
                    # Multiple y values
                    interval_plot = _plot_tree(interval_plot, y, q, quartiles)
                    y -= 1
            else:
                interval_plot = _plot_tree(interval_plot, y, quants, quartiles)

            # Increment index
            var += k

    labels = ylabels if ylabels is not None else labels

    # Update margins: the left margin grows with the longest label
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    interval_plot.set_ylim(-var + 0.5, 0.5)

    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels)

    # Add title
    plot_title = ""
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    if plot_title:
        interval_plot.set_title(plot_title)

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. tick_params is used because assigning the
    # Tick.tick1On / tick2On attributes has no effect on current matplotlib.
    interval_plot.tick_params(axis='y', which='major', left=False, right=False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle='--')

    # Generate Gelman-Rubin plot
    if plot_rhat:
        _make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels, varnames, plot_transformed)

    return gs
Beispiel #4
0
def forestplot(trace,
               models=None,
               varnames=None,
               transform=identity_transform,
               alpha=0.05,
               quartiles=True,
               rhat=True,
               main=None,
               xtitle=None,
               xlim=None,
               ylabels=None,
               colors='C0',
               chain_spacing=0.1,
               vline=0,
               gs=None,
               plot_transformed=False,
               plot_kwargs=None):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals from a trace
    or list of traces.

    Parameters
    ----------

    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list (optional)
        List with names for the models in the list of traces. Useful when
        plotting more that one trace.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    colors : list or string, optional
        list with valid matplotlib colors, one color per model. Alternative a
        string can be passed. If the string is `cycle `, it will automatically
        chose a color per model from the matyplolib's cycle. If a single color
        is passed, eg 'k', 'C2', 'red' this color will be used for all models.
        Defauls to 'C0' (blueish in most matplotlib styles)
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.1).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'marker', and 'markersize'.

    Returns
    -------

    gs : matplotlib GridSpec

    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    if models is None:
        if len(trace) > 1:
            models = ['m_{}'.format(i) for i in range(len(trace))]
        else:
            models = ['']
    elif len(models) != len(trace):
        raise ValueError("The number of names for the models does not match "
                         "the number of models")

    if colors == 'cycle':
        colors = ['C{}'.format(i % 10) for i in range(len(models))]
    elif isinstance(colors, str):
        colors = [colors for i in range(len(models))]

    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    nchains = [tr.nchains for tr in trace]

    if varnames is None:
        varnames = []
        for idx, tr in enumerate(trace):
            varnames_tmp = get_default_varnames(tr.varnames, plot_transformed)
            for v in varnames_tmp:
                if v not in varnames:
                    varnames.append(v)

    plot_rhat = [rhat and nch > 1 for nch in nchains]
    # Empty list for y-axis labels
    if gs is None:
        # Initialize plot
        if np.any(plot_rhat):
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
            gr_plot = plt.subplot(gs[1])
            gr_plot.set_xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
            gr_plot.set_xlim(0.9, 2.1)
            gr_plot.set_yticks([])
            gr_plot.set_title('R-hat')
        else:
            gs = gridspec.GridSpec(1, 1)

        # Subplot for confidence intervals
        interval_plot = plt.subplot(gs[0])

    trace_quantiles = []
    hpd_intervals = []
    for tr in trace:
        trace_quantiles.append(
            quantiles(tr, qlist, transform=transform, squeeze=False))
        hpd_intervals.append(hpd(tr, alpha, transform=transform,
                                 squeeze=False))

    labels = []
    var = 0
    all_quants = []
    bands = [0.05, 0] * len(varnames)
    var_old = 0.5
    for v_idx, v in enumerate(varnames):
        for h, tr in enumerate(trace):
            if v not in tr.varnames:
                labels.append(models[h] + ' ' + v)
                var += 1
            else:
                for j, chain in enumerate(tr.chains):
                    var_quantiles = trace_quantiles[h][chain][v]

                    quants = [var_quantiles[vq] for vq in qlist]
                    var_hpd = hpd_intervals[h][chain][v].T

                    # Substitute HPD interval for quantile
                    quants[0] = var_hpd[0].T
                    quants[-1] = var_hpd[1].T

                    # Ensure x-axis contains range of current interval
                    all_quants.extend(np.ravel(quants))

                    # Number of elements in current variable
                    value = tr.get_values(v, chains=[chain])[0]
                    k = np.size(value)

                    # Append variable name(s) to list
                    if j == 0:
                        if k > 1:
                            names = _var_str(v, np.shape(value))
                            names[0] = models[h] + ' ' + names[0]
                            labels += names
                        else:
                            labels.append(models[h] + ' ' + v)

                    # Add spacing for each chain, if more than one
                    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1)**i
                                    for i in range(nchains[h] - 1)]

                    # Y coordinate with offset
                    y = -var + offset[j]

                    # Deal with multivariate nodes

                    if k > 1:
                        qs = np.moveaxis(np.array(quants), 0, -1).squeeze()
                        for q in qs.reshape(-1, len(quants)):
                            # Multiple y values
                            interval_plot = _plot_tree(interval_plot, y, q,
                                                       quartiles, colors[h],
                                                       plot_kwargs)
                            y -= 1
                    else:
                        interval_plot = _plot_tree(interval_plot, y, quants,
                                                   quartiles, colors[h],
                                                   plot_kwargs)

                # Genenerate Gelman-Rubin plot
                if plot_rhat[h] and v in tr.varnames:
                    R = gelman_rubin(tr, [v])
                    if k > 1:
                        Rval = dict2pd(R, 'rhat').values
                        gr_plot.plot([min(r, 2) for r in Rval],
                                     [-(j + var) for j in range(k)],
                                     'o',
                                     color=colors[h],
                                     markersize=4)
                    else:
                        gr_plot.plot(min(R[v], 2),
                                     -var,
                                     'o',
                                     color=colors[h],
                                     markersize=4)
                var += k

        if len(trace) > 1:
            interval_plot.axhspan(var_old,
                                  y - chain_spacing - 0.5,
                                  facecolor='k',
                                  alpha=bands[v_idx])
            gr_plot.axhspan(var_old,
                            y - chain_spacing - 0.5,
                            facecolor='k',
                            alpha=bands[v_idx])
            var_old = y - chain_spacing - 0.5

    if ylabels is not None:
        labels = ylabels

    # Update margins
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis for forestplot and R-hat
    interval_plot.set_ylim(-var + 0.5, 0.5)
    if np.any(plot_rhat):
        gr_plot.set_ylim(-var + 0.5, 0.5)

    plotrange = [np.min(all_quants), np.max(all_quants)]
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    else:
        plot_title = ""

    interval_plot.set_title(plot_title,
                            fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1On = False
        ticks.tick2On = False

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    return gs
    features = model.evaluate_features(target_train.index,
                                       target_train.columns)

    trend_features = features["temporal_trend"].swaplevel(0, 1).loc["09162"]
    periodic_features = features["temporal_seasonal"].swaplevel(0,
                                                                1).loc["09162"]
    #t_all = t_all_b if disease == "borreliosis" else t_all_cr

    trace = load_trace(disease, use_age, use_eastwest)
    trend_params = pm.trace_to_dataframe(trace, varnames=["W_t_t"])
    periodic_params = pm.trace_to_dataframe(trace, varnames=["W_t_s"])

    TT = trend_params.values.dot(trend_features.values.T)
    TP = periodic_params.values.dot(periodic_features.values.T)
    TTP = TT + TP
    TT_quantiles = quantiles(TT, (25, 75))
    TP_quantiles = quantiles(TP, (25, 75))
    TTP_quantiles = quantiles(TTP, (25, 75))

    dates = [n.wednesday() for n in target_train.index.values]

    # Temporal periodic effect
    ax_p = fig.add_subplot(grid[0, i])

    ax_p.fill_between(dates,
                      np.exp(TP_quantiles[25]),
                      np.exp(TP_quantiles[75]),
                      alpha=0.5,
                      zorder=1,
                      facecolor=C1)
    ax_p.plot_date(dates,
def curves_no_periodic(model_i=15, prediction_day=30, save_plot=False):
    """Plot predicted vs. reported case curves for two counties above a map.

    The figure has three rows: one curve panel each for Düsseldorf and
    Leipzig (prediction mean, 25-75% and 5-95% quantile bands, reported
    counts) and a county map coloured by the prediction mean, on which both
    counties are outlined in red and annotated.

    Parameters
    ----------
    model_i : int, optional
        Only used to build the output name ``../figures/curves_<model_i>.pdf``.
    prediction_day : int, optional
        Not used by this function; kept so existing callers keep working.
    save_plot : bool, optional
        When true, the figure is written to disk with ``plt.savefig``.

    Returns
    -------
    None
    """
    with open('../data/counties/counties.pkl', "rb") as f:
        counties = pkl.load(f)

    # Window handed to plot_counties as xlim/ylim of the map panel.
    xlim = (5.5, 15.5)
    ylim = (47, 56)

    countyByName = OrderedDict([('Düsseldorf', '05111'), ('Leipzig', '14713'),
                                ('Nürnberg', '09564'), ('München', '09162')])
    plot_county_names = {"covid19": ["Düsseldorf", "Leipzig"]}

    # Colours: C1 for the mean curve, C2 for the quantile bands and outlines.
    C1 = "#D55E00"
    C2 = "#E69F00"

    fig = plt.figure(figsize=(12, 14))
    grid = plt.GridSpec(3,
                        1,
                        top=0.9,
                        bottom=0.1,
                        left=0.07,
                        right=0.97,
                        hspace=0.25,
                        wspace=0.15,
                        height_ratios=[1, 1, 1.75])

    # Single disease / single grid column; `i` is kept as the column index.
    i = 0
    disease = "covid19"
    prediction_region = "germany"

    data = load_daily_data(disease, prediction_region, counties)

    start_day = pd.Timestamp('2020-04-10')
    i_start_day = (start_day - data.index.min()).days
    day_0 = pd.Timestamp('2020-05-21')
    day_p5 = day_0 + pd.Timedelta(days=5)

    _, target, _, _ = split_data(data,
                                 train_start=start_day,
                                 test_start=day_0,
                                 post_test=day_p5)

    # Load the prediction samples and cut out the window
    # [start_day, day_p5) along the day axis.
    res = load_final_nowcast_pred()
    n_days = (day_p5 - start_day).days
    # NOTE(review): 412 is hard-coded as the size of the last axis --
    # presumably the number of counties; confirm against the data.
    prediction_samples = np.reshape(res['y'], (res['y'].shape[0], -1, 412))
    prediction_samples = prediction_samples[:, i_start_day:i_start_day +
                                            n_days, :]

    # Target dates followed by the remaining days up to day_p5 - 1 day.
    ext_index = pd.DatetimeIndex(
        [d for d in target.index] +
        [d for d in pd.date_range(target.index[-1] + timedelta(1),
                                  day_p5 - timedelta(1))])

    def as_frame(values):
        # One (day x county) table on the extended date index.
        return pd.DataFrame(data=values,
                            index=ext_index,
                            columns=target.columns)

    # TODO: figure out where quantiles comes from and if its pymc3, how to replace it
    prediction_quantiles = quantiles(prediction_samples, (5, 25, 75, 95))

    prediction_mean = as_frame(np.mean(prediction_samples, axis=0))
    prediction_q25 = as_frame(prediction_quantiles[25])
    prediction_q75 = as_frame(prediction_quantiles[75])
    prediction_q5 = as_frame(prediction_quantiles[5])
    prediction_q95 = as_frame(prediction_quantiles[95])

    map_ax = fig.add_subplot(grid[2, i])
    map_ax.set_position(grid[2, i].get_position(fig).translated(0, -0.05))
    label_day = prediction_mean.index[-5]
    map_ax.set_xlabel("{}.{}.{}".format(label_day.day, label_day.month,
                                        label_day.year),
                      fontsize=22)

    # Choropleth map of the prediction mean, highlighted counties in red.
    highlighted = plot_county_names[disease]
    plot_counties(map_ax,
                  counties,
                  prediction_mean.iloc[-10].to_dict(),
                  edgecolors=dict(
                      zip(map(countyByName.get, highlighted),
                          ["red"] * len(highlighted))),
                  xlim=xlim,
                  ylim=ylim,
                  contourcolor="black",
                  background=False,
                  xticks=False,
                  yticks=False,
                  grid=False,
                  frame=True,
                  ylabel=False,
                  xlabel=False,
                  lw=2)

    map_ax.set_rasterized(True)

    # These do not depend on the county, so build them once.
    dates = [pd.Timestamp(day) for day in ext_index]
    ticks = [
        '2020-03-02', '2020-03-12', '2020-03-22', '2020-04-01',
        '2020-04-11', '2020-04-21', '2020-05-1', '2020-05-11', '2020-05-21'
    ]
    labels = [
        '02.03.2020', '12.03.2020', '22.03.2020', '01.04.2020',
        '11.04.2020', '21.04.2020', '01.05.2020', '11.05.2020',
        '21.05.2020'
    ]

    for j, name in enumerate(highlighted):
        ax = fig.add_subplot(grid[j, i])
        county_id = countyByName[name]

        def dotted(series, **extra):
            # Dotted outline of a quantile band on the current panel.
            ax.plot_date(dates,
                         series,
                         ":",
                         color=C2,
                         linewidth=2.0,
                         **extra)

        # Prediction mean with the inter-quartile band.
        p_pred = ax.plot_date(dates,
                              prediction_mean[county_id],
                              "-",
                              color=C1,
                              linewidth=2.0,
                              zorder=4)
        p_quant = ax.fill_between(dates,
                                  prediction_q25[county_id],
                                  prediction_q75[county_id],
                                  facecolor=C2,
                                  alpha=0.5,
                                  zorder=1)
        dotted(prediction_q25[county_id], zorder=3)
        dotted(prediction_q75[county_id], zorder=3)

        # Reported counts; the last five dates have no target values.
        p_real = ax.plot_date(dates[:-5], target[county_id], "k.")

        # Vertical markers five and ten entries before the end of the index.
        ax.axvline(dates[-5], ls='-', lw=2, c='cornflowerblue')
        ax.axvline(dates[-10], ls='--', lw=2, c='cornflowerblue')

        if j == 1:
            ax.set_xlabel("Time", fontsize=20)
        ax.tick_params(axis="both",
                       direction='out',
                       size=6,
                       labelsize=16,
                       length=6)
        # plt.xticks acts on the current axes, i.e. the panel just created.
        plt.xticks(ticks, labels)
        plt.setp(ax.get_xticklabels(), visible=j > 0, rotation=45)

        # Label the county on the map, offset from its centroid.
        cent = np.array(counties[county_id]["shape"].centroid.coords[0])
        txt = map_ax.annotate(name,
                              cent,
                              cent + 0.5,
                              color="white",
                              arrowprops=dict(facecolor='white',
                                              shrink=0.001,
                                              headwidth=3),
                              fontsize=26,
                              fontweight='bold',
                              fontname="Arial")
        txt.set_path_effects(
            [PathEffects.withStroke(linewidth=2, foreground='black')])

        # pd.Timedelta(1) would be one *nanosecond*; the right limit is meant
        # to be the last plotted day, i.e. one day before day_p5.
        ax.set_xlim([start_day, day_p5 - pd.Timedelta(days=1)])
        # Freeze the limits so the wide 5-95% band cannot rescale the panel.
        ax.autoscale(False)
        p_quant2 = ax.fill_between(dates,
                                   prediction_q5[county_id],
                                   prediction_q95[county_id],
                                   facecolor=C2,
                                   alpha=0.25,
                                   zorder=0)
        dotted(prediction_q5[county_id], alpha=0.5, zorder=1)
        dotted(prediction_q95[county_id], alpha=0.5, zorder=1)

        if i == 0 and j == 0:
            # Raw strings: "\%" is not a valid escape in a normal literal.
            ax.legend([p_real[0], p_pred[0], p_quant, p_quant2], [
                "reported", "predicted", r"25\%-75\% quantile",
                r"5\%-95\% quantile"
            ],
                      fontsize=12,
                      loc="upper right")
        # Panel tag ("2A <county>", "2B <county>") in axes coordinates.
        fig.text(0,
                 1 + 0.025,
                 r"$\textbf{" + str(i + 2) + "ABC"[j] + " " +
                 plot_county_names["covid19"][j] + r"}$",
                 fontsize=22,
                 transform=ax.transAxes)

    fig.text(0,
             0.95,
             r"$\textbf{" + str(i + 2) + r"C}$",
             fontsize=22,
             transform=map_ax.transAxes)

    # Shared y-axis caption for the two curve panels.
    fig.text(0.01,
             0.66,
             "Reported/predicted infections",
             va='center',
             rotation='vertical',
             fontsize=22)

    if save_plot:
        plt.savefig("../figures/curves_{}.pdf".format(model_i))
Beispiel #7
0
# if disease == "borreliosis":
#       data = data[data.index >= parse_yearweek("2013-KW1")]

# Only the second element returned by split_data is used here
# (plots for the training period!).
_, target, _, _ = split_data(data,
                             train_start=pd.Timestamp(2020, 1, 28),
                             test_start=pd.Timestamp(2020, 3, 30),
                             post_test=pd.Timestamp(2020, 3, 31))
# _, _, _, target = split_data(data)
county_ids = target.columns

res = load_pred(disease, use_age, use_eastwest)
n_days = 62  # for now; get from timestamps up top!

# Reshape the samples to (sample, day, remaining axis).
prediction_samples = np.reshape(res['y'], (res['y'].shape[0], n_days, -1))
# TODO: figure out where quantiles comes from and if its pymc3, how to replace it
prediction_quantiles = quantiles(prediction_samples, (5, 25, 75, 95))

# Mean and quartile tables, all sharing the target's index and columns.
prediction_mean, prediction_q25, prediction_q75 = (
    pd.DataFrame(data=values, index=target.index, columns=target.columns)
    for values in (
        np.mean(prediction_samples, axis=0),
        prediction_quantiles[25],
        prediction_quantiles[75],
    )
)
def curves(start, n_weeks=3, model_i=35, save_plot=False):
    """Draw a county map of predicted new infections per 100,000 inhabitants.

    Parameters
    ----------
    start : int or str
        Offset in days from 2020-01-28 to the first day of the data window;
        converted with ``int``.
    n_weeks : int or str, optional
        Window length in weeks; converted with ``int``.
    model_i : int or str, optional
        Model index passed to ``load_pred_model_window``; converted with
        ``int``.
    save_plot : bool, optional
        When true, ``map.png`` and ``map.csv`` are written to
        ``../figures/<YYYY>_<MM>_<DD>/`` (date of the window start).

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn. ``plt.close()`` is called before
        returning.
    """
    with open('../data/counties/counties.pkl', "rb") as f:
        counties = pkl.load(f)

    start = int(start)
    n_weeks = int(n_weeks)
    model_i = int(model_i)

    # Window handed to plot_counties as xlim/ylim of the map.
    xlim = (5.5, 15.5)
    ylim = (47, 56)

    fig = plt.figure(figsize=(6, 8))
    grid = plt.GridSpec(1, 1)

    # Single disease / single grid column; `i` is kept as the column index.
    i = 0
    disease = "covid19"
    prediction_region = "germany"
    data = load_daily_data_n_weeks(start, n_weeks, disease, prediction_region,
                                   counties)

    start_day = pd.Timestamp('2020-01-28') + pd.Timedelta(days=start)
    day_0 = start_day + pd.Timedelta(days=n_weeks * 7 + 5)
    day_p5 = day_0 + pd.Timedelta(days=5)

    _, target, _, _ = split_data(data,
                                 train_start=start_day,
                                 test_start=day_0,
                                 post_test=day_p5)

    # Load the prediction samples and keep the first n_days along the day
    # axis.
    res = load_pred_model_window(model_i, start, n_weeks)
    n_days = (day_p5 - start_day).days
    # NOTE(review): 412 is hard-coded as the size of the last axis --
    # presumably the number of counties; confirm against the data.
    prediction_samples = np.reshape(res['y'], (res['y'].shape[0], -1, 412))
    prediction_samples = prediction_samples[:, :n_days, :]

    # Target dates followed by the remaining days up to day_p5 - 1 day.
    ext_index = pd.DatetimeIndex(
        [d for d in target.index] +
        [d for d in pd.date_range(target.index[-1] + timedelta(1),
                                  day_p5 - timedelta(1))])

    prediction_mean = pd.DataFrame(data=np.mean(prediction_samples, axis=0),
                                   index=ext_index,
                                   columns=target.columns)

    map_ax = fig.add_subplot(grid[0, i])

    # Scale prediction mean and last reported row to cases per 100,000.
    # The row is copied and addressed with .iloc: the columns are labels, so
    # plain integer keys would rely on the deprecated positional fallback
    # and would write into a row taken straight from prediction_mean.
    # NOTE(review): this pairs the k-th entry of `counties` with the k-th
    # column by position -- assumes both are in the same order; confirm.
    map_vals = prediction_mean.iloc[-10].copy()
    map_rki = data.iloc[-1].values.astype('float64')
    map_keys = []
    for ik, (key, county) in enumerate(counties.items()):
        n_people = county['demographics'][('total', 2018)]
        map_vals.iloc[ik] = (map_vals.iloc[ik] / n_people) * 100000
        map_rki[ik] = (map_rki[ik] / n_people) * 100000
        map_keys.append(key)

    map_df = pd.DataFrame(index=None)
    map_df["countyID"] = map_keys
    map_df["newInf100k"] = list(map_vals)
    map_df["newInf100k_RKI"] = list(map_rki)

    # Choropleth map of the scaled prediction mean; no county is outlined.
    plot_counties(map_ax,
                  counties,
                  map_vals.to_dict(),
                  edgecolors=None,
                  xlim=xlim,
                  ylim=ylim,
                  contourcolor="black",
                  background=False,
                  xticks=False,
                  yticks=False,
                  grid=False,
                  frame=True,
                  ylabel=False,
                  xlabel=False,
                  lw=2)
    map_ax.set_rasterized(True)

    # Caption text placed in figure coordinates.
    fig.text(0.71,
             0.17,
             "Neuinfektionen \n pro 100.000 \n Einwohner",
             fontsize=14,
             color=[0.3, 0.3, 0.3])

    if save_plot:
        # One output folder per window start date: ../figures/YYYY_MM_DD
        day_folder_path = "../figures/{}".format(
            start_day.strftime("%Y_%m_%d"))
        Path(day_folder_path).mkdir(parents=True, exist_ok=True)

        plt.savefig("{}/map.png".format(day_folder_path), dpi=300)
        map_df.to_csv("{}/map.csv".format(day_folder_path))
    plt.close()
    return fig
Beispiel #9
0
def plotdata_csv(start, n_weeks, csv_path, counties, output_dir):
    """Write one prediction CSV per county plus a ``map.csv`` summary.

    Parameters
    ----------
    start : int
        Offset in days from 2020-01-28 to the first day of the data window.
    n_weeks : int
        Window length in weeks.
    csv_path : str
        Passed through to ``load_data_n_weeks``.
    counties : mapping
        Per-county records; ``counties[id]["demographics"][("total", 2018)]``
        is read as the population used for the per-100k columns.
    output_dir : str
        Directory receiving ``<county_id>.csv`` for every county and
        ``map.csv``.

    Returns
    -------
    None
    """
    county_by_name = make_county_dict()
    data = load_data_n_weeks(start, n_weeks, csv_path)
    first_day = pd.Timestamp("2020-01-28") + pd.Timedelta(days=start)
    day_0 = first_day + pd.Timedelta(days=n_weeks * 7 + 5)
    day_m5 = day_0 - pd.Timedelta(days=5)
    day_p5 = day_0 + pd.Timedelta(days=5)
    _, target, _, _ = split_data(data,
                                 train_start=first_day,
                                 test_start=day_0,
                                 post_test=day_p5)

    def as_samples(values):
        # Reshape to (sample, day, 412), the layout used by the original.
        return np.reshape(values, (values.shape[0], -1, 412))

    # Load the prediction samples ("y") and their means ("μ").
    res = load_predictions(start, n_weeks)
    res_trend = load_trend_predictions(start, n_weeks)

    samples = as_samples(res["y"])
    samples_mu = as_samples(res["μ"])
    samples_trend = as_samples(res_trend["y"])
    samples_trend_mu = as_samples(res_trend["μ"])
    inc_7day = sample_x_days_incidence_by_county(samples_trend, 7)
    inc_7day_mu = sample_x_days_incidence_by_county(samples_trend_mu, 7)

    # Target dates followed by the remaining days up to day_p5 - 1 day.
    ext_index = pd.DatetimeIndex(
        list(target.index) +
        list(pd.date_range(target.index[-1] + timedelta(1),
                           day_p5 - timedelta(1))))

    # TODO: figure out if we want to replace quantiles function (newer pymc3 versions don't support it)
    levels = (5, 25, 75, 95)
    q_raw = quantiles(samples, levels)
    q_trend = quantiles(samples_trend, levels)
    q_7day = quantiles(inc_7day, levels)

    def as_frame(values):
        return pd.DataFrame(data=values,
                            index=ext_index,
                            columns=target.columns)

    def padded(values):
        # Prepend six NaN rows so the 7-day values line up with ext_index.
        return np.pad(values, ((6, 0), (0, 0)),
                      "constant",
                      constant_values=np.nan)

    def family(prefix, mean_values, quants, prepare=lambda values: values):
        # (column label, table) pairs in the column order of the CSV files.
        pairs = [("{} Prediction Mean".format(prefix), as_frame(mean_values))]
        for level in (25, 75, 5, 95):
            pairs.append(("{} Prediction Q{}".format(prefix, level),
                          as_frame(prepare(quants[level]))))
        return pairs

    tables = (family("Raw", np.mean(samples_mu, axis=0), q_raw) +
              family("Trend", np.mean(samples_trend_mu, axis=0), q_trend) +
              family("Trend 7Day", padded(np.mean(inc_7day_mu, axis=0)),
                     q_7day, lambda values: padded(values.astype(float))))
    table_by_label = dict(tables)
    prediction_mean = table_by_label["Raw Prediction Mean"]
    prediction_mean_7day = table_by_label["Trend 7Day Prediction Mean"]
    prediction_q95_trend = table_by_label["Trend Prediction Q95"]

    rki_7day = target.rolling(7).sum()

    # Values on the last date of the target.
    ref_date = target.iloc[-1].name
    on_ref_date = prediction_mean.index == ref_date
    nowcast_vals = prediction_mean.loc[on_ref_date]
    nowcast7day_vals = prediction_mean_7day.loc[on_ref_date]
    rki_vals = target.iloc[-1]
    rki_7day_vals = rki_7day.iloc[-1]

    # Row flags shared by every county file.
    is_nowcast = (day_m5 <= ext_index) & (ext_index < day_0)
    is_prediction = (day_0 <= ext_index)

    map_columns = {
        label: []
        for label in ("countyID", "newInf100k", "7DayInf100k",
                      "newInf100k_RKI", "7DayInf100k_RKI", "newInfRaw",
                      "7DayInfRaw", "newInfRaw_RKI", "7DayInfRaw_RKI")
    }

    for county, county_id in county_by_name.items():
        # Reported series, extended with five NaN entries to the length of
        # the extended index.
        rki_data = np.append(target.loc[:, county_id].values,
                             np.repeat(np.nan, 5))
        rki_data7day = np.append(rki_7day.loc[:, county_id].values,
                                 np.repeat(np.nan, 5))
        n_people = counties[county_id]["demographics"][("total", 2018)]

        nowcast = nowcast_vals[county_id].item()
        nowcast_7day = nowcast7day_vals[county_id].item()
        reported = rki_vals[county_id].item()
        reported_7day = rki_7day_vals[county_id].item()

        map_columns["countyID"].append(county_id)
        map_columns["newInf100k"].append(nowcast / n_people * 100000)
        map_columns["7DayInf100k"].append(nowcast_7day / n_people * 100000)
        map_columns["newInf100k_RKI"].append(reported / n_people * 100000)
        map_columns["7DayInf100k_RKI"].append(reported_7day / n_people *
                                              100000)
        map_columns["newInfRaw"].append(nowcast)
        map_columns["7DayInfRaw"].append(nowcast_7day)
        map_columns["newInfRaw_RKI"].append(reported)
        map_columns["7DayInfRaw_RKI"].append(reported_7day)

        # Every prediction table contributes a raw and a per-100k column.
        columns = {}
        for label, table in tables:
            series = table.loc[:, county_id].values
            columns[label] = series
            columns[label + " 100k"] = np.multiply(
                np.divide(series, n_people), 100000)
        columns["RKI Meldedaten"] = rki_data
        columns["RKI 7Day Incidence"] = rki_data7day
        columns["is_nowcast"] = is_nowcast
        # True where the reported count exceeds the trend's 95% quantile.
        columns["is_high"] = np.less(
            prediction_q95_trend.loc[:, county_id].values, rki_data)
        columns["is_prediction"] = is_prediction

        county_data = pd.DataFrame(columns, index=ext_index)
        fpath = os.path.join(output_dir,
                             "{}.csv".format(county_by_name[county]))
        county_data.to_csv(fpath)

    map_df = pd.DataFrame(index=None)
    for label, values in map_columns.items():
        map_df[label] = values
    map_df.to_csv(os.path.join(output_dir, "map.csv"))
Beispiel #10
0
def forestplot(trace, models=None, varnames=None, transform=identity_transform,
               alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
               xlim=None, ylabels=None, colors='C0', chain_spacing=0.1, vline=0,
               gs=None, plot_transformed=False, plot_kwargs=None):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals from a trace
    or list of traces.

    Parameters
    ----------

    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list (optional)
        List with names for the models in the list of traces. Useful when
        plotting more than one trace.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively
        a string can be passed. If the string is `cycle`, a color per model
        is chosen automatically from matplotlib's cycle. If a single color
        is passed, eg 'k', 'C2', 'red' this color will be used for all models.
        Defaults to 'C0' (blueish in most matplotlib styles)
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.1).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None. The intervals are drawn
        in ``gs[0]``. When R-hat values are to be plotted they go into
        ``gs[1]``; if the supplied grid has only one cell the R-hat panel is
        skipped.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'marker', and 'markersize'.

    Returns
    -------

    gs : matplotlib GridSpec

    Raises
    ------

    ValueError
        If `models` is given and its length differs from the number of traces.

    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    if models is None:
        if len(trace) > 1:
            models = ['m_{}'.format(i) for i in range(len(trace))]
        else:
            models = ['']
    elif len(models) != len(trace):
        raise ValueError("The number of names for the models does not match "
                         "the number of models")

    if colors == 'cycle':
        colors = ['C{}'.format(i % 10) for i in range(len(models))]
    elif isinstance(colors, str):
        colors = [colors] * len(models)

    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    nchains = [tr.nchains for tr in trace]

    if varnames is None:
        # Union of the default variable names of all traces, in order of
        # first appearance.
        varnames = []
        for tr in trace:
            for v in get_default_varnames(tr.varnames, plot_transformed):
                if v not in varnames:
                    varnames.append(v)

    # R-hat is only drawn for traces with more than one chain
    plot_rhat = [rhat and nch > 1 for nch in nchains]

    if gs is None:
        # Initialize plot
        if np.any(plot_rhat):
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    gr_plot = None
    if np.any(plot_rhat):
        n_rows, n_cols = gs.get_geometry()
        if n_rows * n_cols > 1:
            gr_plot = plt.subplot(gs[1])
            # Ticks and labels are set in two calls: older matplotlib
            # versions interpret a second positional argument of set_xticks
            # as the `minor` flag rather than as tick labels.
            gr_plot.set_xticks((1.0, 1.5, 2.0))
            gr_plot.set_xticklabels(("1", "1.5", "2+"))
            gr_plot.set_xlim(0.9, 2.1)
            gr_plot.set_yticks([])
            gr_plot.set_title('R-hat')
        else:
            # A caller-supplied single-cell grid has no room for the R-hat
            # panel; skip it instead of failing on an undefined axes.
            plot_rhat = [False] * len(trace)

    # Subplot for confidence intervals
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = []
    hpd_intervals = []
    for tr in trace:
        trace_quantiles.append(quantiles(tr, qlist, transform=transform,
                                         squeeze=False))
        hpd_intervals.append(hpd(tr, alpha, transform=transform,
                                 squeeze=False))

    labels = []
    var = 0  # running row index; rows are drawn at y = -var
    all_quants = []
    # Alternating background shading per variable when comparing models
    bands = [0.05, 0] * len(varnames)
    var_old = 0.5
    for v_idx, v in enumerate(varnames):
        for h, tr in enumerate(trace):
            if v not in tr.varnames:
                # Keep an empty, labelled row so models stay aligned
                labels.append(models[h] + ' ' + v)
                var += 1
            else:
                # Add spacing for each chain, if more than one
                offset = [0] + [(chain_spacing * ((i + 2) / 2)) *
                                (-1) ** i for i in range(nchains[h] - 1)]

                for j, chain in enumerate(tr.chains):
                    var_quantiles = trace_quantiles[h][chain][v]

                    quants = [var_quantiles[vq] for vq in qlist]
                    var_hpd = hpd_intervals[h][chain][v].T

                    # Substitute HPD interval for quantile
                    quants[0] = var_hpd[0].T
                    quants[-1] = var_hpd[1].T

                    # Ensure x-axis contains range of current interval
                    all_quants.extend(np.ravel(quants))

                    # Number of elements in current variable
                    value = tr.get_values(v, chains=[chain])[0]
                    k = np.size(value)

                    # Append variable name(s) to list
                    if j == 0:
                        if k > 1:
                            names = _var_str(v, np.shape(value))
                            names[0] = models[h] + ' ' + names[0]
                            labels += names
                        else:
                            labels.append(models[h] + ' ' + v)

                    # Y coordinate with offset
                    y = - var + offset[j]

                    # Deal with multivariate nodes
                    if k > 1:
                        qs = np.moveaxis(np.array(quants), 0, -1).squeeze()
                        for q in qs.reshape(-1, len(quants)):
                            # Multiple y values
                            interval_plot = _plot_tree(interval_plot, y, q,
                                                       quartiles, colors[h],
                                                       plot_kwargs)
                            y -= 1
                    else:
                        interval_plot = _plot_tree(interval_plot, y, quants,
                                                   quartiles, colors[h],
                                                   plot_kwargs)

                # Generate Gelman-Rubin plot
                if plot_rhat[h] and v in tr.varnames:
                    R = gelman_rubin(tr, [v])
                    if k > 1:
                        Rval = dict2pd(R, 'rhat').values
                        gr_plot.plot([min(r, 2) for r in Rval],
                                     [-(j + var) for j in range(k)], 'o',
                                     color=colors[h], markersize=4)
                    else:
                        gr_plot.plot(min(R[v], 2), -var, 'o', color=colors[h],
                                     markersize=4)
                var += k

        if len(trace) > 1:
            interval_plot.axhspan(var_old, y - chain_spacing - 0.5,
                                  facecolor='k', alpha=bands[v_idx])
            if np.any(plot_rhat):
                gr_plot.axhspan(var_old, y - chain_spacing - 0.5,
                                facecolor='k', alpha=bands[v_idx])
            var_old = y - chain_spacing - 0.5

    if ylabels is not None:
        labels = ylabels

    # Update margins: widen the left margin with the longest label
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis for forestplot and R-hat
    interval_plot.set_ylim(- var + 0.5, 0.5)
    if np.any(plot_rhat):
        gr_plot.set_ylim(- var + 0.5, 0.5)

    plotrange = [np.min(all_quants), np.max(all_quants)]
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([- l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    else:
        plot_title = ""

    interval_plot.set_title(plot_title,
                            fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. The per-tick `tick1On`/`tick2On` attributes
    # no longer exist in current matplotlib, where assigning them is a silent
    # no-op; tick_params hides the tick marks on both sides instead.
    interval_plot.tick_params(axis='y', which='major', left=False, right=False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    return gs
def curves_appendix(model_i=15, save_plot=False):
    """Plot reported and predicted case counts for 25 counties on a 5x5 grid.

    For every county listed in ``plot_county_names`` the panel shows the
    posterior mean of the prediction samples, the 25%-75% and 5%-95%
    quantile bands, the reported counts, and two vertical lines marking the
    5th- and 10th-from-last dates of the plotted range.

    Parameters
    ----------
    model_i : int
        Only used to build the name of the output file.
    save_plot : bool
        If True, the figure is written to
        ``../figures/curves_<disease>_appendix_<model_i>.pdf``.
    """
    with open('../data/counties/counties.pkl', "rb") as f:
        counties = pkl.load(f)

    # County name -> county id. (Duplicate entries of the original literal
    # were dropped; they mapped to identical ids.)
    countyByName = OrderedDict([('Düsseldorf', '05111'),
                                ('Recklinghausen', '05562'),
                                ("Hannover", "03241"), ("Hamburg", "02000"),
                                ("Berlin-Mitte", "11001"),
                                ("Osnabrück", "03404"),
                                ("Frankfurt (Main)", "06412"),
                                ("Görlitz", "14626"), ("Stuttgart", "08111"),
                                ("Potsdam", "12054"), ("Köln", "05315"),
                                ("Aachen", "05334"), ("Rostock", "13003"),
                                ("Flensburg", "01001"),
                                ("Frankfurt (Oder)", "12053"),
                                ("Lübeck", "01003"), ("Münster", "05515"),
                                ("Berlin Neukölln", "11008"),
                                ('Göttingen', "03159"), ("Cottbus", "12052"),
                                ("Erlangen", "09562"), ("Regensburg", "09362"),
                                ("Bayreuth", "09472"), ("Bautzen", "14625"),
                                ('Nürnberg', '09564'), ('München', '09162'),
                                ("Würzburg", "09679"), ("Deggendorf", "09271"),
                                ("Ansbach", "09571"), ("Rottal-Inn", "09277"),
                                ("Passau", "09275"), ("Schwabach", "09565"),
                                ("Memmingen", "09764"),
                                ("Erlangen-Höchstadt", "09572"),
                                ("Nürnberger Land", "09574"),
                                ('Roth', "09576"), ('Starnberg', "09188"),
                                ('Berchtesgadener Land', "09172"),
                                ('Schweinfurt', "09678"),
                                ("Augsburg", "09772"),
                                ('Neustadt a.d.Waldnaab', "09374"),
                                ("Fürstenfeldbruck", "09179"),
                                ('Rosenheim', "09187"), ("Straubing", "09263"),
                                ("Erding", "09177"),
                                ("Tirschenreuth", "09377"),
                                ('Miltenberg', "09676"),
                                ('Neumarkt i.d.OPf.', "09373"),
                                ('Heinsberg', "05370"),
                                ('Landsberg am Lech', "09181"),
                                ("Tübingen", "08416"),
                                ("Bielefeld", "05711")])

    plot_county_names = {
        "covid19": [
            "Düsseldorf", "Heinsberg", "Hannover", "München", "Hamburg",
            "Berlin-Mitte", "Osnabrück", "Frankfurt (Main)", "Görlitz",
            "Stuttgart", "Landsberg am Lech", "Köln", "Rottal-Inn", "Rostock",
            "Flensburg", "Frankfurt (Oder)", "Lübeck", "Münster",
            "Berlin Neukölln", "Göttingen", "Bielefeld", "Tübingen",
            "Augsburg", "Bayreuth", "Nürnberg"
        ]
    }

    # colors for curves
    C1 = "#D55E00"
    C2 = "#E69F00"

    disease = "covid19"
    prediction_region = "germany"

    data = load_daily_data(disease, prediction_region, counties)

    start_day = pd.Timestamp('2020-03-01')
    i_start_day = (start_day - data.index.min()).days
    day_0 = pd.Timestamp('2020-05-21')
    day_p5 = day_0 + pd.Timedelta(days=5)

    _, target, _, _ = split_data(data,
                                 train_start=start_day,
                                 test_start=day_0,
                                 post_test=day_p5)

    # Load our prediction samples
    res = load_final_pred()
    n_days = (day_p5 - start_day).days

    # NOTE(review): 412 is hard-coded as the size of the last axis —
    # presumably the number of counties; confirm against load_final_pred.
    prediction_samples = np.reshape(res['y'], (res['y'].shape[0], -1, 412))
    prediction_samples = prediction_samples[:, i_start_day:i_start_day +
                                            n_days, :]
    prediction_quantiles = quantiles(prediction_samples, (5, 25, 75, 95))

    # Index of the target dates extended up to (excluding) day_p5
    ext_index = pd.DatetimeIndex(
        list(target.index) +
        list(pd.date_range(target.index[-1] + timedelta(1),
                           day_p5 - timedelta(1))))

    def _as_frame(values):
        """Wrap a (date, county) array in a DataFrame on the extended index."""
        return pd.DataFrame(data=values,
                            index=ext_index,
                            columns=target.columns)

    prediction_mean = _as_frame(np.mean(prediction_samples, axis=0))
    prediction_q25 = _as_frame(prediction_quantiles[25])
    prediction_q75 = _as_frame(prediction_quantiles[75])
    prediction_q5 = _as_frame(prediction_quantiles[5])
    prediction_q95 = _as_frame(prediction_quantiles[95])

    fig = plt.figure(figsize=(12, 12))
    grid = plt.GridSpec(5,
                        5,
                        top=0.90,
                        bottom=0.11,
                        left=0.07,
                        right=0.92,
                        hspace=0.2,
                        wspace=0.3)

    # x values and tick layout are the same for every panel
    dates = [pd.Timestamp(day) for day in ext_index]
    # NOTE(review): '2020-03-12' is 11 days after Mar. 01 but is labelled
    # '10' (and so on for the later ticks) — confirm the intended positions.
    ticks = [
        '2020-03-01', '2020-03-12', '2020-03-22', '2020-04-01',
        '2020-04-11', '2020-04-21', '2020-05-1', '2020-05-11', '2020-05-21'
    ]
    labels = ['0', '10', '20', '30', '40', '50', '60', '70', '80']

    for j, name in enumerate(plot_county_names[disease]):
        # TODO: this should be encapsulated as plot_curve(county) (, days)

        ax = fig.add_subplot(grid[np.unravel_index(j, (5, 5))])

        county_id = countyByName[name]

        # plot our predictions w/ quartiles
        p_pred = ax.plot(dates,
                         prediction_mean[county_id],
                         "-",
                         color=C1,
                         linewidth=2.0,
                         zorder=4)
        p_quant = ax.fill_between(dates,
                                  prediction_q25[county_id],
                                  prediction_q75[county_id],
                                  facecolor=C2,
                                  alpha=0.5,
                                  zorder=1)
        ax.plot(dates,
                prediction_q25[county_id],
                ":",
                color=C2,
                linewidth=2.0,
                zorder=3)
        ax.plot(dates,
                prediction_q75[county_id],
                ":",
                color=C2,
                linewidth=2.0,
                zorder=3)

        # The last 5 dates of the extended index have no reported values
        p_real = ax.plot(dates[:-5], target[county_id], "k.")

        ax.set_title(name, fontsize=18)
        plt.xticks(ticks, labels)
        ax.tick_params(axis="both", direction='out', size=2, labelsize=14)
        plt.setp(ax.get_xticklabels(), visible=False)
        if j > 19:
            # Only the bottom row shows (every other) x tick label
            plt.setp(ax.get_xticklabels(), rotation=60)
            plt.setp(ax.get_xticklabels()[::2], visible=True)

        # Freeze the view first so the wide 5%-95% band cannot stretch the
        # axes. The band must use the same date x-values as the other
        # curves; with integer day numbers it falls outside the frozen
        # view and is never visible.
        ax.autoscale(False)
        p_quant2 = ax.fill_between(dates,
                                   prediction_q5[county_id],
                                   prediction_q95[county_id],
                                   facecolor=C2,
                                   alpha=0.25,
                                   zorder=0)
        ax.plot(dates,
                prediction_q5[county_id],
                ":",
                color=C2,
                alpha=0.5,
                linewidth=2.0,
                zorder=1)
        ax.plot(dates,
                prediction_q95[county_id],
                ":",
                color=C2,
                alpha=0.5,
                linewidth=2.0,
                zorder=1)

        # Plot blue line for indicating where predictions start.
        ax.axvline(dates[-5], ls='-', lw=2, c='cornflowerblue')
        ax.axvline(dates[-10], ls='--', lw=2, c='cornflowerblue')

    # Raw strings keep the literal backslashes without relying on the
    # invalid escape sequence "\%".
    plt.legend(
        [p_real[0], p_pred[0], p_quant, p_quant2],
        ["reported", "predicted", r"25\%-75\% quantile",
         r"5\%-95\% quantile"],
        fontsize=16,
        ncol=5,
        loc="upper center",
        bbox_to_anchor=(0, -0.01, 1, 1),
        bbox_transform=plt.gcf().transFigure)
    fig.text(0.5, 0.02, "Time [days since Mar. 01]", ha='center', fontsize=22)
    fig.text(0.01,
             0.46,
             "Reported/predicted infections",
             va='center',
             rotation='vertical',
             fontsize=22)

    if save_plot:
        plt.savefig("../figures/curves_{}_appendix_{}.pdf".format(
            disease, model_i))
def temporal_contribution(model_i=15,
                          combinations=combinations,
                          save_plot=False):
    """Plot the temporal trend and periodic contributions for one county.

    Three stacked panels are drawn from the posterior samples of the
    temporal weights (``W_t_t`` and ``W_t_s``), all on an exponential scale:
    the combined contribution with 25%/75% and 2.5%/97.5% quantiles, the mean
    trend contribution, and the mean periodic contribution. Features are
    taken for county id "09162".

    Parameters
    ----------
    model_i : int
        Index into `combinations`; also used in the output file name.
    combinations : sequence
        Each entry unpacks into five values: use_ia, use_report_delay,
        use_demographics, trend_order, periodic_order.
    save_plot : bool
        If True, the figure is written to
        ``../figures/temporal_contribution_<model_i>.pdf``.
    """
    # Unpack all five entries (this keeps the length check of the tuple
    # unpacking); only the first two are used below.
    (use_ia, use_report_delay, _use_demographics, _trend_order,
     _periodic_order) = combinations[model_i]

    plt.style.use('ggplot')

    with open('../data/counties/counties.pkl', "rb") as f:
        county_info = pkl.load(f)

    C1 = "#D55E00"
    C2 = "#E69F00"

    # With report delay the layout reserves a fourth row and a wider figure;
    # only three panels are drawn in either case.
    if use_report_delay:
        figsize, n_rows = (25, 10), 4
    else:
        figsize, n_rows = (16, 10), 3
    fig = plt.figure(figsize=figsize)
    grid = plt.GridSpec(n_rows,
                        1,
                        top=0.93,
                        bottom=0.12,
                        left=0.11,
                        right=0.97,
                        hspace=0.28,
                        wspace=0.30)

    disease = "covid19"
    prediction_region = "germany"
    county_id = "09162"

    data = load_daily_data(disease, prediction_region, county_info)
    first_day = pd.Timestamp('2020-04-01')
    last_day = data.index.max()

    _, target_train, _, _ = split_data(
        data,
        train_start=first_day,
        test_start=last_day - pd.Timedelta(days=1),
        post_test=last_day + pd.Timedelta(days=1))

    tspan = (target_train.index[0], target_train.index[-1])

    # NOTE(review): the demographics flag and the polynomial orders from
    # `combinations` are not passed on; the model is always built with
    # demographics enabled and order 4 — confirm this is intended.
    model = BaseModel(
        tspan,
        county_info, [
            "../data/ia_effect_samples/{}_{}.pkl".format(disease, i)
            for i in range(100)
        ],
        include_ia=use_ia,
        include_report_delay=use_report_delay,
        include_demographics=True,
        trend_poly_order=4,
        periodic_poly_order=4)

    features = model.evaluate_features(target_train.index,
                                       target_train.columns)

    # Select the rows of the chosen county from the feature tables
    trend_features = features["temporal_trend"].swaplevel(0, 1).loc[county_id]
    periodic_features = features["temporal_seasonal"].swaplevel(
        0, 1).loc[county_id]

    trace = load_final_trace()
    trend_params = pm.trace_to_dataframe(trace, varnames=["W_t_t"])
    periodic_params = pm.trace_to_dataframe(trace, varnames=["W_t_s"])

    # One row per posterior sample, one column per day
    TT = trend_params.values.dot(trend_features.values.T)
    TP = periodic_params.values.dot(periodic_features.values.T)
    TTP = TT + TP

    TTP_quantiles = quantiles(TTP, (2.5, 25, 75, 97.5))

    dates = [pd.Timestamp(day) for day in target_train.index.values]
    days = [(day - min(dates)).days for day in dates]

    # Temporal trend+periodic effect
    ax_tp = fig.add_subplot(grid[0, 0])

    ax_tp.fill_between(days,
                       np.exp(TTP_quantiles[25]),
                       np.exp(TTP_quantiles[75]),
                       alpha=0.5,
                       zorder=1,
                       facecolor=C1)
    ax_tp.plot(days, np.exp(TTP.mean(axis=0)), "-", color=C1, lw=2, zorder=5)
    ax_tp.plot(days, np.exp(TTP_quantiles[25]), "-", color=C2, lw=2, zorder=3)
    ax_tp.plot(days, np.exp(TTP_quantiles[75]), "-", color=C2, lw=2, zorder=3)
    ax_tp.plot(days,
               np.exp(TTP_quantiles[2.5]),
               "--",
               color=C2,
               lw=2,
               zorder=3)
    ax_tp.plot(days,
               np.exp(TTP_quantiles[97.5]),
               "--",
               color=C2,
               lw=2,
               zorder=3)

    ax_tp.tick_params(axis="x", rotation=45)

    # Temporal trend effect (mean only)
    ax_t = fig.add_subplot(grid[1, 0], sharex=ax_tp)
    ax_t.plot(days, np.exp(TT.mean(axis=0)), "-", color=C1, lw=2, zorder=5)
    ax_t.tick_params(axis="x", rotation=45)
    ax_t.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    # Temporal periodic effect (mean only)
    ax_p = fig.add_subplot(grid[2, 0], sharex=ax_tp)
    ax_p.plot(days, np.exp(TP.mean(axis=0)), "-", color=C1, lw=2, zorder=5)
    ax_p.set_ylim([-0.0001, 0.001])
    ax_p.tick_params(axis="x", rotation=45)
    ax_p.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    ax_p.set_xlabel("time [days]", fontsize=22)

    ax_p.set_ylabel("periodic\ncontribution", fontsize=22)
    ax_t.set_ylabel("trend\ncontribution", fontsize=22)
    ax_tp.set_ylabel("combined\ncontribution", fontsize=22)

    ax_t.set_xlim(days[0], days[-1])
    # Only the bottom panel shows x tick labels
    ax_t.tick_params(labelbottom=False, labelleft=True, labelsize=18, length=6)
    ax_p.tick_params(labelbottom=True, labelleft=True, labelsize=18, length=6)
    ax_tp.tick_params(labelbottom=False,
                      labelleft=True,
                      labelsize=18,
                      length=6)

    if save_plot:
        fig.savefig("../figures/temporal_contribution_{}.pdf".format(model_i))