Exemplo n.º 1
0
    def _load_model_for_fips(self, scenario: str = 'inferred'):
        """
        Try to load a model for the locale, else load the state level model
        and update parameters for the county.

        Parameters
        ----------
        scenario: str
            Suppression-policy scenario. 'inferred' reuses the fitted epsilon
            as the final suppression level; any other value is mapped from the
            fitted R0 via sp.get_future_suppression_from_r0.

        Returns
        -------
        model
            The loaded SEIR model, already run with the scenario's
            suppression policy applied.

        Raises
        ------
        FileNotFoundError
            If neither a locale-level nor a state-level MLE fit artifact
            exists on disk.
        """
        # First try the fit artifact specific to this fips.
        artifact_path = get_run_artifact_path(self.fips, RunArtifact.MLE_FIT_MODEL)
        if os.path.exists(artifact_path):
            with open(artifact_path, 'rb') as f:
                model = pickle.load(f)
            inferred_params = fit_results.load_inference_result(self.fips)

        else:
            _logger.info(f'No MLE model found for {self.state_name}: {self.fips}. Reverting to state level.')
            # Fall back to the state-level artifact (first two fips digits).
            artifact_path = get_run_artifact_path(self.fips[:2], RunArtifact.MLE_FIT_MODEL)
            if os.path.exists(artifact_path):
                with open(artifact_path, 'rb') as f:
                    model = pickle.load(f)
                inferred_params = fit_results.load_inference_result(self.fips[:2])
            else:
                raise FileNotFoundError(f'Could not locate state result for {self.state_name}')

            # Rescale state values to the county population and replace county
            # specific params.
            default_params = ParameterEnsembleGenerator(
                self.fips, N_samples=1, t_list=model.t_list,
                suppression_policy=model.suppression_policy).get_average_seir_parameters()
            # Scale initial compartments by the county/state population ratio.
            population_ratio = default_params['N'] / model.N
            model.N *= population_ratio
            model.I_initial *= population_ratio
            model.E_initial *= population_ratio
            model.A_initial *= population_ratio
            # Susceptibles are the remainder so compartments still sum to N.
            model.S_initial = model.N - model.I_initial - model.E_initial - model.A_initial

            # Capacity-like parameters are replaced with county defaults
            # rather than scaled from the state model.
            for key in {'beds_general', 'beds_ICU', 'ventilators'}:
                setattr(model, key, default_params[key])

        # Determine the appropriate future suppression policy based on the
        # scenario of interest.
        if scenario == 'inferred':
            eps_final = inferred_params['eps']
        else:
            eps_final = sp.get_future_suppression_from_r0(inferred_params['R0'], scenario=scenario)

        # Two-step policy: the fitted eps applies until t_break, then
        # eps_final takes over from "today" measured in days since the
        # inferred simulation start date t0_date.
        model.suppression_policy = sp.generate_two_step_policy(
            self.t_list,
            eps=inferred_params['eps'],
            t_break=inferred_params['t_break'],
            t_break_final=(datetime.datetime.today() - datetime.datetime.fromisoformat(inferred_params['t0_date'])).days,
            eps_final=eps_final
        )
        model.run()
        return model
Exemplo n.º 2
0
def get_compartment_value_on_date(fips,
                                  compartment,
                                  date,
                                  ensemble_results=None):
    """
    Return the median (ci_50) value of a compartment at a specified date
    under the inferred suppression policy.

    Parameters
    ----------
    fips: str
        State or County fips.
    compartment: str
        Name of the compartment to retrieve.
    date: datetime
        Date to retrieve values for.
    ensemble_results: NoneType or dict
        Pass in the pre-loaded simulation data to save time, else load it.

    Returns
    -------
    value: float
        Value of compartment on a given date.
    """
    if ensemble_results is None:
        ensemble_results = load_ensemble_results(fips)
    # Circular import avoidance
    from pyseir.inference.fit_results import load_inference_result

    # The time axis of the ensemble is indexed in integer days since the
    # inferred simulation start date (t0_date).
    simulation_start_date = datetime.fromisoformat(
        load_inference_result(fips)["t0_date"])
    date_idx = int((date - simulation_start_date).days)
    return ensemble_results["suppression_policy__inferred"][compartment][
        "ci_50"][date_idx]
Exemplo n.º 3
0
    def set_inference_parameters(self):
        """
        Setup inference parameters based on data availability and manual
        overrides.  As data becomes more sparse, we further constrain the fit,
        which improves stability substantially.

        Populates ``self.fit_params`` with a per-instance parameter dict:
        defaults, then state-specific overrides, then (for 5-digit county
        fips) values seeded from the state-level fit result.
        """
        # Copy so the overrides below cannot mutate the shared
        # DEFAULT_FIT_PARAMS mapping used by every instance.
        self.fit_params = dict(self.DEFAULT_FIT_PARAMS)
        # Update any state specific params.
        for k, v in self.PARAM_SETS.items():
            if self.state_obj.abbr in k:
                self.fit_params.update(v)

        # Without hospitalization data, the hospitalization fraction cannot
        # be inferred; fix it at 1.
        self.fit_params["fix_hosp_fraction"] = self.hospitalizations is None
        if self.fit_params["fix_hosp_fraction"]:
            self.fit_params["hosp_fraction"] = 1

        # County-level fits (5-digit fips) are seeded from the state fit.
        if len(self.fips) == 5:
            # Guess t0 as the day cumulative cases first reach the threshold.
            OBSERVED_NEW_CASES_GUESS_THRESHOLD = 2
            idx_enough_cases = np.argwhere(
                np.cumsum(self.observed_new_cases) >=
                OBSERVED_NEW_CASES_GUESS_THRESHOLD)[0][0]
            initial_cases_guess = np.cumsum(
                self.observed_new_cases)[idx_enough_cases]
            t0_guess = list(self.times)[idx_enough_cases]

            state_fit_result = load_inference_result(fips=self.state_obj.fips)
            self.fit_params["t0"] = t0_guess

            total_cases = np.sum(self.observed_new_cases)
            # Scale the initial infected guess by the assumed test fraction.
            self.fit_params["log10_I_initial"] = np.log10(
                initial_cases_guess / self.fit_params["test_fraction"])
            self.fit_params["limit_t0"] = state_fit_result[
                "t0"] - 20, state_fit_result["t0"] + 30
            # Shift the state's t_break into the county's time frame.
            self.fit_params["t_break"] = state_fit_result["t_break"] - (
                t0_guess - state_fit_result["t0"])
            self.fit_params["R0"] = state_fit_result["R0"]
            self.fit_params["test_fraction"] = state_fit_result[
                "test_fraction"]
            self.fit_params["eps"] = state_fit_result["eps"]
            # With sparse data, progressively freeze parameters for stability.
            if total_cases < 100:
                self.fit_params["t_break"] = 10
                self.fit_params["fix_test_fraction"] = True
                self.fit_params["fix_R0"] = True
                self.fit_params["limit_t0"] = (
                    state_fit_result["t0"] - 5,
                    state_fit_result["t0"] + 30,
                )
            if total_cases < 50:
                self.fit_params["fix_eps"] = True
                self.fit_params["fix_t_break"] = True
