Exemplo n.º 1
0
    def load_observations(fips=None, ref_date=REF_DATE):
        """
        Load observations (new cases, new deaths and hospitalizations) for
        given fips code.

        Parameters
        ----------
        fips: str
            FIPS code. A 5-character code is loaded through the county
            loaders, a 2-character code through the state loaders.
        ref_date: Datetime
            Reference start date. Observation times are treated as day
            offsets from this date when the index is built.

        Returns
        -------
        observations: pd.DataFrame
            Contains observations for given fips codes, with columns:
            - times: day offsets returned by the new-case loader
            - new_cases: observed new cases
            - new_deaths: observed new deaths
            - hospitalizations: observed hospitalizations (NaN where no
              hospitalization value was placed)
            and dates of observation as index. Columns that are entirely
            NaN are dropped.

        Raises
        ------
        TypeError
            If fips is None.
        ValueError
            If fips is neither 2 nor 5 characters long.
        """
        if fips is None:
            raise TypeError(
                "fips is required: pass a 2-character state or "
                "5-character county FIPS code")
        if len(fips) not in (2, 5):
            # Without this check neither loader branch runs and the function
            # fails later on an unbound local variable.
            raise ValueError(
                f"fips must be 2 or 5 characters long, got {fips!r}")

        observations = {}
        if len(fips) == 5:
            times, observations['new_cases'], observations['new_deaths'] = \
                load_data.load_new_case_data_by_fips(fips, ref_date)
            hospital_times, hospitalizations, hospitalization_data_type = \
                load_data.load_hospitalization_data(fips, t0=ref_date)
            observations['times'] = times.values
        else:
            state_obj = us.states.lookup(fips)
            observations['times'], observations['new_cases'], observations['new_deaths'] = \
                load_data.load_new_case_data_by_state(state_obj.name, ref_date)
            hospital_times, hospitalizations, hospitalization_data_type = \
                load_data.load_hospitalization_data_by_state(state_obj.abbr, t0=ref_date)
            observations['times'] = np.array(observations['times'])

        # Start from an all-NaN column and fill in only the days for which a
        # hospitalization value exists; positions are offsets from the first
        # observation time.
        observations['hospitalizations'] = np.full(
            observations['times'].shape[0], np.nan)
        if hospitalization_data_type is HospitalizationDataType.CUMULATIVE_HOSPITALIZATIONS:
            # np.diff returns one value fewer than its input. Each increment
            # is attributed to the later of the two times it spans, so the
            # first hospital time is skipped to keep indices and values the
            # same length.
            observations['hospitalizations'][
                hospital_times[1:] -
                observations['times'].min()] = np.diff(hospitalizations)
        elif hospitalization_data_type is HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
            observations['hospitalizations'][
                hospital_times -
                observations['times'].min()] = hospitalizations

        observation_dates = [
            ref_date + timedelta(int(t)) for t in observations['times']
        ]
        observations = pd.DataFrame(
            observations,
            index=pd.DatetimeIndex(observation_dates)).dropna(axis=1,
                                                              how='all')

        return observations
