def run(cfg, state=None, queue=None, interactive=False, cv_file=None,
        feat_file=None, logger=logger):
    """
    Compute features, then optionally cross-validate and/or train a decoder.

    Params
    ------
    cfg: configuration object providing TRIGGER_FILE, CV_PERFORM and
         EXPORT_CLS. cfg.tdef is set by this function.
    state: multiprocessing.Value('i') used as a run flag. Whenever its value
           is 0 at one of the checkpoints the function calls sys.exit(-1).
           It is set to 0 when the function finishes. When None (default) a
           fresh Value initialised to 1 is created for this call.
    queue: optional queue stdout is redirected to.
    interactive: not used by this function.
    cv_file: forwarded to cross_validate().
    feat_file: forwarded to train_decoder().
    logger: logger passed to redirect_stdout_to_queue().
    """
    # A default of mp.Value('i', 1) would be created once at definition time
    # and shared by every call. Because this function sets it to 0 on exit,
    # a second call relying on that default would abort immediately.
    if state is None:
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # add tdef object
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    # Extract features
    if not state.value:
        sys.exit(-1)
    featdata = features.compute_features(cfg)

    # Perform cross validation
    if not state.value:
        sys.exit(-1)
    if cfg.CV_PERFORM[cfg.CV_PERFORM['selected']] is not None:
        cross_validate(cfg, featdata, cv_file=cv_file)

    # Train a decoder (only when EXPORT_CLS is exactly True)
    if not state.value:
        sys.exit(-1)
    if cfg.EXPORT_CLS is True:
        train_decoder(cfg, featdata, feat_file=feat_file)

    with state.get_lock():
        state.value = 0
