def test_time_series_input_transform():
    """Check the rank of TimeseriesInputAdapter output for ndarray and DataFrame input.

    NOTE(review): a second ``test_time_series_input_transform`` is defined
    later in this module. Python rebinds the name, so pytest only collects
    the later definition and this body never runs. Its first half also
    expects ``row.ndim == 2`` for the same ``generate_data(shape=(32, ))``
    input where the later test expects ``row.ndim == 3``. Reconcile this
    against the adapter's real output, then rename or delete this test.
    """
    # ndarray input. The result is iterated with as_numpy_iterator(), so it
    # is presumably a tf.data.Dataset -- TODO confirm.
    x = utils.generate_data(shape=(32, ))
    input_node = input_adapter.TimeseriesInputAdapter(2)
    x = input_node.transform(x)
    for row in x.as_numpy_iterator():
        assert row.ndim == 2
    # DataFrame input: after fit_transform() the adapter exposes
    # column_names, whose first entry is expected to be 'sex'.
    (x, _), _1 = utils.dataframe_dataframe()
    input_node = input_adapter.TimeseriesInputAdapter(lookback=2)
    x = input_node.fit_transform(x)
    assert input_node.column_names[0] == 'sex'
    for row in x.as_numpy_iterator():
        assert row.ndim == 2
def test_time_series_input_restore_look_back():
    """An adapter rebuilt with from_config(get_config()) keeps its lookback."""
    config = input_adapter.TimeseriesInputAdapter(2).get_config()
    restored = input_adapter.TimeseriesInputAdapter.from_config(config)
    assert restored.lookback == 2
def test_time_series_input_col_type_without_name_csv():
    """Reject column_types without column_names for a string ndarray from the CSV.

    Renamed from ``test_time_series_input_col_type_without_name``. A later
    test in this module reuses that name, which shadowed this definition,
    so pytest never collected it.
    """
    # ``np.unicode`` was only a deprecated alias of the builtin ``str``
    # (deprecated in NumPy 1.20, removed later). ``str`` gives the same
    # dtype and works on every NumPy version.
    train_x = pd.read_csv(utils.TRAIN_CSV_PATH).to_numpy().astype(str)
    with pytest.raises(ValueError) as info:
        adapter = input_adapter.TimeseriesInputAdapter(
            lookback=2, column_types=utils.COLUMN_TYPES)
        adapter.transform(train_x)
    # Checked outside the ``with`` block: nothing after the raising call runs.
    assert str(info.value) == "Column names must be specified."
def test_time_series_input_less_col_name_numpy():
    """Reject too few column_names for ndarray input with a length error.

    Renamed from ``test_time_series_input_less_col_name``. A later
    definition in this module reuses that name, so pytest never collected
    this one.
    """
    (x, _), _unused = utils.dataframe_numpy()
    with pytest.raises(ValueError) as info:
        adapter = input_adapter.TimeseriesInputAdapter(
            lookback=2, column_names=utils.LESS_COLUMN_NAMES_FROM_CSV)
        adapter.transform(x)
    assert 'Expect column_names to have length' in str(info.value)
def test_time_series_input_name_type_mismatch_csv():
    """Reject column_types whose 'age' key was renamed to 'age_' for CSV input.

    Renamed from ``test_time_series_input_name_type_mismatch``. A later
    definition in this module reuses that name, so pytest never collected
    this one.
    """
    # Shallow-copy so the shared module-level dict in utils is not mutated.
    column_types = copy.copy(utils.COLUMN_TYPES)
    column_types["age_"] = column_types.pop("age")
    with pytest.raises(ValueError) as info:
        adapter = input_adapter.TimeseriesInputAdapter(
            lookback=2, column_types=column_types)
        adapter.transform(pd.read_csv(utils.TRAIN_CSV_PATH))
    assert "Column_names and column_types are mismatched." in str(info.value)
def test_time_series_input_col_type_without_name():
    """transform() raises when column_types is given but column_names is not."""
    structured = utils.generate_structured_data(500)
    with pytest.raises(ValueError) as excinfo:
        input_adapter.TimeseriesInputAdapter(
            lookback=2,
            column_types=utils.COLUMN_TYPES_FROM_NUMPY,
        ).transform(structured)
    assert str(excinfo.value) == 'Column names must be specified.'
def test_time_series_input_name_type_mismatch():
    """column_types with 'age' renamed to 'age_' is reported as mismatched."""
    (frame, _), _unused = utils.dataframe_dataframe()
    # Work on a shallow copy so the shared dict in utils stays untouched.
    mismatched = copy.copy(utils.COLUMN_TYPES_FROM_CSV)
    mismatched['age_'] = mismatched.pop('age')
    with pytest.raises(ValueError) as excinfo:
        input_adapter.TimeseriesInputAdapter(
            lookback=2,
            column_types=mismatched,
        ).transform(frame)
    expected = 'Column_names and column_types are mismatched.'
    assert expected in str(excinfo.value)
def test_time_series_input_transform():
    """Every element produced by transform() on generated ndarray data has rank 3."""
    raw = utils.generate_data(shape=(32, ))
    dataset = input_adapter.TimeseriesInputAdapter(2).transform(raw)
    assert all(element.ndim == 3 for element in dataset.as_numpy_iterator())
def test_time_series_input_transform_df_to_dataset():
    """fit_transform() on a pandas DataFrame returns a tf.data.Dataset."""
    adapter = input_adapter.TimeseriesInputAdapter(2)
    frame = pd.DataFrame(utils.generate_data(shape=(32, )))
    result = adapter.fit_transform(frame)
    assert isinstance(result, tf.data.Dataset)
def test_time_series_input_less_col_name():
    """Supplying two fewer column_names than COLUMN_NAMES raises a length error."""
    too_few = utils.COLUMN_NAMES[:-2]
    with pytest.raises(ValueError) as excinfo:
        adapter = input_adapter.TimeseriesInputAdapter(
            lookback=2,
            column_names=too_few,
        )
        adapter.transform(pd.read_csv(utils.TRAIN_CSV_PATH))
    assert "Expect column_names to have length" in str(excinfo.value)
def test_time_series_input_with_illegal_dim():
    """Data from generate_data(shape=(32, 32)) is rejected with a dimension error."""
    data = utils.generate_data(shape=(32, 32))
    adapter = input_adapter.TimeseriesInputAdapter(2)
    with pytest.raises(ValueError) as excinfo:
        adapter.transform(data)
    message = str(excinfo.value)
    assert "Expect the data in TimeseriesInput to have 2" in message
def test_time_series_input_type_error():
    """A plain string passed to transform() is rejected with a TypeError."""
    adapter = input_adapter.TimeseriesInputAdapter(2)
    with pytest.raises(TypeError) as excinfo:
        adapter.transform("unknown")
    message = str(excinfo.value)
    assert "Expect the data in TimeseriesInput to be numpy" in message