def _generate_input_data(
    fips: str,
    include_testing_correction: bool,
    include_deaths: bool,
    figure_collector: Optional[list],
):
    """
    Load new-case/new-death time series for a fips code and return the
    filtered, smoothed input DataFrame, keeping the RtInferenceEngine
    agnostic to aggregation level.

    Parameters
    ----------
    fips: str
        State or county fips code to load data for.
    include_testing_correction: bool
        If True, include a correction for testing increases and decreases.
    include_deaths: bool
        Passed through to filter_and_smooth_input_data.
    figure_collector: Optional[list]
        Passed through to filter_and_smooth_input_data to collect figures.

    Returns
    -------
    pd.DataFrame
        Filtered and smoothed cases/deaths indexed by date.
    """
    times, new_cases, new_deaths = load_data.load_new_case_data_by_fips(
        fips,
        t0=InferRtConstants.REF_DATE,
        include_testing_correction=include_testing_correction,
    )
    # Convert integer day offsets into calendar dates for the index.
    observation_dates = [
        InferRtConstants.REF_DATE + timedelta(days=int(t)) for t in times
    ]
    raw_df = pd.DataFrame(
        {"cases": new_cases, "deaths": new_deaths}, index=observation_dates
    )
    return filter_and_smooth_input_data(
        df=raw_df,
        include_deaths=include_deaths,
        figure_collector=figure_collector,
        display_name=fips,
        log=rt_log.new(fips=fips),
    )
def load_observations(fips=None, ref_date=REF_DATE):
    """
    Load observations (new cases, new deaths and hospitalizations) for given
    fips code.

    Parameters
    ----------
    fips: str
        FIPS code. 5 digits selects a county, 2 digits a state.
    ref_date: Datetime
        Reference start date.

    Returns
    -------
    observations: pd.DataFrame
        Contains observations for given fips codes, with columns:
        - new_cases: float, observed new cases
        - new_deaths: float, observed new deaths
        - hospitalizations: float, observed hospitalizations
        and dates of observation as index.

    Raises
    ------
    ValueError
        If fips is not a 2- or 5-digit code.
    """
    observations = {}
    if fips is not None and len(fips) == 5:
        times, observations['new_cases'], observations['new_deaths'] = \
            load_data.load_new_case_data_by_fips(fips, ref_date)
        hospital_times, hospitalizations, hospitalization_data_type = \
            load_data.load_hospitalization_data(fips, t0=ref_date)
        observations['times'] = times.values
    elif fips is not None and len(fips) == 2:
        state_obj = us.states.lookup(fips)
        observations['times'], observations['new_cases'], observations['new_deaths'] = \
            load_data.load_new_case_data_by_state(state_obj.name, ref_date)
        hospital_times, hospitalizations, hospitalization_data_type = \
            load_data.load_hospitalization_data_by_state(state_obj.abbr, t0=ref_date)
    else:
        # Previously an unsupported fips fell through and crashed below with
        # an opaque NameError on hospital_times; fail fast instead.
        raise ValueError(f'fips must be a 2- or 5-digit code, got {fips!r}')

    observations['times'] = np.array(observations['times'])
    # Hospitalizations default to NaN on days with no reported data.
    observations['hospitalizations'] = np.full(
        observations['times'].shape[0], np.nan)
    if hospitalization_data_type is HospitalizationDataType.CUMULATIVE_HOSPITALIZATIONS:
        # np.diff yields one fewer value than hospital_times, so the daily
        # increments must be aligned with the dates they end on
        # (hospital_times[1:]). Indexing with the full hospital_times array,
        # as before, raises a shape-mismatch ValueError.
        observations['hospitalizations'][
            hospital_times[1:] - observations['times'].min()] = np.diff(hospitalizations)
    elif hospitalization_data_type is HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
        observations['hospitalizations'][
            hospital_times - observations['times'].min()] = hospitalizations

    observation_dates = [
        ref_date + timedelta(int(t)) for t in observations['times']
    ]
    observations = pd.DataFrame(
        observations,
        index=pd.DatetimeIndex(observation_dates)).dropna(axis=1, how='all')
    return observations
