Example #1
def crawl_bucket(bucket, path, jobdir):
    """Gets subject list for a given s3 bucket and path
    Parameters
    ----------
    bucket : str
        s3 bucket
    path : str
        The directory where the dataset is stored on the S3 bucket
    jobdir : str
        Directory of batch jobs to generate/check up on
    Returns
    -------
    OrderedDict
        dictionary containing all subjects and sessions from the path location
    """
    from pynets.core.utils import flatten

    # if jobdir has seshs info file in it, use that instead
    sesh_path = f"{jobdir}/seshs.json"
    if os.path.isfile(sesh_path):
        print("seshs.json found -- loading bucket info from there")
        with open(sesh_path, "r") as f:
            seshs = json.load(f)
        print("Information obtained from s3.")
        return seshs

    # set up bucket crawl
    subj_pattern = r"(?<=sub-)(\w*)(?=/ses)"
    sesh_pattern = r"(?<=ses-)(\d*)"
    all_subfiles = get_matching_s3_objects(bucket, path + "/sub-")
    all_subfiles = list(all_subfiles)
    all_subs = [re.findall(subj_pattern, obj) for obj in all_subfiles]
    subjs = list(set(flatten(all_subs)))
    seshs = OrderedDict()

    # populate seshs
    for subj in subjs:
        prefix = f"{path}/sub-{subj}/"
        all_seshfiles = get_matching_s3_objects(bucket, prefix)
        all_seshfiles = list(all_seshfiles)
        all_seshs = [re.findall(sesh_pattern, obj) for obj in all_seshfiles]
        sesh = list(set(flatten(all_seshs)))

        if sesh != []:
            seshs[subj] = sesh
            print(f"{subj} added to sessions.")
        else:
            seshs[subj] = None
            print(f"{subj} not added (no sessions).")

    # print session IDs and create json cache
    print(("Session IDs: " + ", ".join([
        subj + "-" + sesh if sesh is not None else subj for subj in subjs
        for sesh in seshs[subj]
    ])))
    with open(sesh_path, "w") as f:
        json.dump(seshs, f)
    print(f"{sesh_path} created.")
    print("Information obtained from s3.")
    return seshs
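
A minimal, self-contained sketch of the regex step above. The S3 keys below are made up purely to illustrate how subj_pattern and sesh_pattern pull out subject and session IDs; the real function gets its keys from get_matching_s3_objects.

import re
from collections import OrderedDict

subj_pattern = r"(?<=sub-)(\w*)(?=/ses)"
sesh_pattern = r"(?<=ses-)(\d*)"

# Hypothetical S3 object keys, for illustration only
keys = [
    "data/sub-0025427/ses-1/func/sub-0025427_ses-1_bold.nii.gz",
    "data/sub-0025427/ses-2/func/sub-0025427_ses-2_bold.nii.gz",
    "data/sub-0025428/ses-1/func/sub-0025428_ses-1_bold.nii.gz",
]

subjs = sorted(set(s for k in keys for s in re.findall(subj_pattern, k)))
seshs = OrderedDict(
    (subj, sorted(set(x for k in keys if f"sub-{subj}/" in k
                      for x in re.findall(sesh_pattern, k))))
    for subj in subjs
)
print(seshs)  # OrderedDict([('0025427', ['1', '2']), ('0025428', ['1'])])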
Example #2
def _omni_embed(pop_array, atlas, graph_path, ID, subgraph_name='whole_brain'):
    from graspy.embed import OmnibusEmbed, ClassicalMDS
    variance_threshold = VarianceThreshold(threshold=0.00001)
    diags = np.array([np.triu(pop_array[i]) for i in range(len(pop_array))])
    graphs_ix_keep = variance_threshold.fit(
        diags.reshape(diags.shape[0], diags.shape[1] *
                      diags.shape[2]).T).get_support(indices=True)
    pop_array_red = [pop_array[i] for i in graphs_ix_keep]

    # Omnibus embedding -- random dot product graph (rdpg)
    print("%s%s%s%s%s" % ('Embedding ensemble for atlas: ', atlas, ' and ',
                          subgraph_name, '...'))
    omni = OmnibusEmbed(check_lcc=False)
    mds = ClassicalMDS()
    try:
        omni_fit = omni.fit_transform(pop_array_red)
    except Exception:
        # Fall back to the full population if the reduced set fails to embed
        omni_fit = omni.fit_transform(pop_array)

    # Transform omnibus tensor into dissimilarity feature
    mds_fit = mds.fit_transform(omni_fit)

    dir_path = str(Path(os.path.dirname(graph_path)).parent)

    namer_dir = dir_path + '/embeddings'
    if not os.path.isdir(namer_dir):
        os.makedirs(namer_dir, exist_ok=True)

    out_path = "%s%s%s%s%s%s%s%s" % (namer_dir, '/', list(
        flatten(ID))[0], '_omnetome_', atlas, '_', subgraph_name, '.npy')
    print('Saving...')
    np.save(out_path, mds_fit)
    del mds, mds_fit, omni, omni_fit
    return out_path
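
The variance-threshold filtering at the top of _omni_embed can be illustrated in isolation with scikit-learn and a few synthetic matrices (made up here; one is all zeros, so it carries no variance and is dropped):

import numpy as np
from sklearn.feature_selection import VarianceThreshold

rng = np.random.default_rng(0)
pop_array = [rng.random((5, 5)) for _ in range(3)]
pop_array.append(np.zeros((5, 5)))  # degenerate, zero-variance graph

diags = np.array([np.triu(g) for g in pop_array])
flat = diags.reshape(diags.shape[0], diags.shape[1] * diags.shape[2]).T
keep_ix = VarianceThreshold(threshold=0.00001).fit(flat).get_support(
    indices=True)
print(keep_ix)  # [0 1 2] -- the all-zero graph is dropped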
Example #3
def build_asetomes(est_path_iterlist, ID):
    """
    Embeds single graphs using the ASE algorithm.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.
    ID : str
        A subject id or other unique identifier.
    """
    from pynets.core.utils import prune_suffices
    from pynets.stats.embeddings import _ase_embed

    out_paths = []
    for file_ in list(flatten(est_path_iterlist)):
        mat = np.load(file_)
        atlas = prune_suffices(file_.split('/')[-3])
        res = prune_suffices('_'.join(
            file_.split('/')[-1].split('modality')[1].split('_')[1:]).split(
                '_est')[0])
        if 'rsn' in res:
            subgraph = res.split('rsn-')[1]
        else:
            subgraph = 'whole_brain'
        out_path = _ase_embed(mat, atlas, file_, ID, subgraph_name=subgraph)
        out_paths.append(out_path)

    return out_paths
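
A sketch of the path parsing above on a hypothetical file path (the directory layout and filename are invented for illustration; the real code additionally runs prune_suffices over these strings):

# Hypothetical graph file path, for illustration only
file_ = ("/output/sub-01/func/DesikanKlein2012/graphs/"
         "sub-01_modality-func_rsn-Default_est-corr_thrtype-PROP.npy")

atlas = file_.split('/')[-3]  # 'DesikanKlein2012'
res = '_'.join(
    file_.split('/')[-1].split('modality')[1].split('_')[1:]).split('_est')[0]
subgraph = res.split('rsn-')[1] if 'rsn' in res else 'whole_brain'
print(atlas, res, subgraph)  # DesikanKlein2012 rsn-Default Default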
Example #4
def _mase_embed(pop_array, atlas, graph_path, ID, subgraph_name='whole_brain'):
    from graspy.embed import MultipleASE
    variance_threshold = VarianceThreshold(threshold=0.00001)
    diags = np.array([np.triu(pop_array[i]) for i in range(len(pop_array))])
    graphs_ix_keep = variance_threshold.fit(
        diags.reshape(diags.shape[0], diags.shape[1] *
                      diags.shape[2]).T).get_support(indices=True)
    pop_array_red = [pop_array[i] for i in graphs_ix_keep]

    # Multiple adjacency spectral embedding (MASE)
    print(f"Embedding ensemble for atlas: {atlas} and {subgraph_name}...")
    mase = MultipleASE()
    try:
        mase_fit = mase.fit_transform(pop_array_red)
    except Exception:
        # Fall back to the full population if the reduced set fails to embed
        mase_fit = mase.fit_transform(pop_array)

    dir_path = str(Path(os.path.dirname(graph_path)).parent)
    namer_dir = dir_path + '/embeddings'
    if not os.path.isdir(namer_dir):
        os.makedirs(namer_dir, exist_ok=True)

    out_path = "%s%s%s%s%s%s%s%s" % (namer_dir, '/', list(
        flatten(ID))[0], '_masetome_', atlas, '_', subgraph_name, '.npy')
    print('Saving...')
    np.save(out_path, mase.scores_)
    del mase, mase_fit

    return out_path
Example #5
def build_asetomes(est_path_iterlist, ID):
    """
    Embeds single graphs using the ASE algorithm.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.
    ID : str
        A subject id or other unique identifier.

    """
    import numpy as np
    from pynets.core.utils import prune_suffices, flatten
    from pynets.stats.embeddings import _ase_embed
    import yaml
    import pkg_resources

    # Gradient dimensionality preset
    with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
              "r") as stream:
        hardcoded_params = yaml.safe_load(stream)
    try:
        n_components = hardcoded_params["gradients"]["n_components"][0]
    except KeyError:
        import sys
        print("ERROR: available gradient dimensionality presets not "
              "successfully extracted from runconfig.yaml")
        sys.exit(1)

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    out_paths = []
    for file_ in est_path_iterlist:
        mat = np.load(file_)
        atlas = prune_suffices(file_.split("/")[-3])
        res = prune_suffices("_".join(
            file_.split("/")[-1].split("modality")[1].split("_")[1:]).split(
                "_est")[0])
        if "rsn" in res:
            subgraph = res.split("rsn-")[1]
        else:
            subgraph = "all_nodes"
        out_path = _ase_embed(mat,
                              atlas,
                              file_,
                              ID,
                              subgraph_name=subgraph,
                              n_components=n_components)
        out_paths.append(out_path)

    return out_paths
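
For reference, the preset lookup reduces to a nested-key access on the parsed YAML. A self-contained sketch with an inline document in the same 'gradients: n_components' layout (the value 8 is made up):

import yaml

cfg_text = """
gradients:
  n_components: [8]
"""
hardcoded_params = yaml.safe_load(cfg_text)
n_components = hardcoded_params["gradients"]["n_components"][0]
print(n_components)  # 8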
Example #6
def flatten(l):
    """
    Flatten list of lists.
    """
    import collections.abc
    for el in l:
        if isinstance(el, collections.abc.Iterable) and \
                not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el
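
A quick usage example of flatten on an arbitrarily nested list; note that strings are yielded whole rather than split into characters:

nested = [1, [2, [3, 4]], "ab", [(5, 6), []]]
print(list(flatten(nested)))  # [1, 2, 3, 4, 'ab', 5, 6]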
Example #7
def test_flatten():
    """
    Test list flatten functionality
    """
    # Slow, but successfully flattens a large array
    l = np.random.rand(3, 3, 3).tolist()
    l = utils.flatten(l)

    i = 0
    for item in l:
        i += 1
    assert i == (3 * 3 * 3)
Example #8
def build_asetomes(est_path_iterlist, ID):
    """
    Embeds single graphs using the ASE algorithm.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.
    ID : str
        A subject id or other unique identifier.

    """
    import numpy as np
    from pynets.core.utils import prune_suffices, flatten
    from pynets.stats.embeddings import _ase_embed

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    out_paths = []
    for file_ in est_path_iterlist:
        mat = np.load(file_)
        atlas = prune_suffices(file_.split("/")[-3])
        res = prune_suffices("_".join(
            file_.split("/")[-1].split("modality")[1].split("_")[1:]).split(
                "_est")[0])
        if "rsn" in res:
            subgraph = res.split("rsn-")[1]
        else:
            subgraph = "whole_brain"
        out_path = _ase_embed(mat, atlas, file_, ID, subgraph_name=subgraph)
        out_paths.append(out_path)

    return out_paths
Example #9
def build_multigraphs(est_path_iterlist):
    """
    Constructs a multimodal multigraph for each available resolution of
    vertices.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy file containing graph.

    Returns
    -------
    multigraph_list_all : list
        List of multiplex graph dictionaries corresponding to
        each unique node resolution.
    graph_path_list_all : list
        List of lists consisting of pairs of most similar
        structural and functional connectomes for each unique node resolution.

    References
    ----------
    .. [1] Bullmore, E., & Sporns, O. (2009). Complex brain networks: Graph
      theoretical analysis of structural and functional systems.
      Nature Reviews Neuroscience. https://doi.org/10.1038/nrn2575
    .. [2] Vaiana, M., & Muldoon, S. F. (2018). Multilayer Brain Networks.
      Journal of Nonlinear Science. https://doi.org/10.1007/s00332-017-9436-8

    """
    import os
    import itertools
    import numpy as np
    from pathlib import Path
    from pynets.statistics.individual.multiplex import matching
    from pynets.core.utils import flatten, load_runconfig

    raw_est_path_iterlist = list(flatten(est_path_iterlist))

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        func_models = hardcoded_params["available_models"]["func_models"]
    except KeyError:
        print("ERROR: available functional models not successfully extracted"
              " from advanced.yaml")
    try:
        dwi_models = hardcoded_params["available_models"]["dwi_models"]
    except KeyError:
        print("ERROR: available structural models not successfully extracted"
              " from advanced.yaml")

    atlases = list(
        set([x.split("/")[-3].split("/")[0] for x in raw_est_path_iterlist]))
    parcel_dict_func = dict.fromkeys(atlases)
    parcel_dict_dwi = dict.fromkeys(atlases)
    est_path_iterlist_dwi = list(
        set([i for i in raw_est_path_iterlist if "dwi" in i]))
    est_path_iterlist_func = list(
        set([i for i in raw_est_path_iterlist if "func" in i]))

    if "_subnet" in ";".join(est_path_iterlist_func):
        func_subnets = list(
            set([
                i.split("_subnet-")[1].split("_")[0]
                for i in est_path_iterlist_func
            ]))
    else:
        func_subnets = []
    if "_subnet" in ";".join(est_path_iterlist_dwi):
        dwi_subnets = list(
            set([
                i.split("_subnet-")[1].split("_")[0]
                for i in est_path_iterlist_dwi
            ]))
    else:
        dwi_subnets = []

    dir_path = str(
        Path(os.path.dirname(est_path_iterlist_dwi[0])).parent.parent.parent)
    namer_dir = f"{dir_path}/dwi-func"
    if not os.path.isdir(namer_dir):
        os.mkdir(namer_dir)

    multigraph_list_all = []
    graph_path_list_all = []
    for atlas in atlases:
        if len(func_subnets) >= 1:
            parcel_dict_func[atlas] = {}
            for sub_net in func_subnets:
                parcel_dict_func[atlas][sub_net] = []
        else:
            parcel_dict_func[atlas] = []

        if len(dwi_subnets) >= 1:
            parcel_dict_dwi[atlas] = {}
            for sub_net in dwi_subnets:
                parcel_dict_dwi[atlas][sub_net] = []
        else:
            parcel_dict_dwi[atlas] = []

        for graph_path in est_path_iterlist_dwi:
            if atlas in graph_path:
                if len(dwi_subnets) >= 1:
                    for sub_net in dwi_subnets:
                        if sub_net in graph_path:
                            parcel_dict_dwi[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_dwi[atlas].append(graph_path)

        for graph_path in est_path_iterlist_func:
            if atlas in graph_path:
                if len(func_subnets) >= 1:
                    for sub_net in func_subnets:
                        if sub_net in graph_path:
                            parcel_dict_func[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_func[atlas].append(graph_path)

        parcel_dict = {}
        # Create dictionary of all possible pairs of structural-functional
        # graphs for each unique resolution of vertices
        if len(dwi_subnets) >= 1 and len(func_subnets) >= 1:
            parcel_dict[atlas] = {}
            subnets = np.intersect1d(dwi_subnets, func_subnets).tolist()
            for subnet in subnets:
                parcel_dict[atlas][subnet] = list(
                    set(
                        itertools.product(parcel_dict_dwi[atlas][subnet],
                                          parcel_dict_func[atlas][subnet])))
                for paths in list(parcel_dict[atlas][subnet]):
                    [mG_nx, mG, out_dwi_mat, out_func_mat] = matching(
                        paths,
                        atlas,
                        namer_dir,
                    )
        else:
            parcel_dict[atlas] = list(
                set(
                    itertools.product(parcel_dict_dwi[atlas],
                                      parcel_dict_func[atlas])))
            for paths in list(parcel_dict[atlas]):
                [mG_nx, mG, out_dwi_mat, out_func_mat] = matching(
                    paths,
                    atlas,
                    namer_dir,
                )
        multigraph_list_all.append((mG_nx, mG))
        graph_path_list_all.append((out_dwi_mat, out_func_mat))

    return (
        multigraph_list_all,
        graph_path_list_all,
        len(multigraph_list_all) * [namer_dir],
    )
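
The pairing step inside the atlas loop reduces to an itertools.product over the structural and functional graph lists for a given atlas. A toy sketch with hypothetical file names (each resulting pair would be handed to matching()):

import itertools

# Hypothetical graph file paths for one atlas, for illustration only
dwi_graphs = ["sub-01_dwi_atlas-A_model-csd.npy",
              "sub-01_dwi_atlas-A_model-csa.npy"]
func_graphs = ["sub-01_func_atlas-A_model-corr.npy"]

pairs = list(set(itertools.product(dwi_graphs, func_graphs)))
for dwi_path, func_path in pairs:
    print(dwi_path, "<->", func_path)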
Example #10
def build_asetomes(est_path_iterlist):
    """
    Embeds single graphs using the ASE algorithm.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.

    """
    from pathlib import Path
    import os
    import numpy as np
    from pynets.statistics.individual.spectral import _ase_embed
    from pynets.core.utils import prune_suffices, flatten, load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        n_components = hardcoded_params["gradients"][
            "n_components"][0]
    except KeyError:
        import sys
        print(
            "ERROR: available gradient dimensionality presets not "
            "successfully extracted from advanced.yaml"
        )
        sys.exit(1)

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    out_paths = []
    for file_ in est_path_iterlist:
        mat = np.load(file_)
        if not np.isfinite(mat).all():
            continue

        atlas = prune_suffices(file_.split("/")[-3])
        res = prune_suffices("_".join(file_.split(
            "/")[-1].split("modality")[1].split("_")[1:]).split("_est")[0])
        if "subnet" in res:
            subgraph = res.split("subnet-")[1].split('_')[0]
        else:
            subgraph = "all_nodes"

            out_path = _ase_embed(mat, atlas, file_, subgraph_name=subgraph,
                                  n_components=n_components, prune=0, norm=1)

        if out_path is not None:
            out_paths.append(out_path)
        else:
            # Add a null tmp file to prevent pool from breaking
            dir_path = str(Path(os.path.dirname(file_)).parent)
            namer_dir = f"{dir_path}/embeddings"
            if not os.path.isdir(namer_dir):
                os.makedirs(namer_dir, exist_ok=True)
            out_path = f"{namer_dir}/gradient-ASE" \
                       f"_subnet-{atlas}_granularity-{subgraph}_" \
                       f"{os.path.basename(file_)}_NULL"
            # TODO: Replace this band-aid solution with the real fix
            out_path = out_path.replace('subnet-subnet-',
                                        'subnet-').replace(
                'granularity-granularity-', 'granularity-')
            if not os.path.exists(out_path):
                open(out_path, 'w').close()
            out_paths.append(out_path)

    return out_paths
Example #11
    def parcellate(self):
        """
        API for performing any of a variety of clustering routines available through NiLearn.
        """
        import gc
        import time
        import os
        from nilearn.regions import Parcellations
        from pynets.fmri.estimation import fill_confound_nans

        start = time.time()

        if (self.clust_type == 'ward') and (self.local_corr != 'allcorr'):
            if self._local_conn_mat_path is not None:
                if not os.path.isfile(self._local_conn_mat_path):
                    raise FileNotFoundError('File containing sparse matrix of local connectivity structure not found.')
            else:
                raise FileNotFoundError('File containing sparse matrix of local connectivity structure not found.')

        if self.clust_type in ('complete', 'average', 'single', 'ward') or \
                (self.clust_type in ('rena', 'kmeans') and
                 self.num_conn_comps == 1):
            self._clust_est = Parcellations(method=self.clust_type, standardize=self._standardize,
                                            detrend=self._detrending,
                                            n_parcels=self.k, mask=self._clust_mask_corr_img,
                                            connectivity=self._local_conn, mask_strategy='background', memory_level=2,
                                            random_state=42)

            if self.conf is not None:
                import pandas as pd
                confounds = pd.read_csv(self.conf, sep='\t')
                if confounds.isnull().values.any():
                    conf_corr = fill_confound_nans(confounds, self._dir_path)
                    self._clust_est.fit(self._func_img, confounds=conf_corr)
                else:
                    self._clust_est.fit(self._func_img, confounds=self.conf)
            else:
                self._clust_est.fit(self._func_img)

            self._clust_est.labels_img_.set_data_dtype(np.uint16)
            nib.save(self._clust_est.labels_img_, self.uatlas)
        elif self.clust_type == 'ncut':
            out_img = parcellate_ncut(self._local_conn, self.k, self._clust_mask_corr_img)
            out_img.set_data_dtype(np.uint16)
            nib.save(out_img, self.uatlas)
        elif self.clust_type in ('rena', 'kmeans') and self.num_conn_comps > 1:
            from pynets.core import nodemaker
            from nilearn.regions import connected_regions, Parcellations
            from nilearn.image import iter_img, new_img_like
            from pynets.core.utils import flatten, proportional

            mask_img_list = []
            mask_voxels_dict = dict()
            for i, mask_img in enumerate(list(iter_img(self._conn_comps))):
                mask_voxels_dict[i] = int(np.sum(np.asarray(mask_img.dataobj)))
                mask_img_list.append(mask_img)

            # Allocate k across connected components using Hagenbach-Bischoff Quota based on number of voxels
            k_list = proportional(self.k, list(mask_voxels_dict.values()))

            conn_comp_atlases = []
            print("%s%s%s" % ('Building ', len(mask_img_list), ' separate atlases with voxel-proportional nclusters '
                                                               'for each connected component...'))
            for i, mask_img in enumerate(mask_img_list):
                if k_list[i] == 0:
                    # print('0 voxels in component. Discarding...')
                    continue
                self._clust_est = Parcellations(method=self.clust_type, standardize=self._standardize,
                                                detrend=self._detrending,
                                                n_parcels=k_list[i], mask=mask_img,
                                                mask_strategy='background',
                                                memory_level=2,
                                                random_state=42)
                if self.conf is not None:
                    import pandas as pd
                    confounds = pd.read_csv(self.conf, sep='\t')
                    if confounds.isnull().values.any():
                        conf_corr = fill_confound_nans(confounds, self._dir_path)
                        self._clust_est.fit(self._func_img, confounds=conf_corr)
                    else:
                        self._clust_est.fit(self._func_img, confounds=self.conf)
                else:
                    self._clust_est.fit(self._func_img)
                conn_comp_atlases.append(self._clust_est.labels_img_)

            # Then combine the multiple atlases, corresponding to each connected component, into a single atlas
            atlas_of_atlases = []
            for atlas in conn_comp_atlases:
                bna_data = np.around(np.asarray(atlas.dataobj)).astype('uint16')

                # Get an array of unique parcels
                bna_data_for_coords_uniq = np.unique(bna_data)

                # Number of parcels:
                par_max = len(bna_data_for_coords_uniq) - 1
                img_stack = []
                for idx in range(1, par_max + 1):
                    roi_img = bna_data == bna_data_for_coords_uniq[idx].astype('uint16')
                    img_stack.append(roi_img.astype('uint16'))
                img_stack = np.array(img_stack)

                img_list = []
                for idy in range(par_max):
                    img_list.append(new_img_like(atlas, img_stack[idy]))
                atlas_of_atlases.append(img_list)
                del img_list, img_stack, bna_data

            atlas_of_atlases = list(flatten(atlas_of_atlases))

            [super_atlas_ward, _] = nodemaker.create_parcel_atlas(atlas_of_atlases)
            super_atlas_ward.set_data_dtype(np.uint16)

            nib.save(super_atlas_ward, self.uatlas)
            del atlas_of_atlases, super_atlas_ward, conn_comp_atlases, mask_img_list, mask_voxels_dict

        print("%s%s%s" % (self.clust_type, self.k, " clusters: %.2fs" % (time.time() - start)))

        del self._clust_est
        self._func_img.uncache()
        self._clust_mask_corr_img.uncache()
        gc.collect()

        return self.uatlas
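
The proportional helper used above to split self.k across connected components is not shown here. As an assumption, the sketch below uses a simple largest-remainder allocation to convey the idea of distributing k clusters in proportion to per-component voxel counts; it is an illustrative stand-in, not the pynets implementation:

import numpy as np

def proportional_sketch(k, voxel_counts):
    """Largest-remainder allocation of k clusters across components."""
    counts = np.asarray(voxel_counts, dtype=float)
    quotas = k * counts / counts.sum()
    alloc = np.floor(quotas).astype(int)
    # Hand leftover clusters to the largest fractional remainders
    for i in np.argsort(quotas - alloc)[::-1][:k - alloc.sum()]:
        alloc[i] += 1
    return alloc.tolist()

print(proportional_sketch(100, [5000, 300, 120]))  # [92, 6, 2]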
Example #12
def build_multigraphs(est_path_iterlist, ID):
    """
    Constructs a multimodal multigraph for each available resolution of vertices.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy file containing graph.
    ID : str
        A subject id or other unique identifier.

    Returns
    -------
    multigraph_list_all : list
        List of multiplex graph dictionaries corresponding to
        each unique node resolution.
    graph_path_list_top : list
        List of lists consisting of pairs of most similar
        structural and functional connectomes for each unique node resolution.

    References
    ----------
    .. [1] Bullmore, E., & Sporns, O. (2009). Complex brain networks: Graph
      theoretical analysis of structural and functional systems.
      Nature Reviews Neuroscience. https://doi.org/10.1038/nrn2575
    .. [2] Vaiana, M., & Muldoon, S. F. (2018). Multilayer Brain Networks.
      Journal of Nonlinear Science. https://doi.org/10.1007/s00332-017-9436-8

    """
    import pkg_resources
    import yaml
    import os
    import itertools
    from pathlib import Path
    from pynets.core.utils import flatten
    from pynets.stats.netmotifs import motif_matching

    raw_est_path_iterlist = list(
        set([
            i.split('_thrtype')[0] + '_raw.npy'
            for i in list(flatten(est_path_iterlist))
        ]))

    # Available functional and structural connectivity models
    with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
              'r') as stream:
        hardcoded_params = yaml.safe_load(stream)
    try:
        func_models = hardcoded_params['available_models']['func_models']
    except KeyError:
        print('ERROR: available functional models not successfully '
              'extracted from runconfig.yaml')
    try:
        struct_models = hardcoded_params['available_models'][
            'struct_models']
    except KeyError:
        print('ERROR: available structural models not successfully '
              'extracted from runconfig.yaml')

    atlases = list(
        set([x.split('/')[-3].split('/')[0] for x in raw_est_path_iterlist]))
    parcel_dict_func = dict.fromkeys(atlases)
    parcel_dict_dwi = dict.fromkeys(atlases)
    est_path_iterlist_dwi = list(
        set([
            i for i in raw_est_path_iterlist
            if i.split('est-')[1].split('_')[0] in struct_models
        ]))
    est_path_iterlist_func = list(
        set([
            i for i in raw_est_path_iterlist
            if i.split('est-')[1].split('_')[0] in func_models
        ]))

    func_subnets = list(
        set([
            i.split('_est')[0].split('/')[-1] for i in est_path_iterlist_func
        ]))
    dwi_subnets = list(
        set([i.split('_est')[0].split('/')[-1]
             for i in est_path_iterlist_dwi]))

    dir_path = str(
        Path(os.path.dirname(est_path_iterlist_dwi[0])).parent.parent.parent)
    namer_dir = f"{dir_path}/graphs_multilayer"
    if not os.path.isdir(namer_dir):
        os.mkdir(namer_dir)

    name_list = []
    metadata_list = []
    multigraph_list_all = []
    graph_path_list_all = []
    for atlas in atlases:
        if len(func_subnets) >= 1:
            parcel_dict_func[atlas] = {}
            for sub_net in func_subnets:
                parcel_dict_func[atlas][sub_net] = []
        else:
            parcel_dict_func[atlas] = []

        if len(dwi_subnets) >= 1:
            parcel_dict_dwi[atlas] = {}
            for sub_net in dwi_subnets:
                parcel_dict_dwi[atlas][sub_net] = []
        else:
            parcel_dict_dwi[atlas] = []

        for graph_path in est_path_iterlist_dwi:
            if atlas in graph_path:
                if len(dwi_subnets) >= 1:
                    for sub_net in dwi_subnets:
                        if sub_net in graph_path:
                            parcel_dict_dwi[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_dwi[atlas].append(graph_path)

        for graph_path in est_path_iterlist_func:
            if atlas in graph_path:
                if len(func_subnets) >= 1:
                    for sub_net in func_subnets:
                        if sub_net in graph_path:
                            parcel_dict_func[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_func[atlas].append(graph_path)

        parcel_dict = {}
        # Create dictionary of all possible pairs of structural-functional graphs for each unique resolution
        # of vertices
        if len(dwi_subnets) >= 1 and len(func_subnets) >= 1:
            parcel_dict[atlas] = {}
            dwi_subnets.sort(key=lambda x: x.split('_rsn-')[1])
            func_subnets.sort(key=lambda x: x.split('_rsn-')[1])

            for sub_net_dwi, sub_net_func in list(
                    zip(dwi_subnets, func_subnets)):
                rsn = sub_net_dwi.split('_rsn-')[1]
                parcel_dict[atlas][rsn] = list(
                    set(
                        itertools.product(
                            parcel_dict_dwi[atlas][sub_net_dwi],
                            parcel_dict_func[atlas][sub_net_func])))
                for paths in list(parcel_dict[atlas][rsn]):
                    [
                        name_list, metadata_list, multigraph_list_all,
                        graph_path_list_all
                    ] = motif_matching(paths,
                                       ID,
                                       atlas,
                                       namer_dir,
                                       name_list,
                                       metadata_list,
                                       multigraph_list_all,
                                       graph_path_list_all,
                                       rsn=rsn)
        else:
            parcel_dict[atlas] = list(
                set(
                    itertools.product(parcel_dict_dwi[atlas],
                                      parcel_dict_func[atlas])))
            for paths in list(parcel_dict[atlas]):
                [
                    name_list, metadata_list, multigraph_list_all,
                    graph_path_list_all
                ] = motif_matching(paths, ID, atlas, namer_dir, name_list,
                                   metadata_list, multigraph_list_all,
                                   graph_path_list_all)

    graph_path_list_top = [list(i[0].values()) for i in graph_path_list_all]
    assert len(multigraph_list_all) == len(name_list) == len(metadata_list)

    return multigraph_list_all, graph_path_list_top, len(name_list) * [
        namer_dir
    ], name_list, metadata_list
Example #13
def main():
    """Initializes main script from command-line call to generate single-subject or multi-subject workflow(s)"""
    import os
    import gc
    import sys
    import json
    import ast
    import yaml
    import itertools
    from types import SimpleNamespace
    from pathlib import Path
    import pkg_resources
    from pynets.core.utils import flatten
    from pynets.cli.pynets_run import build_workflow
    from multiprocessing import set_start_method, Process, Manager
    try:
        import pynets
    except ImportError:
        print('PyNets not installed! Ensure that you are referencing the correct site-packages and using Python3.6+')

    if len(sys.argv) < 2:
        print("\nMissing command-line inputs! See help options with the -h flag.\n")
        sys.exit()

    print('Obtaining Derivatives Layout...')

    modalities = ['func', 'dwi']

    bids_args = get_bids_parser().parse_args()
    participant_label = bids_args.participant_label
    session_label = bids_args.session_label
    modality = bids_args.modality
    bids_config = bids_args.config
    analysis_level = bids_args.analysis_level
    clean = bids_args.clean

    if analysis_level == 'group' and participant_label is not None:
        raise ValueError('Error: You have indicated a group analysis level run, but specified a participant label!')

    if analysis_level == 'participant' and participant_label is None:
        raise ValueError('Error: You have indicated a participant analysis level run, but not specified a participant '
                         'label!')

    if bids_config:
        with open(bids_config, 'r') as stream:
            arg_dict = json.load(stream)
    else:
        with open(pkg_resources.resource_filename(
                "pynets", "config/bids_config_test.json"), 'r') as stream:
            arg_dict = json.load(stream)

    # Available functional and structural connectivity models
    with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
              'r') as stream:
        hardcoded_params = yaml.safe_load(stream)
    try:
        func_models = hardcoded_params['available_models']['func_models']
    except KeyError:
        print('ERROR: available functional models not successfully extracted '
              'from runconfig.yaml')
        sys.exit()
    try:
        struct_models = hardcoded_params['available_models']['struct_models']
    except KeyError:
        print('ERROR: available structural models not successfully extracted '
              'from runconfig.yaml')
        sys.exit()

    space = hardcoded_params['bids_defaults']['space'][0]
    func_desc = hardcoded_params['bids_defaults']['desc'][0]

    # S3
    # Primary inputs

    s3 = bids_args.bids_dir.startswith("s3://")

    if not s3:
        bids_dir = bids_args.bids_dir

    # secondary inputs
    sec_s3_objs = []
    if isinstance(bids_args.ua, list):
        for i in bids_args.ua:
            if i.startswith("s3://"):
                print('Downloading user atlas: ', i, ' from S3...')
                sec_s3_objs.append(i)
    if isinstance(bids_args.cm, list):
        for i in bids_args.cm:
            if i.startswith("s3://"):
                print('Downloading clustering mask: ', i, ' from S3...')
                sec_s3_objs.append(i)
    if isinstance(bids_args.roi, list):
        for i in bids_args.roi:
            if i.startswith("s3://"):
                print('Downloading ROI mask: ', i, ' from S3...')
                sec_s3_objs.append(i)
    if isinstance(bids_args.way, list):
        for i in bids_args.way:
            if i.startswith("s3://"):
                print('Downloading tractography waymask: ', i, ' from S3...')
                sec_s3_objs.append(i)

    if bids_args.ref:
        if bids_args.ref.startswith("s3://"):
            print('Downloading atlas labeling reference file: ', bids_args.ref, ' from S3...')
            sec_s3_objs.append(bids_args.ref)

    if s3 or len(sec_s3_objs) > 0:
        from boto3.session import Session
        from pynets.core import cloud_utils
        from pynets.core.utils import as_directory

        home = os.path.expanduser("~")
        creds = bool(cloud_utils.get_credentials())

        if s3:
            buck, remo = cloud_utils.parse_path(bids_args.bids_dir)
            os.makedirs(f"{home}/.pynets", exist_ok=True)
            os.makedirs(f"{home}/.pynets/input", exist_ok=True)
            os.makedirs(f"{home}/.pynets/output", exist_ok=True)
            bids_dir = as_directory(f"{home}/.pynets/input", remove=False)
            if (not creds) and bids_args.push_location:
                raise AttributeError(
                    "No AWS credentials found, but `--push_location` flag "
                    "called. Pushing will most likely fail.")
            else:
                output_dir = as_directory(f"{home}/.pynets/output", remove=False)

            # Get S3 input data if needed
            if analysis_level == 'participant':
                for partic, ses in list(itertools.product(participant_label, session_label)):
                    if ses is not None:
                        info = "sub-" + partic + '/ses-' + ses
                    elif ses is None:
                        info = "sub-" + partic
                    cloud_utils.s3_get_data(buck, remo, bids_dir, modality, info=info)
            elif analysis_level == 'group':
                if len(session_label) > 1 and session_label[0] != 'None':
                    for ses in session_label:
                        info = 'ses-' + ses
                        cloud_utils.s3_get_data(buck, remo, bids_dir, modality, info=info)
                else:
                    cloud_utils.s3_get_data(buck, remo, bids_dir, modality)

        if len(sec_s3_objs) > 0:
            [access_key, secret_key] = cloud_utils.get_credentials()

            session = Session(
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key
            )

            s3_r = session.resource('s3')
            s3_c = cloud_utils.s3_client(service="s3")
            sec_dir = as_directory(home + "/.pynets/secondary_files", remove=False)
            for s3_obj in [i for i in sec_s3_objs if i is not None]:
                buck, remo = cloud_utils.parse_path(s3_obj)
                s3_c.download_file(buck, remo, f"{sec_dir}/{os.path.basename(s3_obj)}")

            if isinstance(bids_args.ua, list):
                local_ua = bids_args.ua.copy()
                for i in local_ua:
                    if i.startswith("s3://"):
                        local_ua[local_ua.index(i)] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.ua = local_ua
            if isinstance(bids_args.cm, list):
                local_cm = bids_args.cm.copy()
                for i in bids_args.cm:
                    if i.startswith("s3://"):
                        local_cm[local_cm.index(i)] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.cm = local_cm
            if isinstance(bids_args.roi, list):
                local_roi = bids_args.roi.copy()
                for i in bids_args.roi:
                    if i.startswith("s3://"):
                        local_roi[local_roi.index(i)] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.roi = local_roi
            if isinstance(bids_args.way, list):
                local_way = bids_args.way.copy()
                for i in bids_args.way:
                    if i.startswith("s3://"):
                        local_way[local_way.index(i)] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.way = local_way

            if bids_args.ref:
                if bids_args.ref.startswith("s3://"):
                    bids_args.ref = f"{sec_dir}/{os.path.basename(bids_args.ref)}"
    else:
        output_dir = bids_args.output_dir
        if output_dir is None:
            raise ValueError('Must specify an output directory')

    intermodal_dict = {k: [] for k in ['funcs', 'confs', 'dwis', 'bvals', 'bvecs', 'anats', 'masks',
                                       'subjs', 'seshs']}
    if analysis_level == 'group':
        if len(modality) > 1:
            i = 0
            for mod in modality:
                outs = sweep_directory(bids_dir, modality=mod, space=space, func_desc=func_desc, sesh=session_label)
                if mod == 'func':
                    if i == 0:
                        funcs, confs, _, _, _, anats, masks, subjs, seshs = outs
                    else:
                        funcs, confs, _, _, _, _, _, _, _ = outs
                    intermodal_dict['funcs'].append(funcs)
                    intermodal_dict['confs'].append(confs)
                elif mod == 'dwi':
                    if i == 0:
                        _, _, dwis, bvals, bvecs, anats, masks, subjs, seshs = outs
                    else:
                        _, _, dwis, bvals, bvecs, _, _, _, _ = outs
                    intermodal_dict['dwis'].append(dwis)
                    intermodal_dict['bvals'].append(bvals)
                    intermodal_dict['bvecs'].append(bvecs)
                intermodal_dict['anats'].append(anats)
                intermodal_dict['masks'].append(masks)
                intermodal_dict['subjs'].append(subjs)
                intermodal_dict['seshs'].append(seshs)
                i += 1
        else:
            intermodal_dict = None
            outs = sweep_directory(bids_dir, modality=modality[0], space=space, func_desc=func_desc,
                                   sesh=session_label)
            funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = outs
    elif analysis_level == 'participant':
        if len(modality) > 1:
            i = 0
            for mod in modality:
                outs = sweep_directory(bids_dir, modality=mod, space=space, func_desc=func_desc,
                                       subj=participant_label, sesh=session_label)
                if mod == 'func':
                    if i == 0:
                        funcs, confs, _, _, _, anats, masks, subjs, seshs = outs
                    else:
                        funcs, confs, _, _, _, _, _, _, _ = outs
                    intermodal_dict['funcs'].append(funcs)
                    intermodal_dict['confs'].append(confs)
                elif mod == 'dwi':
                    if i == 0:
                        _, _, dwis, bvals, bvecs, anats, masks, subjs, seshs = outs
                    else:
                        _, _, dwis, bvals, bvecs, _, _, _, _ = outs
                    intermodal_dict['dwis'].append(dwis)
                    intermodal_dict['bvals'].append(bvals)
                    intermodal_dict['bvecs'].append(bvecs)
                intermodal_dict['anats'].append(anats)
                intermodal_dict['masks'].append(masks)
                intermodal_dict['subjs'].append(subjs)
                intermodal_dict['seshs'].append(seshs)
                i += 1
        else:
            intermodal_dict = None
            outs = sweep_directory(bids_dir, modality=modality[0], space=space, func_desc=func_desc,
                                   subj=participant_label, sesh=session_label)
            funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = outs
    else:
        raise ValueError('Analysis level invalid. Must be `participant` or `group`. See --help.')

    if intermodal_dict:
        funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = [list(set(list(flatten(i)))) for i in
                                                                        intermodal_dict.values()]

    arg_list = []
    for mod in modalities:
        arg_list.append(arg_dict[mod])

    arg_list.append(arg_dict['gen'])

    args_dict_all = {}
    models = []
    for d in arg_list:
        if 'mod' in d.keys():
            if len(modality) == 1:
                if any(x in d['mod'] for x in func_models):
                    if 'dwi' in modality:
                        del d['mod']
                elif any(x in d['mod'] for x in struct_models):
                    if 'func' in modality:
                        del d['mod']
            else:
                if any(x in d['mod'] for x in func_models) or any(x in d['mod'] for x in struct_models):
                    models.append(ast.literal_eval(d['mod']))
        args_dict_all.update(d)

    if len(modality) > 1:
        args_dict_all['mod'] = str(list(set(flatten(models))))

    print('Arguments parsed from bids_config.json:\n')
    print(args_dict_all)

    for key, val in args_dict_all.items():
        if isinstance(val, str):
            args_dict_all[key] = ast.literal_eval(val)

    id_list = []
    for i in sorted(list(set(subjs))):
        for ses in sorted(list(set(seshs))):
            id_list.append(i + '_' + ses)

    args_dict_all['work'] = bids_args.work
    args_dict_all['output_dir'] = output_dir
    args_dict_all['plug'] = bids_args.plug
    args_dict_all['pm'] = bids_args.pm
    args_dict_all['v'] = bids_args.v
    args_dict_all['clean'] = bids_args.clean
    if funcs is not None:
        args_dict_all['func'] = sorted(funcs)
    else:
        args_dict_all['func'] = None
    if confs is not None:
        args_dict_all['conf'] = sorted(confs)
    else:
        args_dict_all['conf'] = None
    if dwis is not None:
        args_dict_all['dwi'] = sorted(dwis)
        args_dict_all['bval'] = sorted(bvals)
        args_dict_all['bvec'] = sorted(bvecs)
    else:
        args_dict_all['dwi'] = None
        args_dict_all['bval'] = None
        args_dict_all['bvec'] = None
    if anats is not None:
        args_dict_all['anat'] = sorted(anats)
    else:
        args_dict_all['anat'] = None
    if masks is not None:
        args_dict_all['m'] = sorted(masks)
    else:
        args_dict_all['m'] = None
    args_dict_all['g'] = None
    if ('dwi' in modality) and (bids_args.way is not None):
        args_dict_all['way'] = bids_args.way
    else:
        args_dict_all['way'] = None
    args_dict_all['id'] = id_list
    args_dict_all['ua'] = bids_args.ua
    args_dict_all['ref'] = bids_args.ref
    args_dict_all['roi'] = bids_args.roi
    if ('func' in modality) and (bids_args.cm is not None):
        args_dict_all['cm'] = bids_args.cm
    else:
        args_dict_all['cm'] = None

    # Mimic argparse with SimpleNamespace object
    args = SimpleNamespace(**args_dict_all)
    print(args)

    set_start_method('forkserver')
    with Manager() as mgr:
        retval = mgr.dict()
        p = Process(target=build_workflow, args=(args, retval))
        p.start()
        p.join()
        if p.is_alive():
            p.terminate()

        retcode = p.exitcode or retval.get('return_code', 0)

        pynets_wf = retval.get('workflow', None)
        work_dir = retval.get('work_dir')
        plugin_settings = retval.get('plugin_settings', None)
        execution_dict = retval.get('execution_dict', None)
        run_uuid = retval.get('run_uuid', None)

        retcode = retcode or int(pynets_wf is None)
        if retcode != 0:
            sys.exit(retcode)

        # Clean up master process before running workflow, which may create forks
        gc.collect()

    mgr.shutdown()

    if bids_args.push_location:
        print(f"Pushing to s3 at {bids_args.push_location}.")
        push_buck, push_remo = cloud_utils.parse_path(bids_args.push_location)
        for id in id_list:
            cloud_utils.s3_push_data(
                push_buck,
                push_remo,
                output_dir,
                modality,
                subject=id.split('_')[0],
                session=id.split('_')[1],
                creds=creds,
            )

    sys.exit(0)
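
The workflow launch at the end of main() follows a standard multiprocessing pattern: run the builder in a child process and pass results back through a Manager dict. A stripped-down, self-contained sketch of that pattern (the worker below is a stand-in, not build_workflow):

from multiprocessing import Process, Manager

def _worker(args, retval):
    # Stand-in for build_workflow(args, retval)
    retval['return_code'] = 0
    retval['run_uuid'] = 'demo-uuid'

if __name__ == '__main__':
    with Manager() as mgr:
        retval = mgr.dict()
        p = Process(target=_worker, args=({'demo': True}, retval))
        p.start()
        p.join()
        retcode = p.exitcode or retval.get('return_code', 0)
        print(retcode, retval.get('run_uuid'))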
Example #14
def benchmark_reproducibility(base_dir, comb, modality, alg, par_dict, disc,
                              final_missingness_summary, icc_tmps_dir, icc,
                              mets, ids, template):
    import gc
    import json
    import glob
    from pathlib import Path
    import ast
    import matplotlib
    from pynets.stats.utils import gen_sub_vec
    matplotlib.use('Agg')

    df_summary = pd.DataFrame(
        columns=['grid', 'modality', 'embedding', 'discriminability'])
    print(comb)
    df_summary.at[0, "modality"] = modality
    df_summary.at[0, "embedding"] = alg

    if modality == 'func':
        try:
            extract, hpass, model, res, atlas, smooth = comb
        except BaseException:
            print(f"Missing {comb}...")
            extract, hpass, model, res, atlas = comb
            smooth = '0'
        # comb_tuple = (atlas, extract, hpass, model, res, smooth)
        comb_tuple = comb
    else:
        directget, minlength, model, res, atlas, tol = comb
        # comb_tuple = (atlas, directget, minlength, model, res, tol)
        comb_tuple = comb

    df_summary.at[0, "grid"] = comb_tuple

    # missing_sub_seshes = \
    #     final_missingness_summary.loc[(final_missingness_summary['alg']==alg)
    #                                   & (final_missingness_summary[
    #                                          'modality']==modality) &
    #                                   (final_missingness_summary[
    #                                        'grid']==comb_tuple)
    #                                   ].drop_duplicates(subset='id')

    # icc
    if icc is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin" " must be installed!")
        for met in mets:
            id_dict = {}
            dfs = []
            for ses in [str(i) for i in range(1, 11)]:
                for ID in ids:
                    id_dict[ID] = {}
                    if comb_tuple in par_dict[ID][str(
                            ses)][modality][alg].keys():
                        id_dict[ID][str(ses)] = \
                            par_dict[ID][str(ses)][modality][alg][comb_tuple][
                                mets.index(met)][0]
                    df = pd.DataFrame(id_dict).T
                    if df.empty:
                        del df
                        return df_summary
                    df.columns.values[0] = f"{met}"
                    df.replace(0, np.nan, inplace=True)
                    df['id'] = df.index
                    df['ses'] = ses
                    df.reset_index(drop=True, inplace=True)
                    dfs.append(df)
            df_long = pd.concat(dfs, names=[
                'id', 'ses', f"{met}"
            ]).drop(columns=[str(i) for i in range(1, 10)])
            if '10' in df_long.columns:
                df_long[f"{met}"] = df_long[f"{met}"].fillna(df_long['10'])
                df_long = df_long.drop(columns='10')
            try:
                c_icc = pg.intraclass_corr(data=df_long,
                                           targets='id',
                                           raters='ses',
                                           ratings=f"{met}",
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                c_icc3 = c_icc.drop(
                    index=['ICC1', 'ICC2', 'ICC1k', 'ICC2k', 'ICC3'])
                df_summary.at[0, f"icc_{met}"] = c_icc3['ICC'].values[0]
                df_summary.at[0, f"icc_{met}_CI95%_L"] = \
                    c_icc3['CI95%'].values[0][0]
                df_summary.at[0, f"icc_{met}_CI95%_U"] = \
                    c_icc3['CI95%'].values[0][1]
            except BaseException:
                print('FAILED...')
                print(df_long)
                del df_long
                return df_summary
            del df_long
    elif icc is True and alg != 'topology':
        import re
        from pynets.stats.utils import parse_closest_ixs
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin" " must be installed!")
        dfs = []
        coords_frames = []
        labels_frames = []
        for ses in [str(i) for i in range(1, 11)]:
            for ID in ids:
                if ses in par_dict[ID].keys():
                    if comb_tuple in par_dict[ID][str(
                            ses)][modality][alg].keys():
                        if 'data' in par_dict[ID][str(
                                ses)][modality][alg][comb_tuple].keys():
                            if par_dict[ID][str(ses)][modality][alg][
                                    comb_tuple]['data'] is not None:
                                if isinstance(
                                        par_dict[ID][str(ses)][modality][alg]
                                    [comb_tuple]['data'], str):
                                    data_path = par_dict[ID][str(ses)][
                                        modality][alg][comb_tuple]['data']
                                    parent_dir = Path(
                                        os.path.dirname(
                                            par_dict[ID][str(ses)][modality]
                                            [alg][comb_tuple]['data'])).parent
                                    if os.path.isfile(data_path):
                                        try:
                                            if data_path.endswith('.npy'):
                                                emb_data = np.load(data_path)
                                            elif data_path.endswith('.csv'):
                                                emb_data = np.array(
                                                    pd.read_csv(
                                                        data_path)).reshape(
                                                            -1, 1)
                                            else:
                                                emb_data = np.nan
                                            node_files = glob.glob(
                                                f"{parent_dir}/nodes/*.json")
                                        except Exception:
                                            print(f"Failed to load data from "
                                                  f"{data_path}..")
                                            continue
                                    else:
                                        continue
                                else:
                                    node_files = glob.glob(
                                        f"{base_dir}/pynets/sub-{ID}/ses-"
                                        f"{ses}/{modality}/rsn-"
                                        f"{atlas}_res-{res}/nodes/*.json")
                                    emb_data = par_dict[ID][str(ses)][
                                        modality][alg][comb_tuple]['data']

                                emb_shape = emb_data.shape[0]

                                if len(node_files) > 0:
                                    ixs, node_dict = parse_closest_ixs(
                                        node_files,
                                        emb_shape,
                                        template=template)
                                    if len(ixs) != emb_shape:
                                        ixs, node_dict = parse_closest_ixs(
                                            node_files, emb_shape)
                                    if isinstance(node_dict, dict):
                                        coords = [
                                            node_dict[i]['coord']
                                            for i in node_dict.keys()
                                        ]
                                        labels = [
                                            node_dict[i]['label'][
                                                'BrainnetomeAtlas'
                                                'Fan2016']
                                            for i in node_dict.keys()
                                        ]
                                    else:
                                        print(f"Failed to parse coords/"
                                              f"labels from {node_files}. "
                                              f"Skipping...")
                                        continue
                                    df_coords = pd.DataFrame(
                                        [str(tuple(x)) for x in coords]).T
                                    df_coords.columns = [
                                        f"rsn-{atlas}_res-"
                                        f"{res}_{i}" for i in ixs
                                    ]
                                    # labels = [
                                    #     list(i['label'])[7] for i
                                    #     in
                                    #     node_dict]
                                    df_labels = pd.DataFrame(labels).T
                                    df_labels.columns = [
                                        f"rsn-{atlas}_res-"
                                        f"{res}_{i}" for i in ixs
                                    ]
                                    coords_frames.append(df_coords)
                                    labels_frames.append(df_labels)
                                else:
                                    print(f"No node files detected for "
                                          f"{comb_tuple} and {ID}-{ses}...")
                                    ixs = [
                                        i for i in par_dict[ID][str(ses)]
                                        [modality][alg][comb_tuple]['index']
                                        if i is not None
                                    ]
                                    coords_frames.append(pd.Series())
                                    labels_frames.append(pd.Series())

                                if len(ixs) == emb_shape:
                                    df_pref = pd.DataFrame(emb_data.T,
                                                           columns=[
                                                               f"{alg}_{i}_rsn"
                                                               f"-{atlas}_res-"
                                                               f"{res}"
                                                               for i in ixs
                                                           ])
                                    df_pref['id'] = ID
                                    df_pref['ses'] = ses
                                    df_pref.replace(0, np.nan, inplace=True)
                                    df_pref.reset_index(drop=True,
                                                        inplace=True)
                                    dfs.append(df_pref)
                                else:
                                    print(
                                        f"Embedding shape {emb_shape} for "
                                        f"{comb_tuple} does not correspond to "
                                        f"{len(ixs)} indices found for "
                                        f"{ID}-{ses}. Skipping...")
                                    continue
                        else:
                            print(
                                f"data not found in {comb_tuple}. Skipping...")
                            continue
                else:
                    continue

        if len(dfs) == 0:
            return df_summary

        if len(coords_frames) > 0 and len(labels_frames) > 0:
            coords_frames_icc = pd.concat(coords_frames)
            labels_frames_icc = pd.concat(labels_frames)
            nodes = True
        else:
            nodes = False

        df_long = pd.concat(dfs, axis=0)
        df_long = df_long.dropna(axis='columns', thresh=0.75 * len(df_long))
        df_long = df_long.dropna(axis='rows', how='all')

        dict_sum = df_summary.drop(
            columns=['grid', 'modality', 'embedding', 'discriminability'
                     ]).to_dict()

        for lp in [
                i for i in df_long.columns if 'ses' not in i and 'id' not in i
        ]:
            ix = int(lp.split(f"{alg}_")[1].split('_')[0])
            rsn = lp.split(f"{alg}_{ix}_")[1]
            df_long_clean = df_long[['id', 'ses', lp]]
            # df_long_clean = df_long[['id', 'ses', lp]].loc[(df_long[['id',
            # 'ses', lp]]['id'].duplicated() == True) & (df_long[['id', 'ses',
            # lp]]['ses'].duplicated() == True) & (df_long[['id', 'ses',
            # lp]][lp].isnull()==False)]
            # df_long_clean[lp] = np.abs(df_long_clean[lp].round(6))
            # df_long_clean['ses'] = df_long_clean['ses'].astype('int')
            # g = df_long_clean.groupby(['ses'])
            # df_long_clean = pd.DataFrame(g.apply(
            #     lambda x: x.sample(g.size().min()).reset_index(drop=True))
            #     ).reset_index(drop=True)
            try:
                c_icc = pg.intraclass_corr(data=df_long_clean,
                                           targets='id',
                                           raters='ses',
                                           ratings=lp,
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                c_icc3 = c_icc.drop(
                    index=['ICC1', 'ICC2', 'ICC1k', 'ICC2k', 'ICC3'])
                icc_val = c_icc3['ICC'].values[0]
                if nodes is True:
                    coord_in = np.array(ast.literal_eval(
                        coords_frames_icc[f"{rsn}_{ix}"].mode().values[0]),
                                        dtype=np.dtype("O"))
                    label_in = np.array(
                        labels_frames_icc[f"{rsn}_{ix}"].mode().values[0],
                        dtype=np.dtype("O"))
                else:
                    coord_in = np.nan
                    label_in = np.nan
                dict_sum[f"{lp}_icc"] = icc_val
                del c_icc, c_icc3, icc_val
            except BaseException:
                print(f"FAILED for {lp}...")
                # print(df_long)
                #df_summary.at[0, f"{lp}_icc"] = np.nan
                coord_in = np.nan
                label_in = np.nan

            dict_sum[f"{lp}_coord"] = coord_in
            dict_sum[f"{lp}_label"] = label_in

        df_summary = pd.concat(
            [df_summary, pd.DataFrame(pd.Series(dict_sum).T).T], axis=1)

        print(df_summary)

        tup_name = str(comb_tuple).replace('\', \'',
                                           '_').replace('(', '').replace(
                                               ')', '').replace('\'', '')
        df_summary.to_csv(f"{icc_tmps_dir}/{alg}_{tup_name}.csv",
                          index=False,
                          header=True)
        del df_long

    # discriminability
    if disc is True:
        vect_all = []
        for ID in ids:
            try:
                out = gen_sub_vec(base_dir, par_dict, ID, modality, alg,
                                  comb_tuple)
            except BaseException:
                print(f"{ID} {modality} {alg} {comb_tuple} failed...")
                continue
            # print(out)
            vect_all.append(out)
        # ## TODO: Remove the .iloc below to include global efficiency.
        # vect_all = [pd.DataFrame(i).iloc[1:] for i in vect_all if i is not
        #             None and not np.isnan(np.array(i)).all()]
        vect_all = [
            pd.DataFrame(i) for i in vect_all
            if i is not None and not np.isnan(np.array(i)).all()
        ]

        if len(vect_all) > 0:
            X_top = pd.concat(vect_all, axis=0, join="outer")
            X_top = np.array(
                X_top.dropna(axis='columns', thresh=0.50 * len(X_top)))

            shapes = []
            for ix, i in enumerate(vect_all):
                shapes.append(i.shape[0] * [list(ids)[ix]])
            Y = np.array(list(flatten(shapes)))
            if alg == 'topology':
                imp = IterativeImputer(max_iter=50, random_state=42)
            else:
                imp = SimpleImputer()
            X_top = imp.fit_transform(X_top)
            scaler = StandardScaler()
            X_top = scaler.fit_transform(X_top)
            try:
                discr_stat_val, rdf = discr_stat(X_top, Y)
                df_summary.at[0, "discriminability"] = discr_stat_val
                print(discr_stat_val)
                print("\n")
                del discr_stat_val
            except BaseException:
                print('Discriminability calculation failed...')
                return df_summary
            # print(rdf)
        del vect_all

    gc.collect()
    return df_summary
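
A minimal sketch of the per-feature ICC step above, assuming pingouin is installed. The toy id/ses long-format frame is hypothetical and stands in for df_long_clean; only the ICC3k estimate is kept, as in the summary code:

import pandas as pd
import pingouin as pg

# Hypothetical long-format data: one embedding feature, two sessions per subject
df_long_clean = pd.DataFrame({
    'id': ['s1', 's1', 's2', 's2', 's3', 's3'],
    'ses': [1, 2, 1, 2, 1, 2],
    'feat': [0.41, 0.44, 0.29, 0.31, 0.55, 0.50],
})

c_icc = pg.intraclass_corr(data=df_long_clean, targets='id', raters='ses',
                           ratings='feat', nan_policy='omit').set_index("Type")

# Retain only the two-way mixed, average-raters estimate (ICC3k)
icc_val = c_icc.loc['ICC3k', 'ICC']
print(round(icc_val, 3))
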
Exemplo n.º 15
0
def build_omnetome(est_path_iterlist, ID):
    """
    Embeds an ensemble population of graphs into a lower-dimensional feature
    vector via omnibus embedding.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy file containing graph.
    ID : str
        A subject id or other unique identifier.

    Returns
    -------
    out_paths_dwi : list
        List of file paths to .npy files of structural (dwi) omnibus
        embeddings, one per atlas (and subnetwork, if applicable).
    out_paths_func : list
        List of file paths to .npy files of functional omnibus embeddings,
        one per atlas (and subnetwork, if applicable).

    References
    ----------
    .. [1] Liu, Y., He, L., Cao, B., Yu, P. S., Ragin, A. B., & Leow, A. D.
      (2018). Multi-view multi-graph embedding for brain network clustering
      analysis. 32nd AAAI Conference on Artificial Intelligence, AAAI 2018.
    .. [2] Levin, K., Athreya, A., Tang, M., Lyzinski, V., & Priebe, C. E.
      (2017, November). A central limit theorem for an omnibus embedding of
      multiple random dot product graphs. In Data Mining Workshops (ICDMW),
      2017 IEEE International Conference on (pp. 964-967). IEEE.

    """
    from pathlib import Path
    import sys
    import numpy as np
    from pynets.core.utils import flatten
    from pynets.stats.embeddings import _omni_embed
    from pynets.core.utils import load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()

    try:
        func_models = hardcoded_params["available_models"]["func_models"]
    except KeyError:
        print(
            "ERROR: available functional models not sucessfully extracted"
            " from runconfig.yaml"
        )
        sys.exit(1)
    try:
        struct_models = hardcoded_params["available_models"][
            "struct_models"]
    except KeyError:
        print(
            "ERROR: available structural models not sucessfully extracted"
            " from runconfig.yaml"
        )
        sys.exit(1)
    try:
        n_components = hardcoded_params["gradients"][
            "n_components"][0]
    except KeyError:
        print(
            "ERROR: available gradient dimensionality presets not "
            "sucessfully extracted from runconfig.yaml"
        )
        sys.exit(1)

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    if len(est_path_iterlist) > 1:
        atlases = list(set([x.split("/")[-3].split("/")[0]
                            for x in est_path_iterlist]))
        parcel_dict_func = dict.fromkeys(atlases)
        parcel_dict_dwi = dict.fromkeys(atlases)

        est_path_iterlist_dwi = list(
            set(
                [
                    i
                    for i in est_path_iterlist
                    if i.split("model-")[1].split("_")[0] in struct_models
                ]
            )
        )
        est_path_iterlist_func = list(
            set(
                [
                    i
                    for i in est_path_iterlist
                    if i.split("model-")[1].split("_")[0] in func_models
                ]
            )
        )

        if "_rsn" in ";".join(est_path_iterlist_func):
            func_subnets = list(
                set([i.split("_rsn-")[1].split("_")[0] for i in
                     est_path_iterlist_func])
            )
        else:
            func_subnets = []
        if "_rsn" in ";".join(est_path_iterlist_dwi):
            dwi_subnets = list(
                set([i.split("_rsn-")[1].split("_")[0] for i in
                     est_path_iterlist_dwi])
            )
        else:
            dwi_subnets = []

        out_paths_func = []
        out_paths_dwi = []
        for atlas in atlases:
            if len(func_subnets) >= 1:
                parcel_dict_func[atlas] = {}
                for sub_net in func_subnets:
                    parcel_dict_func[atlas][sub_net] = []
            else:
                parcel_dict_func[atlas] = []

            if len(dwi_subnets) >= 1:
                parcel_dict_dwi[atlas] = {}
                for sub_net in dwi_subnets:
                    parcel_dict_dwi[atlas][sub_net] = []
            else:
                parcel_dict_dwi[atlas] = []

            for graph_path in est_path_iterlist_dwi:
                if atlas in graph_path:
                    if len(dwi_subnets) >= 1:
                        for sub_net in dwi_subnets:
                            if sub_net in graph_path:
                                parcel_dict_dwi[atlas][sub_net].append(
                                    graph_path)
                    else:
                        parcel_dict_dwi[atlas].append(graph_path)

            for graph_path in est_path_iterlist_func:
                if atlas in graph_path:
                    if len(func_subnets) >= 1:
                        for sub_net in func_subnets:
                            if sub_net in graph_path:
                                parcel_dict_func[atlas][sub_net].append(
                                    graph_path)
                    else:
                        parcel_dict_func[atlas].append(graph_path)
            if len(parcel_dict_func[atlas]) > 0:
                if isinstance(parcel_dict_func[atlas], dict):
                    # RSN case
                    for rsn in parcel_dict_func[atlas]:
                        pop_rsn_list = []
                        graph_path_list = []
                        for graph in parcel_dict_func[atlas][rsn]:
                            pop_rsn_list.append(np.load(graph))
                            graph_path_list.append(graph)
                        if len(pop_rsn_list) > 1:
                            if len(
                                    list(set([i.shape for i in
                                              pop_rsn_list]))) > 1:
                                raise RuntimeWarning(
                                    "Inconsistent number of"
                                    " vertices in graph population "
                                    "that precludes embedding...")
                            out_path = _omni_embed(
                                pop_rsn_list, atlas, graph_path_list, ID, rsn,
                                n_components
                            )
                            out_paths_func.append(out_path)
                        else:
                            print(
                                "WARNING: Only one graph sampled, omnibus"
                                " embedding not appropriate."
                            )
                            pass
                else:
                    pop_list = []
                    graph_path_list = []
                    for pop_ref in parcel_dict_func[atlas]:
                        pop_list.append(np.load(pop_ref))
                        graph_path_list.append(pop_ref)
                    if len(pop_list) > 1:
                        if len(list(set([i.shape for i in pop_list]))) > 1:
                            raise RuntimeWarning(
                                "Inconsistent number of vertices in "
                                "graph population that precludes embedding")
                        out_path = _omni_embed(pop_list, atlas,
                                               graph_path_list, ID,
                                               n_components=n_components)
                        out_paths_func.append(out_path)
                    else:
                        print(
                            "WARNING: Only one graph sampled, omnibus "
                            "embedding not appropriate."
                        )
                        pass

            if len(parcel_dict_dwi[atlas]) > 0:
                if isinstance(parcel_dict_dwi[atlas], dict):
                    # RSN case
                    graph_path_list = []
                    for rsn in parcel_dict_dwi[atlas]:
                        pop_rsn_list = []
                        for graph in parcel_dict_dwi[atlas][rsn]:
                            pop_rsn_list.append(np.load(graph))
                            graph_path_list.append(graph)
                        if len(pop_rsn_list) > 1:
                            if len(
                                    list(set([i.shape for i in
                                              pop_rsn_list]))) > 1:
                                raise RuntimeWarning(
                                    "Inconsistent number of"
                                    " vertices in graph population "
                                    "that precludes embedding")
                            out_path = _omni_embed(
                                pop_rsn_list, atlas, graph_path_list, ID,
                                rsn, n_components
                            )
                            out_paths_dwi.append(out_path)
                        else:
                            print(
                                "WARNING: Only one graph sampled, omnibus"
                                " embedding not appropriate."
                            )
                            pass
                else:
                    pop_list = []
                    graph_path_list = []
                    for pop_ref in parcel_dict_dwi[atlas]:
                        pop_list.append(np.load(pop_ref))
                        graph_path_list.append(pop_ref)
                    if len(pop_list) > 1:
                        if len(list(set([i.shape for i in pop_list]))) > 1:
                            raise RuntimeWarning(
                                "Inconsistent number of vertices in graph"
                                " population that precludes embedding")
                        out_path = _omni_embed(pop_list, atlas,
                                               graph_path_list,
                                               ID, n_components=n_components)
                        out_paths_dwi.append(out_path)
                    else:
                        print(
                            "WARNING: Only one graph sampled, omnibus "
                            "embedding not appropriate."
                        )
                        pass
    else:
        print("At least two graphs required to build an omnetome...")
        out_paths_func = []
        out_paths_dwi = []
        pass

    return out_paths_dwi, out_paths_func
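
A hypothetical invocation of build_omnetome, assuming a working pynets installation and two functional graph .npy files that exist on disk and follow pynets-style naming (the atlas is parsed from the third-to-last path component, and the token after 'model-' must appear under func_models in runconfig.yaml). The paths and subject ID below are illustrative only:

est_paths = [
    "/output/sub-0025427/func/coords_dosenbach_2010/graphs/"
    "graph_sub-0025427_model-partcorr_thr-0.30.npy",
    "/output/sub-0025427/func/coords_dosenbach_2010/graphs/"
    "graph_sub-0025427_model-corr_thr-0.30.npy",
]
out_paths_dwi, out_paths_func = build_omnetome(est_paths, ID="0025427")
print(out_paths_func)  # expected: one omnetome .npy path per atlas
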
Exemplo n.º 16
0
def psycho_naming(coords, node_size):
    """
    Perform automated sentiment labeling of each coordinate in a list of MNI coordinates.

    Parameters
    ----------
    coords : list
        List of (x, y, z) coordinate tuples in MNI space, corresponding either to a
        coordinate atlas or to the center-of-mass of each parcellation node.
    node_size : int
        Spherical centroid node size in the case that coordinate-based centroids
        are used as ROI's for tracking.

    Returns
    -------
    labels : list
        List of string labels corresponding to each coordinate-corresponding psychological topic.

    References
    ----------
    .. [1] Tor D., W. (2011). NeuroSynth: a new platform for large-scale automated synthesis of
      human functional neuroimaging data. Frontiers in Neuroinformatics.
      https://doi.org/10.3389/conf.fninf.2011.08.00058
    .. [2] Tausczik, Y. R., & Pennebaker, J. W. (2010). The psychological meaning of words:
      LIWC and computerized text analysis methods. Journal of Language and Social Psychology.
      https://doi.org/10.1177/0261927X09351676

    """
    import liwc
    import pkg_resources
    import nimare
    import nltk
    from collections import Counter
    from nltk.corpus import sentiwordnet as swn
    from pynets.core.utils import flatten
    from nltk.stem import WordNetLemmatizer

    try:
        swn.senti_synsets('TEST')
    except LookupError:
        nltk.download('sentiwordnet')
        nltk.download('wordnet')

    with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
              'r') as stream:
        hardcoded_params = yaml.load(stream, Loader=yaml.FullLoader)
        try:
            LIWC_file = hardcoded_params['sentiment_labeling']['liwc_file'][0]
        except KeyError:
            print('LIWC file not found. Check runconfig.yaml.')
        try:
            neurosynth_dset_file = hardcoded_params['sentiment_labeling'][
                'neurosynth_db'][0]
        except KeyError:
            print(
                'Neurosynth dataset .pkl file not found. Check runconfig.yaml.'
            )
    stream.close()

    try:
        dset = nimare.dataset.Dataset.load(neurosynth_dset_file)
    except FileNotFoundError:
        print('Loading neurosynth dictionary failed!')

    try:
        parse, category_names = liwc.load_token_parser(LIWC_file)
    except FileNotFoundError:
        print('Loading LIWC dictionary failed!')

    labels = []
    print('Building coordinate labels...')
    for coord in coords:
        print(coord)
        roi_ids = dset.get_studies_by_coordinate(
            np.array(coord).reshape(1, -1), node_size)
        labs = dset.get_labels(ids=roi_ids)
        labs_filt = list(
            flatten([
                list([
                    i for j in swn.senti_synsets(i)
                    if j.pos_score() > 0.75 or j.neg_score() > 0.75
                ]) for i in labs
            ]))
        st = WordNetLemmatizer()
        labs_filt = list(set([st.lemmatize(k) for k in labs_filt]))
        # LIWC categories to ignore when tallying topic labels
        excluded_cats = {
            'bio', 'adj', 'verb', 'conj', 'adverb', 'auxverb', 'prep',
            'article', 'ipron', 'ppron', 'pronoun', 'function', 'affect',
            'cogproc'
        }
        liwc_counts = dict(
            Counter(
                top.split(' (')[0] for token in labs_filt
                for top in parse(token)
                if top.split(' (')[0].lower() not in excluded_cats))
        liwc_counts_ordered = dict(
            sorted(liwc_counts.items(), key=lambda x: x[1], reverse=True))

        if 'posemo' in liwc_counts_ordered and 'negemo' in liwc_counts_ordered:
            if liwc_counts_ordered['posemo'] > liwc_counts_ordered['negemo']:
                del liwc_counts_ordered['negemo']
            else:
                del liwc_counts_ordered['posemo']
        liwc_counts_ordered_ratios = {}
        for i in liwc_counts_ordered:
            liwc_counts_ordered_ratios[i] = float(
                liwc_counts_ordered[i]) / float(
                    sum(liwc_counts_ordered.values()))

        lab = ' '.join(
            map(str, [
                key + ' ' + str(np.round(100 * val, 2)) + '%'
                for key, val in liwc_counts_ordered_ratios.items()
            ]))
        print(lab)
        if len(lab) > 0:
            labels.append(lab)
        else:
            labels.append(np.nan)
        del roi_ids, labs_filt, lab, liwc_counts_ordered, liwc_counts, labs
        print('\n')

    return labels
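
A hypothetical call to psycho_naming, assuming the LIWC dictionary and NeuroSynth dataset paths are configured in runconfig.yaml and the required NLTK corpora are available; the MNI centroids below are illustrative:

coords = [(0, -52, 26), (-2, 50, -6)]  # illustrative MNI centroids
labels = psycho_naming(coords, node_size=6)
for coord, label in zip(coords, labels):
    print(coord, '->', label)
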
Exemplo n.º 17
0
def parcellate(func_boot_img, local_corr, clust_type, _local_conn_mat_path,
               num_conn_comps, _clust_mask_corr_img, _standardize,
               _detrending, k, _local_conn, conf, _dir_path, _conn_comps):
    """
    API for performing any of a variety of clustering routines available
    through Nilearn.
    """
    import time
    import os
    import numpy as np
    from nilearn.regions import Parcellations
    from pynets.fmri.estimation import fill_confound_nans
    # from joblib import Memory
    import tempfile

    cache_dir = tempfile.mkdtemp()
    # memory = Memory(cache_dir, verbose=0)

    start = time.time()

    if (clust_type == "ward") and (local_corr != "allcorr"):
        if _local_conn_mat_path is not None:
            if not os.path.isfile(_local_conn_mat_path):
                raise FileNotFoundError(
                    "File containing sparse matrix of local connectivity"
                    " structure not found."
                )
        else:
            raise FileNotFoundError(
                "File containing sparse matrix of local connectivity"
                " structure not found."
            )

    if (
        clust_type == "complete"
        or clust_type == "average"
        or clust_type == "single"
        or clust_type == "ward"
        or (clust_type == "rena" and num_conn_comps == 1)
        or (clust_type == "kmeans" and num_conn_comps == 1)
    ):
        _clust_est = Parcellations(
            method=clust_type,
            standardize=_standardize,
            detrend=_detrending,
            n_parcels=k,
            mask=_clust_mask_corr_img,
            connectivity=_local_conn,
            mask_strategy="background",
            random_state=42
        )

        if conf is not None:
            import pandas as pd
            import random
            from nipype.utils.filemanip import fname_presuffix, copyfile

            out_name_conf = fname_presuffix(
                conf, suffix=f"_tmp{random.randint(1, 1000)}",
                newpath=cache_dir
            )
            copyfile(
                conf,
                out_name_conf,
                copy=True,
                use_hardlink=False)

            confounds = pd.read_csv(out_name_conf, sep="\t")
            if confounds.isnull().values.any():
                conf_corr = fill_confound_nans(confounds, _dir_path)
                try:
                    _clust_est.fit(func_boot_img, confounds=conf_corr)
                except UserWarning:
                    return None
                os.remove(conf_corr)
            else:
                try:
                    _clust_est.fit(func_boot_img, confounds=out_name_conf)
                except UserWarning:
                    return None
            os.remove(out_name_conf)
        else:
            try:
                _clust_est.fit(func_boot_img)
            except UserWarning:
                return None
        _clust_est.labels_img_.set_data_dtype(np.uint16)
        print(
            f"{clust_type}{k}"
            f"{(' clusters: %.2fs' % (time.time() - start))}"
        )

        return _clust_est.labels_img_

    elif clust_type == "ncut":
        out_img = parcellate_ncut(
            _local_conn, k, _clust_mask_corr_img
        )
        out_img.set_data_dtype(np.uint16)
        print(
            f"{clust_type}{k}"
            f"{(' clusters: %.2fs' % (time.time() - start))}"
        )
        return out_img

    elif (
        clust_type in ("rena", "kmeans")
        and num_conn_comps > 1
    ):
        from pynets.core import nodemaker
        from nilearn.regions import connected_regions, Parcellations
        from nilearn.image import iter_img, new_img_like
        from pynets.core.utils import flatten, proportional

        mask_img_list = []
        mask_voxels_dict = dict()
        for i, mask_img in enumerate(iter_img(_conn_comps)):
            mask_voxels_dict[i] = int(
                np.sum(np.asarray(mask_img.dataobj)))
            mask_img_list.append(mask_img)

        # Allocate k across connected components using Hagenbach-Bischoff
        # Quota based on number of voxels
        k_list = proportional(k, list(mask_voxels_dict.values()))

        conn_comp_atlases = []
        print(
            f"Building {len(mask_img_list)} separate atlases with "
            f"voxel-proportional k clusters for each "
            f"connected component...")
        for i, mask_img in enumerate(iter_img(mask_img_list)):
            if k_list[i] < 5:
                print(f"Only {k_list[i]} voxels in component. Discarding...")
                continue
            _clust_est = Parcellations(
                method=clust_type,
                standardize=_standardize,
                detrend=_detrending,
                n_parcels=k_list[i],
                mask=mask_img,
                mask_strategy="background",
                random_state=i
            )
            if conf is not None:
                import pandas as pd
                import random
                from nipype.utils.filemanip import fname_presuffix, copyfile

                out_name_conf = fname_presuffix(
                    conf, suffix=f"_tmp{random.randint(1, 1000)}",
                    newpath=cache_dir
                )
                copyfile(
                    conf,
                    out_name_conf,
                    copy=True,
                    use_hardlink=False)

                confounds = pd.read_csv(out_name_conf, sep="\t")
                if confounds.isnull().values.any():
                    conf_corr = fill_confound_nans(
                        confounds, _dir_path)
                    try:
                        _clust_est.fit(func_boot_img, confounds=conf_corr)
                    except UserWarning:
                        continue
                else:
                    try:
                        _clust_est.fit(func_boot_img, confounds=conf)
                    except UserWarning:
                        continue
            else:
                try:
                    _clust_est.fit(func_boot_img)
                except UserWarning:
                    continue
            conn_comp_atlases.append(_clust_est.labels_img_)

        # Then combine the multiple atlases, corresponding to each
        # connected component, into a single atlas
        atlas_of_atlases = []
        for atlas in iter_img(conn_comp_atlases):
            bna_data = np.around(
                np.asarray(
                    atlas.dataobj)).astype("uint16")

            # Get an array of unique parcels
            bna_data_for_coords_uniq = np.unique(bna_data)

            # Number of parcels:
            par_max = len(bna_data_for_coords_uniq) - 1
            img_stack = []
            for idx in range(1, par_max + 1):
                roi_img = bna_data == bna_data_for_coords_uniq[idx].astype(
                    "uint16")
                img_stack.append(roi_img.astype("uint16"))
            img_stack = np.array(img_stack)

            img_list = []
            for idy in range(par_max):
                img_list.append(new_img_like(atlas, img_stack[idy]))
            atlas_of_atlases.append(img_list)
            del img_list, img_stack, bna_data

        atlas_of_atlases = list(flatten(atlas_of_atlases))

        [super_atlas_ward, _] = nodemaker.create_parcel_atlas(
            atlas_of_atlases)
        super_atlas_ward.set_data_dtype(np.uint16)
        del atlas_of_atlases, conn_comp_atlases, mask_img_list, \
            mask_voxels_dict

        print(
            f"{clust_type}{k}"
            f"{(' clusters: %.2fs' % (time.time() - start))}"
        )

        # memory.clear(warn=False)

        return super_atlas_ward
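
The single-connected-component branch above reduces to a standard nilearn Parcellations fit. A minimal sketch, assuming a preprocessed 4D BOLD image and a binary brain mask are already on disk (the file names are hypothetical):

import nibabel as nib
from nilearn.regions import Parcellations

func_img = nib.load("func_preproc.nii.gz")  # hypothetical 4D BOLD image
mask_img = nib.load("brain_mask.nii.gz")    # hypothetical binary mask

ward = Parcellations(method="ward", n_parcels=100, standardize=False,
                     detrend=True, mask=mask_img,
                     mask_strategy="background", random_state=42)
ward.fit(func_img)
labels_img = ward.labels_img_  # Nifti1Image of cluster labels
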
Exemplo n.º 18
0
    def parcellate(self, func_boot_img):
        """
        API for performing any of a variety of clustering routines available
        through Nilearn.
        """
        import time
        import os
        from nilearn.regions import Parcellations
        from pynets.fmri.estimation import fill_confound_nans

        start = time.time()

        if (self.clust_type == "ward") and (self.local_corr != "allcorr"):
            if self._local_conn_mat_path is not None:
                if not os.path.isfile(self._local_conn_mat_path):
                    raise FileNotFoundError(
                        "File containing sparse matrix of local connectivity"
                        " structure not found.")
            else:
                raise FileNotFoundError(
                    "File containing sparse matrix of local connectivity"
                    " structure not found.")

        if (self.clust_type == "complete" or self.clust_type == "average"
                or self.clust_type == "single" or self.clust_type == "ward"
                or (self.clust_type == "rena" and self.num_conn_comps == 1)
                or (self.clust_type == "kmeans" and self.num_conn_comps == 1)):
            _clust_est = Parcellations(
                method=self.clust_type,
                standardize=self._standardize,
                detrend=self._detrending,
                n_parcels=self.k,
                mask=self._clust_mask_corr_img,
                connectivity=self._local_conn,
                mask_strategy="background",
                memory_level=2,
                random_state=42,
            )

            if self.conf is not None:
                import pandas as pd

                confounds = pd.read_csv(self.conf, sep="\t")
                if confounds.isnull().values.any():
                    conf_corr = fill_confound_nans(confounds, self._dir_path)
                    _clust_est.fit(func_boot_img, confounds=conf_corr)
                else:
                    _clust_est.fit(func_boot_img, confounds=self.conf)
            else:
                _clust_est.fit(func_boot_img)

            _clust_est.labels_img_.set_data_dtype(np.uint16)
            print(f"{self.clust_type}{self.k}"
                  f"{(' clusters: %.2fs' % (time.time() - start))}")
            return _clust_est.labels_img_

        elif self.clust_type == "ncut":
            out_img = parcellate_ncut(self._local_conn, self.k,
                                      self._clust_mask_corr_img)
            out_img.set_data_dtype(np.uint16)
            print(f"{self.clust_type}{self.k}"
                  f"{(' clusters: %.2fs' % (time.time() - start))}")
            return out_img

        elif (self.clust_type == "rena"
              or self.clust_type == "kmeans" and self.num_conn_comps > 1):
            from pynets.core import nodemaker
            from nilearn.regions import connected_regions, Parcellations
            from nilearn.image import iter_img, new_img_like
            from pynets.core.utils import flatten, proportional

            mask_img_list = []
            mask_voxels_dict = dict()
            for i, mask_img in enumerate(list(iter_img(self._conn_comps))):
                mask_voxels_dict[i] = int(
                    np.sum(np.asarray(mask_img.dataobj)))
                mask_img_list.append(mask_img)

            # Allocate k across connected components using Hagenbach-Bischoff
            # Quota based on number of voxels
            k_list = proportional(self.k, list(mask_voxels_dict.values()))

            conn_comp_atlases = []
            print(f"Building {len(mask_img_list)} separate atlases with "
                  f"voxel-proportional k clusters for each "
                  f"connected component...")
            for i, mask_img in enumerate(mask_img_list):
                if k_list[i] == 0:
                    # print('0 voxels in component. Discarding...')
                    continue
                _clust_est = Parcellations(
                    method=self.clust_type,
                    standardize=self._standardize,
                    detrend=self._detrending,
                    n_parcels=k_list[i],
                    mask=mask_img,
                    mask_strategy="background",
                    memory_level=2,
                    random_state=42,
                )
                if self.conf is not None:
                    import pandas as pd

                    confounds = pd.read_csv(self.conf, sep="\t")
                    if confounds.isnull().values.any():
                        conf_corr = fill_confound_nans(confounds,
                                                       self._dir_path)
                        _clust_est.fit(func_boot_img, confounds=conf_corr)
                    else:
                        _clust_est.fit(func_boot_img, confounds=self.conf)
                else:
                    _clust_est.fit(func_boot_img)
                conn_comp_atlases.append(_clust_est.labels_img_)

            # Then combine the multiple atlases, corresponding to each
            # connected component, into a single atlas
            atlas_of_atlases = []
            for atlas in conn_comp_atlases:
                bna_data = np.around(np.asarray(
                    atlas.dataobj)).astype("uint16")

                # Get an array of unique parcels
                bna_data_for_coords_uniq = np.unique(bna_data)

                # Number of parcels:
                par_max = len(bna_data_for_coords_uniq) - 1
                img_stack = []
                for idx in range(1, par_max + 1):
                    roi_img = bna_data == bna_data_for_coords_uniq[idx].astype(
                        "uint16")
                    img_stack.append(roi_img.astype("uint16"))
                img_stack = np.array(img_stack)

                img_list = []
                for idy in range(par_max):
                    img_list.append(new_img_like(atlas, img_stack[idy]))
                atlas_of_atlases.append(img_list)
                del img_list, img_stack, bna_data

            atlas_of_atlases = list(flatten(atlas_of_atlases))

            [super_atlas_ward,
             _] = nodemaker.create_parcel_atlas(atlas_of_atlases)
            super_atlas_ward.set_data_dtype(np.uint16)
            del atlas_of_atlases, conn_comp_atlases, mask_img_list, \
                mask_voxels_dict

            print(f"{self.clust_type}{self.k}"
                  f"{(' clusters: %.2fs' % (time.time() - start))}")
            return super_atlas_ward
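
For intuition, the quota-based split of k across connected components (proportional to their voxel counts) can be sketched with a largest-remainder allocation. This is an illustration only and may differ in detail from pynets' proportional helper:

def allocate_k(k, voxel_counts):
    """Allocate k clusters across components in proportion to voxel counts."""
    total = sum(voxel_counts)
    quotas = [k * v / total for v in voxel_counts]
    alloc = [int(q) for q in quotas]
    # Hand the leftover clusters to the largest fractional remainders
    order = sorted(range(len(quotas)),
                   key=lambda i: quotas[i] - alloc[i], reverse=True)
    for i in order[:k - sum(alloc)]:
        alloc[i] += 1
    return alloc

print(allocate_k(100, [5000, 1500, 300]))  # -> [74, 22, 4]
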
Exemplo n.º 19
0
def collect_pandas_df(network, ID, net_mets_csv_list, plot_switch, multi_nets, multimodal):
    """
    API for summarizing independent lists of pickled pandas DataFrames of graph metrics for each modality, RSN, and ROI.

    Parameters
    ----------
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default') used to filter nodes in the
        study of brain subgraphs.
    ID : str
        A subject id or other unique identifier.
    net_mets_csv_list : list
        List of file paths to pickled pandas DataFrames of network metrics.
    plot_switch : bool
        Activate summary plotting (histograms, ROC curves, etc.)
    multi_nets : list
        List of Yeo RSN's specified in workflow(s).
    multimodal : bool
        Indicates whether multiple modalities of input data have been specified.

    Returns
    -------
    combination_complete : bool
        If True, then collect_pandas_df completed successfully
    """
    from pathlib import Path
    import yaml
    from pynets.core.utils import flatten
    from pynets.stats.netstats import collect_pandas_df_make

    # Available functional and structural connectivity models
    with open("%s%s" % (str(Path(__file__).parent.parent), '/runconfig.yaml'), 'r') as stream:
        hardcoded_params = yaml.load(stream, Loader=yaml.FullLoader)
        try:
            func_models = hardcoded_params['available_models']['func_models']
        except KeyError:
            print('ERROR: available functional models not successfully extracted from runconfig.yaml')
        try:
            struct_models = hardcoded_params['available_models']['struct_models']
        except KeyError:
            print('ERROR: available structural models not successfully extracted from runconfig.yaml')

    net_mets_csv_list = list(flatten(net_mets_csv_list))

    if multi_nets is not None:
        net_mets_csv_list_nets = net_mets_csv_list
        for network in multi_nets:
            net_mets_csv_list = list(set([i for i in net_mets_csv_list_nets if network in i]))
            if multimodal is True:
                net_mets_csv_list_dwi = list(set([i for i in net_mets_csv_list if i.split('mets_')[1].split('_')[0]
                                                   in struct_models]))
                combination_complete_dwi = collect_pandas_df_make(net_mets_csv_list_dwi, ID, network, plot_switch)
                net_mets_csv_list_func = list(set([i for i in net_mets_csv_list if
                                                    i.split('mets_')[1].split('_')[0] in func_models]))
                combination_complete_func = collect_pandas_df_make(net_mets_csv_list_func, ID, network, plot_switch)

                if combination_complete_dwi is True and combination_complete_func is True:
                    combination_complete = True
                else:
                    combination_complete = False
            else:
                combination_complete = collect_pandas_df_make(net_mets_csv_list, ID, network, plot_switch)
    else:
        if multimodal is True:
            net_mets_csv_list_dwi = list(set([i for i in net_mets_csv_list if i.split('mets_')[1].split('_')[0] in
                                               struct_models]))
            combination_complete_dwi = collect_pandas_df_make(net_mets_csv_list_dwi, ID, network, plot_switch)
            net_mets_csv_list_func = list(set([i for i in net_mets_csv_list if i.split('mets_')[1].split('_')[0]
                                                in func_models]))
            combination_complete_func = collect_pandas_df_make(net_mets_csv_list_func, ID, network, plot_switch)

            if combination_complete_dwi is True and combination_complete_func is True:
                combination_complete = True
            else:
                combination_complete = False
        else:
            combination_complete = collect_pandas_df_make(net_mets_csv_list, ID, network, plot_switch)

    return combination_complete
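
A hypothetical single-modality invocation of collect_pandas_df; the CSV paths are illustrative, and the nested list is flattened internally before being passed to collect_pandas_df_make:

net_mets_csv_list = [
    ["/output/sub-0025427/func/netmetrics/0025427_net_mets_partcorr_0.20.csv"],
    ["/output/sub-0025427/func/netmetrics/0025427_net_mets_partcorr_0.30.csv"],
]
combination_complete = collect_pandas_df(network=None, ID="0025427",
                                         net_mets_csv_list=net_mets_csv_list,
                                         plot_switch=False, multi_nets=None,
                                         multimodal=False)
print(combination_complete)
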
Exemplo n.º 20
0
def streams2graph(atlas_mni,
                  streams,
                  overlap_thr,
                  dir_path,
                  track_type,
                  target_samples,
                  conn_model,
                  network,
                  node_size,
                  dens_thresh,
                  ID,
                  roi,
                  min_span_tree,
                  disp_filt,
                  parc,
                  prune,
                  atlas,
                  uatlas,
                  labels,
                  coords,
                  norm,
                  binary,
                  directget,
                  warped_fa,
                  error_margin,
                  max_length,
                  fa_wei=True):
    '''
    Use tracked streamlines as a basis for estimating a structural connectome.

    Parameters
    ----------
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    streams : str
        File path to streamline array sequence in .trk format.
    overlap_thr : int
        Number of voxels for which a given streamline must intersect with an ROI
        for an edge to be counted.
    dir_path : str
        Path to directory containing subject derivative data for a given pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based centroids
        are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image.
    error_margin : int
        Euclidean margin of error for classifying a streamline as a connection to an ROI. Default is 2 voxels.
    max_length : int
        Maximum fiber length threshold in mm to restrict tracking.
    fa_wei : bool
        Scale streamline count edges by fractional anisotropy (FA). Default is True.

    Returns
    -------
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    streams : str
        File path to streamline array sequence in .trk format.
    conn_matrix : array
        Adjacency matrix stored as an m x n array of nodes and edges.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    dir_path : str
        Path to directory containing subject derivative data for given run.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based centroids
        are used as ROI's for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    max_length : int
        Maximum fiber length threshold in mm to restrict tracking.
    '''
    from dipy.tracking.streamline import Streamlines, values_from_volume
    from dipy.tracking._utils import (_mapping_to_voxel, _to_voxel_coordinates)
    import networkx as nx
    from itertools import combinations
    from collections import defaultdict
    from pynets.core import utils, nodemaker
    from dipy.io.streamline import load_tractogram
    from dipy.io.stateful_tractogram import Space, Origin
    import time

    # Load parcellation
    roi_img = nib.load(atlas_mni)
    atlas_data = np.around(roi_img.get_fdata())
    roi_zooms = roi_img.header.get_zooms()
    roi_shape = roi_img.shape

    # Read Streamlines
    streamlines = Streamlines(
        load_tractogram(streams,
                        roi_img,
                        to_space=Space.RASMM,
                        to_origin=Origin.TRACKVIS,
                        bbox_valid_check=False).streamlines)
    roi_img.uncache()

    fa_weights = values_from_volume(
        nib.load(warped_fa).get_fdata(), streamlines, np.eye(4))
    global_fa_weights = list(utils.flatten(fa_weights))
    min_global_fa_wei = min(global_fa_weights)
    max_global_fa_wei = max(global_fa_weights)
    fa_weights_norm = []
    for val_list in fa_weights:
        fa_weights_norm.append((val_list - min_global_fa_wei) /
                               (max_global_fa_wei - min_global_fa_wei))

    # Instantiate empty networkX graph object & dictionary and create voxel-affine mapping
    lin_T, offset = _mapping_to_voxel(np.eye(4))
    mx = len(np.unique(atlas_data.astype('uint16'))) - 1
    g = nx.Graph(ecount=0, vcount=mx)
    edge_dict = defaultdict(int)
    node_dict = dict(
        zip(np.unique(atlas_data.astype('uint16')) + 1,
            np.arange(mx) + 1))

    # Add empty vertices
    for node in range(1, mx + 1):
        g.add_node(node)

    # Build graph
    start_time = time.time()

    ix = 0
    for s in streamlines:
        # Map the streamlines coordinates to voxel coordinates and get labels for label_volume
        i, j, k = np.vstack(
            np.array([
                nodemaker.get_sphere(coord, error_margin, roi_zooms, roi_shape)
                for coord in _to_voxel_coordinates(s, lin_T, offset)
            ])).T

        # get labels for label_volume
        lab_arr = atlas_data[i, j, k]
        endlabels = []
        for lab in np.unique(lab_arr).astype('uint32'):
            if (lab > 0) and (np.sum(lab_arr == lab) >= overlap_thr):
                try:
                    endlabels.append(node_dict[lab])
                except KeyError:
                    print("%s%s%s" % (
                        'Label ', lab,
                        ' missing from parcellation. Check registration and ensure valid '
                        'input parcellation file.'))

        edges = combinations(endlabels, 2)
        for edge in edges:
            lst = tuple([int(node) for node in edge])
            edge_dict[tuple(sorted(lst))] += 1

        edge_list = [(k[0], k[1], v) for k, v in edge_dict.items()]

        if fa_wei is True:
            # Add edgelist to g, weighted by average fa of the streamline
            g.add_weighted_edges_from(edge_list,
                                      weight=np.nanmean(fa_weights_norm[ix]))
        else:
            g.add_weighted_edges_from(edge_list)
        ix = ix + 1

    print("%s%s%s" % ('Graph construction runtime: ',
                      np.round(time.time() - start_time, 1), 's'))
    del streamlines

    if fa_wei is True:
        # Add average fa weights to streamline counts
        for u, v in list(g.edges):
            h = g.get_edge_data(u, v)
            edge_att_dict = {}
            for e, w in h.items():
                if w not in edge_att_dict.keys():
                    edge_att_dict[w] = []
                else:
                    edge_att_dict[w].append(e)
            for key in edge_att_dict.keys():
                edge_att_dict[key] = np.nanmean(edge_att_dict[key])
            vals = []
            for e2, w2 in edge_att_dict.items():
                vals.append(float(e2) * float(w2))
            g.edges[u, v].update({'weight': np.nanmean(vals)})

    # Convert to numpy matrix
    conn_matrix_raw = nx.to_numpy_matrix(g)

    # Enforce symmetry
    conn_matrix = np.maximum(conn_matrix_raw, conn_matrix_raw.T)

    return atlas_mni, streams, conn_matrix, track_type, target_samples, dir_path, conn_model, network, node_size, dens_thresh, ID, roi, min_span_tree, disp_filt, parc, prune, atlas, uatlas, labels, coords, norm, binary, directget, max_length
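
The core of the graph-construction loop above amounts to counting, for every streamline, each pair of parcels it intersects. A standalone sketch with hypothetical per-streamline label lists:

from collections import defaultdict
from itertools import combinations

# Hypothetical per-streamline parcel labels (each inner list holds the labels
# one streamline intersects above overlap_thr)
streamline_labels = [[3, 17], [3, 17, 42], [8, 42]]

edge_dict = defaultdict(int)
for endlabels in streamline_labels:
    for u, v in combinations(endlabels, 2):
        edge_dict[tuple(sorted((u, v)))] += 1

print(dict(edge_dict))
# {(3, 17): 2, (3, 42): 1, (17, 42): 1, (8, 42): 1}
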
Exemplo n.º 21
0
def pass_meta_outs(conn_model_iterlist, est_path_iterlist, network_iterlist, thr_iterlist,
                   prune_iterlist, ID_iterlist, roi_iterlist, norm_iterlist, binary_iterlist, embed,
                   multimodal, multiplex):
    """
    Passes lists of iterable parameters as metadata.

    Parameters
    ----------
    conn_model_iterlist : list
       List of connectivity estimation model parameters (e.g. corr for correlation, cov for covariance,
       sps for precision covariance, partcorr for partial correlation). sps type is used by default.
    est_path_iterlist : list
        List of file paths to .npy file containing graph with thresholding applied.
    network_iterlist : list
        List of resting-state networks based on Yeo-7 and Yeo-17 naming (e.g. 'Default') used to filter nodes in the
        study of brain subgraphs.
    thr_iterlist : list
        List of values, between 0 and 1, to threshold the graph using any variety of methods
        triggered through other options.
    prune_iterlist : list
        List of booleans indicating whether final graphs were pruned of disconnected nodes/isolates.
    ID_iterlist : list
        List of repeated subject id strings.
    roi_iterlist : list
        List of file paths to binarized/boolean region-of-interest Nifti1Image files.
    norm_iterlist : list
        Indicates method of normalizing resulting graph.
    binary_iterlist : list
        List of booleans indicating whether resulting graph edges to form an unweighted graph were binarized.
    embed : str
        Embed the ensemble(s) produced into feature vector(s). Options include: omni or mase.
    multimodal : bool
        Boolean indicating whether multiple modalities of input data have been specified.
    multiplex : int
        Switch indicating approach to multiplex graph analysis if multimodal is also True.

    Returns
    -------
    conn_model_iterlist : list
       List of connectivity estimation model parameters (e.g. corr for correlation, cov for covariance,
       sps for precision covariance, partcorr for partial correlation). sps type is used by default.
    est_path_iterlist : list
        List of file paths to .npy file containing graph with thresholding applied.
    network_iterlist : list
        List of resting-state networks based on Yeo-7 and Yeo-17 naming (e.g. 'Default') used to filter nodes in the
        study of brain subgraphs.
    thr_iterlist : list
        List of values, between 0 and 1, to threshold the graph using any variety of methods
        triggered through other options.
    prune_iterlist : list
        List of booleans indicating whether final graphs were pruned of disconnected nodes/isolates.
    ID_iterlist : list
        List of repeated subject id strings.
    roi_iterlist : list
        List of file paths to binarized/boolean region-of-interest Nifti1Image files.
    norm_iterlist : list
        Indicates method of normalizing resulting graph.
    binary_iterlist : list
        List of booleans indicating whether resulting graph edges to form an unweighted graph were binarized.
    """
    from pynets.core.utils import flatten
    from pynets.stats import netmotifs, embeddings

    if embed is not None:
        embeddings.build_embedded_connectome(list(flatten(est_path_iterlist)), list(flatten(ID_iterlist))[0],
                                             multimodal, embed)

    if (multiplex > 0) and (multimodal is True):
        multigraph_list_all = netmotifs.build_multigraphs(est_path_iterlist, list(flatten(ID_iterlist))[0])

    return (conn_model_iterlist, est_path_iterlist, network_iterlist, thr_iterlist, prune_iterlist, ID_iterlist,
            roi_iterlist, norm_iterlist, binary_iterlist)
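
A hypothetical call to pass_meta_outs for a single functional graph with omnibus embedding requested; the path and all list entries are illustrative and assume the graph file exists on disk:

est_path = ("/output/sub-0025427/func/coords_dosenbach_2010/graphs/"
            "graph_sub-0025427_model-partcorr_thr-0.30.npy")
meta = pass_meta_outs(conn_model_iterlist=["partcorr"],
                      est_path_iterlist=[est_path],
                      network_iterlist=[None], thr_iterlist=[0.30],
                      prune_iterlist=[True], ID_iterlist=["0025427"],
                      roi_iterlist=[None], norm_iterlist=[6],
                      binary_iterlist=[False], embed="omni",
                      multimodal=False, multiplex=0)
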
Exemplo n.º 22
0
def streams2graph(atlas_mni,
                  streams,
                  overlap_thr,
                  dir_path,
                  track_type,
                  target_samples,
                  conn_model,
                  network,
                  node_size,
                  dens_thresh,
                  ID,
                  roi,
                  min_span_tree,
                  disp_filt,
                  parc,
                  prune,
                  atlas,
                  uatlas,
                  labels,
                  coords,
                  norm,
                  binary,
                  directget,
                  warped_fa,
                  error_margin,
                  min_length,
                  fa_wei=True):
    '''
    Use tracked streamlines as a basis for estimating a structural connectome.

    Parameters
    ----------
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    streams : str
        File path to streamline array sequence in .trk format.
    overlap_thr : int
        Number of voxels for which a given streamline must intersect with an ROI
        for an edge to be counted.
    dir_path : str
        Path to directory containing subject derivative data for a given pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based centroids
        are used as ROIs for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image.
    error_margin : int
        Euclidean margin of error for classifying a streamline as a connection to an ROI. Default is 2 voxels.
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    fa_wei : bool
        Scale streamline count edges by fractional anisotropy (FA). Default is True.

    Returns
    -------
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    streams : str
        File path to streamline array sequence in .trk format.
    conn_matrix : array
        Adjacency matrix of the graph, stored as an N x N array where N is the number of nodes.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    dir_path : str
        Path to directory containing subject derivative data for given run.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based centroids
        are used as ROIs for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.

    References
    ----------
    .. [1] Sporns, O., Tononi, G., & Kötter, R. (2005). The human connectome:
      A structural description of the human brain. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.0010042
    .. [2] Sotiropoulos, S. N., & Zalesky, A. (2019). Building connectomes
      using diffusion MRI: why, how and but. NMR in Biomedicine.
      https://doi.org/10.1002/nbm.3752
    .. [3] Chung, M. K., Hanson, J. L., Adluru, N., Alexander, A. L., Davidson,
      R. J., & Pollak, S. D. (2017). Integrative Structural Brain Network
      Analysis in Diffusion Tensor Imaging. Brain Connectivity.
      https://doi.org/10.1089/brain.2016.0481

    '''
    import gc
    import time
    from dipy.tracking.streamline import Streamlines, values_from_volume
    from dipy.tracking._utils import (_mapping_to_voxel, _to_voxel_coordinates)
    import networkx as nx
    from itertools import combinations
    from collections import defaultdict
    from pynets.core import utils, nodemaker
    from pynets.dmri.dmri_utils import generate_sl
    from dipy.io.streamline import load_tractogram
    from dipy.io.stateful_tractogram import Space, Origin

    start = time.time()

    # Load parcellation
    roi_img = nib.load(atlas_mni)
    atlas_data = np.around(np.asarray(roi_img.dataobj))
    roi_zooms = roi_img.header.get_zooms()
    roi_shape = roi_img.shape

    # Read Streamlines
    streamlines = [
        i.astype(np.float32) for i in Streamlines(
            load_tractogram(streams,
                            roi_img,
                            to_space=Space.RASMM,
                            to_origin=Origin.TRACKVIS,
                            bbox_valid_check=False).streamlines)
    ]
    roi_img.uncache()

    if fa_wei is True:
        fa_weights = values_from_volume(
            np.asarray(nib.load(warped_fa).dataobj), streamlines, np.eye(4))
        global_fa_weights = list(utils.flatten(fa_weights))
        min_global_fa_wei = min(i for i in global_fa_weights if i > 0)
        max_global_fa_wei = max(global_fa_weights)
        fa_weights_norm = []
        # Here we normalize by global FA
        for val_list in fa_weights:
            fa_weights_norm.append(
                np.nanmean((val_list - min_global_fa_wei) /
                           (max_global_fa_wei - min_global_fa_wei)))

    # Make streamlines into generators to keep memory at a minimum
    sl = [generate_sl(i) for i in streamlines]
    del streamlines

    # Instantiate empty networkX graph object & dictionary and create voxel-affine mapping
    lin_T, offset = _mapping_to_voxel(np.eye(4))
    mx = len(np.unique(atlas_data.astype('uint16'))) - 1
    g = nx.Graph(ecount=0, vcount=mx)
    edge_dict = defaultdict(int)
    node_dict = dict(
        zip(np.unique(atlas_data.astype('uint16'))[1:],
            np.arange(mx) + 1))

    # Add empty vertices
    for node in range(1, mx + 1):
        g.add_node(node)

    # Build graph
    ix = 0
    bad_idxs = []
    for s in sl:
        # Map the streamlines coordinates to voxel coordinates and get labels for label_volume
        vox_coords = _to_voxel_coordinates(Streamlines(s), lin_T, offset)
        lab_coords = [
            nodemaker.get_sphere(coord, error_margin, roi_zooms, roi_shape)
            for coord in vox_coords
        ]
        [i, j, k] = np.vstack(np.array(lab_coords)).T

        # get labels for label_volume
        lab_arr = atlas_data[i, j, k]
        endlabels = []
        # Use jx for the label index so the streamline counter ix (used to
        # index fa_weights_norm) is not clobbered
        for jx, lab in enumerate(np.unique(lab_arr).astype('uint32')):
            if (lab > 0) and (np.sum(lab_arr == lab) >= overlap_thr):
                try:
                    endlabels.append(node_dict[lab])
                except KeyError:
                    bad_idxs.append(jx)
                    print(
                        f"Label {lab} missing from parcellation. Check registration and ensure valid input "
                        f"parcellation file.")

        edges = combinations(endlabels, 2)
        for edge in edges:
            lst = tuple([int(node) for node in edge])
            edge_dict[tuple(sorted(lst))] += 1

        edge_list = [(k[0], k[1], v) for k, v in edge_dict.items()]

        if fa_wei is True:
            # Add edgelist to g, weighted by average fa of the streamline
            g.add_weighted_edges_from(edge_list, weight=fa_weights_norm[ix])
        else:
            g.add_weighted_edges_from(edge_list)
        ix = ix + 1

        del lab_coords, lab_arr, endlabels, edges, edge_list

    gc.collect()

    if fa_wei is True:
        # Add average fa weights to streamline counts
        for u, v in list(g.edges):
            h = g.get_edge_data(u, v)
            edge_att_dict = {}
            for e, w in h.items():
                if w not in edge_att_dict.keys():
                    edge_att_dict[w] = [e]
                else:
                    edge_att_dict[w].append(e)
            for key in edge_att_dict.keys():
                edge_att_dict[key] = np.nanmean(edge_att_dict[key])
            vals = []
            for e2, w2 in edge_att_dict.items():
                vals.append(float(e2) * float(w2))
            g.edges[u, v].update({'weight': np.nanmean(vals)})

    # Convert to numpy matrix
    conn_matrix_raw = nx.to_numpy_array(g)

    # Impose symmetry
    conn_matrix = np.maximum(conn_matrix_raw, conn_matrix_raw.T)

    print('Graph Building Complete:\n', str(time.time() - start))

    if len(bad_idxs) > 0:
        bad_idxs = sorted(list(set(bad_idxs)), reverse=True)
        for j in bad_idxs:
            del labels[j], coords[j]

    coords = np.array(coords)
    labels = np.array(labels)

    return (atlas_mni, streams, conn_matrix, track_type, target_samples,
            dir_path, conn_model, network, node_size, dens_thresh, ID, roi,
            min_span_tree, disp_filt, parc, prune, atlas, uatlas, labels,
            coords, norm, binary, directget, min_length)
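
As a rough, runnable illustration of the counting step above (mapping streamline coordinates to parcel labels and incrementing edge counts), here is a toy sketch that uses a synthetic label volume and streamlines already expressed in voxel indices, so no dipy or nibabel I/O is needed:

import numpy as np
import networkx as nx
from itertools import combinations
from collections import defaultdict

# Synthetic parcellation with two labeled parcels
atlas_data = np.zeros((10, 10, 10), dtype='uint16')
atlas_data[1:4, 1:4, 1:4] = 1
atlas_data[6:9, 6:9, 6:9] = 2

# Two synthetic streamlines, already expressed as voxel indices
streamlines_vox = [np.array([[2, 2, 2], [4, 4, 4], [7, 7, 7]]),
                   np.array([[2, 3, 2], [5, 5, 5], [8, 8, 8]])]

overlap_thr = 1
g = nx.Graph()
g.add_nodes_from(int(lab) for lab in np.unique(atlas_data) if lab > 0)
edge_dict = defaultdict(int)

for s in streamlines_vox:
    i, j, k = s.T
    lab_arr = atlas_data[i, j, k]
    # Keep only labels the streamline intersects at least overlap_thr times
    endlabels = [int(lab) for lab in np.unique(lab_arr)
                 if lab > 0 and np.sum(lab_arr == lab) >= overlap_thr]
    for edge in combinations(endlabels, 2):
        edge_dict[tuple(sorted(edge))] += 1

g.add_weighted_edges_from([(u, v, w) for (u, v), w in edge_dict.items()])
conn_matrix_raw = nx.to_numpy_array(g)
print(np.maximum(conn_matrix_raw, conn_matrix_raw.T))  # symmetric counts
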
Exemplo n.º 23
0
def streams2graph(atlas_mni, streams, dir_path, track_type, target_samples,
                  conn_model, network, node_size, dens_thresh, ID, roi,
                  min_span_tree, disp_filt, parc, prune, atlas, uatlas, labels,
                  coords, norm, binary, directget, warped_fa, min_length,
                  error_margin):
    """
    Use tracked streamlines as a basis for estimating a structural connectome.

    Parameters
    ----------
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    streams : str
        File path to streamline array sequence in .trk format.
    dir_path : str
        Path to directory containing subject derivative data for a given
        pynets run.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROIs for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    directget : str
        The statistical approach to tracking. Options are:
        det (deterministic), closest (clos), boot (bootstrapped),
        and prob (probabilistic).
    warped_fa : str
        File path to MNI-space warped FA Nifti1Image.
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        Euclidean margin of error for classifying a streamline as a connection
        to an ROI. Default is 2 voxels.

    Returns
    -------
    atlas_mni : str
        File path to atlas parcellation Nifti1Image in T1w-warped MNI space.
    streams : str
        File path to streamline array sequence in .trk format.
    conn_matrix : array
        Adjacency matrix of the graph, stored as an N x N array where N is
        the number of nodes.
    track_type : str
        Tracking algorithm used (e.g. 'local' or 'particle').
    target_samples : int
        Total number of streamline samples specified to generate streams.
    dir_path : str
        Path to directory containing subject derivative data for given run.
    conn_model : str
        Connectivity reconstruction method (e.g. 'csa', 'tensor', 'csd').
    network : str
        Resting-state network based on Yeo-7 and Yeo-17 naming (e.g. 'Default')
        used to filter nodes in the study of brain subgraphs.
    node_size : int
        Spherical centroid node size in the case that coordinate-based
        centroids are used as ROIs for tracking.
    dens_thresh : bool
        Indicates whether a target graph density is to be used as the basis for
        thresholding.
    ID : str
        A subject id or other unique identifier.
    roi : str
        File path to binarized/boolean region-of-interest Nifti1Image file.
    min_span_tree : bool
        Indicates whether local thresholding from the Minimum Spanning Tree
        should be used.
    disp_filt : bool
        Indicates whether local thresholding using a disparity filter and
        'backbone network' should be used.
    parc : bool
        Indicates whether to use parcels instead of coordinates as ROI nodes.
    prune : bool
        Indicates whether to prune final graph of disconnected nodes/isolates.
    atlas : str
        Name of atlas parcellation used.
    uatlas : str
        File path to atlas parcellation Nifti1Image in MNI template space.
    labels : list
        List of string labels corresponding to graph nodes.
    coords : list
        List of (x, y, z) tuples corresponding to a coordinate atlas used or
        which represent the center-of-mass of each parcellation node.
    norm : int
        Indicates method of normalizing resulting graph.
    binary : bool
        Indicates whether to binarize resulting graph edges to form an
        unweighted graph.
    directget : str
        The statistical approach to tracking. Options are: det (deterministic),
        closest (clos), boot (bootstrapped), and prob (probabilistic).
    min_length : int
        Minimum fiber length threshold in mm to restrict tracking.
    error_margin : int
        Euclidean margin of error for classifying a streamline as a connection
        to an ROI. Default is 2 voxels.

    References
    ----------
    .. [1] Sporns, O., Tononi, G., & Kötter, R. (2005). The human connectome:
      A structural description of the human brain. PLoS Computational Biology.
      https://doi.org/10.1371/journal.pcbi.0010042
    .. [2] Sotiropoulos, S. N., & Zalesky, A. (2019). Building connectomes
      using diffusion MRI: why, how and but. NMR in Biomedicine.
      https://doi.org/10.1002/nbm.3752
    .. [3] Chung, M. K., Hanson, J. L., Adluru, N., Alexander, A. L., Davidson,
      R. J., & Pollak, S. D. (2017). Integrative Structural Brain Network
      Analysis in Diffusion Tensor Imaging. Brain Connectivity.
      https://doi.org/10.1089/brain.2016.0481
    """
    import gc
    import time
    import pkg_resources
    import sys
    import yaml
    from dipy.tracking.streamline import Streamlines, values_from_volume
    from dipy.tracking._utils import _mapping_to_voxel, _to_voxel_coordinates
    import networkx as nx
    from itertools import combinations
    from collections import defaultdict
    from pynets.core import utils, nodemaker
    from pynets.dmri.dmri_utils import generate_sl
    from dipy.io.streamline import load_tractogram
    from dipy.io.stateful_tractogram import Space, Origin

    with open(pkg_resources.resource_filename("pynets", "runconfig.yaml"),
              "r") as stream:
        hardcoded_params = yaml.load(stream, Loader=yaml.FullLoader)
        fa_wei = hardcoded_params["StructuralNetworkWeighting"][
            "fa_weighting"][0]
        fiber_density = hardcoded_params["StructuralNetworkWeighting"][
            "fiber_density"][0]
        overlap_thr = hardcoded_params["StructuralNetworkWeighting"][
            "overlap_thr"][0]
        roi_neighborhood_tol = \
            hardcoded_params['tracking']["roi_neighborhood_tol"][0]

    start = time.time()

    if float(roi_neighborhood_tol) <= float(error_margin):
        sys.exit('roi_neighborhood_tol preset must be greater than the value '
                 'of the structural connectome error_margin parameter.')
    else:
        print(f"Using fiber-roi intersection tolerance: {error_margin}...")

    # Load FA
    fa_img = nib.load(warped_fa)

    # Load parcellation
    roi_img = nib.load(atlas_mni)
    atlas_data = np.around(np.asarray(roi_img.dataobj))
    roi_zooms = roi_img.header.get_zooms()
    roi_shape = roi_img.shape

    # Read Streamlines
    streamlines = [
        i.astype(np.float32) for i in Streamlines(
            load_tractogram(
                streams, fa_img, to_origin=Origin.NIFTI,
                to_space=Space.VOXMM).streamlines)
    ]

    # from fury import actor, window
    # renderer = window.Renderer()
    # template_actor = actor.contour_from_roi(roi_img.get_fdata(),
    #                                         color=(50, 50, 50), opacity=0.05)
    # renderer.add(template_actor)
    # lines_actor = actor.streamtube(streamlines, window.colors.orange,
    #                                linewidth=0.3)
    # renderer.add(lines_actor)
    # window.show(renderer)

    roi_img.uncache()

    if fa_wei is True:
        fa_weights = values_from_volume(
            np.asarray(fa_img.dataobj, dtype=np.float32), streamlines,
            np.eye(4))
        global_fa_weights = list(utils.flatten(fa_weights))
        min_global_fa_wei = min([i for i in global_fa_weights if i > 0])
        max_global_fa_wei = max(global_fa_weights)
        fa_weights_norm = []
        # Here we normalize by global FA
        for val_list in fa_weights:
            fa_weights_norm.append(
                np.nanmean((val_list - min_global_fa_wei) /
                           (max_global_fa_wei - min_global_fa_wei)))

    # Make streamlines into generators to keep memory at a minimum
    total_streamlines = len(streamlines)
    sl = [generate_sl(i) for i in streamlines]
    del streamlines

    # Instantiate empty networkX graph object & dictionary and create
    # voxel-affine mapping
    lin_T, offset = _mapping_to_voxel(np.eye(4))
    mx = len(np.unique(atlas_data.astype("uint16"))) - 1
    g = nx.Graph(ecount=0, vcount=mx)
    edge_dict = defaultdict(int)
    node_dict = dict(
        zip(np.unique(atlas_data.astype("uint16"))[1:],
            np.arange(mx) + 1))

    # Add empty vertices with label volume attributes
    for node in range(1, mx + 1):
        g.add_node(node,
                   roi_volume=np.sum(atlas_data.astype("uint16") == node))

    # Build graph
    pc = 0
    bad_idxs = []
    fiberlengths = {}
    fa_weights_dict = {}
    print(f"Quantifying fiber-ROI intersection for {atlas}:")
    for ix, s in enumerate(sl):
        # Percent counter
        pcN = int(round(100 * float(ix / total_streamlines)))
        if pcN % 10 == 0 and ix > 0 and pcN > pc:
            pc = pcN
            print(f"{pcN}%")

        # Map the streamlines coordinates to voxel coordinates and get labels
        # for label_volume
        vox_coords = _to_voxel_coordinates(Streamlines(s), lin_T, offset)

        lab_coords = [
            nodemaker.get_sphere(coord, error_margin, roi_zooms, roi_shape)
            for coord in vox_coords
        ]
        [i, j, k] = np.vstack(np.array(lab_coords)).T

        # get labels for label_volume
        lab_arr = atlas_data[i, j, k]
        # print(lab_arr)
        endlabels = []
        for jx, lab in enumerate(np.unique(lab_arr).astype("uint32")):
            if (lab > 0) and (np.sum(lab_arr == lab) >= overlap_thr):
                try:
                    endlabels.append(node_dict[lab])
                except BaseException:
                    bad_idxs.append(jx)
                    print(f"Label {lab} missing from parcellation. Check "
                          f"registration and ensure valid input parcellation "
                          f"file.")

        edges = combinations(endlabels, 2)
        for edge in edges:
            # Get fiber lengths along edge
            if fiber_density is True:
                if not (edge[0], edge[1]) in fiberlengths.keys():
                    fiberlengths[(edge[0], edge[1])] = [len(vox_coords)]
                else:
                    fiberlengths[(edge[0], edge[1])].append(len(vox_coords))

            # Get FA values along edge
            if fa_wei is True:
                if not (edge[0], edge[1]) in fa_weights_dict.keys():
                    fa_weights_dict[(edge[0], edge[1])] = [fa_weights_norm[ix]]
                else:
                    fa_weights_dict[(edge[0],
                                     edge[1])].append(fa_weights_norm[ix])

            lst = tuple([int(node) for node in edge])
            edge_dict[tuple(sorted(lst))] += 1

        edge_list = [(k[0], k[1], count) for k, count in edge_dict.items()]

        g.add_weighted_edges_from(edge_list)

        del lab_coords, lab_arr, endlabels, edges, edge_list

    gc.collect()

    # Add fiber density attributes for each edge
    # Adapted from the normalized fiber-density estimation routines of
    # Sebastian Tourbier.
    if fiber_density is True:
        print("Weighting edges by fiber density...")
        # Summarize total fibers and total label volumes
        total_fibers = 0
        total_volume = 0
        u_start = -1
        for u, v, d in g.edges(data=True):
            total_fibers += d['weight']
            if u != u_start:
                total_volume += g.nodes[int(u)]['roi_volume']
            u_start = u

        ix = 0
        for u, v, d in g.edges(data=True):
            if d['weight'] > 0:
                edge_fiberlength_mean = np.nanmean(fiberlengths[(u, v)])
                # Use a distinct name so the fiber_density flag loaded from
                # runconfig.yaml is not overwritten inside this loop
                edge_fiber_density = (float(
                    ((float(d['weight']) / float(total_fibers)) /
                     float(edge_fiberlength_mean)) *
                    ((2.0 * float(total_volume)) /
                     (g.nodes[int(u)]['roi_volume'] +
                      g.nodes[int(v)]['roi_volume'])))) * 1000
            else:
                edge_fiber_density = 0
            g.edges[u, v].update({"fiber_density": edge_fiber_density})
            ix += 1

    if fa_wei is True:
        print("Weighting edges by FA...")
        # Add FA attributes for each edge
        ix = 0
        for u, v, d in g.edges(data=True):
            if d['weight'] > 0:
                edge_average_fa = np.nanmean(fa_weights_dict[(u, v)])
            else:
                edge_average_fa = np.nan
            g.edges[u, v].update({"fa_weight": edge_average_fa})
            ix += 1

    # Summarize weights
    if fa_wei is True and fiber_density is True:
        for u, v, d in g.edges(data=True):
            g.edges[u, v].update(
                {"final_weight": (d['fa_weight']) * d['fiber_density']})
    elif fiber_density is True and fa_wei is False:
        for u, v, d in g.edges(data=True):
            g.edges[u, v].update({"final_weight": d['fiber_density']})
    elif fa_wei is True and fiber_density is False:
        for u, v, d in g.edges(data=True):
            g.edges[u,
                    v].update({"final_weight": d['fa_weight'] * d['weight']})
    else:
        for u, v, d in g.edges(data=True):
            g.edges[u, v].update({"final_weight": d['weight']})

    # Convert weighted graph to numpy matrix
    conn_matrix_raw = nx.to_numpy_array(g, weight='final_weight')

    # Enforce symmetry
    conn_matrix = np.maximum(conn_matrix_raw, conn_matrix_raw.T)

    print("Structural graph completed:\n", str(time.time() - start))

    if len(bad_idxs) > 0:
        bad_idxs = sorted(list(set(bad_idxs)), reverse=True)
        for j in bad_idxs:
            del labels[j], coords[j]

    coords = np.array(coords)
    labels = np.array(labels)

    assert len(coords) == len(labels) == conn_matrix.shape[0]

    return (atlas_mni, streams, conn_matrix, track_type, target_samples,
            dir_path, conn_model, network, node_size, dens_thresh, ID, roi,
            min_span_tree, disp_filt, parc, prune, atlas, uatlas, labels,
            coords, norm, binary, directget, min_length, error_margin)
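
The fiber-density weighting above can be hard to parse inside the loop, so here is a small numeric sketch of the same normalization (adapted, per the code comments, from Sebastian Tourbier's routines) with made-up counts and volumes:

def fiber_density_weight(n_fibers_uv, total_fibers, mean_fiber_len_uv,
                         vol_u, vol_v, total_volume):
    # Normalize the raw streamline count for edge (u, v) by the total fiber
    # count and the mean fiber length, then correct for the volumes of the
    # two parcels it connects (scaled by 1000, as in streams2graph).
    return (((n_fibers_uv / total_fibers) / mean_fiber_len_uv) *
            ((2.0 * total_volume) / (vol_u + vol_v))) * 1000


# Hypothetical values: 40 of 5000 streamlines connect parcels u and v
print(fiber_density_weight(n_fibers_uv=40, total_fibers=5000,
                           mean_fiber_len_uv=62.0, vol_u=850, vol_v=910,
                           total_volume=120000))

When both fa_wei and fiber_density are enabled, the final edge weight multiplies this density by the mean FA along the edge; otherwise it falls back to the density alone, the FA-scaled streamline count, or the raw count.
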
Exemplo n.º 24
0
def build_asetomes(est_path_iterlist, ID):
    """
    Embeds single graphs using the ASE algorithm.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy files, each containing a graph.
    ID : str
        A subject id or other unique identifier.

    """
    from pathlib import Path
    import os
    import numpy as np
    from pynets.core.utils import prune_suffices, flatten
    from pynets.stats.embeddings import _ase_embed
    from pynets.core.utils import load_runconfig

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        n_components = hardcoded_params["gradients"][
            "n_components"][0]
    except KeyError:
        import sys
        print(
            "ERROR: available gradient dimensionality presets not "
            "sucessfully extracted from runconfig.yaml"
        )
        sys.exit(1)

    if isinstance(est_path_iterlist, list):
        est_path_iterlist = list(flatten(est_path_iterlist))
    else:
        est_path_iterlist = [est_path_iterlist]

    out_paths = []
    for file_ in est_path_iterlist:
        mat = np.load(file_)
        if not np.isfinite(mat).all():
            continue

        atlas = prune_suffices(file_.split("/")[-3])
        res = prune_suffices("_".join(file_.split(
            "/")[-1].split("modality")[1].split("_")[1:]).split("_est")[0])
        if "rsn" in res:
            subgraph = res.split("rsn-")[1].split('_')[0]
        else:
            subgraph = "all_nodes"
        out_path = _ase_embed(mat, atlas, file_, ID, subgraph_name=subgraph,
                              n_components=n_components)
        if out_path is not None:
            out_paths.append(out_path)
        else:
            # Add a null tmp file to prevent pool from breaking
            dir_path = str(Path(os.path.dirname(file_)).parent)
            namer_dir = f"{dir_path}/embeddings"
            if os.path.isdir(namer_dir) is False:
                os.makedirs(namer_dir, exist_ok=True)
            out_path = f"{namer_dir}/gradient-ASE" \
                       f"_{atlas}_{subgraph}_{os.path.basename(file_)}_NULL"
            if not os.path.exists(out_path):
                os.mknod(out_path)
            out_paths.append(out_path)

    return out_paths
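
To make the filename parsing in build_asetomes concrete, here is a sketch with a hypothetical .npy path; prune_suffices is replaced by a no-op stand-in because its exact behavior lives in pynets.core.utils:

def prune_suffices(s):
    # Stand-in only; see pynets.core.utils.prune_suffices for the real helper
    return s


# Hypothetical embedding input path
file_ = ("/output/sub-01/dwi/DesikanKlein2012/graphs/"
         "graph_sub-01_modality-dwi_rsn-Default_model-csd_est.npy")

atlas = prune_suffices(file_.split("/")[-3])
res = prune_suffices("_".join(
    file_.split("/")[-1].split("modality")[1].split("_")[1:]).split("_est")[0])
subgraph = res.split("rsn-")[1].split('_')[0] if "rsn" in res else "all_nodes"
print(atlas, res, subgraph)  # DesikanKlein2012 rsn-Default_model-csd Default
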
Exemplo n.º 25
0
def benchmark_reproducibility(comb, modality, alg, sub_dict_clean, disc,
                              int_consist, final_missingness_summary):
    df_summary = pd.DataFrame(
        columns=['grid', 'modality', 'embedding', 'discriminability'])
    print(comb)
    df_summary.at[0, "modality"] = modality
    df_summary.at[0, "embedding"] = alg

    if modality == 'func':
        try:
            extract, hpass, model, res, atlas, smooth = comb
        except ValueError:
            print(f"Missing smoothing entry in {comb}...")
            extract, hpass, model, res, atlas = comb
            smooth = '0'
        comb_tuple = (atlas, extract, hpass, model, res, smooth)
    else:
        directget, minlength, model, res, atlas, tol = comb
        comb_tuple = (atlas, directget, minlength, model, res, tol)

    df_summary.at[0, "grid"] = comb_tuple

    missing_sub_seshes = \
        final_missingness_summary.loc[(final_missingness_summary['alg']==alg)
                                      & (final_missingness_summary[
                                             'modality']==modality) &
                                      (final_missingness_summary[
                                           'grid']==comb_tuple)
                                      ].drop_duplicates(subset='id')

    # int_consist
    if int_consist is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate test-retest int_consist. pingouin"
                  " must be installed!")
        for met in mets:
            id_dict = {}
            for ID in ids:
                id_dict[ID] = {}
                for ses in sub_dict_clean[ID].keys():
                    if comb_tuple in sub_dict_clean[ID][ses][modality][
                            alg].keys():
                        id_dict[ID][ses] = \
                        sub_dict_clean[ID][ses][modality][alg][comb_tuple][
                            mets.index(met)][0]
            df_wide = pd.DataFrame(id_dict).T
            if df_wide.empty:
                del df_wide
                return pd.Series()
            df_wide = df_wide.add_prefix(f"{met}_visit_")
            df_wide.replace(0, np.nan, inplace=True)
            try:
                c_alpha = pg.cronbach_alpha(data=df_wide)
            except:
                print('FAILED...')
                print(df_wide)
                del df_wide
                return pd.Series()
            df_summary.at[0, f"cronbach_alpha_{met}"] = c_alpha[0]
            del df_wide

    # icc
    if icc is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin" " must be installed!")
        for met in mets:
            id_dict = {}
            dfs = []
            for ses in [str(i) for i in range(1, 11)]:
                for ID in ids:
                    id_dict[ID] = {}
                    if comb_tuple in sub_dict_clean[ID][ses][modality][
                            alg].keys():
                        id_dict[ID][ses] = \
                        sub_dict_clean[ID][ses][modality][alg][comb_tuple][
                            mets.index(met)][0]
                    df = pd.DataFrame(id_dict).T
                    if df.empty:
                        del df
                        return pd.Series()
                    df.columns.values[0] = f"{met}"
                    df.replace(0, np.nan, inplace=True)
                    df['id'] = df.index
                    df['ses'] = ses
                    df.reset_index(drop=True, inplace=True)
                    dfs.append(df)
            df_long = pd.concat(dfs, names=[
                'id', 'ses', f"{met}"
            ]).drop(columns=[str(i) for i in range(1, 10)])
            try:
                c_icc = pg.intraclass_corr(data=df_long,
                                           targets='id',
                                           raters='ses',
                                           ratings=f"{met}",
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                df_summary.at[0, f"icc_{met}"] = pd.DataFrame(
                    c_icc.drop(
                        index=['ICC1', 'ICC2', 'ICC3'])['ICC']).mean()[0]
            except:
                print('FAILED...')
                print(df_long)
                del df_long
                return pd.Series()
            del df_long

    if disc is True:
        vect_all = []
        for ID in ids:
            try:
                out = gen_sub_vec(sub_dict_clean, ID, modality, alg,
                                  comb_tuple)
            except:
                print(f"{ID} {modality} {alg} {comb_tuple} failed...")
                continue
            # print(out)
            vect_all.append(out)
        vect_all = [
            i for i in vect_all if i is not None and not np.isnan(i).all()
        ]
        if len(vect_all) > 0:
            if alg == 'topology':
                X_top = np.swapaxes(np.hstack(vect_all), 0, 1)
                bad_ixs = [i[1] for i in np.argwhere(np.isnan(X_top))]
                for m in set(bad_ixs):
                    if (X_top.shape[0] - bad_ixs.count(m)) / \
                        X_top.shape[0] < 0.50:
                        X_top = np.delete(X_top, m, axis=1)
            else:
                if len(vect_all) > 0:
                    X_top = np.array(pd.concat(vect_all, axis=0))
                else:
                    return pd.Series()
            shapes = []
            for ix, i in enumerate(vect_all):
                shapes.append(i.shape[0] * [list(ids)[ix]])
            Y = np.array(list(flatten(shapes)))
            if alg == 'topology':
                imp = IterativeImputer(max_iter=50, random_state=42)
            else:
                imp = SimpleImputer()
            X_top = imp.fit_transform(X_top)
            scaler = StandardScaler()
            X_top = scaler.fit_transform(X_top)
            try:
                discr_stat_val, rdf = discr_stat(X_top, Y)
            except:
                return pd.Series()
            df_summary.at[0, "discriminability"] = discr_stat_val
            print(discr_stat_val)
            print("\n")
            # print(rdf)
            del discr_stat_val
        del vect_all
    return df_summary
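
For the discriminability branch above, the feature-matrix construction is the part most worth seeing in isolation. The sketch below builds X (stacked per-subject vectors) and Y (repeated subject labels) from synthetic data and applies the same imputation/scaling steps; discr_stat itself is a pynets routine and is not reproduced here:

import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
ids = ['sub-01', 'sub-02', 'sub-03']

# Two repeated measurements (e.g. sessions) of 5 features per subject
vect_all = [rng.normal(loc=i, size=(2, 5)) for i, _ in enumerate(ids)]

X_top = np.vstack(vect_all)
Y = np.array([sid for sid, v in zip(ids, vect_all) for _ in range(v.shape[0])])

X_top = SimpleImputer().fit_transform(X_top)
X_top = StandardScaler().fit_transform(X_top)
print(X_top.shape, Y)  # inputs expected by discr_stat(X_top, Y)
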
Exemplo n.º 26
0
def main():
    """Initializes main script from command-line call to generate
    single-subject or multi-subject workflow(s)"""
    import os
    import gc
    import sys
    import json
    from pynets.core.utils import build_args_from_config
    import itertools
    from types import SimpleNamespace
    import pkg_resources
    from pynets.core.utils import flatten
    from pynets.cli.pynets_run import build_workflow
    from multiprocessing import set_start_method, Process, Manager
    from colorama import Fore, Style

    try:
        import pynets
    except ImportError:
        print(
            "PyNets not installed! Ensure that you are referencing the correct"
            " site-packages and using Python3.6+"
        )

    if len(sys.argv) < 2:
        print("\nMissing command-line inputs! See help options with the -h"
              " flag.\n")
        sys.exit()

    print(f"{Fore.LIGHTBLUE_EX}\nBIDS API\n")

    print(Style.RESET_ALL)

    print(f"{Fore.LIGHTGREEN_EX}Obtaining Derivatives Layout...")

    print(Style.RESET_ALL)

    modalities = ["func", "dwi"]
    space = 'T1w'

    bids_args = get_bids_parser().parse_args()
    participant_label = bids_args.participant_label
    session_label = bids_args.session_label
    run = bids_args.run_label
    if isinstance(run, list):
        run = str(run[0]).zfill(2)
    modality = bids_args.modality
    bids_config = bids_args.config
    analysis_level = bids_args.analysis_level
    clean = bids_args.clean

    if analysis_level == "group" and participant_label is not None:
        raise ValueError(
            "Error: You have indicated a group analysis level run, but"
            " specified a participant label!"
        )

    if analysis_level == "participant" and participant_label is None:
        raise ValueError(
            "Error: You have indicated a participant analysis level run, but"
            " not specified a participant "
            "label!")

    if bids_config:
        with open(bids_config, "r") as stream:
            arg_dict = json.load(stream)
    else:
        with open(
            pkg_resources.resource_filename("pynets",
                                            "config/bids_config.json"),
            "r",
        ) as stream:
            arg_dict = json.load(stream)

    # S3
    # Primary inputs
    s3 = bids_args.bids_dir.startswith("s3://")

    if not s3:
        bids_dir = bids_args.bids_dir

    # secondary inputs
    sec_s3_objs = []
    if isinstance(bids_args.ua, list):
        for i in bids_args.ua:
            if i.startswith("s3://"):
                print("Downloading user atlas: ", i, " from S3...")
                sec_s3_objs.append(i)
    if isinstance(bids_args.cm, list):
        for i in bids_args.cm:
            if i.startswith("s3://"):
                print("Downloading clustering mask: ", i, " from S3...")
                sec_s3_objs.append(i)
    if isinstance(bids_args.roi, list):
        for i in bids_args.roi:
            if i.startswith("s3://"):
                print("Downloading ROI mask: ", i, " from S3...")
                sec_s3_objs.append(i)
    if isinstance(bids_args.way, list):
        for i in bids_args.way:
            if i.startswith("s3://"):
                print("Downloading tractography waymask: ", i, " from S3...")
                sec_s3_objs.append(i)

    if bids_args.ref:
        if bids_args.ref.startswith("s3://"):
            print(
                "Downloading atlas labeling reference file: ",
                bids_args.ref,
                " from S3...",
            )
            sec_s3_objs.append(bids_args.ref)

    if s3 or len(sec_s3_objs) > 0:
        from boto3.session import Session
        from pynets.core import cloud_utils
        from pynets.core.utils import as_directory

        home = os.path.expanduser("~")
        creds = bool(cloud_utils.get_credentials())

        if s3:
            buck, remo = cloud_utils.parse_path(bids_args.bids_dir)
            os.makedirs(f"{home}/.pynets", exist_ok=True)
            os.makedirs(f"{home}/.pynets/input", exist_ok=True)
            os.makedirs(f"{home}/.pynets/output", exist_ok=True)
            bids_dir = as_directory(f"{home}/.pynets/input", remove=False)
            if (not creds) and bids_args.push_location:
                raise AttributeError(
                    """No AWS credentials found, but `--push_location` flag
                     called. Pushing will most likely fail.""")
            else:
                output_dir = as_directory(
                    f"{home}/.pynets/output", remove=False)

            # Get S3 input data if needed
            if analysis_level == "participant":
                for partic, ses in list(
                    itertools.product(participant_label, session_label)
                ):
                    if ses is not None:
                        info = "sub-" + partic + "/ses-" + ses
                    elif ses is None:
                        info = "sub-" + partic
                    cloud_utils.s3_get_data(
                        buck, remo, bids_dir, modality, info=info)
            elif analysis_level == "group":
                if len(session_label) > 1 and session_label[0] != "None":
                    for ses in session_label:
                        info = "ses-" + ses
                        cloud_utils.s3_get_data(
                            buck, remo, bids_dir, modality, info=info
                        )
                else:
                    cloud_utils.s3_get_data(buck, remo, bids_dir, modality)

        if len(sec_s3_objs) > 0:
            [access_key, secret_key] = cloud_utils.get_credentials()

            session = Session(
                aws_access_key_id=access_key, aws_secret_access_key=secret_key
            )

            s3_r = session.resource("s3")
            s3_c = cloud_utils.s3_client(service="s3")
            sec_dir = as_directory(
                home + "/.pynets/secondary_files", remove=False)
            for s3_obj in [i for i in sec_s3_objs if i is not None]:
                buck, remo = cloud_utils.parse_path(s3_obj)
                s3_c.download_file(
                    buck, remo, f"{sec_dir}/{os.path.basename(s3_obj)}")

            if isinstance(bids_args.ua, list):
                local_ua = bids_args.ua.copy()
                for i in local_ua:
                    if i.startswith("s3://"):
                        local_ua[local_ua.index(
                            i)] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.ua = local_ua
            if isinstance(bids_args.cm, list):
                local_cm = bids_args.cm.copy()
                for i in bids_args.cm:
                    if i.startswith("s3://"):
                        local_cm[local_cm.index(
                            i)] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.cm = local_cm
            if isinstance(bids_args.roi, list):
                local_roi = bids_args.roi.copy()
                for i in bids_args.roi:
                    if i.startswith("s3://"):
                        local_roi[
                            local_roi.index(i)
                        ] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.roi = local_roi
            if isinstance(bids_args.way, list):
                local_way = bids_args.way.copy()
                for i in bids_args.way:
                    if i.startswith("s3://"):
                        local_way[
                            local_way.index(i)
                        ] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.way = local_way

            if bids_args.ref:
                if bids_args.ref.startswith("s3://"):
                    bids_args.ref = f"{sec_dir}/" \
                                    f"{os.path.basename(bids_args.ref)}"
    else:
        output_dir = bids_args.output_dir
        if output_dir is None:
            raise ValueError("Must specify an output directory")

    intermodal_dict = {
        k: []
        for k in [
            "funcs",
            "confs",
            "dwis",
            "bvals",
            "bvecs",
            "anats",
            "masks",
            "subjs",
            "seshs",
        ]
    }
    if analysis_level == "group":
        if len(modality) > 1:
            i = 0
            for mod_ in modality:
                outs = sweep_directory(
                    bids_dir,
                    modality=mod_,
                    space=space,
                    sesh=session_label,
                    run=run
                )
                if mod_ == "func":
                    if i == 0:
                        funcs, confs, _, _, _, anats, masks, subjs, seshs =\
                            outs
                    else:
                        funcs, confs, _, _, _, _, _, _, _ = outs
                    intermodal_dict["funcs"].append(funcs)
                    intermodal_dict["confs"].append(confs)
                elif mod_ == "dwi":
                    if i == 0:
                        _, _, dwis, bvals, bvecs, anats, masks, subjs, seshs =\
                            outs
                    else:
                        _, _, dwis, bvals, bvecs, _, _, _, _ = outs
                    intermodal_dict["dwis"].append(dwis)
                    intermodal_dict["bvals"].append(bvals)
                    intermodal_dict["bvecs"].append(bvecs)
                intermodal_dict["anats"].append(anats)
                intermodal_dict["masks"].append(masks)
                intermodal_dict["subjs"].append(subjs)
                intermodal_dict["seshs"].append(seshs)
                i += 1
        else:
            intermodal_dict = None
            outs = sweep_directory(
                bids_dir,
                modality=modality[0],
                space=space,
                sesh=session_label,
                run=run
            )
            funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = outs
    elif analysis_level == "participant":
        if len(modality) > 1:
            i = 0
            for mod_ in modality:
                outs = sweep_directory(
                    bids_dir,
                    modality=mod_,
                    space=space,
                    subj=participant_label,
                    sesh=session_label,
                    run=run
                )
                if mod_ == "func":
                    if i == 0:
                        funcs, confs, _, _, _, anats, masks, subjs, seshs =\
                            outs
                    else:
                        funcs, confs, _, _, _, _, _, _, _ = outs
                    intermodal_dict["funcs"].append(funcs)
                    intermodal_dict["confs"].append(confs)
                elif mod_ == "dwi":
                    if i == 0:
                        _, _, dwis, bvals, bvecs, anats, masks, subjs, seshs =\
                            outs
                    else:
                        _, _, dwis, bvals, bvecs, _, _, _, _ = outs
                    intermodal_dict["dwis"].append(dwis)
                    intermodal_dict["bvals"].append(bvals)
                    intermodal_dict["bvecs"].append(bvecs)
                intermodal_dict["anats"].append(anats)
                intermodal_dict["masks"].append(masks)
                intermodal_dict["subjs"].append(subjs)
                intermodal_dict["seshs"].append(seshs)
                i += 1
        else:
            intermodal_dict = None
            outs = sweep_directory(
                bids_dir,
                modality=modality[0],
                space=space,
                subj=participant_label,
                sesh=session_label,
                run=run
            )
            funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = outs
    else:
        raise ValueError(
            "Analysis level invalid. Must be `participant` or `group`. See"
            " --help."
        )

    if intermodal_dict:
        funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = [
            list(set(list(flatten(i)))) for i in intermodal_dict.values()
        ]

    args_dict_all = build_args_from_config(modality, arg_dict)

    id_list = []
    for i in sorted(list(set(subjs))):
        for ses in sorted(list(set(seshs))):
            id_list.append(i + "_" + ses)

    args_dict_all["work"] = bids_args.work
    args_dict_all["output_dir"] = output_dir
    args_dict_all["plug"] = bids_args.plug
    args_dict_all["pm"] = bids_args.pm
    args_dict_all["v"] = bids_args.v
    args_dict_all["clean"] = bids_args.clean
    if funcs is not None:
        args_dict_all["func"] = sorted(funcs)
    else:
        args_dict_all["func"] = None
    if confs is not None:
        args_dict_all["conf"] = sorted(confs)
    else:
        args_dict_all["conf"] = None
    if dwis is not None:
        args_dict_all["dwi"] = sorted(dwis)
        args_dict_all["bval"] = sorted(bvals)
        args_dict_all["bvec"] = sorted(bvecs)
    else:
        args_dict_all["dwi"] = None
        args_dict_all["bval"] = None
        args_dict_all["bvec"] = None
    if anats is not None:
        args_dict_all["anat"] = sorted(anats)
    else:
        args_dict_all["anat"] = None
    if masks is not None:
        args_dict_all["m"] = sorted(masks)
    else:
        args_dict_all["m"] = None
    args_dict_all["g"] = None
    if ("dwi" in modality) and (bids_args.way is not None):
        args_dict_all["way"] = bids_args.way
    else:
        args_dict_all["way"] = None
    args_dict_all["id"] = id_list
    args_dict_all["ua"] = bids_args.ua
    args_dict_all["ref"] = bids_args.ref
    args_dict_all["roi"] = bids_args.roi
    if ("func" in modality) and (bids_args.cm is not None):
        args_dict_all["cm"] = bids_args.cm
    else:
        args_dict_all["cm"] = None

    # Mimic argparse with SimpleNamespace object
    args = SimpleNamespace(**args_dict_all)
    print(args)

    set_start_method("forkserver")
    with Manager() as mgr:
        retval = mgr.dict()
        p = Process(target=build_workflow, args=(args, retval))
        p.start()
        p.join()
        if p.is_alive():
            p.terminate()

        retcode = p.exitcode or retval.get("return_code", 0)

        pynets_wf = retval.get("workflow", None)
        work_dir = retval.get("work_dir")
        plugin_settings = retval.get("plugin_settings", None)
        execution_dict = retval.get("execution_dict", None)
        run_uuid = retval.get("run_uuid", None)

        retcode = retcode or int(pynets_wf is None)
        if retcode != 0:
            sys.exit(retcode)

        # Clean up master process before running workflow, which may create
        # forks
        gc.collect()

    mgr.shutdown()

    if bids_args.push_location:
        print(f"Pushing to s3 at {bids_args.push_location}.")
        push_buck, push_remo = cloud_utils.parse_path(bids_args.push_location)
        for id in id_list:
            cloud_utils.s3_push_data(
                push_buck,
                push_remo,
                output_dir,
                modality,
                subject=id.split("_")[0],
                session=id.split("_")[1],
                creds=creds,
            )

    sys.exit(0)

    return
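
The workflow launch at the end of main() isolates build_workflow in a child process and passes results back through a Manager dict. A stripped-down sketch of that pattern, with a dummy worker standing in for pynets.cli.pynets_run.build_workflow (note that the forkserver start method is POSIX-only):

from multiprocessing import set_start_method, Process, Manager


def build_workflow_stub(args, retval):
    # Dummy stand-in for pynets.cli.pynets_run.build_workflow
    retval["return_code"] = 0
    retval["workflow"] = f"workflow-for-{args}"


if __name__ == "__main__":
    set_start_method("forkserver", force=True)
    with Manager() as mgr:
        retval = mgr.dict()
        p = Process(target=build_workflow_stub, args=("sub-01", retval))
        p.start()
        p.join()
        retcode = p.exitcode or retval.get("return_code", 0)
        print(retcode, retval.get("workflow", None))
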
Exemplo n.º 27
0
def build_multigraphs(est_path_iterlist, ID):
    """
    Constructs a multimodal multigraph for each available resolution of
    vertices.

    Parameters
    ----------
    est_path_iterlist : list
        List of file paths to .npy file containing graph.
    ID : str
        A subject id or other unique identifier.

    Returns
    -------
    multigraph_list_all : list
        List of multiplex graph dictionaries corresponding to
        each unique node resolution.
    graph_path_list_top : list
        List of lists consisting of pairs of most similar
        structural and functional connectomes for each unique node resolution.

    References
    ----------
    .. [1] Bullmore, E., & Sporns, O. (2009). Complex brain networks: Graph
      theoretical analysis of structural and functional systems.
      Nature Reviews Neuroscience. https://doi.org/10.1038/nrn2575
    .. [2] Vaiana, M., & Muldoon, S. F. (2018). Multilayer Brain Networks.
      Journal of Nonlinear Science. https://doi.org/10.1007/s00332-017-9436-8

    """
    import os
    import itertools
    import numpy as np
    from pathlib import Path
    from pynets.core.utils import flatten
    from pynets.stats.netmotifs import motif_matching
    from pynets.core.utils import load_runconfig

    raw_est_path_iterlist = list(
        set(
            [
                os.path.dirname(i) + '/raw' + os.path.basename(i).split(
                    "_thrtype")[0] + ".npy"
                for i in list(flatten(est_path_iterlist))
            ]
        )
    )

    # Available functional and structural connectivity models
    hardcoded_params = load_runconfig()
    try:
        func_models = hardcoded_params["available_models"]["func_models"]
    except KeyError:
        print(
            "ERROR: available functional models not sucessfully extracted"
            " from runconfig.yaml"
        )
    try:
        struct_models = hardcoded_params["available_models"][
            "struct_models"]
    except KeyError:
        print(
            "ERROR: available structural models not sucessfully extracted"
            " from runconfig.yaml"
        )

    atlases = list(set([x.split("/")[-3].split("/")[0]
                        for x in raw_est_path_iterlist]))
    parcel_dict_func = dict.fromkeys(atlases)
    parcel_dict_dwi = dict.fromkeys(atlases)
    est_path_iterlist_dwi = list(
        set(
            [
                i
                for i in raw_est_path_iterlist
                if i.split("model-")[1].split("_")[0] in struct_models
            ]
        )
    )
    est_path_iterlist_func = list(
        set(
            [
                i
                for i in raw_est_path_iterlist
                if i.split("model-")[1].split("_")[0] in func_models
            ]
        )
    )

    if "_rsn" in ";".join(est_path_iterlist_func):
        func_subnets = list(
            set([i.split("_rsn-")[1].split("_")[0] for i in
                 est_path_iterlist_func])
        )
    else:
        func_subnets = []
    if "_rsn" in ";".join(est_path_iterlist_dwi):
        dwi_subnets = list(
            set([i.split("_rsn-")[1].split("_")[0] for i in
                 est_path_iterlist_dwi])
        )
    else:
        dwi_subnets = []

    dir_path = str(
        Path(
            os.path.dirname(
                est_path_iterlist_dwi[0])).parent.parent.parent)
    namer_dir = f"{dir_path}/graphs_multilayer"
    if not os.path.isdir(namer_dir):
        os.mkdir(namer_dir)

    name_list = []
    metadata_list = []
    multigraph_list_all = []
    graph_path_list_all = []
    for atlas in atlases:
        if len(func_subnets) >= 1:
            parcel_dict_func[atlas] = {}
            for sub_net in func_subnets:
                parcel_dict_func[atlas][sub_net] = []
        else:
            parcel_dict_func[atlas] = []

        if len(dwi_subnets) >= 1:
            parcel_dict_dwi[atlas] = {}
            for sub_net in dwi_subnets:
                parcel_dict_dwi[atlas][sub_net] = []
        else:
            parcel_dict_dwi[atlas] = []

        for graph_path in est_path_iterlist_dwi:
            if atlas in graph_path:
                if len(dwi_subnets) >= 1:
                    for sub_net in dwi_subnets:
                        if sub_net in graph_path:
                            parcel_dict_dwi[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_dwi[atlas].append(graph_path)

        for graph_path in est_path_iterlist_func:
            if atlas in graph_path:
                if len(func_subnets) >= 1:
                    for sub_net in func_subnets:
                        if sub_net in graph_path:
                            parcel_dict_func[atlas][sub_net].append(graph_path)
                else:
                    parcel_dict_func[atlas].append(graph_path)

        parcel_dict = {}
        # Create dictionary of all possible pairs of structural-functional
        # graphs for each unique resolution of vertices
        if len(dwi_subnets) >= 1 and len(func_subnets) >= 1:
            parcel_dict[atlas] = {}
            rsns = np.intersect1d(dwi_subnets, func_subnets).tolist()
            for rsn in rsns:
                parcel_dict[atlas][rsn] = list(set(itertools.product(
                    parcel_dict_dwi[atlas][rsn],
                    parcel_dict_func[atlas][rsn])))
                for paths in list(parcel_dict[atlas][rsn]):
                    [
                        name_list,
                        metadata_list,
                        multigraph_list_all,
                        graph_path_list_all,
                    ] = motif_matching(
                        paths,
                        ID,
                        atlas,
                        namer_dir,
                        name_list,
                        metadata_list,
                        multigraph_list_all,
                        graph_path_list_all,
                        rsn=rsn,
                    )
        else:
            parcel_dict[atlas] = list(set(itertools.product(
                parcel_dict_dwi[atlas], parcel_dict_func[atlas])))
            for paths in list(parcel_dict[atlas]):
                [
                    name_list,
                    metadata_list,
                    multigraph_list_all,
                    graph_path_list_all,
                ] = motif_matching(
                    paths,
                    ID,
                    atlas,
                    namer_dir,
                    name_list,
                    metadata_list,
                    multigraph_list_all,
                    graph_path_list_all,
                )

    graph_path_list_top = [list(i[0].values()) for i in graph_path_list_all]
    assert len(multigraph_list_all) == len(name_list) == len(metadata_list)

    return (
        multigraph_list_all,
        graph_path_list_top,
        len(name_list) * [namer_dir],
        name_list,
        metadata_list,
    )
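
Finally, the structural-functional pairing at the heart of build_multigraphs reduces to an itertools.product over per-atlas path lists. A minimal sketch with hypothetical paths for one atlas and no subnetworks:

import itertools

# Hypothetical raw graph paths, keyed by atlas
parcel_dict_dwi = {"DesikanKlein2012": [
    "/out/DesikanKlein2012/rawgraph_sub-01_modality-dwi_model-csd.npy"]}
parcel_dict_func = {"DesikanKlein2012": [
    "/out/DesikanKlein2012/rawgraph_sub-01_modality-func_model-corr.npy",
    "/out/DesikanKlein2012/rawgraph_sub-01_modality-func_model-partcorr.npy"]}

atlas = "DesikanKlein2012"
pairs = list(set(itertools.product(parcel_dict_dwi[atlas],
                                   parcel_dict_func[atlas])))
for dwi_path, func_path in pairs:
    # Each pair is a candidate multiplex layer set handed to motif_matching
    print("structural:", dwi_path, "| functional:", func_path)
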