class wrapper_KNN(machine_learning_method):
    """Wrapper for pyriemann's ``KNearestNeighbor`` classifier.

    The neighbour count and distance metric are taken from
    ``method_args['n_neighbors']`` and ``method_args['metric']``.
    NOTE(review): assumes the ``machine_learning_method`` base class stores
    ``method_args`` on ``self`` -- confirm against the base class.
    """

    def __init__(self, method_name, method_args):
        super(wrapper_KNN, self).__init__(method_name, method_args)
        # Start with a single-job classifier; see set_parallel to change it.
        self.init_method()

    def init_method(self, n_jobs=1):
        """(Re)create the wrapped classifier, discarding any fitted state.

        Parameters
        ----------
        n_jobs : int, default 1
            Job count forwarded to ``KNearestNeighbor``.
        """
        self.classifier = KNearestNeighbor(
            n_neighbors=self.method_args['n_neighbors'],
            metric=self.method_args['metric'],
            n_jobs=n_jobs)

    def set_parallel(self, is_parallel=False, n_jobs=8):
        """Enable or disable parallel execution and rebuild the classifier.

        The classifier is always recreated, so it must be fitted again after
        this call.

        Parameters
        ----------
        is_parallel : bool, default False
            When true the classifier uses ``n_jobs`` jobs; otherwise it uses 1.
        n_jobs : int, default 8
            Job count used when ``is_parallel`` is true.
        """
        logging.warning(
            'The call to this set_parallel method is reseting the class, and must be fitted again'
        )
        self.parallel = is_parallel
        self.n_jobs = n_jobs

        # Rebuild unconditionally: otherwise turning parallelism off after
        # turning it on would keep the old multi-job classifier, contradicting
        # the warning logged above.
        self.init_method(n_jobs if is_parallel else 1)

    def fit(self, X, y):
        """Fit the wrapped classifier and return whatever its ``fit`` returns."""
        return self.classifier.fit(X, y)

    def predict(self, X):
        """Return the wrapped classifier's predictions for ``X``."""
        return self.classifier.predict(X)
def test_KNN():
    """1-NN with the Riemannian metric must recover the training labels."""
    per_class = 10
    matrices = generate_cov(3 * per_class, 3)
    expected = np.repeat(np.array([0, 1, 2]), per_class)

    clf = KNearestNeighbor(n_neighbors=1, metric='riemann')
    clf.fit(matrices, expected)
    predicted = clf.predict(matrices)
    assert_array_equal(expected, predicted)
Ejemplo n.º 3
0
def test_KNN():
    """Fit a 1-nearest-neighbour classifier and predict on its own training set."""
    covset = generate_cov(30, 3)
    labels = np.repeat([0, 1, 2], 10)

    estimator = KNearestNeighbor(1, metric='riemann')
    estimator.fit(covset, labels)
    assert_array_equal(labels, estimator.predict(covset))
Ejemplo n.º 4
0
def test_1NN(get_covmats, get_labels):
    """With K=1, predicting on the training data must return its labels."""
    n_trials = 9
    n_channels = 3
    n_classes = 3

    matrices = get_covmats(n_trials, n_channels)
    targets = get_labels(n_trials, n_classes)

    model = KNearestNeighbor(n_neighbors=1, metric="riemann")
    model.fit(matrices, targets)
    assert_array_equal(targets, model.predict(matrices))
Ejemplo n.º 5
0
    logging.info('Doing training')
    clf_knn_k_fold = []  # Container of classifers trained on each fold
    clf_mdm_k_fold = []  # Container of classifers trained on each fold
    accuracy_list_training_knn = []
    accuracy_list_training_mdm = []
    i = 1
    for train_index, test_index in kf.split(X_train):

        logging.info(f'Doing fold {i}')
        clf_knn = KNearestNeighbor(n_neighbors, metric, n_jobs)
        clf_mdm = MDM(metric, n_jobs)
        X_train_fold, X_test_fold = X_train[train_index], X_train[test_index]
        y_train_fold, y_test_fold = y_train[train_index], y_train[test_index]

        clf_knn.fit(X_train_fold, y_train_fold)
        y_predicted = clf_knn.predict(X_test_fold)
        accuracy = (y_test_fold == y_predicted).sum() / len(y_test_fold)
        clf_knn_k_fold.append(clf_knn)
        accuracy_list_training_knn.append(accuracy)

        clf_mdm.fit(X_train_fold, y_train_fold)
        y_predicted = clf_mdm.predict(X_test_fold)
        accuracy = (y_test_fold == y_predicted).sum() / len(y_test_fold)
        clf_mdm_k_fold.append(clf_mdm)
        accuracy_list_training_mdm.append(accuracy)

        i += 1

    # Testing on test dataset
    logging.info('Doing testing')
    accuracy_list_testing_knn = []