def _whitelist_candidates_per_fips(fips):
    """
    Summarize case/death counts for one county fips as a pd.Series with
    fields: fips, total_cases, total_deaths, nonzero_case_datapoints,
    nonzero_death_datapoints.
    """
    _, new_cases, new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=datetime(day=1, month=1, year=2020))
    summary = {
        'fips': fips,
        'total_cases': new_cases.sum(),
        'total_deaths': new_deaths.sum(),
        # Count the days on which at least one case / death was reported.
        'nonzero_case_datapoints': np.sum(new_cases > 0),
        'nonzero_death_datapoints': np.sum(new_deaths > 0),
    }
    return pd.Series(summary)
def run_for_fips(cls, fips, n_retries=3, with_age_structure=False):
    """
    Run the model fitter for a state or county fips code.

    Parameters
    ----------
    fips: str
        2-digit state or 5-digit county fips code.
    n_retries: int
        The model fitter is stochastic in nature and a seed cannot be set.
        This is a bandaid until more sophisticated retries can be implemented.
    with_age_structure: bool
        If True run model with age structure.

    Returns
    -------
    : ModelFitter
        The fitted model, or None on failure / no county cases.
    """
    # Assert that there are some cases for counties
    if len(fips) == 5:
        _, observed_new_cases, _ = load_data.load_new_case_data_by_fips(
            fips, t0=datetime.today())
        if observed_new_cases.sum() < 1:
            return None

    try:
        retries_left = n_retries
        model_is_empty = True
        while retries_left > 0 and model_is_empty:
            model_fitter = cls(fips=fips, with_age_structure=with_age_structure)
            try:
                model_fitter.fit()
                if model_fitter.mle_model and os.environ.get(
                        'PYSEIR_PLOT_RESULTS') == 'True':
                    model_fitter.plot_fitting_results()
            except RuntimeError as e:
                logging.warning('No convergence.. Retrying ' + str(e))
            # BUG FIX: decrement unconditionally. Previously retries_left was
            # only decremented on RuntimeError, so a fit() that returned
            # without raising but left mle_model falsy looped forever.
            retries_left = retries_left - 1
            if model_fitter.mle_model:
                model_is_empty = False
        if retries_left <= 0 and model_is_empty:
            raise RuntimeError(
                f'Could not converge after {n_retries} for fips {fips}')
    except Exception:
        # Catch-all boundary: log and return None rather than crash the
        # batch run across many fips codes.
        logging.exception(f"Failed to run {fips}")
        return None
    return model_fitter
def generate_whitelist(self):
    """
    Generate a county whitelist based on the cuts above.

    Returns
    -------
    df: whitelist
        DataFrame with columns fips, state, county and a boolean
        inference_ok flag; also written to the whitelist artifact path.
    """
    logging.info('Generating county level whitelist...')

    whitelist_generator_inputs = []
    for fips in self.county_metadata.fips:
        # times is unused; only the case/death series matter for the cuts.
        _, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
            fips, t0=datetime(day=1, month=1, year=2020))
        metadata = self.county_metadata[self.county_metadata.fips == fips].iloc[0].to_dict()

        record = dict(
            fips=fips,
            state=metadata['state'],
            county=metadata['county'],
            total_cases=observed_new_cases.sum(),
            total_deaths=observed_new_deaths.sum(),
            nonzero_case_datapoints=np.sum(observed_new_cases > 0),
            nonzero_death_datapoints=np.sum(observed_new_deaths > 0)
        )
        whitelist_generator_inputs.append(record)

    df_candidates = pd.DataFrame(whitelist_generator_inputs)

    # BUG FIX: .copy() ensures df_whitelist owns its data. Assigning a new
    # column into a column-slice view triggers pandas' SettingWithCopyWarning
    # and may silently fail to write.
    df_whitelist = df_candidates[['fips', 'state', 'county']].copy()
    df_whitelist['inference_ok'] = (
        (df_candidates.nonzero_case_datapoints >= self.nonzero_case_datapoints)
        & (df_candidates.nonzero_death_datapoints >= self.nonzero_death_datapoints)
        & (df_candidates.total_cases >= self.total_cases)
        & (df_candidates.total_deaths >= self.total_deaths)
    )

    output_path = get_run_artifact_path(
        fips='06',  # Dummy fips since not used here...
        artifact=RunArtifact.WHITELIST_RESULT)
    df_whitelist.to_json(output_path)

    return df_whitelist