Exemplo n.º 2
0
    def __init__(self,
                 fips,
                 ref_date=datetime(year=2020, month=1, day=1),
                 min_deaths=2,
                 n_years=1,
                 cases_to_deaths_err_factor=.5,
                 hospital_to_deaths_err_factor=.5,
                 percent_error_on_max_observation=0.5,
                 with_age_structure=False):
        """
        Store the fit configuration and load the observed data for one region.

        Parameters
        ----------
        fips: str
            FIPS code. A 2-character code is handled as a state; any other
            length is handled as a county.
        ref_date: datetime
            Reference date handed to the data loaders.
        min_deaths: int
            Stored on the instance; not used by the constructor itself.
        n_years: int or float
            Span of the daily time grid ``t_list`` in 365-day years.
        cases_to_deaths_err_factor: float
            Stored on the instance; not used by the constructor itself.
        hospital_to_deaths_err_factor: float
            Stored on the instance; not used by the constructor itself.
        percent_error_on_max_observation: float
            Stored on the instance; not used by the constructor itself.
        with_age_structure: bool
            Stored on the instance; not used by the constructor itself.
        """
        # Pin the numpy random state. Whether the Minuit optimizer picks this
        # up is not known.
        np.random.seed(seed=42)

        n_days = int(365 * n_years)

        self.fips = fips
        self.ref_date = ref_date
        self.min_deaths = min_deaths
        # One grid point per day, from day 0 through day n_days inclusive.
        self.t_list = np.linspace(0, n_days, n_days + 1)
        self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
        self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
        self.percent_error_on_max_observation = percent_error_on_max_observation
        self.t0_guess = 60
        self.with_age_structure = with_age_structure

        if len(fips) != 2:
            # Anything that is not a 2-digit state FIPS goes down the county
            # path.
            self.agg_level = AggregationLevel.COUNTY
            county_table = load_data.load_county_metadata().set_index('fips')
            record = county_table.loc[fips].to_dict()
            state_name = record['state']
            self.state_obj = us.states.lookup(state_name)
            county_name = record['county']
            self.display_name = (
                county_name + ', ' + state_name if county_name else state_name
            )
            # TODO Swap for new data source.
            case_data = load_data.load_new_case_data_by_fips(
                self.fips, t0=self.ref_date)
            self.times, self.observed_new_cases, self.observed_new_deaths = case_data
            hosp_data = load_data.load_hospitalization_data(
                self.fips, t0=self.ref_date)
            (self.hospital_times,
             self.hospitalizations,
             self.hospitalization_data_type) = hosp_data
        else:
            # State FIPS are 2 digits.
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            case_data = load_data.load_new_case_data_by_state(
                self.state, self.ref_date)
            self.times, self.observed_new_cases, self.observed_new_deaths = case_data

            hosp_data = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr, t0=self.ref_date)
            (self.hospital_times,
             self.hospitalizations,
             self.hospitalization_data_type) = hosp_data
            self.display_name = self.state

        stdevs = self.calculate_observation_errors()
        self.cases_stdev, self.hosp_stdev, self.deaths_stdev = stdevs
        self.set_inference_parameters()

        self.model_fit_keys = ['R0', 'eps', 't_break', 'log10_I_initial']

        self.SEIR_kwargs = self.get_average_seir_parameters()

        # Result slots, all empty until a fit has been run.
        self.fit_results = self.mle_model = None
        self.chi2_deaths = self.chi2_cases = self.chi2_hosp = None
        self.dof_deaths = self.dof_cases = self.dof_hosp = None
