Exemplo n.º 1
0
def fit_all_case_data(num_procs=4):
    """Fit every region that has both case counts and population data.

    Each region is fitted by ``fit_population`` in a multiprocessing pool.
    Only regions present in both ``parse_tsv()`` and
    ``load_population_data()`` are fitted.

    Args:
        num_procs: number of worker processes in the pool.

    Returns:
        dict mapping each region key returned by ``fit_population`` to its
        fitted-parameter dict. The value is ``None`` when the fit returned
        ``None`` or produced a non-finite ``'logInitial'``.
    """
    print(f"Pooling with {num_procs} processors")
    case_counts = parse_tsv()
    scenario_data = load_population_data()
    age_distributions = load_distribution()

    # One argument list per region. The trailing False is forwarded
    # unchanged to fit_population, exactly as before.
    fit_args = [
        [
            region,
            case_counts[region],
            scenario_data[region],
            age_distributions[scenario_data[region]['ages']],
            False,
        ]
        for region in case_counts
        if region in scenario_data
    ]

    # The context manager guarantees the worker processes are torn down,
    # even if a fit raises. Previously the pool was never closed or joined.
    with multi.Pool(num_procs) as pool:
        results = pool.map(fit_population, fit_args)

    results_dict = {}
    for region, fit in results:
        # Keep only usable fits. Failed fits and fits with a NaN/inf
        # initial-case estimate are recorded as None.
        if fit is not None and np.isfinite(fit['logInitial']):
            results_dict[region] = fit
        else:
            results_dict[region] = None

    return results_dict
Exemplo n.º 2
0
def generate(output_json, num_procs=1, recalculate=False):
    """Build scenario presets from fitted parameters and write them as JSON.

    Two parameter files under ``BASE_PATH`` are processed in order:

    - ``FIT_PARAMETERS``: the current fit. It is recomputed with
      :func:`fit_all_case_data` and written back to disk when ``recalculate``
      is true or the file does not exist yet.
    - ``'fit_parameters_1stwave.json'``: the first-wave fit. It is always
      read from disk. Its scenarios get a single "Intervention 1" mitigation
      and are named ``"[1st wave] <region>"``.

    Regions without a usable fit, and regions whose key starts with
    ``'FRA-'``, are skipped.

    Args:
        output_json: path of the file the ``ScenarioArray`` is written to.
        num_procs: worker processes passed to ``fit_all_case_data``.
        recalculate: force refitting of the current-wave parameters.
    """
    import re

    # Keys such as "USA-Texas": three upper-case letters followed by a dash.
    # Compiled once, as a raw string, instead of on every loop iteration.
    subregion_pattern = re.compile(r'[A-Z][A-Z][A-Z]-')

    scenarios = []
    case_counts = parse_tsv()
    scenario_data = load_population_data()

    for fname in (FIT_PARAMETERS, 'fit_parameters_1stwave.json'):
        first_wave = '1stwave' in fname
        print("reading file", fname)
        fit_fname = os.path.join(BASE_PATH, fname)
        # Only the current-wave file is ever (re)generated here.
        if not first_wave and (recalculate or not os.path.isfile(fit_fname)):
            results = fit_all_case_data(num_procs)
            with open(fit_fname, 'w') as fh:
                json.dump(results, fh)
        else:
            with open(fit_fname, 'r') as fh:
                results = json.load(fh)

        for region in scenario_data:
            fit = results.get(region)
            if fit is None or region.startswith('FRA-'):
                continue
            if first_wave:
                # Bring first-wave entries into the shape used below. This
                # mutates the loaded entry in place.
                fit['logInitial'] = np.log(fit['initialCases'])
                fit['tMax'] = '2020-08-31'
                fit['seroprevalence'] = 0.0
                # Skip small prefixed sub-regions and regions that were not
                # fitted to case data (no 'containment_start').
                if ((subregion_pattern.match(region)
                        and fit['initialCases'] < 100)
                        or 'containment_start' not in fit):
                    continue
            elif not np.isfinite(fit['logInitial']):
                continue

            scenario = AllParams(**scenario_data[region],
                                 tMin=fit['tMin'],
                                 tMax=fit['tMax'],
                                 cases_key=region
                                 if region in case_counts else 'None')
            if first_wave:
                scenario.mitigation.mitigation_intervals = [
                    MitigationInterval(
                        name="Intervention 1",
                        tMin=datetime.strptime(fit['containment_start'],
                                               '%Y-%m-%d').date(),
                        id=uuid4(),
                        tMax=scenario.simulation.simulation_time_range.end +
                        timedelta(1),
                        color=mitigation_colors.get("Intervention 1",
                                                    "#cccccc"),
                        mitigationValue=round(100 * fit['efficacy']))
                ]
            elif region in case_counts:
                set_mitigation(scenario, fit.get('mitigations', []))
            else:
                scenario.mitigation.mitigation_intervals = []

            intervals = scenario.mitigation.mitigation_intervals
            if intervals:
                # The last intervention ends one day after the fit's tMax.
                intervals[-1].time_range.end = datetime.strptime(
                    fit['tMax'], '%Y-%m-%d').date() + timedelta(1)
            scenario.population.seroprevalence = round(
                100 * fit['seroprevalence'], 2)
            scenario.population.initial_number_of_cases = int(
                round(np.exp(fit['logInitial'])))

            scenario_name = f"[1st wave] {region}" if first_wave else region
            scenarios.append(ScenarioData(scenario, scenario_name))

    with open(output_json, "w+") as fd:
        output = ScenarioArray(scenarios)
        output.marshalJSON(fd)
Exemplo n.º 3
0
    # Serialize all collected scenarios into output_json via
    # ScenarioArray.marshalJSON.
    with open(output_json, "w+") as fd:
        output = ScenarioArray(scenarios)
        output.marshalJSON(fd)


if __name__ == '__main__':
    # Debug driver. It first regenerates the scenario presets into
    # test.json (recalculate=False reuses the fit-parameter file when it
    # exists). It then refits a single region and builds model curves
    # from the fitted parameters.

    generate('test.json', recalculate=False)

    # NOTE(review): check_fit, trace_ages, get_IFR and plt are not used in
    # the lines visible here; presumably they are used further below.
    from scripts.test_fitting_procedure import generate_data, check_fit
    from scripts.model import trace_ages, get_IFR
    from matplotlib import pyplot as plt

    case_counts = parse_tsv()
    scenario_data = load_population_data()
    age_distributions = load_distribution()
    # Only the LAST `region = ...` assignment below takes effect. The
    # earlier ones are dead stores, presumably kept as quick alternatives.
    # region = 'JPN-Kagawa'
    region = 'United States of America'
    # region = 'Germany'
    region = 'Switzerland'
    region = 'USA-Texas'
    age_dis = age_distributions[scenario_data[region]['ages']]
    # Same argument layout as in fit_all_case_data, but with a trailing
    # True instead of False. Here the result is unpacked as a 3-tuple.
    region, p, fit_params = fit_population(
        (region, case_counts[region], scenario_data[region], age_dis, True))

    model_data = generate_data(fit_params)
    # 7-step differences x[t] - x[t-7]. This presumably turns cumulative
    # model counts into weekly increments (TODO confirm). model_time drops
    # the first 7 points so it stays aligned with the differenced series.
    model_cases = model_data['cases'][7:] - model_data['cases'][:-7]
    model_deaths = model_data['deaths'][7:] - model_data['deaths'][:-7]
    model_time = fit_params.time[7:]