def pca_component_plot(component_ix: int, rotation_matrix: np.ndarray,
                       stimXr: np.ndarray, respXr: np.ndarray,
                       stim_epochs: mne.Epochs, response_epochs: mne.Epochs):
    '''Plot one component's topography next to its two timecourses.

    The figure has three panels: the component's scalp map, then the
    mean +/- SEM of the component's stimulus-locked and response-locked
    timecourses (via `eegf.plot_mean_sem`).

    Pars:
    - component_ix: zero-based index of the component to plot (the panel
      title shows it one-based).
    - rotation_matrix: spatial weights; column `component_ix` is drawn as a
      topomap using `response_epochs.info`.
    - stimXr: Rotated stimulus-locked data; axis 1 indexes components
      (presumably (observation, component, time) - confirm against caller).
    - respXr: Rotated response-locked data, same layout as `stimXr`.
    - stim_epochs, response_epochs: The original epoch objects; only their
      `.times` (and `response_epochs.info`) are used.

    Returns the matplotlib Figure.
    '''
    fig, axes = plt.subplots(1, 3, figsize=(18, 3),
                             gridspec_kw={'width_ratios': [1, 4, 4]})
    topo_ax, *trace_axes = axes
    topomap(rotation_matrix[:, component_ix],
            response_epochs.info,
            axes=topo_ax)
    topo_ax.set_title('Component %i' % (component_ix + 1))
    panels = [
        (stimXr, stim_epochs.times, 'Time from stimulus'),
        (respXr, response_epochs.times, 'Time to response'),
    ]
    for ax, (X, times, xlab) in zip(trace_axes, panels):
        # eegf.plot_mean_sem draws on the current axes.
        plt.sca(ax)
        eegf.plot_mean_sem(X[:, component_ix] * million, times)
        # Dashed zero reference lines spanning the current axis limits.
        plt.hlines(0, *plt.xlim(), linestyle='dashed')
        plt.vlines(0, *plt.ylim(), linestyle='dashed')
        plt.xlabel(xlab)
        ax.set_yticklabels([])
    plt.tight_layout()
    return fig
def do_threeway_comparison(epochs,
                           ch,
                           by_subject=True,
                           agg_func=mean_by_subject,
                           crop=None,
                           baseline=None,
                           title=None,
                           ax=None,
                           legend=True,
                           neg_up=True):
    '''Plot mean +/- SEM traces for the Easy, Difficult and Guess conditions.

    Pars:
    - epochs: passed to `get_condition_epochs`; the subsets it returns are
      labelled, in order, 'Easy', 'Difficult', 'Guess'.
    - ch: channel index to plot. A two-element list plots the difference
      ch[0] - ch[1].
    - by_subject: if True, each condition is first reduced with `agg_func`;
      otherwise the raw epoch data are plotted.
    - agg_func: aggregation used when `by_subject` is True.
    - crop, baseline: forwarded to `get_condition_epochs`.
    - title: optional axes title (set only when not None).
    - ax: axes to draw on; defaults to the current axes.
    - legend: whether to draw a legend.
    - neg_up: if True, call `eegf.flipy()` once after plotting (presumably
      flips the y-axis so negative values point up - confirm in eegf).
    '''
    if ax is not None:
        plt.sca(ax)
    es = get_condition_epochs(epochs, crop=crop, baseline=baseline)
    labs = ['Easy', 'Difficult', 'Guess']
    for e, label in zip(es, labs):
        X = agg_func(e) if by_subject else e.get_data()
        if isinstance(ch, list) and len(ch) == 2:
            X = million * (X[:, ch[0]] - X[:, ch[1]])
        else:
            X = X[:, ch] * million
        eegf.plot_mean_sem(X, e.times, label=label)
    # Flip once, after every trace is drawn. Calling it per condition makes
    # the outcome depend on the number of conditions if flipy toggles.
    if neg_up:
        eegf.flipy()
    if title is not None:
        plt.title(title)
    if legend:
        plt.legend()
