Exemplo n.º 1
0
 def GetScores(self):
     '''
     Run both scoring passes and return the four resulting rates.

     Calls ``self.scoresG()`` and ``self.scoresC()``, then replaces each of
     the four stored graphs (``Gcore_real``, ``Gcompared_real``,
     ``Gcore_exp``, ``Gcompared_exp``) with the result of passing it
     through ``fixG``.

     Returns:
         list: ``[self.GFNR, self.GFPR, self.CFNR, self.CFPR]``
         (presumably populated by the two scoring calls -- confirm in
         ``scoresG``/``scoresC``).
     '''
     self.scoresG()
     self.scoresC()

     # Same order as before: real graphs first, then the "exp" graphs.
     graph_attrs = ('Gcore_real', 'Gcompared_real',
                    'Gcore_exp', 'Gcompared_exp')
     for attr_name in graph_attrs:
         setattr(self, attr_name, fixG(getattr(self, attr_name)))

     rates = [self.GFNR, self.GFPR, self.CFNR, self.CFPR]
     return rates
     
                    
         
Exemplo n.º 2
0
    def BuildGraph(self):
        '''
        Build a directed graph from the data returned by ``self.ReadLines()``.

        ``ReadLines`` yields four values; the first three are used here as
        per-edge diameters, node positions and edges (the fourth is ignored).
        One node is created per position and receives a ``'pos'`` attribute.
        Edge indices are shifted by -1 before being added (presumably the
        input is 1-based -- confirm against ``ReadLines``).

        Each node receives a ``'d'`` attribute equal to the largest diameter
        among the edges that touch it.

        The result, passed through ``fixG``, is stored in ``self.G``.
        '''
        d, p, e, _ = self.ReadLines()

        G = DiGraph()
        G.add_nodes_from(range(len(p)))

        for node, pos in zip(G.GetNodes(), p):
            G.node[node]['pos'] = pos

        e = np.array(e) - 1
        G.add_edges_from(e.tolist())

        # Propagate edge diameters to both end nodes, keeping the maximum.
        # An explicit membership test replaces the former bare ``except:``,
        # which also hid unrelated errors (and KeyboardInterrupt).
        for edge, diam in zip(e, d):
            for n in (edge[0], edge[1]):
                attrs = G.node[n]
                if 'd' not in attrs or diam > attrs['d']:
                    attrs['d'] = diam

        self.G = fixG(G)
Exemplo n.º 3
0
    def Update(self, FullyCC=False):
        '''
        Build a graph from the edge/vertex files and store it in ``self.Graph``.

        The files named by ``self.filenameEdges`` and ``self.filenameVertices``
        are parsed with ``self.readCGAL``; ``self.getGraph`` turns the result
        into node positions ``p`` and a connectivity array ``c``. One node is
        created per row of ``p`` and carries that row as its ``'pos'``
        attribute.

        Inputs:
            FullyCC: if true, keep only the largest connected component
                     (first one on ties) and pass it through ``fixG``.
        '''
        P1, P2, P11, P22 = self.readCGAL(self.filenameEdges,
                                         self.filenameVertices)
        p, intersections, c = self.getGraph(P1, P2, P11, P22)

        n_nodes = np.shape(p)[0]

        G = Graph()
        G.add_nodes_from(range(n_nodes))
        G.add_edges_from(np.ndarray.tolist(c))
        for i in range(n_nodes):
            G.node[i]['pos'] = p[i, :]

        # NOTE: the former bare ``G.to_undirected()`` call was dropped: it
        # returns a new graph instead of converting in place, so its result
        # was discarded and it only cost a full copy.

        if FullyCC:
            components = list(nx.connected_component_subgraphs(G))
            # An empty graph has no components; leave G untouched in that
            # case instead of failing on an out-of-range index.
            if components:
                G = fixG(max(components, key=len))

        self.Graph = G
