예제 #1
0
def test_structured_data_unknown_str_in_col_type(tmp_dir):
    """Check that an unsupported column-type string is rejected on construction.

    The classifier is built with ``utils.FALSE_COLUMN_TYPES_FROM_CSV``
    (presumably a mapping containing an invalid type name -- confirm in
    ``utils``) and must raise ``ValueError`` whose message mentions the
    allowed column types.

    Args:
        tmp_dir: Directory passed to the classifier as ``directory``.
    """
    expected_fragment = 'Column_types should be either "categorical"'

    with pytest.raises(ValueError) as exc_info:
        structured_data.StructuredDataClassifier(
            column_types=utils.FALSE_COLUMN_TYPES_FROM_CSV,
            directory=tmp_dir,
            max_trials=1,
            seed=utils.SEED,
        )

    # Substring check on purpose: only the start of the message is pinned.
    assert expected_fragment in str(exc_info.value)
예제 #2
0
def test_structured_data_col_name_type_mismatch(tmp_dir):
    """Check that inconsistent column names and column types are rejected.

    The classifier receives ``utils.COLUMN_NAMES_FROM_NUMPY`` together with
    ``utils.COLUMN_TYPES_FROM_CSV`` and must raise ``ValueError`` reporting
    the mismatch.

    Args:
        tmp_dir: Directory passed to the classifier as ``directory``.
    """
    expected_fragment = 'Column_names and column_types are mismatched.'

    with pytest.raises(ValueError) as exc_info:
        structured_data.StructuredDataClassifier(
            column_names=utils.COLUMN_NAMES_FROM_NUMPY,
            column_types=utils.COLUMN_TYPES_FROM_CSV,
            directory=tmp_dir,
            max_trials=1,
            seed=utils.SEED,
        )

    assert expected_fragment in str(exc_info.value)
예제 #3
0
def test_structured_data_classifier_from_csv(init, fit, tmp_dir):
    """Fit a classifier from file paths and inspect what reaches ``fit``.

    Training and validation data are given as ``utils.TRAIN_FILE_PATH`` and
    ``utils.TEST_FILE_PATH`` (presumably CSV paths -- confirm in ``utils``)
    with the label column named ``'survived'``. The test then verifies that
    the first recorded ``fit`` call received a pandas DataFrame as ``x`` and
    a NumPy array as ``y``.

    Args:
        init: Fixture object exposing ``called`` (presumably a mock).
        fit: Fixture object exposing ``call_args_list`` (presumably a mock).
        tmp_dir: Directory passed to the classifier as ``directory``.
    """
    label_column = 'survived'
    classifier = structured_data.StructuredDataClassifier(
        directory=tmp_dir,
        max_trials=1,
        seed=utils.SEED,
    )

    classifier.fit(
        x=utils.TRAIN_FILE_PATH,
        y=label_column,
        epochs=2,
        validation_data=(utils.TEST_FILE_PATH, label_column),
    )

    assert init.called
    # Only the keyword arguments of the first recorded call are inspected.
    _positional, fit_kwargs = fit.call_args_list[0]
    assert isinstance(fit_kwargs['x'], pandas.DataFrame)
    assert isinstance(fit_kwargs['y'], np.ndarray)
예제 #4
0
def test_structured_classifier(init, fit, tmp_dir):
    """Fit a classifier on generated in-memory data and check the hooks fired.

    Generates 500 samples via ``utils.generate_structured_data`` with
    three-class labels from ``utils.generate_one_hot_labels``, fits for two
    epochs using the training data as validation data too, and asserts that
    both the ``init`` and ``fit`` fixtures were called.

    Args:
        init: Fixture object exposing ``called`` (presumably a mock).
        fit: Fixture object exposing ``called`` (presumably a mock).
        tmp_dir: Directory passed to the classifier as ``directory``.
    """
    sample_count = 500
    features = utils.generate_structured_data(sample_count)
    labels = utils.generate_one_hot_labels(
        num_instances=sample_count,
        num_classes=3,
    )

    classifier = structured_data.StructuredDataClassifier(
        column_names=utils.COLUMN_NAMES_FROM_NUMPY,
        directory=tmp_dir,
        max_trials=1,
        seed=utils.SEED,
    )
    # The same arrays serve as validation data.
    validation_pair = (features, labels)
    classifier.fit(features, labels, epochs=2, validation_data=validation_pair)

    assert init.called
    assert fit.called