def _check_cfg_selected(cfg, optional_vars, select):
    """
    Validate a dict-style config attribute that carries a 'selected' key.

    The attribute named `select` is read from cfg. The value stored under its
    'selected' key must itself be a key of that dict. Every default listed in
    optional_vars[<selected>] that is missing from the selected sub-dict is
    copied into it, a warning is logged and the attribute is written back.

    Parameters
    ----------
    cfg : python.module
        The config module containing the parameters to check
    optional_vars : dict
        Default sub-parameters, keyed by the possible 'selected' values
    select : str
        Name of the cfg attribute (type=dict) containing the key 'selected'

    Raises
    ------
    RuntimeError
        If the selected entry is not a key of the attribute's dict.
    """
    param = getattr(cfg, select)
    choice = param['selected']

    # Guard clause: the chosen entry must exist before defaults are merged.
    if choice not in param:
        logger.error('%s not defined in config.' % choice)
        raise RuntimeError

    for name, default in optional_vars[choice].items():
        if name in param[choice]:
            continue
        param[choice].update({name: default})
        setattr(cfg, select, param)
        logger.warning(
            'Updating internal parameter for classifier %s: %s=%s' %
            (choice, name, default))
def event_timestamps_to_indices(sigfile, eventfile):
    """
    Convert LSL timestamps to sample indices for separately recorded events.

    Parameters
    ----------
    sigfile : str
        Raw signal file (Python pickle) recorded with stream_recorder.py.
        The loaded object must expose a 'timestamps' array.
    eventfile : str
        Tab-separated event file indexed with LSL timestamps: column 0 is the
        timestamp and column 2 the integer event value.

    Returns
    -------
    list
        [sample_index, 0, event_value] entries, which can be used as an input
        to mne.io.RawArray.add_events(). Events later than the last signal
        timestamp are dropped with a warning.
    """
    raw = qc.load_obj(sigfile)
    ts = raw['timestamps'].reshape(-1)
    # Array methods avoid iterating the whole array at Python level.
    ts_min = ts.min()
    ts_max = ts.max()
    events = []

    with open(eventfile) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank (e.g. trailing) lines instead of crashing.
                continue
            data = line.split('\t')
            event_ts = float(data[0])
            event_value = int(data[2])
            # find the first index not smaller than event_ts
            next_index = np.searchsorted(ts, event_ts)
            if next_index >= len(ts):
                logger.warning(
                    'Event %d at time %.3f is out of time range (%.3f - %.3f).'
                    % (event_value, event_ts, ts_min, ts_max))
            else:
                events.append([next_index, 0, event_value])
    return events
def load_config(cfg_module):
    """
    Import and return a config module.

    Parameters
    ----------
    cfg_module : str
        Dotted module name. A deprecated path-style argument containing '/'
        (e.g. 'pkg/config.py') is converted to the dotted form with a warning.

    Returns
    -------
    module
        The module returned by importlib.import_module().
    """
    if '/' in cfg_module:
        cfg_module = cfg_module.replace('/', '.')
        # Strip only a trailing '.py'. A blanket replace would also corrupt
        # names such as 'pkg.pyconfig', where '.py' appears mid-name.
        if cfg_module.endswith('.py'):
            cfg_module = cfg_module[:-len('.py')]
        logger.warning('Replacing deprecated config path to new style: %s' %
                       cfg_module)
        logger.warning('Please change your argument.')
    return importlib.import_module(cfg_module)
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Online protocol for Alpha/Theta neurofeedback.

    Parameters
    ----------
    cfg : python.module
        The loaded config module; GLOBAL_TIME and TIMER_SLEEP are read here.
    state : multiprocessing.Value
        Shared run state: 0 = stop, 1 = start, 2 = wait.
    queue : object
        Passed to redirect_stdout_to_queue() together with the logger.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        pass

    # Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    #----------------------------------------------------------------------
    # LSL stream connection
    #----------------------------------------------------------------------
    # choose amp
    amp_name, amp_serial = find_lsl_stream(cfg, state)

    # Connect to lsl stream
    sr = connect_lsl_stream(cfg, amp_name, amp_serial)

    # Get sampling rate (kept available for the user code section below)
    sfreq = sr.get_sample_rate()

    # Get trigger channel (kept available for the user code section below)
    trg_ch = sr.get_trigger_channel()

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    global_timer = qc.Timer(autoreset=False)
    internal_timer = qc.Timer(autoreset=True)

    # Timestamp of the newest sample already handled. It must exist before
    # the first iteration, otherwise the comparison below raises
    # UnboundLocalError.
    last_ts = 0

    while state.value == 1 and global_timer.sec() < cfg.GLOBAL_TIME:

        #------------------------------------------------------------------
        # Data acquisition
        #------------------------------------------------------------------
        sr.acquire()
        window, tslist = sr.get_window()  # window = [samples x channels]
        window = window.T  # window = [channels x samples]

        # Check if proper real-time acquisition
        tsnew = np.where(np.array(tslist) > last_ts)[0]
        if len(tsnew) == 0:
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue

        #------------------------------------------------------------------
        # ADD YOUR CODE HERE
        #------------------------------------------------------------------

        last_ts = tslist[-1]
        internal_timer.sleep_atleast(cfg.TIMER_SLEEP)
def check_config(cfg):
    """
    Ensure that the config file contains the parameters.

    Required parameters that are missing are logged and abort with
    RuntimeError. Missing optional parameters are set on cfg to the defaults
    listed below, each with a warning.
    """
    critical_vars = {
        'COMMON': [
            'TRIGGER_DEVICE',
            'TRIGGER_FILE',
            'SCREEN_SIZE',
            'START_VOICE',
            'END_VOICE',
        ],
    }
    optional_vars = {
        'GLOBAL_TIME': 2 * 60,
        'SCREEN_POS': (0, 0),
        'GLASS_USE': False,
    }

    for name in critical_vars['COMMON']:
        if hasattr(cfg, name):
            continue
        logger.error('%s is a required parameter' % name)
        raise RuntimeError

    for name, default in optional_vars.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, default)
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))
def check_cfg_optional(cfg, optional_vars, key_var):
    """
    Assign default values to the optional parameters missing from cfg.

    cfg = config module containing the parameters to check
    optional_vars = optional parameters with predefined values
    key_var = key to look at in optional_vars
    """
    defaults = optional_vars[key_var]
    for name in defaults:
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, defaults[name])
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))
def stop(self):
    """
    Stop the daemon.

    Clears every shared running flag, waits up to 10 seconds for each worker
    process to exit, then resets the internal state and logs the stop message.
    Does nothing except log a warning when the decoder is already stopped.
    """
    if self.is_running() == 0:
        logger.warning('Decoder already stopped.')
        return

    for running in self.running:
        running.value = 0

    for proc in self.procs:
        proc.join(10)
        if proc.is_alive():
            # NOTE(review): assumes self.procs holds multiprocessing.Process
            # objects, whose `pid` is an attribute; calling it raised
            # TypeError on this warning path.
            logger.warning('Process %s did not die properly.' % proc.pid)

    self.reset()
    logger.info(self.stopmsg)
def sort_by_value(s, rev=False):
    """
    Sort dictionary or list by value and return a sorted list of keys and values.

    A list is treated as a mapping from index to item.
    Values must be hashable and unique: when values repeat, only one key per
    value survives, and a warning is logged if Q_VERBOSE > 0.
    """
    assert type(s) in (dict, list), 'Input must be a dictionary or list.'
    if type(s) == list:
        s = dict(enumerate(s))

    # Inverse mapping; later keys overwrite earlier ones on equal values.
    value_to_key = {value: key for key, value in s.items()}
    if Q_VERBOSE > 0 and len(value_to_key) != len(s):
        logger.warning('sort_by_value(): %d identical values' %
                       (len(s) - len(value_to_key) + 1))

    values = sorted(value_to_key, reverse=rev)
    keys = [value_to_key[value] for value in values]
    return keys, values
def paintEvent(self, e):
    """
    Repaint the interface for timer-triggered paint events.

    Paint events with no sender (Qt widget resizing, clicking, etc.) are
    ignored unless the force_repaint flag is set. The flag is cleared
    whenever a repaint is performed.
    """
    sender = self.sender()

    if 'force_repaint' not in vars(self):
        logger.warning('force_repaint is not set! Is it a Qt bug?')
        self.force_repaint = 0

    # Nothing to do for sender-less events unless a repaint was forced.
    if sender is None and not self.force_repaint:
        return

    self.force_repaint = 0
    painter = QPainter()
    painter.begin(self)
    # Update the interface
    self.paintInterface(painter)
    painter.end()