def merge_events(trigger_file, events, eeg_in, eeg_out):
    """
    Relabel (merge) the events of a raw recording and save the result.

    Params
    ------
    trigger_file: trigger definition file passed to trigger_def().
    events: dict {target_event_name: [source_event_name, ...]}. Every event
            whose code matches one of the source names is relabeled with the
            code of target_event_name. Names are resolved with tdef.by_key.
    eeg_in: input raw file, loaded with pu.load_raw().
    eeg_out: output file; overwritten if it exists.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(eeg_in)

    def _print_counts():
        # Print "<name>: <count>" per event code. Codes missing from the
        # definition file are printed numerically instead of raising KeyError.
        codes, counts = np.unique(eve[:, 2], return_counts=True)
        for code, count in zip(codes, counts):
            name = tdef.by_value[code] if code in tdef.by_value else code
            print('%s: %d' % (name, count))

    print('\nEvents before merging')
    _print_counts()

    for key, ev_src in events.items():
        ev_out = tdef.by_key[key]
        idx = [np.where(eve[:, 2] == tdef.by_key[e])[0] for e in ev_src]
        # np.concatenate() raises on an empty list; no sources means nothing
        # to relabel.
        if idx:
            eve[np.concatenate(idx), 2] = ev_out

    # sanity check: no two consecutive events may share a sample index, and
    # no code may exceed the largest code known to the definition file.
    dups = np.where(0 == np.diff(eve[:, 0]))[0]
    assert len(dups) == 0, 'Duplicate event positions found.'
    assert max(eve[:, 2]) <= max(tdef.by_value.keys()), \
        'Event code larger than any code in the definition file.'

    # reset trigger channel (channel 0) and write the merged events into it
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(eeg_out, overwrite=True)

    print('\nResulting events')
    _print_counts()
def run(cfg, state=None, queue=None):
    """
    Compute EEG microstate maps from a BrainVision recording and save them.

    The recording at cfg.DATA_PATH is restricted to a fixed list of 19
    channels, average-referenced, band-pass filtered (1-30 Hz) and segmented
    into 4 microstates. The resulting maps are written to
    cfg.OUT_MICROSTATES_FILE as space-delimited text.

    Params
    ------
    cfg: configuration object providing TRIGGER_FILE, DATA_PATH and
         OUT_MICROSTATES_FILE. cfg.tdef is set by this function.
    state: multiprocessing.Value('i') run flag; if its value is 0 the
           function calls sys.exit(-1). When None (default) a fresh Value
           initialised to 1 is created for this call.
    queue: optional queue stdout is redirected to.
    """
    # Avoid a Value created once at definition time and shared by all calls.
    if state is None:
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # add tdef object
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    if not state.value:
        sys.exit(-1)

    raw = mne.io.read_raw_brainvision(cfg.DATA_PATH, preload=True)
    outfile = cfg.OUT_MICROSTATES_FILE
    ch_names = [
        'P3', 'C3', 'F3', 'Fz', 'F4', 'C4', 'P4', 'Cz', 'Pz', 'Fp1', 'Fp2',
        'T3', 'T5', 'O1', 'O2', 'F7', 'F8', 'T6', 'T4'
    ]
    raw.pick_channels(ch_names)
    raw.set_montage('standard_1005')
    raw.set_eeg_reference('average')
    raw.filter(1, 30)

    maps, segmentation = microstates.segment(raw.get_data(), n_states=4,
                                             max_n_peaks=10000000,
                                             max_iter=5000, normalize=True)
    np.savetxt(outfile, maps, delimiter=" ")
def preprocess(loadedraw, events,
               APPLY_CAR,
               l_freq,
               h_freq,
               filter_method,
               tmin,
               tmax,
               tlow,
               thigh,
               n_jobs,
               picks_feat,
               baselineRange,
               verbose=False,
               cv_container=None):
    """
    Copy a raw recording, optionally apply CAR and temporal filtering, and epoch it.

    Params
    ------
    loadedraw: raw object; it is copied, never modified. Channel 0 is treated
               as the trigger channel.
    events: events array passed to mne.Epochs.
    APPLY_CAR: apply a Common Average Reference over channels 1..N.
    l_freq, h_freq: band edges passed to pu.butter_bandpass (and h_freq to
               raw.filter for the 'NC' method).
    filter_method: 'NC' (non-causal raw.filter) or 'LFILT' (causal lfilter,
               channel by channel); any other value skips filtering.
    tmin, tmax, tlow: currently not used.
    thigh: end of the epoch window in seconds (epochs start at 0).
    n_jobs, picks_feat, baselineRange, verbose: forwarded to MNE.
    cv_container: the 'NC' filter is only applied when this is None.

    Returns
    -------
    (tdef, sfreq, event_id, b, a, zi, t_lower, t_upper, epochs, total_wframes)
    """
    raw = loadedraw.copy()

    # Spatial filter - Common Average Reference (CAR); channel 0 (trigger)
    # is excluded.
    if APPLY_CAR:
        raw._data[1:] = raw._data[1:] - np.mean(raw._data[1:], axis=0)

    # Properties initialization
    tdef = trigger_def('triggerdef_errp.ini')
    sfreq = raw.info['sfreq']
    event_id = dict(correct=tdef.by_key['FEEDBACK_CORRECT'],
                    wrong=tdef.by_key['FEEDBACK_WRONG'])

    # Bandpass filter coefficients; raw._data.shape[0] - 1 because channel 0
    # is the trigger channel.
    b, a, zi = pu.butter_bandpass(h_freq, l_freq, sfreq, raw._data.shape[0] - 1)

    # Strings must be compared with ==; `is` tests object identity and only
    # matches by accident of string interning.
    if filter_method == 'NC' and cv_container is None:
        # NOTE(review): the low edge is hard-coded to 2 Hz rather than
        # l_freq -- confirm this is intended.
        raw.filter(l_freq=2, filter_length='10s', h_freq=h_freq,
                   n_jobs=n_jobs, picks=picks_feat, method='fft',
                   iir_params=None)
    if filter_method == 'LFILT':
        # Start at 1 to skip the trigger channel; zi carries the per-channel
        # filter state.
        for x in range(1, raw._data.shape[0]):
            raw._data[x, :], zi[:, x - 1] = lfilter(b, a, raw._data[x, :], -1,
                                                    zi[:, x - 1])

    # Epoching and baselining: the window is [0, thigh] seconds.
    t_lower = 0
    t_upper = thigh
    epochs = mne.Epochs(raw, events=events, event_id=event_id, tmin=t_lower,
                        tmax=t_upper, baseline=baselineRange,
                        picks=picks_feat, preload=True, proj=False,
                        verbose=verbose)
    total_wframes = epochs.get_data().shape[2]
    print('Preprocess: Epoching done')

    return tdef, sfreq, event_id, b, a, zi, t_lower, t_upper, epochs, total_wframes
def __init__(self, classifier=None, buffer_size=1.0, fake=False, amp_serial=None,
             amp_name=None, fake_dirs=None, parallel=None, alpha_new=None):
    """
    Params
    ------
    classifier: file name of the classifier
    buffer_size: buffer window size in seconds
    fake:
        False: Connect to an amplifier LSL server and decode
        True: Create a mock decoder (fake probabilities biased to 1.0)
    amp_serial, amp_name: amplifier identifiers, stored on the instance.
    fake_dirs: list of event names from 'triggerdef_16.ini' used as the fake
        decoder's labels. Required (as a list) when fake is enabled.
    parallel: dict(period, stride, num_strides)
        period: Decoding period length for a single decoder in seconds.
        stride: Time step between decoders in seconds.
        num_strides: Number of decoders to run in parallel.
    alpha_new: exponential smoothing factor, real value in [0, 1]; default 1.
        p_new = p_new * alpha_new + p_old * (1 - alpha_new)

    Example: If the decoder runs 32ms per cycle, we can set
             period=0.04, stride=0.01, num_strides=4
             to achieve 100 Hz decoding.

    Raises
    ------
    ValueError: alpha_new is outside [0, 1].
    IOError: the classifier file could not be loaded.
    RuntimeError: fake decoding was requested but fake_dirs is not a list.
    """
    self.classifier = classifier
    self.buffer_sec = buffer_size
    self.startmsg = 'Decoder daemon started.'
    self.stopmsg = 'Decoder daemon stopped.'
    self.fake = fake
    self.amp_serial = amp_serial
    self.amp_name = amp_name
    self.parallel = parallel

    if alpha_new is None:
        alpha_new = 1
    if not 0 <= alpha_new <= 1:
        raise ValueError('alpha_new must be a real number between 0 and 1.')
    self.alpha_new = alpha_new
    self.alpha_old = 1 - alpha_new

    # `in` compares with ==, so 0/False/None all select the real decoder,
    # exactly like the former `fake == False or fake is None`.
    if fake in (False, None):
        self.model = qc.load_obj(self.classifier)
        if self.model is None:
            # Report the file that failed to load (self.model is None here).
            raise IOError('Error loading %s' % self.classifier)
        self.labels = self.model['cls'].classes_
        self.label_names = [self.model['classes'][k] for k in self.labels]
    else:
        # create a fake decoder with LEFT/RIGHT classes
        self.model = None
        tdef = trigger_def('triggerdef_16.ini')
        if not isinstance(fake_dirs, list):
            raise RuntimeError('Decoder(): wrong argument type of fake_dirs: %s.' % type(fake_dirs))
        self.labels = [tdef.by_key[t] for t in fake_dirs]
        self.label_names = [tdef.by_value[v] for v in self.labels]
        self.startmsg = '** WARNING: FAKE ' + self.startmsg
        self.stopmsg = 'FAKE ' + self.stopmsg

    self.psdlock = mp.Lock()
    self.reset()
    self.start()
def on_new_tdef_file(self, key, trigger_file):
    """
    Load the trigger definitions in trigger_file into self.tdef and, if any
    events are present, rebuild the event QComboBox layout.

    `key` is accepted for the signal signature but not used.
    """
    new_tdef = trigger_def(trigger_file)
    self.tdef = new_tdef
    if not self.events:
        return
    self.on_update_VBoxLayout()
class Basic: """ Contains the basic parameters for the training modality of Motor Imagery protocol """ '''""""""""""""""""""""""""""" DATA """""""""""""""""""""""""""''' # read all data files from this directory for training DATADIR = r'C:\LSL\pycnbi_local\z2\records\fif' # which trigger set? tdef = trigger_def('triggerdef_16.ini') TRIGGER_DEF = {tdef.LEFT_GO, tdef.RIGHT_GO} # epoch ranges in seconds relative to onset EPOCH = [0.5, 4.5] '''""""""""""""""""""""""""""" CHANNEL SPECIFICATION CHANNEL_PICKS Pick a subset of channels for PSD. Note that Python uses zero-based indexing. However, for fif files saved using PyCNBI library, index 0 is the trigger channel and data channels start from index 1. (to be consistent with MATLAB) None: Use all channels. Ignored if LOAD_PSD == True REF_CH_NEW: Re-reference to this set of channels, averaged if more than 1. REF_CH_OLD: Recover this channel which was used as reference channel. """""""""""""""""""""""""""''' #CHANNEL_PICKS = None CHANNEL_PICKS = [ 'Fp1', 'Fp2', 'Fz', 'F3', 'F4', 'F7', 'F8', 'Cz', 'C3', 'C4', 'P3', 'Pz', 'P4' ] #CHANNEL_PICKS = CAP['ANTNEURO_64_NO_PERIPHERAL'] EXCLUDES = ['M1', 'M2', 'EOG'] REF_CH = None #REF_CH = ['CPz', ['M1', 'M2']] '''""""""""""""""""""""""""""" FILTERS """""""""""""""""""""""""""''' # apply spatial filter immediately after loading data SP_FILTER = 'car' # None | 'car' | 'laplacian' # only consider the following channels while computing SP_CHANNELS = CHANNEL_PICKS # CHANNEL_PICKS # None | list # apply spectrial filter immediately after applying SP_FILTER # Can be either overlap-add FIR or forward-backward IIR via filtfilt # Value: None or [lfreq, hfreq] # if lfreq < hfreq: bandpass # if lfreq > hfreq: bandstop # if lfreq == None: highpass # if hfreq == None: lowpass #TP_FILTER = [0.6, 4.0] TP_FILTER = None NOTCH_FILTER = None # None or list of values
def on_new_tdef_file(self, key, trigger_file):
    """
    Clear the layout and rebuild four direction QComboBoxes from the event
    names of the newly selected trigger definition file.

    `key` is accepted for the signal signature but not used.
    """
    self.clear_hBoxLayout()
    definitions = trigger_def(trigger_file)

    # The GUI removes a real None entry once it is selected, so the string
    # 'None' is kept in the file and mapped back to None here.
    choices = []
    for event_name in definitions.by_name:
        if event_name == 'None':
            choices.append(None)
        else:
            choices.append(event_name)

    self.create_the_comboBoxes(self.chosen_value, choices, 4)
def merge_events(trigger_file, events, rawfile_in, rawfile_out):
    """
    Relabel (merge) the events of a raw recording and save the result.

    Params
    ------
    trigger_file: trigger definition file passed to trigger_def().
    events: dict {target_event_name: [source_event_name, ...]}. Every event
            whose code matches one of the source names is relabeled with the
            code of target_event_name. Names are resolved with tdef.by_name.
    rawfile_in: input raw file, loaded with pu.load_raw().
    rawfile_out: output file; overwritten if it exists.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(rawfile_in)

    def _log_counts(unknown_fmt):
        # Log one line per event code and return the codes that are missing
        # from the definition file (formatted with unknown_fmt).
        unknown = []
        codes, counts = np.unique(eve[:, 2], return_counts=True)
        for key, count in zip(codes, counts):
            if key in tdef.by_value:
                logger.info('%s: %d events' % (tdef.by_value[key], count))
            else:
                logger.info(unknown_fmt % (key, count))
                unknown.append(key)
        return unknown

    logger.info('=== Before merging ===')
    for key in _log_counts('%d: %d events'):
        logger.warning('Key %d was not found in the definition file.' % key)

    for key, ev_src in events.items():
        ev_out = tdef.by_name[key]
        idx = [np.where(eve[:, 2] == tdef.by_name[e])[0] for e in ev_src]
        # np.concatenate() raises on an empty list; no sources means nothing
        # to relabel.
        if idx:
            eve[np.concatenate(idx), 2] = ev_out

    # sanity check: no two consecutive events may share a sample index
    dups = np.where(0 == np.diff(eve[:, 0]))[0]
    assert len(dups) == 0, 'Duplicate event positions found.'

    # reset trigger channel (channel 0) and write the merged events into it
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    logger.info('=== After merging ===')
    _log_counts('%s: %d events')
def run(cfg, queue=None, interactive=False, cv_file=None, feat_file=None):
    """
    Compute features from cfg, then cross-validate when the selected CV mode
    is configured, and train/export a decoder when cfg.EXPORT_CLS is True.

    cfg.tdef is set from cfg.TRIGGER_FILE. `interactive` is not used.
    """
    redirect_stdout_to_queue(queue)
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    featdata = features.compute_features(cfg)

    selected_cv = cfg.CV_PERFORM[cfg.CV_PERFORM['selected']]
    if selected_cv is not None:
        cross_validate(cfg, featdata, cv_file=cv_file)

    export_decoder = cfg.EXPORT_CLS is True
    if export_decoder:
        train_decoder(cfg, featdata, feat_file=feat_file)
