def get_leadfield(subject, session, head_model="three_layer"):
    """
    Compute leadfield with presets for this subject

    Parameters
    ----------
    subject : int
        Subject number; formatted as ``"S%02i"`` to build the subject ID.
    session
        Session identifier, forwarded to the metadata and trans lookups.
    head_model : str, 'three_layer' or 'single_layer'
        Selects the conductivity preset handed to ``pymegsr.get_leadfield``.

    Returns
    -------
    The return value of ``pymegsr.get_leadfield`` (passed through as is).

    Raises
    ------
    ValueError
        If ``head_model`` is not one of the two supported names.
    """
    # NOTE(review): assumes pymegsr.get_leadfield forwards `conductivity` to
    # the BEM model builder, where a 1-tuple selects a single-layer model
    # (MNE convention) -- confirm against pymeg.
    conductivities = {
        "three_layer": (0.3, 0.006, 0.3),
        "single_layer": (0.3,),
    }
    if head_model not in conductivities:
        raise ValueError(
            "head_model must be 'three_layer' or 'single_layer', got %r"
            % (head_model,))

    raw_filename = metadata.get_raw_filename(subject, session)
    # The stimulus-locked epoch file with index 0 is used as the reference.
    epoch_filename = metadata.get_epoch_filename(subject, session, 0,
                                                 "stimulus", "fif")
    trans = get_trans(subject, session)

    return pymegsr.get_leadfield(
        "S%02i" % subject,
        raw_filename,
        epoch_filename,
        trans,
        conductivity=conductivities[head_model],
        bem_sub_path="bem_ft",
        njobs=4,
    )
def extract(subject,
            session,
            recording,
            epoch,
            signal_type='BB',
            BEM='three_layer',
            debug=False,
            chunks=100,
            njobs=4):
    """
    Run LCMV source reconstruction for one recording, chunk by chunk.

    Epochs are processed in slices of ``chunks`` trials. Each slice is
    reconstructed and written to its own file (named by ``lcmvfilename``)
    via ``to_hdf`` under ``'epochs'``. Slices whose output file already
    exists are skipped, so an interrupted run can be resumed.

    Parameters
    ----------
    subject, session, recording
        Identify the recording; forwarded to the epoch loaders, the file
        name templates and ``lcmvfilename``.
    epoch : str
        ``'stimulus'`` loads epochs with ``get_stim_epoch``; any other value
        loads them with ``get_response_epoch``.
    signal_type : str
        ``'BB'`` reconstructs the broadband signal. Any other value is used
        as a key into the TFR presets defined below (``'HF'`` or ``'LF'``).
    BEM, debug
        Currently unused; kept for interface compatibility.
    chunks : int
        Number of epochs per output file.
    njobs : int
        Number of jobs for the TFR estimation and reconstruction; also the
        thread count that is set again when the function exits.

    Raises
    ------
    TypeError
        As long as the ``'TODO'`` file name templates below are not replaced
        by real patterns (see the FIXME).
    KeyError
        If ``signal_type`` is neither ``'BB'`` nor a key of the TFR presets.
    """
    mne.set_log_level('WARNING')
    lcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)

    try:
        logging.info('Reading stimulus data')

        if epoch == 'stimulus':
            data_cov, epochs, epochs_filename = get_stim_epoch(
                subject, session, recording)
        else:
            data_cov, epochs, epochs_filename = get_response_epoch(
                subject, session, recording)

        # FIXME: 'TODO' is an unfilled placeholder for the raw / trans file
        # name patterns. It contains no format specifiers, so the `%`
        # below raises TypeError until real glob patterns are supplied.
        # NOTE(review): raw_filename stays a list while trans_filename takes
        # the first match -- confirm which form sr.get_leadfield expects.
        raw_filename = glob('TODO' % (subject, session, recording))

        trans_filename = glob('TODO' % (subject, session, recording))[0]
        logging.info('Setting up source space and forward model')

        forward, bem, source = sr.get_leadfield(subject,
                                                raw_filename,
                                                epochs_filename,
                                                trans_filename,
                                                bem_sub_path='bem_ft')
        labels = sr.get_labels(subject)
        labels = sr.labels_exclude(labels,
                                   exclude_filters=[
                                       'wang2015atlas.IPS4',
                                       'wang2015atlas.IPS5',
                                       'wang2015atlas.SPL',
                                       'JWDG_lat_Unknown'
                                   ])
        labels = sr.labels_remove_overlap(
            labels,
            priority_filters=['wang', 'JWDG'],
        )

        # TFR presets: 'HF' covers 36..160 in steps of 4, 'LF' covers 2..35
        # in steps of 1 (values of `foi`).
        fois_h = np.arange(36, 162, 4)
        fois_l = np.arange(2, 36, 1)
        tfr_params = {
            'HF': {
                'foi': fois_h,
                'cycles': fois_h * 0.25,
                'time_bandwidth': 2 + 1,
                'n_jobs': njobs,
                'est_val': fois_h,
                'est_key': 'HF',
                'sf': 600,
                'decim': 10
            },
            'LF': {
                'foi': fois_l,
                'cycles': fois_l * 0.4,
                'time_bandwidth': 1 + 1,
                'n_jobs': njobs,
                'est_val': fois_l,
                'est_key': 'LF',
                'sf': 600,
                'decim': 10
            }
        }

        # Third column of the MNE events array (the event code per epoch).
        events = epochs.events[:, 2]
        filters = lcmv.setup_filters(epochs.info, forward, data_cov, None,
                                     labels)
        set_n_threads(1)

        for i in range(0, len(events), chunks):
            filename = lcmvfilename(subject,
                                    session,
                                    signal_type,
                                    recording,
                                    chunk=i)
            # Already reconstructed in an earlier run: skip.
            if os.path.isfile(filename):
                continue
            if signal_type == 'BB':
                logging.info('Starting reconstruction of BB signal')
                M = lcmv.reconstruct_broadband(filters,
                                               epochs.info,
                                               epochs._data[i:i + chunks],
                                               events[i:i + chunks],
                                               epochs.times,
                                               njobs=1)
            else:
                logging.info('Starting reconstruction of TFR signal')
                # Honour the caller's njobs instead of a hard-coded 4.
                M = lcmv.reconstruct_tfr(filters,
                                         epochs.info,
                                         epochs._data[i:i + chunks],
                                         events[i:i + chunks],
                                         epochs.times,
                                         est_args=tfr_params[signal_type],
                                         njobs=njobs)
            M.to_hdf(filename, 'epochs')
    finally:
        # Set the thread count back to njobs even when reconstruction fails.
        set_n_threads(njobs)
