Example #1
0
def test_n_train_examples(n=500):
    """Train the CNN on ``n`` random training examples and sanity-check it.

    Checks that:
      * accuracy on the first ``n`` test examples exceeds 0.1,
      * setting ``cnn.loader`` to an unrecognised name makes ``predict``
        return ``MNIST_TEST_SIZE`` predictions (falls back to the test set),
      * ``predict_proba(idx=None, loader='test')`` returns a non-None result
        whose length is ``MNIST_TEST_SIZE``.

    Does nothing when ``python_version.is_compatible()`` is false.

    Args:
        n: Number of training examples to sample (without replacement), and
            number of leading test examples to evaluate accuracy on.
    """
    if not python_version.is_compatible():
        return

    cnn = CNN(epochs=3, log_interval=1000, loader='train', seed=0)
    # Sample *positions* into X_train rather than its values, so that
    # X_train[idx] is valid regardless of what X_train contains. A local
    # seeded generator keeps the chosen training subset reproducible.
    rng = np.random.RandomState(0)
    idx = rng.choice(len(X_train), n, replace=False)
    cnn.fit(train_idx=X_train[idx],
            train_labels=y_train[idx],
            loader='train')

    cnn.loader = 'test'
    pred = cnn.predict(X_test[:n])
    acc = accuracy_score(y_test[:n], pred)
    print(acc)
    assert acc > 0.1

    # Check that dataset defaults to test set when an invalid name is given.
    cnn.loader = 'INVALID'
    pred = cnn.predict(X_test[:n])
    assert len(pred) == MNIST_TEST_SIZE

    # Check that predict_proba runs on all examples when None is passed in.
    cnn.loader = 'test'
    proba = cnn.predict_proba(idx=None, loader='test')
    assert proba is not None
    # Assert on proba itself; checking `pred` here would only re-test the
    # INVALID-loader call above and never validate predict_proba's output.
    assert len(proba) == MNIST_TEST_SIZE
Example #2
0
def test_n_train_examples():
    """Train the CNN on the 'sklearn-digits' dataset and sanity-check it.

    Checks that:
      * accuracy on ``X_test_idx`` exceeds 0.1,
      * ``predict`` raises ``ValueError`` when ``cnn.loader`` is set to an
        unrecognised name,
      * ``predict_proba(idx=None, loader='test')`` returns a non-None result
        whose length is ``SKLEARN_DIGITS_TEST_SIZE``.

    Does nothing when ``python_version.is_compatible()`` is false.
    """
    if not python_version.is_compatible():
        return

    cnn = CNN(epochs=3, log_interval=1000, loader='train', seed=0,
              dataset='sklearn-digits', )
    cnn.fit(train_idx=X_train_idx, train_labels=y_train,
            loader='train', )

    cnn.loader = 'test'
    pred = cnn.predict(X_test_idx)
    acc = accuracy_score(y_test, pred)
    print(acc)
    assert acc > 0.1

    # Check that an exception is raised when an invalid loader name is given.
    cnn.loader = 'INVALID'
    with pytest.raises(ValueError):
        cnn.predict(X_test_idx)

    # Check that predict_proba runs on all examples when None is passed in.
    cnn.loader = 'test'
    proba = cnn.predict_proba(idx=None, loader='test')
    assert proba is not None
    # Assert on proba itself; `pred` still holds the earlier predict() output
    # and would not validate predict_proba's coverage.
    assert len(proba) == SKLEARN_DIGITS_TEST_SIZE