def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Training protocol for Alpha/Theta neurofeedback.

    Computes the PSD features described by ``cfg``, averages them over the
    windows, and logs two values each for the theta band (< 8 Hz) and the
    alpha band (>= 8 Hz):
    - a reference: the rounded band mean;
    - a threshold: the rounded value of (reference - 0.5 * band std).

    Parameters
    ----------
    cfg : config object
        Must provide TRIGGER_FILE and FEATURES['PSD'] with the keys 'fmin',
        'fmax' and 'wlen'. It is also passed to features.compute_features().
        ``cfg.tdef`` is set as a side effect.
    state : multiprocessing.Value
        If its value is falsy before feature extraction, the process exits
        with code -1.
    queue : optional
        Forwarded to redirect_stdout_to_queue() together with the logger.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Attach the trigger definitions to the config object
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    if not state.value:
        sys.exit(-1)

    psd_cfg = cfg.FEATURES['PSD']
    # Frequency axis used to split the PSD into bands.
    # NOTE(review): assumed to match the PSD bins produced by compute_features.
    freqs = np.arange(psd_cfg['fmin'], psd_cfg['fmax'] + 0.5, 1 / psd_cfg['wlen'])
    featdata = features.compute_features(cfg)

    # Average the PSD over the windows (first axis after squeezing)
    mean_psd = np.mean(np.squeeze(featdata['X_data']), 0)

    def band_stats(band_psd):
        """Return (rounded mean, rounded(mean_rounded - 0.5 * std)) of a band."""
        reference = round(np.mean(band_psd))
        return reference, round(reference - (0.5 * np.std(band_psd)))

    alpha_ref, alpha_thr = band_stats(mean_psd[freqs >= 8])
    theta_ref, theta_thr = band_stats(mean_psd[freqs < 8])

    logger.info('Theta ref = {}; alpha ref ={}'.format(theta_ref, alpha_ref))
    logger.info('Theta thr = {}; alpha thr ={}'.format(theta_thr, alpha_thr))
def on_new_tdef_file(self, key, trigger_file):
    """
    Reload the trigger definitions after a new tdef file was chosen.

    The parsed definitions replace ``self.tdef``. When ``self.events`` is
    truthy, the layout is rebuilt via ``on_update_VBoxLayout`` so that the
    event QComboBox shows the events of the new file.

    Parameters
    ----------
    key : object
        Not used by this method.
    trigger_file :
        Trigger definition file, passed to ``trigger_def``.
    """
    new_tdef = trigger_def(trigger_file)
    self.tdef = new_tdef
    if not self.events:
        return
    self.on_update_VBoxLayout()
def on_new_tdef_file(self, key, trigger_file):
    """
    Rebuild the direction QComboBoxes from the events of a new tdef file.

    The horizontal layout is cleared first. Four combo boxes are then
    created, each offering every event name found in the trigger definition.

    Parameters
    ----------
    key : object
        Not used by this method.
    trigger_file :
        Trigger definition file, passed to ``trigger_def``.
    """
    self.clear_hBoxLayout()
    tdef = trigger_def(trigger_file)
    n_directions = 4

    # An event literally named 'None' is offered as a real None
    # (a real None is removed when selected in the GUI).
    event_names = []
    for name in tdef.by_name:
        event_names.append(None if name == 'None' else name)

    self.create_the_comboBoxes(self.chosen_value, event_names, n_directions)
def _log_event_counts(eve, tdef):
    """
    Log how many events of each trigger value appear in ``eve``.

    Values defined in ``tdef.by_value`` are printed with their event name;
    the others are printed with their raw value.

    Parameters
    ----------
    eve : numpy.ndarray
        Event array whose third column holds the trigger values.
    tdef : trigger definition object exposing ``by_value``.

    Returns
    -------
    list
        Trigger values that are not defined in ``tdef``.
    """
    notfounds = []
    values, counts = np.unique(eve[:, 2], return_counts=True)
    for key, count in zip(values, counts):
        if key in tdef.by_value:
            logger.info('%s: %d events' % (tdef.by_value[key], count))
        else:
            logger.info('%s: %d events' % (key, count))
            notfounds.append(key)
    return notfounds


