from autumn.tool_kit import schema_builder as sb
from typing import List, Dict


def test_schema_builder():
    """Check that ``sb.build_schema`` turns INPUT_SCHEMA into EXPECTED_SCHEMA.

    INPUT_SCHEMA holds schema_builder type descriptors keyed by field name;
    EXPECTED_SCHEMA is the plain-dict schema they should expand into.
    """
    built = sb.build_schema(**INPUT_SCHEMA)
    wanted = EXPECTED_SCHEMA
    assert built == wanted


# Keyword arguments passed to ``sb.build_schema`` by ``test_schema_builder``.
# Each value is a schema_builder type descriptor for the field of the same name.
INPUT_SCHEMA = dict(
    region=sb.Nullable(str),
    translations=sb.DictGeneric(str, str),
    outputs_to_plot=sb.List(sb.Dict(name=str)),
    pop_distribution_strata=sb.List(str),
    prevalence_combos=sb.List(sb.List(str)),
    input_function=sb.Dict(start_time=float, func_names=sb.List(str)),
    parameter_category_values=sb.Dict(time=float, param_names=sb.List(str)),
)

EXPECTED_SCHEMA = {
    "region": {
        "type": "string",
        "nullable": True
    },
    "translations": {
        "type": "dict",
        "valuesrules": {
            "type": "string"
        },
        "keysrules": {
            "type": "string"
# Example #2
from autumn.tool_kit.scenarios import Scenario
from autumn.tool_kit import schema_builder as sb
from autumn.tool_kit.uncertainty import export_mcmc_quantiles
from autumn.db.database import Database

from .plotter import Plotter, COLOR_THEME

# Logger for this module, named after its dotted import path.
logger = logging.getLogger(name=__name__)

# Validator for output-plotting configuration data.
validate_plot_config = sb.build_validator(
    # String-to-string translation mapping used to format plot titles.
    translations=sb.DictGeneric(str, str),
    # Derived / generated outputs to plot; each entry carries a name plus
    # target times and (nested) target values.
    outputs_to_plot=sb.List(
        sb.Dict(
            name=str,
            target_times=sb.List(float),
            target_values=sb.List(sb.List(float)),
        )
    ),
)


def plot_mcmc_parameter_trace(plotter: Plotter,
                              mcmc_tables: List[pd.DataFrame],
                              param_name: str):
    """
    Plot the prameter traces for each MCMC run.
    """
    _overwrite_non_accepted_mcmc_runs(mcmc_tables, column_name=param_name)
    fig, axis, _, _, _ = plotter.get_figure()
    for idx, table_df in enumerate(mcmc_tables):
        table_df[param_name].plot.line(ax=axis, alpha=0.8, linewidth=0.7)
# Example #3
from autumn.tool_kit import schema_builder as sb

validate_params = sb.build_validator(
    stratify_by=sb.List(str),
    # Country info
    iso3=str,
    region=sb.Nullable(str),
    # Running time.
    start_time=float,
    end_time=float,
    time_step=float,
    # Compartment construction
    compartment_periods=sb.DictGeneric(str, float),
    compartment_periods_calculated=dict,
    # Infectiousness adjustments (not sure where used)
    hospital_props=sb.List(float),
    hospital_props_multiplier=float,
    # mortality parameters
    use_raw_mortality_estimates=bool,
    infection_fatality_props=sb.List(float),
    ifr_double_exp_model_params=dict,
    # Age stratified params
    agegroup_breaks=sb.List(float),
    age_based_susceptibility=sb.DictGeneric(str, float),
    # Clinical status stratified params
    clinical_strata=sb.List(str),
    non_sympt_infect_multiplier=float,
    late_infect_multiplier=sb.Dict(sympt_isolate=float,
                                   hospital_non_icu=float,
                                   icu=float),
    icu_mortality_prop=float,
# Example #4
from autumn.demography.social_mixing import load_specific_prem_sheet
from autumn.tool_kit.scenarios import Scenario
from autumn.tool_kit import schema_builder as sb

from .plotter import Plotter, COLOR_THEME

# Module-level logger. Named with the dotted module name (``__name__``) rather
# than ``__file__``: a filesystem path is not a dotted logger name, so the
# logger would sit outside the package's logger hierarchy and ignore any
# configuration applied to parent loggers.
logger = logging.getLogger(__name__)


# Validator for output-plotting configuration data.
validate_plot_config = sb.build_validator(
    # String-to-string translation mapping used to format plot titles.
    translations=sb.DictGeneric(str, str),
    # Derived / generated outputs to plot; each entry carries a name plus
    # target times and (nested) target values.
    outputs_to_plot=sb.List(
        sb.Dict(
            name=str,
            target_times=sb.List(float),
            target_values=sb.List(sb.List(float)),
        )
    ),
    # Strata across which the population distribution is plotted.
    pop_distribution_strata=sb.List(str),
    # Prevalence combinations to plot, each given as a list of strings.
    prevalence_combos=sb.List(sb.List(str)),
    # Input functions to visualise over the model time range.
    input_function=sb.Dict(
        start_time=float,
        func_names=sb.List(str),
    ),
    # Parameters whose values are shown across categories at a single time.
    parameter_category_values=sb.Dict(
        time=float,
        param_names=sb.List(str),
    ),
)


def plot_mcmc_parameter_trace(plotter: Plotter, mcmc_tables: List[pd.DataFrame], param_name: str):
    """
    Plot the prameter traces for each MCMC run.
# Example #5
from copy import deepcopy

from summer.model.strat_model import StratifiedModel
from autumn.tool_kit import schema_builder as sb

from .requested_outputs import RequestedOutput

# Validator for the post-processing configuration passed to ``post_process``.
# FIXME: this config representation could be improved somehow.
validate_post_process_config = sb.build_validator(
    # Names of the outputs to generate, e.g.
    # ["prevXinfectiousXamongXage_10Xstrain_sensitive", "distribution_of_strataXstrain"]
    requested_outputs=sb.List(str),
    # Constant multiplier for each generated output, keyed by output name, e.g.
    # {"prevXinfectiousXamongXage_10Xstrain_sensitive": 1.0e5}
    multipliers=sb.DictGeneric(str, float),
    # [compartment, stratification] pairs used to build further requested
    # outputs, e.g. [["infectious", "location"], ["latent", "location"]]
    collated_combos=sb.List(
        sb.List(str)
    ),
)


def post_process(model: StratifiedModel,
                 post_process_config: dict,
                 add_defaults=True):
    """
    Derive generated outputs from a model after the model has run.
    Returns a dict of generated outputs.
    """
    validate_post_process_config(post_process_config)

    # Read config.