Example #1
    def calculate_results(mapping, diff, diff_lt, normalize=True):
        """ Calculate results """

        lats = []
        lons = []
        text = []
        val_std = []

        for index in diff_lt:
            ix = mapping.index[index % mapping.shape[0]]
            lat = mapping.loc[ix, 'latitude']
            lon = mapping.loc[ix, 'longitude']
            lats.append(lat)
            lons.append(lon)
            text.append(f'({lat}, {lon})={diff[index]}')

            diff_ix = diff[index]
            val_std.append(1.5 / diff_ix if diff_ix > 0 else 2)

        result = pd.DataFrame({
            'lat': lats,
            'lon': lons,
            'text': text,
            'val_std': val_std
        })

        if normalize:
            result['val_std'] = ModelHelpers.normalize(result['val_std'])

        return result
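
A minimal usage sketch for the snippet above (the toy mapping frame and diff values are illustrative; normalize=False sidesteps the ModelHelpers.normalize dependency, and calculate_results is assumed to be reachable as a plain function, although it is shown as a method here):

import numpy as np
import pandas as pd

# toy station mapping with latitude/longitude columns
mapping = pd.DataFrame({
    'latitude': [10.0, 12.5, 15.0],
    'longitude': [75.0, 77.5, 80.0]
})

diff = np.array([0.0, 3.0, 6.0, 1.5])  # per-index differences
diff_lt = [1, 2, 3]                    # indices selected by some threshold

result = calculate_results(mapping, diff, diff_lt, normalize=False)
# -> one row per index with lat, lon, a "(lat, lon)=diff" label,
#    and val_std = 1.5 / diff (or 2 where diff <= 0)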
Example #2
    def train_test_split(trmm_data,
                         prediction_ts,
                         onset_ts,
                         years=range(1998, 2017),
                         years_train=range(1998, 2016),
                         years_dev=None,
                         years_test=range(2016, 2017)):
        """
        Prepare data to be in a digestible format for the model

        :trmm_data: Filtered and optionally aggregated TRMM dataset to use
        :outcomes: Outcomes as generated by the base model

        :return:
        """

        # generate outcomes
        outcomes = ModelHelpers.generate_outcomes(prediction_ts,
                                                  onset_ts,
                                                  years,
                                                  numerical=True)

        # unstack the entire trmm dataset
        # => bring into matrix form with lat/lon on axes
        unstacked = ModelHelpers.unstack_all(trmm_data, years)

        # generate training data
        X_train = ModelHelpers.reshape_years(
            [unstacked[year] for year in years_train], num_channels=1)
        y_train = ModelHelpers.stack_outcomes(outcomes, years_train)

        # generate test data
        X_test = ModelHelpers.reshape_years(
            [unstacked[year] for year in years_test], num_channels=1)
        y_test = ModelHelpers.stack_outcomes(outcomes, years_test)

        if years_dev:
            X_dev = ModelHelpers.reshape_years(
                [unstacked[year] for year in years_dev], num_channels=1)
            y_dev = ModelHelpers.stack_outcomes(outcomes, years_dev)

            return X_train, y_train, X_test, y_test, X_dev, y_dev, unstacked

        return X_train, y_train, X_test, y_test, None, None, unstacked
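
A hedged call sketch for this variant, reusing the loader outputs shown elsewhere in these examples (trmm_data, prediction_ts, onset_ts); the year ranges are illustrative, and the method is assumed to be reachable as shown (compare the ModelERAv2.train_test_split call in Example #6):

X_train, y_train, X_test, y_test, X_dev, y_dev, unstacked = train_test_split(
    trmm_data,
    prediction_ts,
    onset_ts,
    years=range(1998, 2017),
    years_train=range(1998, 2015),
    years_dev=range(2015, 2016),
    years_test=range(2016, 2017))
# without years_dev, X_dev and y_dev come back as None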

# tail of a TUNINGS list whose start is truncated in the source
TUNINGS = [
    # ...
    {
        'epochs': 50,
        'patience': PATIENCE,
        'lr_plateau': (0.1, 10, 0.0001)
    },
    {
        'epochs': 50,
        'patience': PATIENCE,
        'lr_plateau': (0.1, 10, 0.0001),
        'dropout_conv': 0.3,
        'dropout_recurrent': 0.2
    },
]

# prepare onset dates and prediction timestamps
onset_dates, onset_ts = ModelHelpers.load_onset_dates(version='v2',
                                                      objective=True)
prediction_ts = ModelHelpers.generate_prediction_ts(PREDICT_ON, YEARS)

# load the ERA dataset
print("> Loading Dataset")


def filter_fun(df, year):
    # setup a filter function
    return ModelHelpers.filter_until(df, prediction_ts[year])


data_temp, data_hum = ERA.load_dataset(YEARS,
                                       version='v2',
                                       filter_fun=filter_fun)

# tail of another TUNINGS list whose start is truncated in the source
TUNINGS = [{
    # ...
    'epochs': EPOCHS,
    'patience': PATIENCE,
    'pool_dims': (2, 2, 2, 0),
    'dense_activation': 'tanh',
    'dropout_recurrent': 0.3
}, {
    'epochs': EPOCHS,
    'patience': PATIENCE,
    'pool_dims': (2, 2, 2, 0),
    'dense_activation': 'tanh',
    'dense_nodes': 512
}]

# --- Loading the dataset ---
print("> Loading Dataset")
onset_dates, onset_ts = ModelHelpers.load_onset_dates()
prediction_ts = ModelHelpers.generate_prediction_ts(PREDICT_ON, YEARS)


def filter_fun(df, year):
    return ModelHelpers.filter_until(df, prediction_ts[year])


data_trmm = TRMM.load_dataset(YEARS,
                              PRE_MONSOON,
                              invalidate=False,
                              filter_fun=filter_fun,
                              aggregation_resolution=AGGREGATION_RESOLUTION,
                              bundled=False)

# --- Building a model with numerical output based on the above tunings ---
Example #6
def train_model(x):
    # x is a 2D parameter row (one candidate per call), so scalars are read as x[0, i]
    config = {
        'aggregation_resolution': 1.0,
        'config_build': {
            'batch_norm': True,
            'conv_activation': mapping['conv_activation'][int(x[0, 8])],
            'conv_dropout': float(x[0, 1]),
            'conv_filters': mapping['network_structure'][int(x[0, 7])]['conv_filters'],
            'conv_kernels': mapping['network_structure'][int(x[0, 7])]['conv_kernels'],
            'conv_strides': mapping['network_structure'][int(x[0, 7])]['conv_strides'],
            'conv_pooling': mapping['network_structure'][int(x[0, 7])]['conv_pooling'],
            'conv_kernel_regularizer': (mapping['regularizer'][int(x[0, 4])], float(x[0, 3])),
            'conv_recurrent_activation': 'hard_sigmoid',
            'conv_recurrent_regularizer': (mapping['regularizer'][int(x[0, 4])], float(x[0, 3])),
            'conv_recurrent_dropout': float(x[0, 2]),
            'dense_dropout': float(x[0, 0]),
            'dense_nodes': mapping['network_structure'][int(x[0, 7])]['dense_nodes'],
            'dense_activation': mapping['dense_activation'][int(x[0, 9])],
            'dense_kernel_regularizer': (mapping['regularizer'][int(x[0, 4])], float(x[0, 3])),
            'learning_rate': mapping['initial_lr'][int(x[0, 6])],
            'loss': 'mean_squared_error',
            'optimizer': mapping['optimizer'][int(x[0, 5])],
            'padding': 'same'
        },
        'config_fit': {
            'batch_size': 1,
            'epochs': EPOCHS,
            'lr_plateau': (0.5, 10, 0.00001),
            'patience': PATIENCE,
            'tensorboard': True,
            'validation_split': 0.1
        },
        'objective_onsets': True,
        'predict_on': PREDICT_ON,
        'years': YEARS,
        'years_train': YEARS_TRAIN,
        'years_dev': YEARS_DEV,
        'years_test': YEARS_TEST,
    }

    # hash the config that was passed down
    hashed_params = hashlib.md5(str(config).encode()).hexdigest()

    # build a model based on the above tunings
    print("> Training Model")
    print(f">> Parameters: {x}")
    print(f">> Config: {config}")

    # prepare onset dates and prediction timestamps
    onset_dates, onset_ts = ModelHelpers.load_onset_dates(version='v2',
                                                          objective=True)
    prediction_ts = ModelHelpers.generate_prediction_ts(PREDICT_ON, YEARS)

    def filter_fun(df, year):
        # setup a filter function
        return ModelHelpers.filter_until(df, prediction_ts[year])

    # load the TRMM dataset
    print("> Loading Dataset")
    # load data for the pre-monsoon period (MAM)
    trmm_data = TRMM.load_dataset(
        YEARS,
        range(3, 6),
        aggregation_resolution=config['aggregation_resolution'],
        timestamp=True,
        invalidate=False,
        version='v2',
        filter_fun=filter_fun,
        bundled=False)

    era_temp, era_hum = ERA.load_dataset(
        YEARS,
        invalidate=False,
        timestamp=True,
        filter_fun=filter_fun,
        aggregation_resolution=config['aggregation_resolution'])

    # train test split
    print("> Train-Test Split")
    X_train, y_train, X_test, y_test, X_dev, y_dev, unstacked = ModelERAv2.train_test_split(
        [trmm_data, era_temp, era_hum],
        prediction_ts,
        onset_ts,
        years=config['years'],
        years_train=config['years_train'],
        years_test=config['years_test'],
        years_dev=config['years_dev'])

    # always build a fresh model (invalidate=True below bypasses any cached one)
    print('>> Building a new model')

    # throw away the models in memory
    K.clear_session()

    # train a model based on the given config
    results = ModelHelpers.run_config(
        ModelERAv2,
        config,
        X_train,
        y_train,
        invalidate=True,
        validation_data=(X_dev, y_dev) if config['years_dev'] else None,
        cache_id=hashed_params,
        version=VERSION)

    model = results['model']

    # evaluate the latest model
    eval_latest_train = model.evaluate(X_train, y_train)
    print(model.predict(X_train, batch_size=1), y_train)
    print(f'Train (latest): {eval_latest_train}')

    eval_latest_dev = 'NO_DEV'
    if X_dev is not None:
        eval_latest_dev = model.evaluate(X_dev, y_dev)
        print(model.predict(X_dev, batch_size=1), y_dev)
        print(f'Dev (latest): {eval_latest_dev}')

    eval_latest_test = model.evaluate(X_test, y_test)
    print(model.predict(X_test, batch_size=1), y_test)
    print(f'Test (latest): {eval_latest_test}')

    # evaluate the best model
    best_path = f'00_CACHE/lstm_{VERSION}_{hashed_params}_best.h5'
    if os.path.isfile(best_path):
        best_model = load_model(best_path)

        eval_best_train = best_model.evaluate(X_train, y_train)
        print(best_model.predict(X_train, batch_size=1), y_train)
        print(f'Train (best): {eval_best_train}')

        eval_best_dev = 'NO_DEV'
        if X_dev is not None:
            eval_best_dev = best_model.evaluate(X_dev, y_dev)
            print(best_model.predict(X_dev, batch_size=1), y_dev)
            print(f'Dev (best): {eval_best_dev}')

        eval_best_test = best_model.evaluate(X_test, y_test)
        print(best_model.predict(X_test, batch_size=1), y_test)
        print(f'Test (best): {eval_best_test}')

        with open(f'03_EVALUATION/bayesian/{VERSION}_optimization_out.csv',
                  'a') as csvfile:
            writer = csv.writer(csvfile,
                                delimiter=';',
                                quotechar='"',
                                quoting=csv.QUOTE_MINIMAL)
            writer.writerow([
                str(hashed_params),
                str(config),
                str(eval_latest_train),
                str(eval_best_train),
                str(eval_latest_dev),
                str(eval_best_dev),
                str(eval_latest_test),
                str(eval_best_test)
            ])
    else:
        with open(f'03_EVALUATION/bayesian/{VERSION}_optimization_out.csv',
                  'a') as csvfile:
            writer = csv.writer(csvfile,
                                delimiter=';',
                                quotechar='"',
                                quoting=csv.QUOTE_MINIMAL)
            writer.writerow([
                str(hashed_params),
                str(config),
                str(eval_latest_train), '-',
                str(eval_latest_dev), '-',
                str(eval_latest_test), '-'
            ])

    # return MSE over train and dev set as the objective value for the optimizer
    # NOTE: assumes a dev set is configured; without one, eval_latest_dev
    # is the 'NO_DEV' placeholder string and this addition would fail
    # TODO: should this be weighted in any way?
    # TODO: latest or best? or both?
    # TODO: we could overfit on the dev set by doing this?!
    # but this is the same as optimizing validation error manually...
    return eval_latest_train[1] + eval_latest_dev[1]
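
For context, a function with this signature is typically handed to a Bayesian optimizer as the objective. A sketch using GPyOpt (an assumption — the source only hints at this via the 'bayesian' output path), where each domain entry corresponds to one x[0, i] lookup above and the bounds are illustrative:

import GPyOpt

domain = [
    {'name': 'dense_dropout', 'type': 'continuous', 'domain': (0.0, 0.5)},  # x[0, 0]
    {'name': 'conv_dropout', 'type': 'continuous', 'domain': (0.0, 0.5)},   # x[0, 1]
    # ... one entry per remaining column of x ...
]

optimizer = GPyOpt.methods.BayesianOptimization(f=train_model, domain=domain)
optimizer.run_optimization(max_iter=20)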

# tail of a TUNINGS list whose start is truncated in the source;
# each entry nests a 'config_fit' dict as in train_model above
TUNINGS = [{
    # ...
    'config_fit': {
        # ...
        'tensorboard': False,
        'validation_split': 0.1
    },
    'objective_onsets': True,
    'predict_on': PREDICT_ON,
    'years': YEARS,
    'years_train': YEARS_TRAIN,
    'years_dev': YEARS_DEV,
    'years_test': YEARS_TEST,
}]

# get the tuning for the current index
TUNING = TUNINGS[INDEX]

# prepare onset dates and prediction timestamps
onset_dates, onset_ts = ModelHelpers.load_onset_dates(
    version='v2', objective=TUNING['objective_onsets'])
prediction_ts = ModelHelpers.generate_prediction_ts(TUNING['predict_on'],
                                                    TUNING['years'])

# load the ERA dataset
print("> Loading Dataset")


def filter_fun(df, year):
    # setup a filter function
    return ModelHelpers.filter_until(df, prediction_ts[year])


data_trmm = TRMM.load_dataset(
    TUNING['years'],
    PRE_MONSOON,
    # ... remaining arguments truncated in the source ...
)

# tail of another TUNINGS list whose start is truncated in the source
TUNINGS = [
    # ...
    {
        'epochs': 50,
        'patience': PATIENCE,
        'conv_kernel_regularizer': ('L2', 0.1),
        'conv_recurrent_regularizer': ('L2', 0.1),
        'dense_kernel_regularizer': ('L2', 0.2),
        'lr_plateau': (0.1, 10, 0.0001),
        'learning_rate': 0.1,
        'dropout_conv': 0.3,
        'dropout_recurrent': 0.2
    },
]

# prepare onset dates and prediction timestamps
onset_dates, onset_ts = ModelHelpers.load_onset_dates()
prediction_ts = ModelHelpers.generate_prediction_ts(PREDICT_ON, YEARS)


def filter_fun(df, year):
    # setup a filter function
    return ModelHelpers.filter_until(df, prediction_ts[year])


# load the TRMM dataset
print("> Loading Dataset")
data_trmm = TRMM.load_dataset(YEARS,
                              PRE_MONSOON,
                              invalidate=False,
                              filter_fun=filter_fun,
                              aggregation_resolution=AGGREGATION_RESOLUTION,
                              bundled=False)
    def train_test_split(datasets,
                         prediction_ts,
                         onset_ts,
                         true_offset,
                         years=range(1979, 2018),
                         years_train=range(1979, 2010),
                         years_dev=range(2010, 2013),
                         years_test=range(2013, 2018)):
        """
        Prepare data to be in a digestible format for the model

        :datasets: List of datasets to use as features
        :outcomes: Outcomes as generated by the base model

        :return:
        """

        # generate outcomes
        outcomes_train = ModelHelpers.generate_outcomes(
            prediction_ts,
            onset_ts,
            years_train,
            numerical=True,
            sequence=True,
            true_offset=true_offset)
        outcomes_rest = ModelHelpers.generate_outcomes(
            prediction_ts,
            onset_ts, [i for j in (years_dev, years_test) for i in j],
            numerical=True,
            sequence=True)
        print(outcomes_train, outcomes_rest)

        # datasets = ModelHelpers.augment_data(datasets, prediction_ts, years)

        # unstack the entire dataset
        # => bring into matrix form with lat/lon on axes
        # unstacked = ModelHelpers.unstack_all(datasets, years)
        # unstacked = ModelHelpers.prepare_datasets(years, datasets, prediction_ts)

        # print(unstacked[1995][0][0])
        # print(f'> unstacked: {unstacked[1995].shape!s}')

        # generate training data
        X_train = ModelHelpers.prepare_datasets(years_train,
                                                datasets,
                                                prediction_ts,
                                                true_offset=true_offset)
        # X_train = ModelHelpers.reshape_years([unstacked[year] for year in years_train], num_channels=len(datasets))
        y_train = ModelHelpers.stack_outcomes(outcomes_train,
                                              years_train,
                                              augmented=True)
        print(X_train[0][0][0])
        print('> X_train', X_train.shape, 'y_train', y_train.shape)

        X_train = ModelHelpers.normalize_channels(X_train)

        # generate test data
        X_test = ModelHelpers.prepare_datasets(years_test, datasets,
                                               prediction_ts)
        # X_test = ModelHelpers.reshape_years([unstacked[year] for year in years_test], num_channels=len(datasets))
        y_test = ModelHelpers.stack_outcomes(outcomes_rest,
                                             years_test,
                                             augmented=True)
        print(X_test[0][0][0])
        print('> X_test', X_test.shape, 'y_test', y_test.shape)

        X_test = ModelHelpers.normalize_channels(X_test)

        if years_dev:
            # X_dev = ModelHelpers.reshape_years([unstacked[year] for year in years_dev], num_channels=len(datasets))
            X_dev = ModelHelpers.prepare_datasets(years_dev, datasets,
                                                  prediction_ts)
            y_dev = ModelHelpers.stack_outcomes(outcomes_rest,
                                                years_dev,
                                                augmented=True)
            print(X_dev.shape)
            print('> X_dev', X_dev.shape, 'y_dev', y_dev.shape)

            X_dev = ModelHelpers.normalize_channels(X_dev)

            return X_train, y_train, X_test, y_test, X_dev, y_dev

        return X_train, y_train, X_test, y_test, None, None
Example #10
    def train_test_split(trmm_data,
                         prediction_ts,
                         onset_ts,
                         numerical=False,
                         years=range(1998, 2017),
                         years_train=range(1998, 2016),
                         years_dev=None,
                         years_test=range(2016, 2017)):
        """
        Prepare data to be in a digestible format for the model

        :trmm_data: Filtered and optionally aggregated TRMM dataset to use
        :outcomes: Outcomes as generated by the base model

        :return:
        """

        def unstack_year(df):
            """ Unstack a single year and return an unstacked sequence of grids """

            return np.array(
                [df.iloc[:, i].unstack().values for i in range(df.shape[1])])

        def unstack_all(df):
            """ Unstack all years and return the resulting dict """

            result = {}

            for year in years:
                result[year] = unstack_year(df[year])

            return result

        def reshape_years(arr):
            # append a singleton channel axis to each year's (t, lat, lon) stack
            return np.array([
                year.reshape((year.shape[0], year.shape[1], year.shape[2], 1))
                for year in arr
            ])

        def stack_outcomes(outcomes, years):
            if numerical:
                return [outcomes[year] for year in years]

            return np.concatenate([outcomes[year] for year in years])

        # generate outcomes
        outcomes = ModelHelpers.generate_outcomes(
            prediction_ts, onset_ts, years, numerical=numerical)

        # unstack the entire trmm dataset
        # => bring into matrix form with lat/lon on axes
        unstacked = unstack_all(trmm_data)

        # generate training data
        X_train = reshape_years([unstacked[year] for year in years_train])
        y_train = stack_outcomes(outcomes, years_train)

        # generate test data
        X_test = reshape_years([unstacked[year] for year in years_test])
        y_test = stack_outcomes(outcomes, years_test)

        if years_dev:
            X_dev = reshape_years([unstacked[year] for year in years_dev])
            y_dev = stack_outcomes(outcomes, years_dev)

            return X_train, y_train, X_test, y_test, X_dev, y_dev, unstacked

        # NOTE: unlike the dev branch above, this path returns five values
        return X_train, y_train, X_test, y_test, unstacked
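
The unstack helpers above assume each year's DataFrame has a (lat, lon) MultiIndex with one column per timestep; a toy sketch of what unstack_year produces:

import numpy as np
import pandas as pd

index = pd.MultiIndex.from_product([[10.0, 12.5], [75.0, 77.5]],
                                   names=['lat', 'lon'])
df = pd.DataFrame({'d1': [1, 2, 3, 4], 'd2': [5, 6, 7, 8]}, index=index)

grids = np.array(
    [df.iloc[:, i].unstack().values for i in range(df.shape[1])])
print(grids.shape)  # (2, 2, 2): timesteps x lat x lon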
    def train_test_split(
        datasets,
        prediction_ts,
        # prediction_ts_test,
        onset_ts,
        years=range(1979, 2018),
        years_train=range(1979, 2010),
        years_dev=range(2010, 2013),
        years_test=range(2013, 2018)):
        """
        Prepare data to be in a digestible format for the model

        :datasets: List of datasets to use as features (e.g., t and r dataframes)
        :prediction_ts: Sequences of prediction timestamps to use
        :onset_ts: Onset dates to use for outcome calculations
        :years: The overall years of all sets
        :years_train: The years to use for the training set
        :years_dev: The years to optionally use for the validation set
        :years_test: The years to use for the test set

        :return:
        """

        # generate outcomes
        outcomes = ModelHelpers.generate_outcomes(prediction_ts,
                                                  onset_ts,
                                                  years,
                                                  numerical=True,
                                                  sequence=True)
        # outcomes_test = ModelHelpers.generate_outcomes(prediction_ts_test, onset_ts, years_test, numerical=True, sequence=True)
        # print(outcomes_test)

        # generate training data
        X_train = ModelHelpers.prepare_datasets(years_train, datasets,
                                                prediction_ts)
        y_train = ModelHelpers.stack_outcomes(outcomes,
                                              years_train,
                                              augmented=True)
        print(X_train[0][0][0])
        print('> X_train', X_train.shape, 'y_train', y_train.shape)

        # standardize the training set, extracting mean and std
        X_train, X_mean, X_std = ModelHelpers.normalize_channels(X_train,
                                                                 seperate=True)

        # generate test data
        X_test = ModelHelpers.prepare_datasets(years_test, datasets,
                                               prediction_ts)
        y_test = ModelHelpers.stack_outcomes(outcomes,
                                             years_test,
                                             augmented=True)
        print(X_test)
        print('> X_test', X_test.shape, 'y_test', y_test.shape)

        # standardize the test set using mean and std from the training set
        X_test = ModelHelpers.normalize_channels(X_test,
                                                 mean=X_mean,
                                                 std=X_std)

        if years_dev:
            X_dev = ModelHelpers.prepare_datasets(years_dev, datasets,
                                                  prediction_ts)
            y_dev = ModelHelpers.stack_outcomes(outcomes,
                                                years_dev,
                                                augmented=True)
            print(X_dev.shape)
            print('> X_dev', X_dev.shape, 'y_dev', y_dev.shape)

            # standardize the dev set using mean and std from the training set
            X_dev = ModelHelpers.normalize_channels(X_dev,
                                                    mean=X_mean,
                                                    std=X_std)

            return X_train, y_train, X_test, y_test, X_dev, y_dev

        return X_train, y_train, X_test, y_test, None, None
Example #12
    def train_test_split(datasets,
                         prediction_ts,
                         onset_ts,
                         years=range(1979, 2018),
                         years_train=range(1979, 2012),
                         years_dev=range(2012, 2015),
                         years_test=range(2015, 2018)):
        """
        Prepare data to be in a digestible format for the model

        :datasets: List of datasets to use as features
        :outcomes: Outcomes as generated by the base model

        :return:
        """

        # generate outcomes
        outcomes = ModelHelpers.generate_outcomes(prediction_ts,
                                                  onset_ts,
                                                  years,
                                                  numerical=True)

        # unstack the entire dataset
        # => bring into matrix form with lat/lon on axes
        unstacked = ModelHelpers.unstack_all(datasets, years)

        print(unstacked[2017][0][0])
        print(f'> unstacked: {unstacked[2017].shape!s}')

        # generate training data
        X_train = ModelHelpers.reshape_years(
            [unstacked[year] for year in years_train],
            num_channels=len(datasets))
        # NOTE: the training set is scaled with standardize=False,
        # while the test and dev sets below use the default
        X_train = ModelHelpers.normalize_channels(X_train, standardize=False)
        y_train = ModelHelpers.stack_outcomes(outcomes, years_train)
        print(X_train[0][0][0])
        print(f'> X_train: {X_train.shape!s}')

        # generate test data
        X_test = ModelHelpers.reshape_years(
            [unstacked[year] for year in years_test],
            num_channels=len(datasets))
        X_test = ModelHelpers.normalize_channels(X_test)
        y_test = ModelHelpers.stack_outcomes(outcomes, years_test)
        print(X_test[0][0][0])
        print(f'> X_test: {X_test.shape!s}')

        if years_dev:
            X_dev = ModelHelpers.reshape_years(
                [unstacked[year] for year in years_dev],
                num_channels=len(datasets))
            X_dev = ModelHelpers.normalize_channels(X_dev)
            y_dev = ModelHelpers.stack_outcomes(outcomes, years_dev)
            print(X_dev.shape)
            print(f'> X_dev: {X_dev.shape!s}')

            return X_train, y_train, X_test, y_test, X_dev, y_dev, unstacked

        return X_train, y_train, X_test, y_test, None, None, unstacked
def filter_fun(df, year):
    return ModelHelpers.filter_until(df, onset_ts[year])

# tail of a TUNINGS list whose start is truncated in the source
TUNINGS = [
    # ...
    {
        'epochs': 50,
        'patience': PATIENCE,
        'conv_kernel_regularizer': ('L2', 0.1),
        'conv_recurrent_regularizer': ('L2', 0.1),
        'dense_kernel_regularizer': ('L2', 0.2),
        'lr_plateau': (0.1, 10, 0.0001),
        'learning_rate': 0.1,
        'dropout_conv': 0.3,
        'dropout_recurrent': 0.2
    },
]

# prepare onset dates and prediction timestamps
onset_dates, onset_ts = ModelHelpers.load_onset_dates()
prediction_ts = ModelHelpers.generate_prediction_ts(PREDICT_ON, YEARS)


def filter_fun(df, year):
    # setup a filter function
    return ModelHelpers.filter_until(df, prediction_ts[year])


# load the TRMM dataset
print("> Loading Dataset")
data_trmm = TRMM.load_dataset(
    YEARS,
    PRE_MONSOON,
    invalidate=False,
    filter_fun=filter_fun,
    aggregation_resolution=AGGREGATION_RESOLUTION,
    bundled=False)
Example #15
# tail of a TUNINGS list whose start is truncated in the source
TUNINGS = [
    {
        # ...
        'years_train': [
            1979, 1980, 1981, 1982, 1983, 1984, 1986, 1987, 1988, 1989, 1990,
            1991, 1992, 1993, 1994, 1996, 1997, 1998, 1999, 2000, 2001, 2002,
            2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2016
        ],
        'years_dev': None,
        'years_test': [1985, 1995, 2003, 2004, 2005, 2014, 2015, 2017]
    }
]

# get the tuning for the current index
TUNING = TUNINGS[INDEX]

# load onset dates
onset_dates, onset_ts = ModelHelpers.load_onset_dates(
    version='v2', objective=TUNING['objective_onsets'])

# prepare prediction timestamps
# generate a sequence of timestamps for train and validation (and, optionally, test)
prediction_ts = ModelHelpers.generate_prediction_ts(
    TUNING['predict_on'],
    TUNING['years'],
    onset_dates=onset_dates,
    sequence_length=TUNING['prediction_sequence'],
    sequence_offset=TUNING['prediction_offset'],
    example_length=TUNING['prediction_example_length'])

# prediction_ts_test = ModelHelpers.generate_prediction_ts(TUNING['predict_on'], TUNING['years_test'], onset_dates=onset_dates, sequence_length=TUNING['prediction_sequence'], sequence_offset=TUNING['prediction_offset'], example_length=TUNING['prediction_example_length'])


    def train_test_split(datasets,
                         prediction_ts,
                         onset_ts,
                         years=range(1979, 2018),
                         years_train=range(1979, 2012),
                         years_dev=range(2012, 2015),
                         years_test=range(2015, 2018)):
        """
        Prepare data to be in a digestible format for the model

        :datasets: List of datasets to use as features
        :outcomes: Outcomes as generated by the base model

        :return:
        """
        def unstack_year(df):
            """ Unstack a single year and return an unstacked sequence of grids """

            return np.array(
                [df.iloc[:, i].unstack().values for i in range(df.shape[1])])

        def unstack_all(dataframes):
            """ Unstack all years and return the resulting dict """

            result = {}

            for year in years:
                result[year] = np.stack(
                    [unstack_year(df[year]) for df in dataframes], axis=-1)

            return result

        def reshape_years(arr):
            # NOTE: hard-codes two feature channels in the target shape
            return np.array([
                year.reshape((year.shape[0], year.shape[1], year.shape[2], 2))
                for year in arr
            ])

        def stack_outcomes(outcomes, years):
            return [outcomes[year] for year in years]

        def normalize_channels(arr):
            # min-max normalize each channel separately:
            # axes 1 and 2 (lat and lon) are reduced,
            # axes 0 and 3 (image and channel) are kept separate
            # see: https://stackoverflow.com/questions/42460217/how-to-normalize-a-4d-numpy-array
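            # worked example for one image and channel:
            # a 2x2 grid [[0, 5], [10, 20]] maps to [[0.0, 0.25], [0.5, 1.0]]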
            arr_min = arr.min(axis=(1, 2), keepdims=True)
            arr_max = arr.max(axis=(1, 2), keepdims=True)

            return (arr - arr_min) / (arr_max - arr_min)

        # generate outcomes
        outcomes = ModelHelpers.generate_outcomes(prediction_ts,
                                                  onset_ts,
                                                  years,
                                                  numerical=True)

        # unstack the entire dataset
        # => bring into matrix form with lat/lon on axes
        unstacked = unstack_all(datasets)

        print(unstacked[2017][0][0])
        print(f'> unstacked: {unstacked[2017].shape!s}')

        # generate training data
        X_train = reshape_years([unstacked[year] for year in years_train])
        X_train = normalize_channels(X_train)
        y_train = stack_outcomes(outcomes, years_train)
        print(X_train[0][0][0])
        print(f'> X_train: {X_train.shape!s}')

        # generate test data
        X_test = reshape_years([unstacked[year] for year in years_test])
        X_test = normalize_channels(X_test)
        y_test = stack_outcomes(outcomes, years_test)
        print(X_test[0][0][0])
        print(f'> X_test: {X_test.shape!s}')

        if years_dev:
            X_dev = reshape_years([unstacked[year] for year in years_dev])
            X_dev = normalize_channels(X_dev)
            y_dev = stack_outcomes(outcomes, years_dev)
            print(X_dev.shape)
            print(f'> X_dev: {X_dev.shape!s}')

            return X_train, y_train, X_test, y_test, X_dev, y_dev, unstacked

        # NOTE: unlike the dev branch above, this path returns five values
        return X_train, y_train, X_test, y_test, unstacked
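
A hedged call sketch for this two-channel variant, pairing it with the ERA loader output shown earlier (data_temp and data_hum); the year ranges mirror the defaults, and the method is assumed to be reachable as a plain function:

X_train, y_train, X_test, y_test, X_dev, y_dev, unstacked = train_test_split(
    [data_temp, data_hum],  # exactly two feature channels, matching the reshape above
    prediction_ts,
    onset_ts,
    years=range(1979, 2018))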