Beispiel #1
0
def plot_model_weights(model: BaseSequence2Sequence, ch_names=None):
    """Plot the weights of a fitted model as an ERP-style figure.

    The model's factored weights (``model.W_`` together with ``model.R_``) are
    expanded with ``factored2full`` and handed to ``plot_erp``, labelled with
    the model's ``evtlabs`` and shifted by its ``offset``.

    Args:
        model: a fitted model exposing ``W_``, ``R_``, ``evtlabs`` and ``offset``.
        ch_names: optional channel names forwarded to ``plot_erp``.
    """
    from updateSummaryStatistics import plot_erp
    from scoreStimulus import factored2full

    full_weights = factored2full(model.W_, model.R_)
    plot_erp(
        full_weights,
        ch_names=ch_names,
        evtlabs=model.evtlabs,
        offset=model.offset,
    )
def testcase():
    """Visual end-to-end check of MultiCCA on the twente two-finger dataset.

    Loads ``S00.mat``. The file comes from a hard-coded Windows or Linux
    location, or from the path given as the first command-line argument. The
    function then:

    * prints the data shapes and the sample rate;
    * plots summary statistics and the ERP;
    * fits a cross-validated MultiCCA model and runs ``decodingCurveSupervised``
      on its predictions;
    * plots the fitted model and the cumulative output scores for up to 25 trials.
    """
    import sys

    # Raw strings: in a normal literal the '\t' in '\twente' and '\twofinger'
    # is a TAB character, which silently corrupted the Windows path.
    if os.path.isdir(r'D:\external_data'):
        sessfn = r'D:\external_data\twente\twofinger\S00.mat'
    else:
        sessfn = '/home/jadref/data/bci/external_data/twente/twofinger/S00.mat'
    # command-line, for testing
    if len(sys.argv) > 1:
        sessfn = sys.argv[1]

    from load_twofinger import load_twofinger
    oX, oY, coords = load_twofinger(sessfn, ofs=60, nsubtrials=40)
    fs = coords[1]['fs']
    ch_names = coords[2]['coords']
    X = oX.copy()
    Y = oY.copy()

    print("X({}){}".format([c['name'] for c in coords], X.shape))
    print("Y={}".format(Y.shape))
    print("fs={}".format(fs))

    # response length in samples (700 ms); an integer count, as in the other
    # testcases, rather than the float fs*.7
    tau = int(fs * .7)
    evtlabs = None
    times = np.arange(tau) / fs
    rank = 1

    # visualize the dataset
    from updateSummaryStatistics import updateSummaryStatistics, plot_erp, plot_summary_statistics
    import matplotlib.pyplot as plt

    # summary statistics for the first output only (Y[..., 0:1, :])
    Cxx, Cxy, Cyy = updateSummaryStatistics(X, Y[..., 0:1, :], tau=tau)

    plt.figure(1)
    print("summary stats")
    plot_summary_statistics(Cxx, Cxy, Cyy, evtlabs, times, ch_names)

    plt.figure(2)
    print("ERP")
    plot_erp(Cxy, ch_names=ch_names, evtlabs=evtlabs, times=times, plottype='plot', axis=-1)

    from model_fitting import MultiCCA
    from decodingCurveSupervised import decodingCurveSupervised
    cca = MultiCCA(tau=tau, evtlabs=evtlabs, rank=rank)
    scores = cca.cv_fit(X, Y)
    Fy = scores['estimator']
    print("Fy={}".format(Fy.shape))
    decodingCurveSupervised(Fy)

    # plot the solution
    from scoreStimulus import factored2full
    print("Plot Model")
    plt.figure(3)
    plot_erp(factored2full(cca.W_, cca.R_), ch_names=ch_names, evtlabs=evtlabs, times=times)
    # plot the cumulative Fy for (at most) the first 25 trials in a 5x5 grid
    plt.figure(4)
    for ti in range(min(Fy.shape[0], 25)):
        plt.subplot(5, 5, ti + 1)
        plt.imshow(np.cumsum(Fy[ti, :, :], axis=-2), aspect='auto')
    plt.show()
