def get_psd(sr, psde, picks):
    """Fetch the current window from the stream receiver and return its PSD.

    Parameters
    ----------
    sr : stream receiver object exposing acquire() and get_window()
    psde : PSD estimator exposing transform()
    picks : channel indices selected during training

    Returns
    -------
    PSD array as produced by psde.transform ([channels x freqs]).

    NOTE(review): the filter settings (sfreq, spatial, spatial_ch, spectral,
    spectral_ch, notch, notch_ch, multiplier) are free variables resolved from
    the enclosing scope — confirm they are defined where this function lives.
    """
    sr.acquire()
    window, _ts = sr.get_window()  # acquired as [times x channels]
    window = window.T              # work in [channels x times]
    # Filtering must happen before channel selection so that the original
    # channel order is preserved for the spatial filter.
    pu.preprocess(window, sfreq=sfreq, spatial=spatial, spatial_ch=spatial_ch,
                  spectral=spectral, spectral_ch=spectral_ch, notch=notch,
                  notch_ch=notch_ch, multiplier=multiplier)
    # keep only the channels the classifier was trained on
    return psde.transform(window[picks])
def get_tfr(fif_file, cfg, tfr, n_jobs=1):
    """Compute a time-frequency representation of a single fif file and export it.

    Parameters
    ----------
    fif_file : path to the raw .fif recording
    cfg : config object (CHANNEL_PICKS, SP_CHANNELS, filters, FREQ_RANGE,
          EVENT_START, EXPORT_MATLAB, EXPORT_PNG, plotting options, ...)
    tfr : MNE TFR callable (e.g. mne.time_frequency.tfr_multitaper)
    n_jobs : number of parallel jobs passed to preprocessing and the TFR

    Side effects
    ------------
    Creates a 'plot_<fname>' directory next to the input file and writes a
    .mat dump and/or per-channel .png plots into it, depending on cfg flags.

    Raises
    ------
    RuntimeError if a picked channel index is out of range.
    """
    raw, events = pu.load_raw(fif_file)
    p = qc.parse_path(fif_file)
    fname = p.name
    outpath = p.dir
    export_dir = '%s/plot_%s' % (outpath, fname)
    qc.make_dirs(export_dir)

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)
    # BUGFIX: indices are 0-based, so an index equal to the channel count is
    # already out of range (the original only rejected indices strictly greater).
    if max(picks) >= len(raw.info['ch_names']):
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
            (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels,
                  spectral=cfg.TP_FILTER, spectral_ch=picks,
                  notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # MNE TFR functions do not support Raw instances yet, so convert to Epoch
    if cfg.EVENT_START is None:
        # no trigger defined: inject a dummy event at sample 0 covering the
        # whole recording as one epoch
        raw._data[0][0] = 1
        events = np.array([[0, 0, 1]])
        classes = None
    else:
        classes = {'START':cfg.EVENT_START}
    tmax = (raw._data.shape[1] - 1) / raw.info['sfreq']
    epochs_all = mne.Epochs(raw, events, classes, tmin=0, tmax=tmax,
                            picks=picks, baseline=None, preload=True)

    print('\n>> Processing %s' % fif_file)
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    power = tfr(epochs_all, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                return_itc=False, decim=1, n_jobs=n_jobs)

    if cfg.EXPORT_MATLAB is True:
        # export all channels to MATLAB
        mout = '%s/%s-%s.mat' % (export_dir, fname, cfg.SP_FILTER)
        scipy.io.savemat(mout, {'tfr':power.data, 'chs':power.ch_names,
                                'events':events, 'sfreq':raw.info['sfreq'],
                                'freqs':cfg.FREQ_RANGE})

    if cfg.EXPORT_PNG is True:
        # Plot power of each channel
        for ch in np.arange(len(picks)):
            ch_name = raw.ch_names[picks[ch]]
            title = 'Channel %s' % (ch_name)
            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            fig = power.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE,
                             show=False, colorbar=True, title=title,
                             vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
            fout = '%s/%s-%s-%s.png' % (export_dir, fname, cfg.SP_FILTER, ch_name)
            fig.savefig(fout)
            # BUGFIX: release the figure after saving so memory does not grow
            # with the number of channels (matches the sibling get_tfr exporter).
            fig.clf()
            print('Exported %s' % fout)

    print('Finished !')
def get_prob(self):
    """
    Read the latest window and classify it.

    Returns
    -------
    The likelihood P(X|C), where X=window, C=model
    """
    if self.fake:
        # fake decoder: biased likelihood for the first class
        probs = [random.uniform(0.0, 1.0)]
        # the other class likelihoods split the residual probability equally
        p_others = (1 - probs[0]) / (len(self.labels) - 1)
        for x in range(1, len(self.labels)):
            probs.append(p_others)
        time.sleep(0.0625)  # simulated delay for PSD + RF
    else:
        self.sr.acquire()
        w, ts = self.sr.get_window()  # w = times x channels
        w = w.T  # -> channels x times
        # apply filters. Important: maintain the original channel order at this point.
        pu.preprocess(w, sfreq=self.sfreq, spatial=self.spatial,
                      spatial_ch=self.spatial_ch, spectral=self.spectral,
                      spectral_ch=self.spectral_ch, notch=self.notch,
                      notch_ch=self.notch_ch, multiplier=self.multiplier)
        # select the same channels used for training
        w = w[self.picks]
        # psd = channels x freqs
        psd = self.psde.transform(w.reshape((1, w.shape[0], w.shape[1])))
        # update psd buffer ( < 1 msec overhead )
        self.psd_buffer = np.concatenate((self.psd_buffer, psd), axis=0)
        self.ts_buffer.append(ts[0])
        if ts[0] - self.ts_buffer[0] > self.buffer_sec:
            # search speed comparison for ordered arrays:
            # http://stackoverflow.com/questions/16243955/numpy-first-occurence-of-value-greater-than-existing-value
            # BUGFIX: trim relative to self.buffer_sec instead of a hard-coded
            # 1.0 s, so the retained history matches the condition above.
            t_index = np.searchsorted(self.ts_buffer, ts[0] - self.buffer_sec)
            self.ts_buffer = self.ts_buffer[t_index:]
            self.psd_buffer = self.psd_buffer[t_index:, :, :]  # numpy delete is slower
        # assert ts[0] - self.ts_buffer[0] <= self.buffer_sec
        # make a feature vector and classify
        feats = np.concatenate(psd[0]).reshape(1, -1)
        # compute likelihoods
        probs = self.cls.predict_proba(feats)[0]
    return probs
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute TFRs (grand-averaged or per-trial) and export them as .mat / .png.

    @params:
    cfg: config object; see check_cfg() for required fields
         (TFR_TYPE: 'multitaper' or 'morlet' or 'butter', EXPORT_PATH,
         T_BUFFER, TRIGGERS, EPOCH, POWER_AVERAGED, export flags, ...)
    recursive: if True, load raw files in sub-dirs recursively
    n_jobs: number of cores to run in parallel (overridden by cfg.N_JOBS)

    Raises ValueError for an unknown TFR type, missing events or invalid
    epoch range; RuntimeError for an out-of-range channel index.
    '''
    cfg = check_cfg(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        butter_order = 4  # TODO: parameterize
        tfr = lfilter
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_DIRS'):
        if export_path is None:
            raise ValueError('For multiple directories, cfg.EXPORT_PATH cannot be None')
        else:
            outpath = export_path
        # custom event file
        if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
            events = mne.read_events(cfg.EVENT_FILE)
        file_prefix = 'grandavg'
        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_DIRS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True, recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)
    else:
        print('Loading', cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)
        # custom events
        if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
            events = mne.read_events(cfg.EVENT_FILE)
        if export_path is None:
            [outpath, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        else:
            outpath = export_path

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)
    # BUGFIX: indices are 0-based, so an index equal to the channel count is
    # already out of range (the original only rejected indices strictly greater).
    if max(picks) >= len(raw.info['ch_names']):
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
            (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels,
                  spectral=cfg.TP_FILTER, spectral_ch=picks,
                  notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # Read epochs: keep only the triggers actually present in the data
    classes = {}
    for t in cfg.TRIGGERS:
        if t in set(events[:, -1]):
            if hasattr(cfg, 'tdef'):
                classes[cfg.tdef.by_value[t]] = t
            else:
                classes[str(t)] = t
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    try:
        tmin = cfg.EPOCH[0]
        tmin_buffer = tmin - t_buffer
        raw_tmax = raw._data.shape[1] / sfreq - 0.1
        if cfg.EPOCH[1] is None:
            if cfg.POWER_AVERAGED:
                raise ValueError('EPOCH value cannot have None for grand averaged TFR')
            else:
                if len(cfg.TRIGGERS) > 1:
                    raise ValueError('If the end time of EPOCH is None, only a single event can be defined.')
                # epoch runs from the (single) event up to the end of the signal
                t_ref = events[np.where(events[:, 2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
                tmax = raw_tmax - t_ref - t_buffer
        else:
            tmax = cfg.EPOCH[1]
        tmax_buffer = tmax + t_buffer
        if tmax_buffer > raw_tmax:
            raise ValueError('Epoch length with buffer (%.3f) is larger than signal length (%.3f)' % (tmax_buffer, raw_tmax))
        #print('Epoch tmin = %.1f, tmax = %.1f, raw length = %.1f' % (tmin, tmax, raw_tmax))
        epochs_all = mne.Epochs(raw, events, classes, tmin=tmin_buffer, tmax=tmax_buffer,
                                proj=False, picks=picks, baseline=None, preload=True)
        if epochs_all.drop_log_stats() > 0:
            print('\n** Bad epochs found. Dropping into a Python shell.')
            print(epochs_all.drop_log)
            print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
                (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
            print('\nType exit to continue.\n')
            pdb.set_trace()
    except Exception:
        # BUGFIX: narrowed from a bare 'except:' so KeyboardInterrupt/SystemExit
        # are not swallowed into the debugger.
        print('\n*** (tfr_export) ERROR OCCURRED WHILE EPOCHING ***')
        traceback.print_exc()
        print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        pdb.set_trace()

    power = {}
    for evname in classes:
        #export_dir = '%s/plot_%s' % (outpath, evname)
        export_dir = outpath
        qc.make_dirs(export_dir)
        print('\n>> Processing %s' % evname)
        freqs = cfg.FREQ_RANGE  # define frequencies of interest
        n_cycles = freqs / 2.  # different number of cycle per frequency
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                print('No %s epochs. Skipping.' % evname)
                continue

            if tfr_type == 'butter':
                b, a = butter_bandpass(cfg.FREQ_RANGE[0], cfg.FREQ_RANGE[-1], sfreq,
                                       order=butter_order)
                tfr_filtered = lfilter(b, a, epochs, axis=2)
                tfr_hilbert = hilbert(tfr_filtered)
                tfr_power = abs(tfr_hilbert)
                tfr_data = np.mean(tfr_power, axis=0)
            elif tfr_type == 'fir':
                raise NotImplementedError
            else:
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname)
                scipy.io.savemat(mout, {'tfr':tfr_data, 'chs':epochs.ch_names,
                                        'events':events, 'sfreq':sfreq,
                                        'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                print('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                # Inspect power for each channel
                for ch in np.arange(len(picks)):
                    chname = raw.ch_names[picks[ch]]
                    title = 'Peri-event %s - Channel %s' % (evname, chname)
                    # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                    fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE,
                                             show=False, colorbar=True, title=title,
                                             vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                    fout = '%s/%s-%s-%s-%s.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname)
                    fig.savefig(fout)
                    fig.clf()
                    print('Exported to %s' % fout)
        else:
            # TFR per event
            for ep in range(len(epochs_all[evname])):
                epochs = epochs_all[evname][ep]
                if len(epochs) == 0:
                    print('No %s epochs. Skipping.' % evname)
                    continue
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(mout, {'tfr':power[evname].data, 'chs':power[evname].ch_names,
                                            'events':events, 'sfreq':sfreq, 'tmin':tmin,
                                            'tmax':tmax, 'freqs':cfg.FREQ_RANGE})
                    print('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch in np.arange(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE,
                                                 show=False, colorbar=True, title=title,
                                                 vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname, ep + 1)
                        fig.savefig(fout)
                        fig.clf()
                        print('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        export_dir = '%s/diff' % outpath
        qc.make_dirs(export_dir)
        # BUGFIX: dict.keys() is a non-subscriptable view in Python 3;
        # materialize it before indexing labels[0]/labels[1].
        labels = list(classes.keys())
        df = power[labels[0]] - power[labels[1]]
        df.data = np.log(np.abs(df.data))
        # Inspect power diff for each channel
        for ch in np.arange(len(picks)):
            chname = raw.ch_names[picks[ch]]
            title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1], chname)
            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            # NOTE(review): vmin=3.0 > vmax=-3.0 looks swapped — confirm whether
            # an inverted color scale is intended here.
            fig = df.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                          colorbar=True, title=title, vmin=3.0, vmax=-3.0, dB=False)
            fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix, cfg.SP_FILTER, labels[0], labels[1], chname)
            print('Exporting to %s' % fout)
            fig.savefig(fout)
            fig.clf()
    print('Finished !')
def compute_features(cfg):
    """Load training EEG files, epoch them and compute classifier features.

    Parameters
    ----------
    cfg : config object (DATADIR, CHANNEL_PICKS, EXCLUDES, REF_CH,
          LOAD_EVENTS_FILE, TRIGGER_DEF, tdef, EPOCH, filter settings,
          FEATURES, PSD, N_JOBS, ...)

    Returns
    -------
    featdata : dict produced by get_psd_feature(), augmented with
               'picks', 'sfreq' and 'ch_names'.

    Raises
    ------
    RuntimeError for invalid CHANNEL_PICKS/EXCLUDES entries or epoching errors;
    ValueError for an out-of-range channel index;
    NotImplementedError for unsupported feature types.
    """
    # Load file list
    ftrain = []
    for f in qc.get_file_list(cfg.DATADIR, fullpath=True):
        # BUGFIX: the original compared f[-4:] against '.fiff' (5 characters),
        # which could never match; endswith handles both extensions correctly.
        if f.endswith(('.fif', '.fiff')):
            ftrain.append(f)

    # Preprocessing, epoching and PSD computation
    if len(ftrain) > 1 and cfg.CHANNEL_PICKS is not None and type(
            cfg.CHANNEL_PICKS[0]) == int:
        raise RuntimeError(
            'When loading multiple EEG files, CHANNEL_PICKS must be list of string, not integers because they may have different channel order.'
        )
    raw, events = pu.load_multi(ftrain)
    if cfg.REF_CH is not None:
        pu.rereference(raw, cfg.REF_CH[1], cfg.REF_CH[0])
    if cfg.LOAD_EVENTS_FILE is not None:
        events = mne.read_events(cfg.LOAD_EVENTS_FILE)
    triggers = {cfg.tdef.by_value[c]: c for c in set(cfg.TRIGGER_DEF)}

    # Pick channels: CHANNEL_PICKS entries may be indices or channel names
    if cfg.CHANNEL_PICKS is None:
        chlist = [int(x) for x in pick_types(raw.info, stim=False, eeg=True)]
    else:
        chlist = cfg.CHANNEL_PICKS
    picks = []
    for c in chlist:
        if type(c) == int:
            picks.append(c)
        elif type(c) == str:
            picks.append(raw.ch_names.index(c))
        else:
            raise RuntimeError(
                'CHANNEL_PICKS has a value of unknown type %s.\nCHANNEL_PICKS=%s'
                % (type(c), cfg.CHANNEL_PICKS))
    if cfg.EXCLUDES is not None:
        for c in cfg.EXCLUDES:
            if type(c) == str:
                if c not in raw.ch_names:
                    qc.print_c(
                        'Warning: Exclusion channel %s does not exist. Ignored.'
                        % c, 'Y')
                    continue
                c_int = raw.ch_names.index(c)
            elif type(c) == int:
                c_int = c
            else:
                raise RuntimeError(
                    'EXCLUDES has a value of unknown type %s.\nEXCLUDES=%s'
                    % (type(c), cfg.EXCLUDES))
            if c_int in picks:
                del picks[picks.index(c_int)]
    # BUGFIX: indices are 0-based, so an index equal to the channel count is
    # already out of range ('>=' rather than '>').
    if max(picks) >= len(raw.ch_names):
        raise ValueError(
            '"picks" has a channel index %d while there are only %d channels.'
            % (max(picks), len(raw.ch_names)))
    if hasattr(cfg, 'SP_CHANNELS') and cfg.SP_CHANNELS is not None:
        qc.print_c(
            'compute_features(): SP_CHANNELS parameter is not supported yet. Will be set to CHANNEL_PICKS.',
            'Y')
    if hasattr(cfg, 'TP_CHANNELS') and cfg.TP_CHANNELS is not None:
        qc.print_c(
            'compute_features(): TP_CHANNELS parameter is not supported yet. Will be set to CHANNEL_PICKS.',
            'Y')
    if hasattr(cfg, 'NOTCH_CHANNELS') and cfg.NOTCH_CHANNELS is not None:
        qc.print_c(
            'compute_features(): NOTCH_CHANNELS parameter is not supported yet. Will be set to CHANNEL_PICKS.',
            'Y')

    # Read epochs
    try:
        # Experimental: multiple epoch ranges
        if type(cfg.EPOCH[0]) is list:
            epochs_train = []
            for ep in cfg.EPOCH:
                epoch = Epochs(raw, events, triggers, tmin=ep[0], tmax=ep[1],
                               proj=False, picks=picks, baseline=None,
                               preload=True, verbose=False, detrend=None)
                # Channels are already selected by 'picks' param so use all channels.
                pu.preprocess(epoch, spatial=cfg.SP_FILTER, spatial_ch=None,
                              spectral=cfg.TP_FILTER, spectral_ch=None,
                              notch=cfg.NOTCH_FILTER, notch_ch=None,
                              multiplier=cfg.MULTIPLIER, n_jobs=cfg.N_JOBS)
                epochs_train.append(epoch)
        else:
            # Usual method: single epoch range
            epochs_train = Epochs(raw, events, triggers, tmin=cfg.EPOCH[0],
                                  tmax=cfg.EPOCH[1], proj=False, picks=picks,
                                  baseline=None, preload=True, verbose=False,
                                  detrend=None)
            # Channels are already selected by 'picks' param so use all channels.
            pu.preprocess(epochs_train, spatial=cfg.SP_FILTER, spatial_ch=None,
                          spectral=cfg.TP_FILTER, spectral_ch=None,
                          notch=cfg.NOTCH_FILTER, notch_ch=None,
                          multiplier=cfg.MULTIPLIER, n_jobs=cfg.N_JOBS)
    except Exception:
        # BUGFIX: narrowed from a bare 'except:' so KeyboardInterrupt/SystemExit
        # still propagate. Catch and report errors from child processes.
        qc.print_c('\n*** (trainer.py) ERROR OCCURRED WHILE EPOCHING ***\n', 'R')
        traceback.print_exc()
        if interactive:
            print('Dropping into a shell.\n')
            embed()
        raise RuntimeError
    # BUGFIX: Python 3 dict.values() is a view object; materialize it so
    # np.unique operates on the trigger values instead of wrapping the view.
    label_set = np.unique(list(triggers.values()))

    # Compute features
    if cfg.FEATURES == 'PSD':
        featdata = get_psd_feature(epochs_train, cfg.EPOCH, cfg.PSD,
                                   feat_picks=None, n_jobs=cfg.N_JOBS)
    elif cfg.FEATURES == 'TIMELAG':
        ''' TODO: Implement multiple epochs for timelag feature '''
        raise NotImplementedError(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR TIMELAG FEATURE.')
    elif cfg.FEATURES == 'WAVELET':
        ''' TODO: Implement multiple epochs for wavelet feature '''
        raise NotImplementedError(
            'MULTIPLE EPOCHS NOT IMPLEMENTED YET FOR WAVELET FEATURE.')
    else:
        raise NotImplementedError('%s feature type is not supported.' % cfg.FEATURES)

    featdata['picks'] = picks
    featdata['sfreq'] = raw.info['sfreq']
    featdata['ch_names'] = raw.ch_names
    return featdata
def slice_win(epochs_data, w_starts, w_length, psde, picks=None, title=None,
              flatten=True, preprocess=None, verbose=False):
    '''
    Compute PSD values of a sliding window

    Params
        epochs_data ([channels]x[samples]): raw epoch data
        w_starts (list): starting indices of sample segments
        w_length (int): window length in number of samples
        psde: MNE PSDEstimator object
        picks (list): subset of channels within epochs_data
        title (string): print out the title associated with PID
        flatten (boolean): generate concatenated feature vectors
            If True: X = [windows] x [channels x freqs]
            If False: X = [windows] x [channels] x [freqs]
            NOTE(review): this parameter is currently ignored — the PSD is
            always flattened below; confirm intended behavior.
        preprocess (dict): None or parameters for pycnbi_utils.preprocess()
            with the following keys: sfreq, spatial, spatial_ch, spectral,
            spectral_ch, notch, notch_ch, multiplier, ch_names, rereference,
            decim, n_jobs

    Returns:
        [windows] x [channels*freqs] or [windows] x [channels] x [freqs]
    '''

    # BUGFIX: the original declared WrongIndexError with 'def', which makes it
    # a plain function — raising it would fail with "exceptions must derive
    # from BaseException". Declare a proper exception class instead.
    class WrongIndexError(Exception):
        pass

    if type(w_length) is not int:
        logger.warning('w_length type is %s. Converting to int.' % type(w_length))
        w_length = int(w_length)

    if title is None:
        title = '[PID %d] Frames %d-%d' % (os.getpid(), w_starts[0], w_starts[-1] + w_length - 1)
    else:
        title = '[PID %d] %s' % (os.getpid(), title)
    if preprocess is not None and preprocess['decim'] != 1:
        title += ' (decim factor %d)' % preprocess['decim']
    logger.info(title)

    X = None
    for n in w_starts:
        n = int(round(n))
        if n >= epochs_data.shape[1]:
            logger.error(
                'w_starts has an out-of-bounds index %d for epoch length %d.'
                % (n, epochs_data.shape[1]))
            raise WrongIndexError
        window = epochs_data[:, n:(n + w_length)]
        if preprocess is not None:
            window = pu.preprocess(window,
                                   sfreq=preprocess['sfreq'],
                                   spatial=preprocess['spatial'],
                                   spatial_ch=preprocess['spatial_ch'],
                                   spectral=preprocess['spectral'],
                                   spectral_ch=preprocess['spectral_ch'],
                                   notch=preprocess['notch'],
                                   notch_ch=preprocess['notch_ch'],
                                   multiplier=preprocess['multiplier'],
                                   ch_names=preprocess['ch_names'],
                                   rereference=preprocess['rereference'],
                                   decim=preprocess['decim'],
                                   n_jobs=preprocess['n_jobs'])

        # dimension: psde.transform( [epochs x channels x times] )
        psd = psde.transform(
            window.reshape((1, window.shape[0], window.shape[1])))
        psd = psd.reshape((psd.shape[0], psd.shape[1] * psd.shape[2]))
        if picks:
            psd = psd[0][picks]
            psd = psd.reshape((1, len(psd)))

        if X is None:
            X = psd
        else:
            X = np.concatenate((X, psd), axis=0)

        if verbose == True:
            logger.info('[PID %d] processing frame %d / %d' % (os.getpid(), n, w_starts[-1]))

    return X
def get_prob(self, timestamp=False):
    """Classify the most recent window.

    Parameters
    ----------
    timestamp : bool
        If True, also return the LSL timestamp of the leading edge of the
        window used for decoding.

    Returns
    -------
    The likelihood P(X|C), where X=window, C=model; followed by the
    timestamp when requested.
    """
    if self.fake:
        # fake decoder: draw a biased likelihood for the first class and
        # split the remainder equally among the other classes
        first = random.uniform(0.0, 1.0)
        probs = [first]
        shared = (1 - first) / (len(self.labels) - 1)
        for _ in range(1, len(self.labels)):
            probs.append(shared)
        time.sleep(0.0625)  # simulated decoding delay
        t_prob = pylsl.local_clock()
    else:
        self.sr.acquire(blocking=True)
        window, ts = self.sr.get_window()  # [times x channels]
        t_prob = ts[-1]
        window = window.T  # -> [channels x times]
        # re-reference channels
        # TODO: add re-referencing function to preprocess()
        # Filter before channel selection so the original channel order is
        # kept intact for the spatial filter.
        window = pu.preprocess(window, sfreq=self.sfreq, spatial=self.spatial,
                               spatial_ch=self.spatial_ch, spectral=self.spectral,
                               spectral_ch=self.spectral_ch, notch=self.notch,
                               notch_ch=self.notch_ch, multiplier=self.multiplier,
                               decim=self.decim)
        # keep only the channels used during training
        window = window[self.picks]
        # PSD shape: [1 x channels x freqs]
        psd = self.psde.transform(
            window.reshape((1, window.shape[0], window.shape[1])))
        # flatten into a single feature vector and classify
        feats = np.concatenate(psd[0]).reshape(1, -1)
        probs = self.cls.predict_proba(feats)[0]
        # NOTE: the PSD ring-buffer update that used to live here is disabled;
        # see version control history if it needs to be restored.
    if timestamp:
        return probs, t_prob
    else:
        return probs