Example #1
0
def extract_filter(subject,
                   session,
                   recording,
                   epoch,
                   signal_type='BB',
                   BEM='three_layer',
                   debug=False,
                   chunks=50,
                   njobs=4):
    """Compute LCMV filters for one recording and pickle them to disk.

    Parameters
    ----------
    subject : str
        Subject identifier such as ``'S1'``; everything after the first
        character is parsed as the integer subject number.
    session, recording : int
        Session number and recording index (matched against the
        ``trans_matrix`` column of the filename table).
    epoch : str
        ``'stimulus'`` selects stimulus-locked epochs; any other value selects
        response-locked epochs.
    signal_type, BEM, debug, chunks, njobs
        Not used by this function (kept for signature compatibility).

    Returns
    -------
    The object returned by ``lcmv.setup_filters``. The same object is also
    pickled to ``join(path, 'filter_sub<N>_SESS<S>_recording<R>_epoch<E>.pickle')``.
    """
    mne.set_log_level('WARNING')
    lcmv.logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)
    subject_int = int(subject[1:])
    print('reading stimulus data')
    # Subjects with an individual MRI; everybody else is mapped onto the
    # FreeSurfer template brain.
    MRI_subjects = ["S1", "S5", "S6", "S8", "S11", "S12", "S16", "S17"]
    subject2 = subject if subject in MRI_subjects else "fsaverage"

    logging.info('Reading stimulus data')

    if epoch == 'stimulus':
        data_cov, epochs, epochs_filename = get_stim_epoch(
            subject, session, recording)
    else:
        data_cov, epochs, epochs_filename = get_response_epoch(
            subject, session, recording)

    # Look up the raw MEG filename for this subject/session/recording.
    fname = '/storage/genis/preprocessed_megdata/filenames_sub%i.pickle' % (
        subject_int)
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # trusted, locally produced files.
    with open(fname, 'rb') as f:
        data = pickle.load(f)
    df = pd.DataFrame.from_dict(data)
    # A single combined mask; chained boolean indexing (df[a][b][c]) makes
    # pandas reindex each mask against an already-filtered frame.
    mask = ((df.subject == subject_int) & (df.session == session) &
            (df.trans_matrix == recording))
    raw_filename = df[mask].filename.iloc[0]
    trans_filename = '/home/genis/pymeg/S%i-sess%i-%i-trans.fif' % (
        subject_int, session, recording)

    print('data_cov_done')

    logging.info('Setting up source space and forward model')
    epo_filename = get_filename_trans_matrix(subject, session, recording)
    forward, bem, source = sr.get_leadfield(subject2,
                                            raw_filename,
                                            epo_filename,
                                            trans_filename,
                                            bem_sub_path='bem_ft')
    labels = sr.get_labels(subject2)
    labels = sr.labels_exclude(labels,
                               exclude_filters=[
                                   'wang2015atlas.IPS4', 'wang2015atlas.IPS5',
                                   'wang2015atlas.SPL', 'JWDG_lat_Unknown'
                               ])
    labels = sr.labels_remove_overlap(
        labels,
        priority_filters=['wang', 'JWDG'],
    )

    filters = lcmv.setup_filters(epochs.info, forward, data_cov, None, labels)
    fname = 'filter_sub%i_SESS%i_recording%i_epoch%s.pickle' % (
        subject_int, session, recording, epoch)
    # `path` is not defined in this function; presumably a module-level
    # output directory -- verify against the rest of the file.
    filename = join(path, fname)
    with open(filename, 'wb') as f:
        pickle.dump(filters, f)
    print('filter_done')

    return filters
