def get_labels(subject, only_glasser):
    """Return the cortical labels used for source reconstruction.

    Parameters
    ----------
    subject : int
        Subject number; formatted as ``"S%02i"`` to build the subject name.
    only_glasser : bool
        If true, request labels with the ``"select_nothing"`` filter plus the
        HCPMMP1 annotation (presumably yielding only the HCPMMP1 / Glasser
        parcels — confirm against ``pymegsr.get_labels``). If false, load
        Wang 2015 and JWDG label files together with HCPMMP1, drop a fixed
        set of Wang/JWDG labels, and resolve overlaps giving priority to
        Wang and JWDG labels.

    Returns
    -------
    Whatever ``pymegsr.get_labels`` / ``pymegsr.labels_remove_overlap``
    return (a collection of label objects).
    """
    subject_name = "S%02i" % subject

    if only_glasser:
        return pymegsr.get_labels(
            subject=subject_name,
            filters=["select_nothing"],
            annotations=["HCPMMP1"],
        )

    excluded = [
        "wang2015atlas.IPS4",
        "wang2015atlas.IPS5",
        "wang2015atlas.SPL",
        "JWDG_lat_Unknown",
    ]
    selected = pymegsr.get_labels(
        subject=subject_name,
        filters=["*wang*.label", "*JWDG*.label"],
        annotations=["HCPMMP1"],
    )
    selected = pymegsr.labels_exclude(labels=selected, exclude_filters=excluded)
    return pymegsr.labels_remove_overlap(
        labels=selected, priority_filters=["wang", "JWDG"]
    )
Exemple #2
0
def aggregate(subject, session, datatype):
    """Aggregate one session's LCMV source data into per-label clusters.

    Every label (after excluding a fixed set of Wang/JWDG labels) becomes its
    own cluster. The aggregate is written to
    ``/home/rbrink/NRS/agg/S<subject>_SESS<session>_<datatype>_agg.hdf``.

    Parameters
    ----------
    subject : str
        Subject identifier, inserted verbatim (``%s``) into the input glob
        and the output filename.
    session : int
        Session number (``%i``).
    datatype : {'F', 'BB'}
        ``'F'``: time-frequency data, baseline (-0.4, -0.2) s, converted to
        decibels. ``'BB'``: broadband data, baseline (-0.2, 0) s, no
        decibel conversion.

    Raises
    ------
    ValueError
        If ``datatype`` is not one of the supported values. The check runs
        before any data or labels are loaded.
    """
    # (baseline window, convert to decibels) per supported data type.
    settings = {
        'F': ((-0.4, -0.2), True),   # time-frequency
        'BB': ((-0.2, 0), False),    # broadband
    }
    if datatype not in settings:
        raise ValueError('Unsupported datatype %r; expected one of %s'
                         % (datatype, sorted(settings)))
    baseline, to_decibels = settings[datatype]

    from pymeg import aggregate_sr as asr
    from os.path import join

    data = (
        '/home/pmurphy/Surprise_accumulation/Analysis/MEG/Conv2mne/%s-SESS%i-*%s*lcmv.hdf'
        % (subject, session, datatype))
    labels = pymegsr.get_labels(subject)
    labels = pymegsr.labels_exclude(labels,
                                    exclude_filters=[
                                        'wang2015atlas.IPS4',
                                        'wang2015atlas.IPS5',
                                        'wang2015atlas.SPL', 'JWDG_lat_Unknown'
                                    ])
    # One cluster per label, named after the label itself.
    clusters = {l.name: [l.name] for l in labels}
    # The same glob serves as both data and baseline source.
    agg = asr.aggregate_files(data,
                              data, baseline,
                              to_decibels=to_decibels,
                              all_clusters=clusters,
                              hemis=['Single'])

    filename = join('/home/rbrink/NRS/agg/',
                    'S%s_SESS%i_%s_agg.hdf' % (subject, session, datatype))
    asr.agg2hdf(agg, filename)
