Example #1
0
def _generate_input_data(
    fips: str,
    include_testing_correction: bool,
    include_deaths: bool,
    figure_collector: Optional[list],
):
    """
    Load and smooth the input case/death series for a fips code.

    Performing the loading here keeps the RtInferenceEngine agnostic to the
    aggregation level (state vs county).

    Parameters
    ----------
    fips: str
        State or county fips code.
    include_testing_correction: bool
        If True, include a correction for testing increases and decreases.
    include_deaths: bool
        Forwarded to the smoothing step.
    figure_collector: Optional[list]
        Collector the smoothing step may append diagnostic figures to.

    Returns
    -------
    pd.DataFrame
        Filtered and smoothed cases/deaths indexed by date.
    """
    times, new_cases, new_deaths = load_data.load_new_case_data_by_fips(
        fips,
        t0=InferRtConstants.REF_DATE,
        include_testing_correction=include_testing_correction,
    )

    # Convert integer day offsets into absolute dates for the index.
    dates = [InferRtConstants.REF_DATE + timedelta(days=int(t)) for t in times]

    raw = pd.DataFrame({"cases": new_cases, "deaths": new_deaths}, index=dates)
    return filter_and_smooth_input_data(
        df=raw,
        include_deaths=include_deaths,
        figure_collector=figure_collector,
        display_name=fips,
        log=rt_log.new(fips=fips),
    )
    def load_observations(fips=None, ref_date=REF_DATE):
        """
        Load observations (new cases, new deaths and hospitalizations) for
        given fips code.

        Parameters
        ----------
        fips: str
            2-digit state or 5-digit county FIPS code.
        ref_date: Datetime
            Reference start date.

        Returns
        -------
        observations: pd.DataFrame
            Contains observations for given fips codes, with columns:
            - new_cases: float, observed new cases
            - new_deaths: float, observed new deaths
            - hospitalizations: float, observed hospitalizatons
            and dates of observation as index. Columns with no data at all
            are dropped.

        Raises
        ------
        ValueError
            If `fips` is not a 2-digit state or 5-digit county code.
        """

        observations = {}
        if len(fips) == 5:  # County-level FIPS.
            times, observations['new_cases'], observations['new_deaths'] = \
                load_data.load_new_case_data_by_fips(fips, ref_date)
            hospital_times, hospitalizations, hospitalization_data_type = \
                load_data.load_hospitalization_data(fips, t0=ref_date)
            observations['times'] = times.values
        elif len(fips) == 2:  # State-level FIPS.
            state_obj = us.states.lookup(fips)
            observations['times'], observations['new_cases'], observations['new_deaths'] = \
                load_data.load_new_case_data_by_state(state_obj.name, ref_date)
            hospital_times, hospitalizations, hospitalization_data_type = \
                load_data.load_hospitalization_data_by_state(state_obj.abbr, t0=ref_date)
            observations['times'] = np.array(observations['times'])
        else:
            # Previously any other length fell through with unbound locals
            # (hospital_times etc.), raising a confusing NameError below.
            # Fail fast with a clear message instead.
            raise ValueError(f'Invalid fips code: {fips}')

        # Start from an all-NaN series and fill in only the days that have
        # hospital data; hospital_times are offsets relative to the first
        # observation day.
        observations['hospitalizations'] = np.full(
            observations['times'].shape[0], np.nan)
        if hospitalization_data_type is HospitalizationDataType.CUMULATIVE_HOSPITALIZATIONS:
            # Cumulative counts are converted to daily increments via diff.
            observations['hospitalizations'][
                hospital_times -
                observations['times'].min()] = np.diff(hospitalizations)
        elif hospitalization_data_type is HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
            observations['hospitalizations'][
                hospital_times -
                observations['times'].min()] = hospitalizations

        observation_dates = [
            ref_date + timedelta(int(t)) for t in observations['times']
        ]
        # dropna(axis=1, how='all') removes e.g. hospitalizations when no
        # hospital data was available for this fips.
        observations = pd.DataFrame(
            observations,
            index=pd.DatetimeIndex(observation_dates)).dropna(axis=1,
                                                              how='all')

        return observations