Exemplo n.º 4
0
    def BuildGraph(self):
        '''
        1 Build graph based on info from 'path_net'
        2 Set attributes to graph nodes based on info from 'path_attr'

        Nodes get a 'pos' attribute from the positions returned by
        ``self.ReadNet()``. Edge indices are shifted by -1 before being added
        (presumably 1-based input -- confirm against ``ReadNet``).

        When available, the per-edge attributes 'Dia', 'flow' and 'ppO2' from
        ``self.attr`` are propagated to the nodes as 'd' (and 'r' = d / 2),
        'flow' and 'po2'. Each of these steps is best-effort: on failure a
        message is printed and the build continues.

        The result, passed through ``fixG``, is stored in ``self.G``.
        '''
        def setatt(g, e, v, name):
            # Give both end nodes of every edge the edge value, keeping the
            # largest value seen per node. Operates on the graph passed in
            # (the previous version silently used the enclosing ``G``).
            for edge, val in zip(e, v):
                for n in (edge[0], edge[1]):
                    attrs = g.node[n]
                    if name not in attrs or val > attrs[name]:
                        attrs[name] = val
            return g

        p, e = self.ReadNet()

        if self.path_attr is not None:
            self.attr = self.ReadAttr()

        G = DiGraph() if self.mode == 'di' else Graph()

        G.add_nodes_from(range(len(p)))

        for node, pos in zip(G.GetNodes(), p):
            G.node[node]['pos'] = pos

        e = np.array(e) - 1
        G.add_edges_from(e.tolist())

        # The three blocks below stay deliberately tolerant (missing
        # ``self.attr``, missing key, non-numeric data ...), but no longer
        # swallow KeyboardInterrupt/SystemExit as a bare ``except:`` does.

        # set diameter/radius
        try:
            d = np.array(self.attr['Dia']).ravel().astype(float)
            G = setatt(G, e, d, 'd')
            G = setatt(G, e, d / 2, 'r')
        except Exception:
            print('--Cannot set diam!')

        # set flow
        try:
            flow = np.array(self.attr['flow']).ravel().astype(float)
            G = setatt(G, e, flow, 'flow')
        except Exception:
            print('--Cannot set flow!')

        # set po2
        try:
            po2 = np.array(self.attr['ppO2']).ravel().astype(float)
            G = setatt(G, e, po2, 'po2')
        except Exception:
            print('--Cannot set po2!')

        self.G = fixG(G)
Exemplo n.º 5
0
    def GetOutput(self):
        '''
        Read the input file, add sources/sinks and return the graph.

        Calls ``self.ReadFile()`` followed by ``self.AddSourcesSinks()``.

        Returns:
            ``fixG(self.G)`` when ``self.G`` is set, otherwise None.
        '''
        self.ReadFile()
        self.AddSourcesSinks()

        if self.G is None:
            return None
        return fixG(self.G)
Exemplo n.º 6
0
    def Update(self, label=None, ret=False):
        '''
        Generate a graph from a label image.

        Pipeline: generate initial geometry --> contract graph --> refine graph

        Inputs:
            label: label image to process; when None, ``self.label`` is used
            ret: if True the final graph is returned, otherwise it is stored
                 in ``self.Graph`` and nothing is returned

        References:

        @article{damseh2019laplacian,
            title={Laplacian Flow Dynamics on Geometric Graphs for Anatomical Modeling of Cerebrovascular Networks}, 
            author={Damseh, Rafat and Delafontaine-Martel, Patrick and Pouliot, Philippe and Cheriet, Farida and Lesage, Frederic}, 
            journal={arXiv preprint arXiv:1912.10003}, year={2019}}
        
        @article{damseh2018automatic,
            title={Automatic Graph-Based Modeling of Brain Microvessels Captured With Two-Photon Microscopy}, 
            author={Damseh, Rafat and Pouliot, Philippe and Gagnon, Louis and Sakadzic, Sava and Boas, 
                    David and Cheriet, Farida and Lesage, Frederic}, 
            journal={IEEE journal of biomedical and health informatics}, 
            volume={23}, 
            number={6}, 
            pages={2551--2562}, 
            year={2018}, 
            publisher={IEEE}} 
        '''

        # ---------- generate -----------#
        source = self.label if label is None else label
        generator = GenerateGraph(source)
        generator.UpdateGridGraph(Sampling=self.sampling)
        initial_graph = generator.GetOutput()

        # ---------- contract -----------#
        contractor = ContractGraph(initial_graph)
        contractor.Update(DistParam=self.dist_param,
                          MedParam=self.med_param,
                          SpeedParam=self.speed_param,
                          DegreeThreshold=self.degree_threshold,
                          StopParam=self.stop_param,
                          NFreeIteration=self.n_free_iteration)
        contracted_graph = contractor.GetOutput()

        # ---------- refine -----------#
        refiner = RefineGraph(contracted_graph)
        refiner.Update(AreaParam=self.area_param, PolyParam=self.poly_param)
        final_graph = fixG(refiner.GetOutput())

        # ----- return ----#
        if not ret:
            self.Graph = final_graph
            return None
        return final_graph
