def Initialize(self, indim, outdim):
    """Set up the ring buffer, optional low-pass filter and downsampling
    indices, then start the listener thread and wait until it is ready.

    Reads 'HardwareSamplingRate', 'HardwareChunkMsec', 'NIABufferSizeMsec',
    'DSFilterFreqFactor' and 'DSFilterOrder' from self.params, and
    'SecondsPerPacket' / 'SamplesPerPacket' from self.nominal.

    indim  -- only indim[0] is used, as the second argument of SigTools.ring
              (presumably the number of channels -- confirm against SigTools).
    outdim -- not used by this method.
    """
    self.warp = 1000.0  # let the samples flowing into the ring buffer set the pace
    self.eegfs = self.samplingrate()
    self.hwfs = int(self.params['HardwareSamplingRate'])
    self.chunk = SigTools.msec2samples(
        float(self.params['HardwareChunkMsec']), self.hwfs)
    ringsize = SigTools.msec2samples(
        float(self.params['NIABufferSizeMsec']), self.hwfs)
    self.ring = SigTools.ring(ringsize, indim[0])
    self.ring.allow_overflow = True
    self.nominal['HardwareSamplesPerPacket'] = SigTools.msec2samples(
        self.nominal['SecondsPerPacket'] * 1000.0, self.hwfs)

    # The cutoff is expressed as a fraction of half of self.eegfs; the filter
    # itself is designed at the hardware sampling rate.
    cutoff = float(self.params['DSFilterFreqFactor']) * self.eegfs / 2.0
    order = int(self.params['DSFilterOrder'])
    if order > 0 and cutoff > 0.0:
        self.filter = SigTools.causalfilter(
            freq_hz=cutoff,
            samplingfreq_hz=self.hwfs,
            order=order,
            type='lowpass')  #, method=SigTools.firdesign)
    else:
        # A non-positive order or cutoff disables filtering altogether.
        self.filter = None

    # SamplesPerPacket + 1 evenly spaced positions spanning one packet of
    # hardware samples (both ends included), rounded to integer indices.
    self.dsind = numpy.linspace(0.0,
                                self.nominal['HardwareSamplesPerPacket'],
                                self.nominal['SamplesPerPacket'] + 1,
                                endpoint=True)
    # The builtin int replaces the numpy.int alias, which newer NumPy
    # releases no longer provide; the resulting dtype is the same.
    self.dsind = numpy.round(self.dsind).astype(int).tolist()

    # Ask the listener thread to start and wait for its 'ready' message.
    self._threads['listen'].post('start')
    self._threads['listen'].read('ready', wait=True, remove=True)
    self._check_threads()
Example #2
0
    def initialize(cls, app, indim, outdim):
        """Prepare ERP collection on `app`: buffers, subject record, worker
        thread and (optionally) the on-screen feedback bar.

        Nothing is done unless the 'ERPDatabaseEnable' parameter equals 1.
        `indim` and `outdim` are accepted but not used.
        """
        if int(app.params['ERPDatabaseEnable']) != 1:
            return

        if int(app.params['ShowSignalTime']):
            for state_name in ('LastERPVal', 'ERPCollected'):
                app.addstatemonitor(state_name)

        # ---------------------------------------------------------------
        # Buffers used to save the data:
        #   leaky_trap holds the data to be saved (pre- plus post-stimulus
        #   samples plus some breathing room);
        #   trig_trap holds only the trigger channel.
        # ---------------------------------------------------------------
        msec_per_sample = 1000.0 / app.eegfs
        # x_vec is needed when saving trials.
        app.x_vec = np.arange(app.erpwin[0], app.erpwin[1], msec_per_sample, dtype=float)
        app.post_stim_samples = SigTools.msec2samples(app.erpwin[1], app.eegfs)
        app.pre_stim_samples = SigTools.msec2samples(np.abs(app.erpwin[0]), app.eegfs)
        trap_length = app.pre_stim_samples + app.post_stim_samples + 5 * app.spb
        n_erp_channels = len(app.params['ERPChan'])
        app.leaky_trap = SigTools.Buffering.trap(trap_length, n_erp_channels, leaky=True)
        app.trig_trap = SigTools.Buffering.trap(
            app.post_stim_samples,
            1,
            trigger_channel=0,
            trigger_threshold=app.trigthresh[0])

        # ---------------------------------------------------------------
        # Database model for the subject (period lookup is disabled).
        # ---------------------------------------------------------------
        lookup = Subject.objects.get_or_create(name=app.params['SubjectName'])
        app.subject = lookup[0]

        # ---------------------------------------------------------------
        # Database interactions can be slow (especially when calculating
        # a trial's features), so they are handed to a daemon thread.
        # ---------------------------------------------------------------
        app.erp_thread = ERPThread(Queue.Queue(), app)
        app.erp_thread.setDaemon(True)
        app.erp_thread.start()

        # ---------------------------------------------------------------
        # ERP feedback elements:
        #   the screen ranges from -2*fbthresh to +2*fbthresh, and the
        #   calculated ERP value is scaled so that 65536 (int16 range)
        #   fills the screen.
        # ---------------------------------------------------------------
        if int(app.params['ERPFeedbackDisplay']) != 2:
            return

        fbthresh = app.params['ERPFeedbackThreshold'].val
        app.erp_scale = (2.0**16) / (4.0 * np.abs(fbthresh))
        single = fbthresh * app.erp_scale
        double = 2.0 * fbthresh * app.erp_scale
        if fbthresh < 0:
            fbmax, fbmin = single, double
        else:
            fbmax, fbmin = double, single

        pixels_per_unit = app.scrh / float(2**16)  # signal amplitude -> pixels
        zero_pixel = app.scrh / 2.0                # input 0.0 is drawn here
        app.addbar(color=(1, 0, 0),
                   pos=(0.9 * app.scrw, zero_pixel),
                   thickness=0.1 * app.scrw,
                   fac=pixels_per_unit)
        bar_label = 'bartext_' + str(len(app.bars))
        app.stimuli[bar_label].color = [0, 0, 0]

        box_position = (0.8 * app.scrw, pixels_per_unit * fbmin + zero_pixel)
        box_size = (0.2 * app.scrw, pixels_per_unit * (fbmax - fbmin))
        target_box = Block(position=box_position,
                           size=box_size,
                           color=(1, 0, 0, 0.5),
                           anchor='lowerleft')
        app.stimulus('erp_target_box', z=1, stim=target_box)
Example #3
0
    def stplot(self, img, fs=None, drawnow=True, **kwargs):
        """Plot a channels-by-time image with linked per-column and per-row views.

        img     -- 2-D array-like; img.shape[0] must equal self.size.
        fs      -- optional sampling frequency. When given, the time base
                   and the x tick labels are converted with
                   SigTools.samples2msec; otherwise sample indices are used.
        drawnow -- if true, pylab.draw() is called before returning.
        kwargs  -- passed through to SigTools.Plotting.imagesc.

        Returns the handle returned by imagesc, with the extra attributes
        channels, img, timebase, spaceax and timeax attached.

        Raises ValueError if the number of rows in img differs from self.size.
        """
        #kwargs['colorbar'] = kwargs.get('colorbar', True)
        if img.shape[0] != self.size:
            raise ValueError(
                "number of rows in image should match number of channels (=%d)"
                % self.size)
        import SigTools
        pylab = SigTools.Plotting.load_pylab()
        t = numpy.arange(img.shape[1], dtype=numpy.float64)
        if fs is None:
            xlabel = 'time-sample index'
        else:
            t = SigTools.samples2msec(t, fs)
            xlabel = 'time (msec)'
        pylab.subplot(121)
        # The image is always drawn without an explicit x vector (x=None),
        # as in the original code; onpick() below uses the picked x
        # coordinate directly as a column index into h.img.
        h = SigTools.Plotting.imagesc(img,
                                      x=None,
                                      aspect='auto',
                                      drawnow=False,
                                      picker=5,
                                      **kwargs)
        ax = pylab.gca()
        if fs is not None:
            # Keep the ticks that fall inside the current x limits and
            # relabel them in milliseconds.
            xl = ax.get_xlim()
            xt = numpy.array([
                tick for tick in ax.get_xticks()
                if min(xl) <= tick <= max(xl)
            ])
            xtl = ['%g' % msec for msec in SigTools.samples2msec(xt, fs)]
            ax.set(xticks=xt, xticklabels=xtl)
        ax.set(
            xlabel=xlabel,
            yticks=range(img.shape[0]),
            yticklabels=[
                '%d: %s' % (index, label)
                for index, label in enumerate(self.get_labels())
            ])
        ax.grid(True)
        # Attach everything the pick callback needs to the image handle.
        h.channels = self
        h.img = numpy.asarray(img)
        h.timebase = t
        h.spaceax = pylab.subplot(222)
        h.timeax = pylab.subplot(224)

        def onpick(evt):
            """On a pick event, plot the picked column and the picked row."""
            col = int(round(evt.mouseevent.xdata))
            row = int(round(evt.mouseevent.ydata))
            h = evt.artist
            pylab.axes(h.spaceax)
            h.channels.plot(h.img[:, col], clim=h.get_clim(), drawnow=False)
            pylab.axes(h.timeax)
            SigTools.plot(h.timebase, h.img[row], drawnow=False)
            pylab.gca().set(ylim=h.get_clim(),
                            xlim=(h.timebase[0], h.timebase[-1]))
            pylab.draw()

        pylab.gcf().canvas.mpl_connect('pick_event', onpick)
        if drawnow: pylab.draw()
        return h
		def onpick(evt):
			"""Pick-event callback: plot the picked column on the artist's
			spaceax and the picked row on its timeax, then redraw."""
			click = evt.mouseevent
			col = int(round(click.xdata))
			row = int(round(click.ydata))
			image = evt.artist
			# Values of the picked column, plotted via the attached channels object.
			pylab.axes(image.spaceax)
			column_values = image.img[:,col]
			image.channels.plot(column_values, clim=image.get_clim(), drawnow=False)
			# Values of the picked row against the stored time base.
			pylab.axes(image.timeax)
			SigTools.plot(image.timebase, image.img[row], drawnow=False)
			time_limits = (image.timebase[0],image.timebase[-1])
			pylab.gca().set(ylim=image.get_clim(), xlim=time_limits)
			pylab.draw()
def PlotSCD( files='*.pk' ):
	"""
	Load the given data files, correlate d['x'] with d['y'] along axis 0 and
	show r*|r| as an image; the title lists how often each value of d['y']
	occurs.

	<files> may be a glob pattern, a directory (in which case '*.pk' inside
	it is used), or anything else that DataFiles.load accepts.
	"""
	import pylab
	if isinstance( files, basestring ):
		pattern = files
		if os.path.isdir( pattern ):
			pattern = os.path.join( pattern, '*.pk' )
		files = glob.glob( pattern )
	data = DataFiles.load( files )
	labels = data['y']
	r = SigTools.correlate( data['x'], labels, axis=0 )
	signed_square = r*numpy.abs(r)
	msec = SigTools.samples2msec( range( r.shape[1] ), data['fs'] )
	SigTools.imagesc( signed_square, y=data['channels'], x=msec, aspect='auto', balance=0.0, colorbar=True )
	counts = []
	for yi in numpy.unique( labels ):
		counts.append( '%d: %d' % ( yi, ( labels == yi ).sum() ) )
	pylab.title( ', '.join( counts ) )
	pylab.draw()
Example #6
0
 def onpick(evt):
     """Pick-event callback: show the picked column and row of the image.

     The column is plotted on the artist's spaceax through its channels
     object, the row on its timeax against its timebase.
     """
     picked = evt.mouseevent
     col = int(round(picked.xdata))
     row = int(round(picked.ydata))
     artist = evt.artist

     pylab.axes(artist.spaceax)
     artist.channels.plot(artist.img[:, col],
                          clim=artist.get_clim(),
                          drawnow=False)

     pylab.axes(artist.timeax)
     SigTools.plot(artist.timebase, artist.img[row], drawnow=False)
     first_time = artist.timebase[0]
     last_time = artist.timebase[-1]
     pylab.gca().set(ylim=artist.get_clim(), xlim=(first_time, last_time))
     pylab.draw()
def PlotTrials( files='*.pk', channel='Cz' ):
	"""
	Load the given data files and show, for one channel, every trial as a row
	of an image. Rows are ordered by the trial's d['y'] value, ties keeping
	their original order. The title lists how often each d['y'] value occurs,
	followed by the channel name.

	<files> may be a glob pattern, a directory (in which case '*.pk' inside
	it is used), or anything else that DataFiles.load accepts.
	"""
	import pylab
	if isinstance( files, basestring ):
		pattern = files
		if os.path.isdir( pattern ):
			pattern = os.path.join( pattern, '*.pk' )
		files = glob.glob( pattern )
	d = DataFiles.load( files )
	chind = d[ 'channels' ].index( channel )
	trials = d[ 'x' ][ :, chind, : ]
	# sort trial indices by ( label, original position )
	order = sorted( range( len( trials ) ), key=lambda i: ( d[ 'y' ][ i ], i ) )
	v = numpy.array( [ trials[ i ] for i in order ] )
	msec = SigTools.samples2msec( range( v.shape[1] ), d['fs'] )
	SigTools.imagesc( v, x=msec, aspect='auto', balance=0.0, colorbar=True )
	counts = ', '.join( [ '%d: %d' % ( yi, ( d['y'] == yi ).sum() ) for yi in numpy.unique( d['y'] ) ] )
	pylab.title( counts + ( ' (channel %s)' % channel ) )
	pylab.draw()
Example #8
0
def PlotSCD(files='*.pk'):
    """Plot the signed squared correlation between d['x'] and d['y'].

    `files` may be a glob pattern, a directory (its '*.pk' files are used)
    or anything else DataFiles.load accepts. The title lists the number of
    occurrences of each distinct value in d['y'].
    """
    import pylab

    if isinstance(files, basestring):
        if os.path.isdir(files):
            files = os.path.join(files, '*.pk')
        files = glob.glob(files)

    d = DataFiles.load(files)
    y = d['y']
    r = SigTools.correlate(d['x'], y, axis=0)

    image_options = dict(aspect='auto', balance=0.0, colorbar=True)
    image_options['y'] = d['channels']
    image_options['x'] = SigTools.samples2msec(range(r.shape[1]), d['fs'])
    SigTools.imagesc(r * numpy.abs(r), **image_options)

    def describe(value):
        # "<value>: <number of trials with that value>"
        return '%d: %d' % (value, (y == value).sum())

    pylab.title(', '.join(describe(value) for value in numpy.unique(y)))
    pylab.draw()
			def squares(points):
				"""
				Build one closed square outline around each point.

				The half-width of every square is half the square root of the
				smallest non-zero entry of SigTools.sqdist(points). Returns a
				(vertices, codes) pair of tuples: for each point the four
				corners followed by the first corner again, with one MOVETO
				code followed by four LINETO codes.
				"""
				def around(x, d=0.5):
					lo0, hi0 = x[0]-d, x[0]+d
					lo1, hi1 = x[1]-d, x[1]+d
					return ((lo0,lo1),(lo0,hi1),(hi0,hi1),(hi0,lo1))
				nonzero = [sq for sq in SigTools.sqdist(points).flat if sq != 0]
				half = 0.5 * min(nonzero) ** 0.5
				outline, codes = (), ()
				for pt in points:
					corners = around(tuple(pt), half)
					outline = outline + corners + (corners[0],)
					codes = codes + (matplotlib.path.Path.MOVETO,) + (matplotlib.path.Path.LINETO,) * len(corners)
				return outline, codes
