Ejemplo n.º 1
0
def test_pretrain_fail():
    """Requesting pretraining epochs without pretraining data must fail.

    The data is prepared without a ``pretrain_file`` argument, and
    ``train_model`` is then called with ``pretrain_epochs=2``. The test
    expects that call to raise ``KeyError``.
    """
    # No pretrain_file is passed on purpose: its absence is what the
    # training call below is expected to trip over.
    # NOTE(review): other tests in this file pass ``y_vars_finetune`` rather
    # than ``y_vars`` -- confirm ``y_vars`` is still an accepted keyword.
    prepped_data = preproc_utils.prep_all_data(
        x_data_file="test_data/test_data",
        y_data_file="test_data/obs_temp_flow",
        train_start_date="2003-09-15",
        train_end_date="2004-09-16",
        val_start_date="2004-09-17",
        val_end_date="2005-09-18",
        test_start_date="2005-09-19",
        test_end_date="2006-09-20",
        spatial_idx_name="segs_test",
        time_idx_name="times_test",
        x_vars=["seg_rain", "seg_tave_air"],
        y_vars=["temp_c"],
    )

    # The return value is irrelevant here; only the raised exception matters.
    with pytest.raises(KeyError):
        train.train_model(io_data=prepped_data,
                          pretrain_epochs=2,
                          finetune_epochs=2,
                          hidden_units=10,
                          out_dir='test_data/test_training_out',
                          model_type="lstm",
                          seed=2,
                          dropout=0.12,
                          loss_func=loss_functions.rmse_masked_one_var)
Ejemplo n.º 2
0
def test_finetune_rgcn():
    """Smoke-test fine-tuning an "rgcn" model with pretraining disabled.

    Data is prepared for segments 2007 and 2012 with a distance matrix file,
    the output directory is recreated empty, and ``train_model`` is run for
    two fine-tuning epochs and zero pretraining epochs. The test makes no
    assertions on the result; it passes if nothing raises.
    """
    prepped_data = preproc_utils.prep_all_data(
        x_data_file="test_data/test_data",
        y_data_file="test_data/obs_temp_flow",
        train_start_date="2003-09-15",
        train_end_date="2004-09-16",
        val_start_date="2004-09-17",
        val_end_date="2005-09-18",
        test_start_date="2005-09-19",
        test_end_date="2006-09-20",
        spatial_idx_name="segs_test",
        time_idx_name="times_test",
        segs=[2007, 2012],
        distfile="../../../drb-dl-model/data/in/distance_matrix.npz",
        x_vars=["seg_rain", "seg_tave_air"],
        y_vars=["temp_c"],
    )

    # Recreate the output directory so the run starts from an empty folder.
    test_out_dir = 'test_data/test_training_out'
    if os.path.exists(test_out_dir):
        shutil.rmtree(test_out_dir)

    os.mkdir(test_out_dir)

    # Pass the same path that was just prepared instead of repeating the
    # literal, so the cleaned directory and the training output cannot drift.
    train.train_model(io_data=prepped_data,
                      finetune_epochs=2,
                      pretrain_epochs=0,
                      hidden_units=10,
                      out_dir=test_out_dir,
                      model_type="rgcn",
                      seed=2,
                      dropout=0.12,
                      loss_func=loss_functions.rmse)
Ejemplo n.º 3
0
def test_prep_data():
    """Check the entries returned by ``prep_all_data`` without pretraining.

    The call passes no ``pretrain_file``; the test only verifies that each
    expected key is present in the returned mapping.
    """
    # Checked in this order; the first missing key fails the test.
    expected_keys = (
        "x_trn",
        "x_val",
        "x_tst",
        "ids_trn",
        "ids_val",
        "ids_tst",
        "times_trn",
        "times_val",
        "times_tst",
        "y_obs_trn",
        "y_obs_val",
        "y_obs_tst",
        "y_obs_vars",
        "y_mean",
        "y_std",
        "x_vars",
        "x_mean",
        "x_std",
    )

    prepped_data = preproc_utils.prep_all_data(
        x_data_file="test_data/test_data",
        y_data_file="test_data/obs_temp_flow",
        train_start_date="2003-09-15",
        train_end_date="2004-09-16",
        val_start_date="2004-09-17",
        val_end_date="2005-09-18",
        test_start_date="2005-09-19",
        test_end_date="2006-09-20",
        spatial_idx_name="segs_test",
        time_idx_name="times_test",
        x_vars=["seg_rain", "seg_tave_air"],
        y_vars_finetune=["temp_c", "discharge_cms"],
    )

    for expected_key in expected_keys:
        assert expected_key in prepped_data.keys()
Ejemplo n.º 4
0
def test_prep_data_w_pretrain_file_no_y_pretrain():
    """Expect ``ValueError`` when ``pretrain_file`` is given without
    ``y_vars_pretrain``.
    """
    # A pretrain_file is supplied, but y_vars_pretrain is deliberately left out.
    prep_kwargs = {
        "x_data_file": "test_data/test_data",
        "y_data_file": "test_data/obs_temp_flow",
        "pretrain_file": "test_data/obs_temp_flow",
        "train_start_date": "2003-09-15",
        "train_end_date": "2004-09-16",
        "val_start_date": "2004-09-17",
        "val_end_date": "2005-09-18",
        "test_start_date": "2005-09-19",
        "test_end_date": "2006-09-20",
        "spatial_idx_name": "segs_test",
        "time_idx_name": "times_test",
        "x_vars": ["seg_rain", "seg_tave_air"],
        "y_vars_finetune": ["temp_c", "discharge_cms"],
    }

    with pytest.raises(ValueError):
        preproc_utils.prep_all_data(**prep_kwargs)