def run_for_fips(cls, fips, n_retries=3):
    """
    Run the model fitter for a state or county fips code.

    Parameters
    ----------
    fips: str
        2-digit state or 5-digit county fips code.
    n_retries: int
        The model fitter is stochastic in nature and a seed cannot be set.
        This is a bandaid until more sophisticated retries can be implemented.

    Returns
    -------
    : ModelFitter
        The fitted model, or None on failure / no county cases.
    """
    # Skip counties with no observed cases at all.
    if len(fips) == 5:
        _, observed_new_cases, _ = load_data.load_new_case_data_by_fips(
            fips, t0=datetime.today())
        if observed_new_cases.sum() < 1:
            return None

    try:
        attempt = 0
        while attempt < n_retries:
            model_fitter = cls(fips)
            try:
                model_fitter.fit()
                if model_fitter.mle_model:
                    model_fitter.plot_fitting_results()
                break
            except RuntimeError as err:
                logging.warning('No convergence.. Retrying ' + str(err))
            attempt += 1
        if model_fitter.mle_model is None:
            raise RuntimeError(
                f'Could not converge after {n_retries} for fips {fips}')
    except Exception:
        # Boundary catch-all: log the failure and keep the batch going.
        logging.exception(f"Failed to run {fips}")
        return None
    return model_fitter
def __init__(self,
             fips,
             ref_date=datetime(year=2020, month=1, day=1),
             min_deaths=2,
             n_years=1,
             cases_to_deaths_err_factor=.5,
             hospital_to_deaths_err_factor=.5,
             percent_error_on_max_observation=0.5,
             with_age_structure=False):
    """
    Set up the model fitter for a state (2-digit fips) or county (5-digit
    fips): load case, death and hospitalization observations and initialize
    fit bookkeeping.

    Parameters
    ----------
    fips: str
        2-digit state or 5-digit county fips code.
    ref_date: datetime
        Reference date that the loaded day-offset time series are measured
        from.
    min_deaths: int
        Minimum death count threshold (used downstream by the fitter).
    n_years: int
        Number of years spanned by the simulation time grid t_list.
    cases_to_deaths_err_factor: float
        Error scale factor for case observations relative to deaths
        (consumed by calculate_observation_errors — see that method).
    hospital_to_deaths_err_factor: float
        Error scale factor for hospitalization observations relative to
        deaths (consumed by calculate_observation_errors).
    percent_error_on_max_observation: float
        Fractional error applied to the maximum observation.
    with_age_structure: bool
        If True run the model with age structure.
    """
    # Seed the random state. It is unclear whether this propagates to the
    # Minuit optimizer.
    np.random.seed(seed=42)

    self.fips = fips
    self.ref_date = ref_date
    self.min_deaths = min_deaths
    # Daily grid: one point per day over n_years.
    self.t_list = np.linspace(0, int(365 * n_years), int(365 * n_years) + 1)
    self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
    self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
    self.percent_error_on_max_observation = percent_error_on_max_observation
    # Initial guess for the epidemic start, in days since ref_date.
    self.t0_guess = 60
    self.with_age_structure = with_age_structure

    if len(fips) == 2:  # State FIPS are 2 digits
        self.agg_level = AggregationLevel.STATE
        self.state_obj = us.states.lookup(self.fips)
        self.state = self.state_obj.name

        self.times, self.observed_new_cases, self.observed_new_deaths = \
            load_data.load_new_case_data_by_state(self.state, self.ref_date)

        self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
            load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
        self.display_name = self.state
    else:
        self.agg_level = AggregationLevel.COUNTY
        geo_metadata = load_data.load_county_metadata().set_index(
            'fips').loc[fips].to_dict()
        state = geo_metadata['state']
        self.state_obj = us.states.lookup(state)
        county = geo_metadata['county']
        # Some rows may have no county name; fall back to the state alone.
        if county:
            self.display_name = county + ', ' + state
        else:
            self.display_name = state

        # TODO Swap for new data source.
        self.times, self.observed_new_cases, self.observed_new_deaths = \
            load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
        self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
            load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

    self.cases_stdev, self.hosp_stdev, self.deaths_stdev = self.calculate_observation_errors(
    )

    self.set_inference_parameters()

    # Names of the parameters that Minuit optimizes.
    self.model_fit_keys = ['R0', 'eps', 't_break', 'log10_I_initial']

    self.SEIR_kwargs = self.get_average_seir_parameters()
    self.fit_results = None
    self.mle_model = None

    # Goodness-of-fit bookkeeping, populated after fitting.
    self.chi2_deaths = None
    self.chi2_cases = None
    self.chi2_hosp = None
    self.dof_deaths = None
    self.dof_cases = None
    self.dof_hosp = None
