# Example #1
# 0
def load_mTRF_audio(datadir,
                    regressor='envelope',
                    ntrl=15,
                    stopband=((0, .5), (15, -1)),
                    ofs=60,
                    nvirt_out=30,
                    verb=1):
    """Load an mTRF audio/EEG .mat file, pre-process it and cut it into trials.

    The file is read with ``loadmat`` and must contain the keys ``'EEG'``
    (continuous EEG, (nSamp, d)), ``'Fs'`` (sample rate) and ``regressor``
    (the stimulus regressor, (nSamp, e)).

    Processing: optional spectral filtering of X in continuous time,
    appending block-randomized copies of the regressor as extra (non-target)
    outputs, slicing into ``ntrl`` equal-length trials, and decimation by an
    integer factor towards ``ofs`` Hz.

    Args:
        datadir: path of the .mat file to load.
        regressor: key of the stimulus regressor in the .mat file.
        ntrl: number of equal-length trials to cut the recording into.
            Values <= 1 keep the whole recording as a single trial.
        stopband: filter specification passed to ``butter_sosfilt``, or
            None to skip spectral filtering.
        ofs: target output sample rate in Hz; None keeps the input rate.
        nvirt_out: number of virtual outputs requested from
            ``block_randomize``.
        verb: verbosity level; > 0 prints the processing steps.

    Returns:
        (X, Y, coords) where
            X: (nTrl, nSamp, d) EEG data,
            Y: (nTrl, nSamp, nY, e) stimulus; output 0 is the true
               regressor, the remaining outputs are the randomized copies,
            coords: list of per-dimension meta-info dicts for X.
    """
    # guard: 0 / negative trial counts would divide by zero in block_size below
    ntrl = max(int(ntrl), 1)

    d = loadmat(datadir)
    X = d['EEG']  # (nSamp,d)
    Y = d[regressor]  # (nSamp,e)
    Y = Y[:, np.newaxis, :]  # (nSamp, nY, e)
    fs = d['Fs'][0][0]
    if ofs is None:
        ofs = fs

    # preprocess -> spectral filter, in continuous time!
    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs, axis=-2)

    # generate artificial other stimulus streams, for testing
    Y_test = block_randomize(Y,
                             nvirt_out,
                             axis=-3,
                             block_size=Y.shape[0] // ntrl // 2)
    Y = np.concatenate((Y, Y_test), -2)  # (nSamp, nY, e)

    # slice X,Y into 'trials'
    if ntrl > 1:
        winsz = X.shape[0] // ntrl
        X = window_axis(X, axis=0, winsz=winsz, step=winsz)  # (ntrl,nSamp,d)
        Y = window_axis(Y, axis=0, winsz=winsz,
                        step=winsz)  # (nTrl,nSamp,nY,e)
    else:
        # single trial: just add a leading trial axis to the data itself
        X = X[np.newaxis, ...]  # (1,nSamp,d)
        Y = Y[np.newaxis, ...]  # (1,nSamp,nY,e)

    # preprocess -> downsample
    resamprate = int(fs / ofs)
    if resamprate > 1:
        if verb > 0:
            print("resample: {}->{}hz rsrate={}".format(
                fs, fs / resamprate, resamprate))
        X = X[:, ::resamprate, :]  # decimate X (trl, samp, d)
        Y = Y[:, ::resamprate, ...]  # decimate Y (trl, samp, nY, e)
        fs = fs / resamprate

    # make meta-info
    coords = [None] * X.ndim
    coords[0] = {'name': 'trial'}
    coords[1] = {
        'name': 'time',
        'fs': fs,
        'units': 'ms',
        'coords': np.arange(X.shape[1]) * 1000 / fs
    }
    coords[2] = {'name': 'channel', 'coords': None}

    return (X, Y, coords)
def testcase():
    """Manual regression check of the python pipeline against Matlab results.

    Loads a brainstream session (overridable on the command line as
    ``datadir [sessdir [sessfn]]``), fits and evaluates ``MultiCCA`` on it,
    then compares intermediate results (X, Y, sliced data, summary
    statistics) with the reference values stored in ``/home/jadref/<sessfn>``
    and plots python vs. Matlab data for visual inspection.

    N.B. depends on local data files and project modules; not a unit test.
    """
    import sys

    # plos_one
    #datadir = '/home/jadref/removable/SD Card/data/bci/own_experiments/noisetagging_v3'
    #sessdir = 's01'
    #sessfn = 'traindata.mat'

    # lowlands
    datadir = '/home/jadref/removable/SD Card/data/bci/own_experiments/lowlands'
    sessdir = ''
    sessfn = 'LL_ENG_39_20170820_tr_train_1.mat' # should perform 100%

    # command-line, for testing
    if len(sys.argv) > 1:
        datadir = sys.argv[1]
    if len(sys.argv) > 2:
        sessdir = sys.argv[2]
    if len(sys.argv) > 3:
        sessfn = sys.argv[3]

    from load_brainstream import load_brainstream
    X, Y, coords = load_brainstream(datadir, sessdir, sessfn, ofs=60,
                                    stopband=((0, 5.5), (24, -1)))

    from model_fitting import MultiCCA
    from decodingCurveSupervised import decodingCurveSupervised
    cca = MultiCCA(tau=18)
    cca.fit(X, Y)
    print('cca = {}'.format(cca))
    Fy = cca.predict(X, Y, dedup0=True)
    print("Fy={}".format(Fy.shape))
    decodingCurveSupervised(Fy)

    # test w.r.t. matlab
    from scipy.io import loadmat
    from utils import window_axis
    import numpy as np
    Xe = window_axis(X, axis=-2, winsz=18)
    Ye = window_axis(Y, axis=-2, winsz=1)

    matdata = loadmat('/home/jadref/' + sessfn)
    Xm = np.moveaxis(matdata['X'], (0, 1, 2), (2, 1, 0))  # (t,s,d)
    Ym = np.moveaxis(matdata['Y'], (0, 1, 2), (2, 1, 0))  # (t,s,y)
    stem = np.moveaxis(matdata['stimTimes_samp'], (0, 1), (1, 0))  # (t,e)

    cca = MultiCCA(tau=18)
    cca.fit(Xm, Ym)
    cca.score(Xm, Ym)
    decodingCurveSupervised(cca.predict(Xm, Ym))

    print("X-Xm={}".format(np.max(np.abs(X-Xm).ravel())))
    print("Y-Ym={}".format(np.max(np.abs(Y[:, :, 0]-Ym[:, :, 0]).ravel())))

    Xem = np.moveaxis(matdata['Xe'], (0, 1, 2, 3), (3, 2, 1, 0))  # (t,e,tau,d)
    Yem = np.moveaxis(matdata['Ye'], (0, 1, 2), (2, 1, 0))  # (t,e, y)

    # off by 1 in the sliced version, w.r.t. Matlab
    print("Xe-Xem={}".format(np.max(np.abs(Xe[:,1:Xem.shape[1]+1,:,:]-Xem).ravel())))

    print("Ye-Yem={}".format(np.max(np.abs(Ye[:,1:Yem.shape[1]+1, 0, 0]-Yem[:, :, 0]).ravel())))

    cca = MultiCCA(tau=Xem.shape[-2])
    cca.fit(Xem, Yem, stem)
    print('cca = {}'.format(cca))
    Fy = cca.predict(Xem, Yem)
    print("Fy={}".format(Fy.shape))
    decodingCurveSupervised(Fy)

    # run in stages
    from updateSummaryStatistics import updateSummaryStatistics
    from stim2event import stim2event
    Yeem = np.moveaxis(matdata['Yee'], (0, 1, 2, 3), (3, 2, 1, 0))
    Yeme = stim2event(Yem, ['re', 'fe'], -2)  # py convert to brain events
    print("Yeem-Yeme={}".format(np.max(np.abs(Yeem-Yeme).ravel())))

    Cxx, Cxy, Cyy = updateSummaryStatistics(Xem, Yeem[:, :, 0:1, :], stem)

    Cxxm = matdata['Cxx']
    Cxym = np.moveaxis(matdata['Cxy'], (0, 1, 2), (2, 1, 0))
    Cyym = np.moveaxis(matdata['Cyy'], (0, 1, 2, 3), (3, 2, 1, 0))

    print('Cxx-Cxxm={}'.format(np.max(np.abs(Cxx-Cxxm).ravel())))
    print('Cxy-Cxym={}'.format(np.max(np.abs(Cxy-Cxym).ravel())))
    print('Cyy-Cyym={}'.format(np.max(np.abs(Cyy-Cyym).ravel())))

    import matplotlib.pyplot as plt
    # python vs. matlab EEG (same trial indices as the original check)
    plt.figure()
    plt.plot(X[1, :, 0])
    plt.plot(Xm[0, :, 0])
    # python vs. matlab stimulus for output 0, as compared in "Y-Ym" above
    plt.figure()
    plt.plot(Y[:, :, 0].ravel())
    plt.plot(Ym[:, :, 0].ravel())
    plt.show()
def load_twofinger(datadir, sessdir=None, sessfn=None, ofs=60, stopband=((0,1),(25,-1)), subtriallen=10, nvirt=20, verb=0, ch_idx=slice(32)):
    """Load a 'twofinger' .mat recording as sub-trials with virtual targets.

    The .mat file must contain ``'X'`` (channels x samples EEG), ``'Y'``
    (per-sample integer labels) and ``'chann'`` (channel names).  The sample
    rate is hard-coded to 512 Hz; it is not read from the file.

    Args:
        datadir, sessdir, sessfn: path components joined (when given) to
            form the data file name.
        ofs: target output sample rate in Hz (integer decimation factor).
        stopband: filter specification passed to ``butter_sosfilt``, or None.
        subtriallen: length in seconds of the sub-trials the continuous
            recording is cut into.
        nvirt: number of virtual outputs requested from ``block_randomize``.
        verb: verbosity level; > 0 prints the processing steps.
        ch_idx: channel selection (slice, index array or boolean mask), or
            None to keep all channels.

    Returns:
        (X, Y, coords) where
            X: (nTrl, nSamp, d) EEG data (a single trial if the recording is
               not longer than ``subtriallen``),
            Y: (nTrl, nSamp, nY, e) stimulus events; output 0 is built from
               the true labels, the rest are randomized copies,
            coords: list of per-dimension meta-info dicts for X, with time
               coordinates in ms.
    """
    # load the data file
    Xfn = datadir
    if sessdir:
        Xfn = os.path.join(Xfn, sessdir)
    if sessfn:
        Xfn = os.path.join(Xfn, sessfn)

    data = loadmat(Xfn)

    def squeeze(v):
        """Strip singleton wrapper levels (size-1 arrays) around a value."""
        while v.size == 1 and v.ndim > 0:
            v = v[0]
        return v

    fs = 512  # N.B. hard-coded, not read from the file
    ch_names = [c[0] for c in squeeze(data['chann']).ravel()]
    X = squeeze(data['X'])  # (ch,samp)
    X = np.moveaxis(X, (0, 1), (1, 0))  # (samp,ch)
    X = np.ascontiguousarray(X.astype(np.float32))
    if ch_idx is not None:
        X = X[:, ch_idx]
        if isinstance(ch_idx, slice):
            ch_names = ch_names[ch_idx]
        else:
            # a plain list cannot be indexed by an index array / boolean mask
            keep = np.atleast_1d(np.arange(len(ch_names))[ch_idx])
            ch_names = [ch_names[i] for i in keep]
    if verb > 0: print("X={}".format(X.shape), flush=True)

    lab = squeeze(data['Y']).astype(int).ravel()  # (samp,)

    # preprocess -> spectral filter, in continuous time!
    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)

    # make the targets, for the events we care about
    Y, lab2class = lab2ind(lab, marker2stim.values())  # (samp, e) feature dim per class
    if verb > 0: print("Y={}".format(Y.shape))

    # preprocess -> downsample
    resamprate = int(fs / ofs)
    if resamprate > 1:
        if verb > 0:
            print("resample: {}->{}hz rsrate={}".format(fs, fs/resamprate, resamprate))
        X = X[..., ::resamprate, :]  # decimate X (samp, d)
        # re-sample Y, being sure to keep any events in the re-sample window
        Y = window_axis(Y, winsz=resamprate, step=resamprate, axis=-2)  # (samp, win, e)
        Y = np.max(Y, axis=-2)  # (samp, e)  N.B. use max so don't loose single sample events
        # decimation keeps a trailing partial window which window_axis may
        # drop: trim both to a common length so X and Y stay sample-aligned
        nsamp = min(X.shape[-2], Y.shape[-2])
        X = X[..., :nsamp, :]
        Y = Y[..., :nsamp, :]
        fs = fs / resamprate
        if verb > 0:
            print("X={}".format(X.shape))
            print("Y={}".format(Y.shape))

    # make virtual targets
    Y = Y[:, np.newaxis, :]  # (nsamp,1,e)
    Y_virt = block_randomize(Y, nvirt, axis=-3)  # (nsamp,nvirt,e)
    Y = np.concatenate((Y, Y_virt), axis=-2)  # (nsamp,1+nvirt,e)
    if verb > 0: print("Y={}".format(Y.shape))

    # cut into sub-trials
    nsubtrials = X.shape[0] / fs / subtriallen
    if nsubtrials > 1:
        winsz = int(X.shape[0] // nsubtrials)
        if verb > 0: print('subtrial winsz={}'.format(winsz))
        # slice into sub-trials
        X = window_axis(X, axis=0, winsz=winsz, step=winsz)  # (trl,win,d)
        Y = window_axis(Y, axis=0, winsz=winsz, step=winsz)  # (trl,win,nY,e)
        if verb > 0:
            print("X={}".format(X.shape))
            print("Y={}".format(Y.shape))
    else:
        # too short to split: keep as one trial so X is always (trl,samp,d)
        X = X[np.newaxis, ...]
        Y = Y[np.newaxis, ...]

    # make coords array for the meta-info about the dimensions of X
    coords = [None] * X.ndim
    coords[0] = {'name': 'trial'}
    coords[1] = {'name': 'time', 'unit': 'ms',
                 'coords': np.arange(X.shape[1]) * 1000 / fs,
                 'fs': fs}
    coords[2] = {'name': 'channel', 'coords': ch_names}
    # return data + metadata
    return (X, Y, coords)
# Example #4
# 0
def load_brainsonfire(datadir,
                      sessdir=None,
                      sessfn=None,
                      ofs=60,
                      stopband=((0, 1), (25, -1)),
                      subtriallen=10,
                      nvirt=20,
                      chIdx=slice(64),
                      verb=2):
    """Load a 'brains-on-fire' buffer recording as sub-trials with virtual targets.

    Header, data and events are read with the ``read_buffer_offline_*``
    helpers.  Events whose (lower-cased) type equals the module-level
    ``trigger_event`` become per-sample stimulus indicators.

    Args:
        datadir, sessdir, sessfn: path components joined (when given) to
            form the recording file name.
        ofs: target output sample rate in Hz (integer decimation factor).
        stopband: filter specification passed to ``butter_sosfilt``, or None.
        subtriallen: length in seconds of the sub-trials the continuous
            recording is cut into.
        nvirt: number of virtual outputs requested from ``block_randomize``.
        chIdx: channel selection (slice, index array or boolean mask), or
            None to keep all channels.
        verb: verbosity level; > 0 prints processing steps, > 1 also
            prints loading progress.

    Returns:
        (X, Y, coords) where
            X: (nTrl, nSamp, d) EEG data (a single trial if the recording is
               not longer than ``subtriallen``),
            Y: (nTrl, nSamp, nY, e) stimulus events; output 0 comes from the
               trigger events, the rest are randomized copies,
            coords: list of per-dimension meta-info dicts for X, with time
               coordinates in ms.

    Raises:
        ValueError: if the recording contains no trigger events.
    """
    # load the data file
    Xfn = datadir
    if sessdir:
        Xfn = os.path.join(Xfn, sessdir)
    if sessfn:
        Xfn = os.path.join(Xfn, sessfn)

    if verb > 1: print("Loading header")
    hdr = read_buffer_offline_header(Xfn)
    if verb > 1: print("Loading data")
    X = read_buffer_offline_data(Xfn, hdr)  # (nsamp,nch)
    if verb > 1: print("Loading events")
    evts = read_buffer_offline_events(Xfn)

    fs = hdr.fs
    ch_names = hdr.labels

    if chIdx is not None:
        X = X[..., chIdx]
        if ch_names is not None:
            if isinstance(chIdx, slice):
                ch_names = ch_names[chIdx]
            else:
                # a plain list cannot be indexed by an index array / mask
                keep = np.atleast_1d(np.arange(len(ch_names))[chIdx])
                ch_names = [ch_names[i] for i in keep]

    # pre-resample to save memory
    rsrate = int(fs // 120)
    if rsrate > 1:
        if verb > 0:
            print("Pre-re-sample by {}: {}->{}Hz".format(
                rsrate, fs, fs / rsrate))
        X = X[::rsrate, :]
        for e in evts:
            e.sample = e.sample / rsrate
        fs = fs / rsrate

    if verb > 0: print("X={} @{}Hz".format(X.shape, fs), flush=True)

    # extract the trigger info
    trigevts = [e for e in evts if e.type.lower() == trigger_event]
    if not trigevts:
        raise ValueError("No '{}' trigger events found in {}".format(
            trigger_event, Xfn))
    trig_samp = np.array([e.sample for e in trigevts], dtype=int)
    trig_val = [e.value for e in trigevts]
    trig_ind, lab2class = lab2ind(trig_val)  # convert to indicator (ntrig,ncls)
    # up-sample to stim rate
    Y = np.zeros((X.shape[0], trig_ind.shape[-1]), dtype=bool)
    Y[trig_samp, :] = trig_ind
    if verb > 0:
        print("Y={}".format(Y.shape))

    # BODGE: trim to useful data range (1s either side of the triggers)
    first_trig, last_trig = trig_samp.min(), trig_samp.max()
    if (.1 < (first_trig - fs) / X.shape[0]
            or (last_trig + fs) / X.shape[0] < .9):
        if verb > 0:
            print('Trimming range: {}-{}s'.format(first_trig / fs,
                                                  last_trig / fs))
        # limit to the useful data range, clamped to the recording: an
        # unclamped negative start would silently wrap to the end of X
        rng = slice(max(0, int(first_trig - fs)),
                    min(X.shape[0], int(last_trig + fs)))
        X = X[rng, :]
        Y = Y[rng, ...]
        if verb > 0: print("X={}".format(X.shape))
        if verb > 0: print("Y={}".format(Y.shape))

    # preprocess -> spectral filter, in continuous time!
    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)

    # preprocess -> downsample
    resamprate = int(fs / ofs)
    if resamprate > 1:
        if verb > 0:
            print("resample by {}: {}->{}Hz".format(resamprate, fs,
                                                    fs / resamprate))
        X = X[..., ::resamprate, :]  # decimate X (samp, d)
        # re-sample Y, being sure to keep any events in the re-sample window
        Y = window_axis(Y, winsz=resamprate, step=resamprate,
                        axis=-2)  # (samp, win, e)
        # N.B. use max so don't loose single sample events
        Y = np.max(Y, axis=-2)  # (samp, e)
        # decimation keeps a trailing partial window which window_axis may
        # drop: trim both to a common length so X and Y stay sample-aligned
        nsamp = min(X.shape[-2], Y.shape[-2])
        X = X[..., :nsamp, :]
        Y = Y[..., :nsamp, :]
        fs = fs / resamprate

    # make virtual targets
    Y = Y[:, np.newaxis, :]  # (nsamp,1,e)
    Y_virt = block_randomize(Y, nvirt, axis=-3)  # (nsamp,nvirt,e)
    Y = np.concatenate((Y, Y_virt), axis=-2)  # (nsamp,1+nvirt,e)
    if verb > 0: print("Y={}".format(Y.shape))

    # cut into sub-trials
    nsubtrials = X.shape[0] / fs / subtriallen
    if nsubtrials > 1:
        winsz = int(X.shape[0] // nsubtrials)
        if verb > 0: print('subtrial winsz={}'.format(winsz))
        # slice into sub-trials
        X = window_axis(X, axis=0, winsz=winsz, step=winsz)  # (trl,win,d)
        Y = window_axis(Y, axis=0, winsz=winsz, step=winsz)  # (trl,win,nY,e)
        if verb > 0: print("X={}".format(X.shape))
        if verb > 0: print("Y={}".format(Y.shape))
    else:
        # too short to split: keep as one trial so X is always (trl,samp,d)
        X = X[np.newaxis, ...]
        Y = Y[np.newaxis, ...]

    # make coords array for the meta-info about the dimensions of X
    coords = [None] * X.ndim
    coords[0] = {'name': 'trial'}
    coords[1] = {'name': 'time', 'unit': 'ms',
                 'coords': np.arange(X.shape[1]) * 1000 / fs,
                 'fs': fs}
    coords[2] = {'name': 'channel', 'coords': ch_names}
    # return data + metadata
    return (X, Y, coords)
def load_cocktail(datadir,
                  sessdir=None,
                  sessfn=None,
                  ofs=60,
                  stopband=((0, 5), (25, -1)),
                  verb=0,
                  trlen_ms=None,
                  subtriallen=10):
    """Load a cocktail-party subject's runs with per-book audio envelopes.

    Every ``*.mat`` file in the session directory is a run (``...Run<N>.mat``)
    holding ``'eegData'`` and ``'fs'``.  For each run and each entry of the
    module-level ``books`` list, the envelope file
    ``<sessdir>/../../Stimuli/Envelopes/<book>/<book>_<N>_env.mat`` is loaded.
    The attended book is looked up in ``subids_attended_book`` using the
    subject number parsed from the ``Subject<id>_`` part of the session path.

    Args:
        datadir, sessdir, sessfn: path components joined (when given) to
            locate the session; ``datadir`` may start with ``~``.
        ofs: target output sample rate in Hz (integer decimation factor).
        stopband: filter specification passed to ``butter_sosfilt``, or None.
        verb: verbosity level; > 0 prints progress.
        trlen_ms: currently unused.
        subtriallen: length in seconds of the sub-trials each run is cut
            into.

    Returns:
        (X, Y, coords) where
            X: (nTrl, nSamp, d) EEG data,
            Y: (nTrl, nSamp, 1+nBooks, e) envelopes; output 0 is the attended
               book, outputs 1.. are every book in ``books`` order,
            coords: list of per-dimension meta-info dicts for X, with time
               coordinates in ms (no channel names are available).

    Raises:
        FileNotFoundError: if the session directory contains no .mat runs.
        ValueError: if the runs have different sample rates.
    """

    def _scalar(v):
        """Unwrap a size-1 array as returned by loadmat into its scalar."""
        return np.asarray(v).ravel()[0]

    # load the data file
    Xfn = os.path.expanduser(datadir)
    if sessdir:
        Xfn = os.path.join(Xfn, sessdir)
    if sessfn:
        Xfn = os.path.join(Xfn, sessfn)
    sessdir = Xfn if os.path.isdir(Xfn) else os.path.dirname(Xfn)
    stimdir = os.path.join(sessdir, '..', '..', 'Stimuli', 'Envelopes')

    runfns = glob(os.path.join(sessdir, '*.mat'))
    if not runfns:
        raise FileNotFoundError("No run .mat files found in {}".format(sessdir))

    # extract subId (to get attended stream)
    subid = int(sessdir.split("Subject")[1].split("_")[0])
    attended_book = subids_attended_book[subid]
    # extract run ids (from the file name only) and sort into numeric order
    runid = [int(os.path.basename(f).split("Run")[1].split(".mat")[0])
             for f in runfns]
    runfns = sorted(zip(runid, runfns))

    # load the raw EEG data + the envelope of every book for each run
    data = []
    stim = []
    if verb > 0: print("Run:", end='')
    for ri, rf in runfns:
        if verb > 0: print("{} ".format(ri), end='', flush=True)
        data.append(loadmat(rf))
        stim.append([
            loadmat(
                os.path.join(stimdir, book, "{}_{}_env.mat".format(book, ri)))
            for book in books
        ])
    if verb > 0: print()
    # make a label list for the trials
    lab = [books.index(attended_book)] * len(data)
    fs = _scalar(data[0]['fs'])

    if not all(_scalar(run['fs']) == fs for run in data):
        raise ValueError("Different sample rates in different runs")

    # make the X and Y arrays, truncated to the shortest EEG/envelope so
    # every run fits the common (nTr, nSamp, ...) layout
    nSamp = min(min(run['eegData'].shape[0] for run in data),
                min(env['envelope'].shape[0] for envs in stim for env in envs))
    nch = data[0]['eegData'].shape[1]
    nenv = stim[0][0]['envelope'].shape[1]
    X = np.zeros((len(data), nSamp, nch), dtype='float32')
    Y = np.zeros((len(data), nSamp, 1 + len(stim[0]), nenv),
                 dtype='float32')  # (nTr,nSamp,nY,e)
    for ti, (run, envs) in enumerate(zip(data, stim)):
        X[ti, :, :] = run['eegData'][:nSamp, :]
        Y[ti, :, 0, :] = envs[lab[ti]]['envelope'][:nSamp, :]  # objID==0 is attended
        for si, env in enumerate(envs):  # all possible stimuli
            Y[ti, :, si + 1, :] = env['envelope'][:nSamp, :]

    if verb > 0: print("X={}".format(X.shape), flush=True)

    # preprocess -> spectral filter, in continuous time!
    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs)

    # preprocess -> downsample
    resamprate = int(fs / ofs)
    if resamprate > 1:
        if verb > 0:
            print("resample: {}->{}hz rsrate={}".format(fs, fs / resamprate, resamprate))
        X = X[:, ::resamprate, :]  # decimate X (trl, samp, d)
        Y = Y[:, ::resamprate, ...]  # decimate Y (trl, samp, y, e)
        fs = fs / resamprate

    nsubtrials = X.shape[1] / fs / subtriallen
    if nsubtrials > 1:
        winsz = int(X.shape[1] // nsubtrials)
        if verb > 0: print('{} subtrials -> winsz={}'.format(nsubtrials, winsz))
        # slice into sub-trials
        X = window_axis(X, axis=1, winsz=winsz, step=winsz)  # (trl,win,samp,d)
        Y = window_axis(Y, axis=1, winsz=winsz,
                        step=winsz)  # (trl,win,samp,nY,e)
        # concatenate windows into trial dim
        X = X.reshape((X.shape[0] * X.shape[1], ) + X.shape[2:])
        Y = Y.reshape((Y.shape[0] * Y.shape[1], ) + Y.shape[2:])
        if verb > 0: print("X={}".format(X.shape))

    # make coords array for the meta-info about the dimensions of X
    coords = [None] * X.ndim
    coords[0] = {'name': 'trial'}
    coords[1] = {'name': 'time', 'unit': 'ms',
                 'coords': np.arange(X.shape[1]) * 1000 / fs,
                 'fs': fs}
    # no channel names are loaded for this dataset
    coords[2] = {'name': 'channel', 'coords': None}
    # return data + metadata
    return (X, Y, coords)
# Example #6
# 0
def load_p300_prn(datadir, sessdir=None, sessfn=None, ofs=60, offset_ms=(-1000,1000), ifs=None, fr=None, stopband=((0,1), (25,-1)), order=6, subtriallen=10, verb=0, nvirt=20, chidx=slice(64)):
    """Load a P300 (PRN) speller .mat recording as (trial, time, channel) data.

    The .mat file holds ``'X'`` ([ch x samp x trl] raw EEG) and a ``'di'``
    dimension-info struct with channel names, sample times, the sample rate
    and the per-trial flash info (``flash``, ``flashi_ms``, optionally
    ``flipgrid``).  Flash values are placed at the nearest sample and held
    until the next flash.

    Args:
        datadir, sessdir, sessfn: path components joined (when given) to
            form the data file name.
        ofs: target output sample rate in Hz (rounded integer decimation).
        offset_ms: (pre, post) ms around the first/last flash to keep.
        ifs: sample rate override; None reads it from the file.
        fr: currently unused.
        stopband: filter specification passed to ``butter_sosfilt``, or None.
        order: filter order passed to ``butter_sosfilt``.
        subtriallen: sub-trial length in seconds, or None to not split.
        verb: verbosity level; > 0 prints processing steps.
        nvirt: number of virtual outputs made with ``block_randomize`` when
            the file has no ``flipgrid`` info.
        chidx: channel selection, or None to keep all channels.

    Returns:
        (X, Y, coords) where
            X: (nTrl, nSamp, d) EEG data,
            Y: (nTrl, nSamp, nY+1) stimulus; output 0 is the true target,
            coords: list of per-dimension meta-info dicts for X (time in ms).

    Raises:
        NotImplementedError: for MATLAB v7.3 (HDF5) files, which loadmat
            cannot read.
    """

    # load the data file
    Xfn = datadir
    if sessdir:
        Xfn = os.path.join(Xfn, sessdir)
    if sessfn:
        Xfn = os.path.join(Xfn, sessfn)
    try:
        data = loadmat(Xfn)
    except NotImplementedError as err:
        # loadmat raises this for MATLAB v7.3 (HDF5) files; fail clearly here
        # rather than with an unbound `data` below.
        # TODO[] : support these, e.g. via h5py -- HDF5 field access differs from loadmat's
        raise NotImplementedError(
            "Cannot load {}: MATLAB v7.3 (HDF5) .mat files are not supported".format(Xfn)
        ) from err

    X = data['X'].astype("float32")  # [ ch x samp x trl ]:float - raw eeg
    X = np.moveaxis(X, (0, 1, 2), (2, 1, 0))  # (nTrl, nSamp, d)
    X = np.ascontiguousarray(X)  # ensure memory efficient layout

    ch_names = np.stack(data['di'][0]['vals'][0][0]).ravel()
    # Extract the sample rate (deeply nested in the dimension-info struct)
    if ifs is None:
        fs = data['di'][1]['info'][0]['fs'][0, 0][0, 0]
    else:
        fs = ifs

    extrainfo = data['di'][2]['extra'][0]
    try:
        Ye = np.stack(extrainfo['flipgrid'][0], -1)  # (nY,nEp,nTrl)
    except (KeyError, ValueError, IndexError):
        # no usable per-output flip info -> virtual targets are made below
        Ye = None
    Ye0 = np.stack(extrainfo['flash'][0], -1)  # true-target  (1,nEp,nTrl)

    samptimes = data['di'][1]['vals'][0].ravel()  # (nSamp,)
    flashi_ms = np.stack(extrainfo['flashi_ms'][0], -1)  # (1,nEp,nTrl)

    # convert flashi_ms to flashi_samp and upsampled Ye to sample rate
    Ye0 = np.moveaxis(Ye0, (0, 1, 2), (2, 1, 0))  # (nTrl, nEp, 1)
    stimTimes_ms = np.moveaxis(flashi_ms, (0, 1, 2), (2, 1, 0))  # (nTrl, nEp, 1)
    if Ye is not None:
        Ye = np.moveaxis(Ye, (0, 1, 2), (2, 1, 0))  # (nTrl, nEp, nY)
    else:
        # make a pseudo-set of alternative targets
        Ye = block_randomize(Ye0[..., np.newaxis], nvirt, -3)  # (nTrl,nEp,nvirt,1)
        Ye = Ye[..., 0]  # (nTrl,nEp,nvirt)
        if verb > 0: print("{} virt targets".format(Ye.shape[-1]))

    # upsample to sample rate
    stimTimes_samp = np.zeros(stimTimes_ms.shape, dtype=int)  # index from trial start for each flash
    Y = np.zeros(X.shape[:-1] + (Ye.shape[-1] + Ye0.shape[-1],), dtype='float32')  # (nTrl, nSamp, nY+1)
    for ti in range(Y.shape[0]):
        lastflash = None
        for fi, flash_time_ms in enumerate(stimTimes_ms[ti, :, 0]):
            # find nearest sample time
            si = np.argmin(np.abs(samptimes - flash_time_ms))
            stimTimes_samp[ti, fi, 0] = si

            # hold previous values until this flash.  N.B. compare with None:
            # a flash at sample 0 is a valid (but falsy) index.
            if lastflash is not None:
                Y[ti, lastflash + 1:si, :] = Y[ti, lastflash, :]

            Y[ti, si, 0] = Ye0[ti, fi, 0]  # true info always 1st row
            Y[ti, si, 1:] = Ye[ti, fi, :]  # rest possiblities
            lastflash = si

    # preprocess -> ch-seln
    if chidx is not None:
        X = X[..., chidx]
        ch_names = ch_names[chidx]

    # Trim to useful data range
    stimRng = (np.min(stimTimes_samp[:, 0, 0] + offset_ms[0] * fs / 1000),
               np.max(stimTimes_samp[:, -1, 0] + offset_ms[1] * fs / 1000))
    if verb > 0: print("stimRng={}".format(stimRng))
    if 0 < stimRng[0] or stimRng[1] < X.shape[-2]:
        if verb > -1:
            print('Trimming range: {}-{}ms'.format(stimRng[0] * 1000 / fs,
                                                   stimRng[-1] * 1000 / fs))
        # limit to the useful data range
        rng = slice(int(max(0, stimRng[0])), int(min(X.shape[-2], stimRng[1])))
        X = X[..., rng, :]
        Y = Y[..., rng, :]
        if verb > 0: print("X={}".format(X.shape))
        if verb > 0: print("Y={}".format(Y.shape))

    # preprocess -> spectral filter
    if stopband is not None:
        if verb > 0:
            print("preFilter: {}Hz".format(stopband))
        X, _, _ = butter_sosfilt(X, stopband, fs, order=order)

    # preprocess -> downsample
    resamprate = int(round(fs / ofs))
    if resamprate > 1:
        if verb > 0:
            print("resample: {}->{}hz rsrate={}".format(fs, ofs, resamprate))
        X = X[:, ::resamprate, :]  # decimate X (trl, samp, d)
        Y = Y[:, ::resamprate, :]  # decimate Y (trl, samp, y)
        fs = fs / resamprate

    nsubtrials = X.shape[1] / fs / subtriallen if subtriallen is not None else 0
    if nsubtrials > 1:
        winsz = int(X.shape[1] // nsubtrials)
        if verb > 0: print('{} subtrials -> winsz={}'.format(nsubtrials, winsz))
        # slice into sub-trials
        X = window_axis(X, axis=1, winsz=winsz, step=winsz)  # (trl,win,samp,d)
        Y = window_axis(Y, axis=1, winsz=winsz, step=winsz)  # (trl,win,samp,nY)
        # concatenate windows into trial dim
        X = X.reshape((X.shape[0] * X.shape[1],) + X.shape[2:])
        Y = Y.reshape((Y.shape[0] * Y.shape[1],) + Y.shape[2:])
        if verb > 0: print("X={}".format(X.shape))

    # make coords array for the meta-info about the dimensions of X
    coords = [None] * X.ndim
    coords[0] = {'name': 'trial', 'coords': np.arange(X.shape[0])}
    coords[1] = {'name': 'time', 'unit': 'ms',
                 'coords': np.arange(X.shape[1]) * 1000 / fs,
                 'fs': fs}
    coords[2] = {'name': 'channel', 'coords': ch_names}

    return (X, Y, coords)