def extract(subject,
            session,
            recording,
            epoch,
            signal_type='BB',
            BEM='three_layer',
            debug=False,
            chunks=100,
            njobs=4,
            raw_filename=None,
            trans_filename=None):
    """Reconstruct LCMV source time courses in chunks and save them as HDF.

    Parameters
    ----------
    subject, session, recording
        Passed to the epoch loaders and to ``lcmvfilename``.
    epoch : str
        ``'stimulus'`` selects stimulus-locked epochs; anything else selects
        response-locked epochs.
    signal_type : str
        ``'BB'`` for broadband reconstruction, otherwise a key of the local
        ``tfr_params`` dict (``'HF'`` or ``'LF'``).
    BEM, debug
        Not used by this function.
    chunks : int
        Number of trials reconstructed and written per output file.
    njobs : int
        Worker count for the TFR estimation; also the thread count restored
        when the function returns.
    raw_filename, trans_filename : str
        Paths to the raw recording and the head/MRI transformation. Both are
        required: the original implementation built them from unfinished
        ``'TODO'`` glob templates, which always raised ``TypeError``.

    Chunks whose output file already exists are skipped.
    """
    if raw_filename is None or trans_filename is None:
        # Fail before the expensive epoch loading instead of after it.
        raise NotImplementedError(
            'raw_filename and trans_filename must be provided; the path '
            'templates for locating them were never filled in (TODO).')

    mne.set_log_level('WARNING')
    lcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)
    try:
        logging.info('Reading stimulus data')

        if epoch == 'stimulus':
            data_cov, epochs, epochs_filename = get_stim_epoch(
                subject, session, recording)
        else:
            data_cov, epochs, epochs_filename = get_response_epoch(
                subject, session, recording)

        logging.info('Setting up source space and forward model')

        forward, bem, source = sr.get_leadfield(subject,
                                                raw_filename,
                                                epochs_filename,
                                                trans_filename,
                                                bem_sub_path='bem_ft')
        labels = sr.get_labels(subject)
        labels = sr.labels_exclude(labels,
                                   exclude_filters=[
                                       'wang2015atlas.IPS4',
                                       'wang2015atlas.IPS5',
                                       'wang2015atlas.SPL', 'JWDG_lat_Unknown'
                                   ])
        labels = sr.labels_remove_overlap(
            labels,
            priority_filters=['wang', 'JWDG'],
        )

        fois_h = np.arange(36, 162, 4)
        fois_l = np.arange(2, 36, 1)
        tfr_params = {
            'HF': {
                'foi': fois_h,
                'cycles': fois_h * 0.25,
                'time_bandwidth': 2 + 1,
                'n_jobs': njobs,
                'est_val': fois_h,
                'est_key': 'HF',
                'sf': 600,
                'decim': 10
            },
            'LF': {
                'foi': fois_l,
                'cycles': fois_l * 0.4,
                'time_bandwidth': 1 + 1,
                'n_jobs': njobs,
                'est_val': fois_l,
                'est_key': 'LF',
                'sf': 600,
                'decim': 10
            }
        }

        events = epochs.events[:, 2]
        filters = lcmv.setup_filters(epochs.info, forward, data_cov, None,
                                     labels)
        set_n_threads(1)

        for i in range(0, len(events), chunks):
            # NOTE(review): `epoch` is not part of the filename, so stimulus-
            # and response-locked runs may map to the same file and the
            # second run would be skipped -- confirm lcmvfilename's signature.
            filename = lcmvfilename(subject,
                                    session,
                                    signal_type,
                                    recording,
                                    chunk=i)
            if os.path.isfile(filename):
                continue
            if signal_type == 'BB':
                logging.info('Starting reconstruction of BB signal')
                M = lcmv.reconstruct_broadband(filters,
                                               epochs.info,
                                               epochs._data[i:i + chunks],
                                               events[i:i + chunks],
                                               epochs.times,
                                               njobs=1)
            else:
                logging.info('Starting reconstruction of TFR signal')
                M = lcmv.reconstruct_tfr(filters,
                                         epochs.info,
                                         epochs._data[i:i + chunks],
                                         events[i:i + chunks],
                                         epochs.times,
                                         est_args=tfr_params[signal_type],
                                         njobs=njobs)
            M.to_hdf(filename, 'epochs')
    finally:
        # Restore the thread count even when reconstruction fails.
        set_n_threads(njobs)
