def balance_samples(X, Y, balance_type, verbose=False):
    """
    Balance the number of samples per class by random resampling.

    Params
    ======
    X: samples; indexed along axis 0 (converted with np.asarray).
    Y: labels, one per sample of X.
    balance_type: 'OVER' randomly duplicates samples of every class until it
        has as many samples as the largest class. 'UNDER' keeps all samples
        of the smallest class and randomly draws, without replacement, the
        same number of samples from every other class. None or False
        returns X and Y untouched.
    verbose: currently unused; kept for backward compatibility.

    Returns
    =======
    (X_balanced, Y_balanced)

    Raises
    ======
    ValueError: balance_type is not one of the values above.
    """
    if balance_type is None or balance_type is False:
        return X, Y
    if balance_type not in ('OVER', 'UNDER'):
        logger.error('Unknown balancing type %s' % balance_type)
        raise ValueError('Unknown balancing type %s' % balance_type)

    # Fancy indexing below requires arrays, not plain lists.
    X = np.asarray(X)
    Y = np.asarray(Y)
    label_set, counts = np.unique(Y, return_counts=True)
    if label_set.size == 0:
        # nothing to balance
        return np.array(X), np.array(Y)

    if balance_type == 'OVER':
        # Oversample from classes that lack samples
        target = counts.max()
        X_parts = [X]
        Y_parts = [Y]
        for c, count in zip(label_set, counts):
            if count == target:
                continue
            yl = np.where(Y == c)[0]
            extra_idx = np.random.choice(yl, target - count)
            X_parts.append(X[extra_idx])
            Y_parts.append(Y[extra_idx])
        X_balanced = np.concatenate(X_parts, axis=0)
        Y_balanced = np.concatenate(Y_parts, axis=0)
    else:
        # Undersample from classes that are excessive
        target = counts.min()
        smallest = label_set[np.argmin(counts)]
        kept_idx = [np.where(Y == smallest)[0]]
        for c in label_set:
            if c == smallest:
                continue
            yl = np.where(Y == c)[0]
            # replace=False: drawing with replacement would duplicate some
            # samples while discarding others that could have been kept.
            kept_idx.append(np.random.choice(yl, target, replace=False))
        kept_idx = np.concatenate(kept_idx)
        X_balanced = X[kept_idx]
        Y_balanced = Y[kept_idx]

    logger.info_green('\nNumber of samples after %ssampling' % balance_type.lower())
    for c in label_set:
        logger.info('%s: %d -> %d' % (c, np.count_nonzero(Y == c),
                                      np.count_nonzero(Y_balanced == c)))

    return X_balanced, Y_balanced
