Example #1
def test_fetch_atlas_destrieux_2009_atlas():
    datadir = os.path.join(tmpdir, 'destrieux_2009')
    os.mkdir(datadir)
    with open(
            os.path.join(datadir, 'destrieux2009_rois_labels_lateralized.csv'),
            'w') as dummy:
        dummy.write("name,index")
    bunch = datasets.fetch_atlas_destrieux_2009(data_dir=tmpdir, verbose=0)

    assert_equal(len(url_request.urls), 1)
    assert_equal(
        bunch['maps'],
        os.path.join(tmpdir, 'destrieux_2009',
                     'destrieux2009_rois_lateralized.nii.gz'))

    with open(os.path.join(datadir, 'destrieux2009_rois_labels.csv'),
              'w') as dummy:
        dummy.write("name,index")
    bunch = datasets.fetch_atlas_destrieux_2009(lateralized=False,
                                                data_dir=tmpdir,
                                                verbose=0)

    assert_equal(len(url_request.urls), 1)
    assert_equal(
        bunch['maps'],
        os.path.join(tmpdir, 'destrieux_2009', 'destrieux2009_rois.nii.gz'))
Example #2
def coordinate_label(mni_coord, atlas='aal', thresh=None, ret_proba=False):

    if atlas == 'aal':
        atl = datasets.fetch_atlas_aal()
        atl.prob = False
    elif atlas == 'harvard_oxford':
        atl = datasets.fetch_atlas_harvard_oxford('cort-prob-2mm')
        atl.prob = True
    elif atlas == 'destrieux':
        atl = datasets.fetch_atlas_destrieux_2009()
        atl.indices = atl.labels['index']
        atl.labels = atl.labels['name']
        atl.prob = False
    else:
        raise ValueError("Unknown atlas: %r" % atlas)

    atl_map = load_img(atl.maps)
    atl_aff = atl_map.affine

    atl_labels = atl.labels
    if not atl.prob:
        atl_indices = atl.indices

    labels_out = list()

    for coord in mni_coord:

        mat_coord = np.asarray(resampling.coord_transform(
            coord[0], coord[1], coord[2], np.linalg.inv(atl_aff)),
                               dtype=int)

        if atl.prob and ret_proba:

            lab_out = get_prob_atlas_label(atl_map,
                                           atl_labels,
                                           mat_coord,
                                           thresh=thresh)

        elif atl.prob and not ret_proba:

            lab_out, _ = get_prob_atlas_label(atl_map,
                                              atl_labels,
                                              mat_coord,
                                              thresh=thresh)

        else:

            lab_out = get_atlas_label(atl_map, atl_labels, atl_indices,
                                      mat_coord)

        labels_out.append(lab_out)

    return labels_out
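
A minimal usage sketch for coordinate_label, assuming the module-level imports (nilearn datasets, load_img, resampling, numpy) and the helpers get_atlas_label / get_prob_atlas_label are available; the MNI coordinates below are illustrative, not from the original source.

# Hypothetical example coordinates (x, y, z) in MNI space.
mni_coords = [(-42, -58, 48), (6, 52, -2)]

# Deterministic lookup in the AAL atlas.
labels = coordinate_label(mni_coords, atlas='aal')

# Probabilistic lookup in Harvard-Oxford, also returning probabilities;
# the threshold value is an arbitrary choice.
labels_proba = coordinate_label(mni_coords, atlas='harvard_oxford',
                                thresh=25, ret_proba=True)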
Example #3
def get_atlas(name):
    if name == "destrieux_2009":
        atlas = datasets.fetch_atlas_destrieux_2009()
        atlas_filename = atlas['maps']
    elif name == "harvard_oxford":
        atlas = datasets.fetch_atlas_harvard_oxford("cort-maxprob-thr25-2mm")
        atlas_filename = atlas['maps']
    elif name == "aal":
        atlas = datasets.fetch_atlas_aal()
        atlas_filename = atlas['maps']
    elif name == "smith_2009":
        atlas = datasets.fetch_atlas_smith_2009()
        atlas_filename = atlas['rsn70']
    else:
        raise ValueError('Unknown atlas name: %s' % name)
    return atlas_filename
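
A short usage sketch, assuming nilearn's datasets module is imported as in the function above and nibabel is available.

# Hypothetical usage: fetch the Destrieux parcellation and inspect it.
import nibabel as nib

atlas_filename = get_atlas("destrieux_2009")
img = nib.load(atlas_filename)
print(img.shape)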
Example #4
def add_destrieux(nl):
    nl.new_symbol(name="destrieux")
    destrieux_atlas = datasets.fetch_atlas_destrieux_2009()
    destrieux_atlas_image = nib.load(destrieux_atlas["maps"])
    destrieux_labels = dict(destrieux_atlas["labels"])

    destrieux_set = set()
    for k, v in destrieux_labels.items():
        if k == 0:
            continue
        destrieux_set.add((
            v.decode("utf8"),
            ExplicitVBR.from_spatial_image_label(destrieux_atlas_image, k),
        ))

    nl.add_tuple_set(destrieux_set, name="destrieux")
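
A sketch of how this helper might be wired up, assuming the NeuroLang frontend used in the next example (fe.NeurolangDL) and the imports above.

# Hypothetical usage; fe is assumed to be the NeuroLang frontend module.
nl = fe.NeurolangDL()
add_destrieux(nl)
# The atlas regions are now queryable through the "destrieux" symbol.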
Example #5
    def setup_(self):
        nl = fe.NeurolangDL()

        destrieux_atlas = datasets.fetch_atlas_destrieux_2009()
        yeo_atlas = datasets.fetch_atlas_yeo_2011()

        img = nib.load(destrieux_atlas['maps'])
        aff = img.affine
        data = img.get_data()
        rset = []
        for label, name in destrieux_atlas['labels']:
            if label == 0:
                continue
            voxels = np.transpose((data == label).nonzero())
            if len(voxels) == 0:
                continue
            rset.append((name.decode('utf8'),
                         fe.ExplicitVBR(voxels,
                                        aff,
                                        image_dim=img.shape,
                                        prebuild_tree=True)))
        nl.add_tuple_set(rset, name='destrieux')

        img = nib.load(yeo_atlas['thick_17'])
        aff = img.affine
        data = img.get_data().squeeze()
        rset = []
        for label in range(1, 18):
            name = str(label)
            voxels = np.transpose((data == label).nonzero())
            if len(voxels) == 0:
                continue
            rset.append((name,
                         fe.ExplicitVBR(voxels,
                                        aff,
                                        image_dim=data.shape,
                                        prebuild_tree=True)))
        nl.add_tuple_set(rset, name='yeo')
        self.nl = nl
Example #6
def fetch_atlases(atlas_names):
    """Fetch atlases provided by name(s)

    Parameters
    ----------
    atlas_names : str or list of str
        Name(s) of the atlas(es) to fetch from the web. A few are shipped
        with FSL and Nilearn.
        Valid options:  ['harvard_oxford', 'destrieux', 'diedrichsen',
                         'juelich', 'jhu', 'mist', 'yeo_networks7',
                         'yeo_networks17']

    Returns
    -------
    data : dict
        Dictionary of fetched atlases, keyed by atlas name.
    """
    data = {}
    atlas_names = _check_atlases(atlas_names)
    for atlas_name in atlas_names:
        if atlas_name == 'harvard_oxford':
            name = 'cort-maxprob-thr25-2mm'
            data[atlas_name] = datasets.fetch_atlas_harvard_oxford(name)
        elif atlas_name == 'destrieux':
            data[atlas_name] = datasets.fetch_atlas_destrieux_2009()
        elif atlas_name == 'diedrichsen':
            data[atlas_name] = _fetch_atlas_diedrichsen('maxprob-thr25-2mm')
        elif atlas_name == 'juelich':
            data[atlas_name] = _fetch_atlas_juelich('maxprob-thr25-2mm')
        elif atlas_name == 'jhu':
            data[atlas_name] = _fetch_atlas_jhu('labels-2mm')
        elif atlas_name == 'mist':
            data[atlas_name] = fetch_mist()
        elif atlas_name in ['yeo_networks7', 'yeo_networks17']:
            data[atlas_name] = fetch_yeo(atlas_name)
        else:
            raise ValueError("Not a valid atlas. Given atlas is exhausted")
    return data
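
A brief usage sketch, assuming _check_atlases accepts a single name or a list of names, as the docstring suggests.

# Hypothetical usage: fetch two atlases and inspect their map paths.
atlases = fetch_atlases(['destrieux', 'harvard_oxford'])
for name, bunch in atlases.items():
    print(name, bunch['maps'])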
Example #7
    os.mkdir(write_dir)

# Access to the data
data_dir = os.path.join(SMOOTH_DERIVATIVES, 'group')
mask_gm = nib.load(os.path.join(data_dir, 'gm_mask.nii.gz'))
ref_affine, ref_shape = mask_gm.affine, mask_gm.shape

# df = make_db('/neurospin/ibc/smooth_derivatives')
df = data_parser(derivatives=SMOOTH_DERIVATIVES, conditions=CONTRASTS)
#conditions = df.contrast[df.modality == 'bold'].unique()
conditions = CONTRASTS.contrast.values
n_conditions = len(conditions)

# Mask of the ROI
# intersect with GM mask
atlas = datasets.fetch_atlas_destrieux_2009()
# 29: left hemisphere
# 104: right hemisphere
roi_index = 29
roi_mask = math_img('im1 == %d' % roi_index, im1=atlas.maps)
roi_mask = resample_img(roi_mask,
                        ref_affine,
                        ref_shape,
                        interpolation='nearest')
masker = NiftiMasker(mask_img=roi_mask, memory=mem).fit()

path_train = {}
path_test = {}
X_train = []
X_test = []
subjects = df.subject.unique()
Example #8
def load_database():
    return datasets.fetch_atlas_destrieux_2009(data_dir="neurolang_data")
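
The returned value is a nilearn Bunch; a minimal sketch of how it might be consumed, with key names taken from nilearn's fetch_atlas_destrieux_2009 documentation.

# Hypothetical usage of the loader above.
destrieux = load_database()
print(destrieux['maps'])        # path to the NIfTI label image
print(destrieux['labels'][:5])  # first few (index, name) label entries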
Example #9
def plot_vol_scatter(vol, ax=None, pointint=2, c='b', alpha=0.6, s=12.,
                     xlim=[0, 100], ylim=[0, 100], zlim=[0, 100], marker='o',
                     figsize=(20, 20), linewidth=0.05, show_bb=True, bb=None,
                     bg_img=None,
                     bg_params={'alpha': 0.5, 's': 20., 'linewidth': 0.005,
                                'marker': '.', 'c': 'k'},
                     bg_pointint=15):

  if isinstance(vol, str):
    if os.path.isfile(vol):
      img = nib.load(vol)
      dat = img.get_data()
  elif isinstance(vol, nib.Nifti1Image):
    img = vol
    dat = img.get_data()
  else:
    raise TypeError('image file type not recognized')

  
  xs,ys,zs = np.nonzero(dat>0)    
  idx = np.arange(0,xs.shape[0])#,pointint)

  if not ax: 
    fig = plt.figure(figsize=figsize)
    ax = fig.gca(projection='3d')
    ax.set_aspect('equal')
    ax.set_xlim([xlim[0],xlim[1]])  
    ax.set_ylim([ylim[0],ylim[1]])
    ax.set_zlim([zlim[0],zlim[1]])

  ax.scatter3D(xs[idx],ys[idx],zs[idx],c=c,alpha=alpha,s=s, marker=marker,linewidths=linewidth)
   

  if bg_img == 'nilearn_destrieux':
    dest = fetch_atlas_destrieux_2009()
    bg_img = resample_to_img(nib.load(dest['maps']),
                             img,interpolation='nearest')
  if bg_img is not None:
    
    bg_dat = bg_img.get_data()
    xs,ys,zs = np.nonzero(bg_dat>0)
    idx = np.arange(0,xs.shape[0],bg_pointint)
    ax.scatter3D(xs[idx],ys[idx],zs[idx],**bg_params)


  if show_bb:

    if bb is None:

      bb = get_bounding_box_inds(dat)

    if len(bb) == 6:
      xmin,xmax,ymin,ymax,zmin,zmax = bb
    elif bb.shape == (3,2):
      [[xmin,xmax],[ymin,ymax],[zmin,zmax]] = bb
    else:
      raise ValueError('Bounding box not recognized')

    corners = np.array(list(product([xmin,xmax],
                                    [ymin,ymax],
                                    [zmin,zmax])))
    cornerpairs = list(combinations(corners,2))

    linestoplot = [(s,e) for (s,e) in cornerpairs \
                   if ((np.abs(s-e) == 0).sum() == 2)]

    for (s,e) in linestoplot:
      ax.plot3D(*zip(s,e), color=c)


  return ax
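
A usage sketch for plot_vol_scatter, assuming matplotlib is available; 'some_roi.nii.gz' is a placeholder path, not a file from the original source.

# Hypothetical usage: scatter a binary ROI volume over the Destrieux
# atlas rendered as a background point cloud.
import matplotlib.pyplot as plt

ax = plot_vol_scatter('some_roi.nii.gz', c='r', bg_img='nilearn_destrieux')
plt.show()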
Example #10
# view.open_in_browser()

view

##############################################################################
# Impact of plot parameters on visualization
# ------------------------------------------
#
# You can specify arguments to be passed on to the function
# :func:`nilearn.surface.vol_to_surf` using `vol_to_surf_kwargs`. This allows
# fine-grained control of how the input 3D image is resampled and
# interpolated. For example, when viewing a volumetric atlas, you want to
# avoid averaging the labels between neighboring regions; nearest-neighbor
# interpolation with zero radius achieves this.

destrieux = datasets.fetch_atlas_destrieux_2009(legacy_format=False)

view = plotting.view_img_on_surf(
    destrieux.maps,
    surf_mesh="fsaverage",
    vol_to_surf_kwargs={
        "n_samples": 1,
        "radius": 0.0,
        "interpolation": "nearest"
    },
    symmetric_cmap=False,
)

# view.open_in_browser()
view
Example #11
def get_labelled_atlas(query, data_dir=None, return_labels=True):
    """Parses input query to determine which atlas to fetch and what version
    of the atlas to use (if applicable).

    Parameters
    ----------
    query : str
        Input string in the following format:
        nilearn:{atlas_name}:{atlas_parameters}. Valid options for
        `atlas_name` are: 'destrieux', 'yeo', 'aal', 'basc', 'talairach',
        and 'schaefer'. `atlas_parameters` is not available for the
        `destrieux` atlas.
    data_dir : str, optional
        Directory in which to save atlas data. By default None, which creates
        a ~/nilearn_data/ directory as per nilearn.
    return_labels : bool, optional
        Whether to return atlas labels. Default is True. Not available for the
        'basc' atlas.

    Returns
    -------
    img : str
        Path to the atlas image.
    labels : list or None
        The accompanying labels (if provided).

    Raises
    ------
    ValueError
        Raised when the query is not formatted correctly or when no match
        is found.
    """

    # extract parameters
    params = query.split(':')
    if len(params) == 3:
        _, atlas_name, sub_param = params
    elif len(params) == 2:
        _, atlas_name = params
        sub_param = None
    else:
        raise ValueError('Incorrect atlas query string provided')

    # get atlas
    if atlas_name == 'destrieux':
        atlas = fetch_atlas_destrieux_2009(lateralized=True, data_dir=data_dir)
        img = atlas['maps']
        labels = atlas['labels']
    elif atlas_name == 'yeo':
        atlas = fetch_atlas_yeo_2011(data_dir=data_dir)
        img = atlas[sub_param]
        if '17' in sub_param:
            labels = pd.read_csv(atlas['colors_17'],
                                 sep=r'\s+')['NONE'].tolist()
        else:
            labels = pd.read_csv(atlas['colors_7'],
                                 sep=r'\s+')['NONE'].tolist()
    elif atlas_name == 'aal':
        version = 'SPM12' if sub_param is None else sub_param
        atlas = fetch_atlas_aal(version=version, data_dir=data_dir)
        img = atlas['maps']
        labels = atlas['labels']
    elif atlas_name == 'basc':

        version, scale = sub_param.split('-')
        atlas = fetch_atlas_basc_multiscale_2015(version=version,
                                                 data_dir=data_dir)
        img = atlas['scale{}'.format(scale.zfill(3))]
        labels = None
    elif atlas_name == 'talairach':
        atlas = fetch_atlas_talairach(level_name=sub_param, data_dir=data_dir)
        img = atlas['maps']
        labels = atlas['labels']
    elif atlas_name == 'schaefer':
        n_rois, networks, resolution = sub_param.split('-')
        # corrected version of schaefer labels until fixed in nilearn
        correct_url = ('https://raw.githubusercontent.com/ThomasYeoLab/CBIG/'
                       'v0.14.3-Update_Yeo2011_Schaefer2018_labelname/'
                       'stable_projects/brain_parcellation/'
                       'Schaefer2018_LocalGlobal/Parcellations/MNI/')
        atlas = fetch_atlas_schaefer_2018(n_rois=int(n_rois),
                                          yeo_networks=int(networks),
                                          resolution_mm=int(resolution),
                                          data_dir=data_dir,
                                          base_url=correct_url)
        img = atlas['maps']
        labels = atlas['labels']
    else:
        raise ValueError('No atlas detected. Check query string')

    if not return_labels:
        labels = None
    elif labels is not None:
        labels = labels.astype(str).tolist()

    return img, labels
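
Usage sketches for the query format described in the docstring; the parameter strings are illustrative.

# Hypothetical queries in the nilearn:{atlas_name}:{atlas_parameters} format.
img, labels = get_labelled_atlas('nilearn:destrieux')
img, labels = get_labelled_atlas('nilearn:schaefer:400-7-2')
img, _ = get_labelled_atlas('nilearn:basc:sym-122', return_labels=False)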
Example #12
def get_atlas_rois(atlas, roi_idx, hemisphere, res=None, path=None):
    """
    Extract ROIs from a given atlas.

    Parameters
    ----------
    atlas : str
        Atlas dataset to be downloaded through nilearn's dataset_fetch_atlas functionality.
    roi_idx : list
        List of int indices of the ROI(s) you want to extract from the atlas. If not sure, use get_atlas_info.
    hemisphere : list
        List of str giving the hemisphere(s) of the ROI(s) you want to extract. Can be ['left'], ['right'] or ['left', 'right'].
    res : str
        Specific version of the atlas to be downloaded. Only necessary for Harvard-Oxford and Talairach.
        Please check nilearn's respective documentation at
        https://nilearn.github.io/modules/generated/nilearn.datasets.fetch_atlas_harvard_oxford.html or
        https://nilearn.github.io/modules/generated/nilearn.datasets.fetch_atlas_talairach.html
    path : str
        Path where the extracted ROI(s) will be saved. If None, ROI(s) will be saved in the current
        working directory.

    Returns
    -------
    list_rois: list
        A list of the extracted ROIs.

    Examples
    --------
    >>> get_atlas_rois('aal', [1, 2, 3], ['left', 'right'], path='/home/urial/Desktop')
    list_rois
    """

    if atlas == 'aal':
        atl_ds = datasets.fetch_atlas_aal()

    elif atlas == 'harvard_oxford':
        if res is None:
            print(
                'Please provide the specific version of the Harvard-Oxford atlas you would like to use.'
            )
        else:
            atl_ds = datasets.fetch_atlas_harvard_oxford(res)

    elif atlas == 'destrieux':
        atl_ds = datasets.fetch_atlas_destrieux_2009()

    elif atlas == 'msdl':
        atl_ds = datasets.fetch_atlas_msdl()

    elif atlas == 'talairach':
        if res is None:
            print(
                'Please provide the level of the Talairach atlas you would like to use.'
            )
        else:
            atl_ds = datasets.fetch_atlas_talairach(level_name=res)

    elif atlas == 'pauli_2017':
        atl_ds = datasets.fetch_atlas_pauli_2017()

    else:
        raise ValueError('Unknown atlas: %s' % atlas)

    if roi_idx is None:
        print('Please provide the indices of the ROIs you want to extract.')
    elif hemisphere is None:
        print(
            'Please provide the hemisphere(s) from which you want to extract ROIs.'
        )

    for label in roi_idx:
        for hemi in hemisphere:
            roi_ex = Node(PickAtlas(), name='roi_ex')
            roi_ex.inputs.atlas = atl_ds.maps
            roi_ex.inputs.labels = label
            roi_ex.inputs.hemi = hemi
            if path is None:
                roi_ex.inputs.output_file = '%s_%s_%s.nii.gz' % (
                    atlas, str(label), hemi)
                roi_ex.run()
                list_rois = glob('%s_*.nii.gz' % atlas)
            elif path:
                roi_ex.inputs.output_file = opj(
                    path, '%s_%s_%s.nii.gz' % (atlas, str(label), hemi))
                roi_ex.run()
                list_rois = glob(opj(path, '%s_*.nii.gz' % atlas))

    print('The following ROIs were extracted: ')
    print('\n'.join(map(str, list_rois)))

    return list_rois
Example #13
def get_atlas_info(atlas, res=None):
    """
    Gather all information from a specified atlas, including the path to the atlas maps, as well as labels
    and their indices.

    Parameters
    ----------
    atlas : str
        Atlas dataset to be downloaded through nilearn's dataset_fetch_atlas functionality.
    res : str
        Specific version of the atlas to be downloaded. Only necessary for Harvard-Oxford and Talairach.
        Please check nilearn's respective documentation at
        https://nilearn.github.io/modules/generated/nilearn.datasets.fetch_atlas_harvard_oxford.html or
        https://nilearn.github.io/modules/generated/nilearn.datasets.fetch_atlas_talairach.html

    Returns
    -------
    atlas_info_df : pandas dataframe
        A pandas DataFrame containing information about the ROIs and their indices included in a given atlas.
    atl_ds.maps : str
        Path to the atlas maps.

    Examples
    --------
    >>> get_atlas_info('aal')
    atlas_info_df
    atl_ds.maps
    """

    if atlas == 'aal':
        atl_ds = datasets.fetch_atlas_aal()

    elif atlas == 'harvard_oxford':
        if res is None:
            print(
                'Please provide the specific version of the Harvard-Oxford atlas you would like to use.'
            )
        else:
            atl_ds = datasets.fetch_atlas_harvard_oxford(res)

    elif atlas == 'destrieux':
        atl_ds = datasets.fetch_atlas_destrieux_2009()

    elif atlas == 'msdl':
        atl_ds = datasets.fetch_atlas_msdl()

    elif atlas == 'talairach':
        if res is None:
            print(
                'Please provide the level of the Talairach atlas you would like to use.'
            )
        else:
            atl_ds = datasets.fetch_atlas_talairach(level_name=res)

    elif atlas == 'pauli_2017':
        atl_ds = datasets.fetch_atlas_pauli_2017()

    else:
        raise ValueError('Unknown atlas: %s' % atlas)

    index = []
    labels = []

    for ind, label in enumerate(atl_ds.labels):
        index.append(ind)
        if atlas == 'destrieux':
            labels.append(label[1])
        else:
            labels.append(label)

    atlas_info_df = pd.DataFrame({'index': index, 'label': labels})

    return atlas_info_df, atl_ds.maps
Example #14
def main(
    workdir,
    outdir,
    atlas,
    kernel,
    sparsity,
    affinity,
    approach,
    gradients,
    subcort,
    neurosynth,
    neurosynth_file,
    sleuth_file,
    nimare_dataset,
    roi_mask,
    term,
    topic,
):
    workdir = op.join(workdir, "tmp")
    if op.isdir(workdir):
        shutil.rmtree(workdir)
    os.makedirs(workdir)

    atlas_name = "atlas-{0}".format(atlas)
    kernel_name = "kernel-{0}".format(kernel)
    sparsity_name = "sparsity-{0}".format(sparsity)
    affinity_name = "affinity-{0}".format(affinity)
    approach_name = "approach-{0}".format(approach)
    gradients_name = "gradients-{0}".format(gradients)
    dset = None

    # handle neurosynth dataset, if called
    if neurosynth:
        if neurosynth_file is None:

            ns_data_dir = op.join(workdir, "neurosynth")
            dataset_file = op.join(ns_data_dir, "neurosynth_dataset.pkl.gz")
            # download neurosynth dataset if necessary
            if not op.isfile(dataset_file):
                neurosynth_download(ns_data_dir)

        else:
            dataset_file = neurosynth_file

        dset = Dataset.load(dataset_file)
        dataset_name = "dataset-neurosynth"

    # handle sleuth text file, if called
    if sleuth_file is not None:
        dset = convert_sleuth_to_dataset(sleuth_file, target="mni152_2mm")
        dataset_name = "dataset-{0}".format(op.basename(sleuth_file).split(".")[0])

    if nimare_dataset is not None:
        dset = Dataset.load(nimare_dataset)
        dataset_name = "dataset-{0}".format(op.basename(nimare_dataset).split(".")[0])

    if dset:
        # slice studies, if needed
        if roi_mask is not None:
            roi_ids = dset.get_studies_by_mask(roi_mask)
            print(
                "{}/{} studies report at least one coordinate in the "
                "ROI".format(len(roi_ids), len(dset.ids))
            )
            dset_sel = dset.slice(roi_ids)
            dset = dset_sel
            dataset_name = "dataset-neurosynth_mask-{0}".format(
                op.basename(roi_mask).split(".")[0]
            )

        if term is not None:
            labels = ["Neurosynth_TFIDF__{label}".format(label=label) for label in [term]]
            term_ids = dset.get_studies_by_label(labels=labels, label_threshold=0.1)
            print(
                "{}/{} studies report association "
                "with the term {}".format(len(term_ids), len(dset.ids), term)
            )
            dset_sel = dset.slice(term_ids)
            dset = dset_sel
            # img_inds = np.nonzero(dset.masker.mask_img.get_fdata())  # unused
            # vox_locs = np.unravel_index(img_inds, dset.masker.mask_img.shape)  # unused
            dataset_name = "dataset-neurosynth_term-{0}".format(term)

        if topic is not None:
            topics = [
                "Neurosynth_{version}__{topic}".format(version=topic[0],
                                                       topic=t)
                for t in topic[1:]
            ]
            topics_ids = []
            for topic in topics:
                topic_ids = dset.annotations.id[np.where(dset.annotations[topic])[0]].tolist()
                topics_ids.extend(topic_ids)
                print(
                    "{}/{} studies report association "
                    "with the term {}".format(len(topic_ids), len(dset.ids), topic)
                )
            topics_ids_unique = np.unique(topics_ids)
            print("{} unique ids".format(len(topics_ids_unique)))
            dset_sel = dset.slice(topics_ids_unique)
            dset = dset_sel
            # img_inds = np.nonzero(dset.masker.mask_img.get_fdata())  # unused
            # vox_locs = np.unravel_index(img_inds, dset.masker.mask_img.shape)  # unused
            dataset_name = "dataset-neurosynth_topic-{0}".format("_".join(topic[1:]))

        if (
            neurosynth
            or (sleuth_file is not None)
            or (nimare_dataset is not None)
        ):
            # set kernel for MA smoothing
            if kernel == "peaks2maps":
                print("Running peak2maps")
                k = Peaks2MapsKernel(resample_to_mask=True)
            elif kernel == "alekernel":
                print("Running alekernel")
                k = ALEKernel(fwhm=15)

            if atlas is not None:
                if atlas == "harvard-oxford":
                    print("Parcellating using the Harvard Oxford Atlas")
                    # atlas_labels = atlas.labels[1:]  # unused
                    atlas_shape = atlas.maps.shape
                    atlas_affine = atlas.maps.affine
                    atlas_data = atlas.maps.get_fdata()
                elif atlas == "aal":
                    print("Parcellating using the AAL Atlas")
                    atlas = datasets.fetch_atlas_aal()
                    # atlas_labels = atlas.labels  # unused
                    atlas_shape = nib.load(atlas.maps).shape
                    atlas_affine = nib.load(atlas.maps).affine
                    atlas_data = nib.load(atlas.maps).get_fdata()
                elif atlas == "craddock-2012":
                    print("Parcellating using the Craddock-2012 Atlas")
                    atlas = datasets.fetch_atlas_craddock_2012()
                elif atlas == "destrieux-2009":
                    print("Parcellating using the Destrieux-2009 Atlas")
                    atlas = datasets.fetch_atlas_destrieux_2009(lateralized=True)
                    # atlas_labels = atlas.labels[3:]  # unused
                    atlas_shape = nib.load(atlas.maps).shape
                    atlas_affine = nib.load(atlas.maps).affine
                    atlas_data = nib.load(atlas.maps).get_fdata()
                elif atlas == "msdl":
                    print("Parcellating using the MSDL Atlas")
                    atlas = datasets.fetch_atlas_msdl()
                elif atlas == "surface":
                    print("Generating surface vertices")

                if atlas != "fsaverage5" and atlas != "hcp":
                    imgs = k.transform(dset, return_type="image")

                    masker = NiftiLabelsMasker(
                        labels_img=atlas.maps, standardize=True, memory="nilearn_cache"
                    )
                    time_series = masker.fit_transform(imgs)

                else:
                    # change to array for other approach
                    imgs = k.transform(dset, return_type="image")
                    print(np.shape(imgs))

                    if atlas == "fsaverage5":
                        fsaverage = fetch_surf_fsaverage(mesh="fsaverage5")
                        pial_left = fsaverage.pial_left
                        pial_right = fsaverage.pial_right
                        medial_wall_inds_left = surface.load_surf_data(
                            "./templates/lh.Medial_wall.label"
                        )
                        print(np.shape(medial_wall_inds_left))
                        medial_wall_inds_right = surface.load_surf_data(
                            "./templates/rh.Medial_wall.label"
                        )
                        print(np.shape(medial_wall_inds_right))
                        sulc_left = fsaverage.sulc_left
                        sulc_right = fsaverage.sulc_right

                    elif atlas == "hcp":
                        pial_left = "./templates/S1200.L.pial_MSMAll.32k_fs_LR.surf.gii"
                        pial_right = "./templates/S1200.R.pial_MSMAll.32k_fs_LR.surf.gii"
                        medial_wall_inds_left = np.where(
                            nib.load("./templates/hcp.tmp.lh.dscalar.nii").get_fdata()[0] == 0
                        )[0]
                        medial_wall_inds_right = np.where(
                            nib.load("./templates/hcp.tmp.rh.dscalar.nii").get_fdata()[0] == 0
                        )[0]
                        left_verts = 32492 - len(medial_wall_inds_left)
                        sulc_left = nib.load(
                            "./templates/S1200.sulc_MSMAll.32k_fs_LR.dscalar.nii"
                        ).get_fdata()[0][0:left_verts]
                        sulc_left = np.insert(
                            sulc_left,
                            np.subtract(
                                medial_wall_inds_left, np.arange(len(medial_wall_inds_left))
                            ),
                            0,
                        )
                        sulc_right = nib.load(
                            "./templates/S1200.sulc_MSMAll.32k_fs_LR.dscalar.nii"
                        ).get_fdata()[0][left_verts:]
                        sulc_right = np.insert(
                            sulc_right,
                            np.subtract(
                                medial_wall_inds_right, np.arange(len(medial_wall_inds_right))
                            ),
                            0,
                        )

                    surf_lh = surface.vol_to_surf(
                        imgs,
                        pial_left,
                        radius=6.0,
                        interpolation="nearest",
                        kind="ball",
                        n_samples=None,
                        mask_img=dset.masker.mask_img,
                    )
                    surf_rh = surface.vol_to_surf(
                        imgs,
                        pial_right,
                        radius=6.0,
                        interpolation="nearest",
                        kind="ball",
                        n_samples=None,
                        mask_img=dset.masker.mask_img,
                    )
                    surfs = np.transpose(np.vstack((surf_lh, surf_rh)))
                    del surf_lh, surf_rh

                    # handle cortex first
                    coords_left = surface.load_surf_data(pial_left)[0]
                    coords_left = np.delete(coords_left, medial_wall_inds_left, axis=0)
                    coords_right = surface.load_surf_data(pial_right)[0]
                    coords_right = np.delete(coords_right, medial_wall_inds_right, axis=0)

                    print("Left Hemipshere Vertices")
                    surface_macms_lh, inds_discard_lh = build_macms(dset, surfs, coords_left)
                    print(np.shape(surface_macms_lh))
                    print(inds_discard_lh)

                    print("Right Hemipshere Vertices")
                    surface_macms_rh, inds_discard_rh = build_macms(dset, surfs, coords_right)
                    print(np.shape(surface_macms_rh))
                    print(len(inds_discard_rh))

                    lh_vertices_total = np.shape(surface_macms_lh)[0]
                    rh_vertices_total = np.shape(surface_macms_rh)[0]
                    time_series = np.transpose(np.vstack((surface_macms_lh, surface_macms_rh)))
                    print(np.shape(time_series))
                    del surface_macms_lh, surface_macms_rh

                    if subcort:
                        subcort_img = nib.load("templates/rois-subcortical_mni152_mask.nii.gz")
                        subcort_vox = np.asarray(np.where(subcort_img.get_fdata()))
                        subcort_mm = vox2mm(subcort_vox.T, subcort_img.affine)

                        print("Subcortical Voxels")
                        subcort_macm, inds_discard_subcort = build_macms(dset, surfs, subcort_mm)

                        num_subcort_vox = np.shape(subcort_macm)[0]
                        print(inds_discard_subcort)

                        time_series = np.hstack((time_series, np.asarray(subcort_macm).T))
                        print(np.shape(time_series))

                time_series = time_series.astype("float32")

                print("calculating correlation matrix")
                correlation = ConnectivityMeasure(kind="correlation")
                time_series = correlation.fit_transform([time_series])[0]
                print(np.shape(time_series))

                if affinity == "cosine":
                    time_series = calculate_affinity(time_series, 10 * sparsity)

            else:
                time_series = np.transpose(k.transform(dset, return_type="array"))

    print("Performing gradient analysis")

    gradients, statistics = embed.compute_diffusion_map(
        time_series, alpha=0.5, return_result=True, overwrite=True
    )
    pickle.dump(statistics, open(op.join(workdir, "statistics.p"), "wb"))

    # if subcortical included in gradient decomposition, remove gradient scores
    if subcort:
        subcort_grads = gradients[np.shape(gradients)[0] - num_subcort_vox :, :]
        subcort_grads = insert(subcort_grads, inds_discard_subcort)
        gradients = gradients[0 : np.shape(gradients)[0] - num_subcort_vox, :]

    # get left hemisphere gradient scores, and insert 0's where medial wall is
    gradients_lh = gradients[0:lh_vertices_total, :]
    if len(inds_discard_lh) > 0:
        gradients_lh = insert(gradients_lh, inds_discard_lh)
    gradients_lh = insert(gradients_lh, medial_wall_inds_left)

    # get right hemisphere gradient scores and insert 0's where medial wall is
    gradients_rh = gradients[-rh_vertices_total:, :]
    if len(inds_discard_rh) > 0:
        gradients_rh = insert(gradients_rh, inds_discard_rh)
    gradients_rh = insert(gradients_rh, medial_wall_inds_right)

    grad_dict = {
        "grads_lh": gradients_lh,
        "grads_rh": gradients_rh,
        "pial_left": pial_left,
        "sulc_left": sulc_left,
        "pial_right": pial_right,
        "sulc_right": sulc_right,
    }
    if subcort:
        grad_dict["subcort_grads"] = subcort_grads
    pickle.dump(grad_dict, open(op.join(workdir, "gradients.p"), "wb"))

    # map the gradient to the parcels
    for i in range(np.shape(gradients)[1]):
        if atlas is not None:
            if atlas == "fsaverage5" or atlas == "hcp":

                plot_surfaces(grad_dict, i, workdir)

                if subcort:
                    tmpimg = masking.unmask(subcort_grads[:, i], subcort_img)
                    nib.save(tmpimg, op.join(workdir, "gradient-{0}.nii.gz".format(i)))
            else:
                tmpimg = np.zeros(atlas_shape)
                for j, n in enumerate(np.unique(atlas_data)[1:]):
                    inds = atlas_data == n
                    tmpimg[inds] = gradients[j, i]
                # save once per gradient, after the parcel loop completes
                nib.save(
                    nib.Nifti1Image(tmpimg, atlas_affine),
                    op.join(workdir, "gradient-{0}.nii.gz".format(i)),
                )
        else:
            tmpimg = np.zeros(np.prod(dset.masker.mask_img.shape))
            inds = np.ravel_multi_index(
                np.nonzero(dset.masker.mask_img.get_fdata()), dset.masker.mask_img.shape
            )
            tmpimg[inds] = gradients[:, i]
            nib.save(
                nib.Nifti1Image(
                    np.reshape(tmpimg, dset.masker.mask_img.shape), dset.masker.mask_img.affine
                ),
                op.join(workdir, "gradient-{0}.nii.gz".format(i)),
            )

            os.system(
                "python3 /Users/miriedel/Desktop/GitHub/surflay/make_figures.py "
                "-f {grad_image} --colormap jet".format(
                    grad_image=op.join(workdir, "gradient-{0}.nii.gz".format(i))
                )
            )

    output_dir = op.join(
        outdir,
        (
            f"{dataset_name}_{atlas_name}_{kernel_name}_{sparsity_name}_{gradients_name}_"
            f"{affinity_name}_{approach_name}"
        )
    )

    shutil.copytree(workdir, output_dir)

    shutil.rmtree(workdir)
Example #15
    def process(self):
        raw_dir = '/'.join([self.root, 'raw'])
        fs_subject_dir = osp.join(raw_dir, self.raw_file_names[0],
                                  'freesurfer/5.1')

        # copy .annot to SUBJECT_DIR
        os.system('cp {} {}'.format(osp.join(raw_dir, self.raw_file_names[2]),
                                    fs_subject_dir))
        os.system('cp {} {}'.format(osp.join(raw_dir, self.raw_file_names[3]),
                                    fs_subject_dir))

        # process fs outputs
        shell_script_path = osp.join(os.getcwd(), 'fs_preproc.sh')
        process_fs_output(fs_subject_dir, shell_script_path)

        if self.atlas == 'HCPMMP1':
            fs_subject_dir = osp.join(fs_subject_dir,
                                      'all_output')  # `all_output` hardcoded

        # phenotypic for label
        s3_pheno_path = '/'.join([self.root, 'raw', self.raw_file_names[1]])
        pheno_df = pd.read_csv(s3_pheno_path)

        # load atlas for fmri
        if self.atlas == 'HCPMMP1':
            # load and transform atlas
            atlas_nii_file = '/'.join(
                [self.root, 'raw', self.raw_file_names[4]])
            atlas_img = image.load_img(atlas_nii_file)
            # split left and right hemisphere
            atlas_img.get_data()[atlas_img.get_data()[:int(atlas_img.shape[0] / 2 + 1), :, :].nonzero()] += \
                atlas_img.get_data().max()
            num_nodes = 360
        elif self.atlas == 'destrieux':
            from nilearn.datasets import fetch_atlas_destrieux_2009
            atlas_nii_file = fetch_atlas_destrieux_2009().maps
            atlas_img = image.load_img(atlas_nii_file)
            num_nodes = 148

        # read FreeSurfer output
        anatomical_features_dict = read_fs_stats(fs_subject_dir, self.atlas)

        # make subject_ids
        subject_ids = list(anatomical_features_dict.keys())
        if self.site == 'ALL':
            import urllib
            all_subject_ids_path = 'https://raw.githubusercontent.com/parisots/population-gcn/master/subject_IDs.txt'
            response = urllib.request.urlopen(all_subject_ids_path)
            all_subject_ids = [
                s.decode() for s in response.read().splitlines()
            ]
            subject_ids = [s for s in subject_ids if s[-5:] in all_subject_ids]
            assert len(subject_ids) == 871

        # process the data
        data_list = []
        failed_subject_list = []
        for subject in tqdm(subject_ids, desc='subject_list'):
            try:
                y, sex, iq, site_id, subject_id = label_from_pheno(
                    pheno_df, subject)

                # read anatomical features from dict
                lh_df, rh_df = anatomical_features_dict[subject]
                node_features = torch.from_numpy(
                    np.concatenate([
                        lh_df[self.anatomical_feature_names].values,
                        rh_df[self.anatomical_feature_names].values
                    ])).float()
                if node_features.shape[0] != num_nodes:
                    # check missing nodes, for 'destrieux'
                    continue

                # path for preprocessed functional MRI
                fmri_nii_file = '/'.join([
                    self.root, 'raw', 'Outputs', self.pipeline, self.strategy,
                    self.derivative, "{}_func_preproc.nii.gz".format(subject)
                ])

                # nilearn masker and corr
                masker = NiftiLabelsMasker(labels_img=atlas_img,
                                           standardize=True,
                                           memory='nilearn_cache',
                                           verbose=5)
                correlation_measure = ConnectivityMeasure(kind='correlation')
                time_series = masker.fit_transform(fmri_nii_file)

                # 0 in regions
                # for i in range(num_nodes):
                #     assert np.any(time_series[:, i])

                # handle broken file in ABIDE preprocessed filternoglobal
                if subject == 'UM_1_0050302':
                    time_series = nilearn.signal.clean(time_series.transpose(),
                                                       low_pass=0.1,
                                                       high_pass=0.01,
                                                       t_r=2).transpose()
                if subject == 'Leuven_2_0050730':
                    time_series = nilearn.signal.clean(time_series.transpose(),
                                                       low_pass=0.1,
                                                       high_pass=0.01,
                                                       t_r=1.6667).transpose()
                # optional data augmentation
                if self.resample_ts:
                    time_series_list = resample_temporal(time_series)
                elif self.dfc_resample:
                    time_series_list = window_slpit_ts(time_series,
                                                       window=30,
                                                       step=5)
                else:
                    time_series_list = [time_series]
                # correlation form time series
                connectivity_matrix_list = correlation_measure.fit_transform(
                    time_series_list)
                for adj, time_series in zip(connectivity_matrix_list,
                                            time_series_list):
                    time_series, raw_adj = torch.tensor(
                        time_series), torch.tensor(adj)
                    adj_statistics = get_adj_statistics(adj)
                    padded_time_series = torch.zeros(30, num_nodes)
                    padded_time_series[:time_series.shape[0]] = time_series
                    padded_time_series = padded_time_series.t()

                    # transform adj
                    # np.fill_diagonal(adj, 0)  # remove self-loop for transform
                    adj = self.transform_edge(
                        adj) if self.transform_edge is not None else adj
                    # set a threshold for adj
                    if self.threshold is not None:
                        adj = top_k_percent_adj(adj, self.threshold)
                        assert check_strongly_connected(adj)
                    # create torch_geometric Data
                    edge_index, edge_weight = from_scipy_sparse_matrix(
                        coo_matrix(adj))

                    data = Data(x=node_features,
                                edge_index=edge_index,
                                edge_attr=edge_weight,
                                y=y)
                    # additional node feature
                    data.adj_statistics = adj_statistics
                    data.time_series, data.raw_adj = padded_time_series, raw_adj
                    # phenotypic data
                    data.sex, data.iq, data.site_id, data.subject_id = sex, iq, site_id, subject_id
                    data.num_nodes = data.x.shape[0]
                    data_list.append(data)
            except Exception as e:
                print(e)
                logging.warning("failed at subject {}".format(subject))
                failed_subject_list.append(subject)
        # with open('failed_fmri_subject_list', 'w') as f:
        #     f.write("\n".join(failed_subject_list))
        print("failed_subject_list", failed_subject_list)
        if self.site == 'ALL':
            data_list = repermute(data_list, all_subject_ids)
        self.data, self.slices = self.collate(data_list)
        torch.save((self.data, self.slices), self.processed_paths[0])