def test_normalize_off():
    """Batch a single range-valued frame with normalization disabled.

    One 160-row frame is built whose ``A`` and ``B`` columns are
    ``range(160)``. It is batched (16 rows per batch, no pseudo-stratify) into
    in-memory storage. For every stored batch, consecutive values of feature
    columns 0 (A) and 1 (B), taken at index 0 of the middle axis, must all
    differ.

    Note: ``np.diff(x).all()`` checks for non-zero differences, not for strictly
    increasing order.
    """
    n_rows = 160
    frame = pd.DataFrame(
        {
            "time": pd.to_datetime(list(range(n_rows)), unit="s"),
            "A": range(n_rows),
            "B": range(n_rows),
            "y": np.ones(n_rows),
        }
    )
    feature_df_list = [frame]

    storage = BatchStorageMemory(StorageMeta())
    translate = Translate(
        features=["A", "B"],
        look_back=0,
        look_forward=0,
        n_seconds=1,
        normalize=False,
    )
    batch_generator = Builder(
        storage, translate, batch_size=16, pseudo_stratify=False
    )
    batch_generator.generate_and_save_batches(feature_df_list)

    # Reaches into the private ``_data`` mapping of the memory storage.
    for stored in storage._data.values():
        features = stored["features"]
        # Last-axis index 0 is feature A and 1 is feature B (same order as
        # ``features=`` above). Presumably the middle axis is the time window
        # (length 1 with look_back/look_forward == 0) — TODO confirm.
        for column in (0, 1):
            assert np.diff(features[:, 0, column]).all()
def test_normalize_on():
    """Batch ten range-valued frames with normalization enabled.

    The input is five repetitions of a pair of 50-row frames:
    (A = 1..50, B = 101..150) followed by (A = 51..100, B = 151..200).
    After batching (10 rows per batch, ``normalize=True``), the fitted scaler
    means must be within 1 of 50 for A and 150 for B. Within every stored
    batch, consecutive values of feature columns 0 (A) and 1 (B) must all
    differ.

    Note: ``np.diff(x).all()`` checks for non-zero differences, not for strictly
    increasing order.
    """

    def _frame(a_start, b_start):
        # A fresh 50-row frame whose A and B columns are consecutive integers.
        return pd.DataFrame(
            {
                "time": pd.to_datetime(list(range(50)), unit="s"),
                "A": range(a_start, a_start + 50),
                "B": range(b_start, b_start + 50),
                "y": np.ones(50),
            }
        )

    # New DataFrame objects on every repetition, ordered pair-by-pair.
    feature_df_list = [
        _frame(a_start, b_start)
        for _ in range(5)
        for a_start, b_start in ((1, 101), (51, 151))
    ]

    storage = BatchStorageMemory(StorageMeta())
    translate = Translate(
        features=["A", "B"],
        look_back=0,
        look_forward=0,
        n_seconds=1,
        normalize=True,
        verbose=True,
    )
    batch_generator = Builder(
        storage, translate, batch_size=10, pseudo_stratify=False
    )
    batch_generator.generate_and_save_batches(feature_df_list)

    tools.assert_almost_equal(translate.scaler.mean_[0], 50, delta=1)
    tools.assert_almost_equal(translate.scaler.mean_[1], 150, delta=1)

    # Reaches into the private ``_data`` mapping of the memory storage.
    for stored in storage._data.values():
        features = stored["features"]
        # Last-axis index 0 is feature A and 1 is feature B.
        for column in (0, 1):
            assert np.diff(features[:, 0, column]).all()
