Example #1
def any2fif(filename, interactive=False, outdir=None, channel_file=None):
    """
    Generic file format converter
    """
    p = qc.parse_path(filename)
    if outdir is not None:
        qc.make_dirs(outdir)

    if p.ext == 'pcl':
        eve_file = '%s/%s.txt' % (p.dir, p.name.replace('raw', 'eve'))
        if os.path.exists(eve_file):
            logger.info('Adding events from %s' % eve_file)
        else:
            eve_file = None
        pcl2fif(filename,
                interactive=interactive,
                outdir=outdir,
                external_event=eve_file)
    elif p.ext == 'eeg':
        eeg2fif(filename, interactive=interactive, outdir=outdir)
    elif p.ext in ['edf', 'bdf']:
        bdf2fif(filename, interactive=interactive, outdir=outdir)
    elif p.ext == 'gdf':
        gdf2fif(filename,
                interactive=interactive,
                outdir=outdir,
                channel_file=channel_file)
    elif p.ext == 'xdf':
        xdf2fif(filename, interactive=interactive, outdir=outdir)
    else:  # unknown format
        logger.error(
            'Ignored unrecognized file extension %s. It should be one of [.pcl | .eeg | .edf | .bdf | .gdf | .xdf]'
            % p.ext)
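
A minimal usage sketch (file path and output directory are hypothetical):

# For .pcl input, a matching event file ('raw' replaced by 'eve', .txt extension)
# in the same directory is picked up automatically, as shown above.
any2fif('/data/subject1/session1-raw.pcl', interactive=False, outdir='/data/subject1/fif')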
Example #2
def get_tfr(fif_file, cfg, tfr, n_jobs=1):
    raw, events = pu.load_raw(fif_file)
    p = qc.parse_path(fif_file)
    fname = p.name
    outpath = p.dir

    export_dir = '%s/plot_%s' % (outpath, fname)
    qc.make_dirs(export_dir)

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    if max(picks) >= len(raw.info['ch_names']):  # channel indices are 0-based
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                  spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # MNE TFR functions do not support Raw instances yet, so convert to Epoch
    if cfg.EVENT_START is None:
        # no start event defined: insert a dummy event at sample 0 so that the
        # whole recording becomes a single epoch
        raw._data[0][0] = 1
        events = np.array([[0, 0, 1]])
        classes = None
    else:
        classes = {'START':cfg.EVENT_START}
    tmax = (raw._data.shape[1] - 1) / raw.info['sfreq']
    epochs_all = mne.Epochs(raw, events, classes, tmin=0, tmax=tmax,
                    picks=picks, baseline=None, preload=True)
    print('\n>> Processing %s' % fif_file)
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    power = tfr(epochs_all, freqs=freqs, n_cycles=n_cycles, use_fft=False,
        return_itc=False, decim=1, n_jobs=n_jobs)

    if cfg.EXPORT_MATLAB is True:
        # export all channels to MATLAB
        mout = '%s/%s-%s.mat' % (export_dir, fname, cfg.SP_FILTER)
        scipy.io.savemat(mout, {'tfr':power.data, 'chs':power.ch_names, 'events':events,
            'sfreq':raw.info['sfreq'], 'freqs':cfg.FREQ_RANGE})

    if cfg.EXPORT_PNG is True:
        # Plot power of each channel
        for ch in np.arange(len(picks)):
            ch_name = raw.ch_names[picks[ch]]
            title = 'Channel %s' % (ch_name)
            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            fig = power.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
            fout = '%s/%s-%s-%s.png' % (export_dir, fname, cfg.SP_FILTER, ch_name)
            fig.savefig(fout)
            print('Exported %s' % fout)

    print('Finished!')
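
get_tfr() only reads attributes from cfg, so a plain namespace is enough to sketch the expected config. The attribute names below come from the code above; every value is a placeholder:

import numpy as np
from types import SimpleNamespace

cfg = SimpleNamespace(
    CHANNEL_PICKS=None,
    SP_CHANNELS=None,
    SP_FILTER='car',
    TP_FILTER=None,
    NOTCH_FILTER=None,
    MULTIPLIER=1,
    EVENT_START=None,               # None: the whole recording becomes one epoch
    FREQ_RANGE=np.arange(1, 41),    # must be array-like since n_cycles = freqs / 2
    EXPORT_MATLAB=True,
    EXPORT_PNG=False,
    BS_TIMES=None, BS_MODE='logratio', VMIN=None, VMAX=None)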
Example #3
def load_raw(rawfile,
             spfilter=None,
             spchannels=None,
             events_ext=None,
             multiplier=1,
             verbose='ERROR'):
    """
    Loads data from a fif-format file.
    You can convert non-fif files (.eeg, .bdf, .gdf, .pcl) to fif format.

    Parameters:
    rawfile: (absolute) data file path
    spfilter: 'car' | 'laplacian' | None
    spchannels: None | list (for CAR) | dict (for LAPLACIAN)
        'car': channel indices used for CAR filtering. If None, use all channels except
               the trigger channel (index 0).
        'laplacian': {channel:[neighbor1, neighbor2, ...], ...}
        *** Note ***
        Since PyCNBI puts trigger channel as index 0, data channel starts from index 1.
    events_ext: Add externally recorded events.
                [ [sample_index1, 0, event_value1],... ]
    multiplier: Multiply all values except triggers (to convert unit).

    Returns:
    raw: mne.io.RawArray object. First channel (index 0) is always trigger channel.
    events: mne-compatible events numpy array object (N x [frame, 0, type])
    spfilter= {None | 'car' | 'laplacian'}

    """

    if not os.path.exists(rawfile):
        raise IOError('File %s not found' % rawfile)
    if not os.path.isfile(rawfile):
        raise IOError('%s is not a file' % rawfile)

    extension = qc.parse_path(rawfile).ext
    assert extension in ['fif', 'fiff'], 'only fif format is supported'
    raw = mne.io.Raw(rawfile, preload=True, verbose=verbose)
    preprocess(raw,
               spatial=spfilter,
               spatial_ch=spchannels,
               multiplier=multiplier)
    if events_ext is not None:
        events = mne.read_events(events_ext)
    else:
        tch = find_event_channel(raw)
        if tch is not None:
            events = mne.find_events(raw,
                                     stim_channel=raw.ch_names[tch],
                                     shortest_event=1,
                                     uint_cast=True,
                                     consecutive=True)
        else:
            events = []

    return raw, events
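
A call sketch (the file path is hypothetical):

raw, events = load_raw('/data/subject1/session1-raw.fif', spfilter='car')
print(raw.ch_names[0])  # index 0 is always the trigger channel
print(events[:3])       # N x [sample, 0, event_value]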
Example #4
def main(input_dir, channel_file=None):
    count = 0
    for f in qc.get_file_list(input_dir, fullpath=True, recursive=True):
        p = qc.parse_path(f)
        outdir = p.dir + '/fif/'
        if p.ext in ['pcl', 'bdf', 'edf', 'gdf', 'eeg', 'xdf']:
            print('Converting %s' % f)
            any2fif(f, interactive=True, outdir=outdir, channel_file=channel_file)
            count += 1

    print('\n>> %d files converted.' % count)
Example #5
def load_config(cfg_module):
    cfg_module = cfg_module.replace('\\', '/')
    if '/' in cfg_module:
        spec = importlib.util.spec_from_file_location(
            qc.parse_path(cfg_module).name, cfg_module)
        cfg = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(cfg)
    else:
        cfg = importlib.import_module(cfg_module)
        importlib.reload(cfg)  # in case cfg_module was dynamically changed
    logger.info('Loaded config %s' % cfg_module)
    return cfg
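
Both call styles this variant supports, sketched with hypothetical names:

cfg = load_config('config_offline')             # module importable from sys.path
cfg = load_config(r'C:\exp\config_offline.py')  # explicit file path; backslashes are normalized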
Example #6
def load_config(cfg_module):
    cfg_module = cfg_module.replace('\\', '/')
    if '/' in cfg_module:
        pp = qc.parse_path(cfg_module)
        cwd = os.getcwd()
        # temporarily switch to the config's directory so that import_module()
        # can find it (relies on the working directory being on sys.path)
        os.chdir(pp.dir)
        cfg = importlib.import_module(pp.name)
        os.chdir(cwd)
    else:
        cfg = importlib.import_module(cfg_module)
    logger.info('Loaded config %s' % cfg_module)
    return cfg
Example #7
def fif2mat_file(fif_file, out_dir='./'):
    raw, events = pu.load_raw(fif_file)
    events[:, 0] += 1  # MATLAB uses 1-based indexing
    sfreq = raw.info['sfreq']
    data = dict(signals=raw._data,
                events=events,
                sfreq=sfreq,
                ch_names=raw.ch_names)
    fname = qc.parse_path(fif_file).name
    matfile = '%s/%s.mat' % (out_dir, fname)
    scipy.io.savemat(matfile, data)
    logger.info('Exported to %s' % matfile)
Example #8
def fix_channel_names(fif_dir, new_channel_names):
    '''
    Change channel names of fif files in a given directory.

    Input
    -----
    @fif_dir: path to fif files
    @new_channel_names: list of new channel names

    Output
    ------
    Modified fif files are saved in fif_dir/corrected/

    Kyuhwa Lee, 2019.
    '''

    flist = []
    for f in qc.get_file_list(fif_dir):
        if qc.parse_path(f).ext == 'fif':
            flist.append(f)

    if len(flist) > 0:
        qc.make_dirs('%s/corrected' % fif_dir)
        for f in flist:
            logger.info('\nLoading %s' % f)
            p = qc.parse_path(f)
            raw, eve = pu.load_raw(f)
            if len(raw.ch_names) != len(new_channel_names):
                raise RuntimeError(
                    'The number of new channel names does not match that of the fif file.'
                )
            # update both ch_names and the per-channel info structures
            raw.info['ch_names'] = new_channel_names
            for ch, new_ch in zip(raw.info['chs'], new_channel_names):
                ch['ch_name'] = new_ch
            out_fif = '%s/corrected/%s.fif' % (p.dir, p.name)
            logger.info('Exporting to %s' % out_fif)
            raw.save(out_fif)
    else:
        logger.warning('No fif files found in %s' % fif_dir)
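
Usage sketch (directory and channel names are hypothetical; PyCNBI keeps the trigger channel at index 0):

new_names = ['TRIGGER', 'Fp1', 'Fp2', 'F3', 'F4']  # length must equal the fif channel count
fix_channel_names('/data/subject1/fif', new_names)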
Example #9
def fif_resample(fif_dir, sfreq_target):
    out_dir = fif_dir + '/fif_resample%d' % sfreq_target
    qc.make_dirs(out_dir)
    for f in qc.get_file_list(fif_dir):
        pp = qc.parse_path(f)
        if pp.ext != 'fif':
            continue
        logger.info('Resampling %s' % f)
        raw, events = pu.load_raw(f)
        raw.resample(sfreq_target)
        fif_out = '%s/%s.fif' % (out_dir, pp.name)
        raw.save(fif_out)
        logger.info('Exported to %s' % fif_out)
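
Usage sketch (path hypothetical); the output directory name follows from the code above:

fif_resample('/data/subject1/fif', 512)  # writes to /data/subject1/fif/fif_resample512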
Example #10
def trigger_def(ini_file, verbose=False):
    class TriggerDef(object):
        def __init__(self, items):
            self.by_name = {}
            self.by_value = {}
            for key, value in items:
                value = int(value)
                setattr(self, key, value)
                self.by_name[key] = value
                self.by_value[value] = key

        # show all possible trigger values
        def check_data(self):
            print('Attributes of the final class')
            for attr in dir(self):
                if not callable(getattr(self,
                                        attr)) and not attr.startswith("__"):
                    print(attr, getattr(self, attr))

    if not os.path.exists(ini_file):
        search_path = []
        path_ini = qc.parse_path(ini_file)
        path_self = qc.parse_path(__file__)
        search_path.append(ini_file + '.ini')
        search_path.append('%s/%s' % (path_self.dir, path_ini.name))
        search_path.append('%s/%s.ini' % (path_self.dir, path_ini.name))
        for ini_file in search_path:
            if os.path.exists(ini_file):
                if verbose:
                    logger.info('Found trigger definition file %s' % ini_file)
                break
        else:
            raise IOError('Trigger event definition file %s not found' %
                          ini_file)
    config = ConfigParser(inline_comment_prefixes=('#', ';'))
    config.optionxform = str
    config.read(ini_file)
    return TriggerDef(config.items('events'))
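
trigger_def() parses the [events] section of an INI file; a sketch with hypothetical event names and values:

# contents of triggerdef_example.ini:
#   [events]
#   LEFT_GO  = 9
#   RIGHT_GO = 11
tdef = trigger_def('triggerdef_example.ini')
print(tdef.LEFT_GO)       # 9, exposed as an attribute
print(tdef.by_value[11])  # 'RIGHT_GO'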
Example #11
def fif2mat(data_dir):
    out_dir = '%s/mat_files' % data_dir
    qc.make_dirs(out_dir)
    for rawfile in qc.get_file_list(data_dir, fullpath=True):
        if rawfile[-4:] != '.fif':
            continue
        raw, events = pu.load_raw(rawfile)
        events[:, 0] += 1  # MATLAB uses 1-based indexing
        sfreq = raw.info['sfreq']
        data = dict(signals=raw._data, events=events, sfreq=sfreq,
                    ch_names=raw.ch_names)
        fname = qc.parse_path(rawfile).name
        matfile = '%s/%s.mat' % (out_dir, fname)
        scipy.io.savemat(matfile, data)
        logger.info('Exported to %s' % matfile)
    logger.info('Finished exporting.')
Example #12
def fif2mat(input_path):
    if os.path.isdir(input_path):
        out_dir = '%s/mat_files' % input_path
        qc.make_dirs(out_dir)
        num_processed = 0
        for rawfile in qc.get_file_list(input_path, fullpath=True):
            if rawfile[-4:] != '.fif':
                continue
            fif2mat_file(rawfile, out_dir)
            num_processed += 1
        if num_processed == 0:
            logger.warning('No fif files found in the path.')
    elif os.path.isfile(input_path):
        out_dir = '%s/mat_files' % qc.parse_path(input_path).dir
        qc.make_dirs(out_dir)
        fif2mat_file(input_path, out_dir)
    else:
        raise ValueError('Neither directory nor file: %s' % input_path)
    logger.info('Finished.')
Example #13
# auto import subpackages
from pycnbi.utils import q_common as qc
import os
ROOT = qc.parse_path(os.path.realpath(__file__)).dir
for d in qc.get_dir_list(ROOT):
    if os.path.exists('%s/__init__.py' % d):
        exe_package = 'import pycnbi.%s' % d.replace(ROOT + '/', '')
        exec(exe_package)
def feature_importances_topo(featfile,
                             topo_layout_file=None,
                             channels=None,
                             channel_name_show=None):
    """
    Compute feature importances across frequency bands and channels

    @params
    topo_layout_file: if not None, topography map images will be generated and saved.
    channel_name_show: list of channel names to show on the topography map.

    """
    print('Loading %s' % featfile)

    if channels is None:
        channel_set = set()
        with open(featfile) as f:
            f.readline()
            for l in f:
                ch = l.strip().split('\t')[1]
                channel_set.add(ch)
        channels = list(channel_set)

    # channel index lookup table
    ch2index = {ch: i for i, ch in enumerate(channels)}

    data_delta = np.zeros(len(channels))
    data_theta = np.zeros(len(channels))
    data_mu = np.zeros(len(channels))
    data_beta = np.zeros(len(channels))
    data_beta1 = np.zeros(len(channels))
    data_beta2 = np.zeros(len(channels))
    data_beta3 = np.zeros(len(channels))
    data_lgamma = np.zeros(len(channels))
    data_hgamma = np.zeros(len(channels))
    data_per_ch = np.zeros(len(channels))

    # Each data line of featfile: importance<TAB>channel<TAB>frequency
    with open(featfile) as f:
        f.readline()  # skip header
        for l in f:
            token = l.strip().split('\t')
            importance = float(token[0])
            ch = token[1]
            fq = float(token[2])
            if fq <= 3:
                data_delta[ch2index[ch]] += importance
            elif fq <= 7:
                data_theta[ch2index[ch]] += importance
            elif fq <= 12:
                data_mu[ch2index[ch]] += importance
            elif fq <= 30:
                data_beta[ch2index[ch]] += importance
            elif fq <= 70:
                data_lgamma[ch2index[ch]] += importance
            else:
                data_hgamma[ch2index[ch]] += importance
            # beta sub-bands; bounded on both sides so lower bands don't leak in
            if 12.5 <= fq <= 16:
                data_beta1[ch2index[ch]] += importance
            elif 16 < fq <= 20:
                data_beta2[ch2index[ch]] += importance
            elif 20 < fq <= 28:
                data_beta3[ch2index[ch]] += importance
            data_per_ch[ch2index[ch]] += importance

    hlen = 18 + len(channels) * 7
    result = '>> Feature importance distribution\n'
    result += 'bands   ' + qc.list2string(channels,
                                          '%6s') + ' | ' + 'per band\n'
    result += '-' * hlen + '\n'
    result += 'delta   ' + qc.list2string(
        data_delta, '%6.2f') + ' | %6.2f\n' % np.sum(data_delta)
    result += 'theta   ' + qc.list2string(
        data_theta, '%6.2f') + ' | %6.2f\n' % np.sum(data_theta)
    result += 'mu      ' + qc.list2string(
        data_mu, '%6.2f') + ' | %6.2f\n' % np.sum(data_mu)
    #result += 'beta    ' + qc.list2string(data_beta, '%6.2f') + ' | %6.2f\n' % np.sum(data_beta)
    result += 'beta1   ' + qc.list2string(
        data_beta1, '%6.2f') + ' | %6.2f\n' % np.sum(data_beta1)
    result += 'beta2   ' + qc.list2string(
        data_beta2, '%6.2f') + ' | %6.2f\n' % np.sum(data_beta2)
    result += 'beta3   ' + qc.list2string(
        data_beta3, '%6.2f') + ' | %6.2f\n' % np.sum(data_beta3)
    result += 'lgamma  ' + qc.list2string(
        data_lgamma, '%6.2f') + ' | %6.2f\n' % np.sum(data_lgamma)
    result += 'hgamma  ' + qc.list2string(
        data_hgamma, '%6.2f') + ' | %6.2f\n' % np.sum(data_hgamma)
    result += '-' * hlen + '\n'
    result += 'per_ch  ' + qc.list2string(data_per_ch, '%6.2f') + ' | 100.00\n'
    print(result)
    p = qc.parse_path(featfile)
    with open('%s/%s_summary.txt' % (p.dir, p.name), 'w') as fsum:
        fsum.write(result)

    # export topo maps
    if topo_layout_file is not None:
        # default visualization setting
        res = 64
        contours = 6

        # select channel names to show
        if channel_name_show is None:
            channel_name_show = channels
        chan_vis = [''] * len(channels)
        for ch in channel_name_show:
            chan_vis[channels.index(ch)] = ch

        # set channel locations and reverse lookup table
        chanloc = {}
        if not os.path.exists(topo_layout_file):
            topo_layout_file = PYCNBI_ROOT + '/layout/' + topo_layout_file
            if not os.path.exists(topo_layout_file):
                raise FileNotFoundError('Layout file %s not found.' %
                                        topo_layout_file)
        print('Using layout %s' % topo_layout_file)
        with open(topo_layout_file) as f:
            for l in f:
                token = l.strip().split('\t')
                ch = token[5]
                x = float(token[1])
                y = float(token[2])
                chanloc[ch] = [x, y]
        pos = np.zeros((len(channels), 2))
        for i, ch in enumerate(channels):
            pos[i] = chanloc[ch]

        vmin = min(data_per_ch)
        vmax = max(data_per_ch)
        total = sum(data_per_ch)
        rate_delta = sum(data_delta) * 100.0 / total
        rate_theta = sum(data_theta) * 100.0 / total
        rate_mu = sum(data_mu) * 100.0 / total
        rate_beta = sum(data_beta) * 100.0 / total
        rate_beta1 = sum(data_beta1) * 100.0 / total
        rate_beta2 = sum(data_beta2) * 100.0 / total
        rate_beta3 = sum(data_beta3) * 100.0 / total
        rate_lgamma = sum(data_lgamma) * 100.0 / total
        rate_hgamma = sum(data_hgamma) * 100.0 / total
        export_topo(data_per_ch,
                    pos,
                    'features_topo_all.png',
                    xlabel='all bands 1-40 Hz',
                    chan_vis=chan_vis)
        export_topo(data_delta,
                    pos,
                    'features_topo_delta.png',
                    xlabel='delta 1-3 Hz (%.1f%%)' % rate_delta,
                    chan_vis=chan_vis)
        export_topo(data_theta,
                    pos,
                    'features_topo_theta.png',
                    xlabel='theta 4-7 Hz (%.1f%%)' % rate_theta,
                    chan_vis=chan_vis)
        export_topo(data_mu,
                    pos,
                    'features_topo_mu.png',
                    xlabel='mu 8-12 Hz (%.1f%%)' % rate_mu,
                    chan_vis=chan_vis)
        export_topo(data_beta,
                    pos,
                    'features_topo_beta.png',
                    xlabel='beta 13-30 Hz (%.1f%%)' % rate_beta,
                    chan_vis=chan_vis)
        export_topo(data_beta1,
                    pos,
                    'features_topo_beta1.png',
                    xlabel='beta 12.5-16 Hz (%.1f%%)' % rate_beta1,
                    chan_vis=chan_vis)
        export_topo(data_beta2,
                    pos,
                    'features_topo_beta2.png',
                    xlabel='beta 16-20 Hz (%.1f%%)' % rate_beta2,
                    chan_vis=chan_vis)
        export_topo(data_beta3,
                    pos,
                    'features_topo_beta3.png',
                    xlabel='beta 20-28 Hz (%.1f%%)' % rate_beta3,
                    chan_vis=chan_vis)
        export_topo(data_lgamma,
                    pos,
                    'features_topo_lowgamma.png',
                    xlabel='low gamma 31-40 Hz (%.1f%%)' % rate_lgamma,
                    chan_vis=chan_vis)
Example #15
def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    @params:
    cfg: config object; cfg.TFR_TYPE ('multitaper' | 'morlet' | 'butter'),
         cfg.EXPORT_PATH (path to save plots) and other settings are read from it
    recursive: if True, load raw files in sub-dirs recursively
    n_jobs: number of cores to run in parallel (overridden by cfg.N_JOBS if set)
    '''

    cfg = check_cfg(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        butter_order = 4 # TODO: parameterize
        tfr = lfilter
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    n_jobs = cfg.N_JOBS  # note: cfg.N_JOBS overrides the n_jobs argument
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_DIRS'):
        if export_path is None:
            raise ValueError('For multiple directories, cfg.EXPORT_PATH cannot be None')
        else:
            outpath = export_path
        file_prefix = 'grandavg'

        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_DIRS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True, recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)

        # custom event file; read after load_multi() so the events are not overwritten
        if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
            events = mne.read_events(cfg.EVENT_FILE)
    else:
        print('Loading', cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)

        # custom events
        if hasattr(cfg, 'EVENT_FILE') and cfg.EVENT_FILE is not None:
            events = mne.read_events(cfg.EVENT_FILE)

        if export_path is None:
            [outpath, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        else:
            outpath = export_path

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    if max(picks) >= len(raw.info['ch_names']):  # channel indices are 0-based
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    pu.preprocess(raw, spatial=cfg.SP_FILTER, spatial_ch=spchannels, spectral=cfg.TP_FILTER,
                  spectral_ch=picks, notch=cfg.NOTCH_FILTER, notch_ch=picks,
                  multiplier=cfg.MULTIPLIER, n_jobs=n_jobs)

    # Read epochs
    classes = {}
    for t in cfg.TRIGGERS:
        if t in set(events[:, -1]):
            if hasattr(cfg, 'tdef'):
                classes[cfg.tdef.by_value[t]] = t
            else:
                classes[str(t)] = t
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    try:
        tmin = cfg.EPOCH[0]
        tmin_buffer = tmin - t_buffer
        raw_tmax = raw._data.shape[1] / sfreq - 0.1
        if cfg.EPOCH[1] is None:
            if cfg.POWER_AVERAGED:
                raise ValueError('EPOCH value cannot have None for grand averaged TFR')
            else:
                if len(cfg.TRIGGERS) > 1:
                    raise ValueError('If the end time of EPOCH is None, only a single event can be defined.')
                t_ref = events[np.where(events[:,2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
                tmax = raw_tmax - t_ref - t_buffer
        else:
            tmax = cfg.EPOCH[1]
        tmax_buffer = tmax + t_buffer
        if tmax_buffer > raw_tmax:
            raise ValueError('Epoch length with buffer (%.3f) is larger than signal length (%.3f)' % (tmax_buffer, raw_tmax))

        #print('Epoch tmin = %.1f, tmax = %.1f, raw length = %.1f' % (tmin, tmax, raw_tmax))
        epochs_all = mne.Epochs(raw, events, classes, tmin=tmin_buffer, tmax=tmax_buffer,
                                proj=False, picks=picks, baseline=None, preload=True)
        if epochs_all.drop_log_stats() > 0:
            print('\n** Bad epochs found. Dropping into a Python shell.')
            print(epochs_all.drop_log)
            print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
                (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
            print('\nType exit to continue.\n')
            pdb.set_trace()
    except Exception:
        print('\n*** (tfr_export) ERROR OCCURRED WHILE EPOCHING ***')
        traceback.print_exc()
        print('tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
            (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq))
        pdb.set_trace()

    power = {}
    for evname in classes:
        #export_dir = '%s/plot_%s' % (outpath, evname)
        export_dir = outpath
        qc.make_dirs(export_dir)
        print('\n>> Processing %s' % evname)
        freqs = cfg.FREQ_RANGE  # define frequencies of interest
        n_cycles = freqs / 2.  # different number of cycle per frequency
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                print('No %s epochs. Skipping.' % evname)
                continue

            if tfr_type == 'butter':
                b, a = butter_bandpass(cfg.FREQ_RANGE[0], cfg.FREQ_RANGE[-1], sfreq, order=butter_order)
                # filter the epoch data array (n_epochs, n_channels, n_times) along the time axis
                tfr_filtered = lfilter(b, a, epochs.get_data(), axis=2)
                tfr_hilbert = hilbert(tfr_filtered)
                tfr_power = abs(tfr_hilbert)
                tfr_data = np.mean(tfr_power, axis=0)
            elif tfr_type == 'fir':
                raise NotImplementedError
            else:
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname)
                scipy.io.savemat(mout, {'tfr':tfr_data, 'chs':epochs.ch_names,
                    'events':events, 'sfreq':sfreq, 'epochs':cfg.EPOCH, 'freqs':cfg.FREQ_RANGE})
                print('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                # Inspect power for each channel
                for ch in np.arange(len(picks)):
                    chname = raw.ch_names[picks[ch]]
                    title = 'Peri-event %s - Channel %s' % (evname, chname)

                    # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                    fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                        colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                    fout = '%s/%s-%s-%s-%s.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname)
                    fig.savefig(fout)
                    fig.clf()
                    print('Exported to %s' % fout)
        else:
            # TFR per event
            for ep in range(len(epochs_all[evname])):
                epochs = epochs_all[evname][ep]
                if len(epochs) == 0:
                    print('No %s epochs. Skipping.' % evname)
                    continue
                power[evname] = tfr(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=False,
                    return_itc=False, decim=1, n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(mout, {'tfr':power[evname].data, 'chs':power[evname].ch_names,
                        'events':events, 'sfreq':sfreq, 'tmin':tmin, 'tmax':tmax, 'freqs':cfg.FREQ_RANGE})
                    print('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    # Inspect power for each channel
                    for ch in np.arange(len(picks)):
                        chname = raw.ch_names[picks[ch]]
                        title = 'Peri-event %s - Channel %s, Trial %d' % (evname, chname, ep + 1)
                        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
                        fig = power[evname].plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                            colorbar=True, title=title, vmin=cfg.VMIN, vmax=cfg.VMAX, dB=False)
                        fout = '%s/%s-%s-%s-%s-ep%02d.png' % (export_dir, file_prefix, cfg.SP_FILTER, evname, chname, ep + 1)
                        fig.savefig(fout)
                        fig.clf()
                        print('Exported %s' % fout)

    if hasattr(cfg, 'POWER_DIFF'):
        export_dir = '%s/diff' % outpath
        qc.make_dirs(export_dir)
        labels = list(classes.keys())  # list() needed for indexing in Python 3
        df = power[labels[0]] - power[labels[1]]
        df.data = np.log(np.abs(df.data))
        # Inspect power diff for each channel
        for ch in np.arange(len(picks)):
            chname = raw.ch_names[picks[ch]]
            title = 'Peri-event %s-%s - Channel %s' % (labels[0], labels[1], chname)

            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            fig = df.plot([ch], baseline=cfg.BS_TIMES, mode=cfg.BS_MODE, show=False,
                          colorbar=True, title=title, vmin=3.0, vmax=-3.0, dB=False)
            fout = '%s/%s-%s-diff-%s-%s-%s.jpg' % (export_dir, file_prefix, cfg.SP_FILTER, labels[0], labels[1], chname)
            print('Exporting to %s' % fout)
            fig.savefig(fout)
            fig.clf()
    print('Finished!')
Example #16
def log_decoding(decoder, logfile, amp_name=None, amp_serial=None, pklfile=True, matfile=False, autostop=False, prob_smooth=False):
    """
    Decode online and write results with event timestamps

    input
    -----
    decoder: Decoder or DecoderDaemon class object.
    logfile: file name to contain the results.
    amp_name: LSL server name (if known).
    amp_serial: LSL server serial number (if known).
    pklfile: export results to a Python pickle file if True.
    matfile: export results to a MATLAB .mat file if True.
    autostop: automatically finish when no more data is received.
    prob_smooth: use smoothed probability values according to the decoder's smoothing parameter.
    """

    import cv2
    import scipy

    # run event acquisition process in the background
    state = mp.Value('i', 1)
    event_queue = mp.Queue()
    proc = mp.Process(target=log_decoding_helper, args=[state, event_queue, amp_name, amp_serial, autostop])
    proc.start()
    logger.info_green('Spawned event acquisition process.')

    # init variables and choose decoding function
    labels = decoder.get_label_names()
    probs = []
    prob_times = []
    if prob_smooth:
        decode_fn = decoder.get_prob_smooth_unread
    else:
        decode_fn = decoder.get_prob_unread
        
    # simple controller UI
    cv2.namedWindow("Decoding", cv2.WINDOW_AUTOSIZE)
    cv2.moveWindow("Decoding", 1400, 50)
    img = np.zeros([100, 400, 3], np.uint8)
    cv2.putText(img, 'Press any key to start', (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)
    cv2.waitKeyEx()
    img *= 0
    cv2.putText(img, 'Press ESC to stop', (40, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
    cv2.imshow("Decoding", img)

    key = 0
    started = False
    tm_watchdog = qc.Timer(autoreset=True)
    tm_cls = qc.Timer()
    while key != 27:
        prob, prob_time = decode_fn(True)
        t_lsl = pylsl.local_clock()
        key = cv2.waitKeyEx(1)
        if prob is None:
            # watch dog
            if tm_cls.sec() > 5:
                if autostop and started:
                    logger.info('No more streaming data. Finishing.')
                    break
                tm_cls.reset()
            tm_watchdog.sleep_atleast(0.001)
            continue
        probs.append(prob)
        prob_times.append(prob_time)
        txt = '[%.3f] ' % prob_time
        txt += ', '.join(['%s: %.2f' % (l, p) for l, p in zip(labels, prob)])
        txt += ' (%d ms, LSL Diff = %.3f)' % (tm_cls.msec(), (t_lsl-prob_time))
        logger.info(txt)
        if not started:
            started = True
        tm_cls.reset()

    # finish up processes
    cv2.destroyAllWindows()
    logger.info('Cleaning up event acquisition process.')
    state.value = 0
    decoder.stop()
    event_times, event_values = event_queue.get()
    proc.join()

    # save values
    if len(prob_times) == 0:
        logger.error('No decoding result. Please debug.')
        import pdb
        pdb.set_trace()
    t_start = prob_times[0]
    probs = np.vstack(probs)
    event_times = np.array(event_times)
    event_times = event_times[np.where(event_times >= t_start)[0]] - t_start
    prob_times = np.array(prob_times) - t_start
    event_values = np.array(event_values)
    data = dict(probs=probs, prob_times=prob_times, event_times=event_times, event_values=event_values, labels=labels)
    if pklfile:
        qc.save_obj(logfile, data)
        logger.info('Saved to %s' % logfile)
    if matfile:
        pp = qc.parse_path(logfile)
        matout = '%s/%s.mat' % (pp.dir, pp.name)
        scipy.io.savemat(matout, data)
        logger.info('Saved to %s' % matout)
Example #17
    # reset trigger channel
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    logger.info('=== After merging ===')
    for key in np.unique(eve[:, 2]):
        if key in tdef.by_value:
            logger.info(
                '%s: %d events' %
                (tdef.by_value[key], len(np.where(eve[:, 2] == key)[0])))
        else:
            logger.info('%s: %d events' %
                        (key, len(np.where(eve[:, 2] == key)[0])))


# sample code
if __name__ == '__main__':
    fif_dir = r'D:\data\STIMO\GO004\offline\all'
    trigger_file = 'triggerdef_gait_chuv.ini'
    events = {'BOTH_GO': ['LEFT_GO', 'RIGHT_GO']}

    out_dir = fif_dir + '/merged'
    qc.make_dirs(out_dir)
    for rawfile_in in qc.get_file_list(fif_dir):
        p = qc.parse_path(rawfile_in)
        if p.ext != 'fif':
            continue
        rawfile_out = '%s/%s.%s' % (out_dir, p.name, p.ext)
        merge_events(trigger_file, events, rawfile_in, rawfile_out)
Example #18
    dups = np.where(0 == np.diff(eve[:, 0]))[0]
    assert len(dups) == 0
    assert max(eve[:, 2]) <= max(tdef.by_value.keys())

    # reset trigger channel
    raw._data[0] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(eeg_out, overwrite=True)

    print('\nResulting events')
    for key in np.unique(eve[:, 2]):
        print('%s: %d' %
              (tdef.by_value[key], len(np.where(eve[:, 2] == key)[0])))


if __name__ == '__main__':
    fif_dir = r'D:\data\STIMO\GO004\offline\all'
    trigger_file = 'triggerdef_gait_chuv.ini'
    events = {'BOTH_GO': ['LEFT_GO', 'RIGHT_GO']}

    fiflist = []
    out_dir = fif_dir + '/merged'
    qc.make_dirs(out_dir)
    for f in qc.get_file_list(fif_dir):
        p = qc.parse_path(f)
        if p.ext != 'fif':
            continue
        eeg_in = f
        eeg_out = '%s/%s.%s' % (out_dir, p.name, p.ext)
        merge_events(trigger_file, events, eeg_in, eeg_out)