def do_twoway_comparison(epochs,
                         ch,
                         variable,
                         labels,
                         crop=None,
                         baseline=None,
                         title=None,
                         by_subject=True,
                         neg_up=True,
                         ax=None):
    '''Plot mean +/- SEM traces split by a binary metadata variable.

    Pars:
    - epochs: MNE epochs with `variable` in their metadata; they are copied,
      so the caller's object is not modified.
    - ch: channel index to plot.
    - variable: metadata column; epochs are split on `variable == 0` and
      `variable == 1`.
    - labels: legend labels for the 0 and 1 groups, in that order.
    - crop: optional (tmin, tmax) passed to `crop`.
    - baseline: optional baseline passed to `apply_baseline`.
    - title: optional axes title (set only when not None).
    - by_subject: if True, average with `mean_by_subject` before plotting.
    - neg_up: if True, call `eegf.flipy()` once after plotting (presumably
      flips the y-axis so negative values point up - confirm in eegf).
    - ax: optional axes to draw on; defaults to the current axes.
    '''
    if ax is not None:
        plt.sca(ax)
    E = epochs.copy()
    if crop is not None:
        E = E.crop(*crop)
    if baseline is not None:
        E = E.apply_baseline(baseline)
    es = [E['%s == 0' % variable], E['%s == 1' % variable]]
    for e, label in zip(es, labels):
        data = mean_by_subject(e) if by_subject else e.get_data()
        X = data[:, ch] * million
        eegf.plot_mean_sem(X, e.times, label=label)
    # Flip once, after both traces: with two conditions a per-condition call
    # to a toggling flip would cancel itself out.
    if neg_up:
        eegf.flipy()
    if title is not None:
        plt.title(title)
    plt.legend()
def do_rt_comparison(epochs, ch, drop_direction, bins=6, se=False):
    '''Plot channel activity averaged within reaction-time quantile bins.

    drop direction = 1 for trial-locked epochs, -1 for response-locked

    Pars:
    - epochs: MNE epochs whose metadata has an 'rt' column.
    - ch: channel index; assumed to be a single index so the selected data
      is (epoch, time) - a list would break the time masking below.
    - drop_direction: 1 blanks (NaN) samples after each bin's upper RT edge;
      any other value blanks samples before minus that edge.
      Assumes 'rt' is in seconds, on the same scale as the epoch times.
    - bins: number of quantile bins passed to `pd.qcut`.
    - se: if True, plot mean +/- SEM via `eegf.plot_mean_sem`; otherwise
      plot only the mean.
    '''
    qs = pd.qcut(epochs.metadata['rt'], bins)
    # Categories are the bin intervals in ascending order. Trials with a
    # missing RT get NaN from qcut and fall in no bin, so they are skipped.
    for q in qs.cat.categories:
        ix = qs == q
        if not ix.any():
            continue
        e = epochs[ix]
        lbl = '%.2f < RT < %.2f' % (q.left, q.right)
        X = e.get_data()[:, ch] * million
        if drop_direction == 1:
            ti = e.time_as_index(q.right)[0]
            if ti > 0:
                X[:, ti:] = np.nan
        else:
            ti = e.time_as_index(-q.right)[0]
            if ti > 0:
                X[:, :ti] = np.nan
        if se:
            eegf.plot_mean_sem(X, e.times, label=lbl)
        else:
            plt.plot(e.times, X.mean(0), label=lbl)
    plt.vlines(0, *plt.ylim(), linestyle='dashed')
def _project_onto_components(data, vectors):
    '''Project (epoch, channel, time) data onto (channel, component) vectors.

    Returns an (epoch, component, time) array; equivalent to stacking
    data[i].T.dot(vectors) per epoch and swapping the last two axes.
    '''
    # (component, channel) @ (epoch, channel, time) broadcasts over epochs.
    return np.matmul(vectors.T, data)


def _align_component_signs(vectors, resp_data, t0, t1):
    '''Return `vectors` with each column's sign set so its mean projection rises from t0 to t1.

    Pars:
    - vectors: (channel, component) spatial vectors.
    - resp_data: (epoch, channel, time) data used to judge the sign.
    - t0, t1: sample indices into the time axis of `resp_data`.

    Components whose mean projection is equal at t0 and t1 keep their sign,
    instead of being multiplied by np.sign(0) == 0 and zeroed out.
    '''
    mean_tc = _project_onto_components(resp_data, vectors).mean(0)
    signs = np.sign(mean_tc[:, t1] - mean_tc[:, t0])
    signs[signs == 0] = 1
    return vectors * signs


def _plot_component_topomaps(fig, gs, row, vectors, info, n, prefix):
    '''Draw the first `n` component topographies in grid row `row`, columns 2..n+1.'''
    for i in range(n):
        ax = fig.add_subplot(gs[row, 2 + i])
        topomap(vectors[:, i], info, axes=ax, show=False)
        plt.title('%s%i' % (prefix, i + 1))


def _plot_component_timecourses(ax, data, times, n, xlim, xlabel):
    '''Plot mean +/- SEM of the first `n` component timecourses on `ax`.'''
    plt.sca(ax)
    for i in range(n):
        eegf.plot_mean_sem(data[:, i], times, label='C %i' % (i + 1))
    plt.vlines(0, *plt.ylim(), linestyle='--')
    plt.legend(loc='upper left', prop={'size': 10})
    plt.xlim(*xlim)
    plt.xlabel(xlabel)