Exemplo n.º 3
0
    def __init__(
        self,
        fips,
        window_size=InferRtConstants.COUNT_SMOOTHING_WINDOW_SIZE,
        kernel_std=5,
        r_list=np.linspace(0, 10, 501),
        process_sigma=0.05,
        ref_date=datetime(year=2020, month=1, day=1),
        confidence_intervals=(0.68, 0.95),
        min_cases=5,
        min_deaths=5,
        include_testing_correction=True,
    ):
        """
        Store the inference settings and load the time series for one region.

        Parameters
        ----------
        fips: str
            FIPS code. A 2-character code is handled as a state; any other
            length is handled as a county.
        window_size, kernel_std, r_list, process_sigma, confidence_intervals:
            Stored on the instance; not used by the constructor itself.
        ref_date: datetime
            Reference date handed to the data loaders; loaded day offsets are
            converted to dates relative to it.
        min_cases: int
            Stored after being capped at InferRtConstants.MIN_COUNTS_TO_INFER.
        min_deaths: int
            Capped the same way unless InferRtConstants.DISABLE_DEATHS is set.
        include_testing_correction: bool
            Forwarded to the new-case loader for the main case series. The raw
            case series is always loaded with the correction off.
        """
        # The original author notes that parameter generation, used for the
        # cross-correlation in align_time_series, has stochastic FFT elements;
        # the RNG is seeded up front.
        np.random.seed(InferRtConstants.RNG_SEED)

        self.fips = fips
        self.r_list = r_list
        self.window_size = window_size
        self.kernel_std = kernel_std
        self.process_sigma = process_sigma
        self.ref_date = ref_date
        self.confidence_intervals = confidence_intervals
        # Rounding is disabled, so high count thresholds are no longer needed:
        # cap them at the configured minimum.
        self.min_cases = min(InferRtConstants.MIN_COUNTS_TO_INFER, min_cases)
        if InferRtConstants.DISABLE_DEATHS:
            self.min_deaths = min_deaths
        else:
            self.min_deaths = min(InferRtConstants.MIN_COUNTS_TO_INFER, min_deaths)
        self.include_testing_correction = include_testing_correction

        if len(fips) != 2:
            self.agg_level = AggregationLevel.COUNTY
            county_table = load_data.load_county_metadata().set_index("fips")
            self.geo_metadata = county_table.loc[fips].to_dict()
            self.state = self.geo_metadata["state"]
            self.state_obj = us.states.lookup(self.state)
            self.county = self.geo_metadata["county"]
            self.display_name = (
                self.county + ", " + self.state if self.county else self.state
            )

            corrected_series = load_data.load_new_case_data_by_fips(
                self.fips,
                t0=self.ref_date,
                include_testing_correction=self.include_testing_correction,
            )
            self.times, self.observed_new_cases, self.observed_new_deaths = corrected_series

            raw_series = load_data.load_new_case_data_by_fips(
                self.fips, t0=self.ref_date, include_testing_correction=False,
            )
            self.times_raw_new_cases, self.raw_new_cases, _ = raw_series

            hosp_series = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)
            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = hosp_series
        else:
            # State FIPS are 2 digits.
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            corrected_series = load_data.load_new_case_data_by_state(
                self.state,
                self.ref_date,
                include_testing_correction=self.include_testing_correction,
            )
            self.times, self.observed_new_cases, self.observed_new_deaths = corrected_series

            raw_series = load_data.load_new_case_data_by_state(
                self.state, self.ref_date, include_testing_correction=False
            )
            self.times_raw_new_cases, self.raw_new_cases, _ = raw_series

            hosp_series = load_data.load_hospitalization_data_by_state(
                state=self.state_obj.abbr, t0=self.ref_date
            )
            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = hosp_series
            self.display_name = self.state

        def _to_date(day_offset):
            # Day offsets count from ref_date.
            return ref_date + timedelta(days=int(day_offset))

        self.case_dates = [_to_date(t) for t in self.times]
        self.raw_new_case_dates = [_to_date(t) for t in self.times_raw_new_cases]

        # hospital_dates is only defined when a hospitalization data type was
        # reported by the loader.
        if self.hospitalization_data_type:
            self.hospital_dates = [_to_date(t) for t in self.hospital_times]

        ensemble = ParameterEnsembleGenerator(
            fips=self.fips, N_samples=500, t_list=np.linspace(0, 365, 366)
        )
        self.default_parameters = ensemble.get_average_seir_parameters()

        # Serial period = Incubation + 0.5 * Infections
        incubation_period = 1 / self.default_parameters["sigma"]
        half_infectious_period = 0.5 * 1 / self.default_parameters["delta"]
        self.serial_period = incubation_period + half_infectious_period

        # With only a current-occupancy series, new admissions are rebuilt as
        # the day-over-day change plus the estimated outflow.
        current_only = (
            self.hospitalization_data_type
            is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS
        )
        if current_only:
            los_general = self.default_parameters["hospitalization_length_of_stay_general"]
            los_icu = self.default_parameters["hospitalization_length_of_stay_icu"]
            hosp_rate_general = self.default_parameters["hospitalization_rate_general"]
            hosp_rate_icu = self.default_parameters["hospitalization_rate_icu"]
            icu_rate = hosp_rate_icu / hosp_rate_general
            # Outflow rate per occupied bed, mixing general and ICU lengths of
            # stay by the ICU share.
            outflow_rate = (1 - icu_rate) / los_general + icu_rate / los_icu
            outflow = self.hospitalizations[:-1] * outflow_rate
            self.hospitalizations = np.diff(self.hospitalizations) + outflow
            # The difference has no value for the first time point, so drop it
            # from both time axes.
            self.hospital_dates = self.hospital_dates[1:]
            self.hospital_times = self.hospital_times[1:]

        self.log_likelihood = None

        self.log = structlog.getLogger(Rt_Inference_Target=self.display_name)
        self.log.info(event="Running:")
