コード例 #1
0
def plot_mixing_matrix(plotter: Plotter, location: str, iso3: str):
    """
    Plot a country's mixing matrix for one contact location as a heatmap.

    :param plotter: supplies the figure/axis and saves the result
    :param location: location key passed to get_country_mixing_matrix
    :param iso3: ISO3 country code
    """
    fig, axis, _, _, _, _ = plotter.get_figure()

    mixing_matrix = get_country_mixing_matrix(location, iso3)
    # Draw on the axis returned with ``fig``; pyplot's "current" axes is not guaranteed
    # to belong to that figure. extent spans 0-80 on both axes with row 0 at the top.
    axis.imshow(mixing_matrix, cmap="hot", interpolation="none", extent=[0, 80, 80, 0])
    plotter.save_figure(fig, filename="mixing-matrix", title_text="Mixing matrix")
コード例 #2
0
def plot_mixing_matrix_2(plotter: Plotter, iso3: str):
    """
    Plot the mixing matrix of every contact location for one country on a 2x3 grid.

    Each panel is a heatmap with its own colorbar; the grid slot keyed "none" is left
    blank.

    :param plotter: supplies the multi-panel figure and saves it
    :param iso3: ISO3 country code passed to get_country_mixing_matrix
    """
    fig, axes, _, _, _, _ = plotter.get_figure(n_panels=6)
    # Grid position (row, column) of each location's panel.
    positions = {
        "all_locations": [0, 0],
        "home": [0, 1],
        "work": [0, 2],
        "other_locations": [1, 1],
        "school": [1, 2],
        "none": [1, 0],
    }

    for location, (row, col) in positions.items():
        axis = axes[row, col]
        if location == "none":
            axis.axis("off")
            continue

        mixing_matrix = get_country_mixing_matrix(location, iso3)
        # extent spans 0-80 on both axes with row 0 at the top.
        image = axis.imshow(mixing_matrix, cmap="hot", interpolation="none", extent=[0, 80, 80, 0])
        axis.set_title(get_plot_text_dict(location), fontsize=12)
        axis.set_xticks([5, 25, 45, 65])
        axis.set_yticks([5, 25, 45, 65])
        axis.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
        # The colorbar takes its colormap from the image itself.
        axis.figure.colorbar(image, ax=axis)

    # Lay out only once titles, tick labels and colorbars exist; running tight_layout
    # before anything is drawn has nothing to measure.
    fig.tight_layout()
    plotter.save_figure(fig, filename="mixing-matrix", title_text="Mixing matrix")
コード例 #3
0
def plot_time_varying_multi_input(
    plotter: Plotter,
    tv_key: dict,
    times: List[float],
    is_logscale: bool,
):
    """
    Plot several time-varying series (Google mobility by location) on one set of axes.

    :param plotter: supplies the figure/axis and saves the result
    :param tv_key: mapping of location name -> values; it is passed straight to
        ``pd.DataFrame``, so each value sequence must have one entry per time in ``times``
    :param times: x-axis values, converted to dates relative to REF_DATE
    :param is_logscale: use a logarithmic y-axis
    """
    fig, axes, _, _, _, _ = plotter.get_figure()
    if is_logscale:
        axes.set_yscale("log")

    df = pd.DataFrame(tv_key)
    df.index = times

    axes.plot(df.index, df.values)
    change_xaxis_to_date(axes, REF_DATE)
    # Pass labels by keyword only: combining a positional argument with ``labels=``
    # makes matplotlib warn and discard the positional argument.
    axes.legend(labels=[get_plot_text_dict(location) for location in df.columns], loc="best")
    if X_MIN is not None and X_MAX is not None:
        axes.set_xlim((X_MIN, X_MAX))
    # A zero lower bound is not valid on a log axis (matplotlib ignores it with a warning).
    if not is_logscale:
        axes.set_ylim(bottom=0.0)

    plotter.save_figure(
        fig, filename="time-variant-Google mobility", title_text="Google mobility"
    )
コード例 #4
0
def plot_time_varying_input(
    plotter: Plotter,
    tv_key: str,
    tv_func: Callable[[float], float],
    times: List[float],
    is_logscale: bool,
):
    """
    Plot one or more time-varying functions evaluated over ``times`` on a single axis.

    :param plotter: supplies the figure/axis and saves the result
    :param tv_key: name used for the output file ("time-variant-<tv_key>") and title
    :param tv_func: a callable of time, or a list of such callables (one line each)
    :param times: time points at which each function is evaluated
    :param is_logscale: use a logarithmic y-axis
    """
    fig, axes, _, _, _, _ = plotter.get_figure()
    if is_logscale:
        axes.set_yscale("log")

    funcs = tv_func if isinstance(tv_func, list) else [tv_func]
    for func in funcs:
        axes.plot(times, [func(t) for t in times])

    if X_MIN is not None and X_MAX is not None:
        axes.set_xlim((X_MIN, X_MAX))

    plotter.save_figure(fig, filename=f"time-variant-{tv_key}", title_text=tv_key)
コード例 #5
0
def plot_multi_age_distribution(plotter: Plotter, sub_region: List[str], iso3: str):
    """
    Plot population age distributions (in millions, 5-year bands) for four regions.

    Only the Philippines is supported. The panels are filled from ``sub_region[0..3]``,
    the first column top-to-bottom and then the second. The first panel is titled
    "Philippines"; the others are titled by their sub-region name.

    :param plotter: saves the resulting 2x2 figure
    :param sub_region: four region names passed to get_population_by_agegroup
    :param iso3: ISO3 country code; must be "PHL"
    :raises ValueError: if ``iso3`` is not "PHL" (there is no figure to save otherwise)
    """
    # Compare by value: ``is`` tests object identity, which is unreliable for strings.
    if iso3 != "PHL":
        raise ValueError(f"Multi-region age distribution is only available for PHL, got {iso3!r}")

    agegroup_strata = list(range(0, 100, 5))
    multi, axes = pyplot.subplots(2, 2, figsize=(30, 30))

    # (x, y) are (column, row); this order fills the first column before the second.
    for i, (x, y) in enumerate(((0, 0), (0, 1), (1, 0), (1, 1))):
        axis = axes[y, x]
        age_distribution = get_population_by_agegroup(agegroup_strata, iso3, sub_region[i])
        # 10e5 == 1e6, so bar heights are in millions.
        age_distribution = [each / 10e5 for each in age_distribution]
        axis.bar(agegroup_strata, height=age_distribution, width=4, align="edge")
        title = "Philippines" if i == 0 else sub_region[i]
        axis.set_title(title).set_fontsize(30)
        axis.set_xlabel("Age").set_fontsize(20)
        axis.set_ylabel("Millions").set_fontsize(20)
        axis.set_ylim(0, 12)
        axis.xaxis.set_tick_params(labelsize=20)
        axis.yaxis.set_tick_params(labelsize=20)

    plotter.save_figure(multi, filename="age-distribution", title_text="Age distribution")
