Example #1
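# Note: this function is an excerpt from a larger module.  It assumes that os, sys, time,
# tempfile and numpy are imported at module level, and that the helpers it calls
# (CONTAINER, bci2000path, bci2000root, make_bciprm, ParmList, ReadMatFile, read_bcidate,
# megs, SigTools) as well as the flags BCI2000_ROOT_DIR, dat2stream_has_p_flag and
# stream2mat_saves_parms are defined or imported elsewhere in that module.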
def bci2000chain(datfile,
                 chain='SpectralSignalProcessing',
                 parms=(),
                 dims='auto',
                 start=None,
                 duration=None,
                 verbose=False,
                 keep=False,
                 binpath=None,
                 **kwargs):
    """
	
	This is a provisional Python port of the Matlab-based tools/matlab/bci2000chain.m
	
	Excuse the relative paths - this example works if you're currently working in the root
	directory of the BCI2000 distro:
	
	    bci2000chain( datfile='data/samplefiles/eeg3_2.dat',
	                  chain='TransmissionFilter|SpatialFilter|ARFilter',
	                  binpath='tools/cmdline',
	                  parms=[ 'tools/matlab/ExampleParameters.prm' ], SpatialFilterType=3 )
	
	Or for more portable operation, you can do this kind of thing:
	
	    bci2000root( '/PATH/TO/BCI2000' )
	    
	    bci2000chain(datfile=bci2000path( 'data/samplefiles/eeg3_2.dat' ),
	                 chain='TransmissionFilter|SpatialFilter|ARFilter',
	                 parms=[ bci2000path( 'tools/matlab/ExampleParameters.prm' ) ], SpatialFilterType=3 )
	
	
	For clues, see http://www.bci2000.org/wiki/index.php/User_Reference:Matlab_Tools
	and matlab documentation in bci2000chain.m
	
	Most arguments are like the flags in the Matlab equivalent (mutatis, to a large extent, mutandis).
	Note that, for now, there is no global way of managing the system path: either add the
	tools/cmdline directory to your $PATH variable before starting Python, or supply its location
	on every call via the <binpath> argument.

	"""###

    if verbose: verbosityFlag = '-v'
    else: verbosityFlag = '-q'

    if isinstance(chain, str):
        if chain.lower() == 'SpectralSignalProcessing'.lower():
            chain = 'TransmissionFilter|SpatialFilter|SpectralEstimator|LinearClassifier|LPFilter|ExpressionFilter|Normalizer'
        elif chain.lower() == 'ARSignalProcessing'.lower():
            chain = 'TransmissionFilter|SpatialFilter|ARFilter|LinearClassifier|LPFilter|ExpressionFilter|Normalizer'
        elif chain.lower() == 'P3SignalProcessing'.lower():
            chain = 'TransmissionFilter|SpatialFilter|P3TemporalFilter|LinearClassifier'
        chain = chain.split('|')
    chain = [c.strip() for c in chain if len(c.strip())]

    if len(chain) == 0: print('WARNING: chain is empty')

    if start != None and len(str(start).strip()):
        start = ' --start=' + str(start).replace(' ', '')
    else:
        start = ''
    if duration != None and len(str(duration).strip()):
        duration = ' --duration=' + str(duration).replace(' ', '')
    else:
        duration = ''

    if dims == None or str(dims).lower() == 'auto': dims = 0
    if dims not in (0, 2, 3): raise ValueError("dims must be 2, 3 or 'auto'")

    out = CONTAINER()
    err = ''

    cmd = ''
    binaries = []
    tmpdir = tempfile.mkdtemp(prefix='bci2000chain_')

    tmpdatfile = os.path.join(tmpdir, 'in.dat')
    prmfile_in = os.path.join(tmpdir, 'in.prm')
    prmfile_out = os.path.join(tmpdir, 'out.prm')
    matfile = os.path.join(tmpdir, 'out.mat')
    bcifile = os.path.join(tmpdir, 'out.bci')
    shfile = os.path.join(tmpdir, 'go.bat')
    logfile = os.path.join(tmpdir, 'log.txt')

    if not isinstance(datfile, str):
        raise ValueError(
            'datfile must be a filename'
        )  # TODO: if datfile contains the appropriate info, use some create_bcidat equivalent and do datfile = tmpdatfile

    if not os.path.isfile(datfile):
        raise IOError('file not found: %s' % datfile)

    mappings = {
        '$DATFILE': datfile,
        '$PRMFILE_IN': prmfile_in,
        '$PRMFILE_OUT': prmfile_out,
        '$MATFILE': matfile,
        '$BCIFILE': bcifile,
        '$SHFILE': shfile,
        '$LOGFILE': logfile,
    }

    if binpath == None and BCI2000_ROOT_DIR != None: binpath = 'tools/cmdline'
    binpath = bci2000path(binpath)

    def exe(name):
        if binpath: return '"' + os.path.join(binpath, name) + '"'
        else: return name

    if parms == None: parms = []
    if isinstance(parms, tuple): parms = list(parms)
    if not isinstance(parms, list): parms = [parms]
    else: parms = list(parms)
    if len(kwargs): parms.append(kwargs)

    if parms == None or len(parms) == 0:
        cmd += exe('bci_dat2stream') + start + duration + ' < "$DATFILE"'
        binaries.append('bci_dat2stream')
    else:
        if verbose: print('# writing custom parameter file %s' % prmfile_in)
        parms = make_bciprm(verbosityFlag, datfile, parms, '>', prmfile_in)

        if dat2stream_has_p_flag:
            # new-style bci_dat2stream with -p option
            cmd += exe('bci_dat2stream') + ' "-p$PRMFILE_IN"' + start + duration + ' < "$DATFILE"'
            binaries.append('bci_dat2stream')
        else:
            if len(start) or len(duration):
                raise ValueError(
                    'old versions of bci_dat2stream have no --start or --duration option'
                )
            cmd += '('  # old-style bci_dat2stream with no -p option
            cmd += exe('bci_prm2stream') + ' < "$PRMFILE_IN"'
            cmd += '&& ' + exe(
                'bci_dat2stream') + ' --transmit-sd < "$DATFILE"'
            cmd += ')'
            binaries.append('bci_dat2stream')
            binaries.append('bci_prm2stream')

    for c in chain:
        cmd += ' | ' + exe(c)
        binaries.append(c)

    if stream2mat_saves_parms:
        # new-style bci_stream2mat with Parms output
        cmd += ' | ' + exe('bci_stream2mat') + ' > "$MATFILE"'
        binaries.append('bci_stream2mat')
    else:
        cmd += ' > "$BCIFILE"'
        cmd += ' && ' + exe('bci_stream2mat') + ' < "$BCIFILE" > "$MATFILE"'
        # old-style bci_stream2mat without Parms output
        cmd += ' && ' + exe('bci_stream2prm') + ' < "$BCIFILE" > "$PRMFILE_OUT"'
        binaries.append('bci_stream2mat')
        binaries.append('bci_stream2prm')

    for k, v in list(mappings.items()):
        cmd = cmd.replace(k, v)

    win = sys.platform.lower().startswith('win')
    if win: shebang = '@'
    else: shebang = '#!/bin/sh\n\n'
    open(shfile, 'wt').write(shebang + cmd + '\n')
    if not win: os.chmod(shfile, 0o744)  # rwxr--r--

    def tidytext(x):
        return x.strip().replace('\r\n', '\n').replace('\r', '\n')

    def getstatusoutput(cmd):
        # TODO: does this work on Windows? could always make use of logfile here if not
        pipe = os.popen(cmd + ' 2>&1', 'r')
        output = pipe.read()
        status = pipe.close()
        if status == None: status = 0
        return status, tidytext(output)

    def getoutput(cmd):
        return getstatusoutput(cmd)[1]

    if verbose: print('# querying version information')
    binaries = CONTAINER([
        (bin, getoutput(exe(bin) + ' --version').replace('\n', '  '))
        for bin in binaries
    ])

    if verbose: print(cmd)
    t0 = time.time()
    failed, output = getstatusoutput(shfile)
    chaintime = time.time() - t0

    failsig = 'Configuration Error: '
    if failsig in output: failed = 1
    printable_output = output
    printable_lines = output.split('\n')
    maxlines = 10
    if len(printable_lines) > maxlines:
        printable_output = '\n'.join(
            printable_lines[:maxlines] +
            ['[%d more lines omitted]' % (len(printable_lines) - maxlines)])
    if failed:
        if verbose:
            err = 'system call failed:\n' + printable_output  # cmd has already been printed, so don't clutter things further
        else:
            err = 'system call failed:\n\n%s\n\n%s' % (cmd, printable_output)

    if err == '':
        if verbose: print('# loading %s' % matfile)
        try:
            mat = ReadMatFile(matfile)
        except:
            err = "The chain must have failed: could not load %s\nShell output was as follows:\n%s" % (
                matfile, printable_output)
        else:
            if 'Data' not in mat:
                err = "The chain must have failed: no 'Data' variable found in %s\nShell output was as follows:\n%s" % (
                    matfile, printable_output)
            if 'Index' not in mat:
                err = "The chain must have failed: no 'Index' variable found in %s\nShell output was as follows:\n%s" % (
                    matfile, printable_output)

    if err == '':
        out.FileName = datfile
        if stream2mat_saves_parms:
            if verbose: print('# decoding parameters loaded from the mat-file')
            parms = make_bciprm(verbosityFlag, mat['Parms'][0])
        else:
            if verbose: print('# reading output parameter file ' + prmfile_out)
            # if you get an error that prmfile_out does not exist, recompile your bci_dat2stream and
            # bci_stream2mat binaries from up-to-date sources, and ensure that dat2stream_has_p_flag
            # and stream2mat_saves_parms, at the top of this file, are both set to 1
            parms = ParmList(prmfile_out)

        out.DateStr = read_bcidate(parms, 'ISO')
        out.DateNum = read_bcidate(parms)
        out.FilterChain = chain
        out.ToolVersions = binaries
        out.ShellInput = cmd
        out.ShellOutput = output
        out.ChainTime = chaintime
        out.ChainSpeedFactor = None
        out.Megabytes = None
        out.Parms = parms
        out.Parms.sort()

        mat['Index'] = mat['Index'][0, 0]
        sigind = mat['Index'].Signal - 1  # indices vary across channels fastest, then elements
        nChannels, nElements = sigind.shape
        nBlocks = mat['Data'].shape[1]
        out.Blocks = nBlocks
        out.BlocksPerSecond = float(parms.SamplingRate.ScaledValue) / float(
            parms.SampleBlockSize.ScaledValue)
        out.SecondsPerBlock = float(parms.SampleBlockSize.ScaledValue) / float(
            parms.SamplingRate.ScaledValue)
        out.ChainSpeedFactor = float(out.Blocks * out.SecondsPerBlock) / float(
            out.ChainTime)

        def unnumpify(x):
            while isinstance(x, numpy.ndarray) and x.size == 1:
                x = x[0]
            if isinstance(x, (numpy.ndarray, tuple, list)):
                x = [unnumpify(xi) for xi in x]
            return x

        out.Channels = nChannels
        out.ChannelLabels = unnumpify(mat.get('ChannelLabels', []))
        try:
            out.ChannelSet = SigTools.ChannelSet(out.ChannelLabels)
        except:
            out.ChannelSet = None

        out.Elements = nElements
        out.ElementLabels = unnumpify(mat.get('ElementLabels', []))
        out.ElementValues = numpy.ravel(mat.get('ElementValues', []))
        out.ElementUnit = unnumpify(mat.get('ElementUnit', None))
        out.ElementRate = out.BlocksPerSecond * out.Elements

        out.Time = out.SecondsPerBlock * numpy.arange(0, nBlocks)
        out.FullTime = out.Time
        out.FullElementValues = out.ElementValues

        # Why does sigind have to be transposed before vectorizing to achieve the same result as sigind(:) WITHOUT a transpose in Matlab?
        # This will probably forever be one of the many deep mysteries of numpy dimensionality handling
        out.Signal = mat['Data'][sigind.T.ravel(), :]  # nChannels*nElements - by - nBlocks
        out.Signal = out.Signal + 0.0  # make a contiguous copy

        def seconds(s):
            # -1 means "no information", 0 means "not units of time", >0 means the scaling factor
            if getattr(s, 'ElementUnit', None) in ('', None,
                                                   ()) or s.ElementUnit == []:
                return -1
            s = s.ElementUnit
            if s.endswith('seconds'): s = s[:-6]
            elif s.endswith('second'): s = s[:-5]
            elif s.endswith('sec'): s = s[:-2]
            if s.endswith('s'):
                return {
                    'ps': 1e-12,
                    'ns': 1e-9,
                    'us': 1e-6,
                    'mus': 1e-6,
                    'ms': 1e-3,
                    's': 1e+0,
                    'ks': 1e+3,
                    'Ms': 1e+6,
                    'Gs': 1e+9,
                    'Ts': 1e+12,
                }.get(s, 0)
            else:
                return 0

        if dims == 0:  # dimensionality has not been specified explicitly: so guess, based on ElementUnit and/or filter name
            # 3-dimensional output makes more sense than continuous 2-D whenever "elements" can't just be concatenated into an unbroken time-stream
            if len(chain): lastfilter = chain[-1].lower()
            else: lastfilter = ''
            if lastfilter == 'p3temporalfilter':
                dims = 3
            else:
                factor = seconds(out)
                if factor > 0:  # units of time.  TODO: could detect whether the out.ElementValues*factor are (close enough to) contiguous from block to block; then p3temporalfilter wouldn't have to be a special case above
                    dims = 2
                elif factor == 0:  # not units of time: use 3D by default
                    dims = 3
                elif lastfilter in [
                        'p3temporalfilter', 'arfilter', 'fftfilter',
                        'coherencefilter', 'coherencefftfilter'
                ]:  # no ElementUnit info? guess based on filter name
                    dims = 3
                else:
                    dims = 2

        if dims == 3:
            out.Signal = numpy.reshape(
                out.Signal, (nChannels, nElements, nBlocks),
                order='F')  # nChannels - by - nElements - by - nBlocks
            out.Signal = numpy.transpose(
                out.Signal,
                (2, 0, 1))  # nBlocks - by - nChannels - by - nElements
            out.Signal = out.Signal + 0.0  # make a contiguous copy
        elif dims == 2:
            out.FullTime = numpy.repeat(out.Time, nElements)
            factor = seconds(out)
            if len(out.ElementValues):
                out.FullElementValues = numpy.tile(out.ElementValues, nBlocks)
                if factor > 0:
                    out.FullTime = out.FullTime + out.FullElementValues * factor

            out.Signal = numpy.reshape(out.Signal,
                                       (nChannels, nElements * nBlocks),
                                       order='F')  # nChannels - by - nSamples
            out.Signal = numpy.transpose(out.Signal,
                                         (1, 0))  # nSamples - by - nChannels
            out.Signal = out.Signal + 0.0  # make a contiguous copy
        else:
            raise RuntimeError('internal error')

        out.States = CONTAINER()
        try:
            fieldnames = mat['Index']._fieldnames
        except AttributeError:
            items = list(mat['Index'].items())
        else:
            items = [(k, getattr(mat['Index'], k)) for k in fieldnames]
        states = [(k, int(v) - 1) for k, v in items if k != 'Signal']
        for k, v in states:
            setattr(out.States, k, mat['Data'][v, :])
        # TODO: how do the command-line tools handle event states? this seems to be set up to deliver one value per block whatever kind of state we're dealing with

        out.Megabytes = megs(out)
    else:
        out = err
        keep = True
        print(err)
        print()

    if os.path.isdir(tmpdir):
        files = sorted([
            os.path.join(tmpdir, file) for file in os.listdir(tmpdir)
            if file not in ['.', '..']
        ])
        if keep:
            print(
                'The following commands should be executed to clean up the temporary files:'
            )
        elif verbose:
            print('# removing temp files and directory ' + tmpdir)

        for file in files:
            if keep: print("os.remove(r'%s')" % file)
            else:
                try:
                    os.remove(file)
                except Exception as err:
                    sys.stderr.write("failed to remove %s:\n    %s\n" %
                                     (file, str(err)))

        if keep: print("os.rmdir(r'%s')" % tmpdir)
        else:
            try:
                os.rmdir(tmpdir)
            except Exception as err:
                sys.stderr.write("failed to remove %s:\n    %s" %
                                 (tmpdir, str(err)))
        if keep: print("")

    return out
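
A minimal usage sketch for the function above, following the docstring's first example (so, as noted there, it assumes the current working directory is the root of the BCI2000 distribution). The file paths, chain string, and SpatialFilterType value are taken from that docstring, and the attributes inspected afterwards are fields the function assigns to its output container; on failure the function returns an error string instead.

result = bci2000chain(datfile='data/samplefiles/eeg3_2.dat',
                      chain='TransmissionFilter|SpatialFilter|ARFilter',
                      binpath='tools/cmdline',
                      parms=['tools/matlab/ExampleParameters.prm'],
                      SpatialFilterType=3)

print(result.FileName, result.DateStr)                   # source file and its recording date
print(result.Channels, result.Elements, result.Blocks)   # geometry reported by the chain
print(result.Signal.shape)    # samples-by-channels (dims=2) or blocks-by-channels-by-elements (dims=3)
print(result.ChainSpeedFactor, result.Megabytes)         # timing and memory summary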
Example #2
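# Note: like the previous example, this function is excerpted from a larger module (written
# for Python 2, hence the print statements and basestring checks).  It assumes that os, time
# and numpy are imported at module level and that DataFiles, SigTools and Parameters are
# helpers defined or imported elsewhere in that codebase.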
def ClassifyERPs(
    featurefiles,
    C=(10.0, 1.0, 0.1, 0.01),
    gamma=(1.0, 0.8, 0.6, 0.4, 0.2, 0.0),
    keepchan=(),
    rmchan=(),
    rmchan_usualsuspects=('AUDL', 'AUDR', 'LAUD', 'RAUD', 'SYNC', 'VSYNC',
                          'VMRK', 'OLDREF'),
    rebias=True,
    save=False,
    select=False,
    description='ERPs to attended vs unattended events',
    maxcount=None,
    classes=None,
    folds=None,
    time_window=None,
    keeptrials=None,
):

    file_inventory = []
    d = DataFiles.load(featurefiles,
                       catdim=0,
                       maxcount=maxcount,
                       return_details=file_inventory)
    if isinstance(folds, basestring) and folds.lower() in [
            'lofo', 'loro', 'leave one run out', 'leave one file out'
    ]:
        n, folds = 0, []
        for each in file_inventory:
            neach = each[1]['x']
            folds.append(range(n, n + neach))
            n += neach

    if 'x' not in d:
        raise ValueError(
            "found no trial data - no 'x' variable - in the specified files")
    if 'y' not in d:
        raise ValueError(
            "found no trial labels - no 'y' variable - in the specified files")

    x = d['x']
    y = numpy.array(d['y'].flat)
    if keeptrials != None:
        x = x[numpy.asarray(keeptrials), :, :]
        y = y[numpy.asarray(keeptrials)]

    if time_window != None:
        fs = d['fs']
        t = SigTools.samples2msec(numpy.arange(x.shape[2]), fs)
        x[:, :, t < min(time_window)] = 0
        x[:, :, t > max(time_window)] = 0

    if classes != None:
        for cl in classes:
            if cl not in y:
                raise ValueError("class %s is not in the dataset" % str(cl))
        mask = numpy.array([yi in classes for yi in y])
        y = y[mask]
        x = x[mask]
        discarded = sum(mask == False)
        if discarded:
            print "discarding %d trials that are outside the requested classes %s" % (
                discarded, str(classes))

    n = len(y)
    uy = numpy.unique(y)
    if uy.size != 2:
        raise ValueError("expected 2 classes in dataset, found %d : %s" %
                         (uy.size, str(uy)))
    for uyi in uy:
        nyi = sum([yi == uyi for yi in y])
        nyi_min = 2
        if nyi < nyi_min:
            raise ValueError(
                'only %d exemplars of class %s - need at least %d' %
                (nyi, str(uyi), nyi_min))

    y = numpy.sign(y - uy.mean())

    cov, trchvar = SigTools.spcov(
        x=x, y=y, balance=False,
        return_trchvar=True)  # NB: symwhitenkern would not be able to balance

    starttime = time.time()

    chlower = [ch.lower() for ch in d['channels']]
    if keepchan in [None, (), '', []]:
        if isinstance(rmchan, basestring): rmchan = rmchan.split()
        if isinstance(rmchan_usualsuspects, basestring):
            rmchan_usualsuspects = rmchan_usualsuspects.split()
        allrmchan = [
            ch.lower() for ch in list(rmchan) + list(rmchan_usualsuspects)
        ]
        unwanted = numpy.array([ch in allrmchan for ch in chlower])
        notfound = [ch for ch in rmchan if ch.lower() not in chlower]
    else:
        if isinstance(keepchan, basestring): keepchan = keepchan.split()
        lowerkeepchan = [ch.lower() for ch in keepchan]
        unwanted = numpy.array([ch not in lowerkeepchan for ch in chlower])
        notfound = [ch for ch in keepchan if ch.lower() not in chlower]

    wanted = numpy.logical_not(unwanted)
    print ' '
    if len(notfound):
        print "WARNING: could not find channel%s %s\n" % ({
            1: ''
        }.get(len(notfound), 's'), ', '.join(notfound))
    removed = [ch for removing, ch in zip(unwanted, d['channels']) if removing]
    if len(removed):
        print "removed %d channel%s (%s)" % (len(removed), {
            1: ''
        }.get(len(removed), 's'), ', '.join(removed))
    print "classification will be based on %d channel%s" % (sum(wanted), {
        1: ''
    }.get(sum(wanted), 's'))
    print "%d negatives + %d positives = %d exemplars" % (sum(y < 0),
                                                          sum(y > 0), n)
    print ' '

    x[:, unwanted, :] = 0
    cov[:, unwanted] = 0
    cov[unwanted, :] = 0
    nu = numpy.asarray(cov).diagonal()[wanted].mean()
    for i in range(len(cov)):
        if cov[i, i] == 0: cov[i, i] = nu

    if not isinstance(C, (tuple, list, numpy.ndarray, type(None))): C = [C]
    if not isinstance(gamma, (tuple, list, numpy.ndarray, type(None))):
        gamma = [gamma]

    c = SigTools.klr2class(lossfunc=SigTools.balanced_loss, relcost='balance')
    c.varyhyper({})
    if c != None: c.hyper.C = list(C)
    if gamma == None: c.hyper.kernel.func = SigTools.linkern
    else:
        c.varyhyper({
            'kernel.func': SigTools.symwhitenkern,
            'kernel.cov': [cov],
            'kernel.gamma': list(gamma)
        })
    c.cvtrain(x=x, y=y, folds=folds)
    if rebias: c.rebias()
    c.calibrate()

    chosen = c.cv.chosen.hyper
    if gamma == None:
        Ps = None
        Gp = c.featureweight(x=x)
    else:
        Ps = SigTools.svd(
            SigTools.shrinkcov(cov, copy=True,
                               gamma=chosen.kernel.gamma)).isqrtm
        xp = SigTools.spfilt(x, Ps.H, copy=True)
        Gp = c.featureweight(x=xp)

    u = SigTools.stfac(Gp, Ps)
    u.channels = d['channels']
    u.channels_used = wanted
    u.fs = d['fs']
    u.trchvar = trchvar
    try:
        u.channels = SigTools.ChannelSet(u.channels)
    except:
        print 'WARNING: failed to convert channels to ChannelSet'

    elapsed = time.time() - starttime
    minutes = int(elapsed / 60.0)
    seconds = int(round(elapsed - minutes * 60.0))
    print '%d min %d sec' % (minutes, seconds)
    datestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    csummary = '%s (%s) trained on %d (CV %s = %.3f) at %s' % (
        c.__class__.__name__,
        SigTools.experiment()._shortdesc(chosen),
        sum(c.input.istrain),
        c.loss.func.__name__,
        c.loss.train,
        datestamp,
    )
    description = 'binary classification of %s: %s' % (description, csummary)
    u.description = description

    if save or select:
        if not isinstance(save, basestring):
            save = featurefiles
            if isinstance(save, (tuple, list)): save = save[-1]
            if save.lower().endswith('.gz'): save = save[:-3]
            if save.lower().endswith('.pk'): save = save[:-3]
            save = save + '_weights.prm'
        print "\nsaving %s\n" % save
        Parameters.Param(u.G.A,
                         Name='ERPClassifierWeights',
                         Section='PythonSig',
                         Subsection='Epoch',
                         Comment=csummary).write_to(save)
        Parameters.Param(c.model.bias,
                         Name='ERPClassifierBias',
                         Section='PythonSig',
                         Subsection='Epoch',
                         Comment=csummary).append_to(save)
        Parameters.Param(description,
                         Name='SignalProcessingDescription',
                         Section='PythonSig').append_to(save)
        if select:
            if not isinstance(select, basestring): select = 'ChosenWeights.prm'
            if not os.path.isabs(select):
                select = os.path.join(os.path.split(save)[0], select)
            print "saving %s\n" % select
            import shutil
            shutil.copyfile(save, select)

    print description
    return u, c
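
A hypothetical invocation sketch for ClassifyERPs, shown only to illustrate how the arguments above fit together: the feature-file names are placeholders, and the keyword values are merely plausible settings for options that the function body handles explicitly (the 'lofo' fold specification, the millisecond time window, the rmchan string, and the .pk-derived save path).

u, c = ClassifyERPs(['session1_features.pk', 'session2_features.pk'],  # hypothetical feature files
                    C=(10.0, 1.0, 0.1),
                    gamma=(0.6, 0.4, 0.2),
                    folds='lofo',             # leave-one-file-out folds, built from the file inventory
                    time_window=(200, 600),   # zero out samples outside 200-600 ms
                    rmchan='A1 A2',           # extra (hypothetical) channels to exclude, beyond the usual suspects
                    save=True)                # writes *_weights.prm next to the last feature file
print u.description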