def merge_events(trigger_file, events, rawfile_in, rawfile_out):
    """
    Merge several event types into one and save the relabelled recording.

    Parameters
    ----------
    trigger_file :
        Trigger definition file, parsed with ``trigger_def``.
    events : dict
        Maps a target event name to an iterable of source event names.
        Every source event is relabelled with the target's trigger value.
        All names are looked up in ``tdef.by_name``.
    rawfile_in :
        Recording to read with ``pu.load_raw``.
    rawfile_out :
        Destination passed to ``raw.save`` (overwritten if it exists).

    Raises
    ------
    KeyError
        If an event name is not in the trigger definition file.
    ValueError
        If the recording has no channel named 'TRIGGER'.
    RuntimeError
        If two events share the same sample index.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(rawfile_in)

    logger.info('=== Before merging ===')
    for key in _log_event_counts(eve, tdef):
        logger.warning('Key %d was not found in the definition file.' % key)

    for key, ev_src in events.items():
        ev_out = tdef.by_name[key]
        src_values = [tdef.by_name[e] for e in ev_src]
        # np.isin handles an empty source list (nothing is relabelled), where
        # np.concatenate of an empty index list would raise.
        eve[np.isin(eve[:, 2], src_values), 2] = ev_out

    # sanity check: no two events may share the same sample index.
    # Raised explicitly so the check survives `python -O`.
    dups = np.where(np.diff(eve[:, 0]) == 0)[0]
    if len(dups) > 0:
        raise RuntimeError('%d events share the same sample index.' % len(dups))

    # Reset the trigger channel before writing the merged events into it.
    # Look it up by name: the events are written to 'TRIGGER', which is not
    # guaranteed to be channel 0.
    trig_idx = raw.ch_names.index('TRIGGER')
    raw._data[trig_idx] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    logger.info('=== After merging ===')
    _log_event_counts(eve, tdef)
def run(cfg, state=mp.Value('i', 1), queue=None, interactive=False, cv_file=None,
        feat_file=None, logger=logger):
    """
    Train a decoder from the offline data described by ``cfg``.

    Steps:
    1. Compute the features.
    2. Optionally cross-validate, when the selected CV_PERFORM entry is not None.
    3. Optionally export a trained decoder, when cfg.EXPORT_CLS is True.

    Before each step the process exits with code -1 if ``state.value`` is
    falsy. On completion ``state.value`` is set to 0.

    Parameters
    ----------
    cfg : config object
        Must provide TRIGGER_FILE, CV_PERFORM and EXPORT_CLS. ``cfg.tdef`` is
        set as a side effect.
    state : multiprocessing.Value
        Shared run flag, read before each step and reset to 0 at the end.
    queue : optional
        Forwarded to redirect_stdout_to_queue().
    interactive : bool
        Not used by this function.
    cv_file, feat_file : optional
        Forwarded to cross_validate() and train_decoder() respectively.
    logger :
        Logger whose output is redirected to ``queue``.
    """
    def exit_if_stopped():
        # A falsy shared state means the user asked to stop.
        if not state.value:
            sys.exit(-1)

    redirect_stdout_to_queue(logger, queue, 'INFO')
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    exit_if_stopped()
    featdata = features.compute_features(cfg)

    exit_if_stopped()
    selected_cv = cfg.CV_PERFORM[cfg.CV_PERFORM['selected']]
    if selected_cv is not None:
        cross_validate(cfg, featdata, cv_file=cv_file)

    exit_if_stopped()
    if cfg.EXPORT_CLS is True:
        train_decoder(cfg, featdata, feat_file=feat_file)

    with state.get_lock():
        state.value = 0
def _make_stream_info(server_name, ch_names, sfreq):
    """Build the pylsl.StreamInfo (with channel/amplifier metadata) for the replay."""
    sinfo = pylsl.StreamInfo(server_name, channel_count=len(ch_names), channel_format='float32',
                             nominal_srate=sfreq, type='EEG', source_id=server_name)
    desc = sinfo.desc()
    channel_desc = desc.append_child("channels")
    for ch in ch_names:
        channel_desc.append_child('channel').append_child_value('label', str(ch)) \
            .append_child_value('type', 'EEG').append_child_value('unit', 'microvolts')
    desc.append_child('amplifier').append_child('settings').append_child_value(
        'is_slave', 'false')
    desc.append_child('acquisition').append_child_value(
        'manufacturer', 'NeuroDecode').append_child_value('serial_number', 'N/A')
    return sinfo


def _log_chunk_events(trigger_samples, tdef):
    """
    Log the non-zero trigger values found in one chunk of the trigger channel.

    trigger_samples: trigger-channel samples of the chunk.
    tdef: trigger definition used to print event names, or None to print
    only the raw values.
    """
    event_values = set(trigger_samples) - {0}
    if not event_values:
        return
    if tdef is None:
        logger.info('Events: %s' % event_values)
        return
    for event in event_values:
        if event in tdef.by_value:
            logger.info('Events: %s (%s)' % (event, tdef.by_value[event]))
        else:
            logger.info('Events: %s (Undefined event)' % event)


def stream_player(server_name, fif_file, chunk_size, auto_restart=True, wait_start=True,
                  repeat=float('inf'), high_resolution=False, trigger_file=None):
    """
    Replay a fif file as an LSL stream.

    Input
    =====
    server_name: LSL server name.
    fif_file: fif file to replay.
    chunk_size: number of samples to send at once (usually 16-32 is good enough).
    auto_restart: play from beginning again after reaching the end without
        asking; if False, wait for Enter before each restart.
    wait_start: wait for user to start in the beginning.
    repeat: number of times the whole file is played (default: forever).
    high_resolution: use perf_counter() instead of sleep() for higher time
        resolution but uses much more cpu due to polling.
    trigger_file: used to convert event numbers into event strings for readability.

    Raises
    ======
    RuntimeError: if the fif file could not be loaded.

    Note: Run neurodecode.set_log_level('DEBUG') to print the time stamp of
    each sent chunk.
    """
    raw, events = pu.load_raw(fif_file)
    # Check before touching raw.info, which would fail on a failed load.
    if raw is None:
        raise RuntimeError('Error while loading %s' % fif_file)

    sfreq = raw.info['sfreq']  # sampling frequency
    n_channels = len(raw.ch_names)  # number of channels
    tdef = trigger_def(trigger_file) if trigger_file is not None else None
    try:
        event_ch = raw.ch_names.index('TRIGGER')
    except ValueError:
        event_ch = None

    logger.info_green('Successfully loaded %s' % fif_file)
    logger.info('Server name: %s' % server_name)
    logger.info('Sampling frequency %.3f Hz' % sfreq)
    logger.info('Number of channels : %d' % n_channels)
    logger.info('Chunk size : %d' % chunk_size)
    for i, ch in enumerate(raw.ch_names):
        logger.info('%d %s' % (i, ch))
    logger.info('Trigger channel : %s' % event_ch)

    sinfo = _make_stream_info(server_name, raw.ch_names, sfreq)
    outlet = pylsl.StreamOutlet(sinfo, chunk_size=chunk_size)

    if wait_start:
        input('Press Enter to start streaming.')
    logger.info('Streaming started')

    clock = time.perf_counter if high_resolution else time.time
    n_samples = raw._data.shape[1]
    t_chunk = chunk_size / sfreq
    idx_chunk = 0
    t_start = clock()

    # start streaming; `played` counts completed passes over the file
    played = 0
    while played < repeat:
        idx_current = idx_chunk * chunk_size
        chunk = raw._data[:, idx_current:idx_current + chunk_size]
        data = chunk.transpose().tolist()
        finished = idx_current >= n_samples - chunk_size

        t_send = t_start + idx_chunk * t_chunk
        if high_resolution:
            # busy-wait for sub-millisecond accuracy (if a resolution over 2 KHz is needed)
            while time.perf_counter() < t_send:
                pass
        else:
            # time.sleep() can have 500 us resolution using the tweak tool provided.
            t_wait = t_send - time.time()
            if t_wait > 0.001:
                time.sleep(t_wait)

        outlet.push_chunk(data)
        logger.debug('[%8.3fs] sent %d samples (LSL %8.3f)'
                     % (time.perf_counter(), len(data), pylsl.local_clock()))

        if event_ch is not None:
            _log_chunk_events(chunk[event_ch], tdef)
        idx_chunk += 1

        if finished:
            played += 1
            if played >= repeat:
                break
            if auto_restart is False:
                input('Reached the end of data. Press Enter to restart or Ctrl+C to stop.')
            else:
                logger.info('Reached the end of data. Restarting.')
            idx_chunk = 0
            t_start = clock()
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Offline training protocol with a bar visual.

    Each direction in cfg.DIRECTIONS is shown cfg.TRIALS_EACH times, in
    random order. A trial goes through these stages:
    1. gap;
    2. cue;
    3. 'ready': a full bar in the trial direction;
    4. 'go': the bar grows over a randomized duration.
    A trigger from cfg.tdef is sent at every stage. ESC or a falsy
    ``state.value`` aborts the loop.

    Parameters
    ----------
    cfg : config object
        Reads REFRESH_RATE, TRIGGER_FILE, TRIGGER_DEVICE, TRIALS_EACH,
        DIRECTIONS ('L', 'R', 'U', 'D' or 'B'), TIMINGS, TRIAL_PAUSE,
        GLASS_USE, SCREEN_POS and SCREEN_SIZE. ``cfg.tdef`` is set as a side
        effect.
    state : multiprocessing.Value
        0: stop, 1: start, 2: wait. Set to 0 when the protocol ends.
    queue : optional
        Forwarded to redirect_stdout_to_queue().

    Raises
    ------
    RuntimeError
        If the direction sequence contains an unknown direction code.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        pass

    # Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    esc_key = 27  # cv2 key code of ESC

    # Bar(s) to move and trigger-name prefix (<PREFIX>_READY / <PREFIX>_GO)
    # for each direction code.
    dir_bars = {'L': ('L',), 'R': ('R',), 'U': ('U',), 'D': ('D',), 'B': ('L', 'R')}
    dir_trigger_prefix = {'L': 'LEFT', 'R': 'RIGHT', 'U': 'UP', 'D': 'DOWN', 'B': 'BOTH'}

    dir_sequence = list(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    random.shuffle(dir_sequence)
    num_trials = len(dir_sequence)

    event = 'start'
    trial = 1
    direction = None

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'No trigger device set. Press Ctrl+C to stop or Enter to continue.')
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # timers
    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()

    def jittered(base_key, rand_key):
        # Base duration plus a uniform random offset in [-rand, +rand]
        return cfg.TIMINGS[base_key] + random.uniform(-cfg.TIMINGS[rand_key],
                                                      cfg.TIMINGS[rand_key])

    t_dir = jittered('DIR', 'DIR_RANDOMIZE')
    t_dir_ready = jittered('READY', 'READY_RANDOMIZE')

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                # Mask like the per-frame check below so modifier bits are ignored
                key = 0xFF & cv2.waitKey()
                if key == esc_key or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in dir_bars:
                # %s: direction codes are strings, '%d' would raise TypeError
                raise RuntimeError('Unknown direction %s' % direction)
            for side in dir_bars[direction]:
                bar.move(side, 100, overlay=True)
            trigger.signal(getattr(cfg.tdef, dir_trigger_prefix[direction] + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            trigger.signal(getattr(cfg.tdef, dir_trigger_prefix[direction] + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            t_dir = jittered('DIR', 'DIR_RANDOMIZE')
            t_dir_ready = jittered('READY', 'READY_RANDOMIZE')

        # protocol: grow the bar proportionally to the elapsed 'go' time
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for side in dir_bars.get(direction, ()):
                bar.move(side, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == esc_key or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0
import neurodecode import numpy as np import neurodecode.utils.q_common as qc from neurodecode.triggers.trigger_def import trigger_def tdef = trigger_def('triggerdef_16.ini') DATA_DIRS = [r'D:\data\MI\rx1\train'] CHANNEL_PICKS = [5, 6, 7, 11] '''""""""""""""""""""""""""""" Epochs and events of interest """""""""""""""""""""""""""''' TRIGGERS = {tdef.LEFT_GO, tdef.RIGHT_GO} EPOCH = [-2.0, 4.0] EVENT_FILE = None '''""""""""""""""""""""""""""" Baseline relative to onset while plotting None in index 0: beginning of data None in index 1: end of data """""""""""""""""""""""""""''' BS_TIMES = (None, 0) '''""""""""""""""""""""""""""" PSD """""""""""""""""""""""""""''' FREQ_RANGE = np.arange(1, 40, 1) '''""""""""""""""""""""""""""" Unit conversion
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Offline protocol: record for cfg.GLOBAL_TIME seconds.

    Steps:
    1. Wait for the GUI to start the protocol.
    2. Show 'Close your eyes and relax', play cfg.START_VOICE and wait for a
       key press.
    3. Send INIT, record until the time is up, ESC is pressed or the state
       leaves 1.
    4. Send END, play cfg.END_VOICE and close the window.

    Parameters
    ----------
    cfg : config object
        Reads TRIGGER_FILE, TRIGGER_DEVICE, START_VOICE, END_VOICE, GLASS_USE,
        SCREEN_POS, SCREEN_SIZE and GLOBAL_TIME. ``cfg.tdef`` is set as a
        side effect.
    state : multiprocessing.Value
        0: stop, 1: start, 2: wait. Set to 0 when ESC is pressed during recording.
    queue : optional
        Forwarded to redirect_stdout_to_queue().

    Raises
    ------
    RuntimeError
        If the trigger device cannot be initialised.
    """
    esc_key = 27  # cv2 key code of ESC

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        pass

    # Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    # Init trigger communication
    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error('\n** Error connecting to trigger device.')
        raise RuntimeError('Cannot connect to trigger device %s' % cfg.TRIGGER_DEVICE)

    # Preload the starting voice
    pgmixer.init()
    pgmixer.music.load(cfg.START_VOICE)

    # Init feedback
    viz = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    viz.fill()
    viz.put_text('Close your eyes and relax')
    viz.update()

    # Play the start voice
    pgmixer.music.play()

    # Wait a key press
    key = 0xFF & cv2.waitKey(0)
    if key == esc_key or not state.value:
        sys.exit(-1)

    viz.fill()
    viz.put_text('Recording in progress')
    viz.update()

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    trigger.signal(cfg.tdef.INIT)

    # Start the clock only now: the time spent waiting for the key press
    # must not be subtracted from the recording duration.
    global_timer = qc.Timer(autoreset=False)
    while state.value == 1 and global_timer.sec() < cfg.GLOBAL_TIME:
        # Mask like the start-key check so modifier bits don't hide ESC
        key = 0xFF & cv2.waitKey(1)
        if key == esc_key:
            with state.get_lock():
                state.value = 0

    trigger.signal(cfg.tdef.END)

    # Remove the text
    viz.fill()
    viz.put_text('Recording is finished')
    viz.update()

    # Ending voice
    pgmixer.music.load(cfg.END_VOICE)
    pgmixer.music.play()
    time.sleep(5)

    # Close cv2 window
    viz.finish()
def _export_online_results(decoder_file, dir_true, dir_detected):
    """
    Write the online results next to the decoder file and print a summary.

    The file is named ``online-<timestamp>.txt``. It holds one
    'ground-truth,prediction' line per classified attempt, followed by the
    accuracy and the confusion matrix computed by qc.confusion_matrix.
    The accuracy and confusion matrix are also printed.
    """
    fdir, _, _ = qc.parse_path_list(decoder_file)
    logfile = os.path.join(fdir, time.strftime("online-%Y%m%d-%H%M%S.txt", time.localtime()))
    cfmat, acc = qc.confusion_matrix(dir_true, dir_detected)
    with open(logfile, 'w') as fout:
        fout.write('Ground-truth,Prediction\n')
        fout.writelines('%s,%s\n' % (gt, dt) for gt, dt in zip(dir_true, dir_detected))
        fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
        fout.write(cfmat)
    logger.info('Log exported to %s' % logfile)
    print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
    print(cfmat)


def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Online protocol: classify trials in real time and show the feedback.

    Connects a BCIDecoderDaemon, either to an LSL amplifier or in fake mode
    when cfg.FAKE_CLS is set. It then runs cfg.TRIALS_EACH trials per
    decoder class, showing the feedback with the visual selected by
    cfg.FEEDBACK_TYPE. Ground truth and prediction of every classified
    attempt are exported next to the decoder file.

    Parameters
    ----------
    cfg : config object
        Protocol configuration (DECODER_FILE, DIRECTIONS, TRIALS_EACH,
        FEEDBACK_TYPE, TRIGGER_FILE, TRIGGER_DEVICE, ...).
    state : multiprocessing.Value
        0: stop, 1: start, 2: wait. Set to 0 when the protocol ends.
    queue : optional
        Forwarded to redirect_stdout_to_queue().

    Raises
    ------
    ValueError
        If cfg.FEEDBACK_TYPE is not 'BAR', 'COLORS' or 'BODY'.
    AttributeError
        If FEEDBACK_TYPE is 'BODY' and FEEDBACK_IMAGE_PATH is missing.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        pass

    # Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    # Validate before starting the decoder process; an unknown type used to
    # leave `visual` unbound and fail later with a NameError.
    if cfg.FEEDBACK_TYPE not in ('BAR', 'COLORS', 'BODY'):
        raise ValueError('Unknown FEEDBACK_TYPE %s' % cfg.FEEDBACK_TYPE)

    if cfg.FAKE_CLS is None:
        # choose amp
        if cfg.AMP_NAME is None and cfg.AMP_SERIAL is None:
            amp_name, amp_serial = pu.search_lsl(state, ignore_markers=True)
        else:
            amp_name = cfg.AMP_NAME
            amp_serial = cfg.AMP_SERIAL
        fake_dirs = None
    else:
        amp_name = None
        amp_serial = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = trigger_def(cfg.TRIGGER_FILE)
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            'Cannot connect to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # init classification
    decoder = BCIDecoderDaemon(
        cfg.DECODER_FILE, buffer_size=1.0,
        fake=(cfg.FAKE_CLS is not None),
        amp_name=amp_name, amp_serial=amp_serial, fake_dirs=fake_dirs,
        parallel=cfg.PARALLEL_DECODING[cfg.PARALLEL_DECODING['selected']],
        alpha_new=cfg.PROB_ALPHA_NEW)

    # Decoder labels are either values already used as labels in
    # cfg.DIRECTIONS, or trigger values translated to event names via tdef.
    dirdata = {d[1] for d in cfg.DIRECTIONS}
    labels = [x if x in dirdata else tdef.by_value[x] for x in decoder.get_labels()]

    # map class labels to bar directions
    bar_def = {label: str(direction) for direction, label in cfg.DIRECTIONS}
    bar_dirs = [bar_def[l] for l in labels]
    if cfg.TRIALS_RANDOMIZE:
        dir_seq = bar_dirs * cfg.TRIALS_EACH
        random.shuffle(dir_seq)
    else:
        dir_seq = [d[0] for d in cfg.DIRECTIONS] * cfg.TRIALS_EACH
    num_trials = len(dir_seq)

    logger.info('Initializing decoder.')
    while decoder.is_running() == 0:
        time.sleep(0.01)

    # feedback visual object
    if cfg.FEEDBACK_TYPE == 'BAR':
        from neurodecode.protocols.viz_bars import BarVisual
        visual = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS,
                           screen_size=cfg.SCREEN_SIZE)
    elif cfg.FEEDBACK_TYPE == 'COLORS':
        from neurodecode.protocols.viz_colors import ColorVisual
        visual = ColorVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS,
                             screen_size=cfg.SCREEN_SIZE)
    else:  # 'BODY' (validated above)
        if not hasattr(cfg, 'FEEDBACK_IMAGE_PATH'):
            raise AttributeError('FEEDBACK_IMAGE_PATH is undefined in your config.')
        from neurodecode.protocols.viz_human import BodyVisual
        visual = BodyVisual(cfg.FEEDBACK_IMAGE_PATH, use_glass=cfg.GLASS_USE,
                            screen_pos=cfg.SCREEN_POS, screen_size=cfg.SCREEN_SIZE)
    visual.put_text('Waiting to start')

    if cfg.LOG_PROBS:
        logdir = qc.parse_path_list(cfg.DECODER_FILE)[0]
        # os.path.join adds the separator that plain concatenation could miss
        probs_logfile = os.path.join(
            logdir, time.strftime("probs-%Y%m%d-%H%M%S.txt", time.localtime()))
    else:
        probs_logfile = None
    feedback = Feedback(cfg, state, visual, tdef, trigger, probs_logfile)

    # Rex direction for each bar direction
    rex_dirs = {'U': 'N', 'L': 'W', 'R': 'E', 'D': 'S'}

    # start
    trial = 1
    # Parallel per-attempt lists: with TRIALS_RETRY or an early stop, the
    # number of attempts differs from len(dir_seq), so dir_seq cannot be
    # paired with the predictions directly.
    dir_true = []
    dir_detected = []
    prob_history = {c: [] for c in bar_dirs}
    while trial <= num_trials:
        title_text = 'Trial %d / %d' % (trial, num_trials) if cfg.SHOW_TRIALS else 'Ready'
        true_label = dir_seq[trial - 1]

        pred_label = feedback.classify(decoder, true_label, title_text, bar_dirs,
                                       prob_history=prob_history)
        if pred_label is None:
            break
        dir_true.append(true_label)
        dir_detected.append(pred_label)

        if cfg.WITH_REX is True and pred_label == true_label:
            rex_dir = rex_dirs.get(pred_label)
            if rex_dir is None:
                logger.warning('Rex cannot execute undefined action %s' % pred_label)
            else:
                visual.move(pred_label, 100, overlay=False, barcolor='B')
                visual.update()
                logger.info('Executing Rex action %s' % rex_dir)
                # NOTE(review): `pycnbi` is referenced here but this code lives in
                # neurodecode; confirm the name is importable in this module.
                os.system('%s/Rex/RexControlSimple.exe %s %s'
                          % (pycnbi.ROOT, cfg.REX_COMPORT, rex_dir))
                time.sleep(8)

        msg = 'Correct' if true_label == pred_label else 'Wrong'
        if cfg.TRIALS_RETRY is False or true_label == pred_label:
            logger.info('Trial %d: %s (%s -> %s)' % (trial, msg, true_label, pred_label))
            trial += 1

    if dir_detected:
        # write performance and log results
        _export_online_results(cfg.DECODER_FILE, dir_true, dir_detected)

    visual.finish()

    with state.get_lock():
        state.value = 0

    if decoder:
        decoder.stop()

    logger.info('Finished.')
def run(cfg, amp_name, amp_serial, state=mp.Value('i', 1), experiment_mode=True,
        baseline=False):
    """
    Online protocol for Alpha/Theta neurofeedback.

    Connects to the amplifier's LSL stream and loops for cfg['global_time']
    seconds. Each acquired window is spatially filtered.

    Unless ``baseline`` is True, each window is also turned into a PSD
    feature, used to mix the two feedback sounds:
    - THETA: the theta-band PSD;
    - ALPHA_THETA: the alpha/theta PSD ratio.
    In baseline mode data is acquired but no feedback is computed or played.

    Parameters
    ----------
    cfg : dict
        Protocol configuration (window_size, buffer_size, *_band_freq,
        sampling_frequency, n_jobs, music paths, feature_type, global_time, ...).
    amp_name, amp_serial :
        Forwarded to protocol_utils.connect_lsl_stream().
    state : multiprocessing.Value
        Only read in experiment mode: a falsy value after the start key press
        makes the process exit with code -1.
    experiment_mode : bool
        If True, send INIT/END triggers, show the instructions window and play
        the start/end voices.
    baseline : bool
        If True, no feature is computed and no feedback sound is played.

    Raises
    ------
    ValueError
        If not in baseline mode and cfg['feature_type'] is neither
        FeatureType.THETA nor FeatureType.ALPHA_THETA.
    RuntimeError
        If the trigger device cannot be initialised (experiment mode).
    """
    # Fail fast: an unsupported type used to leave `feature` undefined inside
    # the acquisition loop.
    if not baseline and cfg['feature_type'] not in (FeatureType.THETA, FeatureType.ALPHA_THETA):
        raise ValueError('Unsupported feature type: %s' % cfg['feature_type'])

    #----------------------------------------------------------------------
    # LSL stream connection
    #----------------------------------------------------------------------
    sr = protocol_utils.connect_lsl_stream(amp_name=amp_name, amp_serial=amp_serial,
                                           window_size=cfg['window_size'],
                                           buffer_size=cfg['buffer_size'])
    sfreq = sr.get_sample_rate()

    #----------------------------------------------------------------------
    # PSD estimators initialization
    #----------------------------------------------------------------------
    psde_alpha = protocol_utils.init_psde(*cfg['alpha_band_freq'].values(),
                                          sampling_frequency=cfg['sampling_frequency'],
                                          n_jobs=cfg['n_jobs'])
    psde_theta = protocol_utils.init_psde(*cfg['theta_band_freq'].values(),
                                          sampling_frequency=cfg['sampling_frequency'],
                                          n_jobs=cfg['n_jobs'])

    #----------------------------------------------------------------------
    # Initialize the feedback sounds
    #----------------------------------------------------------------------
    sound_1, sound_2 = protocol_utils.init_feedback_sounds(cfg['music_state_1_path'],
                                                           cfg['music_state_2_path'])

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    internal_timer = qc.Timer(autoreset=True)
    pgmixer.init()

    if experiment_mode:
        # Init trigger communication
        trigger_signals = trigger_def(cfg['trigger_file'])
        trigger = pyLptControl.Trigger(state, cfg['trigger_device'])
        if trigger.init(50) == False:
            logger.error('\n** Error connecting to trigger device.')
            raise RuntimeError('Cannot connect to trigger device %s' % cfg['trigger_device'])

        # Preload the starting voice
        print(cfg['start_voice_file'])
        pgmixer.music.load(cfg['start_voice_file'])

        # Init feedback
        viz = BarVisual(False, screen_pos=cfg['screen_pos'], screen_size=cfg['screen_size'])
        viz.fill()
        viz.put_text('Close your eyes and relax')
        viz.update()

        pgmixer.music.play()

        # Wait a key press
        key = 0xFF & cv2.waitKey(0)
        if key == KEYS['esc'] or not state.value:
            sys.exit(-1)

        print('recording started')
        trigger.signal(trigger_signals.INIT)

    if not baseline:
        sound_1.play(loops=-1)
        sound_2.play(loops=-1)

    last_ts = None
    last_ratio = None
    measured_psd_ratios = np.full(cfg['window_size_psd_max'], np.nan)

    # Start the session clock only now, so the time spent waiting for the
    # start key press does not shorten the session.
    global_timer = qc.Timer(autoreset=False)
    while global_timer.sec() < cfg['global_time']:
        #------------------------------------------------------------------
        # Data acquisition
        #------------------------------------------------------------------
        sr.acquire()
        window, tslist = sr.get_window()  # window = [samples x channels]
        window = window.T  # window = [channels x samples]

        # Check if proper real-time acquisition: skip if no new sample arrived
        if last_ts is not None and not np.any(np.array(tslist) > last_ts):
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue

        # Spatial filtering
        window = pu.preprocess(window, sfreq=sfreq, spatial=cfg.get('spatial_filter'),
                               spatial_ch=cfg.get('spatial_channels'))

        #------------------------------------------------------------------
        # Computing the Power Spectrum Densities using multitapers
        #------------------------------------------------------------------
        if not baseline:
            if cfg['feature_type'] == FeatureType.THETA:
                feature = protocol_utils.compute_psd(window, psde_theta)
            else:  # FeatureType.ALPHA_THETA (validated above)
                psd_alpha = protocol_utils.compute_psd(window, psde_alpha)
                psd_theta = protocol_utils.compute_psd(window, psde_theta)
                feature = psd_alpha / psd_theta

            # Normalise by the maximum of the recent (non-NaN) feature values
            measured_psd_ratios = add_to_queue(measured_psd_ratios, feature)
            current_music_ratio = feature / np.max(
                measured_psd_ratios[~np.isnan(measured_psd_ratios)])

            # Smooth: move 25% of the way from the previous ratio to the new one
            if last_ratio is not None:
                applied_music_ratio = last_ratio + (current_music_ratio - last_ratio) * 0.25
            else:
                applied_music_ratio = current_music_ratio

            mix_sounds(style=cfg['music_mix_style'], sounds=(sound_1, sound_2),
                       feature_value=applied_music_ratio)

            print((f"{cfg['feature_type']}: {feature:0.3f}"
                   f"\t, current_music_ratio: {current_music_ratio:0.3f}"
                   f"\t, applied music ratio: {applied_music_ratio:0.3f}"))

            last_ratio = applied_music_ratio

        last_ts = tslist[-1]
        internal_timer.sleep_atleast(cfg['timer_sleep'])

    if not baseline:
        sound_1.fadeout(3)
        sound_2.fadeout(3)

    if experiment_mode:
        trigger.signal(trigger_signals.END)

        # Remove the text
        viz.fill()
        viz.put_text('Recording is finished')
        viz.update()

        # Ending voice
        pgmixer.music.load(cfg['end_voice_file'])
        pgmixer.music.play()
        time.sleep(5)

        # Close cv2 window
        viz.finish()

    print('done')