コード例 #6
0
def plot_outputs_single(
    plotter: Plotter,
    scenario: Scenario,
    output_config: dict,
    is_logscale=False,
    axis=None,
    single_panel=True,
    xaxis_date=False,
):
    """
    Plot one derived model output for a single scenario, overlaid with its targets.

    :param output_config: must hold "output_key" (the output to plot) plus "values" and
        "times" (the target points)
    :param axis: axis to draw on when ``single_panel`` is false; ignored otherwise
    :param single_panel: when true, create a fresh figure and save it (subdir "outputs",
        filename = output key); when false the caller owns and saves the figure
    :param xaxis_date: relabel the x-axis as dates relative to REF_DATE
    """
    if single_panel:
        fig, axis, _, _, _, _ = plotter.get_figure()

    if is_logscale:
        axis.set_yscale("log")

    output_name, target_values, target_times = (
        output_config["output_key"],
        output_config["values"],
        output_config["times"],
    )
    _plot_outputs_to_axis(axis, scenario, output_name)
    _plot_targets_to_axis(axis, target_values, target_times)

    if xaxis_date:
        change_xaxis_to_date(axis, REF_DATE)

    has_x_bounds = X_MIN is not None and X_MAX is not None
    if has_x_bounds:
        axis.set_xlim((X_MIN, X_MAX))

    if not single_panel:
        return
    plotter.save_figure(fig, filename=output_name, subdir="outputs", title_text=output_name)
コード例 #7
0
def plot_acceptance_ratio(
    plotter: Plotter,
    mcmc_tables: List[pd.DataFrame],
    burn_in: int,
    label_font_size=6,
    dpi_request=300,
):
    """
    Plot the progressive acceptance ratio of each MCMC chain over iterations.

    :param mcmc_tables: MCMC run tables, combined with db.load.append_tables; each row
        needs "chain" and "accept" columns
    :param burn_in: when positive, a dotted vertical line marks this iteration
    """
    fig, axis, _, _, _, _ = plotter.get_figure()
    full_df = db.load.append_tables(mcmc_tables)
    # Iterate over the chain ids actually present; range(max(chain)) skipped the
    # highest-numbered chain.
    for chain in sorted(full_df["chain"].unique()):
        chain_df = full_df[full_df["chain"] == chain]
        ratios = collate_acceptance_ratios(chain_df["accept"])
        axis.plot(ratios, alpha=0.8, linewidth=0.7)

    # The burn-in marker is shared by all chains, so draw it once.
    if burn_in > 0:
        axis.axvline(x=burn_in, color=COLOR_THEME[1], linestyle="dotted")

    axis.set_title("acceptance ratio", fontsize=label_font_size)
    axis.set_xlabel("iterations", fontsize=label_font_size)
    axis.set_ylim(bottom=0.0)
    pyplot.setp(axis.get_yticklabels(), fontsize=label_font_size)
    pyplot.setp(axis.get_xticklabels(), fontsize=label_font_size)
    plotter.save_figure(fig, filename="acceptance_ratio", dpi_request=dpi_request)
コード例 #8
0
ファイル: plots.py プロジェクト: thecapitalistcycle/AuTuMN
def plot_seroprevalence_by_age(
    plotter: Plotter,
    uncertainty_df: pd.DataFrame,
    scenario_id: int,
    time: float,
    ref_date=REF_DATE,
    axis=None,
    name="",
    requested_quantiles=None,
):
    """
    Plot seroprevalence by age for one scenario at one time point.

    If no ``axis`` is given, a new figure is created and saved as outputs/sero_by_age.

    :return: (max_value, seroprevalence_by_age, overall_seropos_estimates); the first two
        come from plot_age_seroprev_to_axis, and the last holds the
        "proportion_seropositive" rows' values indexed by quantile
    """
    owns_figure = axis is None
    if owns_figure:
        fig, axis, _, _, _, _ = plotter.get_figure()

    axis, max_value, df, seroprevalence_by_age = plot_age_seroprev_to_axis(
        uncertainty_df, scenario_id, time, axis, requested_quantiles, ref_date, name
    )
    if owns_figure:
        plotter.save_figure(fig, filename="sero_by_age", subdir="outputs", title_text="")

    is_overall = df["type"] == "proportion_seropositive"
    overall_seropos_estimates = df.loc[is_overall, ["quantile", "value"]].set_index("quantile")

    return max_value, seroprevalence_by_age, overall_seropos_estimates
コード例 #9
0
def plot_posterior(
    plotter: Plotter,
    mcmc_params: List[pd.DataFrame],
    mcmc_tables: List[pd.DataFrame],
    burn_in: int,
    param_name: str,
    num_bins: int,
    prior,
):
    """
    Plot a density histogram of one parameter's posterior samples (after burn-in).

    When ``prior`` is truthy, its density is drawn on top over the x-range chosen by
    workout_plot_x_range.
    """
    posterior_df = get_posterior(mcmc_params, mcmc_tables, param_name, burn_in)
    fig, axis, _, _, _, _ = plotter.get_figure()
    posterior_df.hist(bins=num_bins, ax=axis, density=True)

    if prior:
        bounds = workout_plot_x_range(prior)
        grid = np.linspace(bounds[0], bounds[1], num=1000)
        densities = [calculate_prior(prior, point, log=False) for point in grid]
        axis.plot(grid, densities)

    plotter.save_figure(
        fig,
        filename=f"{param_name}-posterior",
        title_text=f"{param_name} posterior",
    )
コード例 #10
0
def plot_agg_compartments_multi_scenario(
    plotter: Plotter,
    scenarios: List[Scenario],
    compartment_names: List[str],
    is_logscale=False,
):
    """
    Plot, for each scenario, the sum of the named compartments over time.

    One line is drawn per scenario, coloured by its position in COLOR_THEME and labelled
    with the scenario name.
    """
    fig, axis, _, _, _, _ = plotter.get_figure()
    scenario_labels = []
    for color_idx, scenario in enumerate(scenarios):
        model = scenario.model
        totals = np.zeros(model.outputs.shape[0])
        for name in compartment_names:
            column = model.compartment_names.index(name)
            totals += model.outputs[:, column]

        axis.plot(model.times, totals, color=COLOR_THEME[color_idx], alpha=0.7)
        scenario_labels.append(scenario.name)

    axis.legend(scenario_labels)
    if is_logscale:
        axis.set_yscale("log")

    plotter.save_figure(fig, filename="aggregate-compartments", title_text="aggregate compartments")
