コード例 #1
0
ファイル: epochs2mat.py プロジェクト: vferat/NeuroDecode
def epochs2mat(data_dir,
               channel_picks,
               event_id,
               tmin,
               tmax,
               merge_epochs=False,
               spfilter=None,
               spchannels=None):
    """
    Export epochs of the .fif files found in a directory to MATLAB .mat files.

    Input
    =====
    data_dir: directory containing the .fif files.
    channel_picks, event_id, tmin, tmax: epoching parameters forwarded
        unchanged to save_mat().
    merge_epochs: if True, load and merge all files and write a single
        <data_dir>/epochs_all.mat; otherwise write one <name>-epochs.mat
        next to each .fif file.
    spfilter, spchannels: spatial filter settings passed to pu.load_multi();
        only used when merge_epochs is True.

    If no .fif file is found, a warning is logged and nothing is written.
    """
    fiflist = [f for f in qc.get_file_list(data_dir, fullpath=True)
               if f[-4:] == '.fif']
    if not fiflist:
        # nothing to export; also avoids logging an output path never set
        logger.warning('No fif files found in %s' % data_dir)
        return

    if merge_epochs:
        # load all raw files in the directory and merge epochs
        raw, events = pu.load_multi(fiflist,
                                    spfilter=spfilter,
                                    spchannels=spchannels)
        matfile = data_dir + '/epochs_all.mat'
        save_mat(raw, events, channel_picks, event_id, tmin, tmax, matfile)
        logger.info('Exported to %s' % matfile)
    else:
        # process individual raw file separately; log every exported file
        for data_file in fiflist:
            base, fname, _ = qc.parse_path_list(data_file)
            matfile = '%s/%s-epochs.mat' % (base, fname)
            raw, events = pu.load_raw(data_file)
            save_mat(raw, events, channel_picks, event_id, tmin, tmax, matfile)
            logger.info('Exported to %s' % matfile)
コード例 #2
0
ファイル: fif_info.py プロジェクト: vferat/NeuroDecode
def run(fif_file):
    """
    Show a short summary of a fif file, then open an interactive shell.

    The summary prints the Raw object, its channel names, the set of event
    values (third column of `events`) and the sampling rate. The shell is
    started with `embed()`; `raw` and `events` are the variables to inspect.
    """
    print('Loading "%s"' % (fif_file,))
    raw, events = pu.load_raw(fif_file)
    print('Raw info: %s' % (raw,))
    print('Channels: %s' % (', '.join(raw.ch_names),))
    print('Events: %s' % (set(events[:, 2]),))
    print('Sampling freq: %.3f Hz' % (raw.info['sfreq'],))
    qc.print_c('\n>> Interactive mode start. Type quit or Ctrl+D to finish', 'g')
    qc.print_c('>> Variables: raw, events\n', 'g')
    embed()
コード例 #3
0
def fif_resample(fif_dir, sfreq_target):
    """
    Resample every .fif file of a directory to `sfreq_target` Hz.

    Results keep their base name and are written into
    <fif_dir>/fif_resample<sfreq_target>/ (created if needed).
    """
    out_dir = fif_dir + '/fif_resample%d' % sfreq_target
    qc.make_dirs(out_dir)
    # lazily pair each path with its parsed form so parsing and processing
    # stay interleaved exactly as before
    parsed_files = ((path, qc.parse_path(path)) for path in qc.get_file_list(fif_dir))
    for path, parts in parsed_files:
        if parts.ext == 'fif':
            logger.info('Resampling %s' % path)
            raw, _ = pu.load_raw(path)
            raw.resample(sfreq_target)
            out_file = '%s/%s.fif' % (out_dir, parts.name)
            raw.save(out_file)
            logger.info('Exported to %s' % out_file)