Example #10
0
def PlotTrials(files='*.pk', channel='Cz'):
    """Show every trial of one channel as a row of an image.

    `files` may be a glob pattern, a directory (its '*.pk' files are used)
    or anything else DataFiles.load accepts. Rows are ordered by the
    trial's d['y'] value, ties keeping their original order. The title
    lists the count of each distinct d['y'] value and the channel name.
    """
    import pylab

    if isinstance(files, basestring):
        if os.path.isdir(files):
            files = os.path.join(files, '*.pk')
        files = glob.glob(files)

    d = DataFiles.load(files)
    chind = d['channels'].index(channel)

    # Decorate each trial with (label, position), sort, then undecorate.
    tagged = []
    for position, trial in enumerate(d['x'][:, chind, :]):
        tagged.append((d['y'][position], position, trial))
    tagged.sort()
    v = numpy.array([entry[2] for entry in tagged])

    time_msec = SigTools.samples2msec(range(v.shape[1]), d['fs'])
    SigTools.imagesc(v,
                     x=time_msec,
                     aspect='auto',
                     balance=0.0,
                     colorbar=True)

    per_class = []
    for yi in numpy.unique(d['y']):
        per_class.append('%d: %d' % (yi, (d['y'] == yi).sum()))
    pylab.title(', '.join(per_class) + (' (channel %s)' % channel))
    pylab.draw()
	def Initialize(self, indim, outdim):
		"""
		Set up the ring buffer, optional low-pass filter and downsampling
		indices, then start the listener thread and wait until it is ready.

		Reads 'HardwareSamplingRate', 'HardwareChunkMsec', 'NIABufferSizeMsec',
		'DSFilterFreqFactor' and 'DSFilterOrder' from self.params, and
		'SecondsPerPacket' / 'SamplesPerPacket' from self.nominal.

		indim  -- only indim[0] is used, as the second argument of SigTools.ring
		          (presumably the number of channels -- confirm against SigTools).
		outdim -- not used by this method.
		"""
		self.warp = 1000.0 # let the samples flowing into the ring buffer set the pace
		self.eegfs = self.samplingrate()
		self.hwfs = int(self.params['HardwareSamplingRate'])
		self.chunk = SigTools.msec2samples(float(self.params['HardwareChunkMsec']), self.hwfs)
		ringsize = SigTools.msec2samples(float(self.params['NIABufferSizeMsec']), self.hwfs)
		self.ring = SigTools.ring(ringsize, indim[0])
		self.ring.allow_overflow = True
		self.nominal['HardwareSamplesPerPacket'] = SigTools.msec2samples(self.nominal['SecondsPerPacket']*1000.0, self.hwfs)

		# The cutoff is a fraction of half of self.eegfs; the filter itself is
		# designed at the hardware sampling rate.
		cutoff = float(self.params['DSFilterFreqFactor']) * self.eegfs / 2.0
		order = int(self.params['DSFilterOrder'])
		if order > 0 and cutoff > 0.0:
			self.filter = SigTools.causalfilter(freq_hz=cutoff, samplingfreq_hz=self.hwfs, order=order, type='lowpass') #, method=SigTools.firdesign)
		else:
			# A non-positive order or cutoff disables filtering altogether.
			self.filter = None
		# SamplesPerPacket+1 evenly spaced positions spanning one packet of
		# hardware samples (both ends included), rounded to integer indices.
		self.dsind = numpy.linspace(0.0, self.nominal['HardwareSamplesPerPacket'], self.nominal['SamplesPerPacket']+1, endpoint=True)
		# The builtin int replaces the numpy.int alias, which newer NumPy
		# releases no longer provide; the resulting dtype is the same.
		self.dsind = numpy.round(self.dsind).astype(int).tolist()

		# Ask the listener thread to start and wait for its 'ready' message.
		self._threads['listen'].post('start')
		self._threads['listen'].read('ready', wait=True, remove=True)
		self._check_threads()
	def stplot(self, img, fs=None, drawnow=True, **kwargs):
		"""
		Plot a channels-by-time image with linked per-column and per-row views.

		img     -- 2-D array-like; img.shape[0] must equal self.size.
		fs      -- optional sampling frequency. When given, the time base and
		           the x tick labels are converted with SigTools.samples2msec;
		           otherwise sample indices are used.
		drawnow -- if true, pylab.draw() is called before returning.
		kwargs  -- passed through to SigTools.Plotting.imagesc.

		Returns the handle returned by imagesc, with the extra attributes
		channels, img, timebase, spaceax and timeax attached.

		Raises ValueError if the number of rows in img differs from self.size.
		"""
		#kwargs['colorbar'] = kwargs.get('colorbar', True)
		if img.shape[0] != self.size: raise ValueError("number of rows in image should match number of channels (=%d)" % self.size)
		import SigTools
		pylab = SigTools.Plotting.load_pylab()
		t = numpy.arange(img.shape[1], dtype=numpy.float64)
		if fs is None:
			xlabel = 'time-sample index'
		else:
			t = SigTools.samples2msec(t, fs)
			xlabel = 'time (msec)'
		pylab.subplot(121)
		# The image is always drawn without an explicit x vector (x=None), as in
		# the original code; onpick() below uses the picked x coordinate
		# directly as a column index into h.img.
		h = SigTools.Plotting.imagesc(img, x=None, aspect='auto', drawnow=False, picker=5, **kwargs)
		ax = pylab.gca()
		if fs is not None:
			# Keep the ticks that fall inside the current x limits and relabel
			# them in milliseconds.
			xl = ax.get_xlim()
			xt = numpy.array([tick for tick in ax.get_xticks() if min(xl)<=tick<=max(xl)])
			xtl = ['%g' % msec for msec in SigTools.samples2msec(xt, fs)]
			ax.set(xticks=xt, xticklabels=xtl)
		ax.set(xlabel=xlabel, yticks=range(img.shape[0]), yticklabels=['%d: %s' % (index, label) for index, label in enumerate(self.get_labels())])
		ax.grid(True)
		# Attach everything the pick callback needs to the image handle.
		h.channels = self
		h.img = numpy.asarray(img)
		h.timebase = t
		h.spaceax = pylab.subplot(222)
		h.timeax = pylab.subplot(224)
		def onpick(evt):
			"""On a pick event, plot the picked column and the picked row."""
			col = int(round(evt.mouseevent.xdata))
			row = int(round(evt.mouseevent.ydata))
			h = evt.artist
			pylab.axes(h.spaceax)
			h.channels.plot(h.img[:,col], clim=h.get_clim(), drawnow=False)
			pylab.axes(h.timeax)
			SigTools.plot(h.timebase, h.img[row], drawnow=False)
			pylab.gca().set(ylim=h.get_clim(), xlim=(h.timebase[0],h.timebase[-1]))
			pylab.draw()
		pylab.gcf().canvas.mpl_connect('pick_event', onpick)
		if drawnow: pylab.draw()
		return h
	def Initialize(self, indim, outdim):
		"""
		Set up the audio ring buffer and downsampling indices, then start the
		listener thread and wait until it is ready.

		Reads 'AudioSamplingRate', 'AudioBitDepth', 'AudioChunkMsec',
		'NumberOfAudioChannels', 'UseAudioEnvelope' and 'AudioBufferSizeMsec'
		from self.params, and 'SecondsPerPacket' / 'SamplesPerPacket' from
		self.nominal. Neither indim nor outdim is used by this method.
		"""
		self.warp = 1000.0 # let the audio samples flowing into the ring buffer set the pace
		self.eegfs = self.samplingrate()
		self.audiofs = int(self.params['AudioSamplingRate'])
		self.audiobits = int(self.params['AudioBitDepth'])
		self.audiochunk = WavTools.msec2samples(float(self.params['AudioChunkMsec']), self.audiofs)
		self.audiochannels = int(self.params['NumberOfAudioChannels'])
		self.use_env = int(self.params['UseAudioEnvelope'])
		ringsize = WavTools.msec2samples(float(self.params['AudioBufferSizeMsec']), self.audiofs)
		self.ring = SigTools.ring(ringsize, self.audiochannels)
		self.ring.allow_overflow = True
		self.nominal['AudioSamplesPerPacket'] = WavTools.msec2samples(self.nominal['SecondsPerPacket']*1000.0, self.audiofs)
		# SamplesPerPacket+1 evenly spaced positions spanning one packet of
		# audio samples (both ends included), rounded to integer indices.
		self.dsind = numpy.linspace(0.0, self.nominal['AudioSamplesPerPacket'], self.nominal['SamplesPerPacket']+1, endpoint=True)
		# The builtin int replaces the numpy.int alias, which newer NumPy
		# releases no longer provide; the resulting dtype is the same.
		self.dsind = numpy.round(self.dsind).astype(int).tolist()
		# Ask the listener thread to start and wait for its 'ready' message.
		self._threads['listen'].post('start')
		self._threads['listen'].read('ready', wait=True, remove=True)
		self._check_threads()
Example #14
0
 def Initialize(self, indim, outdim):
     """Set up the audio ring buffer and downsampling indices, then start
     the listener thread and wait until it is ready.

     Reads 'AudioSamplingRate', 'AudioBitDepth', 'AudioChunkMsec',
     'NumberOfAudioChannels', 'UseAudioEnvelope' and 'AudioBufferSizeMsec'
     from self.params, and 'SecondsPerPacket' / 'SamplesPerPacket' from
     self.nominal. Neither indim nor outdim is used by this method.
     """
     self.warp = 1000.0  # let the audio samples flowing into the ring buffer set the pace
     self.eegfs = self.samplingrate()
     self.audiofs = int(self.params['AudioSamplingRate'])
     self.audiobits = int(self.params['AudioBitDepth'])
     self.audiochunk = WavTools.msec2samples(
         float(self.params['AudioChunkMsec']), self.audiofs)
     self.audiochannels = int(self.params['NumberOfAudioChannels'])
     self.use_env = int(self.params['UseAudioEnvelope'])
     ringsize = WavTools.msec2samples(
         float(self.params['AudioBufferSizeMsec']), self.audiofs)
     self.ring = SigTools.ring(ringsize, self.audiochannels)
     self.ring.allow_overflow = True
     self.nominal['AudioSamplesPerPacket'] = WavTools.msec2samples(
         self.nominal['SecondsPerPacket'] * 1000.0, self.audiofs)
     # SamplesPerPacket + 1 evenly spaced positions spanning one packet of
     # audio samples (both ends included), rounded to integer indices.
     self.dsind = numpy.linspace(0.0,
                                 self.nominal['AudioSamplesPerPacket'],
                                 self.nominal['SamplesPerPacket'] + 1,
                                 endpoint=True)
     # The builtin int replaces the numpy.int alias, which newer NumPy
     # releases no longer provide; the resulting dtype is the same.
     self.dsind = numpy.round(self.dsind).astype(int).tolist()
     # Ask the listener thread to start and wait for its 'ready' message.
     self._threads['listen'].post('start')
     self._threads['listen'].read('ready', wait=True, remove=True)
     self._check_threads()
def TimingWindow(filename='.', ind=-1, save=None):
	"""
Recreate BCI2000's timing window offline, from a saved .dat file specified by <filename>.
It is also possible to supply a directory name as <filename>, and an index <ind> (default
value -1 meaning "the last run") to choose a file automatically from that directory.

Based on BCI2000's   src/shared/modules/signalsource/DataIOFilter.cpp where the timing window
content is computed in DataIOFilter::Process(), this is what appears to happen:
    
         Begin SampleBlock #t:
            Enter SignalSource module's first Process() method (DataIOFilter::Process)
            Save previous SampleBlock to file
            Wait to acquire new SampleBlock from hardware
 +--------- Measure SourceTime in SignalSource module
 |   |   |  Make a local copy of all states (NB: all except SourceTime were set during #t-1) ---+
B|  R|  S|  Pipe the signal through the rest of BCI2000                                         |
 |   |   +- Measure StimulusTime in Application module, on leaving last Process() method        |
 |   |                                                                                          |
 |   |                                                                                          |
 |   |   Begin SampleBlock #t+1:                                                                |
 |   +----- Enter SignalSource module's first Process() method (DataIOFilter::Process)          |
 |          Save data from #t, SourceTime state from #t, and other states from #t-1, to file <--+
 |          Wait to acquire new SampleBlock from hardware
 +--------- Measure SourceTime in SignalSource module
            Make a local copy of all states (NB: all except SourceTime were set during #t)
            Leave DataIOFilter::Process() and pipe the signal through the rest of BCI2000
            Measure StimulusTime in Application module, on leaving last Process() method

B stands for Block duration.
R stands for Roundtrip time (visible in VisualizeTiming, not reconstructable from the .dat file)
S is the filter cascade time (marked "Stimulus" in the VisualizeTiming window).

Note that, on any given SampleBlock as saved in the file, SourceTime will be *greater* than
any other timestamp states (including StimulusTime), because it is the only state that is
updated in time to be saved with the data packet it belongs to. All the others lag by one
packet.  This is corrected for at the point commented with ??? in the Python code.

Arguments:
    filename : a file or directory name, or any object that has a .filename attribute
               (in which case that attribute is used).
    ind      : passed on to bcistream() together with filename.
    save     : if true-ish, the current figure is saved to this target via savefig().

Returns a SigTools.sstruct with the fields filename, params, summarystr, t, SourceTime,
StimulusTime, BlockDuration, BlockDuration2, ProcessingTime, ExpectedBlockDuration, rT,
dT and T.  The timing curves are also plotted as a side effect.
"""
	
	# Allow an object carrying a .filename attribute to be passed instead of a name.
	if hasattr(filename, 'filename'): filename = filename.filename
	
	b = bcistream(filename=filename, ind=ind)
		
	out = SigTools.sstruct()
	out.filename = b.filename
	#print "decoding..."
	sig,states = b.decode('all')
	#print "done"
	# NOTE: b.samples2msec, b.params, b.datestamp and b.filename are still used after close().
	b.close()

	# dT: per-sample differences, T: unwrapped timestamps, rT: timestamps relative to SourceTime.
	dT,T,rT = {},{},{}
	statenames = ['SourceTime', 'StimulusTime'] + ['PythonTiming%02d' % (x+1) for x in range(2)]
	# Only keep the timestamp states that are actually present in this file.
	statenames = [s for s in statenames if s in states]
	for key in statenames:
		# Timestamps are unwrapped with base 65536 (presumably 16-bit counters -- TODO confirm).
		dT[key],T[key] = SigTools.unwrapdiff(states[key].flatten(), base=65536, dtype=numpy.float64)

	# NOTE(review): 'SourceTime' is indexed unconditionally here, and 'StimulusTime' further
	# below, although statenames was filtered above -- a KeyError results if either is absent.
	# sel holds the indices at which the SourceTime difference is non-zero.
	sel, = numpy.where(dT['SourceTime'])
	for key in statenames:
		dT[key] = dT[key][sel[1:]]
		# ??? SourceTime is saved with the block it belongs to, whereas all other states lag
		# by one block (see docstring), hence the one-step offset between the two selections.
		if key == 'SourceTime': tsel = sel[:-1]  # ??? why the shift
		else:                   tsel = sel[1:]   # ??? relative to here?
		T[key] = T[key][tsel+1]

	# Express all timestamps relative to the first selected SourceTime.
	t0 = T['SourceTime'][0]
	for key in statenames: T[key] -= t0

	# x axis of the plot: SourceTime divided by 1000 (labelled 'seconds' below).
	t = T['SourceTime'] / 1000

	# Nominal block duration, drawn as a horizontal reference line.
	expected = b.samples2msec(b.params['SampleBlockSize'])
	datestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(b.datestamp))
	paramstr = ', '.join(['%s=%s' % (x,b.params[x]) for x in ['SampleBlockSize', 'SamplingRate', 'VisualizeTiming', 'VisualizeSource']])
	# Names of all filters in the three filter chains, joined by '-'.
	chainstr = '-'.join([x for x,y in b.params['SignalSourceFilterChain']+b.params['SignalProcessingFilterChain']+b.params['ApplicationFilterChain']])
	titlestr = '\n'.join([b.filename, datestamp, paramstr, chainstr])

	SigTools.plot(t[[0,-1]], [expected]*2, drawnow=False)
	SigTools.plot(t, dT['SourceTime'], hold=True, drawnow=False)

	# Every other timestamp state is plotted relative to SourceTime.
	for key in statenames:
		if key == 'SourceTime': continue
		rT[key] = T[key] - T['SourceTime']
		SigTools.plot(t, rT[key], hold=True, drawnow=False)
	
	import pylab
	pylab.title(titlestr)
	pylab.grid(True)
	pylab.xlabel('seconds')
	pylab.ylabel('milliseconds')
	# Make sure the y axis reaches at least twice the expected block duration.
	ymin,ymax = pylab.ylim(); pylab.ylim(ymax=max(ymax,expected*2))
	pylab.xlim(xmax=t[-1])
	pylab.draw()
	out.params = SigTools.sstruct(b.params)
	out.summarystr = titlestr
	out.t = t
	out.SourceTime = T['SourceTime']
	out.StimulusTime = T['StimulusTime']
	out.BlockDuration = dT['SourceTime']
	out.BlockDuration2 = dT['StimulusTime']
	out.ProcessingTime = out.StimulusTime - out.SourceTime
	out.ExpectedBlockDuration = expected
	out.rT = rT
	out.dT = dT
	out.T = T
	
	if save:
		pylab.gcf().savefig(save, orientation='landscape')
	
	return out