def merge_events(trigger_file, events, rawfile_in, rawfile_out):
    """
    Merge several event types of a raw file into single event types.

    trigger_file: trigger definition file, loaded with trigger_def()
    events: dict mapping an output event name to the list of source event
            names that are relabelled to it
    rawfile_in: raw file to load
    rawfile_out: destination file, overwritten if it exists

    Event counts are logged before and after merging. Merges are applied in
    the iteration order of `events` on the same event array.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(rawfile_in)

    def count(value):
        # Number of events currently carrying this trigger value.
        return len(np.where(eve[:, 2] == value)[0])

    logger.info('=== Before merging ===')
    notfounds = []
    for key in np.unique(eve[:, 2]):
        if key in tdef.by_value:
            logger.info('%s: %d events' % (tdef.by_value[key], count(key)))
            continue
        logger.info('%d: %d events' % (key, count(key)))
        notfounds.append(key)
    for key in notfounds:
        logger.warning('Key %d was not found in the definition file.' % key)

    for target_name, source_names in events.items():
        target_value = tdef.by_name[target_name]
        hits = [
            np.where(eve[:, 2] == tdef.by_name[name])[0]
            for name in source_names
        ]
        eve[np.concatenate(hits), 2] = target_value

    # sanity check: no two consecutive events on the same sample
    dups = np.where(0 == np.diff(eve[:, 0]))[0]
    assert len(dups) == 0

    # reset trigger channel
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    logger.info('=== After merging ===')
    for key in np.unique(eve[:, 2]):
        label = tdef.by_value[key] if key in tdef.by_value else key
        logger.info('%s: %d events' % (label, count(key)))
def sort_by_value(s, reverse=False):
    """
    Return (keys, values) of a dictionary or list, ordered by value.

    A list is treated as a mapping from index to item. Values must be
    hashable; when values repeat, a warning is logged and only the last key
    of each repeated value is kept.
    """
    is_list = type(s) == list
    assert type(s) == dict or is_list, 'Input must be a dictionary or list.'
    mapping = dict(enumerate(s)) if is_list else s

    inverse = {}
    for key, value in mapping.items():
        inverse[value] = key

    if len(inverse) != len(mapping):
        logger.warning(
            'sort_by_value(): %d identical values' %
            (len(mapping.values()) - len(set(mapping.values())) + 1))

    values = sorted(inverse, reverse=reverse)
    keys = [inverse[value] for value in values]
    return keys, values
def check_config(cfg):
    """
    Ensure that the config module contains the required parameters.

    Missing optional parameters are set on cfg to the defaults below, each
    with a warning.

    Raises
    ------
    RuntimeError
        If a required parameter or a required TIMINGS key is missing, or if
        TRIGGER_DEVICE is None.
    """
    critical_vars = {
        'COMMON': [
            'TRIGGER_DEVICE',
            'TRIGGER_FILE',
            'SCREEN_SIZE',
            'DIRECTIONS',
            'DIR_RANDOM',
            'TRIALS_EACH',
        ],
        'TIMINGS': [
            'INIT',
            'GAP',
            'CUE',
            'READY',
            'READY_RANDOMIZE',
            'DIR',
            'DIR_RANDOMIZE',
        ],
    }
    # NOTE: DIR_RANDOM is also required above, so its default is never used.
    optional_vars = {
        'FEEDBACK_TYPE': 'BAR',
        'FEEDBACK_IMAGE_PATH': None,
        'SCREEN_POS': (0, 0),
        'DIR_RANDOM': True,
        'GLASS_USE': False,
        'TRIAL_PAUSE': False,
        'REFRESH_RATE': 30,
    }

    for key in critical_vars['COMMON']:
        if not hasattr(cfg, key):
            raise RuntimeError('%s is a required parameter' % key)

    if not hasattr(cfg, 'TIMINGS'):
        logger.error('"TIMINGS" not defined in config.')
        raise RuntimeError
    for v in critical_vars['TIMINGS']:
        if v not in cfg.TIMINGS:
            logger.error('%s not defined in config.' % v)
            raise RuntimeError

    for key, default in optional_vars.items():
        if not hasattr(cfg, key):
            setattr(cfg, key, default)
            # The original referenced an undefined name `optional` here,
            # which raised NameError for any missing optional parameter.
            logger.warning('Setting undefined %s=%s' % (key, default))

    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'The trigger device is set to None! No events will be saved.')
        raise RuntimeError(
            'The trigger device is set to None! No events will be saved.')
def check_config(cfg):
    '''
    Check the variables contained in the loaded config file

    Mandatory variables are validated, missing optional variables receive the
    defaults listed below, and the sub-parameters of the 'CCC' attribute are
    completed. A warning is logged when TRIGGER_DEVICE is None.

    Parameters
    ----------
    cfg : python.module
        The loaded config module
    '''
    # Critical variables: they must be defined in the config_offline.py
    critical_vars = {
        'COMMON': ['TRIALS_NB', 'TRIGGER_FILE', 'TRIGGER_DEVICE'],
    }

    # Optional variables: they do not need to be defined in the
    # config_offline.py. Undefined ones are added with the values below.
    optional_vars = {
        'COMMON': {
            'REFRESH_RATE': 20,
        },
        # Internal parameters for the CCC
        'XXX': {
            'min': 1,
            'max': 40,
        },
    }

    # Critical variables first, then optional ones, then the internal
    # parameters of CCC.
    _check_cfg_mandatory(cfg, critical_vars, 'COMMON')
    _check_cfg_optional(cfg, optional_vars, 'COMMON')
    _check_cfg_selected(cfg, optional_vars, 'CCC')

    # TRIGGER_DEVICE is mandatory, but a None value only triggers a warning
    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'The trigger device is set to None! No events will be saved.')
def signal(self, value):
    """
    Send a trigger value through the configured trigger type.

    'SOFTWARE' returns the result of write_event(value), 'FAKE' only logs
    and returns True. For any other type the value is written with
    set_data() and the off-timer is started; a new signal is refused
    (returns False) while the previous off-timer is still alive.
    """
    kind = self.lpttype

    if kind == 'SOFTWARE':
        if self.verbose is True:
            logger.info('Sending software trigger %s' % value)
        return self.write_event(value)

    if kind == 'FAKE':
        logger.info('Sending FAKE trigger signal %s' % value)
        return True

    # Hardware trigger: refuse to overlap with the previous signal.
    if self.offtimer.is_alive():
        logger.warning(
            'You are sending a new signal before the end of the last signal. Signal ignored.'
        )
        logger.warning('self.delay=%.1f' % self.delay)
        return False

    self.set_data(value)
    if self.verbose is True:
        logger.info('Sending %s' % value)
    self.offtimer.start()
    return True
def sample_decoding(decoder):
    """
    Decoding example

    Polls the decoder forever and prints, for every new classification, the
    smoothed and raw probability of each label followed by the label with
    the highest smoothed probability. A warning is logged when no
    classification arrived during the last 5 seconds.

    Parameters
    ----------
    decoder : The decoder to use
    """

    def get_index_max(seq):
        # Index (list) or key (dict) of the largest item; None otherwise.
        kind = type(seq)
        if kind == list:
            return max(range(len(seq)), key=seq.__getitem__)
        if kind == dict:
            return max(seq, key=seq.__getitem__)
        logger.error('Unsupported input %s' % kind)
        return None

    # load trigger definitions for labeling
    labels = decoder.get_label_names()
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()

    while True:
        praw = decoder.get_prob_unread()
        psmooth = decoder.get_prob_smooth()

        if praw is None:
            # watch dog
            if tm_cls.sec() > 5:
                logger.warning(
                    'No classification was done in the last 5 seconds. Are you receiving data streams?'
                )
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue

        parts = ['[%8.1f msec]' % (tm_cls.msec())]
        for idx, label in enumerate(labels):
            parts.append(' %s %.3f (raw %.3f)' %
                         (label, psmooth[idx], praw[idx]))
        best = get_index_max(psmooth)
        parts.append(' %s' % labels[best])
        print(''.join(parts))
        tm_cls.reset()
def check_config(cfg):
    """
    Ensure that the config file contains the parameters.

    DATA_PATH is required; a missing one is logged and raises RuntimeError.
    No optional parameter is defined yet. The checked cfg is returned.
    """
    critical_vars = {
        'COMMON': ['DATA_PATH'],
    }
    optional_vars = {}

    missing = [name for name in critical_vars['COMMON']
               if not hasattr(cfg, name)]
    if missing:
        # Only the first missing parameter is reported before aborting.
        logger.error('%s is a required parameter' % missing[0])
        raise RuntimeError

    for name, default in optional_vars.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, default)
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))

    return cfg
def fix_channel_names(fif_dir, new_channel_names):
    '''
    Change channel names of fif files in a given directory.

    Input
    -----
    @fif_dir: path to fif files
    @new_channel_names: list of new channel names

    Output
    ------
    Modified fif files are saved in fif_dir/corrected/

    Raises
    ------
    RuntimeError if a file's channel count differs from the number of new
    channel names.

    Kyuhwa Lee, 2019.
    '''
    flist = [f for f in qc.get_file_list(fif_dir)
             if qc.parse_path(f).ext == 'fif']

    if not flist:
        logger.warning('No fif files found in %s' % fif_dir)
        return

    qc.make_dirs('%s/corrected' % fif_dir)
    # Reuse the filtered list: no second directory scan, and "Loading" is
    # only logged for files that are actually loaded.
    for f in flist:
        logger.info('\nLoading %s' % f)
        p = qc.parse_path(f)
        raw, eve = pu.load_raw(f)
        if len(raw.ch_names) != len(new_channel_names):
            raise RuntimeError(
                'The number of new channels do not match that of fif file.')
        raw.info['ch_names'] = new_channel_names
        for ch, new_ch in zip(raw.info['chs'], new_channel_names):
            ch['ch_name'] = new_ch
        # NOTE(review): the output path is built from the file's own
        # directory (p.dir) while the folder is created under fif_dir;
        # presumably identical -- confirm get_file_list() is not recursive.
        out_fif = '%s/corrected/%s.fif' % (p.dir, p.name)
        logger.info('Exporting to %s' % out_fif)
        raw.save(out_fif)
def check_config(cfg):
    """
    Ensure that the config file contains the parameters.

    DATA_PATH is required; a missing one is logged and raises RuntimeError.
    Missing optional parameters are set on cfg to the defaults below, each
    with a warning.
    """
    required = ('DATA_PATH',)
    defaults = {
        'AMP_NAME': None,
        'AMP_SERIAL': None,
        'GLOBAL_TIME': 1.0 * 60,
        'NJOBS': 1,
    }

    for name in required:
        if hasattr(cfg, name):
            continue
        logger.error('%s is a required parameter' % name)
        raise RuntimeError

    for name, default in defaults.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, default)
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))
def check_config(cfg):
    """
    Check if the required parameters are defined in the config module

    Parameters
    ----------
    cfg : python.module
        The loaded config module

    Returns
    -------
    python.module
        The checked config module
    """
    # Critical variables: they must be defined in the config module
    critical_vars = {
        'COMMON': ['AAA', 'BBB', 'CCC'],
    }

    # Optional variables
    # NOTE(review): 'COMMON' is a set of names without default values here,
    # while sibling templates use a {name: default} dict -- confirm what
    # _check_cfg_optional() expects before filling in this template.
    optional_vars = {
        'COMMON': {'DDD', 'EEE'},
        # Internal parameters for the AAA
        'XXX': {
            'min': 1,
            'max': 40,
        },
    }

    # Check the critical variables
    _check_cfg_mandatory(cfg, critical_vars, 'COMMON')

    # Check the optional variables
    _check_cfg_optional(cfg, optional_vars, 'COMMON')

    # Check the internal param of AAA
    _check_cfg_selected(cfg, optional_vars, 'AAA')

    # TRIGGER_DEVICE is not among the critical variables above, so it may be
    # absent: treat a missing attribute like None instead of raising
    # AttributeError.
    if getattr(cfg, 'TRIGGER_DEVICE', None) is None:
        logger.warning(
            'The trigger device is set to None! No events will be saved.')

    return cfg
def slice_win(epochs_data, w_starts, w_length, psde, picks=None, title=None,
              flatten=True, preprocess=None, verbose=False):
    '''
    Compute PSD values of a sliding window

    Params
        epochs_data ([channels]x[samples]): raw epoch data
        w_starts (list): starting indices of sample segments
        w_length (int): window length in number of samples
        psde: MNE PSDEstimator object
        picks (list): indices applied to the flattened PSD vector of each
            window
            NOTE(review): documented elsewhere as a subset of channels, but
            the indexing happens after channels and freqs are flattened --
            confirm with callers.
        title (string): print out the title associated with PID
        flatten (boolean): currently unused; the output is always
            concatenated feature vectors, X = [windows] x [channels x freqs]
        preprocess (dict): None or parameters for pycnbi_utils.preprocess()
            with the following keys:
            sfreq, spatial, spatial_ch, spectral, spectral_ch, notch,
            notch_ch, multiplier, ch_names, rereference, decim, n_jobs
        verbose (boolean): log the progress of every window

    Returns:
        [windows] x [channels*freqs], or None when w_starts is empty

    Raises:
        WrongIndexError (an IndexError subclass) if a window starts at or
        beyond the end of the epoch.
    '''

    # A real exception class: the original declared this with `def`, so
    # `raise WrongIndexError` itself failed with TypeError.
    class WrongIndexError(IndexError):
        """A window start index lies outside the epoch."""

    preprocess_keys = ('sfreq', 'spatial', 'spatial_ch', 'spectral',
                       'spectral_ch', 'notch', 'notch_ch', 'multiplier',
                       'ch_names', 'rereference', 'decim', 'n_jobs')

    if type(w_length) is not int:
        logger.warning('w_length type is %s. Converting to int.' %
                       type(w_length))
        w_length = int(w_length)

    if title is None:
        title = '[PID %d] Frames %d-%d' % (os.getpid(), w_starts[0],
                                           w_starts[-1] + w_length - 1)
    else:
        title = '[PID %d] %s' % (os.getpid(), title)
    if preprocess is not None and preprocess['decim'] != 1:
        title += ' (decim factor %d)' % preprocess['decim']
    logger.info(title)

    # Collect one row per window and concatenate once at the end instead of
    # re-allocating the whole matrix for every window.
    rows = []
    for n in w_starts:
        n = int(round(n))
        if n >= epochs_data.shape[1]:
            msg = ('w_starts has an out-of-bounds index %d for epoch length %d.'
                   % (n, epochs_data.shape[1]))
            logger.error(msg)
            raise WrongIndexError(msg)

        window = epochs_data[:, n:(n + w_length)]
        if preprocess is not None:
            # Only the documented keys are forwarded, as before.
            window = pu.preprocess(
                window, **{key: preprocess[key] for key in preprocess_keys})

        # dimension: psde.transform( [epochs x channels x times] )
        psd = psde.transform(
            window.reshape((1, window.shape[0], window.shape[1])))
        psd = psd.reshape((psd.shape[0], psd.shape[1] * psd.shape[2]))
        if picks:
            psd = psd[0][picks]
            psd = psd.reshape((1, len(psd)))
        rows.append(psd)

        if verbose:
            logger.info('[PID %d] processing frame %d / %d' %
                        (os.getpid(), n, w_starts[-1]))

    return np.concatenate(rows, axis=0) if rows else None
def test_receiver():
    """
    Interactive check of a StreamReceiver: connect to an LSL server and
    print the incoming data in an endless loop.

    Each iteration prints the LSL clock difference, the new trigger values,
    and either the window average (TIME_INDEX is None) or the sample at
    TIME_INDEX of the monitored channels, plus the mean PSD per channel
    when SHOW_PSD is enabled.
    """
    import mne
    import os

    CH_INDEX = [1]  # channel to monitor
    TIME_INDEX = None  # integer or None. None = average of raw values of the current window
    SHOW_PSD = False
    mne.set_log_level('ERROR')
    os.environ['OMP_NUM_THREADS'] = '1'  # actually improves performance for multitaper

    # connect to LSL server
    amp_name, amp_serial = pu.search_lsl()
    sr = StreamReceiver(window_size=1, buffer_size=1, amp_serial=amp_serial,
                        eeg_only=False, amp_name=amp_name)
    sfreq = sr.get_sample_rate()
    trg_ch = sr.get_trigger_channel()
    logger.info('Trigger channel = %d' % trg_ch)

    # PSD init
    if SHOW_PSD:
        psde = mne.decoding.PSDEstimator(
            sfreq=sfreq, fmin=1, fmax=50, bandwidth=None, adaptive=False,
            low_bias=True, n_jobs=1, normalization='length', verbose=None)

    watchdog = qc.Timer()
    pacer = qc.Timer(autoreset=True)
    last_ts = 0

    while True:
        sr.acquire()
        window, tslist = sr.get_window()  # window = [samples x channels]
        window = window.T  # channel x samples

        qc.print_c('LSL Diff = %.3f' % (pylsl.local_clock() - tslist[-1]), 'G')

        # print event values
        fresh = np.where(np.array(tslist) > last_ts)[0]
        if len(fresh) == 0:
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue
        trigger = np.unique(window[trg_ch, fresh[0]:])

        # for Biosemi
        # if sr.amp_name=='BioSemi':
        #    trigger= set( [255 & int(x-1) for x in trigger ] )

        if len(trigger) > 0:
            logger.info('Triggers: %s' % np.array(trigger))

        logger.info('[%.1f] Receiving data...' % watchdog.sec())

        if TIME_INDEX is not None:
            datatxt = qc.list2string(window[CH_INDEX, TIME_INDEX], '%-15.6f')
            print('[%.3f]' % tslist[TIME_INDEX] + ' data: %s' % datatxt)
        else:
            datatxt = qc.list2string(np.mean(window[CH_INDEX, :], axis=1),
                                     '%-15.6f')
            print('[%.3f : %.3f]' % (tslist[0], tslist[-1]) +
                  ' data: %s' % datatxt)

        # show PSD
        if SHOW_PSD:
            psd = psde.transform(
                window.reshape((1, window.shape[0], window.shape[1])))
            psd = psd.reshape((psd.shape[1], psd.shape[2]))
            for mean_power in np.mean(psd, axis=1):
                print('%.1f' % mean_power, end=' ')

        last_ts = tslist[-1]
        pacer.sleep_atleast(0.05)
def acquire(self, blocking=True):
    """
    Reads data into buffer. It is a blocking function as default.

    Pulls one chunk from the first inlet, moves the trigger channel to
    column 0 (or inserts a zero column when there is none), appends the
    result to the internal buffer and returns it.

    Parameters:
        blocking: if False, return empty values immediately when the first
            pull yields no sample instead of waiting.

    Returns:
        data [samples x channels], timestamps [samples]
        Both are empty when nothing was received: right away in
        non-blocking mode, or once the watchdog reaches 5 seconds.
    """
    # The clock offset is measured only while the timestamp buffer is empty.
    timestamp_offset = False
    if len(self.timestamps[0]) == 0:
        timestamp_offset = True

    self.watchdog.reset()
    tslist = []
    received = False
    chunk = None
    while not received:
        # The else branch of this loop runs only when the watchdog condition
        # fails (timeout), never after the `break` below.
        while self.watchdog.sec() < 5:
            # chunk = [frames]x[ch], tslist = [frames]
            if len(tslist) == 0:
                chunk, tslist = self.inlets[0].pull_chunk(max_samples=self.stream_bufsize)
                if blocking == False and len(tslist) == 0:
                    return np.empty((0, len(self.ch_list))), []
            if len(tslist) > 0:
                if timestamp_offset is True:
                    # Local clock taken right after the chunk arrived; it is
                    # compared with the server timestamp further below.
                    lsl_clock = pylsl.local_clock()
                received = True
                break
            time.sleep(0.0005)
        else:
            logger.warning('Timeout occurred while acquiring data. Amp driver bug?')
            # give up and return empty values to avoid deadlock
            return np.empty((0, len(self.ch_list))), []
    data = np.array(chunk)

    # BioSemi has pull-up resistor instead of pull-down
    if self.amp_name == 'BioSemi' and self._lsl_tr_channel is not None:
        datatype = data.dtype
        # Keep the low 8 bits of the trigger value and subtract 1.
        data[:, self._lsl_tr_channel] = (np.bitwise_and(255, data[:, self._lsl_tr_channel].astype(int)) - 1).astype(datatype)

    # multiply values (to change unit)
    if self.multiplier != 1:
        data[:, self._lsl_eeg_channels] *= self.multiplier

    if self._lsl_tr_channel is not None:
        # move trigger channel to 0 and add back to the buffer
        data = np.concatenate((data[:, self._lsl_tr_channel].reshape(-1, 1),
                               data[:, self._lsl_eeg_channels]), axis=1)
    else:
        # add an empty channel with zeros to channel 0
        data = np.concatenate((np.zeros((data.shape[0],1)),
                               data[:, self._lsl_eeg_channels]), axis=1)

    # add data to buffer
    chunk = data.tolist()
    self.buffers[0].extend(chunk)
    self.timestamps[0].extend(tslist)
    # Keep only the newest `bufsize` samples (a bufsize <= 0 disables this).
    if self.bufsize > 0 and len(self.timestamps[0]) > self.bufsize:
        self.buffers[0] = self.buffers[0][-self.bufsize:]
        self.timestamps[0] = self.timestamps[0][-self.bufsize:]

    if timestamp_offset is True:
        timestamp_offset = False
        logger.info('LSL timestamp = %s' % lsl_clock)
        # NOTE(review): this reads the LAST stream's timestamps ([-1]) while
        # the data above is stored for stream 0 -- the same list when there
        # is a single stream; confirm for multi-stream setups.
        logger.info('Server timestamp = %s' % self.timestamps[-1][-1])
        self.lsl_time_offset = self.timestamps[-1][-1] - lsl_clock
        logger.info('Offset = %.3f ' % (self.lsl_time_offset))
        if abs(self.lsl_time_offset) > 0.1:
            logger.warning('LSL server has a high timestamp offset.')
        else:
            logger.info_green('LSL time server synchronized')

    ''' TODO: test the merging of multiple streams
    # if we have multiple synchronized amps
    if len(self.inlets) > 1:
        for i in range(1, len(self.inlets)):
            chunk, tslist = self.inlets[i].pull_chunk(max_samples=len(tslist))  # [frames][channels]
            self.buffers[i].extend(chunk)
            self.timestamps[i].extend(tslist)
            if self.bufsize > 0 and len(self.buffers[i]) > self.bufsize:
                self.buffers[i] = self.buffers[i][-self.bufsize:]
    '''

    # data= array[samples, channels], tslist=[samples]
    return (data, tslist)
def compute_features(cfg):
    '''
    Compute features using config specification.

    Performs preprocessing, epoching and feature computation.

    Input
    =====
    Config file object

    Output
    ======
    Feature data in dictionary
    - X_data: feature vectors
    - Y_data: feature labels
    - wlen: window length in seconds
    - w_frames: window length in frames
    - psde: MNE PSD estimator object
    - picks: channels used for feature computation
    - sfreq: sampling frequency
    - ch_names: channel names
    - times: feature timestamp (leading edge of a window)

    Raises
    ======
    RuntimeError: invalid channel specification or failure while epoching
    ValueError: no channel left, or a picked index exceeds the channel count
    NotImplementedError: selected feature type other than 'PSD'
    '''
    # Preprocessing, epoching and PSD computation
    # NOTE(review): the original tested f[-4:] against ['.fif', '.fiff'];
    # the 5-character '.fiff' could never match, so only '.fif' files were
    # ever loaded. Confirm loader support before enabling '.fiff'.
    ftrain = [f for f in qc.get_file_list(cfg.DATA_PATH, fullpath=True)
              if f.endswith('.fif')]

    if len(ftrain) > 1 and cfg.PICKED_CHANNELS is not None and type(
            cfg.PICKED_CHANNELS[0]) == int:
        logger.error(
            'When loading multiple EEG files, PICKED_CHANNELS must be list of string, not integers because they may have different channel order.'
        )
        raise RuntimeError

    raw, events = pu.load_multi(ftrain)

    reref = cfg.REREFERENCE[cfg.REREFERENCE['selected']]
    if reref is not None:
        pu.rereference(raw, reref['New'], reref['Old'])

    events_file = cfg.LOAD_EVENTS[cfg.LOAD_EVENTS['selected']]
    if events_file is not None:
        events = mne.read_events(events_file)

    trigger_def_int = {getattr(cfg.tdef, name) for name in cfg.TRIGGER_DEF}
    triggers = {cfg.tdef.by_value[c]: c for c in trigger_def_int}

    # Pick channels
    if cfg.PICKED_CHANNELS is None:
        chlist = [int(x) for x in pick_types(raw.info, stim=False, eeg=True)]
    else:
        chlist = cfg.PICKED_CHANNELS
    picks = []
    for c in chlist:
        if type(c) == int:
            picks.append(c)
        elif type(c) == str:
            picks.append(raw.ch_names.index(c))
        else:
            logger.error(
                'PICKED_CHANNELS has a value of unknown type %s.\nPICKED_CHANNELS=%s'
                % (type(c), cfg.PICKED_CHANNELS))
            raise RuntimeError

    if cfg.EXCLUDED_CHANNELS is not None:
        for c in cfg.EXCLUDED_CHANNELS:
            if type(c) == str:
                if c not in raw.ch_names:
                    logger.warning(
                        'Exclusion channel %s does not exist. Ignored.' % c)
                    continue
                c_int = raw.ch_names.index(c)
            elif type(c) == int:
                c_int = c
            else:
                logger.error(
                    'EXCLUDED_CHANNELS has a value of unknown type %s.\nEXCLUDED_CHANNELS=%s'
                    % (type(c), cfg.EXCLUDED_CHANNELS))
                raise RuntimeError
            if c_int in picks:
                picks.remove(c_int)

    if not picks:
        logger.error('No channel left after applying the channel selection.')
        raise ValueError('No channel left after applying the channel selection.')
    # Valid indices are 0 .. len(ch_names) - 1, hence >= (the original used
    # > and let an index equal to the channel count through).
    if max(picks) >= len(raw.ch_names):
        logger.error(
            '"picks" has a channel index %d while there are only %d channels.'
            % (max(picks), len(raw.ch_names)))
        raise ValueError

    if hasattr(cfg, 'SP_CHANNELS') and cfg.SP_CHANNELS is not None:
        logger.warning(
            'SP_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if hasattr(cfg, 'TP_CHANNELS') and cfg.TP_CHANNELS is not None:
        logger.warning(
            'TP_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if hasattr(cfg, 'NOTCH_CHANNELS') and cfg.NOTCH_CHANNELS is not None:
        logger.warning(
            'NOTCH_CHANNELS parameter is not supported yet. Will be set to PICKED_CHANNELS.'
        )
    if 'decim' not in cfg.FEATURES['PSD']:
        cfg.FEATURES['PSD']['decim'] = 1
        logger.warning('PSD["decim"] undefined. Set to 1.')

    # Read epochs
    epoch_kwargs = dict(proj=False, picks=picks, baseline=None, preload=True,
                        verbose=False, detrend=None)
    try:
        if type(cfg.EPOCH[0]) is list:
            # Experimental: multiple epoch ranges
            epochs_train = [
                Epochs(raw, events, triggers, tmin=ep[0], tmax=ep[1],
                       **epoch_kwargs)
                for ep in cfg.EPOCH
            ]
        else:
            # Usual method: single epoch range
            epochs_train = Epochs(raw, events, triggers, tmin=cfg.EPOCH[0],
                                  tmax=cfg.EPOCH[1], on_missing='warning',
                                  **epoch_kwargs)
    except Exception as err:
        # Not a bare except: KeyboardInterrupt / SystemExit must propagate.
        logger.exception('Problem while epoching.')
        raise RuntimeError('Problem while epoching.') from err

    # Compute features
    # The selected feature name is stored under FEATURES['selected']; the
    # original compared the whole FEATURES dict with 'TIMELAG' / 'WAVELET',
    # which could never match.
    selected = cfg.FEATURES['selected']
    if selected == 'PSD':
        # With multiple epoch ranges epochs_train is a list, which has no
        # `.info`; every entry is built from the same raw.
        first_epochs = epochs_train[0] if isinstance(epochs_train, list) else epochs_train
        preprocess = dict(
            sfreq=first_epochs.info['sfreq'],
            spatial=cfg.SP_FILTER,
            spatial_ch=None,
            spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
            spectral_ch=None,
            notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
            notch_ch=None,
            multiplier=cfg.MULTIPLIER,
            ch_names=None,
            rereference=None,
            decim=cfg.FEATURES['PSD']['decim'],
            n_jobs=cfg.N_JOBS)
        featdata = get_psd_feature(epochs_train, cfg.EPOCH,
                                   cfg.FEATURES['PSD'], picks=None,
                                   preprocess=preprocess, n_jobs=cfg.N_JOBS)
    elif selected == 'TIMELAG':
        # TODO: Implement multiple epochs for timelag feature
        logger.error(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR TIMELAG FEATURE.')
        raise NotImplementedError
    elif selected == 'WAVELET':
        # TODO: Implement multiple epochs for wavelet feature
        logger.error(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR WAVELET FEATURE.')
        raise NotImplementedError
    else:
        logger.error('%s feature type is not supported.' % selected)
        raise NotImplementedError

    featdata['picks'] = picks
    featdata['sfreq'] = raw.info['sfreq']
    featdata['ch_names'] = raw.ch_names

    return featdata
def confusion_matrix(Y_true, Y_pred, label_len=6):
    """
    Generate confusion matrix in a string format

    Parameters
    ----------
    Y_true : list
        The true labels
    Y_pred : list
        The test labels
    label_len : int
        The maximum label text length displayed (minimum length: 6)

    Returns
    -------
    cfmat : str
        The confusion matrix in str format, each row normalized by its sum
        (X-axis: prediction, Y-axis: ground truth)
    acc : float
        The accuracy

    Raises
    ------
    RuntimeError
        If Y_pred has more items than Y_true. A longer Y_true is truncated
        to the length of Y_pred instead.
    """
    import numpy as np
    from sklearn.metrics import confusion_matrix as sk_confusion_matrix

    # find labels
    if type(Y_true) == np.ndarray:
        Y_labels = np.unique(Y_true)
    else:
        Y_labels = list(set(Y_true))

    # Check the provided label name length
    if label_len < 6:
        label_len = 6
        logger.warning('label_len < 6. Setting to 6.')
    label_tpl = '%' + '-%ds' % label_len
    col_tpl = '%' + '-%d.2f' % label_len

    # sanity check
    if len(Y_pred) > len(Y_true):
        raise RuntimeError('Y_pred has more items than Y_true')
    elif len(Y_pred) < len(Y_true):
        Y_true = Y_true[:len(Y_pred)]

    # `labels` is keyword-only in current scikit-learn releases.
    cm = sk_confusion_matrix(Y_true, Y_pred, labels=Y_labels)

    # Row-wise rates; rows without any sample stay at zero.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_rate = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                        where=row_sums > 0)

    # Fill confusion string ('gt\\dt' is the same text as before, written
    # without the invalid '\d' escape sequence)
    lines = [label_tpl % 'gt\\dt' +
             ''.join(label_tpl % str(l)[:label_len] for l in Y_labels)]
    for l, r in zip(Y_labels, cm_rate):
        lines.append(label_tpl % str(l)[:label_len] +
                     ''.join(col_tpl % c for c in r))
    cm_txt = '\n'.join(lines) + '\n'

    # compute accuracy
    total = cm.sum()
    if total > 0:
        acc = float(np.trace(cm)) / total
    else:
        acc = 0.0

    return cm_txt, acc
def check_config(cfg):
    """
    Ensure that the config module contains the required parameters.

    Required parameters, and the required keys of the TIMINGS and BAR_STEP
    dicts, are logged and raise RuntimeError when missing. Missing optional
    parameters are set on cfg to the defaults below, each with a warning.
    A warning is logged when TRIGGER_DEVICE is None.
    """
    critical_vars = {
        'COMMON': [
            'DECODER_FILE',
            'TRIGGER_DEVICE',
            'TRIGGER_FILE',
            'DIRECTIONS',
            'TRIALS_EACH',
            'PROB_ALPHA_NEW',
        ],
        'TIMINGS': ['INIT', 'GAP', 'READY', 'FEEDBACK', 'DIR_CUE', 'CLASSIFY'],
        'BAR_STEP': ['left', 'right', 'up', 'down', 'both'],
    }
    optional_vars = {
        'AMP_NAME': None,
        'FAKE_CLS': None,
        'TRIALS_RANDOMIZE': True,
        'BAR_SLOW_START': {
            'selected': 'False',
            'False': None,
            'True': [1.0],
        },
        'PARALLEL_DECODING': {
            'selected': 'False',
            'False': None,
            'True': {
                'period': 0.06,
                'num_strides': 3,
            },
        },
        'SHOW_TRIALS': True,
        'FREE_STYLE': False,
        'REFRESH_RATE': 30,
        'BAR_BIAS': None,
        'BAR_REACH_FINISH': False,
        'FEEDBACK_TYPE': 'BAR',
        'SHOW_CUE': True,
        'SCREEN_SIZE': (1920, 1080),
        'SCREEN_POS': (0, 0),
        'DEBUG_PROBS': False,
        'LOG_PROBS': False,
        'WITH_REX': False,
        'WITH_STIMO': False,
        'ADAPTIVE': None,
    }

    for name in critical_vars['COMMON']:
        if hasattr(cfg, name):
            continue
        logger.error('%s is a required parameter' % name)
        raise RuntimeError

    # Dict-valued parameters: the attribute and each required key must exist.
    for section in ('TIMINGS', 'BAR_STEP'):
        if not hasattr(cfg, section):
            logger.error('"%s" not defined in config.' % section)
            raise RuntimeError
        container = getattr(cfg, section)
        for entry in critical_vars[section]:
            if entry in container:
                continue
            logger.error('%s not defined in config.' % entry)
            raise RuntimeError

    for name, default in optional_vars.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, default)
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))

    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'The trigger device is set to None! No events will be saved.')
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Run the online feedback protocol: decode, show feedback and log results.

    Every direction of the trial sequence is presented through
    feedback.classify(). At the end of each run the (ground truth, prediction)
    pairs, the accuracy and the confusion matrix are written to an
    "online-<timestamp>.txt" file in the directory of cfg.DECODER_FILE.

    Parameters
    ----------
    cfg : python.module
        The config module containing the protocol parameters
    state : mp.Value
        Shared integer flag (0: stop, 1: start, 2: wait). It is set to 0 once
        the protocol has run to completion. The default object is created only
        once, when the function is defined, so every call relying on the
        default shares the same flag.
    queue :
        Passed to redirect_stdout_to_queue() together with the logger
    """

    def confusion_matrix(Y_true, Y_pred, label_len=6):
        """
        Generate confusion matrix in a string format

        Parameters
        ----------
        Y_true : list
            The true labels
        Y_pred : list
            The test labels
        label_len : int
            The maximum label text length displayed (minimum length: 6)

        Returns
        -------
        cfmat : str
            The confusion matrix in str format (X-axis: prediction, Y-axis: ground truth)
        acc : float
            The accuracy
        """
        import numpy as np
        from sklearn.metrics import confusion_matrix as sk_confusion_matrix

        # find labels
        if isinstance(Y_true, np.ndarray):
            Y_labels = np.unique(Y_true)
        else:
            Y_labels = list(set(Y_true))

        # Check the provided label name length
        if label_len < 6:
            label_len = 6
            logger.warning('label_len < 6. Setting to 6.')
        label_tpl = '%' + '-%ds' % label_len
        col_tpl = '%' + '-%d.2f' % label_len

        # sanity check
        if len(Y_pred) > len(Y_true):
            raise RuntimeError('Y_pred has more items than Y_true')
        elif len(Y_pred) < len(Y_true):
            Y_true = Y_true[:len(Y_pred)]

        # 'labels' must be given by keyword: recent scikit-learn releases do
        # not accept it as a positional argument any more.
        cm = sk_confusion_matrix(Y_true, Y_pred, labels=Y_labels)

        # compute confusion matrix: row-normalise, leaving empty rows at zero
        # so that no NaN is produced
        cm_rate = cm.copy().astype('float')
        cm_sum = np.sum(cm, axis=1)
        for r, s in zip(cm_rate, cm_sum):
            if s > 0:
                r /= s

        # Fill confusion string
        parts = [label_tpl % 'gt\\dt']
        parts.extend(label_tpl % str(l)[:label_len] for l in Y_labels)
        parts.append('\n')
        for l, r in zip(Y_labels, cm_rate):
            parts.append(label_tpl % str(l)[:label_len])
            parts.extend(col_tpl % c for c in r)
            parts.append('\n')
        cm_txt = ''.join(parts)

        # compute accuracy
        correct = 0.0
        for c in range(cm.shape[0]):
            correct += cm[c][c]
        total = cm.sum()
        acc = correct / total if total > 0 else 0.0

        return cm_txt, acc

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)  # yield the CPU instead of spinning

    # Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    if cfg.FAKE_CLS is None:
        # choose amp
        if cfg.AMP_NAME is None:
            amp_name = search_lsl(ignore_markers=True, state=state)
        else:
            amp_name = cfg.AMP_NAME
        fake_dirs = None
    else:
        amp_name = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = TriggerDef(cfg.TRIGGER_FILE)
    trigger = Trigger(cfg.TRIGGER_DEVICE, state)
    if trigger.init(50) == False:
        logger.error(
            'Cannot connect to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger('FAKE', state)
        trigger.init(50)

    # For adaptive (need to share the actual true label across processes)
    label = mp.Value('i', 0)

    # init classification
    decoder = BCIDecoderDaemon(
        amp_name, cfg.DECODER_FILE, buffer_size=1.0,
        fake=(cfg.FAKE_CLS is not None), fake_dirs=fake_dirs,
        parallel=cfg.PARALLEL_DECODING[cfg.PARALLEL_DECODING['selected']],
        alpha_new=cfg.PROB_ALPHA_NEW, label=label)

    # Decoder labels that are not already direction labels are trigger values
    # and are translated through the trigger definition.
    dirdata = {d[1] for d in cfg.DIRECTIONS}
    labels = [x if x in dirdata else tdef.by_value[x]
              for x in decoder.get_labels()]

    # map class labels to bar directions
    bar_def = {lbl: str(direction) for direction, lbl in cfg.DIRECTIONS}
    bar_dirs = [bar_def[l] for l in labels]
    dir_seq = []
    for _ in range(cfg.TRIALS_EACH):
        dir_seq.extend(bar_dirs)

    logger.info('Initializing decoder.')
    while decoder.is_running() == 0:
        time.sleep(0.01)

    # visual feedback object
    if cfg.FEEDBACK_TYPE == 'BAR':
        from neurodecode.protocols.viz_bars import BarVisual
        visual = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS,
                           screen_size=cfg.SCREEN_SIZE)
    elif cfg.FEEDBACK_TYPE == 'BODY':
        assert hasattr(cfg, 'FEEDBACK_IMAGE_PATH'
                       ), 'FEEDBACK_IMAGE_PATH is undefined in your config.'
        from neurodecode.protocols.viz_human import BodyVisual
        visual = BodyVisual(cfg.FEEDBACK_IMAGE_PATH, use_glass=cfg.GLASS_USE,
                            screen_pos=cfg.SCREEN_POS,
                            screen_size=cfg.SCREEN_SIZE)
    else:
        # Fail with an explicit message rather than a NameError on 'visual'.
        logger.error('Unknown FEEDBACK_TYPE %s' % cfg.FEEDBACK_TYPE)
        raise RuntimeError('Unknown FEEDBACK_TYPE %s' % cfg.FEEDBACK_TYPE)
    visual.put_text('Waiting to start')

    if cfg.LOG_PROBS:
        logdir = io.parse_path(cfg.DECODER_FILE).dir
        # NOTE(review): no '/' is inserted here, unlike for the "online-" log
        # below -- confirm whether parse_path().dir ends with a separator.
        # The directory is kept out of the strftime() format so that a '%' in
        # the path cannot be interpreted as a directive.
        probs_logfile = logdir + time.strftime("probs-%Y%m%d-%H%M%S.txt",
                                               time.localtime())
    else:
        probs_logfile = None
    feedback = Feedback(cfg, state, visual, tdef, trigger, probs_logfile)

    # If adaptive classifier
    adaptive_cfg = cfg.ADAPTIVE[cfg.ADAPTIVE['selected']]
    if adaptive_cfg:
        nb_runs = adaptive_cfg[0]
        adaptive = True
    else:
        nb_runs = 1
        adaptive = False

    # Rex action executed for each correctly detected bar direction
    rex_actions = {'U': 'N', 'L': 'W', 'R': 'E', 'D': 'S'}

    run_idx = 1
    while run_idx <= nb_runs:
        if cfg.TRIALS_RANDOMIZE:
            random.shuffle(dir_seq)
        else:
            dir_seq = [d[0] for d in cfg.DIRECTIONS] * cfg.TRIALS_EACH
        num_trials = len(dir_seq)

        # For adaptive, retrain classifier
        if run_idx > 1:
            # Allow to retrain classifier
            with decoder.label.get_lock():
                decoder.label.value = 1
            # Wait that the retraining is done
            while decoder.label.value == 1:
                time.sleep(0.01)

            feedback.viz.put_text('Press any key')
            feedback.viz.update()
            cv2.waitKeyEx()
            feedback.viz.fill()

        # start
        trial = 1
        dir_detected = []
        # Ground truth of every attempt, index-aligned with dir_detected. With
        # cfg.TRIALS_RETRY a failed trial is presented again, so dir_seq alone
        # does not line up with the predictions.
        dir_presented = []
        prob_history = {c: [] for c in bar_dirs}
        while trial <= num_trials:
            if cfg.SHOW_TRIALS:
                title_text = 'Trial %d / %d' % (trial, num_trials)
            else:
                title_text = 'Ready'
            true_label = dir_seq[trial - 1]

            result = feedback.classify(decoder, true_label, title_text,
                                       bar_dirs, prob_history=prob_history,
                                       adaptive=adaptive)
            if result is None:
                decoder.stop()
                return
            pred_label = result
            dir_detected.append(pred_label)
            dir_presented.append(true_label)

            if cfg.WITH_REX is True and pred_label == true_label:
                rex_dir = rex_actions.get(pred_label)
                if rex_dir is None:
                    logger.warning('Rex cannot execute undefined action %s' %
                                   pred_label)
                else:
                    visual.move(pred_label, 100, overlay=False, barcolor='B')
                    visual.update()
                    logger.info('Executing Rex action %s' % rex_dir)
                    os.system(
                        '%s/Rex/RexControlSimple.exe %s %s' %
                        (os.environ['NEUROD_ROOT'], cfg.REX_COMPORT, rex_dir))
                    time.sleep(8)

            msg = 'Correct' if true_label == pred_label else 'Wrong'
            # With TRIALS_RETRY the same trial is repeated until it is correct
            if cfg.TRIALS_RETRY is False or true_label == pred_label:
                logger.info('Trial %d: %s (%s -> %s)' %
                            (trial, msg, true_label, pred_label))
                trial += 1

        if dir_detected:
            # write performance and log results
            fdir = io.parse_path(cfg.DECODER_FILE).dir
            logfile = fdir + time.strftime("/online-%Y%m%d-%H%M%S.txt",
                                           time.localtime())
            cfmat, acc = confusion_matrix(dir_presented, dir_detected)
            with open(logfile, 'w') as fout:
                fout.write('Ground-truth,Prediction\n')
                for gt, dt in zip(dir_presented, dir_detected):
                    fout.write('%s,%s\n' % (gt, dt))
                fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
                fout.write(cfmat)
                logger.info('Log exported to %s' % logfile)
            print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
            print(cfmat)

        run_idx += 1

    visual.finish()

    with state.get_lock():
        state.value = 0

    if decoder.is_running():
        decoder.stop()

    # NOTE: prob_history is filled through feedback.classify() but not used
    # afterwards; the disabled automatic-threshold search that consumed it has
    # been removed from this function.
    logger.info('Finished.')
