Example #1
0
def plot_parallel_coordinates(
    plotter: Plotter,
    mcmc_tables: List[pd.DataFrame],
    mcmc_params: List[pd.DataFrame],
):
    """
    Show an interactive parallel-coordinates figure of sampled parameters.

    Parameter names are read from the "name" column of the first entry of
    mcmc_params, leaving out any whose name contains "dispersion_param".
    Lines are coloured by the "fitness" column of the merged table and the
    figure is displayed with figure.show(); the plotter argument is not used.
    """
    all_names = mcmc_params[0].loc[:, "name"].unique().tolist()
    parameters = [name for name in all_names if "dispersion_param" not in name]

    # Share a budget of about 500 lines evenly between the entries of mcmc_tables
    target_n_lines = 500.0
    per_chain = int(target_n_lines / len(mcmc_tables))
    combined_mcmc_df = merge_and_pivot_mcmc_parameters_loglike(
        mcmc_tables,
        mcmc_params,
        parameters,
        n_samples_per_chain=per_chain,
    )

    axis_labels = {name: get_plot_text_dict(name) for name in parameters}
    figure = px.parallel_coordinates(
        combined_mcmc_df,
        color="fitness",
        dimensions=parameters,
        labels=axis_labels,
        color_continuous_scale=px.colors.diverging.Tealrose,
        height=800,
        # 200 pixels of width per plotted parameter
        width=len(parameters) * 200,
    )
    figure.show()
Example #2
0
def plot_mixing_matrix_2(plotter: Plotter, iso3: str):
    """
    Plot the mixing matrices of one country, one location per panel.

    Six panels are requested from the plotter and addressed as
    axes[row, column]. Five of them show the matrix returned by
    get_country_mixing_matrix for a location; the remaining one is switched off.
    """
    fig, axes, _, n_rows, n_cols, _ = plotter.get_figure(n_panels=6)
    fig.tight_layout()

    # (row, column) of each panel; "none" marks the cell that is left blank
    grid_cells = {
        "all_locations": (0, 0),
        "home": (0, 1),
        "work": (0, 2),
        "other_locations": (1, 1),
        "school": (1, 2),
        "none": (1, 0),
    }
    tick_positions = [5, 25, 45, 65]

    for location, (row, col) in grid_cells.items():
        axis = axes[row, col]
        if location == "none":
            axis.axis("off")
            continue

        matrix = get_country_mixing_matrix(location, iso3)
        image = axis.imshow(matrix, cmap="hot", interpolation="none", extent=[0, 80, 80, 0])
        axis.set_title(get_plot_text_dict(location), fontsize=12)
        axis.set_xticks(tick_positions)
        axis.set_yticks(tick_positions)
        # Put the x tick marks and labels along the top edge of the panel
        axis.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
        axis.figure.colorbar(image, ax=axis, cmap="hot")

    plotter.save_figure(fig, filename="mixing-matrix", title_text="Mixing matrix")
Example #3
0
def plot_time_varying_multi_input(
    plotter: Plotter,
    tv_key: str,
    times: List[float],
    is_logscale: bool,
):
    """
    Plot one or more time-varying series against time on a single panel.

    tv_key is handed directly to pd.DataFrame and every resulting column is
    drawn as its own line, indexed by times.
    NOTE(review): the annotation says str, but the value has to be something
    pd.DataFrame can build columns from (presumably a dict of series keyed by
    location) - confirm against the caller.

    The x-axis is limited to the module-level X_MIN and X_MAX when both are
    set, and the figure is saved under the fixed "Google mobility" name.
    """
    fig, axes, max_dims, n_rows, n_cols, _ = plotter.get_figure()
    if is_logscale:
        axes.set_yscale("log")

    df = pd.DataFrame(tv_key)
    df.index = times

    axes.plot(df.index, df.values)
    change_xaxis_to_date(axes, REF_DATE)
    # Labels are passed by keyword only. Mixing a positional argument with
    # labels= makes matplotlib warn and discard the positional one.
    pyplot.legend(
        loc="best", labels=[get_plot_text_dict(location) for location in df.columns]
    )
    if X_MIN is not None and X_MAX is not None:
        axes.set_xlim((X_MIN, X_MAX))
    # A lower limit of zero cannot be shown on a log axis, so only apply it
    # on a linear scale
    if not is_logscale:
        axes.set_ylim(bottom=0.0)

    plotter.save_figure(
        fig, filename="time-variant-Google mobility", title_text="Google mobility"
    )
