Esempio n. 1
0
def _check_cfg_selected(cfg, optional_vars, select):
    """
    Validate a dict-type config attribute that carries sub-parameters.

    The attribute named `select` must be a dict with a 'selected' key whose
    value is itself a key of that dict. Sub-parameters missing from the
    selected entry are filled in from `optional_vars`.

    Parameters
    ----------
    cfg : python.module
        The config module containing the parameters to check
    optional_vars : dict
        Default sub-parameters, looked up by the selected entry's name
    select : str
        Name of the cfg attribute (type=dict) containing a key: selected

    Raises
    ------
    RuntimeError
        If the selected entry is not a key of the attribute's dict
    """
    param = getattr(cfg, select)
    selected = param['selected']

    if selected not in param:
        logger.error('%s not defined in config.' % selected)
        raise RuntimeError

    chosen = param[selected]
    for name, default in optional_vars[selected].items():
        if name in chosen:
            continue
        chosen.update({name: default})
        # Write the dict back onto the config after each filled-in default
        setattr(cfg, select, param)
        logger.warning(
            'Updating internal parameter for classifier %s: %s=%s' %
            (selected, name, default))
def event_timestamps_to_indices(sigfile, eventfile):
    """
    Map LSL-timestamped events of a separately recorded event file onto
    sample indices of a recorded signal.

    Parameters:
    sigfile: raw signal file loaded with qc.load_obj(); its 'timestamps'
        entry is flattened and searched.
    eventfile: tab-separated text file; column 0 is read as the LSL
        timestamp and column 2 as the integer event value.

    Returns:
    list of [sample_index, 0, event_value]. Events whose timestamp lies
    after the last signal timestamp are reported with a warning and dropped.
    """

    signal = qc.load_obj(sigfile)
    stamps = signal['timestamps'].reshape(-1)
    first_ts, final_ts = min(stamps), max(stamps)
    n_samples = len(stamps)
    events = []

    with open(eventfile) as fin:
        for line in fin:
            fields = line.strip().split('\t')
            event_ts = float(fields[0])
            event_value = int(fields[2])
            # index of the first signal timestamp that is >= event_ts
            sample_idx = np.searchsorted(stamps, event_ts)
            if sample_idx < n_samples:
                events.append([sample_idx, 0, event_value])
                continue
            logger.warning('Event %d at time %.3f is out of time range (%.3f - %.3f).' % (event_value, event_ts, first_ts, final_ts))
    return events
Esempio n. 3
0
def load_config(cfg_module):
    """
    Import and return a config module.

    Parameters
    ----------
    cfg_module : str
        Dotted module name ('pkg.config'). The deprecated path style
        ('pkg/config.py') is still accepted and converted, with a warning.

    Returns
    -------
    The imported module object.
    """
    if '/' in cfg_module:
        # Strip only a trailing '.py'. A blanket replace('.py', '') would
        # also eat the start of any path component beginning with "py"
        # once the slashes have become dots (e.g. 'pkg/pyconf.py').
        if cfg_module.endswith('.py'):
            cfg_module = cfg_module[:-len('.py')]
        cfg_module = cfg_module.replace('/', '.')
        logger.warning('Replacing deprecated config path to new style: %s' %
                       cfg_module)
        logger.warning('Please change your argument.')
    return importlib.import_module(cfg_module)
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Online protocol for Alpha/Theta neurofeedback.

    Parameters
    ----------
    cfg : python.module
        Loaded config module. GLOBAL_TIME and TIMER_SLEEP are read here;
        the module is also handed to the LSL helper functions.
    state : mp.Value
        Shared integer flag polled by the protocol
        (0: stop, 1: start, 2: wait). The default object is created once,
        when the function is defined.
    queue :
        Forwarded to redirect_stdout_to_queue() together with the logger.
    """
    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait for the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1: start, 2: wait
        pass

    # Protocol runs only if state is non-zero
    if not state.value:
        sys.exit(-1)

    #----------------------------------------------------------------------
    # LSL stream connection
    #----------------------------------------------------------------------
    # Choose amp
    amp_name, amp_serial = find_lsl_stream(cfg, state)

    # Connect to lsl stream
    sr = connect_lsl_stream(cfg, amp_name, amp_serial)

    # Get sampling rate
    sfreq = sr.get_sample_rate()

    # Get trigger channel
    trg_ch = sr.get_trigger_channel()

    #----------------------------------------------------------------------
    # Main
    #----------------------------------------------------------------------
    global_timer = qc.Timer(autoreset=False)
    internal_timer = qc.Timer(autoreset=True)

    # Timestamp of the newest sample already processed. It must exist before
    # the first comparison below, otherwise the first iteration raises
    # UnboundLocalError.
    last_ts = 0

    while state.value == 1 and global_timer.sec() < cfg.GLOBAL_TIME:

        #----------------------------------------------------------------------
        # Data acquisition
        #----------------------------------------------------------------------
        sr.acquire()
        window, tslist = sr.get_window()    # window = [samples x channels]
        window = window.T                   # window = [channels x samples]

        # Check if proper real-time acquisition
        tsnew = np.where(np.array(tslist) > last_ts)[0]
        if len(tsnew) == 0:
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue

        #----------------------------------------------------------------------
        # ADD YOUR CODE HERE
        #----------------------------------------------------------------------

        last_ts = tslist[-1]
        internal_timer.sleep_atleast(cfg.TIMER_SLEEP)
Esempio n. 5
0
def check_config(cfg):
    """
    Verify a loaded config module.

    Every required parameter must be an attribute of cfg, otherwise the
    problem is logged and RuntimeError is raised. Optional parameters that
    are missing are added to cfg with their default value and reported
    with a warning.
    """
    required = (
        'TRIGGER_DEVICE',
        'TRIGGER_FILE',
        'SCREEN_SIZE',
        'START_VOICE',
        'END_VOICE',
    )
    defaults = {
        'GLOBAL_TIME': 2 * 60,
        'SCREEN_POS': (0, 0),
        'GLASS_USE': False,
    }

    for name in required:
        if hasattr(cfg, name):
            continue
        logger.error('%s is a required parameter' % name)
        raise RuntimeError

    for name, value in defaults.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, value)
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))
Esempio n. 6
0
def check_cfg_optional(cfg, optional_vars, key_var):
    """
    Assign default values to optional parameters missing from the config.

    cfg = config module containing the parameters to check
    optional_vars = dict of groups; each group maps parameter name -> default
    key_var = name of the group in optional_vars to apply

    Parameters already present on cfg are left untouched; each one that is
    added is reported with a warning.
    """
    defaults = optional_vars[key_var]
    for name, default in defaults.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, default)
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))
Esempio n. 7
0
 def stop(self):
     """
     Stop the daemon.

     Clears every shared running flag, waits up to 10 seconds for each
     worker process to exit, warns about any process that is still alive,
     then resets the internal state and logs the stop message.
     Does nothing except warn when the decoder is not running.
     """
     if self.is_running() == 0:
         logger.warning('Decoder already stopped.')
         return
     for running in self.running:
         running.value = 0
     for proc in self.procs:
         proc.join(10)
         if proc.is_alive():
             # pid is a plain attribute on multiprocessing.Process; calling
             # it would raise TypeError instead of emitting this warning.
             # NOTE(review): assumes self.procs holds Process objects.
             logger.warning('Process %s did not die properly.' % proc.pid)
     self.reset()
     logger.info(self.stopmsg)
Esempio n. 8
0
def sort_by_value(s, rev=False):
    """
    Order the entries of a dictionary or list by their values.

    A list is treated as a mapping from index to element. Returns the tuple
    (keys, values): values sorted (descending when rev is True) and keys
    holding the matching key for each value.
    Values must be hashable and unique; when duplicates collapse, a warning
    is logged if Q_VERBOSE > 0.
    """
    assert type(s) in (dict, list), 'Input must be a dictionary or list.'
    mapping = dict(enumerate(s)) if type(s) == list else s
    inverse = {val: key for key, val in mapping.items()}
    if Q_VERBOSE > 0 and len(inverse) != len(mapping):
        duplicates = len(mapping.values()) - len(set(mapping.values())) + 1
        logger.warning('sort_by_value(): %d identical values' % duplicates)
    values = sorted(inverse, reverse=rev)
    keys = [inverse[val] for val in values]
    return keys, values
 def paintEvent(self, e):
     """
     Repaint the interface, but only for paint events that have a sender
     or when a repaint has been explicitly forced.

     Events without a sender (widget resizing, clicking, ...) are skipped
     unless self.force_repaint is set. The flag is cleared once a repaint
     is performed.
     """
     sender = self.sender()
     if 'force_repaint' not in self.__dict__:
         logger.warning('force_repaint is not set! Is it a Qt bug?')
         self.force_repaint = 0

     # No sender and no forced repaint: nothing to draw
     if sender is None and not self.force_repaint:
         return

     self.force_repaint = 0
     painter = QPainter()
     painter.begin(self)
     # Update the interface
     self.paintInterface(painter)
     painter.end()
Esempio n. 10
0
def merge_events(trigger_file, events, rawfile_in, rawfile_out):
    """
    Relabel groups of events in a raw file and save the result.

    trigger_file: passed to trigger_def() to obtain the name/value tables
    events: dict mapping a target event name to the source event names
        that are relabelled with the target's value
    rawfile_in: raw file loaded with pu.load_raw()
    rawfile_out: output path; an existing file is overwritten

    Event counts are logged before and after merging. Event values that are
    missing from the trigger definition are reported with a warning.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(rawfile_in)

    logger.info('=== Before merging ===')
    unknown = []
    for code in np.unique(eve[:, 2]):
        count = len(np.where(eve[:, 2] == code)[0])
        if code in tdef.by_value:
            logger.info('%s: %d events' % (tdef.by_value[code], count))
            continue
        logger.info('%d: %d events' % (code, count))
        unknown.append(code)
    for code in unknown:
        logger.warning('Key %d was not found in the definition file.' %
                       code)

    for target, sources in events.items():
        target_value = tdef.by_name[target]
        rows = [
            np.where(eve[:, 2] == tdef.by_name[src])[0] for src in sources
        ]
        eve[np.concatenate(rows), 2] = target_value

    # sanity check: no two consecutive events share a sample index
    dups = np.where(0 == np.diff(eve[:, 0]))[0]
    assert len(dups) == 0

    # reset trigger channel
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    logger.info('=== After merging ===')
    for code in np.unique(eve[:, 2]):
        count = len(np.where(eve[:, 2] == code)[0])
        label = tdef.by_value[code] if code in tdef.by_value else code
        logger.info('%s: %d events' % (label, count))
