Exemplo n.º 1
0
def balance_samples(X, Y, balance_type, verbose=False):
    """
    Balance the number of samples per class by random resampling.

    Params
    ======
    X: samples, indexable along the first axis (lists are converted to
       numpy arrays).
    Y: class label of each sample, same length as X.
    balance_type: 'OVER' to randomly duplicate samples of the smaller
                  classes until they match the largest class, 'UNDER' to
                  randomly keep only as many samples of each class as the
                  smallest class has, None or False to return the input
                  unchanged.
    verbose: accepted for backward compatibility; the summary is always
             logged regardless of this flag.

    Returns
    =======
    (X_balanced, Y_balanced) as numpy arrays, or the untouched (X, Y) when
    balance_type is None or False.

    Raises
    ======
    ValueError if balance_type is not one of the values above.
    """
    if balance_type is None or balance_type is False:
        return X, Y
    if balance_type not in ('OVER', 'UNDER'):
        logger.error('Unknown balancing type %s' % balance_type)
        raise ValueError('Unknown balancing type %s' % balance_type)

    # index through arrays so that plain list inputs work as well
    X_arr = np.asarray(X)
    Y_arr = np.asarray(Y)
    label_set = np.unique(Y_arr)
    indices = [np.where(Y_arr == c)[0] for c in label_set]
    counts = [len(idx) for idx in indices]

    if not indices:
        # nothing to balance
        X_parts, Y_parts = [X_arr], [Y_arr]
    elif balance_type == 'OVER':
        # Oversample from classes that lack samples; the first class with
        # the maximum number of samples is the reference and stays as is.
        ref = int(np.argmax(counts))
        X_parts, Y_parts = [X_arr], [Y_arr]
        for i, idx in enumerate(indices):
            if i == ref:
                continue
            extra_idx = np.random.choice(idx, counts[ref] - len(idx))
            X_parts.append(X_arr[extra_idx])
            Y_parts.append(Y_arr[extra_idx])
    else:
        # Undersample from classes that are excessive; the first class with
        # the minimum number of samples is the reference and is kept whole.
        ref = int(np.argmin(counts))
        X_parts, Y_parts = [X_arr[indices[ref]]], [Y_arr[indices[ref]]]
        for i, idx in enumerate(indices):
            if i == ref:
                continue
            # draw without replacement: never duplicate a sample while
            # throwing away others of the same class
            reduced_idx = np.random.choice(idx, counts[ref], replace=False)
            X_parts.append(X_arr[reduced_idx])
            Y_parts.append(Y_arr[reduced_idx])

    # single concatenation instead of repeatedly re-allocating with np.append
    X_balanced = np.concatenate(X_parts, axis=0)
    Y_balanced = np.concatenate(Y_parts, axis=0)

    logger.info_green('\nNumber of samples after %ssampling' %
                      balance_type.lower())
    for c, idx in zip(label_set, indices):
        logger.info('%s: %d -> %d' %
                    (c, len(idx), len(np.where(Y_balanced == c)[0])))

    return X_balanced, Y_balanced