Example #3
0
def do_source_recon(subj, session, njobs=4):
    """Run LCMV source reconstruction of LF/HF power for one subject/session.

    Reads stimulus- and response-locked epochs from
    ``<data_folder>/epochs/<subj>/<session>/``. It baseline-corrects both
    using the stimulus-locked -0.3..-0.2 s window, then writes one HDF file
    per (time-locking, band, chunk of 100 trials) to
    ``<data_folder>/source_level/``. Does nothing if the stimulus-locked
    epochs file is missing.

    Parameters
    ----------
    subj, session : str
        Path components identifying the recording.
    njobs : int
        Parallel jobs for lead field, filter and TFR computation.
    """
    epochs_filename_stim = os.path.join(data_folder, "epochs", subj, session,
                                        '{}-epo.fif.gz'.format('stimlock'))
    epochs_filename_resp = os.path.join(data_folder, "epochs", subj, session,
                                        '{}-epo.fif.gz'.format('resplock'))
    trans_filename = os.path.join(data_folder, "transformation_matrix",
                                  '{}_{}-trans.fif'.format(subj, session))

    if not os.path.isfile(epochs_filename_stim):
        return

    # Use the middle run of the session as the raw file for the lead field.
    runs = sorted([
        run.split('/')[-1] for run in glob.glob(
            os.path.join(data_folder, "raw", subj, session, "meg", "*.ds"))
    ])
    center = int(np.floor(len(runs) / 2.0))
    raw_filename = os.path.join(data_folder, "raw", subj, session, "meg",
                                runs[center])

    # Load labels.
    labels = sr.get_labels(subject=subj,
                           filters=['*wang*.label', '*JWG*.label'],
                           annotations=['HCPMMP1'])
    labels = sr.labels_exclude(labels=labels,
                               exclude_filters=[
                                   'wang2015atlas.IPS4',
                                   'wang2015atlas.IPS5',
                                   'wang2015atlas.SPL', 'JWG_lat_Unknown'
                               ])
    labels = sr.labels_remove_overlap(labels=labels,
                                      priority_filters=['wang', 'JWG'])
    print(labels)

    # Load epochs, keeping only MEG channels (names starting with 'M').
    epochs_stim = mne.read_epochs(epochs_filename_stim)
    epochs_stim = epochs_stim.pick_channels(
        [x for x in epochs_stim.ch_names if x.startswith('M')])
    epochs_resp = mne.read_epochs(epochs_filename_resp)
    epochs_resp = epochs_resp.pick_channels(
        [x for x in epochs_resp.ch_names if x.startswith('M')])

    # Keep trials present in both time-lockings and subtract the
    # stimulus-locked baseline (-0.3..-0.2 s) from both.
    overlap = list(
        set(epochs_stim.events[:, 2]).intersection(
            set(epochs_resp.events[:, 2])))
    epochs_stim = epochs_stim[[str(l) for l in overlap]]
    epochs_resp = epochs_resp[[str(l) for l in overlap]]
    id_time = (-0.3 <= epochs_stim.times) & (epochs_stim.times <= -0.2)
    means = epochs_stim._data[:, :, id_time].mean(-1)
    # NOTE(review): this assumes both epoch objects list the overlapping
    # trials in the same order -- confirm.
    epochs_stim._data = epochs_stim._data - means[:, :, np.newaxis]
    epochs_resp._data = epochs_resp._data - means[:, :, np.newaxis]

    # TFR settings.
    fois_h = np.arange(42, 162, 4)
    fois_l = np.arange(2, 42, 2)
    tfr_params = {
        'HF': {
            'foi': fois_h,
            'cycles': fois_h * 0.4,
            'time_bandwidth': 5 + 1,
            'n_jobs': njobs,
            'est_val': fois_h,
            'est_key': 'HF'
        },
        'LF': {
            'foi': fois_l,
            'cycles': fois_l * 0.4,
            'time_bandwidth': 1 + 1,
            'n_jobs': njobs,
            'est_val': fois_l,
            'est_key': 'LF'
        }
    }

    # Data covariance from the stimulus-locked 0..1 s window; no noise cov.
    data_cov = lcmv.get_cov(epochs_stim, tmin=0, tmax=1)
    noise_cov = None

    forward, bem, source = sr.get_leadfield(
        subject=subj,
        raw_filename=raw_filename,
        epochs_filename=epochs_filename_stim,
        trans_filename=trans_filename,
        conductivity=(0.3, 0.006, 0.3),
        njobs=njobs)

    # Source-level analysis, in chunks of trials to bound memory use.
    chunks = 100
    for tl, epochs in zip(['stimlock', 'resplock'],
                          [epochs_stim, epochs_resp]):
        events = epochs.events[:, 2]
        # The filters do not depend on the frequency band, so build them once
        # per time-locking rather than once per band.
        filters = lcmv.setup_filters(epochs.info,
                                     forward,
                                     data_cov,
                                     noise_cov,
                                     labels,
                                     njobs=njobs)
        for signal_type in ['LF', 'HF']:
            print(signal_type)
            for i in range(0, len(events), chunks):
                filename = os.path.join(
                    data_folder, "source_level",
                    'lcmv_{}_{}_{}_{}_{}-source.hdf'.format(
                        subj, session, tl, signal_type, i))
                M = lcmv.reconstruct_tfr(filters,
                                         epochs.info,
                                         epochs._data[i:i + chunks],
                                         events[i:i + chunks],
                                         epochs.times,
                                         est_args=tfr_params[signal_type],
                                         njobs=njobs)
                M.to_hdf(filename, 'epochs')
                # Free the chunk before computing the next one.
                del M
