コード例 #1
0
def event_timestamps_to_indices(sigfile, eventfile):
    """
    Convert LSL timestamps to sample indices for separately recorded events.

    Parameters:
    sigfile: raw signal file (Python Pickle) recorded with stream_recorder.py.
             Its 'timestamps' entry holds one LSL timestamp per sample
             (assumed to be sorted ascending, as required by np.searchsorted).
    eventfile: tab-separated event file where events are indexed with LSL
               timestamps. Column 0 is the timestamp (float) and column 2 is
               the integer event value. Blank lines are ignored.

    Returns:
    events list of [sample_index, 0, event_value], which can be used as an
    input to mne.io.RawArray.add_events().

    An event whose timestamp is later than the last sample timestamp is
    skipped with a warning.
    """
    raw = qc.load_obj(sigfile)
    ts = np.asarray(raw['timestamps']).reshape(-1)
    ts_min = ts.min()
    ts_max = ts.max()
    events = []

    with open(eventfile) as f:
        for line in f:
            line = line.strip()
            if not line:
                # A stray empty line would otherwise crash float('').
                continue
            data = line.split('\t')
            event_ts = float(data[0])
            event_value = int(data[2])
            # index of the first sample whose timestamp is not smaller than event_ts
            next_index = np.searchsorted(ts, event_ts)
            if next_index >= len(ts):
                qc.print_c('** WARNING: Event %d at time %.3f is out of time range (%.3f - %.3f).'
                           % (event_value, event_ts, ts_min, ts_max), 'y')
            else:
                events.append([next_index, 0, event_value])

    return events
コード例 #2
0
def get_decoder_info(classifier):
    """
    Get only the classifier information without connecting to a server.

    Params
    ------
        classifier: path of the model file, loaded with qc.load_obj().

    Returns
    -------
        info dictionary with keys: labels, cls, psde, w_seconds, w_frames,
        wstep, sfreq, psd_shape, psd_size.

    Exits the process with status -1 if the model file cannot be loaded.
    """
    model = qc.load_obj(classifier)
    if model is None:
        # Report the path that failed to load (model itself is None here).
        print('>> Error loading %s' % classifier)
        sys.exit(-1)

    cls = model['cls']
    psde = model['psde']
    labels = list(cls.classes_)
    w_seconds = model['w_seconds']
    w_frames = model['w_frames']
    wstep = model['wstep']
    sfreq = model['sfreq']

    # Feed one all-zero window (1 x n_picks x w_frames) through the PSD
    # estimator to find out the shape/size of the features it produces.
    psd_temp = psde.transform(np.zeros((1, len(model['picks']), w_frames)))

    return dict(labels=labels, cls=cls, psde=psde, w_seconds=w_seconds,
                w_frames=w_frames, wstep=wstep, sfreq=sfreq,
                psd_shape=psd_temp.shape, psd_size=psd_temp.size)
コード例 #3
0
def pcl2mat_old(fpcl):
    """
    Convert an old-format raw pickle recording into a MATLAB .mat file.

    For old format data only: raw['signals'] must be a list whose first
    element is the samples x channels signal array, and raw['timestamps']
    a list whose first element holds the timestamps.

    Events are extracted with find_events(): from the first channel (low
    8 bits of value-1) when there are more than 17 channels (BioSemi),
    otherwise from the last channel.

    The output is written next to the input file with its extension
    replaced by '.mat', containing the keys signals, timestamps, events,
    sample_rate and n_channels.
    """
    import os

    raw = qc.load_obj(fpcl)
    assert isinstance(raw['signals'], list)
    signals = np.array(raw['signals'][0])  # samples x channels
    ts = raw['timestamps'][0]
    srate = raw['sample_rate']
    n_ch = raw['channels']
    if n_ch > 17:  # BioSemi
        ev16 = signals[:, 0] - 1  # first channel is event channel
        events_raw = 0xFF & ev16.astype(int)  # keep only the low 8 bits
        events = find_events(events_raw)
    else:
        events = find_events(signals[:, -1])

    print('Signal dimension:', signals.shape)
    print('Timestamp dimension:', len(ts))
    print('Sampling rate:', srate)
    print('No. channels:', n_ch)
    data = dict(signals=signals, timestamps=ts, events=events,
                sample_rate=srate, n_channels=n_ch)
    # splitext handles any extension length (the old fpcl[:-4] assumed 3 chars).
    fmat = os.path.splitext(fpcl)[0] + '.mat'
    scipy.io.savemat(fmat, data)
    print('Saved data as', fmat)
コード例 #4
0
def azureml_main(dataframe1=None, dataframe2=None):
    """
    Script entry point: compute features for dataframe1 and return predictions.

    If a zip file is connected to the third input port, it is unzipped under
    the "Script Bundle" directory, which is added to sys.path; modules it
    contains (here q_common and tester) can therefore be imported. Models are
    trained and cross-validated beforehand by running tester.py locally.

    dataframe1: raw data handed to tester.get_features().
    dataframe2: unused.

    Returns the predictions produced by tester.predictor().
    """
    import sys
    import sklearn
    import numpy
    import pandas
    import q_common as qc
    import tester

    # Report interpreter, platform and library versions.
    print(sys.version)
    print('\nPlatform: %s' % tester.PLATFORM)
    print('sklearn: %s' % sklearn.__version__)
    print('pandas: %s' % pandas.__version__)
    print('numpy: %s' % numpy.__version__)
    print('MY_PATH: %s\n\n' % tester.MY_PATH)

    # Start timing before the model is loaded.
    stopwatch = qc.Timer()

    classifier_path = '%s/classifiers.pkl' % tester.MY_PATH
    trained = qc.load_obj(classifier_path)
    assert trained is not None

    # Preprocessing / feature parameters travel with the trained model.
    feats = tester.get_features(dataframe1, trained['cfg'],
                                trained['psd_params'], trained['epochs'])
    predictions = tester.predictor(feats, trained)

    print('Done. Took %.1f seconds.' % stopwatch.sec())
    print('\n*** Predicted labels start ***\n')
    print(predictions)
    print('\n*** Predicted labels end ***\n')
    return predictions
