Exemplo n.º 1
0
def get_psd(sr, psde, picks):
    """
    Acquire the latest window from a stream receiver and return its PSD.

    Parameters
    ----------
    sr : object providing acquire() and get_window(); get_window() returns
        the window (times x channels) and its timestamps.
    psde : object providing transform(); called on the picked channels.
    picks : index used to select the channels that were used for training.

    Returns
    -------
    Whatever psde.transform() returns for the picked channels.

    Notes
    -----
    The filter settings (sfreq, spatial, spatial_ch, spectral, spectral_ch,
    notch, notch_ch, multiplier) are read from names defined outside this
    function.
    """
    sr.acquire()
    window, _ = sr.get_window()  # timestamps are not needed here
    signals = window.T  # times x channels -> channels x times

    # Filtering happens before channel selection so that the *_ch settings
    # still refer to the original channel order.
    filter_settings = dict(sfreq=sfreq,
                           spatial=spatial,
                           spatial_ch=spatial_ch,
                           spectral=spectral,
                           spectral_ch=spectral_ch,
                           notch=notch,
                           notch_ch=notch_ch,
                           multiplier=multiplier)
    # The return value is not used; presumably preprocess() filters the
    # array in place -- TODO confirm.
    pu.preprocess(signals, **filter_settings)

    # Keep only the channels used for training, then estimate the PSD.
    return psde.transform(signals[picks])
Exemplo n.º 2
0
def get_tfr(fif_file, cfg, tfr, n_jobs=1):
    """
    Compute the time-frequency representation of one raw file and export it.

    The file is loaded, filtered, wrapped into an Epochs object and passed to
    ``tfr``. Exports go to '<directory of fif_file>/plot_<file name>'.

    Parameters
    ----------
    fif_file : path of the raw file, loaded with pu.load_raw().
    cfg : configuration object. Fields read here: CHANNEL_PICKS, SP_CHANNELS,
        SP_FILTER, TP_FILTER, NOTCH_FILTER, MULTIPLIER, EVENT_START,
        FREQ_RANGE, EXPORT_MATLAB, EXPORT_PNG, BS_TIMES, BS_MODE, VMIN, VMAX.
    tfr : callable invoked as tfr(epochs, freqs=..., n_cycles=...,
        use_fft=False, return_itc=False, decim=1, n_jobs=...).
    n_jobs : forwarded to pu.preprocess() and to ``tfr``.

    Raises
    ------
    RuntimeError
        If a picked channel index is not a valid index into the channel list.
    """
    raw, events = pu.load_raw(fif_file)
    p = qc.parse_path(fif_file)
    fname = p.name
    outpath = p.dir

    export_dir = '%s/plot_%s' % (outpath, fname)
    qc.make_dirs(export_dir)

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    # Valid indices are 0 .. n_channels - 1, so an index equal to the channel
    # count is already out of range.
    n_channels = len(raw.info['ch_names'])
    if max(picks) >= n_channels:
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), n_channels)
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                  spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # MNE TFR functions do not support Raw instances yet, so convert to Epoch
    if cfg.EVENT_START is None:
        # No start event configured: overwrite the first sample of channel 0
        # and use a single synthetic event at sample 0 so that the whole
        # recording becomes one epoch.
        # NOTE(review): presumably channel 0 is the trigger channel -- confirm.
        raw._data[0][0] = 1
        events = np.array([[0, 0, 1]])
        classes = None
    else:
        classes = {'START':cfg.EVENT_START}
    tmax = (raw._data.shape[1] - 1) / raw.info['sfreq']
    epochs_all = mne.Epochs(raw, events, classes, tmin=0, tmax=tmax,
                    picks=picks, baseline=None, preload=True)
    print('\n>> Processing %s' % fif_file)
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    power = tfr(epochs_all, freqs=freqs, n_cycles=n_cycles, use_fft=False,
        return_itc=False, decim=1, n_jobs=n_jobs)

    if cfg.EXPORT_MATLAB is True:
        # export all channels to MATLAB
        mout = '%s/%s-%s.mat' % (export_dir, fname, cfg.SP_FILTER)
        scipy.io.savemat(mout, {'tfr':power.data, 'chs':power.ch_names, 'events':events,
            'sfreq':raw.info['sfreq'], 'freqs':cfg.FREQ_RANGE})

    if cfg.EXPORT_PNG is True:
        # Plot power of each channel; ch is the position within the picked
        # channels, pick is the index into the raw channel list.
        for ch, pick in enumerate(picks):
            ch_name = raw.ch_names[pick]
            title = 'Channel %s' % (ch_name)
            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            fig = power.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
            fout = '%s/%s-%s-%s.png' % (export_dir, fname, cfg.SP_FILTER, ch_name)
            fig.savefig(fout)
            # Clear the figure once saved so its contents are not kept around
            # for every channel.
            fig.clf()
            print('Exported %s' % fout)

    print('Finished !')