コード例 #4
0
ファイル: fif2mat.py プロジェクト: vferat/NeuroDecode
def fif2mat(data_dir):
    """
    Convert every .fif file in `data_dir` into a MATLAB .mat file.

    Each output goes to <data_dir>/mat_files/<name>.mat and contains the keys
    'signals', 'events', 'sfreq' and 'ch_names'. The sample indices in the
    first column of the event array are shifted by +1 for MATLAB.
    """
    out_dir = '%s/mat_files' % data_dir
    qc.make_dirs(out_dir)
    fif_paths = (path for path in qc.get_file_list(data_dir, fullpath=True)
                 if path[-4:] == '.fif')
    for fif_path in fif_paths:
        raw, events = pu.load_raw(fif_path)
        # MATLAB uses 1-based indexing
        events[:, 0] += 1
        payload = {
            'signals': raw._data,
            'events': events,
            'sfreq': raw.info['sfreq'],
            'ch_names': raw.ch_names,
        }
        matfile = '%s/%s.mat' % (out_dir, qc.parse_path(fif_path).name)
        scipy.io.savemat(matfile, payload)
        logger.info('Exported to %s' % matfile)
    logger.info('Finished exporting.')
コード例 #5
0
ファイル: merge_events.py プロジェクト: vferat/NeuroDecode
def _log_event_counts(eve, tdef):
    """Log the number of events per trigger value in `eve` (merge_events helper).

    Returns the list of trigger values that are not defined in `tdef`.
    """
    undefined = []
    values, counts = np.unique(eve[:, 2], return_counts=True)
    for key, count in zip(values, counts):
        if key in tdef.by_value:
            logger.info('%s: %d events' % (tdef.by_value[key], count))
        else:
            logger.info('%s: %d events' % (key, count))
            undefined.append(key)
    return undefined


def merge_events(trigger_file, events, rawfile_in, rawfile_out):
    """
    Relabel groups of events into new event types and save a new raw file.

    Input
    =====
    trigger_file: trigger definition file, loaded with trigger_def().
    events: mapping from an output event name to an iterable of source event
            names; every source event is relabelled to the output event.
            All names must be defined in trigger_file.
    rawfile_in: raw file to read.
    rawfile_out: raw file to write (overwritten if it exists).

    Raises
    ======
    RuntimeError: if two events share the same sample index.
    ValueError: if the raw file has no 'TRIGGER' channel.
    """
    tdef = trigger_def(trigger_file)
    raw, eve = pu.load_raw(rawfile_in)

    logger.info('=== Before merging ===')
    for key in _log_event_counts(eve, tdef):
        logger.warning('Key %d was not found in the definition file.' % key)

    for key in events:
        ev_src = events[key]
        ev_out = tdef.by_name[key]
        x = [np.where(eve[:, 2] == tdef.by_name[e])[0] for e in ev_src]
        # an empty source list has nothing to relabel (and np.concatenate
        # raises on an empty list)
        if x:
            eve[np.concatenate(x), 2] = ev_out

    # sanity check: two events at the same sample cannot both be written to
    # the trigger channel
    dups = np.where(np.diff(eve[:, 0]) == 0)[0]
    if len(dups) > 0:
        raise RuntimeError('%d events share the same sample index.' % len(dups))

    # reset the trigger channel (looked up by name, not assumed to be
    # channel 0) before writing the merged events into it
    trig_ch = raw.ch_names.index('TRIGGER')
    raw._data[trig_ch] *= 0
    raw.add_events(eve, 'TRIGGER')
    raw.save(rawfile_out, overwrite=True)

    logger.info('=== After merging ===')
    _log_event_counts(eve, tdef)
コード例 #6
0
def fix_channel_names(fif_dir, new_channel_names):
    '''
    Change channel names of fif files in a given directory.

    Input
    -----
    @fif_dir: path to fif files
    @new_channel_names: list of new channel names, one per channel and in the
                        same order as the channels stored in each file

    Output
    ------
    Modified fif files are saved in fif_dir/corrected/

    Raises
    ------
    RuntimeError if a file's channel count differs from len(new_channel_names).

    Kyuhwa Lee, 2019.
    '''
    flist = [f for f in qc.get_file_list(fif_dir) if qc.parse_path(f).ext == 'fif']
    if not flist:
        logger.warning('No fif files found in %s' % fif_dir)
        return

    qc.make_dirs('%s/corrected' % fif_dir)
    for f in flist:
        logger.info('\nLoading %s' % f)
        p = qc.parse_path(f)
        raw, _ = pu.load_raw(f)
        if len(raw.ch_names) != len(new_channel_names):
            raise RuntimeError('The number of new channels does not match that of fif file.')
        # rename_channels keeps info['ch_names'], info['chs'] and info['bads']
        # consistent, unlike editing the info fields by hand
        raw.rename_channels(dict(zip(raw.ch_names, new_channel_names)))
        out_fif = '%s/corrected/%s.fif' % (p.dir, p.name)
        logger.info('Exporting to %s' % out_fif)
        raw.save(out_fif)
