def run_ensemble(self):
        """
        Run an ensemble of models for each suppression policy and generate the
        output report / results dataset.

        For the ``'suppression_policy__inferred'`` policy the ensemble is the
        single pickled MLE fit model found at the run artifact path. For every
        other policy the ensemble is produced by sampling SEIR parameters and
        running ``self._run_single_simulation`` on each sample.

        A policy whose MLE model artifact is missing is skipped: nothing is
        written to ``self.all_outputs`` for it. The report is only generated
        when at least one policy produced an ensemble; the JSON results file
        is always written.
        """
        # Ensemble of the last policy that actually produced one. It stays
        # None when every policy was skipped (or there were no policies).
        model_ensemble = None

        for suppression_policy_name, suppression_policy in self.suppression_policies.items():

            logging.info(
                f'Running simulation ensemble for {self.state_name} {self.fips} {suppression_policy_name}'
            )

            if self.agg_level is AggregationLevel.COUNTY:
                # Done before the policy is processed so the metadata is
                # emitted even if this policy ends up being skipped.
                county_metadata = self.county_metadata
                self.all_outputs['county_metadata'] = county_metadata
                county_metadata['age_distribution'] = list(
                    county_metadata['age_distribution'])
                # Previously 'age_bins' was always overwritten with a copy of
                # 'age_distribution'. Use the real bins when they exist and
                # keep the legacy value only as a fallback.
                bins_key = ('age_bins' if 'age_bins' in county_metadata
                            else 'age_distribution')
                county_metadata['age_bins'] = list(county_metadata[bins_key])

            if suppression_policy_name == 'suppression_policy__inferred':
                artifact_path = get_run_artifact_path(
                    self.fips, RunArtifact.MLE_FIT_MODEL)
                if not os.path.exists(artifact_path):
                    logging.warning(
                        f'No MLE model found for {self.state_name}: {self.fips}. Skipping.'
                    )
                    # Actually skip: falling through would either hit an
                    # unbound ensemble or silently reuse the previous
                    # policy's ensemble for the inferred output.
                    continue
                # NOTE: pickle must only be used on trusted, locally produced
                # run artifacts.
                with open(artifact_path, 'rb') as f:
                    model_ensemble = [pickle.load(f)]
            else:
                parameter_sampler = ParameterEnsembleGenerator(
                    fips=self.fips,
                    N_samples=self.n_samples,
                    t_list=self.t_list,
                    suppression_policy=suppression_policy)
                parameter_ensemble = parameter_sampler.sample_seir_parameters(
                    override_params=self.override_params)
                model_ensemble = [
                    self._run_single_simulation(parameters)
                    for parameters in parameter_ensemble
                ]

            self.all_outputs[
                f'{suppression_policy_name}'] = self._generate_output_for_suppression_policy(
                    model_ensemble)

        if self.generate_report and self.output_file_report:
            if model_ensemble is None:
                logging.warning(
                    f'No model ensemble was produced for {self.state_name} {self.fips}; '
                    f'skipping report generation.')
            else:
                report = CountyReport(self.fips,
                                      model_ensemble=model_ensemble,
                                      county_outputs=self.all_outputs,
                                      filename=self.output_file_report,
                                      summary=self.summary)
                report.generate_and_save()

        with open(self.output_file_data, 'w') as f:
            json.dump(self.all_outputs, f)
    def _load_model_for_region(self, scenario="inferred"):
        """
        Return a model for this region, run with a scenario specific suppression policy.

        The region's own MLE fit is used when it can be loaded. Otherwise the state level
        fit is loaded, rescaled to the region's population and given the region's bed and
        ventilator values.

        Parameters
        ----------
        scenario: str
            Forwarded to ``sp.estimate_future_suppression_from_fits`` to pick the final
            suppression level.

        Returns
        -------
        model
            The loaded model, after ``model.run()`` has been called on it.

        Raises
        ------
        FileNotFoundError
            If neither a regional nor a state level model could be loaded.
        """
        regional_input = self.regional_input
        model = regional_input.load_mle_fit_model()

        if not model:
            _log.info(f"No MLE model found. Reverting to state level.",
                      region=regional_input.region)
            model = regional_input.load_state_mle_fit_model()
            if not model:
                raise FileNotFoundError(
                    f"Could not locate state result for {self.state_name}")
            inferred_params = regional_input.load_state_inference_result()

            # The state fit has to be scaled down to the regional population and
            # given the region specific capacity parameters.
            # TODO: get_average_seir_parameters should return the analytic solution when available
            # right now it runs an average over the ensemble (with N_samples not consistently set
            # across the code base).
            regional_generator = ParameterEnsembleGenerator(
                N_samples=500,
                combined_datasets_latest=regional_input.latest,
                t_list=model.t_list,
                suppression_policy=model.suppression_policy,
            )
            default_params = regional_generator.get_average_seir_parameters()

            population_ratio = default_params["N"] / model.N
            model.N *= population_ratio
            model.I_initial *= population_ratio
            model.E_initial *= population_ratio
            model.A_initial *= population_ratio
            # Susceptibles are whatever remains of the rescaled population.
            model.S_initial = model.N - model.I_initial - model.E_initial - model.A_initial

            for capacity_key in ("beds_general", "beds_ICU", "ventilators"):
                setattr(model, capacity_key, default_params[capacity_key])
        else:
            inferred_params = regional_input.load_inference_result()

        # The future suppression level depends on the scenario of interest.
        eps_final = sp.estimate_future_suppression_from_fits(inferred_params,
                                                             scenario=scenario)

        model.suppression_policy = sp.get_epsilon_interpolator(
            eps=inferred_params["eps"],
            t_break=inferred_params["t_break"],
            eps2=inferred_params["eps2"],
            t_delta_phases=inferred_params["t_delta_phases"],
            # Whole days elapsed between the fitted t0 date and today.
            t_break_final=(datetime.datetime.today() -
                           datetime.datetime.fromisoformat(
                               inferred_params["t0_date"])).days,
            eps_final=eps_final,
        )
        model.run()
        return model