コード例 #11
0
def plot_multiple_posteriors(
    plotter: Plotter,
    mcmc_params: List[pd.DataFrame],
    mcmc_tables: List[pd.DataFrame],
    burn_in: int,
    num_bins: int,
    title_font_size: int,
    label_font_size: int,
    capitalise_first_letter: bool,
    dpi_request: int,
    priors: list,
    parameters: list,
    file_name="all_posteriors",
):
    """
    Plot posterior histograms for several parameters on one multi-panel figure.

    Each panel shows the density histogram of one parameter's posterior samples. If a
    prior whose "param_name" matches that parameter is found in ``priors``, its density
    is overlaid; parameters without a matching prior get no prior curve. Unused grid
    slots are hidden.
    """
    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(len(parameters))

    # First prior listed for each parameter name (matches the previous first-match lookup).
    priors_by_name = {}
    for prior in priors:
        priors_by_name.setdefault(prior["param_name"], prior)

    for i_panel in range(n_rows * n_cols):
        axis = axes[indices[i_panel][0], indices[i_panel][1]]
        if i_panel >= len(parameters):
            axis.axis("off")
            continue

        param_name = parameters[i_panel]
        vals_df = get_posterior(mcmc_params, mcmc_tables, param_name, burn_in)
        vals_df.hist(bins=num_bins, ax=axis, density=True)

        # Only overlay a prior that belongs to this parameter; never fall back to an
        # unrelated one.
        param_prior = priors_by_name.get(param_name)
        if param_prior is not None:
            x_range = workout_plot_x_range(param_prior)
            x_values = np.linspace(x_range[0], x_range[1], num=1000)
            y_values = [calculate_prior(param_prior, x, log=False) for x in x_values]
            axis.plot(x_values, y_values)

        axis.set_title(
            get_plot_text_dict(param_name, capitalise_first_letter=capitalise_first_letter),
            fontsize=title_font_size,
        )
        pyplot.setp(axis.get_yticklabels(), fontsize=label_font_size)
        pyplot.setp(axis.get_xticklabels(), fontsize=label_font_size)

    fig.tight_layout()
    plotter.save_figure(fig, filename=file_name, dpi_request=dpi_request)
コード例 #12
0
def plot_multiple_param_traces(
    plotter: Plotter,
    mcmc_params: List[pd.DataFrame],
    burn_in: int,
    title_font_size: int,
    label_font_size: int,
    capitalise_first_letter: bool,
    dpi_request: int,
    optional_param_request=None,
    file_name="all_traces",
    x_ticks_on=True,
):
    """
    Plot per-chain parameter traces on a multi-panel figure, one panel per parameter.

    Only the first table in ``mcmc_params`` is read. Unless ``optional_param_request``
    lists the parameters to plot, every parameter whose name does not contain
    "dispersion_param" is plotted. Unused grid slots are hidden.
    """
    params_df = mcmc_params[0]

    # Skip the dispersion parameters - only plot the epidemiological ones.
    parameters = [
        param for param in params_df.loc[:, "name"].unique().tolist()
        if "dispersion_param" not in param
    ]
    params_to_plot = optional_param_request if optional_param_request else parameters

    # Every chain id present. range(last chain id) dropped the final chain, and drew
    # nothing at all when the only chain was chain 0.
    chain_ids = sorted(params_df["chain"].unique())

    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(
        len(params_to_plot), share_xaxis=True, share_yaxis=False)

    for i in range(n_rows * n_cols):
        axis = axes[indices[i][0], indices[i][1]]
        if i >= len(params_to_plot):
            axis.axis("off")
            continue

        param_name = params_to_plot[i]
        name_mask = params_df["name"] == param_name
        for i_chain in chain_ids:
            param_values = params_df[name_mask & (params_df["chain"] == i_chain)].values
            # NOTE(review): positional column 3 is assumed to hold the parameter value;
            # confirm against the table's column order.
            axis.plot(param_values[:, 3], alpha=0.8, linewidth=0.7)
        axis.set_title(
            get_plot_text_dict(param_name, capitalise_first_letter=capitalise_first_letter),
            fontsize=title_font_size,
        )

        if not x_ticks_on:
            axis.set_xticks([])
        elif indices[i][0] == n_rows - 1:
            # Only the bottom row gets an x-axis label.
            x_label = "Iterations" if capitalise_first_letter else "iterations"
            axis.set_xlabel(x_label, fontsize=label_font_size)
        pyplot.setp(axis.get_yticklabels(), fontsize=label_font_size)
        pyplot.setp(axis.get_xticklabels(), fontsize=label_font_size)

        if burn_in > 0:
            axis.axvline(x=burn_in, color=COLOR_THEME[1], linestyle="dotted")

    fig.tight_layout()
    plotter.save_figure(fig, filename=file_name, dpi_request=dpi_request)
コード例 #13
0
def plot_loglikelihood_boxplots(plotter: Plotter,
                                mcmc_tables: List[pd.DataFrame]):
    """
    Draw box plots of -log(-loglikelihood), grouped by the "chain" column.

    The caller's tables are not modified. The transform assumes every log-likelihood is
    negative.
    """
    fig, axis, _, _, _, _ = plotter.get_figure()
    # Work on a copy so the derived column does not leak into the caller's table
    # (pd.concat already returns a new frame).
    if len(mcmc_tables) > 1:
        df = pd.concat(mcmc_tables)
    else:
        df = mcmc_tables[0].copy()

    df["-log(-loglikelihood)"] = [-log(-v) for v in df["loglikelihood"]]
    df.boxplot(column=["-log(-loglikelihood)"], by="chain", ax=axis)
    plotter.save_figure(fig, filename="loglikelihood-boxplots", title_text="")