Exemplo n.º 4
0
    def map_fips(self, fips):
        """
        For a given fips code, generate the CAN UI output format.

        One JSON file is written per suppression policy found in the ensemble
        results. Nothing is written, and None is returned, when the fips is a
        5-character code missing from the whitelist, when no ensemble result
        is returned by the loader, or when the inference result cannot be
        loaded.

        Parameters
        ----------
        fips: str
            FIPS code to map. A 5-character code is handled as a county
            (whitelist check, county population, county rescaling); any other
            length uses the state population.
        """
        logging.info(f"Mapping output to WebUI for {self.state}, {fips}")

        # Check the whitelist before loading anything so excluded counties
        # cost no I/O.
        if len(fips) == 5 and fips not in self.df_whitelist.fips.values:
            logging.info(f"Excluding {fips} due to white list...")
            return

        pyseir_outputs = load_data.load_ensemble_results(fips)
        if pyseir_outputs is None:
            # NOTE(review): assumes the loader signals a missing result with
            # None - confirm against load_ensemble_results.
            logging.warning(f"Ensemble result not found for {fips}. Skipping...")
            return

        try:
            fit_results = load_inference_result(fips)
            t0_simulation = datetime.fromisoformat(fit_results["t0_date"])
        except (KeyError, ValueError):
            logging.error(f"Fit result not found for {fips}. Skipping...")
            return

        # ---------------------------------------------------------------------
        # Rescale hosps based on the population ratio... Could swap this to
        # infection ratio later?
        # ---------------------------------------------------------------------
        hosp_times, current_hosp, _ = load_data.load_hospitalization_data_by_state(
            state=self.state_abbreviation,
            t0=t0_simulation,
            convert_cumulative_to_current=True,
            category="hospitalized",
        )

        _, current_icu, _ = load_data.load_hospitalization_data_by_state(
            state=self.state_abbreviation,
            t0=t0_simulation,
            convert_cumulative_to_current=True,
            category="icu",
        )

        if len(fips) == 5:
            population = self.population_data.get_record_for_fips(fips)[CommonFields.POPULATION]
        else:
            population = self.population_data.get_record_for_state(self.state_abbreviation)[
                CommonFields.POPULATION
            ]

        policies = [key for key in pyseir_outputs if key.startswith("suppression_policy")]

        if current_hosp is not None:
            # Only the most recent observation is used to anchor the model.
            t_latest_hosp_data, current_hosp = hosp_times[-1], current_hosp[-1]
            t_latest_hosp_data_date = t0_simulation + timedelta(days=int(t_latest_hosp_data))

            state_hosp_gen = load_data.get_compartment_value_on_date(
                fips=fips[:2], compartment="HGen", date=t_latest_hosp_data_date
            )
            state_hosp_icu = load_data.get_compartment_value_on_date(
                fips=fips[:2], compartment="HICU", date=t_latest_hosp_data_date
            )

            if len(fips) == 5:
                # Rescale the county level hospitalizations by the expected
                # ratio of county / state hospitalizations from simulations.
                # We use ICU data if available too.
                county_hosp = load_data.get_compartment_value_on_date(
                    fips=fips,
                    compartment="HGen",
                    date=t_latest_hosp_data_date,
                    ensemble_results=pyseir_outputs,
                )
                county_icu = load_data.get_compartment_value_on_date(
                    fips=fips,
                    compartment="HICU",
                    date=t_latest_hosp_data_date,
                    ensemble_results=pyseir_outputs,
                )
                current_hosp *= (county_hosp + county_icu) / (state_hosp_gen + state_hosp_icu)

            hosp_rescaling_factor = current_hosp / (state_hosp_gen + state_hosp_icu)

            # Some states have covidtracking issues. We shouldn't ground ICU cases
            # to zero since so far these have all been bad reporting.
            if current_icu is not None and current_icu[-1] > 0:
                icu_rescaling_factor = current_icu[-1] / state_hosp_icu
            else:
                icu_rescaling_factor = current_hosp / (state_hosp_gen + state_hosp_icu)
        else:
            hosp_rescaling_factor = 1.0
            icu_rescaling_factor = 1.0

        # The Rt-related columns keep decimals; every other non-date column is
        # truncated to an integer before serialization.
        rt_columns = [schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]

        # Iterate through each suppression policy.
        # Model output is interpolated to the dates desired for the API.
        for suppression_policy in policies:

            output_for_policy = pyseir_outputs[suppression_policy]
            output_model = pd.DataFrame()

            t_list = output_for_policy["t_list"]
            t_list_downsampled = range(0, int(max(t_list)), self.output_interval_days)

            def downsample(values):
                # Interpolate a model series onto the output day grid.
                return np.interp(t_list_downsampled, t_list, values)

            # Column creation order is significant: rows are serialized below
            # as plain value lists in column order.
            output_model[schema.DAY_NUM] = t_list_downsampled
            output_model[schema.DATE] = [
                (t0_simulation + timedelta(days=t)).date().strftime("%m/%d/%y")
                for t in t_list_downsampled
            ]
            output_model[schema.TOTAL] = population
            output_model[schema.TOTAL_SUSCEPTIBLE] = downsample(output_for_policy["S"]["ci_50"])
            output_model[schema.EXPOSED] = downsample(output_for_policy["E"]["ci_50"])
            output_model[schema.INFECTED] = downsample(
                np.add(output_for_policy["I"]["ci_50"], output_for_policy["A"]["ci_50"])
            )  # Infected + Asympt.
            output_model[schema.INFECTED_A] = output_model[schema.INFECTED]
            output_model[schema.INFECTED_B] = hosp_rescaling_factor * downsample(
                output_for_policy["HGen"]["ci_50"]
            )  # Hosp General
            output_model[schema.INFECTED_C] = icu_rescaling_factor * downsample(
                output_for_policy["HICU"]["ci_50"]
            )  # Hosp ICU
            # General + ICU beds. don't include vent here because they are also counted in ICU
            output_model[schema.ALL_HOSPITALIZED] = np.add(
                output_model[schema.INFECTED_B], output_model[schema.INFECTED_C]
            )
            output_model[schema.ALL_INFECTED] = output_model[schema.INFECTED]
            output_model[schema.DEAD] = downsample(output_for_policy["total_deaths"]["ci_50"])
            final_beds = np.mean(output_for_policy["HGen"]["capacity"])
            output_model[schema.BEDS] = final_beds
            output_model[schema.CUMULATIVE_INFECTED] = downsample(
                np.cumsum(output_for_policy["total_new_infections"]["ci_50"])
            )

            if fit_results:
                output_model[schema.Rt] = downsample(
                    fit_results["eps"] * fit_results["R0"] * np.ones(len(t_list))
                )
                output_model[schema.Rt_ci90] = downsample(
                    2 * fit_results["eps_error"] * fit_results["R0"] * np.ones(len(t_list))
                )
            else:
                output_model[schema.Rt] = 0
                output_model[schema.Rt_ci90] = 0

            output_model[schema.CURRENT_VENTILATED] = icu_rescaling_factor * downsample(
                output_for_policy["HVent"]["ci_50"]
            )
            output_model[schema.POPULATION] = population
            # Average capacity.
            output_model[schema.ICU_BED_CAPACITY] = np.mean(output_for_policy["HICU"]["capacity"])
            output_model[schema.VENTILATOR_CAPACITY] = np.mean(
                output_for_policy["HVent"]["capacity"]
            )

            # Truncate date range of output.
            output_dates = pd.to_datetime(output_model["date"])
            output_model = output_model[
                (output_dates >= datetime(month=3, day=3, year=2020))
                & (output_dates < datetime.today() + timedelta(days=90))
            ]
            output_model = output_model.fillna(0)

            # Fill in results for the Rt indicator.
            try:
                rt_results = load_Rt_result(fips)
                rt_results.index = rt_results["Rt_MAP_composite"].index.strftime("%m/%d/%y")
                merged = output_model.merge(
                    rt_results[["Rt_MAP_composite", "Rt_ci95_composite"]],
                    right_index=True,
                    left_on="date",
                    how="left",
                )
                output_model[schema.RT_INDICATOR] = merged["Rt_MAP_composite"]

                # NOTE(review): the column is named CI90 but is filled from the
                # "Rt_ci95_composite" series - confirm which interval is meant.
                output_model[schema.RT_INDICATOR_CI90] = (
                    merged["Rt_ci95_composite"] - merged["Rt_MAP_composite"]
                )
            except (ValueError, KeyError):
                # Best effort: without an Rt result the indicator columns are
                # emitted as the string "NaN".
                output_model[schema.RT_INDICATOR] = "NaN"
                output_model[schema.RT_INDICATOR_CI90] = "NaN"

            output_model[[schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]] = output_model[
                [schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]
            ].fillna("NaN")

            # Truncate floats and cast as strings to match data model.
            int_columns = [
                col
                for col in output_model.columns
                if col != schema.DATE and col not in rt_columns
            ]
            output_model.loc[:, int_columns] = (
                output_model[int_columns].fillna(0).astype(int).astype(str)
            )
            output_model.loc[:, rt_columns] = (
                output_model[rt_columns].fillna(0).round(decimals=4).astype(str)
            )

            # Convert the records format to just list(list(values))
            output_model = [
                list(timestep.values()) for timestep in output_model.to_dict(orient="records")
            ]

            output_path = get_run_artifact_path(
                fips, RunArtifact.WEB_UI_RESULT, output_dir=self.output_dir
            )
            policy_enum = Intervention.from_webui_data_adaptor(suppression_policy)
            output_path = output_path.replace("__INTERVENTION_IDX__", str(policy_enum.value))
            with open(output_path, "w") as f:
                json.dump(output_model, f)
