def xdawn_embedding(data, use_xdawn):
    """Embed the test trials of an EEG data set in a low-dimensional space.

    Two routes are available:

    * ``use_xdawn`` truthy: Xdawn covariance matrices (``'scm'`` estimator,
      3 filters) are fitted on the training split, the test split is
      transformed, and the resulting matrices are embedded with
      ``Embedding`` (Riemannian metric, 3 components).
    * ``use_xdawn`` falsy: covariances (``'lwf'`` estimator) are mapped to
      the Riemannian tangent space, again fitted on the training split, and
      the test features are reduced with UMAP. The UMAP output
      dimensionality is left at the library default.

    Parameters
    ----------
    data : dict
        A dictionary containing training and testing data. The keys
        ``'train_x'``, ``'train_y'`` and ``'test_x'`` are read.
    use_xdawn : bool
        Selects the Xdawn + ``Embedding`` route when truthy, the
        tangent-space + UMAP route otherwise.

    Returns
    -------
    array
        Embedded coordinates of the test trials.

    """
    if not use_xdawn:
        # Tangent-space features, then a non-linear reduction with UMAP.
        feature_pipeline = Pipeline([
            ('cov_transform', Covariances(estimator='lwf')),
            ('tangent_space', TangentSpace(metric='riemann')),
        ])
        fitted_pipeline = feature_pipeline.fit(data['train_x'],
                                               data['train_y'])
        test_features = fitted_pipeline.transform(data['test_x'])
        projector = umap.UMAP(n_neighbors=30, min_dist=1, spread=2)
        return projector.fit_transform(test_features)

    # Xdawn covariances: filters come from the training split only.
    xdawn_covs = XdawnCovariances(estimator='scm', nfilter=3)
    fitted_xdawn = xdawn_covs.fit(data['train_x'], data['train_y'])
    test_covs = fitted_xdawn.transform(data['test_x'])

    embedder = Embedding(metric='riemann', n_components=3)
    return embedder.fit_transform(test_covs)
def get_sourcetarget_split_p300(source, target, ncovs_train):
    """Build source, target-train and target-test covariance sets.

    The source trials are converted to Xdawn covariances in one go. The
    target trials are shuffled, split into a training and a testing subset,
    and converted with an Xdawn estimator fitted on the training subset
    only.

    Parameters
    ----------
    source : dict
        Must provide ``'signals'`` and ``'labels'``; ``'labels'`` must
        support ``.flatten()``. The dictionary passed in is not modified.
    target : dict
        Same layout as ``source``.
    ncovs_train : int or None
        Number of target trials with label 2 to put in the training subset.
        Up to five times as many trials with label 1 are added to it. When
        ``None``, every label-2 trial of the target is used.

    Returns
    -------
    source : dict
        ``'covs'`` and ``'labels'`` of the source data.
    target_train : dict
        ``'covs'`` and ``'labels'`` of the target training subset.
    target_test : dict
        ``'covs'`` and ``'labels'`` of every target trial that is not in
        the training subset.

    Notes
    -----
    The shuffle uses NumPy's global random state (``np.random.shuffle``),
    so results are reproducible only if the caller seeds it.

    """
    X_source = source['signals']
    y_source = source['labels'].flatten()
    covs_source = XdawnCovariances(classes=[2]).fit_transform(
        X_source, y_source)

    # Rebind the local name to a fresh dict; the caller's dict is untouched.
    source = {'covs': covs_source, 'labels': y_source}

    X_target = target['signals']
    y_target = target['labels'].flatten()

    if ncovs_train is None:
        ncovs_train = np.sum(y_target == 2)

    # Shuffle trials and labels with the same permutation.
    sel = np.arange(len(y_target))
    np.random.shuffle(sel)
    X_target = X_target[sel]
    y_target = y_target[sel]

    idx_erps = np.where(y_target == 2)[0][:ncovs_train]
    # because there's one ERP in every 6 flashes
    idx_rest = np.where(y_target == 1)[0][:ncovs_train * 5]

    idx_train = np.concatenate([idx_erps, idx_rest])

    # Complement of the training indices, in ascending order. A boolean mask
    # avoids a Python-level membership scan per trial and keeps an integer
    # dtype even when no trial is left for testing.
    in_train = np.zeros(len(y_target), dtype=bool)
    in_train[idx_train] = True
    idx_test = np.flatnonzero(~in_train)

    # Xdawn filters are estimated on the target training subset only.
    erp = XdawnCovariances(classes=[2])
    erp.fit(X_target[idx_train], y_target[idx_train])

    target_train = {
        'covs': erp.transform(X_target[idx_train]),
        'labels': y_target[idx_train],
    }
    target_test = {
        'covs': erp.transform(X_target[idx_test]),
        'labels': y_target[idx_test],
    }

    return source, target_train, target_test