示例#3
0
    def _load_model_for_fips(self, scenario='inferred'):
        """
        Return a model for ``self.fips``, run with a two step suppression policy.

        The pickled MLE fit for the locale is used when its artifact exists.
        Otherwise the artifact for the first two characters of the FIPS (the
        state) is loaded, rescaled to the county population and given the
        county's bed and ventilator values.

        Parameters
        ----------
        scenario: str
            With ``'inferred'`` the fitted ``eps`` is kept as the final
            suppression level; any other value is forwarded to
            ``sp.get_future_suppression_from_r0``.

        Returns
        -------
        model
            The loaded model, after ``model.run()`` has been called on it.

        Raises
        ------
        FileNotFoundError
            If neither the locale nor the state artifact exists.
        """
        def _unpickle(path):
            # Only meant for trusted, locally produced run artifacts.
            with open(path, 'rb') as handle:
                return pickle.load(handle)

        locale_artifact = get_run_artifact_path(self.fips, RunArtifact.MLE_FIT_MODEL)
        if os.path.exists(locale_artifact):
            model = _unpickle(locale_artifact)
            inferred_params = fit_results.load_inference_result(self.fips)
        else:
            _logger.info(f'No MLE model found for {self.state_name}: {self.fips}. Reverting to state level.')
            state_artifact = get_run_artifact_path(self.fips[:2], RunArtifact.MLE_FIT_MODEL)
            if not os.path.exists(state_artifact):
                raise FileNotFoundError(f'Could not locate state result for {self.state_name}')
            model = _unpickle(state_artifact)
            inferred_params = fit_results.load_inference_result(self.fips[:2])

            # Scale the state fit down to the county population and swap in the
            # county specific capacity parameters.
            county_generator = ParameterEnsembleGenerator(
                self.fips, N_samples=1, t_list=model.t_list,
                suppression_policy=model.suppression_policy)
            default_params = county_generator.get_average_seir_parameters()

            population_ratio = default_params['N'] / model.N
            model.N *= population_ratio
            model.I_initial *= population_ratio
            model.E_initial *= population_ratio
            model.A_initial *= population_ratio
            # Susceptibles are whatever remains of the rescaled population.
            model.S_initial = model.N - model.I_initial - model.E_initial - model.A_initial

            for capacity_key in ('beds_general', 'beds_ICU', 'ventilators'):
                setattr(model, capacity_key, default_params[capacity_key])

        # The final suppression level depends on the scenario of interest.
        eps_final = (
            inferred_params['eps']
            if scenario == 'inferred'
            else sp.get_future_suppression_from_r0(inferred_params['R0'], scenario=scenario)
        )

        model.suppression_policy = sp.generate_two_step_policy(
            self.t_list,
            eps=inferred_params['eps'],
            t_break=inferred_params['t_break'],
            # Whole days elapsed between the fitted t0 date and today.
            t_break_final=(
                datetime.datetime.today()
                - datetime.datetime.fromisoformat(inferred_params['t0_date'])
            ).days,
            eps_final=eps_final
        )
        model.run()
        return model