Example #16
0
def bci2000chain(datfile, chain='SpectralSignalProcessing', parms=(), dims='auto', start=None, duration=None, verbose=False, keep=False, binpath=None, **kwargs):
	"""
	Pipe a BCI2000 .dat file through a chain of BCI2000 command-line filter tools and
	load the result.

	This is a provisional Python equivalent for tools/matlab/bci2000chain.m

	Excuse the relative paths - this example works if you're currently working in the root
	directory of the BCI2000 distro:

	    bci2000chain(datfile='data/samplefiles/eeg3_2.dat',
	                 chain='TransmissionFilter|SpatialFilter|ARFilter',
	                 binpath='tools/cmdline',
	                 parms=['tools/matlab/ExampleParameters.prm'], SpatialFilterType=3)

	Arguments:
	    datfile  : name of an existing .dat file.
	    chain    : '|'-delimited string, or sequence, of filter tool names. The shortcut
	               names SpectralSignalProcessing, ARSignalProcessing and P3SignalProcessing
	               (case-insensitive) are expanded to predefined chains.
	    parms    : None, a single parameter source, or a list/tuple of them; these are
	               handed to make_bciprm() together with <datfile>. Additional keyword
	               arguments are appended as one more parameter source (a dict).
	    dims     : 2, 3, or 'auto'/None to guess the output dimensionality from the
	               element unit and/or the name of the last filter.
	    start, duration : if given and non-blank, passed (with spaces removed) to
	               bci_dat2stream as --start / --duration.
	    verbose  : print progress information.
	    keep     : leave the temporary files in place and print the commands that would
	               remove them. This is forced on when the chain fails.
	    binpath  : directory containing the command-line tools; if not given, the tools
	               must be found via the system path.

	Returns a SigTools.sstruct describing the result on success. On failure the error
	message is printed and returned as a string.

	Raises ValueError for an unsupported <dims>, a non-string <datfile>, or
	<start>/<duration> in combination with an old-style bci_dat2stream, and IOError if
	<datfile> does not exist.

	For clues, see http://www.bci2000.org/wiki/index.php/User_Reference:Matlab_Tools
	and matlab documentation in bci2000chain.m

	Most arguments are like the flags in the Matlab equivalent. mutatis to a large extent mutandis.
	Note that for now there is no global way of managing the system path. Either add the
	tools/cmdline directory to your $PATH variable before starting Python, or supply its location
	every time while calling, in the <binpath> argument.

	TODO: test on Windoze
	"""

	if isinstance(chain, basestring):
		if chain.lower() == 'SpectralSignalProcessing'.lower():
			chain = 'TransmissionFilter|SpatialFilter|SpectralEstimator|LinearClassifier|LPFilter|ExpressionFilter|Normalizer'
		elif chain.lower() == 'ARSignalProcessing'.lower():
			chain = 'TransmissionFilter|SpatialFilter|ARFilter|LinearClassifier|LPFilter|ExpressionFilter|Normalizer'
		elif chain.lower() == 'P3SignalProcessing'.lower():
			chain = 'TransmissionFilter|SpatialFilter|P3TemporalFilter|LinearClassifier'
		chain = chain.split('|')
	chain = [c.strip() for c in chain if len(c.strip())]

	if len(chain) == 0: print('WARNING: chain is empty')

	if start is not None and len(str(start).strip()): start = ' --start=' + str(start).replace(' ', '')
	else: start = ''
	if duration is not None and len(str(duration).strip()): duration = ' --duration=' + str(duration).replace(' ', '')
	else: duration = ''

	if dims is None or str(dims).lower() == 'auto': dims = 0
	if dims not in (0, 2, 3): raise ValueError("dims must be 2, 3 or 'auto'")

	out = SigTools.sstruct()
	err = ''

	cmd = ''
	binaries = []
	tmpdir = tempfile.mkdtemp(prefix='bci2000chain_')

	tmpdatfile  = os.path.join(tmpdir, 'in.dat')
	prmfile_in  = os.path.join(tmpdir, 'in.prm')
	prmfile_out = os.path.join(tmpdir, 'out.prm')
	matfile     = os.path.join(tmpdir, 'out.mat')
	bcifile     = os.path.join(tmpdir, 'out.bci')
	shfile      = os.path.join(tmpdir, 'go.bat')
	logfile     = os.path.join(tmpdir, 'log.txt')

	if not isinstance(datfile, basestring):
		raise ValueError('datfile must be a filename') # TODO: if datfile contains the appropriate info, use some create_bcidat equivalent and do datfile = tmpdatfile

	if not os.path.isfile(datfile): raise IOError('file not found: %s' % datfile)

	# Placeholders that are substituted into the command line once it is complete.
	mappings = {
		'$DATFILE':     datfile,
		'$PRMFILE_IN':  prmfile_in,
		'$PRMFILE_OUT': prmfile_out,
		'$MATFILE':     matfile,
		'$BCIFILE':     bcifile,
		'$SHFILE':      shfile,
		'$LOGFILE':     logfile,
	}

	def exe(name):
		# Prefix a tool name with <binpath>, if one was supplied.
		if binpath: return os.path.join(binpath, name)
		else: return name

	# Normalize <parms> to a fresh list, and treat extra keyword arguments as one more source.
	if parms is None: parms = []
	if isinstance(parms, tuple): parms = list(parms)
	if not isinstance(parms, list): parms = [parms]
	else: parms = list(parms)
	if len(kwargs): parms.append(kwargs)

	if len(parms) == 0:
		cmd += exe('bci_dat2stream') + start + duration + ' < "$DATFILE"'
		binaries.append('bci_dat2stream')
	else:
		if verbose: print('# writing custom parameter file %s' % prmfile_in)
		parms = make_bciprm(datfile, parms, '>', prmfile_in)

		if dat2stream_has_p_flag:
			cmd += exe('bci_dat2stream') + ' -p$PRMFILE_IN' + start + duration + ' < "$DATFILE"' # new-style bci_dat2stream with -p option
			binaries.append('bci_dat2stream')
		else:
			if len(start) or len(duration):
				raise ValueError('old versions of bci_dat2stream have no --start or --duration option')
			cmd += '('   # old-style bci_dat2stream with no -p option
			cmd +=         exe('bci_prm2stream') + ' < $PRMFILE_IN'
			cmd += '&& ' + exe('bci_dat2stream') + ' --transmit-sd < $DATFILE'
			cmd += ')'
			binaries.append('bci_dat2stream')
			binaries.append('bci_prm2stream')

	for c in chain:
		cmd += ' | ' + exe(c)
		binaries.append(c)

	if stream2mat_saves_parms:
		cmd += ' | ' + exe('bci_stream2mat') + ' > $MATFILE' # new-style bci_stream2mat with Parms output
		binaries.append('bci_stream2mat')
	else:
		cmd += ' > $BCIFILE'
		cmd += ' && ' + exe('bci_stream2mat') + ' < $BCIFILE > $MATFILE'
		cmd += ' && ' + exe('bci_stream2prm') + ' < $BCIFILE > $PRMFILE_OUT' # old-style bci_stream2mat without Parms output
		binaries.append('bci_stream2mat')
		binaries.append('bci_stream2prm')

	for k,v in mappings.items(): cmd = cmd.replace(k, v)

	# Write the command line to a script file so that it can be run in one system call.
	win = sys.platform.lower().startswith('win')
	if win: shebang = ''
	else: shebang = '#!/bin/sh\n\n'
	with open(shfile, 'wt') as script:  # closed explicitly before the script is executed
		script.write(shebang+cmd+'\n')
	if not win: os.chmod(shfile, 484) # rwxr--r--

	def tidytext(x):
		return x.strip().replace('\r\n', '\n').replace('\r', '\n')
	def getstatusoutput(cmd):
		pipe = os.popen(cmd + ' 2>&1', 'r') # TODO: does this work on Windows? could always make use of logfile here if not
		output = pipe.read()
		status = pipe.close()
		if status is None: status = 0
		return status,tidytext(output)
	def getoutput(cmd):
		return getstatusoutput(cmd)[1]

	if verbose: print('# querying version information')
	binaries = SigTools.sstruct([(name, getoutput(exe(name) + ' --version').replace('\n', '  ') ) for name in binaries])

	if verbose: print(cmd)
	t0 = time.time()
	failed,output = getstatusoutput(shfile)
	chaintime = time.time() - t0

	# The tools may report a configuration error in their output without a failing exit status.
	failsig = 'Configuration Error: '
	if failsig in output: failed = 1
	printable_output = output
	printable_lines = output.split('\n')
	maxlines = 10
	if len(printable_lines) > maxlines:
		printable_output = '\n'.join(printable_lines[:maxlines] + ['[%d more lines omitted]' % (len(printable_lines) - maxlines)])
	if failed:
		if verbose: err = 'system call failed:\n' + printable_output # cmd has already been printed, so don't clutter things further
		else: err = 'system call failed:\n%s\n%s' % (cmd, printable_output)

	if err == '':
		if verbose: print('# loading %s' % matfile)
		try:
			mat = SigTools.loadmat(matfile)
		except Exception:
			err = "The chain must have failed: could not load %s\nShell output was as follows:\n%s" % (matfile, printable_output)
		else:
			if 'Data' not in mat:  err = "The chain must have failed: no 'Data' variable found in %s\nShell output was as follows:\n%s" % (matfile, printable_output)
			if 'Index' not in mat: err = "The chain must have failed: no 'Index' variable found in %s\nShell output was as follows:\n%s" % (matfile, printable_output)

	if err == '':
		out.FileName = datfile
		if stream2mat_saves_parms:
			if verbose: print('# decoding parameters loaded from the mat-file')
			parms = make_bciprm(mat.Parms[0])
		else:
			if verbose: print('# reading output parameter file ' + prmfile_out)
			parms = ParmList(prmfile_out) # if you get an error that prmfile_out does not exist, recompile your bci_dat2stream and bci_stream2mat binaries from up-to-date sources, and ensure that dat2stream_has_p_flag and stream2mat_saves_parms, at the top of this file, are both set to 1

		out.DateStr = read_bcidate(parms, 'ISO')
		out.DateNum = read_bcidate(parms)
		out.FilterChain = chain
		out.ToolVersions = binaries
		out.ShellInput = cmd
		out.ShellOutput = output
		out.ChainTime = chaintime
		out.ChainSpeedFactor = None
		out.Megabytes = None
		out.Parms = parms
		out.Parms.sort()

		mat.Index = mat.Index[0,0]
		sigind = mat.Index.Signal - 1 # indices vary across channels fastest, then elements
		nChannels,nElements = sigind.shape
		nBlocks = mat.Data.shape[1]
		out.Blocks = nBlocks
		out.BlocksPerSecond = float(parms.SamplingRate.ScaledValue) / float(parms.SampleBlockSize.ScaledValue)
		out.SecondsPerBlock = float(parms.SampleBlockSize.ScaledValue)  / float(parms.SamplingRate.ScaledValue)
		out.ChainSpeedFactor = float(out.Blocks * out.SecondsPerBlock) / float(out.ChainTime)

		def unnumpify(x):
			# Unwrap single-element arrays, and turn arrays/sequences into (nested) lists.
			while isinstance(x, numpy.ndarray) and x.size == 1: x = x[0]
			if isinstance(x, (numpy.ndarray,tuple,list)): x = [unnumpify(xi) for xi in x]
			return x

		out.Channels = nChannels
		out.ChannelLabels = unnumpify(mat.get('ChannelLabels', []))
		# Best effort: a ChannelSet is optional, so any failure to build one yields None.
		try: out.ChannelSet = SigTools.ChannelSet(out.ChannelLabels)
		except Exception: out.ChannelSet = None

		out.Elements = nElements
		out.ElementLabels = unnumpify(mat.get('ElementLabels', []))
		out.ElementValues = numpy.ravel(mat.get('ElementValues', []))
		out.ElementUnit = unnumpify(mat.get('ElementUnit', None))
		out.ElementRate = out.BlocksPerSecond * out.Elements

		out.Time = out.SecondsPerBlock * numpy.arange(0, nBlocks)
		out.FullTime = out.Time
		out.FullElementValues = out.ElementValues


		# Why does sigind have to be transposed before vectorizing to achieve the same result as sigind(:) WITHOUT a transpose in Matlab?
		# This will probably forever be one of the many deep mysteries of numpy dimensionality handling
		out.Signal = mat.Data[sigind.T.ravel(), :]  # nChannels*nElements - by - nBlocks
		out.Signal = out.Signal + 0.0 # make a contiguous copy

		def seconds(s):  # -1 means "no information", 0 means "not units of time",  >0 means the scaling factor
			if getattr(s, 'ElementUnit', None) in ('',None): return -1
			s = s.ElementUnit
			# Reduce '...seconds', '...second' and '...sec' to '...s' before the lookup.
			if   s.endswith('seconds'): s = s[:-6]
			elif s.endswith('second'):  s = s[:-5]
			elif s.endswith('sec'):     s = s[:-2]
			if s.endswith('s'): return { 'ps': 1e-12, 'ns': 1e-9, 'us': 1e-6, 'mus': 1e-6, 'ms': 1e-3, 's' : 1e+0, 'ks': 1e+3, 'Ms': 1e+6, 'Gs': 1e+9, 'Ts': 1e+12,}.get(s, 0)
			else: return 0

		if dims == 0: # dimensionality has not been specified explicitly: so guess, based on ElementUnit and/or filter name
			# 3-dimensional output makes more sense than continuous 2-D whenever "elements" can't just be concatenated into an unbroken time-stream
			if len(chain): lastfilter = chain[-1].lower()
			else: lastfilter = ''
			if lastfilter == 'p3temporalfilter':
				dims = 3
			else:
				factor = seconds(out)
				if factor > 0:  # units of time.  TODO: could detect whether the out.ElementValues*factor are (close enough to) contiguous from block to block; then p3temporalfilter wouldn't have to be a special case above
					dims = 2
				elif factor == 0: # not units of time: use 3D by default
					dims = 3
				elif lastfilter in ['p3temporalfilter', 'arfilter', 'fftfilter', 'coherencefilter', 'coherencefftfilter']:  # no ElementUnit info? guess based on filter name
					dims = 3
				else:
					dims = 2

		if dims == 3:
			out.Signal = numpy.reshape(out.Signal, (nChannels, nElements, nBlocks), order='F') # nChannels - by - nElements - by - nBlocks
			out.Signal = numpy.transpose(out.Signal, (2,0,1))                                  # nBlocks - by - nChannels - by - nElements
			out.Signal = out.Signal + 0.0 # make a contiguous copy
		elif dims == 2:
			out.FullTime = numpy.repeat(out.Time, nElements)
			factor = seconds(out)
			if len(out.ElementValues):
				out.FullElementValues = numpy.tile(out.ElementValues, nBlocks)
				if factor > 0: out.FullTime = out.FullTime + out.FullElementValues * factor

			out.Signal = numpy.reshape(out.Signal, (nChannels, nElements*nBlocks), order='F')  # nChannels - by - nSamples
			out.Signal = numpy.transpose(out.Signal, (1,0))                                    # nSamples - by - nChannels
			out.Signal = out.Signal + 0.0 # make a contiguous copy
		else:
			raise RuntimeError('internal error')

		out.States = SigTools.sstruct()
		states = [(k,int(getattr(mat.Index, k))-1) for k in mat.Index._fieldnames if k != 'Signal']
		for k,v in states: setattr(out.States, k, mat.Data[v,:])
		# TODO: how do the command-line tools handle event states? this seems to be set up to deliver one value per block whatever kind of state we're dealing with

		out.Megabytes = megs(out)
	else:
		# On failure the error message takes the place of the result, and the temporary
		# files are kept so that they can be inspected.
		out = err
		keep = True
		print(err)
		print('')


	if os.path.isdir(tmpdir):
		paths = sorted([os.path.join(tmpdir, name) for name in os.listdir(tmpdir) if name not in ['.','..']])
		if keep: print('The following commands should be executed to clean up the temporary files:')
		elif verbose: print('# removing temp files and directory ' + tmpdir)

		for path in paths:
			if keep: print("    os.remove('%s')" % path)
			else: os.remove(path)

		# The directory itself also has to go, otherwise every call leaves an empty
		# bci2000chain_* directory behind.
		if keep: print("    os.rmdir('%s')" % tmpdir)
		else:
			try: os.rmdir(tmpdir)
			except OSError as rmerr: print('WARNING: could not remove %s: %s' % (tmpdir, rmerr))

	# Hand the result (or, on failure, the error message) back to the caller.
	return out
