def fit_all_case_data(num_procs=4):
    """Run ``fit_population`` for every region with case and population data.

    A region is fitted only when it appears both in the case counts returned
    by ``parse_tsv()`` and in the population data returned by
    ``load_population_data()``. The fits are distributed over a process pool.

    Args:
        num_procs: Number of worker processes in the pool.

    Returns:
        A dict mapping each fitted region to the parameter dict returned by
        ``fit_population``, or to ``None`` when the fit returned ``None`` or
        its ``'logInitial'`` entry is not finite.
    """
    print(f"Pooling with {num_procs} processors")
    case_counts = parse_tsv()
    scenario_data = load_population_data()
    age_distributions = load_distribution()

    # One argument list per region. The trailing False makes fit_population
    # return a (region, params) pair, as unpacked below.
    # NOTE(review): the script entry point passes True and unpacks three
    # values -- confirm the exact meaning of this flag in fit_population.
    tasks = [
        [
            region,
            case_counts[region],
            scenario_data[region],
            age_distributions[scenario_data[region]['ages']],
            False,
        ]
        for region in case_counts
        if region in scenario_data
    ]

    # The context manager releases the worker processes once the (blocking)
    # map has returned, instead of leaving the pool open indefinitely.
    with multi.Pool(num_procs) as pool:
        results = pool.map(fit_population, tasks)

    results_dict = {}
    for region, fit in results:
        # Failed fits and fits with a non-finite initial case count are
        # both recorded as None.
        usable = fit is not None and np.isfinite(fit['logInitial'])
        results_dict[region] = fit if usable else None
    return results_dict
def generate(output_json, num_procs=1, recalculate=False):
    """Assemble scenarios from the stored fit results and write them as JSON.

    Two fit-parameter files under ``BASE_PATH`` are processed in turn:
    ``FIT_PARAMETERS`` and ``'fit_parameters_1stwave.json'``. The first one is
    regenerated with ``fit_all_case_data`` when ``recalculate`` is true or the
    file does not exist; the first-wave file is only ever read.

    Args:
        output_json: Path of the JSON file the scenario array is written to.
        num_procs: Number of processes passed to ``fit_all_case_data``.
        recalculate: Force a refit of the non-first-wave parameters.
    """
    import re

    scenarios = []
    case_counts = parse_tsv()
    scenario_data = load_population_data()

    for fname in (FIT_PARAMETERS, 'fit_parameters_1stwave.json'):
        first_wave = '1stwave' in fname
        print("reading file", fname)
        fit_fname = os.path.join(BASE_PATH, fname)

        # Only the non-first-wave parameters can be recomputed here.
        refit = (not first_wave) and (recalculate
                                      or not os.path.isfile(fit_fname))
        if refit:
            results = fit_all_case_data(num_procs)
            with open(fit_fname, 'w') as fh:
                json.dump(results, fh)
        else:
            with open(fit_fname, 'r') as fh:
                results = json.load(fh)

        for region, population in scenario_data.items():
            # Regions without a usable fit, and all 'FRA-' regions, are
            # left out.
            if region not in results:
                continue
            fit = results[region]
            if fit is None or region.startswith('FRA-'):
                continue

            if first_wave:
                # First-wave entries get these three values set here
                # before the skip test below.
                fit['logInitial'] = np.log(fit['initialCases'])
                fit['tMax'] = '2020-08-31'
                fit['seroprevalence'] = 0.0
                # skip if a small region or not fit to case data
                # (no 'containment_start')
                small_subregion = (re.match('[A-Z][A-Z][A-Z]-', region)
                                   and fit['initialCases'] < 100)
                if small_subregion or 'containment_start' not in fit:
                    continue
            elif not np.isfinite(fit['logInitial']):
                continue

            has_case_data = region in case_counts
            scenario = AllParams(
                **population,
                tMin=fit['tMin'],
                tMax=fit['tMax'],
                cases_key=region if has_case_data else 'None')

            if first_wave:
                # A single intervention starting at the fitted containment
                # date and ending one day after the simulation range.
                containment_start = datetime.strptime(
                    fit['containment_start'], '%Y-%m-%d').date()
                scenario.mitigation.mitigation_intervals = [
                    MitigationInterval(
                        name="Intervention 1",
                        tMin=containment_start,
                        id=uuid4(),
                        tMax=(scenario.simulation.simulation_time_range.end
                              + timedelta(1)),
                        color=mitigation_colors.get("Intervention 1",
                                                    "#cccccc"),
                        mitigationValue=round(100 * fit['efficacy']))
                ]
            elif has_case_data:
                set_mitigation(scenario, fit.get('mitigations', []))
            else:
                scenario.mitigation.mitigation_intervals = []

            # The last interval, if any, is set to end one day after tMax.
            intervals = scenario.mitigation.mitigation_intervals
            if len(intervals):
                last_day = datetime.strptime(fit['tMax'], '%Y-%m-%d').date()
                intervals[-1].time_range.end = last_day + timedelta(1)

            scenario.population.seroprevalence = round(
                100 * fit['seroprevalence'], 2)
            scenario.population.initial_number_of_cases = int(
                round(np.exp(fit['logInitial'])))

            scenario_name = f"[1st wave] {region}" if first_wave else region
            scenarios.append(ScenarioData(scenario, scenario_name))

    with open(output_json, "w+") as fd:
        output = ScenarioArray(scenarios)
        output.marshalJSON(fd)
with open(output_json, "w+") as fd: output = ScenarioArray(scenarios) output.marshalJSON(fd) if __name__ == '__main__': generate('test.json', recalculate=False) from scripts.test_fitting_procedure import generate_data, check_fit from scripts.model import trace_ages, get_IFR from matplotlib import pyplot as plt case_counts = parse_tsv() scenario_data = load_population_data() age_distributions = load_distribution() # region = 'JPN-Kagawa' region = 'United States of America' # region = 'Germany' region = 'Switzerland' region = 'USA-Texas' age_dis = age_distributions[scenario_data[region]['ages']] region, p, fit_params = fit_population( (region, case_counts[region], scenario_data[region], age_dis, True)) model_data = generate_data(fit_params) model_cases = model_data['cases'][7:] - model_data['cases'][:-7] model_deaths = model_data['deaths'][7:] - model_data['deaths'][:-7] model_time = fit_params.time[7:]