def plot_brain_color_legend(palette):
    """Paint palette-coloured ROI clusters on subject S04's left hemisphere.

    Wang 2015 and JWDG labels (plus the HCPMMP1 annotation) are loaded for
    S04, a fixed set of labels is excluded, overlaps are resolved with
    Wang/JWDG priority, and labels are grouped into clusters. Every
    left-hemisphere label of a cluster named in ``palette`` is drawn with
    that cluster's colour. A montage is saved to a hard-coded PNG path.

    Parameters
    ----------
    palette : mapping
        Cluster name -> colour accepted by ``Brain.add_label``.

    Returns
    -------
    surfer.Brain
        The rendered brain.
    """
    from surfer import Brain
    from pymeg import atlas_glasser as ag
    from pymeg import source_reconstruction as sr

    excluded = [
        'wang2015atlas.IPS4', 'wang2015atlas.IPS5',
        'wang2015atlas.SPL', 'JWDG_lat_Unknown'
    ]
    raw_labels = sr.get_labels(subject='S04',
                               filters=['*wang*.label', '*JWDG*.label'],
                               annotations=['HCPMMP1'])
    kept = sr.labels_exclude(labels=raw_labels, exclude_filters=excluded)
    disjoint = sr.labels_remove_overlap(labels=kept,
                                        priority_filters=['wang', 'JWDG'])
    clusters = ag.labels2clusters(disjoint)

    brain = Brain('S04', 'lh', 'inflated', views=['lat'], background='w')
    for name, members in clusters.items():
        if name not in palette.keys():
            continue
        colour = palette[name]
        for member in members:
            if member.hemi != 'lh':
                continue
            brain.add_label(member, color=colour, alpha=1)

    brain.save_montage(
        '/Users/nwilming/Dropbox/UKE/confidence_study/brain_colorbar.png',
        [['par', 'fro'], ['lat', 'med']])
    return brain
def plot_summary_results(
        data,
        cmap='RdBu_r',
        limits=None,
        ex_sub='S04',
        measure='auc',
        epoch='response',
        classifier='svc',
        views=None):
    """Plot one brain per signal in ``data`` and return the results.

    Parameters
    ----------
    data : pandas.DataFrame
        Must have a ``'signal'`` column; rows are grouped by it and each
        group is passed to ``plot_one_brain``.
    cmap : str
        Matplotlib colormap name.
    limits : dict, optional
        Signal name -> ``(vmin, vmax)`` colour limits. Defaults to the
        built-in table below. Every signal present in ``data`` must have an
        entry, otherwise ``KeyError`` is raised.
    ex_sub : str
        Subject whose labels are used for the cluster layout.
    measure, epoch, classifier :
        Forwarded unchanged to ``plot_one_brain``.
    views : list, optional
        Montage layout forwarded to ``plot_one_brain``; defaults to
        ``[['par', 'fro'], ['lat', 'med']]``.

    Returns
    -------
    dict
        Signal name -> object returned by ``plot_one_brain`` for that signal.
    """
    from pymeg import roi_clusters as rois, source_reconstruction as sr

    # Fresh defaults per call instead of shared mutable default arguments.
    if limits is None:
        limits = {
            'MIDC_split': (0.3, 0.7),
            'MIDC_nosplit': (0.3, 0.7),
            'CONF_signed': (0.3, 0.7),
            'CONF_unsigned': (.3, 0.7),
            'SIDE_nosplit': (0.3, 0.7),
            'CONF_unsign_split': (.3, 0.7),
            'SSD': (-0.05, 0.05),
            'SSD_acc_contrast': (-0.05, 0.05),
            'SSD_acc_contrast_diff': (-0.05, 0.05),
            'SSD_delta_contrast': (-0.05, 0.05)
        }
    if views is None:
        views = [['par', 'fro'], ['lat', 'med']]

    labels = sr.get_labels(ex_sub)
    lc = rois.labels_to_clusters(labels, rois.all_clusters, hemi='lh')
    # The colormap does not depend on the signal; look it up once.
    colortable = cm.get_cmap(cmap)

    brains = {}
    for signal, dsignal in data.groupby('signal'):
        vmin, vmax = limits[signal]
        norm = colors.Normalize(vmin=vmin, vmax=vmax)

        # Bind this iteration's ``norm`` now so the callback stays correct
        # even if ``plot_one_brain`` keeps it around.
        def cfunc(x, norm=norm):
            return colortable(norm(x))

        brains[signal] = plot_one_brain(dsignal,
                                        signal,
                                        lc,
                                        cfunc,
                                        ex_sub=ex_sub,
                                        measure=measure,
                                        classifier=classifier,
                                        epoch=epoch,
                                        views=views)
    return brains
