Beispiel #1
0
def spatially_whiten(X: np.ndarray, *args, **kwargs):
    """Whiten X across its channel (last) axis.

    The channel covariance of X is estimated with ``updateCxx`` and handed,
    together with any extra positional/keyword arguments, to
    ``robust_whitener``. The resulting matrix is right-multiplied onto X.

    Args:
        X (np.ndarray): data to whiten; channels/space must be the *last* axis.
        *args, **kwargs: forwarded unchanged to ``robust_whitener``.

    Returns:
        tuple: ``(Xw, W)`` where ``Xw = X @ W`` is the whitened data and
        ``W`` is the whitening matrix that was applied.
    """
    covariance = updateCxx(None, X, None)
    whitener, _unused = robust_whitener(covariance, *args, **kwargs)
    # matmul broadcasts over all leading axes, so any number of leading dims works
    whitened = X @ whitener
    return whitened, whitener
# restrict the data to the C3/C4 channel subset
minsubset = ('C3', 'C4')
chsubset = minsubset
keep = [c in chsubset for c in ch_names]
X = X[..., keep]
ch_names = [c for c, k in zip(ch_names, keep) if k]
plt.figure(100)
# NOTE(review): labelled 'car' but no common-average reference is applied in
# this snippet -- confirm whether a CAR step is expected before this point.
plot_erp(X, lab, 'car')

# hp-lp : stop 0-8Hz and 16Hz upwards (presumably -1 = up to Nyquist)
X, _, _ = butter_sosfilt(X, stopband=((0, 8), (16, -1)), fs=fs)
plt.figure(101)
plot_erp(X, lab, 'hp-lp', plotp=True)

# whiten : X is (trial, sample, channel), whitener applied on the channel axis
Cxx = updateCxx(None, X, None)
W, _ = robust_whitener(Cxx)
X = np.einsum("Tsd,dw->Tsw", X, W)

plt.figure(102)
plot_erp(X, lab, 'wht')

# welch : 0.5s windows with 50% overlap along the sample axis.
# noverlap is a count of samples, not a fraction -- scipy casts it to int, so
# the previous `noverlap=.5` silently meant *no* overlap.
nperseg = int(fs * .5)
freqs, X = welch(X,
                 fs,
                 axis=-2,
                 nperseg=nperseg,
                 noverlap=nperseg // 2,
                 return_onesided=True,
                 scaling='spectrum')