def train_decoder(cfg, featdata, feat_file=None):
    """
    Train the final decoder using all data, save it to disk and report the
    most important features.

    Params
    ======
    cfg: config module. If cfg.FEATURES['PSD']['wlen'] is None it is set to
        featdata['wlen'] as a side effect.
    featdata: feature data dict. Reads 'X_data', 'Y_data', 'wlen',
        'w_frames', 'ch_names', 'psde', 'sfreq' and 'picks'.
    feat_file: file to write the feature ranking to when
        cfg.EXPORT_GOOD_FEATURES is set. Defaults to
        <cfg.DATA_PATH>/classifier/good_features.txt.

    Raises
    ======
    ValueError: unknown classifier name.
    NotImplementedError: selected feature type is not 'PSD'.
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        params = cfg.CLASSIFIER[selected_classifier]
        cls = GradientBoostingClassifier(
            loss='deviance', learning_rate=params['learning_rate'],
            n_estimators=params['trees'], subsample=1.0,
            max_depth=params['depth'], random_state=params['seed'],
            max_features='sqrt', verbose=0, warm_start=False, presort='auto')
    elif selected_classifier == 'XGB':
        params = cfg.CLASSIFIER[selected_classifier]
        # The seed comes from the XGB settings like every other parameter
        # (it used to be read from cfg.GB).
        cls = XGBClassifier(
            loss='deviance', learning_rate=params['learning_rate'],
            n_estimators=params['trees'], subsample=1.0,
            max_depth=params['depth'], random_state=params['seed'],
            max_features='sqrt', verbose=0, warm_start=False, presort='auto')
    elif selected_classifier == 'RF':
        params = cfg.CLASSIFIER[selected_classifier]
        cls = RandomForestClassifier(
            n_estimators=params['trees'], max_features='auto',
            max_depth=params['depth'], n_jobs=cfg.N_JOBS,
            random_state=params['seed'], oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER[selected_classifier]['r_coeff'])
    else:
        logger.error('Unknown classifier %s' % selected_classifier)
        raise ValueError('Unknown classifier %s' % selected_classifier)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen
    w_frames = featdata['w_frames']
    ch_names = featdata['ch_names']
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    if cfg.CV['BALANCE_SAMPLES']:
        X_data_merged, Y_data_merged = balance_samples(
            X_data_merged, Y_data_merged, cfg.CV['BALANCE_SAMPLES'],
            verbose=True)

    # Start training the decoder
    logger.info_green('Training the decoder')
    timer = qc.Timer()
    cls.n_jobs = cfg.N_JOBS
    cls.fit(X_data_merged, Y_data_merged)
    logger.info('Trained %d samples x %d dimension in %.1f sec' %
                (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    cls.n_jobs = 1  # always set n_jobs=1 for testing

    # Export the decoder
    classes = {c: cfg.tdef.by_value[c] for c in np.unique(Y_data)}
    if cfg.FEATURES['selected'] == 'PSD':
        data = dict(cls=cls, ch_names=ch_names, psde=featdata['psde'],
                    sfreq=featdata['sfreq'], picks=featdata['picks'],
                    classes=classes, epochs=cfg.EPOCH, w_frames=w_frames,
                    w_seconds=cfg.FEATURES['PSD']['wlen'],
                    wstep=cfg.FEATURES['PSD']['wstep'],
                    spatial=cfg.SP_FILTER, spatial_ch=featdata['picks'],
                    spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                    spectral_ch=featdata['picks'],
                    notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                    notch_ch=featdata['picks'], multiplier=cfg.MULTIPLIER,
                    ref_ch=cfg.REREFERENCE[cfg.REREFERENCE['selected']],
                    decim=cfg.FEATURES['PSD']['decim'])
    else:
        # Without this, `data` would be unbound at save time.
        logger.error('Unsupported feature type %s' % cfg.FEATURES['selected'])
        raise NotImplementedError(
            'Unsupported feature type %s' % cfg.FEATURES['selected'])
    clsfile = '%s/classifier/classifier-%s.pkl' % (
        cfg.DATA_PATH, platform.architecture()[0])
    qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
    qc.save_obj(clsfile, data)
    logger.info('Decoder saved to %s' % clsfile)

    # Reverse-lookup frequency from FFT
    fq = 0
    if isinstance(cfg.FEATURES['PSD']['wlen'], list):
        fq_res = 1.0 / cfg.FEATURES['PSD']['wlen'][0]
    else:
        fq_res = 1.0 / cfg.FEATURES['PSD']['wlen']
    fqlist = []
    while fq <= cfg.FEATURES['PSD']['fmax']:
        if fq >= cfg.FEATURES['PSD']['fmin']:
            fqlist.append(fq)
        fq += fq_res

    # Show top distinctive features
    if cfg.FEATURES['selected'] == 'PSD':
        logger.info_green('Good features ordered by importance')
        if selected_classifier in ['RF', 'GB', 'XGB']:
            keys, values = qc.sort_by_value(
                list(cls.feature_importances_), rev=True)
        elif selected_classifier in ['LDA', 'rLDA']:
            keys, values = qc.sort_by_value(cls.w, rev=True)
        keys = np.array(keys)
        values = np.array(values)

        # Build the per-feature channel labels in a separate list: the
        # multi-window case needs the original names while it is filling
        # the new one.
        picked_names = [ch_names[c] for c in featdata['picks']]
        if isinstance(wlen, list):
            feat_ch_names = ['w%d-%s' % (w, name)
                             for w in range(len(wlen))
                             for name in picked_names]
        else:
            feat_ch_names = picked_names

        chlist, hzlist = features.feature2chz(keys, fqlist,
                                              ch_names=feat_ch_names)
        valnorm = values[:cfg.FEAT_TOPN].copy()
        valsum = np.sum(valnorm)
        if valsum == 0:
            valsum = 1  # avoid division by zero
        valnorm = valnorm / valsum * 100.0

        # show top-N features
        for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
            if i >= cfg.FEAT_TOPN:
                break
            txt = '%-3s %5.1f Hz normalized importance %-6s raw importance %-6s feature %-5d' %\
                  (ch, hz, '%.2f%%' % valnorm[i], '%.2f%%' % (values[i] * 100.0), keys[i])
            logger.info(txt)

        if cfg.EXPORT_GOOD_FEATURES:
            if feat_file is None:
                feat_file = '%s/classifier/good_features.txt' % cfg.DATA_PATH
            # Open only when everything to write is ready so the handle is
            # always closed, even if a write fails.
            with open(feat_file, 'w') as gfout:
                gfout.write('Importance(%) Channel Frequency Index\n')
                for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
                    gfout.write('%.3f\t%s\t%s\t%d\n' %
                                (values[i] * 100.0, ch, hz, keys[i]))
def cross_validate(cfg, featdata, cv_file=None):
    """
    Perform cross validation and report the result.

    Params
    ======
    cfg: config module. Reads CLASSIFIER, CV, CV_PERFORM, FEATURES, EPOCH,
        N_JOBS, tdef and the filter settings.
    featdata: feature data dict. Reads 'X_data' (indexed as
        trials x samples x features), 'Y_data', 'wlen', 'sfreq', 'ch_names'
        and 'picks'.
    cv_file: file to write the report to when the selected CV settings have
        export_result set to True. Defaults to cv_result.txt under
        cfg.DATA_PATH (inside classifier/ if cfg.EXPORT_CLS is True).

    Raises
    ======
    ValueError: unknown classifier name.
    NotImplementedError: unsupported cross-validation type.
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        params = cfg.CLASSIFIER['GB']
        cls = GradientBoostingClassifier(
            loss='deviance', learning_rate=params['learning_rate'],
            presort='auto', n_estimators=params['trees'], subsample=1.0,
            max_depth=params['depth'], random_state=params['seed'],
            max_features='sqrt', verbose=0, warm_start=False)
    elif selected_classifier == 'XGB':
        params = cfg.CLASSIFIER['XGB']
        # random_state must be the seed value, not the whole settings dict.
        cls = XGBClassifier(
            loss='deviance', learning_rate=params['learning_rate'],
            presort='auto', n_estimators=params['trees'], subsample=1.0,
            max_depth=params['depth'], random_state=params['seed'],
            max_features='sqrt', verbose=0, warm_start=False)
    elif selected_classifier == 'RF':
        params = cfg.CLASSIFIER['RF']
        cls = RandomForestClassifier(
            n_estimators=params['trees'], max_features='auto',
            max_depth=params['depth'], n_jobs=cfg.N_JOBS,
            random_state=params['seed'], oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER['rLDA']['r_coeff'])
    else:
        logger.error('Unknown classifier type %s' % selected_classifier)
        raise ValueError('Unknown classifier type %s' % selected_classifier)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']

    # Choose CV type
    ntrials, nsamples, fsize = X_data.shape
    selected_cv = cfg.CV_PERFORM['selected']
    if selected_cv == 'LeaveOneOut':
        cv_params = cfg.CV_PERFORM[selected_cv]
        logger.info_green('%d-fold leave-one-out cross-validation' % ntrials)
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif selected_cv == 'StratifiedShuffleSplit':
        cv_params = cfg.CV_PERFORM[selected_cv]
        logger.info_green(
            '%d-fold stratified cross-validation with test set ratio %.2f' %
            (cv_params['folds'], cv_params['test_ratio']))
        if SKLEARN_OLD:
            cv = StratifiedShuffleSplit(
                Y_data[:, 0], cv_params['folds'],
                test_size=cv_params['test_ratio'],
                random_state=cv_params['seed'])
        else:
            cv = StratifiedShuffleSplit(
                n_splits=cv_params['folds'],
                test_size=cv_params['test_ratio'],
                random_state=cv_params['seed'])
    else:
        logger.error('%s is not supported yet. Sorry.' % selected_cv)
        raise NotImplementedError(
            '%s is not supported yet. Sorry.' % selected_cv)
    logger.info('%d trials, %d samples per trial, %d feature dimension' %
                (ntrials, nsamples, fsize))

    # Do it!
    timer_cv = qc.Timer()
    scores, cm_txt = crossval_epochs(
        cv, X_data, Y_data, cls, cfg.tdef.by_value,
        cfg.CV['BALANCE_SAMPLES'], n_jobs=cfg.N_JOBS,
        ignore_thres=cfg.CV['IGNORE_THRES'],
        decision_thres=cfg.CV['DECISION_THRES'])
    t_cv = timer_cv.sec()

    # Export results
    txt = 'Cross validation took %d seconds.\n' % t_cv
    txt += '\n- Class information\n'
    txt += '%d epochs, %d samples per epoch, %d feature dimension (total %d samples)\n' %\
           (ntrials, nsamples, fsize, ntrials * nsamples)
    for ev in np.unique(Y_data):
        txt += '%s: %d trials\n' % (cfg.tdef.by_value[ev],
                                    len(np.where(Y_data[:, 0] == ev)[0]))
    if cfg.CV['BALANCE_SAMPLES']:
        # Same setting that was passed to crossval_epochs() above.
        txt += 'The number of samples was balanced using %ssampling.\n' % \
               cfg.CV['BALANCE_SAMPLES'].lower()
    txt += '\n- Experiment condition\n'
    txt += 'Sampling frequency: %.3f Hz\n' % featdata['sfreq']
    txt += 'Spatial filter: %s (channels: %s)\n' % (cfg.SP_FILTER,
                                                    cfg.SP_CHANNELS)
    txt += 'Spectral filter: %s\n' % cfg.TP_FILTER[cfg.TP_FILTER['selected']]
    txt += 'Notch filter: %s\n' % cfg.NOTCH_FILTER[
        cfg.NOTCH_FILTER['selected']]
    txt += 'Channels: ' + ','.join(
        [str(featdata['ch_names'][p]) for p in featdata['picks']]) + '\n'
    txt += 'PSD range: %.1f - %.1f Hz\n' % (cfg.FEATURES['PSD']['fmin'],
                                            cfg.FEATURES['PSD']['fmax'])
    txt += 'Window step: %.2f msec\n' % (
        1000.0 * cfg.FEATURES['PSD']['wstep'] / featdata['sfreq'])
    if isinstance(wlen, list):
        for i, w in enumerate(wlen):
            txt += 'Window size: %.1f msec\n' % (w * 1000.0)
            txt += 'Epoch range: %s sec\n' % (cfg.EPOCH[i])
    else:
        txt += 'Window size: %.1f msec\n' % (
            cfg.FEATURES['PSD']['wlen'] * 1000.0)
        txt += 'Epoch range: %s sec\n' % (cfg.EPOCH)
    txt += 'Decimation factor: %d\n' % cfg.FEATURES['PSD']['decim']

    # Compute stats
    cv_mean, cv_std = np.mean(scores), np.std(scores)
    # .get(): presumably not every CV type defines a seed (LeaveOneOut takes
    # none above) -- verify against the config.
    txt += '\n- Average CV accuracy over %d epochs (random seed=%s)\n' % (
        ntrials, cv_params.get('seed'))
    # Compare the CV *name*; the settings dict never equals a string, which
    # silently dropped this line from the report.
    if selected_cv in ['LeaveOneOut', 'StratifiedShuffleSplit']:
        txt += "mean %.3f, std: %.3f\n" % (cv_mean, cv_std)
    txt += 'Classifier: %s, ' % selected_classifier
    if selected_classifier == 'RF':
        txt += '%d trees, %s max depth, random state %s\n' % (
            cfg.CLASSIFIER['RF']['trees'], cfg.CLASSIFIER['RF']['depth'],
            cfg.CLASSIFIER['RF']['seed'])
    elif selected_classifier == 'GB' or selected_classifier == 'XGB':
        # Report the settings of the classifier that was actually used.
        clf_params = cfg.CLASSIFIER[selected_classifier]
        txt += '%d trees, %s max depth, %s learing_rate, random state %s\n' % (
            clf_params['trees'], clf_params['depth'],
            clf_params['learning_rate'], clf_params['seed'])
    elif selected_classifier == 'rLDA':
        txt += 'regularization coefficient %.2f\n' % cfg.CLASSIFIER['rLDA'][
            'r_coeff']
    # NOTE(review): the label says "Decision threshold" but the value is
    # IGNORE_THRES -- confirm which one is meant.
    if cfg.CV['IGNORE_THRES'] is not None:
        txt += 'Decision threshold: %.2f\n' % cfg.CV['IGNORE_THRES']
    txt += '\n- Confusion Matrix\n' + cm_txt
    logger.info(txt)

    # Export to a file
    if cv_params.get('export_result') is True:
        if cv_file is None:
            if cfg.EXPORT_CLS is True:
                qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
                cv_file = '%s/classifier/cv_result.txt' % cfg.DATA_PATH
            else:
                cv_file = '%s/cv_result.txt' % cfg.DATA_PATH
        with open(cv_file, 'w') as fout:
            fout.write(txt)