def extract(subject,
            session,
            recording,
            epoch,
            signal_type='BB',
            BEM='three_layer',
            debug=False,
            chunks=100,
            njobs=4,
            raw_filename=None,
            trans_filename=None):
    """Reconstruct source signals with LCMV beamformers, in trial chunks.

    Chunks whose output file (``lcmvfilename(...)``) already exists are
    skipped. Each remaining chunk is written with ``M.to_hdf(filename,
    'epochs')``.

    Parameters
    ----------
    subject, session, recording :
        Identify the data; forwarded to the epoch loaders, the lead-field
        setup, ``sr.get_labels`` and ``lcmvfilename``.
    epoch : str
        ``'stimulus'`` loads stimulus-locked epochs; any other value loads
        response-locked epochs.
    signal_type : str
        ``'BB'`` for broadband reconstruction; otherwise a key of the local
        TFR parameter table (``'HF'`` or ``'LF'``).
    BEM, debug :
        Accepted for interface compatibility; not used here.
    chunks : int
        Number of trials per reconstruction / output file.
    njobs : int
        ``n_jobs`` for the TFR estimator; also the thread count restored by
        ``set_n_threads`` when the function finishes (or fails).
    raw_filename : str
        Raw recording used for the forward model. Required: no default
        path template is configured for this dataset.
    trans_filename : str
        Head/MRI transformation file. Required for the same reason.

    Raises
    ------
    ValueError
        If ``raw_filename`` or ``trans_filename`` is not supplied.
    """
    # Fail fast, before the expensive epoch loading. The previous
    # ``glob('TODO' % (...))`` placeholder could never succeed.
    if raw_filename is None or trans_filename is None:
        raise ValueError(
            'raw_filename and trans_filename must be given explicitly; '
            'no path template is configured for this dataset')

    mne.set_log_level('WARNING')
    lcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)

    logging.info('Reading stimulus data')

    if epoch == 'stimulus':
        data_cov, epochs, epochs_filename = get_stim_epoch(
            subject, session, recording)
    else:
        data_cov, epochs, epochs_filename = get_response_epoch(
            subject, session, recording)

    logging.info('Setting up source space and forward model')

    forward, bem, source = sr.get_leadfield(subject,
                                            raw_filename,
                                            epochs_filename,
                                            trans_filename,
                                            bem_sub_path='bem_ft')
    labels = sr.get_labels(subject)
    labels = sr.labels_exclude(labels,
                               exclude_filters=[
                                   'wang2015atlas.IPS4', 'wang2015atlas.IPS5',
                                   'wang2015atlas.SPL', 'JWDG_lat_Unknown'
                               ])
    labels = sr.labels_remove_overlap(
        labels,
        priority_filters=['wang', 'JWDG'],
    )

    fois_h = np.arange(36, 162, 4)
    fois_l = np.arange(2, 36, 1)
    tfr_params = {
        'HF': {
            'foi': fois_h,
            'cycles': fois_h * 0.25,
            'time_bandwidth': 2 + 1,
            'n_jobs': njobs,
            'est_val': fois_h,
            'est_key': 'HF',
            'sf': 600,
            'decim': 10
        },
        'LF': {
            'foi': fois_l,
            'cycles': fois_l * 0.4,
            'time_bandwidth': 1 + 1,
            'n_jobs': njobs,
            'est_val': fois_l,
            'est_key': 'LF',
            'sf': 600,
            'decim': 10
        }
    }

    events = epochs.events[:, 2]
    filters = lcmv.setup_filters(epochs.info, forward, data_cov, None, labels)
    set_n_threads(1)

    try:
        for i in range(0, len(events), chunks):
            filename = lcmvfilename(subject,
                                    session,
                                    signal_type,
                                    recording,
                                    chunk=i)
            if os.path.isfile(filename):
                continue
            if signal_type == 'BB':
                logging.info('Starting reconstruction of BB signal')
                M = lcmv.reconstruct_broadband(filters,
                                               epochs.info,
                                               epochs._data[i:i + chunks],
                                               events[i:i + chunks],
                                               epochs.times,
                                               njobs=1)
            else:
                logging.info('Starting reconstruction of TFR signal')
                M = lcmv.reconstruct_tfr(filters,
                                         epochs.info,
                                         epochs._data[i:i + chunks],
                                         events[i:i + chunks],
                                         epochs.times,
                                         est_args=tfr_params[signal_type],
                                         njobs=4)
            M.to_hdf(filename, 'epochs')
    finally:
        # Restore the thread count even if a chunk fails.
        set_n_threads(njobs)