Exemplo n.º 3
0
    def get_prob(self):
        """
        Read the latest window and return the class likelihoods.

        In fake mode no signal is read: the first class gets a random
        likelihood and the remaining probability mass is shared equally by
        the other classes. Otherwise the latest window is acquired, filtered,
        reduced to the training channels and converted to a PSD, which is
        appended to the PSD buffer and classified.

        Returns
        -------
        The likelihood P(X|C), where X=window, C=model. A list in fake mode,
        otherwise the first row returned by self.cls.predict_proba().
        """
        if self.fake:
            # fake decoder: biased likelihood for the first class
            n_others = len(self.labels) - 1
            probs = [random.uniform(0.0, 1.0)]
            # likelihoods of the other classes are all set equal
            p_others = (1 - probs[0]) / n_others
            probs.extend([p_others] * n_others)
            time.sleep(0.0625)  # simulated delay for PSD + RF
        else:
            self.sr.acquire()
            w, ts = self.sr.get_window()  # w = times x channels
            w = w.T  # -> channels x times

            # apply filters. Important: maintain the original channel order at this point.
            # The return value is not used; presumably preprocess() filters
            # the array in place -- TODO confirm.
            pu.preprocess(w, sfreq=self.sfreq, spatial=self.spatial, spatial_ch=self.spatial_ch,
                          spectral=self.spectral, spectral_ch=self.spectral_ch, notch=self.notch,
                          notch_ch=self.notch_ch, multiplier=self.multiplier)

            # select the same channels used for training
            w = w[self.picks]

            # transform() is given a single-epoch batch: 1 x channels x times
            psd = self.psde.transform(w.reshape((1, w.shape[0], w.shape[1])))

            # update psd buffer ( < 1 msec overhead )
            self.psd_buffer = np.concatenate((self.psd_buffer, psd), axis=0)
            self.ts_buffer.append(ts[0])
            if ts[0] - self.ts_buffer[0] > self.buffer_sec:
                # Drop everything older than buffer_sec. The cut-off has to
                # use buffer_sec (not a fixed 1.0 s), otherwise the buffer
                # shrinks to one second every time it overflows.
                # search speed comparison for ordered arrays:
                # http://stackoverflow.com/questions/16243955/numpy-first-occurence-of-value-greater-than-existing-value
                t_index = np.searchsorted(self.ts_buffer, ts[0] - self.buffer_sec)
                self.ts_buffer = self.ts_buffer[t_index:]
                self.psd_buffer = self.psd_buffer[t_index:, :, :]  # numpy delete is slower

            # make a feature vector (all channels concatenated) and classify
            feats = np.concatenate(psd[0]).reshape(1, -1)

            # compute likelihoods
            probs = self.cls.predict_proba(feats)[0]

        return probs
