コード例 #1
0
def run_binary():
    """Train and evaluate the binary classifiers on every demo dataset.

    For each dataset (synthetic, two-moon, binarized MNIST) and for each
    loss type ("perceptron", "logistic"), a model is trained with
    ``sol.binary_train`` and the train/test accuracy is printed.

    Relies on ``np``, ``sol`` and ``accuracy_score`` being available at
    module level.
    """
    from data_loader import toy_data_binary, \
                            moon_dataset, \
                            data_loader_mnist

    datasets = [(toy_data_binary(), 'Synthetic data'),
                (moon_dataset(), 'Two Moon data'),
                (data_loader_mnist(), 'Binarized MNIST data')]

    for data, name in datasets:
        print(name)
        # Use the dataset of the current iteration; previously the synthetic
        # data was reloaded here, so every dataset name reported results for
        # the same synthetic split.
        X_train, X_test, y_train, y_test = data

        if name == 'Binarized MNIST data':
            # Collapse the digit labels into two classes: <5 -> 0, else 1.
            y_train = np.asarray([0 if yi < 5 else 1 for yi in y_train])
            y_test = np.asarray([0 if yi < 5 else 1 for yi in y_test])

        for loss_type in ["perceptron", "logistic"]:
            # Train with the loss being evaluated (was hard-coded to
            # "logistic", making the "perceptron" line misleading).
            w, b = sol.binary_train(X_train, y_train, loss=loss_type)
            train_preds = sol.binary_predict(X_train, w, b, loss=loss_type)
            preds = sol.binary_predict(X_test, w, b, loss=loss_type)
            print(loss_type + ' train acc: %f, test acc: %f' %
                  (accuracy_score(y_train, train_preds),
                   accuracy_score(y_test, preds)))
        print()
コード例 #2
0
def run_binary():
    """Run the binary classifier on synthetic data and on binarized MNIST.

    For each dataset a model is fitted with ``binary_train``, predictions
    are made on the train and test splits with ``binary_predict``, and both
    accuracies are printed. MNIST labels are reduced to two classes first
    (labels below 5 become 0, all others become 1).
    """
    from data_loader import toy_data_binary, \
                            data_loader_mnist

    def fit_and_report(train_x, test_x, train_y, test_y):
        # Fit once, predict on both splits, then print the two accuracies.
        weights, bias = binary_train(train_x, train_y)
        fitted = binary_predict(train_x, weights, bias)
        held_out = binary_predict(test_x, weights, bias)
        train_acc = accuracy_score(train_y, fitted)
        test_acc = accuracy_score(test_y, held_out)
        print('train acc: %f, test acc: %f' % (train_acc, test_acc))

    def to_two_classes(labels):
        # Plain Python list of 0/1 labels, one per input label.
        return [0 if label < 5 else 1 for label in labels]

    print('Performing binary classification on synthetic data')
    syn_train_x, syn_test_x, syn_train_y, syn_test_y = toy_data_binary()
    fit_and_report(syn_train_x, syn_test_x, syn_train_y, syn_test_y)

    print('Performing binary classification on binarized MNIST')
    mnist_train_x, mnist_test_x, digits_train, digits_test = data_loader_mnist()
    two_class_train = to_two_classes(digits_train)
    two_class_test = to_two_classes(digits_test)
    fit_and_report(mnist_train_x, mnist_test_x, two_class_train, two_class_test)