def balance_tpr(cfg, featdata):
    """
    Find the threshold of class index 0 that yields equal number of true
    positive samples of each class.
    Currently only available for binary classes.

    The threshold is the median of the values returned by
    get_predict_proba() over all cross-validation folds.

    Params
    ======
    cfg: config module
    featdata: feature data computed using compute_features()

    Returns
    =======
    The threshold value; 0.5 if only a single prediction is available.

    Raises
    ======
    ValueError: unknown classifier name.
    NotImplementedError: unsupported cross-validation type.
    """
    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()
    if n_jobs > 1:
        logger.info('balance_tpr(): Using %d cores' % n_jobs)

    # Init a classifier
    # NOTE: this must be the classifier *name*; comparing the settings dict
    # against the names below could never match.
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        params = cfg.CLASSIFIER['GB']
        cls = GradientBoostingClassifier(
            loss='deviance', learning_rate=params['learning_rate'],
            n_estimators=params['trees'], subsample=1.0,
            max_depth=params['depth'], random_state=params['seed'],
            max_features='sqrt', verbose=0, warm_start=False, presort='auto')
    elif selected_classifier == 'XGB':
        params = cfg.CLASSIFIER['XGB']
        cls = XGBClassifier(
            loss='deviance', learning_rate=params['learning_rate'],
            n_estimators=params['trees'], subsample=1.0,
            max_depth=params['depth'], random_state=params['seed'],
            max_features='sqrt', verbose=0, warm_start=False, presort='auto')
    elif selected_classifier == 'RF':
        params = cfg.CLASSIFIER['RF']
        cls = RandomForestClassifier(
            n_estimators=params['trees'], max_features='auto',
            max_depth=params['depth'], n_jobs=cfg.N_JOBS,
            random_state=params['seed'], oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER['rLDA']['r_coeff'])
    else:
        logger.error('Unknown classifier type %s' % selected_classifier)
        raise ValueError('Unknown classifier type %s' % selected_classifier)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    # The PSD settings live under cfg.FEATURES, as in train_decoder().
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen

    # Choose CV type
    ntrials, nsamples, fsize = X_data.shape
    selected_CV = cfg.CV_PERFORM['selected']
    if selected_CV == 'LeaveOneOut':
        logger.info_green('\n%d-fold leave-one-out cross-validation' % ntrials)
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif selected_CV == 'StratifiedShuffleSplit':
        cv_params = cfg.CV_PERFORM[selected_CV]
        logger.info_green(
            '\n%d-fold stratified cross-validation with test set ratio %.2f' %
            (cv_params['folds'], cv_params['test_ratio']))
        # 'seed' is the key cross_validate() reads for the same settings.
        if SKLEARN_OLD:
            cv = StratifiedShuffleSplit(
                Y_data[:, 0], cv_params['folds'],
                test_size=cv_params['test_ratio'],
                random_state=cv_params['seed'])
        else:
            cv = StratifiedShuffleSplit(
                n_splits=cv_params['folds'],
                test_size=cv_params['test_ratio'],
                random_state=cv_params['seed'])
    else:
        logger.error('%s is not supported yet. Sorry.' % selected_CV)
        raise NotImplementedError(
            '%s is not supported yet. Sorry.' % selected_CV)
    logger.info('%d trials, %d samples per trial, %d feature dimension' %
                (ntrials, nsamples, fsize))

    # For classifier itself, single core is usually faster
    cls.n_jobs = 1
    Y_preds = []

    if SKLEARN_OLD:
        splits = cv
    else:
        splits = cv.split(X_data, Y_data[:, 0])

    # Create the pool only after all validation passed so that an early
    # error cannot leave worker processes behind.
    pool = mp.Pool(n_jobs) if n_jobs > 1 else None
    try:
        results = []
        for cnum, (train, test) in enumerate(splits):
            X_train = np.concatenate(X_data[train])
            X_test = np.concatenate(X_data[test])
            Y_train = np.concatenate(Y_data[train])
            Y_test = np.concatenate(Y_data[test])
            if pool is not None:
                results.append(
                    pool.apply_async(
                        get_predict_proba,
                        [cls, X_train, Y_train, X_test, Y_test, cnum + 1]))
            else:
                Y_preds.append(
                    get_predict_proba(cls, X_train, Y_train, X_test, Y_test,
                                      cnum + 1))

        # Aggregate predictions
        if pool is not None:
            pool.close()
            pool.join()
            for r in results:
                Y_preds.append(r.get())
    finally:
        if pool is not None:
            # no-op after a clean close()/join(); cleans up on error
            pool.terminate()
    Y_preds = np.concatenate(Y_preds, axis=0)

    # Find threshold for class index 0: the median of the sorted values
    Y_preds = sorted(Y_preds)
    mid_idx = int(len(Y_preds) / 2)
    if len(Y_preds) == 1:
        return 0.5  # should not reach here in normal conditions
    elif len(Y_preds) % 2 == 0:
        thres = Y_preds[mid_idx - 1] + (
            Y_preds[mid_idx] - Y_preds[mid_idx - 1]) / 2
    else:
        thres = Y_preds[mid_idx]
    return thres