Example #4
0
def plot_param_matrix(
    plotter: StreamlitPlotter,
    mcmc_params: List[pd.DataFrame],
    parameters: List,
    label_param_string=False,
    show_ticks=False,
    file_name="",
):
    """
    Draw the parameter-vs-parameter matrix and show a key to the parameters.

    The matrix itself is produced by plots.calibration.plots.plot_param_vs_param
    using fixed settings. Afterwards the display names of the parameters are
    numbered from one, offered as a downloadable CSV, and written to the
    Streamlit page both as a single line of text and as a table.
    """
    plots.calibration.plots.plot_param_vs_param(
        plotter,
        mcmc_params,
        parameters,
        BURN_INS,  # burn_in
        "Shade",  # style
        20,  # bins
        8,  # label_font_size
        2,  # label_chars
        300,  # dpi_request
        label_param_string=label_param_string,
        show_ticks=show_ticks,
        file_name=file_name,
    )

    display_names = [get_plot_text_dict(name) for name in parameters]
    params_df = pd.DataFrame({"names": display_names})
    params_df["numbers"] = range(1, len(params_df) + 1)
    create_downloadable_csv(params_df, "parameter_indices")

    # One "<number>, <name>; " entry per parameter, in the order given
    key_entries = [
        str(number) + ", " + name + "; "
        for number, name in enumerate(display_names, start=1)
    ]
    st.write("".join(key_entries))
    st.dataframe(params_df)
Example #5
0
def plot_multiple_posteriors(
    plotter: Plotter,
    mcmc_params: List[pd.DataFrame],
    mcmc_tables: List[pd.DataFrame],
    burn_in: int,
    num_bins: int,
    title_font_size: int,
    label_font_size: int,
    capitalise_first_letter: bool,
    dpi_request: int,
    priors: list,
    parameters: list,
    file_name="all_posteriors",
):
    """
    Plot a posterior histogram for each requested parameter, one per panel.

    For each entry of parameters, the values returned by get_posterior are
    drawn as a density histogram. If priors contains an entry whose
    "param_name" equals the parameter, the prior density is overlaid as a
    line; parameters without a matching prior get the histogram only.
    Panels beyond the number of parameters are switched off.
    """
    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(len(parameters))

    for i_panel in range(n_rows * n_cols):
        axis = axes[indices[i_panel][0], indices[i_panel][1]]

        if i_panel >= len(parameters):
            axis.axis("off")
            continue

        param_name = parameters[i_panel]
        vals_df = get_posterior(mcmc_params, mcmc_tables, param_name, burn_in)

        # Plot histograms
        vals_df.hist(bins=num_bins, ax=axis, density=True)

        # Plot the prior, but only when one is actually defined for this
        # parameter - never fall back to some other parameter's prior
        prior = next(
            (candidate for candidate in priors if candidate["param_name"] == param_name),
            None,
        )
        if prior is not None:
            x_range = workout_plot_x_range(prior)
            x_values = np.linspace(x_range[0], x_range[1], num=1000)
            y_values = [calculate_prior(prior, x, log=False) for x in x_values]
            axis.plot(x_values, y_values)

        axis.set_title(
            get_plot_text_dict(
                param_name,
                capitalise_first_letter=capitalise_first_letter),
            fontsize=title_font_size,
        )
        pyplot.setp(axis.get_yticklabels(), fontsize=label_font_size)
        pyplot.setp(axis.get_xticklabels(), fontsize=label_font_size)

    fig.tight_layout()
    plotter.save_figure(fig, filename=file_name, dpi_request=dpi_request)