def xdawn_embedding(data):
    """Embed the test trials of an EEG data set with Laplacian Eigenmaps.

    Xdawn covariance matrices (``'scm'`` estimator, 3 filters) are fitted on
    the training split and computed for the test split; the test matrices
    are then embedded with ``Embedding`` using the Riemannian metric and
    three components.

    Parameters
    ----------
    data : dict
        A dictionary containing training and testing data. The keys
        ``'train_x'``, ``'train_y'`` and ``'test_x'`` are read.

    Returns
    -------
    array
        Embedded coordinates of the test trials.

    """
    # Filters come from the training split only.
    xdawn_covs = XdawnCovariances(estimator='scm', nfilter=3)
    fitted_xdawn = xdawn_covs.fit(data['train_x'], data['train_y'])
    test_covs = fitted_xdawn.transform(data['test_x'])

    embedder = Embedding(metric='riemann', n_components=3)
    return embedder.fit_transform(test_covs)
# Cross-validated multiclass classification: Xdawn covariances followed by an
# MDM classifier, then plots of the spatial patterns and confusion matrix.
print('Multiclass classification with XDAWN + MDM')

clf = make_pipeline(XdawnCovariances(n_components), MDM())

# Refit the pipeline on each training fold and store the predictions for the
# held-out fold in `pr`, so `pr` ends up holding out-of-fold predictions.
for train_idx, test_idx in cv.split(epochs_data):
    # y_test is assigned but not used below.
    y_train, y_test = labels[train_idx], labels[test_idx]

    clf.fit(epochs_data[train_idx], y_train)
    pr[test_idx] = clf.predict(epochs_data[test_idx])

print(classification_report(labels, pr))

###############################################################################
# plot the spatial patterns
# The estimator is refitted on all epochs for visualisation only.
xd = XdawnCovariances(n_components)
xd.fit(epochs_data, labels)

# Reuse the existing `evoked` object as a container for topomap plotting:
# its data become the transposed Xdawn patterns and its "times" become plain
# integer indices, one per pattern.
evoked.data = xd.Xd_.patterns_.T
evoked.times = np.arange(evoked.data.shape[0])
# Indices 0, n, 2n, 3n: presumably the first pattern of each of the four
# classes - TODO confirm the ordering of `patterns_`.
evoked.plot_topomap(
    times=[0, n_components, 2 * n_components, 3 * n_components],
    ch_type='grad',
    colorbar=False,
    size=1.5)

###############################################################################
# plot the confusion matrix
names = ['audio left', 'audio right', 'vis left', 'vis right']
plot_confusion_matrix(labels, pr, names)
plt.show()
# Cross-validated multiclass classification: Xdawn covariances followed by an
# MDM classifier, then plots of the spatial patterns and confusion matrix.
epochs_data = epochs.get_data()

print("Multiclass classification with XDAWN + MDM")

clf = make_pipeline(XdawnCovariances(n_components), MDM())

# Refit the pipeline on each training fold and store the predictions for the
# held-out fold in `pr`.
# NOTE(review): `cv` is iterated directly, so it must already yield
# (train, test) index pairs - confirm against where `cv` is built.
for train_idx, test_idx in cv:
    # y_test is assigned but not used below.
    y_train, y_test = labels[train_idx], labels[test_idx]

    clf.fit(epochs_data[train_idx], y_train)
    pr[test_idx] = clf.predict(epochs_data[test_idx])

print(classification_report(labels, pr))

###############################################################################
# plot the spatial patterns
# The estimator is refitted on all epochs for visualisation only.
xd = XdawnCovariances(n_components)
xd.fit(epochs_data, labels)

# Reuse the existing `evoked` object as a container for topomap plotting:
# its data become the transposed Xdawn patterns and its "times" become plain
# integer indices, one per pattern.
# NOTE(review): the fitted attribute is read as `xd.Xd`; verify this name
# against the installed pyriemann version.
evoked.data = xd.Xd.patterns_.T
evoked.times = np.arange(evoked.data.shape[0])
# Indices 0, n, 2n, 3n: presumably the first pattern of each of the four
# classes - TODO confirm the ordering of `patterns_`.
evoked.plot_topomap(
    times=[0, n_components, 2 * n_components, 3 * n_components], ch_type="grad", colorbar=False, size=1.5
)