def stream_player(server_name, fif_file, chunk_size, auto_restart=True,
                  wait_start=True, repeat=float('inf'),
                  high_resolution=False, trigger_file=None):
    """
    Replay a fif file as an LSL stream.

    Input
    =====
    server_name: LSL server name.
    fif_file: fif file to replay.
    chunk_size: number of samples to send at once (usually 16-32 is good enough).
    auto_restart: play from beginning again after reaching the end.
        If False, wait for the user to press Enter before restarting.
    wait_start: wait for user to start in the beginning.
    repeat: number of loops to play (default: play forever).
    high_resolution: use perf_counter() instead of sleep() for higher time
        resolution but uses much more cpu due to polling.
    trigger_file: used to convert event numbers into event strings for
        readability.

    Raises
    ======
    RuntimeError: the fif file could not be loaded.

    Note: Run pycnbi.set_log_level('DEBUG') to print out the relative time
    stamps since started.
    """
    raw, events = pu.load_raw(fif_file)
    # Check before touching raw; otherwise a failed load surfaces as an
    # AttributeError on None instead of this error.
    if raw is None:
        raise RuntimeError('Error while loading %s' % fif_file)
    sfreq = raw.info['sfreq']  # sampling frequency
    n_channels = len(raw.ch_names)  # number of channels
    if trigger_file is not None:
        tdef = trigger_def(trigger_file)
    try:
        event_ch = raw.ch_names.index('TRIGGER')
    except ValueError:
        event_ch = None

    logger.info_green('Successfully loaded %s' % fif_file)
    logger.info('Server name: %s' % server_name)
    logger.info('Sampling frequency %.3f Hz' % sfreq)
    logger.info('Number of channels : %d' % n_channels)
    logger.info('Chunk size : %d' % chunk_size)
    for i, ch in enumerate(raw.ch_names):
        logger.info('%d %s' % (i, ch))
    logger.info('Trigger channel : %s' % event_ch)

    # set server information
    sinfo = pylsl.StreamInfo(server_name, channel_count=n_channels,
                             channel_format='float32', nominal_srate=sfreq,
                             type='EEG', source_id=server_name)
    desc = sinfo.desc()
    channel_desc = desc.append_child("channels")
    for ch in raw.ch_names:
        channel_desc.append_child('channel').append_child_value('label', str(ch))\
            .append_child_value('type', 'EEG').append_child_value('unit', 'microvolts')
    desc.append_child('amplifier').append_child('settings').append_child_value(
        'is_slave', 'false')
    desc.append_child('acquisition').append_child_value(
        'manufacturer', 'PyCNBI').append_child_value('serial_number', 'N/A')
    outlet = pylsl.StreamOutlet(sinfo, chunk_size=chunk_size)

    if wait_start:
        input('Press Enter to start streaming.')
    logger.info('Streaming started')

    # high-resolution mode polls perf_counter(); otherwise wall clock + sleep
    clock = time.perf_counter if high_resolution else time.time

    idx_chunk = 0
    t_chunk = chunk_size / sfreq  # duration of one chunk in seconds
    finished = False
    t_start = clock()

    # start streaming
    # Count completed loops from 0 so that exactly `repeat` loops are played
    # (starting at 1 played one loop too few, and none for repeat=1).
    played = 0
    while played < repeat:
        idx_current = idx_chunk * chunk_size
        chunk = raw._data[:, idx_current:idx_current + chunk_size]
        data = chunk.transpose().tolist()
        if idx_current >= raw._data.shape[1] - chunk_size:
            finished = True
        if high_resolution:
            # if a resolution over 2 KHz is needed
            t_sleep_until = t_start + idx_chunk * t_chunk
            while time.perf_counter() < t_sleep_until:
                pass
        else:
            # time.sleep() can have 500 us resolution using the tweak tool provided.
            t_wait = t_start + idx_chunk * t_chunk - time.time()
            if t_wait > 0.001:
                time.sleep(t_wait)
        outlet.push_chunk(data)
        logger.debug('[%8.3fs] sent %d samples (LSL %8.3f)' %
                     (time.perf_counter(), len(data), pylsl.local_clock()))
        if event_ch is not None:
            # distinct non-zero values of the trigger channel in this chunk
            event_values = set(chunk[event_ch]) - {0}
            if len(event_values) > 0:
                if trigger_file is None:
                    logger.info('Events: %s' % event_values)
                else:
                    for event in event_values:
                        if event in tdef.by_value:
                            logger.info('Events: %s (%s)' %
                                        (event, tdef.by_value[event]))
                        else:
                            logger.info('Events: %s (Undefined event)' % event)
        idx_chunk += 1

        if finished:
            if auto_restart is False:
                input(
                    'Reached the end of data. Press Enter to restart or Ctrl+C to stop.'
                )
            else:
                logger.info('Reached the end of data. Restarting.')
            idx_chunk = 0
            finished = False
            t_start = clock()
            played += 1