Example #3
0
def _whitelist_candidates_per_fips(fips):
    """
    Build a whitelist-candidate record for a single county fips code.

    Returns
    -------
    pd.Series
        Record holding the fips code, total case/death counts, and the
        number of nonzero case/death datapoints.
    """
    _, new_cases, new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=datetime(day=1, month=1, year=2020))

    record = {
        'fips': fips,
        'total_cases': new_cases.sum(),
        'total_deaths': new_deaths.sum(),
        'nonzero_case_datapoints': np.sum(new_cases > 0),
        'nonzero_death_datapoints': np.sum(new_deaths > 0),
    }
    return pd.Series(record)
Example #4
0
    def run_for_fips(cls, fips, n_retries=3, with_age_structure=False):
        """
        Run the model fitter for a state or county fips code.

        Parameters
        ----------
        fips: str
            2-digit state or 5-digit county fips code.
        n_retries: int
            The model fitter is stochastic in nature and a seed cannot be set.
            This is a bandaid until more sophisticated retries can be
            implemented.
        with_age_structure: bool
            If True run model with age structure.

        Returns
        -------
        : ModelFitter
            The fitted instance, or None when the county has no cases or the
            fit failed / never converged.
        """
        # Assert that there are some cases for counties
        if len(fips) == 5:
            _, observed_new_cases, _ = load_data.load_new_case_data_by_fips(
                fips, t0=datetime.today())
            if observed_new_cases.sum() < 1:
                return None

        try:
            retries_left = n_retries
            model_is_empty = True
            # Retry the stochastic fit until an MLE model is produced or the
            # retry budget is exhausted.
            while retries_left > 0 and model_is_empty:
                model_fitter = cls(fips=fips,
                                   with_age_structure=with_age_structure)
                try:
                    model_fitter.fit()
                    # Plotting is opt-in via the PYSEIR_PLOT_RESULTS env var.
                    if model_fitter.mle_model and os.environ.get(
                            'PYSEIR_PLOT_RESULTS') == 'True':
                        model_fitter.plot_fitting_results()
                except RuntimeError as e:
                    logging.warning('No convergence.. Retrying ' + str(e))
                retries_left = retries_left - 1
                if model_fitter.mle_model:
                    model_is_empty = False
            if retries_left <= 0 and model_is_empty:
                raise RuntimeError(
                    f'Could not converge after {n_retries} for fips {fips}')
        except Exception:
            # Broad by design: any failure is logged and mapped to None so a
            # batch run over many fips codes keeps going.
            logging.exception(f"Failed to run {fips}")
            return None
        return model_fitter
    def generate_whitelist(self):
        """
        Generate a county whitelist based on the cuts above.

        Returns
        -------
        df: whitelist
            One row per county with fips, state, county, and a boolean
            `inference_ok` column that is True when the county passes every
            threshold (case/death totals and nonzero-datapoint counts).
        """
        logging.info('Generating county level whitelist...')

        whitelist_generator_inputs = []
        for fips in self.county_metadata.fips:
            times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
                fips, t0=datetime(day=1, month=1, year=2020))

            metadata = self.county_metadata[self.county_metadata.fips == fips].iloc[0].to_dict()

            record = dict(
                fips=fips,
                state=metadata['state'],
                county=metadata['county'],
                total_cases=observed_new_cases.sum(),
                total_deaths=observed_new_deaths.sum(),
                nonzero_case_datapoints=np.sum(observed_new_cases > 0),
                nonzero_death_datapoints=np.sum(observed_new_deaths > 0)
            )
            whitelist_generator_inputs.append(record)

        df_candidates = pd.DataFrame(whitelist_generator_inputs)

        # .copy() makes this an independent frame: the column slice can be a
        # view of df_candidates, and assigning a new column to a view raises
        # SettingWithCopyWarning and is not guaranteed to take effect.
        df_whitelist = df_candidates[['fips', 'state', 'county']].copy()
        df_whitelist.loc[:, 'inference_ok'] = (
                  (df_candidates.nonzero_case_datapoints >= self.nonzero_case_datapoints)
                & (df_candidates.nonzero_death_datapoints >= self.nonzero_death_datapoints)
                & (df_candidates.total_cases >= self.total_cases)
                & (df_candidates.total_deaths >= self.total_deaths)
        )

        output_path = get_run_artifact_path(
            fips='06', # Dummy fips since not used here...
            artifact=RunArtifact.WHITELIST_RESULT)
        df_whitelist.to_json(output_path)

        return df_whitelist
