def __init__(self, fips, N_samples, t_list,
                 I_initial=1, suppression_policy=None):

        # Caching globally to avoid relatively significant performance overhead
        # of loading for each county.
        global beds_data, population_data
        if not beds_data or not population_data:
            beds_data = DHBeds.local().beds()
            population_data = FIPSPopulation.local().population()

        self.fips = fips
        self.agg_level = AggregationLevel.COUNTY if len(self.fips) == 5 else AggregationLevel.STATE
        self.N_samples = N_samples
        self.I_initial = I_initial
        self.suppression_policy = suppression_policy
        self.t_list = t_list

        if self.agg_level is AggregationLevel.COUNTY:
            self.county_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
            self.state_abbr = us.states.lookup(self.county_metadata['state']).abbr
            self.population = population_data.get_county_level('USA', state=self.state_abbr, fips=self.fips)
            # TODO: Some counties do not have hospitals. Likely need to go to HRR level..
            self.beds = beds_data.get_county_level(self.state_abbr, fips=self.fips) or 0
            self.icu_beds = beds_data.get_county_level(self.state_abbr, fips=self.fips, column='icu_beds') or 0
        else:
            self.state_abbr = us.states.lookup(fips).abbr
            self.population = population_data.get_state_level('USA', state=self.state_abbr)
            self.beds = beds_data.get_state_level(self.state_abbr) or 0
            self.icu_beds = beds_data.get_state_level(self.state_abbr, column='icu_beds') or 0
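The global-variable cache above keeps the bed and population datasets loaded once per process. The same idea can be written as a memoized loader; a minimal sketch, assuming DHBeds and FIPSPopulation are importable from this project's dataset package (the import path below is a guess):

from functools import lru_cache


@lru_cache(maxsize=None)
def _load_shared_datasets():
    """Load the beds and population datasets once per process."""
    # Assumed import path; adjust to wherever DHBeds / FIPSPopulation live.
    from libs.datasets import DHBeds, FIPSPopulation
    return DHBeds.local().beds(), FIPSPopulation.local().population()


# Usage inside __init__:
# beds_data, population_data = _load_shared_datasets()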
Example #2
    def __init__(
        self,
        state,
        reference_date=datetime(day=1, month=3, year=2020),
        plot_compartments=("HICU", "HGen", "HVent"),
        primary_suppression_policy="suppression_policy__0.5",
    ):
        self.state = state
        self.reference_date = reference_date
        self.plot_compartments = plot_compartments
        self.primary_suppression_policy = primary_suppression_policy

        # Load the county metadata and extract names for the state.
        county_metadata = load_data.load_county_metadata()
        self.counties = county_metadata[county_metadata["state"].str.lower() ==
                                        self.state.lower()]["fips"]
        self.ensemble_data_by_county = {
            fips: load_data.load_ensemble_results(fips)
            for fips in self.counties
        }
        self.county_metadata = county_metadata.set_index("fips")
        self.names = [
            self.county_metadata.loc[fips, "county"].replace(" County", "")
            for fips in self.counties
        ]
        self.filename = os.path.join(
            OUTPUT_DIR, self.state, "reports",
            f"summary__{self.state}__state_report.pdf")
        self.surge_filename = os.path.join(
            OUTPUT_DIR, self.state, "reports",
            f"summary__{self.state}__state_surge_report.xlsx")
    def generate_state(self, states_only=False):
        """
        Generate the web UI output for each county in a state.

        Parameters
        ----------
        states_only: bool
            If True, only run the state level.
        """
        state_fips = us.states.lookup(self.state).fips
        self.map_fips(state_fips)

        if not states_only:
            df = load_data.load_county_metadata()
            all_fips = df[df['state'].str.lower() == self.state.lower()].fips

            if not self.include_imputed:
                # Filter to counties that have at least one reported case.
                fips_with_cases = self.jhu_local.timeseries() \
                    .get_subset(AggregationLevel.COUNTY, country='USA') \
                    .get_data(country='USA', state=self.state_abbreviation)
                fips_with_cases = fips_with_cases[
                    fips_with_cases.cases > 0].fips.unique().tolist()
                all_fips = [
                    fips for fips in all_fips if fips in fips_with_cases
                ]

            p = Pool()
            p.map(self.map_fips, all_fips)
            p.close()
Example #4
def run_state(state, states_only=False):
    """
    Run the R_t inference for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    states_only: bool
        If True, only run the state level.
    """
    state_obj = us.states.lookup(state)
    df = RtInferenceEngine.run_for_fips(state_obj.fips)
    output_path = get_run_artifact_path(state_obj.fips, RunArtifact.RT_INFERENCE_RESULT)
    df.to_json(output_path)

    # Run the counties.
    if not states_only:
        df = load_data.load_county_metadata()
        all_fips = df[df['state'].str.lower() == state_obj.name.lower()].fips.values

        # Something in here doesn't like multiprocessing...
        # p = Pool(2)
        rt_inferences = list(map(RtInferenceEngine.run_for_fips, all_fips))
        # p.close()

        for fips, rt_inference in zip(all_fips, rt_inferences):
            county_output_file = get_run_artifact_path(fips, RunArtifact.RT_INFERENCE_RESULT)
            if rt_inference is not None:
                rt_inference.to_json(county_output_file)
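For reference, a driver for the run_state function above might look like the following; the state names and the __main__ guard are illustrative only:

if __name__ == "__main__":
    # Run R_t inference for a few states, counties included.
    for state in ["California", "New York"]:
        run_state(state, states_only=False)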
    def __init__(self,
                 state,
                 reference_date=datetime(day=1, month=3, year=2020),
                 plot_compartments=('HICU', 'HGen', 'HVent'),
                 primary_suppression_policy='suppression_policy__0.5'):
        self.state = state
        self.reference_date = reference_date
        self.plot_compartments = plot_compartments
        self.primary_suppression_policy = primary_suppression_policy

        # Load the county metadata and extract names for the state.
        county_metadata = load_data.load_county_metadata()
        self.counties = county_metadata[county_metadata['state'].str.lower() ==
                                        self.state.lower()]['fips']
        self.ensemble_data_by_county = {
            fips: load_data.load_ensemble_results(fips)
            for fips in self.counties
        }
        self.county_metadata = county_metadata.set_index('fips')
        self.names = [
            self.county_metadata.loc[fips, 'county'].replace(' County', '')
            for fips in self.counties
        ]
        self.filename = os.path.join(
            OUTPUT_DIR, self.state, 'reports',
            f"summary__{self.state}__state_report.pdf")
        self.surge_filename = os.path.join(
            OUTPUT_DIR, self.state, 'reports',
            f"summary__{self.state}__state_surge_report.xlsx")
Example #6
def run_state(state, ensemble_kwargs, states_only=False):
    """
    Run the EnsembleRunner for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    ensemble_kwargs: dict
        Kwargs passed to the EnsembleRunner object.
    states_only: bool
        If True, only run the state level.
    """
    # Run the state level
    runner = EnsembleRunner(fips=us.states.lookup(state).fips, **ensemble_kwargs)
    runner.run_ensemble()

    if not states_only:
        # Run county level
        df = load_data.load_county_metadata()
        all_fips = df[df['state'].str.lower() == state.lower()].fips
        p = Pool()
        f = partial(_run_county, ensemble_kwargs=ensemble_kwargs)
        p.map(f, all_fips)
        p.close()
    def __init__(
        self, total_cases=50, total_deaths=0, nonzero_case_datapoints=5, nonzero_death_datapoints=0
    ):
        self.county_metadata = load_data.load_county_metadata()
        self.df_whitelist = None

        self.total_cases = total_cases
        self.total_deaths = total_deaths
        self.nonzero_case_datapoints = nonzero_case_datapoints
        self.nonzero_death_datapoints = nonzero_death_datapoints
Example #8
def generate_empirical_distancing_policy_by_state(t_list,
                                                  state,
                                                  future_suppression,
                                                  reference_start_date=None):
    """
    Produce a suppression policy at state level based on Imperial College
    estimates of social distancing programs combined with County level
    datasets about their implementation.

    Note: This takes about 250 ms per state, which adds up when running e.g.
    MLE optimization. The bottleneck is computing the suppression policy to
    date, which is done by summing over counties. It should be computed once
    per state and lru-cached rather than recomputed for each county on every
    call (see the memoization sketch after this function); using numpy instead
    of pandas would also help.

    Parameters
    ----------
    t_list: array-like
        List of times to interpolate over.
    state: str
        State full name to lookup interventions against.
    future_suppression: float
        The suppression level to apply on an ongoing basis after today, and
        backward in time as the lockdown / stay-at-home efficacy.
    reference_start_date: pd.Timestamp
        Start date as reference to shift t_list.

    Returns
    -------
    suppression_model: callable
        suppression_model(t) returns the suppression level at time t
    """
    county_metadata = load_data.load_county_metadata()
    counties_fips = county_metadata[county_metadata.state ==
                                    state].fips.unique()

    if reference_start_date is None:
        reference_start_date = min([infer_t0(fips) for fips in counties_fips])

    # Aggregate the counties to the state level, weighted by population.
    weight = county_metadata.loc[county_metadata.state == state,
                                 "total_population"].values
    weight = weight / weight.sum()
    results = []
    for fips in counties_fips:
        suppression_policy = generate_empirical_distancing_policy(
            fips=fips,
            t_list=t_list,
            future_suppression=future_suppression,
            reference_start_date=reference_start_date,
        )
        results.append(suppression_policy(t_list).clip(max=1, min=0))
    results_for_state = (np.vstack(results).T * weight).sum(axis=1)

    return interp1d(t_list, results_for_state, fill_value="extrapolate")
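As the note in the docstring suggests, the state-level policy can be memoized so it is computed once per state. A minimal sketch of that idea; lru_cache needs hashable arguments, so t_list is rebuilt from scalars (the helper name and cache size are illustrative):

from functools import lru_cache

import numpy as np


@lru_cache(maxsize=64)
def cached_state_policy(state, future_suppression, t_max, n_points):
    # Rebuild t_list from hashable scalars before delegating to the
    # uncached implementation above.
    t_list = np.linspace(0, t_max, n_points)
    return generate_empirical_distancing_policy_by_state(
        t_list, state, future_suppression)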
    def generate_surge_spreadsheet(self):
        """
        Produce a spreadsheet summarizing peak values and peak times for each
        county in self.state, with one sheet per mitigation policy.
        """
        df = load_data.load_county_metadata()
        all_fips = df[df['state'].str.lower() == self.state.lower()].fips
        all_data = {fips: load_data.load_ensemble_results(fips) for fips in all_fips}
        df = df.set_index('fips')

        records = []
        for fips, ensembles in all_data.items():
            county_name = df.loc[fips]['county']
            t0 = fit_results.load_t0(fips)

            for suppression_policy, ensemble in ensembles.items():

                county_record = dict(
                    county_name=county_name,
                    county_fips=fips,
                    mitigation_policy=policy_to_mitigation(suppression_policy)
                )

                for compartment in ['HGen', 'general_admissions_per_day', 'HICU', 'icu_admissions_per_day', 'total_new_infections',
                                    'direct_deaths_per_day', 'total_deaths', 'D']:
                    compartment_name = compartment_to_name_map[compartment]

                    county_record[compartment_name + ' Peak Value Mean'] = '%.0f' % ensemble[compartment]['peak_value_mean']
                    county_record[compartment_name + ' Peak Value Median'] = '%.0f' % ensemble[compartment]['peak_value_ci50']
                    county_record[compartment_name + ' Peak Value CI25'] = '%.0f' % ensemble[compartment]['peak_value_ci25']
                    county_record[compartment_name + ' Peak Value CI75'] = '%.0f' % ensemble[compartment]['peak_value_ci75']
                    county_record[compartment_name + ' Peak Time Median'] = (t0 + timedelta(days=ensemble[compartment]['peak_time_ci50'])).date().isoformat()

                    # Leaving for now...
                    # if 'surge_start' in ensemble[compartment]:
                    #     if not np.isnan(np.nanmean(ensemble[compartment]['surge_start'])):
                    #         county_record[compartment_name + ' Surge Start Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_start']))).date().isoformat()
                    #         county_record[compartment_name + ' Surge End Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_end']))).date().isoformat()

                records.append(county_record)

        df = pd.DataFrame(records)
        writer = pd.ExcelWriter(self.surge_filename, engine='xlsxwriter')
        for policy in df['mitigation_policy'].unique()[::-1]:
            df[df['mitigation_policy'] == policy].drop(['mitigation_policy', 'county_fips'], axis=1).to_excel(writer, sheet_name=policy)
        writer.save()
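One caveat on the final write step: writer.save() is deprecated in newer pandas releases in favor of using ExcelWriter as a context manager. A sketch of the tail of this method rewritten that way, reusing the df built above:

        with pd.ExcelWriter(self.surge_filename, engine='xlsxwriter') as writer:
            for policy in df['mitigation_policy'].unique()[::-1]:
                sheet = df[df['mitigation_policy'] == policy].drop(['mitigation_policy', 'county_fips'], axis=1)
                sheet.to_excel(writer, sheet_name=policy)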
Example #10
    def __init__(self, fips, N_samples, t_list,
                 I_initial=1, suppression_policy=None):
        self.fips = fips
        self.agg_level = AggregationLevel.COUNTY if len(self.fips) == 5 else AggregationLevel.STATE
        self.N_samples = N_samples
        self.I_initial = I_initial
        self.suppression_policy = suppression_policy
        self.t_list = t_list

        if self.agg_level is AggregationLevel.COUNTY:
            self.county_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
            self.state_abbr = us.states.lookup(self.county_metadata['state']).abbr
            self._latest = combined_datasets.get_us_latest_for_fips(self.fips)
        else:
            self.state_abbr = us.states.lookup(fips).abbr
            self._latest = combined_datasets.get_us_latest_for_state(self.state_abbr)
Example #11
def run_state(state, ensemble_kwargs):
    """
    Run the EnsembleRunner for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    ensemble_kwargs: dict
        Kwargs passed to the EnsembleRunner object.
    """
    df = load_data.load_county_metadata()
    all_fips = df[df['state'].str.lower() == state.lower()].fips
    p = Pool()
    f = partial(_run_county, ensemble_kwargs=ensemble_kwargs)
    p.map(f, all_fips)
    p.close()
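The pool handling above (create, map, close) never joins or terminates the pool explicitly; a context manager is a tidier equivalent. A sketch reusing _run_county and ensemble_kwargs from the function above; note that Pool.__exit__ calls terminate(), which is safe here because map() blocks until every county has finished:

from functools import partial
from multiprocessing import Pool

with Pool() as p:
    p.map(partial(_run_county, ensemble_kwargs=ensemble_kwargs), all_fips)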
Example #12
def load_t0(fips):
    """
    Load the simulation start time by county.

    Parameters
    ----------
    fips: str
        County FIPS

    Returns
    -------
    : datetime
        Simulation start time t0, i.e. the date at which cumulative cases C = 1.
    """
    county_metadata = load_county_metadata().set_index('fips')
    state = county_metadata.loc[fips]['state']
    fit_results = os.path.join(OUTPUT_DIR, state, 'data', f'summary__{state}_imputed_start_times.pkl')
    return datetime.fromtimestamp(pd.read_pickle(fit_results).set_index('fips').loc[fips]['t0_date'].timestamp())
Example #13
def run_state(state):
    """
    Run the fitter for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    """
    df = load_data.load_county_metadata()
    all_fips = df[df['state'].str.lower() == state.lower()].fips

    p = Pool()
    fit_results = p.map(fit_county_model, all_fips)

    output_file = os.path.join(OUTPUT_DIR, state.title(), 'data',
                               f'summary_{state}__mle_fit_results.json')
    pd.DataFrame(fit_results).to_json(output_file)

    p.map(plot_inferred_result, fit_results)
    p.close()
Example #14
def load_t0(fips):
    """
    Load the simulation start time by county.

    Parameters
    ----------
    fips: str
        County FIPS

    Returns
    -------
    : datetime
        Simulation start time t0, i.e. the date at which cumulative cases C = 1.
    """
    county_metadata = load_county_metadata().set_index("fips")
    state = county_metadata.loc[fips]["state"]
    fit_results = os.path.join(OUTPUT_DIR, "pyseir", state, "data",
                               f"summary__{state}_imputed_start_times.pkl")
    return datetime.fromtimestamp(
        pd.read_pickle(fit_results).set_index("fips").loc[fips]
        ["t0_date"].timestamp())
    def __init__(self,
                 fips,
                 N_samples,
                 t_list,
                 I_initial=1,
                 suppression_policy=None):

        self.fips = fips
        self.N_samples = N_samples
        self.I_initial = I_initial
        self.suppression_policy = suppression_policy
        self.t_list = t_list
        county_metadata = load_data.load_county_metadata()
        hospital_bed_data = load_data.load_hospital_data()

        # TODO: Some counties do not have hospitals. Likely need to go to HRR level..
        hospital_bed_data = hospital_bed_data[[
            'fips', 'num_licensed_beds', 'num_staffed_beds', 'num_icu_beds',
            'bed_utilization', 'potential_increase_in_bed_capac'
        ]].groupby('fips').sum()
        self.county_metadata_merged = county_metadata.merge(
            hospital_bed_data, on='fips',
            how='left').set_index('fips').loc[fips].to_dict()
Example #16
    def __init__(self,
                 fips,
                 N_samples,
                 t_list,
                 I_initial=1,
                 suppression_policy=None):

        # Caching globally to avoid relatively significant performance overhead
        # of loading for each county.
        global beds_data, population_data
        if not beds_data or not population_data:
            beds_data = CovidCareMapBeds.local().beds()
            population_data = FIPSPopulation.local().population()

        self.fips = fips
        self.agg_level = AggregationLevel.COUNTY if len(
            self.fips) == 5 else AggregationLevel.STATE
        self.N_samples = N_samples
        self.I_initial = I_initial
        self.suppression_policy = suppression_policy
        self.t_list = t_list

        if self.agg_level is AggregationLevel.COUNTY:
            self.county_metadata = load_data.load_county_metadata().set_index(
                'fips').loc[fips].to_dict()
            self.state_abbr = us.states.lookup(
                self.county_metadata['state']).abbr
            self.population = population_data.get_record_for_fips(
                fips=self.fips)[CommonFields.POPULATION]
            # TODO: Some counties do not have hospitals. Likely need to go to HRR level..
            self._beds_data = beds_data.get_record_for_fips(fips)
        else:
            self.state_abbr = us.states.lookup(fips).abbr
            self.population = population_data.get_record_for_state(
                self.state_abbr)[CommonFields.POPULATION]
            self._beds_data = beds_data.get_record_for_state(self.state_abbr)
Example #17
def plot_inferred_result(fit_results):
    """
    Plot the results of an MLE inference
    """
    fips = fit_results['fips']
    county_metadata = load_data.load_county_metadata().set_index(
        'fips').loc[fips].to_dict()
    times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=ref_date)
    if observed_new_cases.sum() < 5:
        logging.warning(
            f"{county_metadata['county']} has fewer than 5 cases. Aborting plot."
        )
        return
    else:
        logging.info(f"Plotting MLE Fits for {county_metadata['county']}")

    R0, t0, eps = fit_results['R0'], fit_results['t0'], fit_results['eps']

    model = SEIRModel(R0=R0,
                      suppression_policy=suppression_policies.
                      generate_empirical_distancing_policy(
                          t_list, fips, future_suppression=eps),
                      **get_average_SEIR_parameters(fit_results['fips']))
    model.run()

    data_dates = [ref_date + timedelta(days=t) for t in times]
    model_dates = [
        ref_date + timedelta(days=t + fit_results['t0']) for t in t_list
    ]
    plt.figure(figsize=(10, 8))
    plt.errorbar(data_dates,
                 observed_new_cases,
                 marker='o',
                 linestyle='',
                 label='Observed Cases Per Day')
    plt.errorbar(data_dates,
                 observed_new_deaths,
                 yerr=np.sqrt(observed_new_deaths),
                 marker='o',
                 linestyle='',
                 label='Observed Deaths')
    plt.plot(model_dates,
             model.results['total_new_infections'],
             label='Estimated Total New Infections Per Day')
    plt.plot(model_dates,
             model.gamma * model.results['total_new_infections'],
             label='Symptomatic Model Cases Per Day')
    plt.plot(model_dates,
             model.results['direct_deaths_per_day'],
             label='Model Deaths Per Day')
    plt.yscale('log')
    plt.ylim(.9e0)
    plt.xlim(data_dates[0], data_dates[-1] + timedelta(days=90))

    plt.xticks(rotation=30)
    plt.legend(loc=1)
    plt.grid(which='both', alpha=.3)
    plt.title(county_metadata['county'])
    for i, (k, v) in enumerate(fit_results.items()):
        if k not in ('fips', 't0_date', 'county', 'state'):
            plt.text(.025,
                     .97 - 0.04 * i,
                     f'{k}={v:1.3f}',
                     transform=plt.gca().transAxes,
                     fontsize=12)
        else:
            plt.text(.025,
                     .97 - 0.04 * i,
                     f'{k}={v}',
                     transform=plt.gca().transAxes,
                     fontsize=12)

    output_file = os.path.join(
        OUTPUT_DIR, fit_results['state'].title(), 'reports',
        f'{fit_results["state"]}__{fit_results["county"]}__{fit_results["fips"]}__mle_fit_results.pdf'
    )
    plt.savefig(output_file)
Example #18
    def __init__(self,
                 fips,
                 ref_date=datetime(year=2020, month=1, day=1),
                 min_deaths=2,
                 n_years=1,
                 cases_to_deaths_err_factor=.5,
                 hospital_to_deaths_err_factor=.5,
                 percent_error_on_max_observation=0.5,
                 with_age_structure=False):

        # Seed the random state. It is unclear whether this propagates to the
        # Minuit optimizer.
        np.random.seed(seed=42)

        self.fips = fips
        self.ref_date = ref_date
        self.min_deaths = min_deaths
        self.t_list = np.linspace(0, int(365 * n_years),
                                  int(365 * n_years) + 1)
        self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
        self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
        self.percent_error_on_max_observation = percent_error_on_max_observation
        self.t0_guess = 60
        self.with_age_structure = with_age_structure

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            self.times, self.observed_new_cases, self.observed_new_deaths = \
                load_data.load_new_case_data_by_state(self.state, self.ref_date)

            self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
                load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            geo_metadata = load_data.load_county_metadata().set_index(
                'fips').loc[fips].to_dict()
            state = geo_metadata['state']
            self.state_obj = us.states.lookup(state)
            county = geo_metadata['county']
            if county:
                self.display_name = county + ', ' + state
            else:
                self.display_name = state
            # TODO Swap for new data source.
            self.times, self.observed_new_cases, self.observed_new_deaths = \
                load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
            self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
                load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

        self.cases_stdev, self.hosp_stdev, self.deaths_stdev = self.calculate_observation_errors()
        self.set_inference_parameters()

        self.model_fit_keys = ['R0', 'eps', 't_break', 'log10_I_initial']

        self.SEIR_kwargs = self.get_average_seir_parameters()
        self.fit_results = None
        self.mle_model = None

        self.chi2_deaths = None
        self.chi2_cases = None
        self.chi2_hosp = None
        self.dof_deaths = None
        self.dof_cases = None
        self.dof_hosp = None
def generate_start_times_for_state(state, generate_report=False):
    """
    Generate imputed start dates for each county.

    Parameters
    ----------
    state: str
        State to model counties of.
    generate_report: bool
        If True, generate summary plots.
    """
    metadata = load_data.load_county_metadata()
    state_dir = os.path.join(OUTPUT_DIR, "pyseir", state)
    os.makedirs(state_dir, exist_ok=True)
    os.makedirs(os.path.join(state_dir, "reports"), exist_ok=True)
    os.makedirs(os.path.join(state_dir, "data"), exist_ok=True)

    logging.info(f"Imputing start times for {state.capitalize()}")

    state_fips = us.states.lookup(state).fips
    metadata["state_fips"] = metadata["fips"].apply(lambda x: x[:2])
    counties = metadata[metadata["state_fips"] == state_fips].fips
    if len(counties) == 0:
        raise ValueError(f"No entries for state {state}.")

    # Fit exponential model to extract T0.
    f = partial(_fit_fips, generate_report=generate_report)
    with Pool(maxtasksperchild=1) as p:
        fips_to_fit_map = {
            fips: val for fips, val in zip(counties.values, p.map(f, counties.values))
        }

    # --------------------------------
    # ML to impute start time for counties with no data, based on pop. density
    # --------------------------------
    # Merge in county level metadata.
    county_fits = (
        pd.DataFrame.from_dict(fips_to_fit_map, orient="index")
        .reset_index()
        .rename({"index": "fips"}, axis=1)
    )
    merged = county_fits.merge(metadata, on="fips")
    merged["days_from_2020_01_01"] = (merged.t0_date - datetime.fromisoformat("2020-01-01")).dt.days

    samples_with_data = merged["days_from_2020_01_01"].notnull()
    samples_with_no_data = merged["days_from_2020_01_01"].isnull()
    if samples_with_no_data.any():
        X = np.nan_to_num(
            np.log(
                merged[["population_density", "housing_density", "total_population"]][
                    samples_with_data
                ]
            )
        )
        X_predict = np.nan_to_num(
            np.log(
                merged[["population_density", "housing_density", "total_population"]][
                    samples_with_no_data
                ]
            )
        )

        # Test a few regressions
        for estimator in [LinearRegression(), RandomForestRegressor(), BayesianRidge()]:
            cv_result = cross_validate(
                estimator,
                X=X,
                y=merged["days_from_2020_01_01"][samples_with_data],
                scoring="r2",
                cv=2,
            )
            logging.info(f'{estimator.__class__.__name__} CV r2: {cv_result["test_score"].mean()}')

        # Train best model and impute the missing times.
        best_model = BayesianRidge()
        best_model.fit(X=X, y=merged["days_from_2020_01_01"][samples_with_data])

        merged.loc[samples_with_no_data, "days_from_2020_01_01"] = best_model.predict(X_predict)
        merged.loc[samples_with_no_data, "t0_date"] = datetime.fromisoformat(
            "2020-01-01"
        ) + np.array([timedelta(days=t) for t in best_model.predict(X_predict)])

    # Plot doubling time by population density
    merged.loc[samples_with_no_data, "imputed_start_time"] = True
    merged.loc[samples_with_data, "imputed_start_time"] = False
    merged.loc[samples_with_data, "doubling_rate_days"] = np.log(2) * merged["model_params"][
        samples_with_data
    ].apply(lambda x: x["scale"])
    merged.to_pickle(os.path.join(state_dir, "data", f"summary__{state}_imputed_start_times.pkl"))

    if generate_report:
        # Plot population density
        plt.figure(figsize=(14, 4))
        for i, x in enumerate(("population_density", "housing_density", "total_population")):
            plt.subplot(1, 3, i + 1)
            plt.title(state)
            sns.jointplot(
                x=np.log10(merged[x]), y="days_from_2020_01_01", data=merged, kind="reg", height=5
            )
            plt.xlabel("log10 Population Density")
        plt.savefig(
            os.path.join(state_dir, "reports", f"summary__{state}__population_density.pdf"),
            bbox_inches="tight",
        )
        plt.close()

        # Plot Doubling Rates by distance
        # TODO: Impute doubling time.
        sns.jointplot(
            np.log10(merged.population_density), merged.doubling_rate_days, kind="reg", height=10
        )
        plt.xlabel("Log10 Population Density", fontsize=16)
        plt.ylabel("Doubling Time [Days]", fontsize=16)
        plt.grid()
        plt.savefig(
            os.path.join(state_dir, "reports", f"summary__{state}__doubling_time.pdf"),
            bbox_inches="tight",
        )
        plt.close()
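The imputation step above is an ordinary regression of start time on log-scaled density features. A standalone sketch of the same idea on synthetic data (all numbers below are made up):

import numpy as np
from sklearn.linear_model import BayesianRidge

rng = np.random.default_rng(0)
log_density = rng.uniform(1, 8, size=200)                      # log population density
start_day = 80 - 5 * log_density + rng.normal(0, 3, size=200)  # days from reference date

model = BayesianRidge().fit(log_density.reshape(-1, 1), start_day)
print(model.predict(np.array([[6.5]])))                        # imputed start day for a dense county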
Example #20
    def __init__(
        self,
        fips,
        window_size=InferRtConstants.COUNT_SMOOTHING_WINDOW_SIZE,
        kernel_std=5,
        r_list=np.linspace(0, 10, 501),
        process_sigma=0.05,
        ref_date=datetime(year=2020, month=1, day=1),
        confidence_intervals=(0.68, 0.95),
        min_cases=5,
        min_deaths=5,
        include_testing_correction=True,
    ):
        np.random.seed(InferRtConstants.RNG_SEED)
        # Parameter generation used for Xcor in align_time_series has some stochastic FFT elements, hence the seed.
        self.fips = fips
        self.r_list = r_list
        self.window_size = window_size
        self.kernel_std = kernel_std
        self.process_sigma = process_sigma
        self.ref_date = ref_date
        self.confidence_intervals = confidence_intervals
        self.min_cases = min_cases
        self.min_deaths = min_deaths
        self.include_testing_correction = include_testing_correction

        # Because rounding is disabled we don't need high min_deaths, min_cases anymore
        self.min_cases = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_cases)
        if not InferRtConstants.DISABLE_DEATHS:
            self.min_deaths = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_deaths)

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_state(
                self.state,
                self.ref_date,
                include_testing_correction=self.include_testing_correction,
            )
            self.times_raw_new_cases, self.raw_new_cases, _ = load_data.load_new_case_data_by_state(
                self.state, self.ref_date, include_testing_correction=False
            )

            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data_by_state(
                state=self.state_obj.abbr, t0=self.ref_date
            )
            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            self.geo_metadata = (
                load_data.load_county_metadata().set_index("fips").loc[fips].to_dict()
            )
            self.state = self.geo_metadata["state"]
            self.state_obj = us.states.lookup(self.state)
            self.county = self.geo_metadata["county"]
            if self.county:
                self.display_name = self.county + ", " + self.state
            else:
                self.display_name = self.state

            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_fips(
                self.fips,
                t0=self.ref_date,
                include_testing_correction=self.include_testing_correction,
            )
            (
                self.times_raw_new_cases,
                self.raw_new_cases,
                _,
            ) = load_data.load_new_case_data_by_fips(
                self.fips, t0=self.ref_date, include_testing_correction=False,
            )
            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

        self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
        self.raw_new_case_dates = [
            ref_date + timedelta(days=int(t)) for t in self.times_raw_new_cases
        ]

        if self.hospitalization_data_type:
            self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

        self.default_parameters = ParameterEnsembleGenerator(
            fips=self.fips, N_samples=500, t_list=np.linspace(0, 365, 366)
        ).get_average_seir_parameters()

        # Serial period = incubation period + 0.5 * infectious period.
        self.serial_period = (
            1 / self.default_parameters["sigma"] + 0.5 * 1 / self.default_parameters["delta"]
        )

        # If we only receive current hospitalizations, we need to account for
        # the outflow to reconstruct new admissions.
        if (
            self.hospitalization_data_type
            is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS
        ):
            los_general = self.default_parameters["hospitalization_length_of_stay_general"]
            los_icu = self.default_parameters["hospitalization_length_of_stay_icu"]
            hosp_rate_general = self.default_parameters["hospitalization_rate_general"]
            hosp_rate_icu = self.default_parameters["hospitalization_rate_icu"]
            icu_rate = hosp_rate_icu / hosp_rate_general
            flow_out_of_hosp = self.hospitalizations[:-1] * (
                (1 - icu_rate) / los_general + icu_rate / los_icu
            )
            # Reconstruct daily new admissions from the change in census plus the outflow.
            self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
            self.hospital_dates = self.hospital_dates[1:]
            self.hospital_times = self.hospital_times[1:]

        self.log_likelihood = None

        self.log = structlog.getLogger(Rt_Inference_Target=self.display_name)
        self.log.info(event="Running:")
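The census-to-admissions conversion above is a conservation argument: new admissions on a given day equal the change in the hospital census plus the estimated discharges. A standalone numpy sketch with illustrative numbers:

import numpy as np

census = np.array([100.0, 120.0, 150.0, 160.0])   # current hospitalizations by day
los_general, los_icu, icu_rate = 7.0, 10.0, 0.3   # illustrative lengths of stay and ICU share

discharges = census[:-1] * ((1 - icu_rate) / los_general + icu_rate / los_icu)
new_admissions = np.diff(census) + discharges     # implied admissions on days 1..3
print(new_admissions)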
Example #21
    def __init__(self,
                 fips,
                 window_size=7,
                 kernel_std=2,
                 r_list=np.linspace(0, 10, 501),
                 process_sigma=0.15,
                 ref_date=datetime(year=2020, month=1, day=1),
                 confidence_intervals=(0.68, 0.75, 0.90)):

        self.fips = fips
        self.r_list = r_list
        self.window_size = window_size
        self.kernel_std = kernel_std
        self.process_sigma = process_sigma
        self.ref_date = ref_date
        self.confidence_intervals = confidence_intervals

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name
            self.geo_metadata = load_data.load_county_metadata_by_state(self.state).loc[self.state].to_dict()

            self.times, self.observed_new_cases, self.observed_new_deaths = \
                load_data.load_new_case_data_by_state(self.state, self.ref_date)

            self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
                load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            self.geo_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
            self.state = self.geo_metadata['state']
            self.state_obj = us.states.lookup(self.state)
            self.county = self.geo_metadata['county']
            if self.county:
                self.display_name = self.county + ', ' + self.state
            else:
                self.display_name = self.state

            # TODO Swap for new data source.
            self.times, self.observed_new_cases, self.observed_new_deaths = \
                load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
            self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
                load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

        logging.info(f'Running Rt Inference for {self.display_name}')

        self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
        if self.hospitalization_data_type:
            self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

        self.default_parameters = ParameterEnsembleGenerator(
            fips=self.fips,
            N_samples=500,
            t_list=np.linspace(0, 365, 366)
        ).get_average_seir_parameters()

        # Serial period = incubation period + 0.5 * infectious period.
        self.serial_period = 1 / self.default_parameters['sigma'] + 0.5 * 1 / self.default_parameters['delta']

        # If we only receive current hospitalizations, we need to account for
        # the outflow to reconstruct new admissions.
        if self.hospitalization_data_type is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
            los_general = self.default_parameters['hospitalization_length_of_stay_general']
            los_icu = self.default_parameters['hospitalization_length_of_stay_icu']
            hosp_rate_general = self.default_parameters['hospitalization_rate_general']
            hosp_rate_icu = self.default_parameters['hospitalization_rate_icu']
            icu_rate = hosp_rate_icu / hosp_rate_general
            flow_out_of_hosp = self.hospitalizations[:-1] * ((1 - icu_rate) / los_general + icu_rate / los_icu)
            # Reconstruct daily new admissions from the change in census plus the outflow.
            self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
            self.hospital_dates = self.hospital_dates[1:]
            self.hospital_times = self.hospital_times[1:]

        self.log_likelihood = None
Example #22
    def __init__(
        self,
        fips,
        ref_date=datetime(year=2020, month=1, day=1),
        min_deaths=2,
        n_years=1,
        cases_to_deaths_err_factor=0.5,
        hospital_to_deaths_err_factor=0.5,
        percent_error_on_max_observation=0.5,
        with_age_structure=False,
    ):

        # Seed the random state. It is unclear whether this propagates to the
        # Minuit optimizer.
        np.random.seed(seed=42)

        self.fips = fips
        self.ref_date = ref_date
        self.days_since_ref_date = (dt.date.today() - ref_date.date() -
                                    timedelta(days=7)).days
        # Number of days the end of the 2nd ramp may extend past days_since_ref_date without a penalty on the chi2 score.
        self.days_allowed_beyond_ref = 0
        self.min_deaths = min_deaths
        self.t_list = np.linspace(0, int(365 * n_years),
                                  int(365 * n_years) + 1)
        self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
        self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
        self.percent_error_on_max_observation = percent_error_on_max_observation
        self.t0_guess = 60
        self.with_age_structure = with_age_structure

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_state(self.state,
                                                      self.ref_date)

            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr, t0=self.ref_date)

            (
                self.icu_times,
                self.icu,
                self.icu_data_type,
            ) = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr,
                t0=self.ref_date,
                category=HospitalizationCategory.ICU)

            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            geo_metadata = load_data.load_county_metadata().set_index(
                "fips").loc[fips].to_dict()
            state = geo_metadata["state"]
            self.state_obj = us.states.lookup(state)
            county = geo_metadata["county"]
            if county:
                self.display_name = county + ", " + state
            else:
                self.display_name = state
            # TODO Swap for new data source.
            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_fips(self.fips,
                                                     t0=self.ref_date)
            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data(self.fips,
                                                    t0=self.ref_date)
            (
                self.icu_times,
                self.icu,
                self.icu_data_type,
            ) = load_data.load_hospitalization_data(
                self.fips,
                t0=self.ref_date,
                category=HospitalizationCategory.ICU)

        self.cases_stdev, self.hosp_stdev, self.deaths_stdev = self.calculate_observation_errors()
        self.set_inference_parameters()

        self.model_fit_keys = [
            "R0",
            "eps",
            "t_break",
            "eps2",
            "t_delta_phases",
            "log10_I_initial",
        ]

        self.SEIR_kwargs = self.get_average_seir_parameters()
        self.fit_results = None
        self.mle_model = None

        self.chi2_deaths = None
        self.chi2_cases = None
        self.chi2_hosp = None
        self.dof_deaths = None
        self.dof_cases = None
        self.dof_hosp = None
Example #23
def generate_start_times_for_state(state, generate_report=False):
    """
    Generate imputed start dates for each county.

    Parameters
    ----------
    state: str
        State to model counties of.
    generate_report: bool
        If True, generate summary plots.
    """
    metadata = load_data.load_county_metadata()
    state_dir = os.path.join(OUTPUT_DIR, 'pyseir', state)
    os.makedirs(state_dir, exist_ok=True)
    os.makedirs(os.path.join(state_dir, 'reports'), exist_ok=True)
    os.makedirs(os.path.join(state_dir, 'data'), exist_ok=True)

    logging.info(f'Imputing start times for {state.capitalize()}')

    state_fips = us.states.lookup(state).fips
    metadata['state_fips'] = metadata['fips'].apply(lambda x: x[:2])
    counties = metadata[metadata['state_fips'] == state_fips].fips
    if len(counties) == 0:
        raise ValueError(f'No entries for state {state}.')

    # Fit exponential model to extract T0.
    f = partial(_fit_fips, generate_report=generate_report)
    p = Pool()
    fips_to_fit_map = {fips: val for fips, val in zip(counties.values, p.map(f, counties.values))}
    p.close()

    # --------------------------------
    # ML to impute start time for counties with no data, based on pop. density
    # --------------------------------
    # Merge in county level metadata.
    county_fits = pd.DataFrame.from_dict(fips_to_fit_map, orient='index').reset_index().rename({'index': 'fips'}, axis=1)
    merged = county_fits.merge(metadata, on='fips')
    merged['days_from_2020_01_01'] = (merged.t0_date - datetime.fromisoformat('2020-01-01')).dt.days

    samples_with_data = merged['days_from_2020_01_01'].notnull()
    samples_with_no_data = merged['days_from_2020_01_01'].isnull()
    if samples_with_no_data.any():
        X = np.nan_to_num(np.log(merged[['population_density', 'housing_density', 'total_population']][samples_with_data]))
        X_predict = np.nan_to_num(np.log(merged[['population_density', 'housing_density', 'total_population']][samples_with_no_data]))

        # Test a few regressions
        for estimator in [LinearRegression(), RandomForestRegressor(), BayesianRidge()]:
            cv_result = cross_validate(estimator, X=X, y=merged['days_from_2020_01_01'][samples_with_data], scoring='r2', cv=2)
            logging.info(f'{estimator.__class__.__name__} CV r2: {cv_result["test_score"].mean()}')

        # Train best model and impute the missing times.
        best_model = BayesianRidge()
        best_model.fit(X=X, y=merged['days_from_2020_01_01'][samples_with_data])

        merged.loc[samples_with_no_data, 'days_from_2020_01_01'] = best_model.predict(X_predict)
        merged.loc[samples_with_no_data, 't0_date'] = datetime.fromisoformat('2020-01-01') \
                                                      + np.array([timedelta(days=t) for t in best_model.predict(X_predict)])

    # Plot doubling time by population density
    merged.loc[samples_with_no_data, 'imputed_start_time'] = True
    merged.loc[samples_with_data, 'imputed_start_time'] = False
    merged.loc[samples_with_data, 'doubling_rate_days'] = np.log(2) * merged['model_params'][samples_with_data].apply(lambda x: x['scale'])
    merged.to_pickle(os.path.join(state_dir, 'data', f'summary__{state}_imputed_start_times.pkl'))

    if generate_report:
        # Plot population density
        plt.figure(figsize=(14, 4))
        for i, x in enumerate(('population_density', 'housing_density', 'total_population')):
            plt.subplot(1, 3, i + 1)
            plt.title(state)
            sns.jointplot(x=np.log10(merged[x]), y='days_from_2020_01_01', data=merged, kind='reg', height=5)
            plt.xlabel('log10 Population Density')
        plt.savefig(os.path.join(state_dir, 'reports', f'summary__{state}__population_density.pdf'), bbox_inches='tight')
        plt.close()

        # Plot Doubling Rates by distance
        # TODO: Impute doubling time.
        sns.jointplot(np.log10(merged.population_density), merged.doubling_rate_days, kind='reg', height=10)
        plt.xlabel('Log10 Population Density', fontsize=16)
        plt.ylabel('Doubling Time [Days]', fontsize=16)
        plt.grid()
        plt.savefig(os.path.join(state_dir, 'reports', f'summary__{state}__doubling_time.pdf'), bbox_inches='tight')
        plt.close()
Example #24
    def generate_surge_spreadsheet(self):
        """
        Produce a summary of peak values and peak times for each county in
        self.state.
        """
        df = load_data.load_county_metadata()
        all_fips = load_data.get_all_fips_codes_for_a_state(self.state)
        all_data = {
            fips: load_data.load_ensemble_results(fips)
            for fips in all_fips
        }
        df = df.set_index("fips")

        records = []
        for fips, ensembles in all_data.items():
            county_name = df.loc[fips]["county"]
            t0 = fit_results.load_t0(fips)

            for suppression_policy, ensemble in ensembles.items():

                county_record = dict(
                    county_name=county_name,
                    county_fips=fips,
                    mitigation_policy=policy_to_mitigation(suppression_policy),
                )

                for compartment in [
                        "HGen",
                        "general_admissions_per_day",
                        "HICU",
                        "icu_admissions_per_day",
                        "total_new_infections",
                        "direct_deaths_per_day",
                        "total_deaths",
                        "D",
                ]:
                    compartment_name = compartment_to_name_map[compartment]

                    county_record[compartment_name + " Peak Value Mean"] = (
                        "%.0f" % ensemble[compartment]["peak_value_mean"])
                    county_record[compartment_name + " Peak Value Median"] = (
                        "%.0f" % ensemble[compartment]["peak_value_ci50"])
                    county_record[compartment_name + " Peak Value CI25"] = (
                        "%.0f" % ensemble[compartment]["peak_value_ci25"])
                    county_record[compartment_name + " Peak Value CI75"] = (
                        "%.0f" % ensemble[compartment]["peak_value_ci75"])
                    county_record[compartment_name + " Peak Time Median"] = ((
                        t0 +
                        timedelta(days=ensemble[compartment]["peak_time_ci50"])
                    ).date().isoformat())

                    # Leaving for now...
                    # if 'surge_start' in ensemble[compartment]:
                    #     if not np.isnan(np.nanmean(ensemble[compartment]['surge_start'])):
                    #         county_record[compartment_name + ' Surge Start Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_start']))).date().isoformat()
                    #         county_record[compartment_name + ' Surge End Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_end']))).date().isoformat()

                records.append(county_record)

        df = pd.DataFrame(records)
        df.to_json(self.surge_filename)
Example #25
def fit_county_model(fips):
    """
    Fit the county's current trajectory, using the existing measures. We fit
    only to mortality data if available, else revert to case data.

    We assume a Poisson process generates mortalities at a rate defined by the
    underlying dynamical model.

    TODO @ EC: Add hospitalization data when available.

    Parameters
    ----------
    fips: str
        County fips.

    Returns
    -------
    fit_values: dict
        Optimal values from the fitter.
    """
    county_metadata = load_data.load_county_metadata().set_index(
        'fips').loc[fips].to_dict()
    times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=ref_date)

    logging.info(
        f'Fitting MLE model to {county_metadata["county"]}, {county_metadata["state"]}'
    )
    SEIR_params = get_average_SEIR_parameters(fips)

    def _fit_seir(R0, t0, eps):
        model = SEIRModel(R0=R0,
                          suppression_policy=suppression_policies.
                          generate_empirical_distancing_policy(
                              t_list, fips, future_suppression=eps),
                          **SEIR_params)
        model.run()

        predicted_cases = model.gamma * np.interp(
            times, t_list + t0, model.results['total_new_infections'])
        predicted_deaths = np.interp(times, t_list + t0,
                                     model.results['direct_deaths_per_day'])

        # Assume the error on the case count could be off by a massive factor.
        # Basically we don't want to use case data if there is appreciable
        # mortality data available. Longer term there is a better procedure.
        cases_variance = 1e10 * observed_new_cases.copy() ** 2  # Inflate the variance to deweight case data.
        deaths_variance = observed_new_deaths.copy()  # Poisson-distributed error.

        # Zero-inflated Poisson: avoid floating point errors.
        cases_variance[cases_variance == 0] = 1e10
        deaths_variance[deaths_variance == 0] = 1e10

        # Compute Chi2
        chi2_cases = np.sum(
            (observed_new_cases - predicted_cases)**2 / cases_variance)
        if observed_new_deaths.sum() > 5:
            chi2_deaths = np.sum(
                (observed_new_deaths - predicted_deaths)**2 / deaths_variance)
        else:
            chi2_deaths = 0
        return chi2_deaths + chi2_cases

    # Note that error def is not right here. We need a realistic error model...
    m = iminuit.Minuit(_fit_seir,
                       R0=4,
                       t0=50,
                       eps=.5,
                       error_eps=.2,
                       limit_R0=[1, 8],
                       limit_eps=[0, 2],
                       limit_t0=[-90, 90],
                       error_t0=1,
                       error_R0=1.,
                       errordef=1)
    m.migrad()
    values = dict(fips=fips, **dict(m.values))
    values['t0_date'] = ref_date + timedelta(days=values['t0'])
    values['Reff_current'] = values['R0'] * values['eps']
    values['observed_total_deaths'] = np.sum(observed_new_deaths)
    values['county'] = county_metadata['county']
    values['state'] = county_metadata['state']
    values['total_population'] = county_metadata['total_population']
    values['population_density'] = county_metadata['population_density']
    return values
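A final note on the Minuit call above: the limit_*/error_* keyword style belongs to iminuit 1.x. Roughly the iminuit 2.x equivalent, with _fit_seir as defined inside fit_county_model and the same starting values, errors, and limits:

import iminuit

m = iminuit.Minuit(_fit_seir, R0=4, t0=50, eps=0.5)
m.errordef = 1
m.errors['R0'] = 1.0
m.errors['t0'] = 1.0
m.errors['eps'] = 0.2
m.limits['R0'] = (1, 8)
m.limits['eps'] = (0, 2)
m.limits['t0'] = (-90, 90)
m.migrad()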