def spatially_whiten(X: np.ndarray, *args, **kwargs):
    """Spatially whiten an n-d array along its last (channel/space) axis.

    The channel covariance of ``X`` is accumulated with ``updateCxx`` and
    converted into a whitening matrix by ``robust_whitener``.  Any extra
    positional or keyword arguments are forwarded untouched to
    ``robust_whitener``.

    Args:
        X (np.ndarray): the data to be whitened, channels/space in the *last*
            axis.

    Returns:
        tuple: ``(Xw, W)`` -- ``Xw`` is the whitened data and ``W`` is the
        whitening matrix that was right-multiplied onto ``X``.
    """
    covariance = updateCxx(None, X, None)
    whitener, _ = robust_whitener(covariance, *args, **kwargs)
    # right-multiply so the channel axis (last) is the one mixed
    whitened = X @ whitener
    return whitened, whitener
# Quick-look ERP pipeline restricted to the C3/C4 channel subset.
# NOTE(review): relies on X, lab, ch_names and fs defined earlier in the
# script; the "Tsd" einsum below implies X is 3-d with channels last.
minsubset = ('C3', 'C4')
chsubset = minsubset
keep = [c in chsubset for c in ch_names]  # boolean mask over channels
X = X[..., keep]
ch_names = [c for c, k in zip(ch_names, keep) if k]
plt.figure(100)
plot_erp(X, lab, 'car')

# hp-lp: stop-bands 0-8Hz and (16,-1) as interpreted by butter_sosfilt
X, _, _ = butter_sosfilt(X, stopband=((0, 8), (16, -1)), fs=fs)
plt.figure(101)
plot_erp(X, lab, 'hp-lp', plotp=True)

# whiten: channel covariance -> robust whitening matrix, applied on channels
Cxx = updateCxx(None, X, None)
W, _ = robust_whitener(Cxx)
X = np.einsum("Tsd,dw->Tsw", X, W)
plt.figure(102)
plot_erp(X, lab, 'wht')

# welch spectrum along the sample axis with 0.5s windows.
# scipy's `noverlap` is a number of *samples* (cast with int()), so the
# previous `noverlap=.5` silently became 0 overlap; request 50% explicitly.
freqs, X = welch(X, fs, axis=-2, nperseg=int(fs * .5),
                 noverlap=int(fs * .5) // 2, return_onesided=True,
                 scaling='spectrum')
