Ejemplo n.º 1
0
def plot_parameter_traces(param_values_by_chain, max_n_iter=2500):
    """Draw a grid of trace plots: one row per parameter, one column per region.

    The top row holds the region titles (``COUNTRY_TITLES``) and the left
    column holds the parameter display names (``param_info[...]["name"]``).
    Parameter names are taken from ``param_values_by_chain["belgium"]``; if a
    region has no values for a parameter, that panel is left empty. Each trace
    is drawn for at most ``max_n_iter`` leading iterations.

    The y-range of a panel is the parameter's prior bounds when its prior is
    uniform, otherwise ``param_info[param_name]["range"]``, padded by 10% of
    its width on each side. The figure is written to
    ``figures/param_traces.pdf`` and then closed.

    Args:
        param_values_by_chain: mapping region -> {param_name: sequence of
            sampled values}. Presumably one sampler trace per parameter;
            confirm against the caller.
        max_n_iter: maximum number of leading iterations plotted per trace.
    """
    param_names = list(param_values_by_chain["belgium"].keys())
    n_params = len(param_names)
    n_regions = len(OPTI_REGIONS)

    panel_w, panel_h = 6, 2
    title_w, title_h = 4, 2
    # Layout is done by tight_layout() at the end. The former
    # constrained_layout=True was overridden by that call anyway (the two
    # layout mechanisms are incompatible), so it is dropped: same result,
    # no conflict.
    fig = plt.figure(
        figsize=(title_w + panel_w * n_regions, title_h + panel_h * n_params),
    )
    spec = fig.add_gridspec(
        ncols=n_regions + 1,
        nrows=n_params + 1,
        width_ratios=[title_w] + [panel_w] * n_regions,
        height_ratios=[title_h] + [panel_h] * n_params,
    )

    # Index priors by name once. setdefault keeps the first match, which is
    # the same precedence as the previous per-panel linear search.
    priors_by_name = {}
    for prior in get_prior_distributions_for_opti():
        priors_by_name.setdefault(prior["param_name"], prior)

    for i_country, country in enumerate(OPTI_REGIONS):
        title_ax = fig.add_subplot(spec[0, i_country + 1])
        title_ax.text(
            0.5,
            0.2,
            COUNTRY_TITLES[country],
            fontsize=23,
            horizontalalignment="center",
            verticalalignment="center",
        )
        title_ax.axis("off")

        country_values = param_values_by_chain[country]
        for i_param, param_name in enumerate(param_names):
            if param_name in country_values:
                ax = fig.add_subplot(spec[i_param + 1, i_country + 1])

                # Only the bottom row keeps its iteration (x) axis.
                if i_param < n_params - 1:
                    ax.get_xaxis().set_visible(False)

                param_values = country_values[param_name]
                n_iterations = min(max_n_iter, len(param_values))
                ax.plot(
                    range(n_iterations),
                    param_values[:n_iterations],
                    "-",
                    color="royalblue",
                )
                ax.grid(False, axis="x")

                # Default range from param_info; a uniform prior's bounds
                # take precedence when one exists.
                y_range = param_info[param_name]["range"]
                prior = priors_by_name.get(param_name)
                if prior is not None and prior["distribution"] == "uniform":
                    y_range = prior["distri_params"]
                buffer = 0.1 * (y_range[1] - y_range[0])
                ax.set_ylim([y_range[0] - buffer, y_range[1] + buffer])

            # Row labels are drawn once, while handling the first region.
            if i_country == 0:
                label_ax = fig.add_subplot(spec[i_param + 1, 0])
                label_ax.text(
                    0.5,
                    0.2,
                    param_info[param_name]["name"],
                    fontsize=20,
                    horizontalalignment="center",
                    verticalalignment="center",
                )
                label_ax.axis("off")

    fig.tight_layout()
    fig.savefig("figures/param_traces.pdf")
    # Release the figure so repeated calls don't accumulate open pyplot figures.
    plt.close(fig)
Ejemplo n.º 2
0
from apps.covid_19.calibration import base
from autumn.constants import Region
from apps.covid_19.mixing_optimisation.utils import (
    get_prior_distributions_for_opti,
    get_target_outputs_for_opti,
    get_weekly_summed_targets,
    add_dispersion_param_prior_for_gaussian,
)
from autumn.tool_kit.utils import print_target_to_plots_from_calibration

# Calibration set-up for Italy.
country = Region.ITALY

PAR_PRIORS = get_prior_distributions_for_opti()

# Override the contact_rate prior bounds for this region. The prior dicts are
# modified in place.
# NOTE(review): if get_prior_distributions_for_opti() returns shared/cached
# dicts, this override would leak to other callers -- confirm it returns
# fresh objects.
for par in PAR_PRIORS:
    if par["param_name"] == "contact_rate":
        par["distri_params"] = [0.025, 0.06]

# WHO-sourced targets restricted to the [61, 182] time window (presumably
# model time in days -- confirm units in get_target_outputs_for_opti).
TARGET_OUTPUTS = get_target_outputs_for_opti(
    country,
    source="who",
    data_start_time=61,
    data_end_time=182,
)