def decode(
    subject,
    area,
    epoch_type="stimulus",
    only_glasser=False,
    BEM="three_layer",
    debug=False,
    target="response",
):
    """Decode behavioural targets from source-level TFR data of one ROI.

    Merges all labels belonging to ``area`` into a single label. It then
    reconstructs "F" and "LF" TFR data for sessions 0-3 and drops trials with
    ``choice_rt < 0.225``. Finally it decodes "response",
    "unsigned_confidence" and "signed_confidence". Per-target results are
    cached in ``/home/nwilming/S<subject>_trg<target>_ROI_<area>.hdf``. The
    concatenation is written to ``/home/nwilming/S<subject>_ROI_<area>.hdf``.

    Parameters
    ----------
    subject : int
    area : str
        Key into ``areas_to_labels``.
    epoch_type : {"stimulus", "response"}
    only_glasser : bool
        Forwarded to ``get_labels``.
    BEM : str
        Forwarded to ``get_leadfield``.
    debug, target
        Not used; all three targets above are always decoded.

    Returns
    -------
    The decoding result for the last target decoded ("signed_confidence").

    Raises
    ------
    RuntimeError
        If no label matches ``area`` or ``epoch_type`` is unrecognised.
    """
    mne.set_log_level("WARNING")
    pymeglcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)
    labels = get_labels(subject, only_glasser)
    wanted = areas_to_labels[area]
    labels = [x for x in labels if any(cl in x.name for cl in wanted)]
    print(labels)
    if not labels:
        raise RuntimeError('Expecting at least one label')
    # Merge all matching labels into one combined label.
    label = labels.pop()
    for l in labels:
        label += l

    print('Selecting this label for area %s:' % area, label)

    logging.info("Reading stimulus data")
    if epoch_type == "stimulus":
        data = [get_stim_epoch(subject, i) for i in range(4)]
    elif epoch_type == "response":
        data = [get_response_epoch(subject, i) for i in range(4)]
    else:
        raise RuntimeError("Did not recognize epoch")

    logging.info("Setting up source space and forward model")

    fwds = [get_leadfield(subject, session, BEM)[0] for session in range(4)]

    fois = np.arange(10, 150, 5)
    lfois = np.arange(1, 10, 1)
    tfr_params = {
        "F": {
            "foi": fois,
            "cycles": fois * 0.1,
            "time_bandwidth": 2,
            "n_jobs": 1,
            "est_val": fois,
            "est_key": "F",
        },
        "LF": {
            "foi": lfois,
            "cycles": lfois * 0.25,
            "time_bandwidth": 2,
            "n_jobs": 1,
            "est_val": lfois,
            "est_key": "LF",
        },
    }

    # One set of filters per session, each built from that session's
    # covariance and forward model.
    filters = [
        pymeglcmv.setup_filters(epochs.info, forward, data_cov, None, [label])
        for (data_cov, epochs), forward in zip(data, fwds)
    ]
    set_n_threads(1)

    all_epochs = [d[1] for d in data]
    F_tfrdata, events, F_freq, times = decoding.get_lcmv(
        tfr_params["F"], all_epochs, filters, njobs=6
    )
    LF_tfrdata, events, LF_freq, times = decoding.get_lcmv(
        tfr_params["LF"], all_epochs, filters, njobs=6
    )

    tfrdata = np.hstack([F_tfrdata, LF_tfrdata])
    del LF_tfrdata, F_tfrdata
    freq = np.concatenate([F_freq, LF_freq])
    meta = augment_meta(
        preprocessing.get_meta_for_subject(
            subject, epoch_type, sessions=range(4)
        ).set_index("hash")
    )

    # Kick out trials with RT < 0.225 s.
    choice_rt = meta.choice_rt
    fast_enough = choice_rt[choice_rt >= 0.225].index.values
    valid_trials = np.isin(events, fast_enough)
    # Count before filtering so the total reflects all trials.
    n_all = len(valid_trials)
    n_out = int((~valid_trials).sum())
    tfrdata = tfrdata[valid_trials]
    events = events[valid_trials]
    print('Kicking out %i/%i (%0.2f percent) trials due to RT'
          % (n_out, n_all, 100.0 * n_out / max(n_all, 1)))

    all_s = []
    for trg in ["response", "unsigned_confidence", "signed_confidence"]:
        fname = "/home/nwilming/S%i_trg%s_ROI_%s.hdf" % (subject, trg, area)
        try:
            k = pd.read_hdf(fname)
        except FileNotFoundError:
            target_vals = meta.loc[:, trg]
            dcd = decoding.Decoder(target_vals)
            k = dcd.classify(
                tfrdata, times, freq, events, area,
                average_vertices=False, use_phase=True
            )
            k.loc[:, "target"] = trg
            k.to_hdf(fname, "df")
        all_s.append(k)
    all_s = pd.concat(all_s)
    all_s.loc[:, 'ROI'] = area
    all_s.to_hdf("/home/nwilming/S%i_ROI_%s.hdf" % (subject, area), "df")
    # NOTE(review): returns only the last target's frame, not `all_s`;
    # kept for backward compatibility.
    return k