plt.figure(103)
def load_ninapro_db2(datadir, stopband=((0, 15), (45, 65), (95, 125), (250, -1)), envelopeband=(10, -1), trlen_ms=None, ofs=60, nvirt=20, rectify=True, whiten=True, log=True, plot=False, filterbank=None, zscore_y=True, verb=1):
    """Load a NinaPro DB2 recording and slice it into EMG-envelope trials.

    Continuous EMG pipeline: optional pre-filter -> optional spatial
    whitening -> optional filter bank (bands stacked as virtual channels) ->
    optional rectification -> optional log -> optional envelope filter ->
    decimation to ``ofs``.  The glove data is decimated identically,
    optionally z-scored per channel, and augmented with ``nvirt``
    block-randomized copies that act as non-target virtual stimulus streams.

    Args:
        datadir (str): path of the ``.mat`` file holding the ``emg``,
            ``glove`` and ``stimulus`` variables.
        stopband: pre-filter spec for ``butter_sosfilt``; None to skip.
        envelopeband: envelope filter spec applied after rectification;
            None to skip.
        trlen_ms (float): trial length in ms.  Defaults to the largest gap
            between consecutive trial onsets (or the rest of the recording
            when only one trial is present).
        ofs (float): output sample rate in Hz; None keeps the 2000Hz input.
        nvirt (int): number of randomized virtual Y streams to append.
        rectify (bool): take the absolute value of the filtered EMG.
        whiten (bool): spatially whiten the EMG before the filter bank.
        log (bool): log-transform the amplitude (clipped below at 1e-6).
        plot (bool): plot intermediate processing stages.
        filterbank: list of band specs; each band's output becomes a set of
            virtual channels.  None to skip.
        zscore_y (bool): z-score each glove channel over time.
        verb (int): verbosity level; >0 prints progress/diagnostics.

    Returns:
        tuple: ``(X, Y, coords)`` -- X is (trial, time, channel); Y is
        (trial, time, nY, glove-channel) with the true glove signal first
        along the nY axis; ``coords`` holds per-axis meta-info for X.

    Raises:
        ValueError: if no trial onsets are found in ``stimulus``.
    """
    d = loadmat(datadir, variable_names=('emg', 'glove', 'stimulus'))
    X = d['emg']  # (nSamp,d)
    Y = d['glove']  # (nSamp,e)
    lab = d['stimulus'].ravel()  # (nSamp,) - used to slice out trials+labels
    fs = 2000
    if ofs is None:
        ofs = fs

    # Trial onsets = samples where the stimulus label increases.  np.diff
    # index i compares lab[i+1] with lab[i], so the new label -- and the
    # trial -- starts at sample i+1.  Cast to a signed type first so a label
    # decrease can never wrap around to a positive value for unsigned data.
    trl_start = np.flatnonzero(np.diff(lab.astype(np.int64)) > 0) + 1
    if trl_start.size == 0:
        raise ValueError("no trial onsets found in 'stimulus'")
    lab = lab[trl_start]
    if verb > 0:
        print('trl_start={}'.format(trl_start))
        print('label={}'.format(lab))
        print("diff(trl_start)={}".format(np.diff(trl_start)))
    if trlen_ms is None:
        if trl_start.size > 1:
            trlen_ms = np.max(np.diff(trl_start)) * 1000 / fs
        else:
            # a single trial: let it run to the end of the recording
            trlen_ms = (X.shape[0] - trl_start[0]) * 1000 / fs
        if verb > 0:
            print('trlen_ms={}'.format(trlen_ms))

    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)
        if plot:
            plt.figure(101)
            plt.plot(X)
            plt.title("hp+notch+lp")

    # preprocess -> spatial whiten
    # TODO[] : make this fit->transform method
    if whiten:
        if verb > 0:
            print("spatial whitener")
        Cxx = updateCxx(None, X, None)
        W, _ = robust_whitener(Cxx)
        X = np.einsum("sd,dw->sw", X, W)
        if plot:
            plt.figure(102)
            plt.plot(X)
            plt.title("+whiten")

    if filterbank is not None:
        if verb > 0:
            print("Filterbank: {}".format(filterbank))
        # apply filter bank to frequency ranges into virtual channels
        # TODO: make a nicer shape, e.g. (tr,samp,band,ch)
        # copy so every band filters the same input even if butter_sosfilt
        # were to work in place
        Xs = [butter_sosfilt(X.copy(), band, fs)[0] for band in filterbank]
        # stack the bands as virtual channels
        X = np.concatenate(Xs, -1)

    if rectify:
        X = np.abs(X)
    if log:
        if verb > 0:
            print("log amplitude")
        X = np.log(np.maximum(X, 1e-6))
    if plot:
        plt.figure(103)
        plt.plot(X)
        plt.title("+abs")

    if envelopeband is not None:
        if verb > 0:
            print("Envelop band={}".format(envelopeband))
        X, _, _ = butter_sosfilt(X, envelopeband, fs)  # low-pass = envelope extraction
        if plot:
            plt.figure(104)
            plt.plot(X)
            plt.title("env")

    # preprocess -> downsample by integer decimation
    resamprate = int(fs / ofs)
    if resamprate > 1:
        if verb > 0:
            print("resample: {}->{}hz rsrate={}".format(
                fs, fs / resamprate, resamprate))
        X = X[..., ::resamprate, :]  # decimate X (samp, d)
        Y = Y[..., ::resamprate, :]  # decimate Y (samp, e)
        trl_start = trl_start / resamprate
        fs = fs / resamprate

    # pre-process : z-trans Y
    if zscore_y:
        if verb > 0:
            print("Z-trans Y")
        mu = np.mean(Y, axis=-2, keepdims=True)
        std = np.std(Y, axis=-2, keepdims=True)
        std[std < 1e-6] = 1  # guard divide by 0
        Y = (Y - mu) / std

    # generate artificial other stimulus streams, for testing
    # TODO: randomize in better way
    Y = Y[:, np.newaxis, :]  # (nSamp, 1, e)
    Y_test = block_randomize(Y, nvirt, axis=-3,
                             block_size=Y.shape[0] // 100 // 2)
    Y = np.concatenate((Y, Y_test), -2)  # (nSamp, nY, e), true stream first

    # slice X,Y into fixed-length, zero-padded trials
    oX = X  # (nSamp,d)
    oY = Y  # (nSamp,nY,e)
    trlen_samp = int(trlen_ms * fs / 1000)
    X = np.zeros((trl_start.size, trlen_samp, oX.shape[-1]))
    Y = np.zeros((trl_start.size, trlen_samp) + oY.shape[-2:])
    if verb > 0:
        print("Slicing {} trials of {}ms".format(len(trl_start), trlen_ms))
    for ti, tii in enumerate(trl_start):
        tii = int(tii)
        # the last trial may be truncated by the end of the recording
        trl_len = min(oX.shape[0], tii + trlen_samp) - tii
        X[ti, :trl_len, ...] = oX[tii:tii + trl_len, ...]
        Y[ti, :trl_len, ...] = oY[tii:tii + trl_len, ...]

    # make meta-info
    coords = [None] * X.ndim
    coords[0] = {'name': 'trial', 'coords': lab}
    coords[1] = {
        'name': 'time',
        'fs': fs,
        'units': 'ms',
        'coords': np.arange(X.shape[1]) * 1000 / fs
    }
    coords[2] = {'name': 'channel', 'coords': None}
    return (X, Y, coords)
def extract_envelope(X, fs, stopband=None, whiten=True, filterbank=None, log=True, env_stopband=(10, -1), verb=False, plot=False):
    """extract the amplitude envelope from continuous multi-channel data

    Processing order: optional pre-filter -> optional spatial whitening ->
    optional filter bank (each band's output stacked along the channel axis
    as virtual channels) -> rectification (abs) -> optional log -> optional
    envelope filter.

    Args:
        X (np.ndarray): (samples, channels) data.  Must be 2-d: both the
            whitening einsum and the plotting code index it that way.
        fs (float): sampling rate of X in Hz.
        stopband (optional): pre-filter stop band(s) passed to
            `butter_sosfilt`. Defaults to None (no pre-filter).
        whiten (bool, optional): flag if we spatially whiten before envelope
            extraction. Defaults to True.
        filterbank (list, optional): band specs; each is applied separately
            to the (whitened) data and the outputs are concatenated along the
            channel axis. Defaults to None (no filter bank).
        log (bool, optional): flag if we return log-amplitude (clipped below
            at 1e-6) instead of the raw rectified amplitude. Defaults to True.
        env_stopband (tuple, optional): stop band of the filter applied to the
            rectified signal to extract the envelope; None to skip.
            Defaults to (10,-1).
        verb (bool, optional): verbosity level. Defaults to False.
        plot (bool, optional): flag if we plot the first 10s of the data
            after each preprocessing step. Defaults to False.

    Returns:
        np.ndarray: the extracted envelopes, samples along axis 0.
    """
    from multipleCCA import robust_whitener
    from updateSummaryStatistics import updateCxx
    from utils import butter_sosfilt

    nplot = int(fs * 10)  # only the first 10s are plotted
    if plot:
        import matplotlib.pyplot as plt

        def _show(fig, data, title):
            # draw the first 10s of `data` into a freshly cleared figure;
            # copy so later in-place changes can't alter the plotted data
            plt.figure(fig)
            plt.clf()
            plt.plot(data[:nplot, :].copy())
            plt.title(title)

        _show(100, X, "raw")

    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)
        if plot:
            _show(101, X, "hp+notch+lp")

    # preprocess -> spatial whiten
    # TODO[] : make this fit->transform method
    if whiten:
        if verb > 0:
            print("spatial whitener")
        Cxx = updateCxx(None, X, None)
        W, _ = robust_whitener(Cxx)
        X = np.einsum("sd,dw->sw", X, W)
        if plot:
            _show(102, X, "+whiten")

    if filterbank is not None:
        if verb > 0:
            print("Filterbank: {}".format(filterbank))
        if plot:
            plt.figure(103)
            plt.clf()
        # apply filter bank to frequency ranges into virtual channels
        # TODO: make a nicer shape, e.g. (tr,samp,band,ch)
        Xs = []
        for bi, band in enumerate(filterbank):
            # copy so every band filters the same input even if
            # butter_sosfilt were to work in place
            Xf, _, _ = butter_sosfilt(X.copy(), band, fs)
            Xs.append(Xf)
            if plot:
                plt.subplot(len(filterbank), 1, bi + 1)
                plt.plot(Xf[:nplot, :])
                plt.title("+filterbank {}".format(band))
        # stack the bands as virtual channels
        X = np.concatenate(Xs, -1)

    X = np.abs(X)  # rectify
    if plot:
        _show(104, X, "+abs")

    if log:
        if verb > 0:
            print("log amplitude")
        X = np.log(np.maximum(X, 1e-6))
        if plot:
            _show(105, X, "+log")

    if env_stopband is not None:
        if verb > 0:
            print("Envelop band={}".format(env_stopband))
        X, _, _ = butter_sosfilt(X, env_stopband, fs)  # low-pass = envelope extraction
        if plot:
            # own figure, so the "+abs" plot in figure 104 is kept
            _show(106, X, "+env")

    return X