コード例 #7
0
ファイル: raw2psd.py プロジェクト: vferat/NeuroDecode
def raw2psd(rawfile=None,
            fmin=1,
            fmax=40,
            wlen=0.5,
            wstep=1,
            tmin=0.0,
            tmax=None,
            channel_picks=None,
            excludes=None,
            n_jobs=1):
    """
    Compute PSD features over a sliding window on the entire raw file.
    Leading edge of the window is the time reference, i.e. do not use future data.

    Input
    =====
    rawfile: fif file.
    channel_picks: None (all EEG channels) or list of channel names
    tmin (sec): time from the beginning of the recording where the first
                window starts; the first feature is computed at tmin + wlen.
    tmax (sec): time from the beginning of the recording where sliding stops
                (clipped to the signal length). None = until the end.
    fmin (Hz): minimum PSD frequency
    fmax (Hz): maximum PSD frequency
    wlen (sec): sliding window length for computing PSD (sec)
    wstep (int): sliding window step (time samples)
    excludes (list): list of channels to exclude (None = exclude nothing)
    n_jobs (int): number of jobs used by mne.decoding.PSDEstimator

    Output
    ======
    Files written next to rawfile:
    psd-<name>-header.pkl: dict with 'psdfile', 'times' (sample index of each
        window's leading edge), 'sfreq', 'channels', 'wframes' and 'events'
        (latest event value at or before each leading edge, 0 before the
        first event).
    psd-<name>-data.npy: PSD array of shape (n_windows, n_channels, n_freqs).
    """
    if excludes is None:
        excludes = []
    raw, eve = pu.load_raw(rawfile)
    sfreq = raw.info['sfreq']
    wframes = int(round(sfreq * wlen))
    # pick_types() works in place, so raw.ch_names below lists only the picked channels
    raw_eeg = raw.pick_types(meg=False, eeg=True, stim=False, exclude=excludes)
    if channel_picks is None:
        # store channel indices (not names) so the export step can map them
        # back to names the same way in both branches
        chlist = list(range(len(raw.ch_names)))
        rawdata = raw_eeg._data
    else:
        chlist = [raw.ch_names.index(ch) for ch in channel_picks]
        rawdata = raw_eeg._data[np.array(chlist)]

    n_samples = rawdata.shape[1]
    if tmax is None:
        t_end = n_samples
    else:
        # windows must not run past the end of the signal
        t_end = min(int(round(tmax * sfreq)), n_samples)
    t_start = int(round(tmin * sfreq)) + wframes
    psde = mne.decoding.PSDEstimator(sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs,
        bandwidth=None, low_bias=True, adaptive=False, normalization='length',
        verbose=None)
    print('[PID %d] %s' % (os.getpid(), rawfile))
    psd_all = []
    evelist = []
    times = []
    t_len = t_end - t_start
    last_eve = 0
    y_i = 0
    t_last = t_start
    tm = qc.Timer()
    for t in range(t_start, t_end, wstep):
        # compute PSD on the window ending at sample t
        window = rawdata[:, t - wframes:t]
        psd = psde.transform(window[np.newaxis, :, :])
        psd_all.append(psd[0])
        times.append(t)

        # matching events at the current window: consume every event whose
        # onset is at or before t, even when several fall inside one step
        # (assumes eve is sorted by sample index)
        while y_i < eve.shape[0] and t >= eve[y_i][0]:
            last_eve = eve[y_i][2]
            y_i += 1
        evelist.append(last_eve)

        elapsed = tm.sec()
        if elapsed >= 1 and t > t_last:
            perc = (t - t_start) / t_len
            fps = (t - t_last) / wstep / elapsed
            est = (t_end - t) / wstep / fps
            logger.info('[PID %d] %.1f%% (%.1f FPS, %ds left)' %
                        (os.getpid(), perc * 100.0, fps, est))
            t_last = t
            tm.reset()
    logger.info('Finished.')

    # export data
    try:
        chnames = [raw.ch_names[ch] for ch in chlist]
        psd_all = np.array(psd_all)
        [basedir, fname, fext] = qc.parse_path_list(rawfile)
        fout_header = '%s/psd-%s-header.pkl' % (basedir, fname)
        fout_psd = '%s/psd-%s-data.npy' % (basedir, fname)
        header = {
            'psdfile': fout_psd,
            'times': np.array(times),
            'sfreq': sfreq,
            'channels': chnames,
            'wframes': wframes,
            'events': evelist
        }
        qc.save_obj(fout_header, header)
        np.save(fout_psd, psd_all)
        logger.info('Exported to:\n(header) %s\n(numpy array) %s' %
                    (fout_header, fout_psd))
    except Exception:
        # keep the computed features reachable for manual recovery
        logger.exception('(%s) Unexpected error occurred while exporting data. Dropping you into a shell for recovery.' %\
            os.path.basename(__file__))
        embed()