コード例 #5
0
    def __init__(self,
                 classifier=None,
                 buffer_size=1.0,
                 fake=False,
                 amp_serial=None,
                 amp_name=None):
        """
        Params
        ------
            classifier: classifier file, loaded with qc.load_obj().
            buffer_size: buffer size in seconds.
            fake: if True, skip loading a classifier and use a fake decoder
                  whose labels are the LEFT_GO and RIGHT_GO triggers.
            amp_serial: stored as self.amp_serial.
            amp_name: stored as self.amp_name.

        Exits the process with status -1 if the classifier cannot be loaded.
        Calls self.reset() and then self.start() before returning.
        """
        self.classifier = classifier
        self.buffer_sec = buffer_size
        self.startmsg = 'Decoder daemon started.'
        self.stopmsg = 'Decoder daemon stopped.'
        self.fake = fake
        self.amp_serial = amp_serial
        self.amp_name = amp_name

        if fake == False:
            self.model = qc.load_obj(self.classifier)
            if self.model is None:
                # Report the path that failed to load (self.model is None here).
                self.print('Error loading %s' % self.classifier)
                sys.exit(-1)
            self.labels = self.model['cls'].classes_
        else:
            # create a fake decoder with LEFT/RIGHT classes
            self.model = None
            from triggerdef_16 import TriggerDef
            tdef = TriggerDef()
            self.labels = [tdef.by_key['LEFT_GO'], tdef.by_key['RIGHT_GO']]
            self.startmsg = 'FAKE ' + self.startmsg
            self.stopmsg = 'FAKE ' + self.stopmsg

        self.psdlock = mp.Lock()
        self.reset()
        self.start()
コード例 #6
0
def load_raw_old(rawfile, spfilter=None, spchannels=None, events_ext=None):
    """
    ** Deprecated function **
    Please use convert2fif to convert non-fif files to fif first.

    Load a raw recording and return raw data and events.

    Supports gdf, bdf, fif, and Python raw format (pcl).
    Any non-fif file will be saved into .fif format in the fif/ directory
    after loading.

    Parameters:
    rawfile: (absolute) data file path
    spfilter: 'car' | 'laplacian' | 'bipolar' | None
    spchannels: None | list (for CAR) | dict (for LAPLACIAN) | index (for BIPOLAR)
        'car': channel indices used for CAR filtering. If None, use all channels except
               the trigger channel (index 0).
        'laplacian': {channel:[neighbor1, neighbor2, ...], ...}
        'bipolar': channel(s) subtracted from every channel from index 1 on.
        *** Note ***
        Since PyCNBI puts trigger channel as index 0, data channel starts from index 1.
    events_ext: path of an externally recorded event file (e.g. software trigger)
                whose events are indexed with LSL timestamps. It is converted with
                event_timestamps_to_indices(). Only supported for .pcl files.

    Returns:
    raw: mne Raw object. For converted (non-fif) files the first channel
         (index 0) is always the trigger channel.
    events: mne-compatible events numpy array object (N x [frame, 0, type]),
            always re-detected from the 'TRIGGER' channel.

    Exits the process with status -1 on unrecoverable errors.
    """
    if not os.path.exists(rawfile):
        qc.print_c('# ERROR: File %s not found' % rawfile, 'r')
        sys.exit(-1)

    rawfile = rawfile.replace('\\', '/')
    dirs = rawfile.split('/')
    if len(dirs) == 1:
        basedir = './'
    else:
        basedir = '/'.join(dirs[:-1]) + '/'
    extension = rawfile.split('.')[-1]
    basename = '.'.join(rawfile.split('.')[:-1])
    raw = None
    events = []

    if extension == 'pcl':
        data = qc.load_obj(rawfile)

        if isinstance(data['signals'], list):
            print('Converting into numpy format')
            signals_raw = np.array(data['signals'][0]).T  # to channels x samples
        else:
            signals_raw = data['signals'].T  # to channels x samples
        sample_rate = data['sample_rate']

        # BioSemi or gtec?
        if data['channels'] == 17:
            # move the trigger channel to the first row
            if find_event_channel(signals_raw) != 16:
                qc.print_c(
                    '**** WARNING: Assuming GTEC_16 format. Double-check trigger channel !! *****',
                    'r')
            signals = np.concatenate(
                (signals_raw[16, :].reshape(1, -1), signals_raw[:16, :]))
            info = mne.create_info(CAP['GTEC_16'], sample_rate,
                                   CAP['GTEC_16_INFO'])
        elif data['channels'] >= 73:
            signals = signals_raw[:73, :]  # trigger channel is already the first row
            sigtrig = signals[0, :] - 1
            signals[0, :] = 0xFF & sigtrig.astype(int)  # keep only the low 8 bits
            info = mne.create_info(CAP['BIOSEMI_64'], sample_rate,
                                   CAP['BIOSEMI_64_INFO'])
        elif data['channels'] == 24:
            qc.print_c(
                '**** ASSUMING SmartBCI system with no trigger channel ****',
                'y')
            # Re-reference to the mean of channels A1=9 and A2=16 (0-based 8, 15).
            ear_avg = (signals_raw[8] + signals_raw[15]) / 2.0
            # Prepend an all-zero trigger channel.
            trigger = np.zeros((1, signals_raw.shape[1]))
            signals = np.vstack((trigger, signals_raw - ear_avg))
            info = mne.create_info(CAP['SMARTBCI_24'], sample_rate,
                                   CAP['SMARTBCI_24_INFO'])
        else:  # ok, unknown format
            # guess trigger channel
            trig_ch = find_event_channel(signals_raw)
            if trig_ch is not None:
                qc.print_c(
                    'Found trigger channel %d. Moving to channel 0.' % trig_ch,
                    'y')
                signals = np.concatenate(
                    (signals_raw[[trig_ch]], signals_raw[:trig_ch],
                     signals_raw[trig_ch + 1:]),
                    axis=0)
                assert signals_raw.shape == signals.shape
                num_eeg_channels = data['channels'] - 1
            else:
                # assuming no trigger channel exists, add a trigger channel to index 0 for consistency.
                qc.print_c(
                    '**** Unrecognized number of channels (%d). Adding an event channel to index 0.'
                    % data['channels'], 'r')
                eventch = np.zeros([1, signals_raw.shape[1]])
                signals = np.concatenate((eventch, signals_raw), axis=0)
                num_eeg_channels = data['channels']

            ch_names = ['TRIGGER'] + [
                'CH%d' % (x + 1) for x in range(num_eeg_channels)
            ]
            ch_info = ['stim'] + ['eeg'] * num_eeg_channels
            info = mne.create_info(ch_names, sample_rate, ch_info)

    elif extension in ['fif', 'fiff']:
        raw = mne.io.Raw(rawfile, preload=True)

    elif extension in ['bdf', 'gdf']:
        # convert to mat using MATLAB (MNE's edf reader has an offset bug)
        matfile = basename + '.mat'
        if not os.path.exists(matfile):
            print('>> Converting input to mat file')
            run = "[sig,header]=sload('%s.%s'); save('%s.mat','sig','header');" % (
                basename, extension, basename)
            qc.matlab(run)
            if not os.path.exists(matfile):
                qc.print_c('>> ERROR: mat file convertion error.', 'r')
                sys.exit(-1)

        mat = scipy.io.loadmat(matfile)
        os.remove(matfile)
        sample_rate = int(mat['header']['SampleRate'])
        nch = mat['sig'].shape[1]

        if extension == 'gdf':
            # Note: gdf might  have a software trigger channel
            if nch == 17:
                ch_names = CAP['GTEC_16']
                ch_info = CAP['GTEC_16_INFO'][:nch]
            else:
                ch_names = ['TRIGGER'] + ['ch%d' % x for x in range(1, nch)]
                ch_info = ['stim'] + ['eeg'] * (nch - 1)

            # Read events from header.
            # Important: event position may have the same frame number for two
            # consecutive events. It might be due to the CNBI software trigger bug.
            # Example: f1.20121220.102907.offline.mi.mi_rhlh.gdf (Two 10201's in evpos)
            evtype = mat['header']['EVENT'][0][0][0]['TYP'][0]
            evpos = mat['header']['EVENT'][0][0][0]['POS'][0]
            for e in range(evtype.shape[0]):
                label = int(evtype[e])
                events.append([int(evpos[e][0]), 0, label])

        elif extension == 'bdf':
            # assume Biosemi always has the same number of channels
            if nch == 73:
                extra_ch = nch - len(CAP['BIOSEMI_64_INFO'])
                extra_names = ['EXTRA%d' % ch for ch in range(extra_ch)]
                ch_names = CAP['BIOSEMI_64'] + extra_names
                ch_info = CAP['BIOSEMI_64_INFO'] + ['misc'] * extra_ch
            else:
                qc.print_c(
                    '****** load_raw(): WARNING: Unrecognized number of channels (%d) ******'
                    % nch, 'y')
                qc.print_c(
                    'The last channel will be assumed to be trigger. Press Enter to continue, or Ctrl+C to break.',
                    'r')
                raw_input()
                # Set the trigger to be channel 0 because later we will move it to channel 0.
                ch_names = ['TRIGGER'] + ['CH%d' % (x + 1) for x in range(nch - 1)]
                ch_info = ['stim'] + ['eeg'] * (nch - 1)

        # Move the event channel (last column of sig) to row 0 (for consistency)
        signals_raw = mat['sig'].T  # -> channels x samples
        signals = np.concatenate(
            (signals_raw[-1, :].reshape(1, -1), signals_raw[:-1, :]))

        # Note: Biosig's sload() sometimes returns bogus event values so we use the following for events
        bdf = mne.io.read_raw_edf(rawfile, preload=True)
        events = mne.find_events(bdf)
        # Overwrite the event channel with the correct event values. The event
        # channel now lives in row 0 (moved above); writing row -1 would have
        # clobbered the last data channel instead.
        signals[0, :] = bdf._data[-1]

        info = mne.create_info(ch_names, sample_rate, ch_info)
    else:
        # unknown format
        qc.print_c(
            'ERROR: Unrecognized file extension %s. It should be [.pcl | .fif | .fiff | .gdf | .bdf]'
            % extension, 'r')
        sys.exit(-1)

    if raw is None:
        # signals= channels x samples
        raw = mne.io.RawArray(signals, info)

        # Software trigger: inject the header/MNE events only if the trigger
        # channel carries no events. `events` may be a list or an ndarray, so
        # test its length rather than comparing it with [].
        trigch = raw.info['ch_names'].index('TRIGGER')
        if len(events) > 0 and max(raw[trigch][0][0]) == 0:
            raw.add_events(events, stim_channel='TRIGGER')

        # external events with LSL timestamps
        if events_ext is not None:
            if extension != 'pcl':
                qc.print_c(
                    '>> ERROR: external events can be only added to raw .pcl files',
                    'r')
                sys.exit(-1)
            events_index = event_timestamps_to_indices(rawfile, events_ext)
            raw.add_events(events_index, stim_channel='TRIGGER')

        qc.make_dirs(basedir + 'fif/')
        fifname = basedir + 'fif/' + basename.split('/')[-1] + '.fif'
        raw.save(fifname, overwrite=True, verbose=False)
        print('Saving to', fifname)

    # find a value changing from zero to a non-zero value
    events = mne.find_events(raw, stim_channel='TRIGGER', shortest_event=1)

    # apply spatial filter (in place on raw._data)
    if spfilter == 'car':
        if not spchannels:
            raw._data[1:] = raw._data[1:] - np.mean(raw._data[1:], axis=0)
        else:
            raw._data[spchannels] = raw._data[spchannels] - np.mean(
                raw._data[spchannels], axis=0)
    elif spfilter == 'laplacian':
        if not isinstance(spchannels, dict):
            raise RuntimeError('For Lapcacian, SP_CHANNELS must be of a form {CHANNEL:[NEIGHBORS], ...}')
        # Use an unmodified copy so each channel's neighbors are pre-filter values.
        rawcopy = raw._data.copy()
        for src in spchannels:
            nei = spchannels[src]
            raw._data[src] = rawcopy[src] - np.mean(rawcopy[nei], axis=0)
    elif spfilter == 'bipolar':
        raw._data[1:] -= raw._data[spchannels]
    elif spfilter is None:
        pass
    else:
        qc.print_c('# ERROR: Unknown spatial filter %s' % spfilter, 'r')
        sys.exit(-1)

    return raw, events
