Example #1
0
def check_other_classifiers(train_X, train_y, test_X, test_y):
    """Benchmark Riemannian classifiers and plot the class mean covariances.

    Covariance matrices are estimated for the training and test epochs, then:

    * a ``TSclassifier`` is scored with 10-fold cross-validation on the
      training covariances and, after fitting on all of them, on the test
      covariances;
    * an ``MDM`` classifier (Riemannian mean and distance) is scored with
      the same cross-validation, and a default ``MDM`` is fitted on all
      training covariances;
    * the first two class mean covariance matrices of that fitted ``MDM``
      are drawn as heatmaps, saved to ``meancovmat.png`` and shown.

    Parameters
    ----------
    train_X, test_X : array-like
        Epoch data passed straight to ``Covariances().transform``.
    train_y, test_y : iterable of array-like
        One row per epoch; the index of the first element equal to ``1`` in
        each row is used as the class label (e.g. one-hot encoded labels).
        A row without any ``1`` raises ``IndexError``.

    Returns
    -------
    None
        Results are printed and plotted.
    """
    from pyriemann.classification import MDM, TSclassifier
    from pyriemann.estimation import Covariances
    import seaborn as sns
    import pandas as pd

    # Collapse each label row to the index of its first 1.
    train_y = [np.where(row == 1)[0][0] for row in train_y]
    test_y = [np.where(row == 1)[0][0] for row in test_y]

    cov_data_train = Covariances().transform(train_X)
    cov_data_test = Covariances().transform(test_X)

    # random_state only has an effect when shuffling is enabled; recent
    # scikit-learn releases reject a seed combined with shuffle=False.
    cv = KFold(n_splits=10, shuffle=True, random_state=42)

    # Tangent-space classifier: cross-validated, then held-out score.
    clf = TSclassifier()
    scores = cross_val_score(clf, cov_data_train, train_y, cv=cv, n_jobs=1)
    print("Tangent space Classification accuracy: ", np.mean(scores))

    clf = TSclassifier()
    clf.fit(cov_data_train, train_y)
    print(clf.score(cov_data_test, test_y))

    # Minimum-distance-to-mean classifier: cross-validated score.
    mdm = MDM(metric=dict(mean='riemann', distance='riemann'))
    scores = cross_val_score(mdm, cov_data_train, train_y, cv=cv, n_jobs=1)
    print("MDM Classification accuracy: ", np.mean(scores))

    # Refit on the full training set to obtain the class mean covariances.
    mdm = MDM()
    mdm.fit(cov_data_train, train_y)

    # Label the heatmap axes with channel indices taken from the data
    # instead of assuming a fixed channel count.
    n_channels = cov_data_train.shape[-1]
    ch_names = list(range(n_channels))

    _, axes = plt.subplots(1, 2)
    titles = ('Mean covariance - feet', 'Mean covariance - hands')
    for class_idx, (ax, title) in enumerate(zip(axes, titles)):
        df = pd.DataFrame(data=mdm.covmeans_[class_idx],
                          index=ch_names,
                          columns=ch_names)
        sns.heatmap(df,
                    ax=ax,
                    square=True,
                    cbar=False,
                    xticklabels=2,
                    yticklabels=2)
        ax.set_title(title)

        # plt.xticks / plt.yticks act on the current axes, so each subplot
        # has to be made current before its tick labels are rotated.
        plt.sca(ax)
        plt.xticks(rotation='vertical')
        plt.yticks(rotation='horizontal')

    plt.savefig("meancovmat.png")
    plt.show()
    def fit(self, X, y):
        """Search band-pass / epoch-window settings and keep the best models.

        For every combination of ``self.epoch_bounds`` and
        ``self.bandpass_filters`` the data is split 75/25 (fixed seed),
        band-pass filtered and cropped through ``self.preprocess_X``, and
        three models are fitted on the training part and scored on the
        held-out part:

        * CSP (5 components, Ledoit-Wolf regularisation) followed by LDA,
        * MDM on covariance matrices,
        * ``TSclassifier`` on covariance matrices.

        For each model family at most ``self.num_votes`` best-scoring
        entries are retained in the corresponding ``self.ranked_*`` lists.
        The lists are maintained with ``bisect`` and are therefore in
        ascending score order, provided they were empty or sorted on entry.
        A progress report with the current rankings is printed after every
        configuration.

        Parameters
        ----------
        X : array-like
            Input data, validated with ``check_X_y(..., allow_nd=True)``.
        y : array-like
            Target labels, one per sample of ``X``.

        Returns
        -------
        self
            The estimator, with ``classes_``, ``X_``, ``y_``, ``fit_`` and
            the ``ranked_*`` lists updated.
        """
        # check_X_y validates X and y together, so no separate
        # check_array pass over X is needed afterwards.
        X, y = check_X_y(X, y, allow_nd=True)

        # set internal vars
        self.classes_ = unique_labels(y)
        self.X_ = X
        self.y_ = y

        def keep_top(score, scores, *companions):
            """Insert a result into parallel rankings and trim them.

            ``scores`` is the ascending list of scores; each companion is a
            ``(list, item)`` pair whose item is inserted at the same index.
            Afterwards every list longer than ``self.num_votes`` drops its
            first (lowest-scoring) entry.
            """
            idx = bisect(scores, score)
            scores.insert(idx, score)
            for items, item in companions:
                items.insert(idx, item)
            for ranked in [scores] + [items for items, _ in companions]:
                if len(ranked) > self.num_votes:
                    ranked.pop(0)

        def print_ranking(scores, opts):
            """Print one ``index , score , settings`` line per ranked entry."""
            for i, (ranked_score, ranked_opts) in enumerate(zip(scores, opts)):
                print("%s , %s , %s" % (i, round(ranked_score, 4),
                                        ranked_opts))

        # NOTE: every print below passes a single pre-formatted string so
        # that the output is the same under Python 2 and Python 3.

        seed = 7
        np.random.seed(seed)

        for epoch_trim in self.epoch_bounds:
            for bandpass in self.bandpass_filters:

                # The split uses a fixed random_state, so every
                # configuration is evaluated on the same partition.
                # NOTE(review): kept inside the loop because it is not
                # visible here whether preprocess_X modifies its input.
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=0.25, random_state=seed)

                # separate out inputs that are tuples
                bandpass_start, bandpass_end = bandpass
                epoch_trim_start, epoch_trim_end = epoch_trim

                # 5th-order band-pass filter coefficients; the cut-offs are
                # divided by half the sampling rate (self.sfreq * 0.5).
                b, a = butter(
                    5,
                    np.array([bandpass_start, bandpass_end]) /
                    (self.sfreq * 0.5), 'bandpass')

                # filter and crop TRAINING SET, then re-validate
                X_train = self.preprocess_X(X_train, b, a, epoch_trim_start,
                                            epoch_trim_end)
                X_train, y_train = check_X_y(X_train, y_train, allow_nd=True)

                # filter and crop TEST SET, then re-validate
                X_test = self.preprocess_X(X_test, b, a, epoch_trim_start,
                                           epoch_trim_end)
                X_test, y_test = check_X_y(X_test, y_test, allow_nd=True)

                # NOTE: the number of CSP filters is hard-coded; the
                # self-tuning step that used to choose it is disabled.
                best_num_filters = 5

                # CSP `reg` : float | str | None
                #   if not None, allow regularization for covariance
                #   estimation; a float is a shrinkage value in [0, 1], a
                #   str selects optimal shrinkage ('ledoit_wolf' or 'oas').
                transformer = CSP(n_components=best_num_filters,
                                  reg='ledoit_wolf')
                transformer.fit(X_train, y_train)

                # use these CSP spatial filters to transform train and test
                spatial_filters_train = transformer.transform(X_train)
                spatial_filters_test = transformer.transform(X_test)

                # NOTE: if NaN or inf values start appearing in the CSP
                # output, np.nan_to_num on both feature sets is a failsafe.

                # train LDA on the CSP features
                classifier = LinearDiscriminantAnalysis()
                classifier.fit(spatial_filters_train, y_train)
                score = classifier.score(spatial_filters_test, y_test)

                # Report the settings just evaluated. This statement was
                # garbled in the source and has been reconstructed.
                print("bandpass: %s %s epoch window: %s %s" %
                      (bandpass_start, bandpass_end, epoch_trim_start,
                       epoch_trim_end))

                # rank the CSP + LDA result
                keep_top(score, self.ranked_scores,
                         (self.ranked_scores_opts,
                          dict(bandpass=bandpass,
                               epoch_trim=epoch_trim,
                               filters=best_num_filters)),
                         (self.ranked_classifiers, classifier),
                         (self.ranked_transformers, transformer))

                # compute covariance matrices for the Riemannian classifiers
                cov_data_train = covariances(X=X_train)
                cov_data_test = covariances(X=X_test)

                # minimum distance to mean
                clf_mdm = MDM(metric=dict(mean='riemann', distance='riemann'))
                clf_mdm.fit(cov_data_train, y_train)
                score_mdm = clf_mdm.score(cov_data_test, y_test)
                keep_top(score_mdm, self.ranked_scores_mdm,
                         (self.ranked_scores_opts_mdm,
                          dict(bandpass=bandpass,
                               epoch_trim=epoch_trim,
                               filters=best_num_filters)),
                         (self.ranked_classifiers_mdm, clf_mdm))

                # tangent-space classifier
                clf_ts = TSclassifier()
                clf_ts.fit(cov_data_train, y_train)
                score_ts = clf_ts.score(cov_data_test, y_test)
                keep_top(score_ts, self.ranked_scores_ts,
                         (self.ranked_scores_opts_ts,
                          dict(bandpass=bandpass,
                               epoch_trim=epoch_trim,
                               filters=best_num_filters)),
                         (self.ranked_classifiers_ts, clf_ts))

                print("CSP+LDA score: %s Tangent space w/LR score: %s" %
                      (score, score_ts))

                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
                print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
                print("    T O P   %s   C L A S S I F I E R S" %
                      (self.num_votes,))
                print("")
                print_ranking(self.ranked_scores, self.ranked_scores_opts)
                print("-------------------------------------")
                print_ranking(self.ranked_scores_ts,
                              self.ranked_scores_opts_ts)
                print("-------------------------------------")
                print_ranking(self.ranked_scores_mdm,
                              self.ranked_scores_opts_mdm)

        # finish up, set the flag to indicate "fitted" state
        self.fit_ = True

        # Return the classifier
        return self