Esempio n. 11
0
    def sort_by_value(s, reverse=False):
        """
        Order the entries of a dictionary or list by their values.

        A list is treated as a mapping from index to element. Returns the
        tuple (keys, values): values sorted (descending when reverse is
        True) and keys holding the matching key for each value. Duplicate
        values collapse to one entry and trigger a warning.
        """
        assert type(s) in (dict, list), 'Input must be a dictionary or list.'

        mapping = dict(enumerate(s)) if type(s) == list else s
        inverse = {val: key for key, val in mapping.items()}

        if len(inverse) != len(mapping):
            duplicates = len(mapping.values()) - len(set(mapping.values())) + 1
            logger.warning('sort_by_value(): %d identical values' % duplicates)

        values = sorted(inverse, reverse=reverse)
        keys = [inverse[val] for val in values]

        return keys, values
Esempio n. 12
0
def check_config(cfg):
    """
    Verify a loaded config module.

    Required parameters (attributes of cfg and keys of cfg.TIMINGS) must be
    present. Optional parameters that are missing are added to cfg with
    their default value and reported with a warning.

    Raises
    ------
    RuntimeError
        If a required parameter or TIMINGS key is missing, or if
        TRIGGER_DEVICE is None.
    """
    critical_vars = {
        'COMMON': [
            'TRIGGER_DEVICE', 'TRIGGER_FILE', 'SCREEN_SIZE', 'DIRECTIONS',
            'DIR_RANDOM', 'TRIALS_EACH'
        ],
        'TIMINGS': [
            'INIT', 'GAP', 'CUE', 'READY', 'READY_RANDOMIZE', 'DIR',
            'DIR_RANDOMIZE'
        ]
    }
    optional_vars = {
        'FEEDBACK_TYPE': 'BAR',
        'FEEDBACK_IMAGE_PATH': None,
        'SCREEN_POS': (0, 0),
        'DIR_RANDOM': True,
        'GLASS_USE': False,
        'TRIAL_PAUSE': False,
        'REFRESH_RATE': 30
    }

    for key in critical_vars['COMMON']:
        if not hasattr(cfg, key):
            raise RuntimeError('%s is a required parameter' % key)

    if not hasattr(cfg, 'TIMINGS'):
        logger.error('"TIMINGS" not defined in config.')
        raise RuntimeError
    for v in critical_vars['TIMINGS']:
        if v not in cfg.TIMINGS:
            logger.error('%s not defined in config.' % v)
            raise RuntimeError

    for key, default in optional_vars.items():
        if not hasattr(cfg, key):
            setattr(cfg, key, default)
            # Report the default that was applied; this message previously
            # referenced an undefined name and raised NameError.
            logger.warning('Setting undefined %s=%s' % (key, default))

    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'The trigger device is set to None! No events will be saved.')
        raise RuntimeError(
            'The trigger device is set to None! No events will be saved.')
Esempio n. 13
0
def check_config(cfg):
    '''
    Validate the variables of a loaded config module.

    Runs the mandatory, optional and selected-parameter checks, then warns
    when no trigger device is configured.

    Parameters
    ----------
    cfg : python.module
        The loaded config module
    '''

    # Variables that must be defined in the config_offline.py
    critical_vars = {'COMMON': ['TRIALS_NB', 'TRIGGER_FILE', 'TRIGGER_DEVICE']}

    # Variables that may be omitted from the config_offline.py; the values
    # listed here are the ones handed to the optional/selected checks
    common_defaults = {'REFRESH_RATE': 20}
    # Internal parameters for the CCC
    xxx_defaults = {'min': 1, 'max': 40}
    optional_vars = {'COMMON': common_defaults, 'XXX': xxx_defaults}

    # Mandatory variables first, then the optional ones
    _check_cfg_mandatory(cfg, critical_vars, 'COMMON')
    _check_cfg_optional(cfg, optional_vars, 'COMMON')

    # Internal parameters of CCC
    _check_cfg_selected(cfg, optional_vars, 'CCC')

    # A missing trigger device is only reported, not rejected
    if cfg.TRIGGER_DEVICE == None:
        logger.warning(
            'The trigger device is set to None! No events will be saved.')
Esempio n. 14
0
 def signal(self, value):
     """
     Send a trigger value through the configured trigger type.

     'SOFTWARE' forwards the value to write_event() and returns its result.
     'FAKE' only logs the value and returns True.
     Any other type writes the value with set_data() and starts the
     off-timer; if that timer is still alive the new signal is ignored and
     False is returned.
     """
     kind = self.lpttype
     if kind == 'SOFTWARE':
         if self.verbose is True:
             logger.info('Sending software trigger %s' % value)
         return self.write_event(value)

     if kind == 'FAKE':
         logger.info('Sending FAKE trigger signal %s' % value)
         return True

     # Hardware path: refuse to overlap with the previous signal
     if self.offtimer.is_alive():
         logger.warning(
             'You are sending a new signal before the end of the last signal. Signal ignored.'
         )
         logger.warning('self.delay=%.1f' % self.delay)
         return False

     self.set_data(value)
     if self.verbose is True:
         logger.info('Sending %s' % value)
     self.offtimer.start()
     return True
Esempio n. 15
0
def sample_decoding(decoder):
    """
    Decoding example: print the class probabilities of every new
    classification result, forever.

    Parameters
    ----------
    decoder : The decoder to use. Its get_label_names(), get_prob_unread()
        and get_prob_smooth() methods are called.
    """
    def get_index_max(seq):
        # Index (list) or key (dict) of the largest element
        kind = type(seq)
        if kind == list:
            return max(range(len(seq)), key=seq.__getitem__)
        if kind == dict:
            return max(seq, key=seq.__getitem__)
        logger.error('Unsupported input %s' % type(seq))
        return None

    # load trigger definitions for labeling
    labels = decoder.get_label_names()
    tm_watchdog = Timer(autoreset=True)
    tm_cls = Timer()
    while True:
        praw = decoder.get_prob_unread()
        psmooth = decoder.get_prob_smooth()

        if praw is not None:
            pieces = ['[%8.1f msec]' % (tm_cls.msec())]
            for i, label in enumerate(labels):
                pieces.append('   %s %.3f (raw %.3f)' %
                              (label, psmooth[i], praw[i]))
            best = get_index_max(psmooth)
            pieces.append('   %s' % labels[best])
            print(''.join(pieces))
            tm_cls.reset()
            continue

        # watch dog: nothing new to read
        if tm_cls.sec() > 5:
            logger.warning(
                'No classification was done in the last 5 seconds. Are you receiving data streams?'
            )
            tm_cls.reset()
        tm_watchdog.sleep_atleast(0.001)
Esempio n. 16
0
def check_config(cfg):
    """
    Verify a loaded config module and return it.

    Required parameters must be attributes of cfg, otherwise the problem is
    logged and RuntimeError is raised. Optional parameters (none are
    defined at the moment) would be added with their default value.
    """
    required = ('DATA_PATH',)
    defaults = {}

    for name in required:
        if hasattr(cfg, name):
            continue
        logger.error('%s is a required parameter' % name)
        raise RuntimeError

    for name, value in defaults.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, value)
        logger.warning('Setting undefined parameter %s=%s' % (name, getattr(cfg, name)))

    return cfg