コード例 #8
0
def stream_player(server_name,
                  fif_file,
                  chunk_size,
                  auto_restart=True,
                  wait_start=True,
                  repeat=float('inf'),
                  high_resolution=False,
                  trigger_file=None):
    """
    Replay a fif file as an LSL stream.

    Input
    =====
    server_name: LSL server name.
    fif_file: fif file to replay.
    chunk_size: number of samples to send at once (usually 16-32 is good enough).
    auto_restart: play from beginning again after reaching the end. If False,
                  wait for the user to press Enter before each replay.
    wait_start: wait for user to start in the beginning.
    repeat: number of times the whole file is played (float('inf') = forever).
    high_resolution: use perf_counter() instead of sleep() for higher time resolution
                     but uses much more cpu due to polling.
    trigger_file: used to convert event numbers into event strings for readability.

    Note: Run neurodecode.set_log_level('DEBUG') to print out the relative time stamps since started.
    """
    raw, events = pu.load_raw(fif_file)
    if raw is None:
        raise RuntimeError('Error while loading %s' % fif_file)
    sfreq = raw.info['sfreq']  # sampling frequency
    n_channels = len(raw.ch_names)  # number of channels
    tdef = trigger_def(trigger_file) if trigger_file is not None else None
    try:
        event_ch = raw.ch_names.index('TRIGGER')
    except ValueError:
        event_ch = None
    logger.info_green('Successfully loaded %s' % fif_file)
    logger.info('Server name: %s' % server_name)
    logger.info('Sampling frequency %.3f Hz' % sfreq)
    logger.info('Number of channels : %d' % n_channels)
    logger.info('Chunk size : %d' % chunk_size)
    for i, ch in enumerate(raw.ch_names):
        logger.info('%d %s' % (i, ch))
    logger.info('Trigger channel : %s' % event_ch)

    # set server information
    sinfo = pylsl.StreamInfo(server_name, channel_count=n_channels, channel_format='float32',\
        nominal_srate=sfreq, type='EEG', source_id=server_name)
    desc = sinfo.desc()
    channel_desc = desc.append_child("channels")
    for ch in raw.ch_names:
        channel_desc.append_child('channel').append_child_value('label', str(ch))\
            .append_child_value('type','EEG').append_child_value('unit','microvolts')
    desc.append_child('amplifier').append_child('settings').append_child_value(
        'is_slave', 'false')
    desc.append_child('acquisition').append_child_value(
        'manufacturer',
        'NeuroDecode').append_child_value('serial_number', 'N/A')
    outlet = pylsl.StreamOutlet(sinfo, chunk_size=chunk_size)

    if wait_start:
        input('Press Enter to start streaming.')
    logger.info('Streaming started')

    # the same clock is used for the reference time and for waiting
    clock = time.perf_counter if high_resolution else time.time
    n_samples = raw._data.shape[1]
    t_chunk = chunk_size / sfreq
    idx_chunk = 0
    n_played = 0
    t_start = clock()

    # start streaming
    while n_played < repeat:
        idx_current = idx_chunk * chunk_size
        chunk = raw._data[:, idx_current:idx_current + chunk_size]
        data = chunk.transpose().tolist()
        finished = idx_current >= n_samples - chunk_size
        t_target = t_start + idx_chunk * t_chunk
        if high_resolution:
            # busy-wait: resolution beyond 2 kHz at the cost of CPU usage
            while clock() < t_target:
                pass
        else:
            # time.sleep() can have 500 us resolution using the tweak tool provided.
            t_wait = t_target - clock()
            if t_wait > 0.001:
                time.sleep(t_wait)
        outlet.push_chunk(data)
        logger.debug('[%8.3fs] sent %d samples (LSL %8.3f)' %
                     (time.perf_counter(), len(data), pylsl.local_clock()))
        if event_ch is not None:
            event_values = set(chunk[event_ch]) - {0}
            if event_values:
                if tdef is None:
                    logger.info('Events: %s' % event_values)
                else:
                    for event in event_values:
                        if event in tdef.by_value:
                            logger.info('Events: %s (%s)' %
                                        (event, tdef.by_value[event]))
                        else:
                            logger.info('Events: %s (Undefined event)' % event)
        idx_chunk += 1

        if finished:
            n_played += 1
            if n_played >= repeat:
                # last requested loop done: do not prompt for another restart
                break
            if auto_restart:
                logger.info('Reached the end of data. Restarting.')
            else:
                input(
                    'Reached the end of data. Press Enter to restart or Ctrl+C to stop.'
                )
            idx_chunk = 0
            t_start = clock()
    logger.info('Streaming finished')