示例#4
0
    def get_average_seir_parameters(self):
        """
        Generate the additional fitter candidates from the ensemble generator.

        The ensemble average is computed with no suppression policy. Every key
        listed in ``self.fit_params`` is dropped from it, and so are the
        ``suppression_policy`` and ``I_initial`` entries.

        Returns
        -------
        SEIR_kwargs: dict
            The average ensemble params without the removed keys.
        """
        generator = ParameterEnsembleGenerator(
            fips=self.fips,
            N_samples=5000,
            t_list=self.t_list,
            suppression_policy=None)
        averaged_params = generator.get_average_seir_parameters()

        SEIR_kwargs = {}
        for name, value in averaged_params.items():
            if name in self.fit_params:
                continue
            SEIR_kwargs[name] = value

        # Both entries are expected to still be present at this point; a
        # missing one raises KeyError.
        for dropped_key in ('suppression_policy', 'I_initial'):
            del SEIR_kwargs[dropped_key]
        return SEIR_kwargs
    def get_average_seir_parameters(self):
        """
        Generate the additional fitter candidates from the ensemble generator.

        The ensemble average is computed from ``self.regional_input.latest``
        with no suppression policy. Every key listed in ``self.fit_params`` is
        dropped from it, and so are the ``suppression_policy`` and
        ``I_initial`` entries.

        Returns
        -------
        SEIR_kwargs: dict
            The average ensemble params without the removed keys.
        """
        generator = ParameterEnsembleGenerator(
            N_samples=5000,
            t_list=self.t_list,
            combined_datasets_latest=self.regional_input.latest,
            suppression_policy=None,
        )
        averaged_params = generator.get_average_seir_parameters()

        SEIR_kwargs = {}
        for name, value in averaged_params.items():
            if name in self.fit_params:
                continue
            SEIR_kwargs[name] = value

        # Both entries are expected to still be present at this point; a
        # missing one raises KeyError.
        for dropped_key in ("suppression_policy", "I_initial"):
            del SEIR_kwargs[dropped_key]
        return SEIR_kwargs
示例#6
0
def get_average_SEIR_parameters(fips):
    """
    Generate the additional fitter candidates from the ensemble generator. This
    has the suppression policy and R0 keys removed.

    The time grid is read from the module level ``t_list``.

    Parameters
    ----------
    fips
        Passed as the first positional argument to ParameterEnsembleGenerator.

    Returns
    -------
    params: dict
        The average ensemble params.
    """
    generator = ParameterEnsembleGenerator(
        fips, N_samples=10000, t_list=t_list, suppression_policy=None)
    averaged_params = generator.get_average_seir_parameters()

    # pop() without a default: a missing key raises KeyError.
    for unwanted_key in ('R0', 'suppression_policy'):
        averaged_params.pop(unwanted_key)
    return averaged_params
示例#7
0
    def run_ensemble(self):
        """
        Simulate an ensemble for every entry of ``self.suppression_policy`` and
        write out the report / results dataset.

        Each policy's output is stored in ``self.all_outputs`` under the key
        ``suppression_policy__<policy>``. When ``self.generate_report`` is set,
        a CountyReport is built from the ensemble of the last policy. The
        collected outputs are always dumped as JSON to
        ``self.output_file_data``.
        """
        for future_suppression in self.suppression_policy:
            logging.info(f'Generating For Policy {future_suppression}')

            distancing_policy = generate_empirical_distancing_policy(
                t_list=self.t_list,
                fips=self.fips,
                future_suppression=future_suppression)
            sampler = ParameterEnsembleGenerator(
                fips=self.fips,
                N_samples=self.n_samples,
                t_list=self.t_list,
                suppression_policy=distancing_policy)

            # One simulated model per sampled parameter set.
            model_ensemble = [
                self._run_single_simulation(sampled_parameters)
                for sampled_parameters in sampler.sample_seir_parameters()
            ]

            logging.info(
                f'Generating Report for suppression policy {future_suppression}'
            )
            output_key = f'suppression_policy__{future_suppression}'
            self.all_outputs[output_key] = self._generate_output_for_suppression_policy(
                model_ensemble, future_suppression)

        if self.generate_report:
            # Uses the ensemble left over from the final loop iteration.
            CountyReport(self.fips,
                         model_ensemble=model_ensemble,
                         county_outputs=self.all_outputs,
                         filename=self.output_file_report,
                         summary=self.summary).generate_and_save()

        with open(self.output_file_data, 'w') as results_file:
            json.dump(self.all_outputs, results_file)