Esempio n. 17
0
def fix_channel_names(fif_dir, new_channel_names):
    '''
    Change channel names of fif files in a given directory.

    Input
    -----
    @fif_dir: path to fif files
    @new_channel_names: list of new channel names; its length must equal
        the number of channels of every fif file

    Output
    ------
    Modified fif files are saved in a "corrected" sub-directory next to
    each input file.

    Raises
    ------
    RuntimeError if a file's channel count differs from the number of new
    channel names.

    Kyuhwa Lee, 2019.
    '''

    # Keep only the fif files so that nothing else is logged or loaded
    flist = [
        f for f in qc.get_file_list(fif_dir)
        if qc.parse_path(f).ext == 'fif'
    ]

    if not flist:
        logger.warning('No fif files found in %s' % fif_dir)
        return

    qc.make_dirs('%s/corrected' % fif_dir)
    for f in flist:
        logger.info('\nLoading %s' % f)
        p = qc.parse_path(f)
        raw, eve = pu.load_raw(f)
        if len(raw.ch_names) != len(new_channel_names):
            raise RuntimeError('The number of new channels do not matach that of fif file.')
        # Rename both the name list and the per-channel info entries
        raw.info['ch_names'] = new_channel_names
        for ch, new_ch in zip(raw.info['chs'], new_channel_names):
            ch['ch_name'] = new_ch
        out_fif = '%s/corrected/%s.fif' % (p.dir, p.name)
        logger.info('Exporting to %s' % out_fif)
        raw.save(out_fif)
Esempio n. 18
0
def check_config(cfg):
    """
    Verify a loaded config module.

    Required parameters must be attributes of cfg, otherwise the problem is
    logged and RuntimeError is raised. Optional parameters that are missing
    are added to cfg with their default value and reported with a warning.
    """
    required = ('DATA_PATH',)
    defaults = {
        'AMP_NAME': None,
        'AMP_SERIAL': None,
        'GLOBAL_TIME': 1.0 * 60,
        'NJOBS': 1,
    }

    for name in required:
        if hasattr(cfg, name):
            continue
        logger.error('%s is a required parameter' % name)
        raise RuntimeError

    for name, value in defaults.items():
        if hasattr(cfg, name):
            continue
        setattr(cfg, name, value)
        logger.warning('Setting undefined parameter %s=%s' %
                       (name, getattr(cfg, name)))
Esempio n. 19
0
def check_config(cfg):
    """
    Validate the parameters of a loaded config module and return it.

    Runs the mandatory, optional and selected-parameter checks, then warns
    when no trigger device is configured.

    Parameters
    ----------
    cfg : python.module
        The loaded config module
    """
    # Variables that must be defined
    critical_vars = {'COMMON': ['AAA', 'BBB', 'CCC']}

    # Optional variables, plus the internal parameters for the AAA
    optional_vars = {
        'COMMON': {'DDD', 'EEE'},
        'XXX': {'min': 1, 'max': 40},
    }

    # Mandatory variables first, then the optional ones
    _check_cfg_mandatory(cfg, critical_vars, 'COMMON')
    _check_cfg_optional(cfg, optional_vars, 'COMMON')

    # Internal parameters of AAA
    _check_cfg_selected(cfg, optional_vars, 'AAA')

    # A missing trigger device is only reported, not rejected
    if cfg.TRIGGER_DEVICE == None:
        logger.warning('The trigger device is set to None! No events will be saved.')

    return cfg
Esempio n. 20
0
def slice_win(epochs_data,
              w_starts,
              w_length,
              psde,
              picks=None,
              title=None,
              flatten=True,
              preprocess=None,
              verbose=False):
    '''
    Compute PSD values of a sliding window

    Params
        epochs_data ([channels]x[samples]): raw epoch data
        w_starts (list): starting indices of sample segments
        w_length (int): window length in number of samples
        psde: MNE PSDEstimator object
        picks (list): subset of the flattened [channels x freqs] features
            to keep for each window
        title (string): print out the title associated with PID
        flatten (boolean): accepted for interface compatibility; the
            current implementation always returns flattened vectors
        preprocess (dict): None or parameters for pycnbi_utils.preprocess() with the following keys:
            sfreq, spatial, spatial_ch, spectral, spectral_ch, notch, notch_ch,
            multiplier, ch_names, rereference, decim, n_jobs
    Returns:
        [windows] x [channels*freqs], or None when w_starts is empty
    Raises:
        IndexError (WrongIndexError) when a window start lies beyond the
        end of epochs_data
    '''

    # A real exception type: raising a plain function object would fail
    # with TypeError instead of reporting the bad index.
    class WrongIndexError(IndexError):
        pass

    if type(w_length) is not int:
        logger.warning('w_length type is %s. Converting to int.' %
                       type(w_length))
        w_length = int(w_length)
    if title is None:
        title = '[PID %d] Frames %d-%d' % (os.getpid(), w_starts[0],
                                           w_starts[-1] + w_length - 1)
    else:
        title = '[PID %d] %s' % (os.getpid(), title)
    if preprocess is not None and preprocess['decim'] != 1:
        title += ' (decim factor %d)' % preprocess['decim']
    logger.info(title)

    # Explicit emptiness test so that numpy arrays are accepted as picks
    use_picks = picks is not None and len(picks) > 0

    rows = []
    for n in w_starts:
        n = int(round(n))
        if n >= epochs_data.shape[1]:
            msg = 'w_starts has an out-of-bounds index %d for epoch length %d.' % (
                n, epochs_data.shape[1])
            logger.error(msg)
            raise WrongIndexError(msg)
        window = epochs_data[:, n:(n + w_length)]

        if preprocess is not None:
            window = pu.preprocess(window,
                                   sfreq=preprocess['sfreq'],
                                   spatial=preprocess['spatial'],
                                   spatial_ch=preprocess['spatial_ch'],
                                   spectral=preprocess['spectral'],
                                   spectral_ch=preprocess['spectral_ch'],
                                   notch=preprocess['notch'],
                                   notch_ch=preprocess['notch_ch'],
                                   multiplier=preprocess['multiplier'],
                                   ch_names=preprocess['ch_names'],
                                   rereference=preprocess['rereference'],
                                   decim=preprocess['decim'],
                                   n_jobs=preprocess['n_jobs'])

        # dimension: psde.transform( [epochs x channels x times] )
        psd = psde.transform(
            window.reshape((1, window.shape[0], window.shape[1])))
        psd = psd.reshape((psd.shape[0], psd.shape[1] * psd.shape[2]))
        if use_picks:
            psd = psd[0][picks]
            psd = psd.reshape((1, len(psd)))

        rows.append(psd)

        if verbose:
            logger.info('[PID %d] processing frame %d / %d' %
                        (os.getpid(), n, w_starts[-1]))

    # Stack once at the end instead of re-copying the result per window
    if not rows:
        return None
    return np.concatenate(rows, axis=0)
Esempio n. 21
0
def test_receiver():
    """
    Interactive check of a StreamReceiver.

    Connects to the LSL stream found by pu.search_lsl() and then loops
    forever, printing the trigger values and the monitored channel data of
    each window that contains new samples.
    """
    import mne
    import os

    CH_INDEX = [1]  # channel to monitor
    TIME_INDEX = None  # integer or None. None = average of raw values of the current window
    SHOW_PSD = False
    mne.set_log_level('ERROR')
    os.environ['OMP_NUM_THREADS'] = '1'  # actually improves performance for multitaper

    # connect to LSL server
    amp_name, amp_serial = pu.search_lsl()
    receiver = StreamReceiver(window_size=1, buffer_size=1, amp_serial=amp_serial, eeg_only=False, amp_name=amp_name)
    sfreq = receiver.get_sample_rate()
    trg_ch = receiver.get_trigger_channel()
    logger.info('Trigger channel = %d' % trg_ch)

    # PSD init
    if SHOW_PSD:
        psde = mne.decoding.PSDEstimator(
            sfreq=sfreq, fmin=1, fmax=50, bandwidth=None,
            adaptive=False, low_bias=True, n_jobs=1, normalization='length', verbose=None)

    elapsed = qc.Timer()
    pacer = qc.Timer(autoreset=True)
    last_ts = 0
    while True:
        receiver.acquire()
        window, tslist = receiver.get_window()  # window = [samples x channels]
        window = window.T  # channels x samples

        qc.print_c('LSL Diff = %.3f' % (pylsl.local_clock() - tslist[-1]), 'G')

        # samples newer than the ones seen in the previous iteration
        fresh = np.where(np.array(tslist) > last_ts)[0]
        if len(fresh) == 0:
            logger.warning('There seems to be delay in receiving data.')
            time.sleep(1)
            continue

        # print event values
        trigger = np.unique(window[trg_ch, fresh[0]:])

        # for Biosemi
        # if sr.amp_name=='BioSemi':
        #    trigger= set( [255 & int(x-1) for x in trigger ] )

        if len(trigger) > 0:
            logger.info('Triggers: %s' % np.array(trigger))

        logger.info('[%.1f] Receiving data...' % elapsed.sec())

        if TIME_INDEX is not None:
            datatxt = qc.list2string(window[CH_INDEX, TIME_INDEX], '%-15.6f')
            print('[%.3f]' % tslist[TIME_INDEX] + ' data: %s' % datatxt)
        else:
            datatxt = qc.list2string(np.mean(window[CH_INDEX, :], axis=1), '%-15.6f')
            print('[%.3f : %.3f]' % (tslist[0], tslist[-1]) + ' data: %s' % datatxt)

        # show PSD
        if SHOW_PSD:
            psd = psde.transform(window.reshape((1, window.shape[0], window.shape[1])))
            psd = psd.reshape((psd.shape[1], psd.shape[2]))
            for p in np.mean(psd, axis=1):
                print('%.1f' % p, end=' ')

        last_ts = tslist[-1]
        pacer.sleep_atleast(0.05)