Example #17
0
def test(codebook,
         model=None,
         alphabet=None,
         txt=None,
         balanced_acc=0.75,
         default_context='. ',
         threshold=1.0,
         min_epochs=1,
         max_epochs=72,
         minprob=None,
         exponent=1.0,
         w=80):
    """
    Simulate spelling `txt` one character at a time with `codebook` and
    summarize how well a TextPrediction.Decoder recovers it.

    For each character a fresh Decoder is created.  Codebook columns are
    presented cyclically (column index = d.L % number_of_columns); for each
    column a noisy classifier output is simulated as
    logistic(2*z*(randn + z*sign)), where z = SigTools.invcg(balanced_acc) and
    sign is +1 if the correct symbol's codebook entry is 1, else -1.  This
    continues until the decoder returns a symbol.

    Arguments:
        codebook        -- 2-D array-like with one row per alphabet symbol.  May
                           also be a dict with a 'Matrix' entry, or a multi-line
                           string of whitespace-separated integers in which '.'
                           and '-' are read as 0.
        model           -- language model; a TextPrediction.LanguageModel is
                           created if None.
        alphabet        -- symbols to choose from; defaults to
                           TextPrediction.default_alphabet36.
        txt             -- text to spell, or (if shorter than 1000 characters
                           and resolvable by TextPrediction.FindFile) the name
                           of a file to read it from.  Defaults to
                           TextPrediction.default_test_text.
        balanced_acc    -- desired per-epoch accuracy, passed to SigTools.invcg.
        default_context -- passed to model.prepend_context.
        threshold, min_epochs, max_epochs, minprob, exponent
                        -- passed straight through to the Decoder.
        w               -- if non-zero, decoded symbols are echoed to stdout
                           with a progress percentage after every `w` symbols.

    Returns an sstruct with the conditions, input/output text, per-epoch
    confusion matrix and accuracies, and per-letter epoch counts and accuracy.
    A KeyboardInterrupt stops the simulation early; results for the characters
    completed so far are still returned.

    Raises ValueError if the alphabet contains characters unknown to the model,
    or if the codebook row count differs from the alphabet size.
    """

    modelin = model  # the report stores what the caller passed in (possibly None)
    # NOTE: the model is built before the default alphabet is substituted, so
    # LanguageModel receives alphabet=None when the caller supplies neither.
    if model is None: model = TextPrediction.LanguageModel(alphabet=alphabet)
    if alphabet is None: alphabet = TextPrediction.default_alphabet36
    if txt is None: txt = TextPrediction.default_test_text

    alphabet = model.match_case(alphabet)
    if len(set(alphabet) - set(model.alphabet)) > 0:
        raise ValueError(
            "some characters in the alphabet are not in the model")

    # A short `txt` may be a file name rather than the text itself.
    if len(txt) < 1000:
        txtfile = TextPrediction.FindFile(txt)
        if os.path.isfile(txtfile):
            txt = TextPrediction.ReadUTF8(txtfile)

    # Things will fail if some characters in the test text are not in alphabet.
    # So (this is a massive hack) temporarily restrict the model's alphabet
    # while cleaning the text, and always restore it afterwards.
    oldalphabet = model.alphabet
    model.alphabet = set(oldalphabet).intersection(alphabet)
    try:
        txt = model.clean(txt)
    finally:
        model.alphabet = oldalphabet

    if isinstance(codebook, dict) and 'Matrix' in codebook:
        codebook = codebook['Matrix']

    if isinstance(codebook, basestring):
        rows = codebook.strip().replace('.', '0').replace('-', '0').split('\n')
        codebook = [[int(i) for i in row.split()] for row in rows]
    codebook = numpy.asarray(codebook)

    N, L = codebook.shape
    if N != len(alphabet):
        raise ValueError(
            "number of rows in the codebook should be equal to the number of symbols in the alphabet"
        )

    out = txt[:0]  # empty sequence of the same type as txt
    nchars = len(txt)
    # builtin bool rather than the numpy.bool alias, which NumPy 1.24 removed
    correct = numpy.zeros((nchars, ), dtype=bool)
    nepochs = numpy.zeros((nchars, ), dtype=numpy.int16)
    z = SigTools.invcg(balanced_acc)

    # confusion[true codebook bit, rounded simulated output]
    confusion = numpy.zeros((2, 2), dtype=numpy.float64)
    correct_running = SigTools.running_mean()
    nepochs_running = SigTools.running_mean()
    try:
        for i in range(nchars):
            correct_symbol = txt[i]
            correct_index = alphabet.index(
                correct_symbol
            )  # will throw an error if not found, and rightly so
            context = model.prepend_context(txt[:i], default_context)
            d = TextPrediction.Decoder(choices=alphabet,
                                       context=context,
                                       model=model,
                                       verbose=False,
                                       threshold=threshold,
                                       min_epochs=min_epochs,
                                       max_epochs=max_epochs,
                                       minprob=minprob,
                                       exponent=exponent)
            result = None
            while result is None:
                col = codebook[:, d.L % L]
                Cij = int(col[correct_index])
                d.new_column(col)
                mean = z * {0: -1, 1: 1}[Cij]
                x = numpy.random.randn() + mean
                p = SigTools.logistic(2.0 * z * x)
                confusion[Cij, int(round(p))] += 1
                result = d.new_transmission(p)
            nepochs[i] = d.L
            correct[i] = (result == correct_symbol)
            out = out + result
            nepochs_running += d.L
            correct_running += (result == correct_symbol)
            if w:
                sys.stdout.write(result)
                if (i + 1) % w == 0:
                    sys.stdout.write('   %3d\n' %
                                     round(100.0 * float(i + 1) / nchars))
                elif (i + 1) == nchars:
                    sys.stdout.write('\n')
    except KeyboardInterrupt:
        # allow the user to stop early; report on what has been done so far
        if w: sys.stdout.write('\n')

    ndone = len(out)

    s = sstruct()
    s.alphabet = alphabet
    s.codebook = codebook
    s.model = modelin
    s.input = txt[:ndone]
    s.output = out
    s.conditions = sstruct()
    s.conditions.threshold = threshold
    s.conditions.min_epochs = min_epochs
    s.conditions.max_epochs = max_epochs
    s.conditions.minprob = minprob
    s.conditions.exponent = exponent
    s.epoch_acc = sstruct()
    s.epoch_acc.desired = balanced_acc
    s.epoch_acc.empirical_nontarget, s.epoch_acc.empirical_target = (
        confusion.diagonal() / confusion.sum(axis=1)).flat
    s.epoch_acc.confusion = confusion
    s.nepochs = sstruct()
    s.nepochs.each = nepochs[:ndone]
    s.nepochs.mean = nepochs_running.m
    s.nepochs.std = nepochs_running.v_unbiased**0.5
    s.nepochs.ste = s.nepochs.std / nepochs_running.n**0.5
    s.letter_acc = sstruct()
    s.letter_acc.each = correct[:ndone]
    s.letter_acc.mean = correct_running.m
    s.letter_acc.std = correct_running.v_unbiased**0.5
    s.letter_acc.ste = s.letter_acc.std / correct_running.n**0.5
    return s
Example #18
0
def ClassifyERPs(
        featurefile,
        C=(10.0, 1.0, 0.1, 0.01),
        gamma=(1.0, 0.8, 0.6, 0.4, 0.2, 0.0),
        rmchan=(),
        rebias=True,
        save=False,
        description='ERPs to attended vs unattended events',
        maxcount=None,
    ):
    """
    Train a binary classifier (SigTools.klr2class, cross-validated) on the
    trials stored in `featurefile` and return the resulting weights.

    Arguments:
        featurefile -- passed to DataFiles.load; the loaded data must provide
                       'x', 'y', 'channels' and 'fs'.
        C           -- candidate value(s) for the classifier's C hyperparameter
                       (a scalar is wrapped in a list; None leaves the
                       classifier's own setting untouched).
        gamma       -- candidate covariance-shrinkage value(s) for the
                       symwhitenkern kernel; None selects SigTools.linkern.
        rmchan      -- channel names (sequence, or whitespace-separated string)
                       to zero out, in addition to a fixed list of
                       audio/sync/marker/reference channels.
        rebias      -- if true, call c.rebias() after training.
        save        -- False, True (derive a '_weights.prm' file name from
                       `featurefile`) or an explicit file name.
        description -- free text included in the saved description.
        maxcount    -- passed to DataFiles.load.

    Returns (u, c): the SigTools.stfac weight factorization and the trained
    classifier.

    Raises ValueError unless the labels contain exactly two classes.

    NOTE: the array d['x'] is modified in place (unwanted channels are zeroed).
    """

    def plural(count):
        # suffix for "channel"/"channels" in the messages below
        return {1: ''}.get(count, 's')

    d = DataFiles.load(featurefile, catdim=0, maxcount=maxcount)

    x = d['x']
    y = numpy.array(d['y'].flat)
    n = len(y)
    uy = numpy.unique(y)
    if uy.size != 2: raise ValueError("expected 2 classes in dataset, found %d" % uy.size)
    y = numpy.sign(y - uy.mean())  # map the two class labels onto -1 / +1

    cov, trchvar = SigTools.spcov(x=x, y=y, balance=False, return_trchvar=True)  # NB: symwhitenkern would not be able to balance

    starttime = time.time()

    if isinstance(rmchan, basestring): rmchan = rmchan.split()
    allrmchan = tuple([ch.lower() for ch in rmchan]) + ('audl', 'audr', 'laud', 'raud', 'sync', 'vsync', 'vmrk', 'oldref')
    chlower = [ch.lower() for ch in d['channels']]
    unwanted = numpy.array([ch in allrmchan for ch in chlower])
    wanted = numpy.logical_not(unwanted)
    notfound = [ch for ch in rmchan if ch.lower() not in chlower]
    print(' ')
    if len(notfound): print("WARNING: could not find channel%s %s\n" % (plural(len(notfound)), ', '.join(notfound)))
    removed = [ch for removing, ch in zip(unwanted, d['channels']) if removing]
    if len(removed): print("removed %d channel%s (%s)" % (len(removed), plural(len(removed)), ', '.join(removed)))
    print("classification will be based on %d channel%s" % (sum(wanted), plural(sum(wanted))))
    print("%d negatives + %d positives = %d exemplars" % (sum(y < 0), sum(y > 0), n))
    print(' ')

    # Zero out the unwanted channels, then give them the mean variance of the
    # wanted channels on the covariance diagonal so that no diagonal entry is 0.
    x[:, unwanted, :] = 0
    cov[:, unwanted] = 0
    cov[unwanted, :] = 0
    nu = numpy.asarray(cov).diagonal()[wanted].mean()
    for i in range(len(cov)):
        if cov[i, i] == 0: cov[i, i] = nu

    if not isinstance(C, (tuple, list, numpy.ndarray, type(None))): C = [C]
    if not isinstance(gamma, (tuple, list, numpy.ndarray, type(None))): gamma = [gamma]

    c = SigTools.klr2class(lossfunc=SigTools.balanced_loss, relcost='balance').varyhyper({})
    # The original tested the classifier `c` here instead of the argument `C`,
    # which made C=None crash in list(None).
    if C is not None: c.hyper.C = list(C)
    # `is None` rather than `== None`: gamma may be a numpy array
    if gamma is None: c.hyper.kernel.func = SigTools.linkern
    else: c.varyhyper({'kernel.func': SigTools.symwhitenkern, 'kernel.cov': [cov], 'kernel.gamma': list(gamma)})
    c.cvtrain(x=x, y=y)
    if rebias: c.rebias()
    c.calibrate()

    chosen = c.cv.chosen.hyper
    if gamma is None:
        Ps = None
        Gp = c.featureweight(x=x)
    else:
        Ps = SigTools.svd(SigTools.shrinkcov(cov, copy=True, gamma=chosen.kernel.gamma)).isqrtm
        xp = SigTools.spfilt(x, Ps.H, copy=True)
        Gp = c.featureweight(x=xp)

    u = SigTools.stfac(Gp, Ps)
    u.channels = d['channels']
    u.channels_used = wanted
    u.fs = d['fs']
    u.trchvar = trchvar

    elapsed = time.time() - starttime
    minutes = int(elapsed / 60.0)
    seconds = int(round(elapsed - minutes * 60.0))
    print('%d min %d sec' % (minutes, seconds))
    datestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    csummary = '%s (%s) trained on %d (CV %s = %.3f) at %s' % (
        c.__class__.__name__,
        SigTools.experiment()._shortdesc(chosen),
        sum(c.input.istrain),
        c.loss.func.__name__,
        c.loss.train,
        datestamp,
    )
    description = 'binary classification of %s: %s' % (description, csummary)
    u.description = description

    if save:
        if not isinstance(save, basestring):
            # derive the output name from the (last) feature file
            save = featurefile
            if isinstance(save, (tuple, list)): save = save[-1]
            if save.lower().endswith('.gz'): save = save[:-3]
            if save.lower().endswith('.pk'): save = save[:-3]
            save = save + '_weights.prm'
        print("\nsaving %s\n" % save)
        Parameters.Param(u.G.A, name='ERPClassifierWeights', tab='PythonSig', section='Epoch', comment=csummary).writeto(save)
        Parameters.Param(c.model.bias, name='ERPClassifierBias', tab='PythonSig', section='Epoch', comment=csummary).appendto(save)
        Parameters.Param(description, name='SignalProcessingDescription', tab='PythonSig').appendto(save)
    return u, c