def acquire(self, blocking=True):
    """
    Reads data into buffer. It is a blocking function as default.

    Pulls a chunk from the first inlet, moves the trigger channel to
    column 0 (or inserts a zero column if there is no trigger channel),
    appends the result to the internal buffer and returns it.

    Params:
        blocking: if False, return immediately with empty values when no
            sample is available. Otherwise keep polling until data arrive
            or the watchdog reaches 5 seconds.

    Returns:
        data [samples x channels], timestamps [samples]
        On timeout, or when non-blocking and nothing is available:
        an empty (0 x len(self.ch_list)) array and an empty list.
    """
    # The first chunk ever received is used to measure the offset between
    # the server timestamps and the local LSL clock.
    timestamp_offset = len(self.timestamps[0]) == 0

    self.watchdog.reset()
    tslist = []
    chunk = None
    while len(tslist) == 0:
        if self.watchdog.sec() >= 5:
            logger.warning(
                'Timeout occurred while acquiring data. Amp driver bug?')
            # give up and return empty values to avoid deadlock
            return np.empty((0, len(self.ch_list))), []
        # chunk = [frames]x[ch], tslist = [frames]
        chunk, tslist = self.inlets[0].pull_chunk(
            max_samples=self.stream_bufsize)
        if len(tslist) > 0:
            break
        if not blocking:
            return np.empty((0, len(self.ch_list))), []
        time.sleep(0.0005)
    if timestamp_offset:
        lsl_clock = pylsl.local_clock()
    data = np.array(chunk)

    # BioSemi has pull-up resistor instead of pull-down
    if self.amp_name == 'BioSemi' and self._lsl_tr_channel is not None:
        datatype = data.dtype
        data[:, self._lsl_tr_channel] = (np.bitwise_and(
            255, data[:, self._lsl_tr_channel].astype(int)) -
                                         1).astype(datatype)

    # multiply values (to change unit)
    if self.multiplier != 1:
        data[:, self._lsl_eeg_channels] *= self.multiplier

    if self._lsl_tr_channel is not None:
        # move trigger channel to 0 and add back to the buffer
        first_col = data[:, self._lsl_tr_channel].reshape(-1, 1)
    else:
        # add an empty channel with zeros to channel 0
        first_col = np.zeros((data.shape[0], 1))
    data = np.concatenate((first_col, data[:, self._lsl_eeg_channels]),
                          axis=1)

    # add data to buffer
    chunk = data.tolist()
    self.buffers[0].extend(chunk)
    self.timestamps[0].extend(tslist)
    if self.bufsize > 0 and len(self.timestamps[0]) > self.bufsize:
        self.buffers[0] = self.buffers[0][-self.bufsize:]
        self.timestamps[0] = self.timestamps[0][-self.bufsize:]

    if timestamp_offset:
        # Only stream 0 is filled above, so read its latest timestamp;
        # indexing the last stream would fail or be wrong with several
        # streams.
        server_ts = self.timestamps[0][-1]
        logger.info('LSL timestamp = %s' % lsl_clock)
        logger.info('Server timestamp = %s' % server_ts)
        self.lsl_time_offset = server_ts - lsl_clock
        logger.info('Offset = %.3f ' % (self.lsl_time_offset))
        if abs(self.lsl_time_offset) > 0.1:
            logger.warning('LSL server has a high timestamp offset.')
        else:
            logger.info_green('LSL time server synchronized')

    # TODO: test the merging of multiple streams. For every additional
    # synchronized inlet i >= 1, pull len(tslist) samples and extend
    # self.buffers[i] / self.timestamps[i], trimming to self.bufsize.

    # data= array[samples, channels], tslist=[samples]
    return (data, tslist)
