示例#1
0
def infer_t0(fips, method='first_case', default=pd.Timestamp('2020-02-01')):
    """
    Infer t0 for a given fips under the given method.

    Supported methods:
       - first_case: t0 is the earliest date present in the county case data
         for this fips; falls back to ``default`` if the county has no rows.
       - impute: t0 is loaded via ``fit_results.load_t0``.
       - reference_date: t0 is simply ``default``.

    Parameters
    ----------
    fips : str
        County fips
    method : str
        The method to determine t0. One of 'first_case', 'impute',
        'reference_date'.
    default : pd.Timestamp
        Default t0, used by 'reference_date' and as the fallback for
        'first_case'.

    Returns
    -------
    t0 : pd.Timestamp
        Inferred t0 for given fips.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported methods.
    """

    if method == 'impute':
        t0 = fit_results.load_t0(fips)
    elif method == 'first_case':
        case_data = load_county_case_data()
        # NOTE: `fips in case_data.fips` would test membership against the
        # Series *index* rather than the fips values, so select the rows
        # explicitly and check whether anything matched.
        county_cases = case_data[case_data.fips == fips]
        if county_cases.empty:
            t0 = default
        else:
            t0 = county_cases.date.min()
    elif method == 'reference_date':
        t0 = default
    else:
        raise ValueError(f'Invalid method {method} for t0 inference')
    return t0
示例#2
0
    def __init__(self,
                 fips,
                 n_years=2,
                 n_samples=250,
                 suppression_policy=(0.35, 0.5, 0.75, 1),
                 skip_plots=False,
                 output_percentiles=(5, 25, 32, 50, 75, 68, 95),
                 generate_report=True):
        """
        Store the run configuration for one county and derive output paths.

        Parameters
        ----------
        fips : str
            County fips; used to look up county metadata and t0.
        n_years : int
            Length of the time grid in years (``365 * n_years`` points
            spanning 0 to ``365 * n_years``).
        n_samples : int
            Stored on the instance as ``n_samples``.
        suppression_policy : tuple
            Stored on the instance as ``suppression_policy``.
        skip_plots : bool
            Stored on the instance as ``skip_plots``.
        output_percentiles : tuple
            Stored on the instance as ``output_percentiles``.
        generate_report : bool
            Stored on the instance as ``generate_report``.
        """
        n_days = 365 * n_years

        self.fips = fips
        self.t_list = np.linspace(0, n_days, n_days)
        self.skip_plots = skip_plots

        self.county_metadata = load_data.load_county_metadata_by_fips(fips)
        self.output_percentiles = output_percentiles
        self.n_samples = n_samples
        self.n_years = n_years
        self.t0 = fit_results.load_t0(fips)
        self.date_generated = datetime.datetime.utcnow().isoformat()
        self.suppression_policy = suppression_policy

        # Snapshot of everything configured so far, minus the time grid.
        # Attributes assigned below this point are intentionally not included.
        summary = copy.deepcopy(vars(self))
        del summary['t_list']
        self.summary = summary
        self.generate_report = generate_report

        self.all_outputs = {}

        # Report and data files share a directory root and a file stem.
        state = self.county_metadata['state']
        county = self.county_metadata['county']
        file_stem = f"{state}__{county}__{self.fips}__ensemble_projections"
        self.output_file_report = os.path.join(
            OUTPUT_DIR, state, 'reports', f"{file_stem}.pdf")
        self.output_file_data = os.path.join(
            OUTPUT_DIR, state, 'data', f"{file_stem}.json")
    def generate_surge_spreadsheet(self):
        """
        Produce a spreadsheet summarizing peaks.

        For every county whose metadata state matches ``self.state``
        (case-insensitive), load its ensemble results and record, per
        suppression policy and compartment, the peak value statistics and the
        median peak date (t0 + ``peak_time_ci50`` days). The records are
        written to ``self.surge_filename`` as an Excel workbook with one sheet
        per mitigation policy.

        Returns
        -------
        None
        """
        compartments = ['HGen', 'general_admissions_per_day', 'HICU', 'icu_admissions_per_day', 'total_new_infections',
                        'direct_deaths_per_day', 'total_deaths', 'D']

        df = load_data.load_county_metadata()
        all_fips = df[df['state'].str.lower() == self.state.lower()].fips
        all_data = {fips: load_data.load_ensemble_results(fips) for fips in all_fips}
        df = df.set_index('fips')

        records = []
        for fips, ensembles in all_data.items():
            county_name = df.loc[fips]['county']
            t0 = fit_results.load_t0(fips)

            for suppression_policy, ensemble in ensembles.items():

                county_record = dict(
                    county_name=county_name,
                    county_fips=fips,
                    mitigation_policy=policy_to_mitigation(suppression_policy)
                )

                for compartment in compartments:
                    compartment_name = compartment_to_name_map[compartment]
                    stats = ensemble[compartment]

                    county_record[compartment_name + ' Peak Value Mean'] = '%.0f' % stats['peak_value_mean']
                    county_record[compartment_name + ' Peak Value Median'] = '%.0f' % stats['peak_value_ci50']
                    county_record[compartment_name + ' Peak Value CI25'] = '%.0f' % stats['peak_value_ci25']
                    county_record[compartment_name + ' Peak Value CI75'] = '%.0f' % stats['peak_value_ci75']
                    # Peak times are stored as day offsets from t0.
                    county_record[compartment_name + ' Peak Time Median'] = (t0 + timedelta(days=stats['peak_time_ci50'])).date().isoformat()

                    # Leaving for now...
                    # if 'surge_start' in ensemble[compartment]:
                    #     if not np.isnan(np.nanmean(ensemble[compartment]['surge_start'])):
                    #         county_record[compartment_name + ' Surge Start Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_start']))).date().isoformat()
                    #         county_record[compartment_name + ' Surge End Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_end']))).date().isoformat()

                records.append(county_record)

        df = pd.DataFrame(records)
        # The context manager saves and closes the workbook even if a sheet
        # fails to write; the explicit `writer.save()` call is no longer
        # available in recent pandas releases.
        with pd.ExcelWriter(self.surge_filename, engine='xlsxwriter') as writer:
            # One sheet per policy, in reverse order of first appearance.
            for policy in df['mitigation_policy'].unique()[::-1]:
                policy_rows = df[df['mitigation_policy'] == policy]
                policy_rows.drop(['mitigation_policy', 'county_fips'], axis=1).to_excel(writer, sheet_name=policy)