コード例 #14
0
def plot_param_vs_param_by_chain(
    plotter: Plotter,
    mcmc_params: List[pd.DataFrame],
    parameters: list,
    label_font_size: int,
    label_chars: int,
    dpi_request: int,
):
    """
    Plot a lower-triangular matrix of pairwise parameter plots, one overlay per chain.

    Panel (row r, column c) shows:
    - for r > c, a scatter with ``parameters[c]`` on the horizontal axis and
      ``parameters[r]`` on the vertical axis;
    - for r == c, a histogram of ``parameters[r]``;
    - for r < c, nothing (the panel is hidden).

    Row names (truncated to ``label_chars``) label the first column, and column names
    label the bottom row. Ticks are removed from every panel.
    """
    n_params = len(parameters)
    fig, axes, _, _, _, _ = plotter.get_figure(n_panels=n_params**2)

    for chain_df in mcmc_params:
        for row, row_param in enumerate(parameters):
            row_values = chain_df[chain_df["name"] == row_param]["value"].to_list()
            for col, col_param in enumerate(parameters):
                axis = axes[row, col]
                axis.xaxis.set_ticks([])
                axis.yaxis.set_ticks([])

                if row > col:
                    col_values = chain_df[chain_df["name"] == col_param]["value"].to_list()
                    # Horizontal = column parameter, vertical = row parameter, so the data
                    # agree with the axis labels set below.
                    axis.scatter(col_values, row_values, alpha=0.5, s=0.1)
                elif row == col:
                    axis.hist(row_values)
                else:
                    axis.axis("off")

                if col == 0:
                    axis.set_ylabel(row_param[:label_chars], rotation=0, fontsize=label_font_size)
                if row == n_params - 1:
                    axis.set_xlabel(col_param[:label_chars], fontsize=label_font_size)

    plotter.save_figure(fig, filename="parameter_correlation_matrix", dpi_request=dpi_request)
コード例 #15
0
def plot_age_distribution(plotter: Plotter, sub_region: str, iso3: str):
    """
    Plot one region's population by 5-year age band, in millions, as a bar chart.

    :param sub_region: region name passed to get_population_by_agegroup
    :param iso3: ISO3 country code
    """
    fig, axis, _, _, _, _ = plotter.get_figure()

    agegroup_strata = list(range(0, 100, 5))
    age_distribution = get_population_by_agegroup(agegroup_strata, iso3, sub_region)
    # 10e5 == 1e6, so bar heights are in millions.
    age_distribution = [each / 10e5 for each in age_distribution]
    axis.set_xlabel("Age", fontsize=10)
    axis.set_ylabel("Millions", fontsize=10)
    # Draw on the returned axis rather than pyplot's "current" axes.
    axis.bar(agegroup_strata, height=age_distribution, width=4, align="edge")
    plotter.save_figure(fig, filename="age-distribution", title_text="Age distribution")
コード例 #16
0
def plot_burn_in(plotter: Plotter, num_iters: int, burn_in: int):
    """
    Plot the trade-off between burn-in size and the iterations left after it.

    Dotted lines mark the chosen ``burn_in`` (vertical) and the number of iterations it
    leaves, ``num_iters - burn_in`` (horizontal).
    """
    # Create a single figure; a second get_figure() call would leave an unused one behind.
    fig, axis, _, _, _, _ = plotter.get_figure()

    # Iterations remaining after a burn-in of size n, clamped at zero.
    values = [max(num_iters - n, 0) for n in range(num_iters)]

    axis.plot(values, color=COLOR_THEME[0])
    axis.set_ylabel("Number iters after burn-in")
    axis.set_xlabel("Burn-in size")
    axis.set_ylim(bottom=-5, top=num_iters)
    axis.set_xlim(left=0, right=num_iters)
    axis.axvline(x=burn_in, color=COLOR_THEME[1], linestyle="dotted")
    axis.axhline(y=num_iters - burn_in, color=COLOR_THEME[1], linestyle="dotted")
    plotter.save_figure(fig, filename="burn-in", title_text="burn-in")
コード例 #17
0
def plot_calibration_fit(
    plotter: Plotter,
    output_name: str,
    outputs: list,
    targets,
    is_logscale=False,
):
    """
    Plot model outputs against calibration targets for one output on a single figure.

    The file name and title gain a "-logscale" / " (logscale)" suffix when
    ``is_logscale`` is set.
    """
    fig, axis, _, _, _, _ = plotter.get_figure()
    plot_calibration(axis, output_name, outputs, targets, is_logscale)

    file_suffix = "-logscale" if is_logscale else ""
    title_suffix = " (logscale)" if is_logscale else ""
    plotter.save_figure(
        fig,
        filename=f"calibration-fit-{output_name}{file_suffix}",
        title_text=f"Calibration fit for {output_name}{title_suffix}",
    )
コード例 #18
0
def plot_all_params_vs_loglike(
    plotter: Plotter,
    mcmc_tables: List[pd.DataFrame],
    mcmc_params: List[pd.DataFrame],
    burn_in: int,
    title_font_size: int,
    label_font_size: int,
    capitalise_first_letter: bool,
    dpi_request: int,
    file_name="all_params_vs_loglike",
):
    """
    Plot log-likelihood against each (non-dispersion) parameter on a multi-panel figure.

    :param file_name: output file name. It was previously hard-coded to "all_posteriors",
        which clobbered the output of plot_multiple_posteriors.
    """
    # Skip the dispersion parameters - only plot the epidemiological ones.
    parameters = [
        param for param in mcmc_params[0].loc[:, "name"].unique().tolist()
        if "dispersion_param" not in param
    ]

    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(
        len(parameters), share_xaxis=False, share_yaxis=True)

    for i in range(n_rows * n_cols):
        axis = axes[indices[i][0], indices[i][1]]
        if i >= len(parameters):
            axis.axis("off")
            continue

        param_name = parameters[i]
        plot_param_vs_loglike(mcmc_tables, mcmc_params, param_name, burn_in, axis)
        axis.set_title(
            get_plot_text_dict(param_name, capitalise_first_letter=capitalise_first_letter),
            fontsize=title_font_size,
        )
        if indices[i][0] == n_rows - 1:
            # NOTE(review): plot_single_param_loglike labels this same helper's x-axis with
            # the parameter name, so "iterations" looks wrong here; confirm.
            x_label = "Iterations" if capitalise_first_letter else "iterations"
            axis.set_xlabel(x_label, fontsize=label_font_size)
        pyplot.setp(axis.get_yticklabels(), fontsize=label_font_size)
        pyplot.setp(axis.get_xticklabels(), fontsize=label_font_size)

    fig.tight_layout()
    plotter.save_figure(fig, filename=file_name, dpi_request=dpi_request)