Example #6
0
    def run_for_fips(cls, fips, n_retries=3):
        """
        Run the model fitter for a state or county fips code.

        Parameters
        ----------
        fips: str
            2-digit state or 5-digit county fips code.
        n_retries: int
            The model fitter is stochastic in nature and a seed cannot be set.
            This is a bandaid until more sophisticated retries can be
            implemented.

        Returns
        -------
        : ModelFitter
            The fitted instance, or None when the county has no cases or the
            fit never converged.
        """
        # Assert that there are some cases for counties
        if len(fips) == 5:
            _, observed_new_cases, _ = load_data.load_new_case_data_by_fips(
                fips, t0=datetime.today())
            if observed_new_cases.sum() < 1:
                return None

        try:
            # Retry the stochastic fit; break on the first attempt that
            # produced an MLE model.
            for i in range(n_retries):
                model_fitter = cls(fips)
                try:
                    model_fitter.fit()
                    if model_fitter.mle_model:
                        model_fitter.plot_fitting_results()
                        break
                except RuntimeError as e:
                    logging.warning('No convergence.. Retrying ' + str(e))
            # NOTE(review): if n_retries == 0 the loop never runs and
            # model_fitter is unbound here; the resulting NameError is
            # swallowed by the broad except below and mapped to None.
            if model_fitter.mle_model is None:
                raise RuntimeError(
                    f'Could not converge after {n_retries} for fips {fips}')
        except Exception:
            logging.exception(f"Failed to run {fips}")
            return None
        return model_fitter
Example #7
0
    def __init__(self,
                 fips,
                 ref_date=datetime(year=2020, month=1, day=1),
                 min_deaths=2,
                 n_years=1,
                 cases_to_deaths_err_factor=.5,
                 hospital_to_deaths_err_factor=.5,
                 percent_error_on_max_observation=0.5,
                 with_age_structure=False):
        """
        Load observations and configure the fitter for one state or county.

        Parameters
        ----------
        fips: str
            2-digit state or 5-digit county fips code.
        ref_date: datetime
            Date that t=0 on the simulation time grid maps to.
        min_deaths: int
            Minimum death count threshold (stored on the instance).
        n_years: int
            Length of the simulation time grid in years.
        cases_to_deaths_err_factor: float
            Error scale for cases relative to deaths — presumably consumed
            by calculate_observation_errors(); confirm there.
        hospital_to_deaths_err_factor: float
            Error scale for hospitalizations relative to deaths; see above.
        percent_error_on_max_observation: float
            Fractional error assumed on the maximum observation.
        with_age_structure: bool
            If True run the age-structured model variant.
        """

        # Seed the random state. It is unclear whether this propagates to the
        # Minuit optimizer.
        np.random.seed(seed=42)

        self.fips = fips
        self.ref_date = ref_date
        self.min_deaths = min_deaths
        # Daily time grid: one point per day over n_years, inclusive.
        self.t_list = np.linspace(0, int(365 * n_years),
                                  int(365 * n_years) + 1)
        self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
        self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
        self.percent_error_on_max_observation = percent_error_on_max_observation
        self.t0_guess = 60
        self.with_age_structure = with_age_structure

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            self.times, self.observed_new_cases, self.observed_new_deaths = \
                load_data.load_new_case_data_by_state(self.state, self.ref_date)

            self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
                load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            geo_metadata = load_data.load_county_metadata().set_index(
                'fips').loc[fips].to_dict()
            state = geo_metadata['state']
            self.state_obj = us.states.lookup(state)
            county = geo_metadata['county']
            # Fall back to the bare state name when county metadata is empty.
            if county:
                self.display_name = county + ', ' + state
            else:
                self.display_name = state
            # TODO Swap for new data source.
            self.times, self.observed_new_cases, self.observed_new_deaths = \
                load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
            self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
                load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

        self.cases_stdev, self.hosp_stdev, self.deaths_stdev = self.calculate_observation_errors(
        )
        self.set_inference_parameters()

        # The parameters Minuit varies during the fit.
        self.model_fit_keys = ['R0', 'eps', 't_break', 'log10_I_initial']

        self.SEIR_kwargs = self.get_average_seir_parameters()
        self.fit_results = None
        self.mle_model = None

        # Goodness-of-fit bookkeeping, initialized to None here and
        # presumably populated by the fit — confirm in fit().
        self.chi2_deaths = None
        self.chi2_cases = None
        self.chi2_hosp = None
        self.dof_deaths = None
        self.dof_cases = None
        self.dof_hosp = None