def __init__(
    self,
    fips,
    window_size=InferRtConstants.COUNT_SMOOTHING_WINDOW_SIZE,
    kernel_std=5,
    r_list=np.linspace(0, 10, 501),
    process_sigma=0.05,
    ref_date=datetime(year=2020, month=1, day=1),
    confidence_intervals=(0.68, 0.95),
    min_cases=5,
    min_deaths=5,
    include_testing_correction=True,
):
    """
    Load observations and initialize Rt-inference state for a state
    (2-digit fips) or county (5-digit fips).

    Parameters
    ----------
    fips: str
        2-digit state or 5-digit county fips code.
    window_size: int
        Smoothing window size for the count time series.
    kernel_std: float
        Smoothing kernel standard deviation.
    r_list: np.ndarray
        Grid of candidate R_t values to evaluate the posterior on.
    process_sigma: float
        Process noise scale for the R_t random walk.
    ref_date: datetime
        Reference date that day-offset time series are measured from.
    confidence_intervals: tuple of float
        Confidence levels to compute for R_t.
    min_cases: int
        Minimum case count required for inference (capped below).
    min_deaths: int
        Minimum death count required for inference (capped below).
    include_testing_correction: bool
        If True, apply the testing-volume correction to case counts.
    """
    # Param Generation used for Xcor in align_time_series, has some
    # stochastic FFT elements.
    np.random.seed(InferRtConstants.RNG_SEED)

    self.fips = fips
    self.r_list = r_list
    self.window_size = window_size
    self.kernel_std = kernel_std
    self.process_sigma = process_sigma
    self.ref_date = ref_date
    self.confidence_intervals = confidence_intervals
    self.min_cases = min_cases
    self.min_deaths = min_deaths
    self.include_testing_correction = include_testing_correction

    # Because rounding is disabled we don't need high min_deaths, min_cases anymore
    self.min_cases = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_cases)
    if not InferRtConstants.DISABLE_DEATHS:
        self.min_deaths = min(InferRtConstants.MIN_COUNTS_TO_INFER, self.min_deaths)

    if len(fips) == 2:  # State FIPS are 2 digits
        self.agg_level = AggregationLevel.STATE
        self.state_obj = us.states.lookup(self.fips)
        self.state = self.state_obj.name

        # Corrected series used for inference...
        (
            self.times,
            self.observed_new_cases,
            self.observed_new_deaths,
        ) = load_data.load_new_case_data_by_state(
            self.state,
            self.ref_date,
            include_testing_correction=self.include_testing_correction,
        )
        # ...and the uncorrected series kept alongside for reference.
        self.times_raw_new_cases, self.raw_new_cases, _ = load_data.load_new_case_data_by_state(
            self.state, self.ref_date, include_testing_correction=False
        )

        (
            self.hospital_times,
            self.hospitalizations,
            self.hospitalization_data_type,
        ) = load_data.load_hospitalization_data_by_state(
            state=self.state_obj.abbr, t0=self.ref_date
        )
        self.display_name = self.state
    else:
        self.agg_level = AggregationLevel.COUNTY
        self.geo_metadata = (
            load_data.load_county_metadata().set_index("fips").loc[fips].to_dict()
        )
        self.state = self.geo_metadata["state"]
        self.state_obj = us.states.lookup(self.state)
        self.county = self.geo_metadata["county"]
        # Some rows may have no county name; fall back to the state alone.
        if self.county:
            self.display_name = self.county + ", " + self.state
        else:
            self.display_name = self.state

        (
            self.times,
            self.observed_new_cases,
            self.observed_new_deaths,
        ) = load_data.load_new_case_data_by_fips(
            self.fips,
            t0=self.ref_date,
            include_testing_correction=self.include_testing_correction,
        )
        (
            self.times_raw_new_cases,
            self.raw_new_cases,
            _,
        ) = load_data.load_new_case_data_by_fips(
            self.fips,
            t0=self.ref_date,
            include_testing_correction=False,
        )
        (
            self.hospital_times,
            self.hospitalizations,
            self.hospitalization_data_type,
        ) = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

    # Convert integer day offsets into calendar dates.
    self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
    self.raw_new_case_dates = [
        ref_date + timedelta(days=int(t)) for t in self.times_raw_new_cases
    ]

    if self.hospitalization_data_type:
        self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

    self.default_parameters = ParameterEnsembleGenerator(
        fips=self.fips, N_samples=500, t_list=np.linspace(0, 365, 366)
    ).get_average_seir_parameters()

    # Serial period = Incubation + 0.5 * Infections
    self.serial_period = (
        1 / self.default_parameters["sigma"] + 0.5 * 1 / self.default_parameters["delta"]
    )

    # If we only receive current hospitalizations, we need to account for
    # the outflow to reconstruct new admissions.
    if (
        self.hospitalization_data_type
        is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS
    ):
        los_general = self.default_parameters["hospitalization_length_of_stay_general"]
        los_icu = self.default_parameters["hospitalization_length_of_stay_icu"]
        hosp_rate_general = self.default_parameters["hospitalization_rate_general"]
        hosp_rate_icu = self.default_parameters["hospitalization_rate_icu"]
        icu_rate = hosp_rate_icu / hosp_rate_general
        # Estimated daily discharges based on average length of stay.
        flow_out_of_hosp = self.hospitalizations[:-1] * (
            (1 - icu_rate) / los_general + icu_rate / los_icu
        )
        # We are attempting to reconstruct the cumulative hospitalizations.
        # The first sample is dropped since np.diff shortens by one.
        self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
        self.hospital_dates = self.hospital_dates[1:]
        self.hospital_times = self.hospital_times[1:]

    self.log_likelihood = None

    self.log = structlog.getLogger(Rt_Inference_Target=self.display_name)
    self.log.info(event="Running:")
