Example #1
def run_state(state, states_only=False):
    """
    Run the R_t inference for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    states_only: bool
        If True only run the state level.
    """
    state_obj = us.states.lookup(state)
    df = RtInferenceEngine.run_for_fips(state_obj.fips)
    output_path = get_run_artifact_path(state_obj.fips, RunArtifact.RT_INFERENCE_RESULT)
    if df is None or df.empty:
        logging.error("Empty dataframe encountered! No RtInference results available for %s", state)
    else:
        df.to_json(output_path)

    # Run the counties.
    if not states_only:
        all_fips = load_data.get_all_fips_codes_for_a_state(state)

        # Something in here doesn't like multiprocessing...
        rt_inferences = all_fips.map(lambda x: RtInferenceEngine.run_for_fips(x)).tolist()

        for fips, rt_inference in zip(all_fips, rt_inferences):
            county_output_file = get_run_artifact_path(fips, RunArtifact.RT_INFERENCE_RESULT)
            if rt_inference is not None:
                rt_inference.to_json(county_output_file)
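
A minimal driver sketch for the function above; the entry point and state names are assumptions, since no caller is shown in this excerpt:

if __name__ == "__main__":
    # Run the state-level inference plus every county in the state.
    run_state("California")
    # Or restrict to the state-level aggregate only.
    run_state("New York", states_only=True)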
Example #2
def run_state(state, states_only=False):
    """
    Run the R_t inference for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    states_only: bool
        If True only run the state level.
    """
    state_obj = us.states.lookup(state)
    df = RtInferenceEngine.run_for_fips(state_obj.fips)
    output_path = get_run_artifact_path(state_obj.fips, RunArtifact.RT_INFERENCE_RESULT)
    df.to_json(output_path)

    # Run the counties.
    if not states_only:
        df = load_data.load_county_metadata()
        all_fips = df[df['state'].str.lower() == state_obj.name.lower()].fips.values

        # Something in here doesn't like multiprocessing...
        # p = Pool(2)
        rt_inferences = list(map(RtInferenceEngine.run_for_fips, all_fips))
        # p.close()

        for fips, rt_inference in zip(all_fips, rt_inferences):
            county_output_file = get_run_artifact_path(fips, RunArtifact.RT_INFERENCE_RESULT)
            if rt_inference is not None:
                rt_inference.to_json(county_output_file)
Example #3
    def __init__(
        self,
        fips,
        n_years=0.5,
        n_samples=250,
        suppression_policy=(0.35, 0.5, 0.75, 1),
        skip_plots=False,
        output_percentiles=(5, 25, 32, 50, 75, 68, 95),
        generate_report=True,
        run_mode=RunMode.DEFAULT,
        min_hospitalization_threshold=5,
        hospitalization_to_confirmed_case_ratio=1 / 4,
    ):

        self.fips = fips
        self.agg_level = AggregationLevel.COUNTY if len(fips) == 5 else AggregationLevel.STATE

        self.t_list = np.linspace(0, int(365 * n_years), int(365 * n_years) + 1)
        self.skip_plots = skip_plots
        self.run_mode = RunMode(run_mode)
        self.hospitalizations_for_state = None
        self.min_hospitalization_threshold = min_hospitalization_threshold
        self.hospitalization_to_confirmed_case_ratio = hospitalization_to_confirmed_case_ratio

        if self.agg_level is AggregationLevel.COUNTY:
            self.county_metadata = load_data.load_county_metadata_by_fips(fips)
            self.state_abbr = us.states.lookup(self.county_metadata["state"]).abbr
            self.state_name = us.states.lookup(self.county_metadata["state"]).name

            self.output_file_report = get_run_artifact_path(self.fips, RunArtifact.ENSEMBLE_REPORT)
            self.output_file_data = get_run_artifact_path(self.fips, RunArtifact.ENSEMBLE_RESULT)

        else:
            self.state_abbr = us.states.lookup(self.fips).abbr
            self.state_name = us.states.lookup(self.fips).name

            self.output_file_report = None
            self.output_file_data = get_run_artifact_path(self.fips, RunArtifact.ENSEMBLE_RESULT)

        os.makedirs(os.path.dirname(self.output_file_data), exist_ok=True)
        if self.output_file_report:
            os.makedirs(os.path.dirname(self.output_file_report), exist_ok=True)

        self.output_percentiles = output_percentiles
        self.n_samples = n_samples
        self.n_years = n_years
        # TODO: Will be soon replaced with loaders for all the inferred params.
        # self.t0 = fit_results.load_t0(fips)
        self.date_generated = datetime.datetime.utcnow().isoformat()
        self.suppression_policy = suppression_policy
        self.summary = copy.deepcopy(self.__dict__)
        self.summary.pop("t_list")
        self.generate_report = generate_report

        self.suppression_policies = None
        self.override_params = dict()
        self.init_run_mode()

        self.all_outputs = {}
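
A hedged construction sketch for the __init__ above. The enclosing class name is not shown in this excerpt, so EnsembleRunner is an assumption, as is pairing it with run_ensemble (Example #9):

# Hypothetical usage; EnsembleRunner stands in for the class that owns
# this __init__.
runner = EnsembleRunner(
    fips="06",                 # 2-digit fips -> AggregationLevel.STATE
    n_samples=250,
    suppression_policy=(0.35, 0.5, 0.75, 1),
    generate_report=False,
)
runner.run_ensemble()          # see Example #9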
Example #4
def run_state(state, states_only=False, with_age_structure=False):
    """
    Run the fitter for each county in a state.

    Parameters
    ----------
    state: str
        State to run against.
    states_only: bool
        If True only run the state level.
    with_age_structure: bool
        If True run model with age structure.
    """
    state_obj = us.states.lookup(state)
    logging.info(f"Running MLE fitter for state {state_obj.name}")

    model_fitter = ModelFitter.run_for_fips(
        fips=state_obj.fips, with_age_structure=with_age_structure)

    df_whitelist = load_data.load_whitelist()
    df_whitelist = df_whitelist[df_whitelist["inference_ok"] == True]

    output_path = get_run_artifact_path(state_obj.fips,
                                        RunArtifact.MLE_FIT_RESULT)
    data = pd.DataFrame(model_fitter.fit_results, index=[state_obj.fips])
    data.to_json(output_path)

    with open(get_run_artifact_path(state_obj.fips, RunArtifact.MLE_FIT_MODEL),
              "wb") as f:
        pickle.dump(model_fitter.mle_model, f)

    # Run the counties.
    if not states_only:
        # TODO: Replace with build_county_list
        df_whitelist = load_data.load_whitelist()
        df_whitelist = df_whitelist[df_whitelist["inference_ok"] == True]

        all_fips = df_whitelist[df_whitelist["state"].str.lower() ==
                                state_obj.name.lower()].fips.values

        if len(all_fips) > 0:
            p = Pool()
            fitters = p.map(ModelFitter.run_for_fips, all_fips)
            p.close()

            county_output_file = get_run_artifact_path(
                all_fips[0], RunArtifact.MLE_FIT_RESULT)
            data = pd.DataFrame([fit.fit_results for fit in fitters if fit])
            data.to_json(county_output_file)

            # Serialize the model results.
            for fips, fitter in zip(all_fips, fitters):
                if fitter:
                    with open(
                            get_run_artifact_path(fips,
                                                  RunArtifact.MLE_FIT_MODEL),
                            "wb") as f:
                        pickle.dump(fitter.mle_model, f)
Example #5
def _persist_results_per_state(state_df):
    county_output_file = get_run_artifact_path(state_df.fips[0],
                                               RunArtifact.MLE_FIT_RESULT)
    data = state_df.drop(['state', 'mle_model'], axis=1)
    data.to_json(county_output_file)

    for fips, county_series in state_df.iterrows():
        with open(get_run_artifact_path(fips, RunArtifact.MLE_FIT_MODEL),
                  'wb') as f:
            pickle.dump(county_series.mle_model, f)
Example #6
    def _load_model_for_fips(self, scenario='inferred'):
        """
        Try to load a model for the locale, else load the state level model
        and update parameters for the county.
        """
        artifact_path = get_run_artifact_path(self.fips, RunArtifact.MLE_FIT_MODEL)
        if os.path.exists(artifact_path):
            with open(artifact_path, 'rb') as f:
                model = pickle.load(f)
            inferred_params = fit_results.load_inference_result(self.fips)

        else:
            _logger.info(f'No MLE model found for {self.state_name}: {self.fips}. Reverting to state level.')
            artifact_path = get_run_artifact_path(self.fips[:2], RunArtifact.MLE_FIT_MODEL)
            if os.path.exists(artifact_path):
                with open(artifact_path, 'rb') as f:
                    model = pickle.load(f)
                inferred_params = fit_results.load_inference_result(self.fips[:2])
            else:
                raise FileNotFoundError(f'Could not locate state result for {self.state_name}')

            # Rescale state values to the county population and replace county
            # specific params.
            default_params = ParameterEnsembleGenerator(
                self.fips, N_samples=1, t_list=model.t_list,
                suppression_policy=model.suppression_policy).get_average_seir_parameters()
            population_ratio = default_params['N'] / model.N
            model.N *= population_ratio
            model.I_initial *= population_ratio
            model.E_initial *= population_ratio
            model.A_initial *= population_ratio
            model.S_initial = model.N - model.I_initial - model.E_initial - model.A_initial

            for key in {'beds_general', 'beds_ICU', 'ventilators'}:
                setattr(model, key, default_params[key])

        # Determine the appropriate future suppression policy based on the
        # scenario of interest.
        if scenario == 'inferred':
            eps_final = inferred_params['eps']
        else:
            eps_final = sp.get_future_suppression_from_r0(inferred_params['R0'], scenario=scenario)

        model.suppression_policy = sp.generate_two_step_policy(
            self.t_list,
            eps=inferred_params['eps'],
            t_break=inferred_params['t_break'],
            t_break_final=(datetime.datetime.today() - datetime.datetime.fromisoformat(inferred_params['t0_date'])).days,
            eps_final=eps_final
        )
        model.run()
        return model
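
A toy restatement of the population-rescaling step above, with made-up numbers, showing how the compartment initial conditions scale when a county falls back to the state-level model:

# Toy values, not from the source: a 1e6-person state model rescaled to
# a 5e4-person county.
state_N, county_N = 1e6, 5e4
I_initial, E_initial, A_initial = 200.0, 400.0, 300.0

population_ratio = county_N / state_N   # default_params['N'] / model.N
N = state_N * population_ratio          # == county_N
I0 = I_initial * population_ratio
E0 = E_initial * population_ratio
A0 = A_initial * population_ratio
S0 = N - I0 - E0 - A0                   # susceptibles absorb the remainder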
Example #7
    def generate_whitelist(self):
        """
        Generate a county whitelist based on the cuts above.

        Returns
        -------
        df: whitelist
        """
        logging.info("Generating county level whitelist...")

        # parallel load and compute
        df_candidates = self.county_metadata.fips.parallel_apply(_whitelist_candidates_per_fips)

        # join extra data
        df_candidates = df_candidates.merge(
            self.county_metadata[["fips", "state", "county"]],
            left_on="fips",
            right_on="fips",
            how="inner",
        )
        df_candidates["inference_ok"] = (
            (df_candidates.nonzero_case_datapoints >= self.nonzero_case_datapoints)
            & (df_candidates.nonzero_death_datapoints >= self.nonzero_death_datapoints)
            & (df_candidates.total_cases >= self.total_cases)
            & (df_candidates.total_deaths >= self.total_deaths)
        )

        output_path = get_run_artifact_path(
            fips="06", artifact=RunArtifact.WHITELIST_RESULT  # Dummy fips since not used here...
        )
        df_whitelist = df_candidates[["fips", "state", "county", "inference_ok"]]
        df_whitelist.to_json(output_path)

        return df_whitelist
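
A hedged usage sketch; the enclosing class and its constructor arguments are not shown here, so WhitelistGenerator() is an assumption:

gen = WhitelistGenerator()  # hypothetical class name and constructor
df_whitelist = gen.generate_whitelist()
ok_fips = df_whitelist.loc[df_whitelist["inference_ok"], "fips"].tolist()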
Example #8
 def generate_report(self):
     """
     Generate pdf report of backtesting results.
     """
     output_path = get_run_artifact_path(self.fips, "backtest_result")
     pdf = matplotlib.backends.backend_pdf.PdfPages(output_path)
     self.plot_backtest_results(self.backtest_results, pdf)
     self.plot_historical_predictions(self.historical_predictions,
                                      self.observations, pdf)
     pdf.close()
Example #9
    def run_ensemble(self):
        """
        Run an ensemble of models for each suppression policy and generate the
        output report / results dataset.
        """
        for suppression_policy_name, suppression_policy in self.suppression_policies.items():

            logging.info(
                f'Running simulation ensemble for {self.state_name} {self.fips} {suppression_policy_name}'
            )

            if suppression_policy_name == 'suppression_policy__inferred':

                artifact_path = get_run_artifact_path(
                    self.fips, RunArtifact.MLE_FIT_MODEL)
                if os.path.exists(artifact_path):
                    with open(artifact_path, 'rb') as f:
                        model_ensemble = [pickle.load(f)]
                else:
                    logging.warning(
                        f'No MLE model found for {self.state_name}: {self.fips}. Skipping.'
                    )
                    # Skip this policy; otherwise model_ensemble below would be
                    # unbound (or stale from a previous policy iteration).
                    continue
            else:
                parameter_sampler = ParameterEnsembleGenerator(
                    fips=self.fips,
                    N_samples=self.n_samples,
                    t_list=self.t_list,
                    suppression_policy=suppression_policy)
                parameter_ensemble = parameter_sampler.sample_seir_parameters(
                    override_params=self.override_params)
                model_ensemble = list(
                    map(self._run_single_simulation, parameter_ensemble))

            if self.agg_level is AggregationLevel.COUNTY:
                self.all_outputs['county_metadata'] = self.county_metadata
                self.all_outputs['county_metadata']['age_distribution'] = list(
                    self.all_outputs['county_metadata']['age_distribution'])
                self.all_outputs['county_metadata']['age_bins'] = list(
                    self.all_outputs['county_metadata']['age_bins'])

            self.all_outputs[
                f'{suppression_policy_name}'] = self._generate_output_for_suppression_policy(
                    model_ensemble)

        if self.generate_report and self.output_file_report:
            report = CountyReport(self.fips,
                                  model_ensemble=model_ensemble,
                                  county_outputs=self.all_outputs,
                                  filename=self.output_file_report,
                                  summary=self.summary)
            report.generate_and_save()

        with open(self.output_file_data, 'w') as f:
            json.dump(self.all_outputs, f)
Example #10
def load_whitelist():
    """
    Load the whitelist result.

    Returns
    -------
    whitelist: pd.DataFrame
        DataFrame containing a whitelist of product features for counties.
    """
    # Whitelist path isn't state specific, but the call requires ANY fips
    PLACEHOLDER_FIPS = "06"
    path = get_run_artifact_path(fips=PLACEHOLDER_FIPS, artifact=RunArtifact.WHITELIST_RESULT)
    return pd.read_json(path, dtype={"fips": str})
Example #11
def load_whitelist():
    """
    Load the whitelist result.

    Returns
    -------
    whitelist: pd.DataFrame
        DataFrame containing a whitelist of product features for counties.
    """
    path = get_run_artifact_path(
        fips='06',  # dummy since not used for whitelist.
        artifact=RunArtifact.WHITELIST_RESULT)
    return pd.read_json(path, dtype={'fips': str})
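
Typical downstream filtering of the loaded whitelist, mirroring the pattern used in Example #4 (the state name is a placeholder):

df_whitelist = load_whitelist()
df_whitelist = df_whitelist[df_whitelist["inference_ok"] == True]
all_fips = df_whitelist[df_whitelist["state"].str.lower() == "idaho"].fips.values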
Example #12
def test__pyseir_end_to_end():
    # This covers a lot of edge cases.
    cli._run_all(state='idaho')
    path = get_run_artifact_path('16001', RunArtifact.WEB_UI_RESULT).replace('__INTERVENTION_IDX__', '2')
    assert os.path.exists(path)

    with open(path) as f:
        output = json.load(f)

    output = pd.DataFrame(output)
    rt_col = schema.CAN_MODEL_OUTPUT_SCHEMA.index(schema.RT_INDICATOR)

    assert (output[rt_col].astype(float) > 0).any()
    assert (output.loc[output[rt_col].astype(float).notnull(), rt_col].astype(float) < 6).all()
Example #13
def test_pyseir_end_to_end_idaho():
    # This covers a lot of edge cases.
    # cli._run_all(state='Idaho')
    cli._build_all_for_states(states=["Idaho"], generate_reports=False, fips="16001")
    path = get_run_artifact_path("16001", RunArtifact.WEB_UI_RESULT).replace(
        "__INTERVENTION_IDX__", "2"
    )
    path = pathlib.Path(path)
    assert path.exists()
    output = CANPyseirLocationOutput.load_from_path(path)
    data = output.data
    with_values = data[schema.RT_INDICATOR].dropna()
    assert len(with_values) > 10
    assert (with_values > 0).all()
    assert (with_values < 6).all()
Example #14
def load_Rt_result(fips):
    """
    Load the Rt inference result.

    Parameters
    ----------
    fips: str
        State or County FIPS code.

    Returns
    -------
    results: pd.DataFrame
        DataFrame containing the R_t inferences.
    """
    path = get_run_artifact_path(fips, RunArtifact.RT_INFERENCE_RESULT)
    return pd.read_json(path)
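
A short consumption sketch; the fips is a placeholder and the composite column names are taken from Examples #24 and #26:

rt_results = load_Rt_result("16001")
rt_map = rt_results["Rt_MAP_composite"]
rt_ci95 = rt_results["Rt_ci95_composite"]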
Example #15
def load_inference_result(fips):
    """
    Load fit results by state or county fips code.

    Parameters
    ----------
    fips: str
        State or County FIPS code.

    Returns
    -------
    : dict
        Dictionary of fit result information.
    """
    output_file = get_run_artifact_path(fips, RunArtifact.MLE_FIT_RESULT)
    return pd.read_json(output_file).iloc[0].to_dict()
Example #16
def run_county(fips):
    """
    Run the R_t inference for a single county.

    Parameters
    ----------
    fips: str
        County fips to run against
    """
    if not fips:
        return None

    df = RtInferenceEngine.run_for_fips(fips)
    county_output_file = get_run_artifact_path(fips, RunArtifact.RT_INFERENCE_RESULT)
    if df is not None and not df.empty:
        df.to_json(county_output_file)
Example #17
def test__pyseir_end_to_end():
    # This covers a lot of edge cases.
    # cli._run_all(state='idaho')
    cli._build_all_for_states(states=["idaho"], generate_reports=False)
    path = get_run_artifact_path("16001", RunArtifact.WEB_UI_RESULT).replace(
        "__INTERVENTION_IDX__", "2")
    assert os.path.exists(path)

    with open(path) as f:
        output = json.load(f)

    output = pd.DataFrame(output)
    rt_col = schema.CAN_MODEL_OUTPUT_SCHEMA.index(schema.RT_INDICATOR)

    assert (output[rt_col].astype(float) > 0).any()
    assert (output.loc[output[rt_col].astype(float).notnull(),
                       rt_col].astype(float) < 6).all()
Example #18
def load_ensemble_results(fips):
    """
    Retrieve ensemble results for a given state or county fips code.

    Parameters
    ----------
    fips: str
        State or county FIPS to load.

    Returns
    -------
    ensemble_results: dict
    """
    output_filename = get_run_artifact_path(fips, RunArtifact.ENSEMBLE_RESULT)
    with open(output_filename) as f:
        fit_results = json.load(f)
    return fit_results
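
A consumption sketch based on the key structure written in Example #9 and read back in Example #24 (the fips is a placeholder):

pyseir_outputs = load_ensemble_results("16001")
output_for_policy = pyseir_outputs["suppression_policy__inferred"]
t_list = output_for_policy["t_list"]
hgen_median = output_for_policy["HGen"]["ci_50"]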
Example #19
    def generate_whitelist(self):
        """
        Generate a county whitelist based on the cuts above.

        Returns
        -------
        df: whitelist
        """
        logging.info('Generating county level whitelist...')

        whitelist_generator_inputs = []
        for fips in self.county_metadata.fips:
            times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
                fips, t0=datetime(day=1, month=1, year=2020))

            metadata = self.county_metadata[self.county_metadata.fips == fips].iloc[0].to_dict()

            record = dict(
                fips=fips,
                state=metadata['state'],
                county=metadata['county'],
                total_cases=observed_new_cases.sum(),
                total_deaths=observed_new_deaths.sum(),
                nonzero_case_datapoints=np.sum(observed_new_cases > 0),
                nonzero_death_datapoints=np.sum(observed_new_deaths > 0)
            )
            whitelist_generator_inputs.append(record)

        df_candidates = pd.DataFrame(whitelist_generator_inputs)

        df_whitelist = df_candidates[['fips', 'state', 'county']]
        df_whitelist.loc[:, 'inference_ok'] = (
                  (df_candidates.nonzero_case_datapoints >= self.nonzero_case_datapoints)
                & (df_candidates.nonzero_death_datapoints >= self.nonzero_death_datapoints)
                & (df_candidates.total_cases >= self.total_cases)
                & (df_candidates.total_deaths >= self.total_deaths)
        )

        output_path = get_run_artifact_path(
            fips='06', # Dummy fips since not used here...
            artifact=RunArtifact.WHITELIST_RESULT)
        df_whitelist.to_json(output_path)

        return df_whitelist
Example #20
def load_inference_result(fips):
    """
    Load fit results by state or county fips code.

    Parameters
    ----------
    fips: str
        State or County FIPS code.

    Returns
    -------
    : dict
        Dictionary of fit result information.
    """
    output_file = get_run_artifact_path(fips, RunArtifact.MLE_FIT_RESULT)
    df = pd.read_json(output_file, dtype={"fips": "str"})
    if len(fips) == 2:
        return df.iloc[0].to_dict()
    else:
        return df.set_index("fips").loc[fips].to_dict()
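
The two lookup paths this version supports, sketched with Idaho fips codes as placeholders:

state_fit = load_inference_result("16")      # 2-digit fips: first (state) row
county_fit = load_inference_result("16001")  # 5-digit fips: row selected via the fips index
t0_date = county_fit["t0_date"]              # consumed e.g. in Example #24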
Example #21
def run_rt_for_fips(
    fips: str,
    include_deaths: bool = False,
    include_testing_correction: bool = True,
    figure_collector: Optional[list] = None,
):
    """Entry Point for Infer Rt"""

    # Generate the Data Packet to Pass to RtInferenceEngine
    input_df = _generate_input_data(
        fips=fips,
        include_testing_correction=include_testing_correction,
        include_deaths=include_deaths,
        figure_collector=figure_collector,
    )
    if input_df.dropna().empty:
        rt_log.warning(
            event="Infer Rt Skipped. No Data Passed Filter Requirements:",
            fips=fips)
        return

    # Save a reference to the instantiated engine (eventually I want to pull
    # out the figure generation and saving so that I don't have to pass a
    # display_name and fips into the class).
    engine = RtInferenceEngine(
        data=input_df,
        display_name=_get_display_name(fips),
        fips=fips,
        include_deaths=include_deaths,
    )

    # Generate the output DataFrame (consider renaming the function infer_all to be clearer)
    output_df = engine.infer_all()

    # Save the output to json for downstream repacking and incorporation.
    if output_df is not None and not output_df.empty:
        output_path = get_run_artifact_path(fips,
                                            RunArtifact.RT_INFERENCE_RESULT)
        output_df.to_json(output_path)
    return
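
An invocation sketch for the entry point above; the fips is a placeholder, and passing a plain list as figure_collector is an assumption:

run_rt_for_fips("16001")  # writes RT_INFERENCE_RESULT json when inference succeeds

figures = []
run_rt_for_fips(
    "16001",
    include_deaths=True,
    include_testing_correction=True,
    figure_collector=figures,
)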
Example #22
def plot_fitting_results(result) -> None:
    """
    Entry point from model_fitter. Generate and save all PySEIR related Figures.
    """
    output_file = get_run_artifact_path(result.fips, RunArtifact.MLE_FIT_REPORT)

    # Save the mle fitter
    mle_fig = result.mle_model.plot_results()
    mle_fig.savefig(output_file.replace("mle_fit_results", "mle_fit_model"), bbox_inches="tight")

    # Generate Figure
    fig = _plot_model_fitter_results(result)

    # Save the figure in Log-Linear
    fig.gca().set_yscale("log")
    fig.savefig(output_file, bbox_inches="tight")

    # Save the figure in Linear mode
    fig.gca().set_yscale("linear")
    fig.savefig(
        output_file.replace("mle_fit_results", "mle_fit_results_linear"), bbox_inches="tight"
    )
    return
Example #23
    def generate_whitelist(self):
        """
        Generate a county whitelist based on the cuts above.

        Returns
        -------
        df: whitelist
        """
        logging.info('Generating county level whitelist...')

        # parallel load and compute
        df_candidates = self.county_metadata.fips.parallel_apply(
            _whitelist_candidates_per_fips)

        # join extra data
        df_candidates = df_candidates.merge(
            self.county_metadata[['fips', 'state', 'county']],
            left_on='fips',
            right_on='fips',
            how='inner')

        df_whitelist = df_candidates[['fips', 'state', 'county']]
        df_whitelist.loc[:, 'inference_ok'] = (
            (df_candidates.nonzero_case_datapoints >=
             self.nonzero_case_datapoints)
            & (df_candidates.nonzero_death_datapoints >=
               self.nonzero_death_datapoints)
            & (df_candidates.total_cases >= self.total_cases)
            & (df_candidates.total_deaths >= self.total_deaths))

        output_path = get_run_artifact_path(
            fips='06',  # Dummy fips since not used here...
            artifact=RunArtifact.WHITELIST_RESULT)
        df_whitelist.to_json(output_path)

        return df_whitelist
Example #24
    def map_fips(self, fips):
        """
        For a given county fips code, generate the CAN UI output format.

        Parameters
        ----------
        fips: str
            County FIPS code to map.
        """
        logging.info(f"Mapping output to WebUI for {self.state}, {fips}")
        pyseir_outputs = load_data.load_ensemble_results(fips)

        if len(fips) == 5 and fips not in self.df_whitelist.fips.values:
            logging.info(f"Excluding {fips} due to white list...")
            return
        try:
            fit_results = load_inference_result(fips)
            t0_simulation = datetime.fromisoformat(fit_results["t0_date"])
        except (KeyError, ValueError):
            logging.error(f"Fit result not found for {fips}. Skipping...")
            return

        # ---------------------------------------------------------------------
        # Rescale hosps based on the population ratio... Could swap this to
        # infection ratio later?
        # ---------------------------------------------------------------------
        hosp_times, current_hosp, _ = load_data.load_hospitalization_data_by_state(
            state=self.state_abbreviation,
            t0=t0_simulation,
            convert_cumulative_to_current=True,
            category="hospitalized",
        )

        _, current_icu, _ = load_data.load_hospitalization_data_by_state(
            state=self.state_abbreviation,
            t0=t0_simulation,
            convert_cumulative_to_current=True,
            category="icu",
        )

        if len(fips) == 5:
            population = self.population_data.get_record_for_fips(fips)[CommonFields.POPULATION]
        else:
            population = self.population_data.get_record_for_state(self.state_abbreviation)[
                CommonFields.POPULATION
            ]

        # logging.info(f'Mapping output to WebUI for {self.state}, {fips}')
        # pyseir_outputs = load_data.load_ensemble_results(fips)
        # if pyseir_outputs is None:
        #     logging.warning(f'Fit result not found for {fips}: Skipping county')
        #     return None

        policies = [key for key in pyseir_outputs.keys() if key.startswith("suppression_policy")]
        if current_hosp is not None:
            t_latest_hosp_data, current_hosp = hosp_times[-1], current_hosp[-1]
            t_latest_hosp_data_date = t0_simulation + timedelta(days=int(t_latest_hosp_data))

            state_hosp_gen = load_data.get_compartment_value_on_date(
                fips=fips[:2], compartment="HGen", date=t_latest_hosp_data_date
            )
            state_hosp_icu = load_data.get_compartment_value_on_date(
                fips=fips[:2], compartment="HICU", date=t_latest_hosp_data_date
            )

            if len(fips) == 5:
                # Rescale the county level hospitalizations by the expected
                # ratio of county / state hospitalizations from simulations.
                # We use ICU data if available too.
                county_hosp = load_data.get_compartment_value_on_date(
                    fips=fips,
                    compartment="HGen",
                    date=t_latest_hosp_data_date,
                    ensemble_results=pyseir_outputs,
                )
                county_icu = load_data.get_compartment_value_on_date(
                    fips=fips,
                    compartment="HICU",
                    date=t_latest_hosp_data_date,
                    ensemble_results=pyseir_outputs,
                )
                current_hosp *= (county_hosp + county_icu) / (state_hosp_gen + state_hosp_icu)

            hosp_rescaling_factor = current_hosp / (state_hosp_gen + state_hosp_icu)

            # Some states have covidtracking issues. We shouldn't ground ICU cases
            # to zero since so far these have all been bad reporting.
            if current_icu is not None and current_icu[-1] > 0:
                icu_rescaling_factor = current_icu[-1] / state_hosp_icu
            else:
                icu_rescaling_factor = current_hosp / (state_hosp_gen + state_hosp_icu)
        else:
            hosp_rescaling_factor = 1.0
            icu_rescaling_factor = 1.0

        # Iterate through each suppression policy.
        # Model output is interpolated to the dates desired for the API.
        for i_policy, suppression_policy in enumerate(
            [key for key in pyseir_outputs.keys() if key.startswith("suppression_policy")]
        ):

            output_for_policy = pyseir_outputs[suppression_policy]
            output_model = pd.DataFrame()

            t_list = output_for_policy["t_list"]
            t_list_downsampled = range(0, int(max(t_list)), self.output_interval_days)

            output_model[schema.DAY_NUM] = t_list_downsampled
            output_model[schema.DATE] = [
                (t0_simulation + timedelta(days=t)).date().strftime("%m/%d/%y")
                for t in t_list_downsampled
            ]
            output_model[schema.TOTAL] = population
            output_model[schema.TOTAL_SUSCEPTIBLE] = np.interp(
                t_list_downsampled, t_list, output_for_policy["S"]["ci_50"]
            )
            output_model[schema.EXPOSED] = np.interp(
                t_list_downsampled, t_list, output_for_policy["E"]["ci_50"]
            )
            output_model[schema.INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.add(output_for_policy["I"]["ci_50"], output_for_policy["A"]["ci_50"]),
            )  # Infected + Asympt.
            output_model[schema.INFECTED_A] = output_model[schema.INFECTED]
            output_model[schema.INFECTED_B] = hosp_rescaling_factor * np.interp(
                t_list_downsampled, t_list, output_for_policy["HGen"]["ci_50"]
            )  # Hosp General
            output_model[schema.INFECTED_C] = icu_rescaling_factor * np.interp(
                t_list_downsampled, t_list, output_for_policy["HICU"]["ci_50"]
            )  # Hosp ICU
            # General + ICU beds. don't include vent here because they are also counted in ICU
            output_model[schema.ALL_HOSPITALIZED] = np.add(
                output_model[schema.INFECTED_B], output_model[schema.INFECTED_C]
            )
            output_model[schema.ALL_INFECTED] = output_model[schema.INFECTED]
            output_model[schema.DEAD] = np.interp(
                t_list_downsampled, t_list, output_for_policy["total_deaths"]["ci_50"]
            )
            final_beds = np.mean(output_for_policy["HGen"]["capacity"])
            output_model[schema.BEDS] = final_beds
            output_model[schema.CUMULATIVE_INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.cumsum(output_for_policy["total_new_infections"]["ci_50"]),
            )

            if fit_results:
                output_model[schema.Rt] = np.interp(
                    t_list_downsampled,
                    t_list,
                    fit_results["eps"] * fit_results["R0"] * np.ones(len(t_list)),
                )
                output_model[schema.Rt_ci90] = np.interp(
                    t_list_downsampled,
                    t_list,
                    2 * fit_results["eps_error"] * fit_results["R0"] * np.ones(len(t_list)),
                )
            else:
                output_model[schema.Rt] = 0
                output_model[schema.Rt_ci90] = 0

            output_model[schema.CURRENT_VENTILATED] = icu_rescaling_factor * np.interp(
                t_list_downsampled, t_list, output_for_policy["HVent"]["ci_50"]
            )
            output_model[schema.POPULATION] = population
            # Average capacity.
            output_model[schema.ICU_BED_CAPACITY] = np.mean(output_for_policy["HICU"]["capacity"])
            output_model[schema.VENTILATOR_CAPACITY] = np.mean(
                output_for_policy["HVent"]["capacity"]
            )

            # Truncate date range of output.
            output_dates = pd.to_datetime(output_model["date"])
            output_model = output_model[
                (output_dates >= datetime(month=3, day=3, year=2020))
                & (output_dates < datetime.today() + timedelta(days=90))
            ]
            output_model = output_model.fillna(0)

            # Fill in results for the Rt indicator.
            try:
                rt_results = load_Rt_result(fips)
                rt_results.index = rt_results["Rt_MAP_composite"].index.strftime("%m/%d/%y")
                merged = output_model.merge(
                    rt_results[["Rt_MAP_composite", "Rt_ci95_composite"]],
                    right_index=True,
                    left_on="date",
                    how="left",
                )
                output_model[schema.RT_INDICATOR] = merged["Rt_MAP_composite"]

                # With 90% probability the value is between rt_indicator - ci90 to rt_indicator + ci90
                output_model[schema.RT_INDICATOR_CI90] = (
                    merged["Rt_ci95_composite"] - merged["Rt_MAP_composite"]
                )
            except (ValueError, KeyError) as e:
                output_model[schema.RT_INDICATOR] = "NaN"
                output_model[schema.RT_INDICATOR_CI90] = "NaN"

            output_model[[schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]] = output_model[
                [schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]
            ].fillna("NaN")

            # Truncate floats and cast as strings to match data model.
            int_columns = [
                col
                for col in output_model.columns
                if col
                not in (
                    schema.DATE,
                    schema.Rt,
                    schema.Rt_ci90,
                    schema.RT_INDICATOR,
                    schema.RT_INDICATOR_CI90,
                )
            ]
            output_model.loc[:, int_columns] = (
                output_model[int_columns].fillna(0).astype(int).astype(str)
            )
            output_model.loc[
                :, [schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]
            ] = (
                output_model[
                    [schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]
                ]
                .fillna(0)
                .round(decimals=4)
                .astype(str)
            )

            # Convert the records format to just list(list(values))
            output_model = [
                [val for val in timestep.values()]
                for timestep in output_model.to_dict(orient="records")
            ]

            output_path = get_run_artifact_path(
                fips, RunArtifact.WEB_UI_RESULT, output_dir=self.output_dir
            )
            policy_enum = Intervention.from_webui_data_adaptor(suppression_policy)
            output_path = output_path.replace("__INTERVENTION_IDX__", str(policy_enum.value))
            with open(output_path, "w") as f:
                json.dump(output_model, f)
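
A compact, self-contained illustration of the np.interp downsampling pattern used throughout map_fips, with toy data and output_interval_days assumed to be 4:

import numpy as np

t_list = np.arange(0, 11)                       # daily model time steps
series = t_list ** 2.0                          # stand-in for a ci_50 series
t_list_downsampled = range(0, int(max(t_list)), 4)
resampled = np.interp(t_list_downsampled, t_list, series)  # values at t = 0, 4, 8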
Example #25
    def plot_fitting_results(self):
        """
        Plotting model fitting results.
        """
        data_dates = [self.ref_date + timedelta(days=t) for t in self.times]
        if self.hospital_times is not None:
            hosp_dates = [
                self.ref_date + timedelta(days=float(t))
                for t in self.hospital_times
            ]
        model_dates = [
            self.ref_date + timedelta(days=t + self.fit_results['t0'])
            for t in self.t_list
        ]

        # Don't display the zero-inflated error bars
        cases_err = np.array(self.cases_stdev)
        cases_err[self.observed_new_cases == 0] = 0
        death_err = deepcopy(self.deaths_stdev)
        death_err[self.observed_new_deaths == 0] = 0
        if self.hosp_stdev is not None:
            hosp_stdev = deepcopy(self.hosp_stdev)
            hosp_stdev[hosp_stdev > 1e5] = 0

        plt.figure(figsize=(18, 12))
        plt.errorbar(data_dates,
                     self.observed_new_cases,
                     yerr=cases_err,
                     marker='o',
                     linestyle='',
                     label='Observed Cases Per Day',
                     color='steelblue',
                     capsize=3,
                     alpha=.4,
                     markersize=10)
        plt.errorbar(data_dates,
                     self.observed_new_deaths,
                     yerr=death_err,
                     marker='d',
                     linestyle='',
                     label='Observed Deaths Per Day',
                     color='firebrick',
                     capsize=3,
                     alpha=.4,
                     markersize=10)

        plt.plot(model_dates,
                 self.mle_model.results['total_new_infections'],
                 label='Estimated Total New Infections Per Day',
                 linestyle='--',
                 lw=4,
                 color='steelblue')
        plt.plot(model_dates,
                 self.fit_results['test_fraction'] *
                 self.mle_model.results['total_new_infections'],
                 label='Estimated Tested New Infections Per Day',
                 color='steelblue',
                 lw=4)

        plt.plot(model_dates,
                 self.mle_model.results['total_deaths_per_day'],
                 label='Model Deaths Per Day',
                 color='firebrick',
                 lw=4)

        if self.hospitalization_data_type is HospitalizationDataType.CUMULATIVE_HOSPITALIZATIONS:
            new_hosp_observed = self.hospitalizations[
                1:] - self.hospitalizations[:-1]
            plt.errorbar(hosp_dates[1:],
                         new_hosp_observed,
                         yerr=hosp_stdev,
                         marker='s',
                         linestyle='',
                         label='Observed New Hospitalizations Per Day',
                         color='darkseagreen',
                         capsize=3,
                         alpha=1)
            predicted_hosp = (self.mle_model.results['HGen_cumulative'] +
                              self.mle_model.results['HICU_cumulative'])
            predicted_hosp = predicted_hosp[1:] - predicted_hosp[:-1]
            plt.plot(model_dates[1:],
                     self.fit_results['hosp_fraction'] * predicted_hosp,
                     label='Estimated Total New Hospitalizations Per Day',
                     linestyle='-.',
                     lw=4,
                     color='darkseagreen',
                     markersize=10)
        elif self.hospitalization_data_type is HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
            plt.errorbar(hosp_dates,
                         self.hospitalizations,
                         yerr=hosp_stdev,
                         marker='s',
                         linestyle='',
                         label='Observed Total Current Hospitalizations',
                         color='darkseagreen',
                         capsize=3,
                         alpha=.5,
                         markersize=10)
            predicted_hosp = (self.mle_model.results['HGen'] +
                              self.mle_model.results['HICU'])
            plt.plot(model_dates,
                     self.fit_results['hosp_fraction'] * predicted_hosp,
                     label='Estimated Total Current Hospitalizations',
                     linestyle='-.',
                     lw=4,
                     color='darkseagreen')

        plt.plot(model_dates,
                 self.fit_results['hosp_fraction'] *
                 self.mle_model.results['HICU'],
                 label='Estimated ICU Occupancy',
                 linestyle=':',
                 lw=6,
                 color='black')
        plt.plot(model_dates,
                 self.fit_results['hosp_fraction'] *
                 self.mle_model.results['HGen'],
                 label='Estimated General Occupancy',
                 linestyle=':',
                 lw=4,
                 color='black',
                 alpha=0.4)

        plt.yscale('log')
        y_lim = plt.ylim(.8e0)

        start_intervention_date = self.ref_date + timedelta(
            days=self.fit_results['t_break'] + self.fit_results['t0'])
        stop_intervention_date = start_intervention_date + timedelta(days=14)

        plt.fill_betweenx([y_lim[0], y_lim[1]],
                          [start_intervention_date, start_intervention_date],
                          [stop_intervention_date, stop_intervention_date],
                          alpha=0.2,
                          label='Estimated Intervention')

        running_total = timedelta(days=0)
        for i_label, k in enumerate(('symptoms_to_hospital_days',
                                     'hospitalization_length_of_stay_general',
                                     'hospitalization_length_of_stay_icu')):

            end_time = timedelta(days=self.SEIR_kwargs[k])
            x = start_intervention_date + running_total
            y = 1.5**(i_label + 1)
            plt.errorbar(x=[x],
                         y=[y],
                         xerr=[[timedelta(days=0)], [end_time]],
                         marker='',
                         capsize=8,
                         color='k',
                         elinewidth=3,
                         capthick=3)
            plt.text(x + (end_time + timedelta(days=2)),
                     y,
                     k.replace('_', ' ').title(),
                     fontsize=14)
            running_total += end_time

        if self.SEIR_kwargs['beds_ICU'] > 0:
            plt.hlines(self.SEIR_kwargs['beds_ICU'],
                       *plt.xlim(),
                       color='k',
                       linestyles='-',
                       linewidths=6,
                       alpha=0.2)
            plt.text(data_dates[0] + timedelta(days=5),
                     self.SEIR_kwargs['beds_ICU'] * 1.1,
                     'Available ICU Capacity',
                     color='k',
                     alpha=0.5,
                     fontsize=15)

        plt.ylim(*y_lim)
        plt.xlim(min(model_dates[0], data_dates[0]),
                 data_dates[-1] + timedelta(days=150))
        plt.xticks(rotation=30, fontsize=14)
        plt.yticks(fontsize=14)
        plt.legend(loc=4, fontsize=14)
        plt.grid(which='both', alpha=.5)
        plt.title(self.display_name, fontsize=20)

        for i, (k, v) in enumerate(self.fit_results.items()):

            fontweight = 'bold' if k in ('R0', 'Reff') else 'normal'

            if np.isscalar(v) and not isinstance(v, str):
                plt.text(1.05,
                         .7 - 0.032 * i,
                         f'{k}={v:1.3f}',
                         transform=plt.gca().transAxes,
                         fontsize=15,
                         alpha=.6,
                         fontweight=fontweight)
            else:
                plt.text(1.05,
                         .7 - 0.032 * i,
                         f'{k}={v}',
                         transform=plt.gca().transAxes,
                         fontsize=15,
                         alpha=.6,
                         fontweight=fontweight)

        output_file = get_run_artifact_path(self.fips,
                                            RunArtifact.MLE_FIT_REPORT)
        plt.savefig(output_file, bbox_inches='tight')
        plt.close()

        self.mle_model.plot_results()
        plt.savefig(output_file.replace('mle_fit_results', 'mle_fit_model'),
                    bbox_inches='tight')
        plt.close()
Example #26
    def infer_all(self, plot=True, shift_deaths=0):
        """
        Infer R_t from all available data sources.

        Parameters
        ----------
        plot: bool
            If True, generate a plot of the inference.
        shift_deaths: int
            Shift the death time series by this amount with respect to cases
            (when plotting only, does not shift the returned result).

        Returns
        -------
        inference_results: pd.DataFrame
            Columns containing MAP estimates and confidence intervals.
        """
        df_all = None
        available_timeseries = self.get_available_timeseries()

        for timeseries_type in available_timeseries:
            # Add Raw Data Output to Output Dataframe
            dates_raw, times_raw, timeseries_raw = self.get_timeseries(timeseries_type)
            df_raw = pd.DataFrame()
            df_raw["date"] = dates_raw
            df_raw = df_raw.set_index("date")
            df_raw[timeseries_type.value] = timeseries_raw

            df = pd.DataFrame()
            dates, times, posteriors, start_idx = self.get_posteriors(timeseries_type)
            # Note that it is possible for the dates to be missing days
            # This can cause problems when:
            #   1) computing posteriors that assume continuous data (above),
            #   2) when merging data with variable keys
            if posteriors is None:
                continue

            df[f"Rt_MAP__{timeseries_type.value}"] = posteriors.idxmax()
            for ci in self.confidence_intervals:
                ci_low, ci_high = self.highest_density_interval(posteriors, ci=ci)

                low_val = 1 - ci
                high_val = ci
                df[f"Rt_ci{int(math.floor(100 * low_val))}__{timeseries_type.value}"] = ci_low
                df[f"Rt_ci{int(math.floor(100 * high_val))}__{timeseries_type.value}"] = ci_high

            df["date"] = dates
            df = df.set_index("date")

            if df_all is None:
                df_all = df
            else:
                # To avoid any surprises merging the data, keep only the keys from the case data
                # which will be the first added to df_all. So merge with how ="left" rather than "outer"
                df_all = df_all.merge(df_raw, left_index=True, right_index=True, how="left")
                df_all = df_all.merge(df, left_index=True, right_index=True, how="left")

            # ------------------------------------------------
            # Compute the indicator lag using the curvature
            # alignment method.
            # ------------------------------------------------
            if (
                timeseries_type in (TimeseriesType.NEW_DEATHS, TimeseriesType.NEW_HOSPITALIZATIONS)
                and f"Rt_MAP__{TimeseriesType.NEW_CASES.value}" in df_all.columns
            ):

                # Go back up to 21 days, or the max time series length we have if shorter.
                last_idx = max(-21, -len(df))
                series_a = df_all[f"Rt_MAP__{TimeseriesType.NEW_CASES.value}"].iloc[-last_idx:]
                series_b = df_all[f"Rt_MAP__{timeseries_type.value}"].iloc[-last_idx:]

                shift_in_days = self.align_time_series(series_a=series_a, series_b=series_b,)

                df_all[f"lag_days__{timeseries_type.value}"] = shift_in_days
                logging.debug(
                    "Using timeshift of: %s for timeseries type: %s ",
                    shift_in_days,
                    timeseries_type,
                )
                # Shift all the columns.
                for col in df_all.columns:
                    if timeseries_type.value in col:
                        df_all[col] = df_all[col].shift(shift_in_days)
                        # Extend death and hospitalization Rt signals beyond the
                        # shift to avoid sudden jumps in the composite metric.
                        #
                        # N.B interpolate() behaves differently depending on the location
                        # of the missing values: For any nans appearing in between valid
                        # elements of the series, an interpolated value is filled in.
                        # For values at the end of the series, the last *valid* value is used.
                        logging.debug("Filling in %s missing values", shift_in_days)
                        df_all[col] = df_all[col].interpolate(
                            limit_direction="forward", method="linear"
                        )

        if df_all is None:
            logging.warning("Inference not possible for fips: %s", self.fips)
            return None

        if (
            not InferRtConstants.DISABLE_DEATHS
            and "Rt_MAP__new_deaths" in df_all
            and "Rt_MAP__new_cases" in df_all
        ):
            df_all["Rt_MAP_composite"] = np.nanmean(
                df_all[["Rt_MAP__new_cases", "Rt_MAP__new_deaths"]], axis=1
            )
            # Just use the Stdev of cases. A correlated quadrature summed error
            # would be better, but is also more confusing and difficult to fix
            # discontinuities between death and case errors since deaths are
            # only available for a subset. Systematic errors are much larger in
            # any case.
            df_all["Rt_ci95_composite"] = df_all["Rt_ci95__new_cases"]

        elif "Rt_MAP__new_cases" in df_all:
            df_all["Rt_MAP_composite"] = df_all["Rt_MAP__new_cases"]
            df_all["Rt_ci95_composite"] = df_all["Rt_ci95__new_cases"]

        # Correct for tail suppression where Rt is incorrectly forced towards 1 as
        # case smoothing lags when approaching end of time series
        suppression = 1.0 * np.ones(len(df_all))
        if (
            InferRtConstants.CORRECT_TAIL_SUPRESSION > 0.0
            and InferRtConstants.CORRECT_TAIL_SUPRESSION <= 1.0
        ):
            tail_sup = self.evaluate_head_tail_suppression()
            # Calculate rt suppression by smoothing delay at tail of sequence
            # and pad with 1s at front (which won't change anything) to apply
            suppression = np.concatenate(
                [1.0 * np.ones(len(df_all) - len(tail_sup)), tail_sup.values]
            )
            # Adjust Rt by undoing the suppression. If Rt did not lag the
            # likelihood at all, the np.power term would be unnecessary and a
            # simple enable flag would do. But when lagging, linear use of the
            # suppression correction over-adjusts, so a power
            # CORRECT_TAIL_SUPRESSION between 0 and 1 is used.
            df_all["Rt_MAP_composite"] = (df_all["Rt_MAP_composite"] - 1.0) / np.power(
                suppression, InferRtConstants.CORRECT_TAIL_SUPRESSION
            ) + 1.0

        # Optionally Smooth just Rt_MAP_composite.
        # Note this doesn't lag in time and preserves integral of Rteff over time
        for i in range(0, InferRtConstants.SMOOTH_RT_MAP_COMPOSITE):
            kernel_width = round(InferRtConstants.RT_SMOOTHING_WINDOW_SIZE / 4)
            smoothed = (
                df_all["Rt_MAP_composite"]
                .rolling(
                    InferRtConstants.RT_SMOOTHING_WINDOW_SIZE,
                    win_type="gaussian",
                    min_periods=kernel_width,
                    center=True,
                )
                .mean(std=kernel_width)
            )

            # Adjust down confidence interval due to count smoothing over kernel_width values but not
            # below threshold. Adjust confidence interval for tail suppression adjustment also.
            df_all["Rt_MAP_composite"] = smoothed
            df_all["Rt_ci95_composite"] = (
                (df_all["Rt_ci95_composite"] - df_all["Rt_MAP_composite"])
                / math.sqrt(
                    2.0 * kernel_width  # averaging over many points reduces confidence interval
                )
                / np.power(suppression, InferRtConstants.CORRECT_TAIL_SUPRESSION / 2)
            ).apply(lambda v: max(v, InferRtConstants.MIN_CONF_WIDTH)) + df_all["Rt_MAP_composite"]

        if plot:
            plt.figure(figsize=(10, 6))

            # plt.hlines([1.0], *plt.xlim(), alpha=1, color="g")
            # plt.hlines([1.1], *plt.xlim(), alpha=1, color="gold")
            # plt.hlines([1.3], *plt.xlim(), alpha=1, color="r")

            if "Rt_ci5__new_deaths" in df_all:
                if not InferRtConstants.DISABLE_DEATHS:
                    plt.fill_between(
                        df_all.index,
                        df_all["Rt_ci5__new_deaths"],
                        df_all["Rt_ci95__new_deaths"],
                        alpha=0.2,
                        color="firebrick",
                    )
                # Show for reference even if not used
                plt.scatter(
                    df_all.index,
                    df_all["Rt_MAP__new_deaths"].shift(periods=shift_deaths),
                    alpha=1,
                    s=25,
                    color="firebrick",
                    label="New Deaths",
                )

            if "Rt_ci5__new_cases" in df_all:
                if not InferRtConstants.DISABLE_DEATHS:
                    plt.fill_between(
                        df_all.index,
                        df_all["Rt_ci5__new_cases"],
                        df_all["Rt_ci95__new_cases"],
                        alpha=0.2,
                        color="steelblue",
                    )
                plt.scatter(
                    df_all.index,
                    df_all["Rt_MAP__new_cases"],
                    alpha=1,
                    s=25,
                    color="steelblue",
                    label="New Cases",
                    marker="s",
                )

            if "Rt_ci5__new_hospitalizations" in df_all:
                if not InferRtConstants.DISABLE_DEATHS:
                    plt.fill_between(
                        df_all.index,
                        df_all["Rt_ci5__new_hospitalizations"],
                        df_all["Rt_ci95__new_hospitalizations"],
                        alpha=0.4,
                        color="darkseagreen",
                    )
                # Show for reference even if not used
                plt.scatter(
                    df_all.index,
                    df_all["Rt_MAP__new_hospitalizations"],
                    alpha=1,
                    s=25,
                    color="darkseagreen",
                    label="New Hospitalizations",
                    marker="d",
                )

            if "Rt_MAP_composite" in df_all:
                plt.scatter(
                    df_all.index,
                    df_all["Rt_MAP_composite"],
                    alpha=1,
                    s=25,
                    color="black",
                    label="Inferred $R_{t}$ Web",
                    marker="d",
                )

            if "Rt_ci95_composite" in df_all:
                plt.fill_between(
                    df_all.index,
                    df_all["Rt_ci95_composite"],
                    2 * df_all["Rt_MAP_composite"] - df_all["Rt_ci95_composite"],
                    alpha=0.2,
                    color="gray",
                )

            plt.hlines([0.9], *plt.xlim(), alpha=1, color="g")
            plt.hlines([1.1], *plt.xlim(), alpha=1, color="gold")
            plt.hlines([1.4], *plt.xlim(), alpha=1, color="r")

            plt.xticks(rotation=30)
            plt.grid(True)
            plt.xlim(df_all.index.min() - timedelta(days=2), df_all.index.max() + timedelta(days=2))
            plt.ylim(0.0, 3.0)
            plt.ylabel("$R_t$", fontsize=16)
            plt.legend()
            plt.title(self.display_name, fontsize=16)

            output_path = get_run_artifact_path(self.fips, RunArtifact.RT_INFERENCE_REPORT)
            plt.savefig(output_path, bbox_inches="tight")
            plt.close()
        if df_all.empty:
            logging.warning("Inference not possible for fips: %s", self.fips)
        return df_all
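
A minimal demonstration of the interpolate() behavior the N.B. comment above relies on: interior NaNs are interpolated, trailing NaNs are held at the last valid value:

import pandas as pd

s = pd.Series([1.0, None, 3.0, None])
filled = s.interpolate(limit_direction="forward", method="linear")
# filled -> [1.0, 2.0, 3.0, 3.0]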
Example #27
    def apply_gaussian_smoothing(self, timeseries_type, plot=True, smoothed_max_threshold=5):
        """
        Apply a rolling Gaussian window to smooth the data. This signature and
        returns match get_time_series, but will return a subset of the input
        time-series starting at the first non-zero value.

        Parameters
        ----------
        timeseries_type: TimeseriesType
            Which type of time-series to use.
        plot: bool
            If True, plot smoothed and original data.
        smoothed_max_threshold: int
            This parameter allows you to filter out entire series
            (e.g. NEW_DEATHS) when they do not contain high enough
            numeric values. This has been added to account for low-level
            constant smoothed values having a disproportionate effect on
            our final R(t) calculation, when all of their values are below
            this parameter.

        Returns
        -------
        dates: array-like
            Input data over a subset of indices available after windowing.
        times: array-like
            Output integers since the reference date.
        smoothed: array-like
            Gaussian smoothed data.


        """
        timeseries_type = TimeseriesType(timeseries_type)
        dates, times, timeseries = self.get_timeseries(timeseries_type)
        self.log = self.log.bind(timeseries_type=timeseries_type.value)

        # Don't even try if the timeseries is too short (Florida hospitalizations failing with length=6)
        if len(timeseries) < InferRtConstants.MIN_TIMESERIES_LENGTH:
            return [], [], []

        # Hospitalizations have a strange effect in the first few data points across many states.
        # Let's just drop those.
        if timeseries_type in (
            TimeseriesType.CURRENT_HOSPITALIZATIONS,
            TimeseriesType.NEW_HOSPITALIZATIONS,
        ):
            dates, times, timeseries = dates[2:], times[2:], timeseries[2:]

        # Remove outliers before smoothing: replace a value when it is more than
        # 10 std from the 14-day trailing mean.
        timeseries = replace_outliers(pd.Series(timeseries), log=self.log)

        # Smoothing no longer involves rounding
        smoothed = timeseries.rolling(
            self.window_size, win_type="gaussian", min_periods=self.kernel_std, center=True
        ).mean(std=self.kernel_std)

        # Retain logic for detecting what would be nonzero values if rounded
        nonzeros = [idx for idx, val in enumerate(smoothed.round()) if val != 0]

        if smoothed.empty:
            idx_start = 0
        elif max(smoothed) < smoothed_max_threshold:
            # skip the entire array.
            idx_start = len(smoothed)
        else:
            idx_start = nonzeros[0]

        smoothed = smoothed.iloc[idx_start:]
        original = timeseries.loc[smoothed.index]

        # Only plot counts and smoothed timeseries for cases
        if plot and timeseries_type == TimeseriesType.NEW_CASES and len(smoothed) > 0:
            plt.figure(figsize=(10, 6))
            plt.scatter(
                dates[-len(original) :],
                original,
                alpha=0.3,
                label=timeseries_type.value.replace("_", " ").title() + " Shifted",
            )
            plt.plot(dates[-len(original) :], smoothed)
            plt.grid(True, which="both")
            plt.xticks(rotation=30)
            plt.xlim(min(dates[-len(original) :]), max(dates) + timedelta(days=2))
            # plt.legend()
            output_path = get_run_artifact_path(self.fips, RunArtifact.RT_SMOOTHING_REPORT)
            plt.savefig(output_path, bbox_inches="tight")
            plt.close()

        return dates, times, smoothed
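
For reference, the heart of apply_gaussian_smoothing is pandas' Gaussian-weighted rolling mean. Below is a self-contained sketch of the same idea; window=14 and std=3 are illustrative stand-ins for self.window_size and self.kernel_std, not the pipeline's actual settings:

# Standalone sketch of the Gaussian rolling mean plus leading-zero trim.
# window=14 / std=3 are assumed, illustrative values (the gaussian window
# type requires scipy to be installed).
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
counts = pd.Series(rng.poisson(lam=np.linspace(1, 50, 60)).astype(float))

smoothed = counts.rolling(14, win_type="gaussian", min_periods=1, center=True).mean(std=3)

# Start at the first value that would round to non-zero, as the method does.
nonzero = smoothed.round().ne(0)
smoothed = smoothed.loc[nonzero.idxmax():] if nonzero.any() else smoothed.iloc[0:0]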
    def write_region(self, regional_input: RegionalInput) -> None:
        """Generates the CAN UI output format for a given region.

        Args:
            regional_input: the region and its data
        """
        # Get the latest observed values to use in calculating shims
        observed_latest_dict = regional_input.latest

        state = observed_latest_dict[CommonFields.STATE]
        log.info("Mapping output to WebUI.",
                 state=state,
                 fips=regional_input.fips)
        shim_log = structlog.getLogger(fips=regional_input.fips)
        pyseir_outputs = regional_input.ensemble_results()

        try:
            fit_results = regional_input.inference_result()
            t0_simulation = datetime.fromisoformat(fit_results["t0_date"])
        except (KeyError, ValueError):
            log.error("Fit result not found for fips. Skipping...",
                      fips=regional_input.fips)
            return
        population = regional_input.population

        # We will shim all suppression policies by the same amount (since historical tracking error
        # for all policies is the same).
        baseline_policy = "suppression_policy__inferred"  # This could be any valid policy

        # We need the index in the model's temporal frame.
        idx_offset = int(fit_results["t_today"] - fit_results["t0"])

        observed_death_latest = observed_latest_dict[CommonFields.DEATHS]
        observed_total_hosps_latest = observed_latest_dict[
            CommonFields.CURRENT_HOSPITALIZED]
        observed_icu_latest = observed_latest_dict[CommonFields.CURRENT_ICU]

        # For Deaths
        model_death_latest = pyseir_outputs[baseline_policy]["total_deaths"][
            "ci_50"][idx_offset]
        model_acute_latest = pyseir_outputs[baseline_policy]["HGen"]["ci_50"][
            idx_offset]
        model_icu_latest = pyseir_outputs[baseline_policy]["HICU"]["ci_50"][
            idx_offset]
        model_total_hosps_latest = model_acute_latest + model_icu_latest

        death_shim = shim.calculate_strict_shim(
            model=model_death_latest,
            observed=observed_death_latest,
            log=shim_log.bind(type=CommonFields.DEATHS),
        )

        total_hosp_shim = shim.calculate_strict_shim(
            model=model_total_hosps_latest,
            observed=observed_total_hosps_latest,
            log=shim_log.bind(type=CommonFields.CURRENT_HOSPITALIZED),
        )

        # For ICU this one is a little more interesting, since we often don't have
        # observed ICU data. In that case we use information from the same aggregation
        # level (intralevel) to keep the ratio between general hospitalization and
        # ICU hospitalization. (A hedged sketch of both shims follows this function.)
        icu_shim = shim.calculate_intralevel_icu_shim(
            model_acute=model_acute_latest,
            model_icu=model_icu_latest,
            observed_icu=observed_icu_latest,
            observed_total_hosps=observed_total_hosps_latest,
            log=shim_log.bind(type=CommonFields.CURRENT_ICU),
        )

        # Iterate through each suppression policy.
        # Model output is interpolated to the dates desired for the API.
        suppression_policies = [
            key for key in pyseir_outputs.keys()
            if key.startswith("suppression_policy")
        ]
        for suppression_policy in suppression_policies:
            output_for_policy = pyseir_outputs[suppression_policy]
            output_model = pd.DataFrame()
            t_list = output_for_policy["t_list"]
            t_list_downsampled = range(0, int(max(t_list)),
                                       self.output_interval_days)

            output_model[schema.DAY_NUM] = t_list_downsampled
            output_model[schema.DATE] = [
                (t0_simulation + timedelta(days=t)).date().strftime("%Y-%m-%d")
                for t in t_list_downsampled
            ]
            output_model[schema.TOTAL] = population
            output_model[schema.TOTAL_SUSCEPTIBLE] = np.interp(
                t_list_downsampled, t_list, output_for_policy["S"]["ci_50"])
            output_model[schema.EXPOSED] = np.interp(
                t_list_downsampled, t_list, output_for_policy["E"]["ci_50"])
            output_model[schema.INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.add(output_for_policy["I"]["ci_50"],
                       output_for_policy["A"]["ci_50"]),
            )  # Infected + Asympt.
            output_model[schema.INFECTED_A] = output_model[schema.INFECTED]

            interpolated_model_acute_values = np.interp(
                t_list_downsampled, t_list, output_for_policy["HGen"]["ci_50"])
            output_model[schema.INFECTED_B] = interpolated_model_acute_values

            raw_model_icu_values = output_for_policy["HICU"]["ci_50"]
            interpolated_model_icu_values = np.interp(t_list_downsampled,
                                                      t_list,
                                                      raw_model_icu_values)

            # 21 August 2020: The line assigning schema.INFECTED_C in the output_model is
            # commented out while the Linear Regression estimator is patched through this pipeline
            # to be consumed downstream by the ICU utilization calculations. It is left here as a
            # marker for the future, if the ICU utilization calculations are ever disentangled
            # from the PySEIR model outputs.

            output_model[schema.INFECTED_C] = (
                icu_shim + interpolated_model_icu_values).clip(min=0)

            # General + ICU beds. Don't include vents here because they are also counted in ICU.
            output_model[schema.ALL_HOSPITALIZED] = (
                interpolated_model_acute_values +
                interpolated_model_icu_values + total_hosp_shim).clip(min=0)

            output_model[schema.ALL_INFECTED] = output_model[schema.INFECTED]

            # Shim Deaths to Match Observed
            raw_model_deaths_values = output_for_policy["total_deaths"][
                "ci_50"]
            interp_model_deaths_values = np.interp(t_list_downsampled, t_list,
                                                   raw_model_deaths_values)
            output_model[schema.DEAD] = (interp_model_deaths_values +
                                         death_shim).clip(min=0)

            # Continue mapping
            final_beds = np.mean(output_for_policy["HGen"]["capacity"])
            output_model[schema.BEDS] = final_beds
            output_model[schema.CUMULATIVE_INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.cumsum(output_for_policy["total_new_infections"]["ci_50"]),
            )

            if fit_results:
                output_model[schema.Rt] = np.interp(
                    t_list_downsampled,
                    t_list,
                    fit_results["eps2"] * fit_results["R0"] *
                    np.ones(len(t_list)),
                )
                output_model[schema.Rt_ci90] = np.interp(
                    t_list_downsampled,
                    t_list,
                    2 * fit_results["eps2_error"] * fit_results["R0"] *
                    np.ones(len(t_list)),
                )
            else:
                output_model[schema.Rt] = 0
                output_model[schema.Rt_ci90] = 0

            output_model[schema.CURRENT_VENTILATED] = (
                icu_shim +
                np.interp(t_list_downsampled, t_list,
                          output_for_policy["HVent"]["ci_50"])).clip(min=0)
            output_model[schema.POPULATION] = population
            # Average capacity.
            output_model[schema.ICU_BED_CAPACITY] = np.mean(
                output_for_policy["HICU"]["capacity"])
            output_model[schema.VENTILATOR_CAPACITY] = np.mean(
                output_for_policy["HVent"]["capacity"])

            # Truncate date range of output.
            output_dates = pd.to_datetime(output_model["date"])
            output_model = output_model[
                (output_dates >= datetime(month=3, day=3, year=2020))
                & (output_dates < datetime.today() + timedelta(days=90))]

            # Fill in results for the Rt indicator.
            rt_results = regional_input.inferred_infection_rate()

            if rt_results is None or rt_results.empty:
                log.warning(
                    "No Rt Results found, clearing Rt in output.",
                    fips=regional_input.fips,
                    suppression_policy=suppression_policy,
                )
                output_model[schema.RT_INDICATOR] = "NaN"
                output_model[schema.RT_INDICATOR_CI90] = "NaN"
            else:
                rt_results.index = rt_results["date"].dt.strftime("%Y-%m-%d")
                merged = output_model.merge(
                    rt_results[["Rt_MAP_composite", "Rt_ci95_composite"]],
                    right_index=True,
                    left_on="date",
                    how="left",
                )
                output_model[schema.RT_INDICATOR] = merged["Rt_MAP_composite"]

                # With 90% probability the value is between rt_indicator - ci90
                # and rt_indicator + ci90.
                output_model[schema.RT_INDICATOR_CI90] = (
                    merged["Rt_ci95_composite"] - merged["Rt_MAP_composite"])

            output_model[[schema.RT_INDICATOR,
                          schema.RT_INDICATOR_CI90]] = output_model[[
                              schema.RT_INDICATOR, schema.RT_INDICATOR_CI90
                          ]].fillna("NaN")

            int_columns = [
                col for col in output_model.columns if col not in (
                    schema.DATE,
                    schema.Rt,
                    schema.Rt_ci90,
                    schema.RT_INDICATOR,
                    schema.RT_INDICATOR_CI90,
                    schema.FIPS,
                )
            ]
            # Casting floats to ints, then restoring NaN values where zeros were
            # filled in, rather than propagating zeros.
            na_int_columns = output_model.loc[:, int_columns].isna()
            output_model.loc[:, int_columns] = output_model[int_columns].fillna(0).astype(int)
            output_model[na_int_columns] = np.nan
            # Round the float Rt columns for output. (As originally written this was
            # a no-op self-assignment; rounding is the presumed intent.)
            rt_columns = [schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]
            output_model.loc[:, rt_columns] = output_model[rt_columns].round(decimals=4)

            output_model[schema.FIPS] = regional_input.fips
            intervention = Intervention.from_webui_data_adaptor(
                suppression_policy)
            output_model[schema.INTERVENTION] = intervention.value
            output_path = get_run_artifact_path(regional_input.region,
                                                RunArtifact.WEB_UI_RESULT,
                                                output_dir=self.output_dir)
            output_path = output_path.replace("__INTERVENTION_IDX__",
                                              str(intervention.value))
            output_model.to_json(output_path, orient=OUTPUT_JSON_ORIENT)
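
Neither shim.calculate_strict_shim nor shim.calculate_intralevel_icu_shim is included in this snippet. Under an assumed reading, the strict shim is just the additive offset between the latest observation and the model value at the matching time index, and the intralevel ICU shim falls back to the model's ICU share of total hospitalizations when ICU observations are missing. A hedged sketch of both (the real helpers likely add extra guards and logging):

# Hedged sketches; assumptions, not the real shim module.
import numpy as np

def strict_shim_sketch(model_latest, observed_latest):
    # Assumed: the shim is simply observed - model at the latest step.
    if observed_latest is None or np.isnan(observed_latest):
        return 0.0  # nothing observed to anchor to; leave the model untouched
    return observed_latest - model_latest

def intralevel_icu_shim_sketch(model_acute, model_icu, observed_icu, observed_total_hosps):
    # Assumed: with observed ICU, behave like a strict shim; without it,
    # estimate ICU from observed total hospitalizations via the model's
    # ICU share, then shim against that estimate.
    if observed_icu is not None and not np.isnan(observed_icu):
        return observed_icu - model_icu
    if observed_total_hosps is None or np.isnan(observed_total_hosps):
        return 0.0
    icu_share = model_icu / (model_acute + model_icu)
    return observed_total_hosps * icu_share - model_icu

# Usage: shift the model curve so "today" matches the observation, clipping
# at zero as write_region does for deaths and hospitalizations.
model_deaths = np.array([90.0, 100.0, 110.0])
shim_value = strict_shim_sketch(model_deaths[-1], observed_latest=95.0)
shimmed = np.clip(model_deaths + shim_value, 0, None)  # -> [75., 85., 95.]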
Example #29
    def infer_all(self, plot=False, shift_deaths=0):
        """
        Infer R_t from all available data sources.

        Parameters
        ----------
        plot: bool
            If True, generate a plot of the inference.
        shift_deaths: int
            Shift the death time series by this amount with respect to cases
            (when plotting only, does not shift the returned result).

        Returns
        -------
        inference_results: pd.DataFrame
            Columns containing MAP estimates and confidence intervals.
        """
        df_all = None
        available_timeseries = []

        if len(self.get_timeseries(TimeseriesType.NEW_CASES)) > 0:
            available_timeseries.append(TimeseriesType.NEW_CASES)

        if len(self.get_timeseries(TimeseriesType.NEW_DEATHS)) > 0:
            available_timeseries.append(TimeseriesType.NEW_DEATHS)

        if self.hospitalization_data_type is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
            # We have converted this timeseries to new hospitalizations.
            available_timeseries.append(TimeseriesType.NEW_HOSPITALIZATIONS)
        elif self.hospitalization_data_type is load_data.HospitalizationDataType.CUMULATIVE_HOSPITALIZATIONS:
            available_timeseries.append(TimeseriesType.NEW_HOSPITALIZATIONS)

        for timeseries_type in available_timeseries:

            df = pd.DataFrame()
            dates, times, posteriors = self.get_posteriors(timeseries_type)
            if posteriors is not None:
                df[f'Rt_MAP__{timeseries_type.value}'] = posteriors.idxmax()
                for ci in self.confidence_intervals:
                    ci_low, ci_high = self.highest_density_interval(posteriors, ci=ci)
                    df[f'Rt_ci{int(100 * (1 - ci / 2))}__{timeseries_type.value}'] = ci_low
                    df[f'Rt_ci{int(100 * ci / 2)}__{timeseries_type.value}'] = ci_high

                df['date'] = dates
                df = df.set_index('date')

                if df_all is None:
                    df_all = df
                else:
                    df_all = df_all.merge(df, left_index=True, right_index=True, how='outer')

        if plot and df_all is not None:
            plt.figure(figsize=(10, 6))

            if 'Rt_ci5__new_deaths' in df_all:
                plt.fill_between(df_all.index,  df_all['Rt_ci5__new_deaths'],  df_all['Rt_ci95__new_deaths'],
                                 alpha=.2, color='firebrick')
                plt.scatter(df_all.index, df_all['Rt_MAP__new_deaths'].shift(periods=shift_deaths),
                            alpha=1, s=25, color='firebrick', label='New Deaths')

            if 'Rt_ci5__new_cases' in df_all:
                plt.fill_between(df_all.index, df_all['Rt_ci5__new_cases'], df_all['Rt_ci95__new_cases'],
                                 alpha=.2, color='steelblue')
                plt.scatter(df_all.index, df_all['Rt_MAP__new_cases'],
                            alpha=1, s=25, color='steelblue', label='New Cases', marker='s')

            if 'Rt_ci5__new_hospitalizations' in df_all:
                plt.fill_between(df_all.index, df_all['Rt_ci5__new_hospitalizations'], df_all['Rt_ci95__new_hospitalizations'],
                                 alpha=.4, color='darkseagreen')
                plt.scatter(df_all.index, df_all['Rt_MAP__new_hospitalizations'],
                            alpha=1, s=25, color='darkseagreen', label='New Hospitalizations', marker='d')

            plt.hlines([1.0], *plt.xlim(), alpha=1, color='g')
            plt.hlines([1.1], *plt.xlim(), alpha=1, color='gold')
            plt.hlines([1.3], *plt.xlim(), alpha=1, color='r')

            plt.xticks(rotation=30)
            plt.grid(True)
            plt.xlim(df_all.index.min() - timedelta(days=2), df_all.index.max() + timedelta(days=2))
            plt.ylim(0, 5)
            plt.ylabel('$R_t$', fontsize=16)
            plt.legend()
            plt.title(self.display_name, fontsize=16)

            output_path = get_run_artifact_path(self.fips, RunArtifact.RT_INFERENCE_REPORT)
            plt.savefig(output_path, bbox_inches='tight')
            plt.close()

        return df_all
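
The MAP and interval columns above come from posteriors indexed over an R_t grid: posteriors.idxmax() yields the MAP, and highest_density_interval yields the credible bounds. A toy, self-contained stand-in for that extraction; the HDI here is a simplified narrowest-contiguous-span search, not the class's implementation:

# Toy posterior over an R_t grid. idxmax() picks the MAP; hdi_sketch finds
# the narrowest contiguous span holding `ci` of the probability mass.
import numpy as np
import pandas as pd

r_grid = np.linspace(0, 5, 501)
posterior = pd.Series(np.exp(-0.5 * ((r_grid - 1.2) / 0.15) ** 2), index=r_grid)
posterior /= posterior.sum()

rt_map = posterior.idxmax()  # MAP estimate, ~1.2

def hdi_sketch(p, ci=0.9):
    cum = p.cumsum().values
    best = (p.index[0], p.index[-1])
    for lo in range(len(p)):
        # First index hi such that the mass in [lo, hi] reaches ci.
        hi = np.searchsorted(cum, cum[lo] - p.iloc[lo] + ci)
        if hi < len(p) and p.index[hi] - p.index[lo] < best[1] - best[0]:
            best = (p.index[lo], p.index[hi])
    return best

ci_low, ci_high = hdi_sketch(posterior, ci=0.9)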
Example #30
    def infer_all(self, plot=False, shift_deaths=0):
        """
        Infer R_t from all available data sources.

        Parameters
        ----------
        plot: bool
            If True, generate a plot of the inference.
        shift_deaths: int
            Shift the death time series by this amount with respect to cases
            (when plotting only, does not shift the returned result).

        Returns
        -------
        inference_results: pd.DataFrame
            Columns containing MAP estimates and confidence intervals.
        """
        df_all = None
        available_timeseries = []
        IDX_OF_COUNTS = 2
        cases = self.get_timeseries(TimeseriesType.NEW_CASES.value)[IDX_OF_COUNTS]
        deaths = self.get_timeseries(TimeseriesType.NEW_DEATHS.value)[IDX_OF_COUNTS]
        if self.hospitalization_data_type:
            hosps = self.get_timeseries(TimeseriesType.NEW_HOSPITALIZATIONS.value)[IDX_OF_COUNTS]

        if np.sum(cases) > self.min_cases:
            available_timeseries.append(TimeseriesType.NEW_CASES)

        if np.sum(deaths) > self.min_deaths:
            available_timeseries.append(TimeseriesType.NEW_DEATHS)

        if self.hospitalization_data_type is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS and len(hosps) > 3:
            # We have converted this timeseries to new hospitalizations.
            available_timeseries.append(TimeseriesType.NEW_HOSPITALIZATIONS)
        elif self.hospitalization_data_type is load_data.HospitalizationDataType.CUMULATIVE_HOSPITALIZATIONS and len(hosps) > 3:
            available_timeseries.append(TimeseriesType.NEW_HOSPITALIZATIONS)

        for timeseries_type in available_timeseries:

            df = pd.DataFrame()
            dates, times, posteriors = self.get_posteriors(timeseries_type)
            if posteriors is not None:
                df[f'Rt_MAP__{timeseries_type.value}'] = posteriors.idxmax()
                for ci in self.confidence_intervals:
                    ci_low, ci_high = self.highest_density_interval(posteriors, ci=ci)

                    low_val = 1 - ci
                    high_val = ci
                    df[f'Rt_ci{int(math.floor(100 * low_val))}__{timeseries_type.value}'] = ci_low
                    df[f'Rt_ci{int(math.floor(100 * high_val))}__{timeseries_type.value}'] = ci_high

                df['date'] = dates
                df = df.set_index('date')

                if df_all is None:
                    df_all = df
                else:
                    df_all = df_all.merge(df, left_index=True, right_index=True, how='outer')

                # Compute the indicator lag using the curvature alignment method.
                if timeseries_type in (TimeseriesType.NEW_DEATHS, TimeseriesType.NEW_HOSPITALIZATIONS) \
                        and f'Rt_MAP__{TimeseriesType.NEW_CASES.value}' in df_all.columns:
                    # Go back up to 21 days, or the full series length if shorter.
                    last_idx = max(-21, -len(df))
                    shift_in_days = self.align_time_series(
                        series_a=df_all[f'Rt_MAP__{TimeseriesType.NEW_CASES.value}'].iloc[last_idx:],
                        series_b=df_all[f'Rt_MAP__{timeseries_type.value}'].iloc[last_idx:]
                    )
                    )
                    df_all[f'lag_days__{timeseries_type.value}'] = shift_in_days

                    # Shift all the columns.
                    for col in df_all.columns:
                        if timeseries_type.value in col:
                            df_all[col] = df_all[col].shift(shift_in_days)

        if df_all is not None and 'Rt_MAP__new_deaths' in df_all and 'Rt_MAP__new_cases' in df_all:
            df_all['Rt_MAP_composite'] = np.nanmean(df_all[['Rt_MAP__new_cases', 'Rt_MAP__new_deaths']], axis=1)
            # Just use the stdev of cases. A correlated, quadrature-summed error
            # would be better, but it is also more confusing, and it is hard to fix
            # discontinuities between death and case errors since deaths are only
            # available for a subset of dates. Systematic errors are much larger
            # in any case.
            df_all['Rt_ci95_composite'] = df_all['Rt_ci95__new_cases']

        elif df_all is not None and 'Rt_MAP__new_cases' in df_all:
            df_all['Rt_MAP_composite'] = df_all['Rt_MAP__new_cases']
            df_all['Rt_ci95_composite'] = df_all['Rt_ci95__new_cases']

        if plot and df_all is not None:
            plt.figure(figsize=(10, 6))

            if 'Rt_ci5__new_deaths' in df_all:
                plt.fill_between(df_all.index,  df_all['Rt_ci5__new_deaths'],  df_all['Rt_ci95__new_deaths'],
                                 alpha=.2, color='firebrick')
                plt.scatter(df_all.index, df_all['Rt_MAP__new_deaths'].shift(periods=shift_deaths),
                            alpha=1, s=25, color='firebrick', label='New Deaths')

            if 'Rt_ci5__new_cases' in df_all:
                plt.fill_between(df_all.index, df_all['Rt_ci5__new_cases'], df_all['Rt_ci95__new_cases'],
                                 alpha=.2, color='steelblue')
                plt.scatter(df_all.index, df_all['Rt_MAP__new_cases'],
                            alpha=1, s=25, color='steelblue', label='New Cases', marker='s')

            if 'Rt_ci5__new_hospitalizations' in df_all:
                plt.fill_between(df_all.index, df_all['Rt_ci5__new_hospitalizations'], df_all['Rt_ci95__new_hospitalizations'],
                                 alpha=.4, color='darkseagreen')
                plt.scatter(df_all.index, df_all['Rt_MAP__new_hospitalizations'],
                            alpha=1, s=25, color='darkseagreen', label='New Hospitalizations', marker='d')

            plt.hlines([1.0], *plt.xlim(), alpha=1, color='g')
            plt.hlines([1.1], *plt.xlim(), alpha=1, color='gold')
            plt.hlines([1.3], *plt.xlim(), alpha=1, color='r')

            plt.xticks(rotation=30)
            plt.grid(True)
            plt.xlim(df_all.index.min() - timedelta(days=2), df_all.index.max() + timedelta(days=2))
            plt.ylim(0, 5)
            plt.ylabel('$R_t$', fontsize=16)
            plt.legend()
            plt.title(self.display_name, fontsize=16)

            output_path = get_run_artifact_path(self.fips, RunArtifact.RT_INFERENCE_REPORT)
            plt.savefig(output_path, bbox_inches='tight')
            plt.close()

        return df_all
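
align_time_series is likewise not shown. The snippet's comment calls it a curvature-alignment method; a generic stand-in for estimating such an indicator lag is to pick the integer shift that maximizes correlation between the two R_t series (an assumption, not the actual implementation):

# Assumed stand-in for align_time_series: choose the integer shift of
# series_b that best correlates with series_a.
import numpy as np
import pandas as pd

def best_lag(series_a, series_b, max_lag=21):
    best, best_corr = 0, -np.inf
    for lag in range(-max_lag, max_lag + 1):
        corr = series_a.corr(series_b.shift(lag))
        if pd.notna(corr) and corr > best_corr:
            best, best_corr = lag, corr
    return best

# Example: series_b lags series_a by 7 days, so best_lag recovers -7.
a = pd.Series(np.sin(np.arange(60) / 6.0))
b = a.shift(7)
print(best_lag(a, b))  # -> -7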