def train_for_region(self, data_source, region_type, region_name, train_start_date, train_end_date,
                      search_space, search_parameters, train_loss_function, input_filepath):
     observations = DataFetcherModule.get_observations_for_region(region_type, region_name,
                                                                  data_source=data_source, filepath=input_filepath)
     region_metadata = DataFetcherModule.get_regional_metadata(region_type, region_name, data_source=data_source)
     return self.train(region_metadata, observations, train_start_date, train_end_date,
                       search_space, search_parameters, train_loss_function)
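How these wrappers might be called is not shown in the snippets themselves, so here is a minimal, hypothetical sketch; `pipeline` stands in for whatever class defines train_for_region, and every argument value below is an illustrative placeholder rather than the project's real configuration.

# Hypothetical usage sketch; `pipeline`, the region values, and the config
# objects are placeholders, not the project's actual names.
train_results = pipeline.train_for_region(
    data_source=data_source,                 # whatever DataSource value the project expects
    region_type='district',
    region_name='Pune',
    train_start_date='4/20/20',              # '%m/%d/%y', as used elsewhere in these examples
    train_end_date='5/10/20',
    search_space=search_space,               # model-specific hyperparameter search space
    search_parameters=search_parameters,     # e.g. number of search trials
    train_loss_function=train_loss_function, # loss used to score candidate fits
    input_filepath=None)                     # optional local file, per the data fetcher call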
Example #2
 def predict_for_region(self, region_type, region_name, run_day, forecast_start_date,
                        forecast_end_date):
     observations = DataFetcherModule.get_observations_for_region(region_type, region_name)
     region_metadata = DataFetcherModule.get_regional_metadata(region_type, region_name)
     return self.predict(region_type, region_name, region_metadata, observations, run_day,
                         forecast_start_date,
                         forecast_end_date)
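A matching sketch for forecasting, under the same assumptions; note that in Example #11 below the run day is derived as the day before the forecast start, which is the convention followed here.

# Hypothetical usage sketch; `pipeline` and all values are placeholders.
forecast_df = pipeline.predict_for_region(
    region_type='district',
    region_name='Pune',
    run_day='5/10/20',              # day before the forecast start
    forecast_start_date='5/11/20',
    forecast_end_date='5/25/20')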
Example #3
 def evaluate_for_region(self, region_type, region_name, run_day,
                         test_start_date, test_end_date, loss_functions):
     observations = DataFetcherModule.get_observations_for_region(
         region_type, region_name)
     region_metadata = DataFetcherModule.get_regional_metadata(
         region_type, region_name)
     return self.evaluate(region_metadata, observations, run_day,
                          test_start_date, test_end_date, loss_functions)
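And a corresponding sketch for evaluation, again with a hypothetical `pipeline` object; loss_functions is assumed to be whatever list of loss specifications the evaluate implementation accepts.

# Hypothetical usage sketch; `pipeline` and all values are placeholders.
metrics = pipeline.evaluate_for_region(
    region_type='district',
    region_name='Pune',
    run_day='5/10/20',
    test_start_date='5/11/20',
    test_end_date='5/25/20',
    loss_functions=loss_functions)  # assumed list of loss-function configs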
Example #4
 def train_for_region(self, region_type, region_name, train_start_date,
                      train_end_date, search_space, search_parameters,
                      train_loss_function):
     observations = DataFetcherModule.get_observations_for_region(
         region_type, region_name)
     region_metadata = DataFetcherModule.get_regional_metadata(
         region_type, region_name)
     return self.train(region_metadata, observations, train_start_date,
                       train_end_date, search_space, search_parameters,
                       train_loss_function)
Example #5
 def evaluate_for_region(self, data_source, region_type, region_name,
                         run_day, test_start_date, test_end_date,
                         loss_functions, input_filepath):
     observations = DataFetcherModule.get_observations_for_region(
         region_type,
         region_name,
         data_source=data_source,
         filepath=input_filepath)
     region_metadata = DataFetcherModule.get_regional_metadata(
         region_type, region_name, data_source=data_source)
     return self.evaluate(region_metadata, observations, run_day,
                          test_start_date, test_end_date, loss_functions)
Example #6
 def predict_for_region(self, data_source, region_type, region_name,
                        run_day, forecast_start_date, forecast_end_date,
                        input_filepath):
     observations = DataFetcherModule.get_observations_for_region(
         region_type,
         region_name,
         data_source=data_source,
         filepath=input_filepath)
     region_metadata = DataFetcherModule.get_regional_metadata(
         region_type, region_name, data_source=data_source)
     return self.predict(region_type, region_name, region_metadata,
                         observations, run_day, forecast_start_date,
                         forecast_end_date)
Example #7
def get_observations_in_range(region_name,
                              region_type,
                              start_date,
                              end_date,
                              obs_type='confirmed'):
    """
        Return a list of counts of obs_type cases
        from the region in the specified date range.
    """
    observations = DataFetcherModule.get_observations_for_region(
        region_type, region_name)
    observations_df = observations[observations['observation'] == obs_type]

    start_date = datetime.strptime(start_date, '%m/%d/%y')
    end_date = datetime.strptime(end_date, '%m/%d/%y')
    delta = (end_date - start_date).days
    days = []
    for i in range(delta + 1):
        days.append((start_date + timedelta(days=i)).strftime('%-m/%-d/%-y'))

    # Fetch observations in the date range
    observations_df = observations_df[days]

    # Transpose the df to get the
    # observations_df.shape = (num_days, 1)
    observations_df = observations_df.reset_index(drop=True).transpose()

    # Extract the single column of counts as a plain list
    # Note that the hardcoded 0 column label comes from the reset_index
    # in the previous step
    observations = observations_df[0].to_list()
    return observations
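Because the date columns are selected by exact string match, callers must pass dates in the '%m/%d/%y' format shown above; a small sketch of the expected behaviour (region values are placeholders):

# Hypothetical usage sketch; returns one count per day, endpoints inclusive.
confirmed_counts = get_observations_in_range('Pune', 'district',
                                              '5/1/20', '5/10/20',
                                              obs_type='confirmed')
assert len(confirmed_counts) == 10   # 10 calendar days in the range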
Example #8
def get_clean_staffing_ratio(staff_ratios_file_path: str):
    """Fixes the input staff ratio matrix and also the column names, index"""
    staff_ratios = DataFetcherModule.get_staffing_ratios(
        staff_ratios_file_path)
    for b in staff_ratios.columns:
        bnew = b.split('\n')[0].strip()
        staff_ratios = staff_ratios.rename(columns={b: bnew})
    staff_ratios = staff_ratios.set_index('Personnel')
    staff_ratios.fillna(0, inplace=True)
    return staff_ratios.copy()
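A quick sketch of what the cleaned frame should look like, assuming the input sheet really does contain a 'Personnel' column as the code expects (the file path is a placeholder):

# Hypothetical usage sketch
ratios = get_clean_staffing_ratio('staffing_ratios.csv')
print(ratios.index.name)           # 'Personnel'
print(ratios.isna().sum().sum())   # 0, since NaNs are filled with 0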
Example #9
 def predict_for_region(self, data_source: DataSource, region_type: str, region_name: List[str], run_day: str,
                        start_date: str, input_type: InputType, time_intervals: List[ForecastTimeInterval],
                        input_filepath: str):
     """
     Downloads data using the data fetcher module and then runs predict on that dataset.
     @param region_type: region_type supported by data_fetcher module
     @param region_name: region_name supported by data_fetcher module
     @param run_day: date of initialization
     @param start_date: start date of the forecast
     @param input_type: input_type can be npi_list/param_override
     @param time_intervals: list of time_intervals with parameters
     @param data_source: data source
     @param input_filepath: input data file path
     @return: pd.DataFrame: predictions
     """
     observations = DataFetcherModule.get_observations_for_region(region_type, region_name, data_source=data_source,
                                                                  filepath=input_filepath)
     region_metadata = DataFetcherModule.get_regional_metadata(region_type, region_name)
     return self.predict(region_type, region_name, region_metadata, observations, run_day,
                         start_date, input_type, time_intervals)
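A hedged sketch of a call to this scenario-style predict; the DataSource, InputType, and ForecastTimeInterval values are assumed to be constructed elsewhere, since their concrete members are not shown in these examples.

# Hypothetical usage sketch; enum members and intervals are placeholders.
forecast_df = pipeline.predict_for_region(
    data_source=data_source,        # a DataSource value
    region_type='district',
    region_name=['Mumbai'],
    run_day='5/10/20',
    start_date='5/11/20',
    input_type=input_type,          # npi_list or param_override, per the docstring
    time_intervals=time_intervals,  # list of ForecastTimeInterval with parameters
    input_filepath=None)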
Example #10
def plot(model_params,
         forecast_df,
         forecast_start_date,
         forecast_end_date,
         plot_name='default.png'):
    """
        Plot actual_confirmed cases vs forecasts.
        
        Assert that forecast_end_date is prior to the current date
        to ensure availability of actual_counts.
    """
    # Check for forecast_end_date being prior to current date
    end_date = datetime.strptime(forecast_end_date, '%m/%d/%y')
    assert end_date < datetime.now()

    # Fetch actual counts from the DataFetcher module
    region_name = model_params['region']
    region_type = model_params['region_type']
    actual_observations = DataFetcherModule.get_observations_for_region(
        region_type, region_name)

    # Get relevant time-series of actual counts from actual_observations
    actual_observations = get_observations_in_range(region_name,
                                                    region_type,
                                                    forecast_start_date,
                                                    forecast_end_date,
                                                    obs_type='confirmed')

    forecast_df['actual_confirmed'] = actual_observations

    fig, ax = plt.subplots(figsize=(15, 5))
    fig.suptitle(model_params['region'])
    ax.plot(forecast_df['index'],
            forecast_df['actual_confirmed'],
            color='blue',
            label="actual_confirmed")
    ax.plot(forecast_df['index'],
            forecast_df['confirmed_mean'],
            color='orange',
            label="predicted_confirmed")
    ax.set_ylim(ymin=0)
    ax.legend()

    plt.savefig(plot_name)
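A minimal sketch of a call, assuming model_params carries the 'region' and 'region_type' keys and forecast_df carries the 'index' and 'confirmed_mean' columns the function body reads:

# Hypothetical usage sketch; values are placeholders.
plot(model_params,
     forecast_df,
     forecast_start_date='5/11/20',
     forecast_end_date='5/25/20',
     plot_name='pune_confirmed.png')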
Example #11
def plot_m3(train2_model_params,
            train1_start_date,
            forecast_start_date,
            forecast_length,
            rolling_average=False,
            uncertainty=False,
            forecast_config='forecast_config.json',
            plot_config='plot_config.json',
            plot_name='default.png'):

    ## TODO: Log scale
    with open(plot_config) as fplot, \
        open(forecast_config) as fcast:
        default_plot_config = json.load(fplot)
        default_forecast_config = json.load(fcast)

    plot_config = deepcopy(default_plot_config)
    plot_config['uncertainty'] = uncertainty
    plot_config['rolling_average'] = rolling_average

    actual_start_date = (datetime.strptime(train1_start_date, "%m/%d/%y") -
                         timedelta(days=14)).strftime("%-m/%-d/%y")
    forecast_run_day = (datetime.strptime(forecast_start_date, "%m/%d/%y") -
                        timedelta(days=1)).strftime("%-m/%-d/%y")
    forecast_end_date = (
        datetime.strptime(forecast_start_date, "%m/%d/%y") +
        timedelta(days=forecast_length)).strftime("%-m/%-d/%y")

    # Get predictions
    pd_df_forecast = forecast(train2_model_params, forecast_run_day,
                              forecast_start_date, forecast_end_date,
                              default_forecast_config)

    pd_df_forecast['index'] = pd.to_datetime(pd_df_forecast['index'])
    pd_df = pd_df_forecast.sort_values(by=['index'])

    # Get observed data
    actual = DataFetcherModule.get_observations_for_region(
        train2_model_params['region_type'], train2_model_params['region'])
    actual = actual.set_index('observation')
    actual = actual.transpose()
    actual = actual.reset_index()
    start = actual.index[actual['index'] == actual_start_date].tolist()[0]
    end = actual.index[actual['index'] == forecast_run_day].tolist()[0]
    actual = actual[start:end + 1]
    actual['index'] = pd.to_datetime(actual['index'])

    plot_markers = plot_config['markers']
    plot_colors = plot_config['colors']
    plot_labels = plot_config['labels']
    plot_variables = plot_config['variables']

    fig, ax = plt.subplots(figsize=(16, 12))

    for variable in plot_variables:

        # Plot observed values
        ax.plot(actual['index'],
                actual[variable],
                plot_markers['observed'],
                color=plot_colors[variable],
                label=plot_labels[variable] + ': Observed')

        # Plot mean predictions
        if variable + '_mean' in pd_df:
            ax.plot(pd_df['index'],
                    pd_df[variable + '_mean'],
                    plot_markers['predicted']['mean'],
                    color=plot_colors[variable],
                    label=plot_labels[variable] + ': Predicted')

        # Plot uncertainty in predictions
        if plot_config['uncertainty']:

            if variable + '_min' in pd_df:
                ax.plot(pd_df['index'],
                        pd_df[variable + '_min'],
                        plot_markers['predicted']['min'],
                        color=plot_colors[variable],
                        label=plot_labels[variable] + ': Predicted (Min)')

            if variable + '_max' in pd_df:
                ax.plot(pd_df['index'],
                        pd_df[variable + '_max'],
                        plot_markers['predicted']['max'],
                        color=plot_colors[variable],
                        label=plot_labels[variable] + ': Predicted (Max)')

        # Plot rolling average
        if plot_config['rolling_average'] and variable + '_ra' in pd_df:
            ax.plot(pd_df['index'],
                    pd_df[variable + '_ra'],
                    plot_markers['rolling_average'],
                    color=plot_colors[variable],
                    label=plot_labels[variable] + ': Predicted (RA)')

    ax.xaxis.set_major_locator(mdates.DayLocator(interval=5))
    ax.xaxis.set_minor_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))

    plt.ylabel('No of People')
    plt.xlabel('Time')
    plt.legend()
    plt.grid()

    plt.savefig(plot_name)
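Finally, a sketch of a plot_m3 call under the assumption that the two JSON config files exist on disk and that train2_model_params carries the 'region' and 'region_type' keys the function reads:

# Hypothetical usage sketch; paths and dates are placeholders.
plot_m3(train2_model_params,
        train1_start_date='4/20/20',
        forecast_start_date='5/11/20',
        forecast_length=14,
        rolling_average=False,
        uncertainty=True,
        forecast_config='forecast_config.json',
        plot_config='plot_config.json',
        plot_name='m3_forecast.png')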