def _bci2000chain_cleanup(tmpdir, keep, verbose):
    """Delete bci2000chain's temporary directory, or print how to delete it.

    If *keep* is true nothing is deleted; instead the Python commands that
    would remove each file and then the directory itself are printed.
    Failures to delete are reported on stderr rather than raised.
    """
    if not os.path.isdir(tmpdir):
        return
    paths = sorted(os.path.join(tmpdir, name) for name in os.listdir(tmpdir))
    if keep:
        print('The following commands should be executed to clean up the temporary files:')
    elif verbose:
        print('# removing temp files and directory ' + tmpdir)
    for path in paths:
        if keep:
            print("os.remove(r'%s')" % path)
            continue
        try:
            os.remove(path)
        except Exception as exc:
            sys.stderr.write("failed to remove %s:\n %s\n" % (path, str(exc)))
    if keep:
        print("os.rmdir(r'%s')" % tmpdir)
        print("")
    else:
        try:
            os.rmdir(tmpdir)
        except Exception as exc:
            sys.stderr.write("failed to remove %s:\n %s\n" % (tmpdir, str(exc)))


def bci2000chain(datfile, chain='SpectralSignalProcessing', parms=(), dims='auto',
                 start=None, duration=None, verbose=False, keep=False,
                 binpath=None, **kwargs):
    """
    Run a BCI2000 .dat file through a chain of BCI2000 command-line filters.

    This is a provisional Python port of the Matlab-based
    tools/matlab/bci2000chain.m.  Most arguments mirror the flags of the
    Matlab equivalent, mutatis mutandis.  For clues, see
    http://www.bci2000.org/wiki/index.php/User_Reference:Matlab_Tools and the
    Matlab documentation in bci2000chain.m.

    The following example uses relative paths, so it works when the current
    working directory is the root of the BCI2000 distribution::

        bci2000chain(datfile='data/samplefiles/eeg3_2.dat',
                     chain='TransmissionFilter|SpatialFilter|ARFilter',
                     binpath='tools/cmdline',
                     parms=['tools/matlab/ExampleParameters.prm'],
                     SpatialFilterType=3)

    For more portable operation::

        bci2000root('/PATH/TO/BCI2000')
        bci2000chain(datfile=bci2000path('data/samplefiles/eeg3_2.dat'),
                     chain='TransmissionFilter|SpatialFilter|ARFilter',
                     parms=[bci2000path('tools/matlab/ExampleParameters.prm')],
                     SpatialFilterType=3)

    There is currently no global way of managing the system path.  Either add
    the tools/cmdline directory to your $PATH before starting Python, or
    supply its location in ``binpath`` on every call.

    Parameters
    ----------
    datfile : str
        Path of an existing BCI2000 .dat file.
    chain : str or sequence of str
        Filter executables to pipe the data through, either as a sequence or
        as a '|'-separated string.  The names 'SpectralSignalProcessing',
        'ARSignalProcessing' and 'P3SignalProcessing' (case-insensitive)
        expand to the corresponding standard chains.
    parms : item or sequence of items
        Parameter items handed to make_bciprm() to build the input parameter
        file.  Any extra keyword arguments are appended as one more item.
        When nothing is supplied, the .dat file's own parameters are used.
    dims : 2, 3, 'auto' or None
        Layout of ``Signal`` in the result: 2 gives samples-by-channels, 3
        gives blocks-by-channels-by-elements.  'auto'/None guesses: 3 after
        P3TemporalFilter or when the element unit is not a unit of time; 2
        when it is a unit of time; otherwise a guess by filter name.
    start, duration : optional
        Passed to bci_dat2stream as --start / --duration (spaces removed).
    verbose : bool
        Print progress messages and the shell command.
    keep : bool
        If true, temporary files are not deleted; the commands to delete them
        are printed instead.  Forced on when the chain fails.
    binpath : str or None
        Directory containing the command-line tools.  If None and
        BCI2000_ROOT_DIR is set, 'tools/cmdline' is used.  The value is
        resolved through bci2000path().

    Returns
    -------
    On success, a CONTAINER with fields including Signal, States, Parms,
    Time, FullTime, ChannelLabels, ElementValues, ToolVersions, ShellInput
    and ShellOutput.  On failure, the error message string (also printed).

    Raises
    ------
    ValueError
        If datfile is not a string, dims is invalid, or start/duration are
        requested with an old-style bci_dat2stream.
    IOError
        If datfile does not exist.
    """
    verbosityFlag = '-v' if verbose else '-q'

    # Shorthand names for the standard BCI2000 chains, matched case-insensitively.
    named_chains = {
        'spectralsignalprocessing':
            'TransmissionFilter|SpatialFilter|SpectralEstimator|LinearClassifier|LPFilter|ExpressionFilter|Normalizer',
        'arsignalprocessing':
            'TransmissionFilter|SpatialFilter|ARFilter|LinearClassifier|LPFilter|ExpressionFilter|Normalizer',
        'p3signalprocessing':
            'TransmissionFilter|SpatialFilter|P3TemporalFilter|LinearClassifier',
    }
    if isinstance(chain, str):
        chain = named_chains.get(chain.lower(), chain).split('|')
    chain = [c.strip() for c in chain if len(c.strip())]
    if len(chain) == 0:
        print('WARNING: chain is empty')

    if start is not None and len(str(start).strip()):
        start = ' --start=' + str(start).replace(' ', '')
    else:
        start = ''
    if duration is not None and len(str(duration).strip()):
        duration = ' --duration=' + str(duration).replace(' ', '')
    else:
        duration = ''

    if dims is None or str(dims).lower() == 'auto':
        dims = 0  # 0 means "guess from the chain output" (see below)
    if dims not in (0, 2, 3):
        raise ValueError("dims must be 2, 3 or 'auto'")

    # Validate the input before creating the temp directory, so these errors
    # cannot leave an orphaned directory behind.
    if not isinstance(datfile, str):
        raise ValueError('datfile must be a filename')
    # TODO: if datfile contains the appropriate info, use some create_bcidat
    # equivalent and do datfile = os.path.join(tmpdir, 'in.dat')
    if not os.path.isfile(datfile):
        raise IOError('file not found: %s' % datfile)

    out = CONTAINER()
    err = ''
    cmd = ''
    binaries = []

    tmpdir = tempfile.mkdtemp(prefix='bci2000chain_')
    try:
        prmfile_in = os.path.join(tmpdir, 'in.prm')
        prmfile_out = os.path.join(tmpdir, 'out.prm')
        matfile = os.path.join(tmpdir, 'out.mat')
        bcifile = os.path.join(tmpdir, 'out.bci')
        shfile = os.path.join(tmpdir, 'go.bat')
        logfile = os.path.join(tmpdir, 'log.txt')

        # Placeholders used while building the command; substituted at the end.
        mappings = {
            '$DATFILE': datfile,
            '$PRMFILE_IN': prmfile_in,
            '$PRMFILE_OUT': prmfile_out,
            '$MATFILE': matfile,
            '$BCIFILE': bcifile,
            '$SHFILE': shfile,
            '$LOGFILE': logfile,
        }

        if binpath is None and BCI2000_ROOT_DIR is not None:
            binpath = 'tools/cmdline'
        binpath = bci2000path(binpath)

        def exe(name):
            """Return the (quoted, if binpath is set) shell token for a tool."""
            if binpath:
                return '"' + os.path.join(binpath, name) + '"'
            return name

        # Normalise parms to a fresh list so appending kwargs never mutates
        # the caller's object.
        if parms is None:
            parms = []
        elif isinstance(parms, (tuple, list)):
            parms = list(parms)
        else:
            parms = [parms]
        if len(kwargs):
            parms.append(kwargs)

        if not parms:
            cmd += exe('bci_dat2stream') + start + duration + ' < "$DATFILE"'
            binaries.append('bci_dat2stream')
        else:
            if verbose:
                print('# writing custom parameter file %s' % prmfile_in)
            parms = make_bciprm(verbosityFlag, datfile, parms, '>', prmfile_in)
            if dat2stream_has_p_flag:
                # new-style bci_dat2stream with -p option
                cmd += (exe('bci_dat2stream') + ' "-p$PRMFILE_IN"' + start
                        + duration + ' < "$DATFILE"')
                binaries.append('bci_dat2stream')
            else:
                # old-style bci_dat2stream with no -p option: stream the
                # parameters first, then the data.
                if len(start) or len(duration):
                    raise ValueError(
                        'old versions of bci_dat2stream have no --start or --duration option')
                cmd += '('
                cmd += exe('bci_prm2stream') + ' < "$PRMFILE_IN"'
                cmd += ' && ' + exe('bci_dat2stream') + ' --transmit-sd < "$DATFILE"'
                cmd += ')'
                binaries.append('bci_dat2stream')
                binaries.append('bci_prm2stream')

        for c in chain:
            cmd += ' | ' + exe(c)
            binaries.append(c)

        if stream2mat_saves_parms:
            # new-style bci_stream2mat with Parms output
            cmd += ' | ' + exe('bci_stream2mat') + ' > "$MATFILE"'
            binaries.append('bci_stream2mat')
        else:
            # old-style bci_stream2mat without Parms output
            cmd += ' > "$BCIFILE"'
            cmd += ' && ' + exe('bci_stream2mat') + ' < "$BCIFILE" > "$MATFILE"'
            cmd += ' && ' + exe('bci_stream2prm') + ' < "$BCIFILE" > "$PRMFILE_OUT"'
            binaries.append('bci_stream2mat')
            binaries.append('bci_stream2prm')

        for k, v in list(mappings.items()):
            cmd = cmd.replace(k, v)

        win = sys.platform.lower().startswith('win')
        # '@' stops cmd.exe echoing the line; elsewhere a POSIX shebang is used.
        shebang = '@' if win else '#!/bin/sh\n\n'
        with open(shfile, 'wt') as fh:
            fh.write(shebang + cmd + '\n')
        if not win:
            os.chmod(shfile, 0o744)  # rwxr--r--

        def tidytext(x):
            """Strip surrounding whitespace and normalise line endings to '\\n'."""
            return x.strip().replace('\r\n', '\n').replace('\r', '\n')

        def getstatusoutput(command):
            """Run *command* in a shell; return (exit status or 0, combined output)."""
            # TODO: does this work on Windows? could always make use of logfile here if not
            pipe = os.popen(command + ' 2>&1', 'r')
            output = pipe.read()
            status = pipe.close()
            if status is None:
                status = 0
            return status, tidytext(output)

        def getoutput(command):
            return getstatusoutput(command)[1]

        if verbose:
            print('# querying version information')
        binaries = CONTAINER([
            (b, getoutput(exe(b) + ' --version').replace('\n', ' '))
            for b in binaries
        ])

        if verbose:
            print(cmd)
        t0 = time.time()
        # Quote the script path: the temp directory may contain spaces.
        failed, output = getstatusoutput('"%s"' % shfile)
        chaintime = time.time() - t0

        failsig = 'Configuration Error: '
        if failsig in output:
            failed = 1
        printable_output = output
        printable_lines = output.split('\n')
        maxlines = 10
        if len(printable_lines) > maxlines:
            printable_output = '\n'.join(
                printable_lines[:maxlines]
                + ['[%d more lines omitted]' % (len(printable_lines) - maxlines)])

        if failed:
            if verbose:
                # cmd has already been printed, so don't clutter things further
                err = 'system call failed:\n' + printable_output
            else:
                err = 'system call failed:\n\n%s\n\n%s' % (cmd, printable_output)

        if err == '':
            if verbose:
                print('# loading %s' % matfile)
            try:
                mat = ReadMatFile(matfile)
            except Exception:
                err = "The chain must have failed: could not load %s\nShell output was as follows:\n%s" % (
                    matfile, printable_output)
            else:
                if 'Data' not in mat:
                    err = "The chain must have failed: no 'Data' variable found in %s\nShell output was as follows:\n%s" % (
                        matfile, printable_output)
                if 'Index' not in mat:
                    err = "The chain must have failed: no 'Index' variable found in %s\nShell output was as follows:\n%s" % (
                        matfile, printable_output)

        if err == '':
            out.FileName = datfile
            if stream2mat_saves_parms:
                if verbose:
                    print('# decoding parameters loaded from the mat-file')
                parms = make_bciprm(verbosityFlag, mat['Parms'][0])
            else:
                if verbose:
                    print('# reading output parameter file ' + prmfile_out)
                # if you get an error that prmfile_out does not exist, recompile
                # your bci_dat2stream and bci_stream2mat binaries from
                # up-to-date sources, and ensure that dat2stream_has_p_flag and
                # stream2mat_saves_parms, at the top of this file, are both set to 1
                parms = ParmList(prmfile_out)

            out.DateStr = read_bcidate(parms, 'ISO')
            out.DateNum = read_bcidate(parms)
            out.FilterChain = chain
            out.ToolVersions = binaries
            out.ShellInput = cmd
            out.ShellOutput = output
            out.ChainTime = chaintime
            out.ChainSpeedFactor = None
            out.Megabytes = None
            out.Parms = parms
            out.Parms.sort()

            mat['Index'] = mat['Index'][0, 0]
            # indices vary across channels fastest, then elements
            sigind = mat['Index'].Signal - 1
            nChannels, nElements = sigind.shape
            nBlocks = mat['Data'].shape[1]

            out.Blocks = nBlocks
            out.BlocksPerSecond = float(parms.SamplingRate.ScaledValue) / float(
                parms.SampleBlockSize.ScaledValue)
            out.SecondsPerBlock = float(parms.SampleBlockSize.ScaledValue) / float(
                parms.SamplingRate.ScaledValue)
            if out.ChainTime > 0:  # guard against a zero-duration timer reading
                out.ChainSpeedFactor = float(out.Blocks * out.SecondsPerBlock) / float(
                    out.ChainTime)

            def unnumpify(x):
                """Unwrap size-1 arrays and convert nested arrays/sequences to lists."""
                while isinstance(x, numpy.ndarray) and x.size == 1:
                    x = x[0]
                if isinstance(x, (numpy.ndarray, tuple, list)):
                    x = [unnumpify(xi) for xi in x]
                return x

            out.Channels = nChannels
            out.ChannelLabels = unnumpify(mat.get('ChannelLabels', []))
            try:
                out.ChannelSet = SigTools.ChannelSet(out.ChannelLabels)
            except Exception:
                out.ChannelSet = None
            out.Elements = nElements
            out.ElementLabels = unnumpify(mat.get('ElementLabels', []))
            out.ElementValues = numpy.ravel(mat.get('ElementValues', []))
            out.ElementUnit = unnumpify(mat.get('ElementUnit', None))
            out.ElementRate = out.BlocksPerSecond * out.Elements
            out.Time = out.SecondsPerBlock * numpy.arange(0, nBlocks)
            out.FullTime = out.Time
            out.FullElementValues = out.ElementValues

            # sigind is transposed before ravel() so that the row order matches
            # Matlab's column-major sigind(:) in the original bci2000chain.m.
            # Result: nChannels*nElements - by - nBlocks
            out.Signal = mat['Data'][sigind.T.ravel(), :]
            out.Signal = out.Signal + 0.0  # make a contiguous copy

            def seconds(s):
                """Scale factor of s.ElementUnit to seconds.

                -1 means "no information", 0 means "not units of time",
                >0 is the scaling factor.
                """
                if getattr(s, 'ElementUnit', None) in ('', None, ()) or s.ElementUnit == []:
                    return -1
                s = s.ElementUnit
                if s.endswith('seconds'):
                    s = s[:-6]
                elif s.endswith('second'):
                    s = s[:-5]
                elif s.endswith('sec'):
                    s = s[:-2]
                if s.endswith('s'):
                    return {
                        'ps': 1e-12, 'ns': 1e-9, 'us': 1e-6, 'mus': 1e-6,
                        'ms': 1e-3, 's': 1e+0, 'ks': 1e+3, 'Ms': 1e+6,
                        'Gs': 1e+9, 'Ts': 1e+12,
                    }.get(s, 0)
                return 0

            if dims == 0:
                # Dimensionality not specified explicitly: guess, based on
                # ElementUnit and/or filter name.  3-D output makes more sense
                # than continuous 2-D whenever "elements" can't just be
                # concatenated into an unbroken time-stream.
                lastfilter = chain[-1].lower() if len(chain) else ''
                if lastfilter == 'p3temporalfilter':
                    dims = 3
                else:
                    factor = seconds(out)
                    if factor > 0:
                        # units of time. TODO: could detect whether the
                        # out.ElementValues*factor are (close enough to)
                        # contiguous from block to block; then p3temporalfilter
                        # wouldn't have to be a special case above
                        dims = 2
                    elif factor == 0:
                        # not units of time: use 3D by default
                        dims = 3
                    elif lastfilter in ['p3temporalfilter', 'arfilter', 'fftfilter',
                                        'coherencefilter', 'coherencefftfilter']:
                        # no ElementUnit info? guess based on filter name
                        dims = 3
                    else:
                        dims = 2

            if dims == 3:
                # nChannels - by - nElements - by - nBlocks
                out.Signal = numpy.reshape(out.Signal, (nChannels, nElements, nBlocks), order='F')
                # nBlocks - by - nChannels - by - nElements
                out.Signal = numpy.transpose(out.Signal, (2, 0, 1))
                out.Signal = out.Signal + 0.0  # make a contiguous copy
            elif dims == 2:
                out.FullTime = numpy.repeat(out.Time, nElements)
                factor = seconds(out)
                if len(out.ElementValues):
                    out.FullElementValues = numpy.tile(out.ElementValues, nBlocks)
                    if factor > 0:
                        out.FullTime = out.FullTime + out.FullElementValues * factor
                # nChannels - by - nSamples
                out.Signal = numpy.reshape(out.Signal, (nChannels, nElements * nBlocks), order='F')
                # nSamples - by - nChannels
                out.Signal = numpy.transpose(out.Signal, (1, 0))
                out.Signal = out.Signal + 0.0  # make a contiguous copy
            else:
                raise RuntimeError('internal error')

            out.States = CONTAINER()
            try:
                fieldnames = mat['Index']._fieldnames
            except AttributeError:
                items = list(mat['Index'].items())
            else:
                items = [(k, getattr(mat['Index'], k)) for k in fieldnames]
            states = [(k, int(v) - 1) for k, v in items if k != 'Signal']
            for k, v in states:
                # TODO: how do the command-line tools handle event states? this
                # seems to be set up to deliver one value per block whatever
                # kind of state we're dealing with
                setattr(out.States, k, mat['Data'][v, :])

            out.Megabytes = megs(out)
        else:
            out = err
            keep = True  # leave the temp files in place so the failure can be inspected
            print(err)
            print('')
    finally:
        # Runs on success, on reported failure and on exceptions alike, so the
        # temp directory is never silently leaked.
        _bci2000chain_cleanup(tmpdir, keep, verbose)
    return out