# Use weekly counts: replace each target's series with its weekly sums.
for target in TARGET_OUTPUTS:
    target["years"], target["values"] = get_weekly_summed_targets(
        target["years"], target["values"]
    )

MULTIPLIERS = {}

PAR_PRIORS = add_dispersion_param_prior_for_gaussian(PAR_PRIORS,
Ejemplo n.º 3
0
def make_posterior_ranges_figure(param_values):
    """Plot posterior summary intervals, one panel per parameter.

    Parameter names come from ``param_values["belgium"]``. In each panel,
    every region that has values for the parameter gets a horizontal row,
    with ``OPTI_REGIONS[0]`` at the top. Each row shows:
      * a thin black line spanning the 2.5%-97.5% quantiles,
      * a thick steel-blue line spanning the 25%-75% quantiles,
      * a crimson dot at the mean.

    The x-range is the parameter's prior bounds when its prior is uniform,
    otherwise ``param_info[param_name]["range"]``, padded by 10% of its width
    on each side. Panels are laid out 4 per row; unused cells in the last row
    are hidden. The figure is written to ``figures/param_posteriors.pdf`` and
    then closed.

    Args:
        param_values: mapping region -> {param_name: sequence of sampled
            values}.

    Raises:
        ValueError: if ``param_values["belgium"]`` has no parameters.
    """
    param_names = list(param_values["belgium"].keys())
    n_panels = len(param_names)
    if n_panels == 0:
        raise ValueError("param_values['belgium'] contains no parameters")

    country_list = OPTI_REGIONS[::-1]
    n_countries = len(country_list)
    n_col = 4
    n_row = -(-n_panels // n_col)  # ceiling division

    # squeeze=False keeps `axs` 2-D even for a single row, so that
    # axs[i_row, i_col] works when there are <= n_col parameters.
    fig, axs = plt.subplots(n_row, n_col, figsize=(11, 12), squeeze=False)
    # Normally superseded by tight_layout() below; kept as the fallback spacing.
    fig.subplots_adjust(hspace=0.4)

    # Index priors by name once (first match wins, as in a linear search).
    priors_by_name = {}
    for prior in get_prior_distributions_for_opti():
        priors_by_name.setdefault(prior["param_name"], prior)

    for i_panel, param_name in enumerate(param_names):
        i_row, i_col = divmod(i_panel, n_col)
        ax = axs[i_row, i_col]

        # Row h (1-based) belongs to country_list[h - 1]; since country_list
        # is OPTI_REGIONS reversed, OPTI_REGIONS[0] ends up on the top row.
        for h, country in enumerate(country_list, start=1):
            if param_name not in param_values[country]:
                continue
            values = param_values[country][param_name]
            point_estimate = mean(values)
            low_95, low_50, _median, up_50, up_95 = quantile(
                values, q=(0.025, 0.25, 0.5, 0.75, 0.975)
            )
            ax.plot([low_95, up_95], [h, h], linewidth=1, color="black")
            ax.plot([low_50, up_50], [h, h], linewidth=3, color="steelblue")
            ax.plot([point_estimate], [h],
                    marker="o",
                    color="crimson",
                    markersize=5)

        # NOTE(review): a single marker-less point draws nothing visible, and
        # both axis limits are set explicitly below; kept from the original.
        ax.plot([0], [0])
        ax.set_ylim((0.5, n_countries + 0.5))

        ax.set_title(param_info[param_name]["name"], fontsize=10.5)

        x_range = param_info[param_name]["range"]
        prior = priors_by_name.get(param_name)
        if prior is not None and prior["distribution"] == "uniform":
            x_range = prior["distri_params"]
        buffer = 0.1 * (x_range[1] - x_range[0])
        ax.set_xlim([x_range[0] - buffer, x_range[1] + buffer])

        # Format x-ticks if requested
        if "xticks" in param_info[param_name]:
            ax.set_xticks(param_info[param_name]["xticks"])
            ax.set_xticklabels(param_info[param_name]["xlabels"])

        if "multiplier" in param_info[param_name]:
            ax.set_xlabel(param_info[param_name]["multiplier"], labelpad=-7.5)

        # Region names only on the first column; other columns show no y-ticks.
        if i_col == 0:
            ax.set_yticks(list(range(1, n_countries + 1)))
            ax.set_yticklabels([
                c.title().replace("United-Kingdom", "UK") for c in country_list
            ])
        else:
            ax.set_yticks([])

        ax.grid(False, axis="y")

    # Hide the unused cells after the last panel (all in the final row).
    for i_blank in range(n_panels, n_row * n_col):
        blank_row, blank_col = divmod(i_blank, n_col)
        axs[blank_row, blank_col].axis("off")

    fig.tight_layout()
    fig.savefig("figures/param_posteriors.pdf")
    # Release the figure so repeated calls don't accumulate open pyplot figures.
    plt.close(fig)