Exemplo n.º 4
0
    def map_fips(self, fips):
        """
        For a given county fips code, generate the CAN UI output format.

        Loads the ensemble simulation and MLE fit result for the fips,
        rescales modeled hospitalization compartments to observed state data,
        interpolates each suppression policy onto the output date grid, and
        writes one JSON artifact per policy/intervention.

        Parameters
        ----------
        fips: str
            County FIPS code to map.
        """
        logging.info(f"Mapping output to WebUI for {self.state}, {fips}")
        pyseir_outputs = load_data.load_ensemble_results(fips)

        # Counties must be whitelisted to be shipped.
        if len(fips) == 5 and fips not in self.df_whitelist.fips.values:
            logging.info(f"Excluding {fips} due to white list...")
            return
        try:
            fit_results = load_inference_result(fips)
            t0_simulation = datetime.fromisoformat(fit_results["t0_date"])
        except (KeyError, ValueError):
            logging.error(f"Fit result not found for {fips}. Skipping...")
            return

        # ---------------------------------------------------------------------
        # Rescale hosps based on the population ratio... Could swap this to
        # infection ratio later?
        # ---------------------------------------------------------------------
        hosp_times, current_hosp, _ = load_data.load_hospitalization_data_by_state(
            state=self.state_abbreviation,
            t0=t0_simulation,
            convert_cumulative_to_current=True,
            category="hospitalized",
        )

        _, current_icu, _ = load_data.load_hospitalization_data_by_state(
            state=self.state_abbreviation,
            t0=t0_simulation,
            convert_cumulative_to_current=True,
            category="icu",
        )

        # Population record: county for 5-digit fips, otherwise state level.
        if len(fips) == 5:
            population = self.population_data.get_record_for_fips(fips)[CommonFields.POPULATION]
        else:
            population = self.population_data.get_record_for_state(self.state_abbreviation)[
                CommonFields.POPULATION
            ]

        policies = [key for key in pyseir_outputs.keys() if key.startswith("suppression_policy")]
        if current_hosp is not None:
            # Anchor rescaling on the latest observed hospitalization point.
            t_latest_hosp_data, current_hosp = hosp_times[-1], current_hosp[-1]
            t_latest_hosp_data_date = t0_simulation + timedelta(days=int(t_latest_hosp_data))

            state_hosp_gen = load_data.get_compartment_value_on_date(
                fips=fips[:2], compartment="HGen", date=t_latest_hosp_data_date
            )
            state_hosp_icu = load_data.get_compartment_value_on_date(
                fips=fips[:2], compartment="HICU", date=t_latest_hosp_data_date
            )

            if len(fips) == 5:
                # Rescale the county level hospitalizations by the expected
                # ratio of county / state hospitalizations from simulations.
                # We use ICU data if available too.
                county_hosp = load_data.get_compartment_value_on_date(
                    fips=fips,
                    compartment="HGen",
                    date=t_latest_hosp_data_date,
                    ensemble_results=pyseir_outputs,
                )
                county_icu = load_data.get_compartment_value_on_date(
                    fips=fips,
                    compartment="HICU",
                    date=t_latest_hosp_data_date,
                    ensemble_results=pyseir_outputs,
                )
                current_hosp *= (county_hosp + county_icu) / (state_hosp_gen + state_hosp_icu)

            hosp_rescaling_factor = current_hosp / (state_hosp_gen + state_hosp_icu)

            # Some states have covidtracking issues. We shouldn't ground ICU cases
            # to zero since so far these have all been bad reporting.
            if current_icu is not None and current_icu[-1] > 0:
                icu_rescaling_factor = current_icu[-1] / state_hosp_icu
            else:
                icu_rescaling_factor = current_hosp / (state_hosp_gen + state_hosp_icu)
        else:
            # No observed hospitalization data: pass model values through.
            hosp_rescaling_factor = 1.0
            icu_rescaling_factor = 1.0

        # Iterate through each suppression policy.
        # Model output is interpolated to the dates desired for the API.
        for suppression_policy in policies:
            output_for_policy = pyseir_outputs[suppression_policy]
            output_model = pd.DataFrame()

            t_list = output_for_policy["t_list"]
            t_list_downsampled = range(0, int(max(t_list)), self.output_interval_days)

            output_model[schema.DAY_NUM] = t_list_downsampled
            output_model[schema.DATE] = [
                (t0_simulation + timedelta(days=t)).date().strftime("%m/%d/%y")
                for t in t_list_downsampled
            ]
            output_model[schema.TOTAL] = population
            output_model[schema.TOTAL_SUSCEPTIBLE] = np.interp(
                t_list_downsampled, t_list, output_for_policy["S"]["ci_50"]
            )
            output_model[schema.EXPOSED] = np.interp(
                t_list_downsampled, t_list, output_for_policy["E"]["ci_50"]
            )
            output_model[schema.INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.add(output_for_policy["I"]["ci_50"], output_for_policy["A"]["ci_50"]),
            )  # Infected + Asympt.
            output_model[schema.INFECTED_A] = output_model[schema.INFECTED]
            output_model[schema.INFECTED_B] = hosp_rescaling_factor * np.interp(
                t_list_downsampled, t_list, output_for_policy["HGen"]["ci_50"]
            )  # Hosp General
            output_model[schema.INFECTED_C] = icu_rescaling_factor * np.interp(
                t_list_downsampled, t_list, output_for_policy["HICU"]["ci_50"]
            )  # Hosp ICU
            # General + ICU beds. don't include vent here because they are also counted in ICU
            output_model[schema.ALL_HOSPITALIZED] = np.add(
                output_model[schema.INFECTED_B], output_model[schema.INFECTED_C]
            )
            output_model[schema.ALL_INFECTED] = output_model[schema.INFECTED]
            output_model[schema.DEAD] = np.interp(
                t_list_downsampled, t_list, output_for_policy["total_deaths"]["ci_50"]
            )
            final_beds = np.mean(output_for_policy["HGen"]["capacity"])
            output_model[schema.BEDS] = final_beds
            output_model[schema.CUMULATIVE_INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.cumsum(output_for_policy["total_new_infections"]["ci_50"]),
            )

            # R(t) = eps * R0 is constant over the model time axis here; the
            # CI90 column is 2 sigma on eps propagated through R0.
            if fit_results:
                output_model[schema.Rt] = np.interp(
                    t_list_downsampled,
                    t_list,
                    fit_results["eps"] * fit_results["R0"] * np.ones(len(t_list)),
                )
                output_model[schema.Rt_ci90] = np.interp(
                    t_list_downsampled,
                    t_list,
                    2 * fit_results["eps_error"] * fit_results["R0"] * np.ones(len(t_list)),
                )
            else:
                output_model[schema.Rt] = 0
                output_model[schema.Rt_ci90] = 0

            output_model[schema.CURRENT_VENTILATED] = icu_rescaling_factor * np.interp(
                t_list_downsampled, t_list, output_for_policy["HVent"]["ci_50"]
            )
            output_model[schema.POPULATION] = population
            # Average capacity.
            output_model[schema.ICU_BED_CAPACITY] = np.mean(output_for_policy["HICU"]["capacity"])
            output_model[schema.VENTILATOR_CAPACITY] = np.mean(
                output_for_policy["HVent"]["capacity"]
            )

            # Truncate date range of output.
            output_dates = pd.to_datetime(output_model["date"])
            output_model = output_model[
                (output_dates >= datetime(month=3, day=3, year=2020))
                & (output_dates < datetime.today() + timedelta(days=90))
            ]
            output_model = output_model.fillna(0)

            # Fill in results for the Rt indicator.
            try:
                rt_results = load_Rt_result(fips)
                rt_results.index = rt_results["Rt_MAP_composite"].index.strftime("%m/%d/%y")
                merged = output_model.merge(
                    rt_results[["Rt_MAP_composite", "Rt_ci95_composite"]],
                    right_index=True,
                    left_on="date",
                    how="left",
                )
                output_model[schema.RT_INDICATOR] = merged["Rt_MAP_composite"]

                # With 90% probability the value is between rt_indicator - ci90 to rt_indicator + ci90
                output_model[schema.RT_INDICATOR_CI90] = (
                    merged["Rt_ci95_composite"] - merged["Rt_MAP_composite"]
                )
            except (ValueError, KeyError):
                # No Rt result available; emit the data model's "NaN" marker.
                output_model[schema.RT_INDICATOR] = "NaN"
                output_model[schema.RT_INDICATOR_CI90] = "NaN"

            output_model[[schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]] = output_model[
                [schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]
            ].fillna("NaN")

            # Truncate floats and cast as strings to match data model.
            int_columns = [
                col
                for col in output_model.columns
                if col
                not in (
                    schema.DATE,
                    schema.Rt,
                    schema.Rt_ci90,
                    schema.RT_INDICATOR,
                    schema.RT_INDICATOR_CI90,
                )
            ]
            output_model.loc[:, int_columns] = (
                output_model[int_columns].fillna(0).astype(int).astype(str)
            )
            output_model.loc[
                :, [schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]
            ] = (
                output_model[
                    [schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR, schema.RT_INDICATOR_CI90]
                ]
                .fillna(0)
                .round(decimals=4)
                .astype(str)
            )

            # Convert the records format to just list(list(values))
            output_model = [
                [val for val in timestep.values()]
                for timestep in output_model.to_dict(orient="records")
            ]

            output_path = get_run_artifact_path(
                fips, RunArtifact.WEB_UI_RESULT, output_dir=self.output_dir
            )
            policy_enum = Intervention.from_webui_data_adaptor(suppression_policy)
            output_path = output_path.replace("__INTERVENTION_IDX__", str(policy_enum.value))
            with open(output_path, "w") as f:
                json.dump(output_model, f)