def get_filter(
    subject,
    session,
    epoch_type="stimulus",
    only_glasser=False,
    BEM="three_layer",
):
    """Build LCMV filters for one subject/session.

    Parameters
    ----------
    subject : int
        Subject number; formatted as ``"S%02i"`` for label lookup.
    session : int
    epoch_type : {"stimulus", "response"}
        Which epochs supply the data covariance and sensor info.
    only_glasser : bool
        If True, load only the HCPMMP1 annotation labels (the file filter
        ``"select_nothing"`` is presumably meant to match no label files).
        Otherwise also load the Wang/JWDG labels, minus a fixed exclusion
        list, with overlaps resolved in favour of "wang" then "JWDG".
    BEM : str
        Forwarded to ``get_leadfield``.

    Returns
    -------
    The object returned by ``pymeglcmv.setup_filters``.

    Raises
    ------
    RuntimeError
        If ``epoch_type`` is not recognised.
    """
    mne.set_log_level("WARNING")
    pymeglcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)

    logging.info("Reading stimulus data")
    if epoch_type == "stimulus":
        data_cov, epochs = get_stim_epoch(subject, session)
    elif epoch_type == "response":
        data_cov, epochs = get_response_epoch(subject, session)
    else:
        raise RuntimeError("Did not recognize epoch")

    logging.info("Setting up source space and forward model")

    forward, bem, source = get_leadfield(subject, session, BEM)

    if not only_glasser:
        labels = pymegsr.get_labels(
            subject="S%02i" % subject,
            filters=["*wang*.label", "*JWDG*.label"],
            annotations=["HCPMMP1"],
        )
        labels = pymegsr.labels_exclude(
            labels=labels,
            exclude_filters=[
                "wang2015atlas.IPS4",
                "wang2015atlas.IPS5",
                "wang2015atlas.SPL",
                "JWDG_lat_Unknown",
            ],
        )
        labels = pymegsr.labels_remove_overlap(
            labels=labels, priority_filters=["wang", "JWDG"])
    else:
        labels = pymegsr.get_labels(
            subject="S%02i" % subject,
            filters=["select_nothing"],
            annotations=["HCPMMP1"],
        )

    return pymeglcmv.setup_filters(epochs.info, forward, data_cov, None,
                                   labels)