Esempio n. 22
0
    def acquire(self, blocking=True):
        """
        Reads data into buffer. It is a blocking function as default.

        Pulls one chunk from the first inlet, rearranges its channels so
        that the trigger channel (or a column of zeros when there is none)
        comes first, appends it to the internal buffer and returns it.

        Parameters:
            blocking: when False, return immediately with empty values if
                the first pull yields no samples.

        Returns:
            data [samples x channels], timestamps [samples]
            An empty array of shape (0, len(self.ch_list)) and an empty
            list are returned when nothing was received (non-blocking call
            or 5 second timeout).
        """
        # The buffer holds no timestamps yet: measure the clock offset
        # once this call has received data.
        timestamp_offset = False
        if len(self.timestamps[0]) == 0:
            timestamp_offset = True

        self.watchdog.reset()
        tslist = []
        received = False
        chunk = None
        while not received:
            # Poll the inlet until data arrives or the watchdog reaches 5 s
            while self.watchdog.sec() < 5:
                # chunk = [frames]x[ch], tslist = [frames]
                if len(tslist) == 0:
                    chunk, tslist = self.inlets[0].pull_chunk(max_samples=self.stream_bufsize)
                    if blocking == False and len(tslist) == 0:
                        return np.empty((0, len(self.ch_list))), []
                if len(tslist) > 0:
                    if timestamp_offset is True:
                        # local clock at reception, compared with the
                        # stream timestamps further below
                        lsl_clock = pylsl.local_clock()
                    received = True
                    break
                time.sleep(0.0005)
            else:
                # reached only when the inner loop ended without break,
                # i.e. the watchdog limit expired
                logger.warning('Timeout occurred while acquiring data. Amp driver bug?')
                # give up and return empty values to avoid deadlock
                return np.empty((0, len(self.ch_list))), []
        data = np.array(chunk)

        # BioSemi has pull-up resistor instead of pull-down
        if self.amp_name == 'BioSemi' and self._lsl_tr_channel is not None:
            datatype = data.dtype
            # keep the low 8 bits of the trigger value, then subtract 1
            data[:, self._lsl_tr_channel] = (np.bitwise_and(255, data[:, self._lsl_tr_channel].astype(int)) - 1).astype(datatype)

        # multiply values (to change unit)
        if self.multiplier != 1:
            data[:, self._lsl_eeg_channels] *= self.multiplier

        if self._lsl_tr_channel is not None:
            # move trigger channel to 0 and add back to the buffer
            data = np.concatenate((data[:, self._lsl_tr_channel].reshape(-1, 1),
                                   data[:, self._lsl_eeg_channels]), axis=1)
        else:
            # add an empty channel with zeros to channel 0
            data = np.concatenate((np.zeros((data.shape[0],1)),
                                   data[:, self._lsl_eeg_channels]), axis=1)

        # add data to buffer
        chunk = data.tolist()
        self.buffers[0].extend(chunk)
        self.timestamps[0].extend(tslist)
        # keep only the newest bufsize samples (bufsize <= 0 disables trimming)
        if self.bufsize > 0 and len(self.timestamps[0]) > self.bufsize:
            self.buffers[0] = self.buffers[0][-self.bufsize:]
            self.timestamps[0] = self.timestamps[0][-self.bufsize:]

        if timestamp_offset is True:
            timestamp_offset = False
            logger.info('LSL timestamp = %s' % lsl_clock)
            # NOTE(review): the offset is read from self.timestamps[-1],
            # while the data above was appended to self.timestamps[0];
            # these are the same list only when there is a single
            # timestamp buffer - confirm for multi-stream setups.
            logger.info('Server timestamp = %s' % self.timestamps[-1][-1])
            self.lsl_time_offset = self.timestamps[-1][-1] - lsl_clock
            logger.info('Offset = %.3f ' % (self.lsl_time_offset))
            if abs(self.lsl_time_offset) > 0.1:
                logger.warning('LSL server has a high timestamp offset.')
            else:
                logger.info_green('LSL time server synchronized')

        ''' TODO: test the merging of multiple streams
        # if we have multiple synchronized amps
        if len(self.inlets) > 1:
            for i in range(1, len(self.inlets)):
                chunk, tslist = self.inlets[i].pull_chunk(max_samples=len(tslist))  # [frames][channels]
                self.buffers[i].extend(chunk)
                self.timestamps[i].extend(tslist)
                if self.bufsize > 0 and len(self.buffers[i]) > self.bufsize:
                    self.buffers[i] = self.buffers[i][-self.bufsize:]
        '''

        # data= array[samples, channels], tslist=[samples]
        return (data, tslist)
Esempio n. 23
0
def compute_features(cfg):
    '''
    Compute features using config specification.

    Performs preprocessing, epoching and feature computation.

    Input
    =====
    Config file object

    Output
    ======
    Feature data in dictionary, i.e. whatever get_psd_feature() returns plus
    the 'picks', 'sfreq' and 'ch_names' entries added here. According to the
    original documentation of this function the dictionary contains:
    - X_data: feature vectors
    - Y_data: feature labels
    - wlen: window length in seconds
    - w_frames: window length in frames
    - psde: MNE PSD estimator object
    - picks: channels used for feature computation
    - sfreq: sampling frequency
    - ch_names: channel names
    - times: feature timestamp (leading edge of a window)

    Raises
    ======
    RuntimeError: invalid channel specification or failure while epoching
    ValueError: no channel left to pick, or a channel index out of range
    NotImplementedError: the selected feature type is not supported
    '''
    # Collect the training files. str.endswith() is used because the former
    # f[-4:] comparison could never match the 5-character '.fiff' extension.
    ftrain = [
        f for f in qc.get_file_list(cfg.DATA_PATH, fullpath=True)
        if f.endswith(('.fif', '.fiff'))
    ]
    if (len(ftrain) > 1 and cfg.PICKED_CHANNELS is not None
            and len(cfg.PICKED_CHANNELS) > 0
            and isinstance(cfg.PICKED_CHANNELS[0], int)):
        msg = 'When loading multiple EEG files, PICKED_CHANNELS must be list of string, not integers because they may have different channel order.'
        logger.error(msg)
        raise RuntimeError(msg)
    raw, events = pu.load_multi(ftrain)

    reref = cfg.REREFERENCE[cfg.REREFERENCE['selected']]
    if reref is not None:
        pu.rereference(raw, reref['New'], reref['Old'])

    # Events loaded from an external file replace the ones found in the data
    events_file = cfg.LOAD_EVENTS[cfg.LOAD_EVENTS['selected']]
    if events_file is not None:
        events = mne.read_events(events_file)

    # Map event names to their integer codes
    trigger_def_int = {getattr(cfg.tdef, name) for name in cfg.TRIGGER_DEF}
    triggers = {cfg.tdef.by_value[code]: code for code in trigger_def_int}

    # Pick channels
    if cfg.PICKED_CHANNELS is None:
        chlist = [int(x) for x in pick_types(raw.info, stim=False, eeg=True)]
    else:
        chlist = cfg.PICKED_CHANNELS
    picks = []
    for c in chlist:
        if isinstance(c, int):
            picks.append(c)
        elif isinstance(c, str):
            picks.append(raw.ch_names.index(c))
        else:
            msg = ('PICKED_CHANNELS has a value of unknown type %s.\nPICKED_CHANNELS=%s'
                   % (type(c), cfg.PICKED_CHANNELS))
            logger.error(msg)
            raise RuntimeError(msg)
    if cfg.EXCLUDED_CHANNELS is not None:
        for c in cfg.EXCLUDED_CHANNELS:
            if isinstance(c, str):
                if c not in raw.ch_names:
                    logger.warning(
                        'Exclusion channel %s does not exist. Ignored.' % c)
                    continue
                c_int = raw.ch_names.index(c)
            elif isinstance(c, int):
                c_int = c
            else:
                msg = ('EXCLUDED_CHANNELS has a value of unknown type %s.\nEXCLUDED_CHANNELS=%s'
                       % (type(c), cfg.EXCLUDED_CHANNELS))
                logger.error(msg)
                raise RuntimeError(msg)
            if c_int in picks:
                picks.remove(c_int)
    if not picks:
        # max() on an empty list would raise an uninformative ValueError
        msg = '"picks" is empty: no channel left after channel selection.'
        logger.error(msg)
        raise ValueError(msg)
    # Valid indices are 0 .. len(ch_names) - 1, hence >= and not >
    if max(picks) >= len(raw.ch_names):
        msg = ('"picks" has a channel index %d while there are only %d channels.'
               % (max(picks), len(raw.ch_names)))
        logger.error(msg)
        raise ValueError(msg)
    for unsupported in ('SP_CHANNELS', 'TP_CHANNELS', 'NOTCH_CHANNELS'):
        if getattr(cfg, unsupported, None) is not None:
            logger.warning(
                '%s parameter is not supported yet. Will be set to PICKED_CHANNELS.'
                % unsupported)
    if 'decim' not in cfg.FEATURES['PSD']:
        cfg.FEATURES['PSD']['decim'] = 1
        logger.warning('PSD["decim"] undefined. Set to 1.')

    # Read epochs
    epoch_kwargs = dict(proj=False,
                        picks=picks,
                        baseline=None,
                        preload=True,
                        verbose=False,
                        detrend=None)
    multi_range = isinstance(cfg.EPOCH[0], list)
    try:
        if multi_range:
            # Experimental: multiple epoch ranges
            epochs_train = [
                Epochs(raw, events, triggers, tmin=ep[0], tmax=ep[1],
                       **epoch_kwargs) for ep in cfg.EPOCH
            ]
        else:
            # Usual method: single epoch range
            epochs_train = Epochs(raw,
                                  events,
                                  triggers,
                                  tmin=cfg.EPOCH[0],
                                  tmax=cfg.EPOCH[1],
                                  on_missing='warning',
                                  **epoch_kwargs)
    except Exception as exc:
        logger.exception('Problem while epoching.')
        raise RuntimeError('Problem while epoching.') from exc

    # With multiple ranges epochs_train is a list, which has no .info
    first_epochs = epochs_train[0] if multi_range else epochs_train

    # Compute features
    feature_type = cfg.FEATURES['selected']
    if feature_type == 'PSD':
        preprocess = dict(sfreq=first_epochs.info['sfreq'],
                          spatial=cfg.SP_FILTER,
                          spatial_ch=None,
                          spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                          spectral_ch=None,
                          notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                          notch_ch=None,
                          multiplier=cfg.MULTIPLIER,
                          ch_names=None,
                          rereference=None,
                          decim=cfg.FEATURES['PSD']['decim'],
                          n_jobs=cfg.N_JOBS)
        featdata = get_psd_feature(epochs_train,
                                   cfg.EPOCH,
                                   cfg.FEATURES['PSD'],
                                   picks=None,
                                   preprocess=preprocess,
                                   n_jobs=cfg.N_JOBS)
    elif feature_type in ('TIMELAG', 'WAVELET'):
        # TODO: Implement multiple epochs for timelag and wavelet features
        msg = 'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR %s FEATURE.' % feature_type
        logger.error(msg)
        raise NotImplementedError(msg)
    else:
        msg = '%s feature type is not supported.' % feature_type
        logger.error(msg)
        raise NotImplementedError(msg)

    featdata['picks'] = picks
    featdata['sfreq'] = raw.info['sfreq']
    featdata['ch_names'] = raw.ch_names
    return featdata