Exemplo n.º 4
0
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute time-frequency representations (TFR) of epoched data and export
    them as MATLAB files and/or PNG plots.

    @params:
    cfg: configuration object, passed through check_cfg() first. Fields read
        here: TFR_TYPE ('multitaper', 'morlet' or 'butter'; 'fir' raises
        NotImplementedError), EXPORT_PATH, T_BUFFER, N_JOBS, DATA_DIRS or
        DATA_FILE, EVENT_FILE (optional), REREFERENCE, CHANNEL_PICKS,
        SP_CHANNELS, SP_FILTER, TP_FILTER, NOTCH_FILTER, MULTIPLIER, TRIGGERS,
        tdef (optional), EPOCH, POWER_AVERAGED, FREQ_RANGE, EXPORT_MATLAB,
        EXPORT_PNG, BS_TIMES, BS_MODE, VMIN, VMAX, POWER_DIFF (optional).
    recursive: if True, load raw files in sub-dirs recursively
    n_jobs: number of cores to run in parallel. NOTE: the value passed here is
        always replaced by cfg.N_JOBS (or the CPU count if that is None).

    @raises:
    ValueError: unknown TFR type, missing EXPORT_PATH for multiple
        directories, or none of the configured triggers found in the data.
    NotImplementedError: TFR type 'fir'.
    RuntimeError: a picked channel index is out of range.
    '''

    cfg = check_cfg(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        butter_order = 4 # TODO: parameterize
        tfr = lfilter
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_DIRS'):
        if export_path is None:
            raise ValueError('For multiple directories, cfg.EXPORT_PATH cannot be None')
        outpath = export_path
        file_prefix = 'grandavg'

        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_DIRS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True, recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)
    else:
        print('Loading', cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)

        # The file name is always needed as the export prefix, also when an
        # explicit export path replaces the data file's own directory.
        [data_dir, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        outpath = data_dir if export_path is None else export_path

    # Custom events replace the events that came with the data. This must
    # happen after loading, otherwise the loaded events would overwrite them.
    if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
        events = mne.read_events(cfg.EVENT_FILE)

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    # Valid indices are 0 .. n_channels - 1, so an index equal to the channel
    # count is already out of range.
    n_channels = len(raw.info['ch_names'])
    if max(picks) >= n_channels:
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), n_channels)
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                  spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # Read epochs: keep only the configured triggers that occur in the data
    found_events = set(events[:, -1])
    classes = {}
    for t in cfg.TRIGGERS:
        if t in found_events:
            if hasattr(cfg, 'tdef'):
                classes[cfg.tdef.by_value[t]] = t
            else:
                classes[str(t)] = t
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    # Placeholders so that the diagnostic prints below can always be
    # formatted, even when epoching fails before every bound is computed.
    tmin = tmax = tmin_buffer = tmax_buffer = float('nan')
    try:
        tmin = cfg.EPOCH[0]
        tmin_buffer = tmin - t_buffer
        raw_tmax = raw._data.shape[1] / sfreq - 0.1
        if cfg.EPOCH[1] is None:
            if cfg.POWER_AVERAGED:
                raise ValueError('EPOCH value cannot have None for grand averaged TFR')
            else:
                if len(cfg.TRIGGERS) > 1:
                    raise ValueError('If the end time of EPOCH is None, only a single event can be defined.')
                # open-ended epoch: run from the first occurrence of the
                # trigger to the end of the recording (minus the buffer)
                t_ref = events[np.where(events[:,2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
                tmax = raw_tmax - t_ref - t_buffer
        else:
            tmax = cfg.EPOCH[1]
        tmax_buffer = tmax + t_buffer
        if tmax_buffer > raw_tmax:
            raise ValueError('Epoch length with buffer (%.3f) is larger than signal length (%.3f)' % (tmax_buffer, raw_tmax))

        #print('Epoch tmin = %.1f, tmax = %.1f, raw length = %.1f' % (tmin, tmax, raw_tmax))
        # Epochs are cut with the extra buffer; the TFR is cropped back to
        # [tmin, tmax] after it has been computed.
        epochs_all = mne.Epochs(raw, events, classes, tmin=tmin_buffer, tmax=tmax_buffer,
                                proj=False, picks=picks, baseline=None, preload=True)
        if epochs_all.drop_log_stats() > 0:
            print('\n** Bad epochs found. Dropping into a Python shell.')
            print(epochs_all.drop_log)
            print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
                (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
            print('\nType exit to continue.\n')
            pdb.set_trace()
    except Exception:
        # Deliberately interactive: report the failure and hand control to
        # the debugger instead of aborting.
        print('\n*** (tfr_export) ERROR OCCURRED WHILE EPOCHING ***')
        traceback.print_exc()
        print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        pdb.set_trace()

    power = {}
    for evname in classes:
        #export_dir = '%s/plot_%s' % (outpath, evname)
        export_dir = outpath
        qc.make_dirs(export_dir)
        print('\n>> Processing %s' % evname)
        freqs = cfg.FREQ_RANGE  # define frequencies of interest
        n_cycles = freqs / 2.  # different number of cycle per frequency
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                print('No %s epochs. Skipping.' % evname)
                continue

            if tfr_type == 'butter':
                # band-pass filter, then take the magnitude of the analytic
                # signal and average it over epochs
                # NOTE(review): this path never sets power[evname], so the
                # PNG export and the POWER_DIFF section below cannot be used
                # with 'butter' -- confirm the intended behaviour.
                b, a = butter_bandpass(cfg.FREQ_RANGE[0], cfg.FREQ_RANGE[-1], sfreq, order=butter_order)
                tfr_filtered = lfilter(b, a, epochs, axis=2)
                tfr_hilbert = hilbert(tfr_filtered)
                tfr_power = abs(tfr_hilbert)
                tfr_data = np.mean(tfr_power, axis=0)
            elif tfr_type == 'fir':
                raise NotImplementedError
            else:
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname)
                scipy.io.savemat(mout, {'tfr':tfr_data, 'chs':epochs.ch_names,
                    'events':events, 'sfreq':sfreq, 'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                print('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                # Inspect power for each channel; ch is the position within
                # the picked channels, pick the index into the raw channels.
                for ch, pick in enumerate(picks):
                    chname = raw.ch_names[pick]
                    title = 'Peri-event %s - Channel %s' % (evname, chname)

                    # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                    fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                        colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                    fout = '%s/%s-%s-%s-%s.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname)
                    fig.savefig(fout)
                    fig.clf()
                    print('Exported to %s' % fout)
        else:
            # TFR per event; power[evname] ends up holding the last trial
            for ep in range(len(epochs_all[evname])):
                epochs = epochs_all[evname][ep]
                if len(epochs) == 0:
                    print('No %s epochs. Skipping.' % evname)
                    continue
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(mout, {'tfr':power[evname].data, 'chs':power[evname].ch_names,
                        'events':events, 'sfreq':sfreq, 'tmin':tmin, 'tmax':tmax, 'freqs':cfg.FREQ_RANGE})
                    print('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch, pick in enumerate(picks):
                        chname = raw.ch_names[pick]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                            colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname, ep + 1)
                        fig.savefig(fout)
                        fig.clf()
                        print('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        # Difference between the first two classes, shown as log(|a - b|).
        export_dir = '%s/diff' % outpath
        qc.make_dirs(export_dir)
        # A dict view cannot be indexed on Python 3, so take a list of the
        # class names first.
        labels = list(classes)
        df = power[labels[0]] - power[labels[1]]
        df.data = np.log(np.abs(df.data))
        # Inspect power diff for each channel
        for ch, pick in enumerate(picks):
            chname = raw.ch_names[pick]
            title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1], chname)

            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            # NOTE(review): vmin is larger than vmax here -- confirm whether
            # the inverted colour range is intended.
            fig = df.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                          colorbar=True, title=title, vmin=3.0, vmax=-3.0, dB=False)
            fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix, cfg.SP_FILTER, labels[0], labels[1], chname)
            print('Exporting to %s' % fout)
            fig.savefig(fout)
            fig.clf()
    print('Finished !')
Exemplo n.º 5
0
def compute_features(cfg):
    """
    Load the training recordings, epoch and filter them, and compute features.

    Parameters
    ----------
    cfg : configuration object. Fields read here: DATADIR, CHANNEL_PICKS,
        REF_CH, LOAD_EVENTS_FILE, tdef, TRIGGER_DEF, EXCLUDES, EPOCH,
        SP_FILTER, TP_FILTER, NOTCH_FILTER, MULTIPLIER, N_JOBS, FEATURES, PSD,
        and optionally SP_CHANNELS, TP_CHANNELS, NOTCH_CHANNELS (these three
        only trigger a warning).

    Returns
    -------
    The object returned by get_psd_feature(), with the keys 'picks', 'sfreq'
    and 'ch_names' added.

    Raises
    ------
    RuntimeError
        Integer CHANNEL_PICKS with multiple files, a CHANNEL_PICKS / EXCLUDES
        entry of unknown type, or any failure while epoching.
    ValueError
        No channel left after exclusion, or a channel index out of range.
    NotImplementedError
        cfg.FEATURES is anything other than 'PSD'.
    """
    # Load file list. endswith() is used because a fixed 4-character suffix
    # comparison can never match the 5-character '.fiff' extension.
    ftrain = [f for f in qc.get_file_list(cfg.DATADIR, fullpath=True)
              if f.endswith(('.fif', '.fiff'))]

    # Preprocessing, epoching and PSD computation
    if len(ftrain) > 1 and cfg.CHANNEL_PICKS is not None and isinstance(
            cfg.CHANNEL_PICKS[0], int):
        raise RuntimeError(
            'When loading multiple EEG files, CHANNEL_PICKS must be list of string, not integers because they may have different channel order.'
        )
    raw, events = pu.load_multi(ftrain)
    if cfg.REF_CH is not None:
        pu.rereference(raw, cfg.REF_CH[1], cfg.REF_CH[0])
    if cfg.LOAD_EVENTS_FILE is not None:
        events = mne.read_events(cfg.LOAD_EVENTS_FILE)
    triggers = {cfg.tdef.by_value[c]: c for c in set(cfg.TRIGGER_DEF)}

    # Pick channels: default to every EEG channel, otherwise resolve the
    # configured names / indices to indices into raw.ch_names.
    if cfg.CHANNEL_PICKS is None:
        chlist = [int(x) for x in pick_types(raw.info, stim=False, eeg=True)]
    else:
        chlist = cfg.CHANNEL_PICKS
    picks = []
    for c in chlist:
        if isinstance(c, int):
            picks.append(c)
        elif isinstance(c, str):
            picks.append(raw.ch_names.index(c))
        else:
            raise RuntimeError(
                'CHANNEL_PICKS has a value of unknown type %s.\nCHANNEL_PICKS=%s'
                % (type(c), cfg.CHANNEL_PICKS))
    if cfg.EXCLUDES is not None:
        for c in cfg.EXCLUDES:
            if isinstance(c, str):
                if c not in raw.ch_names:
                    qc.print_c(
                        'Warning: Exclusion channel %s does not exist. Ignored.'
                        % c, 'Y')
                    continue
                c_int = raw.ch_names.index(c)
            elif isinstance(c, int):
                c_int = c
            else:
                raise RuntimeError(
                    'EXCLUDES has a value of unknown type %s.\nEXCLUDES=%s' %
                    (type(c), cfg.EXCLUDES))
            if c_int in picks:
                picks.remove(c_int)
    if not picks:
        raise ValueError('"picks" is empty: no channel is left after exclusion.')
    # Valid indices are 0 .. len(ch_names) - 1, so an index equal to the
    # channel count is already out of range.
    if max(picks) >= len(raw.ch_names):
        raise ValueError(
            '"picks" has a channel index %d while there are only %d channels.'
            % (max(picks), len(raw.ch_names)))
    # Per-filter channel lists are not honoured here; warn when one is set.
    for unsupported in ('SP_CHANNELS', 'TP_CHANNELS', 'NOTCH_CHANNELS'):
        if getattr(cfg, unsupported, None) is not None:
            qc.print_c(
                'compute_features(): %s parameter is not supported yet. Will be set to CHANNEL_PICKS.'
                % unsupported, 'Y')

    def _epoch_and_filter(tmin, tmax):
        # Cut epochs for one time range and run the configured filters.
        epochs = Epochs(raw,
                        events,
                        triggers,
                        tmin=tmin,
                        tmax=tmax,
                        proj=False,
                        picks=picks,
                        baseline=None,
                        preload=True,
                        verbose=False,
                        detrend=None)
        # Channels are already selected by 'picks' param so use all channels.
        # The return value is not used; presumably preprocess() modifies the
        # Epochs object in place -- TODO confirm.
        pu.preprocess(epochs,
                      spatial=cfg.SP_FILTER,
                      spatial_ch=None,
                      spectral=cfg.TP_FILTER,
                      spectral_ch=None,
                      notch=cfg.NOTCH_FILTER,
                      notch_ch=None,
                      multiplier=cfg.MULTIPLIER,
                      n_jobs=cfg.N_JOBS)
        return epochs

    # Read epochs
    try:
        if isinstance(cfg.EPOCH[0], list):
            # Experimental: multiple epoch ranges -> list of Epochs objects
            epochs_train = [_epoch_and_filter(ep[0], ep[1]) for ep in cfg.EPOCH]
        else:
            # Usual method: single epoch range
            epochs_train = _epoch_and_filter(cfg.EPOCH[0], cfg.EPOCH[1])
    except Exception:
        qc.print_c('\n*** (trainer.py) ERROR OCCURRED WHILE EPOCHING ***\n',
                   'R')
        # Catch and throw errors from child processes
        traceback.print_exc()
        if interactive:
            print('Dropping into a shell.\n')
            embed()
        raise RuntimeError

    # Compute features
    if cfg.FEATURES == 'PSD':
        featdata = get_psd_feature(epochs_train,
                                   cfg.EPOCH,
                                   cfg.PSD,
                                   feat_picks=None,
                                   n_jobs=cfg.N_JOBS)
    elif cfg.FEATURES == 'TIMELAG':
        # TODO: Implement multiple epochs for timelag feature
        raise NotImplementedError(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR TIMELAG FEATURE.')
    elif cfg.FEATURES == 'WAVELET':
        # TODO: Implement multiple epochs for wavelet feature
        raise NotImplementedError(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR WAVELET FEATURE.')
    else:
        raise NotImplementedError('%s feature type is not supported.' %
                                  cfg.FEATURES)

    featdata['picks'] = picks
    featdata['sfreq'] = raw.info['sfreq']
    featdata['ch_names'] = raw.ch_names
    return featdata
Exemplo n.º 6
0
def slice_win(epochs_data,
              w_starts,
              w_length,
              psde,
              picks=None,
              title=None,
              flatten=True,
              preprocess=None,
              verbose=False):
    '''
    Compute PSD values of a sliding window

    Params
        epochs_data ([channels]x[samples]): raw epoch data
        w_starts (list): starting indices of sample segments
        w_length (int): window length in number of samples
        psde: object whose transform() takes [epochs x channels x times] and
            returns a 3-D array (e.g. an MNE PSDEstimator)
        picks (list or array): indices into the flattened
            [channels x freqs] feature vector of each window; None or empty
            keeps every feature
        title (string): print out the title associated with PID
        flatten (boolean): currently not used by the code; the output is
            always [windows] x [channels x freqs]
        preprocess (dict): None or parameters for pycnbi_utils.preprocess() with the following keys:
            sfreq, spatial, spatial_ch, spectral, spectral_ch, notch, notch_ch,
            multiplier, ch_names, rereference, decim, n_jobs
        verbose (boolean): log every processed window
    Returns:
        [windows] x [features] array, or None if w_starts yields no window
    Raises:
        IndexError (as WrongIndexError): a start index lies beyond the epoch
    '''

    # Raised for a window start beyond the epoch. It has to be a real
    # exception class; raising a plain function is a TypeError.
    class WrongIndexError(IndexError):
        pass

    if not isinstance(w_length, int):
        logger.warning('w_length type is %s. Converting to int.' %
                       type(w_length))
        w_length = int(w_length)
    if title is None:
        title = '[PID %d] Frames %d-%d' % (os.getpid(), w_starts[0],
                                           w_starts[-1] + w_length - 1)
    else:
        title = '[PID %d] %s' % (os.getpid(), title)
    if preprocess is not None and preprocess['decim'] != 1:
        title += ' (decim factor %d)' % preprocess['decim']
    logger.info(title)

    # len() instead of truthiness so that numpy arrays are accepted as picks
    use_picks = picks is not None and len(picks) > 0
    n_samples = epochs_data.shape[1]

    # Collect one row per window and join them once at the end.
    rows = []
    for n in w_starts:
        n = int(round(n))
        if n >= n_samples:
            msg = ('w_starts has an out-of-bounds index %d for epoch length %d.'
                   % (n, n_samples))
            logger.error(msg)
            raise WrongIndexError(msg)
        window = epochs_data[:, n:(n + w_length)]

        if preprocess is not None:
            window = pu.preprocess(window,
                                   sfreq=preprocess['sfreq'],
                                   spatial=preprocess['spatial'],
                                   spatial_ch=preprocess['spatial_ch'],
                                   spectral=preprocess['spectral'],
                                   spectral_ch=preprocess['spectral_ch'],
                                   notch=preprocess['notch'],
                                   notch_ch=preprocess['notch_ch'],
                                   multiplier=preprocess['multiplier'],
                                   ch_names=preprocess['ch_names'],
                                   rereference=preprocess['rereference'],
                                   decim=preprocess['decim'],
                                   n_jobs=preprocess['n_jobs'])

        # dimension: psde.transform( [epochs x channels x times] )
        psd = psde.transform(
            window.reshape((1, window.shape[0], window.shape[1])))
        # flatten [1 x channels x freqs] -> [1 x (channels * freqs)]
        psd = psd.reshape((psd.shape[0], psd.shape[1] * psd.shape[2]))
        if use_picks:
            psd = psd[0][picks].reshape((1, -1))
        rows.append(psd)

        if verbose:
            logger.info('[PID %d] processing frame %d / %d' %
                        (os.getpid(), n, w_starts[-1]))

    if not rows:
        return None
    return np.concatenate(rows, axis=0)
Exemplo n.º 7
0
    def get_prob(self, timestamp=False):
        """
        Decode the most recent window and return the class likelihoods.

        Input
        -----
        timestamp: If True, also return a timestamp: the last timestamp of
            the window used for decoding, or the current LSL clock value in
            fake mode.

        Returns
        -------
        The likelihood P(X|C), where X=window, C=model; returned as
        (likelihoods, timestamp) when timestamp is True.
        """
        if self.fake:
            # Fake decoder: the first class gets a random likelihood and the
            # remaining mass is split evenly among the other classes.
            first = random.uniform(0.0, 1.0)
            n_rest = len(self.labels) - 1
            p_others = (1 - first) / n_rest
            probs = [first] + [p_others] * n_rest
            time.sleep(0.0625)  # simulated delay
            t_prob = pylsl.local_clock()
        else:
            self.sr.acquire(blocking=True)
            window, stamps = self.sr.get_window()  # window = times x channels
            t_prob = stamps[-1]

            # TODO: add re-referencing function to preprocess()

            # Filter on channels x times while the original channel order is
            # still intact; channels are picked afterwards.
            filtered = pu.preprocess(window.T,
                                     sfreq=self.sfreq,
                                     spatial=self.spatial,
                                     spatial_ch=self.spatial_ch,
                                     spectral=self.spectral,
                                     spectral_ch=self.spectral_ch,
                                     notch=self.notch,
                                     notch_ch=self.notch_ch,
                                     multiplier=self.multiplier,
                                     decim=self.decim)

            # Same channels as used for training.
            picked = filtered[self.picks]

            # transform() is given a single-epoch batch: 1 x channels x times
            n_channels, n_times = picked.shape[0], picked.shape[1]
            psd = self.psde.transform(picked.reshape((1, n_channels, n_times)))

            # Concatenate all channels into one feature row and classify.
            feats = np.concatenate(psd[0]).reshape(1, -1)
            probs = self.cls.predict_proba(feats)[0]

            # PSD buffering is disabled in this version; the former logic is
            # kept for reference only:
            #   if self.psd_buffer is None:
            #       self.psd_buffer = psd
            #   else:
            #       self.psd_buffer = np.concatenate((self.psd_buffer, psd), axis=0)
            #       # TODO: CHECK THIS BLOCK
            #       self.ts_buffer.append(ts[0])
            #       if ts[0] - self.ts_buffer[0] > self.buffer_sec:
            #           t_index = np.searchsorted(self.ts_buffer, ts[0] - self.buffer_sec)
            #           self.ts_buffer = self.ts_buffer[t_index:]
            #           self.psd_buffer = self.psd_buffer[t_index:, :, :]

        if not timestamp:
            return probs
        return probs, t_prob