def __init__(self, fips, N_samples, t_list, I_initial=1, suppression_policy=None):
    # Caching globally to avoid the relatively significant performance overhead
    # of loading for each county. Assumes module-level sentinels
    # (beds_data = None, population_data = None) declared near the imports.
    global beds_data, population_data
    if beds_data is None or population_data is None:
        beds_data = DHBeds.local().beds()
        population_data = FIPSPopulation.local().population()

    self.fips = fips
    self.agg_level = AggregationLevel.COUNTY if len(self.fips) == 5 else AggregationLevel.STATE
    self.N_samples = N_samples
    self.I_initial = I_initial
    self.suppression_policy = suppression_policy
    self.t_list = t_list

    if self.agg_level is AggregationLevel.COUNTY:
        self.county_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
        self.state_abbr = us.states.lookup(self.county_metadata['state']).abbr
        self.population = population_data.get_county_level('USA', state=self.state_abbr, fips=self.fips)
        # TODO: Some counties do not have hospitals. Likely need to go to HRR level.
        self.beds = beds_data.get_county_level(self.state_abbr, fips=self.fips) or 0
        self.icu_beds = beds_data.get_county_level(self.state_abbr, fips=self.fips, column='icu_beds') or 0
    else:
        self.state_abbr = us.states.lookup(fips).abbr
        self.population = population_data.get_state_level('USA', state=self.state_abbr)
        self.beds = beds_data.get_state_level(self.state_abbr) or 0
        self.icu_beds = beds_data.get_state_level(self.state_abbr, column='icu_beds') or 0
def __init__(
    self,
    state,
    reference_date=datetime(day=1, month=3, year=2020),
    plot_compartments=("HICU", "HGen", "HVent"),
    primary_suppression_policy="suppression_policy__0.5",
):
    self.state = state
    self.reference_date = reference_date
    self.plot_compartments = plot_compartments
    self.primary_suppression_policy = primary_suppression_policy

    # Load the county metadata and extract names for the state.
    county_metadata = load_data.load_county_metadata()
    self.counties = county_metadata[county_metadata["state"].str.lower() == self.state.lower()]["fips"]
    self.ensemble_data_by_county = {
        fips: load_data.load_ensemble_results(fips) for fips in self.counties
    }
    self.county_metadata = county_metadata.set_index("fips")
    self.names = [
        self.county_metadata.loc[fips, "county"].replace(" County", "")
        for fips in self.counties
    ]
    self.filename = os.path.join(
        OUTPUT_DIR, self.state, "reports", f"summary__{self.state}__state_report.pdf")
    self.surge_filename = os.path.join(
        OUTPUT_DIR, self.state, "reports", f"summary__{self.state}__state_surge_report.xlsx")
def generate_state(self, states_only=False):
    """
    Generate the web UI output for each county in a state.

    Parameters
    ----------
    states_only: bool
        If True, only run the state level.
    """
    state_fips = us.states.lookup(self.state).fips
    self.map_fips(state_fips)

    if not states_only:
        df = load_data.load_county_metadata()
        all_fips = df[df['state'].str.lower() == self.state.lower()].fips

        if not self.include_imputed:
            # Filter to FIPS codes with at least one reported case.
            fips_with_cases = self.jhu_local.timeseries() \
                .get_subset(AggregationLevel.COUNTY, country='USA') \
                .get_data(country='USA', state=self.state_abbreviation)
            fips_with_cases = fips_with_cases[fips_with_cases.cases > 0].fips.unique().tolist()
            all_fips = [fips for fips in all_fips if fips in fips_with_cases]

        p = Pool()
        p.map(self.map_fips, all_fips)
        p.close()
def run_state(state, states_only=False):
    """
    Run the R_t inference for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    states_only: bool
        If True, only run the state level.
    """
    state_obj = us.states.lookup(state)
    df = RtInferenceEngine.run_for_fips(state_obj.fips)
    output_path = get_run_artifact_path(state_obj.fips, RunArtifact.RT_INFERENCE_RESULT)
    df.to_json(output_path)

    # Run the counties.
    if not states_only:
        df = load_data.load_county_metadata()
        all_fips = df[df['state'].str.lower() == state_obj.name.lower()].fips.values

        # Something in here doesn't like multiprocessing...
        # p = Pool(2)
        rt_inferences = list(map(RtInferenceEngine.run_for_fips, all_fips))
        # p.close()

        for fips, rt_inference in zip(all_fips, rt_inferences):
            county_output_file = get_run_artifact_path(fips, RunArtifact.RT_INFERENCE_RESULT)
            if rt_inference is not None:
                rt_inference.to_json(county_output_file)
def run_state(state, ensemble_kwargs, states_only=False):
    """
    Run the EnsembleRunner for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    ensemble_kwargs: dict
        Kwargs passed to the EnsembleRunner object.
    states_only: bool
        If True, only run the state level.
    """
    # Run the state level.
    runner = EnsembleRunner(fips=us.states.lookup(state).fips, **ensemble_kwargs)
    runner.run_ensemble()

    if not states_only:
        # Run the county level.
        df = load_data.load_county_metadata()
        all_fips = df[df['state'].str.lower() == state.lower()].fips
        p = Pool()
        f = partial(_run_county, ensemble_kwargs=ensemble_kwargs)
        p.map(f, all_fips)
        p.close()
def __init__(
    self,
    total_cases=50,
    total_deaths=0,
    nonzero_case_datapoints=5,
    nonzero_death_datapoints=0,
):
    self.county_metadata = load_data.load_county_metadata()
    self.df_whitelist = None

    self.total_cases = total_cases
    self.total_deaths = total_deaths
    self.nonzero_case_datapoints = nonzero_case_datapoints
    self.nonzero_death_datapoints = nonzero_death_datapoints
def generate_empirical_distancing_policy_by_state(t_list, state, future_suppression,
                                                  reference_start_date=None):
    """
    Produce a suppression policy at the state level based on Imperial College
    estimates of social distancing programs, combined with county-level
    datasets about their implementation.

    Note: This takes about 250 ms per state, which adds up when running e.g.
    MLE optimization. The bottleneck is computing the suppression policy to
    date, which is done by summing counties. This should be computed once per
    state and lru cached rather than recomputed for each county on every
    call; it could also use numpy directly instead of pandas.

    Parameters
    ----------
    t_list: array-like
        List of times to interpolate over.
    state: str
        State full name to look up interventions against.
    future_suppression: float
        The suppression level to apply on an ongoing basis after today, and
        going backward as the lockdown / stay-at-home efficacy.
    reference_start_date: pd.Timestamp
        Start date used as a reference to shift t_list.

    Returns
    -------
    suppression_model: callable
        suppression_model(t) returns the suppression level at time t.
    """
    county_metadata = load_data.load_county_metadata()
    counties_fips = county_metadata[county_metadata.state == state].fips.unique()
    if reference_start_date is None:
        reference_start_date = min([infer_t0(fips) for fips in counties_fips])

    # Aggregate the counties to the state level, weighted by population.
    weight = county_metadata.loc[county_metadata.state == state, "total_population"].values
    weight = weight / weight.sum()

    results = []
    for fips in counties_fips:
        suppression_policy = generate_empirical_distancing_policy(
            fips=fips,
            t_list=t_list,
            future_suppression=future_suppression,
            reference_start_date=reference_start_date,
        )
        results.append(suppression_policy(t_list).clip(max=1, min=0))
    results_for_state = (np.vstack(results).T * weight).sum(axis=1)

    return interp1d(t_list, results_for_state, fill_value="extrapolate")
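# Hedged usage sketch (not part of the original source; the state name,
# horizon, and suppression level below are illustrative). The function
# returns a scipy interp1d, so the policy can be evaluated at any model time:
def _example_state_policy_usage():
    t_list = np.linspace(0, 120, 121)
    policy = generate_empirical_distancing_policy_by_state(
        t_list, state="California", future_suppression=0.5)
    # Suppression multiplier applied to R0 on day 45, clipped to [0, 1].
    return policy(45.0)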
def generate_surge_spreadsheet(self):
    """
    Produce a spreadsheet summarizing peaks, with one sheet per mitigation
    policy and one row per county.
    """
    df = load_data.load_county_metadata()
    all_fips = df[df['state'].str.lower() == self.state.lower()].fips
    all_data = {fips: load_data.load_ensemble_results(fips) for fips in all_fips}
    df = df.set_index('fips')

    records = []
    for fips, ensembles in all_data.items():
        county_name = df.loc[fips]['county']
        t0 = fit_results.load_t0(fips)
        for suppression_policy, ensemble in ensembles.items():
            county_record = dict(
                county_name=county_name,
                county_fips=fips,
                mitigation_policy=policy_to_mitigation(suppression_policy)
            )
            for compartment in ['HGen', 'general_admissions_per_day', 'HICU',
                                'icu_admissions_per_day', 'total_new_infections',
                                'direct_deaths_per_day', 'total_deaths', 'D']:
                compartment_name = compartment_to_name_map[compartment]
                county_record[compartment_name + ' Peak Value Mean'] = '%.0f' % ensemble[compartment]['peak_value_mean']
                county_record[compartment_name + ' Peak Value Median'] = '%.0f' % ensemble[compartment]['peak_value_ci50']
                county_record[compartment_name + ' Peak Value CI25'] = '%.0f' % ensemble[compartment]['peak_value_ci25']
                county_record[compartment_name + ' Peak Value CI75'] = '%.0f' % ensemble[compartment]['peak_value_ci75']
                county_record[compartment_name + ' Peak Time Median'] = (t0 + timedelta(days=ensemble[compartment]['peak_time_ci50'])).date().isoformat()
                # Leaving for now...
                # if 'surge_start' in ensemble[compartment]:
                #     if not np.isnan(np.nanmean(ensemble[compartment]['surge_start'])):
                #         county_record[compartment_name + ' Surge Start Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_start']))).date().isoformat()
                #         county_record[compartment_name + ' Surge End Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_end']))).date().isoformat()
            records.append(county_record)

    df = pd.DataFrame(records)
    writer = pd.ExcelWriter(self.surge_filename, engine='xlsxwriter')
    for policy in df['mitigation_policy'].unique()[::-1]:
        # Drop bookkeeping columns before writing each policy's sheet.
        df[df['mitigation_policy'] == policy] \
            .drop(['mitigation_policy', 'county_fips'], axis=1) \
            .to_excel(writer, sheet_name=policy)
    writer.save()
def __init__(self, fips, N_samples, t_list, I_initial=1, suppression_policy=None):
    self.fips = fips
    self.agg_level = AggregationLevel.COUNTY if len(self.fips) == 5 else AggregationLevel.STATE
    self.N_samples = N_samples
    self.I_initial = I_initial
    self.suppression_policy = suppression_policy
    self.t_list = t_list

    if self.agg_level is AggregationLevel.COUNTY:
        self.county_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
        self.state_abbr = us.states.lookup(self.county_metadata['state']).abbr
        self._latest = combined_datasets.get_us_latest_for_fips(self.fips)
    else:
        self.state_abbr = us.states.lookup(fips).abbr
        self._latest = combined_datasets.get_us_latest_for_state(self.state_abbr)
def run_state(state, ensemble_kwargs):
    """
    Run the EnsembleRunner for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    ensemble_kwargs: dict
        Kwargs passed to the EnsembleRunner object.
    """
    df = load_data.load_county_metadata()
    all_fips = df[df['state'].str.lower() == state.lower()].fips
    p = Pool()
    f = partial(_run_county, ensemble_kwargs=ensemble_kwargs)
    p.map(f, all_fips)
    p.close()
def load_t0(fips):
    """
    Load the simulation start time by county.

    Parameters
    ----------
    fips: str
        County FIPS code.

    Returns
    -------
    : datetime
        Start time t0, defined by t0(C=1): the date at which the cumulative
        case count reaches one.
    """
    county_metadata = load_county_metadata().set_index('fips')
    state = county_metadata.loc[fips]['state']
    fit_results = os.path.join(OUTPUT_DIR, state, 'data',
                               f'summary__{state}_imputed_start_times.pkl')
    # Convert the stored pandas Timestamp to a plain datetime.
    return datetime.fromtimestamp(
        pd.read_pickle(fit_results).set_index('fips').loc[fips]['t0_date'].timestamp())
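# Hedged usage sketch (not part of the original source; the FIPS code and the
# 42-day offset are illustrative). t0 anchors model time, so day offsets from
# the SEIR output convert to calendar dates like this:
def _example_t0_usage(fips='06075'):
    t0 = load_t0(fips)
    peak_date = t0 + timedelta(days=42)  # e.g. a peak 42 model-days after t0
    return peak_date.date().isoformat()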
def run_state(state):
    """
    Run the fitter for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    """
    df = load_data.load_county_metadata()
    all_fips = df[df['state'].str.lower() == state.lower()].fips
    p = Pool()
    fit_results = p.map(fit_county_model, all_fips)
    output_file = os.path.join(OUTPUT_DIR, state.title(), 'data',
                               f'summary_{state}__mle_fit_results.json')
    pd.DataFrame(fit_results).to_json(output_file)
    p.map(plot_inferred_result, fit_results)
    p.close()
def load_t0(fips):
    """
    Load the simulation start time by county.

    Parameters
    ----------
    fips: str
        County FIPS code.

    Returns
    -------
    : datetime
        Start time t0, defined by t0(C=1): the date at which the cumulative
        case count reaches one.
    """
    county_metadata = load_county_metadata().set_index("fips")
    state = county_metadata.loc[fips]["state"]
    fit_results = os.path.join(
        OUTPUT_DIR, "pyseir", state, "data",
        f"summary__{state}_imputed_start_times.pkl")
    # Convert the stored pandas Timestamp to a plain datetime.
    return datetime.fromtimestamp(
        pd.read_pickle(fit_results).set_index("fips").loc[fips]["t0_date"].timestamp())
def __init__(self, fips, N_samples, t_list, I_initial=1, suppression_policy=None):
    self.fips = fips
    self.N_samples = N_samples
    self.I_initial = I_initial
    self.suppression_policy = suppression_policy
    self.t_list = t_list

    county_metadata = load_data.load_county_metadata()
    hospital_bed_data = load_data.load_hospital_data()
    # TODO: Some counties do not have hospitals. Likely need to go to HRR level.
    hospital_bed_data = hospital_bed_data[[
        'fips', 'num_licensed_beds', 'num_staffed_beds', 'num_icu_beds',
        'bed_utilization', 'potential_increase_in_bed_capac'
    ]].groupby('fips').sum()
    self.county_metadata_merged = county_metadata.merge(
        hospital_bed_data, on='fips', how='left'
    ).set_index('fips').loc[fips].to_dict()
def __init__(self, fips, N_samples, t_list, I_initial=1, suppression_policy=None):
    # Caching globally to avoid the relatively significant performance overhead
    # of loading for each county. Assumes module-level sentinels
    # (beds_data = None, population_data = None) declared near the imports.
    global beds_data, population_data
    if beds_data is None or population_data is None:
        beds_data = CovidCareMapBeds.local().beds()
        population_data = FIPSPopulation.local().population()

    self.fips = fips
    self.agg_level = AggregationLevel.COUNTY if len(self.fips) == 5 else AggregationLevel.STATE
    self.N_samples = N_samples
    self.I_initial = I_initial
    self.suppression_policy = suppression_policy
    self.t_list = t_list

    if self.agg_level is AggregationLevel.COUNTY:
        self.county_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
        self.state_abbr = us.states.lookup(self.county_metadata['state']).abbr
        self.population = population_data.get_record_for_fips(fips=self.fips)[CommonFields.POPULATION]
        # TODO: Some counties do not have hospitals. Likely need to go to HRR level.
        self._beds_data = beds_data.get_record_for_fips(fips)
    else:
        self.state_abbr = us.states.lookup(fips).abbr
        self.population = population_data.get_record_for_state(self.state_abbr)[CommonFields.POPULATION]
        self._beds_data = beds_data.get_record_for_state(self.state_abbr)
def plot_inferred_result(fit_results):
    """
    Plot the results of an MLE inference.
    """
    fips = fit_results['fips']
    county_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
    times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=ref_date)
    if observed_new_cases.sum() < 5:
        logging.warning(f"{county_metadata['county']} has fewer than 5 cases. Aborting plot.")
        return
    else:
        logging.info(f"Plotting MLE fits for {county_metadata['county']}")

    R0, t0, eps = fit_results['R0'], fit_results['t0'], fit_results['eps']
    model = SEIRModel(
        R0=R0,
        suppression_policy=suppression_policies.generate_empirical_distancing_policy(
            t_list, fips, future_suppression=eps),
        **get_average_SEIR_parameters(fit_results['fips']))
    model.run()

    data_dates = [ref_date + timedelta(days=t) for t in times]
    model_dates = [ref_date + timedelta(days=t + fit_results['t0']) for t in t_list]

    plt.figure(figsize=(10, 8))
    plt.errorbar(data_dates, observed_new_cases, marker='o', linestyle='',
                 label='Observed Cases Per Day')
    plt.errorbar(data_dates, observed_new_deaths, yerr=np.sqrt(observed_new_deaths),
                 marker='o', linestyle='', label='Observed Deaths')
    plt.plot(model_dates, model.results['total_new_infections'],
             label='Estimated Total New Infections Per Day')
    plt.plot(model_dates, model.gamma * model.results['total_new_infections'],
             label='Symptomatic Model Cases Per Day')
    plt.plot(model_dates, model.results['direct_deaths_per_day'],
             label='Model Deaths Per Day')
    plt.yscale('log')
    plt.ylim(bottom=0.9)
    plt.xlim(data_dates[0], data_dates[-1] + timedelta(days=90))
    plt.xticks(rotation=30)
    plt.legend(loc=1)
    plt.grid(which='both', alpha=.3)
    plt.title(county_metadata['county'])

    for i, (k, v) in enumerate(fit_results.items()):
        if k not in ('fips', 't0_date', 'county', 'state'):
            plt.text(.025, .97 - 0.04 * i, f'{k}={v:1.3f}',
                     transform=plt.gca().transAxes, fontsize=12)
        else:
            plt.text(.025, .97 - 0.04 * i, f'{k}={v}',
                     transform=plt.gca().transAxes, fontsize=12)

    output_file = os.path.join(
        OUTPUT_DIR, fit_results['state'].title(), 'reports',
        f'{fit_results["state"]}__{fit_results["county"]}__{fit_results["fips"]}__mle_fit_results.pdf')
    plt.savefig(output_file)
def __init__(self,
             fips,
             ref_date=datetime(year=2020, month=1, day=1),
             min_deaths=2,
             n_years=1,
             cases_to_deaths_err_factor=.5,
             hospital_to_deaths_err_factor=.5,
             percent_error_on_max_observation=0.5,
             with_age_structure=False):
    # Seed the random state. It is unclear whether this propagates to the
    # Minuit optimizer.
    np.random.seed(seed=42)

    self.fips = fips
    self.ref_date = ref_date
    self.min_deaths = min_deaths
    self.t_list = np.linspace(0, int(365 * n_years), int(365 * n_years) + 1)
    self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
    self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
    self.percent_error_on_max_observation = percent_error_on_max_observation
    self.t0_guess = 60
    self.with_age_structure = with_age_structure

    if len(fips) == 2:  # State FIPS codes are 2 digits.
        self.agg_level = AggregationLevel.STATE
        self.state_obj = us.states.lookup(self.fips)
        self.state = self.state_obj.name

        self.times, self.observed_new_cases, self.observed_new_deaths = \
            load_data.load_new_case_data_by_state(self.state, self.ref_date)

        self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
            load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
        self.display_name = self.state
    else:
        self.agg_level = AggregationLevel.COUNTY
        geo_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
        state = geo_metadata['state']
        self.state_obj = us.states.lookup(state)
        county = geo_metadata['county']
        if county:
            self.display_name = county + ', ' + state
        else:
            self.display_name = state
        # TODO: Swap for new data source.
        self.times, self.observed_new_cases, self.observed_new_deaths = \
            load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
        self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
            load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

    self.cases_stdev, self.hosp_stdev, self.deaths_stdev = self.calculate_observation_errors()

    self.set_inference_parameters()

    self.model_fit_keys = ['R0', 'eps', 't_break', 'log10_I_initial']

    self.SEIR_kwargs = self.get_average_seir_parameters()
    self.fit_results = None
    self.mle_model = None

    self.chi2_deaths = None
    self.chi2_cases = None
    self.chi2_hosp = None
    self.dof_deaths = None
    self.dof_cases = None
    self.dof_hosp = None
def generate_start_times_for_state(state, generate_report=False):
    """
    Generate imputed start dates for each county.

    Parameters
    ----------
    state: str
        State to model counties of.
    generate_report: bool
        If True, generate summary plots.
    """
    metadata = load_data.load_county_metadata()
    state_dir = os.path.join(OUTPUT_DIR, "pyseir", state)
    os.makedirs(state_dir, exist_ok=True)
    os.makedirs(os.path.join(state_dir, "reports"), exist_ok=True)
    os.makedirs(os.path.join(state_dir, "data"), exist_ok=True)

    logging.info(f"Imputing start times for {state.capitalize()}")
    state_fips = us.states.lookup(state).fips
    metadata["state_fips"] = metadata["fips"].apply(lambda x: x[:2])
    counties = metadata[metadata["state_fips"] == state_fips].fips

    if len(counties) == 0:
        raise ValueError(f"No entries for state {state}.")

    # Fit an exponential model to extract T0.
    f = partial(_fit_fips, generate_report=generate_report)
    with Pool(maxtasksperchild=1) as p:
        fips_to_fit_map = {
            fips: val for fips, val in zip(counties.values, p.map(f, counties.values))
        }

    # ------------------------------------------------------------------
    # ML to impute start times for counties with no data, based on
    # population density.
    # ------------------------------------------------------------------
    # Merge in county-level metadata.
    county_fits = (
        pd.DataFrame.from_dict(fips_to_fit_map, orient="index")
        .reset_index()
        .rename({"index": "fips"}, axis=1)
    )
    merged = county_fits.merge(metadata, on="fips")
    merged["days_from_2020_01_01"] = (
        merged.t0_date - datetime.fromisoformat("2020-01-01")
    ).dt.days

    samples_with_data = merged["days_from_2020_01_01"].notnull()
    samples_with_no_data = merged["days_from_2020_01_01"].isnull()

    if samples_with_no_data.any():
        X = np.nan_to_num(
            np.log(
                merged[["population_density", "housing_density", "total_population"]][
                    samples_with_data
                ]
            )
        )
        X_predict = np.nan_to_num(
            np.log(
                merged[["population_density", "housing_density", "total_population"]][
                    samples_with_no_data
                ]
            )
        )

        # Test a few regressions.
        for estimator in [LinearRegression(), RandomForestRegressor(), BayesianRidge()]:
            cv_result = cross_validate(
                estimator,
                X=X,
                y=merged["days_from_2020_01_01"][samples_with_data],
                scoring="r2",
                cv=2,
            )
            logging.info(
                f'{estimator.__class__.__name__} CV r2: {cv_result["test_score"].mean()}'
            )

        # Train the best model and impute the missing times.
        best_model = BayesianRidge()
        best_model.fit(X=X, y=merged["days_from_2020_01_01"][samples_with_data])
        merged.loc[samples_with_no_data, "days_from_2020_01_01"] = best_model.predict(X_predict)
        merged.loc[samples_with_no_data, "t0_date"] = datetime.fromisoformat(
            "2020-01-01"
        ) + np.array([timedelta(days=t) for t in best_model.predict(X_predict)])

    merged.loc[samples_with_no_data, "imputed_start_time"] = True
    merged.loc[samples_with_data, "imputed_start_time"] = False
    merged.loc[samples_with_data, "doubling_rate_days"] = np.log(2) * merged["model_params"][
        samples_with_data
    ].apply(lambda x: x["scale"])
    merged.to_pickle(os.path.join(state_dir, "data", f"summary__{state}_imputed_start_times.pkl"))

    if generate_report:
        # Plot start time vs. population density.
        plt.figure(figsize=(14, 4))
        for i, x in enumerate(("population_density", "housing_density", "total_population")):
            plt.subplot(1, 3, i + 1)
            plt.title(state)
            sns.jointplot(
                x=np.log10(merged[x]), y="days_from_2020_01_01", data=merged, kind="reg", height=5
            )
            plt.xlabel("log10 Population Density")
        plt.savefig(
            os.path.join(state_dir, "reports", f"summary__{state}__population_density.pdf"),
            bbox_inches="tight",
        )
        plt.close()

        # Plot doubling time by population density.
        # TODO: Impute doubling time.
        sns.jointplot(
            x=np.log10(merged.population_density), y=merged.doubling_rate_days, kind="reg", height=10
        )
        plt.xlabel("Log10 Population Density", fontsize=16)
        plt.ylabel("Doubling Time [Days]", fontsize=16)
        plt.grid()
        plt.savefig(
            os.path.join(state_dir, "reports", f"summary__{state}__doubling_time.pdf"),
            bbox_inches="tight",
        )
        plt.close()
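# Hedged aside (not from the original source): the doubling_rate_days line
# above assumes the exponential fit is parameterized as cases(t) ~ exp(t / scale),
# so exp(T_double / scale) = 2 gives T_double = ln(2) * scale. A quick check
# with a hypothetical 4-day e-folding time:
def _example_doubling_time(scale=4.0):
    t_double = np.log(2) * scale  # ~2.77 days
    # The daily growth factor agrees whether written as 2**(1/T_double) or exp(1/scale).
    assert abs(2 ** (1 / t_double) - np.exp(1 / scale)) < 1e-12
    return t_double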
def __init__(
    self,
    fips,
    window_size=InferRtConstants.COUNT_SMOOTHING_WINDOW_SIZE,
    kernel_std=5,
    r_list=np.linspace(0, 10, 501),
    process_sigma=0.05,
    ref_date=datetime(year=2020, month=1, day=1),
    confidence_intervals=(0.68, 0.95),
    min_cases=5,
    min_deaths=5,
    include_testing_correction=True,
):
    # Seed the RNG: param generation used for Xcor in align_time_series has
    # some stochastic FFT elements.
    np.random.seed(InferRtConstants.RNG_SEED)

    self.fips = fips
    self.r_list = r_list
    self.window_size = window_size
    self.kernel_std = kernel_std
    self.process_sigma = process_sigma
    self.ref_date = ref_date
    self.confidence_intervals = confidence_intervals
    self.min_cases = min_cases
    self.min_deaths = min_deaths
    self.include_testing_correction = include_testing_correction

    # Because rounding is disabled, we no longer need high min_deaths or min_cases.
    self.min_cases = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_cases)
    if not InferRtConstants.DISABLE_DEATHS:
        self.min_deaths = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_deaths)

    if len(fips) == 2:  # State FIPS codes are 2 digits.
        self.agg_level = AggregationLevel.STATE
        self.state_obj = us.states.lookup(self.fips)
        self.state = self.state_obj.name

        (
            self.times,
            self.observed_new_cases,
            self.observed_new_deaths,
        ) = load_data.load_new_case_data_by_state(
            self.state,
            self.ref_date,
            include_testing_correction=self.include_testing_correction,
        )
        self.times_raw_new_cases, self.raw_new_cases, _ = load_data.load_new_case_data_by_state(
            self.state, self.ref_date, include_testing_correction=False
        )

        (
            self.hospital_times,
            self.hospitalizations,
            self.hospitalization_data_type,
        ) = load_data.load_hospitalization_data_by_state(
            state=self.state_obj.abbr, t0=self.ref_date
        )
        self.display_name = self.state
    else:
        self.agg_level = AggregationLevel.COUNTY
        self.geo_metadata = (
            load_data.load_county_metadata().set_index("fips").loc[fips].to_dict()
        )
        self.state = self.geo_metadata["state"]
        self.state_obj = us.states.lookup(self.state)
        self.county = self.geo_metadata["county"]
        if self.county:
            self.display_name = self.county + ", " + self.state
        else:
            self.display_name = self.state

        (
            self.times,
            self.observed_new_cases,
            self.observed_new_deaths,
        ) = load_data.load_new_case_data_by_fips(
            self.fips,
            t0=self.ref_date,
            include_testing_correction=self.include_testing_correction,
        )
        (
            self.times_raw_new_cases,
            self.raw_new_cases,
            _,
        ) = load_data.load_new_case_data_by_fips(
            self.fips, t0=self.ref_date, include_testing_correction=False,
        )
        (
            self.hospital_times,
            self.hospitalizations,
            self.hospitalization_data_type,
        ) = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

    self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
    self.raw_new_case_dates = [
        ref_date + timedelta(days=int(t)) for t in self.times_raw_new_cases
    ]
    if self.hospitalization_data_type:
        self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

    self.default_parameters = ParameterEnsembleGenerator(
        fips=self.fips, N_samples=500, t_list=np.linspace(0, 365, 366)
    ).get_average_seir_parameters()

    # Serial period = incubation period + 0.5 * infectious period.
    self.serial_period = (
        1 / self.default_parameters["sigma"] + 0.5 * 1 / self.default_parameters["delta"]
    )

    # If we only receive current hospitalizations, we need to account for
    # the outflow to reconstruct new admissions.
    if (
        self.hospitalization_data_type
        is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS
    ):
        los_general = self.default_parameters["hospitalization_length_of_stay_general"]
        los_icu = self.default_parameters["hospitalization_length_of_stay_icu"]
        hosp_rate_general = self.default_parameters["hospitalization_rate_general"]
        hosp_rate_icu = self.default_parameters["hospitalization_rate_icu"]
        icu_rate = hosp_rate_icu / hosp_rate_general
        flow_out_of_hosp = self.hospitalizations[:-1] * (
            (1 - icu_rate) / los_general + icu_rate / los_icu
        )
        # Reconstruct new admissions per day: change in census plus outflow.
        self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
        self.hospital_dates = self.hospital_dates[1:]
        self.hospital_times = self.hospital_times[1:]

    self.log_likelihood = None

    self.log = structlog.getLogger(Rt_Inference_Target=self.display_name)
    self.log.info(event="Running:")
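# Hedged aside (not from the original source): a toy check of the admissions
# reconstruction above, simplified to a single ward with one length of stay.
# For a census H(t) with discharge rate 1/LOS, the census change is
# admissions minus discharges, so admissions(t) ~= diff(H)(t) + H(t-1) / LOS.
def _example_admissions_reconstruction():
    los = 7.0  # hypothetical length of stay, days
    admissions = np.array([10.0, 10.0, 10.0, 10.0, 10.0])
    census = np.zeros_like(admissions)
    for i, a in enumerate(admissions):  # simple discrete-time census update
        prev = census[i - 1] if i > 0 else 0.0
        census[i] = prev + a - prev / los
    recovered = np.diff(census) + census[:-1] / los
    assert np.allclose(recovered, admissions[1:])
    return recovered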
def __init__(self,
             fips,
             window_size=7,
             kernel_std=2,
             r_list=np.linspace(0, 10, 501),
             process_sigma=0.15,
             ref_date=datetime(year=2020, month=1, day=1),
             confidence_intervals=(0.68, 0.75, 0.90)):
    self.fips = fips
    self.r_list = r_list
    self.window_size = window_size
    self.kernel_std = kernel_std
    self.process_sigma = process_sigma
    self.ref_date = ref_date
    self.confidence_intervals = confidence_intervals

    if len(fips) == 2:  # State FIPS codes are 2 digits.
        self.agg_level = AggregationLevel.STATE
        self.state_obj = us.states.lookup(self.fips)
        self.state = self.state_obj.name
        self.geo_metadata = load_data.load_county_metadata_by_state(self.state).loc[self.state].to_dict()

        self.times, self.observed_new_cases, self.observed_new_deaths = \
            load_data.load_new_case_data_by_state(self.state, self.ref_date)

        self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
            load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
        self.display_name = self.state
    else:
        self.agg_level = AggregationLevel.COUNTY
        self.geo_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
        self.state = self.geo_metadata['state']
        self.state_obj = us.states.lookup(self.state)
        self.county = self.geo_metadata['county']
        if self.county:
            self.display_name = self.county + ', ' + self.state
        else:
            self.display_name = self.state
        # TODO: Swap for new data source.
        self.times, self.observed_new_cases, self.observed_new_deaths = \
            load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
        self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
            load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

    logging.info(f'Running Rt inference for {self.display_name}')

    self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
    if self.hospitalization_data_type:
        self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

    self.default_parameters = ParameterEnsembleGenerator(
        fips=self.fips,
        N_samples=500,
        t_list=np.linspace(0, 365, 366)
    ).get_average_seir_parameters()

    # Serial period = incubation period + 0.5 * infectious period.
    self.serial_period = 1 / self.default_parameters['sigma'] + 0.5 * 1 / self.default_parameters['delta']

    # If we only receive current hospitalizations, we need to account for
    # the outflow to reconstruct new admissions.
    if self.hospitalization_data_type is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
        los_general = self.default_parameters['hospitalization_length_of_stay_general']
        los_icu = self.default_parameters['hospitalization_length_of_stay_icu']
        hosp_rate_general = self.default_parameters['hospitalization_rate_general']
        hosp_rate_icu = self.default_parameters['hospitalization_rate_icu']
        icu_rate = hosp_rate_icu / hosp_rate_general
        flow_out_of_hosp = self.hospitalizations[:-1] * ((1 - icu_rate) / los_general + icu_rate / los_icu)
        # Reconstruct new admissions per day: change in census plus outflow.
        self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
        self.hospital_dates = self.hospital_dates[1:]
        self.hospital_times = self.hospital_times[1:]

    self.log_likelihood = None
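# Hedged aside (not from the original source): a numeric illustration of the
# serial-period formula above, with hypothetical rates. With a 3-day mean
# incubation period (1/sigma) and a 6-day mean infectious period (1/delta),
# the serial period is 3 + 0.5 * 6 = 6 days.
def _example_serial_period(sigma=1 / 3, delta=1 / 6):
    return 1 / sigma + 0.5 * 1 / delta  # 6.0 days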
def __init__(
    self,
    fips,
    ref_date=datetime(year=2020, month=1, day=1),
    min_deaths=2,
    n_years=1,
    cases_to_deaths_err_factor=0.5,
    hospital_to_deaths_err_factor=0.5,
    percent_error_on_max_observation=0.5,
    with_age_structure=False,
):
    # Seed the random state. It is unclear whether this propagates to the
    # Minuit optimizer.
    np.random.seed(seed=42)

    self.fips = fips
    self.ref_date = ref_date
    self.days_since_ref_date = (dt.date.today() - ref_date.date() - timedelta(days=7)).days
    # The end of the second ramp may extend up to this many days past
    # days_since_ref_date without a penalty on the chi2 score.
    self.days_allowed_beyond_ref = 0
    self.min_deaths = min_deaths
    self.t_list = np.linspace(0, int(365 * n_years), int(365 * n_years) + 1)
    self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
    self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
    self.percent_error_on_max_observation = percent_error_on_max_observation
    self.t0_guess = 60
    self.with_age_structure = with_age_structure

    if len(fips) == 2:  # State FIPS codes are 2 digits.
        self.agg_level = AggregationLevel.STATE
        self.state_obj = us.states.lookup(self.fips)
        self.state = self.state_obj.name

        (
            self.times,
            self.observed_new_cases,
            self.observed_new_deaths,
        ) = load_data.load_new_case_data_by_state(self.state, self.ref_date)
        (
            self.hospital_times,
            self.hospitalizations,
            self.hospitalization_data_type,
        ) = load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
        (
            self.icu_times,
            self.icu,
            self.icu_data_type,
        ) = load_data.load_hospitalization_data_by_state(
            self.state_obj.abbr, t0=self.ref_date, category=HospitalizationCategory.ICU)
        self.display_name = self.state
    else:
        self.agg_level = AggregationLevel.COUNTY
        geo_metadata = load_data.load_county_metadata().set_index("fips").loc[fips].to_dict()
        state = geo_metadata["state"]
        self.state_obj = us.states.lookup(state)
        county = geo_metadata["county"]
        if county:
            self.display_name = county + ", " + state
        else:
            self.display_name = state
        # TODO: Swap for new data source.
        (
            self.times,
            self.observed_new_cases,
            self.observed_new_deaths,
        ) = load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
        (
            self.hospital_times,
            self.hospitalizations,
            self.hospitalization_data_type,
        ) = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)
        (
            self.icu_times,
            self.icu,
            self.icu_data_type,
        ) = load_data.load_hospitalization_data(
            self.fips, t0=self.ref_date, category=HospitalizationCategory.ICU)

    self.cases_stdev, self.hosp_stdev, self.deaths_stdev = self.calculate_observation_errors()
    self.set_inference_parameters()

    self.model_fit_keys = [
        "R0",
        "eps",
        "t_break",
        "eps2",
        "t_delta_phases",
        "log10_I_initial",
    ]

    self.SEIR_kwargs = self.get_average_seir_parameters()
    self.fit_results = None
    self.mle_model = None

    self.chi2_deaths = None
    self.chi2_cases = None
    self.chi2_hosp = None
    self.dof_deaths = None
    self.dof_cases = None
    self.dof_hosp = None
def generate_start_times_for_state(state, generate_report=False):
    """
    Generate imputed start dates for each county.

    Parameters
    ----------
    state: str
        State to model counties of.
    generate_report: bool
        If True, generate summary plots.
    """
    metadata = load_data.load_county_metadata()
    state_dir = os.path.join(OUTPUT_DIR, 'pyseir', state)
    os.makedirs(state_dir, exist_ok=True)
    os.makedirs(os.path.join(state_dir, 'reports'), exist_ok=True)
    os.makedirs(os.path.join(state_dir, 'data'), exist_ok=True)

    logging.info(f'Imputing start times for {state.capitalize()}')
    state_fips = us.states.lookup(state).fips
    metadata['state_fips'] = metadata['fips'].apply(lambda x: x[:2])
    counties = metadata[metadata['state_fips'] == state_fips].fips

    if len(counties) == 0:
        raise ValueError(f'No entries for state {state}.')

    # Fit an exponential model to extract T0.
    f = partial(_fit_fips, generate_report=generate_report)
    p = Pool()
    fips_to_fit_map = {fips: val for fips, val in zip(counties.values, p.map(f, counties.values))}
    p.close()

    # ------------------------------------------------------------------
    # ML to impute start times for counties with no data, based on
    # population density.
    # ------------------------------------------------------------------
    # Merge in county-level metadata.
    county_fits = pd.DataFrame.from_dict(fips_to_fit_map, orient='index') \
        .reset_index().rename({'index': 'fips'}, axis=1)
    merged = county_fits.merge(metadata, on='fips')
    merged['days_from_2020_01_01'] = (merged.t0_date - datetime.fromisoformat('2020-01-01')).dt.days

    samples_with_data = merged['days_from_2020_01_01'].notnull()
    samples_with_no_data = merged['days_from_2020_01_01'].isnull()

    if samples_with_no_data.any():
        X = np.nan_to_num(np.log(
            merged[['population_density', 'housing_density', 'total_population']][samples_with_data]))
        X_predict = np.nan_to_num(np.log(
            merged[['population_density', 'housing_density', 'total_population']][samples_with_no_data]))

        # Test a few regressions.
        for estimator in [LinearRegression(), RandomForestRegressor(), BayesianRidge()]:
            cv_result = cross_validate(estimator, X=X,
                                       y=merged['days_from_2020_01_01'][samples_with_data],
                                       scoring='r2', cv=2)
            logging.info(f'{estimator.__class__.__name__} CV r2: {cv_result["test_score"].mean()}')

        # Train the best model and impute the missing times.
        best_model = BayesianRidge()
        best_model.fit(X=X, y=merged['days_from_2020_01_01'][samples_with_data])
        merged.loc[samples_with_no_data, 'days_from_2020_01_01'] = best_model.predict(X_predict)
        merged.loc[samples_with_no_data, 't0_date'] = datetime.fromisoformat('2020-01-01') \
            + np.array([timedelta(days=t) for t in best_model.predict(X_predict)])

    merged.loc[samples_with_no_data, 'imputed_start_time'] = True
    merged.loc[samples_with_data, 'imputed_start_time'] = False
    merged.loc[samples_with_data, 'doubling_rate_days'] = \
        np.log(2) * merged['model_params'][samples_with_data].apply(lambda x: x['scale'])
    merged.to_pickle(os.path.join(state_dir, 'data', f'summary__{state}_imputed_start_times.pkl'))

    if generate_report:
        # Plot start time vs. population density.
        plt.figure(figsize=(14, 4))
        for i, x in enumerate(('population_density', 'housing_density', 'total_population')):
            plt.subplot(1, 3, i + 1)
            plt.title(state)
            sns.jointplot(x=np.log10(merged[x]), y='days_from_2020_01_01',
                          data=merged, kind='reg', height=5)
            plt.xlabel('log10 Population Density')
        plt.savefig(os.path.join(state_dir, 'reports', f'summary__{state}__population_density.pdf'),
                    bbox_inches='tight')
        plt.close()

        # Plot doubling time by population density.
        # TODO: Impute doubling time.
        sns.jointplot(x=np.log10(merged.population_density), y=merged.doubling_rate_days,
                      kind='reg', height=10)
        plt.xlabel('Log10 Population Density', fontsize=16)
        plt.ylabel('Doubling Time [Days]', fontsize=16)
        plt.grid()
        plt.savefig(os.path.join(state_dir, 'reports', f'summary__{state}__doubling_time.pdf'),
                    bbox_inches='tight')
        plt.close()
def generate_surge_spreadsheet(self):
    """
    Produce a summary of peaks, with one record per county and mitigation
    policy.
    """
    df = load_data.load_county_metadata()
    all_fips = load_data.get_all_fips_codes_for_a_state(self.state)
    all_data = {fips: load_data.load_ensemble_results(fips) for fips in all_fips}
    df = df.set_index("fips")

    records = []
    for fips, ensembles in all_data.items():
        county_name = df.loc[fips]["county"]
        t0 = fit_results.load_t0(fips)
        for suppression_policy, ensemble in ensembles.items():
            county_record = dict(
                county_name=county_name,
                county_fips=fips,
                mitigation_policy=policy_to_mitigation(suppression_policy),
            )
            for compartment in [
                "HGen",
                "general_admissions_per_day",
                "HICU",
                "icu_admissions_per_day",
                "total_new_infections",
                "direct_deaths_per_day",
                "total_deaths",
                "D",
            ]:
                compartment_name = compartment_to_name_map[compartment]
                county_record[compartment_name + " Peak Value Mean"] = (
                    "%.0f" % ensemble[compartment]["peak_value_mean"])
                county_record[compartment_name + " Peak Value Median"] = (
                    "%.0f" % ensemble[compartment]["peak_value_ci50"])
                county_record[compartment_name + " Peak Value CI25"] = (
                    "%.0f" % ensemble[compartment]["peak_value_ci25"])
                county_record[compartment_name + " Peak Value CI75"] = (
                    "%.0f" % ensemble[compartment]["peak_value_ci75"])
                county_record[compartment_name + " Peak Time Median"] = (
                    (t0 + timedelta(days=ensemble[compartment]["peak_time_ci50"])).date().isoformat())
                # Leaving for now...
                # if 'surge_start' in ensemble[compartment]:
                #     if not np.isnan(np.nanmean(ensemble[compartment]['surge_start'])):
                #         county_record[compartment_name + ' Surge Start Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_start']))).date().isoformat()
                #         county_record[compartment_name + ' Surge End Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_end']))).date().isoformat()
            records.append(county_record)

    df = pd.DataFrame(records)
    # pandas DataFrames expose to_json, not write_json.
    df.to_json(self.surge_filename)
def fit_county_model(fips):
    """
    Fit the county's current trajectory, using the existing measures. We fit
    only to mortality data if available, else revert to case data. We assume
    a Poisson process generates mortalities at a rate defined by the
    underlying dynamical model.

    TODO @ EC: Add hospitalization data when available.

    Parameters
    ----------
    fips: str
        County FIPS code.

    Returns
    -------
    fit_values: dict
        Optimal values from the fitter.
    """
    county_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
    times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=ref_date)
    logging.info(f'Fitting MLE model to {county_metadata["county"]}, {county_metadata["state"]}')

    SEIR_params = get_average_SEIR_parameters(fips)

    def _fit_seir(R0, t0, eps):
        model = SEIRModel(
            R0=R0,
            suppression_policy=suppression_policies.generate_empirical_distancing_policy(
                t_list, fips, future_suppression=eps),
            **SEIR_params)
        model.run()
        predicted_cases = model.gamma * np.interp(
            times, t_list + t0, model.results['total_new_infections'])
        predicted_deaths = np.interp(times, t_list + t0, model.results['direct_deaths_per_day'])

        # Assume the error on the case count could be off by a massive factor.
        # Basically we don't want to use cases if there is appreciable
        # mortality data available. Longer term there is a better procedure.
        cases_variance = 1e10 * observed_new_cases.copy() ** 2
        deaths_variance = observed_new_deaths.copy()  # Poisson-distributed error.

        # Zero-inflated Poisson: avoid floating point errors on zero counts.
        cases_variance[cases_variance == 0] = 1e10
        deaths_variance[deaths_variance == 0] = 1e10

        # Compute chi2.
        chi2_cases = np.sum((observed_new_cases - predicted_cases) ** 2 / cases_variance)
        if observed_new_deaths.sum() > 5:
            chi2_deaths = np.sum((observed_new_deaths - predicted_deaths) ** 2 / deaths_variance)
        else:
            chi2_deaths = 0
        return chi2_deaths + chi2_cases

    # Note that the error definition is not right here. We need a realistic
    # error model...
    m = iminuit.Minuit(_fit_seir,
                       R0=4, t0=50, eps=.5,
                       error_R0=1., error_t0=1, error_eps=.2,
                       limit_R0=[1, 8], limit_t0=[-90, 90], limit_eps=[0, 2],
                       errordef=1)
    m.migrad()

    values = dict(fips=fips, **dict(m.values))
    values['t0_date'] = ref_date + timedelta(days=values['t0'])
    values['Reff_current'] = values['R0'] * values['eps']
    values['observed_total_deaths'] = np.sum(observed_new_deaths)
    values['county'] = county_metadata['county']
    values['state'] = county_metadata['state']
    values['total_population'] = county_metadata['total_population']
    values['population_density'] = county_metadata['population_density']
    return values
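# Hedged aside (not from the original source): a toy version of the weighted
# chi2 used in _fit_seir above. Poisson-distributed counts have var ~= mean,
# so each residual is weighted by 1/observed; zero-count bins get a huge
# variance (1e10) so they contribute essentially nothing.
def _example_chi2(observed, predicted):
    variance = observed.astype(float).copy()  # Poisson: var = mean ~ observed
    variance[variance == 0] = 1e10  # zero-inflation guard
    return np.sum((observed - predicted) ** 2 / variance)

# E.g. _example_chi2(np.array([0, 3, 10]), np.array([1.0, 2.0, 12.0]))
# ~= 1/1e10 + 1/3 + 4/10 ~= 0.73.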