def confusion_matrix(Y_true, Y_pred, label_len=6):
    """
    Generate confusion matrix in a string format

    Input
    -----
    Y_true: true labels
    Y_pred: test labels
    label_len: maximum label text length (minimum length: 6)

    Output
    ------
    (cfmat, acc)
    cfmat: confusion matrix (string)
        X-axis: prediction
        Y-axis: ground truth
    acc: accuracy (float)

    Raises
    ------
    RuntimeError if Y_pred has more items than Y_true. When Y_pred is shorter,
    Y_true is truncated to the length of Y_pred.
    """
    # find labels
    if isinstance(Y_true, np.ndarray):
        Y_labels = np.unique(Y_true)
    else:
        Y_labels = list(set(Y_true))

    if label_len < 6:
        label_len = 6
        logger.warning('label_len < 6. Setting to 6.')
    label_tpl = '%' + '-%ds' % label_len
    col_tpl = '%' + '-%d.2f' % label_len

    # sanity check
    if len(Y_pred) > len(Y_true):
        raise RuntimeError('Y_pred has more items than Y_true')
    elif len(Y_pred) < len(Y_true):
        Y_true = Y_true[:len(Y_pred)]

    # 'labels' must be given by keyword: recent scikit-learn releases do not
    # accept it as a positional argument any more.
    cm = sklearn.metrics.confusion_matrix(Y_true, Y_pred, labels=Y_labels)

    # compute confusion matrix
    # Rows are normalised one by one so that a row without any sample stays at
    # zero; dividing the whole matrix by cm.sum(axis=1) would produce NaN.
    cm_rate = cm.copy().astype('float')
    cm_sum = np.sum(cm, axis=1)
    for r, s in zip(cm_rate, cm_sum):
        if s > 0:
            r /= s

    # Labels go through str() so that non-string labels (e.g. integers
    # returned by np.unique) can be truncated and printed as well.
    parts = [label_tpl % 'gt\\dt']
    parts.extend(label_tpl % str(l)[:label_len] for l in Y_labels)
    parts.append('\n')
    for l, r in zip(Y_labels, cm_rate):
        parts.append(label_tpl % str(l)[:label_len])
        parts.extend(col_tpl % c for c in r)
        parts.append('\n')
    cm_txt = ''.join(parts)

    # compute accuracy
    correct = 0.0
    for c in range(cm.shape[0]):
        correct += cm[c][c]
    total = cm.sum()
    acc = correct / total if total > 0 else 0.0

    return cm_txt, acc