plt.figure(103)
Beispiel #3
0
def load_ninapro_db2(datadir,
                     stopband=((0, 15), (45, 65), (95, 125), (250, -1)),
                     envelopeband=(10, -1),
                     trlen_ms=None,
                     ofs=60,
                     nvirt=20,
                     rectify=True,
                     whiten=True,
                     log=True,
                     plot=False,
                     filterbank=None,
                     zscore_y=True,
                     verb=1):
    """Load a NinaPro DB2 .mat file and slice it into per-trial EMG/glove arrays.

    Args:
        datadir: path of the .mat file. It must contain the variables 'emg',
            'glove' and 'stimulus'.
        stopband: stop-band spec passed to ``butter_sosfilt`` before any other
            processing, or None to skip.
        envelopeband: stop-band spec of the post-rectification (envelope)
            filter, or None to skip.
        trlen_ms: trial length in ms. If None it is set to the longest gap
            between consecutive trial onsets.
        ofs: target output sample rate in Hz (raw rate is 2000Hz). None keeps
            the raw rate.
        nvirt: number of block-randomized copies of Y appended as virtual
            (non-target) outputs.
        rectify: if True take abs(X) before the log/envelope steps.
        whiten: if True spatially whiten X.
        log: if True take log of the amplitude, clipped below at 1e-6.
        plot: if True plot the signal after each preprocessing stage.
        filterbank: optional list of band specs. Each band's filtered copy of X
            is stacked along the channel axis.
        zscore_y: if True z-score each column of Y over time.
        verb: verbosity level (>0 prints progress/diagnostics).

    Returns:
        X: (nTrl, trlen_samp, d) preprocessed EMG, zero-padded at the end of
            trials that run past the recording.
        Y: (nTrl, trlen_samp, nY, e) glove data. Y[:, :, 0, :] is the true
            glove signal; the remaining nY-1 entries come from
            ``block_randomize``.
        coords: list of per-axis meta-info dicts (trial/time/channel).

    Raises:
        ValueError: if ``trlen_ms`` is None and fewer than 2 trial onsets are
            found, so the trial length cannot be inferred.
    """
    d = loadmat(datadir, variable_names=('emg', 'glove', 'stimulus'))
    X = d['emg']  # (nSamp,d)
    Y = d['glove']  # (nSamp,e)
    lab = d['stimulus'].ravel()  # (nSamp,) - used to slice out trials+labels
    fs = 2000
    if ofs is None:
        ofs = fs

    # trial onsets = samples just before the stimulus code increases
    trl_start = np.flatnonzero(np.diff(lab) > 0)
    lab = lab[trl_start + 1]  # stimulus code right after each onset
    if verb > 0:
        print('trl_start={}'.format(trl_start))
        print('label={}'.format(lab))
        print("diff(trl_start)={}".format(np.diff(trl_start)))
    if trlen_ms is None:
        if trl_start.size < 2:
            raise ValueError(
                "need at least 2 trial onsets to infer trlen_ms, found {}".format(
                    trl_start.size))
        trlen_ms = np.max(np.diff(trl_start)) * 1000 / fs
        if verb > 0:
            print('trlen_ms={}'.format(trlen_ms))

    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)
        if plot:
            plt.figure(101)
            plt.plot(X)
            plt.title("hp+notch+lp")
        # TODO[] : make the spatial whitener a fit->transform method

    if whiten:
        if verb > 0:
            print("spatial whitener")
        Cxx = updateCxx(None, X, None)
        W, _ = robust_whitener(Cxx)
        X = np.einsum("sd,dw->sw", X, W)
        if plot:
            plt.figure(102)
            plt.plot(X)
            plt.title("+whiten")

    if filterbank is not None:
        if verb > 0:
            print("Filterbank: {}".format(filterbank))
        # apply each band filter separately; the outputs become virtual channels
        # TODO: make a nicer shape, e.g. (tr,samp,band,ch)
        Xs = [butter_sosfilt(X, band, fs)[0] for band in filterbank]
        X = np.concatenate(Xs, -1)

    # previously abs() was applied unconditionally, ignoring `rectify`
    if rectify:
        X = np.abs(X)

    if log:
        if verb > 0:
            print("log amplitude")
        X = np.log(np.maximum(X, 1e-6))  # clip to avoid log(0)
    if plot:
        plt.figure(103)
        plt.plot(X)
        plt.title("+abs")
    if envelopeband is not None:
        if verb > 0:
            print("Envelop band={}".format(envelopeband))
        X, _, _ = butter_sosfilt(X, envelopeband,
                                 fs)  # low-pass = envelope extraction
        if plot:
            plt.figure(104)
            plt.plot(X)
            plt.title("env")

    # preprocess -> downsample by simple decimation (the envelope filter above
    # is what limits the bandwidth)
    resamprate = int(fs / ofs)
    if resamprate > 1:
        if verb > 0:
            print("resample: {}->{}hz rsrate={}".format(
                fs, fs / resamprate, resamprate))
        X = X[..., ::resamprate, :]  # decimate X (samp, d)
        Y = Y[..., ::resamprate, :]  # decimate Y (samp, e)
        trl_start = trl_start / resamprate
        fs = fs / resamprate

    # pre-process : z-trans Y
    if zscore_y:
        if verb > 0:
            print("Z-trans Y")
        mu = np.mean(Y, axis=-2, keepdims=True)
        std = np.std(Y, axis=-2, keepdims=True)
        std[std < 1e-6] = 1  # guard divide by 0
        Y = (Y - mu) / std

    # generate artificial other stimulus streams, for testing
    # TODO: randomize in better way
    Y = Y[:, np.newaxis, :]  # (nSamp, 1, e)
    Y_test = block_randomize(Y,
                             nvirt,
                             axis=-3,
                             block_size=Y.shape[0] // 100 // 2)
    Y = np.concatenate((Y, Y_test), -2)  # (nSamp, nY, e)

    # slice X,Y into fixed-length trials, zero-padding past the recording end
    oX = X  # (nSamp,d)
    oY = Y  # (nSamp,nY,e)
    trlen_samp = int(trlen_ms * fs / 1000)
    X = np.zeros((trl_start.size, trlen_samp, X.shape[-1]))
    Y = np.zeros((trl_start.size, trlen_samp) + Y.shape[-2:])
    if verb > 0:
        print("Slicing {} trials of {}ms".format(len(trl_start), trlen_ms))
    for ti, tii in enumerate(trl_start):
        tii = int(tii)  # trl_start may be fractional after decimation
        trl_len = min(oX.shape[0], tii + trlen_samp) - tii
        X[ti, :trl_len, ...] = oX[tii:tii + trl_len, ...]
        Y[ti, :trl_len, ...] = oY[tii:tii + trl_len, ...]

    # make meta-info
    coords = [None] * X.ndim
    coords[0] = {'name': 'trial', 'coords': lab}
    coords[1] = {
        'name': 'time',
        'fs': fs,
        'units': 'ms',
        'coords': np.arange(X.shape[1]) * 1000 / fs
    }
    coords[2] = {'name': 'channel', 'coords': None}

    return (X, Y, coords)
Beispiel #4
0
def extract_envelope(X,
                     fs,
                     stopband=None,
                     whiten=True,
                     filterbank=None,
                     log=True,
                     env_stopband=(10, -1),
                     verb=False,
                     plot=False):
    """Extract the amplitude envelope from continuous multi-channel data.

    Pipeline (every stage optional except rectification):
    pre-filter -> spatial whitening -> filter bank -> abs -> log -> envelope filter.

    Args:
        X (np.ndarray): (nSamp, d) continuous data, samples first and channels
            last. The whitening step and the plots index X as 2-d.
        fs (float): sampling rate of X in Hz.
        stopband (optional): pre-filter stop-band spec passed to
            ``butter_sosfilt``. Defaults to None (no pre-filter).
        whiten (bool, optional): if True spatially whiten before envelope
            extraction. Defaults to True.
        filterbank (list, optional): stop-band specs. Each is applied to X
            separately and the outputs are concatenated along the channel axis,
            giving len(filterbank)*d channels. Defaults to None.
        log (bool, optional): if True return log-amplitude (clipped below at
            1e-6) instead of raw amplitude. Defaults to True.
        env_stopband (optional): post-filter applied to the rectified signal to
            smooth it into an envelope; None to skip. Defaults to (10, -1).
        verb (int|bool, optional): verbosity level. Defaults to False.
        plot (bool, optional): if True plot the first 10s after each stage, each
            stage in its own figure (100-106). Defaults to False.

    Returns:
        np.ndarray: the extracted envelopes, (nSamp, d) or
        (nSamp, len(filterbank)*d) when a filter bank is used.
    """
    from multipleCCA import robust_whitener
    from updateSummaryStatistics import updateCxx
    from utils import butter_sosfilt

    nplot = int(fs * 10)  # number of samples (10s) shown in each diagnostic plot

    if plot:
        import matplotlib.pyplot as plt
        plt.figure(100)
        plt.clf()
        plt.plot(X[:nplot, :].copy())
        plt.title("raw")

    if stopband is not None:
        if verb > 0: print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)
        if plot:
            plt.figure(101)
            plt.clf()
            plt.plot(X[:nplot, :].copy())
            plt.title("hp+notch+lp")
        # TODO[] : make the spatial whitener a fit->transform method

    if whiten:
        if verb > 0: print("spatial whitener")
        Cxx = updateCxx(None, X, None)
        W, _ = robust_whitener(Cxx)
        X = np.einsum("sd,dw->sw", X, W)
        if plot:
            plt.figure(102)
            plt.clf()
            plt.plot(X[:nplot, :].copy())
            plt.title("+whiten")

    if filterbank is not None:
        if verb > 0: print("Filterbank: {}".format(filterbank))
        if plot:
            plt.figure(103)
            plt.clf()
        # apply each band filter separately; outputs become virtual channels
        Xs = []
        # TODO: make a nicer shape, e.g. (tr,samp,band,ch)
        # TODO[]: check butter_sosfilt doesn't modify in place (hence the copy)
        for bi, band in enumerate(filterbank):
            Xf, _, _ = butter_sosfilt(X.copy(), band, fs)
            Xs.append(Xf)
            if plot:
                plt.subplot(len(filterbank), 1, bi + 1)
                plt.plot(Xf[:nplot, :])
                plt.title("+filterbank {}".format(band))
        X = np.concatenate(Xs, -1)

    X = np.abs(X)  # rectify
    if plot:
        plt.figure(104)
        plt.clf()
        plt.plot(X[:nplot, :])
        plt.title("+abs")

    if log:
        if verb > 0: print("log amplitude")
        X = np.log(np.maximum(X, 1e-6))  # clip to avoid log(0)
        if plot:
            plt.figure(105)
            plt.clf()
            plt.plot(X[:nplot, :])
            plt.title("+log")

    if env_stopband is not None:
        if verb > 0: print("Envelop band={}".format(env_stopband))
        X, _, _ = butter_sosfilt(X, env_stopband,
                                 fs)  # low-pass = envelope extraction
        if plot:
            # own figure: previously reused (and cleared) the "+abs" figure 104
            plt.figure(106)
            plt.clf()
            plt.plot(X[:nplot, :])
            plt.title("+env")
    return X