def load_mark_EMG(datadir, sessdir=None, sessfn=None, ofs=60, stopband=((0, 10), (45, 55), (95, 105), (145, -1)), filterbank=None, verb=0, log=True, whiten=True, plot=False, env_stopband=(40, -1)):
    """Load a per-trial EMG dataset from a MATLAB file and extract envelopes.

    The file is expected to contain a ``data`` struct array whose ``buf``
    fields hold (channels, samples) trial data, and a ``devents`` array whose
    ``value`` fields hold the integer trial labels.  The EMG is pre-filtered,
    optionally whitened, optionally split by a filter bank, rectified,
    optionally log-transformed, envelope-filtered and decimated to ``ofs``.

    Args:
        datadir (str): data directory, or the file itself (``~`` expanded).
        sessdir (str, optional): sub-directory appended to ``datadir``.
        sessfn (str, optional): file name appended to the path.
        ofs (float): output sample rate in Hz (input is 1000Hz).
        stopband: pre-filter spec for ``butter_sosfilt``; None to skip.
        filterbank: list of band specs whose outputs are stacked as virtual
            channels; None to skip.
        verb (int): verbosity level; >0 prints processing steps.
        log (bool): log-transform the rectified amplitude (clipped at 1e-6).
        whiten (bool): spatially whiten the data before the filter bank.
        plot (bool): plot the first trial after each processing step.
        env_stopband: envelope filter spec applied after rectification;
            defaults to the previously hard-coded (40,-1); None to skip.

    Returns:
        tuple: ``(X, Y, coords)`` -- X is (trial, time, channel); Y is a
        float32 (trial, time, nY, class) indicator array whose first stream
        along nY is the true label and whose remaining streams are one per
        class; ``coords`` holds per-axis meta-info for X.
    """
    fs = 1000
    ch_names = None

    # build the path of the data file and load it
    Xfn = os.path.expanduser(datadir)
    if sessdir:
        Xfn = os.path.join(Xfn, sessdir)
    if sessfn:
        Xfn = os.path.join(Xfn, sessfn)
    print("Loading {}".format(Xfn))
    data = loadmat(Xfn)

    def squeeze(v):
        # strip the size-1 wrapper layers loadmat puts around structs/cells
        while v.size == 1 and v.ndim > 0:
            v = v[0]
        return v

    X = np.array([squeeze(d['buf']) for d in squeeze(data['data'])])  # (nTrl, nCh, nSamp)
    X = np.moveaxis(X, (0, 1, 2), (0, 2, 1))  # (nTrl, nSamp, nCh)
    X = np.ascontiguousarray(X)  # ensure memory efficient
    lab = np.array([squeeze(e['value']) for e in data['devents']], dtype=int)  # (nTrl,)

    # only require matplotlib when plotting is actually requested
    if plot:
        import matplotlib.pyplot as plt
        plt.figure(100)
        plt.plot(X[0, :, :])
        plt.title("raw")

    # preprocess -> spectral filter
    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)
        if plot:
            plt.figure(101)
            plt.plot(X[0, :, :])
            plt.title("hp+notch+lp")

    # preprocess -> spatial whiten
    # TODO[] : make this fit->transform method
    if whiten:
        if verb > 0:
            print("spatial whitener")
        Cxx = updateCxx(None, X, None)
        W, _ = robust_whitener(Cxx)
        X = np.einsum("tsd,dw->tsw", X, W)
        if plot:
            plt.figure(102)
            plt.plot(X[0, :, :])
            plt.title("+whiten")

    if filterbank is not None:
        if verb > 0:
            print("Filterbank: {}".format(filterbank))
        # apply filter bank to frequency ranges into virtual channels
        # TODO: make a nicer shape, e.g. (tr,samp,band,ch)
        Xs = [butter_sosfilt(X, band, fs)[0] for band in filterbank]
        # stack the bands as virtual channels
        X = np.concatenate(Xs, -1)

    X = np.abs(X)  # rectify
    if log:
        if verb > 0:
            print("log amplitude")
        X = np.log(np.maximum(X, 1e-6))
    if plot:
        plt.figure(103)
        plt.plot(X[0, :, :])
        plt.title("+abs")

    if env_stopband is not None:
        X, _, _ = butter_sosfilt(X, env_stopband, fs)  # low-pass = envelope extraction
        if plot:
            plt.figure(104)
            plt.plot(X[0, :, :])
            plt.title("env")

    # preprocess -> downsample to ofs by integer decimation
    resamprate = int(fs / ofs)
    if resamprate > 1:
        if verb > 0:
            print("resample: {}->{}hz rsrate={}".format(fs, fs / resamprate, resamprate))
        X = X[:, ::resamprate, :]  # decimate X (trl, samp, d)
        fs = fs / resamprate

    # get Y: true-label indicator plus one virtual stream per class
    Y_true, lab2class = lab2ind(lab)  # (nTrl, e)
    Y_true = Y_true[:, np.newaxis, :]  # (nTrl,1,e)
    # TODO[] : exhaustive list of other targets...
    Yall = np.eye(Y_true.shape[-1], dtype=bool)  # (nvirt,e)
    Yall = np.tile(Yall, (Y_true.shape[0], 1, 1))  # (nTrl,nvirt,e)
    Y = np.append(Y_true, Yall, axis=-2)  # (nTrl,nvirt+1,e)
    # repeat the (constant per-trial) labels for every output sample
    Y = np.tile(Y[:, np.newaxis, :, :], (1, X.shape[1], 1, 1))  # (nTrl, nSamp, nY, e)
    Y = Y.astype(np.float32)

    # make coords array for the meta-info about the dimensions of X
    coords = [None] * X.ndim
    coords[0] = {'name': 'trial'}
    coords[1] = {'name': 'time', 'unit': 'ms',
                 'coords': np.arange(X.shape[1]) * 1000 / fs,
                 'fs': fs}
    coords[2] = {'name': 'channel', 'coords': ch_names}
    # return data + metadata
    return (X, Y, coords)