Example #6
0
def plot_multiple_param_traces(
    plotter: Plotter,
    mcmc_params: List[pd.DataFrame],
    burn_in: int,
    title_font_size: int,
    label_font_size: int,
    capitalise_first_letter: bool,
    dpi_request: int,
    optional_param_request=None,
    file_name="all_traces",
    x_ticks_on=True,
):
    """
    Plot the trace of each parameter on its own panel, one line per chain.

    Only the first entry of mcmc_params is read. The parameters plotted are
    optional_param_request when it is given and non-empty, otherwise every
    name in that table that does not contain "dispersion_param". A dotted
    vertical line marks burn_in when it is positive, and panels beyond the
    number of parameters are switched off.
    """
    first_table = mcmc_params[0]

    # Except not the dispersion parameters - only the epidemiological ones
    parameters = [
        name for name in first_table.loc[:, "name"].unique().tolist()
        if "dispersion_param" not in name
    ]
    params_to_plot = optional_param_request if optional_param_request else parameters

    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(
        len(params_to_plot), share_xaxis=True, share_yaxis=False)

    x_label = "Iterations" if capitalise_first_letter else "iterations"

    for i_panel in range(n_rows * n_cols):
        row, col = indices[i_panel][0], indices[i_panel][1]
        axis = axes[row, col]

        if i_panel >= len(params_to_plot):
            axis.axis("off")
            continue

        param_name = params_to_plot[i_panel]
        is_this_param = first_table["name"] == param_name

        # NOTE(review): the loop stops one short of the chain id found in the
        # last row - if chain ids start at zero the final chain is never
        # drawn. Confirm the intended numbering.
        for i_chain in range(first_table["chain"].iloc[-1]):
            in_chain = first_table["chain"] == i_chain
            chain_rows = first_table[in_chain & is_this_param].values
            # Column 3 of the table is taken as the sampled value
            # (presumably the "value" column - verify the column order)
            axis.plot(chain_rows[:, 3], alpha=0.8, linewidth=0.7)

        axis.set_title(
            get_plot_text_dict(
                param_name,
                capitalise_first_letter=capitalise_first_letter),
            fontsize=title_font_size,
        )

        if not x_ticks_on:
            axis.set_xticks([])
        elif row == n_rows - 1:
            # Only the bottom row of panels carries the x-axis label
            axis.set_xlabel(x_label, fontsize=label_font_size)
        pyplot.setp(axis.get_yticklabels(), fontsize=label_font_size)
        pyplot.setp(axis.get_xticklabels(), fontsize=label_font_size)

        if burn_in > 0:
            axis.axvline(x=burn_in, color=COLOR_THEME[1], linestyle="dotted")

    fig.tight_layout()
    plotter.save_figure(fig, filename=file_name, dpi_request=dpi_request)
Example #7
0
def plot_all_params_vs_loglike(
    plotter: Plotter,
    mcmc_tables: List[pd.DataFrame],
    mcmc_params: List[pd.DataFrame],
    burn_in: int,
    title_font_size: int,
    label_font_size: int,
    capitalise_first_letter: bool,
    dpi_request: int,
):
    """
    Draw one plot_param_vs_loglike panel per parameter on a shared y-axis.

    Parameter names come from the "name" column of the first entry of
    mcmc_params, excluding any that contain "dispersion_param". Panels
    beyond the number of parameters are switched off.
    """
    # Except not the dispersion parameters - only the epidemiological ones
    all_names = mcmc_params[0].loc[:, "name"].unique().tolist()
    parameters = [name for name in all_names if "dispersion_param" not in name]

    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(
        len(parameters), share_xaxis=False, share_yaxis=True)

    # NOTE(review): the x-axis of these panels is labelled "iterations",
    # which looks carried over from the trace plots - confirm the intent.
    x_label = "Iterations" if capitalise_first_letter else "iterations"

    for i_panel in range(n_rows * n_cols):
        row, col = indices[i_panel][0], indices[i_panel][1]
        axis = axes[row, col]

        if i_panel >= len(parameters):
            axis.axis("off")
            continue

        param_name = parameters[i_panel]
        plot_param_vs_loglike(mcmc_tables, mcmc_params, param_name, burn_in, axis)
        panel_title = get_plot_text_dict(
            param_name, capitalise_first_letter=capitalise_first_letter)
        axis.set_title(panel_title, fontsize=title_font_size)

        # Only the bottom row of panels carries the x-axis label
        if row == n_rows - 1:
            axis.set_xlabel(x_label, fontsize=label_font_size)
        pyplot.setp(axis.get_yticklabels(), fontsize=label_font_size)
        pyplot.setp(axis.get_xticklabels(), fontsize=label_font_size)

    fig.tight_layout()
    # NOTE(review): this file name is the same as the default used by
    # plot_multiple_posteriors - check that one does not overwrite the other.
    plotter.save_figure(fig, filename="all_posteriors", dpi_request=dpi_request)