def __init__(self, fips, window_size=7, kernel_std=2, r_list=np.linspace(0, 10, 501),
             process_sigma=0.15, ref_date=datetime(year=2020, month=1, day=1),
             confidence_intervals=(0.68, 0.75, 0.90)):
    """
    Load observations and initialize Rt-inference state for a state
    (2-digit fips) or county (5-digit fips).

    Parameters
    ----------
    fips: str
        2-digit state or 5-digit county fips code.
    window_size: int
        Smoothing window size for the count time series.
    kernel_std: float
        Smoothing kernel standard deviation.
    r_list: np.ndarray
        Grid of candidate R_t values to evaluate the posterior on.
    process_sigma: float
        Process noise scale for the R_t random walk.
    ref_date: datetime
        Reference date that day-offset time series are measured from.
    confidence_intervals: tuple of float
        Confidence levels to compute for R_t.
    """
    self.fips = fips
    self.r_list = r_list
    self.window_size = window_size
    self.kernel_std = kernel_std
    self.process_sigma = process_sigma
    self.ref_date = ref_date
    self.confidence_intervals = confidence_intervals

    if len(fips) == 2:  # State FIPS are 2 digits
        self.agg_level = AggregationLevel.STATE
        self.state_obj = us.states.lookup(self.fips)
        self.state = self.state_obj.name
        self.geo_metadata = load_data.load_county_metadata_by_state(self.state).loc[self.state].to_dict()

        self.times, self.observed_new_cases, self.observed_new_deaths = \
            load_data.load_new_case_data_by_state(self.state, self.ref_date)

        self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
            load_data.load_hospitalization_data_by_state(self.state_obj.abbr, t0=self.ref_date)
        self.display_name = self.state
    else:
        self.agg_level = AggregationLevel.COUNTY
        self.geo_metadata = load_data.load_county_metadata().set_index('fips').loc[fips].to_dict()
        self.state = self.geo_metadata['state']
        self.state_obj = us.states.lookup(self.state)
        self.county = self.geo_metadata['county']
        # Some rows may have no county name; fall back to the state alone.
        if self.county:
            self.display_name = self.county + ', ' + self.state
        else:
            self.display_name = self.state

        # TODO Swap for new data source.
        self.times, self.observed_new_cases, self.observed_new_deaths = \
            load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
        self.hospital_times, self.hospitalizations, self.hospitalization_data_type = \
            load_data.load_hospitalization_data(self.fips, t0=self.ref_date)

    logging.info(f'Running Rt Inference for {self.display_name}')

    # Convert integer day offsets into calendar dates.
    self.case_dates = [ref_date + timedelta(days=int(t)) for t in self.times]
    if self.hospitalization_data_type:
        self.hospital_dates = [ref_date + timedelta(days=int(t)) for t in self.hospital_times]

    self.default_parameters = ParameterEnsembleGenerator(
        fips=self.fips,
        N_samples=500,
        t_list=np.linspace(0, 365, 366)
    ).get_average_seir_parameters()

    # Serial period = Incubation + 0.5 * Infections
    self.serial_period = 1 / self.default_parameters['sigma'] + 0.5 * 1 / self.default_parameters['delta']

    # If we only receive current hospitalizations, we need to account for
    # the outflow to reconstruct new admissions.
    if self.hospitalization_data_type is load_data.HospitalizationDataType.CURRENT_HOSPITALIZATIONS:
        los_general = self.default_parameters['hospitalization_length_of_stay_general']
        los_icu = self.default_parameters['hospitalization_length_of_stay_icu']
        hosp_rate_general = self.default_parameters['hospitalization_rate_general']
        hosp_rate_icu = self.default_parameters['hospitalization_rate_icu']
        icu_rate = hosp_rate_icu / hosp_rate_general
        # Estimated daily discharges based on average length of stay.
        flow_out_of_hosp = self.hospitalizations[:-1] * ((1 - icu_rate) / los_general + icu_rate / los_icu)
        # We are attempting to reconstruct the cumulative hospitalizations.
        # The first sample is dropped since np.diff shortens by one.
        self.hospitalizations = np.diff(self.hospitalizations) + flow_out_of_hosp
        self.hospital_dates = self.hospital_dates[1:]
        self.hospital_times = self.hospital_times[1:]

    self.log_likelihood = None