def extract_filter(subject,
                   session,
                   recording,
                   epoch,
                   signal_type='BB',
                   BEM='three_layer',
                   debug=False,
                   chunks=50,
                   njobs=4):
    """Compute LCMV filters for one recording and pickle them.

    Subjects listed in ``MRI_subjects`` use their own anatomy; all others
    use ``fsaverage`` for the lead field and labels. The raw filename is
    looked up in a per-subject pickled table. The filters are written to
    ``join(path, 'filter_sub<N>_SESS<session>_recording<recording>_epoch<epoch>.pickle')``.

    Parameters
    ----------
    subject : str
        Subject code whose numeric part follows the first character
        (e.g. ``'S5'`` -> 5).
    session, recording : int
        Session number and recording (transformation-matrix) index.
    epoch : str
        ``'stimulus'`` loads stimulus-locked epochs; anything else loads
        response-locked epochs.
    signal_type, BEM, debug, chunks, njobs :
        Accepted for interface compatibility; not used here.

    Returns
    -------
    The filters returned by ``lcmv.setup_filters``.
    """
    mne.set_log_level('WARNING')
    lcmv.logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)
    subject_int = int(subject[1:])

    # Subjects with an individual MRI; everyone else uses the template brain.
    MRI_subjects = ["S1", "S5", "S6", "S8", "S11", "S12", "S16", "S17"]
    subject2 = subject if subject in MRI_subjects else "fsaverage"

    logging.info('Reading stimulus data')

    if epoch == 'stimulus':
        data_cov, epochs, epochs_filename = get_stim_epoch(
            subject, session, recording)
    else:
        data_cov, epochs, epochs_filename = get_response_epoch(
            subject, session, recording)

    fname = '/storage/genis/preprocessed_megdata/filenames_sub%i.pickle' % (
        subject_int)
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(fname, 'rb') as f:
        data = pickle.load(f)
    df = pd.DataFrame.from_dict(data)
    # One combined mask instead of chained boolean indexing, which relied on
    # pandas re-aligning each mask to the already-filtered frame.
    mask = ((df.subject == subject_int) & (df.session == session)
            & (df.trans_matrix == recording))
    raw_filename = df.loc[mask, 'filename'].iloc[0]
    trans_filename = '/home/genis/pymeg/S%i-sess%i-%i-trans.fif' % (
        subject_int, session, recording)

    logging.info('Setting up source space and forward model')
    epo_filename = get_filename_trans_matrix(subject, session, recording)
    forward, bem, source = sr.get_leadfield(subject2,
                                            raw_filename,
                                            epo_filename,
                                            trans_filename,
                                            bem_sub_path='bem_ft')
    labels = sr.get_labels(subject2)
    labels = sr.labels_exclude(labels,
                               exclude_filters=[
                                   'wang2015atlas.IPS4', 'wang2015atlas.IPS5',
                                   'wang2015atlas.SPL', 'JWDG_lat_Unknown'
                               ])
    labels = sr.labels_remove_overlap(
        labels,
        priority_filters=['wang', 'JWDG'],
    )

    filters = lcmv.setup_filters(epochs.info, forward, data_cov, None, labels)
    fname = 'filter_sub%i_SESS%i_recording%i_epoch%s.pickle' % (
        subject_int, session, recording, epoch)
    filename = join(path, fname)
    with open(filename, 'wb') as f:
        pickle.dump(filters, f)
    logging.info('filter_done')

    return filters