Exemplo n.º 7
0
def PostProcessMRIGraph(graph, upper_distance=7.0, k=5):
    '''
    Reconnect separated segments of an MRI graph.

    The graph is split into connected components. For every end node of a
    component (nodes returned by ``GetJuntionNodes(bifurcation=[0, 1])``),
    nearby nodes belonging to *other* components are looked up with a KD-tree
    and linked to it by new edges. The over-connected graph is then reduced
    with ``FullyConnectedGraph`` and cleaned up by ``RefineGraph``.

    Inputs:
        graph: input graph; it is copied, not modified
        upper_distance: maximum distance for a candidate neighbour
        k: number of neighbours requested from the KD-tree per end node

    Returns:
        the refined graph, passed through ``fixG``
    '''
    nodes_all = graph.GetNodes()

    # ----- overconnecting the graph --------#

    # subgraphs
    graphs = list(nx.connected_component_subgraphs(graph))

    # obtain end nodes/nodes and their positions in each segment
    nodes = [g.GetNodes() for g in graphs]
    end_nodes = [g.GetJuntionNodes(bifurcation=[0, 1])
                 for g in graphs]  # end nodes in each subgraph
    end_nodes_pos = [
        np.array([graph.node[n]['pos'] for n in segment])
        for segment in end_nodes
    ]  # pos of end nodes

    # obtain closest nodes from other segments to each end node of the
    # current segment
    closest_nodes = []
    for end_n, n, end_p in zip(end_nodes, nodes,
                               end_nodes_pos):  # iterate over each segment

        other_nodes = list(set(nodes_all).symmetric_difference(set(n)))

        if not other_nodes:
            # Single-component graph: there is nothing to connect to, and a
            # KD-tree cannot be built on empty data. Keep one (empty) entry
            # per end node so the flattened lists below stay aligned.
            closest_nodes.append([[] for _ in end_n])
            continue

        other_pos = np.array([graph.node[m]['pos'] for m in other_nodes])

        # KD-tree indices --> node ids of 'other_nodes'
        mapping = dict(enumerate(other_nodes))

        # cKDTree.query reports "no neighbour found" with index == len(data)
        ind_notvalid = len(other_pos)
        tree = sp.spatial.cKDTree(other_pos)
        # NOTE(review): '[1:]' discards the first (nearest) hit. The tree only
        # holds nodes of other segments, so that hit is not the query node
        # itself -- confirm that skipping it is intended. Also note that
        # k=1 makes query() return scalars, which this slice cannot handle.
        closest = [
            tree.query(pt, k=k, distance_upper_bound=upper_distance)[1][1:]
            for pt in end_p
        ]
        closest = [[idx for idx in hits if idx != ind_notvalid]
                   for hits in closest]  # fix from query
        closest = [[mapping[idx] for idx in hits]
                   for hits in closest]  # fix indexing
        closest_nodes.append(closest)

    # create new graph and add new edges
    graph_new = graph.copy()
    closest_nodes = [hits for segment in closest_nodes for hits in segment]
    end_nodes = [node for segment in end_nodes for node in segment]
    edges_new = [[src, dst]
                 for src, targets in zip(end_nodes, closest_nodes)
                 for dst in targets]
    graph_new.add_edges_from(edges_new)

    graphs_new = list(nx.connected_component_subgraphs(graph_new))
    print('Elements in each connected component: ')
    print([len(g) for g in graphs_new])

    # refine overconnectivity
    from VascGraph.Skeletonize import RefineGraph
    final_graph = FullyConnectedGraph(graph_new)
    refine = RefineGraph(final_graph)
    refine.Update(AreaParam=50.0, PolyParam=10)
    final_graph = fixG(refine.GetOutput())

    return final_graph