Example #19
0
def bci2000chain(datfile,
                 chain='SpectralSignalProcessing',
                 parms=(),
                 dims='auto',
                 start=None,
                 duration=None,
                 verbose=False,
                 keep=False,
                 binpath=None,
                 **kwargs):
    """
    Run a BCI2000 data file through a chain of command-line filter tools and
    load the result.

    This is a provisional Python port of the Matlab-based tools/matlab/bci2000chain.m

    Excuse the relative paths - this example works if you're currently working in the root
    directory of the BCI2000 distro:

        bci2000chain( datfile='data/samplefiles/eeg3_2.dat',
                      chain='TransmissionFilter|SpatialFilter|ARFilter',
                      binpath='tools/cmdline',
                      parms=[ 'tools/matlab/ExampleParameters.prm' ], SpatialFilterType=3 )

    Or for more portable operation, you can do this kind of thing:

        bci2000root( '/PATH/TO/BCI2000' )

        bci2000chain(datfile=bci2000path( 'data/samplefiles/eeg3_2.dat' ),
                     chain='TransmissionFilter|SpatialFilter|ARFilter',
                     parms=[ bci2000path( 'tools/matlab/ExampleParameters.prm' ) ], SpatialFilterType=3 )

    For clues, see http://www.bci2000.org/wiki/index.php/User_Reference:Matlab_Tools
    and matlab documentation in bci2000chain.m

    Most arguments are like the flags in the Matlab equivalent. mutatis to a large extent mutandis.
    Note that for now there is no global way of managing the system path. Either add the
    tools/cmdline directory to your $PATH variable before starting Python, or supply its location
    every time while calling, in the <binpath> argument.

    Arguments:
        datfile  -- name of an existing data file.
        chain    -- '|'-separated string or sequence of filter tool names; the
                    names 'SpectralSignalProcessing', 'ARSignalProcessing' and
                    'P3SignalProcessing' (case-insensitive) expand to
                    predefined chains.
        parms    -- parameter specification(s) handed to make_bciprm; any extra
                    keyword arguments are appended as a dict.
        dims     -- 2, 3 or 'auto': layout of the returned Signal array.
        start, duration -- if given, passed to bci_dat2stream as --start= and
                    --duration= (whitespace removed).
        verbose  -- print progress information.
        keep     -- keep the temporary files and print the commands that would
                    remove them (forced on when the chain fails).
        binpath  -- location of the command-line binaries.

    Returns a CONTAINER describing the processed signal, states and
    parameters, or - if the chain failed - the error message as a string
    (which is also printed).

    Raises ValueError for an invalid `dims` or a non-string `datfile`, and
    IOError if `datfile` does not exist.
    """

    verbosityFlag = '-v' if verbose else '-q'

    if isinstance(chain, str):
        aliases = {
            'spectralsignalprocessing':
            'TransmissionFilter|SpatialFilter|SpectralEstimator|LinearClassifier|LPFilter|ExpressionFilter|Normalizer',
            'arsignalprocessing':
            'TransmissionFilter|SpatialFilter|ARFilter|LinearClassifier|LPFilter|ExpressionFilter|Normalizer',
            'p3signalprocessing':
            'TransmissionFilter|SpatialFilter|P3TemporalFilter|LinearClassifier',
        }
        chain = aliases.get(chain.lower(), chain).split('|')
    chain = [c.strip() for c in chain if len(c.strip())]

    if len(chain) == 0: print('WARNING: chain is empty')

    if start is not None and len(str(start).strip()):
        start = ' --start=' + str(start).replace(' ', '')
    else:
        start = ''
    if duration is not None and len(str(duration).strip()):
        duration = ' --duration=' + str(duration).replace(' ', '')
    else:
        duration = ''

    if dims is None or str(dims).lower() == 'auto': dims = 0
    if dims not in (0, 2, 3): raise ValueError("dims must be 2, 3 or 'auto'")

    # Validate the input file BEFORE creating the temporary directory, so that
    # bad input does not leave an orphaned directory behind.
    if not isinstance(datfile, str):
        raise ValueError(
            'datfile must be a filename'
        )  # TODO: if datfile contains the appropriate info, use some create_bcidat equivalent and do datfile = tmpdatfile

    if not os.path.isfile(datfile):
        raise IOError('file not found: %s' % datfile)

    out = CONTAINER()
    err = ''

    cmd = ''
    binaries = []
    tmpdir = tempfile.mkdtemp(prefix='bci2000chain_')

    tmpdatfile = os.path.join(tmpdir, 'in.dat')
    prmfile_in = os.path.join(tmpdir, 'in.prm')
    prmfile_out = os.path.join(tmpdir, 'out.prm')
    matfile = os.path.join(tmpdir, 'out.mat')
    bcifile = os.path.join(tmpdir, 'out.bci')
    shfile = os.path.join(tmpdir, 'go.bat')
    logfile = os.path.join(tmpdir, 'log.txt')

    # placeholders used while assembling the shell command
    mappings = {
        '$DATFILE': datfile,
        '$PRMFILE_IN': prmfile_in,
        '$PRMFILE_OUT': prmfile_out,
        '$MATFILE': matfile,
        '$BCIFILE': bcifile,
        '$SHFILE': shfile,
        '$LOGFILE': logfile,
    }

    if binpath is None and BCI2000_ROOT_DIR is not None:
        binpath = 'tools/cmdline'
    binpath = bci2000path(binpath)

    def exe(name):
        # quoted full path of a command-line tool, or the bare name if no binpath is known
        if binpath: return '"' + os.path.join(binpath, name) + '"'
        else: return name

    # normalize parms to a fresh list; keyword arguments become one more entry
    if parms is None: parms = []
    if isinstance(parms, (tuple, list)): parms = list(parms)
    else: parms = [parms]
    if len(kwargs): parms.append(kwargs)

    if len(parms) == 0:
        cmd += exe('bci_dat2stream') + start + duration + ' < "$DATFILE"'
        binaries.append('bci_dat2stream')
    else:
        if verbose: print('# writing custom parameter file %s' % prmfile_in)
        parms = make_bciprm(verbosityFlag, datfile, parms, '>', prmfile_in)

        if dat2stream_has_p_flag:
            cmd += exe(
                'bci_dat2stream'
            ) + ' "-p$PRMFILE_IN"' + start + duration + ' < "$DATFILE"'  # new-style bci_dat2stream with -p option
            binaries.append('bci_dat2stream')
        else:
            if len(start) or len(duration):
                raise ValueError(
                    'old versions of bci_dat2stream have no --start or --duration option'
                )
            cmd += '('  # old-style bci_dat2stream with no -p option
            cmd += exe('bci_prm2stream') + ' < "$PRMFILE_IN"'
            cmd += '&& ' + exe(
                'bci_dat2stream') + ' --transmit-sd < "$DATFILE"'
            cmd += ')'
            binaries.append('bci_dat2stream')
            binaries.append('bci_prm2stream')

    for c in chain:
        cmd += ' | ' + exe(c)
        binaries.append(c)

    if stream2mat_saves_parms:
        cmd += ' | ' + exe(
            'bci_stream2mat'
        ) + ' > "$MATFILE"'  # new-style bci_stream2mat with Parms output
        binaries.append('bci_stream2mat')
    else:
        cmd += ' > "$BCIFILE"'
        cmd += ' && ' + exe('bci_stream2mat') + ' < "$BCIFILE" > "$MATFILE"'
        cmd += ' && ' + exe(
            'bci_stream2prm'
        ) + ' < "$BCIFILE" > "$PRMFILE_OUT"'  # old-style bci_stream2mat without Parms output
        binaries.append('bci_stream2mat')
        binaries.append('bci_stream2prm')

    for k, v in list(mappings.items()):
        cmd = cmd.replace(k, v)

    win = sys.platform.lower().startswith('win')
    if win: shebang = '@'
    else: shebang = '#!/bin/sh\n\n'
    # close the script explicitly so that it is fully written before it is run
    with open(shfile, 'wt') as script:
        script.write(shebang + cmd + '\n')
    if not win: os.chmod(shfile, 0o744)  # rwxr--r--

    def tidytext(x):
        return x.strip().replace('\r\n', '\n').replace('\r', '\n')

    def getstatusoutput(cmd):
        pipe = os.popen(
            cmd + ' 2>&1', 'r'
        )  # TODO: does this work on Windows? could always make use of logfile here if not
        output = pipe.read()
        status = pipe.close()
        if status is None: status = 0
        return status, tidytext(output)

    def getoutput(cmd):
        return getstatusoutput(cmd)[1]

    if verbose: print('# querying version information')
    binaries = CONTAINER([
        (tool, getoutput(exe(tool) + ' --version').replace('\n', '  '))
        for tool in binaries
    ])

    if verbose: print(cmd)
    t0 = time.time()
    failed, output = getstatusoutput(shfile)
    chaintime = time.time() - t0

    # a configuration error counts as failure whatever the exit status
    failsig = 'Configuration Error: '
    if failsig in output: failed = 1
    printable_output = output
    printable_lines = output.split('\n')
    maxlines = 10
    if len(printable_lines) > maxlines:
        printable_output = '\n'.join(
            printable_lines[:maxlines] +
            ['[%d more lines omitted]' % (len(printable_lines) - maxlines)])
    if failed:
        if verbose:
            err = 'system call failed:\n' + printable_output  # cmd has already been printed, so don't clutter things further
        else:
            err = 'system call failed:\n\n%s\n\n%s' % (cmd, printable_output)

    if err == '':
        if verbose: print('# loading %s' % matfile)
        try:
            mat = ReadMatFile(matfile)
        except Exception:
            err = "The chain must have failed: could not load %s\nShell output was as follows:\n%s" % (
                matfile, printable_output)
        else:
            if 'Data' not in mat:
                err = "The chain must have failed: no 'Data' variable found in %s\nShell output was as follows:\n%s" % (
                    matfile, printable_output)
            if 'Index' not in mat:
                err = "The chain must have failed: no 'Index' variable found in %s\nShell output was as follows:\n%s" % (
                    matfile, printable_output)

    if err == '':
        out.FileName = datfile
        if stream2mat_saves_parms:
            if verbose: print('# decoding parameters loaded from the mat-file')
            parms = make_bciprm(verbosityFlag, mat['Parms'][0])
        else:
            if verbose: print('# reading output parameter file ' + prmfile_out)
            parms = ParmList(
                prmfile_out
            )  # if you get an error that prmfile_out does not exist, recompile your bci_dat2stream and bci_stream2mat binaries from up-to-date sources, and ensure that dat2stream_has_p_flag and stream2mat_saves_parms, at the top of this file, are both set to 1

        out.DateStr = read_bcidate(parms, 'ISO')
        out.DateNum = read_bcidate(parms)
        out.FilterChain = chain
        out.ToolVersions = binaries
        out.ShellInput = cmd
        out.ShellOutput = output
        out.ChainTime = chaintime
        out.ChainSpeedFactor = None
        out.Megabytes = None
        out.Parms = parms
        out.Parms.sort()

        mat['Index'] = mat['Index'][0, 0]
        sigind = mat[
            'Index'].Signal - 1  # indices vary across channels fastest, then elements
        nChannels, nElements = sigind.shape
        nBlocks = mat['Data'].shape[1]
        out.Blocks = nBlocks
        out.BlocksPerSecond = float(parms.SamplingRate.ScaledValue) / float(
            parms.SampleBlockSize.ScaledValue)
        out.SecondsPerBlock = float(parms.SampleBlockSize.ScaledValue) / float(
            parms.SamplingRate.ScaledValue)
        out.ChainSpeedFactor = float(out.Blocks * out.SecondsPerBlock) / float(
            out.ChainTime)

        def unnumpify(x):
            # unwrap single-element arrays; turn remaining arrays/sequences into nested lists
            while isinstance(x, numpy.ndarray) and x.size == 1:
                x = x[0]
            if isinstance(x, (numpy.ndarray, tuple, list)):
                x = [unnumpify(xi) for xi in x]
            return x

        out.Channels = nChannels
        out.ChannelLabels = unnumpify(mat.get('ChannelLabels', []))
        try:
            out.ChannelSet = SigTools.ChannelSet(out.ChannelLabels)
        except Exception:
            # best effort: labels that cannot be interpreted simply yield no ChannelSet
            out.ChannelSet = None

        out.Elements = nElements
        out.ElementLabels = unnumpify(mat.get('ElementLabels', []))
        out.ElementValues = numpy.ravel(mat.get('ElementValues', []))
        out.ElementUnit = unnumpify(mat.get('ElementUnit', None))
        out.ElementRate = out.BlocksPerSecond * out.Elements

        out.Time = out.SecondsPerBlock * numpy.arange(0, nBlocks)
        out.FullTime = out.Time
        out.FullElementValues = out.ElementValues

        # Why does sigind have to be transposed before vectorizing to achieve the same result as sigind(:) WITHOUT a transpose in Matlab?
        # This will probably forever be one of the many deep mysteries of numpy dimensionality handling
        out.Signal = mat['Data'][
            sigind.T.ravel(), :]  # nChannels*nElements - by - nBlocks
        out.Signal = out.Signal + 0.0  # make a contiguous copy

        def seconds(
            s
        ):  # -1 means "no information", 0 means "not units of time",  >0 means the scaling factor
            if getattr(s, 'ElementUnit', None) in ('', None,
                                                   ()) or s.ElementUnit == []:
                return -1
            s = s.ElementUnit
            # reduce a spelled-out unit to its abbreviation, e.g. 'mseconds' -> 'ms'
            if s.endswith('seconds'): s = s[:-6]
            elif s.endswith('second'): s = s[:-5]
            elif s.endswith('sec'): s = s[:-2]
            if s.endswith('s'):
                return {
                    'ps': 1e-12,
                    'ns': 1e-9,
                    'us': 1e-6,
                    'mus': 1e-6,
                    'ms': 1e-3,
                    's': 1e+0,
                    'ks': 1e+3,
                    'Ms': 1e+6,
                    'Gs': 1e+9,
                    'Ts': 1e+12,
                }.get(s, 0)
            else:
                return 0

        if dims == 0:  # dimensionality has not been specified explicitly: so guess, based on ElementUnit and/or filter name
            # 3-dimensional output makes more sense than continuous 2-D whenever "elements" can't just be concatenated into an unbroken time-stream
            if len(chain): lastfilter = chain[-1].lower()
            else: lastfilter = ''
            if lastfilter == 'p3temporalfilter':
                dims = 3
            else:
                factor = seconds(out)
                if factor > 0:  # units of time.  TODO: could detect whether the out.ElementValues*factor are (close enough to) contiguous from block to block; then p3temporalfilter wouldn't have to be a special case above
                    dims = 2
                elif factor == 0:  # not units of time: use 3D by default
                    dims = 3
                elif lastfilter in [
                        'p3temporalfilter', 'arfilter', 'fftfilter',
                        'coherencefilter', 'coherencefftfilter'
                ]:  # no ElementUnit info? guess based on filter name
                    dims = 3
                else:
                    dims = 2

        if dims == 3:
            out.Signal = numpy.reshape(
                out.Signal, (nChannels, nElements, nBlocks),
                order='F')  # nChannels - by - nElements - by - nBlocks
            out.Signal = numpy.transpose(
                out.Signal,
                (2, 0, 1))  # nBlocks - by - nChannels - by - nElements
            out.Signal = out.Signal + 0.0  # make a contiguous copy
        elif dims == 2:
            out.FullTime = numpy.repeat(out.Time, nElements)
            factor = seconds(out)
            if len(out.ElementValues):
                out.FullElementValues = numpy.tile(out.ElementValues, nBlocks)
                if factor > 0:
                    out.FullTime = out.FullTime + out.FullElementValues * factor

            out.Signal = numpy.reshape(out.Signal,
                                       (nChannels, nElements * nBlocks),
                                       order='F')  # nChannels - by - nSamples
            out.Signal = numpy.transpose(out.Signal,
                                         (1, 0))  # nSamples - by - nChannels
            out.Signal = out.Signal + 0.0  # make a contiguous copy
        else:
            raise RuntimeError('internal error')

        out.States = CONTAINER()
        # mat['Index'] may expose its fields either as attributes or as a mapping
        try:
            fieldnames = mat['Index']._fieldnames
        except AttributeError:
            items = list(mat['Index'].items())
        else:
            items = [(k, getattr(mat['Index'], k)) for k in fieldnames]
        states = [(k, int(v) - 1) for k, v in items if k != 'Signal']
        for k, v in states:
            setattr(out.States, k, mat['Data'][v, :])
        # TODO: how do the command-line tools handle event states? this seems to be set up to deliver one value per block whatever kind of state we're dealing with

        out.Megabytes = megs(out)
    else:
        # failure: return the message, and keep the temp files for inspection
        out = err
        keep = True
        print(err)
        print()

    if os.path.isdir(tmpdir):
        files = sorted([
            os.path.join(tmpdir, fname) for fname in os.listdir(tmpdir)
            if fname not in ['.', '..']
        ])
        if keep:
            print(
                'The following commands should be executed to clean up the temporary files:'
            )
        elif verbose:
            print('# removing temp files and directory ' + tmpdir)

        for path in files:
            if keep: print("os.remove(r'%s')" % path)
            else:
                try:
                    os.remove(path)
                except Exception as exc:
                    sys.stderr.write("failed to remove %s:\n    %s\n" %
                                     (path, str(exc)))

        if keep: print("os.rmdir(r'%s')" % tmpdir)
        else:
            try:
                os.rmdir(tmpdir)
            except Exception as exc:
                sys.stderr.write("failed to remove %s:\n    %s" %
                                 (tmpdir, str(exc)))
        if keep: print("")

    return out