def do_source_recon(subj, session, njobs=4, chunks=100):
    """Run LCMV time-frequency source reconstruction for one session.

    Does nothing if the stimulus-locked epochs file does not exist.

    Processing steps:

    - Load stimulus- and response-locked epochs and keep only ``M*``
      channels.
    - Restrict both to their shared event codes.
    - Subtract the stimulus-locked -0.3..-0.2 s mean from both.
    - Compute the data covariance on 0..1 s of the stimulus epochs.
    - Build the lead field from the middle raw run.
    - For each time-lock and each band (``'LF'``, ``'HF'``), write
      ``chunks``-trial HDF files to ``data_folder/source_level``.

    Parameters
    ----------
    subj, session : str
        Used as path components under ``data_folder``.
    njobs : int
        Parallel jobs for the lead field, filters and TFR estimation.
    chunks : int
        Trials per reconstructed output file.
    """
    epochs_filename_stim = os.path.join(data_folder, "epochs", subj, session,
                                        '{}-epo.fif.gz'.format('stimlock'))
    epochs_filename_resp = os.path.join(data_folder, "epochs", subj, session,
                                        '{}-epo.fif.gz'.format('resplock'))
    trans_filename = os.path.join(data_folder, "transformation_matrix",
                                  '{}_{}-trans.fif'.format(subj, session))

    if not os.path.isfile(epochs_filename_stim):
        return

    # Use the middle run (sorted by name) as the raw file for the lead field.
    runs = sorted([
        run.split('/')[-1] for run in glob.glob(
            os.path.join(data_folder, "raw", subj, session, "meg", "*.ds"))
    ])
    center = int(np.floor(len(runs) / 2.0))
    raw_filename = os.path.join(data_folder, "raw", subj, session, "meg",
                                runs[center])

    # load labels:
    labels = sr.get_labels(subject=subj,
                           filters=['*wang*.label', '*JWG*.label'],
                           annotations=['HCPMMP1'])
    labels = sr.labels_exclude(labels=labels,
                               exclude_filters=[
                                   'wang2015atlas.IPS4',
                                   'wang2015atlas.IPS5',
                                   'wang2015atlas.SPL', 'JWG_lat_Unknown'
                               ])
    labels = sr.labels_remove_overlap(labels=labels,
                                      priority_filters=['wang', 'JWG'])
    print(labels)

    # load epochs (keep only channels whose name starts with 'M'):
    epochs_stim = mne.read_epochs(epochs_filename_stim)
    epochs_stim = epochs_stim.pick_channels(
        [x for x in epochs_stim.ch_names if x.startswith('M')])
    epochs_resp = mne.read_epochs(epochs_filename_resp)
    epochs_resp = epochs_resp.pick_channels(
        [x for x in epochs_resp.ch_names if x.startswith('M')])

    # Keep shared event codes, then baseline both time-locks with the
    # stimulus-locked -0.3..-0.2 s mean. NOTE(review): this assumes both
    # selections end up in the same trial order — confirm.
    overlap = list(
        set(epochs_stim.events[:, 2]).intersection(
            set(epochs_resp.events[:, 2])))
    epochs_stim = epochs_stim[[str(l) for l in overlap]]
    epochs_resp = epochs_resp[[str(l) for l in overlap]]
    id_time = (-0.3 <= epochs_stim.times) & (epochs_stim.times <= -0.2)
    means = epochs_stim._data[:, :, id_time].mean(-1)
    epochs_stim._data = epochs_stim._data - means[:, :, np.newaxis]
    epochs_resp._data = epochs_resp._data - means[:, :, np.newaxis]

    # TFR settings:
    fois_h = np.arange(42, 162, 4)
    fois_l = np.arange(2, 42, 2)
    tfr_params = {
        'HF': {
            'foi': fois_h,
            'cycles': fois_h * 0.4,
            'time_bandwidth': 5 + 1,
            'n_jobs': njobs,
            'est_val': fois_h,
            'est_key': 'HF'
        },
        'LF': {
            'foi': fois_l,
            'cycles': fois_l * 0.4,
            'time_bandwidth': 1 + 1,
            'n_jobs': njobs,
            'est_val': fois_l,
            'est_key': 'LF'
        }
    }

    # get cov:
    data_cov = lcmv.get_cov(epochs_stim, tmin=0, tmax=1)
    noise_cov = None

    # get lead field:
    forward, bem, source = sr.get_leadfield(
        subject=subj,
        raw_filename=raw_filename,
        epochs_filename=epochs_filename_stim,
        trans_filename=trans_filename,
        conductivity=(0.3, 0.006, 0.3),
        njobs=njobs)

    # do source level analysis:
    for tl, epochs in zip(['stimlock', 'resplock'],
                          [epochs_stim, epochs_resp]):
        events = epochs.events[:, 2]
        # None of the setup_filters inputs depend on the frequency band, so
        # build the filters once per time-lock rather than once per band.
        filters = lcmv.setup_filters(epochs.info,
                                     forward,
                                     data_cov,
                                     noise_cov,
                                     labels,
                                     njobs=njobs)
        for signal_type in ['LF', 'HF']:
            print(signal_type)
            # Chunks are recomputed even if their output file already exists.
            for i in range(0, len(events), chunks):
                filename = os.path.join(
                    data_folder, "source_level",
                    'lcmv_{}_{}_{}_{}_{}-source.hdf'.format(
                        subj, session, tl, signal_type, i))
                M = lcmv.reconstruct_tfr(filters,
                                         epochs.info,
                                         epochs._data[i:i + chunks],
                                         events[i:i + chunks],
                                         epochs.times,
                                         est_args=tfr_params[signal_type],
                                         njobs=njobs)
                M.to_hdf(filename, 'epochs')

                # Free the chunk before computing the next one.
                del M