def do_pca_for_subject(trial_epochs_csd, response_epochs_csd, s, info):
    '''Build a PCA + varimax summary figure for one participant.

    A spatial PCA is fitted on participant `s`'s response-locked data,
    restricted to -2 s .. 0 s and the first 32 channels. The figure shows:
    - row 0: % variance explained and the first 9 PCA topographies;
    - row 1: mean +/- SEM PCA timecourses (trial-locked, response-locked);
    - row 2: topographies of the 9 varimax-rotated components;
    - row 3: mean +/- SEM varimax timecourses (trial-locked, response-locked).
    Each component's sign is chosen so its mean response-locked projection
    is larger at 0 s than at -2 s.

    Pars:
    - trial_epochs_csd, response_epochs_csd: epochs with a 'participant'
      metadata column; only participant `s` is used from either.
    - s: participant number.
    - info: measurement info passed to `topomap`.

    Returns the matplotlib Figure.
    '''
    from matplotlib.gridspec import GridSpec
    n = 9            # number of components plotted and rotated
    n_channels = 32  # only the first 32 channels enter the analysis
    scale = 1e6      # timecourse scaling applied before plotting
    print('# Participant %i' % s)
    subject_query = 'participant == %i' % s
    r_ep = response_epochs_csd[subject_query]
    fitX = r_ep.copy().crop(-2., 0).get_data()[:, :n_channels]
    respX = r_ep.get_data()[:, :n_channels]
    # Trial-locked data for this participant only (it previously fell back
    # to every participant's trials).
    trialX = trial_epochs_csd[subject_query].get_data()[:, :n_channels]
    fig = plt.figure(figsize=(20, 16))
    gs = GridSpec(nrows=4, ncols=12, figure=fig)
    fig.suptitle('Participant %i' % s)

    ## PCA of the across-epoch mean channel covariance
    cov = np.array([np.cov(x - x.mean()) for x in fitX]).mean(0)
    eig_vals, eig_vecs = np.linalg.eigh(cov)
    # eigh returns eigenvalues in ascending order; reverse to descending.
    eig_vals = eig_vals[::-1]
    eig_vecs = eig_vecs[:, ::-1]

    ## Variance explained (dashed line = mean % across components)
    ve = eig_vals / eig_vals.sum()
    plt.sca(fig.add_subplot(gs[0, 0:2]))
    plt.plot(list(range(1, len(ve) + 1)), ve * 100, '-o')
    plt.hlines(ve.mean() * 100, *plt.xlim(), linestyle='dashed')
    plt.ylabel('% variance explained')
    plt.xlabel('Component')
    plt.xticks(list(range(1, 32, 2)))
    plt.xlim(0, 12)

    ## PCA: sign correction, topographies, timecourses
    t0, t1 = response_epochs_csd.time_as_index([-2, 0])
    eig_vecs = _align_component_signs(eig_vecs, respX, t0, t1)
    _plot_component_topomaps(fig, gs, 0, eig_vecs, info, n, 'C')
    trialX_pca = _project_onto_components(trialX, eig_vecs) * scale
    respX_pca = _project_onto_components(respX, eig_vecs) * scale
    _plot_component_timecourses(fig.add_subplot(gs[1:2, :6]), trialX_pca,
                                trial_epochs_csd.times, n, (-.5, 2),
                                'Time from onset (s)')
    _plot_component_timecourses(fig.add_subplot(gs[1:2, 6:]), respX_pca,
                                response_epochs_csd.times, n, (-2, .5),
                                'Time to action (s)')

    ## Varimax rotation of the first n components
    # NOTE(review): loadings are commonly eigenvectors scaled by
    # sqrt(eigenvalue); here they are scaled by the eigenvalue - confirm.
    varimax_vectors = varimax(eig_vals[:n] * eig_vecs[:, :n],
                              method='varimax').T
    varimax_vectors = _align_component_signs(varimax_vectors, respX, t0, t1)
    _plot_component_topomaps(fig, gs, 2, varimax_vectors, info, n, 'VM')
    # Scaled like the PCA timecourses so both rows use the same units.
    trialX_vmax = _project_onto_components(trialX, varimax_vectors) * scale
    respX_vmax = _project_onto_components(respX, varimax_vectors) * scale
    _plot_component_timecourses(fig.add_subplot(gs[3, :6]), trialX_vmax,
                                trial_epochs_csd.times, n, (-.5, 2),
                                'Time from onset (s)')
    _plot_component_timecourses(fig.add_subplot(gs[3, 6:]), respX_vmax,
                                response_epochs_csd.times, n, (-2, .5),
                                'Time to action (s)')
    return fig