コード例 #19
0
def plot_mcmc_parameter_trace(plotter: Plotter,
                              mcmc_params: List[pd.DataFrame], burn_in: int,
                              param_name: str):
    """
    Plot one parameter's value against run number, one line per table in ``mcmc_params``.

    When ``burn_in`` is positive, a dotted vertical line marks it.
    """
    fig, axis, _, _, _, _ = plotter.get_figure()
    for table_df in mcmc_params:
        param_df = table_df[table_df["name"] == param_name]
        axis.plot(param_df["run"], param_df["value"], alpha=0.8, linewidth=0.7)

    # The burn-in marker is shared by all traces, so draw it once rather than per table.
    if burn_in > 0:
        axis.axvline(x=burn_in, color=COLOR_THEME[1], linestyle="dotted")

    axis.set_ylabel(param_name)
    axis.set_xlabel("MCMC iterations")
    plotter.save_figure(fig,
                        filename=f"{param_name}-traces",
                        title_text=f"{param_name}-traces")
コード例 #20
0
def plot_single_param_loglike(
    plotter: Plotter,
    mcmc_tables: List[pd.DataFrame],
    mcmc_params: List[pd.DataFrame],
    burn_in: int,
    param_name: str,
):
    """
    Plot the transformed log-likelihood, -log(-loglikelihood), against one parameter.

    The drawing is delegated to plot_param_vs_loglike; this wrapper adds axis labels and
    saves the figure as "likelihood-against-<param_name>".
    """
    fig, axis, _, _, _, _ = plotter.get_figure()
    plot_param_vs_loglike(mcmc_tables, mcmc_params, param_name, burn_in, axis)

    axis.set_xlabel(param_name)
    axis.set_ylabel("-log(-loglikelihood)")

    file_stem = f"likelihood-against-{param_name}"
    figure_title = f"likelihood against {param_name}"
    plotter.save_figure(fig, filename=file_stem, title_text=figure_title)
コード例 #21
0
def plot_multi_cdr_curves(plotter: Plotter, times, detected_proportions,
                          start_date, end_date, rotation, regions):
    """
    Plot one case-detection (CDR) curve set per region on a multi-panel figure.

    ``detected_proportions[k]`` is drawn in the panel titled after ``regions[k]``; any
    leftover grid slots are hidden. Saved as "multi_cdr_curves".
    """
    n_regions = len(regions)
    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(n_panels=n_regions)

    for panel_idx in range(n_rows * n_cols):
        panel = axes[indices[panel_idx][0], indices[panel_idx][1]]
        if panel_idx >= n_regions:
            panel.axis("off")
            continue
        panel = plot_cdr_to_axis(panel, times, detected_proportions[panel_idx])
        tidy_cdr_axis(panel, rotation, start_date, end_date)
        panel.set_title(get_plot_text_dict(regions[panel_idx]))

    fig.tight_layout()
    plotter.save_figure(fig, filename="multi_cdr_curves")
コード例 #22
0
def plot_loglikelihood_trace(plotter: Plotter,
                             mcmc_tables: List[pd.DataFrame],
                             burn_in=0):
    """
    Plot the log-likelihood of accepted MCMC iterations, one line per chain.

    ``mcmc_tables`` is either a single table holding several chains (told apart by its
    "chain" column and plotted against "run") or one table per chain (plotted against
    the table index). When ``burn_in`` is non-zero, a dotted line marks it and the
    y-range is fitted, with a 20% margin, to the log-likelihoods from row ``burn_in``
    onwards of every table.
    """
    fig, axis, _, _, _, _ = plotter.get_figure()

    if len(mcmc_tables) == 1:  # there may be multiple chains within a single dataframe
        table_df = mcmc_tables[0]
        accept_mask = table_df["accept"] == 1
        for chain_id in table_df["chain"].unique():
            # Combine both masks in one indexing step: chaining df[a][b] makes pandas
            # reindex b against the already-filtered frame and warn.
            masked_df = table_df[accept_mask & (table_df["chain"] == chain_id)]
            axis.plot(masked_df["run"],
                      masked_df["loglikelihood"],
                      alpha=0.8,
                      linewidth=0.7)
    else:  # there is one chain per dataframe
        for table_df in mcmc_tables:
            accepted_df = table_df[table_df["accept"] == 1]
            accepted_df.loglikelihood.plot.line(ax=axis, alpha=0.8, linewidth=0.7)

    axis.set_ylabel("Loglikelihood")
    axis.set_xlabel("MCMC iterations")

    if burn_in:
        axis.axvline(x=burn_in, color=COLOR_THEME[1], linestyle="dotted")
        # Fit the y-range to every table, not only the last one iterated above. Rows are
        # skipped positionally and rejected iterations are included.
        tails = [df["loglikelihood"].iloc[burn_in:] for df in mcmc_tables]
        post_burn_in = pd.concat(tails) if tails else None
        if post_burn_in is not None and not post_burn_in.empty:
            y_min, y_max = post_burn_in.min(), post_burn_in.max()
            margin = 0.2 * (y_max - y_min)
            axis.set_ylim((y_min - margin, y_max + margin))

    plotter.save_figure(fig,
                        filename="loglikelihood-traces",
                        title_text="loglikelihood-traces")
コード例 #23
0
def plot_cdr_curves(
    plotter: Plotter,
    times,
    detected_proportion,
    end_date,
    rotation,
    start_date=1.0,
    alpha=1.0,
    line_width=0.7,
):
    """
    Plot a single set of case-detection (CDR) curves on a one-panel figure.

    The y-axis is labelled as the proportion of symptomatic cases detected; the figure is
    saved as "cdr_curves".
    """
    fig, cdr_axis, _, _, _, _ = plotter.get_figure()
    cdr_axis = plot_cdr_to_axis(
        cdr_axis, times, detected_proportion, alpha=alpha, line_width=line_width
    )
    cdr_axis.set_ylabel("proportion symptomatic cases detected")
    tidy_cdr_axis(cdr_axis, rotation, start_date, end_date)
    plotter.save_figure(fig, filename="cdr_curves")
コード例 #24
0
def plot_outputs_multi(
    plotter: Plotter,
    scenarios: List[Scenario],
    output_config: dict,
    is_logscale=False,
    x_low=0.0,
    x_up=1e6,
):
    """
    Plot one derived output for several scenarios on a single plot, with its targets.

    Scenarios are drawn in reverse order, each keeping the colour index of its original
    position, and each is labelled with its name in the legend.

    :param output_config: must hold "output_key" plus target "values" and "times"
    :param x_low: lower x-limit, applied only when both limits are not None
    :param x_up: upper x-limit, applied only when both limits are not None
    """
    fig, axis, _, _, _, _ = plotter.get_figure()
    output_name = output_config["output_key"]

    legend = []
    for idx, scenario in enumerate(reversed(scenarios)):
        color_idx = len(scenarios) - idx - 1
        _plot_outputs_to_axis(axis, scenario, output_name, color_idx=color_idx, alpha=0.7)
        legend.append(scenario.name)

    axis.legend(legend)

    _plot_targets_to_axis(axis, output_config["values"], output_config["times"])
    if is_logscale:
        axis.set_yscale("log")

    # Use this function's own limits directly instead of shadowing the module-level
    # X_MIN/X_MAX names used by other plots.
    if x_low is not None and x_up is not None:
        axis.set_xlim((x_low, x_up))
    plotter.save_figure(fig, filename=output_name, title_text=output_name)