Example #8
0
    def __init__(
        self,
        fips,
        window_size=InferRtConstants.COUNT_SMOOTHING_WINDOW_SIZE,
        kernel_std=5,
        r_list=np.linspace(0, 10, 501),
        process_sigma=0.05,
        ref_date=datetime(year=2020, month=1, day=1),
        confidence_intervals=(0.68, 0.95),
        min_cases=5,
        min_deaths=5,
        include_testing_correction=True,
    ):
        """
        Configure the Rt inference engine for one state or county.

        Parameters
        ----------
        fips: str
            2-digit state or 5-digit county fips code.
        window_size: int
            Count smoothing window size.
        kernel_std: float
            Smoothing kernel standard deviation.
        r_list: np.array
            Candidate R_t values for the inference grid.
        process_sigma: float
            Process noise standard deviation.
        ref_date: datetime
            Date that t=0 of the loaded time series maps to.
        confidence_intervals: tuple
            Confidence interval levels to compute.
        min_cases: int
            Minimum case threshold (capped by MIN_COUNTS_TO_INFER below).
        min_deaths: int
            Minimum death threshold (capped below unless deaths disabled).
        include_testing_correction: bool
            If True, case counts are corrected for testing volume changes.
        """
        np.random.seed(InferRtConstants.RNG_SEED)
        # Param Generation used for Xcor in align_time_series, has some stochastic FFT elements.
        self.fips = fips
        self.r_list = r_list
        self.window_size = window_size
        self.kernel_std = kernel_std
        self.process_sigma = process_sigma
        self.ref_date = ref_date
        self.confidence_intervals = confidence_intervals
        self.min_cases = min_cases
        self.min_deaths = min_deaths
        self.include_testing_correction = include_testing_correction

        # Because rounding is disabled we don't need high min_deaths, min_cases anymore
        self.min_cases = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_cases)
        if not InferRtConstants.DISABLE_DEATHS:
            self.min_deaths = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_deaths)

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_state(
                self.state,
                self.ref_date,
                include_testing_correction=self.include_testing_correction,
            )
            # Also keep the uncorrected series for reference/diagnostics.
            self.times_raw_new_cases, self.raw_new_cases, _ = load_data.load_new_case_data_by_state(
                self.state, self.ref_date, include_testing_correction=False
            )

            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data_by_state(
                state=self.state_obj.abbr, t0=self.ref_date
            )
            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            self.geo_metadata = (
                load_data.load_county_metadata().set_index("fips").loc[fips].to_dict()
            )
            self.state = self.geo_metadata["state"]
            self.state_obj = us.states.lookup(self.state)
            self.county = self.geo_metadata["county"]
            # Fall back to the bare state name when county metadata is empty.
            if self.county:
                self.display_name = self.county + ", " + self.state
            else:
                self.display_name = self.state

            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_fips(
                self.fips,
                t0=self.ref_date,
                include_testing_correction=self.include_testing_correction,
            )
            (
                self.times_raw_new_cases,
                self.raw_new_cases,
                _,
            ) = load_data.load_new_case_data_by_fips(
                self.fips, t0=self.ref_date, include_testing_correction=False,
            )
            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

        # Map integer day offsets onto absolute dates.
        self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
        self.raw_new_case_dates = [
            ref_date + timedelta(days=int(t)) for t in self.times_raw_new_cases
        ]

        # hospital_dates only exists when hospitalization data was found.
        if self.hospitalization_data_type:
            self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

        self.default_parameters = ParameterEnsembleGenerator(
            fips=self.fips, N_samples=500, t_list=np.linspace(0, 365, 366)
        ).get_average_seir_parameters()

        # Serial period = Incubation + 0.5 * Infections
        self.serial_period = (
            1 / self.default_parameters["sigma"] + 0.5 * 1 / self.default_parameters["delta"]
        )

        # If we only receive current hospitalizations, we need to account for
        # the outflow to reconstruct new admissions.
        if (
            self.hospitalization_data_type
            is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS
        ):
            los_general = self.default_parameters["hospitalization_length_of_stay_general"]
            los_icu = self.default_parameters["hospitalization_length_of_stay_icu"]
            hosp_rate_general = self.default_parameters["hospitalization_rate_general"]
            hosp_rate_icu = self.default_parameters["hospitalization_rate_icu"]
            icu_rate = hosp_rate_icu / hosp_rate_general
            # Estimated discharges per day, split between general and ICU
            # by length of stay.
            flow_out_of_hosp = self.hospitalizations[:-1] * (
                (1 - icu_rate) / los_general + icu_rate / los_icu
            )
            # We are attempting to reconstruct the cumulative hospitalizations.
            self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
            # np.diff drops the first element, so drop it from the axes too.
            self.hospital_dates = self.hospital_dates[1:]
            self.hospital_times = self.hospital_times[1:]

        self.log_likelihood = None

        self.log = structlog.getLogger(Rt_Inference_Target=self.display_name)
        self.log.info(event="Running:")
    def __init__(self,
                 fips,
                 window_size=7,
                 kernel_std=2,
                 r_list=np.linspace(0, 10, 501),
                 process_sigma=0.15,
                 ref_date=datetime(year=2020, month=1, day=1),
                 confidence_intervals=(0.68, 0.75, 0.90)):
        """
        Configure the Rt inference engine for one state or county.

        Parameters
        ----------
        fips: str
            2-digit state or 5-digit county fips code.
        window_size: int
            Count smoothing window size.
        kernel_std: float
            Smoothing kernel standard deviation.
        r_list: np.array
            Candidate R_t values for the inference grid.
        process_sigma: float
            Process noise standard deviation.
        ref_date: datetime
            Date that t=0 of the loaded time series maps to.
        confidence_intervals: tuple
            Confidence interval levels to compute.
        """

        self.fips = fips
        self.r_list = r_list
        self.window_size = window_size
        self.kernel_std = kernel_std
        self.process_sigma = process_sigma
        self.ref_date = ref_date
        self.confidence_intervals = confidence_intervals

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name
            self.geo_metadata = load_data.load_county_metadata_by_state(self.state).loc[self.state].to_dict()

            self.times, self.observed_new_cases, self.observed_new_deaths = \
                load_data.load_new_case_data_by_state(self.state, self.ref_date)

            self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
                load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            self.geo_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
            self.state = self.geo_metadata['state']
            self.state_obj = us.states.lookup(self.state)
            self.county = self.geo_metadata['county']
            # Fall back to the bare state name when county metadata is empty.
            if self.county:
                self.display_name = self.county + ', ' + self.state
            else:
                self.display_name = self.state

            # TODO Swap for new data source.
            self.times, self.observed_new_cases, self.observed_new_deaths = \
                load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
            self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
                load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

        logging.info(f'Running Rt Inference for {self.display_name}')

        # Map integer day offsets onto absolute dates; hospital_dates only
        # exists when hospitalization data was found.
        self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
        if self.hospitalization_data_type:
            self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

        self.default_parameters = ParameterEnsembleGenerator(
            fips=self.fips,
            N_samples=500,
            t_list=np.linspace(0, 365, 366)
        ).get_average_seir_parameters()

        # Serial period = Incubation + 0.5 * Infections
        self.serial_period = 1 / self.default_parameters['sigma'] + 0.5 * 1 /   self.default_parameters['delta']

        # If we only receive current hospitalizations, we need to account for
        # the outflow to reconstruct new admissions.
        if self.hospitalization_data_type is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
            los_general = self.default_parameters['hospitalization_length_of_stay_general']
            los_icu = self.default_parameters['hospitalization_length_of_stay_icu']
            hosp_rate_general = self.default_parameters['hospitalization_rate_general']
            hosp_rate_icu = self.default_parameters['hospitalization_rate_icu']
            icu_rate = hosp_rate_icu / hosp_rate_general
            # Estimated discharges per day, split between general and ICU.
            flow_out_of_hosp = self.hospitalizations[:-1] * ((1 - icu_rate) / los_general + icu_rate / los_icu)
            # We are attempting to reconstruct the cumulative hospitalizations.
            self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
            # np.diff drops the first element, so drop it from the axes too.
            self.hospital_dates = self.hospital_dates[1:]
            self.hospital_times = self.hospital_times[1:]

        self.log_likelihood = None
