def __assert_wcqt_slicer(dataset, t_len, *slicer_args):
    """Check the shape of the first ten ``x_in`` slices from ``wcqt_slices``.

    The slicer is built from the first row of ``dataset.to_df()`` and is
    seeded with ``RANDOM_SEED``.

    Parameters
    ----------
    dataset
        Object exposing ``to_df()``; only its first row is sliced.
    t_len
        Passed to the slicer, and the size every slice must have on axis 2.
    *slicer_args
        Extra positional arguments forwarded to ``streams.wcqt_slices``.
    """
    first_record = dataset.to_df().iloc[0]
    slicer = streams.wcqt_slices(
        first_record, t_len, *slicer_args, random_seed=RANDOM_SEED)

    for _ in range(10):
        shape = next(slicer)['x_in'].shape
        # Slices must be 4-dimensional.
        assert len(shape) == 4
        # Axis 1 must lie strictly between 1 and 10.
        assert 1 < shape[1] < 10
        # Axis 2 must match the requested length.
        assert shape[2] == t_len
def __assert_wcqt_slicer_predict(dataset, t_len, *slicer_args):
    """Check that the slicer yields exactly one well-shaped slice, then stops.

    The slicer is built from the first row of ``dataset.to_df()`` and is
    seeded with ``RANDOM_SEED``. Callers are expected to pass, through
    ``slicer_args``, whatever makes the slicer finite (presumably a
    prediction-mode option -- confirm against ``streams.wcqt_slices``).

    Parameters
    ----------
    dataset
        Object exposing ``to_df()``; only its first row is sliced.
    t_len
        Passed to the slicer, and the size the slice must have on axis 2.
    *slicer_args
        Extra positional arguments forwarded to ``streams.wcqt_slices``.
    """
    first_record = dataset.to_df().iloc[0]
    slicer = streams.wcqt_slices(
        first_record, t_len, *slicer_args, random_seed=RANDOM_SEED)

    # The first pull must succeed and have the expected shape.
    shape = next(slicer)['x_in'].shape
    assert len(shape) == 4
    assert 1 < shape[1] < 10
    assert shape[2] == t_len

    # A second pull must exhaust the generator.
    with pytest.raises(StopIteration):
        next(slicer)['x_in']
def test_wcqt_slicer_with_data_less_tlen(generated_data):
    """The first batch from ``wcqt_slices`` has size ``t_len`` on axis 2.

    ``generated_data`` is unpacked into a dataframe and a target length.
    Judging by the test name, the fixture presumably supplies data shorter
    than ``t_len`` -- confirm against the fixture definition.
    """
    frame, expected_len = generated_data
    first_batch = next(streams.wcqt_slices(frame.iloc[0], expected_len))
    assert first_batch['x_in'].shape[2] == expected_len