def get_filter(
    subject,
    session,
    epoch_type="stimulus",
    only_glasser=False,
    BEM="three_layer",
):
    """Compute LCMV beamformer filters for one subject and session.

    Parameters
    ----------
    subject : int
        Subject number (formatted as ``"S%02i"`` for label loading).
    session : int
        Session number.
    epoch_type : {"stimulus", "response"}
        Which epochs (and data covariance) to use.
    only_glasser : bool
        Forwarded to ``get_labels``; selects HCPMMP1-only labels when true.
    BEM : str
        Forwarded to ``get_leadfield``.

    Returns
    -------
    The filters returned by ``pymeglcmv.setup_filters``.

    Raises
    ------
    RuntimeError
        If ``epoch_type`` is not recognized.
    """
    mne.set_log_level("WARNING")
    pymeglcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)

    logging.info("Reading stimulus data")
    if epoch_type == "stimulus":
        data_cov, epochs = get_stim_epoch(subject, session)
    elif epoch_type == "response":
        data_cov, epochs = get_response_epoch(subject, session)
    else:
        raise RuntimeError("Did not recognize epoch")

    logging.info("Setting up source space and forward model")

    forward, bem, source = get_leadfield(subject, session, BEM)

    # Same label selection as the module-level helper.
    labels = get_labels(subject, only_glasser)

    filters = pymeglcmv.setup_filters(epochs.info, forward, data_cov, None,
                                      labels)
    return filters