示例#8
0
    def __init__(
        self,
        fips,
        window_size=InferRtConstants.COUNT_SMOOTHING_WINDOW_SIZE,
        kernel_std=5,
        r_list=np.linspace(0, 10, 501),
        process_sigma=0.05,
        ref_date=datetime(year=2020, month=1, day=1),
        confidence_intervals=(0.68, 0.95),
        min_cases=5,
        min_deaths=5,
        include_testing_correction=True,
    ):
        """
        Load case, death and hospitalization series for one locale and derive inference inputs.

        Parameters
        ----------
        fips
            Locale identifier. A value of length 2 is treated as a state FIPS; anything else
            is treated as a county FIPS.
        window_size
            Stored on the instance; not otherwise used here.
        kernel_std
            Stored on the instance; not otherwise used here.
        r_list
            Stored on the instance; not otherwise used here.
            NOTE(review): presumably the grid of candidate R values -- confirm against users.
        process_sigma
            Stored on the instance; not otherwise used here.
        ref_date: datetime
            Passed to the data loaders and used as day zero when the loaded day offsets are
            converted to dates.
        confidence_intervals
            Stored on the instance; not otherwise used here.
        min_cases
            Stored after being capped at InferRtConstants.MIN_COUNTS_TO_INFER.
        min_deaths
            Stored after being capped at InferRtConstants.MIN_COUNTS_TO_INFER, unless
            InferRtConstants.DISABLE_DEATHS is set.
        include_testing_correction
            Forwarded to the new case loaders for the main series; the "raw" series is always
            loaded with the correction turned off.
        """
        # Seeds numpy's global RNG.
        np.random.seed(InferRtConstants.RNG_SEED)
        # Param Generation used for Xcor in align_time_series, has some stochastic FFT elements.
        self.fips = fips
        self.r_list = r_list
        self.window_size = window_size
        self.kernel_std = kernel_std
        self.process_sigma = process_sigma
        self.ref_date = ref_date
        self.confidence_intervals = confidence_intervals
        self.min_cases = min_cases
        self.min_deaths = min_deaths
        self.include_testing_correction = include_testing_correction

        # Because rounding is disabled we don't need high min_deaths, min_cases anymore
        self.min_cases = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_cases)
        if not InferRtConstants.DISABLE_DEATHS:
            self.min_deaths = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_deaths)

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            # Main series, with the testing correction as requested by the caller.
            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_state(
                self.state,
                self.ref_date,
                include_testing_correction=self.include_testing_correction,
            )
            # Raw series: same loader, testing correction always off, deaths discarded.
            self.times_raw_new_cases, self.raw_new_cases, _ = load_data.load_new_case_data_by_state(
                self.state, self.ref_date, include_testing_correction=False
            )

            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data_by_state(
                state=self.state_obj.abbr, t0=self.ref_date
            )
            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            self.geo_metadata = (
                load_data.load_county_metadata().set_index("fips").loc[fips].to_dict()
            )
            self.state = self.geo_metadata["state"]
            self.state_obj = us.states.lookup(self.state)
            self.county = self.geo_metadata["county"]
            # Fall back to the state name alone when the county name is empty.
            if self.county:
                self.display_name = self.county + ", " + self.state
            else:
                self.display_name = self.state

            # Main series, with the testing correction as requested by the caller.
            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_fips(
                self.fips,
                t0=self.ref_date,
                include_testing_correction=self.include_testing_correction,
            )
            # Raw series: same loader, testing correction always off, deaths discarded.
            (
                self.times_raw_new_cases,
                self.raw_new_cases,
                _,
            ) = load_data.load_new_case_data_by_fips(
                self.fips, t0=self.ref_date, include_testing_correction=False,
            )
            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

        # Convert day offsets relative to ref_date into dates.
        self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
        self.raw_new_case_dates = [
            ref_date + timedelta(days=int(t)) for t in self.times_raw_new_cases
        ]

        # hospital_dates is only defined when a hospitalization data type was returned.
        if self.hospitalization_data_type:
            self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

        self.default_parameters = ParameterEnsembleGenerator(
            fips=self.fips, N_samples=500, t_list=np.linspace(0, 365, 366)
        ).get_average_seir_parameters()

        # Serial period = Incubation + 0.5 * Infections
        self.serial_period = (
            1 / self.default_parameters["sigma"] + 0.5 * 1 / self.default_parameters["delta"]
        )

        # If we only receive current hospitalizations, we need to account for
        # the outflow to reconstruct new admissions.
        if (
            self.hospitalization_data_type
            is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS
        ):
            los_general = self.default_parameters["hospitalization_length_of_stay_general"]
            los_icu = self.default_parameters["hospitalization_length_of_stay_icu"]
            hosp_rate_general = self.default_parameters["hospitalization_rate_general"]
            hosp_rate_icu = self.default_parameters["hospitalization_rate_icu"]
            icu_rate = hosp_rate_icu / hosp_rate_general
            # Estimated daily outflow: each day's census times a length-of-stay weighted
            # discharge rate.
            flow_out_of_hosp = self.hospitalizations[:-1] * (
                (1 - icu_rate) / los_general + icu_rate / los_icu
            )
            # Reconstruct new admissions: the day-over-day change in the current census plus
            # the estimated outflow.
            self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
            # np.diff drops one element, so drop the first date/time to stay aligned.
            self.hospital_dates = self.hospital_dates[1:]
            self.hospital_times = self.hospital_times[1:]

        self.log_likelihood = None

        self.log = structlog.getLogger(Rt_Inference_Target=self.display_name)
        self.log.info(event="Running:")