# Example no. 3
# 0
def extract_filter(subject,
                   session,
                   recording,
                   epoch,
                   signal_type='BB',
                   BEM='three_layer',
                   debug=False,
                   chunks=50,
                   njobs=4):
    """
    Compute the LCMV filters for one recording and pickle them to disk.

    Parameters
    ----------
    subject : str
        Subject label; everything after the first character is parsed as
        the integer subject number (e.g. ``"S5"`` -> 5).
    session, recording : int
        Session and recording numbers. They select the raw file from the
        per-subject file table and are formatted into the trans file name
        and the output file name.
    epoch : str
        ``'stimulus'`` loads epochs with ``get_stim_epoch``; any other value
        loads them with ``get_response_epoch``.
    signal_type, BEM, debug, chunks, njobs
        Currently unused; kept for interface compatibility.

    Returns
    -------
    The filters returned by ``lcmv.setup_filters``. The same object is also
    pickled to ``filter_sub<N>_SESS<S>_recording<R>_epoch<E>.pickle`` inside
    ``path``.

    Raises
    ------
    IndexError
        If no row of the file table matches subject, session and recording.
    """
    mne.set_log_level('WARNING')
    lcmv.logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)

    subject_int = int(subject[1:])

    # Subjects in this list use their own anatomy label; everybody else
    # falls back to the 'fsaverage' label.
    MRI_subjects = ["S1", "S5", "S6", "S8", "S11", "S12", "S16", "S17"]
    subject2 = subject if subject in MRI_subjects else "fsaverage"

    logging.info('Reading stimulus data')

    if epoch == 'stimulus':
        data_cov, epochs, epochs_filename = get_stim_epoch(
            subject, session, recording)
    else:
        data_cov, epochs, epochs_filename = get_response_epoch(
            subject, session, recording)

    # Per-subject table of preprocessed files.
    # NOTE: pickle.load can execute arbitrary code; only read trusted files.
    fname = '/storage/genis/preprocessed_megdata/filenames_sub%i.pickle' % (
        subject_int)
    with open(fname, 'rb') as f:
        data = pickle.load(f)
    df = pd.DataFrame.from_dict(data)

    # One combined mask instead of chained boolean indexing, which makes
    # pandas re-align each mask against the already filtered frame.
    # NOTE(review): `recording` is matched against the `trans_matrix`
    # column -- confirm this is the intended column.
    selected = ((df.subject == subject_int)
                & (df.session == session)
                & (df.trans_matrix == recording))
    raw_filename = df[selected].filename.iloc[0]

    trans_filename = '/home/genis/pymeg/S%i-sess%i-%i-trans.fif' % (
        subject_int, session, recording)

    logging.info('data_cov_done')

    logging.info('Setting up source space and forward model')
    epo_filename = get_filename_trans_matrix(subject, session, recording)
    forward, bem, source = sr.get_leadfield(subject2,
                                            raw_filename,
                                            epo_filename,
                                            trans_filename,
                                            bem_sub_path='bem_ft')
    labels = sr.get_labels(subject2)
    labels = sr.labels_exclude(labels,
                               exclude_filters=[
                                   'wang2015atlas.IPS4', 'wang2015atlas.IPS5',
                                   'wang2015atlas.SPL', 'JWDG_lat_Unknown'
                               ])
    labels = sr.labels_remove_overlap(
        labels,
        priority_filters=['wang', 'JWDG'],
    )

    filters = lcmv.setup_filters(epochs.info, forward, data_cov, None, labels)

    fname = 'filter_sub%i_SESS%i_recording%i_epoch%s.pickle' % (
        subject_int, session, recording, epoch)
    filename = join(path, fname)
    # The context manager closes the file even if pickling fails.
    with open(filename, 'wb') as f:
        pickle.dump(filters, f)
    logging.info('filter_done')

    return filters