Ejemplo n.º 5
0
def test_prep_data_no_test_dates():
    """Check ``prep_all_data`` output when no test dates are supplied.

    All partition keys must still be present; the test-partition entries
    (``*_tst``) are expected to be ``None`` while the checked training and
    validation entries are expected to be populated.
    """
    prepped_data = preproc_utils.prep_all_data(
        x_data_file="test_data/test_data",
        y_data_file="test_data/obs_temp_flow",
        pretrain_file="test_data/test_data",
        train_start_date="2003-09-15",
        train_end_date="2004-09-16",
        val_start_date="2004-09-17",
        val_end_date="2005-09-18",
        spatial_idx_name="segs_test",
        time_idx_name="times_test",
        x_vars=["seg_rain", "seg_tave_air"],
        y_vars_finetune=["temp_c", "discharge_cms"],
        y_vars_pretrain=["seg_tave_water", "seg_outflow"],
    )

    # Every one of these keys must exist, whatever its value.
    required_keys = (
        "x_trn",
        "x_val",
        "x_tst",
        "ids_trn",
        "ids_val",
        "ids_tst",
        "times_trn",
        "times_val",
        "times_tst",
        "y_pre_trn",
        "y_pre_full",
        "y_obs_trn",
        "y_obs_val",
        "y_obs_tst",
    )
    for required_key in required_keys:
        assert required_key in prepped_data.keys()

    # (key, should_be_none) pairs, checked in this order.
    none_expectations = (
        ("x_trn", False),
        ("x_val", False),
        ("x_tst", True),
        ("ids_trn", False),
        ("ids_val", False),
        ("ids_tst", True),
        ("times_trn", False),
        ("times_val", False),
        ("times_tst", True),
        ("y_pre_trn", False),
        ("y_obs_trn", False),
        ("y_obs_val", False),
        ("y_obs_tst", True),
    )
    for key, should_be_none in none_expectations:
        if should_be_none:
            assert prepped_data[key] is None
        else:
            assert prepped_data[key] is not None
Ejemplo n.º 6
0
def test_prep_data_val_test_sites_test_dates():
    """Check partition contents when both val/test sites and dates are given.

    With ``val_sites=[2007]`` and ``test_sites=[2012]`` the test expects, per
    the counts reported by ``get_num_non_nans``:

    * training: nothing for 2007 and 2012, something for 2014 and 2037;
    * validation: nothing for 2012, something for 2007, 2014 and 2037;
    * test: something for all four segments.
    """
    data = preproc_utils.prep_all_data(
        x_data_file="test_data/test_data",
        y_data_file="test_data/obs_temp_flow",
        pretrain_file="test_data/test_data",
        train_start_date="2003-09-15",
        train_end_date="2004-09-16",
        val_start_date="2004-09-17",
        val_end_date="2005-09-18",
        test_start_date="2005-09-19",
        test_end_date="2006-09-20",
        val_sites=[2007],
        test_sites=[2012],
        spatial_idx_name="segs_test",
        time_idx_name="times_test",
        x_vars=["seg_rain", "seg_tave_air"],
        y_vars_finetune=["temp_c", "discharge_cms"],
        y_vars_pretrain=["seg_tave_water", "seg_outflow"],
    )

    partitions = ("trn", "val", "tst")

    for partition in partitions:
        assert_segs_in_ids(data["ids_" + partition])

    frames = {
        partition: df_from_array(data, partition) for partition in partitions
    }

    # (partition, segment id, whether any non-NaN values are expected)
    expectations = (
        ("trn", 2007, False),
        ("trn", 2012, False),
        ("trn", 2014, True),
        ("trn", 2037, True),
        ("val", 2007, True),
        ("val", 2012, False),
        ("val", 2014, True),
        ("val", 2037, True),
        ("tst", 2007, True),
        ("tst", 2012, True),
        ("tst", 2014, True),
        ("tst", 2037, True),
    )
    for partition, seg_id, has_values in expectations:
        n_non_nan = get_num_non_nans(frames[partition], seg_id)
        if has_values:
            assert n_non_nan > 0
        else:
            assert n_non_nan == 0
Ejemplo n.º 7
0
def test_prep_data_no_scale_y():
    """With ``normalize_y=False`` the y scaling must be the identity.

    The test expects every entry of ``y_std`` to be 1 and every entry of
    ``y_mean`` to be 0 in the prepared data.
    """
    data = preproc_utils.prep_all_data(
        x_data_file="test_data/test_data",
        y_data_file="test_data/obs_temp_flow",
        pretrain_file="test_data/test_data",
        train_start_date="2003-09-15",
        train_end_date="2004-09-16",
        val_start_date="2004-09-17",
        val_end_date="2005-09-18",
        test_start_date="2005-09-19",
        test_end_date="2006-09-20",
        test_sites=[2012, 2037],
        spatial_idx_name="segs_test",
        normalize_y=False,
        time_idx_name="times_test",
        x_vars=["seg_rain", "seg_tave_air"],
        y_vars_finetune=["temp_c", "discharge_cms"],
        y_vars_pretrain=["seg_tave_water", "seg_outflow"],
    )

    # Make sure all the stds are 1 and the means are 0. Absolute deviations
    # are summed so that offsetting errors (e.g. stds of 0 and 2, or means of
    # -1 and 1) cannot cancel out and let the test pass by accident.
    assert abs(data['y_std'] - 1).sum() == 0
    assert abs(data['y_mean']).sum() == 0