Exemplo n.º 5
0
    def __init__(self,
                 fips,
                 window_size=7,
                 kernel_std=2,
                 r_list=np.linspace(0, 10, 501),
                 process_sigma=0.15,
                 ref_date=datetime(year=2020, month=1, day=1),
                 confidence_intervals=(0.68, 0.75, 0.90)):
        """
        Store the inference settings and load the time series for one region.

        Parameters
        ----------
        fips: str
            FIPS code. A 2-character code is handled as a state; any other
            length is handled as a county.
        window_size, kernel_std, r_list, process_sigma, confidence_intervals:
            Stored on the instance; not used by the constructor itself.
        ref_date: datetime
            Reference date handed to the data loaders; loaded day offsets are
            converted to dates relative to it.
        """
        self.fips = fips
        self.r_list = r_list
        self.window_size = window_size
        self.kernel_std = kernel_std
        self.process_sigma = process_sigma
        self.ref_date = ref_date
        self.confidence_intervals = confidence_intervals

        if len(fips) != 2:
            self.agg_level = AggregationLevel.COUNTY
            county_table = load_data.load_county_metadata().set_index('fips')
            self.geo_metadata = county_table.loc[fips].to_dict()
            self.state = self.geo_metadata['state']
            self.state_obj = us.states.lookup(self.state)
            self.county = self.geo_metadata['county']
            self.display_name = (
                self.county + ', ' + self.state if self.county else self.state
            )

            # TODO Swap for new data source.
            case_series = load_data.load_new_case_data_by_fips(
                self.fips, t0=self.ref_date)
            self.times, self.observed_new_cases, self.observed_new_deaths = case_series
            hosp_series = load_data.load_hospitalization_data(
                self.fips, t0=self.ref_date)
            (self.hospital_times,
             self.hospitalizations,
             self.hospitalization_data_type) = hosp_series
        else:
            # State FIPS are 2 digits.
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name
            state_table = load_data.load_county_metadata_by_state(self.state)
            self.geo_metadata = state_table.loc[self.state].to_dict()

            case_series = load_data.load_new_case_data_by_state(
                self.state, self.ref_date)
            self.times, self.observed_new_cases, self.observed_new_deaths = case_series

            hosp_series = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr, t0=self.ref_date)
            (self.hospital_times,
             self.hospitalizations,
             self.hospitalization_data_type) = hosp_series
            self.display_name = self.state

        logging.info(f'Running Rt Inference for {self.display_name}')

        def _to_date(day_offset):
            # Day offsets count from ref_date.
            return ref_date + timedelta(days=int(day_offset))

        self.case_dates = [_to_date(t) for t in self.times]
        # hospital_dates is only defined when a hospitalization data type was
        # reported by the loader.
        if self.hospitalization_data_type:
            self.hospital_dates = [_to_date(t) for t in self.hospital_times]

        ensemble = ParameterEnsembleGenerator(
            fips=self.fips,
            N_samples=500,
            t_list=np.linspace(0, 365, 366)
        )
        self.default_parameters = ensemble.get_average_seir_parameters()

        # Serial period = Incubation + 0.5 * Infections
        incubation_period = 1 / self.default_parameters['sigma']
        half_infectious_period = 0.5 * 1 / self.default_parameters['delta']
        self.serial_period = incubation_period + half_infectious_period

        # With only a current-occupancy series, new admissions are rebuilt as
        # the day-over-day change plus the estimated outflow.
        current_only = (
            self.hospitalization_data_type
            is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS
        )
        if current_only:
            los_general = self.default_parameters['hospitalization_length_of_stay_general']
            los_icu = self.default_parameters['hospitalization_length_of_stay_icu']
            hosp_rate_general = self.default_parameters['hospitalization_rate_general']
            hosp_rate_icu = self.default_parameters['hospitalization_rate_icu']
            icu_rate = hosp_rate_icu / hosp_rate_general
            # Outflow rate per occupied bed, mixing general and ICU lengths of
            # stay by the ICU share.
            outflow_rate = (1 - icu_rate) / los_general + icu_rate / los_icu
            outflow = self.hospitalizations[:-1] * outflow_rate
            self.hospitalizations = np.diff(self.hospitalizations) + outflow
            # The difference has no value for the first time point, so drop it
            # from both time axes.
            self.hospital_dates = self.hospital_dates[1:]
            self.hospital_times = self.hospital_times[1:]

        self.log_likelihood = None