# Example no. 4
# 0
def do_source_recon(subj, session, njobs=4):
    """
    Run LCMV time-frequency source reconstruction for one subject/session.

    Stimulus- and response-locked epochs are restricted to their shared
    event codes and baseline-corrected with the stimulus-locked mean over
    -0.3..-0.2 (in units of ``epochs.times``). They are then reconstructed
    for the 'LF' and 'HF' presets in slices of 100 epochs. Every slice is
    written to
    ``<data_folder>/source_level/lcmv_<subj>_<session>_<lock>_<type>_<i>-source.hdf``.
    Existing output files are overwritten.

    Parameters
    ----------
    subj : str
        Subject identifier, used in paths and as the label subject.
    session : str
        Session identifier, used as a path component.
    njobs : int
        Number of jobs passed to the leadfield, filter and TFR steps.

    Returns
    -------
    None. The function does nothing if the stimulus-locked epochs file does
    not exist.

    Raises
    ------
    IndexError
        If no ``*.ds`` run is found in the raw MEG folder.
    """
    epochs_filename_stim = os.path.join(data_folder, "epochs", subj, session,
                                        '{}-epo.fif.gz'.format('stimlock'))
    epochs_filename_resp = os.path.join(data_folder, "epochs", subj, session,
                                        '{}-epo.fif.gz'.format('resplock'))
    # The transformation matrix is expected to exist already; the
    # sr.make_trans call that would create it is disabled.
    trans_filename = os.path.join(data_folder, "transformation_matrix",
                                  '{}_{}-trans.fif'.format(subj, session))

    # Nothing to do without stimulus-locked epochs.
    if not os.path.isfile(epochs_filename_stim):
        return

    # Use the middle run (sorted by name) as the raw reference file.
    meg_folder = os.path.join(data_folder, "raw", subj, session, "meg")
    runs = sorted(
        os.path.basename(run)
        for run in glob.glob(os.path.join(meg_folder, "*.ds")))
    center = len(runs) // 2
    raw_filename = os.path.join(meg_folder, runs[center])

    # load labels:
    labels = sr.get_labels(subject=subj,
                           filters=['*wang*.label', '*JWG*.label'],
                           annotations=['HCPMMP1'])
    labels = sr.labels_exclude(labels=labels,
                               exclude_filters=[
                                   'wang2015atlas.IPS4',
                                   'wang2015atlas.IPS5',
                                   'wang2015atlas.SPL', 'JWG_lat_Unknown'
                               ])
    labels = sr.labels_remove_overlap(labels=labels,
                                      priority_filters=['wang', 'JWG'])
    print(labels)

    # load epochs, keeping only channels whose name starts with 'M':
    epochs_stim = mne.read_epochs(epochs_filename_stim)
    epochs_stim = epochs_stim.pick_channels(
        [x for x in epochs_stim.ch_names if x.startswith('M')])
    epochs_resp = mne.read_epochs(epochs_filename_resp)
    epochs_resp = epochs_resp.pick_channels(
        [x for x in epochs_resp.ch_names if x.startswith('M')])

    # Keep only epochs whose event code occurs in both lockings, then
    # subtract the stimulus-locked baseline mean from both data sets.
    overlap = list(
        set(epochs_stim.events[:, 2]).intersection(
            set(epochs_resp.events[:, 2])))
    shared_ids = [str(code) for code in overlap]
    epochs_stim = epochs_stim[shared_ids]
    epochs_resp = epochs_resp[shared_ids]
    id_time = (-0.3 <= epochs_stim.times) & (epochs_stim.times <= -0.2)
    means = epochs_stim._data[:, :, id_time].mean(-1)
    epochs_stim._data = epochs_stim._data - means[:, :, np.newaxis]
    epochs_resp._data = epochs_resp._data - means[:, :, np.newaxis]

    # TFR settings:
    fois_h = np.arange(42, 162, 4)
    fois_l = np.arange(2, 42, 2)
    tfr_params = {
        'HF': {
            'foi': fois_h,
            'cycles': fois_h * 0.4,
            'time_bandwidth': 5 + 1,
            'n_jobs': njobs,
            'est_val': fois_h,
            'est_key': 'HF'
        },
        'LF': {
            'foi': fois_l,
            'cycles': fois_l * 0.4,
            'time_bandwidth': 1 + 1,
            'n_jobs': njobs,
            'est_val': fois_l,
            'est_key': 'LF'
        }
    }

    # Data covariance from the stimulus-locked epochs (tmin=0, tmax=1).
    data_cov = lcmv.get_cov(epochs_stim, tmin=0, tmax=1)

    # get lead field:
    forward, bem, source = sr.get_leadfield(
        subject=subj,
        raw_filename=raw_filename,
        epochs_filename=epochs_filename_stim,
        trans_filename=trans_filename,
        conductivity=(0.3, 0.006, 0.3),
        njobs=njobs)

    # do source level analysis:
    chunks = 100
    for tl, epochs in zip(['stimlock', 'resplock'],
                          [epochs_stim, epochs_resp]):
        events = epochs.events[:, 2]
        # The filter inputs do not depend on the signal type, so the
        # filters are built once per epoch set rather than once per type.
        filters = lcmv.setup_filters(epochs.info,
                                     forward,
                                     data_cov,
                                     None,
                                     labels,
                                     njobs=njobs)

        for signal_type in ['LF', 'HF']:
            print(signal_type)

            # in chunks (existing files are recomputed and overwritten):
            for i in range(0, len(events), chunks):
                filename = os.path.join(
                    data_folder, "source_level",
                    'lcmv_{}_{}_{}_{}_{}-source.hdf'.format(
                        subj, session, tl, signal_type, i))
                M = lcmv.reconstruct_tfr(filters,
                                         epochs.info,
                                         epochs._data[i:i + chunks],
                                         events[i:i + chunks],
                                         epochs.times,
                                         est_args=tfr_params[signal_type],
                                         njobs=njobs)
                M.to_hdf(filename, 'epochs')

                # Drop the chunk result before the next one is computed.
                del M