Esempio n. 24
0
    def confusion_matrix(Y_true, Y_pred, label_len=6):
        """
        Generate confusion matrix in a string format

        Parameters
        ----------
        Y_true : list
            The true labels. If longer than Y_pred, only the first
            len(Y_pred) items are used.
        Y_pred : list
            The test labels
        label_len : int
            The maximum label text length displayed (minimum length: 6)

        Returns
        -------
        cfmat : str
            The confusion matrix in str format (X-axis: prediction, Y-axis: ground truth).
            Each row is normalized by the number of ground-truth samples.
        acc : float
            The accuracy

        Raises
        ------
        RuntimeError
            If Y_pred has more items than Y_true.
        """
        import numpy as np
        from sklearn.metrics import confusion_matrix as sk_confusion_matrix

        # find labels (sorted in both cases so that the row order is stable)
        if isinstance(Y_true, np.ndarray):
            Y_labels = np.unique(Y_true)
        else:
            Y_labels = sorted(set(Y_true))

        # Check the provided label name length
        if label_len < 6:
            label_len = 6
            logger.warning('label_len < 6. Setting to 6.')
        label_tpl = '%' + '-%ds' % label_len
        col_tpl = '%' + '-%d.2f' % label_len

        # sanity check
        if len(Y_pred) > len(Y_true):
            raise RuntimeError('Y_pred has more items than Y_true')
        elif len(Y_pred) < len(Y_true):
            Y_true = Y_true[:len(Y_pred)]

        # "labels" is keyword-only in current scikit-learn releases
        cm = sk_confusion_matrix(Y_true, Y_pred, labels=Y_labels)

        # Row-normalize; rows without any ground-truth sample are left as-is
        # instead of producing NaN.
        cm_rate = cm.astype('float')
        row_sums = cm.sum(axis=1, keepdims=True)
        np.divide(cm_rate, row_sums, out=cm_rate, where=row_sums > 0)

        # Fill confusion string
        header = [label_tpl % 'gt\\dt']
        header.extend(label_tpl % str(l)[:label_len] for l in Y_labels)
        lines = [''.join(header)]
        for l, row in zip(Y_labels, cm_rate):
            cells = [label_tpl % str(l)[:label_len]]
            cells.extend(col_tpl % c for c in row)
            lines.append(''.join(cells))
        cm_txt = '\n'.join(lines) + '\n'

        # compute accuracy
        total = cm.sum()
        acc = float(np.trace(cm)) / float(total) if total > 0 else 0.0

        return cm_txt, acc
Esempio n. 25
0
def check_config(cfg):
    """
    Validate a protocol config object and fill in missing optional values.

    Parameters
    ----------
    cfg : python.module
        The config object to check. It is modified in place: every optional
        parameter that is not defined gets its predefined default value.

    Raises
    ------
    RuntimeError
        If a required parameter, or a required key of TIMINGS or BAR_STEP,
        is missing.
    """
    critical_vars = {
        'COMMON': [
            'DECODER_FILE', 'TRIGGER_DEVICE', 'TRIGGER_FILE', 'DIRECTIONS',
            'TRIALS_EACH', 'PROB_ALPHA_NEW'
        ],
        'TIMINGS': ['INIT', 'GAP', 'READY', 'FEEDBACK', 'DIR_CUE', 'CLASSIFY'],
        'BAR_STEP': ['left', 'right', 'up', 'down', 'both']
    }

    optional_vars = {
        'AMP_NAME': None,
        'FAKE_CLS': None,
        'TRIALS_RANDOMIZE': True,
        'BAR_SLOW_START': {
            'selected': 'False',
            'False': None,
            'True': [1.0]
        },
        'PARALLEL_DECODING': {
            'selected': 'False',
            'False': None,
            'True': {
                'period': 0.06,
                'num_strides': 3
            }
        },
        'SHOW_TRIALS': True,
        'FREE_STYLE': False,
        'REFRESH_RATE': 30,
        'BAR_BIAS': None,
        'BAR_REACH_FINISH': False,
        'FEEDBACK_TYPE': 'BAR',
        'SHOW_CUE': True,
        'SCREEN_SIZE': (1920, 1080),
        'SCREEN_POS': (0, 0),
        'DEBUG_PROBS': False,
        'LOG_PROBS': False,
        'WITH_REX': False,
        'WITH_STIMO': False,
        'ADAPTIVE': None,
    }

    # Required top-level parameters
    for key in critical_vars['COMMON']:
        if not hasattr(cfg, key):
            msg = '%s is a required parameter' % key
            logger.error(msg)
            raise RuntimeError(msg)

    # Required containers and the keys each of them must provide
    for section in ('TIMINGS', 'BAR_STEP'):
        if not hasattr(cfg, section):
            msg = '"%s" not defined in config.' % section
            logger.error(msg)
            raise RuntimeError(msg)
        defined = getattr(cfg, section)
        for v in critical_vars[section]:
            if v not in defined:
                msg = '%s not defined in config.' % v
                logger.error(msg)
                raise RuntimeError(msg)

    # Optional parameters fall back to their defaults
    for key, default in optional_vars.items():
        if not hasattr(cfg, key):
            setattr(cfg, key, default)
            logger.warning('Setting undefined parameter %s=%s' %
                           (key, getattr(cfg, key)))

    # Identity test: "== None" would call a user-defined __eq__
    if cfg.TRIGGER_DEVICE is None:
        logger.warning(
            'The trigger device is set to None! No events will be saved.')