def get_psd_feature(epochs_train, window, psdparam, picks=None,
                    preprocess=None, n_jobs=1):
    """
    Wrapper for get_psd() adding meta information.

    Input
    =====
    epochs_train: mne.Epochs object or list of mne.Epochs object.
    window: [t_start, t_end]. Time window range for computing PSD.
            A list of such ranges when epochs_train is a list.
    psdparam: {fmin:float, fmax:float, wlen:float, wstep:int, decim:int}.
              fmin, fmax in Hz, wlen in seconds, wstep in number of samples.
              NOTE: modified in place. A missing or None 'decim' is set
              to 1 and, for a single window, a None 'wlen' is set to the
              window length.
    picks: Channels to compute features from.
    preprocess, n_jobs: passed through to get_psd().

    Output
    ======
    dict object containing computed features with keys X_data, Y_data,
    wlen, w_frames, psde, times and decim. For list input, wlen, w_frames
    and times are lists with one entry per window.

    Raises
    ======
    NotImplementedError: psdparam is a list (one estimator per epoch).
    """
    multi_window = isinstance(window[0], list)
    if multi_window:
        sfreq = epochs_train[0].info['sfreq']
        wlen = []
        w_frames = []
        # multiple PSD estimators, defined for each epoch
        if isinstance(psdparam, list):
            # TODO: implement multi-window PSD for each epoch
            logger.error('Multiple psd function not implemented yet.')
            raise NotImplementedError
        # same PSD estimator for all epochs
        for win in window:
            if psdparam['wlen'] is None:
                wl = win[1] - win[0]
            else:
                wl = psdparam['wlen']
            assert wl > 0
            wlen.append(wl)
            w_frames.append(int(round(sfreq * wl)))
    else:
        sfreq = epochs_train.info['sfreq']
        wlen = window[1] - window[0]
        if psdparam['wlen'] is None:
            psdparam['wlen'] = wlen
        # window length in number of samples(frames)
        w_frames = int(round(sfreq * psdparam['wlen']))
    if 'decim' not in psdparam or psdparam['decim'] is None:
        psdparam['decim'] = 1

    # the estimator works on decimated data
    psde_sfreq = sfreq / psdparam['decim']
    psde = mne.decoding.PSDEstimator(
        sfreq=psde_sfreq, fmin=psdparam['fmin'], fmax=psdparam['fmax'],
        bandwidth=None, adaptive=False, low_bias=True, n_jobs=1,
        normalization='length', verbose='WARNING')

    logger.info_green('PSD computation')
    if isinstance(epochs_train, list):
        X_all = []
        for i, ep in enumerate(epochs_train):
            X, Y_data = get_psd(ep, psde, w_frames[i], psdparam['wstep'],
                                picks, n_jobs=n_jobs, preprocess=preprocess,
                                decim=psdparam['decim'])
            X_all.append(X)
        # concatenate along the feature dimension
        # feature index order: window block x channel block x frequency block
        # feature vector = [window1, window2, ...]
        # where windowX = [channel1, channel2, ...]
        # where channelX = [freq1, freq2, ...]
        X_data = np.concatenate(X_all, axis=2)
    else:
        # feature index order: channel block x frequency block
        # feature vector = [channel1, channel2, ...]
        # where channelX = [freq1, freq2, ...]
        X_data, Y_data = get_psd(epochs_train, psde, w_frames,
                                 psdparam['wstep'], picks, n_jobs=n_jobs,
                                 preprocess=preprocess,
                                 decim=psdparam['decim'])

    # assign relative timestamps for each feature. time reference is the
    # leading edge of a window.
    if isinstance(epochs_train, list):
        # One timestamp array per epochs object, using the same formula as
        # the single-window case. The single-object formula cannot be used
        # here: lists have no get_data() and w_frames / window are lists.
        # NOTE(review): assumes window[i] belongs to epochs_train[i].
        t_features = []
        for i, ep in enumerate(epochs_train):
            w_starts = np.arange(0, ep.get_data().shape[2] - w_frames[i],
                                 psdparam['wstep'])
            t_features.append(w_starts / sfreq + wlen[i] + window[i][0])
    else:
        w_starts = np.arange(0,
                             epochs_train.get_data().shape[2] - w_frames,
                             psdparam['wstep'])
        t_features = w_starts / sfreq + psdparam['wlen'] + window[0]

    return dict(X_data=X_data, Y_data=Y_data, wlen=wlen, w_frames=w_frames,
                psde=psde, times=t_features, decim=psdparam['decim'])