Example #8
0
def plot_multi_cdr_curves(plotter: Plotter, times, detected_proportions,
                          start_date, end_date, rotation, regions):
    """
    Draw the CDR curves of several regions, one region per panel.

    detected_proportions is indexed in step with regions: entry i is plotted
    on the panel titled with the display name of regions[i]. Panels beyond
    the number of regions are switched off.
    """
    n_regions = len(regions)
    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(n_panels=n_regions)

    for i_panel in range(n_rows * n_cols):
        row, col = indices[i_panel][0], indices[i_panel][1]
        axis = axes[row, col]

        if i_panel >= n_regions:
            axis.axis("off")
            continue

        # Continue with whatever axis the helper hands back
        axis = plot_cdr_to_axis(axis, times, detected_proportions[i_panel])
        tidy_cdr_axis(axis, rotation, start_date, end_date)
        axis.set_title(get_plot_text_dict(regions[i_panel]))

    fig.tight_layout()
    plotter.save_figure(fig, filename="multi_cdr_curves")
Example #9
0
def plot_multi_fit(
    plotter: Plotter,
    output_names: list,
    outputs: dict,
    targets,
    is_logscale=False,
    title_font_size=8,
    label_font_size=8,
    dpi_request=300,
    capitalise_first_letter=False,
):
    """
    Plot the calibration fit of several outputs, one output per panel.

    Each name in output_names is drawn with plot_calibration from
    outputs[name] and targets, on an x-axis converted to dates. Panels beyond
    the number of outputs are switched off.

    The figure is saved as "calibration-fit-<name>" using the last output
    that was plotted, or as "calibration-fit" when nothing was plotted.
    label_font_size is accepted but not used here.
    """
    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(
        len(output_names), share_xaxis=True)

    # Bound up front so saving cannot fail with NameError when output_names
    # is empty and the loop never assigns a name
    filename = "calibration-fit"

    for i_output in range(n_rows * n_cols):
        panel = axes[indices[i_output][0], indices[i_output][1]]
        if i_output >= len(output_names):
            panel.axis("off")
            continue

        output = output_names[i_output]
        axis = plot_calibration(
            panel,
            output,
            outputs[output],
            targets,
            is_logscale,
        )
        change_xaxis_to_date(axis, REF_DATE, rotation=0)
        axis.set_title(
            get_plot_text_dict(
                output, capitalise_first_letter=capitalise_first_letter),
            fontsize=title_font_size,
        )
        # Kept for backward compatibility: the file is named after the last
        # output plotted
        filename = f"calibration-fit-{output}"

    fig.tight_layout()
    plotter.save_figure(fig, filename=filename, dpi_request=dpi_request)