Esempio n. 26
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Run the online feedback protocol: decode each trial, show the feedback
    and log the per-run performance next to the decoder file.

    Parameters
    ----------
    cfg : python.module
        The config object holding the protocol parameters read below.
    state : multiprocessing.Value
        Shared run state (0: stop, 1: start, 2: wait). It is set to 0 when the
        protocol finishes. NOTE: the default is created once, when the
        function is defined, and is therefore shared by all calls relying on it.
    queue : object
        Handed over to redirect_stdout_to_queue() together with the logger.

    Raises
    ------
    ValueError
        If cfg.FEEDBACK_TYPE is neither 'BAR' nor 'BODY'.
    """
    def confusion_matrix(Y_true, Y_pred, label_len=6):
        """
        Generate confusion matrix in a string format

        Parameters
        ----------
        Y_true : list
            The true labels. If longer than Y_pred, only the first
            len(Y_pred) items are used.
        Y_pred : list
            The test labels
        label_len : int
            The maximum label text length displayed (minimum length: 6)

        Returns
        -------
        cfmat : str
            The confusion matrix in str format (X-axis: prediction, Y-axis: ground truth).
            Each row is normalized by the number of ground-truth samples.
        acc : float
            The accuracy

        Raises
        ------
        RuntimeError
            If Y_pred has more items than Y_true.
        """
        import numpy as np
        from sklearn.metrics import confusion_matrix as sk_confusion_matrix

        # find labels (sorted in both cases so that the row order is stable)
        if isinstance(Y_true, np.ndarray):
            Y_labels = np.unique(Y_true)
        else:
            Y_labels = sorted(set(Y_true))

        # Check the provided label name length
        if label_len < 6:
            label_len = 6
            logger.warning('label_len < 6. Setting to 6.')
        label_tpl = '%' + '-%ds' % label_len
        col_tpl = '%' + '-%d.2f' % label_len

        # sanity check
        if len(Y_pred) > len(Y_true):
            raise RuntimeError('Y_pred has more items than Y_true')
        elif len(Y_pred) < len(Y_true):
            Y_true = Y_true[:len(Y_pred)]

        # "labels" is keyword-only in current scikit-learn releases
        cm = sk_confusion_matrix(Y_true, Y_pred, labels=Y_labels)

        # Row-normalize; rows without any ground-truth sample are left as-is
        # instead of producing NaN.
        cm_rate = cm.astype('float')
        row_sums = cm.sum(axis=1, keepdims=True)
        np.divide(cm_rate, row_sums, out=cm_rate, where=row_sums > 0)

        # Fill confusion string
        header = [label_tpl % 'gt\\dt']
        header.extend(label_tpl % str(l)[:label_len] for l in Y_labels)
        lines = [''.join(header)]
        for l, row in zip(Y_labels, cm_rate):
            cells = [label_tpl % str(l)[:label_len]]
            cells.extend(col_tpl % c for c in row)
            lines.append(''.join(cells))
        cm_txt = '\n'.join(lines) + '\n'

        # compute accuracy
        total = cm.sum()
        acc = float(np.trace(cm)) / float(total) if total > 0 else 0.0

        return cm_txt, acc

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI)
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)  # yield the CPU instead of spinning on the flag

    #  Protocol runs if state equals to 1
    if not state.value:
        sys.exit(-1)

    if cfg.FAKE_CLS is None:
        # choose amp
        if cfg.AMP_NAME is None:
            amp_name = search_lsl(ignore_markers=True, state=state)
        else:
            amp_name = cfg.AMP_NAME
        fake_dirs = None
    else:
        amp_name = None
        fake_dirs = [v for (k, v) in cfg.DIRECTIONS]

    # events and triggers
    tdef = TriggerDef(cfg.TRIGGER_FILE)
    trigger = Trigger(cfg.TRIGGER_DEVICE, state)
    if trigger.init(50) == False:
        logger.error(
            'Cannot connect to USB2LPT device. Use a mock trigger instead?')
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = Trigger('FAKE', state)
        trigger.init(50)

    # For adaptive (need to share the actual true label accross process)
    label = mp.Value('i', 0)

    # init classification
    decoder = BCIDecoderDaemon(
        amp_name,
        cfg.DECODER_FILE,
        buffer_size=1.0,
        fake=(cfg.FAKE_CLS is not None),
        fake_dirs=fake_dirs,
        parallel=cfg.PARALLEL_DECODING[cfg.PARALLEL_DECODING['selected']],
        alpha_new=cfg.PROB_ALPHA_NEW,
        label=label)

    # Events can be mapped into integers: decoder labels that are not listed
    # in DIRECTIONS are translated through the trigger definition.
    dirdata = {d[1] for d in cfg.DIRECTIONS}
    labels = []
    for x in decoder.get_labels():
        if x not in dirdata:
            labels.append(tdef.by_value[x])
        else:
            labels.append(x)

    # map class labels to bar directions
    bar_def = {lbl: str(direction) for direction, lbl in cfg.DIRECTIONS}
    bar_dirs = [bar_def[l] for l in labels]
    dir_seq = bar_dirs * cfg.TRIALS_EACH

    logger.info('Initializing decoder.')
    while decoder.is_running() == 0:
        time.sleep(0.01)

    # bar visual object
    if cfg.FEEDBACK_TYPE == 'BAR':
        from neurodecode.protocols.viz_bars import BarVisual
        visual = BarVisual(cfg.GLASS_USE,
                           screen_pos=cfg.SCREEN_POS,
                           screen_size=cfg.SCREEN_SIZE)
    elif cfg.FEEDBACK_TYPE == 'BODY':
        assert hasattr(cfg, 'FEEDBACK_IMAGE_PATH'
                       ), 'FEEDBACK_IMAGE_PATH is undefined in your config.'
        from neurodecode.protocols.viz_human import BodyVisual
        visual = BodyVisual(cfg.FEEDBACK_IMAGE_PATH,
                            use_glass=cfg.GLASS_USE,
                            screen_pos=cfg.SCREEN_POS,
                            screen_size=cfg.SCREEN_SIZE)
    else:
        # Without this branch "visual" is undefined and the next statement
        # fails with an unrelated NameError while the decoder keeps running.
        msg = 'Unknown FEEDBACK_TYPE %s' % cfg.FEEDBACK_TYPE
        logger.error(msg)
        decoder.stop()
        raise ValueError(msg)
    visual.put_text('Waiting to start')
    if cfg.LOG_PROBS:
        logdir = io.parse_path(cfg.DECODER_FILE).dir
        # Only the file name goes through strftime; the directory is joined
        # afterwards so that a missing trailing separator or a '%' in the
        # path cannot corrupt the result.
        probs_logfile = os.path.join(
            logdir, time.strftime("probs-%Y%m%d-%H%M%S.txt", time.localtime()))
    else:
        probs_logfile = None
    feedback = Feedback(cfg, state, visual, tdef, trigger, probs_logfile)

    # If adaptive classifier. cfg.ADAPTIVE may be None (its default value).
    adaptive_cfg = cfg.ADAPTIVE[cfg.ADAPTIVE['selected']] if cfg.ADAPTIVE else None
    if adaptive_cfg:
        nb_runs = adaptive_cfg[0]
        adaptive = True
    else:
        nb_runs = 1
        adaptive = False

    rex_actions = {'U': 'N', 'L': 'W', 'R': 'E', 'D': 'S'}

    run_idx = 1
    while run_idx <= nb_runs:

        if cfg.TRIALS_RANDOMIZE:
            random.shuffle(dir_seq)
        else:
            dir_seq = [d[0] for d in cfg.DIRECTIONS] * cfg.TRIALS_EACH
        num_trials = len(dir_seq)

        # For adaptive, retrain classifier
        if run_idx > 1:

            #  Allow to retrain classifier
            with decoder.label.get_lock():
                decoder.label.value = 1

            # Wait that the retraining is done
            while decoder.label.value == 1:
                time.sleep(0.01)

            feedback.viz.put_text('Press any key')
            feedback.viz.update()
            cv2.waitKeyEx()
            feedback.viz.fill()

        # start
        trial = 1
        # One entry per classification attempt. The ground truth is recorded
        # alongside each prediction because a retried trial yields several
        # predictions for a single item of dir_seq.
        dir_true = []
        dir_detected = []
        prob_history = {c: [] for c in bar_dirs}
        while trial <= num_trials:
            if cfg.SHOW_TRIALS:
                title_text = 'Trial %d / %d' % (trial, num_trials)
            else:
                title_text = 'Ready'
            true_label = dir_seq[trial - 1]

            result = feedback.classify(decoder,
                                       true_label,
                                       title_text,
                                       bar_dirs,
                                       prob_history=prob_history,
                                       adaptive=adaptive)

            if result is None:
                decoder.stop()
                return
            pred_label = result
            dir_true.append(true_label)
            dir_detected.append(pred_label)

            if cfg.WITH_REX is True and pred_label == true_label:
                rex_dir = rex_actions.get(pred_label)
                if rex_dir is None:
                    logger.warning('Rex cannot execute undefined action %s' %
                                   pred_label)
                else:
                    visual.move(pred_label, 100, overlay=False, barcolor='B')
                    visual.update()
                    logger.info('Executing Rex action %s' % rex_dir)
                    os.system(
                        '%s/Rex/RexControlSimple.exe %s %s' %
                        (os.environ['NEUROD_ROOT'], cfg.REX_COMPORT, rex_dir))
                    time.sleep(8)

            if true_label == pred_label:
                msg = 'Correct'
            else:
                msg = 'Wrong'
            # With TRIALS_RETRY a wrong trial is repeated until it is correct
            if cfg.TRIALS_RETRY is False or true_label == pred_label:
                logger.info('Trial %d: %s (%s -> %s)' %
                            (trial, msg, true_label, pred_label))
                trial += 1

        if len(dir_detected) > 0:
            # write performance and log results
            fdir = io.parse_path(cfg.DECODER_FILE).dir
            logfile = os.path.join(
                fdir,
                time.strftime("online-%Y%m%d-%H%M%S.txt", time.localtime()))
            with open(logfile, 'w') as fout:
                fout.write('Ground-truth,Prediction\n')
                for gt, dt in zip(dir_true, dir_detected):
                    fout.write('%s,%s\n' % (gt, dt))
                cfmat, acc = confusion_matrix(dir_true, dir_detected)
                fout.write('\nAccuracy %.3f\nConfusion matrix\n' % acc)
                fout.write(cfmat)
                logger.info('Log exported to %s' % logfile)
            print('\nAccuracy %.3f\nConfusion matrix\n' % acc)
            print(cfmat)

        run_idx += 1

    visual.finish()

    with state.get_lock():
        state.value = 0

    if decoder.is_running():
        decoder.stop()

    # TODO: automatic threshold optimization over prob_history for the
    # two-class case (search the bias maximizing the accuracy and append the
    # result to probs_logfile) was sketched here but never enabled.

    logger.info('Finished.')
Esempio n. 27
0
def confusion_matrix(Y_true, Y_pred, label_len=6):
    """
    Generate confusion matrix in a string format

    Input
    -----
        Y_true: true labels. If longer than Y_pred, only the first
            len(Y_pred) items are used.
        Y_pred: test labels
        label_len: maximum label text length (minimum length: 6)

    Output
    ------
        (cfmat, acc)
        cfmat: confusion matrix (string), each row normalized by the number
            of ground-truth samples
            X-axis: prediction
            Y-axis: ground truth
        acc: accuracy (float)

    Raises
    ------
        RuntimeError if Y_pred has more items than Y_true
    """

    # find labels (sorted in both cases so that the row order is stable)
    if isinstance(Y_true, np.ndarray):
        Y_labels = np.unique(Y_true)
    else:
        Y_labels = sorted(set(Y_true))
    if label_len < 6:
        label_len = 6
        logger.warning('label_len < 6. Setting to 6.')
    label_tpl = '%' + '-%ds' % label_len
    col_tpl = '%' + '-%d.2f' % label_len

    # sanity check
    if len(Y_pred) > len(Y_true):
        raise RuntimeError('Y_pred has more items than Y_true')
    elif len(Y_pred) < len(Y_true):
        Y_true = Y_true[:len(Y_pred)]

    # "labels" is keyword-only in current scikit-learn releases
    cm = sklearn.metrics.confusion_matrix(Y_true, Y_pred, labels=Y_labels)

    # Row-normalize; rows without any ground-truth sample are left as-is
    # because a plain division would produce NaN for them.
    cm_rate = cm.astype('float')
    row_sums = cm.sum(axis=1, keepdims=True)
    np.divide(cm_rate, row_sums, out=cm_rate, where=row_sums > 0)

    # Labels are converted with str() so that non-string labels (e.g.
    # integer event codes) can be truncated and printed as well.
    header = [label_tpl % 'gt\\dt']
    header.extend(label_tpl % str(l)[:label_len] for l in Y_labels)
    lines = [''.join(header)]
    for l, row in zip(Y_labels, cm_rate):
        cells = [label_tpl % str(l)[:label_len]]
        cells.extend(col_tpl % c for c in row)
        lines.append(''.join(cells))
    cm_txt = '\n'.join(lines) + '\n'

    # compute accuracy
    total = cm.sum()
    acc = float(np.trace(cm)) / float(total) if total > 0 else 0.0

    return cm_txt, acc
Esempio n. 28
0
    def connect(self, find_any=True):
        """
        Find a matching LSL stream, open an inlet on it and block until the
        initial buffer holds at least one window of samples.

        Run in child process

        Streams are filtered by self.amp_serial, self.amp_name and
        self.eeg_only. The search is repeated once per second until a stream
        is accepted.

        Parameters
        ----------
        find_any : bool
            If True, a stream whose name matches none of the known amplifiers
            is accepted as well.
        """
        # (name fragment, log prefix, trigger channel index or None to detect
        #  it from the channel names, unit multiplier or None to keep it)
        known_amps = (
            ('USBamp', 'Found USBamp streaming server', 16, None),
            # BioSemi: or subtract -6684927? (value when trigger==0)
            ('BioSemi', 'Found BioSemi streaming server', 0, None),
            ('SmartBCI', 'Found SmartBCI streaming server', 23, None),
            ('StreamPlayer', 'Found StreamPlayer streaming server', 0, None),
            # OpenVibe standard unit is Volts, which is not ideal for some
            # numerical computations: change V -> uV unit.
            ('openvibeSignal', 'Found an Openvibe signal streaming server', None, 10**6),
            ('openvibeMarkers', 'Found an Openvibe markers server', None, None),
        )

        server_found = False
        amps = []
        channels = 0
        while not server_found:
            if self.amp_name is None and self.amp_serial is None:
                logger.info("Looking for a streaming server...")
            else:
                logger.info("Looking for %s (Serial %s) ..." % (self.amp_name, self.amp_serial))
            # For now, only 1 amp is supported by a single StreamReceiver object.
            for si in pylsl.resolve_streams():
                # LSL XML parser has a bug which crashes so do not use for now
                #amp_serial = inlet.info().desc().child('acquisition').child_value('serial_number')
                amp_serial = 'N/A'
                amp_name = si.name()

                # connect to a specific amp only?
                if self.amp_serial is not None and self.amp_serial != amp_serial:
                    continue
                if self.amp_name is not None and self.amp_name != amp_name:
                    continue

                # EEG streaming server only?
                if self.eeg_only and si.type() != 'EEG':
                    continue

                for fragment, found_msg, tr_channel, multiplier in known_amps:
                    if fragment in amp_name:
                        break
                else:
                    if not find_any:
                        continue
                    found_msg, tr_channel, multiplier = 'Found a streaming server', None, None

                logger.info(found_msg + (' %s (type %s, amp_serial %s) @ %s.' % (amp_name, si.type(), amp_serial, si.hostname())))
                # The inlet is only needed to read the channel names, so it
                # is opened for the accepted stream alone.
                ch_list = pu.lsl_channel_list(pylsl.StreamInlet(si))
                if tr_channel is None:
                    tr_channel = find_event_channel(ch_names=ch_list)
                self._lsl_tr_channel = tr_channel
                if multiplier is not None:
                    self.multiplier = multiplier
                channels += si.channel_count()
                amps.append(si)
                server_found = True
                break
            time.sleep(1)

        self.amp_name = amp_name

        # define EEG channel indices
        self._lsl_eeg_channels = list(range(channels))
        if self._lsl_tr_channel is None:
            logger.warning('Trigger channel not fonud. Adding an empty channel 0.')
        else:
            if self._lsl_tr_channel != 0:
                logger.info_yellow('Trigger channel found at index %d. Moving to index 0.' % self._lsl_tr_channel)
            self._lsl_eeg_channels.pop(self._lsl_tr_channel)
        self._lsl_eeg_channels = np.array(self._lsl_eeg_channels)
        self.tr_channel = 0  # trigger channel is always set to 0.
        # Signal channels start from 1. Their number is the number of stream
        # channels that are not the trigger: when no trigger channel exists
        # an empty one is prepended, so all stream channels are signal channels.
        self.eeg_channels = np.arange(1, len(self._lsl_eeg_channels) + 1)

        # create new inlets to read from the stream
        inlets = []
        for amp in amps:
            # data type of the 2nd argument (max_buflen) is int according to LSL C++ specification!
            inlets.append(pylsl.StreamInlet(amp, max_buflen=self.stream_bufsec))
            self.buffers.append([])
            self.timestamps.append([])

        sample_rate = amps[0].nominal_srate()
        logger.info('Channels: %d' % channels)
        logger.info('LSL Protocol version: %s' % amps[0].version())
        logger.info('Source sampling rate: %.1f' % sample_rate)
        logger.info('Unit multiplier: %.1f' % self.multiplier)

        # Sizes in samples, rounded to the nearest integer
        self.winsize = int(round(self.winsec * sample_rate))
        self.bufsize = int(round(self.bufsec * sample_rate))
        self.stream_bufsize = int(round(self.stream_bufsec * sample_rate))
        self.sample_rate = sample_rate
        self.connected = True
        self.ch_list = ch_list
        self.inlets = inlets  # Note: not picklable!

        # TODO: check if there's any problem with multiple inlets
        if len(self.inlets) > 1:
            logger.warning('Merging of multiple acquisition servers is not supported yet. Only %s will be used.' % amps[0].name())

        # create channel info
        if self._lsl_tr_channel is None:
            self.ch_list = ['TRIGGER'] + self.ch_list
        else:
            for i, chn in enumerate(self.ch_list):
                if chn == 'TRIGGER' or chn == 'TRG' or 'STI ' in chn:
                    # safe although the list is being iterated: the loop is
                    # left right after the modification
                    self.ch_list.pop(i)
                    self.ch_list = ['TRIGGER'] + self.ch_list
                    break
        logger.info('self.ch_list %s' % self.ch_list)

        # fill in initial buffer
        logger.info('Waiting to fill initial buffer of length %d' % (self.winsize))
        while len(self.timestamps[0]) < self.winsize:
            self.acquire()
            time.sleep(0.1)
        self.ready = True
        logger.info('Start receiving stream data.')
Esempio n. 29
0
        elif color.upper() == 'G':
            c = colorama.Fore.GREEN
        elif color.upper() == 'Y':
            c = colorama.Fore.YELLOW
        elif color.upper() == 'W':
            c = colorama.Fore.WHITE
        elif color.upper() == 'C':
            c = colorama.Fore.CYAN
        else:
            logger.error('print_c(): Unknown color code %s' % color)
            raise ValueError
        print(colorama.Style.BRIGHT + c + str(msg) + colorama.Style.RESET_ALL,
              end=end)

except ImportError:
    logger.warning(
        'colorama module not found. print_c() will ignore color codes.')

    def print_c(msg, color, end='\n'):
        """
        Fallback used when the colorama module cannot be imported:
        print msg followed by end, ignoring the color code.
        """
        print(msg, end=end)


'''"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
 List/Dict related
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""'''


def list2string(vec, fmt, sep=' '):
    """
    Convert a list to string with formatting, separated by sep (default is space).
    Example: fmt= '%.32e', '%.6f', etc.
    """
Esempio n. 30
0
def run(cfg, state=mp.Value('i', 1), queue=None):
    """
    Run the cue-based training protocol with bar feedback and triggers.

    Each trial walks through the event states
    'gap_s' -> 'gap' -> 'cue' -> 'dir_r' -> 'dir' -> 'gap_s' (after an
    initial 'start' state). A trigger value taken from cfg.tdef is sent on
    the INIT, CUE, <direction>_READY, <direction>_GO and BLANK transitions.
    While in the 'dir' state the bar grows towards the cued direction in
    proportion to the elapsed time.

    Parameters
    ----------
    cfg : python.module
        The config module. This function reads REFRESH_RATE, TRIGGER_FILE,
        TRIALS_EACH, DIRECTIONS, TRIGGER_DEVICE, TIMINGS, TRIAL_PAUSE,
        GLASS_USE, SCREEN_POS and SCREEN_SIZE, and sets cfg.tdef.
    state : mp.Value
        Shared state flag, 0: stop, 1: start, 2: wait. It is set to 0 when
        the protocol ends. NOTE: the default value is created once, when the
        function is defined, and is therefore shared between calls that do
        not pass their own state.
    queue :
        Forwarded to redirect_stdout_to_queue() together with the logger.

    Raises
    ------
    RuntimeError
        If a direction drawn from cfg.DIRECTIONS is not one of
        'L', 'R', 'U', 'D', 'B'.
    """
    import time  # local import: only needed for the polling sleep below

    redirect_stdout_to_queue(logger, queue, 'INFO')

    # Wait the recording to start (GUI). Sleep between polls instead of
    # spinning so that the wait does not keep a CPU core busy.
    while state.value == 2:  # 0: stop, 1:start, 2:wait
        time.sleep(0.01)
    #  Protocol start if equals to 1
    if not state.value:
        sys.exit()

    refresh_delay = 1.0 / cfg.REFRESH_RATE

    cfg.tdef = trigger_def(cfg.TRIGGER_FILE)

    # key code returned by cv2.waitKey() (low byte) for the escape key
    key_esc = 27

    # direction code -> (prefix of the cfg.tdef trigger names,
    #                    bar sides to animate)
    # Trigger attributes are looked up by name only when a direction is
    # actually used, so cfg.tdef only needs to define the triggers of the
    # directions listed in cfg.DIRECTIONS.
    directions = {
        'L': ('LEFT', ('L',)),
        'R': ('RIGHT', ('R',)),
        'U': ('UP', ('U',)),
        'D': ('DOWN', ('D',)),
        'B': ('BOTH', ('L', 'R')),  # both hands
    }

    def jittered(name):
        # cfg.TIMINGS[name] +/- a uniform jitter of cfg.TIMINGS[name_RANDOMIZE]
        jitter = cfg.TIMINGS[name + '_RANDOMIZE']
        return cfg.TIMINGS[name] + random.uniform(-jitter, jitter)

    # every direction appears TRIALS_EACH times, in random order
    dir_sequence = list(cfg.DIRECTIONS) * cfg.TRIALS_EACH
    random.shuffle(dir_sequence)
    num_trials = len(dir_sequence)

    event = 'start'
    trial = 1

    # Hardware trigger
    if cfg.TRIGGER_DEVICE is None:
        # NOTE: only a warning is logged here; execution continues without
        # waiting for user input.
        logger.warning(
            'No trigger device set. Press Ctrl+C to stop or Enter to continue.'
        )
    trigger = pyLptControl.Trigger(state, cfg.TRIGGER_DEVICE)
    # explicit comparison kept on purpose: only a False return value
    # switches to the mock trigger
    if trigger.init(50) == False:  # noqa: E712
        logger.error(
            '\n** Error connecting to USB2LPT device. Use a mock trigger instead?'
        )
        input('Press Ctrl+C to stop or Enter to continue.')
        trigger = pyLptControl.MockTrigger()
        trigger.init(50)

    # timers
    timer_trigger = qc.Timer()
    timer_dir = qc.Timer()
    timer_refresh = qc.Timer()
    t_dir = jittered('DIR')
    t_dir_ready = jittered('READY')

    bar = BarVisual(cfg.GLASS_USE,
                    screen_pos=cfg.SCREEN_POS,
                    screen_size=cfg.SCREEN_SIZE)
    bar.fill()
    bar.glass_draw_cue()

    # set when a trial reaches the 'cue' event
    tdef_prefix = None
    sides = ()

    # start
    while trial <= num_trials:
        timer_refresh.sleep_atleast(refresh_delay)
        timer_refresh.reset()

        # segment= { 'cue':(s,e), 'dir':(s,e), 'label':0-4 } (zero-based)
        if event == 'start' and timer_trigger.sec() > cfg.TIMINGS['INIT']:
            event = 'gap_s'
            bar.fill()
            timer_trigger.reset()
            trigger.signal(cfg.tdef.INIT)
        elif event == 'gap_s':
            if cfg.TRIAL_PAUSE:
                bar.put_text('Press any key')
                bar.update()
                # mask to the low byte, like the key read at the end of the
                # loop, so that ESC is recognised regardless of modifier bits
                key = 0xFF & cv2.waitKey()
                if key == key_esc or not state.value:
                    break
                bar.fill()
            bar.put_text('Trial %d / %d' % (trial, num_trials))
            event = 'gap'
            timer_trigger.reset()
        elif event == 'gap' and timer_trigger.sec() > cfg.TIMINGS['GAP']:
            event = 'cue'
            bar.fill()
            bar.draw_cue()
            trigger.signal(cfg.tdef.CUE)
            timer_trigger.reset()
        elif event == 'cue' and timer_trigger.sec() > cfg.TIMINGS['CUE']:
            event = 'dir_r'
            direction = dir_sequence[trial - 1]
            if direction not in directions:
                # direction codes are strings, hence %s
                raise RuntimeError('Unknown direction %s' % (direction,))
            tdef_prefix, sides = directions[direction]
            # show the full-length bar(s) as the "ready" cue
            for side in sides:
                bar.move(side, 100, overlay=True)
            trigger.signal(getattr(cfg.tdef, tdef_prefix + '_READY'))
            timer_trigger.reset()
        elif event == 'dir_r' and timer_trigger.sec() > t_dir_ready:
            bar.fill()
            bar.draw_cue()
            event = 'dir'
            timer_trigger.reset()
            timer_dir.reset()
            trigger.signal(getattr(cfg.tdef, tdef_prefix + '_GO'))
        elif event == 'dir' and timer_trigger.sec() > t_dir:
            event = 'gap_s'
            bar.fill()
            logger.info('trial %d done' % trial)
            trial += 1
            trigger.signal(cfg.tdef.BLANK)
            timer_trigger.reset()
            # draw new randomized durations for the next trial
            t_dir = jittered('DIR')
            t_dir_ready = jittered('READY')

        # protocol: grow the bar with the time elapsed in the 'dir' state,
        # capped at 100
        if event == 'dir':
            dx = min(100, int(100.0 * timer_dir.sec() / t_dir) + 1)
            for side in sides:
                bar.move(side, dx, overlay=True)

        # wait for start
        if event == 'start':
            bar.put_text('Waiting to start')

        bar.update()
        key = 0xFF & cv2.waitKey(1)
        if key == key_esc or not state.value:
            break

    bar.finish()

    with state.get_lock():
        state.value = 0