Exemplo n.º 6
0
    def __init__(
        self,
        fips,
        ref_date=datetime(year=2020, month=1, day=1),
        min_deaths=2,
        n_years=1,
        cases_to_deaths_err_factor=0.5,
        hospital_to_deaths_err_factor=0.5,
        percent_error_on_max_observation=0.5,
        with_age_structure=False,
    ):
        """
        Store the fit configuration and load the observed data for one region.

        Parameters
        ----------
        fips: str
            FIPS code. A 2-character code is handled as a state; any other
            length is handled as a county.
        ref_date: datetime
            Reference date handed to the data loaders and used to compute
            ``days_since_ref_date``.
        min_deaths: int
            Stored on the instance; not used by the constructor itself.
        n_years: int or float
            Span of the daily time grid ``t_list`` in 365-day years.
        cases_to_deaths_err_factor: float
            Stored on the instance; not used by the constructor itself.
        hospital_to_deaths_err_factor: float
            Stored on the instance; not used by the constructor itself.
        percent_error_on_max_observation: float
            Stored on the instance; not used by the constructor itself.
        with_age_structure: bool
            Stored on the instance; not used by the constructor itself.
        """
        # Pin the numpy random state. Whether the Minuit optimizer picks this
        # up is not known.
        np.random.seed(seed=42)

        n_days = int(365 * n_years)

        self.fips = fips
        self.ref_date = ref_date
        # Whole days between ref_date and today, less a 7 day margin.
        elapsed = dt.date.today() - ref_date.date()
        self.days_since_ref_date = (elapsed - timedelta(days=7)).days
        # Number of days the end of the 2nd ramp may extend past
        # days_since_ref_date without a penalty on the chi2 score.
        self.days_allowed_beyond_ref = 0
        self.min_deaths = min_deaths
        # One grid point per day, from day 0 through day n_days inclusive.
        self.t_list = np.linspace(0, n_days, n_days + 1)
        self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
        self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
        self.percent_error_on_max_observation = percent_error_on_max_observation
        self.t0_guess = 60
        self.with_age_structure = with_age_structure

        if len(fips) != 2:
            self.agg_level = AggregationLevel.COUNTY
            county_table = load_data.load_county_metadata().set_index("fips")
            record = county_table.loc[fips].to_dict()
            state_name = record["state"]
            self.state_obj = us.states.lookup(state_name)
            county_name = record["county"]
            self.display_name = (
                county_name + ", " + state_name if county_name else state_name
            )
            # TODO Swap for new data source.
            case_series = load_data.load_new_case_data_by_fips(
                self.fips, t0=self.ref_date)
            self.times, self.observed_new_cases, self.observed_new_deaths = case_series

            hosp_series = load_data.load_hospitalization_data(
                self.fips, t0=self.ref_date)
            (self.hospital_times,
             self.hospitalizations,
             self.hospitalization_data_type) = hosp_series

            icu_series = load_data.load_hospitalization_data(
                self.fips,
                t0=self.ref_date,
                category=HospitalizationCategory.ICU)
            self.icu_times, self.icu, self.icu_data_type = icu_series
        else:
            # State FIPS are 2 digits.
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            case_series = load_data.load_new_case_data_by_state(
                self.state, self.ref_date)
            self.times, self.observed_new_cases, self.observed_new_deaths = case_series

            hosp_series = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr, t0=self.ref_date)
            (self.hospital_times,
             self.hospitalizations,
             self.hospitalization_data_type) = hosp_series

            icu_series = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr,
                t0=self.ref_date,
                category=HospitalizationCategory.ICU)
            self.icu_times, self.icu, self.icu_data_type = icu_series

            self.display_name = self.state

        stdevs = self.calculate_observation_errors()
        self.cases_stdev, self.hosp_stdev, self.deaths_stdev = stdevs
        self.set_inference_parameters()

        self.model_fit_keys = [
            "R0", "eps", "t_break", "eps2", "t_delta_phases", "log10_I_initial",
        ]

        self.SEIR_kwargs = self.get_average_seir_parameters()

        # Result slots, all empty until a fit has been run.
        self.fit_results = self.mle_model = None
        self.chi2_deaths = self.chi2_cases = self.chi2_hosp = None
        self.dof_deaths = self.dof_cases = self.dof_hosp = None