Example #10
0
def plot_vic_seroprevalences(
    plotter: Plotter,
    uncertainty_df: pd.DataFrame,
    scenario_id: int,
    time: float,
    ref_date=REF_DATE,
    name="",
    requested_quantiles=None,
    credible_range=50,
):
    """
    Plot seroprevalence by cluster and by age on two panels sharing a y-axis.

    The left panel shows, for every output whose type contains
    "proportion_seropositiveXcluster_", a vertical bar spanning the credible
    range (as percentages) at the requested scenario and time, plus a point
    at the middle quantile when the number of quantiles is odd. If there are
    no such outputs a message is written on the panel instead. The right
    panel is drawn by plot_age_seroprev_to_axis.

    Returns a tuple of:
      - the max_value returned by plot_age_seroprev_to_axis,
      - a dict of cluster -> quantile -> list of percentage values,
      - a DataFrame of "value" indexed by "quantile" for the
        "proportion_seropositive" rows of the table returned by
        plot_age_seroprev_to_axis.
    """
    cluster_prefix = "proportion_seropositiveXcluster_"

    fig, axes, _, _, _, _ = plotter.get_figure(n_panels=2, share_yaxis="all")
    cluster_axis, age_axis = axes
    mask = (uncertainty_df["scenario"]
            == scenario_id) & (uncertainty_df["time"] == time)
    df = uncertainty_df[mask]
    quantile_vals = df["quantile"].unique().tolist()
    seroprevalence_by_cluster = {}
    sero_outputs = [
        output for output in df["type"].unique().tolist()
        if cluster_prefix in output
    ]

    max_value = -10.0
    if not sero_outputs:
        cluster_axis.text(
            0.0, 0.5,
            "Cluster-specific seroprevalence outputs are not available for this run"
        )
    else:
        for output in sero_outputs:
            output_mask = df["type"] == output
            cluster = output.split(cluster_prefix)[1]
            seroprevalence_by_cluster[cluster] = {}
            for q in quantile_vals:
                # Combine both conditions into one mask. Indexing twice makes
                # pandas reindex the second mask and emit a warning.
                selected = df[output_mask & (df["quantile"] == q)]
                seroprevalence_by_cluster[cluster][q] = [
                    100.0 * v for v in selected["value"].tolist()
                ]
        q_keys = requested_quantiles if requested_quantiles else sorted(
            quantile_vals)
        num_quantiles = len(q_keys)
        half_length = num_quantiles // 2

        cluster_names = [
            get_plot_text_dict(i.split(cluster_prefix)[1])
            for i in sero_outputs
        ]

        # Quantile keys bounding the central credible_range percent
        lower_q_key = (100.0 - credible_range) / 100.0 / 2.0
        upper_q_key = 1.0 - lower_q_key

        x_positions = range(len(seroprevalence_by_cluster))

        for i, cluster in enumerate(seroprevalence_by_cluster):
            cluster_quantiles = seroprevalence_by_cluster[cluster]
            cluster_axis.plot(
                [x_positions[i], x_positions[i]],
                [
                    cluster_quantiles[lower_q_key],
                    cluster_quantiles[upper_q_key],
                ],
                "-",
                color="black",
                lw=1.0,
            )
            # NOTE(review): this running maximum is replaced further down by
            # the value returned from plot_age_seroprev_to_axis - confirm
            # which of the two the caller is meant to receive.
            max_value = max(max_value, cluster_quantiles[upper_q_key][0])

            if num_quantiles % 2:
                # An odd number of quantiles has a middle one; draw it as a point
                q_key = q_keys[half_length]
                label = None if i > 0 else "model"
                cluster_axis.plot(
                    x_positions[i],
                    cluster_quantiles[q_key],
                    "o",
                    color="black",
                    markersize=4,
                    label=label,
                )

        # Fix the tick positions first, then label them once
        cluster_axis.xaxis.set_ticks(x_positions)
        cluster_axis.set_xticklabels(cluster_names,
                                     fontsize=10,
                                     rotation=90)

        cluster_axis.set_ylim(bottom=0.0)
        cluster_axis.set_ylabel("% previously infected", fontsize=13)

    axis, max_value, df, seroprevalence_by_age = plot_age_seroprev_to_axis(
        uncertainty_df,
        scenario_id,
        time,
        age_axis,
        requested_quantiles,
        ref_date,
        name,
        add_date_as_title=False,
        add_ylabel=False,
    )

    plotter.save_figure(fig,
                        filename="sero_by_cluster",
                        subdir="outputs",
                        title_text="")

    overall_seropos_estimates = df[df["type"] == "proportion_seropositive"][[
        "quantile", "value"
    ]].set_index("quantile")

    return max_value, seroprevalence_by_cluster, overall_seropos_estimates