Exemplo n.º 5
0
    def set_inference_parameters(self):
        """
        Setup inference parameters based on data availability and manual
        overrides.  As data becomes more sparse, we further constrain the fit,
        which improves stability substantially.

        Populates ``self.fit_params`` with a per-instance parameter dict:
        defaults, per-fips CSV initial conditions, then (for 5-digit county
        fips) values seeded from the state-level fit result.
        """
        # Copy so the overrides below cannot mutate the shared
        # DEFAULT_FIT_PARAMS mapping used by every instance.
        self.fit_params = dict(self.DEFAULT_FIT_PARAMS)
        # Update State specific SEIR initial guesses
        overwrite_params_df = pd.read_csv(
            "./pyseir_data/pyseir_fitter_initial_conditions_2020_06_10.csv",
            dtype={"fips": object})

        INITIAL_PARAM_SETS = [
            "R0",
            "t0",
            "eps",
            "t_break",
            "eps2",
            "t_delta_phases",
            "test_fraction",
            "hosp_fraction",
            "log10_I_initial",
        ]
        if self.fips in overwrite_params_df["fips"].values:
            this_fips_df = overwrite_params_df.loc[overwrite_params_df["fips"]
                                                   == self.fips]
            # NOTE(review): this assigns a one-row pandas Series per param,
            # not a scalar — confirm downstream consumers accept that.
            for param in INITIAL_PARAM_SETS:
                self.fit_params[param] = this_fips_df[param]

        # Without hospitalization data, fix the hospitalization fraction at 1.
        self.fit_params["fix_hosp_fraction"] = self.hospitalizations is None
        if self.fit_params["fix_hosp_fraction"]:
            self.fit_params["hosp_fraction"] = 1

        # County-level fits (5-digit fips) are seeded from the state fit.
        if len(self.fips) == 5:
            # Guess t0 as the day cumulative cases first reach the threshold.
            OBSERVED_NEW_CASES_GUESS_THRESHOLD = 2
            idx_enough_cases = np.argwhere(
                np.cumsum(self.observed_new_cases) >=
                OBSERVED_NEW_CASES_GUESS_THRESHOLD)[0][0]
            initial_cases_guess = np.cumsum(
                self.observed_new_cases)[idx_enough_cases]
            t0_guess = list(self.times)[idx_enough_cases]

            state_fit_result = load_inference_result(fips=self.state_obj.fips)
            self.fit_params["t0"] = t0_guess

            total_cases = np.sum(self.observed_new_cases)
            # Scale the initial infected guess by the assumed test fraction.
            self.fit_params["log10_I_initial"] = np.log10(
                initial_cases_guess / self.fit_params["test_fraction"])
            self.fit_params[
                "limit_t0"] = t0_guess - 5, state_fit_result["t0"] + 30
            # Shift the state's t_break into the county's time frame.
            self.fit_params["t_break"] = state_fit_result["t_break"] - (
                t0_guess - state_fit_result["t0"])
            self.fit_params["R0"] = state_fit_result["R0"]
            self.fit_params["test_fraction"] = state_fit_result[
                "test_fraction"]
            self.fit_params["eps"] = state_fit_result["eps"]
            # With sparse data, progressively freeze parameters for stability.
            if total_cases < 100:
                self.fit_params["t_break"] = 10
                self.fit_params["fix_test_fraction"] = True
                self.fit_params["fix_R0"] = True
                self.fit_params["limit_t0"] = (
                    state_fit_result["t0"] - 5,
                    state_fit_result["t0"] + 30,
                )
            if total_cases < 50:
                self.fit_params["fix_eps"] = True
                self.fit_params["fix_t_break"] = True
    def _load_model_for_fips(self, scenario: str = "inferred"):
        """
        Try to load a model for the locale, else load the state level model
        and update parameters for the county.

        Parameters
        ----------
        scenario: str
            Suppression-policy scenario passed through to
            sp.estimate_future_suppression_from_fits.

        Returns
        -------
        model
            The loaded SEIR model, already run with the scenario's
            suppression policy applied.

        Raises
        ------
        FileNotFoundError
            If neither a locale-level nor a state-level MLE fit artifact
            exists on disk.
        """
        # First try the fit artifact specific to this fips.
        artifact_path = get_run_artifact_path(self.fips, RunArtifact.MLE_FIT_MODEL)
        if os.path.exists(artifact_path):
            with open(artifact_path, "rb") as f:
                model = pickle.load(f)
            inferred_params = fit_results.load_inference_result(self.fips)

        else:
            _logger.info(
                f"No MLE model found for {self.state_name}: {self.fips}. Reverting to state level."
            )
            # Fall back to the state-level artifact (first two fips digits).
            artifact_path = get_run_artifact_path(self.fips[:2], RunArtifact.MLE_FIT_MODEL)
            if os.path.exists(artifact_path):
                with open(artifact_path, "rb") as f:
                    model = pickle.load(f)
                inferred_params = fit_results.load_inference_result(self.fips[:2])
            else:
                raise FileNotFoundError(f"Could not locate state result for {self.state_name}")

            # Rescale state values to the county population and replace county
            # specific params.
            # TODO: get_average_seir_parameters should return the analytic solution when available
            # right now it runs an average over the ensemble (with N_samples not consistently set
            # across the code base).
            default_params = ParameterEnsembleGenerator(
                self.fips,
                N_samples=500,
                t_list=model.t_list,
                suppression_policy=model.suppression_policy,
            ).get_average_seir_parameters()
            # Scale initial compartments by the county/state population ratio.
            population_ratio = default_params["N"] / model.N
            model.N *= population_ratio
            model.I_initial *= population_ratio
            model.E_initial *= population_ratio
            model.A_initial *= population_ratio
            # Susceptibles are the remainder so compartments still sum to N.
            model.S_initial = model.N - model.I_initial - model.E_initial - model.A_initial

            # Capacity-like parameters are replaced with county defaults
            # rather than scaled from the state model.
            for key in {"beds_general", "beds_ICU", "ventilators"}:
                setattr(model, key, default_params[key])

        # Determine the appropriate future suppression policy based on the
        # scenario of interest.

        eps_final = sp.estimate_future_suppression_from_fits(inferred_params, scenario=scenario)

        # Interpolated multi-phase policy; the final phase starts "today",
        # measured in days since the inferred simulation start date t0_date.
        model.suppression_policy = sp.get_epsilon_interpolator(
            eps=inferred_params["eps"],
            t_break=inferred_params["t_break"],
            eps2=inferred_params["eps2"],
            t_delta_phases=inferred_params["t_delta_phases"],
            t_break_final=(
                datetime.datetime.today()
                - datetime.datetime.fromisoformat(inferred_params["t0_date"])
            ).days,
            eps_final=eps_final,
        )
        model.run()
        return model
    def map_fips(self, fips):
        """
        For a given county fips code, generate the CAN UI output format.

        For each suppression policy in the ensemble output, builds a
        DataFrame of interpolated time series, rescales hospitalizations for
        the inferred policy, and writes one JSON artifact per policy.

        Parameters
        ----------
        fips: str
            County FIPS code to map.
        """
        # Population record: county for 5-digit fips, otherwise state level.
        if len(fips) == 5:
            population = self.population_data.get_county_level(
                'USA', state=self.state_abbreviation, fips=fips)
        else:
            population = self.population_data.get_state_level(
                'USA', state=self.state_abbreviation)

        logging.info(f'Mapping output to WebUI for {self.state}, {fips}')
        pyseir_outputs = load_data.load_ensemble_results(fips)

        policies = [
            key for key in pyseir_outputs.keys()
            if key.startswith('suppression_policy')
        ]

        # Set on each loop iteration; used to rescale the inferred policy.
        # NOTE(review): this relies on the inferred policy not being first in
        # `policies` — if it is, all_hospitalized_today is still None when
        # used below. Confirm the policy ordering guarantees this.
        all_hospitalized_today = None
        try:
            fit_results = load_inference_result(fips)
        except ValueError:
            fit_results = None
            logging.error(
                f'Fit result not found for {fips}: Skipping inference elements'
            )

        for i_policy, suppression_policy in enumerate(policies):
            if suppression_policy == 'suppression_policy__full_containment':  # No longer shipping this.
                continue
            output_for_policy = pyseir_outputs[suppression_policy]
            output_model = pd.DataFrame()

            if suppression_policy == 'suppression_policy__inferred' and fit_results:
                # Counties must be whitelisted for the inferred policy.
                if len(fips) == 5 and fips not in self.df_whitelist.fips:
                    continue

                t0 = datetime.fromisoformat(fit_results['t0_date'])

                # Hospitalizations need to be rescaled by the inferred factor to match observations for display.
                now_idx = int(
                    (datetime.today() -
                     datetime.fromisoformat(fit_results['t0_date'])).days)
                total_hosps = output_for_policy['HGen']['ci_50'][
                    now_idx] + output_for_policy['HICU']['ci_50'][now_idx]
                hosp_fraction = all_hospitalized_today / total_hosps

            else:
                # Non-inferred policies start "today" and are not rescaled.
                t0 = datetime.today()
                hosp_fraction = 1.

            t_list = output_for_policy['t_list']
            t_list_downsampled = range(0, int(max(t_list)),
                                       self.output_interval_days)
            # Col 0
            output_model['days'] = t_list_downsampled
            # Col 1
            output_model['date'] = [
                (t0 + timedelta(days=t)).date().strftime('%m/%d/%y')
                for t in t_list_downsampled
            ]
            # Col 2: total population.
            output_model['t'] = population
            # Col 3: susceptible.
            output_model['b'] = np.interp(t_list_downsampled, t_list,
                                          output_for_policy['S']['ci_50'])
            # Col 4: exposed.
            output_model['c'] = np.interp(t_list_downsampled, t_list,
                                          output_for_policy['E']['ci_50'])
            # Col 5
            output_model['d'] = np.interp(
                t_list_downsampled, t_list,
                np.add(output_for_policy['I']['ci_50'],
                       output_for_policy['A']['ci_50']))  # Infected + Asympt.
            # Col 6
            output_model['e'] = output_model['d']
            # Col 7
            output_model['f'] = np.interp(
                t_list_downsampled, t_list,
                output_for_policy['HGen']['ci_50'])  # Hosp General
            # Col 8
            output_model['g'] = np.interp(
                t_list_downsampled, t_list,
                output_for_policy['HICU']['ci_50'])  # Hosp ICU
            # Col 9
            output_model['all_hospitalized'] = hosp_fraction * np.add(
                output_model['f'], output_model['g'])
            # Col 10
            output_model['all_infected'] = output_model['d']
            # Col 11
            output_model['dead'] = np.interp(
                t_list_downsampled, t_list,
                output_for_policy['total_deaths']['ci_50'])
            # Col 12: combined general + ICU bed capacity.
            final_beds = np.mean(
                output_for_policy['HGen']['capacity']) + np.mean(
                    output_for_policy['HICU']['capacity'])
            output_model['beds'] = final_beds
            output_model['cumulative_infected'] = np.interp(
                t_list_downsampled, t_list,
                np.cumsum(output_for_policy['total_new_infections']['ci_50']))

            # R(t) = eps * R0, constant over the model time axis here.
            if fit_results:
                output_model['R_t'] = np.interp(
                    t_list_downsampled, t_list, fit_results['eps'] *
                    fit_results['R0'] * np.ones(len(t_list)))
                output_model['R_t_stdev'] = np.interp(
                    t_list_downsampled, t_list, fit_results['eps_error'] *
                    fit_results['R0'] * np.ones(len(t_list)))
            else:
                output_model['R_t'] = 0
                output_model['R_t_stdev'] = 0

            # Record the current number of hospitalizations in order to rescale the inference results.
            all_hospitalized_today = output_model['all_hospitalized'][0]

            # Don't backfill inferences
            if suppression_policy != 'suppression_policy__inferred':
                backfill = self.backfill_output_model_fips(
                    fips, t0, final_beds, output_model)
                output_model = pd.concat([
                    backfill, output_model
                ])[output_model.columns].reset_index(drop=True)

            # Truncate date range of output.
            output_dates = pd.to_datetime(output_model['date'])
            output_model = output_model[
                (output_dates > datetime(month=3, day=3, year=2020))
                & (output_dates < datetime.today() + timedelta(days=90))]
            output_model = output_model.fillna(0)

            # Placeholder columns required by the UI data model.
            for col in ['l']:
                output_model[col] = 0
            output_model['population'] = population
            for col in ['m', 'n']:
                output_model[col] = 0

            # Truncate floats and cast as strings to match data model.
            int_columns = [
                col for col in output_model.columns
                if col not in ('date', 'R_t', 'R_t_stdev')
            ]
            output_model.loc[:,
                             int_columns] = output_model[int_columns].fillna(
                                 0).astype(int).astype(str)
            output_model.loc[:, ['R_t', 'R_t_stdev']] = output_model[[
                'R_t', 'R_t_stdev'
            ]].fillna(0).round(decimals=4).astype(str)

            # Convert the records format to just list(list(values))
            output_model = [[
                val for val in timestep.values()
            ] for timestep in output_model.to_dict(orient='records')]

            output_path = get_run_artifact_path(fips,
                                                RunArtifact.WEB_UI_RESULT,
                                                output_dir=self.output_dir)
            policy_enum = Intervention.from_webui_data_adaptor(
                suppression_policy)
            output_path = output_path.replace('__INTERVENTION_IDX__',
                                              str(policy_enum.value))
            with open(output_path, 'w') as f:
                json.dump(output_model, f)