def __init__(self, classifier=None, buffer_size=1.0, fake=False,
             amp_serial=None, amp_name=None):
    """
    Set up the decoder from a saved classifier file, or as a fake decoder.

    Params
    ------
    classifier: classifier file
    buffer_size: length of the signal buffer in seconds
    fake: if True, no model is loaded and no stream receiver is created;
          only fixed left/right labels are set.
    amp_serial: passed to StreamReceiver
    amp_name: passed to StreamReceiver
    """
    self.classifier = classifier
    self.buffer_sec = buffer_size
    self.fake = fake
    self.amp_serial = amp_serial
    self.amp_name = amp_name

    if not (self.fake == False):
        # Fake left-right decoder
        self.psd_shape = None
        self.psd_size = None
        # TODO: parameterize directions using fake_dirs
        self.labels = [11, 9]
        self.label_names = ['LEFT_GO', 'RIGHT_GO']
        return

    model = qc.load_obj(self.classifier)
    if model is None:
        logger.error('Classifier model is None.')
        raise ValueError

    # Copy the model contents onto the instance
    self.cls = model['cls']
    self.psde = model['psde']
    self.labels = list(self.cls.classes_)
    self.label_names = [model['classes'][label] for label in self.labels]
    self.spatial = model['spatial']
    self.spectral = model['spectral']
    self.notch = model['notch']
    self.w_seconds = model['w_seconds']
    self.w_frames = model['w_frames']
    self.wstep = model['wstep']
    self.sfreq = model['sfreq']
    # models without a decimation factor get decim=1
    if 'decim' not in model:
        model['decim'] = 1
    self.decim = model['decim']

    # The stored window length in frames must match sfreq * seconds
    expected_frames = int(round(self.sfreq * self.w_seconds))
    if expected_frames != self.w_frames:
        logger.error('sfreq * w_sec %d != w_frames %d' %
                     (int(round(self.sfreq * self.w_seconds)), self.w_frames))
        raise RuntimeError

    self.multiplier = model['multiplier'] if 'multiplier' in model else 1

    # Stream Receiver
    self.sr = StreamReceiver(window_size=self.w_seconds,
                             amp_name=self.amp_name,
                             amp_serial=self.amp_serial)
    if self.sfreq != self.sr.sample_rate:
        logger.error('Amplifier sampling rate (%.3f) != model sampling rate (%.3f). Stop.' % (self.sr.sample_rate, self.sfreq))
        raise RuntimeError

    # Map channel indices based on channel names of the streaming server
    self.spatial_ch = model['spatial_ch']
    self.spectral_ch = model['spectral_ch']
    self.notch_ch = model['notch_ch']
    # 'ref_ch' of the model is not supported yet
    self.ch_names = self.sr.get_channel_names()
    mc = model['ch_names']

    self.picks = [self.ch_names.index(mc[p]) for p in model['picks']]
    if self.spatial_ch is not None:
        self.spatial_ch = [self.ch_names.index(mc[p])
                           for p in model['spatial_ch']]
    if self.spectral_ch is not None:
        self.spectral_ch = [self.ch_names.index(mc[p])
                            for p in model['spectral_ch']]
    if self.notch_ch is not None:
        self.notch_ch = [self.ch_names.index(mc[p])
                         for p in model['notch_ch']]

    # Timestamp buffer (the PSD buffer is currently not allocated here)
    self.ts_buffer = []

    logger.info_green('Loaded classifier %s (sfreq=%.3f, decim=%d)' %
                      (' vs '.join(self.label_names), self.sfreq, self.decim))