def extract(
    subject,
    session,
    epoch_type="stimulus",
    signal_type="BB",
    only_glasser=False,
    BEM="three_layer",
    debug=False,
    chunks=100,
    njobs=4,
):
    """Reconstruct LCMV source signals in trial chunks and write them to HDF.

    Parameters
    ----------
    subject : int
        Subject number; formatted as ``"S%02i"`` for label lookup.
    session : int
    epoch_type : {"stimulus", "response"}
    signal_type : str
        ``"BB"`` for broadband reconstruction, otherwise a key of the local
        ``tfr_params`` dict (``"F"`` or ``"LF"``).
    only_glasser : bool
        If True, load only HCPMMP1 annotation labels; otherwise also the
        Wang/JWDG labels with a fixed exclusion list and overlap resolution.
    BEM : str
        Forwarded to ``get_leadfield``.
    debug
        Not used.
    chunks : int
        Trials per output file.
    njobs : int
        Worker count for TFR reconstruction; also the thread count restored
        on exit.

    Each chunk is written to ``lcmvfilename(...)`` with ``mode="w"``, so
    existing files are overwritten.

    Raises
    ------
    RuntimeError
        If ``epoch_type`` is not recognised.
    """
    mne.set_log_level("WARNING")
    pymeglcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)
    try:
        logging.info("Reading stimulus data")
        if epoch_type == "stimulus":
            data_cov, epochs = get_stim_epoch(subject, session)
        elif epoch_type == "response":
            data_cov, epochs = get_response_epoch(subject, session)
        else:
            raise RuntimeError("Did not recognize epoch")

        logging.info("Setting up source space and forward model")

        forward, bem, source = get_leadfield(subject, session, BEM)

        if not only_glasser:
            labels = pymegsr.get_labels(
                subject="S%02i" % subject,
                filters=["*wang*.label", "*JWDG*.label"],
                annotations=["HCPMMP1"],
            )
            labels = pymegsr.labels_exclude(
                labels=labels,
                exclude_filters=[
                    "wang2015atlas.IPS4",
                    "wang2015atlas.IPS5",
                    "wang2015atlas.SPL",
                    "JWDG_lat_Unknown",
                ],
            )
            labels = pymegsr.labels_remove_overlap(
                labels=labels, priority_filters=["wang", "JWDG"])
        else:
            labels = pymegsr.get_labels(
                subject="S%02i" % subject,
                filters=["select_nothing"],
                annotations=["HCPMMP1"],
            )
        fois = np.arange(10, 150, 5)
        lfois = np.arange(1, 10, 1)
        tfr_params = {
            "F": {
                "foi": fois,
                "cycles": fois * 0.1,
                "time_bandwidth": 2,
                "n_jobs": 1,
                "est_val": fois,
                "est_key": "F",
            },
            "LF": {
                "foi": lfois,
                "cycles": lfois * 0.25,
                "time_bandwidth": 2,
                "n_jobs": 1,
                "est_val": lfois,
                "est_key": "LF",
            },
        }

        events = epochs.events[:, 2]
        filters = pymeglcmv.setup_filters(epochs.info, forward, data_cov,
                                          None, labels)

        set_n_threads(1)

        # Reconstruct in blocks of `chunks` trials to bound memory use.
        for i in range(0, len(events), chunks):
            filename = lcmvfilename(subject,
                                    session,
                                    signal_type,
                                    epoch_type,
                                    chunk=i,
                                    only_glasser=only_glasser)
            logging.info(filename)
            if signal_type == "BB":
                logging.info("Starting reconstruction of BB signal")
                M = pymeglcmv.reconstruct_broadband(
                    filters,
                    epochs.info,
                    epochs._data[i:i + chunks],
                    events[i:i + chunks],
                    epochs.times,
                    njobs=1,
                )
            else:
                logging.info("Starting reconstruction of TFR signal")
                M = pymeglcmv.reconstruct_tfr(
                    filters,
                    epochs.info,
                    epochs._data[i:i + chunks],
                    events[i:i + chunks],
                    epochs.times,
                    est_args=tfr_params[signal_type],
                    njobs=njobs,
                )
            M.to_hdf(filename, "epochs", mode="w")
    finally:
        # Restore the thread count even when reconstruction fails.
        set_n_threads(njobs)