def ClassifyERPs(
        featurefiles,
        C=(10.0, 1.0, 0.1, 0.01),
        gamma=(1.0, 0.8, 0.6, 0.4, 0.2, 0.0),
        keepchan=(),
        rmchan=(),
        rmchan_usualsuspects=('AUDL', 'AUDR', 'LAUD', 'RAUD', 'SYNC', 'VSYNC', 'VMRK', 'OLDREF'),
        rebias=True,
        save=False,
        select=False,
        description='ERPs to attended vs unattended events',
        maxcount=None,
        classes=None,
        folds=None,
        time_window=None,
        keeptrials=None,
    ):
    """
    Train a binary classifier (SigTools.klr2class, cross-validated) on the
    trials stored in `featurefiles` and return the resulting weights.

    Arguments:
        featurefiles -- passed to DataFiles.load; the loaded data must provide
                        'x', 'y', 'channels' and 'fs'.
        C            -- candidate value(s) for the classifier's C
                        hyperparameter (a scalar is wrapped in a list; None
                        leaves the classifier's own setting untouched).
        gamma        -- candidate covariance-shrinkage value(s) for the
                        symwhitenkern kernel; None selects SigTools.linkern.
        keepchan     -- if non-empty, the only channels to use (sequence or
                        whitespace-separated string); rmchan and
                        rmchan_usualsuspects are then ignored.
        rmchan, rmchan_usualsuspects
                     -- channels to zero out when keepchan is empty.
        rebias       -- if true, call c.rebias() after training.
        save         -- False, True (derive a '_weights.prm' file name from
                        `featurefiles`) or an explicit file name.
        select       -- if true, additionally copy the saved file to
                        'ChosenWeights.prm' (or to the given name) next to it.
        description  -- free text included in the saved description.
        maxcount     -- passed to DataFiles.load.
        classes      -- if given, only trials whose label is in `classes` are kept.
        folds        -- passed to c.cvtrain; one of the strings 'lofo', 'loro',
                        'leave one run out', 'leave one file out' builds one
                        fold per loaded file.
        time_window  -- if given, samples outside [min, max] msec are zeroed.
        keeptrials   -- if given, index/mask selecting the trials to keep.

    Returns (u, c): the SigTools.stfac weight factorization and the trained
    classifier.

    Raises ValueError if 'x' or 'y' is missing, if a requested class is absent,
    if the labels do not contain exactly two classes, or if a class has fewer
    than two exemplars.

    NOTE: the loaded trial array may be modified in place (time-window and
    channel zeroing).
    """

    def plural(count):
        # suffix for "channel"/"channels" in the messages below
        return {1: ''}.get(count, 's')

    file_inventory = []
    d = DataFiles.load(featurefiles, catdim=0, maxcount=maxcount, return_details=file_inventory)
    # 'leave on run out' is kept for backward compatibility with the original misspelling
    if isinstance(folds, basestring) and folds.lower() in ['lofo', 'loro', 'leave on run out', 'leave one run out', 'leave one file out']:
        # NOTE(review): these folds index the trials as loaded; they are not
        # adjusted for the keeptrials/classes filtering below - confirm that
        # the combination is not used.
        n, folds = 0, []
        for each in file_inventory:
            neach = each[1]['x']
            folds.append(range(n, n + neach))
            n += neach

    if 'x' not in d: raise ValueError("found no trial data - no 'x' variable - in the specified files")
    if 'y' not in d: raise ValueError("found no trial labels - no 'y' variable - in the specified files")

    x = d['x']
    y = numpy.array(d['y'].flat)
    # `is not None` throughout: these arguments may be numpy arrays, for which
    # `!= None` is an element-wise comparison
    if keeptrials is not None:
        x = x[numpy.asarray(keeptrials), :, :]
        y = y[numpy.asarray(keeptrials)]

    if time_window is not None:
        fs = d['fs']
        t = SigTools.samples2msec(numpy.arange(x.shape[2]), fs)
        x[:, :, t < min(time_window)] = 0
        x[:, :, t > max(time_window)] = 0

    if classes is not None:
        for cl in classes:
            if cl not in y: raise ValueError("class %s is not in the dataset" % str(cl))
        mask = numpy.array([yi in classes for yi in y])
        y = y[mask]
        x = x[mask]
        discarded = sum(numpy.logical_not(mask))
        if discarded: print("discarding %d trials that are outside the requested classes %s" % (discarded, str(classes)))

    n = len(y)
    uy = numpy.unique(y)
    if uy.size != 2: raise ValueError("expected 2 classes in dataset, found %d : %s" % (uy.size, str(uy)))
    nyi_min = 2
    for uyi in uy:
        nyi = sum([yi == uyi for yi in y])
        if nyi < nyi_min: raise ValueError('only %d exemplars of class %s - need at least %d' % (nyi, str(uyi), nyi_min))

    y = numpy.sign(y - uy.mean())  # map the two class labels onto -1 / +1

    cov, trchvar = SigTools.spcov(x=x, y=y, balance=False, return_trchvar=True)  # NB: symwhitenkern would not be able to balance

    starttime = time.time()

    chlower = [ch.lower() for ch in d['channels']]
    if keepchan in [None, (), '', []]:
        if isinstance(rmchan, basestring): rmchan = rmchan.split()
        if isinstance(rmchan_usualsuspects, basestring): rmchan_usualsuspects = rmchan_usualsuspects.split()
        allrmchan = [ch.lower() for ch in list(rmchan) + list(rmchan_usualsuspects)]
        unwanted = numpy.array([ch in allrmchan for ch in chlower])
        notfound = [ch for ch in rmchan if ch.lower() not in chlower]
    else:
        if isinstance(keepchan, basestring): keepchan = keepchan.split()
        lowerkeepchan = [ch.lower() for ch in keepchan]
        unwanted = numpy.array([ch not in lowerkeepchan for ch in chlower])
        notfound = [ch for ch in keepchan if ch.lower() not in chlower]

    wanted = numpy.logical_not(unwanted)
    print(' ')
    if len(notfound): print("WARNING: could not find channel%s %s\n" % (plural(len(notfound)), ', '.join(notfound)))
    removed = [ch for removing, ch in zip(unwanted, d['channels']) if removing]
    if len(removed): print("removed %d channel%s (%s)" % (len(removed), plural(len(removed)), ', '.join(removed)))
    print("classification will be based on %d channel%s" % (sum(wanted), plural(sum(wanted))))
    print("%d negatives + %d positives = %d exemplars" % (sum(y < 0), sum(y > 0), n))
    print(' ')

    # Zero out the unwanted channels, then give them the mean variance of the
    # wanted channels on the covariance diagonal so that no diagonal entry is 0.
    x[:, unwanted, :] = 0
    cov[:, unwanted] = 0
    cov[unwanted, :] = 0
    nu = numpy.asarray(cov).diagonal()[wanted].mean()
    for i in range(len(cov)):
        if cov[i, i] == 0: cov[i, i] = nu

    if not isinstance(C, (tuple, list, numpy.ndarray, type(None))): C = [C]
    if not isinstance(gamma, (tuple, list, numpy.ndarray, type(None))): gamma = [gamma]

    c = SigTools.klr2class(lossfunc=SigTools.balanced_loss, relcost='balance')
    c.varyhyper({})
    # The original tested the classifier `c` here instead of the argument `C`,
    # which made C=None crash in list(None).
    if C is not None: c.hyper.C = list(C)
    if gamma is None: c.hyper.kernel.func = SigTools.linkern
    else: c.varyhyper({'kernel.func': SigTools.symwhitenkern, 'kernel.cov': [cov], 'kernel.gamma': list(gamma)})
    c.cvtrain(x=x, y=y, folds=folds)
    if rebias: c.rebias()
    c.calibrate()

    chosen = c.cv.chosen.hyper
    if gamma is None:
        Ps = None
        Gp = c.featureweight(x=x)
    else:
        Ps = SigTools.svd(SigTools.shrinkcov(cov, copy=True, gamma=chosen.kernel.gamma)).isqrtm
        xp = SigTools.spfilt(x, Ps.H, copy=True)
        Gp = c.featureweight(x=xp)

    u = SigTools.stfac(Gp, Ps)
    u.channels = d['channels']
    u.channels_used = wanted
    u.fs = d['fs']
    u.trchvar = trchvar
    try: u.channels = SigTools.ChannelSet(u.channels)
    except Exception: print('WARNING: failed to convert channels to ChannelSet')

    elapsed = time.time() - starttime
    minutes = int(elapsed / 60.0)
    seconds = int(round(elapsed - minutes * 60.0))
    print('%d min %d sec' % (minutes, seconds))
    datestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    csummary = '%s (%s) trained on %d (CV %s = %.3f) at %s' % (
        c.__class__.__name__,
        SigTools.experiment()._shortdesc(chosen),
        sum(c.input.istrain),
        c.loss.func.__name__,
        c.loss.train,
        datestamp,
    )
    description = 'binary classification of %s: %s' % (description, csummary)
    u.description = description

    if save or select:
        if not isinstance(save, basestring):
            # derive the output name from the (last) feature file
            save = featurefiles
            if isinstance(save, (tuple, list)): save = save[-1]
            if save.lower().endswith('.gz'): save = save[:-3]
            if save.lower().endswith('.pk'): save = save[:-3]
            save = save + '_weights.prm'
        print("\nsaving %s\n" % save)
        Parameters.Param(u.G.A, Name='ERPClassifierWeights', Section='PythonSig', Subsection='Epoch', Comment=csummary).write_to(save)
        Parameters.Param(c.model.bias, Name='ERPClassifierBias', Section='PythonSig', Subsection='Epoch', Comment=csummary).append_to(save)
        Parameters.Param(description, Name='SignalProcessingDescription', Section='PythonSig').append_to(save)
        if select:
            if not isinstance(select, basestring): select = 'ChosenWeights.prm'
            if not os.path.isabs(select): select = os.path.join(os.path.split(save)[0], select)
            print("saving %s\n" % select)
            import shutil
            shutil.copyfile(save, select)

    print(description)
    return u, c