def connect(self, find_any=True):
    """
    Find a streaming server, connect to it and fill the initial buffer.

    Run in child process.

    LSL streams are resolved once per second until one passes the
    amp_name / amp_serial / eeg_only filters and is either a known amplifier
    or, when find_any is True, any stream. The trigger channel index, the
    channel list, the sampling rate and the buffer sizes are then stored on
    the instance, and the method blocks until self.winsize samples have been
    acquired.

    Parameters
    ----------
    find_any : bool
        If True, also accept a stream whose name matches none of the known
        amplifiers.
    """
    # (substring of the stream name, description used in the log message,
    #  trigger channel index or None to look it up with find_event_channel(),
    #  unit multiplier or None to leave self.multiplier untouched).
    # The entries are tried in this order.
    known_amps = (
        ('USBamp', 'USBamp streaming server', 16, None),
        # or subtract -6684927? (value when trigger==0)
        ('BioSemi', 'BioSemi streaming server', 0, None),
        ('SmartBCI', 'SmartBCI streaming server', 23, None),
        ('StreamPlayer', 'StreamPlayer streaming server', 0, None),
        # OpenVibe standard unit is Volts, which is not ideal for some
        # numerical computations: change V -> uV unit for OpenVibe sources
        ('openvibeSignal', 'an Openvibe signal streaming server', None, 10**6),
        ('openvibeMarkers', 'an Openvibe markers server', None, None),
    )

    server_found = False
    amps = []
    channels = 0
    while not server_found:
        if self.amp_name is None and self.amp_serial is None:
            logger.info("Looking for a streaming server...")
        else:
            logger.info("Looking for %s (Serial %s) ..." %
                        (self.amp_name, self.amp_serial))

        # For now, only 1 amp is supported by a single StreamReceiver object.
        for si in pylsl.resolve_streams():
            # LSL XML parser has a bug which crashes so do not use for now
            #amp_serial = inlet.info().desc().child('acquisition').child_value('serial_number')
            amp_serial = 'N/A'
            amp_name = si.name()

            # connect to a specific amp only?
            if self.amp_serial is not None and self.amp_serial != amp_serial:
                continue
            if self.amp_name is not None and self.amp_name != amp_name:
                continue

            # EEG streaming server only?
            if self.eeg_only and si.type() != 'EEG':
                continue

            known = next((amp for amp in known_amps if amp[0] in amp_name),
                         None)
            if known is not None:
                _, what, tr_channel, multiplier = known
            elif find_any:
                what, tr_channel, multiplier = 'a streaming server', None, None
            else:
                continue

            logger.info('Found %s %s (type %s, amp_serial %s) @ %s.' %
                        (what, amp_name, si.type(), amp_serial, si.hostname()))

            # The probing inlet is only created once the stream is accepted,
            # so rejected streams no longer get an inlet of their own.
            inlet = pylsl.StreamInlet(si)
            ch_list = pu.lsl_channel_list(inlet)
            if tr_channel is None:
                tr_channel = find_event_channel(ch_names=ch_list)
            self._lsl_tr_channel = tr_channel
            if multiplier is not None:
                self.multiplier = multiplier
            channels += si.channel_count()
            amps.append(si)
            server_found = True
            break
        time.sleep(1)

    self.amp_name = amp_name

    # define EEG channel indices
    self._lsl_eeg_channels = list(range(channels))
    if self._lsl_tr_channel is None:
        logger.warning('Trigger channel not found. Adding an empty channel 0.')
    else:
        if self._lsl_tr_channel != 0:
            logger.info_yellow(
                'Trigger channel found at index %d. Moving to index 0.' %
                self._lsl_tr_channel)
        self._lsl_eeg_channels.pop(self._lsl_tr_channel)
    self._lsl_eeg_channels = np.array(self._lsl_eeg_channels)
    self.tr_channel = 0  # trigger channel is always set to 0.
    self.eeg_channels = np.arange(1, channels)  # signal channels start from 1.

    # create new inlets to read from the stream
    inlets = []
    for amp in amps:
        # data type of the 2nd argument (max_buflen) is int according to LSL C++ specification!
        inlets.append(pylsl.StreamInlet(amp, max_buflen=self.stream_bufsec))
        self.buffers.append([])
        self.timestamps.append([])

    sample_rate = amps[0].nominal_srate()
    logger.info('Channels: %d' % channels)
    logger.info('LSL Protocol version: %s' % amps[0].version())
    logger.info('Source sampling rate: %.1f' % sample_rate)
    logger.info('Unit multiplier: %.1f' % self.multiplier)

    # Durations in seconds are converted to a rounded number of samples
    self.winsize = int(round(self.winsec * sample_rate))
    self.bufsize = int(round(self.bufsec * sample_rate))
    self.stream_bufsize = int(round(self.stream_bufsec * sample_rate))
    self.sample_rate = sample_rate
    self.connected = True
    self.ch_list = ch_list
    self.inlets = inlets  # Note: not picklable!

    # TODO: check if there's any problem with multiple inlets
    if len(self.inlets) > 1:
        logger.warning(
            'Merging of multiple acquisition servers is not supported yet. Only %s will be used.'
            % amps[0].name())

    # create channel info: the trigger channel is always listed first
    if self._lsl_tr_channel is None:
        self.ch_list = ['TRIGGER'] + self.ch_list
    else:
        for i, chn in enumerate(self.ch_list):
            if chn == 'TRIGGER' or chn == 'TRG' or 'STI ' in chn:
                self.ch_list.pop(i)
                self.ch_list = ['TRIGGER'] + self.ch_list
                break
    logger.info('self.ch_list %s' % self.ch_list)

    # fill in initial buffer
    logger.info('Waiting to fill initial buffer of length %d' % (self.winsize))
    while len(self.timestamps[0]) < self.winsize:
        self.acquire()
        time.sleep(0.1)
    self.ready = True
    logger.info('Start receiving stream data.')