def log_decoding(decoder, logfile, amp_name=None, amp_serial=None,
                 pklfile=True, matfile=False, autostop=False,
                 prob_smooth=False):
    """
    Decode online and write results with event timestamps

    input
    -----
    decoder: Decoder or DecoderDaemon class object.
    logfile: File name to contain the result in Python pickle format.
    amp_name: LSL server name (if known).
    amp_serial: LSL server serial number (if known).
    pklfile: Export results to Python pickle format.
    matfile: Export results to Matlab .mat file if True.
    autostop: Automatically finish when no more data is received.
    prob_smooth: Use smoothed probability values according to decoder's
        smoothing parameter.

    The saved dict has the keys probs, prob_times, event_times,
    event_values and labels. All times are relative to the first decoded
    probability; events before it are dropped. Nothing is saved if no
    probability was decoded.
    """
    import cv2
    # `import scipy` alone does not guarantee that the io submodule is
    # loaded; import it explicitly for savemat().
    import scipy.io

    # run event acquisition process in the background
    state = mp.Value('i', 1)
    event_queue = mp.Queue()
    proc = mp.Process(target=log_decoding_helper,
                      args=[state, event_queue, amp_name, amp_serial,
                            autostop])
    proc.start()
    logger.info_green('Spawned event acquisition process.')

    # init variables and choose decoding function
    labels = decoder.get_label_names()
    probs = []
    prob_times = []
    if prob_smooth:
        decode_fn = decoder.get_prob_smooth_unread
    else:
        decode_fn = decoder.get_prob_unread

    # simple controller UI
    cv2.namedWindow("Decoding", cv2.WINDOW_AUTOSIZE)
    cv2.moveWindow("Decoding", 1400, 50)
    img = np.zeros([100, 400, 3], np.uint8)
    cv2.putText(img, 'Press any key to start', (20, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)
    cv2.waitKeyEx()
    img *= 0
    cv2.putText(img, 'Press ESC to stop', (40, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)

    key = 0
    started = False
    tm_watchdog = qc.Timer(autoreset=True)
    tm_cls = qc.Timer()  # time since the last decoded probability
    while key != 27:  # 27 = ESC
        prob, prob_time = decode_fn(True)
        t_lsl = pylsl.local_clock()
        key = cv2.waitKeyEx(1)
        if prob is None:
            # watch dog
            if tm_cls.sec() > 5:
                if autostop and started:
                    logger.info('No more streaming data. Finishing.')
                    break
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue
        probs.append(prob)
        prob_times.append(prob_time)
        txt = '[%.3f] ' % prob_time
        txt += ', '.join(['%s: %.2f' % (l, p) for l, p in zip(labels, prob)])
        txt += ' (%d ms, LSL Diff = %.3f)' % (tm_cls.msec(),
                                              (t_lsl - prob_time))
        logger.info(txt)
        if not started:
            started = True
        tm_cls.reset()

    # finish up processes
    cv2.destroyAllWindows()
    logger.info('Cleaning up event acquisition process.')
    state.value = 0
    decoder.stop()
    event_times, event_values = event_queue.get()
    proc.join()

    # save values
    if len(prob_times) == 0:
        # Nothing to align or save. Stop here instead of dropping into a
        # debugger and then failing on prob_times[0].
        logger.error('No decoding result. Please debug.')
        return
    t_start = prob_times[0]
    probs = np.vstack(probs)
    event_times = np.array(event_times)
    event_values = np.array(event_values)
    # Drop events that happened before decoding started. The same mask must
    # be applied to the values so both arrays stay aligned.
    keep = event_times >= t_start
    event_times = event_times[keep] - t_start
    if event_values.shape[0] == keep.shape[0]:
        event_values = event_values[keep]
    prob_times = np.array(prob_times) - t_start
    data = dict(probs=probs, prob_times=prob_times, event_times=event_times,
                event_values=event_values, labels=labels)
    if pklfile:
        qc.save_obj(logfile, data)
        logger.info('Saved to %s' % logfile)
    if matfile:
        pp = qc.parse_path(logfile)
        matout = '%s/%s.mat' % (pp.dir, pp.name)
        scipy.io.savemat(matout, data)
        logger.info('Saved to %s' % matout)