Example #10
0
    def __init__(
        self,
        fips,
        ref_date=datetime(year=2020, month=1, day=1),
        min_deaths=2,
        n_years=1,
        cases_to_deaths_err_factor=0.5,
        hospital_to_deaths_err_factor=0.5,
        percent_error_on_max_observation=0.5,
        with_age_structure=False,
    ):
        """
        Load observations (cases, deaths, hospitalizations, ICU) and
        configure the fitter for one state or county.

        Parameters
        ----------
        fips: str
            2-digit state or 5-digit county fips code.
        ref_date: datetime
            Date that t=0 on the simulation time grid maps to.
        min_deaths: int
            Minimum death count threshold (stored on the instance).
        n_years: int
            Length of the simulation time grid in years.
        cases_to_deaths_err_factor: float
            Error scale for cases relative to deaths — presumably consumed
            by calculate_observation_errors(); confirm there.
        hospital_to_deaths_err_factor: float
            Error scale for hospitalizations relative to deaths; see above.
        percent_error_on_max_observation: float
            Fractional error assumed on the maximum observation.
        with_age_structure: bool
            If True run the age-structured model variant.
        """

        # Seed the random state. It is unclear whether this propagates to the
        # Minuit optimizer.
        np.random.seed(seed=42)

        self.fips = fips
        self.ref_date = ref_date
        # NOTE(review): `dt` here must be the datetime *module* (import
        # datetime as dt) while `datetime` above is the class — confirm the
        # file's imports provide both.
        self.days_since_ref_date = (dt.date.today() - ref_date.date() -
                                    timedelta(days=7)).days
        # ndays end of 2nd ramp may extend past days_since_ref_date w/o  penalty on chi2 score
        self.days_allowed_beyond_ref = 0
        self.min_deaths = min_deaths
        # Daily time grid: one point per day over n_years, inclusive.
        self.t_list = np.linspace(0, int(365 * n_years),
                                  int(365 * n_years) + 1)
        self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
        self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
        self.percent_error_on_max_observation = percent_error_on_max_observation
        self.t0_guess = 60
        self.with_age_structure = with_age_structure

        if len(fips) == 2:  # State FIPS are 2 digits
            self.agg_level = AggregationLevel.STATE
            self.state_obj = us.states.lookup(self.fips)
            self.state = self.state_obj.name

            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_state(self.state,
                                                      self.ref_date)

            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr, t0=self.ref_date)

            # ICU series is loaded via the same helper with a category flag.
            (
                self.icu_times,
                self.icu,
                self.icu_data_type,
            ) = load_data.load_hospitalization_data_by_state(
                self.state_obj.abbr,
                t0=self.ref_date,
                category=HospitalizationCategory.ICU)

            self.display_name = self.state
        else:
            self.agg_level = AggregationLevel.COUNTY
            geo_metadata = load_data.load_county_metadata().set_index(
                "fips").loc[fips].to_dict()
            state = geo_metadata["state"]
            self.state_obj = us.states.lookup(state)
            county = geo_metadata["county"]
            # Fall back to the bare state name when county metadata is empty.
            if county:
                self.display_name = county + ", " + state
            else:
                self.display_name = state
            # TODO Swap for new data source.
            (
                self.times,
                self.observed_new_cases,
                self.observed_new_deaths,
            ) = load_data.load_new_case_data_by_fips(self.fips,
                                                     t0=self.ref_date)
            (
                self.hospital_times,
                self.hospitalizations,
                self.hospitalization_data_type,
            ) = load_data.load_hospitalization_data(self.fips,
                                                    t0=self.ref_date)
            (
                self.icu_times,
                self.icu,
                self.icu_data_type,
            ) = load_data.load_hospitalization_data(
                self.fips,
                t0=self.ref_date,
                category=HospitalizationCategory.ICU)

        self.cases_stdev, self.hosp_stdev, self.deaths_stdev = self.calculate_observation_errors(
        )
        self.set_inference_parameters()

        # Parameters Minuit varies during the fit; this two-ramp variant adds
        # eps2 and t_delta_phases relative to the single-ramp fitter.
        self.model_fit_keys = [
            "R0",
            "eps",
            "t_break",
            "eps2",
            "t_delta_phases",
            "log10_I_initial",
        ]

        self.SEIR_kwargs = self.get_average_seir_parameters()
        self.fit_results = None
        self.mle_model = None

        # Goodness-of-fit bookkeeping, initialized to None here and
        # presumably populated by the fit — confirm in fit().
        self.chi2_deaths = None
        self.chi2_cases = None
        self.chi2_hosp = None
        self.dof_deaths = None
        self.dof_cases = None
        self.dof_hosp = None