# Example no. 5
# 0
def extract(
    recording_number,
    epoch,
    signal_type="BB",
    BEM="three_layer",
    chunks=100,
    njobs=4,
    glasser_only=True,
):
    """
    Run LCMV source reconstruction for one recording, chunk by chunk.

    Epochs are processed in slices of ``chunks`` trials. Each slice is
    reconstructed and written to its own file (named by ``lcmvfilename``)
    via ``to_hdf`` under ``"epochs"``. Slices whose output file already
    exists are skipped, so an interrupted run can be resumed.

    Parameters
    ----------
    recording_number : int
        Index into ``ps.recordings``.
    epoch
        Epoch type; forwarded to ``get_epoch``, ``ps.filenames`` and
        ``lcmvfilename``.
    signal_type : str
        ``"BB"`` reconstructs the broadband signal. Any other value is used
        as a key into the TFR presets defined below (``"HF"`` or ``"LF"``).
    BEM
        Currently unused; kept for interface compatibility.
    chunks : int
        Number of epochs per output file.
    njobs : int
        Number of jobs for the TFR estimation and reconstruction; also the
        thread count that is set again when the function exits.
    glasser_only : bool
        If true, only labels matching ``*wang2015atlas*`` are loaded.
        Otherwise all labels are loaded, a fixed set is excluded and
        overlaps are removed.
        NOTE(review): the name says "glasser" but the filter selects the
        wang2015 atlas -- confirm the intent.

    Raises
    ------
    KeyError
        If ``signal_type`` is neither ``"BB"`` nor a key of the TFR presets.
    """
    recording = ps.recordings[recording_number]
    mne.set_log_level("ERROR")
    lcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)

    try:
        logging.info("Reading stimulus data")
        data_cov, epochs = get_epoch(epoch, recording)
        raw_filename = raw_path / recording.filename
        trans_filename = trans_path / (
            "SQC_S%02i-SESS%i_B%i_trans.fif" %
            (recording.subject, recording.session, recording.block[1]))
        epoch_filename = ps.filenames(recording.subject, epoch,
                                      recording.session,
                                      recording.block[1])[0]

        logging.info("Setting up source space and forward model")

        # Subject ID and subjects directory shared by the calls below.
        subject_id = "SQC_S%02i" % recording.subject
        sdir = "/home/nwilming/seqconf/fsdir/"

        forward, bem, source = sr.get_leadfield(
            subject_id,
            str(raw_filename),
            str(epoch_filename),
            str(trans_filename),
            bem_sub_path="bem",
            sdir=sdir,
        )
        if glasser_only:
            labels = sr.get_labels(
                subject_id,
                filters=["*wang2015atlas*"],
                sdir=sdir,
            )
        else:
            labels = sr.get_labels(subject_id, sdir=sdir)
            labels = sr.labels_exclude(
                labels,
                exclude_filters=[
                    "wang2015atlas.IPS4",
                    "wang2015atlas.IPS5",
                    "wang2015atlas.SPL",
                    "JWDG_lat_Unknown",
                ],
            )
            labels = sr.labels_remove_overlap(
                labels, priority_filters=["wang", "JWDG"])

        # TFR presets: "HF" covers 36..160 in steps of 4, "LF" covers 2..35
        # in steps of 1 (values of `foi`).
        fois_h = np.arange(36, 162, 4)
        fois_l = np.arange(2, 36, 1)
        tfr_params = {
            "HF": {
                "foi": fois_h,
                "cycles": fois_h * 0.25,
                "time_bandwidth": 2 + 1,
                "n_jobs": njobs,
                "est_val": fois_h,
                "est_key": "HF",
                "sf": 600,
                "decim": 10,
            },
            "LF": {
                "foi": fois_l,
                "cycles": fois_l * 0.4,
                "time_bandwidth": 1 + 1,
                "n_jobs": njobs,
                "est_val": fois_l,
                "est_key": "LF",
                "sf": 600,
                "decim": 10,
            },
        }

        # Third column of the MNE events array (the event code per epoch).
        events = epochs.events[:, 2]
        filters = lcmv.setup_filters(epochs.info, forward, data_cov, None,
                                     labels)
        set_n_threads(1)

        for i in range(0, len(events), chunks):
            filename = lcmvfilename(
                recording,
                signal_type,
                epoch,
                chunk=i,
            )
            # Already reconstructed in an earlier run: skip.
            if os.path.isfile(filename):
                continue
            if signal_type == "BB":
                logging.info("Starting reconstruction of BB signal")
                M = lcmv.reconstruct_broadband(
                    filters,
                    epochs.info,
                    epochs._data[i:i + chunks],
                    events[i:i + chunks],
                    epochs.times,
                    njobs=1,
                )
            else:
                logging.info("Starting reconstruction of TFR signal")
                # Honour the caller's njobs instead of a hard-coded 4.
                M = lcmv.reconstruct_tfr(
                    filters,
                    epochs.info,
                    epochs._data[i:i + chunks],
                    events[i:i + chunks],
                    epochs.times,
                    est_args=tfr_params[signal_type],
                    njobs=njobs,
                )
            M.to_hdf(str(filename), "epochs")
    finally:
        # Set the thread count back to njobs even when reconstruction fails.
        set_n_threads(njobs)