コード例 #7
0
ファイル: trainer.py プロジェクト: beihangld3/BCI_IRMCT
def run_trainer(cfg, ftrain, interactive=False):
    # feature selection?
    datadir = cfg.DATADIR
    feat_picks = None
    txt = 'all'
    if cfg.USE_CVA:
        fcva = ftrain[0] + '.cva'
        if os.path.exists(fcva):
            feat_picks = open(fcva).readline().strip().split(',')
            feat_picks = [int(x) for x in feat_picks]
            print('\n>> Using only selected features')
            print(feat_picks)
            txt = 'cva'

    if hasattr(cfg, 'BALANCE_SAMPLES'):
        do_balance = cfg.BALANCE_SAMPLES
    else:
        do_balance = False

    # preprocessing, epoching and PSD computation
    n_epochs = {}
    if cfg.LOAD_PSD:
        raise RunetimeError, 'SORRY, CODE NOT FINISHED.'
        labels = np.array([])
        X_data = None
        Y_data = None
        sfreq = None
        ts = None
        te = None
        for fpsd in qc.get_file_list(datadir, fullpath=True):
            if fpsd[-4:] != '.psd': continue
            data = qc.load_obj(fpsd)
            labels = np.hstack((labels, data['Y'][:, 0]))
            if X_data is None:
                sfreq = data['sfreq']
                tmin = data['tmin']
                tmax = data['tmax']
                '''
				TODO: implement multi-segment epochs
				'''
                if type(cfg.EPOCH[0]) is list:
                    print('MULTI-SEGMENT EPOCH IS NOT SUPPORTED YET.')
                    sys.exit(-1)
                if cfg.EPOCH[0] < tmin or cfg.EPOCH[1] > tmax:
                    raise RuntimeError, '\n*** Epoch time range is out of data range.'
                ts = int((cfg.EPOCH[0] - tmin) * sfreq / data['wstep'])
                te = int((cfg.EPOCH[1] - tmin) * sfreq / data['wstep'])

                # X: trials x channels x features
                X_data = data['X'][:, ts:te, :]
                Y_data = data['Y'][:, ts:te]
            else:
                X_data = np.vstack((X_data, data['X'][:, ts:te, :]))
                Y_data = np.vstack((Y_data, data['Y'][:, ts:te]))
        assert (len(labels) > 0)
        psde = data['psde']
        psd_tmin = data['tmin']
        psd_tmax = data['tmax']
        picks = data['picks']
        w_frames = int(sfreq * data['wlen'])  # window length
        psdparams = dict(fmin=data['fmin'],
                         fmax=data['fmax'],
                         wlen=data['wlen'],
                         wstep=data['wstep'])

        if 'classes' in data:
            triggers = data['classes']
        else:
            triggers = {c: cfg.tdef.by_value[c] for c in set(labels)}

        spfilter = data['spfilter']
        spchannels = data['spchannels']
        tpfilter = data['tpfilter']
        for ev in data['classes']:
            n_epochs[ev] = len(
                np.where(Y_data[:, 0] == data['classes'][ev])[0])

    else:
        spfilter = cfg.SP_FILTER
        tpfilter = cfg.TP_FILTER

        # Load multiple files
        if hasattr(cfg, 'MULTIPLIER'):
            multiplier = cfg.MULTIPLIER
        else:
            multiplier = 1
        raw, events = pu.load_multi(ftrain,
                                    spfilter=spfilter,
                                    multiplier=multiplier)
        if cfg.LOAD_EVENTS_FILE is not None:
            events = mne.read_events(cfg.LOAD_EVENTS_FILE)

        triggers = {cfg.tdef.by_value[c]: c for c in set(cfg.TRIGGER_DEF)}

        # Pick channels
        if cfg.CHANNEL_PICKS is None:
            picks = pick_types(raw.info,
                               meg=False,
                               eeg=True,
                               stim=False,
                               eog=False,
                               exclude='bads')
        else:
            picks = []
            for c in cfg.CHANNEL_PICKS:
                if type(c) == int:
                    picks.append(c)
                elif type(c) == str:
                    picks.append(raw.ch_names.index(c))
                else:
                    raise RuntimeError, 'CHANNEL_PICKS is unknown format.\nCHANNEL_PICKS=%s' % cfg.CHANNEL_PICKS

        if max(picks) > len(raw.info['ch_names']):
            print('ERROR: "picks" has a channel index %d while there are only %d channels.'%\
             ( max(picks),len(raw.info['ch_names']) ) )
            sys.exit(-1)

        # Spatial filter
        if cfg.SP_CHANNELS is None:
            spchannels = pick_types(raw.info,
                                    meg=False,
                                    eeg=True,
                                    stim=False,
                                    eog=False,
                                    exclude='bads')
        else:
            spchannels = []
            for c in cfg.SP_CHANNELS:
                if type(c) == int:
                    spchannels.append(c)
                elif type(c) == str:
                    spchannels.append(raw.ch_names.index(c))
                else:
                    raise RuntimeError, 'SP_CHANNELS is unknown format.\nSP_CHANNELS=%s' % cfg.SP_CHANNELS

        # Spectral filter
        if tpfilter is not None:
            raw = raw.filter(tpfilter[0],
                             tpfilter[1],
                             picks=picks,
                             n_jobs=mp.cpu_count())
        if cfg.NOTCH_FILTER is not None:
            raw = raw.notch_filter(cfg.NOTCH_FILTER,
                                   picks=picks,
                                   n_jobs=mp.cpu_count())

        # Read epochs
        try:
            if type(cfg.EPOCH[0]) is list:
                epochs_train = []
                for ep in cfg.EPOCH:
                    epochs_train.append( Epochs(raw, events, triggers, tmin=ep[0], tmax=ep[1], proj=False,\
                     picks=picks, baseline=None, preload=True, add_eeg_ref=False, verbose=False, detrend=None) )
            else:
                epochs_train= Epochs(raw, events, triggers, tmin=cfg.EPOCH[0], tmax=cfg.EPOCH[1], proj=False,\
                 picks=picks, baseline=None, preload=True, add_eeg_ref=False, verbose=False, detrend=None)
        except:
            print('\n*** (trainer.py) ERROR OCCURRED WHILE EPOCHING ***\n')
            traceback.print_exc()
            if interactive:
                print('Dropping into a shell.\n')
                pdb.set_trace()
            raise RuntimeError

        label_set = np.unique(triggers.values())
        sfreq = raw.info['sfreq']

        # Compute features
        if cfg.FEATURES == 'PSD':
            res = get_psd_feature(epochs_train, cfg.EPOCH, cfg.PSD, feat_picks)
            X_data = res['X_data']
            Y_data = res['Y_data']
            wlen = res['wlen']
            w_frames = res['w_frames']
            psde = res['psde']
            psdfile = '%s/psd/psd-train.pcl' % datadir

        elif cfg.FEATURES == 'TIMELAG':
            '''
			TODO: Implement multiple epochs for timelag feature
			'''
            if type(epcohs_train) is list:
                print(
                    'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR TIMELAG FEATURE.')
                sys.exit(-1)

            X_data, Y_data = get_timelags(epochs_train,
                                          cfg.TIMELAG['w_frames'],
                                          cfg.TIMELAG['wstep'],
                                          cfg.TIMELAG['downsample'])
        elif cfg.FEATURES == 'WAVELET':
            '''
			TODO: Implement multiple epochs for wavelet feature
			'''
            if type(epcohs_train) is list:
                print(
                    'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR WAVELET FEATURE.')
                sys.exit(-1)

            ############################### DO WE NEED SLIDING WINDOW ?????????
            X_data, Y_data = None, None
            for ev in epochs_train.event_id:
                e = 0
                for ep in epochs_train[ev]:
                    e += 1
                    freqs = np.arange(4, 30, 2)
                    n_cycles = freqs / 2
                    tfr = mne.time_frequency.cwt_morlet(ep,
                                                        sfreq,
                                                        freqs=freqs,
                                                        n_cycles=n_cycles)
                    tlen = 0.8
                    tfr = np.log(np.abs(tfr[:, :, round(-sfreq * tlen):]))
                    '''
					qc.make_dirs('%s/mat'% cfg.DATADIR)
					scipy.io.savemat('%s/mat/tfr-%s-%d.mat'% (cfg.DATADIR,ev,e), {'tfr':tfr[2]})
					'''
                    feat = tfr.reshape(1, -1)
                    if X_data is None:
                        X_data = feat
                    else:
                        X_data = np.concatenate((X_data, feat), axis=0)
                # Y_data dimension is different here !
                y = np.empty((epochs_train[ev]._data.shape[0]))  # windows x 1
                y.fill(epochs_train.event_id[ev])
                if Y_data is None:
                    Y_data = y
                else:
                    Y_data = np.concatenate((Y_data, y))

            cls= RandomForestClassifier(n_estimators=cfg.RF['trees'], max_features='auto',\
             max_depth=cfg.RF['maxdepth'], n_jobs=mp.cpu_count() )#, class_weight={cfg.tdef.LOS:20, cfg.tdef.LO:1})
            scores = []
            cnum = 1
            timer = qc.Timer()
            num_labels = len(label_set)
            cm = np.zeros((num_labels, num_labels))

            # select train and test trial ID's
            from sklearn import cross_validation
            cv = cross_validation.ShuffleSplit(X_data.shape[0],
                                               n_iter=20,
                                               test_size=0.1)
            for train, test in cv:
                timer.reset()
                X_train = X_data[train]
                X_test = X_data[test]
                Y_train = Y_data[train]
                Y_test = Y_data[test]
                if do_balance != False:
                    X_train, Y_train = balance_samples(X_train, Y_train,
                                                       do_balance, False)
                    X_test, Y_test = balance_samples(X_test, Y_test,
                                                     do_balance, False)

                cls.n_jobs = mp.cpu_count()
                cls.fit(X_train, Y_train)
                cls.n_jobs = 1
                #score= cls.score( X_test, Y_test )
                Y_pred = cls.predict(X_test)
                score = skmetrics.accuracy_score(Y_test, Y_pred)
                cm += skmetrics.confusion_matrix(Y_test, Y_pred, label_set)
                scores.append(score)
                print('Cross-validation %d / %d (%.2f) - %.1f sec' %
                      (cnum, len(cv), score, timer.sec()))
                cnum += 1

            # show confusion matrix
            cm_sum = np.sum(cm, axis=1)
            cm_rate = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            print('\nY: ground-truth, X: predicted')
            for l in label_set:
                print('%-5s' % cfg.tdef.by_value[l][:5], end='\t')
            print()
            for r in cm_rate:
                for c in r:
                    print('%-5.2f' % c, end='\t')
                print()
            print('Average accuracy: %.2f' % np.mean(scores))
            #X_data= X_data.reshape(1, X_data.shape[0], X_data.shape[1])
            sys.exit()

        else:
            print('>> ERROR: %s not supported yet.' % cfg.FEATURES)
            sys.exit()

        psdparams = cfg.PSD
        for ev in triggers:
            n_epochs[ev] = len(np.where(events[:, -1] == triggers[ev])[0])

    # Init a classifier
    if cfg.CLASSIFIER == 'RF':
        # Make sure to set n_jobs=cpu_count() for training and n_jobs=1 for testing.
        cls= RandomForestClassifier(n_estimators=cfg.RF['trees'], max_features='auto',\
         max_depth=cfg.RF['maxdepth'], n_jobs=mp.cpu_count(), class_weight='balanced' )
    elif cfg.CLASSIFIER == 'LDA':
        cls = LDA()
    elif cfg.CLASSIFIER == 'rLDA':
        cls = rLDA(cfg.RLDA_REGULARIZE_COEFF)
    else:
        raise RuntimeError, '*** Unknown classifier %s' % cfg.CLASSIFIER

    # Cross-validation
    if cfg.CV_PERFORM is not None:
        ntrials, nsamples, fsize = X_data.shape

        if cfg.CV_PERFORM == 'LeaveOneOut':
            print('\n>> %d-fold leave-one-out cross-validation' % ntrials)
            cv = LeaveOneOut(len(Y_data))
        elif cfg.CV_PERFORM == 'StratifiedShuffleSplit':
            print(
                '\n>> %d-fold stratified cross-validation with test set ratio %.2f'
                % (cfg.CV_FOLDS, cfg.CV_TEST_RATIO))
            cv = StratifiedShuffleSplit(Y_data[:, 0],
                                        cfg.CV_FOLDS,
                                        test_size=cfg.CV_TEST_RATIO,
                                        random_state=0)
        else:
            print('>> ERROR: Unsupported CV method yet.')
            sys.exit(-1)
        print('%d trials, %d samples per trial, %d feature dimension' %
              (ntrials, nsamples, fsize))

        # Do it!
        scores = crossval_epochs(cv, X_data, Y_data, cls, cfg.tdef.by_value,
                                 do_balance)

        # Results
        print('\n>> Class information')
        for ev in np.unique(Y_data):
            print(
                '%s: %d trials' %
                (cfg.tdef.by_value[ev], len(np.where(Y_data[:, 0] == ev)[0])))
        if do_balance:
            print('The number of samples was balanced across classes. Method:',
                  do_balance)

        print('\n>> Experiment conditions')
        print('Spatial filter: %s (channels: %s)' % (spfilter, spchannels))
        print('Spectral filter: %s' % tpfilter)
        print('Notch filter: %s' % cfg.NOTCH_FILTER)
        print('Channels: %s' % picks)
        print('PSD range: %.1f - %.1f Hz' %
              (psdparams['fmin'], psdparams['fmax']))
        print('Window step: %.1f msec' % (1000.0 * psdparams['wstep'] / sfreq))
        if type(wlen) is list:
            for i, w in enumerate(wlen):
                print('Window size: %.1f sec' % (w))
                print('Epoch range: %s sec' % (cfg.EPOCH[i]))
        else:
            print('Window size: %.1f sec' % (psdparams['wlen']))
            print('Epoch range: %s sec' % (cfg.EPOCH))

        #chance= 1.0 / len(np.unique(Y_data))
        cv_mean, cv_std = np.mean(scores), np.std(scores)
        print('\n>> Average CV accuracy over %d epochs' % ntrials)
        if cfg.CV_PERFORM in ['LeaveOneOut', 'StratifiedShuffleSplit']:
            print("mean %.3f, std: %.3f" % (cv_mean, cv_std))
        print('Classifier: %s' % cfg.CLASSIFIER)
        if cfg.CLASSIFIER == 'RF':
            print('            %d trees, %d max depth' %
                  (cfg.RF['trees'], cfg.RF['maxdepth']))

        if cfg.USE_LOG:
            logfile = '%s/result_%s_%s.txt' % (datadir, cfg.CLASSIFIER, txt)
            logout = open(logfile, 'a')
            logout.write('%s\t%.3f\t%.3f\n' %
                         (ftrain[0], np.mean(scores), np.var(scores)))
            logout.close()

    # Train classifier
    archtype = platform.architecture()[0]

    clsfile = '%s/classifier/classifier-%s.pcl' % (datadir, archtype)
    print('\n>> Training classifier')
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    if do_balance:
        X_data_merged, Y_data_merged = balance_samples(X_data_merged,
                                                       Y_data_merged,
                                                       do_balance,
                                                       verbose=True)

    timer = qc.Timer()
    cls.fit(X_data_merged, Y_data_merged)
    print('Trained %d samples x %d dimension in %.1f sec'% \
     (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    # set n_jobs = 1 for testing
    cls.n_jobs = 1

    if cfg.EXPORT_CLS == True:
        classes = {c: cfg.tdef.by_value[c] for c in np.unique(Y_data)}
        if cfg.FEATURES == 'PSD':
            data = dict(cls=cls,
                        psde=psde,
                        sfreq=sfreq,
                        picks=picks,
                        classes=classes,
                        epochs=cfg.EPOCH,
                        w_frames=w_frames,
                        w_seconds=psdparams['wlen'],
                        wstep=psdparams['wstep'],
                        spfilter=spfilter,
                        spchannels=spchannels,
                        refchannel=None,
                        tpfilter=tpfilter,
                        notch=cfg.NOTCH_FILTER,
                        triggers=cfg.tdef)
        elif cfg.FEATURES == 'TIMELAG':
            data = dict(cls=cls, parameters=cfg.TIMELAG)

        qc.make_dirs('%s/classifier' % datadir)
        qc.save_obj(clsfile, data)

    # Show top distinctive features
    if cfg.CLASSIFIER == 'RF' and cfg.FEATURES == 'PSD':
        print('\n>> Good features ordered by importance')
        keys, _ = qc.sort_by_value(list(cls.feature_importances_), rev=True)
        if cfg.EXPORT_GOOD_FEATURES:
            gfout = open('%s/good_features.txt' % datadir, 'w')

        # reverse-lookup frequency from fft
        if type(wlen) is not list:
            fq = 0
            fq_res = 1.0 / psdparams['wlen']
            fqlist = []
            while fq <= psdparams['fmax']:
                if fq >= psdparams['fmin']: fqlist.append(fq)
                fq += fq_res

            for k in keys[:cfg.FEAT_TOPN]:
                ch, hz = feature2chz(k, fqlist, picks, ch_names=raw.ch_names)
                print('%s, %.1f Hz  (feature %d)' % (ch, hz, k))
                if cfg.EXPORT_GOOD_FEATURES:
                    gfout.write('%s\t%.1f\n' % (ch, hz))

            if cfg.EXPORT_GOOD_FEATURES:
                if cfg.CV_PERFORM is not None:
                    gfout.write(
                        '\nCross-validation performance: mean %.2f, std %.2f\n'
                        % (cv_mean, cv_std))
                gfout.close()
            print()
        else:
            print('Ignoring good features because of multiple epochs.')

    # Test file
    if len(cfg.ftest) > 0:
        raw_test, events_test = pu.load_raw('%s' % (cfg.ftest), spfilter)
        '''
		TODO: implement multi-segment epochs
		'''
        if type(cfg.EPOCH[0]) is list:
            print('MULTI-SEGMENT EPOCH IS NOT SUPPORTED YET.')
            sys.exit(-1)

        epochs_test= Epochs(raw_test, events_test, triggers, tmin=cfg.EPOCH[0], tmax=cfg.EPOCH[1],\
         proj=False, picks=picks, baseline=None, preload=True, add_eeg_ref=False)

        if cfg.FEATURES == 'PSD':
            psdfile = 'psd-test.pcl'
            if not os.path.exists(psdfile):
                print('\n>> Computing PSD for test set')
                X_test, y_test = pu.get_psd(epochs_test, psde, w_frames,
                                            int(sfreq / 8))
                qc.save_obj(psdfile, {'X': X_test, 'y': y_test})
            else:
                print('\n>> Loading %s' % psdfile)
                data = qc.load_obj(psdfile)
                X_test, y_test = data['X'], data['y']
        else:
            print('>> Feature not supported yet for testing set.')
            sys.exit(-1)

        score_test = cls.score(np.concatenate(X_test), np.concatenate(y_test))
        print('Testing score', score_test)

        # running performance
        print('\nRunning performance over time')
        scores_windows = []
        timer = qc.Timer()
        for ep in range(y_test.shape[0]):
            scores = []
            frames = X_test[ep].shape[0]
            timer.reset()
            for t in range(frames):
                X = X_test[ep][t, :]
                y = [y_test[ep][t]]
                scores.append(cls.score(X, y))
                #print('%d /%d   %.1f msec'% (t,X_test[ep].shape[0],1000*timer.sec()) )
            print('Tested epoch %d, %.3f msec per window' %
                  (ep, timer.sec() * 1000.0 / frames))
            scores_windows.append(scores)
        scores_windows = np.array(scores_windows)

        ###############################################################################
        # Plot performance over time
        ###############################################################################
        #w_times= (w_start + w_frames / 2.) / sfreq + epochs.tmin
        step = float(epochs_test.tmax -
                     epochs_test.tmin) / scores_windows.shape[1]
        w_times = np.arange(epochs_test.tmin, epochs_test.tmax, step)
        plt.plot(w_times, np.mean(scores_windows, 0), label='Score')
        plt.axvline(0, linestyle='--', color='k', label='Onset')
        plt.axhline(0.5, linestyle='-', color='k', label='Chance')
        plt.xlabel('time (s)')
        plt.ylabel('Classification accuracy')
        plt.title('Classification score over time')
        plt.legend(loc='lower right')
        plt.show()
    '''
コード例 #8
0
ファイル: trainer.py プロジェクト: heygayboy/EEG_MEG
def _resolve_channels(raw, spec, cfg_name):
    """
    Convert a channel specification from the config into channel indices.

    Params
    ------
        raw: MNE Raw object; its info / ch_names are used for the lookup
        spec: None, or an iterable whose entries are int indices or str channel names
        cfg_name: name of the config entry (used only in the error message)

    Returns
    -------
        If spec is None, the EEG channel indices returned by pick_types()
        (bad channels excluded); otherwise a list of indices in spec order.

    Raises
    ------
        RuntimeError: if an entry of spec is neither an int nor a str.
        ValueError: if a channel name is not present in raw.ch_names.
    """
    if spec is None:
        return pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads')

    indices = []
    for c in spec:
        if isinstance(c, int):
            indices.append(c)
        elif isinstance(c, str):
            indices.append(raw.ch_names.index(c))
        else:
            raise RuntimeError('%s is unknown format.\n%s=%s' % (cfg_name, cfg_name, spec))
    return indices


def run_trainer(cfg, ftrain, interactive=False):
    """
    Train a PSD-feature classifier from the training recordings and save it.

    Pipeline: load and concatenate `ftrain`, pick channels, apply the temporal
    and notch filters, epoch around the triggers in cfg.TRIGGER_DEF, compute
    PSD features, optionally cross-validate (cfg.CV_PERFORM), fit the
    classifier on all data, save it to
    <cfg.DATADIR>/classifier/classifier-<arch>.pcl, and optionally evaluate
    it on cfg.ftest.

    Params
    ------
        cfg: configuration object (DATADIR, SP_FILTER, TP_FILTER, PSD, EPOCH, ...)
        ftrain: training recordings passed to pu.load_multi(); presumably a
            list of file paths (ftrain[0] is used as the log-file key)
        interactive: if True, drop into pdb when epoching fails

    Raises
    ------
        RuntimeError: on an unknown CHANNEL_PICKS / SP_CHANNELS entry, an
            unknown cfg.CLASSIFIER, or when epoching fails.
    """
    datadir = cfg.DATADIR
    feat_picks = None
    txt = 'all'
    do_balance = False

    spfilter = cfg.SP_FILTER
    tpfilter = cfg.TP_FILTER

    # Load and concatenate multiple files
    multiplier = 1
    raw, events = pu.load_multi(ftrain, spfilter=spfilter, multiplier=multiplier)
    triggers = {cfg.tdef.by_value[c]: c for c in set(cfg.TRIGGER_DEF)}

    # Pick channels
    picks = _resolve_channels(raw, cfg.CHANNEL_PICKS, 'CHANNEL_PICKS')
    n_channels = len(raw.info['ch_names'])
    # Indices are 0-based, so an index equal to n_channels is already out of range.
    if max(picks) >= n_channels:
        print('ERROR: "picks" has a channel index %d while there are only %d channels.' %
              (max(picks), n_channels))
        sys.exit(-1)

    # Channels used by the spatial filter
    spchannels = _resolve_channels(raw, cfg.SP_CHANNELS, 'SP_CHANNELS')

    # Temporal (band-pass) and notch filters, applied to the picked channels only
    if tpfilter is not None:
        raw = raw.filter(tpfilter[0], tpfilter[1], picks=picks, n_jobs=mp.cpu_count())
    if cfg.NOTCH_FILTER is not None:
        raw = raw.notch_filter(cfg.NOTCH_FILTER, picks=picks, n_jobs=mp.cpu_count())

    # Read epochs
    try:
        epochs_train = Epochs(raw, events, triggers, tmin=cfg.EPOCH[0], tmax=cfg.EPOCH[1],
                              proj=False, picks=picks, baseline=None, preload=True,
                              add_eeg_ref=False, verbose=False, detrend=None)
    except Exception:
        print('\n*** (trainer.py) ERROR OCCURRED WHILE EPOCHING ***\n')
        traceback.print_exc()
        if interactive:
            print('Dropping into a shell.\n')
            pdb.set_trace()
        raise RuntimeError('Epoching failed; see the traceback above.')

    sfreq = raw.info['sfreq']

    # Compute PSD features
    res = get_psd_feature(epochs_train, cfg.EPOCH, cfg.PSD, feat_picks)
    X_data = res['X_data']
    Y_data = res['Y_data']
    wlen = res['wlen']
    w_frames = res['w_frames']
    psde = res['psde']
    plot_pca_componet(X_data, Y_data)

    psdparams = cfg.PSD

    # Init a classifier
    if cfg.CLASSIFIER == 'RF':
        # n_jobs=cpu_count() for training; it is reset to 1 below before export (testing).
        cls = RandomForestClassifier(n_estimators=cfg.RF['trees'], max_features='auto',
                                     max_depth=cfg.RF['maxdepth'], n_jobs=mp.cpu_count(),
                                     class_weight='balanced')
    elif cfg.CLASSIFIER == 'LDA':
        cls = LDA()
    else:
        raise RuntimeError('*** Unknown classifier %s' % cfg.CLASSIFIER)

    # Cross-validation
    if cfg.CV_PERFORM is not None:
        ntrials, nsamples, fsize = X_data.shape

        if cfg.CV_PERFORM == 'LeaveOneOut':
            print('\n>> %d-fold leave-one-out cross-validation' % ntrials)
            cv = LeaveOneOut(len(Y_data))
        elif cfg.CV_PERFORM == 'StratifiedShuffleSplit':
            print('\n>> %d-fold stratified cross-validation with test set ratio %.2f' %
                  (cfg.CV_FOLDS, cfg.CV_TEST_RATIO))
            cv = StratifiedShuffleSplit(Y_data[:, 0], cfg.CV_FOLDS,
                                        test_size=cfg.CV_TEST_RATIO, random_state=0)
        else:
            print('>> ERROR: Unsupported CV method yet.')
            sys.exit(-1)
        print('%d trials, %d samples per trial, %d feature dimension' % (ntrials, nsamples, fsize))

        # Do it!
        scores = crossval_epochs(cv, X_data, Y_data, cls, cfg.tdef.by_value, do_balance)

        # Results
        print('\n>> Class information')
        for ev in np.unique(Y_data):
            print('%s: %d trials' % (cfg.tdef.by_value[ev], len(np.where(Y_data[:, 0] == ev)[0])))
        if do_balance:
            print('The number of samples was balanced across classes. Method:', do_balance)

        print('\n>> Experiment conditions')
        print('Spatial filter: %s (channels: %s)' % (spfilter, spchannels))
        print('Spectral filter: %s' % tpfilter)
        print('Notch filter: %s' % cfg.NOTCH_FILTER)
        print('Channels: %s' % picks)
        print('PSD range: %.1f - %.1f Hz' % (psdparams['fmin'], psdparams['fmax']))
        print('Window step: %.1f msec' % (1000.0 * psdparams['wstep'] / sfreq))
        if type(wlen) is list:
            for i, w in enumerate(wlen):
                print('Window size: %.1f sec' % (w))
                print('Epoch range: %s sec' % (cfg.EPOCH[i]))
        else:
            print('Window size: %.1f sec' % (psdparams['wlen']))
            print('Epoch range: %s sec' % (cfg.EPOCH))

        cv_mean, cv_std = np.mean(scores), np.std(scores)
        print('\n>> Average CV accuracy over %d epochs' % ntrials)
        if cfg.CV_PERFORM in ['LeaveOneOut', 'StratifiedShuffleSplit']:
            print("mean %.3f, std: %.3f" % (cv_mean, cv_std))
        print('Classifier: %s' % cfg.CLASSIFIER)
        if cfg.CLASSIFIER == 'RF':
            print('            %d trees, %d max depth' % (cfg.RF['trees'], cfg.RF['maxdepth']))

        if cfg.USE_LOG:
            logfile = '%s/result_%s_%s.txt' % (datadir, cfg.CLASSIFIER, txt)
            with open(logfile, 'a') as logout:
                logout.write('%s\t%.3f\t%.3f\n' % (ftrain[0], np.mean(scores), np.var(scores)))

    # Train classifier on all data
    archtype = platform.architecture()[0]  # bit architecture string, e.g. '64bit'

    clsfile = '%s/classifier/classifier-%s.pcl' % (datadir, archtype)
    print('\n>> Training classifier')
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    timer = qc.Timer()
    cls.fit(X_data_merged, Y_data_merged)
    print('Trained %d samples x %d dimension in %.1f sec' %
          (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    # set n_jobs = 1 for testing
    cls.n_jobs = 1

    # Save the model together with everything needed to reproduce the PSD features
    classes = {c: cfg.tdef.by_value[c] for c in np.unique(Y_data)}
    data = dict(cls=cls, psde=psde, sfreq=sfreq, picks=picks, classes=classes,
                epochs=cfg.EPOCH, w_frames=w_frames, w_seconds=psdparams['wlen'],
                wstep=psdparams['wstep'], spfilter=spfilter, spchannels=spchannels,
                refchannel=None, tpfilter=tpfilter, notch=cfg.NOTCH_FILTER,
                triggers=cfg.tdef)
    qc.make_dirs('%s/classifier' % datadir)
    qc.save_obj(clsfile, data)

    # Show top distinctive features
    if cfg.CLASSIFIER == 'RF':
        print('\n>> Good features ordered by importance')
        keys, _ = qc.sort_by_value(list(cls.feature_importances_), rev=True)

        if type(wlen) is not list:
            # Reverse-lookup frequencies: rebuild the FFT bins within [fmin, fmax]
            # at a resolution of 1/wlen Hz.
            fq = 0
            fq_res = 1.0 / psdparams['wlen']
            fqlist = []
            while fq <= psdparams['fmax']:
                if fq >= psdparams['fmin']:
                    fqlist.append(fq)
                fq += fq_res

            good_lines = []
            for k in keys[:cfg.FEAT_TOPN]:
                ch, hz = qc.feature2chz(k, fqlist, picks, ch_names=raw.ch_names)
                print('%s, %.1f Hz  (feature %d)' % (ch, hz, k))
                good_lines.append('%s\t%.1f\n' % (ch, hz))

            if cfg.EXPORT_GOOD_FEATURES:
                # Opened only here, so the file is always closed and is not
                # truncated when nothing would be written to it.
                with open('%s/good_features.txt' % datadir, 'w') as gfout:
                    gfout.writelines(good_lines)
                    if cfg.CV_PERFORM is not None:
                        gfout.write('\nCross-validation performance: mean %.2f, std %.2f\n' %
                                    (cv_mean, cv_std))
            print()
        else:
            print('Ignoring good features because of multiple epochs.')

    # Test file
    if cfg.ftest:
        raw_test, events_test = pu.load_raw('%s' % (cfg.ftest), spfilter)

        # TODO: implement multi-segment epochs
        if type(cfg.EPOCH[0]) is list:
            print('MULTI-SEGMENT EPOCH IS NOT SUPPORTED YET.')
            sys.exit(-1)

        epochs_test = Epochs(raw_test, events_test, triggers, tmin=cfg.EPOCH[0],
                             tmax=cfg.EPOCH[1], proj=False, picks=picks, baseline=None,
                             preload=True, add_eeg_ref=False)

        # NOTE(review): this cache path is relative to the current working
        # directory and is not keyed on cfg.ftest, so a cache produced from a
        # different test file will be silently reused.
        psdfile = 'psd-test.pcl'
        if not os.path.exists(psdfile):
            print('\n>> Computing PSD for test set')
            X_test, y_test = pu.get_psd(epochs_test, psde, w_frames, int(sfreq / 8))
            qc.save_obj(psdfile, {'X': X_test, 'y': y_test})
        else:
            print('\n>> Loading %s' % psdfile)
            data = qc.load_obj(psdfile)
            X_test, y_test = data['X'], data['y']

        score_test = cls.score(np.concatenate(X_test), np.concatenate(y_test))
        print('Testing score', score_test)

        # Running performance: score each sliding window of each test epoch
        print('\nRunning performance over time')
        scores_windows = []
        timer = qc.Timer()
        for ep in range(y_test.shape[0]):
            scores = []
            frames = X_test[ep].shape[0]
            timer.reset()
            for t in range(frames):
                # scikit-learn expects 2-D (n_samples, n_features) input, even for one sample
                X = X_test[ep][t, :].reshape(1, -1)
                y = [y_test[ep][t]]
                scores.append(cls.score(X, y))
            print('Tested epoch %d, %.3f msec per window' % (ep, timer.sec() * 1000.0 / frames))
            scores_windows.append(scores)
        scores_windows = np.array(scores_windows)
コード例 #9
0
    def __init__(self,
                 classifier=None,
                 buffer_size=1.0,
                 fake=False,
                 amp_serial=None,
                 amp_name=None):
        """
		Params
		------
			classifier: classifier file
			spfilter: spatial filter to use
			buffer_size: length of the signal buffer in seconds
		"""

        from stream_receiver import StreamReceiver

        self.classifier = classifier
        self.buffer_sec = buffer_size
        self.fake = fake
        self.amp_serial = amp_serial
        self.amp_name = amp_name

        if self.fake == False:
            model = qc.load_obj(self.classifier)
            if model == None:
                self.print('Error loading %s' % model)
                sys.exit(-1)
            self.cls = model['cls']
            self.psde = model['psde']
            self.labels = list(self.cls.classes_)
            self.spfilter = model['spfilter']
            self.spchannels = model['spchannels']
            self.notch = model['notch']
            self.w_seconds = model['w_seconds']
            self.w_frames = model['w_frames']
            self.wstep = model['wstep']
            self.sfreq = model['sfreq']
            assert int(self.sfreq * self.w_seconds) == self.w_frames

            # window from StreamReceiver is 0-based
            self.picks = np.array(model['picks']) - 1

            # PSD buffer
            psd_temp = self.psde.transform(
                np.zeros((1, len(model['picks']), self.w_frames)))
            self.psd_shape = psd_temp.shape
            self.psd_size = psd_temp.size
            self.psd_buffer = np.zeros(
                (0, self.psd_shape[1], self.psd_shape[2]))
            self.ts_buffer = []

            # Stream Receiver
            self.sr = StreamReceiver(window_size=self.w_seconds,
                                     amp_name=self.amp_name,
                                     amp_serial=self.amp_serial)
            if self.sfreq != self.sr.sample_rate:
                self.print(
                    'WARNING: The amplifier sampling rate (%.1f) != training data sampling rate (%.1f).'
                    % (self.sr.sample_rate, self.sfreq))
        else:
            model = None
            self.psd_shape = None
            self.psd_size = None
            from triggerdef_16 import TriggerDef
            tdef = TriggerDef()
            # must be changed to non-specific labels
            self.labels = [tdef.by_key['DOWN_GO'], tdef.by_key['UP_GO']]