Example #21
0
    def initialize(cls, app, indim, outdim):
        """
        Set up ERP collection on `app` when the 'ERPDatabaseEnable' parameter is 1:
        sample buffers, the subject's database record, the background ERP
        thread and (when 'ERPFeedbackDisplay' is 2) the on-screen feedback bar
        and target box.  Does nothing otherwise.
        """
        if int(app.params['ERPDatabaseEnable']) != 1:
            return

        if int(app.params['ShowSignalTime']):
            for state_name in ('LastERPVal', 'ERPCollected'):
                app.addstatemonitor(state_name)

        #===================================================================
        # Prepare the buffers for saving the data
        # -leaky_trap contains the data to be saved (trap size defined by pre_stim_samples + post_stim_samples + some breathing room
        # -trig_trap contains only the trigger channel
        #===================================================================
        win_start, win_stop = app.erpwin[0], app.erpwin[1]
        # Time axis of the ERP window; needed when saving trials
        app.x_vec = np.arange(win_start, win_stop, 1000.0 / app.eegfs,
                              dtype=float)
        app.post_stim_samples = SigTools.msec2samples(win_stop, app.eegfs)
        app.pre_stim_samples = SigTools.msec2samples(np.abs(win_start),
                                                     app.eegfs)
        trap_size = app.pre_stim_samples + app.post_stim_samples + 5 * app.spb
        app.leaky_trap = SigTools.Buffering.trap(trap_size,
                                                 len(app.params['ERPChan']),
                                                 leaky=True)
        app.trig_trap = SigTools.Buffering.trap(
            app.post_stim_samples,
            1,
            trigger_channel=0,
            trigger_threshold=app.trigthresh[0])

        #===================================================================
        # Prepare the models from the database.
        #===================================================================
        subject_and_flag = Subject.objects.get_or_create(
            name=app.params['SubjectName'])
        app.subject = subject_and_flag[0]

        #===================================================================
        # Use a thread for database interactions because sometimes they will be slow.
        # (especially when calculating a trial's features)
        #===================================================================
        app.erp_thread = ERPThread(Queue.Queue(), app)
        app.erp_thread.setDaemon(True)
        app.erp_thread.start()

        #===================================================================
        # Setup the ERP feedback elements.
        # -Screen will range from -2*fbthresh to +2*fbthresh
        # -Calculated ERP value will be scaled so 65536(int16) fills the screen.
        #===================================================================
        if int(app.params['ERPFeedbackDisplay']) != 2:
            return

        fbthresh = app.params['ERPFeedbackThreshold'].val
        app.erp_scale = (2.0**16) / (4.0 * np.abs(fbthresh))
        # The target box spans from the threshold to twice the threshold,
        # on whichever side of zero the threshold lies.
        if fbthresh < 0:
            fbmax = fbthresh * app.erp_scale
            fbmin = 2.0 * fbthresh * app.erp_scale
        else:
            fbmax = 2.0 * fbthresh * app.erp_scale
            fbmin = fbthresh * app.erp_scale

        px_per_unit = app.scrh / float(2**16)  # signal amplitude -> pixels
        zero_px = app.scrh / 2.0  # pixel row at which an input of 0.0 is drawn
        app.addbar(color=(1, 0, 0),
                   pos=(0.9 * app.scrw, zero_px),
                   thickness=0.1 * app.scrw,
                   fac=px_per_unit)
        bar_label = 'bartext_' + str(len(app.bars))
        app.stimuli[bar_label].color = [0, 0, 0]

        box_position = (0.8 * app.scrw, px_per_unit * fbmin + zero_px)
        box_size = (0.2 * app.scrw, px_per_unit * (fbmax - fbmin))
        erp_target_box = Block(position=box_position,
                               size=box_size,
                               color=(1, 0, 0, 0.5),
                               anchor='lowerleft')
        app.stimulus('erp_target_box', z=1, stim=erp_target_box)
0
def ClassifyERPs(
    featurefiles,
    C=(10.0, 1.0, 0.1, 0.01),
    gamma=(1.0, 0.8, 0.6, 0.4, 0.2, 0.0),
    keepchan=(),
    rmchan=(),
    rmchan_usualsuspects=('AUDL', 'AUDR', 'LAUD', 'RAUD', 'SYNC', 'VSYNC',
                          'VMRK', 'OLDREF'),
    rebias=True,
    save=False,
    select=False,
    description='ERPs to attended vs unattended events',
    maxcount=None,
    classes=None,
    folds=None,
    time_window=None,
    keeptrials=None,
):

    file_inventory = []
    d = DataFiles.load(featurefiles,
                       catdim=0,
                       maxcount=maxcount,
                       return_details=file_inventory)
    if isinstance(folds, basestring) and folds.lower() in [
            'lofo', 'loro', 'leave on run out', 'leave one file out'
    ]:
        n, folds = 0, []
        for each in file_inventory:
            neach = each[1]['x']
            folds.append(range(n, n + neach))
            n += neach

    if 'x' not in d:
        raise ValueError(
            "found no trial data - no 'x' variable - in the specified files")
    if 'y' not in d:
        raise ValueError(
            "found no trial labels - no 'y' variable - in the specified files")

    x = d['x']
    y = numpy.array(d['y'].flat)
    if keeptrials != None:
        x = x[numpy.asarray(keeptrials), :, :]
        y = y[numpy.asarray(keeptrials)]

    if time_window != None:
        fs = d['fs']
        t = SigTools.samples2msec(numpy.arange(x.shape[2]), fs)
        x[:, :, t < min(time_window)] = 0
        x[:, :, t > max(time_window)] = 0

    if classes != None:
        for cl in classes:
            if cl not in y:
                raise ValueError("class %s is not in the dataset" % str(cl))
        mask = numpy.array([yi in classes for yi in y])
        y = y[mask]
        x = x[mask]
        discarded = sum(mask == False)
        if discarded:
            print "discarding %d trials that are outside the requested classes %s" % (
                discarded, str(classes))

    n = len(y)
    uy = numpy.unique(y)
    if uy.size != 2:
        raise ValueError("expected 2 classes in dataset, found %d : %s" %
                         (uy.size, str(uy)))
    for uyi in uy:
        nyi = sum([yi == uyi for yi in y])
        nyi_min = 2
        if nyi < nyi_min:
            raise ValueError(
                'only %d exemplars of class %s - need at least %d' %
                (nyi, str(uyi), nyi_min))

    y = numpy.sign(y - uy.mean())

    cov, trchvar = SigTools.spcov(
        x=x, y=y, balance=False,
        return_trchvar=True)  # NB: symwhitenkern would not be able to balance

    starttime = time.time()

    chlower = [ch.lower() for ch in d['channels']]
    if keepchan in [None, (), '', []]:
        if isinstance(rmchan, basestring): rmchan = rmchan.split()
        if isinstance(rmchan_usualsuspects, basestring):
            rmchan_usualsuspects = rmchan_usualsuspects.split()
        allrmchan = [
            ch.lower() for ch in list(rmchan) + list(rmchan_usualsuspects)
        ]
        unwanted = numpy.array([ch in allrmchan for ch in chlower])
        notfound = [ch for ch in rmchan if ch.lower() not in chlower]
    else:
        if isinstance(keepchan, basestring): keepchan = keepchan.split()
        lowerkeepchan = [ch.lower() for ch in keepchan]
        unwanted = numpy.array([ch not in lowerkeepchan for ch in chlower])
        notfound = [ch for ch in keepchan if ch.lower() not in chlower]

    wanted = numpy.logical_not(unwanted)
    print ' '
    if len(notfound):
        print "WARNING: could not find channel%s %s\n" % ({
            1: ''
        }.get(len(notfound), 's'), ', '.join(notfound))
    removed = [ch for removing, ch in zip(unwanted, d['channels']) if removing]
    if len(removed):
        print "removed %d channel%s (%s)" % (len(removed), {
            1: ''
        }.get(len(removed), 's'), ', '.join(removed))
    print "classification will be based on %d channel%s" % (sum(wanted), {
        1: ''
    }.get(sum(wanted), 's'))
    print "%d negatives + %d positives = %d exemplars" % (sum(y < 0),
                                                          sum(y > 0), n)
    print ' '

    x[:, unwanted, :] = 0
    cov[:, unwanted] = 0
    cov[unwanted, :] = 0
    nu = numpy.asarray(cov).diagonal()[wanted].mean()
    for i in range(len(cov)):
        if cov[i, i] == 0: cov[i, i] = nu

    if not isinstance(C, (tuple, list, numpy.ndarray, type(None))): C = [C]
    if not isinstance(gamma, (tuple, list, numpy.ndarray, type(None))):
        gamma = [gamma]

    c = SigTools.klr2class(lossfunc=SigTools.balanced_loss, relcost='balance')
    c.varyhyper({})
    if c != None: c.hyper.C = list(C)
    if gamma == None: c.hyper.kernel.func = SigTools.linkern
    else:
        c.varyhyper({
            'kernel.func': SigTools.symwhitenkern,
            'kernel.cov': [cov],
            'kernel.gamma': list(gamma)
        })
    c.cvtrain(x=x, y=y, folds=folds)
    if rebias: c.rebias()
    c.calibrate()

    chosen = c.cv.chosen.hyper
    if gamma == None:
        Ps = None
        Gp = c.featureweight(x=x)
    else:
        Ps = SigTools.svd(
            SigTools.shrinkcov(cov, copy=True,
                               gamma=chosen.kernel.gamma)).isqrtm
        xp = SigTools.spfilt(x, Ps.H, copy=True)
        Gp = c.featureweight(x=xp)

    u = SigTools.stfac(Gp, Ps)
    u.channels = d['channels']
    u.channels_used = wanted
    u.fs = d['fs']
    u.trchvar = trchvar
    try:
        u.channels = SigTools.ChannelSet(u.channels)
    except:
        print 'WARNING: failed to convert channels to ChannelSet'

    elapsed = time.time() - starttime
    minutes = int(elapsed / 60.0)
    seconds = int(round(elapsed - minutes * 60.0))
    print '%d min %d sec' % (minutes, seconds)
    datestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    csummary = '%s (%s) trained on %d (CV %s = %.3f) at %s' % (
        c.__class__.__name__,
        SigTools.experiment()._shortdesc(chosen),
        sum(c.input.istrain),
        c.loss.func.__name__,
        c.loss.train,
        datestamp,
    )
    description = 'binary classification of %s: %s' % (description, csummary)
    u.description = description

    if save or select:
        if not isinstance(save, basestring):
            save = featurefiles
            if isinstance(save, (tuple, list)): save = save[-1]
            if save.lower().endswith('.gz'): save = save[:-3]
            if save.lower().endswith('.pk'): save = save[:-3]
            save = save + '_weights.prm'
        print "\nsaving %s\n" % save
        Parameters.Param(u.G.A,
                         Name='ERPClassifierWeights',
                         Section='PythonSig',
                         Subsection='Epoch',
                         Comment=csummary).write_to(save)
        Parameters.Param(c.model.bias,
                         Name='ERPClassifierBias',
                         Section='PythonSig',
                         Subsection='Epoch',
                         Comment=csummary).append_to(save)
        Parameters.Param(description,
                         Name='SignalProcessingDescription',
                         Section='PythonSig').append_to(save)
        if select:
            if not isinstance(select, basestring): select = 'ChosenWeights.prm'
            if not os.path.isabs(select):
                select = os.path.join(os.path.split(save)[0], select)
            print "saving %s\n" % select
            import shutil
            shutil.copyfile(save, select)

    print description
    return u, c
def StimulusTiming(filename='.', ind=None, channels=0, trigger='StimulusCode > 0', msec=200, rectify=False, threshold=0.5, use_eo=True, save=None, **kwargs):
	"""
In <filename> and <ind>, give it
  - a directory and ind=None:  for all .dat files in the directory, in session/run order
  - a directory and ind=an index or list of indices: for selected .dat files in the directory
  - a dat-file name and ind=anything:  for that particular file
  - a list of filenames and ind=anything: for certain explicitly-specified files

<channels>  may be a 0-based index, list of indices, list of channel names, or space- or comma-
			delimited string of channel names
<rectify>   subtracts the median and takes the abs before doing anything else
<threshold> is on the normalized scale of min=0, max=1 within the resulting image
<use_eo>    uses the EventOffset state to correct timings
	"""###
	if hasattr(filename, 'filename'): filename = filename.filename
		
	if ind==None:
		ind = -1
		if os.path.isdir(filename): filename = ListDatFiles(filename)
		
	if not isinstance(filename, (tuple,list)): filename = [filename]
	if not isinstance(ind, (tuple,list)): ind = [ind]
	n = max(len(filename), len(ind))
	if len(filename) == 1: filename = list(filename) * n
	if len(ind) == 1: ind = list(ind) * n
	
	if isinstance(channels, basestring): channels = channels.replace(',', ' ').split()
	if not isinstance(channels, (tuple,list)): channels = [channels]
	out = [SigTools.sstruct(
			files=[],
			events=[],
			t=None,
			channel=ch,
			img=[],
			edges=[],
			threshold=None,
			EventOffsets=[],
			UseEventOffsets=False,
		) for ch in channels]
	if len(filename) == 0: raise ValueError("no data files specified")
	for f,i in zip(filename, ind):
		b = bcistream(filename=f, ind=i)
		nsamp = b.msec2samples(msec)
		sig,st = b.decode('all')
		statenames = zip(*sorted([(-len(x),x) for x in st]))[1]
		criterion = trigger
		for x in statenames: criterion = criterion.replace(x, "st['%s']"%x)
		criterion = numpy.asarray(eval(criterion)).flatten()
		startind = RisingEdge(criterion).nonzero()[0] + 1
		print "%d events found in %s" % (len(startind), b.filename)
		
		for s in out:
			s.files.append(b.filename)
			s.events.append(len(startind))
			ch = s.channel
			if isinstance(ch, basestring): 
				chn = [x.lower() for x in b.params['ChannelNames']]
				if ch.lower() in chn: ch = chn.index(ch.lower())
				else: raise ValueError("could not find channel %s in %s" % (ch,b.filename))
			if len(b.params['ChannelNames']) == len(sig):
				s.channel = b.params['ChannelNames'][ch]
			
			xx = numpy.asarray(sig)[ch]
			if rectify: xx = numpy.abs(xx - numpy.median(xx))
			xx -= xx.min()
			if xx.max(): xx /= xx.max()
			s.threshold = threshold
			for ind in startind:
				if 'EventOffset' in st:
					eo = st['EventOffset'].flat[ind]
					if use_eo:
						ind += eo - 2**(b.statedefs['EventOffset']['length']-1)
						s.UseEventOffsets = True
				else:
					eo = 0
				s.EventOffsets.append(eo)
				x = xx[ind:ind+nsamp].tolist()
				x += [0.0] * (nsamp - len(x))
				s.img.append(x)
	
	for s in out:
		s.img = numpy.asarray(s.img)
		s.edges = [min(list(x.nonzero()[0])+[numpy.nan]) for x in (s.img > s.threshold)]
		s.edges = b.samples2msec(numpy.asarray(s.edges))
		s.t = b.samples2msec(numpy.arange(nsamp))	
		
	import pylab
	pylab.clf()
	for i,s in enumerate(out):
		pylab.subplot(1, len(out), i+1)
		y = y=range(1,len(s.img)+1)
		SigTools.imagesc(s.img, x=s.t, y=y, aspect='auto', **kwargs)
		xl,yl = pylab.xlim(),pylab.ylim()
		pylab.plot(s.edges, y, 'w*', markersize=10)
		pylab.xlim(xl); pylab.ylim(yl)
		pylab.grid('on')
		#pylab.ylim([len(s.img)+0.5,0.5]) # this corrupts the image!!
	pylab.draw()
	if save:
		pylab.gcf().savefig(save, orientation='portrait')
	return out