def _create_feedback_visual(cfg):
    """
    Create the feedback visual selected by cfg.FEEDBACK_TYPE ('BAR' or 'BODY').

    Raises
    ------
    ValueError: cfg.FEEDBACK_TYPE is not a supported value.
    """
    if cfg.FEEDBACK_TYPE == 'BAR':
        from pycnbi.protocols.viz_bars import BarVisual
        return BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS,
                         screen_size=cfg.SCREEN_SIZE)
    if cfg.FEEDBACK_TYPE == 'BODY':
        assert hasattr(cfg, 'IMAGE_PATH'), 'IMAGE_PATH is undefined in your config.'
        from pycnbi.protocols.viz_human import BodyVisual
        return BodyVisual(cfg.IMAGE_PATH, use_glass=cfg.GLASS_USE,
                          screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    # Fail early with a clear message instead of an undefined visual.
    raise ValueError('Unknown FEEDBACK_TYPE %r' % (cfg.FEEDBACK_TYPE,))


def _export_online_log(cfg, dir_truth, dir_detected):
    """
    Write ground truth/prediction pairs, accuracy and confusion matrix to a
    timestamped file in the classifier's directory and print the summary.
    """
    fdir, _, _ = qc.parse_path_list(cfg.CLS_MI)
    logfile = os.path.join(
        fdir, time.strftime("online-%Y%m%d-%H%M%S.txt", time.localtime()))
    with open(logfile, 'w') as fout:
        fout.write('Ground-truth,Prediction\n')
        for gt, dt in zip(dir_truth, dir_detected):
            fout.write('%s,%s\n' % (gt, dt))
        cfmat, acc = qc.confusion_matrix(dir_truth, dir_detected)
        fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
        fout.write(cfmat)
    print('Log exported to %s' % logfile)
    print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
    print(cfmat)


def config_run(cfg_module):
    """
    Run the online feedback protocol described by a config file.

    Loads the config, selects the amplifier (or a fake decoder), connects the
    trigger device (falling back to a mock trigger after user confirmation),
    starts the decoder daemon, runs the trials with visual feedback and
    finally writes a log with accuracy and confusion matrix next to the
    classifier file. The visual and the decoder are shut down even if a trial
    raises.

    Params
    ------
    cfg_module: path to the python config file.

    Raises
    ------
    IOError: cfg_module does not exist or is not a regular file.
    """
    if not (os.path.exists(cfg_module) and os.path.isfile(cfg_module)):
        raise IOError('%s cannot be loaded.' % os.path.realpath(cfg_module))
    cfg = load_cfg(cfg_module)

    if cfg.FAKE_CLS is None:
        # choose amp
        if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None:
            amp_name, amp_serial = pu.search_lsl(ignore_markers=True)
        else:
            amp_name = cfg.AMP_NAME
            amp_serial = cfg.AMP_SERIAL
        fake_dirs = None
    else:
        amp_name = None
        amp_serial = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = trigger_def(cfg.TRIGGER_DEF)
    if cfg.TRIGGER_DEVICE is None:
        input('\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.')
    trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        qc.print_c('\n** Error connecting to USB2LPT device. Use a mock trigger instead?', 'R')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # init classification
    decoder = BCIDecoderDaemon(cfg.CLS_MI, buffer_size=1.0,
                               fake=(cfg.FAKE_CLS is not None),
                               amp_name=amp_name, amp_serial=amp_serial,
                               fake_dirs=fake_dirs,
                               parallel=cfg.PARALLEL_DECODING,
                               alpha_new=cfg.PROB_ALPHA_NEW)
    try:
        # Decoder labels may be trigger codes (named via tdef) or values that
        # already appear in cfg.DIRECTIONS.
        dirdata = {d[1] for d in cfg.DIRECTIONS}
        labels = [x if x in dirdata else tdef.by_value[x]
                  for x in decoder.get_labels()]

        # map class labels to bar directions
        bar_def = {label: str(direction) for direction, label in cfg.DIRECTIONS}
        bar_dirs = [bar_def[l] for l in labels]
        if cfg.TRIALS_RANDOMIZE:
            dir_seq = bar_dirs * cfg.TRIALS_EACH
            random.shuffle(dir_seq)
        else:
            dir_seq = [d[0] for d in cfg.DIRECTIONS] * cfg.TRIALS_EACH
        num_trials = len(dir_seq)

        qc.print_c('Initializing decoder.', 'W')
        # `is 0` relied on small-int caching; compare by value.
        while decoder.is_running() == 0:
            time.sleep(0.01)

        visual = _create_feedback_visual(cfg)
        try:
            visual.put_text('Waiting to start')
            if cfg.LOG_PROBS:
                logdir = qc.parse_path_list(cfg.CLS_MI)[0]
                # join() works whether or not logdir ends with a separator.
                probs_logfile = os.path.join(
                    logdir,
                    time.strftime("probs-%Y%m%d-%H%M%S.txt", time.localtime()))
            else:
                probs_logfile = None
            feedback = Feedback(cfg, visual, tdef, trigger, probs_logfile)

            # start
            trial = 1
            # One entry per classification. With TRIALS_RETRY a trial can be
            # classified several times, so the ground truth is recorded with
            # each prediction instead of being taken from dir_seq by position.
            dir_truth = []
            dir_detected = []
            prob_history = {c: [] for c in bar_dirs}
            rex_dirs = {'U': 'N', 'L': 'W', 'R': 'E', 'D': 'S'}
            while trial <= num_trials:
                if cfg.SHOW_TRIALS:
                    title_text = 'Trial %d / %d' % (trial, num_trials)
                else:
                    title_text = 'Ready'
                true_label = dir_seq[trial - 1]

                pred_label = feedback.classify(decoder, true_label, title_text,
                                               bar_dirs, prob_history=prob_history)
                if pred_label is None:
                    break
                dir_truth.append(true_label)
                dir_detected.append(pred_label)

                if cfg.WITH_REX is True and pred_label == true_label:
                    rex_dir = rex_dirs.get(pred_label)
                    if rex_dir is None:
                        qc.print_c('Warning: Rex cannot execute undefined action %s' % pred_label, 'W')
                    else:
                        visual.move(pred_label, 100, overlay=False, barcolor='B')
                        visual.update()
                        qc.print_c('Executing Rex action %s' % rex_dir, 'W')
                        os.system('%s/Rex/RexControlSimple.exe %s %s' %
                                  (pycnbi.ROOT, cfg.REX_COMPORT, rex_dir))
                        time.sleep(8)

                msg = 'Correct' if true_label == pred_label else 'Wrong'
                # With TRIALS_RETRY a wrong prediction repeats the same trial.
                if cfg.TRIALS_RETRY is False or true_label == pred_label:
                    print('Trial %d: %s (%s -> %s)' % (trial, msg, true_label, pred_label))
                    trial += 1

            if dir_detected:
                # write performance and log results
                _export_online_log(cfg, dir_truth, dir_detected)
        finally:
            visual.finish()
    finally:
        if decoder:
            decoder.stop()

    print('Finished.')
cfg = check_cfg(imp.load_source(cfg_module, cfg_module)) if cfg.FAKE_CLS is None: # chooose amp if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None: amp_name, amp_serial = pu.search_lsl(ignore_markers=True) else: amp_name = cfg.AMP_NAME amp_serial = cfg.AMP_SERIAL fake_dirs = None else: amp_name = None amp_serial = None fake_dirs = [v for (k, v) in cfg.DIRECTIONS] # events and triggers tdef = trigger_def(cfg.TRIGGER_DEF) if cfg.TRIGGER_DEVICE is None: input('\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.') trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE) if trigger.init(50) == False: qc.print_c('\n** Error connecting to USB2LPT device. Use a mock trigger instead?', 'R') input('Press Ctrl+C to stop or Enter to continue.') trigger = pyLptControl.MockTrigger() trigger.init(50) # init classification qc.print_c('Initializing decoder.', 'W') decoder_UD = BCIDecoder(cfg.CLS_MI, buffer_size=10.0, fake=(cfg.FAKE_CLS is not None), amp_name=amp_name, amp_serial=amp_serial, fake_dirs=fake_dirs)
def getperf(APPLY_CAR, APPLY_PCA, APPLY_OVERSAMPLING, DO_CV, l_freq, h_freq,
            picks_feat, offset, tmin, tmax, baselineRange, reg_coeff, verbose):
    """
    Cross-validate a correct/wrong feedback classifier and select a threshold.

    The data is band-pass filtered, epoched, normalized and downsampled per
    fold (optionally PCA-reduced), and classified with a random forest. The
    first-class probabilities of all test folds are pooled into one ROC sweep
    over 100 thresholds. The chosen threshold maximizes TPR subject to
    FPR < MAX_FPR, with ties broken by the lowest FPR.

    NOTE(review): besides its parameters, this function reads the
    module-level names loadedraw, events, n_jobs, tmin_bkp, useLeaveOneOut,
    decim_factor, MAX_FPR, compute_features and normalizeAcrossEpoch, and
    appends to processing_steps. None of these are defined in this view.
    DO_CV, offset, reg_coeff and verbose are not used. APPLY_OVERSAMPLING only
    records a processing step.

    Returns
    -------
    (auc, cms, best_threshold, best_tpr):
        auc: area under the pooled ROC curve
        cms: row-normalized confusion matrix for each threshold
        best_threshold: array of the selected threshold(s)
        best_tpr: TPR at the selected threshold
    """
    raw = loadedraw.copy()
    if APPLY_CAR:
        # Common Average Reference over channels 1..N (channel 0: trigger).
        raw._data[1:] = raw._data[1:] - np.mean(raw._data[1:], axis=0)
        processing_steps.append('Car')

    tdef = trigger_def('triggerdef_errp.ini')
    sfreq = raw.info['sfreq']
    event_id = dict(correct=tdef.by_name['FEEDBACK_CORRECT'],
                    wrong=tdef.by_name['FEEDBACK_WRONG'])

    # Dataset-wide band-pass; method='iir' with iir_params=None gives a
    # 4th order Butterworth filter.
    raw.filter(l_freq=l_freq, h_freq=h_freq, n_jobs=n_jobs, picks=picks_feat,
               method='iir', iir_params=None)

    # Epoching and baselining
    epochs = mne.Epochs(raw, events=events, event_id=event_id, tmin=tmin,
                        tmax=tmax, baseline=baselineRange, picks=picks_feat,
                        preload=True, proj=False, verbose=False)
    if baselineRange:
        processing_steps.append('Baselining')
    if tmin != tmin_bkp:
        # If the baseline range was before the initial tmin, epochs were cut
        # from an earlier tmin so the baseline lies within [tmin, tmax];
        # restore the initial tmin and drop the extra leading samples.
        epochs.tmin = tmin_bkp
        epochs._data = epochs._data[:, :, int((tmin_bkp - tmin) * sfreq):]

    # The third column of epochs.events holds each epoch's class label.
    label = epochs.events[:, 2]
    # NOTE(review): these are the pre-0.18 scikit-learn splitter APIs
    # (labels passed to the constructor, object iterated directly).
    if useLeaveOneOut is True:
        cv = LeaveOneOut(len(label))
    else:
        cv = StratifiedShuffleSplit(label, n_iter=20, test_size=0.2,
                                    random_state=1337)

    if APPLY_PCA:
        processing_steps.append('PCA')
    if APPLY_OVERSAMPLING:
        processing_steps.append('Oversampling')
    processing_steps.append('Normalization')
    processing_steps.append('Downsampling')

    # Fold processing
    cv_probabilities = []
    cv_probabilities_label = []
    for train, test in cv:
        # Train data: normalize (factors learned on this fold), downsample,
        # then merge channel and time axes into one vector per epoch.
        train_normalized, trainShiftFactor, trainScaleFactor = \
            normalizeAcrossEpoch(epochs._data[train], 'MinMax')
        train_downsampled = train_normalized[:, :, ::decim_factor]
        train_reshaped = train_downsampled.reshape(train_downsampled.shape[0], -1)
        if APPLY_PCA:
            pca = PCA(0.95)
            pca.fit(train_reshaped)
            # Sign inversion kept for consistency with a reference implementation.
            pca.components_ = -pca.components_
            train_x = pca.transform(train_reshaped)
        else:
            pca = None
            train_x = train_reshaped

        # Test data: same pipeline, re-using the training fold's factors/PCA.
        test_label = label[test]
        test_x = compute_features(epochs._data[test], sfreq, l_freq, h_freq,
                                  decim_factor, trainShiftFactor,
                                  trainScaleFactor, pca)

        cls = RandomForestClassifier(n_estimators=100, max_features='auto',
                                     max_depth=None, n_jobs=n_jobs)
        cls.fit(train_x, label[train])
        probs = np.array(cls.predict_proba(test_x))

        # Pool the first-class probability of every test sample.
        cv_probabilities.append(probs[:, 0])
        cv_probabilities_label.append(test_label)

    cv_prob_linear = np.concatenate(cv_probabilities)
    cv_prob_label_linear = np.concatenate(cv_probabilities_label)

    # ROC sweep. NOTE(review): the labels 3 and 4 are hard-coded and assumed
    # to be the two event codes; a sample is predicted 4 when its first-class
    # probability is below the threshold, else 3.
    threshold_list = np.linspace(0, 1, 100)
    biglist_fpr = []
    biglist_tpr = []
    biglist_cms = []
    for thresh in threshold_list:
        biglist_pred = np.where(cv_prob_linear < thresh, 4, 3)
        biglist_cm = confusion_matrix(cv_prob_label_linear, biglist_pred)
        biglist_cm_norm = biglist_cm.astype('float') / biglist_cm.sum(axis=1)[:, np.newaxis]
        biglist_cms.append(biglist_cm_norm)
        biglist_tpr.append(biglist_cm_norm[0][0])
        biglist_fpr.append(biglist_cm_norm[1][0])
    biglist_fpr = np.array(biglist_fpr)
    biglist_tpr = np.array(biglist_tpr)
    biglist_auc = auc(biglist_fpr, biglist_tpr)

    # Best TPR among thresholds with FPR < MAX_FPR, then the lowest FPR among
    # the thresholds reaching that TPR.
    best_tpr_below_maxfpr = np.max(biglist_tpr[biglist_fpr < MAX_FPR])
    best_tpr_idx = np.flatnonzero(biglist_tpr == best_tpr_below_maxfpr)
    best_associated_fpr = np.min(biglist_fpr[best_tpr_idx])
    best_fpr_idx = np.flatnonzero(biglist_fpr == best_associated_fpr)
    best_idx = best_tpr_idx[np.isin(best_tpr_idx, best_fpr_idx)]
    best_threshold = threshold_list[best_idx]

    print('#################################')
    print('Best treshold:' + str(best_threshold))
    print('Gives a TPR of ' + str(best_tpr_below_maxfpr))
    print('And a FPR of ' + str(best_associated_fpr))
    print('CM')
    print(biglist_cms[best_idx[0]])

    return (biglist_auc, biglist_cms, best_threshold, best_tpr_below_maxfpr)
def config_run(cfg_module):
    """
    Run the offline bar-feedback training protocol described by a config file.

    The config file at `cfg_module` is loaded with imp.load_source() (the path
    is also used as the module name). A shuffled sequence containing
    cfg.TRIALS_EACH repetitions of cfg.DIRECTIONS is presented; every trial
    walks through the states start -> gap_s -> gap -> cue -> dir_r -> dir and
    sends the matching trigger defined in cfg.TRIGGER_DEF at each transition.
    Pressing ESC in the visualizer window stops the protocol early.

    Params
    ------
    cfg_module: path to the Python config file.

    Raises
    ------
    RuntimeError: if cfg.DIRECTIONS contains a code other than 'L', 'R', 'U',
        'D' or 'B'.
    """
    cfg = imp.load_source(cfg_module, cfg_module)
    tdef = trigger_def(cfg.TRIGGER_DEF)
    refresh_delay = 1.0 / cfg.REFRESH_RATE

    # visualizer key codes, compared against cv2.waitKey() masked with 0xFF
    keys = {'left': 81, 'right': 83, 'up': 82, 'down': 84, 'pgup': 85, 'pgdn': 86,
            'home': 80, 'end': 87, 'space': 32, 'esc': 27, ',': 44, '.': 46,
            's': 115, 'c': 99, '[': 91, ']': 93, '1': 49, '!': 33, '2': 50,
            '@': 64, '3': 51, '#': 35}
    # direction code -> trigger name prefix ('L' -> tdef.LEFT_READY / tdef.LEFT_GO)
    dir_names = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN', 'B': 'BOTH'}

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    event = 'start'
    trial = 1
    direction = None

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        input('\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.')
    trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        print('\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    def move_bar(code, amount):
        """Draw the bar of direction `code` at `amount` percent ('B' draws left and right)."""
        if code == 'B':
            bar.move('L', amount, overlay=True)
            bar.move('R', amount, overlay=True)
        elif code in dir_names:
            bar.move(code, amount, overlay=True)

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.T_INIT:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(tdef.INIT)
        elif event == 'gap_s':
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
        elif event == 'gap' and timer_trigger.sec() > cfg.T_GAP:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.T_CUE:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in dir_names:
                # '%s' (not '%d'): direction codes are strings
                raise RuntimeError('Unknown direction %s' % (direction,))
            move_bar(direction, 100)
            trigger.signal(getattr(tdef, dir_names[direction] + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > cfg.T_DIR_READY:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            # direction was validated when entering 'dir_r'
            trigger.signal(getattr(tdef, dir_names[direction] + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > cfg.T_DIR:
            event = 'gap_s'
            bar.fill()
            trial += 1
            print('trial ' + str(trial - 1) + ' done')
            trigger.signal(tdef.BLANK)
            timer_trigger.reset()

        # protocol: grow the bar proportionally to the elapsed time of the 'dir' phase
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / cfg.T_DIR) + 1)
            move_bar(direction, dx)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == keys['esc']:
            break
def run(cfg, state=None, queue=None):
    """
    Run the bar-feedback training protocol, controlled by a GUI through `state`.

    Params
    ------
    cfg: config object. Reads REFRESH_RATE, TRIGGER_FILE, TRIALS_EACH,
        DIRECTIONS, TRIGGER_DEVICE, TIMINGS, GLASS_USE, SCREEN_POS,
        SCREEN_SIZE and TRIAL_PAUSE; cfg.tdef is (re)assigned here.
    state: shared multiprocessing.Value('i') used as a control flag
        (0: stop, 1: start, 2: wait). A fresh Value initialised to 1 is
        created when omitted. It is set to 0 when the protocol ends.
    queue: queue that stdout is redirected to (redirect_stdout_to_queue).
    """
    import time  # local import: only needed for the polling sleep below

    if state is None:
        # Create the flag per call: a default evaluated once at definition time
        # would be left at 0 by the first run and make later calls exit at once.
        state = mp.Value('i', 1)

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI). 0: stop, 1: start, 2: wait
    while state.value == 2:
        time.sleep(0.01)  # poll without burning a whole CPU core

    # Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)
    tdef = cfg.tdef

    # visualizer key codes, compared against cv2.waitKey() masked with 0xFF
    keys = {'left': 81, 'right': 83, 'up': 82, 'down': 84, 'pgup': 85, 'pgdn': 86,
            'home': 80, 'end': 87, 'space': 32, 'esc': 27, ',': 44, '.': 46,
            's': 115, 'c': 99, '[': 91, ']': 93, '1': 49, '!': 33, '2': 50,
            '@': 64, '3': 51, '#': 35}
    # direction code -> trigger name prefix ('L' -> tdef.LEFT_READY / tdef.LEFT_GO)
    dir_names = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN', 'B': 'BOTH'}

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    event = 'start'
    trial = 1
    direction = None

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning('No trigger device set. Press Ctrl+C to stop or Enter to continue.')
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # timers
    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()

    def randomized(name):
        """Return cfg.TIMINGS[name] jittered uniformly by +/- cfg.TIMINGS[name + '_RANDOMIZE']."""
        spread = cfg.TIMINGS[name + '_RANDOMIZE']
        return cfg.TIMINGS[name] + random.uniform(-spread, spread)

    t_dir = randomized('DIR')
    t_dir_ready = randomized('READY')

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    def move_bar(code, amount):
        """Draw the bar of direction `code` at `amount` percent ('B' draws left and right)."""
        if code == 'B':
            bar.move('L', amount, overlay=True)
            bar.move('R', amount, overlay=True)
        elif code in dir_names:
            bar.move(code, amount, overlay=True)

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                # mask like the per-frame check below so ESC is recognised consistently
                key = 0xFF & cv2.waitKey()
                if key == keys['esc'] or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in dir_names:
                # '%s' (not '%d'): direction codes are strings
                raise RuntimeError('Unknown direction %s' % (direction,))
            move_bar(direction, 100)
            trigger.signal(getattr(tdef, dir_names[direction] + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            # direction was validated when entering 'dir_r'
            trigger.signal(getattr(tdef, dir_names[direction] + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(tdef.BLANK)
            timer_trigger.reset()
            # draw new random durations for the next trial
            t_dir = randomized('DIR')
            t_dir_ready = randomized('READY')

        # protocol: grow the bar proportionally to the elapsed time of the 'dir' phase
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            move_bar(direction, dx)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == keys['esc'] or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0
cfg = check_cfg(imp.load_source(cfg_module, cfg_module)) if cfg.FAKE_CLS is None: # chooose amp if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None: amp_name, amp_serial = pu.search_lsl(ignore_markers=True) else: amp_name = cfg.AMP_NAME amp_serial = cfg.AMP_SERIAL fake_dirs = None else: amp_name = None amp_serial = None fake_dirs = [v for (k, v) in cfg.DIRECTIONS] # events and triggers tdef = trigger_def(cfg.TRIGGER_FILE) if cfg.TRIGGER_DEVICE is None: input( '\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.' ) trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE) if trigger.init(50) == False: logger.error( 'Cannot connect to USB2LPT device. Use a mock trigger instead?') input('Press Ctrl+C to stop or Enter to continue.') trigger = pyLptControl.MockTrigger() trigger.init(50) # init classification logger.info('Initializing decoder')
def stream_player(server_name, fif_file, chunk_size, auto_restart=True, wait_start=True,
                  repeat=float('inf'), high_resolution=False, trigger_file=None):
    """
    Replay a recorded raw file over LSL, chunk by chunk, paced to its sampling rate.

    Input
    =====
    server_name: LSL server name.
    fif_file: fif file to replay.
    chunk_size: number of samples to send at once (usually 16-32 is good enough).
    auto_restart: play from beginning again after reaching the end. If False,
        wait for the user to press Enter before each restart.
    wait_start: wait for user to start in the beginning.
    repeat: number of passes over the whole file to play (default: forever).
    high_resolution: use perf_counter() instead of sleep() for higher time resolution
        but uses much more cpu due to polling.
    trigger_file: used to convert event numbers into event strings for readability.

    Note: Run pycnbi.set_log_level('DEBUG') to print out the relative time stamps since started.
    """
    raw, events = pu.load_raw(fif_file)
    # check before dereferencing raw
    if raw is None:
        raise RuntimeError('Error while loading %s' % fif_file)
    sfreq = raw.info['sfreq']  # sampling frequency
    n_channels = len(raw.ch_names)  # number of channels
    n_samples = raw._data.shape[1]
    tdef = trigger_def(trigger_file) if trigger_file is not None else None
    try:
        event_ch = raw.ch_names.index('TRIGGER')
    except ValueError:
        event_ch = None

    logger.info_green('Successfully loaded %s' % fif_file)
    logger.info('Server name: %s' % server_name)
    logger.info('Sampling frequency %.3f Hz' % sfreq)
    logger.info('Number of channels : %d' % n_channels)
    logger.info('Chunk size : %d' % chunk_size)
    for i, ch in enumerate(raw.ch_names):
        logger.info('%d %s' % (i, ch))
    logger.info('Trigger channel : %s' % event_ch)

    # set server information
    sinfo = pylsl.StreamInfo(server_name, channel_count=n_channels, channel_format='float32',
                             nominal_srate=sfreq, type='EEG', source_id=server_name)
    desc = sinfo.desc()
    channel_desc = desc.append_child("channels")
    for ch in raw.ch_names:
        channel_desc.append_child('channel').append_child_value('label', str(ch)) \
            .append_child_value('type', 'EEG').append_child_value('unit', 'microvolts')
    desc.append_child('amplifier').append_child('settings').append_child_value(
        'is_slave', 'false')
    desc.append_child('acquisition').append_child_value(
        'manufacturer', 'PyCNBI').append_child_value('serial_number', 'N/A')
    outlet = pylsl.StreamOutlet(sinfo, chunk_size=chunk_size)

    if wait_start:
        input('Press Enter to start streaming.')
    logger.info('Streaming started')

    # clock used to pace the chunks; perf_counter allows busy-waiting at high resolution
    clock = time.perf_counter if high_resolution else time.time
    t_chunk = chunk_size / sfreq
    t_origin = time.perf_counter()  # reference for the relative debug time stamps
    t_start = clock()
    idx_chunk = 0
    played = 0  # number of completed passes over the file

    # start streaming
    while played < repeat:
        idx_current = idx_chunk * chunk_size
        chunk = raw._data[:, idx_current:idx_current + chunk_size]
        data = chunk.transpose().tolist()
        finished = idx_current >= n_samples - chunk_size

        # wait until this chunk is due
        t_due = t_start + idx_chunk * t_chunk
        if high_resolution:
            # busy-wait: if a resolution over 2 KHz is needed
            while time.perf_counter() < t_due:
                pass
        else:
            # time.sleep() can have 500 us resolution using the tweak tool provided.
            t_wait = t_due - time.time()
            if t_wait > 0.001:
                time.sleep(t_wait)
        outlet.push_chunk(data)
        logger.debug('[%8.3fs] sent %d samples (LSL %8.3f)' %
                     (time.perf_counter() - t_origin, len(data), pylsl.local_clock()))

        if event_ch is not None:
            event_values = set(chunk[event_ch]) - {0}
            if event_values:
                if tdef is None:
                    logger.info('Events: %s' % event_values)
                else:
                    for event in event_values:
                        if event in tdef.by_value:
                            logger.info('Events: %s (%s)' % (event, tdef.by_value[event]))
                        else:
                            logger.info('Events: %s (Undefined event)' % event)

        idx_chunk += 1
        if finished:
            played += 1
            if played >= repeat:
                # last requested pass done: do not prompt for / announce a restart
                logger.info('Reached the end of data. Finished streaming.')
                break
            if auto_restart is False:
                input('Reached the end of data. Press Enter to restart or Ctrl+C to stop.')
            else:
                logger.info('Reached the end of data. Restarting.')
            idx_chunk = 0
            t_start = clock()
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute time-frequency representations (TFR) of epochs and export them.

    Data is loaded either from every fif/bdf/gdf file under cfg.DATA_PATHS
    (merged; exported with the prefix 'grandavg') or from cfg.DATA_FILE.
    Epochs are built around the events named in cfg.TRIGGER_DEF and, for each
    event, the TFR is exported as .mat (cfg.EXPORT_MATLAB) and/or one .png per
    channel (cfg.EXPORT_PNG). If cfg has a POWER_DIFF attribute, the log of the
    absolute difference between the first two event TFRs is also plotted.

    @params:
        cfg: config object, validated with check_config(). cfg.TFR_TYPE is
             'multitaper', 'morlet' or 'butter' ('fir' is not implemented).
        recursive: if True, load raw files in sub-dirs recursively
        n_jobs: kept for backward compatibility; the number of cores actually
                used is cfg.N_JOBS (None: all cores).

    Raises:
        NotImplementedError: for cfg.TFR_TYPE == 'fir'.
        ValueError: unknown TFR type, missing EXPORT_PATH with multiple
                    directories, no matching event, or invalid epoch limits.
        RuntimeError: if a picked channel index is out of range.
    '''
    cfg = check_config(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        butter_order = 4  # TODO: parameterize
        tfr = lfilter
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)

    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_PATHS'):
        if export_path is None:
            raise ValueError('For multiple directories, cfg.EXPORT_PATH cannot be None')
        outpath = export_path
        file_prefix = 'grandavg'
        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_PATHS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True, recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)
    else:
        logger.info('Loading %s' % cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)
        if export_path is None:
            outpath = qc.parse_path_list(cfg.DATA_FILE)[0]
        else:
            outpath = export_path
        file_prefix = qc.parse_path(cfg.DATA_FILE).name

    # custom event file: read after loading so the events embedded in the raw
    # file(s) do not overwrite it (previously lost in the DATA_PATHS branch)
    if getattr(cfg, 'EVENT_FILE', None) is not None:
        events = mne.read_events(cfg.EVENT_FILE)

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])
        assert cfg.REREFERENCE[0] in raw.ch_names

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)
    n_ch = len(raw.info['ch_names'])
    if max(picks) >= n_ch:  # valid indices are 0 .. n_ch - 1
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
            (max(picks), n_ch)
        raise RuntimeError(msg)

    # Apply filters
    raw = pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                        spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                        multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # Read epochs
    classes = {}
    event_set = set(events[:, -1])
    tdef = trigger_def(cfg.TRIGGER_FILE) if cfg.TRIGGER_FILE is not None else None
    for t_name in cfg.TRIGGER_DEF:
        if tdef is not None:
            t_value = tdef.by_name[t_name]
            if t_value in event_set:
                classes[t_name] = t_value
        else:
            classes[str(t_name)] = t_name
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    tmin = cfg.EPOCH[0]
    tmin_buffer = tmin - t_buffer
    raw_tmax = raw._data.shape[1] / sfreq - 0.1
    if cfg.EPOCH[1] is None:
        if cfg.POWER_AVERAGED:
            raise ValueError('EPOCH value cannot have None for grand averaged TFR')
        if len(cfg.TRIGGERS) > 1:
            raise ValueError('If the end time of EPOCH is None, only a single event can be defined.')
        t_ref = events[np.where(events[:, 2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
        tmax = raw_tmax - t_ref - t_buffer
    else:
        tmax = cfg.EPOCH[1]
    tmax_buffer = tmax + t_buffer
    if tmax_buffer > raw_tmax:
        raise ValueError('Epoch length with buffer (%.3f) is larger than signal length (%.3f)' % (tmax_buffer, raw_tmax))

    try:
        epochs_all = mne.Epochs(raw, events, classes, tmin=tmin_buffer, tmax=tmax_buffer,
                                proj=False, picks=picks, baseline=None, preload=True)
    except Exception:
        # log the epoch limits for diagnosis, then propagate: continuing would
        # only fail later on an undefined epochs_all
        logger.critical('\n*** (tfr_export) Unknown error occurred while epoching ***')
        logger.critical('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        raise
    if epochs_all.drop_log_stats() > 0:
        # deliberate interactive inspection of dropped epochs
        logger.error('\n** Bad epochs found. Dropping into a Python shell.')
        logger.error(epochs_all.drop_log)
        logger.error('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        logger.error('\nType exit to continue.\n')
        pdb.set_trace()

    power = {}
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    export_dir = outpath
    qc.make_dirs(export_dir)
    for evname in classes:
        logger.info('>> Processing %s' % evname)
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                logger.warning('No %s epochs. Skipping.' % evname)
                continue
            if tfr_type == 'butter':
                b, a = butter_bandpass(cfg.FREQ_RANGE[0], cfg.FREQ_RANGE[-1], sfreq, order=butter_order)
                # filter the epoch data array (axis 2 is time), not the Epochs object
                tfr_filtered = lfilter(b, a, epochs.get_data(), axis=2)
                tfr_power = abs(hilbert(tfr_filtered))
                tfr_data = np.mean(tfr_power, axis=0)
            else:
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname)
                scipy.io.savemat(mout, {'tfr':tfr_data, 'chs':epochs.ch_names,
                    'events':events, 'sfreq':sfreq, 'tmin':tmin, 'tmax':tmax,
                    'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                logger.info('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                if evname not in power:
                    # the butter path produces a plain array, not a plottable TFR object
                    logger.warning('PNG export is not supported for TFR type %s. Skipping.' % tfr_type)
                else:
                    # Inspect power for each channel
                    for ch in np.arange(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s' % (evname, chname)

                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE,
                            show=False, colorbar=True, title=title, vmin=cfg.VMIN,
                            vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname)
                        fig.savefig(fout)
                        plt.close()
                        logger.info('Exported to %s' % fout)
        else:
            # TFR per event
            for ep in range(len(epochs_all[evname])):
                epochs = epochs_all[evname][ep]
                if len(epochs) == 0:
                    logger.warning('No %s epochs. Skipping.' % evname)
                    continue
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(mout, {'tfr':power[evname].data, 'chs':power[evname].ch_names,
                        'events':events, 'sfreq':sfreq, 'tmin':tmin, 'tmax':tmax,
                        'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                    logger.info('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch in np.arange(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE,
                            show=False, colorbar=True, title=title, vmin=cfg.VMIN,
                            vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname,
                            chname, ep + 1)
                        fig.savefig(fout)
                        plt.close()
                        logger.info('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        # dict_keys is not subscriptable in Python 3
        labels = list(classes.keys())
        if len(labels) < 2 or labels[0] not in power or labels[1] not in power:
            logger.warning('POWER_DIFF needs the TFRs of at least two events. Skipping.')
        else:
            export_dir = '%s/diff' % outpath
            qc.make_dirs(export_dir)
            df = power[labels[0]] - power[labels[1]]
            df.data = np.log(np.abs(df.data))
            # Inspect power diff for each channel
            for ch in np.arange(len(picks)):
                chname = raw.ch_names[picks[ch]]
                title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1], chname)

                # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                # symmetric color range; vmin must be below vmax
                fig = df.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                              colorbar=True, title=title, vmin=-3.0, vmax=3.0, dB=False)
                fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix, cfg.SP_FILTER, labels[0], labels[1], chname)
                logger.info('Exporting to %s' % fout)
                fig.savefig(fout)
                plt.close()
    logger.info('Finished !')
scaleFactor = np.max(epoch_data, 0) - np.min(epoch_data, 0) elif method == 'override': # todo: find a better name shiftFactor = givenShiftFactor scaleFactor = givenScaleFactor for trial in range(new_epochs_data.shape[0]): new_epochs_data[trial, :, :] = (new_epochs_data[trial, :, :] - shiftFactor) / scaleFactor return (new_epochs_data, shiftFactor, scaleFactor) if __name__ == '__main__': # load data raw, events = pu.load_multi(FLIST, spfilter='car') tdef = trigger_def('triggerdef_errp.ini') sfreq = raw.info['sfreq'] # epoching tmin = 0 tmax = 1 event_id = dict(correct=tdef.by_key['FEEDBACK_CORRECT'], wrong=tdef.by_key['FEEDBACK_WRONG']) # export plots: apply offline spectral filter if EXPORT_PLOTS == True: # apply filter on entire signal (for offline analysis) raw.filter(l_freq=l_freq, h_freq=h_freq, n_jobs=n_jobs, picks=picks_feat,
def config_run(cfg_module):
    """
    Run the offline training protocol with bar or body-image feedback.

    Loads the config with load_cfg(cfg_module) and presents a shuffled sequence
    of cfg.TRIALS_EACH repetitions of cfg.DIRECTIONS. Each trial runs
    start -> gap_s -> gap -> cue -> dir_r -> dir -> return; the dir/return pair
    is repeated up to cfg.GAIT_STEPS times, alternating between 'L' and 'R'.
    Triggers from cfg.TRIGGER_DEF are sent at each transition. When
    cfg.WITH_STIMO is True and cfg.FEEDBACK_TYPE is 'BODY', b'1' / b'2' is
    written to the STIMO serial port at the end of each left / right step.
    Pressing ESC stops the protocol early. The serial port is always closed.

    Raises
    ------
    ValueError: if cfg.FEEDBACK_TYPE is neither 'BAR' nor 'BODY', or it is
        'BODY' and cfg.IMAGE_PATH is undefined.
    RuntimeError: if cfg.DIRECTIONS contains a code other than 'L', 'R', 'U',
        'D' or 'B'.
    """
    cfg = load_cfg(cfg_module)

    # visualizer key codes, compared against cv2.waitKey() masked with 0xFF
    keys = {'left': 81, 'right': 83, 'up': 82, 'down': 84, 'pgup': 85, 'pgdn': 86,
            'home': 80, 'end': 87, 'space': 32, 'esc': 27, ',': 44, '.': 46,
            's': 115, 'c': 99, '[': 91, ']': 93, '1': 49, '!': 33, '2': 50,
            '@': 64, '3': 51, '#': 35}
    # direction code -> trigger name prefix and BODY cue text ('L' -> LEFT_READY, 'LEFT')
    dir_names = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN', 'B': 'BOTH'}

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    tdef = trigger_def(cfg.TRIGGER_DEF)
    refresh_delay = 1.0 / cfg.REFRESH_RATE
    state = 'start'
    trial = 1
    direction = None

    # STIMO protocol
    ser = None
    if cfg.WITH_STIMO is True:
        print('Opening STIMO serial port (%s / %d bps)' % (cfg.STIMO_COMPORT, cfg.STIMO_BAUDRATE))
        import serial
        ser = serial.Serial(cfg.STIMO_COMPORT, cfg.STIMO_BAUDRATE)
        print('STIMO serial port %s is_open = %s' % (cfg.STIMO_COMPORT, ser.is_open))

    try:
        # init trigger
        if cfg.TRIGGER_DEVICE is None:
            input('\n** Warning: No trigger device set. Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.Trigger(cfg.TRIGGER_DEVICE)
        if trigger.init(50) == False:
            print('\n# Error connecting to USB2LPT device. Use a mock trigger instead?')
            input('Press Ctrl+C to stop or Enter to continue.')
            trigger = pyLptControl.MockTrigger()
            trigger.init(50)

        # visual feedback
        if cfg.FEEDBACK_TYPE == 'BAR':
            from pycnbi.protocols.viz_bars import BarVisual
            visual = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS,
                               screen_size=cfg.SCREEN_SIZE)
        elif cfg.FEEDBACK_TYPE == 'BODY':
            if not hasattr(cfg, 'IMAGE_PATH'):
                raise ValueError('IMAGE_PATH is undefined in your config.')
            from pycnbi.protocols.viz_human import BodyVisual
            visual = BodyVisual(cfg.IMAGE_PATH, use_glass=cfg.GLASS_USE,
                                screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
        else:
            # otherwise `visual` would be undefined below
            raise ValueError('Unknown FEEDBACK_TYPE %s' % (cfg.FEEDBACK_TYPE,))
        visual.put_text('Waiting to start ')

        def move_visual(code, amount, **kwargs):
            """Move the bar of direction `code` to `amount` percent ('B' moves left and right)."""
            if code == 'B':
                visual.move('L', amount, **kwargs)
                visual.move('R', amount, **kwargs)
            elif code in dir_names:
                visual.move(code, amount, **kwargs)

        timer_trigger = qc.Timer()
        timer_dir = qc.Timer()
        timer_refresh = qc.Timer()

        # start
        while trial <= num_trials:
            timer_refresh.sleep_atleast(refresh_delay)
            timer_refresh.reset()

            # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
            if state == 'start' and timer_trigger.sec() > cfg.T_INIT:
                state = 'gap_s'
                visual.fill()
                timer_trigger.reset()
                trigger.signal(tdef.INIT)
            elif state == 'gap_s':
                visual.put_text('Trial %d / %d' % (trial, num_trials))
                state = 'gap'
            elif state == 'gap' and timer_trigger.sec() > cfg.T_GAP:
                state = 'cue'
                visual.fill()
                visual.draw_cue()
                trigger.signal(tdef.CUE)
                timer_trigger.reset()
            elif state == 'cue' and timer_trigger.sec() > cfg.T_CUE:
                state = 'dir_r'
                direction = dir_sequence[trial - 1]
                if direction not in dir_names:
                    # '%s' (not '%d'): direction codes are strings
                    raise RuntimeError('Unknown direction %s' % (direction,))
                if cfg.FEEDBACK_TYPE == 'BAR':
                    move_visual(direction, 100)
                else:
                    visual.put_text(dir_names[direction])
                trigger.signal(getattr(tdef, dir_names[direction] + '_READY'))
                gait_steps = 1
                timer_trigger.reset()
            elif state == 'dir_r' and timer_trigger.sec() > cfg.T_DIR_READY:
                visual.draw_cue()
                state = 'dir'
                timer_trigger.reset()
                timer_dir.reset()
                t_step = cfg.T_DIR + random.random() * cfg.RANDOMIZE_LENGTH
                trigger.signal(getattr(tdef, dir_names[direction] + '_GO'))
            elif state == 'dir':
                if timer_trigger.sec() > t_step:
                    if cfg.FEEDBACK_TYPE == 'BODY':
                        if cfg.WITH_STIMO is True:
                            if direction == 'L':  # left
                                ser.write(b'1')
                                qc.print_c('STIMO: Sent 1', 'g')
                                trigger.signal(tdef.LEFT_STIMO)
                            elif direction == 'R':  # right
                                ser.write(b'2')
                                qc.print_c('STIMO: Sent 2', 'g')
                                trigger.signal(tdef.RIGHT_STIMO)
                        elif direction == 'L':  # left
                            trigger.signal(tdef.LEFT_RETURN)
                        elif direction == 'R':  # right
                            trigger.signal(tdef.RIGHT_RETURN)
                    else:
                        trigger.signal(tdef.FEEDBACK)
                    state = 'return'
                    timer_trigger.reset()
                else:
                    dx = min(100, int(100.0 * timer_dir.sec() / t_step) + 1)
                    move_visual(direction, dx, overlay=True)
            elif state == 'return':
                if timer_trigger.sec() > cfg.T_RETURN:
                    if gait_steps < cfg.GAIT_STEPS:
                        # next gait step with the other leg
                        gait_steps += 1
                        state = 'dir'
                        visual.move('L', 0)
                        if direction == 'L':
                            direction = 'R'
                            trigger.signal(tdef.RIGHT_GO)
                        else:
                            direction = 'L'
                            trigger.signal(tdef.LEFT_GO)
                        # restart the step timer too, as on the first step; otherwise the
                        # step would end T_RETURN seconds early (or immediately)
                        timer_trigger.reset()
                        timer_dir.reset()
                        t_step = cfg.T_DIR + random.random() * cfg.RANDOMIZE_LENGTH
                    else:
                        state = 'gap_s'
                        visual.fill()
                        trial += 1
                        print('trial ' + str(trial - 1) + ' done')
                        trigger.signal(tdef.BLANK)
                        timer_trigger.reset()
                else:
                    # shrink the bar back to 0 over cfg.T_RETURN seconds
                    dx = max(0, int(100.0 * (cfg.T_RETURN - timer_trigger.sec()) / cfg.T_RETURN))
                    move_visual(direction, dx, overlay=True)

            # wait for start
            if state == 'start':
                visual.put_text('Waiting to start ')

            visual.update()
            key = 0xFF & cv2.waitKey(1)

            if key == keys['esc']:
                break
    finally:
        # STIMO protocol: release the serial port even if the protocol failed
        if ser is not None:
            ser.close()
            print('Closed STIMO serial port %s' % cfg.STIMO_COMPORT)
import pycnbi import numpy as np import pycnbi.utils.q_common as qc from pycnbi.triggers.trigger_def import trigger_def tdef = trigger_def('triggerdef_16.ini') DATA_DIRS = [r'D:\data\MI\rx1\train'] CHANNEL_PICKS = [5, 6, 7, 11] '''""""""""""""""""""""""""""" Epochs and events of interest """""""""""""""""""""""""""''' TRIGGERS = {tdef.LEFT_GO, tdef.RIGHT_GO} EPOCH = [-2.0, 4.0] EVENT_FILE = None '''""""""""""""""""""""""""""" Baseline relative to onset while plotting None in index 0: beginning of data None in index 1: end of data """""""""""""""""""""""""""''' BS_TIMES = (None, 0) '''""""""""""""""""""""""""""" PSD """""""""""""""""""""""""""''' FREQ_RANGE = np.arange(1, 40, 1) '''""""""""""""""""""""""""""" Unit conversion
def disp_params(self, cfg_template_module, cfg_module):
    """
    Displays the parameters in the corresponding UI scrollArea.

    For each parameter declared in the template module, a widget is created
    (its kind depends on the parameter key and on the type of its possible
    values). The widget is initialised with the value chosen in the
    subject's config module, connected to self.on_guichanges and registered
    in self.paramsWidgets.

    cfg_template_module: config template module. Its first two members, in
        inspect.getmembers() order, are iterated as the parameter classes;
        the layouts are displayed for members named 'Basic' / 'Advanced'.
    cfg_module: subject-specific config module holding the chosen values.
    """
    self.clear_params()

    # Extract the parameters and their possible values from the template module.
    params = inspect.getmembers(cfg_template_module)
    # Extract the chosen values from the subject's specific module.
    all_chosen_values = inspect.getmembers(cfg_module)

    filePath = self.ui.lineEdit_pathSearch.text()

    # Load channels
    if self.modality == 'train':
        subjectDataPath = '%s/%s/fif' % (os.environ['PYCNBI_DATA'],
                                         filePath.split('/')[-1])
        self.channels = read_params_from_txt(subjectDataPath,
                                             'channelsList.txt')

    self.directions = ()

    # Iterates over the classes (only the first two members are used).
    for par in range(2):
        param = inspect.getmembers(params[par][1])

        # Create layouts
        layout = QFormLayout()

        # Iterates over the list
        for p in param:
            # Remove useless (dunder) attributes
            if '__' in p[0]:
                continue

            # Iterates over the dict
            for key, values in p[1].items():
                chosen_value = self.extract_value_from_module(
                    key, all_chosen_values)

                # For the feedback directions [offline and online].
                if 'DIRECTIONS' in key:
                    self.directions = values
                    # Compare strings with '==': 'is' tests object identity
                    # and only works by accident of string interning.
                    if self.modality == 'offline':
                        nb_directions = 4
                        directions = Connect_Directions(
                            key, chosen_value, values, nb_directions)

                    elif self.modality == 'online':
                        cls_path = self.paramsWidgets[
                            'DECODER_FILE'].lineEdit_pathSearch.text()
                        cls = qc.load_obj(cls_path)
                        # Finds the events on which the decoder has been trained on
                        events = list(map(int, cls['cls'].classes_))
                        nb_directions = len(events)
                        chosen_events = [event[1] for event in chosen_value]
                        chosen_value = [val[0] for val in chosen_value]

                        # Need tdef to convert int to str trigger values.
                        # Loaded directly: the previous "try tdef.by_value(i)"
                        # probe could never succeed (tdef is a local name that
                        # may be unbound, and by_value is subscripted, not
                        # called), so it always fell through to this load.
                        trigger_file = self.extract_value_from_module(
                            'TRIGGER_FILE', all_chosen_values)
                        tdef = trigger_def(trigger_file)
                        events = [tdef.by_value[i] for i in events]

                        directions = Connect_Directions_Online(
                            key, chosen_value, values, nb_directions,
                            chosen_events, events)

                    directions.signal_paramChanged.connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: directions})
                    layout.addRow(key, directions.l)

                # For providing a folder path.
                elif 'PATH' in key:
                    pathfolderfinder = PathFolderFinder(
                        key, DEFAULT_PATH, chosen_value)
                    pathfolderfinder.signal_pathChanged.connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: pathfolderfinder})
                    layout.addRow(key, pathfolderfinder.layout)

                # For providing a file path.
                elif 'FILE' in key:
                    pathfilefinder = PathFileFinder(key, chosen_value)
                    pathfilefinder.signal_pathChanged.connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: pathfilefinder})
                    layout.addRow(key, pathfilefinder.layout)

                # For the special case of choosing the trigger classes to train on
                elif 'TRIGGER_DEF' in key:
                    trigger_file = self.extract_value_from_module(
                        'TRIGGER_FILE', all_chosen_values)
                    tdef = trigger_def(trigger_file)
                    nb_directions = 4
                    # Convert 'None' to real None (real None is removed when
                    # selected in the GUI)
                    tdef_values = [
                        None if i == 'None' else i for i in list(tdef.by_name)
                    ]
                    directions = Connect_Directions(
                        key, chosen_value, tdef_values, nb_directions)
                    directions.signal_paramChanged.connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: directions})
                    layout.addRow(key, directions.l)

                # To select specific electrodes
                elif '_CHANNELS' in key or 'CHANNELS_' in key:
                    ch_select = Channel_Select(key, self.channels,
                                               chosen_value)
                    ch_select.signal_paramChanged.connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: ch_select})
                    layout.addRow(key, ch_select.layout)

                elif 'BIAS' in key:
                    # Add None to the list in case of no bias wanted
                    self.directions = tuple([None] + list(self.directions))
                    bias = Connect_Bias(key, self.directions, chosen_value)
                    bias.signal_paramChanged.connect(self.on_guichanges)
                    self.paramsWidgets.update({key: bias})
                    layout.addRow(key, bias.l)

                # For all the int values.
                elif values is int:
                    spinBox = Connect_SpinBox(key, chosen_value)
                    spinBox.signal_paramChanged.connect(self.on_guichanges)
                    self.paramsWidgets.update({key: spinBox})
                    layout.addRow(key, spinBox.w)

                # For all the float values.
                elif values is float:
                    doublespinBox = Connect_DoubleSpinBox(key, chosen_value)
                    doublespinBox.signal_paramChanged.connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: doublespinBox})
                    layout.addRow(key, doublespinBox.w)

                # For parameters with multiple non-fixed values in a list
                # (user can modify them)
                elif values is list:
                    modifiable_list = Connect_Modifiable_List(
                        key, chosen_value)
                    modifiable_list.signal_paramChanged.connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: modifiable_list})
                    layout.addRow(key, modifiable_list.frame)

                # For parameters containing a string to modify
                elif values is str:
                    lineEdit = Connect_LineEdit(key, chosen_value)
                    lineEdit.signal_paramChanged[str, str].connect(
                        self.on_guichanges)
                    lineEdit.signal_paramChanged[str, type(None)].connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: lineEdit})
                    layout.addRow(key, lineEdit.w)

                # For parameters with multiple fixed values.
                elif type(values) is tuple:
                    comboParams = Connect_ComboBox(key, chosen_value, values)
                    comboParams.signal_paramChanged.connect(
                        self.on_guichanges)
                    comboParams.signal_additionalParamChanged.connect(
                        self.on_guichanges)
                    self.paramsWidgets.update({key: comboParams})
                    layout.addRow(key, comboParams.layout)

                # For parameters with multiple non-fixed values in a dict
                # (user can modify them)
                elif type(values) is dict:
                    # A chosen dict carrying a 'selected' entry is a choice
                    # among alternatives -> combo box; anything else is shown
                    # as an editable dict. Checked explicitly instead of a
                    # bare except, which also hid errors from the combo box.
                    if (isinstance(chosen_value, dict)
                            and 'selected' in chosen_value):
                        comboParams = Connect_ComboBox(
                            key, chosen_value, values)
                        comboParams.signal_paramChanged.connect(
                            self.on_guichanges)
                        comboParams.signal_additionalParamChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: comboParams})
                        layout.addRow(key, comboParams.layout)
                    else:
                        modifiable_dict = Connect_Modifiable_Dict(
                            key, chosen_value, values)
                        modifiable_dict.signal_paramChanged.connect(
                            self.on_guichanges)
                        self.paramsWidgets.update({key: modifiable_dict})
                        layout.addRow(key, modifiable_dict.frame)

            # Add a horizontal line to separate parameters' type.
            if p != param[-1]:
                separator = QFrame()
                separator.setFrameShape(QFrame.HLine)
                separator.setFrameShadow(QFrame.Sunken)
                layout.addRow(separator)

        # Display the parameters according to their types.
        if params[par][0] == 'Basic':
            self.ui.scrollAreaWidgetContents_Basics.setLayout(layout)
        elif params[par][0] == 'Advanced':
            self.ui.scrollAreaWidgetContents_Adv.setLayout(layout)