def __init__(
    self,
    fips,
    ref_date=datetime(year=2020, month=1, day=1),
    min_deaths=2,
    n_years=1,
    cases_to_deaths_err_factor=0.5,
    hospital_to_deaths_err_factor=0.5,
    percent_error_on_max_observation=0.5,
    with_age_structure=False,
):
    """
    Set up the model fitter for a state (2-digit fips) or county (5-digit
    fips): load case, death, hospitalization and ICU observations and
    initialize fit bookkeeping.

    Parameters
    ----------
    fips: str
        2-digit state or 5-digit county fips code.
    ref_date: datetime
        Reference date that the loaded day-offset time series are measured
        from.
    min_deaths: int
        Minimum death count threshold (used downstream by the fitter).
    n_years: int
        Number of years spanned by the simulation time grid t_list.
    cases_to_deaths_err_factor: float
        Error scale factor for case observations relative to deaths
        (consumed by calculate_observation_errors — see that method).
    hospital_to_deaths_err_factor: float
        Error scale factor for hospitalization observations relative to
        deaths (consumed by calculate_observation_errors).
    percent_error_on_max_observation: float
        Fractional error applied to the maximum observation.
    with_age_structure: bool
        If True run the model with age structure.
    """
    # Seed the random state. It is unclear whether this propagates to the
    # Minuit optimizer.
    np.random.seed(seed=42)

    self.fips = fips
    self.ref_date = ref_date
    # Number of days of data available, less a one-week holdback.
    self.days_since_ref_date = (dt.date.today() - ref_date.date() - timedelta(days=7)).days
    # ndays end of 2nd ramp may extend past days_since_ref_date w/o penalty on chi2 score
    self.days_allowed_beyond_ref = 0
    self.min_deaths = min_deaths
    # Daily grid: one point per day over n_years.
    self.t_list = np.linspace(0, int(365 * n_years), int(365 * n_years) + 1)
    self.cases_to_deaths_err_factor = cases_to_deaths_err_factor
    self.hospital_to_deaths_err_factor = hospital_to_deaths_err_factor
    self.percent_error_on_max_observation = percent_error_on_max_observation
    # Initial guess for the epidemic start, in days since ref_date.
    self.t0_guess = 60
    self.with_age_structure = with_age_structure

    if len(fips) == 2:  # State FIPS are 2 digits
        self.agg_level = AggregationLevel.STATE
        self.state_obj = us.states.lookup(self.fips)
        self.state = self.state_obj.name

        (
            self.times,
            self.observed_new_cases,
            self.observed_new_deaths,
        ) = load_data.load_new_case_data_by_state(self.state, self.ref_date)

        (
            self.hospital_times,
            self.hospitalizations,
            self.hospitalization_data_type,
        ) = load_data.load_hospitalization_data_by_state(
            self.state_obj.abbr, t0=self.ref_date)
        (
            self.icu_times,
            self.icu,
            self.icu_data_type,
        ) = load_data.load_hospitalization_data_by_state(
            self.state_obj.abbr, t0=self.ref_date,
            category=HospitalizationCategory.ICU)
        self.display_name = self.state
    else:
        self.agg_level = AggregationLevel.COUNTY
        geo_metadata = load_data.load_county_metadata().set_index(
            "fips").loc[fips].to_dict()
        state = geo_metadata["state"]
        self.state_obj = us.states.lookup(state)
        county = geo_metadata["county"]
        # Some rows may have no county name; fall back to the state alone.
        if county:
            self.display_name = county + ", " + state
        else:
            self.display_name = state

        # TODO Swap for new data source.
        (
            self.times,
            self.observed_new_cases,
            self.observed_new_deaths,
        ) = load_data.load_new_case_data_by_fips(self.fips, t0=self.ref_date)
        (
            self.hospital_times,
            self.hospitalizations,
            self.hospitalization_data_type,
        ) = load_data.load_hospitalization_data(self.fips, t0=self.ref_date)
        (
            self.icu_times,
            self.icu,
            self.icu_data_type,
        ) = load_data.load_hospitalization_data(
            self.fips, t0=self.ref_date, category=HospitalizationCategory.ICU)

    self.cases_stdev, self.hosp_stdev, self.deaths_stdev = self.calculate_observation_errors(
    )

    self.set_inference_parameters()

    # Names of the parameters that Minuit optimizes (two-phase suppression).
    self.model_fit_keys = [
        "R0",
        "eps",
        "t_break",
        "eps2",
        "t_delta_phases",
        "log10_I_initial",
    ]

    self.SEIR_kwargs = self.get_average_seir_parameters()
    self.fit_results = None
    self.mle_model = None

    # Goodness-of-fit bookkeeping, populated after fitting.
    self.chi2_deaths = None
    self.chi2_cases = None
    self.chi2_hosp = None
    self.dof_deaths = None
    self.dof_cases = None
    self.dof_hosp = None