Example #11
0
def fit_county_model(fips):
    """
    Fit the county's current trajectory, using the existing measures. We fit
    only to mortality data if available, else revert to case data.

    We assume a poisson process generates mortalities at a rate defined by the
    underlying dynamical model.

    TODO @ EC: Add hospitalization data when available.

    Parameters
    ----------
    fips: str
        County fips.

    Returns
    -------
    fit_values: dict
        Optimal values from the fitter.
    """
    # NOTE(review): ref_date and t_list are module-level globals here —
    # confirm they are defined at import time.
    county_metadata = load_data.load_county_metadata().set_index(
        'fips').loc[fips].to_dict()
    times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=ref_date)

    logging.info(
        f'Fitting MLE model to {county_metadata["county"]}, {county_metadata["state"]}'
    )
    SEIR_params = get_average_SEIR_parameters(fips)

    def _fit_seir(R0, t0, eps):
        # Objective for Minuit: run the SEIR model at (R0, t0, eps) and
        # return the combined chi2 against observed cases and deaths.
        model = SEIRModel(R0=R0,
                          suppression_policy=suppression_policies.
                          generate_empirical_distancing_policy(
                              t_list, fips, future_suppression=eps),
                          **SEIR_params)
        model.run()

        # Shift the model time axis by t0 and interpolate onto the
        # observation times.
        predicted_cases = model.gamma * np.interp(
            times, t_list + t0, model.results['total_new_infections'])
        predicted_deaths = np.interp(times, t_list + t0,
                                     model.results['direct_deaths_per_day'])

        # Assume the error on the case count could be off by a massive factor 50.
        # Basically we don't want to use it if there appreciable mortality data available.
        # Longer term there is a better procedure.
        cases_variance = 1e10 * observed_new_cases.copy(
        )**2  # Make the stdev N times larger x the number of cases
        deaths_variance = observed_new_deaths.copy()  # Poisson dist error

        # Zero inflated poisson Avoid floating point errors..
        cases_variance[cases_variance == 0] = 1e10
        deaths_variance[deaths_variance == 0] = 1e10

        # Compute Chi2
        chi2_cases = np.sum(
            (observed_new_cases - predicted_cases)**2 / cases_variance)
        # Only use the death series when there is appreciable mortality data.
        if observed_new_deaths.sum() > 5:
            chi2_deaths = np.sum(
                (observed_new_deaths - predicted_deaths)**2 / deaths_variance)
        else:
            chi2_deaths = 0
        return chi2_deaths + chi2_cases

    # Note that error def is not right here. We need a realistic error model...
    # NOTE(review): the error_*/limit_* keyword form is the iminuit <2.0
    # API; this call breaks on iminuit >= 2.0 — confirm the pinned version.
    m = iminuit.Minuit(_fit_seir,
                       R0=4,
                       t0=50,
                       eps=.5,
                       error_eps=.2,
                       limit_R0=[1, 8],
                       limit_eps=[0, 2],
                       limit_t0=[-90, 90],
                       error_t0=1,
                       error_R0=1.,
                       errordef=1)
    m.migrad()
    values = dict(fips=fips, **dict(m.values))
    # Derived quantities and metadata for downstream reporting.
    values['t0_date'] = ref_date + timedelta(days=values['t0'])
    values['Reff_current'] = values['R0'] * values['eps']
    values['observed_total_deaths'] = np.sum(observed_new_deaths)
    values['county'] = county_metadata['county']
    values['state'] = county_metadata['state']
    values['total_population'] = county_metadata['total_population']
    values['population_density'] = county_metadata['population_density']
    return values