def debug_test_dataset(X, Y, coords=None, tau_ms=300, fs=None, offset_ms=0, evtlabs=('re', 'fe'), rank=1, model='cca', cv=True, preproc_args=None, **kwargs):
    """Visualize a dataset, fit a model to it and plot diagnostics of the fit.

    Figures produced:

    * 10: raw X/Y for one trial.
    * 11: summary statistics.
    * 19: grand-average spectrum.
    * 12: ERP.
    * 14: decoding curve.
    * 15: full model.
    * 18: factored model, only when ``clsfr.R_`` is not None.
    * 16: per-event scores Fe.
    * 17: cumulative Fy.
    * 20 and 21: normalized Fy.

    Args:
        X: data array. It is indexed as ``X[trial, :, :]`` and ``welch`` runs
           over axis -2, so this presumably is (trial, sample, channel).
           TODO confirm.
        Y: stimulus array. If 4-d it is used as already event-coded; otherwise
           it is converted with ``stim2event(Y, axis=-2, evtypes=evtlabs)``.
        coords: optional list of per-axis coordinate dicts. ``coords[1]['fs']``
           gives the sample rate, ``coords[2]['coords']`` the channel names, and
           ``coords[2]['pos2d']`` optional channel positions.
        tau_ms: response length in milliseconds.
        fs: sample rate, used only when ``coords`` is None.
        offset_ms: response offset in milliseconds.
        evtlabs: event types for ``stim2event`` and the model.
        rank: model rank, forwarded to ``analyse_dataset``.
        model: model name, forwarded to ``analyse_dataset``.
        cv: forwarded to ``analyse_dataset``.
        preproc_args: if not None, keyword arguments for ``preprocess``.
        **kwargs: further arguments forwarded to ``analyse_dataset``.

    Returns:
        The fitted classifier returned by ``analyse_dataset``.

    Raises:
        ValueError: if both ``coords`` and ``fs`` are None.
    """
    fs = coords[1]['fs'] if coords is not None else fs
    if fs is None:
        raise ValueError("fs must be given when coords is None")
    tau = int(fs * tau_ms / 1000)
    offset = int(offset_ms * fs / 1000)
    times = np.arange(offset, tau + offset) / fs

    if coords is not None:
        print("X({}){}".format([c['name'] for c in coords], X.shape))
    else:
        print("X={}".format(X.shape))
    print("Y={}".format(Y.shape))
    print("fs={}".format(fs))

    if preproc_args is not None:
        X, Y, coords = preprocess(X, Y, coords, **preproc_args)

    # channel names / 2d positions for the topographic plots, when available
    ch_names = coords[2]['coords'] if coords is not None else None
    ch_pos = None
    if coords is not None and 'pos2d' in coords[2]:
        ch_pos = coords[2]['pos2d']
    elif ch_names is not None and len(ch_names) > 0:
        from readCapInf import getPosInfo
        _, xy, _, _ = getPosInfo(ch_names)
        ch_pos = xy
    if ch_pos is not None:
        print('ch_pos={}'.format(ch_pos.shape))

    # visualize the dataset
    from stim2event import stim2event
    from updateSummaryStatistics import updateSummaryStatistics, plot_erp, plot_summary_statistics
    from scoreStimulus import factored2full
    import matplotlib.pyplot as plt

    print("Plot X+Y")
    # example trial: the 4th, clamped for datasets with fewer trials
    trli = min(3, X.shape[0] - 1)
    plt.figure(10); plt.clf()
    plt.subplot(211)
    # (no plt.legend(): an imshow has no labelled artists, so it only warned)
    plt.imshow(X[trli, :, :].T, aspect='auto'); plt.colorbar(); plt.title('X'); plt.xlabel('time (samp)')
    plt.subplot(212)
    if Y.ndim == 3:
        plt.imshow(Y[trli, :, :].T, aspect='auto')
        plt.xlabel('time (samp)')
        plt.ylabel('target')
    else:
        plt.plot(Y[trli, :, 0, :])
    plt.title('Y')
    plt.show()

    print("Plot summary stats")
    if Y.ndim == 4:  # already transformed
        Yevt = Y
    else:  # convert to event
        Yevt = stim2event(Y, axis=-2, evtypes=evtlabs)
    Cxx, Cxy, Cyy = updateSummaryStatistics(X, Yevt[..., 0:1, :], tau=tau)
    plt.figure(11); plt.clf()
    plot_summary_statistics(Cxx, Cxy, Cyy, evtlabs, times, ch_names)
    plt.show()

    print('Plot global spectral properties')
    from scipy.signal import welch
    # nperseg must be an integer; fs//2 is a float when fs is a float
    freqs, FX = welch(X, axis=-2, fs=fs, nperseg=int(fs // 2), return_onesided=True, detrend=False)
    print('FX={}'.format(FX.shape))
    # own figure number: figure 18 is reused below for the factored model
    plt.figure(19); plt.clf()
    muFX = np.median(FX, axis=0, keepdims=True)
    # y-limit: twice the median (over channels) of each channel's peak power
    ylim = (0, 2 * float(np.median(np.max(muFX, axis=-2))))
    plot_erp(muFX, ch_names=ch_names, evtlabs=None, times=freqs, ylim=ylim)
    plt.suptitle("Grand average spectrum")
    plt.show()

    print("Plot ERP")
    plt.figure(12); plt.clf()
    plot_erp(Cxy, ch_names=ch_names, evtlabs=evtlabs, times=times)
    plt.suptitle("ERP")
    plt.show()

    # fit the model
    _, res, Fy, clsfr = analyse_dataset(X, Y, coords, model, evtlabs=evtlabs, cv=cv, tau_ms=tau_ms, fs=fs, offset_ms=offset_ms, rank=rank, **kwargs)

    plt.figure(14)
    plot_decoding_curve(*res)
    plt.suptitle("Decoding Curve")

    print("Plot Model")
    plt.figure(15); plt.clf()
    # prefer the forward model (A_) when the classifier provides one
    if hasattr(clsfr, 'A_'):
        plt.suptitle("fwd-model")
        plot_erp(factored2full(clsfr.A_, clsfr.R_), ch_names=ch_names, evtlabs=evtlabs, times=times)
    else:
        plt.suptitle("bwd-model")
        plot_erp(factored2full(clsfr.W_, clsfr.R_), ch_names=ch_names, evtlabs=evtlabs, times=times)
    plt.show()

    if clsfr.R_ is not None:
        print("Plot Factored Model")
        plt.figure(18); plt.clf()
        if hasattr(clsfr, 'A_'):
            plt.suptitle("fwd-model")
            plot_factoredmodel(clsfr.A_, clsfr.R_, ch_names=ch_names, ch_pos=ch_pos, evtlabs=evtlabs, times=times)
        else:
            plt.suptitle("bwd-model")
            plot_factoredmodel(clsfr.W_, clsfr.R_, ch_names=ch_names, ch_pos=ch_pos, evtlabs=evtlabs, times=times)
        plt.show()

    print("plot Fe")
    plt.figure(16); plt.clf()
    Fe = clsfr.transform(X)
    plot_Fe(Fe)
    plt.suptitle("Fe")
    plt.show()

    print("plot Fy")
    plt.figure(17); plt.clf()
    plot_Fy(Fy, cumsum=True)
    plt.suptitle("Fy")
    plt.show()

    from normalizeOutputScores import normalizeOutputScores, plot_normalizedScores
    print("normalized Fy")
    plt.figure(20); plt.clf()
    # normalize every sample
    ssFy, scale_sFy, decisIdx, nEp, nY = normalizeOutputScores(Fy, minDecisLen=-1)
    plot_Fy(ssFy, cumsum=False)
    plt.suptitle("normalized_Fy")
    plt.show()

    plt.figure(21)
    # example trial (5th); clamped so datasets with < 5 trials don't IndexError
    tri = min(4, Fy.shape[0] - 1)
    plot_normalizedScores(Fy[tri, :, :], ssFy[tri, :, :], scale_sFy[tri, :], decisIdx)

    return clsfr
Beispiel #4
0
def testcase(dataset='toy', loader_args=None):
    """Exercise several sequence-to-sequence models on a dataset and print their decoding curves.

    The dataset is loaded with ``get_dataset(dataset)``. The function then
    fits and evaluates, in order:

    * MultiCCA, both directly and with ``cv_fit``;
    * the forward and backward linear-regression models;
    * sklearn-wrapped Ridge, LogisticRegression and LinearSVC;
    * a GridSearchCV hyper-parameter search over MultiCCA, whose best setting
      is refitted with ``cv_fit``.

    It also slices the data with ``sliceData`` / ``sliceY``. The CCA and SVC
    weights are saved to ``W_cca.png`` and ``W_svc.png``.

    Args:
        dataset: dataset name understood by ``datasets.get_dataset``.
        loader_args: extra keyword arguments for the dataset loader. For
            ``'toy'`` they override the small built-in defaults; before,
            they were silently ignored.
    """
    from model_fitting import MultiCCA, FwdLinearRegression, BwdLinearRegression, LinearSklearn
    from decodingCurveSupervised import decodingCurveSupervised
    from datasets import get_dataset
    from updateSummaryStatistics import plot_erp
    from scoreStimulus import factored2full
    import matplotlib.pyplot as plt

    # copy so neither a shared default nor the caller's dict is ever mutated
    loader_args = dict(loader_args) if loader_args is not None else {}

    loadfn, filenames, dataroot = get_dataset(dataset)
    if dataset == 'toy':
        toy_defaults = dict(tau=10, isi=5, noise2signal=3, nTrl=20, nSamp=50)
        loader_args = {**toy_defaults, **loader_args}
    X, Y, coords = loadfn(filenames[0], **loader_args)
    fs = coords[1]['fs']
    ch_names = coords[2]['coords']

    Y = Y[..., 0]  # (nsamp, nY)

    # raw
    tau = int(.3 * fs)
    rank = 1
    evtlabs = ('re', 'fe', 'rest')  #  ('re', 'ntre')#  ('re', 'ntre', 'rest')#  ('1', '0')#
    cca = MultiCCA(tau=tau, evtlabs=evtlabs)
    print('cca = {}'.format(cca))
    cca.fit(X, Y)
    Fy = cca.predict(X, Y, dedup0=True)
    print("score={}".format(cca.score(X, Y)))
    decodingCurveSupervised(Fy)

    # cca - cv-fit
    print("CV fitted")
    cca = MultiCCA(tau=tau, rank=1, reg=None, evtlabs=evtlabs)
    cv_res = cca.cv_fit(X, Y)
    Fy = cv_res['estimator']
    decodingCurveSupervised(Fy, priorsigma=(cca.sigma0_, cca.priorweight))

    # cca
    cca = MultiCCA(tau=tau, rank=rank, evtlabs=evtlabs, reg=None)
    print("{}".format(cca))
    cca.fit(X, Y)
    Fy = cca.predict(X, Y, dedup0=True)
    decodingCurveSupervised(Fy)

    plot_erp(factored2full(cca.W_, cca.R_), ch_names=ch_names, evtlabs=evtlabs)
    plt.savefig('W_cca.png')

    # fwd-model
    print("Forward Model")
    fwd = FwdLinearRegression(tau=tau, evtlabs=evtlabs, reg=None, badEpThresh=4)
    print("{}".format(fwd))
    fwd.fit(X, Y)
    Fy = fwd.predict(X, Y, dedup0=True)
    decodingCurveSupervised(Fy)

    # bwd-model
    print("Backward Model")
    bwd = BwdLinearRegression(tau=tau, evtlabs=evtlabs, reg=None, badEpThresh=4)
    print("{}".format(bwd))
    bwd.fit(X, Y)
    Fy = bwd.predict(X, Y, dedup0=True)
    decodingCurveSupervised(Fy)

    Py = bwd.predict_proba(X, Y, dedup0=True)
    visualize_Fy_Py(Fy, Py)

    # sklearn wrapper
    from sklearn.linear_model import Ridge, LogisticRegression
    from sklearn.svm import LinearSVC
    print("sklearn-ridge")
    ridge = LinearSklearn(tau=tau, evtlabs=evtlabs, clsfr=Ridge(alpha=0))
    print("{}".format(ridge))
    ridge.fit(X, Y)
    Fy = ridge.predict(X, Y, dedup0=True)
    decodingCurveSupervised(Fy)

    print("sklear-lr")
    # NOTE(review): multi_class= is deprecated in recent scikit-learn releases
    lr = LinearSklearn(tau=tau,
                       evtlabs=evtlabs,
                       clsfr=LogisticRegression(C=1,
                                                multi_class='multinomial',
                                                solver='sag'),
                       labelizeY=True)
    print("{}".format(lr))
    lr.fit(X, Y)
    Fy = lr.predict(X, Y, dedup0=True)
    decodingCurveSupervised(Fy)

    print("sklear-svc")
    svc = LinearSklearn(tau=tau,
                        evtlabs=evtlabs,
                        clsfr=LinearSVC(C=1, multi_class='ovr'),
                        labelizeY=True)
    print("{}".format(svc))
    svc.fit(X, Y)
    Fy = svc.predict(X, Y, dedup0=True)
    decodingCurveSupervised(Fy)

    plot_erp(factored2full(svc.W_, svc.R_), ch_names=ch_names, evtlabs=evtlabs)
    plt.savefig('W_svc.png')

    # hyper-parameter optimization with cross-validation
    from sklearn.model_selection import GridSearchCV
    tuned_parameters = {
        'rank': [1, 2, 3, 5],
        'tau': [int(dur * fs) for dur in [.2, .3, .5, .7]],
        'evtlabs': [['re', 'fe'], ['re', 'ntre'], ['0', '1']]
    }
    cv_cca = GridSearchCV(MultiCCA(), tuned_parameters)
    cv_cca.fit(X, Y)
    print("CVOPT:\n\n{} = {}\n".format(cv_cca.best_estimator_,
                                       cv_cca.best_score_))
    means = cv_cca.cv_results_['mean_test_score']
    stds = cv_cca.cv_results_['std_test_score']
    for mean, std, params in zip(means, stds, cv_cca.cv_results_['params']):
        print("{:5.3f} (+/-{:5.3f}) for {}".format(mean, std * 2, params))
    print()

    # use the best setting to fit a normal model, and get its cv estimated predictions
    cca = MultiCCA()
    cca.set_params(**cv_cca.best_params_)  # Note: **dict -> k,v argument array
    cv_res = cca.cv_fit(X, Y)
    Fy = cv_res['estimator']
    decodingCurveSupervised(Fy)

    # Slice data
    from utils import sliceData, sliceY
    # NOTE(review): the original referenced an undefined name `st` (NameError).
    # Every sample index along axis 1 is used here as a candidate stimulus
    # time -- confirm against the intended source of stimulus times.
    st = np.arange(X.shape[1])
    stimTimes = st[st < X.shape[1] - tau + 1]  # limit valid stimTimes
    Xe = sliceData(X, stimTimes, tau=tau)  # d x tau x ep x trl
    Ye = sliceY(Y, stimTimes, featdim=False)  # y x ep  x trl
    print('cca = {}'.format(cca))
    Fy = cca.predict(X, Y)
    print("Fy={}".format(Fy.shape))
    Py = cca.predict_proba(X, Y)
    print("Py={}".format(Py.shape))
    score = cca.score(X, Y)
    print("score={}".format(score))
    decodingCurveSupervised(Fy)