def test_builder_stratify():
    """Pseudo-stratified batching splits batches evenly between train and validation.

    One 160-row frame of constant ones is batched 16 rows at a time with
    ``pseudo_stratify=True`` and ``stratify_nbatch_groupings=3``. The storage
    meta is configured with ``validation_split=0.5``. The builder must report
    stratification as active, and the meta must list 5 training ids and
    5 validation ids.
    """
    n_rows = 160
    ones = np.ones(n_rows)
    feature_df_list = [
        pd.DataFrame(
            {
                "time": pd.to_datetime(list(range(n_rows)), unit="s"),
                "A": np.ones(n_rows),
                "B": np.ones(n_rows),
                "y": ones,
            }
        )
    ]

    meta = StorageMeta(validation_split=0.5)
    storage = BatchStorageMemory(meta)
    translate = Translate(
        features=["A", "B"], look_back=0, look_forward=0, n_seconds=1
    )
    batch_generator = Builder(
        storage,
        translate,
        batch_size=16,
        stratify_nbatch_groupings=3,
        pseudo_stratify=True,
    )
    batch_generator.generate_and_save_batches(feature_df_list)

    assert batch_generator._stratify
    tools.eq_(len(meta.train.ids), 5)
    tools.eq_(len(meta.validation.ids), 5)
def test_save_and_load_meta():
    """Settings saved with ``save_meta`` are restored by ``load_meta``.

    A builder configured with ``batch_size=16`` and an un-normalized
    ``Translate`` generates batches and saves its meta. A second builder
    shares the same storage but is created with deliberately different
    settings (all 99s, ``normalize=True``). It then calls ``load_meta``. The
    reloaded batch size and ``Translate`` settings must match the originals.
    """
    feature_df_list = [
        pd.DataFrame(
            {
                "time": pd.to_datetime(list(range(160)), unit="s"),
                "A": range(160),
                "B": range(160),
                "y": np.ones(160),
            }
        )
    ]
    meta = StorageMeta()
    storage = BatchStorageMemory(meta)
    translate = Translate(
        features=["A", "B"],
        look_back=0,
        look_forward=0,
        n_seconds=1,
        normalize=False,
    )
    batch_generator = Builder(
        storage, translate, batch_size=16, pseudo_stratify=False
    )
    batch_generator.generate_and_save_batches(feature_df_list)
    batch_generator.save_meta()

    # Give the reload Translate its own name. Previously ``translate`` was
    # rebound here, so every comparison below checked an object against
    # itself and could never fail.
    translate_reload = Translate(
        features=["A", "B"],
        look_back=99,
        look_forward=99,
        n_seconds=99,
        normalize=True,
    )
    batch_generator_reload = Builder(
        storage, translate_reload, batch_size=99, pseudo_stratify=False
    )
    # NOTE(review): this presumes load_meta updates the Translate instance
    # passed to the Builder in place — confirm against Builder.load_meta.
    batch_generator_reload.load_meta()

    tools.eq_(batch_generator.batch_size, batch_generator_reload.batch_size)
    tools.eq_(translate._features, translate_reload._features)
    tools.eq_(translate._look_forward, translate_reload._look_forward)
    tools.eq_(translate._look_back, translate_reload._look_back)
    tools.eq_(translate._n_seconds, translate_reload._n_seconds)
    tools.eq_(translate._normalize, translate_reload._normalize)
def test_builder_storage_meta_validation():
    """A short frame with a windowed Translate yields one train and one validation id.

    One 35-row frame of constant ones is batched 16 rows at a time with
    ``look_back=2`` and ``look_forward=1``. The storage meta uses
    ``validation_split=0.5``. Afterwards the meta must list exactly one
    training id and one validation id.
    """
    row_count = 35
    frame = pd.DataFrame(
        {
            "time": pd.to_datetime(list(range(row_count)), unit="s"),
            "A": np.ones(row_count),
            "B": np.ones(row_count),
            "y": np.ones(row_count),
        }
    )

    meta = StorageMeta(validation_split=0.5)
    storage = BatchStorageMemory(meta)
    translate = Translate(
        features=["A", "B"], look_back=2, look_forward=1, n_seconds=1
    )
    batch_generator = Builder(storage, translate, batch_size=16)
    batch_generator.generate_and_save_batches([frame])

    tools.eq_(len(meta.train.ids), 1)
    tools.eq_(len(meta.validation.ids), 1)