Beispiel #5
0
def load_mark_EMG(datadir, sessdir=None, sessfn=None, ofs=60,
                  stopband=((0,10),(45,55),(95,105),(145,-1)),
                  filterbank=None, verb=0, log=True, whiten=True, plot=False,
                  envelopeband=(40,-1)):
    """Load a trial-structured EMG .mat file and return envelope features + class targets.

    Args:
        datadir: base directory or file path (``~`` is expanded).
        sessdir: optional sub-directory joined onto ``datadir``.
        sessfn: optional file name joined onto the path.
        ofs: target output sample rate in Hz (the raw rate is taken as 1000Hz).
        stopband: pre-filter stop-band spec for ``butter_sosfilt``; None to skip.
        filterbank: optional list of band specs. Each filtered copy of X is
            stacked along the channel axis.
        verb: verbosity level (>0 prints progress).
        log: if True use log-amplitude, clipped below at 1e-6.
        whiten: if True spatially whiten X. This is now independent of
            ``stopband``; previously it was silently skipped when stopband=None.
        plot: if True plot the first trial after each preprocessing stage.
        envelopeband: stop-band spec of the envelope low-pass filter, or None to
            skip. Defaults to (40, -1), the previously hard-coded value.

    Returns:
        X: (nTrl, nSamp, nCh) envelope features.
        Y: (nTrl, nSamp, e+1, e) float32 targets. Y[:, :, 0, :] is
            ``lab2ind(lab)`` (presumably a one-hot true class) and
            Y[:, :, 1:, :] is an identity matrix, i.e. one candidate row per
            class.
        coords: per-axis meta-info dicts (trial/time/channel).
    """
    fs = 1000
    ch_names = None

    # load the data file
    Xfn = os.path.expanduser(datadir)
    if sessdir:
        Xfn = os.path.join(Xfn, sessdir)
    if sessfn:
        Xfn = os.path.join(Xfn, sessfn)

    print("Loading {}".format(Xfn))
    data = loadmat(Xfn)

    def squeeze(v):
        # strip the singleton wrapping loadmat puts around MATLAB structs/cells
        while v.size == 1 and v.ndim > 0:
            v = v[0]
        return v

    X = np.array([squeeze(d['buf']) for d in squeeze(data['data'])]) # (nTrl, nCh, nSamp)
    X = np.moveaxis(X, (0,1,2), (0,2,1)) # (nTrl, nSamp, nCh)
    X = np.ascontiguousarray(X) # contiguous copy for efficient filtering
    lab = np.array([squeeze(e['value']) for e in data['devents']], dtype=int) # (nTrl,)

    if plot:
        # only needed for the diagnostic plots
        import matplotlib.pyplot as plt
        plt.figure(100)
        plt.plot(X[0,:,:])
        plt.title("raw")

    # preprocess -> spectral filter
    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)
        if plot:
            plt.figure(101)
            plt.plot(X[0,:,:])
            plt.title("hp+notch+lp")

    # preprocess -> spatial whiten (previously nested under the stopband branch)
    # TODO[] : make this a fit->transform method
    if whiten:
        if verb > 0:
            print("spatial whitener")
        Cxx = updateCxx(None, X, None)
        W, _ = robust_whitener(Cxx)
        X = np.einsum("tsd,dw->tsw", X, W)
        if plot:
            plt.figure(102)
            plt.plot(X[0,:,:])
            plt.title("+whiten")

    if filterbank is not None:
        if verb > 0:
            print("Filterbank: {}".format(filterbank))
        # apply each band filter separately; outputs become virtual channels
        # TODO: make a nicer shape, e.g. (tr,samp,band,ch)
        Xs = [butter_sosfilt(X, band, fs)[0] for band in filterbank]
        X = np.concatenate(Xs, -1)

    X = np.abs(X) # rectify
    if log:
        if verb > 0:
            print("log amplitude")
        X = np.log(np.maximum(X, 1e-6)) # clip to avoid log(0)
    if plot:
        plt.figure(103)
        plt.plot(X[0,:,:])
        plt.title("+abs")
    if envelopeband is not None:
        X, _, _ = butter_sosfilt(X, envelopeband, fs) # low-pass = envelope extraction
        if plot:
            plt.figure(104)
            plt.plot(X[0,:,:])
            plt.title("env")

    # preprocess -> downsample towards ofs by decimation
    resamprate = int(fs / ofs)
    if resamprate > 1:
        if verb > 0:
            print("resample: {}->{}hz rsrate={}".format(fs, fs / resamprate, resamprate))
        X = X[:, ::resamprate, :] # decimate X (trl, samp, d)
        fs = fs / resamprate

    # get Y
    Y_true, lab2class = lab2ind(lab) # (nTrl, e)
    Y_true = Y_true[:, np.newaxis, :] # (nTrl, 1, e)
    # TODO[] : exhaustive list of other targets...
    Yall = np.eye(Y_true.shape[-1], dtype=bool) # (e, e): one candidate per class
    Yall = np.tile(Yall, (Y_true.shape[0], 1, 1)) # (nTrl, e, e)
    Y = np.append(Y_true, Yall, axis=-2) # (nTrl, e+1, e)
    # repeat the per-trial targets at every output sample
    Y = np.tile(Y[:, np.newaxis, :, :], (1, X.shape[1], 1, 1)) # (nTrl, nSamp, nY, e)
    Y = Y.astype(np.float32)

    # make coords array for the meta-info about the dimensions of X
    coords = [None] * X.ndim
    coords[0] = {'name':'trial'}
    coords[1] = {'name':'time', 'unit':'ms',
                 'coords':np.arange(X.shape[1]) * 1000 / fs,
                 'fs':fs}
    coords[2] = {'name':'channel', 'coords':ch_names}
    # return data + metadata
    return (X, Y, coords)