Example #24
0
def TimingWindow(filename='.', ind=-1, save=None):
    """
Recreate BCI2000's timing window offline, from a saved .dat file specified by <filename>.
It is also possible to supply a directory name as <filename>, and an index <ind> (default
value -1 meaning "the last run") to choose a file automatically from that directory.
If <save> is given, the figure is saved under that filename.

Returns a SigTools.sstruct containing the timing traces. The file must contain the
SourceTime state.  If it does not contain the StimulusTime state, the output fields
StimulusTime, BlockDuration2 and ProcessingTime are set to None.

Based on BCI2000's   src/shared/modules/signalsource/DataIOFilter.cpp where the timing window
content is computed in DataIOFilter::Process(), this is what appears to happen:
    
         Begin SampleBlock #t:
            Enter SignalSource module's first Process() method (DataIOFilter::Process)
            Save previous SampleBlock to file
            Wait to acquire new SampleBlock from hardware
 +--------- Measure SourceTime in SignalSource module
 |   |   |  Make a local copy of all states (NB: all except SourceTime were set during #t-1) ---+
B|  R|  S|  Pipe the signal through the rest of BCI2000                                         |
 |   |   +- Measure StimulusTime in Application module, on leaving last Process() method        |
 |   |                                                                                          |
 |   |                                                                                          |
 |   |   Begin SampleBlock #t+1:                                                                |
 |   +----- Enter SignalSource module's first Process() method (DataIOFilter::Process)          |
 |          Save data from #t, SourceTime state from #t, and other states from #t-1, to file <--+
 |          Wait to acquire new SampleBlock from hardware
 +--------- Measure SourceTime in SignalSource module
            Make a local copy of all states (NB: all except SourceTime were set during #t)
            Leave DataIOFilter::Process() and pipe the signal through the rest of BCI2000
            Measure StimulusTime in Application module, on leaving last Process() method

B stands for Block duration.
R stands for Roundtrip time (visible in VisualizeTiming, not reconstructable from the .dat file)
S is the filter cascade time (marked "Stimulus" in the VisualizeTiming window).

Note that, on any given SampleBlock as saved in the file, SourceTime will be *greater* than
any other timestamp states (including StimulusTime), because it is the only state that is
updated in time to be saved with the data packet it belongs to. All the others lag by one
packet.  This is corrected for at the point commented with ??? in the Python code. 
"""

    if hasattr(filename, 'filename'): filename = filename.filename

    b = bcistream(filename=filename, ind=ind)

    out = SigTools.sstruct()
    out.filename = b.filename
    sig, states = b.decode('all')
    b.close()

    # dT: increments between successive samples of each (16-bit, wrapping)
    # timestamp state;  T: the unwrapped timestamps;  rT: timestamps relative
    # to SourceTime.
    dT, T, rT = {}, {}, {}
    statenames = ['SourceTime', 'StimulusTime'
                  ] + ['PythonTiming%02d' % (x + 1) for x in range(2)]
    statenames = [s for s in statenames if s in states]
    for key in statenames:
        dT[key], T[key] = SigTools.unwrapdiff(states[key].flatten(),
                                              base=65536,
                                              dtype=numpy.float64)

    # Keep only the samples at which SourceTime changes value.
    sel, = numpy.where(dT['SourceTime'])
    for key in statenames:
        dT[key] = dT[key][sel[1:]]
        if key == 'SourceTime': tsel = sel[:-1]  # ??? why the shift
        else: tsel = sel[1:]  # ??? relative to here?
        T[key] = T[key][tsel + 1]

    # Express all timestamps relative to the first SourceTime.
    t0 = T['SourceTime'][0]
    for key in statenames:
        T[key] -= t0

    t = T['SourceTime'] / 1000

    expected = b.samples2msec(b.params['SampleBlockSize'])
    datestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(b.datestamp))
    paramstr = ', '.join([
        '%s=%s' % (x, b.params[x]) for x in [
            'SampleBlockSize', 'SamplingRate', 'VisualizeTiming',
            'VisualizeSource'
        ]
    ])
    chainstr = '-'.join([
        x for x, y in b.params['SignalSourceFilterChain'] +
        b.params['SignalProcessingFilterChain'] +
        b.params['ApplicationFilterChain']
    ])
    titlestr = '\n'.join([b.filename, datestamp, paramstr, chainstr])

    SigTools.plot(t[[0, -1]], [expected] * 2, drawnow=False)
    SigTools.plot(t, dT['SourceTime'], hold=True, drawnow=False)

    for key in statenames:
        if key == 'SourceTime': continue
        rT[key] = T[key] - T['SourceTime']
        SigTools.plot(t, rT[key], hold=True, drawnow=False)

    import pylab
    pylab.title(titlestr)
    pylab.grid(True)
    pylab.xlabel('seconds')
    pylab.ylabel('milliseconds')
    ymin, ymax = pylab.ylim()
    pylab.ylim(ymax=max(ymax, expected * 2))
    pylab.xlim(xmax=t[-1])
    pylab.draw()

    # StimulusTime is optional (see the filtering of statenames above), so do
    # not index it unconditionally: that raised KeyError after all the work
    # had already been done.
    has_stimulus_time = 'StimulusTime' in T

    out.params = SigTools.sstruct(b.params)
    out.summarystr = titlestr
    out.t = t
    out.SourceTime = T['SourceTime']
    out.StimulusTime = T['StimulusTime'] if has_stimulus_time else None
    out.BlockDuration = dT['SourceTime']
    out.BlockDuration2 = dT['StimulusTime'] if has_stimulus_time else None
    out.ProcessingTime = (out.StimulusTime - out.SourceTime) if has_stimulus_time else None
    out.ExpectedBlockDuration = expected
    out.rT = rT
    out.dT = dT
    out.T = T

    if save:
        pylab.gcf().savefig(save, orientation='landscape')

    return out
Example #25
0
def StimulusTiming(filename='.',
                   ind=None,
                   channels=0,
                   trigger='StimulusCode > 0',
                   msec=200,
                   rectify=False,
                   threshold=0.5,
                   use_eo=True,
                   save=None,
                   **kwargs):
    """
In <filename> and <ind>, give it
  - a directory and ind=None:  for all .dat files in the directory, in session/run order
  - a directory and ind=an index or list of indices: for selected .dat files in the directory
  - a dat-file name and ind=anything:  for that particular file
  - a list of filenames and ind=anything: for certain explicitly-specified files

<channels>  may be a 0-based index, list of indices, list of channel names, or space- or comma-
			delimited string of channel names
<rectify>   subtracts the median and takes the abs before doing anything else
<threshold> is on the normalized scale of min=0, max=1 within the resulting image
<use_eo>    uses the EventOffset state to correct timings
	"""###
    if hasattr(filename, 'filename'): filename = filename.filename

    if ind == None:
        ind = -1
        if os.path.isdir(filename): filename = ListDatFiles(filename)

    if not isinstance(filename, (tuple, list)): filename = [filename]
    if not isinstance(ind, (tuple, list)): ind = [ind]
    n = max(len(filename), len(ind))
    if len(filename) == 1: filename = list(filename) * n
    if len(ind) == 1: ind = list(ind) * n

    if isinstance(channels, basestring):
        channels = channels.replace(',', ' ').split()
    if not isinstance(channels, (tuple, list)): channels = [channels]
    out = [
        SigTools.sstruct(
            files=[],
            events=[],
            t=None,
            channel=ch,
            img=[],
            edges=[],
            threshold=None,
            EventOffsets=[],
            UseEventOffsets=False,
        ) for ch in channels
    ]
    if len(filename) == 0: raise ValueError("no data files specified")
    for f, i in zip(filename, ind):
        b = bcistream(filename=f, ind=i)
        nsamp = b.msec2samples(msec)
        sig, st = b.decode('all')
        statenames = zip(*sorted([(-len(x), x) for x in st]))[1]
        criterion = trigger
        for x in statenames:
            criterion = criterion.replace(x, "st['%s']" % x)
        criterion = numpy.asarray(eval(criterion)).flatten()
        startind = RisingEdge(criterion).nonzero()[0] + 1
        print "%d events found in %s" % (len(startind), b.filename)

        for s in out:
            s.files.append(b.filename)
            s.events.append(len(startind))
            ch = s.channel
            if isinstance(ch, basestring):
                chn = [x.lower() for x in b.params['ChannelNames']]
                if ch.lower() in chn: ch = chn.index(ch.lower())
                else:
                    raise ValueError("could not find channel %s in %s" %
                                     (ch, b.filename))
            if len(b.params['ChannelNames']) == len(sig):
                s.channel = b.params['ChannelNames'][ch]

            xx = numpy.asarray(sig)[ch]
            if rectify: xx = numpy.abs(xx - numpy.median(xx))
            xx -= xx.min()
            if xx.max(): xx /= xx.max()
            s.threshold = threshold
            for ind in startind:
                if 'EventOffset' in st:
                    eo = st['EventOffset'].flat[ind]
                    if use_eo:
                        ind += eo - 2**(b.statedefs['EventOffset']['length'] -
                                        1)
                        s.UseEventOffsets = True
                else:
                    eo = 0
                s.EventOffsets.append(eo)
                x = xx[ind:ind + nsamp].tolist()
                x += [0.0] * (nsamp - len(x))
                s.img.append(x)

    for s in out:
        s.img = numpy.asarray(s.img)
        s.edges = [
            min(list(x.nonzero()[0]) + [numpy.nan])
            for x in (s.img > s.threshold)
        ]
        s.edges = b.samples2msec(numpy.asarray(s.edges))
        s.t = b.samples2msec(numpy.arange(nsamp))

    import pylab
    pylab.clf()
    for i, s in enumerate(out):
        pylab.subplot(1, len(out), i + 1)
        y = y = range(1, len(s.img) + 1)
        SigTools.imagesc(s.img, x=s.t, y=y, aspect='auto', **kwargs)
        xl, yl = pylab.xlim(), pylab.ylim()
        pylab.plot(s.edges, y, 'w*', markersize=10)
        pylab.xlim(xl)
        pylab.ylim(yl)
        pylab.grid('on')
        #pylab.ylim([len(s.img)+0.5,0.5]) # this corrupts the image!!
    pylab.draw()
    if save:
        pylab.gcf().savefig(save, orientation='portrait')
    return out
def test(codebook, model=None, alphabet=None, txt=None, balanced_acc=0.75, default_context='. ',
         threshold=1.0, min_epochs=1, max_epochs=72,
         minprob=None, exponent=1.0, w=80):
    """
    Simulate spelling the text <txt>, one character at a time, using the given
    <codebook> and a TextPrediction.Decoder, with randomly generated classifier
    outputs.

    <codebook>        one row per symbol of the alphabet. May be an array-like, a
                      dict with a 'Matrix' entry, or a newline-delimited string of
                      whitespace-separated integers in which '.' and '-' stand for 0.
    <model>           language model; if None, TextPrediction.LanguageModel(alphabet=alphabet)
                      is created.
    <alphabet>        symbols to choose from; defaults to TextPrediction.default_alphabet36.
    <txt>             the text to spell, or (if shorter than 1000 characters) the name
                      of an existing file to read it from; defaults to
                      TextPrediction.default_test_text. The text is cleaned so that it
                      only contains characters from the alphabet.
    <balanced_acc>    desired single-epoch accuracy, converted with SigTools.invcg to the
                      separation of the two unit-variance normal distributions from
                      which the simulated classifier outputs are drawn.
    <default_context> passed to model.prepend_context.
    <threshold>, <min_epochs>, <max_epochs>, <minprob>, <exponent>
                      passed to TextPrediction.Decoder.
    <w>               number of decoded characters per line of progress output;
                      a false value disables the output.

    Pressing ctrl-C stops the simulation early; the results obtained so far are returned.

    Returns an sstruct containing the input and output text, the conditions, and
    epoch-accuracy, epoch-count and letter-accuracy statistics.
    Raises ValueError if the alphabet is not covered by the model, or if the number
    of codebook rows differs from the number of symbols.
    """
    modelin = model
    if model is None: model = TextPrediction.LanguageModel(alphabet=alphabet)
    if alphabet is None: alphabet = TextPrediction.default_alphabet36
    if txt is None: txt = TextPrediction.default_test_text

    alphabet = model.match_case(alphabet)
    if len(set(alphabet) - set(model.alphabet)) > 0:
        raise ValueError("some characters in the alphabet are not in the model")

    if len(txt) < 1000 and os.path.isfile(TextPrediction.FindFile(txt)):
        txt = TextPrediction.ReadUTF8(TextPrediction.FindFile(txt))

    # Things will fail if some characters in the test text are not in alphabet.
    # So (this is a massive hack) temporarily narrow the model's alphabet while
    # cleaning the text. The original alphabet is restored even if cleaning fails.
    oldalphabet = model.alphabet
    model.alphabet = set(oldalphabet).intersection(alphabet)
    try:
        txt = model.clean(txt)
    finally:
        model.alphabet = oldalphabet

    if isinstance(codebook, dict) and 'Matrix' in codebook: codebook = codebook['Matrix']

    if isinstance(codebook, basestring):
        rows = codebook.strip().replace('.', '0').replace('-', '0').split('\n')
        codebook = [[int(i) for i in row.split()] for row in rows]
    codebook = numpy.asarray(codebook)

    N, L = codebook.shape
    if N != len(alphabet): raise ValueError("number of rows in the codebook should be equal to the number of symbols in the alphabet")

    out = txt[:0]
    nchars = len(txt)
    # builtin bool: the numpy.bool alias no longer exists in current NumPy
    correct = numpy.zeros((nchars,), dtype=bool)
    nepochs = numpy.zeros((nchars,), dtype=numpy.int16)
    z = SigTools.invcg(balanced_acc)

    # confusion[true codebook entry, rounded classifier output]
    confusion = numpy.zeros((2, 2), dtype=numpy.float64)
    correct_running = SigTools.running_mean()
    nepochs_running = SigTools.running_mean()
    try:
        for i in range(nchars):
            correct_symbol = txt[i]
            correct_index = alphabet.index(correct_symbol)  # will throw an error if not found, and rightly so
            context = model.prepend_context(txt[:i], default_context)
            d = TextPrediction.Decoder(choices=alphabet,
                                       context=context, model=model, verbose=False,
                                       threshold=threshold, min_epochs=min_epochs, max_epochs=max_epochs,
                                       minprob=minprob, exponent=exponent)
            result = None
            while result is None:
                # Cycle through the columns of the codebook.
                col = codebook[:, d.L % L]
                Cij = int(col[correct_index])
                d.new_column(col)
                # Simulated classifier output: unit-variance normal, centred
                # on +z or -z according to the codebook entry of the target.
                mean = z * {0: -1, 1: 1}[Cij]
                x = numpy.random.randn() + mean
                p = SigTools.logistic(2.0 * z * x)
                confusion[Cij, int(round(p))] += 1
                result = d.new_transmission(p)
            nepochs[i] = d.L
            correct[i] = (result == correct_symbol)
            out = out + result
            nepochs_running += d.L
            correct_running += (result == correct_symbol)
            if w:
                sys.stdout.write(result); sys.stdout.flush()
                if (i + 1) % w == 0: sys.stdout.write('   %3d\n' % round(100.0 * float(i + 1) / nchars))
                elif (i + 1) == nchars: sys.stdout.write('\n')
    except KeyboardInterrupt:
        # Deliberate: allow the user to stop early and still get results.
        if w: sys.stdout.write('\n')

    ndone = len(out)

    s = sstruct()
    s.alphabet = alphabet
    s.codebook = codebook
    s.model = modelin  # as passed in, i.e. None if the default model was used
    s.input = txt[:ndone]
    s.output = out
    s.conditions = sstruct()
    s.conditions.threshold = threshold
    s.conditions.min_epochs = min_epochs
    s.conditions.max_epochs = max_epochs
    s.conditions.minprob = minprob
    s.conditions.exponent = exponent
    s.epoch_acc = sstruct()
    s.epoch_acc.desired = balanced_acc
    s.epoch_acc.empirical_nontarget, s.epoch_acc.empirical_target = (confusion.diagonal() / confusion.sum(axis=1)).flat
    s.epoch_acc.confusion = confusion
    s.nepochs = sstruct()
    s.nepochs.each  = nepochs[:ndone]
    s.nepochs.mean  = nepochs_running.m
    s.nepochs.std   = nepochs_running.v_unbiased ** 0.5
    s.nepochs.ste   = s.nepochs.std / nepochs_running.n ** 0.5
    s.letter_acc = sstruct()
    s.letter_acc.each = correct[:ndone]
    s.letter_acc.mean = correct_running.m
    s.letter_acc.std  = correct_running.v_unbiased ** 0.5
    s.letter_acc.ste  = s.letter_acc.std / correct_running.n ** 0.5
    return s