Example #12
0
def plot_inferred_result(fit_results):
    """
    Plot the results of an MLE inference for one county and save it as a PDF.

    Parameters
    ----------
    fit_results: dict
        Output of the MLE fitter; must contain at least 'fips', 'R0', 't0',
        'eps', 'state' and 'county'. Scalar entries are annotated on the plot.

    Notes
    -----
    Aborts (plots nothing) when the county has fewer than 5 observed cases.
    Saves the figure under OUTPUT_DIR and closes it.
    """
    fips = fit_results['fips']
    county_metadata = load_data.load_county_metadata().set_index(
        'fips').loc[fips].to_dict()
    times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=ref_date)
    if observed_new_cases.sum() < 5:
        logging.warning(
            f"{county_metadata['county']} has fewer than 5 cases. Aborting plot."
        )
        return
    else:
        logging.info(f"Plotting MLE Fits for {county_metadata['county']}")

    R0, t0, eps = fit_results['R0'], fit_results['t0'], fit_results['eps']

    # Re-run the model at the fitted parameters to produce curves to plot.
    model = SEIRModel(R0=R0,
                      suppression_policy=suppression_policies.
                      generate_empirical_distancing_policy(
                          t_list, fips, future_suppression=eps),
                      **get_average_SEIR_parameters(fit_results['fips']))
    model.run()

    data_dates = [ref_date + timedelta(days=t) for t in times]
    # The model time axis is shifted by the fitted t0 offset.
    model_dates = [
        ref_date + timedelta(days=t + fit_results['t0']) for t in t_list
    ]
    plt.figure(figsize=(10, 8))
    plt.errorbar(data_dates,
                 observed_new_cases,
                 marker='o',
                 linestyle='',
                 label='Observed Cases Per Day')
    plt.errorbar(data_dates,
                 observed_new_deaths,
                 yerr=np.sqrt(observed_new_deaths),
                 marker='o',
                 linestyle='',
                 label='Observed Deaths')
    plt.plot(model_dates,
             model.results['total_new_infections'],
             label='Estimated Total New Infections Per Day')
    plt.plot(model_dates,
             model.gamma * model.results['total_new_infections'],
             label='Symptomatic Model Cases Per Day')
    plt.plot(model_dates,
             model.results['direct_deaths_per_day'],
             label='Model Deaths Per Day')
    plt.yscale('log')
    plt.ylim(.9e0)
    plt.xlim(data_dates[0], data_dates[-1] + timedelta(days=90))

    plt.xticks(rotation=30)
    plt.legend(loc=1)
    plt.grid(which='both', alpha=.3)
    plt.title(county_metadata['county'])
    # Annotate the fit parameters; numeric entries get fixed-point format.
    for i, (k, v) in enumerate(fit_results.items()):
        if k not in ('fips', 't0_date', 'county', 'state'):
            plt.text(.025,
                     .97 - 0.04 * i,
                     f'{k}={v:1.3f}',
                     transform=plt.gca().transAxes,
                     fontsize=12)
        else:
            plt.text(.025,
                     .97 - 0.04 * i,
                     f'{k}={v}',
                     transform=plt.gca().transAxes,
                     fontsize=12)

    output_file = os.path.join(
        OUTPUT_DIR, fit_results['state'].title(), 'reports',
        f'{fit_results["state"]}__{fit_results["county"]}__{fit_results["fips"]}__mle_fit_results.pdf'
    )
    plt.savefig(output_file)
    # Close the figure so repeated calls (one per county) don't accumulate
    # open matplotlib figures and exhaust memory.
    plt.close()