def ClassifyERPs(
    featurefiles,
    C=(10.0, 1.0, 0.1, 0.01),
    gamma=(1.0, 0.8, 0.6, 0.4, 0.2, 0.0),
    keepchan=(),
    rmchan=(),
    rmchan_usualsuspects=('AUDL', 'AUDR', 'LAUD', 'RAUD', 'SYNC', 'VSYNC', 'VMRK', 'OLDREF'),
    rebias=True,
    save=False,
    select=False,
    description='ERPs to attended vs unattended events',
    maxcount=None,
    classes=None,
    folds=None,
    time_window=None,
    keeptrials=None,
):
    """
    Train and cross-validate a binary ERP classifier from feature files.

    Trial data are loaded with DataFiles.load(featurefiles, catdim=0).  The
    loaded data must provide 'x' (indexed below as trials x channels x
    samples), 'y' (trial labels with exactly two distinct values),
    'channels' and 'fs'.  A SigTools.klr2class classifier is then
    cross-validated over the candidate hyperparameters.

    Parameters
    ----------
    featurefiles : file name(s) passed to DataFiles.load.
    C : value or sequence of candidate values for the classifier's C
        hyperparameter; None leaves the classifier's own setting untouched.
    gamma : value or sequence of candidate values passed to
        SigTools.shrinkcov via SigTools.symwhitenkern; None selects a linear
        kernel (SigTools.linkern) instead.
    keepchan : if non-empty (str split on whitespace, or sequence), only these
        channels (case-insensitive) are used.
    rmchan, rmchan_usualsuspects : channels to exclude when keepchan is empty.
    rebias : call c.rebias() after cross-validated training.
    save : if true, write weights, bias and description to a .prm file.  A
        string gives the path; otherwise it is derived from the (last)
        feature file name with '_weights.prm' appended.
    select : if true, also copy the saved file to this name (default
        'ChosenWeights.prm'), relative to the directory of the saved file.
    description : text embedded in the SignalProcessingDescription parameter.
    maxcount : forwarded to DataFiles.load.
    classes : if given, keep only trials whose label is in this collection.
    folds : forwarded to c.cvtrain.  The strings 'lofo', 'loro',
        'leave one file out', 'leave one run out' (case-insensitive) create
        one fold per input file.
    time_window : (min, max) in msec (per SigTools.samples2msec); samples
        outside this window are zeroed.
    keeptrials : index array selecting the trials to use.

    Returns
    -------
    (u, c) : u is the SigTools.stfac result, annotated with channels,
        channels_used, fs, trchvar and description; c is the trained classifier.

    Raises
    ------
    ValueError if the data lack 'x' or 'y', a requested class is missing,
    there are not exactly two classes, a class has fewer than 2 exemplars,
    or no channels are left after channel selection.
    """
    try:
        string_types = basestring  # Python 2
    except NameError:
        string_types = str  # Python 3

    def plural(count):
        return '' if count == 1 else 's'

    file_inventory = []
    d = DataFiles.load(featurefiles, catdim=0, maxcount=maxcount, return_details=file_inventory)

    lofo_names = ['lofo', 'loro', 'leave on run out', 'leave one run out', 'leave one file out']
    if isinstance(folds, string_types) and folds.lower() in lofo_names:
        # One cross-validation fold per input file.
        # NOTE(review): these indices are computed before keeptrials/classes
        # filtering below, so they may not line up with the trials actually
        # used if either option removes trials -- confirm before combining.
        n, folds = 0, []
        for each in file_inventory:
            neach = each[1]['x']  # presumably this file's trial count -- confirm
            folds.append(list(range(n, n + neach)))
            n += neach

    if 'x' not in d:
        raise ValueError("found no trial data - no 'x' variable - in the specified files")
    if 'y' not in d:
        raise ValueError("found no trial labels - no 'y' variable - in the specified files")
    x = d['x']
    y = numpy.array(d['y'].flat)

    # Use "is not None": these arguments may be numpy arrays, for which
    # "!= None" is element-wise and has no single truth value.
    if keeptrials is not None:
        keep_idx = numpy.asarray(keeptrials)
        x = x[keep_idx, :, :]
        y = y[keep_idx]

    if time_window is not None:
        fs = d['fs']
        t = SigTools.samples2msec(numpy.arange(x.shape[2]), fs)
        x[:, :, t < min(time_window)] = 0
        x[:, :, t > max(time_window)] = 0

    if classes is not None:
        for cl in classes:
            if cl not in y:
                raise ValueError("class %s is not in the dataset" % str(cl))
        mask = numpy.array([yi in classes for yi in y], dtype=bool)
        y = y[mask]
        x = x[mask]
        discarded = len(mask) - numpy.count_nonzero(mask)
        if discarded:
            print("discarding %d trials that are outside the requested classes %s" % (
                discarded, str(classes)))

    n = len(y)
    uy = numpy.unique(y)
    if uy.size != 2:
        raise ValueError("expected 2 classes in dataset, found %d : %s" % (uy.size, str(uy)))
    nyi_min = 2
    for uyi in uy:
        nyi = numpy.count_nonzero(y == uyi)
        if nyi < nyi_min:
            raise ValueError('only %d exemplars of class %s - need at least %d' % (
                nyi, str(uyi), nyi_min))
    y = numpy.sign(y - uy.mean())  # map the two labels to -1 / +1

    # NB: symwhitenkern would not be able to balance
    cov, trchvar = SigTools.spcov(x=x, y=y, balance=False, return_trchvar=True)

    starttime = time.time()
    chlower = [ch.lower() for ch in d['channels']]
    if keepchan is None or len(keepchan) == 0:
        if isinstance(rmchan, string_types):
            rmchan = rmchan.split()
        if isinstance(rmchan_usualsuspects, string_types):
            rmchan_usualsuspects = rmchan_usualsuspects.split()
        allrmchan = [ch.lower() for ch in list(rmchan) + list(rmchan_usualsuspects)]
        unwanted = numpy.array([ch in allrmchan for ch in chlower])
        notfound = [ch for ch in rmchan if ch.lower() not in chlower]
    else:
        if isinstance(keepchan, string_types):
            keepchan = keepchan.split()
        lowerkeepchan = [ch.lower() for ch in keepchan]
        unwanted = numpy.array([ch not in lowerkeepchan for ch in chlower])
        notfound = [ch for ch in keepchan if ch.lower() not in chlower]
    wanted = numpy.logical_not(unwanted)

    print(' ')
    if len(notfound):
        print("WARNING: could not find channel%s %s\n" % (plural(len(notfound)), ', '.join(notfound)))
    removed = [ch for removing, ch in zip(unwanted, d['channels']) if removing]
    if len(removed):
        print("removed %d channel%s (%s)" % (len(removed), plural(len(removed)), ', '.join(removed)))
    print("classification will be based on %d channel%s" % (sum(wanted), plural(sum(wanted))))
    print("%d negatives + %d positives = %d exemplars" % (sum(y < 0), sum(y > 0), n))
    print(' ')
    if not numpy.any(wanted):
        raise ValueError('no channels left to classify after channel selection')

    x[:, unwanted, :] = 0
    cov[:, unwanted] = 0
    cov[unwanted, :] = 0
    # Give removed channels the mean variance of the kept ones so the
    # covariance stays non-singular.
    nu = numpy.asarray(cov).diagonal()[wanted].mean()
    for i in range(len(cov)):
        if cov[i, i] == 0:
            cov[i, i] = nu

    if not isinstance(C, (tuple, list, numpy.ndarray, type(None))):
        C = [C]
    if not isinstance(gamma, (tuple, list, numpy.ndarray, type(None))):
        gamma = [gamma]

    c = SigTools.klr2class(lossfunc=SigTools.balanced_loss, relcost='balance')
    c.varyhyper({})
    if C is not None:  # was "c != None" (always true), which crashed on C=None
        c.hyper.C = list(C)
    if gamma is None:
        c.hyper.kernel.func = SigTools.linkern
    else:
        c.varyhyper({
            'kernel.func': SigTools.symwhitenkern,
            'kernel.cov': [cov],
            'kernel.gamma': list(gamma),
        })
    c.cvtrain(x=x, y=y, folds=folds)
    if rebias:
        c.rebias()
    c.calibrate()

    chosen = c.cv.chosen.hyper
    if gamma is None:
        Ps = None
        Gp = c.featureweight(x=x)
    else:
        Ps = SigTools.svd(SigTools.shrinkcov(cov, copy=True, gamma=chosen.kernel.gamma)).isqrtm
        xp = SigTools.spfilt(x, Ps.H, copy=True)
        Gp = c.featureweight(x=xp)

    u = SigTools.stfac(Gp, Ps)
    u.channels = d['channels']
    u.channels_used = wanted
    u.fs = d['fs']
    u.trchvar = trchvar
    try:
        u.channels = SigTools.ChannelSet(u.channels)
    except Exception:
        print('WARNING: failed to convert channels to ChannelSet')

    elapsed = time.time() - starttime
    minutes = int(elapsed / 60.0)
    seconds = int(round(elapsed - minutes * 60.0))
    print('%d min %d sec' % (minutes, seconds))

    datestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    csummary = '%s (%s) trained on %d (CV %s = %.3f) at %s' % (
        c.__class__.__name__,
        SigTools.experiment()._shortdesc(chosen),
        sum(c.input.istrain),
        c.loss.func.__name__,
        c.loss.train,
        datestamp,
    )
    description = 'binary classification of %s: %s' % (description, csummary)
    u.description = description

    if save or select:
        if not isinstance(save, string_types):
            # Derive the output name from the (last) feature file.
            save = featurefiles
            if isinstance(save, (tuple, list)):
                save = save[-1]
            if save.lower().endswith('.gz'):
                save = save[:-3]
            if save.lower().endswith('.pk'):
                save = save[:-3]
            save = save + '_weights.prm'
        print("\nsaving %s\n" % save)
        Parameters.Param(u.G.A, Name='ERPClassifierWeights', Section='PythonSig',
                         Subsection='Epoch', Comment=csummary).write_to(save)
        Parameters.Param(c.model.bias, Name='ERPClassifierBias', Section='PythonSig',
                         Subsection='Epoch', Comment=csummary).append_to(save)
        Parameters.Param(description, Name='SignalProcessingDescription',
                         Section='PythonSig').append_to(save)
        if select:
            if not isinstance(select, string_types):
                select = 'ChosenWeights.prm'
            if not os.path.isabs(select):
                select = os.path.join(os.path.split(save)[0], select)
            print("saving %s\n" % select)
            import shutil
            shutil.copyfile(save, select)

    print(description)
    return u, c