Exemplo n.º 8
0
                  degree_threshold=degree_threshold,
                  clustering_resolution=clustering_r,
                  stop_param=stop_param,
                  n_free_iteration=n_free_iteration,
                  area_param=area_param,
                  poly_param=poly_param)

    sk.UpdateWithStitching(size=size,
                           niter1=niter1,
                           niter2=niter2,
                           is_parallel=is_parallel,
                           n_parallel=n_parallel)

    fullgraph = sk.GetOutput()

    # save graph
    WritePajek(path='', name='mygraph.pajek', graph=fixG(fullgraph))

    #load graph
    loaded_g = ReadPajek('mygraph.pajek').GetOutput()

    print('--Visualize final skeleton ...')
    splot = StackPlot(new_engine=True)
    splot.Update((s > 0).astype(int))

    gplot = GraphPlot()
    gplot.Update(loaded_g)
    gplot.SetTubeRadiusByScale(True)
    gplot.SetTubeRadiusByColor(True)
    gplot.SetTubeRadius(3)
Exemplo n.º 9
0
                    SpeedParam=speed_param, 
                    DegreeThreshold=degree_threshold, 
                    StopParam=stop_param,
                    NFreeIteration=n_free_iteration)
    gc=contract.GetOutput()
    
    
    #refine graph
    refine=RefineGraph(gc)
    refine.Update(AreaParam=area_param, 
                  PolyParam=poly_param)
    gr=refine.GetOutput()    
    
    
    #gr=FullyConnectedGraph(gr) # uncomment to get only fully connected components of the graph
    gr=fixG(gr, copy=True) # this is to fix nodes indixing to be starting from 0 (important for visualization)
    

  
    #-------------------------------------------------------------------------#
    # read/ write
    #-------------------------------------------------------------------------#

#    # save graph
    WritePajek(path='', name='mygraph.pajek', graph=fixG(gr))

#    #load graph    
    loaded_g=ReadPajek('mygraph.pajek').GetOutput()
    
    
    