コード例 #25
0
def plot_multi_compartments_single_scenario(
    plotter: Plotter, scenario: Scenario, compartments: List[str], is_logscale=False
):
    """
    Plot the selected compartments of one scenario's model over time.

    Compartments are drawn in reverse order with COLOR_THEME colours; a legend is shown
    only when fewer than 10 compartments are plotted.
    """
    model = scenario.model
    times = model.times

    fig, axis, _, _, _, _ = plotter.get_figure()
    plotted_names = []
    for color_idx, compartment_name in enumerate(reversed(compartments)):
        column = model.compartment_names.index(compartment_name)
        axis.plot(times, model.outputs[:, column], color=COLOR_THEME[color_idx], alpha=0.7)
        plotted_names.append(compartment_name)

    show_legend = len(plotted_names) < 10
    if show_legend:
        axis.legend(plotted_names)
    if is_logscale:
        axis.set_yscale("log")

    plotter.save_figure(fig, filename="compartments", title_text="compartments")
コード例 #26
0
def plot_single_compartment_multi_scenario(
    plotter: Plotter,
    scenarios: List[Scenario],
    compartment_name: str,
    is_logscale=False,
):
    """
    Plot one compartment over time for each of several scenarios.

    Each scenario's line uses the COLOR_THEME entry at its position and is labelled by
    scenario name; the figure is saved under the compartment's name.
    """
    fig, axis, _, _, _, _ = plotter.get_figure()
    scenario_labels = []
    for color_idx, scenario in enumerate(scenarios):
        model = scenario.model
        column = model.compartment_names.index(compartment_name)
        compartment_values = model.outputs[:, column]
        axis.plot(model.times, compartment_values, color=COLOR_THEME[color_idx], alpha=0.7)
        scenario_labels.append(scenario.name)

    axis.legend(scenario_labels)
    if is_logscale:
        axis.set_yscale("log")

    plotter.save_figure(fig, filename=compartment_name, title_text=compartment_name)
コード例 #27
0
def plot_multi_fit(
    plotter: Plotter,
    output_names: list,
    outputs: dict,
    targets,
    is_logscale=False,
    title_font_size=8,
    label_font_size=8,
    dpi_request=300,
    capitalise_first_letter=False,
    file_name=None,
):
    """
    Plot calibration fits for several outputs on one multi-panel figure.

    Each panel plots ``outputs[name]`` against ``targets`` through plot_calibration, with
    the x-axis shown as dates relative to REF_DATE. Unused grid slots are hidden.

    :param label_font_size: currently unused
    :param file_name: output file name; defaults to "calibration-fit-<last output name>"
    :raises ValueError: if ``output_names`` is empty (no file name could be derived)
    """
    if not output_names:
        raise ValueError("plot_multi_fit requires at least one output name")

    fig, axes, _, n_rows, n_cols, indices = plotter.get_figure(
        len(output_names), share_xaxis=True)

    for i_output in range(n_rows * n_cols):
        panel = axes[indices[i_output][0], indices[i_output][1]]
        if i_output >= len(output_names):
            panel.axis("off")
            continue

        output = output_names[i_output]
        axis = plot_calibration(panel, output, outputs[output], targets, is_logscale)
        change_xaxis_to_date(axis, REF_DATE, rotation=0)
        axis.set_title(
            get_plot_text_dict(output, capitalise_first_letter=capitalise_first_letter),
            fontsize=title_font_size,
        )

    fig.tight_layout()
    if file_name is None:
        # NOTE(review): this default matches the name plot_calibration_fit uses for the
        # same output, so the two can overwrite each other; consider a distinct name.
        file_name = f"calibration-fit-{output_names[-1]}"
    plotter.save_figure(fig, filename=file_name, dpi_request=dpi_request)