Example #7
0
def extract(
    recording_number,
    epoch,
    signal_type="BB",
    BEM="three_layer",
    chunks=100,
    njobs=4,
    glasser_only=True,
):
    """Reconstruct LCMV source signals for one recording and save them to HDF.

    Parameters
    ----------
    recording_number
        Index into ``ps.recordings``.
    epoch
        Passed to ``get_epoch`` / ``ps.filenames`` / ``lcmvfilename``.
    signal_type : str
        ``"BB"`` for broadband reconstruction, otherwise a key of the local
        ``tfr_params`` dict (``"HF"`` or ``"LF"``).
    BEM
        Not used (the BEM sub-path is fixed to ``"bem"``).
    chunks : int
        Trials per output file; chunks whose file already exists are skipped.
    njobs : int
        Worker count for TFR estimation; also the thread count restored on
        exit.
    glasser_only : bool
        If True, load only labels matching ``"*wang2015atlas*"``. Otherwise
        load all labels, drop a fixed exclusion list and resolve overlaps in
        favour of "wang" then "JWDG".
    """
    recording = ps.recordings[recording_number]
    mne.set_log_level("ERROR")
    lcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)
    try:
        logging.info("Reading stimulus data")
        data_cov, epochs = get_epoch(epoch, recording)
        raw_filename = raw_path / recording.filename
        trans_filename = trans_path / (
            "SQC_S%02i-SESS%i_B%i_trans.fif" %
            (recording.subject, recording.session, recording.block[1]))
        epoch_filename = ps.filenames(recording.subject, epoch,
                                      recording.session,
                                      recording.block[1])[0]

        logging.info("Setting up source space and forward model")

        forward, bem, source = sr.get_leadfield(
            "SQC_S%02i" % recording.subject,
            str(raw_filename),
            str(epoch_filename),
            str(trans_filename),
            bem_sub_path="bem",
            sdir="/home/nwilming/seqconf/fsdir/",
        )
        if glasser_only:
            # NOTE(review): despite the flag name, this selects wang2015atlas
            # label files -- confirm sr.get_labels' default annotations.
            labels = sr.get_labels(
                "SQC_S%02i" % recording.subject,
                filters=["*wang2015atlas*"],
                sdir="/home/nwilming/seqconf/fsdir/",
            )
        else:
            labels = sr.get_labels("SQC_S%02i" % recording.subject,
                                   sdir="/home/nwilming/seqconf/fsdir/")
            labels = sr.labels_exclude(
                labels,
                exclude_filters=[
                    "wang2015atlas.IPS4",
                    "wang2015atlas.IPS5",
                    "wang2015atlas.SPL",
                    "JWDG_lat_Unknown",
                ],
            )
            labels = sr.labels_remove_overlap(
                labels, priority_filters=["wang", "JWDG"])

        fois_h = np.arange(36, 162, 4)
        fois_l = np.arange(2, 36, 1)
        tfr_params = {
            "HF": {
                "foi": fois_h,
                "cycles": fois_h * 0.25,
                "time_bandwidth": 2 + 1,
                "n_jobs": njobs,
                "est_val": fois_h,
                "est_key": "HF",
                "sf": 600,
                "decim": 10,
            },
            "LF": {
                "foi": fois_l,
                "cycles": fois_l * 0.4,
                "time_bandwidth": 1 + 1,
                "n_jobs": njobs,
                "est_val": fois_l,
                "est_key": "LF",
                "sf": 600,
                "decim": 10,
            },
        }

        events = epochs.events[:, 2]
        filters = lcmv.setup_filters(epochs.info, forward, data_cov, None,
                                     labels)
        set_n_threads(1)

        for i in range(0, len(events), chunks):
            filename = lcmvfilename(
                recording,
                signal_type,
                epoch,
                chunk=i,
            )
            if os.path.isfile(filename):
                continue
            if signal_type == "BB":
                logging.info("Starting reconstruction of BB signal")
                M = lcmv.reconstruct_broadband(
                    filters,
                    epochs.info,
                    epochs._data[i:i + chunks],
                    events[i:i + chunks],
                    epochs.times,
                    njobs=1,
                )
            else:
                logging.info("Starting reconstruction of TFR signal")
                M = lcmv.reconstruct_tfr(
                    filters,
                    epochs.info,
                    epochs._data[i:i + chunks],
                    events[i:i + chunks],
                    epochs.times,
                    est_args=tfr_params[signal_type],
                    njobs=njobs,
                )
            M.to_hdf(str(filename), "epochs")
    finally:
        # Restore the thread count even when reconstruction fails.
        set_n_threads(njobs)