Example #11
0
def plot_timeseries_with_uncertainty(plotter: Plotter,
                                     uncertainty_df: pd.DataFrame,
                                     output_name: str,
                                     scenario_idxs: List[int],
                                     targets: dict,
                                     is_logscale=False,
                                     x_low=0.0,
                                     x_up=1e6,
                                     axis=None,
                                     n_xticks=None,
                                     ref_date=REF_DATE,
                                     add_targets=True,
                                     overlay_uncertainty=True,
                                     title_font_size=12,
                                     label_font_size=10,
                                     dpi_request=300,
                                     capitalise_first_letter=False,
                                     legend=False,
                                     requested_x_ticks=None,
                                     show_title=True,
                                     ylab=None,
                                     x_axis_to_date=True,
                                     start_quantile=0,
                                     sc_colors=None,
                                     custom_title=None,
                                     vlines=None,
                                     hlines=None):
    """
    Plots the uncertainty timeseries for one or more scenarios.
    Also plots any calibration targets that are provided.

    When axis is None a new single-panel figure is requested from the plotter
    and saved at the end; otherwise everything is drawn onto the given axis
    and nothing is saved.

    Colours come from sc_colors (indexed by position in scenario_idxs) when
    given. Otherwise they come from the module-level COLORS and ALPHAS,
    indexed by scenario number, and the number of scenarios plotted is capped
    at len(COLORS).

    vlines and hlines are passed to add_vertical_lines_to_plot and
    add_horizontal_lines_to_plot; None is treated as an empty dict.

    Returns a dict mapping each plotted scenario index to a DataFrame of its
    quantiles, with the times inserted as the first column.
    """
    # None replaces the former {} defaults so that no dict object is shared
    # between calls
    vlines = {} if vlines is None else vlines
    hlines = {} if hlines is None else hlines

    single_panel = axis is None
    if single_panel:
        fig, axis, _, _, _, _ = plotter.get_figure()

    n_scenarios_to_plot = len(scenario_idxs)
    if sc_colors is None:
        n_scenarios_to_plot = min(len(scenario_idxs), len(COLORS))
        colors = _apply_transparency(COLORS[:n_scenarios_to_plot],
                                     ALPHAS[:n_scenarios_to_plot])

    # Plot each scenario on a single axis
    data_to_return = {}
    for i, scenario_idx in enumerate(scenario_idxs[:n_scenarios_to_plot]):
        if sc_colors is not None:
            scenario_colors = sc_colors[i]
        elif scenario_idx < len(colors):
            # Default colours follow the scenario number, not its position
            scenario_colors = colors[scenario_idx]
        else:
            scenario_colors = colors[-1]

        times, quantiles = _plot_uncertainty(
            axis,
            uncertainty_df,
            output_name,
            scenario_idx,
            x_up,
            x_low,
            scenario_colors,
            overlay_uncertainty=overlay_uncertainty,
            start_quantile=start_quantile,
            zorder=i + 1,
        )

        scenario_data = pd.DataFrame.from_dict(quantiles)
        scenario_data.insert(0, "days from 31/12/2019", times)
        data_to_return[scenario_idx] = scenario_data

    # Add plot targets, restricted to the requested time window
    if add_targets:
        values, times = _get_target_values(targets, output_name)
        in_window = [(v, t) for (v, t) in zip(values, times)
                     if x_low <= t <= x_up]
        trunc_values = [v for (v, _) in in_window]
        trunc_times = [t for (_, t) in in_window]
        _plot_targets_to_axis(axis,
                              trunc_values,
                              trunc_times,
                              on_uncertainty_plot=True)

    # Sort out x-axis
    if x_axis_to_date:
        change_xaxis_to_date(axis, ref_date, rotation=0)
    axis.tick_params(axis="x", labelsize=label_font_size)
    axis.tick_params(axis="y", labelsize=label_font_size)

    # Add lines with marking text to plots
    add_vertical_lines_to_plot(axis, vlines)
    add_horizontal_lines_to_plot(axis, hlines)

    if output_name == "proportion_seropositive":
        # Values are proportions of one; show them as percentages without a symbol
        axis.yaxis.set_major_formatter(mtick.PercentFormatter(1, symbol=""))
    if show_title:
        title = custom_title if custom_title else get_plot_text_dict(
            output_name)
        axis.set_title(title, fontsize=title_font_size)

    if requested_x_ticks is not None:
        pyplot.xticks(requested_x_ticks)
    elif n_xticks is not None:
        pyplot.locator_params(axis="x", nbins=n_xticks)

    if is_logscale:
        axis.set_yscale("log")
    elif not output_name.startswith(("rel_diff", "abs_diff")):
        # Difference outputs can be negative, so only the others start at zero
        axis.set_ylim(bottom=0)

    if ylab is not None:
        axis.set_ylabel(ylab, fontsize=label_font_size)

    if legend:
        pyplot.legend(labels=scenario_idxs)

    if single_panel:
        idx_str = "-".join(map(str, scenario_idxs))
        filename = f"uncertainty-{output_name}-{idx_str}"
        plotter.save_figure(fig, filename=filename, dpi_request=dpi_request)

    return data_to_return