elif color.upper() == 'G': c = colorama.Fore.GREEN elif color.upper() == 'Y': c = colorama.Fore.YELLOW elif color.upper() == 'W': c = colorama.Fore.WHITE elif color.upper() == 'C': c = colorama.Fore.CYAN else: logger.error('print_c(): Unknown color code %s' % color) raise ValueError print(colorama.Style.BRIGHT + c + str(msg) + colorama.Style.RESET_ALL, end=end) except ImportError: logger.warning( 'colorama module not found. print_c() will ignore color codes.') def print_c(msg, color, end='\n'): print(msg, end=end) '''""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" List/Dict related """""""""""""""""""""""""""""""""""""""""""""""""""""""""""''' def list2string(vec, fmt, sep=' '): """ Convert a list to string with formatting, separated by sep (default is space). Example: fmt= '%.32e', '%.6f', etc. """
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Run the training protocol: present cued directions and send triggers.

    Each trial goes through the events gap -> cue -> direction "ready" ->
    direction "go", during which the bar grows towards the cued side. A
    trigger is sent at every transition. The loop ends after
    len(cfg.DIRECTIONS) * cfg.TRIALS_EACH trials, when ESC is pressed or when
    state is set to 0.

    Parameters
    ----------
    cfg : python.module
        The config module containing the protocol parameters. The loaded
        trigger definition is stored on it as cfg.tdef.
    state : mp.Value
        Shared integer flag (0: stop, 1: start, 2: wait). It is set to 0 when
        the protocol ends. The default object is created only once, when the
        function is defined, so every call relying on the default shares the
        same flag.
    queue :
        Passed to redirect_stdout_to_queue() together with the logger

    Raises
    ------
    RuntimeError
        If the trial sequence contains a direction other than
        'L', 'R', 'U', 'D' or 'B'.
    """
    import time  # local import: only needed for the start-up wait below

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)  # yield the CPU instead of spinning

    # Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    key_esc = 27  # key code compared against cv2.waitKey()

    # direction -> (bar sides to move, "ready" trigger name, "go" trigger name)
    # The trigger values are looked up on cfg.tdef only when needed, so an
    # unused direction does not have to exist in the trigger file.
    directions = {
        'L': (('L',), 'LEFT_READY', 'LEFT_GO'),
        'R': (('R',), 'RIGHT_READY', 'RIGHT_GO'),
        'U': (('U',), 'UP_READY', 'UP_GO'),
        'D': (('D',), 'DOWN_READY', 'DOWN_GO'),
        'B': (('L', 'R'), 'BOTH_READY', 'BOTH_GO'),  # both hands
    }

    def jittered(name):
        # nominal duration +/- a uniformly drawn random offset
        jitter = cfg.TIMINGS[name + '_RANDOMIZE']
        return cfg.TIMINGS[name] + random.uniform(-jitter, jitter)

    dir_sequence = []
    for _ in range(cfg.TRIALS_EACH):
        dir_sequence.extend(cfg.DIRECTIONS)
    random.shuffle(dir_sequence)
    num_trials = len(cfg.DIRECTIONS) * cfg.TRIALS_EACH

    event = 'start'
    trial = 1
    direction = None  # direction of the current trial, set at the 'cue' event

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    if trigger.init(50) == False:
        logger.error(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # timers
    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()
    t_dir = jittered('DIR')
    t_dir_ready = jittered('READY')

    bar = BarVisual(cfg.GLASS_USE, screen_pos=cfg.SCREEN_POS,
                    screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # event transitions
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                key = cv2.waitKey()
                if key == key_esc or not state.value:
                    break
            bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in directions:
                # '%s': the direction is a string, '%d' would raise a
                # TypeError in place of this error.
                raise RuntimeError('Unknown direction %s' % direction)
            sides, ready_name, _ = directions[direction]
            for side in sides:
                bar.move(side, 100, overlay=True)
            trigger.signal(getattr(cfg.tdef, ready_name))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            _, _, go_name = directions[direction]
            trigger.signal(getattr(cfg.tdef, go_name))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            trial += 1
            logger.info('trial ' + str(trial - 1) + ' done')
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            # draw new randomized durations for the next trial
            t_dir = jittered('DIR')
            t_dir_ready = jittered('READY')

        # protocol: grow the bar towards the cued direction
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for side in directions[direction][0]:
                bar.move(side, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == key_esc or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0