###############################################################################
# plot the confusion matrix
names = ["audio left", "audio right", "vis left", "vis right"]
plot_confusion_matrix(labels, pr, names)
plt.show()
# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
                    picks=picks, baseline=None, preload=True, verbose=False)

# Epoch data and the event code of each epoch (last column of the events).
X = epochs.get_data()
y = epochs.events[:, -1]

###############################################################################
# Embedding the Xdawn covariance matrices with Laplacian Eigenmaps

nfilter = 4
xdwn = XdawnCovariances(estimator='scm', nfilter=nfilter)
# 25% of the epochs are used to fit the Xdawn filters; the covariances of the
# remaining epochs are the ones that get embedded.
split = train_test_split(X, y, train_size=0.25, random_state=42)
Xtrain, Xtest, ytrain, ytest = split
covs = xdwn.fit(Xtrain, ytrain).transform(Xtest)

lapl = Embedding(metric='riemann', n_components=2)
embd = lapl.fit_transform(covs)

###############################################################################
# Plot the two first components of the embedded points

fig, ax = plt.subplots(figsize=(7, 8), facecolor='white')

# One scatter series per experimental condition, selected by event code.
for cond, label in event_id.items():
    idx = (ytest == label)
    ax.scatter(embd[idx, 0], embd[idx, 1], s=36, label=cond)

ax.set_xlabel(r'$\varphi_1$', fontsize=16)
ax.set_ylabel(r'$\varphi_2$', fontsize=16)
        # Load the data of one subject for the current condition and encode
        # the labels as integers.
        X, y, meta = paradigm.get_data(datasets[condition], subjects=[subject])
        y = LabelEncoder().fit_transform(y)

        data[condition]['X'] = X
        data[condition]['y'] = y

        # estimate xDawn covs
        ncomps = 4
        erp = XdawnCovariances(classes=[1],
                               estimator='lwf',
                               nfilter=ncomps,
                               xdawn_estimator='lwf')
        #erp = ERPCovariances(classes=[1], estimator='lwf', svd=ncomps)
        # Fit on one half of the trials, compute covariances on the other.
        split = train_test_split(X, y, train_size=0.50, random_state=42)
        Xtrain, Xtest, ytrain, ytest = split
        covs = erp.fit(Xtrain, ytrain).transform(Xtest)

        # Riemannian mean of each class on the held-out half; presumably
        # encoded label 1 is the target class and 0 the non-target class -
        # TODO confirm the LabelEncoder ordering.
        Mtarget = mean_riemann(covs[ytest == 1])
        Mnontarget = mean_riemann(covs[ytest == 0])
        # Distance between the two class means.
        stats[condition]['distance'] = distance_riemann(Mtarget, Mnontarget)
        # Dispersion of a class: mean squared Riemannian distance of its
        # matrices to the class mean.
        stats[condition]['dispersion_target'] = np.sum(
            [distance_riemann(covi, Mtarget)**2
             for covi in covs[ytest == 1]]) / len(covs[ytest == 1])
        stats[condition]['dispersion_nontarget'] = np.sum([
            distance_riemann(covi, Mnontarget)**2 for covi in covs[ytest == 0]
        ]) / len(covs[ytest == 0])

    # Report the statistics gathered for this subject and store them under
    # the subject key.
    print('subject', subject)
    print(stats)

    results[subject] = stats
                    picks=picks,
                    baseline=None,
                    preload=True,
                    verbose=False)

# Epoch data and the event code of each epoch (last column of the events).
X = epochs.get_data()
y = epochs.events[:, -1]

###############################################################################
# Embedding the Xdawn covariance matrices with Laplacian Eigenmaps

nfilter = 4
xdwn = XdawnCovariances(estimator='scm', nfilter=nfilter)
# 25% of the epochs are used to fit the Xdawn filters; the covariances of the
# remaining epochs are the ones that get embedded.
split = train_test_split(X, y, train_size=0.25, random_state=42)
Xtrain, Xtest, ytrain, ytest = split
covs = xdwn.fit(Xtrain, ytrain).transform(Xtest)

lapl = Embedding(metric='riemann', n_components=2)
embd = lapl.fit_transform(covs)

###############################################################################
# Plot the two first components of the embedded points

fig, ax = plt.subplots(figsize=(7, 8), facecolor='white')

# One scatter series per experimental condition, selected by event code.
for cond, label in event_id.items():
    idx = (ytest == label)
    ax.scatter(embd[idx, 0], embd[idx, 1], s=36, label=cond)

ax.set_xlabel(r'$\varphi_1$', fontsize=16)
ax.set_ylabel(r'$\varphi_2$', fontsize=16)