Example #1
def get_skeleton(seg_fn, out_folder, bfs, res, edgTh, modified_bfs, opt):
    """Run one stage of the segmentation -> skeleton -> graph pipeline, selected by opt ('0'-'3')."""
    if opt == '0':  # mesh -> skeleton
        seg = ReadH5(seg_fn, 'main')
        CreateSkeleton(seg, out_folder, res, res)

    elif opt == '1':  # skeleton -> dense graph
        import networkx as nx
        print('read skel')
        skel = ReadSkeletons(out_folder,
                             skeleton_algorithm='thinning',
                             downsample_resolution=res,
                             read_edges=True)[1]

        print('save node positions')
        node_pos = np.stack(skel.get_nodes()).astype(int)
        WriteH5(out_folder + 'node_pos.h5', node_pos)

        print('generate dt for edge width')
        seg = ReadH5(seg_fn, 'main')
        sz = seg.shape
        bb = GetBbox(seg > 0)
        seg_b = seg[bb[0]:bb[1] + 1, bb[2]:bb[3] + 1, bb[4]:bb[5] + 1]
        dt = distance_transform_cdt(seg_b, return_distances=True)

        print('generate graph')
        new_graph, wt_dict, th_dict, ph_dict = GetGraphFromSkeleton(
            skel, dt=dt, dt_bb=[bb[x] for x in [0, 2, 4]],
            modified_bfs=modified_bfs)

        print('save as a networkx object')
        edge_list = GetEdgeList(new_graph, wt_dict, th_dict, ph_dict)
        G = nx.Graph(shape=sz)
        # each edge_list entry carries its attribute dict, so this also stores the edge attributes
        G.add_edges_from(edge_list)
        nx.write_gpickle(G, out_folder + 'graph-%s.obj' % (bfs))

    elif opt == '2':  # reduced graph
        import networkx as nx
        G = nx.read_gpickle(out_folder + 'graph-%s.obj' % (bfs))

        n0 = len(G.nodes())
        G = ShrinkGraph_v2(G, threshold=edgTh)
        n1 = len(G.nodes())
        print('#nodes: %d -> %d' % (n0, n1))
        nx.write_gpickle(
            G,
            out_folder + 'graph-%s-%d-%d.obj' % (bfs, edgTh[0], 10 * edgTh[1]))
    elif opt == '3':  # generate h5 for visualization
        import networkx as nx
        G = nx.read_gpickle(out_folder + 'graph-%s-%d-%d.obj' %
                            (bfs, edgTh[0], 10 * edgTh[1]))
        pos = ReadH5(out_folder + 'node_pos.h5', 'main')
        vis = Graph2H5(G, pos)
        WriteH5(
            out_folder + 'graph-%s-%d-%d.h5' % (bfs, edgTh[0], 10 * edgTh[1]),
            vis)
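A minimal driver sketch for the pipeline above, assuming get_skeleton is importable from this module; the input path is a placeholder, and the parameter values are taken from the demo script further below.

if __name__ == '__main__':
    seg_fn = '/path/to/segmentation.h5'  # hypothetical 3D segment volume
    out_folder = '../tmp/demo/'
    res = [120, 128, 128]                # z, y, x resolution
    edgTh = [40, 1]                      # edge thresholds for graph reduction
    # run the stages in order: skeleton ('0'), dense graph ('1'),
    # reduced graph ('2'), H5 volume for visualization ('3')
    for opt in ['0', '1', '2', '3']:
        get_skeleton(seg_fn, out_folder, 'bfs', res, edgTh,
                     modified_bfs=False, opt=opt)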
Example #2
def compute_ibex_skeleton_graphs(skeleton, out_folder, input_resolution, downsample_fac):
    """Create ibex skeletons for a labeled volume and return those that contain at least one node."""
    input_resolution = np.array(input_resolution)
    downsample_fac = np.array(downsample_fac)
    downsample_res = input_resolution * downsample_fac

    # skeleton_ids = np.unique(skeleton)
    # skeleton_ids = skeleton_ids[skeleton_ids > 0]
    #
    # locations = scipy.ndimage.find_objects(skeleton)
    #
    # ibex_skeletons = []
    # for idx, skeleton_id in enumerate(skeleton_ids):
    #     loc = locations[int(skeleton_id) - 1]
    #     start_pos = np.array([loc[0].start, loc[1].start, loc[2].start], dtype=np.uint16)
    #     skel_mask = (skeleton[loc] == int(skeleton_id))
    #     CreateSkeleton(skeleton, out_folder, input_resolution, downsample_res)
    #     ibex_skeletons.append(ReadSkeletons(out_folder, skeleton_algorithm='thinning',  downsample_resolution=downsample_res, read_edges=True)[0])
    CreateSkeleton(skeleton, out_folder, input_resolution, downsample_res)
    ibex_skeletons = ReadSkeletons(out_folder, skeleton_algorithm='thinning', downsample_resolution=downsample_res, read_edges=True)

    # keep index 0 plus every skeleton that has at least one node
    keep_indices = [0] + [i for i, skel in enumerate(ibex_skeletons) if skel.get_nodes().shape[0] > 0]
    ibex_skeletons = [ibex_skeletons[i] for i in keep_indices]
    return ibex_skeletons
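A usage sketch for compute_ibex_skeleton_graphs; the input file name, scratch folder, and the resolution/downsampling values are assumptions.

import h5py
import numpy as np

with h5py.File('skeleton.h5', 'r') as f:     # hypothetical labeled volume
    skeleton = np.array(f['main'])           # 3D array of skeleton/segment ids

ibex_skeletons = compute_ibex_skeleton_graphs(
    skeleton,
    out_folder='./tmp/ibex/',                # scratch folder for ibex output
    input_resolution=[30, 6, 6],             # z, y, x voxel size (illustrative)
    downsample_fac=[1, 4, 4])                # per-axis downsampling (illustrative)
print('kept %d skeletons with nodes' % len(ibex_skeletons))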
Example #3
def compute_thinned_nodes(process_id, seg_ids, skel_vol_full, temp_folder, input_resolution, downsample_fac, output_file_name):
    """Skeletonize each segment id and write its node coordinates to a per-process HDF5 file."""
    process_id = str(process_id)
    out_folder = temp_folder + '/' + process_id
    if not os.path.isdir(out_folder):
        os.makedirs(out_folder)

    input_resolution = np.array(input_resolution).astype(np.uint8)
    downsample_fac = np.array(downsample_fac).astype(np.uint8)
    graphs = {}

    with h5py.File(temp_folder + '/(' + process_id + ')' + output_file_name, 'w') as hf_nodes:
        locations = scipy.ndimage.find_objects(skel_vol_full)
        for idx, seg_id in enumerate(seg_ids):
            loc = locations[int(seg_id) - 1]
            start_pos = np.array([loc[0].start, loc[1].start, loc[2].start], dtype=np.uint16)
            skel_mask = (skel_vol_full[loc] == int(seg_id))
            try:
                CreateSkeleton(skel_mask, out_folder, input_resolution, input_resolution*downsample_fac)
                skel_obj = ReadSkeletons(out_folder, skeleton_algorithm='thinning', downsample_resolution=input_resolution*downsample_fac, read_edges=True)[1]
                nodes = start_pos + np.stack(skel_obj.get_nodes()).astype(np.uint16)
            except Exception:
                # skip segments for which skeletonization fails
                continue
            hf_nodes.create_dataset('allNodes' + str(seg_id), data=nodes, compression='gzip')
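A short sketch of reading back the per-segment node arrays written above; the folder, process id, and file name are assumptions that must match an earlier call to compute_thinned_nodes.

import h5py

temp_folder, process_id, output_file_name = './tmp', '0', 'nodes.h5'  # hypothetical
with h5py.File(temp_folder + '/(' + process_id + ')' + output_file_name, 'r') as hf:
    for key in hf.keys():          # datasets are named 'allNodes<seg_id>'
        nodes = hf[key][()]        # (N, 3) node coordinates in full-volume space
        print(key, nodes.shape)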
Example #4
File: demo.py Project: silky/ibexHelper
import sys

import networkx as nx
import numpy as np
from scipy.ndimage.morphology import distance_transform_cdt

opt = sys.argv[1]

res = [120, 128, 128]  # z,y,x
out_folder = '../tmp/demo/'
bfs = 'bfs'
modified_bfs = False
edgTh = [40, 1]  # threshold
# 3d segment volume
seg_fn = '/mnt/coxfs01/donglai/data/JWR/snow_cell/cell128nm/neuron/cell26_d.h5'

if opt == '0':  # mesh -> skeleton
    seg = ReadH5(seg_fn, 'main')
    CreateSkeleton(seg, out_folder, res, res)

elif opt == '1':  # skeleton -> dense graph
    print('read skel')
    skel = ReadSkeletons(out_folder,
                         skeleton_algorithm='thinning',
                         downsample_resolution=res,
                         read_edges=True)[1]

    print('save node positions')
    node_pos = np.stack(skel.get_nodes()).astype(int)
    WriteH5(out_folder + 'node_pos.h5', node_pos)

    print('generate dt for edge width')
    seg = ReadH5(seg_fn, 'main')
    sz = seg.shape
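demo.py selects its pipeline stage from the first command-line argument. A sketch of running the first two stages in order; the assumption is that demo.py sits in the current working directory.

import subprocess

for opt in ['0', '1']:   # '0': mesh -> skeleton, '1': skeleton -> dense graph
    subprocess.run(['python', 'demo.py', opt], check=True)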
Example #5
    args = get_args()  # get args
    seg_fn = args.seg
    output_folder = args.out
    res = [int(i) for i in args.res.split(':')]
    dendrite_ids = np.array([int(i) for i in args.ids.split(':')])

    print('--Load segmentation volume..')
    seg = ReadH5(seg_fn, 'main')

    if args.cs == 1:  # only needed if the skeletons have not been created yet
        print("\n--Create skeletons for given ids:")
        for i, did in enumerate(tqdm(dendrite_ids)):
            blockPrint()
            dendrite_folder = '{}/skels/{}/'.format(output_folder, did)
            CreateSkeleton(seg == did, dendrite_folder, res, res)
            enablePrint()

    print("\n--Analyse skeletons for given ids:")
    lookuptable = np.zeros((dendrite_ids.shape[0], 8))
    for i, did in enumerate(tqdm(dendrite_ids)):
        blockPrint()
        dendrite_folder = '{}/skels/{}/'.format(output_folder, did)
        # load skeleton of given seg id, return it and its graph
        G, skel = skeleton2graph(did, dendrite_folder, seg, res)

        # %% get longest axis
        if args.task == 0 or args.task == 3:  # distance based methods
            #         main_G, _, _, endnodes = search_longest_path_exhaustive(G)
            weight = 'weight'  # longest path based on the edge attribute 'weight'
        elif args.task == 1 or args.task == 4:
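The snippet above depends on a get_args helper that is not shown. A hedged reconstruction based only on the attributes used in the excerpt; the flag names, types, and defaults are assumptions.

import argparse

def get_args():
    p = argparse.ArgumentParser(description='skeleton analysis for selected dendrite ids')
    p.add_argument('--seg', type=str, help='path to segmentation H5 volume')
    p.add_argument('--out', type=str, help='output folder')
    p.add_argument('--res', type=str, default='120:128:128', help='z:y:x resolution')
    p.add_argument('--ids', type=str, help='colon-separated dendrite ids')
    p.add_argument('--cs', type=int, default=1, help='1 to create skeletons first')
    p.add_argument('--task', type=int, default=0, help='longest-axis method selector')
    return p.parse_args()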
Example #6
def compute_skel_graph(process_id, seg_ids, skel_vol_full, temp_folder, input_resolution, downsample_fac, output_file_name, save_graph, use_spline=True):
    """Skeletonize each segment id, extract its graph, and write junction/end-point/node arrays to HDF5."""
    process_id = str(process_id)

    out_folder = temp_folder + '/' + process_id
    if not os.path.isdir(out_folder):
        os.makedirs(out_folder)

    input_resolution = np.array(input_resolution).astype(np.uint8)
    downsample_fac = np.array(downsample_fac).astype(np.uint8)
    graphs = {}
    edge_lists = {}
    with h5py.File(temp_folder + '/(' + process_id + ')' + output_file_name, 'w') as hf_nodes:
        locations = scipy.ndimage.find_objects(skel_vol_full)
        for idx, seg_id in enumerate(seg_ids):
            loc = locations[int(seg_id) - 1]
            start_pos = np.array([loc[0].start, loc[1].start, loc[2].start], dtype=np.uint16)
            skel_mask = (skel_vol_full[loc] == int(seg_id))
            try:
                CreateSkeleton(skel_mask, out_folder, input_resolution, input_resolution*downsample_fac)
                skel_obj = ReadSkeletons(out_folder, skeleton_algorithm='thinning', downsample_resolution=input_resolution*downsample_fac, read_edges=True)[1]
                nodes = np.stack(skel_obj.get_nodes()).astype(np.uint16)

                if nodes.shape[0] < 10:
                    # print('skipped skel: {} (too small!)'.format(seg_id))
                    continue

                graph, wt_dict, th_dict, ph_dict = GetGraphFromSkeleton(skel_obj, modified_bfs=False)
                edge_list = GetEdgeList(graph, wt_dict, th_dict, ph_dict)
            except Exception:
                # skip segments for which skeletonization or graph extraction fails
                # print('Caught exception in skel: ', seg_id)
                # traceback.print_exc(file=sys.stdout)
                continue

            if save_graph is True:
                if use_spline is True:
                    _, graph = upsample_skeleton_using_splines(edge_list, nodes, skel_mask.shape, return_graph=True, start_pos=start_pos)
                    graphs[seg_id] = graph
                else:
                    g = nx.Graph()
                    nodes_shifted = nodes + start_pos
                    for long_sec in edge_list:
                        path = long_sec[2]['path']
                        g.add_edges_from([(tuple(nodes_shifted[path[i]]), tuple(nodes_shifted[path[i + 1]])) for i in range(len(path)-1)])
                    graphs[seg_id] = g
                    edge_lists[seg_id] = edge_list

            g = nx.Graph()
            g.add_edges_from(edge_list)
            j_ids = [x for x in g.nodes() if g.degree(x) > 2]
            e_ids = [x for x in g.nodes() if g.degree(x) == 1]
            nodes = nodes + start_pos
            if len(j_ids) > 0:
                junctions = nodes[j_ids]
                hf_nodes.create_dataset('j' + str(seg_id), data=junctions, compression='gzip')

            if len(e_ids) > 0:
                end_points = nodes[e_ids]
                hf_nodes.create_dataset('e' + str(seg_id), data=end_points, compression='gzip')
            hf_nodes.create_dataset('allNodes' + str(seg_id), data=nodes, compression='gzip')

        if save_graph is True:
            with open(temp_folder + '/(' + process_id + ')graph.h5', 'wb') as pfile:
                pickle.dump(graphs, pfile, protocol=pickle.HIGHEST_PROTOCOL)
            with open(temp_folder + '/(' + process_id + ')edge_list.pkl', 'wb') as pfile:
                pickle.dump(edge_lists, pfile, protocol=pickle.HIGHEST_PROTOCOL)
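A short sketch of inspecting the junction ('j<seg_id>'), end-point ('e<seg_id>'), and 'allNodes<seg_id>' datasets written above; the output path is an assumption and must match an earlier compute_skel_graph call.

import h5py

with h5py.File('./tmp/(0)skel_graph.h5', 'r') as hf:   # hypothetical output path
    for key in hf.keys():
        # 'j<seg_id>': junctions, 'e<seg_id>': end points, 'allNodes<seg_id>': all nodes
        print(key, hf[key].shape)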