Exemplo n.º 8
0
    def map_fips(self, fips: str) -> None:
        """
        For a given fips code, for either a county or state, generate the CAN UI output format.

        Loads the ensemble model output and MLE fit results for ``fips``,
        computes additive "shims" so the model's latest deaths / hospitalization /
        ICU values line up with the latest observed data, then writes one
        WebUI JSON artifact per suppression policy.

        NOTE(review): another method named ``map_fips`` is defined later in
        this class; because Python binds the later definition, that one
        silently overrides this version. Confirm which implementation is
        meant to be live.

        Parameters
        ----------
        fips: str
            FIPS code to map.
        """
        log.info("Mapping output to WebUI.", state=self.state, fips=fips)
        # Dedicated structlog logger so every shim calculation log line
        # carries the fips context.
        shim_log = structlog.getLogger(fips=fips)
        pyseir_outputs = load_data.load_ensemble_results(fips)

        try:
            fit_results = load_inference_result(fips)
            # t0_date is expected to be an ISO-format date string in the fit
            # results; a missing key or unparseable value aborts this fips.
            t0_simulation = datetime.fromisoformat(fit_results["t0_date"])
        except (KeyError, ValueError):
            log.error("Fit result not found for fips. Skipping...", fips=fips)
            return
        population = self._get_population(fips)

        # We will shim all suppression policies by the same amount (since historical tracking error
        # for all policies is the same).
        baseline_policy = "suppression_policy__inferred"  # This could be any valid policy

        # We need the index in the model's temporal frame.
        # t_today - t0 gives "days since simulation start", used to index the
        # model's time series at the current date.
        idx_offset = int(fit_results["t_today"] - fit_results["t0"])

        # Get the latest observed values to use in calculating shims
        observed_latest_dict = combined_datasets.get_us_latest_for_fips(fips)

        observed_death_latest = observed_latest_dict[CommonFields.DEATHS]
        observed_total_hosps_latest = observed_latest_dict[
            CommonFields.CURRENT_HOSPITALIZED]
        observed_icu_latest = observed_latest_dict[CommonFields.CURRENT_ICU]

        # For Deaths
        # Model values at "today" (ci_50 = median ensemble trajectory),
        # pulled from the baseline policy only.
        model_death_latest = pyseir_outputs[baseline_policy]["total_deaths"][
            "ci_50"][idx_offset]
        model_acute_latest = pyseir_outputs[baseline_policy]["HGen"]["ci_50"][
            idx_offset]
        model_icu_latest = pyseir_outputs[baseline_policy]["HICU"]["ci_50"][
            idx_offset]
        model_total_hosps_latest = model_acute_latest + model_icu_latest

        # Additive offset aligning model deaths with the observed latest value.
        death_shim = shim.calculate_strict_shim(
            model=model_death_latest,
            observed=observed_death_latest,
            log=shim_log.bind(type=CommonFields.DEATHS),
        )

        total_hosp_shim = shim.calculate_strict_shim(
            model=model_total_hosps_latest,
            observed=observed_total_hosps_latest,
            log=shim_log.bind(type=CommonFields.CURRENT_HOSPITALIZED),
        )

        # For ICU This one is a little more interesting since we often don't have ICU. In this case
        # we use information from the same aggregation level (intralevel) to keep the ratios
        # between general hospitalization and icu hospitalization
        icu_shim = shim.calculate_intralevel_icu_shim(
            model_acute=model_acute_latest,
            model_icu=model_icu_latest,
            observed_icu=observed_icu_latest,
            observed_total_hosps=observed_total_hosps_latest,
            log=shim_log.bind(type=CommonFields.CURRENT_ICU),
        )

        # Iterate through each suppression policy.
        # Model output is interpolated to the dates desired for the API.
        suppression_policies = [
            key for key in pyseir_outputs.keys()
            if key.startswith("suppression_policy")
        ]
        for suppression_policy in suppression_policies:
            output_for_policy = pyseir_outputs[suppression_policy]
            output_model = pd.DataFrame()
            t_list = output_for_policy["t_list"]
            # Downsample the model's daily grid to every
            # `output_interval_days` days for the API payload.
            t_list_downsampled = range(0, int(max(t_list)),
                                       self.output_interval_days)

            output_model[schema.DAY_NUM] = t_list_downsampled
            output_model[schema.DATE] = [
                (t0_simulation + timedelta(days=t)).date().strftime("%Y-%m-%d")
                for t in t_list_downsampled
            ]
            output_model[schema.TOTAL] = population
            output_model[schema.TOTAL_SUSCEPTIBLE] = np.interp(
                t_list_downsampled, t_list, output_for_policy["S"]["ci_50"])
            output_model[schema.EXPOSED] = np.interp(
                t_list_downsampled, t_list, output_for_policy["E"]["ci_50"])
            output_model[schema.INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.add(output_for_policy["I"]["ci_50"],
                       output_for_policy["A"]["ci_50"]),
            )  # Infected + Asympt.
            output_model[schema.INFECTED_A] = output_model[schema.INFECTED]

            # INFECTED_B = acute (general) hospitalizations; note this column
            # is intentionally NOT shimmed — the shim is applied to the
            # combined ALL_HOSPITALIZED column below.
            interpolated_model_acute_values = np.interp(
                t_list_downsampled, t_list, output_for_policy["HGen"]["ci_50"])
            output_model[schema.INFECTED_B] = interpolated_model_acute_values

            raw_model_icu_values = output_for_policy["HICU"]["ci_50"]
            interpolated_model_icu_values = np.interp(t_list_downsampled,
                                                      t_list,
                                                      raw_model_icu_values)
            # Shimmed ICU; clip keeps a negative shim from producing
            # negative occupancy.
            output_model[schema.INFECTED_C] = (
                icu_shim + interpolated_model_icu_values).clip(min=0)

            # General + ICU beds. don't include vent here because they are also counted in ICU
            output_model[schema.ALL_HOSPITALIZED] = (
                interpolated_model_acute_values +
                interpolated_model_icu_values + total_hosp_shim).clip(min=0)

            output_model[schema.ALL_INFECTED] = output_model[schema.INFECTED]

            # Shim Deaths to Match Observed
            raw_model_deaths_values = output_for_policy["total_deaths"][
                "ci_50"]
            interp_model_deaths_values = np.interp(t_list_downsampled, t_list,
                                                   raw_model_deaths_values)
            output_model[schema.DEAD] = (interp_model_deaths_values +
                                         death_shim).clip(min=0)

            # Continue mapping
            final_beds = np.mean(output_for_policy["HGen"]["capacity"])
            output_model[schema.BEDS] = final_beds
            output_model[schema.CUMULATIVE_INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.cumsum(output_for_policy["total_new_infections"]["ci_50"]),
            )

            if fit_results:
                # Rt reported as a constant series: eps2 * R0 broadcast over
                # the full time grid, then downsampled.
                output_model[schema.Rt] = np.interp(
                    t_list_downsampled,
                    t_list,
                    fit_results["eps2"] * fit_results["R0"] *
                    np.ones(len(t_list)),
                )
                # ci90 as 2x the eps2 standard error scaled by R0.
                output_model[schema.Rt_ci90] = np.interp(
                    t_list_downsampled,
                    t_list,
                    2 * fit_results["eps2_error"] * fit_results["R0"] *
                    np.ones(len(t_list)),
                )
            else:
                output_model[schema.Rt] = 0
                output_model[schema.Rt_ci90] = 0

            # Ventilators reuse the ICU shim since vent counts are a subset
            # of ICU in this model.
            output_model[schema.CURRENT_VENTILATED] = (
                icu_shim +
                np.interp(t_list_downsampled, t_list,
                          output_for_policy["HVent"]["ci_50"])).clip(min=0)
            output_model[schema.POPULATION] = population
            # Average capacity.
            output_model[schema.ICU_BED_CAPACITY] = np.mean(
                output_for_policy["HICU"]["capacity"])
            output_model[schema.VENTILATOR_CAPACITY] = np.mean(
                output_for_policy["HVent"]["capacity"])

            # Truncate date range of output.
            # NOTE(review): uses the string literal "date" rather than
            # schema.DATE used elsewhere — presumably equal; confirm.
            output_dates = pd.to_datetime(output_model["date"])
            output_model = output_model[
                (output_dates >= datetime(month=3, day=3, year=2020))
                & (output_dates < datetime.today() + timedelta(days=90))]
            output_model = output_model.fillna(0)

            # Fill in results for the Rt indicator.
            rt_results = load_Rt_result(fips)
            if rt_results is not None:
                # Re-key the Rt frame by formatted date string so it can be
                # left-joined onto the output on the "date" column.
                rt_results.index = rt_results[
                    "Rt_MAP_composite"].index.strftime("%Y-%m-%d")
                merged = output_model.merge(
                    rt_results[["Rt_MAP_composite", "Rt_ci95_composite"]],
                    right_index=True,
                    left_on="date",
                    how="left",
                )
                output_model[schema.RT_INDICATOR] = merged["Rt_MAP_composite"]

                # With 90% probability the value is between rt_indicator - ci90
                # to rt_indicator + ci90
                output_model[schema.RT_INDICATOR_CI90] = (
                    merged["Rt_ci95_composite"] - merged["Rt_MAP_composite"])
            else:
                log.warning(
                    "No Rt Results found, clearing Rt in output.",
                    fips=fips,
                    suppression_policy=suppression_policy,
                )
                output_model[schema.RT_INDICATOR] = "NaN"
                output_model[schema.RT_INDICATOR_CI90] = "NaN"

            # Dates past the end of the Rt series come out of the left join
            # as NaN; serialize those as the string "NaN".
            output_model[[schema.RT_INDICATOR,
                          schema.RT_INDICATOR_CI90]] = output_model[[
                              schema.RT_INDICATOR, schema.RT_INDICATOR_CI90
                          ]].fillna("NaN")

            # Cast every column except dates / Rt / fips to int for the UI.
            int_columns = [
                col for col in output_model.columns if col not in (
                    schema.DATE,
                    schema.Rt,
                    schema.Rt_ci90,
                    schema.RT_INDICATOR,
                    schema.RT_INDICATOR_CI90,
                    schema.FIPS,
                )
            ]
            output_model.loc[:,
                             int_columns] = output_model[int_columns].fillna(
                                 0).astype(int)
            output_model.loc[:, [
                schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR, schema.
                RT_INDICATOR_CI90
            ]] = output_model[[
                schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR,
                schema.RT_INDICATOR_CI90
            ]].fillna(0)

            output_model[schema.FIPS] = fips
            intervention = Intervention.from_webui_data_adaptor(
                suppression_policy)
            output_model[schema.INTERVENTION] = intervention.value
            # One artifact per (fips, intervention); the path template holds
            # an __INTERVENTION_IDX__ placeholder to substitute.
            output_path = get_run_artifact_path(fips,
                                                RunArtifact.WEB_UI_RESULT,
                                                output_dir=self.output_dir)
            output_path = output_path.replace("__INTERVENTION_IDX__",
                                              str(intervention.value))
            output_model.to_json(output_path, orient=OUTPUT_JSON_ORIENT)
    def map_fips(self, fips: str) -> None:
        """
        For a given fips code, for either a county or state, generate the CAN UI output format.

        This variant rescales model hospitalization/ICU series by
        multiplicative conversion factors (rather than additive shims) and
        serializes each policy's output as a JSON list-of-rows via
        ``json.dump``.

        NOTE(review): this is a second definition of ``map_fips`` in the same
        class; being defined later, it silently overrides the earlier
        version. Confirm which implementation is intended.

        Parameters
        ----------
        fips: str
            FIPS code to map.
        """
        log.info("Mapping output to WebUI.", state=self.state, fips=fips)
        pyseir_outputs = load_data.load_ensemble_results(fips)

        try:
            fit_results = load_inference_result(fips)
            # t0_date is expected to be an ISO-format date string; failure to
            # load/parse skips this fips entirely.
            t0_simulation = datetime.fromisoformat(fit_results["t0_date"])
        except (KeyError, ValueError):
            log.error("Fit result not found for fips. Skipping...", fips=fips)
            return
        population = self._get_population(fips)

        # Get multiplicative conversion factors to scale model output to fit dataset current values
        hosp_rescaling_factor, icu_rescaling_factor = self._get_model_to_dataset_conversion_factors(
            t0_simulation=t0_simulation,
            fips=fips,
            pyseir_outputs=pyseir_outputs,
        )

        # Iterate through each suppression policy.
        # Model output is interpolated to the dates desired for the API.
        suppression_policies = [
            key for key in pyseir_outputs.keys()
            if key.startswith("suppression_policy")
        ]
        for suppression_policy in suppression_policies:
            output_for_policy = pyseir_outputs[suppression_policy]
            output_model = pd.DataFrame()

            t_list = output_for_policy["t_list"]
            # Downsample the daily grid to every `output_interval_days` days.
            t_list_downsampled = range(0, int(max(t_list)),
                                       self.output_interval_days)

            output_model[schema.DAY_NUM] = t_list_downsampled
            # Dates formatted as %m/%d/%y for this (older) output format.
            output_model[schema.DATE] = [
                (t0_simulation + timedelta(days=t)).date().strftime("%m/%d/%y")
                for t in t_list_downsampled
            ]
            output_model[schema.TOTAL] = population
            output_model[schema.TOTAL_SUSCEPTIBLE] = np.interp(
                t_list_downsampled, t_list, output_for_policy["S"]["ci_50"])
            output_model[schema.EXPOSED] = np.interp(
                t_list_downsampled, t_list, output_for_policy["E"]["ci_50"])
            output_model[schema.INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.add(output_for_policy["I"]["ci_50"],
                       output_for_policy["A"]["ci_50"]),
            )  # Infected + Asympt.
            output_model[schema.INFECTED_A] = output_model[schema.INFECTED]
            # INFECTED_B = general hospitalization, rescaled to match data.
            output_model[
                schema.INFECTED_B] = hosp_rescaling_factor * np.interp(
                    t_list_downsampled, t_list,
                    output_for_policy["HGen"]["ci_50"])  # Hosp General

            # INFECTED_C = ICU occupancy, rescaled by the ICU-specific factor.
            raw_model_icu_values = output_for_policy["HICU"]["ci_50"]
            interpolated_model_icu_values = np.interp(t_list_downsampled,
                                                      t_list,
                                                      raw_model_icu_values)
            final_derived_model_value = icu_rescaling_factor * interpolated_model_icu_values
            output_model[schema.INFECTED_C] = final_derived_model_value
            # General + ICU beds. don't include vent here because they are also counted in ICU
            output_model[schema.ALL_HOSPITALIZED] = np.add(
                output_model[schema.INFECTED_B],
                output_model[schema.INFECTED_C])
            output_model[schema.ALL_INFECTED] = output_model[schema.INFECTED]
            # Deaths are not shimmed/rescaled in this variant.
            output_model[schema.DEAD] = np.interp(
                t_list_downsampled, t_list,
                output_for_policy["total_deaths"]["ci_50"])
            final_beds = np.mean(output_for_policy["HGen"]["capacity"])
            output_model[schema.BEDS] = final_beds
            output_model[schema.CUMULATIVE_INFECTED] = np.interp(
                t_list_downsampled,
                t_list,
                np.cumsum(output_for_policy["total_new_infections"]["ci_50"]),
            )

            if fit_results:
                # Rt as a constant series: eps * R0 broadcast over the time
                # grid (this variant uses "eps", not "eps2").
                output_model[schema.Rt] = np.interp(
                    t_list_downsampled,
                    t_list,
                    fit_results["eps"] * fit_results["R0"] *
                    np.ones(len(t_list)),
                )
                # ci90 as 2x the eps standard error scaled by R0.
                output_model[schema.Rt_ci90] = np.interp(
                    t_list_downsampled,
                    t_list,
                    2 * fit_results["eps_error"] * fit_results["R0"] *
                    np.ones(len(t_list)),
                )
            else:
                output_model[schema.Rt] = 0
                output_model[schema.Rt_ci90] = 0

            # Ventilators share the ICU rescaling factor.
            output_model[
                schema.CURRENT_VENTILATED] = icu_rescaling_factor * np.interp(
                    t_list_downsampled, t_list,
                    output_for_policy["HVent"]["ci_50"])
            output_model[schema.POPULATION] = population
            # Average capacity.
            output_model[schema.ICU_BED_CAPACITY] = np.mean(
                output_for_policy["HICU"]["capacity"])
            output_model[schema.VENTILATOR_CAPACITY] = np.mean(
                output_for_policy["HVent"]["capacity"])

            # Truncate date range of output.
            output_dates = pd.to_datetime(output_model["date"])
            output_model = output_model[
                (output_dates >= datetime(month=3, day=3, year=2020))
                & (output_dates < datetime.today() + timedelta(days=90))]
            output_model = output_model.fillna(0)

            # Fill in results for the Rt indicator.
            try:
                rt_results = load_Rt_result(fips)
                # Re-key the Rt frame by formatted date string to enable the
                # left join on the "date" column below.
                rt_results.index = rt_results[
                    "Rt_MAP_composite"].index.strftime("%m/%d/%y")
                merged = output_model.merge(
                    rt_results[["Rt_MAP_composite", "Rt_ci95_composite"]],
                    right_index=True,
                    left_on="date",
                    how="left",
                )
                output_model[schema.RT_INDICATOR] = merged["Rt_MAP_composite"]

                # With 90% probability the value is between rt_indicator - ci90
                # to rt_indicator + ci90
                output_model[schema.RT_INDICATOR_CI90] = (
                    merged["Rt_ci95_composite"] - merged["Rt_MAP_composite"])
            except (ValueError, KeyError) as e:
                # Missing/invalid Rt results are tolerated: the columns are
                # filled with the string "NaN" instead.
                log.warning("Clearing Rt in output for fips.",
                            fips=fips,
                            exc_info=e)
                output_model[schema.RT_INDICATOR] = "NaN"
                output_model[schema.RT_INDICATOR_CI90] = "NaN"

            # Dates past the end of the Rt series come out of the join as
            # NaN; serialize those as the string "NaN".
            output_model[[schema.RT_INDICATOR,
                          schema.RT_INDICATOR_CI90]] = output_model[[
                              schema.RT_INDICATOR, schema.RT_INDICATOR_CI90
                          ]].fillna("NaN")

            # Truncate floats and cast as strings to match data model.
            int_columns = [
                col for col in output_model.columns if col not in (
                    schema.DATE,
                    schema.Rt,
                    schema.Rt_ci90,
                    schema.RT_INDICATOR,
                    schema.RT_INDICATOR_CI90,
                )
            ]
            output_model.loc[:, int_columns] = (
                output_model[int_columns].fillna(0).astype(int).astype(str))
            output_model.loc[:, [
                schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR, schema.
                RT_INDICATOR_CI90
            ]] = (output_model[[
                schema.Rt, schema.Rt_ci90, schema.RT_INDICATOR,
                schema.RT_INDICATOR_CI90
            ]].fillna(0).round(decimals=4).astype(str))

            # Convert the records format to just list(list(values))
            output_model = [[
                val for val in timestep.values()
            ] for timestep in output_model.to_dict(orient="records")]

            # One artifact per (fips, intervention); the path template holds
            # an __INTERVENTION_IDX__ placeholder to substitute.
            output_path = get_run_artifact_path(fips,
                                                RunArtifact.WEB_UI_RESULT,
                                                output_dir=self.output_dir)
            policy_enum = Intervention.from_webui_data_adaptor(
                suppression_policy)
            output_path = output_path.replace("__INTERVENTION_IDX__",
                                              str(policy_enum.value))
            with open(output_path, "w") as f:
                json.dump(output_model, f)