Example #12
0
def plot_param_vs_param(
    plotter: Plotter,
    mcmc_params: List[pd.DataFrame],
    parameters: list,
    burn_in: int,
    style: str,
    bins: int,
    label_font_size: int,
    label_chars: int,
    dpi_request: int,
    label_param_string=True,
    show_ticks=False,
    file_name="parameter_correlation_matrix",
):
    """
    Plot the parameter correlation matrices for each parameter combination.

    The figure is a square grid indexed by (row parameter, column parameter):
      - below the diagonal, the joint distribution of the two parameters,
        drawn as a scatter plot ("Scatter"), a kernel density estimate
        ("KDE") or a 2-D histogram (any other style);
      - on the diagonal, a histogram of the single parameter;
      - above the diagonal, nothing (the panel is switched off).

    Samples are pooled over every entry of mcmc_params, keeping only rows
    whose "run" is greater than burn_in. Axis labels are the parameter names
    when label_param_string is true, otherwise their one-based position in
    parameters. label_chars is accepted but not used here.
    """

    # Prelims
    fig, axes, _, _, _, _ = plotter.get_figure(n_panels=len(parameters)**2)

    # Collate the post-burn-in values of each parameter over the chains once;
    # the same samples serve as both row (y) and column (x) data
    samples = {}
    for param_name in parameters:
        pooled = []
        for chain_df in mcmc_params:
            keep = (chain_df["name"] == param_name) & (chain_df["run"] > burn_in)
            pooled += chain_df[keep]["value"].to_list()
        samples[param_name] = pooled

    last_row = len(parameters) - 1

    # Loop over parameter combinations
    for row_idx, row_param_name in enumerate(parameters):
        for col_idx, col_param_name in enumerate(parameters):

            axis = axes[row_idx, col_idx]
            if not show_ticks:
                axis.xaxis.set_ticks([])
                axis.yaxis.set_ticks([])
            else:
                axis.tick_params(labelsize=4)

            x_values = samples[col_param_name]
            y_values = samples[row_param_name]

            # Plot
            if row_idx > col_idx:
                if style == "Scatter":
                    axis.scatter(
                        x_values,
                        y_values,
                        alpha=0.5,
                        s=0.1,
                        color="k",
                    )
                elif style == "KDE":
                    # x/y go by keyword and fill replaces the deprecated shade
                    sns.kdeplot(
                        x=x_values,
                        y=y_values,
                        ax=axis,
                        fill=True,
                        levels=5,
                        lw=1.0,
                    )
                else:
                    axis.hist2d(x_values, y_values, bins=bins)
            elif row_idx == col_idx:
                axis.hist(
                    y_values,
                    color=[0.2, 0.2, 0.6] if style == "Shade" else "k",
                    bins=bins,
                )
                axis.yaxis.set_ticks([])
            else:
                axis.axis("off")

            # Axis labels: the column parameter is on x and the row parameter
            # on y; only the bottom row and first column are labelled
            x_param_label = col_param_name if label_param_string else str(
                col_idx + 1)
            y_param_label = row_param_name if label_param_string else str(
                row_idx + 1)
            if row_idx == last_row:
                axis.set_xlabel(get_plot_text_dict(x_param_label),
                                fontsize=label_font_size,
                                labelpad=3)
            if col_idx == 0:
                axis.set_ylabel(get_plot_text_dict(y_param_label),
                                fontsize=label_font_size)

    # Save
    plotter.save_figure(fig, filename=file_name, dpi_request=dpi_request)