示例#9
0
    def __init__(self,
                 fips,
                 window_size=7,
                 kernel_std=2,
                 r_list=np.linspace(0, 10, 501),
                 process_sigma=0.15,
                 ref_date=datetime(year=2020, month=1, day=1),
                 confidence_intervals=(0.68, 0.75, 0.90)):
        """
        Load case, death and hospitalization series for one locale and derive inference inputs.

        Parameters
        ----------
        fips
            Locale identifier. A value of length 2 is treated as a state FIPS;
            anything else is treated as a county FIPS.
        window_size, kernel_std, r_list, process_sigma, confidence_intervals
            Stored on the instance; not otherwise used here.
        ref_date: datetime
            Passed to the data loaders and used as day zero when the loaded
            day offsets are converted to dates.
        """
        self.fips = fips
        self.ref_date = ref_date
        self.window_size = window_size
        self.kernel_std = kernel_std
        self.process_sigma = process_sigma
        self.r_list = r_list
        self.confidence_intervals = confidence_intervals

        def _offsets_to_dates(day_offsets):
            # Day offsets are relative to ref_date.
            return [ref_date + timedelta(days=int(offset)) for offset in day_offsets]

        is_state_fips = len(fips) == 2  # State FIPS are 2 digits
        if is_state_fips:
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name
            state_metadata = load_data.load_county_metadata_by_state(self.state)
            self.geo_metadata = state_metadata.loc[self.state].to_dict()

            case_series = load_data.load_new_case_data_by_state(self.state, self.ref_date)
            self.times, self.observed_new_cases, self.observed_new_deaths = case_series

            hospital_series = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr, t0=self.ref_date)
            (self.hospital_times,
             self.hospitalizations,
             self.hospitalization_data_type) = hospital_series
            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            county_table = load_data.load_county_metadata().set_index('fips')
            self.geo_metadata = county_table.loc[fips].to_dict()
            self.state = self.geo_metadata['state']
            self.state_obj = us.states.lookup(self.state)
            self.county = self.geo_metadata['county']
            # Fall back to the state name alone when the county name is empty.
            self.display_name = (
                self.county + ', ' + self.state if self.county else self.state
            )

            # TODO Swap for new data source.
            case_series = load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
            self.times, self.observed_new_cases, self.observed_new_deaths = case_series

            hospital_series = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)
            (self.hospital_times,
             self.hospitalizations,
             self.hospitalization_data_type) = hospital_series

        logging.info(f'Running Rt Inference for {self.display_name}')

        self.case_dates = _offsets_to_dates(self.times)
        # hospital_dates is only defined when a hospitalization data type was returned.
        if self.hospitalization_data_type:
            self.hospital_dates = _offsets_to_dates(self.hospital_times)

        parameter_generator = ParameterEnsembleGenerator(
            fips=self.fips,
            N_samples=500,
            t_list=np.linspace(0, 365, 366)
        )
        self.default_parameters = parameter_generator.get_average_seir_parameters()
        defaults = self.default_parameters

        # Serial period = Incubation + 0.5 * Infections
        self.serial_period = 1 / defaults['sigma'] + 0.5 * 1 / defaults['delta']

        # With only a current hospitalization census available, new admissions
        # have to be reconstructed by adding back the estimated outflow.
        census_only = (
            self.hospitalization_data_type
            is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS
        )
        if census_only:
            los_general = defaults['hospitalization_length_of_stay_general']
            los_icu = defaults['hospitalization_length_of_stay_icu']
            hosp_rate_general = defaults['hospitalization_rate_general']
            hosp_rate_icu = defaults['hospitalization_rate_icu']
            icu_rate = hosp_rate_icu / hosp_rate_general
            flow_out_of_hosp = self.hospitalizations[:-1] * (
                (1 - icu_rate) / los_general + icu_rate / los_icu)
            # Day-over-day change in the census plus the estimated outflow.
            self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
            # np.diff drops one element, so drop the first date/time to stay aligned.
            self.hospital_dates = self.hospital_dates[1:]
            self.hospital_times = self.hospital_times[1:]

        self.log_likelihood = None
    def _load_model_for_fips(self, scenario="inferred"):
        """
        Return a model for ``self.fips``, run with a scenario specific suppression policy.

        The pickled MLE fit for the locale is used when its artifact exists. Otherwise the
        artifact for the first two characters of the FIPS (the state) is loaded, rescaled to
        the county population and given the county's bed and ventilator values.

        Parameters
        ----------
        scenario: str
            Forwarded to ``sp.estimate_future_suppression_from_fits`` to pick the final
            suppression level.

        Returns
        -------
        model
            The loaded model, after ``model.run()`` has been called on it.

        Raises
        ------
        FileNotFoundError
            If neither the locale nor the state artifact exists.
        """

        def _unpickle(path):
            # Only meant for trusted, locally produced run artifacts.
            with open(path, "rb") as handle:
                return pickle.load(handle)

        locale_artifact = get_run_artifact_path(self.fips, RunArtifact.MLE_FIT_MODEL)
        if os.path.exists(locale_artifact):
            model = _unpickle(locale_artifact)
            inferred_params = fit_results.load_inference_result(self.fips)
        else:
            _logger.info(
                f"No MLE model found for {self.state_name}: {self.fips}. Reverting to state level."
            )
            state_artifact = get_run_artifact_path(self.fips[:2], RunArtifact.MLE_FIT_MODEL)
            if not os.path.exists(state_artifact):
                raise FileNotFoundError(f"Could not locate state result for {self.state_name}")
            model = _unpickle(state_artifact)
            inferred_params = fit_results.load_inference_result(self.fips[:2])

            # Scale the state fit down to the county population and swap in the county
            # specific capacity parameters.
            # TODO: get_average_seir_parameters should return the analytic solution when available
            # right now it runs an average over the ensemble (with N_samples not consistently set
            # across the code base).
            county_generator = ParameterEnsembleGenerator(
                self.fips,
                N_samples=500,
                t_list=model.t_list,
                suppression_policy=model.suppression_policy,
            )
            default_params = county_generator.get_average_seir_parameters()

            population_ratio = default_params["N"] / model.N
            model.N *= population_ratio
            model.I_initial *= population_ratio
            model.E_initial *= population_ratio
            model.A_initial *= population_ratio
            # Susceptibles are whatever remains of the rescaled population.
            model.S_initial = model.N - model.I_initial - model.E_initial - model.A_initial

            for capacity_key in ("beds_general", "beds_ICU", "ventilators"):
                setattr(model, capacity_key, default_params[capacity_key])

        # The future suppression level depends on the scenario of interest.
        eps_final = sp.estimate_future_suppression_from_fits(inferred_params, scenario=scenario)

        model.suppression_policy = sp.get_epsilon_interpolator(
            eps=inferred_params["eps"],
            t_break=inferred_params["t_break"],
            eps2=inferred_params["eps2"],
            t_delta_phases=inferred_params["t_delta_phases"],
            # Whole days elapsed between the fitted t0 date and today.
            t_break_final=(
                datetime.datetime.today()
                - datetime.datetime.fromisoformat(inferred_params["t0_date"])
            ).days,
            eps_final=eps_final,
        )
        model.run()
        return model