Exemplo n.º 2
0
def train_decoder(cfg, featdata, feat_file=None):
    """
    Train the final decoder using all data, save it to disk and report the
    most important features.

    Params
    ======
    cfg: config object. Reads CLASSIFIER, FEATURES, CV, N_JOBS, tdef, EPOCH,
         SP_FILTER, TP_FILTER, NOTCH_FILTER, MULTIPLIER, REREFERENCE,
         DATA_PATH, EXPORT_GOOD_FEATURES and FEAT_TOPN.
         cfg.FEATURES['PSD']['wlen'] is filled in from featdata when None.
    featdata: feature data dict with the keys 'X_data', 'Y_data', 'wlen',
              'w_frames', 'ch_names', 'psde', 'sfreq' and 'picks'.
    feat_file: file to write the feature importance list to when
               cfg.EXPORT_GOOD_FEATURES is set. Defaults to
               '<DATA_PATH>/classifier/good_features.txt'.

    Raises
    ======
    NotImplementedError if the selected feature type is not 'PSD'.
    ValueError if the selected classifier is unknown.
    """
    # Only PSD features can be exported; fail before spending time on training
    selected_feature = cfg.FEATURES['selected']
    if selected_feature != 'PSD':
        logger.error('Unsupported feature type %s' % selected_feature)
        raise NotImplementedError(
            'Unsupported feature type %s' % selected_feature)

    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER[selected_classifier]['learning_rate'],
            n_estimators=cfg.CLASSIFIER[selected_classifier]['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER[selected_classifier]['depth'],
            random_state=cfg.CLASSIFIER[selected_classifier]['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER[selected_classifier]['learning_rate'],
            n_estimators=cfg.CLASSIFIER[selected_classifier]['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER[selected_classifier]['depth'],
            # seed lives next to the other XGB parameters, as for GB and RF
            random_state=cfg.CLASSIFIER[selected_classifier]['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(
            n_estimators=cfg.CLASSIFIER[selected_classifier]['trees'],
            max_features='auto',
            max_depth=cfg.CLASSIFIER[selected_classifier]['depth'],
            n_jobs=cfg.N_JOBS,
            random_state=cfg.CLASSIFIER[selected_classifier]['seed'],
            oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER[selected_classifier]['r_coeff'])
    else:
        logger.error('Unknown classifier %s' % selected_classifier)
        raise ValueError('Unknown classifier %s' % selected_classifier)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen
    w_frames = featdata['w_frames']
    ch_names = featdata['ch_names']
    X_data_merged = np.concatenate(X_data)
    Y_data_merged = np.concatenate(Y_data)
    if cfg.CV['BALANCE_SAMPLES']:
        X_data_merged, Y_data_merged = balance_samples(
            X_data_merged,
            Y_data_merged,
            cfg.CV['BALANCE_SAMPLES'],
            verbose=True)

    # Start training the decoder
    logger.info_green('Training the decoder')
    timer = qc.Timer()
    cls.n_jobs = cfg.N_JOBS
    cls.fit(X_data_merged, Y_data_merged)
    logger.info('Trained %d samples x %d dimension in %.1f sec' %
                (X_data_merged.shape[0], X_data_merged.shape[1], timer.sec()))
    cls.n_jobs = 1  # always set n_jobs=1 for testing

    # Export the decoder
    classes = {c: cfg.tdef.by_value[c] for c in np.unique(Y_data)}
    data = dict(cls=cls,
                ch_names=ch_names,
                psde=featdata['psde'],
                sfreq=featdata['sfreq'],
                picks=featdata['picks'],
                classes=classes,
                epochs=cfg.EPOCH,
                w_frames=w_frames,
                w_seconds=cfg.FEATURES['PSD']['wlen'],
                wstep=cfg.FEATURES['PSD']['wstep'],
                spatial=cfg.SP_FILTER,
                spatial_ch=featdata['picks'],
                spectral=cfg.TP_FILTER[cfg.TP_FILTER['selected']],
                spectral_ch=featdata['picks'],
                notch=cfg.NOTCH_FILTER[cfg.NOTCH_FILTER['selected']],
                notch_ch=featdata['picks'],
                multiplier=cfg.MULTIPLIER,
                ref_ch=cfg.REREFERENCE[cfg.REREFERENCE['selected']],
                decim=cfg.FEATURES['PSD']['decim'])
    clsfile = '%s/classifier/classifier-%s.pkl' % (cfg.DATA_PATH,
                                                   platform.architecture()[0])
    qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
    qc.save_obj(clsfile, data)
    logger.info('Decoder saved to %s' % clsfile)

    # Reverse-lookup frequency from FFT: the frequency resolution is the
    # inverse of the (first) window length
    psd_wlen = cfg.FEATURES['PSD']['wlen']
    if isinstance(psd_wlen, list):
        fq_res = 1.0 / psd_wlen[0]
    else:
        fq_res = 1.0 / psd_wlen
    fq = 0
    fqlist = []
    while fq <= cfg.FEATURES['PSD']['fmax']:
        if fq >= cfg.FEATURES['PSD']['fmin']:
            fqlist.append(fq)
        fq += fq_res

    # Show top distinctive features
    logger.info_green('Good features ordered by importance')
    if selected_classifier in ['RF', 'GB', 'XGB']:
        keys, values = qc.sort_by_value(list(cls.feature_importances_),
                                        rev=True)
    else:  # 'LDA' or 'rLDA'
        keys, values = qc.sort_by_value(cls.w, rev=True)
    keys = np.array(keys)
    values = np.array(values)

    # Channel label of each feature column. Build the multi-window labels
    # from the picked names, not from the list being filled.
    picked_names = [ch_names[c] for c in featdata['picks']]
    if isinstance(wlen, list):
        feat_ch_names = [
            'w%d-%s' % (w, name) for w in range(len(wlen))
            for name in picked_names
        ]
    else:
        feat_ch_names = picked_names

    chlist, hzlist = features.feature2chz(keys,
                                          fqlist,
                                          ch_names=feat_ch_names)
    valnorm = values[:cfg.FEAT_TOPN].copy()
    valsum = np.sum(valnorm)
    if valsum == 0:
        valsum = 1  # avoid division by zero when all importances are 0
    valnorm = valnorm / valsum * 100.0

    # show top-N features
    for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
        if i >= cfg.FEAT_TOPN:
            break
        txt = '%-3s %5.1f Hz  normalized importance %-6s  raw importance %-6s  feature %-5d' %\
              (ch, hz, '%.2f%%' % valnorm[i], '%.2f%%' % (values[i] * 100.0), keys[i])
        logger.info(txt)

    if cfg.EXPORT_GOOD_FEATURES:
        if feat_file is None:
            feat_file = '%s/classifier/good_features.txt' % cfg.DATA_PATH
        # the context manager closes the file even if a write fails
        with open(feat_file, 'w') as gfout:
            gfout.write('Importance(%) Channel Frequency Index\n')
            for i, (ch, hz) in enumerate(zip(chlist, hzlist)):
                gfout.write('%.3f\t%s\t%s\t%d\n' %
                            (values[i] * 100.0, ch, hz, keys[i]))
Exemplo n.º 3
0
def cross_validate(cfg, featdata, cv_file=None):
    """
    Perform cross validation and report the results.

    Params
    ======
    cfg: config object. Reads CLASSIFIER, CV_PERFORM, CV, N_JOBS, tdef,
         SP_FILTER, SP_CHANNELS, TP_FILTER, NOTCH_FILTER, FEATURES, EPOCH,
         EXPORT_CLS and DATA_PATH.
    featdata: feature data dict with the keys 'X_data' (3-dimensional:
              trials x samples x features), 'Y_data', 'wlen', 'sfreq',
              'ch_names' and 'picks'.
    cv_file: file to write the report to when the selected CV method has
             'export_result' set to True. Defaults to 'cv_result.txt' under
             cfg.DATA_PATH (inside 'classifier/' if cfg.EXPORT_CLS is True).

    Raises
    ======
    ValueError if the selected classifier is unknown.
    NotImplementedError if the selected CV method is not supported.
    """
    # Init a classifier
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER['GB']['learning_rate'],
            presort='auto',
            n_estimators=cfg.CLASSIFIER['GB']['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER['GB']['depth'],
            random_state=cfg.CLASSIFIER['GB']['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False)
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER['XGB']['learning_rate'],
            presort='auto',
            n_estimators=cfg.CLASSIFIER['XGB']['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER['XGB']['depth'],
            # pass the seed value, not the whole XGB parameter dict
            random_state=cfg.CLASSIFIER['XGB']['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False)
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(
            n_estimators=cfg.CLASSIFIER['RF']['trees'],
            max_features='auto',
            max_depth=cfg.CLASSIFIER['RF']['depth'],
            n_jobs=cfg.N_JOBS,
            random_state=cfg.CLASSIFIER['RF']['seed'],
            oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER['rLDA']['r_coeff'])
    else:
        logger.error('Unknown classifier type %s' % selected_classifier)
        raise ValueError('Unknown classifier type %s' % selected_classifier)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']

    # Choose CV type
    ntrials, nsamples, fsize = X_data.shape
    selected_cv = cfg.CV_PERFORM['selected']
    if selected_cv == 'LeaveOneOut':
        logger.info_green('%d-fold leave-one-out cross-validation' % ntrials)
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif selected_cv == 'StratifiedShuffleSplit':
        cv_params = cfg.CV_PERFORM[selected_cv]
        logger.info_green(
            '%d-fold stratified cross-validation with test set ratio %.2f' %
            (cv_params['folds'], cv_params['test_ratio']))
        if SKLEARN_OLD:
            cv = StratifiedShuffleSplit(Y_data[:, 0],
                                        cv_params['folds'],
                                        test_size=cv_params['test_ratio'],
                                        random_state=cv_params['seed'])
        else:
            cv = StratifiedShuffleSplit(n_splits=cv_params['folds'],
                                        test_size=cv_params['test_ratio'],
                                        random_state=cv_params['seed'])
    else:
        # report the offending name itself; it may not even be a config key
        logger.error('%s is not supported yet. Sorry.' % selected_cv)
        raise NotImplementedError(
            '%s is not supported yet. Sorry.' % selected_cv)
    logger.info('%d trials, %d samples per trial, %d feature dimension' %
                (ntrials, nsamples, fsize))

    # Do it!
    timer_cv = qc.Timer()
    scores, cm_txt = crossval_epochs(cv,
                                     X_data,
                                     Y_data,
                                     cls,
                                     cfg.tdef.by_value,
                                     cfg.CV['BALANCE_SAMPLES'],
                                     n_jobs=cfg.N_JOBS,
                                     ignore_thres=cfg.CV['IGNORE_THRES'],
                                     decision_thres=cfg.CV['DECISION_THRES'])
    t_cv = timer_cv.sec()

    # Export results
    txt = 'Cross validation took %d seconds.\n' % t_cv
    txt += '\n- Class information\n'
    txt += '%d epochs, %d samples per epoch, %d feature dimension (total %d samples)\n' %\
        (ntrials, nsamples, fsize, ntrials * nsamples)
    for ev in np.unique(Y_data):
        txt += '%s: %d trials\n' % (cfg.tdef.by_value[ev],
                                    len(np.where(Y_data[:, 0] == ev)[0]))
    if cfg.CV['BALANCE_SAMPLES']:
        # same setting that was handed to crossval_epochs() above
        txt += 'The number of samples was balanced using %ssampling.\n' % \
            cfg.CV['BALANCE_SAMPLES'].lower()
    txt += '\n- Experiment condition\n'
    txt += 'Sampling frequency: %.3f Hz\n' % featdata['sfreq']
    txt += 'Spatial filter: %s (channels: %s)\n' % (cfg.SP_FILTER,
                                                    cfg.SP_CHANNELS)
    txt += 'Spectral filter: %s\n' % cfg.TP_FILTER[cfg.TP_FILTER['selected']]
    txt += 'Notch filter: %s\n' % cfg.NOTCH_FILTER[
        cfg.NOTCH_FILTER['selected']]
    txt += 'Channels: ' + ','.join(
        [str(featdata['ch_names'][p]) for p in featdata['picks']]) + '\n'
    txt += 'PSD range: %.1f - %.1f Hz\n' % (cfg.FEATURES['PSD']['fmin'],
                                            cfg.FEATURES['PSD']['fmax'])
    txt += 'Window step: %.2f msec\n' % (
        1000.0 * cfg.FEATURES['PSD']['wstep'] / featdata['sfreq'])
    if isinstance(wlen, list):
        for i, w in enumerate(wlen):
            txt += 'Window size: %.1f msec\n' % (w * 1000.0)
            txt += 'Epoch range: %s sec\n' % (cfg.EPOCH[i])
    else:
        txt += 'Window size: %.1f msec\n' % (cfg.FEATURES['PSD']['wlen'] *
                                             1000.0)
        txt += 'Epoch range: %s sec\n' % (cfg.EPOCH)
    txt += 'Decimation factor: %d\n' % cfg.FEATURES['PSD']['decim']

    # Compute stats
    cv_mean, cv_std = np.mean(scores), np.std(scores)
    txt += '\n- Average CV accuracy over %d epochs (random seed=%s)\n' % (
        ntrials, cfg.CV_PERFORM[selected_cv]['seed'])
    # compare the method name (not its parameter dict) so that the summary
    # line is actually written
    if selected_cv in ['LeaveOneOut', 'StratifiedShuffleSplit']:
        txt += "mean %.3f, std: %.3f\n" % (cv_mean, cv_std)
    txt += 'Classifier: %s, ' % selected_classifier
    if selected_classifier == 'RF':
        txt += '%d trees, %s max depth, random state %s\n' % (
            cfg.CLASSIFIER['RF']['trees'], cfg.CLASSIFIER['RF']['depth'],
            cfg.CLASSIFIER['RF']['seed'])
    elif selected_classifier == 'GB' or selected_classifier == 'XGB':
        # report the parameters of the classifier that was actually used
        boost_params = cfg.CLASSIFIER[selected_classifier]
        txt += '%d trees, %s max depth, %s learing_rate, random state %s\n' % (
            boost_params['trees'], boost_params['depth'],
            boost_params['learning_rate'], boost_params['seed'])
    elif selected_classifier == 'rLDA':
        txt += 'regularization coefficient %.2f\n' % cfg.CLASSIFIER['rLDA'][
            'r_coeff']
    # NOTE(review): the label says "Decision threshold" but the value is
    # IGNORE_THRES - confirm which setting is meant to be reported here
    if cfg.CV['IGNORE_THRES'] is not None:
        txt += 'Decision threshold: %.2f\n' % cfg.CV['IGNORE_THRES']
    txt += '\n- Confusion Matrix\n' + cm_txt
    logger.info(txt)

    # Export to a file
    if 'export_result' in cfg.CV_PERFORM[selected_cv] and cfg.CV_PERFORM[
            selected_cv]['export_result'] is True:
        if cv_file is None:
            if cfg.EXPORT_CLS is True:
                qc.make_dirs('%s/classifier' % cfg.DATA_PATH)
                cv_file = '%s/classifier/cv_result.txt' % cfg.DATA_PATH
            else:
                cv_file = '%s/cv_result.txt' % cfg.DATA_PATH
        # the context manager closes the file even if the write fails
        with open(cv_file, 'w') as fout:
            fout.write(txt)
Exemplo n.º 4
0
def balance_tpr(cfg, featdata):
    """
    Find the threshold of class index 0 that yields equal number of true positive samples of each class.
    Currently only available for binary classes.

    The threshold is the median of the values returned by
    get_predict_proba() over all cross-validation folds.

    Params
    ======
    cfg: config module. cfg.FEATURES['PSD']['wlen'] is filled in from
         featdata when it is None.
    featdata: feature data computed using compute_features()

    Returns
    =======
    The threshold value (0.5 if only a single prediction is available).

    Raises
    ======
    ValueError if the selected classifier is unknown.
    NotImplementedError if the selected CV method is not supported.
    """

    n_jobs = cfg.N_JOBS
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    # Init a classifier. 'selected' holds the name of the classifier; the
    # parameters are stored under that name.
    selected_classifier = cfg.CLASSIFIER['selected']
    if selected_classifier == 'GB':
        cls = GradientBoostingClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER['GB']['learning_rate'],
            n_estimators=cfg.CLASSIFIER['GB']['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER['GB']['depth'],
            random_state=cfg.CLASSIFIER['GB']['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'XGB':
        cls = XGBClassifier(
            loss='deviance',
            learning_rate=cfg.CLASSIFIER['XGB']['learning_rate'],
            n_estimators=cfg.CLASSIFIER['XGB']['trees'],
            subsample=1.0,
            max_depth=cfg.CLASSIFIER['XGB']['depth'],
            random_state=cfg.CLASSIFIER['XGB']['seed'],
            max_features='sqrt',
            verbose=0,
            warm_start=False,
            presort='auto')
    elif selected_classifier == 'RF':
        cls = RandomForestClassifier(
            n_estimators=cfg.CLASSIFIER['RF']['trees'],
            max_features='auto',
            max_depth=cfg.CLASSIFIER['RF']['depth'],
            n_jobs=cfg.N_JOBS,
            random_state=cfg.CLASSIFIER['RF']['seed'],
            oob_score=False,
            class_weight='balanced_subsample')
    elif selected_classifier == 'LDA':
        cls = LDA()
    elif selected_classifier == 'rLDA':
        cls = rLDA(cfg.CLASSIFIER['rLDA']['r_coeff'])
    else:
        logger.error('Unknown classifier type %s' % selected_classifier)
        raise ValueError('Unknown classifier type %s' % selected_classifier)

    # Setup features
    X_data = featdata['X_data']
    Y_data = featdata['Y_data']
    wlen = featdata['wlen']
    if cfg.FEATURES['PSD']['wlen'] is None:
        cfg.FEATURES['PSD']['wlen'] = wlen

    # Choose CV type
    ntrials, nsamples, fsize = X_data.shape
    selected_CV = cfg.CV_PERFORM['selected']
    if selected_CV == 'LeaveOneOut':
        logger.info_green('\n%d-fold leave-one-out cross-validation' % ntrials)
        if SKLEARN_OLD:
            cv = LeaveOneOut(len(Y_data))
        else:
            cv = LeaveOneOut()
    elif selected_CV == 'StratifiedShuffleSplit':
        cv_params = cfg.CV_PERFORM[selected_CV]
        logger.info_green(
            '\n%d-fold stratified cross-validation with test set ratio %.2f' %
            (cv_params['folds'], cv_params['test_ratio']))
        # NOTE(review): uses the 'seed' key like cross_validate() does;
        # confirm no config still provides 'random_seed' instead
        if SKLEARN_OLD:
            cv = StratifiedShuffleSplit(Y_data[:, 0],
                                        cv_params['folds'],
                                        test_size=cv_params['test_ratio'],
                                        random_state=cv_params['seed'])
        else:
            cv = StratifiedShuffleSplit(n_splits=cv_params['folds'],
                                        test_size=cv_params['test_ratio'],
                                        random_state=cv_params['seed'])
    else:
        logger.error('%s is not supported yet. Sorry.' % selected_CV)
        raise NotImplementedError(
            '%s is not supported yet. Sorry.' % selected_CV)
    logger.info('%d trials, %d samples per trial, %d feature dimension' %
                (ntrials, nsamples, fsize))

    # For classifier itself, single core is usually faster
    cls.n_jobs = 1
    Y_preds = []

    if SKLEARN_OLD:
        splits = cv
    else:
        splits = cv.split(X_data, Y_data[:, 0])

    # Start the worker pool only after the configuration has been validated
    # so that an early error cannot leave worker processes behind.
    pool = None
    results = []
    if n_jobs > 1:
        logger.info('balance_tpr(): Using %d cores' % n_jobs)
        pool = mp.Pool(n_jobs)
    try:
        for cnum, (train, test) in enumerate(splits):
            X_train = np.concatenate(X_data[train])
            X_test = np.concatenate(X_data[test])
            Y_train = np.concatenate(Y_data[train])
            Y_test = np.concatenate(Y_data[test])
            if pool is not None:
                results.append(
                    pool.apply_async(
                        get_predict_proba,
                        [cls, X_train, Y_train, X_test, Y_test, cnum + 1]))
            else:
                Y_preds.append(
                    get_predict_proba(cls, X_train, Y_train, X_test, Y_test,
                                      cnum + 1))

        # Aggregate predictions
        if pool is not None:
            pool.close()
            pool.join()
            for r in results:
                Y_preds.append(r.get())
    finally:
        if pool is not None:
            # no-op after a clean close()/join(); stops workers on error
            pool.terminate()
    Y_preds = np.concatenate(Y_preds, axis=0)

    # Find threshold for class index 0: the median of the sorted predictions
    Y_preds = sorted(Y_preds)
    mid_idx = len(Y_preds) // 2
    if len(Y_preds) == 1:
        return 0.5  # should not reach here in normal conditions
    elif len(Y_preds) % 2 == 0:
        thres = Y_preds[mid_idx -
                        1] + (Y_preds[mid_idx] - Y_preds[mid_idx - 1]) / 2
    else:
        thres = Y_preds[mid_idx]
    return thres
Exemplo n.º 5
0
def stream_player(server_name,
                  fif_file,
                  chunk_size,
                  auto_restart=True,
                  wait_start=True,
                  repeat=float('inf'),
                  high_resolution=False,
                  trigger_file=None):
    """
    Replay a recorded fif file as an LSL stream.

    Input
    =====
    server_name: LSL server name.
    fif_file: fif file to replay.
    chunk_size: number of samples to send at once (usually 16-32 is good enough).
    auto_restart: play from beginning again after reaching the end.
                  If False, wait for the user to press Enter first.
    wait_start: wait for user to start in the beginning.
    repeat: number of loops to play (infinite by default).
    high_resolution: use perf_counter() instead of sleep() for higher time resolution
                     but uses much more cpu due to polling.
    trigger_file: used to convert event numbers into event strings for readability.

    Raises
    ======
    RuntimeError if the fif file could not be loaded.

    Note: Run pycnbi.set_log_level('DEBUG') to print out the relative time stamps since started.

    """
    raw, _ = pu.load_raw(fif_file)
    # validate before touching any attribute of raw
    if raw is None:
        raise RuntimeError('Error while loading %s' % fif_file)
    sfreq = raw.info['sfreq']  # sampling frequency
    n_channels = len(raw.ch_names)  # number of channels
    if trigger_file is not None:
        tdef = trigger_def(trigger_file)
    try:
        event_ch = raw.ch_names.index('TRIGGER')
    except ValueError:
        event_ch = None
    logger.info_green('Successfully loaded %s' % fif_file)
    logger.info('Server name: %s' % server_name)
    logger.info('Sampling frequency %.3f Hz' % sfreq)
    logger.info('Number of channels : %d' % n_channels)
    logger.info('Chunk size : %d' % chunk_size)
    for i, ch in enumerate(raw.ch_names):
        logger.info('%d %s' % (i, ch))
    logger.info('Trigger channel : %s' % event_ch)

    # set server information
    sinfo = pylsl.StreamInfo(server_name, channel_count=n_channels, channel_format='float32',\
        nominal_srate=sfreq, type='EEG', source_id=server_name)
    desc = sinfo.desc()
    channel_desc = desc.append_child("channels")
    for ch in raw.ch_names:
        channel_desc.append_child('channel').append_child_value('label', str(ch))\
            .append_child_value('type','EEG').append_child_value('unit','microvolts')
    desc.append_child('amplifier').append_child('settings').append_child_value(
        'is_slave', 'false')
    desc.append_child('acquisition').append_child_value(
        'manufacturer', 'PyCNBI').append_child_value('serial_number', 'N/A')
    outlet = pylsl.StreamOutlet(sinfo, chunk_size=chunk_size)

    if wait_start:
        input('Press Enter to start streaming.')
    logger.info('Streaming started')

    # use the same clock for the start time and for the waiting below
    clock = time.perf_counter if high_resolution else time.time
    n_samples = raw._data.shape[1]
    idx_chunk = 0
    t_chunk = chunk_size / sfreq  # duration of one chunk in seconds
    t_start = clock()

    # start streaming; count completed loops from 0 so that exactly
    # `repeat` loops are played
    played = 0
    while played < repeat:
        idx_current = idx_chunk * chunk_size
        chunk = raw._data[:, idx_current:idx_current + chunk_size]
        data = chunk.transpose().tolist()
        # this is the last chunk of the recording
        finished = idx_current >= n_samples - chunk_size
        if high_resolution:
            # if a resolution over 2 KHz is needed
            t_sleep_until = t_start + idx_chunk * t_chunk
            while time.perf_counter() < t_sleep_until:
                pass
        else:
            # time.sleep() can have 500 us resolution using the tweak tool provided.
            t_wait = t_start + idx_chunk * t_chunk - time.time()
            if t_wait > 0.001:
                time.sleep(t_wait)
        outlet.push_chunk(data)
        logger.debug('[%8.3fs] sent %d samples (LSL %8.3f)' %
                     (time.perf_counter(), len(data), pylsl.local_clock()))
        if event_ch is not None:
            # distinct non-zero trigger values contained in this chunk
            event_values = set(chunk[event_ch]) - {0}
            if len(event_values) > 0:
                if trigger_file is None:
                    logger.info('Events: %s' % event_values)
                else:
                    for event in event_values:
                        if event in tdef.by_value:
                            logger.info('Events: %s (%s)' %
                                        (event, tdef.by_value[event]))
                        else:
                            logger.info('Events: %s (Undefined event)' % event)
        idx_chunk += 1

        if finished:
            if auto_restart is False:
                input(
                    'Reached the end of data. Press Enter to restart or Ctrl+C to stop.'
                )
            else:
                logger.info('Reached the end of data. Restarting.')
            idx_chunk = 0
            t_start = clock()
            played += 1
Exemplo n.º 6
0
    def acquire(self, blocking=True):
        """
        Reads data into buffer. It is a blocking function as default.

        Fills the buffer and return the current chunk of data and timestamps.
        In the returned data, column 0 is the trigger channel (zeros if the
        stream has no trigger channel) followed by the EEG channels.

        On the first successful acquisition the offset between the stream
        timestamps and the local LSL clock is stored in self.lsl_time_offset.

        Params:
            blocking: if False, return empty values immediately when no
                      sample is available. Otherwise wait for data, giving
                      up after a 5 second timeout.

        Returns:
            data [samples x channels], timestamps [samples]
            An empty array (0 x len(self.ch_list)) and an empty list are
            returned when nothing was received.
        """
        # the first chunk ever received is used to measure the clock offset
        timestamp_offset = len(self.timestamps[0]) == 0

        self.watchdog.reset()
        tslist = []
        chunk = None
        while self.watchdog.sec() < 5:
            # chunk = [frames]x[ch], tslist = [frames]
            chunk, tslist = self.inlets[0].pull_chunk(
                max_samples=self.stream_bufsize)
            if len(tslist) > 0:
                break
            if blocking == False:
                return np.empty((0, len(self.ch_list))), []
            time.sleep(0.0005)
        else:
            logger.warning(
                'Timeout occurred while acquiring data. Amp driver bug?')
            # give up and return empty values to avoid deadlock
            return np.empty((0, len(self.ch_list))), []
        if timestamp_offset:
            # read the local clock right after the samples arrived
            lsl_clock = pylsl.local_clock()
        data = np.array(chunk)

        # BioSemi has pull-up resistor instead of pull-down
        if self.amp_name == 'BioSemi' and self._lsl_tr_channel is not None:
            datatype = data.dtype
            trigger = data[:, self._lsl_tr_channel].astype(int)
            data[:, self._lsl_tr_channel] = (np.bitwise_and(255, trigger) -
                                             1).astype(datatype)

        # multiply values (to change unit)
        if self.multiplier != 1:
            data[:, self._lsl_eeg_channels] *= self.multiplier

        if self._lsl_tr_channel is not None:
            # move trigger channel to 0 and add back to the buffer
            first_column = data[:, self._lsl_tr_channel].reshape(-1, 1)
        else:
            # add an empty channel with zeros to channel 0
            first_column = np.zeros((data.shape[0], 1))
        data = np.concatenate((first_column, data[:, self._lsl_eeg_channels]),
                              axis=1)

        # add data to buffer
        chunk = data.tolist()
        self.buffers[0].extend(chunk)
        self.timestamps[0].extend(tslist)
        if self.bufsize > 0 and len(self.timestamps[0]) > self.bufsize:
            self.buffers[0] = self.buffers[0][-self.bufsize:]
            self.timestamps[0] = self.timestamps[0][-self.bufsize:]

        if timestamp_offset:
            # Only stream 0 is filled above, so read its latest timestamp;
            # indexing the last stream fails when several inlets exist.
            server_timestamp = self.timestamps[0][-1]
            logger.info('LSL timestamp = %s' % lsl_clock)
            logger.info('Server timestamp = %s' % server_timestamp)
            self.lsl_time_offset = server_timestamp - lsl_clock
            logger.info('Offset = %.3f ' % (self.lsl_time_offset))
            if abs(self.lsl_time_offset) > 0.1:
                logger.warning('LSL server has a high timestamp offset.')
            else:
                logger.info_green('LSL time server synchronized')

        # TODO: test the merging of multiple streams
        # # if we have multiple synchronized amps
        # if len(self.inlets) > 1:
        #     for i in range(1, len(self.inlets)):
        #         chunk, tslist = self.inlets[i].pull_chunk(max_samples=len(tslist))  # [frames][channels]
        #         self.buffers[i].extend(chunk)
        #         self.timestamps[i].extend(tslist)
        #         if self.bufsize > 0 and len(self.buffers[i]) > self.bufsize:
        #             self.buffers[i] = self.buffers[i][-self.bufsize:]

        # data= array[samples, channels], tslist=[samples]
        return (data, tslist)
Exemplo n.º 7
0
def get_psd_feature(epochs_train,
                    window,
                    psdparam,
                    picks=None,
                    preprocess=None,
                    n_jobs=1):
    """
    Wrapper for get_psd() adding meta information.

    Input
    =====
    epochs_train: mne.Epochs object or list of mne.Epochs object.
    window: [t_start, t_end]. Time window range for computing PSD.
            A list of such ranges (one per Epochs object) when epochs_train
            is a list.
    psdparam: {fmin:float, fmax:float, wlen:float, wstep:int, decim:int}.
              fmin, fmax in Hz, wlen in seconds, wstep in number of samples.
              The dict is updated in place: a missing or None 'decim' is set
              to 1 and, for a single window, a None 'wlen' is set to the
              window length.
    picks: Channels to compute features from.
    preprocess: Passed through to get_psd().
    n_jobs: Passed through to get_psd().

    Output
    ======
    dict object containing computed features, with keys
    X_data, Y_data, wlen, w_frames, psde, times and decim.
    'times' is an array for a single window and a list of arrays
    (one per window) for multiple windows.
    """

    if isinstance(window[0], list):
        sfreq = epochs_train[0].info['sfreq']
        # multiple PSD estimators, defined for each epoch
        if isinstance(psdparam, list):
            # TODO: implement multi-window PSD for each epoch, e.g.:
            #   assert len(psdparam) == len(window)
            #   for i, p in enumerate(psdparam):
            #       wl = window[i][1] - window[i][0] if p['wlen'] is None else p['wlen']
            #       wlen.append(wl)
            #       w_frames.append(int(sfreq * wl))
            logger.error('Multiple psd function not implemented yet.')
            raise NotImplementedError
        # same PSD estimator for all epochs
        wlen = []
        w_frames = []
        for win in window:
            wl = psdparam['wlen']
            if wl is None:
                wl = win[1] - win[0]
            assert wl > 0
            wlen.append(wl)
            w_frames.append(int(round(sfreq * wl)))
    else:
        sfreq = epochs_train.info['sfreq']
        wlen = window[1] - window[0]
        if psdparam['wlen'] is None:
            psdparam['wlen'] = wlen
        # window length in number of samples(frames)
        w_frames = int(round(sfreq * psdparam['wlen']))
    if psdparam.get('decim') is None:
        psdparam['decim'] = 1

    # the estimator sees decimated data, hence the reduced sampling rate
    psde_sfreq = sfreq / psdparam['decim']
    psde = mne.decoding.PSDEstimator(sfreq=psde_sfreq,
                                     fmin=psdparam['fmin'],
                                     fmax=psdparam['fmax'],
                                     bandwidth=None,
                                     adaptive=False,
                                     low_bias=True,
                                     n_jobs=1,
                                     normalization='length',
                                     verbose='WARNING')

    logger.info_green('PSD computation')
    multi_epochs = isinstance(epochs_train, list)
    if multi_epochs:
        X_all = []
        for i, ep in enumerate(epochs_train):
            # Y_data of the last Epochs object is the one that is returned
            X, Y_data = get_psd(ep,
                                psde,
                                w_frames[i],
                                psdparam['wstep'],
                                picks,
                                n_jobs=n_jobs,
                                preprocess=preprocess,
                                decim=psdparam['decim'])
            X_all.append(X)
        # concatenate along the feature dimension
        # feature index order: window block x channel block x frequency block
        # feature vector = [window1, window2, ...]
        # where windowX = [channel1, channel2, ...]
        # where channelX = [freq1, freq2, ...]
        X_data = np.concatenate(X_all, axis=2)
    else:
        # feature index order: channel block x frequency block
        # feature vector = [channel1, channel2, ...]
        # where channelX = [freq1, freq2, ...]
        X_data, Y_data = get_psd(epochs_train,
                                 psde,
                                 w_frames,
                                 psdparam['wstep'],
                                 picks,
                                 n_jobs=n_jobs,
                                 preprocess=preprocess,
                                 decim=psdparam['decim'])

    # assign relative timestamps for each feature. time reference is the leading edge of a window.
    if multi_epochs:
        # One timestamp array per window. The previous code called
        # get_data() on the list itself and mixed lists into the arithmetic,
        # so this path could never complete.
        t_features = []
        for i, ep in enumerate(epochs_train):
            w_starts = np.arange(0,
                                 ep.get_data().shape[2] - w_frames[i],
                                 psdparam['wstep'])
            t_features.append(w_starts / sfreq + wlen[i] + window[i][0])
    else:
        w_starts = np.arange(0,
                             epochs_train.get_data().shape[2] - w_frames,
                             psdparam['wstep'])
        t_features = w_starts / sfreq + psdparam['wlen'] + window[0]
    return dict(X_data=X_data,
                Y_data=Y_data,
                wlen=wlen,
                w_frames=w_frames,
                psde=psde,
                times=t_features,
                decim=psdparam['decim'])
Exemplo n.º 8
0
    def __init__(self, classifier=None, buffer_size=1.0, fake=False, amp_serial=None, amp_name=None):
        """
        Set up the decoder from a stored classifier model, or as a fake decoder.

        Params
        ------
        classifier: classifier file
        buffer_size: length of the signal buffer in seconds
        fake: if equal to False, load the classifier model and connect to the
              stream; otherwise only the fixed fake labels are set up
        amp_serial: passed to StreamReceiver
        amp_name: passed to StreamReceiver
        """

        self.classifier = classifier
        self.buffer_sec = buffer_size
        self.fake = fake
        self.amp_serial = amp_serial
        self.amp_name = amp_name

        # NOTE: an equality test, not a truthiness test: fake=None does not
        # equal False and therefore selects the fake branch below.
        if self.fake == False:
            model = qc.load_obj(self.classifier)
            if model is None:
                logger.error('Classifier model is None.')
                raise ValueError

            self.cls = model['cls']
            self.psde = model['psde']
            self.labels = list(self.cls.classes_)
            self.label_names = [model['classes'][label] for label in self.labels]

            # model entries copied verbatim to same-named attributes
            for key in ('spatial', 'spectral', 'notch', 'w_seconds', 'w_frames', 'wstep', 'sfreq'):
                setattr(self, key, model[key])

            # models without a 'decim' entry are treated as decim=1
            self.decim = model.setdefault('decim', 1)

            # the stored window length in frames must agree with sfreq * seconds
            expected_frames = int(round(self.sfreq * self.w_seconds))
            if expected_frames != self.w_frames:
                logger.error('sfreq * w_sec %d != w_frames %d' % (expected_frames, self.w_frames))
                raise RuntimeError

            self.multiplier = model.get('multiplier', 1)

            # Stream Receiver
            self.sr = StreamReceiver(window_size=self.w_seconds, amp_serial=self.amp_serial, amp_name=self.amp_name)
            if self.sr.sample_rate != self.sfreq:
                logger.error('Amplifier sampling rate (%.3f) != model sampling rate (%.3f). Stop.' % (self.sr.sample_rate, self.sfreq))
                raise RuntimeError

            # Map channel indices based on channel names of the streaming server
            self.spatial_ch = model['spatial_ch']
            self.spectral_ch = model['spectral_ch']
            self.notch_ch = model['notch_ch']
            #self.ref_ch = model['ref_ch'] # not supported yet
            self.ch_names = self.sr.get_channel_names()
            model_ch_names = model['ch_names']

            def to_stream_index(model_indices):
                # model channel index -> channel name -> index in the stream
                return [self.ch_names.index(model_ch_names[p]) for p in model_indices]

            self.picks = to_stream_index(model['picks'])
            for attr in ('spatial_ch', 'spectral_ch', 'notch_ch'):
                if getattr(self, attr) is not None:
                    setattr(self, attr, to_stream_index(model[attr]))

            # PSD buffer
            #psd_temp = self.psde.transform(np.zeros((1, len(self.picks), self.w_frames // self.decim)))
            #self.psd_shape = psd_temp.shape
            #self.psd_size = psd_temp.size
            #self.psd_buffer = np.zeros((0, self.psd_shape[1], self.psd_shape[2]))
            #self.psd_buffer = None

            self.ts_buffer = []

            logger.info_green('Loaded classifier %s (sfreq=%.3f, decim=%d)' % (' vs '.join(self.label_names), self.sfreq, self.decim))
        else:
            # Fake left-right decoder: no model is loaded
            self.psd_shape = None
            self.psd_size = None
            # TODO: parameterize directions using fake_dirs
            self.labels = [11, 9]
            self.label_names = ['LEFT_GO', 'RIGHT_GO']
Exemplo n.º 9
0
def log_decoding(decoder, logfile, amp_name=None, amp_serial=None, pklfile=True, matfile=False, autostop=False, prob_smooth=False):
    """
    Decode online and write results with event timestamps

    input
    -----
    decoder: Decoder or DecoderDaemon class object.
    logfile: File name to contain the result in Python pickle format.
    amp_name: LSL server name (if known).
    amp_serial: LSL server serial number (if known).
    pklfile: Export results to Python pickle format.
    matfile: Export results to Matlab .mat file if True.
    autostop: Automatically finish when no more data is received.
    prob_smooth: Use smoothed probability values according to decoder's smoothing parameter.

    output
    ------
    None. Results are written to logfile (and to a .mat file next to it if
    matfile is True). Nothing is written if no probability was decoded.
    """

    import cv2
    # import the submodule explicitly: a bare `import scipy` does not
    # guarantee that scipy.io is loaded on every SciPy version
    import scipy.io

    # run event acquisition process in the background
    state = mp.Value('i', 1)
    event_queue = mp.Queue()
    proc = mp.Process(target=log_decoding_helper, args=[state, event_queue, amp_name, amp_serial, autostop])
    proc.start()
    logger.info_green('Spawned event acquisition process.')

    # init variables and choose decoding function
    labels = decoder.get_label_names()
    probs = []
    prob_times = []
    decode_fn = decoder.get_prob_smooth_unread if prob_smooth else decoder.get_prob_unread

    # simple controller UI
    cv2.namedWindow("Decoding", cv2.WINDOW_AUTOSIZE)
    cv2.moveWindow("Decoding", 1400, 50)
    img = np.zeros([100, 400, 3], np.uint8)
    cv2.putText(img, 'Press any key to start', (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)
    cv2.waitKeyEx()
    img *= 0
    cv2.putText(img, 'Press ESC to stop', (40, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)

    key = 0
    started = False
    tm_watchdog = qc.Timer(autoreset=True)
    tm_cls = qc.Timer()
    while key != 27:  # 27 = ESC
        prob, prob_time = decode_fn(True)
        t_lsl = pylsl.local_clock()
        key = cv2.waitKeyEx(1)
        if prob is None:
            # watch dog: with autostop, give up once decoding had started
            # and nothing new arrived for more than 5 seconds
            if tm_cls.sec() > 5:
                if autostop and started:
                    logger.info('No more streaming data. Finishing.')
                    break
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue
        probs.append(prob)
        prob_times.append(prob_time)
        txt = '[%.3f] ' % prob_time
        txt += ', '.join('%s: %.2f' % (name, p) for name, p in zip(labels, prob))
        txt += ' (%d ms, LSL Diff = %.3f)' % (tm_cls.msec(), (t_lsl-prob_time))
        logger.info(txt)
        started = True
        tm_cls.reset()

    # finish up processes
    cv2.destroyAllWindows()
    logger.info('Cleaning up event acquisition process.')
    state.value = 0
    decoder.stop()
    event_times, event_values = event_queue.get()
    proc.join()

    # save values
    if len(prob_times) == 0:
        # Nothing to save. Previously this dropped into pdb and then failed
        # with IndexError on prob_times[0].
        logger.error('No decoding result. Please debug.')
        return
    t_start = prob_times[0]
    probs = np.vstack(probs)
    event_times = np.array(event_times)
    event_values = np.array(event_values)
    # keep only events at or after the first decoded probability
    keep = event_times >= t_start
    if len(event_values) == len(event_times):
        # drop the matching values too, so times and values stay aligned
        event_values = event_values[keep]
    else:
        # NOTE(review): times and values are assumed to be parallel
        # sequences; if they are not, values are left untouched
        logger.warning('event_times and event_values differ in length; event_values not filtered.')
    event_times = event_times[keep] - t_start
    prob_times = np.array(prob_times) - t_start
    data = dict(probs=probs, prob_times=prob_times, event_times=event_times, event_values=event_values, labels=labels)
    if pklfile:
        qc.save_obj(logfile, data)
        logger.info('Saved to %s' % logfile)
    if matfile:
        pp = qc.parse_path(logfile)
        matout = '%s/%s.mat' % (pp.dir, pp.name)
        scipy.io.savemat(matout, data)
        logger.info('Saved to %s' % matout)