def extract(
    subject,
    session,
    epoch_type="stimulus",
    signal_type="BB",
    only_glasser=False,
    BEM="three_layer",
    debug=False,
    chunks=100,
    njobs=4,
):
    """Reconstruct source signals with LCMV beamformers, in trial chunks.

    Each chunk of ``chunks`` trials is written to
    ``lcmvfilename(..., chunk=i, only_glasser=only_glasser)`` with
    ``to_hdf(..., mode="w")``. Existing files are not skipped.

    Parameters
    ----------
    subject : int
        Subject number.
    session : int
        Session number.
    epoch_type : {"stimulus", "response"}
        Which epochs (and data covariance) to use.
    signal_type : str
        ``"BB"`` for broadband; otherwise a key of the local TFR parameter
        table (``"F"`` or ``"LF"``).
    only_glasser : bool
        Forwarded to ``get_labels`` and ``lcmvfilename``.
    BEM : str
        Forwarded to ``get_leadfield``.
    debug :
        Accepted for interface compatibility; not used here.
    chunks : int
        Trials per reconstruction / output file.
    njobs : int
        Thread count restored by ``set_n_threads`` when the function finishes
        (or fails).

    Raises
    ------
    RuntimeError
        If ``epoch_type`` is not recognized.
    """
    mne.set_log_level("WARNING")
    pymeglcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)

    logging.info("Reading stimulus data")
    if epoch_type == "stimulus":
        data_cov, epochs = get_stim_epoch(subject, session)
    elif epoch_type == "response":
        data_cov, epochs = get_response_epoch(subject, session)
    else:
        raise RuntimeError("Did not recognize epoch")

    logging.info("Setting up source space and forward model")

    forward, bem, source = get_leadfield(subject, session, BEM)

    # Same label selection as the module-level helper.
    labels = get_labels(subject, only_glasser)

    # Now chunk Reconstruction into blocks of ~100 trials to save Memory
    fois = np.arange(10, 150, 5)
    lfois = np.arange(1, 10, 1)
    tfr_params = {
        "F": {
            "foi": fois,
            "cycles": fois * 0.1,
            "time_bandwidth": 2,
            "n_jobs": 1,
            "est_val": fois,
            "est_key": "F",
        },
        "LF": {
            "foi": lfois,
            "cycles": lfois * 0.25,
            "time_bandwidth": 2,
            "n_jobs": 1,
            "est_val": lfois,
            "est_key": "LF",
        },
    }

    events = epochs.events[:, 2]
    filters = pymeglcmv.setup_filters(epochs.info, forward, data_cov, None,
                                      labels)

    set_n_threads(1)

    try:
        for i in range(0, len(events), chunks):
            filename = lcmvfilename(subject,
                                    session,
                                    signal_type,
                                    epoch_type,
                                    chunk=i,
                                    only_glasser=only_glasser)
            logging.info(filename)
            if signal_type == "BB":
                logging.info("Starting reconstruction of BB signal")
                M = pymeglcmv.reconstruct_broadband(
                    filters,
                    epochs.info,
                    epochs._data[i:i + chunks],
                    events[i:i + chunks],
                    epochs.times,
                    njobs=1,
                )
            else:
                logging.info("Starting reconstruction of TFR signal")
                M = pymeglcmv.reconstruct_tfr(
                    filters,
                    epochs.info,
                    epochs._data[i:i + chunks],
                    events[i:i + chunks],
                    epochs.times,
                    est_args=tfr_params[signal_type],
                    njobs=4,
                )
            M.to_hdf(filename, "epochs", mode="w")
    finally:
        # Restore the thread count even if a chunk fails.
        set_n_threads(njobs)