示例#4
0
def infer_t0(fips, method="first_case", default=pd.Timestamp("2020-02-01")):
    """
    Determine the epidemic start time t0 for a county.

    Supported methods:
       - first_case: earliest date in the county's case timeseries, or
         ``default`` when that timeseries is empty.
       - impute: value loaded through ``fit_results.load_t0``.
       - reference_date: ``default`` is returned unchanged.

    Parameters
    ----------
    fips : str
        County fips
    method : str
        Which of the methods above to use.
    default : pd.Timestamp
        Fallback / reference t0.

    Returns
    -------
    t0 : pd.Timestamp
        Inferred t0 for given fips.

    Raises
    ------
    ValueError
        If ``method`` is not a supported method.
    """
    if method == "reference_date":
        return default

    if method == "impute":
        return fit_results.load_t0(fips)

    if method != "first_case":
        raise ValueError(f"Invalid method {method} for t0 inference")

    case_series = combined_datasets.get_timeseries_for_fips(
        fips, columns=[CommonFields.CASES], min_range_with_some_value=True)
    if case_series.empty:
        return default
    return case_series[CommonFields.DATE].min()
    def plot_compartment(self, compartment):
        """
        Plot state level data on a compartment.

        Draws a two-panel figure: the left panel shows each county's median
        peak time per suppression policy, the right panel shows each county's
        median peak value. For the policy equal to
        ``self.primary_suppression_policy`` the 5-95 and 32-68 percentile
        bands are shaded as well, and capacity markers / annotations are added
        when the ensemble data has a 'capacity' entry.

        Reads ``self.state``, ``self.counties``, ``self.names``,
        ``self.ensemble_data_by_county`` and
        ``self.primary_suppression_policy``.

        Parameters
        ----------
        compartment: str
            Compartment of the model to plot.

        Returns
        -------
        fig: matplotlib figure
            The figure created for this plot.
        """
        fig = plt.figure(figsize=(30, 20))
        plt.suptitle(
            f'{self.state}: Median Peak Estimates for {compartment_to_name_map[compartment]} Surges',
            fontsize=20)

        # Default matplotlib colors, extended with the single-letter colors so
        # that more policies can be drawn than the default cycle provides.
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color'] + list(
            'bgrcmyk')
        # The set of suppression policies is taken from the first county's
        # ensemble data; presumably every county has the same policy keys.
        for i_plt, suppression_policy in enumerate(
                list(self.ensemble_data_by_county.values())[0].keys()):
            # ---------------------------------------------------------
            # Plot Peak Times and values. These need to be shifted by t0
            # ---------------------------------------------------------
            plt.subplot(1, 2, 1)
            # Peak times are stored as day offsets; convert to absolute dates
            # using each county's own t0.
            peak_times = [
                fit_results.load_t0(fips) +
                timedelta(days=self.ensemble_data_by_county[fips]
                          [suppression_policy][compartment]['peak_time_ci50'])
                for fips in self.counties
            ]

            # Middle element of the sorted peak dates (upper median for an
            # even number of counties).
            sorted_times = sorted(deepcopy(peak_times))
            median_statewide_peak = sorted_times[len(sorted_times) // 2]

            plt.scatter(peak_times,
                        self.names,
                        label=f'{suppression_policy}',
                        c=color_cycle[i_plt])
            plt.vlines(median_statewide_peak,
                       0,
                       len(self.names),
                       alpha=1,
                       linestyle='-.',
                       colors=color_cycle[i_plt],
                       label=f'State Median: {suppression_policy}')

            plt.subplot(1, 2, 2)
            peak_values = [
                self.ensemble_data_by_county[fips][suppression_policy]
                [compartment]['peak_value_ci50'] for fips in self.counties
            ]
            plt.scatter(peak_values,
                        self.names,
                        label=suppression_policy,
                        c=color_cycle[i_plt])

            # Only the primary policy gets percentile bands, date ticks and
            # capacity annotations.
            if suppression_policy == self.primary_suppression_policy:
                plt.subplot(121)
                for i, fips in enumerate(self.counties):
                    value = 'peak_time'
                    d = self.ensemble_data_by_county[fips][suppression_policy][
                        compartment]
                    t0 = fit_results.load_t0(fips)

                    # 5-95 percentile band around row i.
                    plt.fill_betweenx(
                        [i - .2, i + .2],
                        [t0 + timedelta(days=d[f'{value}_ci5'])] * 2,
                        [t0 + timedelta(days=d[f'{value}_ci95'])] * 2,
                        alpha=.3,
                        color=color_cycle[i_plt])

                    # 32-68 percentile band, drawn on top of the wider one.
                    plt.fill_betweenx(
                        [i - .2, i + .2],
                        [t0 + timedelta(days=d[f'{value}_ci32'])] * 2,
                        [t0 + timedelta(days=d[f'{value}_ci68'])] * 2,
                        alpha=.3,
                        color=color_cycle[i_plt])
                    plt.grid(alpha=.4)
                    plt.xlabel(value)

                # Ticks on the 1st and 15th of every month of 2020 and 2021.
                ticks = []
                for month in range(1, 13):
                    ticks.append(datetime(month=month, day=1, year=2020))
                    ticks.append(datetime(month=month, day=15, year=2020))
                for month in range(1, 13):
                    ticks.append(datetime(month=month, day=1, year=2021))
                    ticks.append(datetime(month=month, day=15, year=2021))
                plt.xticks(ticks, rotation=30)

                # --------------------------
                # Plot Peak Values
                # --------------------------
                plt.subplot(1, 2, 2)
                # NOTE(review): `fips` here is the variable left over from the
                # loop above (i.e. the last county in self.counties); the
                # presence of 'capacity' is checked for that county only.
                plot_capacity = 'capacity' in self.ensemble_data_by_county[
                    fips][suppression_policy][compartment]
                if plot_capacity:
                    # One row per county; the median along axis=1 yields a
                    # single capacity value per county.
                    capacities = np.median(np.vstack([
                        self.ensemble_data_by_county[fips][suppression_policy]
                        [compartment]['capacity'] for fips in self.counties
                    ]),
                                           axis=1)
                    plt.scatter(capacities,
                                self.names,
                                marker='<',
                                s=100,
                                c='r',
                                label='Estimated Capacity')

                for i, fips in enumerate(self.counties):
                    value = 'peak_value'
                    d = self.ensemble_data_by_county[fips][suppression_policy][
                        compartment]
                    # 5-95 percentile band around row i.
                    plt.fill_betweenx([i - .2, i + .2],
                                      [d[f'{value}_ci5']] * 2,
                                      [d[f'{value}_ci95']] * 2,
                                      alpha=.3,
                                      color=color_cycle[i_plt])
                    # 32-68 percentile band.
                    plt.fill_betweenx([i - .2, i + .2],
                                      [d[f'{value}_ci32']] * 2,
                                      [d[f'{value}_ci68']] * 2,
                                      alpha=.3,
                                      color=color_cycle[i_plt])
                    plt.grid(which='both', alpha=.4)
                    plt.xlabel('Required Surge Capacity at Peak', fontsize=14)
                    plt.xscale('log')

                if plot_capacity:
                    # Annotations are placed to the right of the current
                    # upper x-limit.
                    up_lim = plt.xlim()[1]
                    for i, (capacity, peak_value) in enumerate(
                            zip(capacities, peak_values)):
                        if np.isnan(capacity) or capacity == 0:
                            try:
                                # The f-prefix has no placeholders; the value
                                # is substituted by the % operator.
                                plt.text(up_lim * 1.3,
                                         i - .5,
                                         f'UNKNOWN CAPACITY: %s NEEDED' %
                                         int(peak_value),
                                         color='r')
                            except ValueError:
                                # int() raises ValueError when peak_value is
                                # NaN.
                                logging.warning('Error estimating peak. NaN')
                        else:
                            plt.text(
                                up_lim * 1.3,
                                i - .5,
                                f'Surge {peak_value / capacity * 100:.0f}%: {peak_value - capacity:.0f} Needed',
                                color='r')

                    plt.text(.01,
                             .01,
                             f'Surge Capacity Listed for {suppression_policy}',
                             transform=plt.gca().transAxes,
                             color='r',
                             fontsize=16)

        # Caption and legend for the left (peak timing) panel.
        plt.subplot(121)
        caption = textwrap.fill(textwrap.dedent("""
            Surge Peak Timing: Timing of the surge peak under different
            suppression policies. A suppression policy of 0.7 implies contact is
            reduced by 30% (i.e. 30% efficacy of social distancing). Overall
            trends show that higher suppression leads to much longer time until
            surge peak. Several rural counties have imputed start times which
            may be artificially biased to peak sooner. Suppression values below
            ~0.25 (not shown) drive R0 < 1 and decay over time though it is
            unlikely this is achievable.
            
            Error bars represent (68%, 95%) CL based on en ensemble of
            parameters sampled in the appendix for a "best-guess" suppression
            model. Dashed lines indicate the state-wide median. Notably, the
            impact of distancing measures is significantly larger than variance
            associated the epidemiological model suggesting that policy may be
            used to spread these peaks relative to each other to reduce
            coincident surge."""),
                                width=120)
        plt.text(0,
                 1.01,
                 caption,
                 ha='left',
                 va='bottom',
                 transform=plt.gca().transAxes)
        plt.legend()
        # Caption and legend for the right (peak level) panel.
        plt.subplot(122)

        # NOTE(review): `suppression_policy` interpolated below is the last
        # value from the loop above, not necessarily
        # self.primary_suppression_policy — confirm this is intended.
        caption = textwrap.fill(textwrap.dedent(f"""
            Surge Peak Levels: Value of the surge peak under different
            suppression policies. A suppression policy of 0.7 implies contact is
            reduced by 30% (i.e. 30% efficacy of social distancing). Overall
            trends show that higher suppression leads to much lower peak levels
            as the "curve is flattened".

            Error bars represent (68%, 95%) CL based on en ensemble of
            parameters sampled in the appendix for a "best-guess" suppression
            model: {suppression_policy}. Capacity is estimated based on
            aggregating hospital estimates from "Definitive"" to the county
            level. For beds these estimates are (N_total - utilized + estimated
            increase) with each term based on Definitive projections which
            account for utilization (Checking this!). For ventilators, we
            estimate nationally 1.1 ventilators per ICU bed which includes
            national emergency stockpile and a ~30% efficacy of an estimated
            100k old ventilators."""),
                                width=120)
        plt.text(0,
                 1.01,
                 caption,
                 ha='left',
                 va='bottom',
                 transform=plt.gca().transAxes)

        plt.legend()

        return fig
示例#6
0
    def generate_surge_spreadsheet(self):
        """
        Produce a JSON file summarizing peaks.

        For every fips returned by
        ``load_data.get_all_fips_codes_for_a_state(self.state)``, load its
        ensemble results and record, per suppression policy and compartment,
        the peak value statistics and the median peak date (t0 +
        ``peak_time_ci50`` days). The records are written to
        ``self.surge_filename`` as JSON.

        Returns
        -------
        None
        """
        compartments = [
            "HGen",
            "general_admissions_per_day",
            "HICU",
            "icu_admissions_per_day",
            "total_new_infections",
            "direct_deaths_per_day",
            "total_deaths",
            "D",
        ]

        df = load_data.load_county_metadata()
        all_fips = load_data.get_all_fips_codes_for_a_state(self.state)
        all_data = {
            fips: load_data.load_ensemble_results(fips)
            for fips in all_fips
        }
        df = df.set_index("fips")

        records = []
        for fips, ensembles in all_data.items():
            county_name = df.loc[fips]["county"]
            t0 = fit_results.load_t0(fips)

            for suppression_policy, ensemble in ensembles.items():

                county_record = dict(
                    county_name=county_name,
                    county_fips=fips,
                    mitigation_policy=policy_to_mitigation(suppression_policy),
                )

                for compartment in compartments:
                    compartment_name = compartment_to_name_map[compartment]
                    stats = ensemble[compartment]

                    county_record[compartment_name + " Peak Value Mean"] = (
                        "%.0f" % stats["peak_value_mean"])
                    county_record[compartment_name + " Peak Value Median"] = (
                        "%.0f" % stats["peak_value_ci50"])
                    county_record[compartment_name + " Peak Value CI25"] = (
                        "%.0f" % stats["peak_value_ci25"])
                    county_record[compartment_name + " Peak Value CI75"] = (
                        "%.0f" % stats["peak_value_ci75"])
                    # Peak times are stored as day offsets from t0.
                    county_record[compartment_name + " Peak Time Median"] = ((
                        t0 + timedelta(days=stats["peak_time_ci50"])
                    ).date().isoformat())

                    # Leaving for now...
                    # if 'surge_start' in ensemble[compartment]:
                    #     if not np.isnan(np.nanmean(ensemble[compartment]['surge_start'])):
                    #         county_record[compartment_name + ' Surge Start Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_start']))).date().isoformat()
                    #         county_record[compartment_name + ' Surge End Mean'] = (t0 + timedelta(days=np.nanmean(ensemble[compartment]['surge_end']))).date().isoformat()

                records.append(county_record)

        df = pd.DataFrame(records)
        # DataFrame has no `write_json` method; `to_json` is the JSON writer.
        df.to_json(self.surge_filename)
def _measure_in_effect(policies, measure, t_actual):
    """Return True if `measure` has an enactment date strictly before `t_actual`."""
    enacted_on = policies[measure]
    return not pd.isnull(enacted_on) and t_actual > enacted_on


def generate_empirical_distancing_policy(t_list, fips, future_suppression):
    """
    Produce a suppression policy based on Imperial College estimates of social
    distancing programs combined with County level datasets about their
    implementation.

    For counties missing from the public implementations dataset, suppression
    is 1.0 up to the present and ``future_suppression`` afterwards. Otherwise,
    for past dates each enacted measure subtracts its
    ``distancing_measure_suppression`` value from 1, and once stay-at-home is
    in effect the value is ``future_suppression``. Future dates always use
    ``future_suppression``.

    Parameters
    ----------
    t_list: array-like
        List of times (days since t0) to interpolate over.
    fips: str
        County fips to lookup interventions against.
    future_suppression: float
        The suppression level to apply in an ongoing basis after today, and
        going backward as the lockdown / stay-at-home efficacy.

    Returns
    -------
    suppression_model: callable
        suppression_model(t) returns the current suppression model at time t.
    """

    t0 = fit_results.load_t0(fips)
    # Read the clock once so both branches and every timestep agree on what
    # "today" is.
    now = datetime.utcnow()
    rho = []

    # Check for fips that don't match.
    public_implementations = load_public_implementations_data().set_index(
        'fips')

    # Not all counties present in this dataset.
    if fips not in public_implementations.index:
        # Then assume 1.0 until today and then future_suppression going forward.
        for t_step in t_list:
            t_actual = t0 + timedelta(days=t_step)
            rho.append(1.0 if t_actual <= now else future_suppression)
    else:
        policies = public_implementations.loc[fips].to_dict()
        independent_measures = [
            'public_schools', 'entertainment_gym',
            'restaurant_dine-in', 'federal_guidelines'
        ]
        for t_step in t_list:
            t_actual = t0 + timedelta(days=t_step)

            # If this is a future date, assume lockdown continues.
            if t_actual > now:
                rho.append(future_suppression)
                continue

            # If lockdown is in effect, then we don't care about any other
            # measure, just set to future suppression.
            if _measure_in_effect(policies, 'stay_at_home', t_actual):
                rho.append(future_suppression)
                continue

            rho_this_t = 1

            # If the policy was enacted before this timestep then activate it
            # in addition to others. These measures are additive.
            for independent_measure in independent_measures:
                if _measure_in_effect(policies, independent_measure, t_actual):
                    rho_this_t -= distancing_measure_suppression[
                        independent_measure]

            # Only take the max of these, since 500 doesn't matter if 50 is enacted.
            if _measure_in_effect(policies, '50_gatherings', t_actual):
                rho_this_t -= distancing_measure_suppression['50_gatherings']
            elif _measure_in_effect(policies, '500_gatherings', t_actual):
                rho_this_t -= distancing_measure_suppression['500_gatherings']

            rho.append(rho_this_t)

    return interp1d(t_list, rho, fill_value='extrapolate')