def test_time_series_input_transform():
    """Check that TimeseriesInputAdapter yields 2-D elements for array and DataFrame input."""
    # Array input, with lookback passed positionally.
    raw = utils.generate_data(shape=(32, ))
    dataset = input_adapter.TimeseriesInputAdapter(2).transform(raw)
    for element in dataset.as_numpy_iterator():
        assert element.ndim == 2

    # DataFrame input: fit_transform should also record the column names.
    (frame, _), _unused = utils.dataframe_dataframe()
    adapter = input_adapter.TimeseriesInputAdapter(lookback=2)
    dataset = adapter.fit_transform(frame)
    assert adapter.column_names[0] == 'sex'
    for element in dataset.as_numpy_iterator():
        assert element.ndim == 2
def test_time_series_input_restore_look_back():
    """Check that a get_config / from_config round trip preserves lookback."""
    original = input_adapter.TimeseriesInputAdapter(2)
    config = original.get_config()
    restored = input_adapter.TimeseriesInputAdapter.from_config(config)
    assert restored.lookback == 2
def test_time_series_input_col_type_without_name():
    """Check that column_types without column_names raises ValueError."""
    # ``np.unicode`` was only a deprecated alias of the builtin ``str``
    # (deprecated in NumPy 1.20, unavailable in recent releases); passing
    # ``str`` produces the same unicode-string array.
    train_x = pd.read_csv(utils.TRAIN_CSV_PATH).to_numpy().astype(str)
    with pytest.raises(ValueError) as info:
        adapter = input_adapter.TimeseriesInputAdapter(
            lookback=2, column_types=utils.COLUMN_TYPES)
        adapter.transform(train_x)
    assert str(info.value) == "Column names must be specified."
# Example #4
# 0
def test_time_series_input_less_col_name():
    """Check that column_names of the wrong length raises ValueError for numpy input."""
    (features, _), _targets = utils.dataframe_numpy()
    names = utils.LESS_COLUMN_NAMES_FROM_CSV
    with pytest.raises(ValueError) as info:
        adapter = input_adapter.TimeseriesInputAdapter(
            lookback=2, column_names=names)
        adapter.transform(features)
    assert 'Expect column_names to have length' in str(info.value)
def test_time_series_input_name_type_mismatch():
    """Check that column_types keyed by a renamed column raises ValueError."""
    # Shallow-copy so the pop below does not mutate utils.COLUMN_TYPES.
    column_types = copy.copy(utils.COLUMN_TYPES)
    # Re-key the 'age' entry as 'age_' to provoke a name/type mismatch.
    column_types["age_"] = column_types.pop("age")
    with pytest.raises(ValueError) as info:
        input_adapter.TimeseriesInputAdapter(
            lookback=2, column_types=column_types,
        ).transform(pd.read_csv(utils.TRAIN_CSV_PATH))
    assert "Column_names and column_types are mismatched." in str(info.value)
# Example #6
# 0
def test_time_series_input_col_type_without_name():
    """Check that column_types without column_names raises ValueError on structured data."""
    structured = utils.generate_structured_data(500)
    with pytest.raises(ValueError) as info:
        adapter = input_adapter.TimeseriesInputAdapter(
            lookback=2, column_types=utils.COLUMN_TYPES_FROM_NUMPY)
        adapter.transform(structured)
    assert str(info.value) == 'Column names must be specified.'
# Example #7
# 0
def test_time_series_input_name_type_mismatch():
    """Check that column_types keyed by a renamed column raises ValueError for a DataFrame."""
    (frame, _), _targets = utils.dataframe_dataframe()
    # Shallow-copy so the pop below does not mutate the shared dict.
    types = copy.copy(utils.COLUMN_TYPES_FROM_CSV)
    # Re-key the 'age' entry as 'age_' to provoke a name/type mismatch.
    types['age_'] = types.pop('age')
    with pytest.raises(ValueError) as info:
        input_adapter.TimeseriesInputAdapter(
            lookback=2, column_types=types,
        ).transform(frame)
    assert 'Column_names and column_types are mismatched.' in str(info.value)
def test_time_series_input_transform():
    """Check that transforming array data yields 3-D elements."""
    raw = utils.generate_data(shape=(32, ))
    adapter = input_adapter.TimeseriesInputAdapter(2)
    dataset = adapter.transform(raw)
    assert all(element.ndim == 3 for element in dataset.as_numpy_iterator())
def test_time_series_input_transform_df_to_dataset():
    """Check that fit_transform on a DataFrame returns a tf.data.Dataset."""
    adapter = input_adapter.TimeseriesInputAdapter(2)
    frame = pd.DataFrame(utils.generate_data(shape=(32, )))
    result = adapter.fit_transform(frame)
    assert isinstance(result, tf.data.Dataset)
def test_time_series_input_less_col_name():
    """Check that dropping the last two column_names raises ValueError on the CSV."""
    truncated = utils.COLUMN_NAMES[:-2]
    with pytest.raises(ValueError) as info:
        adapter = input_adapter.TimeseriesInputAdapter(
            lookback=2, column_names=truncated)
        adapter.transform(pd.read_csv(utils.TRAIN_CSV_PATH))
    assert "Expect column_names to have length" in str(info.value)
def test_time_series_input_with_illegal_dim():
    """Check that data of shape (32, 32) is rejected with ValueError."""
    data = utils.generate_data(shape=(32, 32))
    adapter = input_adapter.TimeseriesInputAdapter(2)
    with pytest.raises(ValueError) as info:
        adapter.transform(data)
    assert "Expect the data in TimeseriesInput to have 2" in str(info.value)
def test_time_series_input_type_error():
    """Check that a plain string input is rejected with TypeError."""
    adapter = input_adapter.TimeseriesInputAdapter(2)
    with pytest.raises(TypeError) as info:
        adapter.transform("unknown")
    assert "Expect the data in TimeseriesInput to be numpy" in str(info.value)