Exemple #10
0
def extract(
    recording_number,
    epoch,
    signal_type="BB",
    BEM="three_layer",
    chunks=100,
    njobs=4,
    glasser_only=True,
):
    """Reconstruct source signals for one recording with LCMV, in chunks.

    Chunks whose output file already exists are skipped. Each remaining
    chunk is written with ``M.to_hdf(str(filename), "epochs")``.

    Parameters
    ----------
    recording_number :
        Key into ``ps.recordings``.
    epoch :
        Epoch type forwarded to ``get_epoch``, ``ps.filenames`` and
        ``lcmvfilename``.
    signal_type : str
        ``"BB"`` for broadband; otherwise a key of the local TFR table
        (``"HF"`` or ``"LF"``).
    BEM :
        Accepted for interface compatibility; not used here.
    chunks : int
        Trials per reconstruction / output file.
    njobs : int
        ``n_jobs`` for the TFR estimator; thread count restored at the end.
    glasser_only : bool
        If true, labels are requested with the ``"*wang2015atlas*"`` filter
        only; otherwise all labels are loaded, a fixed set is excluded and
        overlaps are removed with Wang/JWDG priority.
    """
    recording = ps.recordings[recording_number]
    mne.set_log_level("ERROR")
    lcmv.logging.getLogger().setLevel(logging.INFO)
    set_n_threads(1)

    logging.info("Reading stimulus data")
    data_cov, epochs = get_epoch(epoch, recording)

    subj_id = recording.subject
    sess_id = recording.session
    block_id = recording.block[1]
    raw_filename = raw_path / recording.filename
    trans_filename = trans_path / (
        "SQC_S%02i-SESS%i_B%i_trans.fif" % (subj_id, sess_id, block_id))
    epoch_filename = ps.filenames(subj_id, epoch, sess_id, block_id)[0]

    fs_subject = "SQC_S%02i" % subj_id
    subjects_dir = "/home/nwilming/seqconf/fsdir/"

    logging.info("Setting up source space and forward model")

    forward, bem, source = sr.get_leadfield(
        fs_subject,
        str(raw_filename),
        str(epoch_filename),
        str(trans_filename),
        bem_sub_path="bem",
        sdir=subjects_dir,
    )

    if glasser_only:
        labels = sr.get_labels(fs_subject,
                               filters=["*wang2015atlas*"],
                               sdir=subjects_dir)
    else:
        dropped = [
            "wang2015atlas.IPS4",
            "wang2015atlas.IPS5",
            "wang2015atlas.SPL",
            "JWDG_lat_Unknown",
        ]
        everything = sr.get_labels(fs_subject, sdir=subjects_dir)
        trimmed = sr.labels_exclude(everything, exclude_filters=dropped)
        labels = sr.labels_remove_overlap(trimmed,
                                          priority_filters=["wang", "JWDG"])

    def band(key, freqs, cycle_factor, time_bandwidth):
        # One TFR estimator configuration; key order matches the original.
        return {
            "foi": freqs,
            "cycles": freqs * cycle_factor,
            "time_bandwidth": time_bandwidth,
            "n_jobs": njobs,
            "est_val": freqs,
            "est_key": key,
            "sf": 600,
            "decim": 10,
        }

    fois_h = np.arange(36, 162, 4)
    fois_l = np.arange(2, 36, 1)
    tfr_params = {
        "HF": band("HF", fois_h, 0.25, 2 + 1),
        "LF": band("LF", fois_l, 0.4, 1 + 1),
    }

    events = epochs.events[:, 2]
    filters = lcmv.setup_filters(epochs.info, forward, data_cov, None, labels)
    set_n_threads(1)

    for start in range(0, len(events), chunks):
        out_file = lcmvfilename(recording, signal_type, epoch, chunk=start)
        if os.path.isfile(out_file):
            continue
        stop = start + chunks
        if signal_type != "BB":
            logging.info("Starting reconstruction of TFR signal")
            reconstructed = lcmv.reconstruct_tfr(
                filters,
                epochs.info,
                epochs._data[start:stop],
                events[start:stop],
                epochs.times,
                est_args=tfr_params[signal_type],
                njobs=4,
            )
        else:
            logging.info("Starting reconstruction of BB signal")
            reconstructed = lcmv.reconstruct_broadband(
                filters,
                epochs.info,
                epochs._data[start:stop],
                events[start:stop],
                epochs.times,
                njobs=1,
            )
        reconstructed.to_hdf(str(out_file), "epochs")
    set_n_threads(njobs)