コード例 #28
0
ファイル: plots.py プロジェクト: thecapitalistcycle/AuTuMN
def plot_vic_seroprevalences(
    plotter: Plotter,
    uncertainty_df: pd.DataFrame,
    scenario_id: int,
    time: float,
    ref_date=REF_DATE,
    name="",
    requested_quantiles=None,
    credible_range=50,
):
    """
    Plot seroprevalence by cluster (left panel) and by age (right panel) for one
    scenario at one time point. The figure is saved as outputs/sero_by_cluster.

    :param credible_range: width, in percent, of the central interval drawn as a vertical
        line per cluster (50 -> 0.25 to 0.75 quantiles); those quantiles must be present
        in ``uncertainty_df``
    :param requested_quantiles: quantiles to consider; if their count is odd, the middle
        one is drawn as a point. Defaults to all quantiles in the data.
    :return: (max_value, seroprevalence_by_cluster, overall_seropos_estimates), where
        max_value is the one returned for the age panel, seroprevalence_by_cluster maps
        cluster -> quantile -> list of percentages, and overall_seropos_estimates holds
        the age helper's "proportion_seropositive" values indexed by quantile
    """
    fig, axes, _, _, _, _ = plotter.get_figure(n_panels=2, share_yaxis="all")
    cluster_axis, age_axis = axes

    cluster_prefix = "proportion_seropositiveXcluster_"
    mask = (uncertainty_df["scenario"] == scenario_id) & (uncertainty_df["time"] == time)
    df = uncertainty_df[mask]
    quantile_vals = df["quantile"].unique().tolist()
    sero_outputs = [
        output for output in df["type"].unique().tolist() if cluster_prefix in output
    ]

    seroprevalence_by_cluster = {}
    if not sero_outputs:
        cluster_axis.text(
            0.0, 0.5,
            "Cluster-specific seroprevalence outputs are not available for this run"
        )
    else:
        for output in sero_outputs:
            output_mask = df["type"] == output
            cluster = output.split(cluster_prefix)[1]
            # Combine masks in one indexing step: chaining df[a][b] makes pandas reindex
            # b against the already-filtered frame and warn.
            seroprevalence_by_cluster[cluster] = {
                q: [100.0 * v for v in df[output_mask & (df["quantile"] == q)]["value"].tolist()]
                for q in quantile_vals
            }
        q_keys = requested_quantiles if requested_quantiles else sorted(quantile_vals)
        num_quantiles = len(q_keys)

        cluster_names = [
            get_plot_text_dict(output.split(cluster_prefix)[1]) for output in sero_outputs
        ]

        lower_q_key = (100.0 - credible_range) / 100.0 / 2.0
        upper_q_key = 1.0 - lower_q_key

        x_positions = range(len(seroprevalence_by_cluster))
        for i, cluster in enumerate(seroprevalence_by_cluster):
            cluster_axis.plot(
                [x_positions[i], x_positions[i]],
                [
                    seroprevalence_by_cluster[cluster][lower_q_key],
                    seroprevalence_by_cluster[cluster][upper_q_key],
                ],
                "-",
                color="black",
                lw=1.0,
            )

            if num_quantiles % 2:
                q_key = q_keys[num_quantiles // 2]
                cluster_axis.plot(
                    x_positions[i],
                    seroprevalence_by_cluster[cluster][q_key],
                    "o",
                    color="black",
                    markersize=4,
                    label=None if i > 0 else "model",
                )

        # Fix the tick positions first, then label them once (setting labels before a
        # fixed locator exists makes matplotlib warn).
        cluster_axis.xaxis.set_ticks(x_positions)
        cluster_axis.set_xticklabels(cluster_names, fontsize=10, rotation=90)

        cluster_axis.set_ylim(bottom=0.0)
        cluster_axis.set_ylabel("% previously infected", fontsize=13)

    # NOTE(review): the returned max_value covers the age panel only; the cluster
    # maximum is not included. Confirm this is intended.
    _, max_value, age_df, _ = plot_age_seroprev_to_axis(
        uncertainty_df,
        scenario_id,
        time,
        age_axis,
        requested_quantiles,
        ref_date,
        name,
        add_date_as_title=False,
        add_ylabel=False,
    )

    plotter.save_figure(fig, filename="sero_by_cluster", subdir="outputs", title_text="")

    overall_seropos_estimates = age_df[age_df["type"] == "proportion_seropositive"][[
        "quantile", "value"
    ]].set_index("quantile")

    return max_value, seroprevalence_by_cluster, overall_seropos_estimates
コード例 #29
0
ファイル: plots.py プロジェクト: thecapitalistcycle/AuTuMN
def plot_timeseries_with_uncertainty(plotter: Plotter,
                                     uncertainty_df: pd.DataFrame,
                                     output_name: str,
                                     scenario_idxs: List[int],
                                     targets: dict,
                                     is_logscale=False,
                                     x_low=0.0,
                                     x_up=1e6,
                                     axis=None,
                                     n_xticks=None,
                                     ref_date=REF_DATE,
                                     add_targets=True,
                                     overlay_uncertainty=True,
                                     title_font_size=12,
                                     label_font_size=10,
                                     dpi_request=300,
                                     capitalise_first_letter=False,
                                     legend=False,
                                     requested_x_ticks=None,
                                     show_title=True,
                                     ylab=None,
                                     x_axis_to_date=True,
                                     start_quantile=0,
                                     sc_colors=None,
                                     custom_title=None,
                                     vlines=None,
                                     hlines=None):
    """
    Plot the uncertainty timeseries for one or more scenarios.

    Each scenario is drawn with `_plot_uncertainty`. Calibration targets for
    `output_name` can optionally be overlaid, restricted to the
    [x_low, x_up] time window.

    Args:
        plotter: Used to create the figure (only when `axis` is None) and to
            save it.
        uncertainty_df: Uncertainty data passed through to `_plot_uncertainty`.
        output_name: Output to plot. Also used for the default title, the
            saved filename, target lookup and y-axis rules.
        scenario_idxs: Scenarios to plot. When `sc_colors` is None, at most
            `len(COLORS)` of them are plotted.
        targets: Calibration targets, read via `_get_target_values`.
        is_logscale: Use a log-scale y axis.
        x_low, x_up: Time window passed to `_plot_uncertainty` and used to
            truncate the targets.
        axis: Existing axis to draw on. If None, a new figure is created and
            saved at the end.
        n_xticks: Approximate number of x ticks, used only when
            `requested_x_ticks` is None.
        ref_date: Reference date for the date-formatted x axis.
        add_targets: Overlay calibration targets.
        overlay_uncertainty, start_quantile: Passed to `_plot_uncertainty`.
        title_font_size, label_font_size: Font sizes for title and tick/axis
            labels.
        dpi_request: Resolution used when saving a single-panel figure.
        capitalise_first_letter: Unused. Kept for backward compatibility.
        legend: Add a legend labelled with `scenario_idxs`.
        requested_x_ticks: Explicit x tick positions.
        show_title: Set a title (`custom_title` if given, otherwise derived
            from `output_name`).
        ylab: Optional y-axis label.
        x_axis_to_date: Format the x axis as dates relative to `ref_date`.
        sc_colors: Optional per-scenario colours, indexed by position in
            `scenario_idxs`.
        custom_title: Title overriding the one derived from `output_name`.
        vlines, hlines: Passed to `add_vertical_lines_to_plot` /
            `add_horizontal_lines_to_plot`. Defaults to no lines.

    Returns:
        Dict mapping each plotted scenario index to a DataFrame of its
        quantiles. The times are inserted as the first column.
    """
    # Fresh dicts per call rather than shared mutable defaults.
    vlines = {} if vlines is None else vlines
    hlines = {} if hlines is None else hlines

    single_panel = axis is None
    if single_panel:
        fig, axis, _, _, _, _ = plotter.get_figure()

    n_scenarios_to_plot = len(scenario_idxs)
    if sc_colors is None:
        # Default palette is finite, so cap the number of scenarios plotted.
        n_scenarios_to_plot = min(len(scenario_idxs), len(COLORS))
        colors = _apply_transparency(COLORS[:n_scenarios_to_plot],
                                     ALPHAS[:n_scenarios_to_plot])

    # Plot each scenario on a single axis
    data_to_return = {}
    for i, scenario_idx in enumerate(scenario_idxs[:n_scenarios_to_plot]):
        if sc_colors is not None:
            scenario_colors = sc_colors[i]
        elif scenario_idx < len(colors):
            # Default colours are looked up by scenario index, not by position.
            scenario_colors = colors[scenario_idx]
        else:
            scenario_colors = colors[-1]

        times, quantiles = _plot_uncertainty(
            axis,
            uncertainty_df,
            output_name,
            scenario_idx,
            x_up,
            x_low,
            scenario_colors,
            overlay_uncertainty=overlay_uncertainty,
            start_quantile=start_quantile,
            zorder=i + 1,
        )

        data_to_return[scenario_idx] = pd.DataFrame.from_dict(quantiles)
        data_to_return[scenario_idx].insert(0, "days from 31/12/2019", times)

    # Add plot targets, keeping only those inside the plotted time window
    if add_targets:
        target_values, target_times = _get_target_values(targets, output_name)
        in_window = [(v, t) for v, t in zip(target_values, target_times)
                     if x_low <= t <= x_up]
        trunc_values = [v for v, _ in in_window]
        trunc_times = [t for _, t in in_window]
        _plot_targets_to_axis(axis,
                              trunc_values,
                              trunc_times,
                              on_uncertainty_plot=True)

    # Sort out x-axis
    if x_axis_to_date:
        change_xaxis_to_date(axis, ref_date, rotation=0)
    axis.tick_params(axis="x", labelsize=label_font_size)
    axis.tick_params(axis="y", labelsize=label_font_size)

    # Add lines with marking text to plots
    add_vertical_lines_to_plot(axis, vlines)
    add_horizontal_lines_to_plot(axis, hlines)

    if output_name == "proportion_seropositive":
        axis.yaxis.set_major_formatter(mtick.PercentFormatter(1, symbol=""))
    if show_title:
        title = custom_title if custom_title else get_plot_text_dict(
            output_name)
        axis.set_title(title, fontsize=title_font_size)

    # Use the target axis explicitly. pyplot's "current axes" may be a
    # different panel when `axis` was supplied by the caller.
    if requested_x_ticks is not None:
        axis.set_xticks(requested_x_ticks)
    elif n_xticks is not None:
        axis.locator_params(axis="x", nbins=n_xticks)

    if is_logscale:
        axis.set_yscale("log")
    elif not (output_name.startswith("rel_diff")
              or output_name.startswith("abs_diff")):
        # Difference outputs can be negative, so only clamp the others at zero.
        axis.set_ylim(bottom=0)

    if ylab is not None:
        axis.set_ylabel(ylab, fontsize=label_font_size)

    if legend:
        axis.legend(labels=scenario_idxs)

    if single_panel:
        idx_str = "-".join(map(str, scenario_idxs))
        filename = f"uncertainty-{output_name}-{idx_str}"
        plotter.save_figure(fig, filename=filename, dpi_request=dpi_request)

    return data_to_return
コード例 #30
0
def plot_param_vs_param(
    plotter: Plotter,
    mcmc_params: List[pd.DataFrame],
    parameters: list,
    burn_in: int,
    style: str,
    bins: int,
    label_font_size: int,
    label_chars: int,
    dpi_request: int,
    label_param_string=True,
    show_ticks=False,
    file_name="parameter_correlation_matrix",
):
    """
    Plot a matrix of pairwise parameter relationships from MCMC output.

    The grid has one row and one column per entry in `parameters`.
    - Below the diagonal: the column parameter (x) against the row
      parameter (y), drawn as a scatter ("Scatter"), KDE ("KDE") or 2-D
      histogram (any other style).
    - On the diagonal: a 1-D histogram of that parameter.
    - Above the diagonal: the panel is switched off.

    Args:
        plotter: Creates the figure and saves it.
        mcmc_params: One DataFrame per chain. Each must have "name", "run"
            and "value" columns.
        parameters: Parameter names to include, in grid order.
        burn_in: Only rows with "run" strictly greater than this are used.
        style: "Scatter", "KDE", "Shade" or anything else (see above).
            "Shade" only changes the diagonal histogram colour.
        bins: Bin count for the 1-D and 2-D histograms.
        label_font_size: Font size for axis labels.
        label_chars: Unused. Kept for backward compatibility.
        dpi_request: Resolution used when saving.
        label_param_string: Label axes with parameter names if True,
            otherwise with their 1-based position.
        show_ticks: Keep (small) tick labels if True, otherwise remove ticks.
        file_name: Filename passed to `plotter.save_figure`.
    """
    fig, axes, _, _, _, _ = plotter.get_figure(n_panels=len(parameters)**2)

    # Pool each parameter's post-burn-in values across all chains. Rows and
    # columns index the same parameter list, so one collection serves both.
    param_values = {}
    for param_name in parameters:
        pooled = []
        for chain_df in mcmc_params:
            mask = (chain_df["name"] == param_name) & (chain_df["run"] > burn_in)
            pooled += chain_df[mask]["value"].to_list()
        param_values[param_name] = pooled

    last_row = len(parameters) - 1

    # Loop over parameter combinations
    for row_idx, row_param_name in enumerate(parameters):
        row_values = param_values[row_param_name]
        for col_idx, col_param_name in enumerate(parameters):
            col_values = param_values[col_param_name]

            axis = axes[row_idx, col_idx]
            if not show_ticks:
                axis.xaxis.set_ticks([])
                axis.yaxis.set_ticks([])
            else:
                axis.tick_params(labelsize=4)

            if row_idx > col_idx:
                # Lower triangle: column parameter on x, row parameter on y.
                if style == "Scatter":
                    axis.scatter(
                        col_values,
                        row_values,
                        alpha=0.5,
                        s=0.1,
                        color="k",
                    )
                elif style == "KDE":
                    # NOTE(review): positional x/y and `shade=` follow the
                    # older seaborn kdeplot API. Newer seaborn expects
                    # `x=`/`y=` keywords and `fill=`. Confirm the pinned
                    # seaborn version.
                    sns.kdeplot(
                        col_values,
                        row_values,
                        ax=axis,
                        shade=True,
                        levels=5,
                        lw=1.0,
                    )
                else:
                    axis.hist2d(col_values, row_values, bins=bins)
            elif row_idx == col_idx:
                axis.hist(
                    row_values,
                    color=[0.2, 0.2, 0.6] if style == "Shade" else "k",
                    bins=bins,
                )
                axis.yaxis.set_ticks([])
            else:
                axis.axis("off")

            # The x axis shows the column parameter and the y axis shows the
            # row parameter. Label only the bottom row and the first column.
            x_param_label = col_param_name if label_param_string else str(
                col_idx + 1)
            y_param_label = row_param_name if label_param_string else str(
                row_idx + 1)
            if row_idx == last_row:
                axis.set_xlabel(get_plot_text_dict(x_param_label),
                                fontsize=label_font_size,
                                labelpad=3)
            if col_idx == 0:
                axis.set_ylabel(get_plot_text_dict(y_param_label),
                                fontsize=label_font_size)

    # Save
    plotter.save_figure(fig, filename=file_name, dpi_request=dpi_request)