def fit_county_model(fips):
    """
    Fit the county's current trajectory, using the existing measures. We fit
    only to mortality data if available, else revert to case data. We assume a
    poisson process generates mortalities at a rate defined by the underlying
    dynamical model.

    TODO @ EC: Add hospitalization data when available.

    Parameters
    ----------
    fips: str
        County fips.

    Returns
    -------
    fit_values: dict
        Optimal values from the fitter.
    """
    county_metadata = load_data.load_county_metadata().set_index(
        'fips').loc[fips].to_dict()
    times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=ref_date)
    logging.info(
        f'Fitting MLE model to {county_metadata["county"]}, {county_metadata["state"]}'
    )
    SEIR_params = get_average_SEIR_parameters(fips)

    def _fit_seir(R0, t0, eps):
        # Chi^2 objective minimized by Minuit: run the SEIR model for the
        # candidate (R0, t0, eps) and compare predictions to observations.
        model = SEIRModel(R0=R0,
                          suppression_policy=suppression_policies.
                          generate_empirical_distancing_policy(
                              t_list, fips, future_suppression=eps),
                          **SEIR_params)
        model.run()

        # Shift model output by t0 and interpolate onto observation times.
        predicted_cases = model.gamma * np.interp(
            times, t_list + t0, model.results['total_new_infections'])
        predicted_deaths = np.interp(times, t_list + t0,
                                     model.results['direct_deaths_per_day'])

        # Assume the error on the case count could be off by a massive factor 50.
        # Basically we don't want to use it if there appreciable mortality data available.
        # Longer term there is a better procedure.
        cases_variance = 1e10 * observed_new_cases.copy(
        )**2  # Make the stdev N times larger x the number of cases
        deaths_variance = observed_new_deaths.copy()  # Poisson dist error

        # Zero inflated poisson Avoid floating point errors..
        cases_variance[cases_variance == 0] = 1e10
        deaths_variance[deaths_variance == 0] = 1e10

        # Compute Chi2
        chi2_cases = np.sum(
            (observed_new_cases - predicted_cases)**2 / cases_variance)
        # Only trust the death term when there is a non-trivial death count.
        if observed_new_deaths.sum() > 5:
            chi2_deaths = np.sum(
                (observed_new_deaths - predicted_deaths)**2 / deaths_variance)
        else:
            chi2_deaths = 0
        return chi2_deaths + chi2_cases

    # Note that error def is not right here. We need a realistic error model...
    # NOTE(review): this uses the legacy iminuit <2.0 keyword API
    # (error_*/limit_*/errordef kwargs) — confirm the pinned iminuit version.
    m = iminuit.Minuit(_fit_seir,
                       R0=4,
                       t0=50,
                       eps=.5,
                       error_eps=.2,
                       limit_R0=[1, 8],
                       limit_eps=[0, 2],
                       limit_t0=[-90, 90],
                       error_t0=1,
                       error_R0=1.,
                       errordef=1)
    m.migrad()

    values = dict(fips=fips, **dict(m.values))
    # Derived quantities and metadata attached to the fit result.
    values['t0_date'] = ref_date + timedelta(days=values['t0'])
    values['Reff_current'] = values['R0'] * values['eps']
    values['observed_total_deaths'] = np.sum(observed_new_deaths)
    values['county'] = county_metadata['county']
    values['state'] = county_metadata['state']
    values['total_population'] = county_metadata['total_population']
    values['population_density'] = county_metadata['population_density']
    return values