示例#11
0
def load_hospitalization_data_by_state(state,
                                       t0,
                                       convert_cumulative_to_current=False,
                                       category="hospitalized"):
    """
    Obtain hospitalization data. We clip because there are sometimes negatives
    either due to data reporting or corrections in case count. These are always
    tiny so we just make downstream easier to work with by clipping.

    Current hospitalizations are preferred whenever any positive current value
    is reported; cumulative hospitalizations are only used otherwise.

    Parameters
    ----------
    state: str
        State to lookup.
    t0: datetime
        Datetime to offset by.
    convert_cumulative_to_current: bool
        If True, and only cumulative hospitalizations are available, convert the
        cumulative hospitalizations to an estimate of current hospitalizations.
    category: str
        'icu' for just ICU or 'hospitalized' for all ICU + Acute.

    Returns
    -------
    times: array(int) or NoneType
        Days since t0 for each hospitalization data point.
    observed_hospitalizations: array, list or NoneType
        Hospitalization values matching ``times``. This is an array for
        reported current or cumulative values, and a list when current values
        were estimated from cumulative ones.
    type: HospitalizationDataType or NoneType
        Specifies cumulative or current hospitalizations.

    All three values are None when no usable data exists for the state.

    Raises
    ------
    ValueError
        If ``category`` is not one of the supported categories.
    """
    # Validate the cheap argument first so that a bad category fails before
    # the (expensive) timeseries is built.
    categories = ["icu", "hospitalized"]
    if category not in categories:
        raise ValueError(
            f"Hospitalization category {category} is not in {categories}")

    abbr = us.states.lookup(state).abbr
    hospitalization_data = (
        combined_datasets.build_us_timeseries_with_all_fields().get_subset(
            AggregationLevel.STATE, country="USA",
            state=abbr).get_data(country="USA", state=abbr))

    if len(hospitalization_data) == 0:
        return None, None, None

    current_column = f"current_{category}"
    cumulative_column = f"cumulative_{category}"

    if (hospitalization_data[current_column] > 0).any():
        reported = hospitalization_data[
            hospitalization_data[current_column].notnull()]
        times_new = (reported["date"].dt.date - t0.date()).dt.days.values
        return (
            times_new,
            reported[current_column].values.clip(min=0),
            HospitalizationDataType.CURRENT_HOSPITALIZATIONS,
        )

    if not (hospitalization_data[cumulative_column] > 0).any():
        return None, None, None

    reported = hospitalization_data[
        hospitalization_data[cumulative_column].notnull()]
    times_new = (reported["date"].dt.date - t0.date()).dt.days.values
    cumulative = reported[cumulative_column].values.clip(min=0)

    # Some minor glitches for a few states: a cumulative count must never
    # decrease. Lower every value that exceeds a later one using a running
    # minimum taken from the end of the series. A single forward pass is not
    # enough: lowering one entry can leave it below its predecessor
    # (e.g. [5, 6, 3] would become [5, 3, 3]).
    cumulative = np.minimum.accumulate(cumulative[::-1])[::-1]

    if not convert_cumulative_to_current:
        return times_new, cumulative, HospitalizationDataType.CUMULATIVE_HOSPITALIZATIONS

    # Must be here to avoid circular import. This is required to convert
    # cumulative hosps to current hosps. We also just use a dummy fips and t_list.
    from pyseir.parameters.parameter_ensemble_generator import ParameterEnsembleGenerator

    params = ParameterEnsembleGenerator(
        fips="06", t_list=[],
        N_samples=1).get_average_seir_parameters()

    icu_ventilated_fraction = params["fraction_icu_requiring_ventilator"]
    # Average ICU stay, weighting ventilated and non-ventilated patients.
    icu_length_of_stay = (
        (1 - icu_ventilated_fraction) *
        params["hospitalization_length_of_stay_icu"] +
        icu_ventilated_fraction *
        params["hospitalization_length_of_stay_icu_and_ventilator"])

    if category == "hospitalized":
        # Admission-rate weighted mean over general and ICU stays.
        average_length_of_stay = (
            params["hospitalization_rate_general"] *
            params["hospitalization_length_of_stay_general"] +
            params["hospitalization_rate_icu"] *
            (1 - icu_ventilated_fraction) *
            params["hospitalization_length_of_stay_icu"] +
            params["hospitalization_rate_icu"] *
            icu_ventilated_fraction *
            params["hospitalization_length_of_stay_icu_and_ventilator"]
        ) / (params["hospitalization_rate_general"] +
             params["hospitalization_rate_icu"])
    else:
        average_length_of_stay = icu_length_of_stay

    # Now compute a cumulative sum, but at each day, subtract the discharges
    # from the previous count. The series starts from an empty hospital.
    current = [0]
    for new_hosps in np.diff(cumulative):
        previous = current[-1]
        current.append(previous + new_hosps -
                       previous / average_length_of_stay)
    return times_new, current, HospitalizationDataType.CURRENT_HOSPITALIZATIONS
    def init_run_mode(self):
        """
        Based on the run mode, generate suppression policies and ensemble
        parameters.  This enables different model combinations and project
        phases.
        """
        self.suppression_policies = dict()

        if self.run_mode is RunMode.CAN_BEFORE_HOSPITALIZATION:
            self.n_samples = 1

            for scenario in [
                    'no_intervention', 'flatten_the_curve', 'full_containment',
                    'social_distancing'
            ]:
                R0 = 3.6
                self.override_params['R0'] = R0
                policy = generate_covidactnow_scenarios(
                    t_list=self.t_list,
                    R0=R0,
                    t0=datetime.datetime.today(),
                    scenario=scenario)
                self.suppression_policies[
                    f'suppression_policy__{scenario}'] = policy
                self.override_params = ParameterEnsembleGenerator(
                    self.fips,
                    N_samples=500,
                    t_list=self.t_list,
                    suppression_policy=policy).get_average_seir_parameters()

            self.override_params['mortality_rate_no_general_beds'] = 0.0
            self.override_params['mortality_rate_from_hospital'] = 0.0
            self.override_params['mortality_rate_from_ICU'] = 0.40
            self.override_params['mortality_rate_no_ICU_beds'] = 1.0

            self.override_params['hospitalization_length_of_stay_general'] = 6
            self.override_params['hospitalization_length_of_stay_icu'] = 13
            self.override_params[
                'hospitalization_length_of_stay_icu_and_ventilator'] = 14

            self.override_params['hospitalization_rate_general'] = 0.0727
            self.override_params[
                'hospitalization_rate_icu'] = 0.13 * self.override_params[
                    'hospitalization_rate_general']
            self.override_params['beds_ICU'] = 0
            self.override_params['symptoms_to_hospital_days'] = 6

            if len(self.covid_data) > 0 and self.covid_data.cases.max() > 0:
                self.t0 = self.covid_data.date.max()
                self.t0, hospitalizations_total = self.get_initial_hospitalizations(
                )

                self.override_params[
                    'HGen_initial'] = hospitalizations_total * (
                        1 - self.override_params['hospitalization_rate_icu'] /
                        self.override_params['hospitalization_rate_general'])
                self.override_params[
                    'HICU_initial'] = hospitalizations_total * self.override_params[
                        'hospitalization_rate_icu'] / self.override_params[
                            'hospitalization_rate_general']
                self.override_params[
                    'HICUVent_initial'] = self.override_params[
                        'HICU_initial'] * self.override_params[
                            'fraction_icu_requiring_ventilator']
                self.override_params[
                    'I_initial'] = hospitalizations_total / self.override_params[
                        'hospitalization_rate_general']

                # The following two params disable the asymptomatic compartment.
                self.override_params['A_initial'] = 0
                self.override_params[
                    'gamma'] = 1  # 100% of Exposed go to the infected bucket.

                # 0.6 is a ~ steady state for the exposed bucket initialization at Reff ~ 1.2
                self.override_params['E_initial'] = 0.6 * (
                    self.override_params['I_initial'] +
                    self.override_params['A_initial'])
                self.override_params['D_initial'] = self.covid_data.deaths.max(
                )

            else:
                self.t0 = datetime.datetime.today()
                self.override_params['I_initial'] = 1
                self.override_params['A_initial'] = 0
                self.override_params[
                    'gamma'] = 1  # 100% of Exposed go to the infected bucket.

        elif self.run_mode is RunMode.CAN_BEFORE_HOSPITALIZATION_NEW_PARAMS:
            self.n_samples = 1

            for scenario in [
                    'no_intervention', 'flatten_the_curve', 'inferred',
                    'social_distancing'
            ]:
                R0 = 3.6
                self.override_params['R0'] = R0
                if scenario != 'inferred':
                    policy = generate_covidactnow_scenarios(
                        t_list=self.t_list,
                        R0=R0,
                        t0=datetime.datetime.today(),
                        scenario=scenario)
                else:
                    policy = None
                self.suppression_policies[
                    f'suppression_policy__{scenario}'] = policy
                self.override_params = ParameterEnsembleGenerator(
                    self.fips,
                    N_samples=500,
                    t_list=self.t_list,
                    suppression_policy=policy).get_average_seir_parameters()

            if len(self.covid_data) > 0 and self.covid_data.cases.max() > 0:
                self.t0 = self.covid_data.date.max()
                self.t0, hospitalizations_total = self.get_initial_hospitalizations(
                )

                self.override_params[
                    'HGen_initial'] = hospitalizations_total * (
                        1 - self.override_params['hospitalization_rate_icu'])
                self.override_params[
                    'HICU_initial'] = hospitalizations_total * self.override_params[
                        'hospitalization_rate_icu']
                self.override_params[
                    'HICUVent_initial'] = self.override_params[
                        'HICU_initial'] * self.override_params[
                            'fraction_icu_requiring_ventilator']
                self.override_params[
                    'I_initial'] = hospitalizations_total / self.override_params[
                        'hospitalization_rate_general']

                # The following two params disable the asymptomatic compartment.
                self.override_params['A_initial'] = 0
                self.override_params[
                    'gamma'] = 1  # 100% of Exposed go to the infected bucket.

                # 0.6 is a ~ steady state for the exposed bucket initialization at Reff ~ 1.2
                self.override_params['E_initial'] = 0.6 * (
                    self.override_params['I_initial'] +
                    self.override_params['A_initial'])
                self.override_params['D_initial'] = self.covid_data.deaths.max(
                )

        elif self.run_mode is RunMode.DEFAULT:
            for suppression_policy in self.suppression_policy:
                self.suppression_policies[
                    f'suppression_policy__{suppression_policy}'] = generate_empirical_distancing_policy(
                        t_list=self.t_list,
                        fips=self.fips,
                        future_suppression=suppression_policy)
            self.override_params = dict()
        else:
            raise ValueError('Invalid run mode.')