Exemplo n.º 10
0
    def UpdateWithStitching(self,
                            size,
                            niter1=10,
                            niter2=5,
                            is_parallel=False,
                            n_parallel=5,
                            ret=False):
        '''
        Generate a graph from a large input by working patch-wise:
            1) image patching -->  2) patch-based contraction (fixing boundary nodes)
            --> 3) graph stitching --> 4) boundary contraction --> 5) global contraction --> refinement

        It is helpful when graphing large inputs.

        Side effect: ``self.label`` is replaced by ``DistMap3D(self.label)``.

        Inputs:
            size: dimension of a 3D patch --> [size, size, size]
            niter1: number of contraction iterations on patches
            niter2: number of contraction iterations on boundary nodes
            is_parallel: if True, patch-based contraction will run in parallel using 'ray'
            n_parallel: number of parallel processes (note: for limited RAM memory, 'n_parallel' should be smaller);
                        patches are also processed in batches of this size in serial mode
            ret: if True, this function will return the output graph;
                 otherwise the graph is stored in ``self.Graph``

        Raises:
            ImportError: if 'scikit-learn' is missing, or if 'ray' is missing
                         while ``is_parallel`` is True. Raised up front so the
                         failure does not surface as a NameError after the
                         expensive patch processing.
        '''
        try:
            from sklearn import neighbors
        except ImportError:
            raise ImportError(
                '  \'scikit-learn\' must be instaled to run this funtion!')

        if is_parallel:

            try:
                import ray
            except ImportError:
                raise ImportError(
                    '  \'ray\' must be installed to run patch-based contraction in parallel!'
                )

            GraphParallel = activate_parallel()

        # obtain distance map
        #self.label=image.morphology.distance_transform_edt(self.label)
        self.label = DistMap3D(self.label)

        # patching
        print('--Extract patches ...')
        patches, patchesid = Decompose(self.label,
                                       size=size)  # extract patches

        print('--Obtain semi-contracted graphs from patches ...')
        # run contraction avoiding boundary nodes for each patch,
        # 'n_parallel' patches at a time
        graphs = []
        starts = np.arange(0, len(patchesid), n_parallel)
        batches = [patchesid[s:s + n_parallel] for s in starts]

        for batch in batches:

            if is_parallel:  # in parallel
                ray.init()
                try:
                    subpatches = [ray.put(patches[ind]) for ind in batch]
                    futures = [
                        GraphParallel.remote(
                            patch,
                            niter=niter1,
                            Sampling=self.sampling,
                            DistParam=self.dist_param,
                            MedParam=self.med_param,
                            SpeedParam=self.speed_param,
                            DegreeThreshold=self.degree_threshold,
                            ClusteringResolution=self.clustering_resolution)
                        for patch in subpatches
                    ]
                    subgraphs = [ray.get(f) for f in futures]
                finally:
                    # always release the ray runtime, even if a task failed
                    ray.shutdown()

            else:  # in serial
                subpatches = [patches[ind] for ind in batch]
                subgraphs = [
                    GraphSerial(
                        patch,
                        niter=niter1,
                        Sampling=self.sampling,
                        DistParam=self.dist_param,
                        MedParam=self.med_param,
                        SpeedParam=self.speed_param,
                        DegreeThreshold=self.degree_threshold,
                        ClusteringResolution=self.clustering_resolution)
                    for patch in subpatches
                ]

            graphs.extend(subgraphs)
        del patches

        # adjust the position of graph nodes coming from each patch
        area = np.sum([g.Area for g in graphs if g is not None])
        pluspos = (size) * np.array(patchesid)
        for plus, g in zip(pluspos, graphs):
            if g is not None:
                AddPos(g, plus)

        print('--Combine semi-contracted graphs ...')
        fullgraph = EmptyGraph()
        nnodes = 0
        for idx, g in enumerate(graphs):
            if g is not None:
                print('    graph id ' + str(idx) + ' added')
                # NOTE(review): the offset accumulates the running node count,
                # so it grows faster than strictly needed and leaves gaps in
                # the labels; ids never collide, and the fixG call below is
                # presumably what compacts them -- confirm.
                nnodes += fullgraph.number_of_nodes()
                new_nodes = nnodes + np.array(range(g.number_of_nodes()))
                mapping = dict(zip(g.GetNodes(), new_nodes))
                g = nx.relabel_nodes(g, mapping)
                fullgraph.add_nodes_from(g.GetNodes())
                fullgraph.add_edges_from(g.GetEdges())
                for k in new_nodes:
                    fullgraph.node[k]['pos'] = g.node[k]['pos']
                    fullgraph.node[k]['r'] = g.node[k]['r']
                    fullgraph.node[k]['ext'] = g.node[k]['ext']
            else:
                print('    graph id ' + str(idx) + ' is None')

        fullgraph = fixG(fullgraph)
        fullgraph.Area = area
        del graphs

        print('--Stitch semi-contracted graphs ...')
        # link every node flagged 'ext' == 1 to the other such nodes lying
        # within a radius of 1.0
        nodes = np.array(
            [k for k in fullgraph.GetNodes() if fullgraph.node[k]['ext'] == 1])
        if len(nodes) > 0:  # a KD-tree cannot be built without points
            pos = np.array([fullgraph.node[k]['pos'] for k in nodes])
            pos_tree = neighbors.KDTree(pos)
            hits = pos_tree.query_radius(pos, r=1.0)
            # query_radius does not sort its results by default, so the query
            # point is not guaranteed to come first; exclude it by index.
            new_edges = [(nodes[i], nodes[j])
                         for i, neigh in enumerate(hits)
                         for j in neigh if j != i]
            fullgraph.add_edges_from(new_edges)

            del hits
            del pos
            del new_edges
            del pos_tree
        del nodes

        print('--Contract ext nodes ...')
        ContractExt(fullgraph,
                    niter=niter2,
                    DistParam=self.dist_param,
                    MedParam=self.med_param,
                    SpeedParam=self.speed_param,
                    DegreeThreshold=self.degree_threshold,
                    ClusteringResolution=self.clustering_resolution)

        print('--Generate final skeleton ...')
        contract_final = ContractGraph(Graph=fullgraph)
        contract_final.Update(DistParam=self.dist_param,
                              MedParam=self.med_param,
                              SpeedParam=self.speed_param,
                              DegreeThreshold=self.degree_threshold,
                              StopParam=self.stop_param,
                              NFreeIteration=self.n_free_iteration)
        gc = contract_final.GetOutput()

        print('--Refine final skeleton ...')
        refine = RefineGraph(Graph=gc)
        refine.Update()
        gr = refine.GetOutput()
        gr = fixG(gr)

        # ----- return ----#
        if ret:
            return gr
        else:
            self.Graph = gr