def plot_inferred_result(fit_results):
    """
    Plot the results of an MLE inference

    Re-runs the SEIR model at the fitted parameters and overlays the model
    curves on the observed case and death data, saving the figure as a PDF
    under OUTPUT_DIR.

    Parameters
    ----------
    fit_results: dict
        Fit result dict from the MLE fitter; must contain at least 'fips',
        'R0', 't0', 'eps', 'state' and 'county'.
    """
    fips = fit_results['fips']
    county_metadata = load_data.load_county_metadata().set_index(
        'fips').loc[fips].to_dict()
    times, observed_new_cases, observed_new_deaths = load_data.load_new_case_data_by_fips(
        fips, t0=ref_date)
    # Too few cases makes the log-scale plot meaningless; skip it.
    if observed_new_cases.sum() < 5:
        logging.warning(
            f"{county_metadata['county']} has fewer than 5 cases. Aborting plot."
        )
        return
    else:
        logging.info(f"Plotting MLE Fits for {county_metadata['county']}")

    R0, t0, eps = fit_results['R0'], fit_results['t0'], fit_results['eps']
    # Re-run the model at the fitted parameters to overlay on the data.
    model = SEIRModel(R0=R0,
                      suppression_policy=suppression_policies.
                      generate_empirical_distancing_policy(
                          t_list, fips, future_suppression=eps),
                      **get_average_SEIR_parameters(fit_results['fips']))
    model.run()

    data_dates = [ref_date + timedelta(days=t) for t in times]
    # Model dates are shifted by the fitted epidemic start offset t0.
    model_dates = [
        ref_date + timedelta(days=t + fit_results['t0']) for t in t_list
    ]

    plt.figure(figsize=(10, 8))
    plt.errorbar(data_dates,
                 observed_new_cases,
                 marker='o',
                 linestyle='',
                 label='Observed Cases Per Day')
    plt.errorbar(data_dates,
                 observed_new_deaths,
                 yerr=np.sqrt(observed_new_deaths),
                 marker='o',
                 linestyle='',
                 label='Observed Deaths')
    plt.plot(model_dates,
             model.results['total_new_infections'],
             label='Estimated Total New Infections Per Day')
    plt.plot(model_dates,
             model.gamma * model.results['total_new_infections'],
             label='Symptomatic Model Cases Per Day')
    plt.plot(model_dates,
             model.results['direct_deaths_per_day'],
             label='Model Deaths Per Day')
    plt.yscale('log')
    plt.ylim(.9e0)
    plt.xlim(data_dates[0], data_dates[-1] + timedelta(days=90))
    plt.xticks(rotation=30)
    plt.legend(loc=1)
    plt.grid(which='both', alpha=.3)
    plt.title(county_metadata['county'])

    # Annotate the axes with each fitted value; numeric values get
    # fixed-precision formatting, the rest are shown verbatim.
    for i, (k, v) in enumerate(fit_results.items()):
        if k not in ('fips', 't0_date', 'county', 'state'):
            plt.text(.025,
                     .97 - 0.04 * i,
                     f'{k}={v:1.3f}',
                     transform=plt.gca().transAxes,
                     fontsize=12)
        else:
            plt.text(.025,
                     .97 - 0.04 * i,
                     f'{k}={v}',
                     transform=plt.gca().transAxes,
                     fontsize=12)

    output_file = os.path.join(
        OUTPUT_DIR, fit_results['state'].title(), 'reports',
        f'{fit_results["state"]}__{fit_results["county"]}__{fit_results["fips"]}__mle_fit_results.pdf'
    )
    plt.savefig(output_file)