コード例 #9
0
def _plot_tfr_channels(tfr_obj, picks, ch_names, cfg, make_title, make_fout,
                       vmin, vmax):
    """Save one TFR image per picked channel (get_tfr helper).

    make_title / make_fout: callables mapping a channel name to the plot
    title / output image path. They are called immediately, once per channel.
    """
    for ch, pick in enumerate(picks):
        chname = ch_names[pick]
        # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
        fig = tfr_obj.plot([ch],
                           baseline=cfg.BS_TIMES,
                           mode=cfg.BS_MODE,
                           show=False,
                           colorbar=True,
                           title=make_title(chname),
                           vmin=vmin,
                           vmax=vmax,
                           dB=False)
        fout = make_fout(chname)
        fig.savefig(fout)
        plt.close()
        logger.info('Exported to %s' % fout)


def get_tfr(cfg, recursive=False, n_jobs=1):
    '''
    Compute peri-event time-frequency representations (TFR) and export them
    as MATLAB files and/or images.

    @params:
    cfg: configuration, validated with check_config(). Notable attributes:
        TFR_TYPE: 'multitaper', 'morlet' or 'butter' ('butter' requires
                  POWER_AVERAGED=True and is exported to MATLAB only)
        EXPORT_PATH: path to save outputs (required with DATA_PATHS)
        DATA_PATHS or DATA_FILE: directories to merge, or a single raw file
        EVENT_FILE: optional event file replacing the events found in the data
        N_JOBS: number of cores (None = all cores); takes precedence over n_jobs
        POWER_DIFF: if truthy, also export log|difference| of the first two
                    event types
    recursive: if True, load raw files in sub-dirs recursively
    n_jobs: number of cores to run in parallel; used only if cfg has no N_JOBS
    '''
    cfg = check_config(cfg)
    tfr_type = cfg.TFR_TYPE
    export_path = cfg.EXPORT_PATH
    t_buffer = cfg.T_BUFFER
    butter_order = 4  # TODO: parameterize
    if tfr_type == 'multitaper':
        tfr = mne.time_frequency.tfr_multitaper
    elif tfr_type == 'morlet':
        tfr = mne.time_frequency.tfr_morlet
    elif tfr_type == 'butter':
        if not cfg.POWER_AVERAGED:
            raise NotImplementedError(
                'TFR type "butter" is only supported with POWER_AVERAGED=True')
        tfr = None  # band-pass filter + Hilbert envelope, computed below
    elif tfr_type == 'fir':
        raise NotImplementedError
    else:
        raise ValueError('Wrong TFR type %s' % tfr_type)
    n_jobs = getattr(cfg, 'N_JOBS', n_jobs)
    if n_jobs is None:
        n_jobs = mp.cpu_count()

    if hasattr(cfg, 'DATA_PATHS'):
        if export_path is None:
            raise ValueError(
                'For multiple directories, cfg.EXPORT_PATH cannot be None')
        outpath = export_path
        file_prefix = 'grandavg'

        # load and merge files from all directories
        flist = []
        for ddir in cfg.DATA_PATHS:
            ddir = ddir.replace('\\', '/')
            if ddir[-1] != '/':
                ddir += '/'
            for f in qc.get_file_list(ddir, fullpath=True,
                                      recursive=recursive):
                if qc.parse_path(f).ext in ['fif', 'bdf', 'gdf']:
                    flist.append(f)
        raw, events = pu.load_multi(flist)
    else:
        logger.info('Loading %s' % cfg.DATA_FILE)
        raw, events = pu.load_raw(cfg.DATA_FILE)
        if export_path is None:
            [outpath, file_prefix, _] = qc.parse_path_list(cfg.DATA_FILE)
        else:
            outpath = export_path
            file_prefix = qc.parse_path(cfg.DATA_FILE).name

    # custom events: applied after loading so the loader's events do not
    # overwrite them
    if getattr(cfg, 'EVENT_FILE', None) is not None:
        events = mne.read_events(cfg.EVENT_FILE)

    # re-referencing
    if cfg.REREFERENCE is not None:
        pu.rereference(raw, cfg.REREFERENCE[1], cfg.REREFERENCE[0])
        assert cfg.REREFERENCE[0] in raw.ch_names

    sfreq = raw.info['sfreq']

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    # valid indices are 0 .. n_channels - 1
    if max(picks) >= len(raw.info['ch_names']):
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    raw = pu.preprocess(raw,
                        spatial=cfg.SP_FILTER,
                        spatial_ch=spchannels,
                        spectral=cfg.TP_FILTER,
                        spectral_ch=picks,
                        notch=cfg.NOTCH_FILTER,
                        notch_ch=picks,
                        multiplier=cfg.MULTIPLIER,
                        n_jobs=n_jobs)

    # Read epochs
    classes = {}
    event_values = set(events[:, -1])
    for t in cfg.TRIGGERS:
        if t in event_values:
            if hasattr(cfg, 'tdef'):
                classes[cfg.tdef.by_value[t]] = t
            else:
                classes[str(t)] = t
    if len(classes) == 0:
        raise ValueError('No desired event was found from the data.')

    # epoch window, validated before epoching so errors are reported as-is
    tmin = cfg.EPOCH[0]
    tmin_buffer = tmin - t_buffer
    raw_tmax = raw._data.shape[1] / sfreq - 0.1
    if cfg.EPOCH[1] is None:
        if cfg.POWER_AVERAGED:
            raise ValueError(
                'EPOCH value cannot have None for grand averaged TFR')
        if len(cfg.TRIGGERS) > 1:
            raise ValueError(
                'If the end time of EPOCH is None, only a single event can be defined.'
            )
        t_ref = events[np.where(
            events[:, 2] == list(cfg.TRIGGERS)[0])[0][0], 0] / sfreq
        tmax = raw_tmax - t_ref - t_buffer
    else:
        tmax = cfg.EPOCH[1]
    tmax_buffer = tmax + t_buffer
    if tmax_buffer > raw_tmax:
        raise ValueError(
            'Epoch length with buffer (%.3f) is larger than signal length (%.3f)'
            % (tmax_buffer, raw_tmax))
    epoch_info = 'tmin = %.1f, tmax = %.1f, tmin_buffer = %.1f, tmax_buffer = %.1f, raw length = %.1f' % \
        (tmin, tmax, tmin_buffer, tmax_buffer, raw._data.shape[1] / sfreq)
    try:
        epochs_all = mne.Epochs(raw,
                                events,
                                classes,
                                tmin=tmin_buffer,
                                tmax=tmax_buffer,
                                proj=False,
                                picks=picks,
                                baseline=None,
                                preload=True)
    except Exception:
        logger.critical(
            '\n*** (tfr_export) Unknown error occurred while epoching ***')
        logger.critical(epoch_info)
        raise
    if epochs_all.drop_log_stats() > 0:
        logger.error('\n** Bad epochs found. Dropping into a Python shell.')
        logger.error(epochs_all.drop_log)
        logger.error(epoch_info)
        logger.error('\nType exit to continue.\n')
        pdb.set_trace()

    power = {}
    export_dir = outpath
    qc.make_dirs(export_dir)
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    for evname in classes:
        logger.info('>> Processing %s' % evname)
        if cfg.POWER_AVERAGED:
            # grand-average TFR
            epochs = epochs_all[evname][:]
            if len(epochs) == 0:
                logger.warning('No %s epochs. Skipping.' % evname)
                continue

            if tfr_type == 'butter':
                b, a = butter_bandpass(cfg.FREQ_RANGE[0],
                                       cfg.FREQ_RANGE[-1],
                                       sfreq,
                                       order=butter_order)
                # filter the (epochs, channels, times) data along time
                tfr_filtered = lfilter(b, a, epochs.get_data(), axis=2)
                tfr_power = np.abs(hilbert(tfr_filtered))
                tfr_data = np.mean(tfr_power, axis=0)
            else:
                power[evname] = tfr(epochs,
                                    freqs=freqs,
                                    n_cycles=n_cycles,
                                    use_fft=False,
                                    return_itc=False,
                                    decim=1,
                                    n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                tfr_data = power[evname].data

            if cfg.EXPORT_MATLAB is True:
                # export all channels to MATLAB
                mout = '%s/%s-%s-%s.mat' % (export_dir, file_prefix,
                                            cfg.SP_FILTER, evname)
                scipy.io.savemat(
                    mout, {
                        'tfr': tfr_data,
                        'chs': epochs.ch_names,
                        'events': events,
                        'sfreq': sfreq,
                        'tmin': tmin,
                        'tmax': tmax,
                        'epochs': cfg.EPOCH,
                        'freqs': cfg.FREQ_RANGE
                    })
                logger.info('Exported %s' % mout)
            if cfg.EXPORT_PNG is True:
                if evname in power:
                    _plot_tfr_channels(
                        power[evname], picks, raw.ch_names, cfg,
                        lambda chname: 'Peri-event %s - Channel %s' % (evname, chname),
                        lambda chname: '%s/%s-%s-%s-%s.png' % (
                            export_dir, file_prefix, cfg.SP_FILTER, evname, chname),
                        cfg.VMIN, cfg.VMAX)
                else:
                    logger.warning('PNG export is not supported for TFR type %s.' % tfr_type)
        else:
            # TFR per event
            epochs_ev = epochs_all[evname]
            for ep in range(len(epochs_ev)):
                epochs = epochs_ev[ep]
                power[evname] = tfr(epochs,
                                    freqs=freqs,
                                    n_cycles=n_cycles,
                                    use_fft=False,
                                    return_itc=False,
                                    decim=1,
                                    n_jobs=n_jobs)
                power[evname] = power[evname].crop(tmin=tmin, tmax=tmax)
                if cfg.EXPORT_MATLAB is True:
                    # export all channels to MATLAB
                    mout = '%s/%s-%s-%s-ep%02d.mat' % (
                        export_dir, file_prefix, cfg.SP_FILTER, evname, ep + 1)
                    scipy.io.savemat(
                        mout, {
                            'tfr': power[evname].data,
                            'chs': power[evname].ch_names,
                            'events': events,
                            'sfreq': sfreq,
                            'tmin': tmin,
                            'tmax': tmax,
                            'epochs': cfg.EPOCH,
                            'freqs': cfg.FREQ_RANGE
                        })
                    logger.info('Exported %s' % mout)
                if cfg.EXPORT_PNG is True:
                    _plot_tfr_channels(
                        power[evname], picks, raw.ch_names, cfg,
                        lambda chname: 'Peri-event %s - Channel %s, Trial %d' % (
                            evname, chname, ep + 1),
                        lambda chname: '%s/%s-%s-%s-%s-ep%02d.png' % (
                            export_dir, file_prefix, cfg.SP_FILTER, evname,
                            chname, ep + 1),
                        cfg.VMIN, cfg.VMAX)

    if getattr(cfg, 'POWER_DIFF', False):
        # dict views are not indexable in Python 3; keep classes' order
        labels = [label for label in classes if label in power]
        if len(labels) < 2:
            logger.warning('POWER_DIFF needs TFRs of at least two event types. Skipping.')
        else:
            export_dir = '%s/diff' % outpath
            qc.make_dirs(export_dir)
            df = power[labels[0]] - power[labels[1]]
            df.data = np.log(np.abs(df.data))
            # Inspect power diff for each channel (symmetric colour range)
            _plot_tfr_channels(
                df, picks, raw.ch_names, cfg,
                lambda chname: 'Peri-event %s-%s - Channel %s' % (
                    labels[0], labels[1], chname),
                lambda chname: '%s/%s-%s-diff-%s-%s-%s.jpg' % (
                    export_dir, file_prefix, cfg.SP_FILTER, labels[0],
                    labels[1], chname),
                -3.0, 3.0)
    logger.info('Finished !')
コード例 #10
0
def get_tfr(fif_file, cfg, tfr, n_jobs=1):
    """
    Compute the TFR of a whole raw file and export it as MATLAB and/or PNG.

    Input
    =====
    fif_file: raw file to process; outputs go to <dir>/plot_<name>/.
    cfg: configuration providing CHANNEL_PICKS, SP_CHANNELS, SP_FILTER,
         TP_FILTER, NOTCH_FILTER, MULTIPLIER, EVENT_START, FREQ_RANGE,
         EXPORT_MATLAB, EXPORT_PNG, BS_TIMES, BS_MODE, VMIN and VMAX.
         FREQ_RANGE must support element-wise division (e.g. a numpy array).
    tfr: TFR function called as tfr(epochs, freqs=..., n_cycles=...,
         use_fft=False, return_itc=False, decim=1, n_jobs=...).
    n_jobs: number of parallel jobs.
    """
    raw, events = pu.load_raw(fif_file)
    p = qc.parse_path(fif_file)
    fname = p.name
    outpath = p.dir

    export_dir = '%s/plot_%s' % (outpath, fname)
    qc.make_dirs(export_dir)

    # set channels of interest
    picks = pu.channel_names_to_index(raw, cfg.CHANNEL_PICKS)
    spchannels = pu.channel_names_to_index(raw, cfg.SP_CHANNELS)

    # valid indices are 0 .. n_channels - 1
    if max(picks) >= len(raw.info['ch_names']):
        msg = 'ERROR: "picks" has a channel index %d while there are only %d channels.' %\
              (max(picks), len(raw.info['ch_names']))
        raise RuntimeError(msg)

    # Apply filters
    raw = pu.preprocess(raw,
                        spatial=cfg.SP_FILTER,
                        spatial_ch=spchannels,
                        spectral=cfg.TP_FILTER,
                        spectral_ch=picks,
                        notch=cfg.NOTCH_FILTER,
                        notch_ch=picks,
                        multiplier=cfg.MULTIPLIER,
                        n_jobs=n_jobs)

    # MNE TFR functions do not support Raw instances yet, so convert to Epoch
    if cfg.EVENT_START is None:
        # one synthetic event at sample 0 makes a single epoch covering the
        # whole recording; the signal data itself is left untouched
        events = np.array([[0, 0, 1]])
        classes = None
    else:
        classes = {'START': cfg.EVENT_START}
    tmax = (raw._data.shape[1] - 1) / raw.info['sfreq']
    epochs_all = mne.Epochs(raw,
                            events,
                            classes,
                            tmin=0,
                            tmax=tmax,
                            picks=picks,
                            baseline=None,
                            preload=True)
    logger.info('\n>> Processing %s' % fif_file)
    freqs = cfg.FREQ_RANGE  # define frequencies of interest
    n_cycles = freqs / 2.  # different number of cycle per frequency
    power = tfr(epochs_all,
                freqs=freqs,
                n_cycles=n_cycles,
                use_fft=False,
                return_itc=False,
                decim=1,
                n_jobs=n_jobs)

    if cfg.EXPORT_MATLAB is True:
        # export all channels to MATLAB
        mout = '%s/%s-%s.mat' % (export_dir, fname, cfg.SP_FILTER)
        scipy.io.savemat(
            mout, {
                'tfr': power.data,
                'chs': power.ch_names,
                'events': events,
                'sfreq': raw.info['sfreq'],
                'freqs': cfg.FREQ_RANGE
            })
        logger.info('Exported %s' % mout)

    if cfg.EXPORT_PNG is True:
        # matplotlib is already required by power.plot(); needed to close figures
        import matplotlib.pyplot as plt
        # Plot power of each channel
        for ch, pick in enumerate(picks):
            ch_name = raw.ch_names[pick]
            title = 'Channel %s' % (ch_name)
            # mode= None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
            fig = power.plot([ch],
                             baseline=cfg.BS_TIMES,
                             mode=cfg.BS_MODE,
                             show=False,
                             colorbar=True,
                             title=title,
                             vmin=cfg.VMIN,
                             vmax=cfg.VMAX,
                             dB=False)
            fout = '%s/%s-%s-%s.png' % (export_dir, fname, cfg.SP_FILTER,
                                        ch_name)
            fig.savefig(fout)
            # release the figure; otherwise one stays open per channel
            plt.close()
            logger.info('Exported %s' % fout)

    logger.info('Finished !')