Ejemplo n.º 1
0
def _test():
    """Interactive scratch session exercising the connection/tracing code.

    This is not a self-contained unit test. It loads arrays from hard-coded,
    user-specific paths under ``/home/ckirst/Desktop``, opens 3d viewers,
    and saves/writes results back to that directory. The ``#%%`` markers
    split it into cells that appear meant to be run one at a time in an
    interactive editor rather than by calling ``_test()`` top to bottom.

    NOTE(review): some names used below are not defined in this function or
    its local imports and must come from elsewhere for the affected cells to
    run:
      * ``reload`` -- a builtin only in Python 2; Python 3 needs
        ``from importlib import reload``.
      * ``ap`` -- used for ``ap.findNeighbours``.
      * ``dv`` -- used for ``dv.plot`` / ``dv.dualPlot``.
    """
    #%%
    # Cell: load volumes, clean the skeleton, find end / isolated points.
    import numpy as np
    import scipy.ndimage as ndi
    import ClearMap.DataProcessing.LargeData as ld  # NOTE: `ld` is not used below
    import ClearMap.Visualization.Plot3d as p3d
    import ClearMap.DataProcessing.ConvolvePointList as cpl
    import ClearMap.ImageProcessing.Skeletonization.Topology3d as t3d
    import ClearMap.ImageProcessing.Skeletonization.SkeletonCleanUp as scu

    import ClearMap.ImageProcessing.Tracing.Connect as con
    # Re-import `con` so edits to the Connect module are picked up in a live
    # session. NOTE(review): `reload` is not a builtin under Python 3.
    reload(con)

    # NOTE(review): hard-coded, user-specific input files.
    data = np.load('/home/ckirst/Desktop/data.npy')
    binary = np.load('/home/ckirst/Desktop/binarized.npy')
    skel = np.load('/home/ckirst/Desktop/skel.npy')
    #points = np.load('/home/ckirst/Desktop/pts.npy');

    # Fortran-ordered copies. The flat indices built in this cell
    # (`points`, `ends`, `isolated`, `special`) use order='F'.
    data = np.copy(data, order='F')
    binary = np.copy(binary, order='F')
    skel = np.copy(skel, order='F')
    skel_copy = np.copy(skel, order='F')
    # Flat (Fortran-order) indices of all skeleton voxels.
    points = np.ravel_multi_index(np.where(skel), skel.shape, order='F')

    # Clean open skeleton branches (length=3); returns the updated skeleton
    # and point list.
    skel, points = scu.cleanOpenBranches(skel,
                                         skel_copy,
                                         points,
                                         length=3,
                                         clean=True)
    # Convolution of the skeleton with the 26-neighbourhood kernel evaluated
    # at `points`; presumably the per-point neighbour count (branch degree),
    # since values 3..5 are highlighted below -- TODO confirm.
    deg = cpl.convolve3DIndex(skel, t3d.n26, points)

    # End points and isolated points (flat indices), excluding a 25 voxel
    # border as requested via `border=25`.
    ends, isolated = con.findEndpoints(skel, points, border=25)
    special = np.sort(np.hstack([ends, isolated]))

    # The same points as (n, 3) coordinate arrays.
    ends_xyz = np.array(np.unravel_index(ends, data.shape, order='F')).T
    isolated_xyz = np.array(np.unravel_index(isolated, data.shape,
                                             order='F')).T
    special_xyz = np.vstack([ends_xyz, isolated_xyz])

    #%%
    # Cell: run the parallel connection step on shared-memory copies.
    import ClearMap.ParallelProcessing.SharedMemoryManager as smm
    data_s = smm.asShared(data, order='F')
    # Boolean arrays are shared as uint8 views and viewed back as bool below.
    binary_s = smm.asShared(binary.view('uint8'), order='F')
    skel_s = smm.asShared(skel.view('uint8'), order='F')

    # NOTE(review): presumably releases previously allocated shared buffers;
    # confirm it does not invalidate the arrays created just above.
    smm.clean()
    res = con.addConnections(data_s,
                             binary_s,
                             skel_s,
                             points,
                             radius=20,
                             start_points=None,
                             add_to_skeleton=True,
                             add_to_mask=True,
                             verbose=True,
                             processes=4,
                             debug=False,
                             block_size=10)

    skel_s = skel_s.view(bool)
    binary_s = binary_s.view(bool)

    #%%
    # Cell: build label / intensity images visualising the result of
    # addConnections.
    mask_img = np.asarray(binary, dtype=int, order='A')
    mask_img[:] = mask_img + binary_s
    # NOTE(review): adds the original `skel`, not `skel_s` that
    # addConnections was asked to update -- confirm intended.
    mask_img[:] = mask_img + skel

    data_img = np.copy(data, order='A')
    data_img[skel] = 120

    # 1d versions for indexing with flat indices. For these Fortran-ordered
    # arrays order='A' ravels in Fortran order and is expected to return
    # views, so the writes below modify mask_img / data_img.
    mask_img_f = np.reshape(mask_img, -1, order='A')
    data_img_f = np.reshape(data_img, -1, order='A')

    # NOTE(review): assumes `res` holds flat indices of added voxels.
    mask_img_f[res] = 7
    data_img_f[res] = 512

    mask_img_f[special] = 8
    data_img_f[special] = 150

    # Highlight skeleton points whose `deg` is 3, 4 or 5 (label deg + 1).
    for d in [3, 4, 5]:
        mask_img_f[points[deg == d]] = d + 1

    # Reuse the viewer cached on `con` if there is one, otherwise create it.
    # NOTE(review): the bare except also hides unrelated errors.
    try:
        con.viewer[0].setSource(mask_img)
        con.viewer[1].setSource(data_img)
    except:
        con.viewer = p3d.plot([mask_img, data_img])

    con.viewer[0].setMinMax([0, 8])
    con.viewer[1].setMinMax([24, 160])

    #%%
    # Cell: connect each special point sequentially with con.connectPoint and
    # paint the resulting paths.
    mask = binary
    data_new = np.copy(data, order='A')
    data_new[skel] = 120

    skel_new = np.asarray(skel, dtype=int, order='A')
    skel_new[:] = skel_new + binary

    binary_new = np.copy(binary, order='A')
    # Per-length path quality of every connection found.
    qs = []
    for i, e in enumerate(special):
        print('------')
        print('%d / %d' % (i, len(special)))
        # maxSteps / costPerDistance are forwarded to the tracer through
        # connectPoint's **trace_parameter.
        path, quality = con.connectPoint(data,
                                         mask,
                                         special,
                                         i,
                                         radius=25,
                                         skeleton=skel,
                                         tubeness=None,
                                         remove_local_mask=True,
                                         min_quality=15.0,
                                         verbose=True,
                                         maxSteps=15000,
                                         costPerDistance=1.0)

        #print path, quality
        if len(path) > 0:
            qs.append(quality * 1.0 / len(path))

            # NOTE: the return values `q` are not used.
            q = con.addPathToMask(skel_new, path, value=7)
            q = con.addPathToMask(data_new, path, value=512)
            binary_new = con.addDilatedPathToMask(binary_new,
                                                  path,
                                                  iterations=1)

    skel_new[:] = skel_new + binary_new
    q = con.addPathToMask(skel_new, special_xyz, value=6)
    for d in [3, 4, 5]:
        xyz = np.array(
            np.unravel_index(points[deg == d], data.shape, order='F')).T
        q = con.addPathToMask(skel_new, xyz, value=d)
    q = con.addPathToMask(data_new, special_xyz, value=150)

    # NOTE(review): bare except also hides unrelated errors.
    try:
        con.viewer[0].setSource(skel_new)
        con.viewer[1].setSource(data_new)
    except:
        con.viewer = p3d.plot([skel_new, data_new])

    con.viewer[0].setMinMax([0, 8])
    con.viewer[1].setMinMax([24, 160])

    #%%
    # Cell: histogram of the per-length path qualities.
    import matplotlib.pyplot as plt
    plt.figure(1)
    plt.clf()
    #plt.plot(qs);
    plt.hist(qs)

    #%%
    # Cell: trace a single end point to the mask with plotting enabled.
    # Only the last assignment (i = 40) takes effect.
    i = 20
    i = 21
    i = 30
    i = 40
    r = 25
    # NOTE(review): unravel_index defaults to order='C' here, but `ends`
    # holds Fortran-order flat indices -- the coordinates are likely wrong.
    center = np.unravel_index(ends[i], data.shape)
    print(center, data.shape)
    mask = binary
    path = con.tracePointToMask(data,
                                mask,
                                center,
                                radius=r,
                                points=special_xyz,
                                plot=True,
                                skel=skel,
                                binary=binary,
                                tubeness=None,
                                removeLocalMask=True,
                                maxSteps=None,
                                verbose=False,
                                costPerDistance=0.0)

    #%%
    # Cell: trace end point `i` to its nearest neighbouring end point.
    # NOTE(review): `ap` is not defined in this function.

    nbs = ap.findNeighbours(ends, i, skel.shape, skel.strides, r)
    # NOTE(review): order='C' default again; see the note above.
    center = np.unravel_index(ends[i], skel.shape)

    # Squared distances from the centre to every neighbour.
    nbs_xyz = np.array(np.unravel_index(nbs, skel.shape)).T
    dists = nbs_xyz - center
    dists = np.sum(dists * dists, axis=1)

    nb = np.argmin(dists)

    center = np.unravel_index(ends[i], data.shape)
    print(center, data.shape)
    mask = binary
    path = con.tracePointToNeighbor(data,
                                    mask,
                                    center,
                                    nbs_xyz[nb],
                                    radius=r,
                                    points=special_xyz,
                                    plot=True,
                                    skel=skel,
                                    binary=binary,
                                    tubeness=None,
                                    removeLocalMask=True,
                                    maxSteps=None,
                                    verbose=False,
                                    costPerDistance=0.0)

    #%%
    # Cell: difference-of-Gaussians filtering of the data.
    # NOTE(review): `dv` is not defined in this function; `data_filter` is
    # not used afterwards.

    import ClearMap.ImageProcessing.Filter.FilterKernel as fkr
    dog = fkr.filterKernel('DoG', size=(13, 13, 13))
    dv.plot(dog)

    data_filter = ndi.correlate(np.asarray(data, dtype=float), dog)
    data_filter -= data_filter.min()
    data_filter = data_filter / 3.0
    #dv.dualPlot(data, data_filter);

    #%%add all paths
    # Cell: trace every special point to the mask with tracePointToMask and
    # paint all paths.
    reload(con)

    r = 25
    mask = binary
    data_new = data.copy()
    data_new[skel] = 120

    skel_new = np.asarray(skel, dtype=int)
    skel_new = skel_new + binary

    binary_new = binary.copy()

    for i, e in enumerate(special):
        # NOTE(review): order='C' default on Fortran-order flat index `e`.
        center = np.unravel_index(e, data.shape)

        print(i, e, center)
        path = con.tracePointToMask(data,
                                    mask,
                                    center,
                                    radius=r,
                                    points=special_xyz,
                                    plot=False,
                                    skel=skel,
                                    binary=binary,
                                    tubeness=None,
                                    removeLocalMask=True,
                                    maxSteps=15000,
                                    costPerDistance=1.0)

        q = con.addPathToMask(skel_new, path, value=7)
        q = con.addPathToMask(data_new, path, value=512)
        binary_new = con.addDilatedPathToMask(binary_new, path, iterations=1)

    q = con.addPathToMask(skel_new, special_xyz, value=6)
    for d in [3, 4, 5]:
        # NOTE(review): order='C' default, unlike the order='F' version above.
        xyz = np.array(np.unravel_index(points[deg == d], data.shape)).T
        q = con.addPathToMask(skel_new, xyz, value=d)
    q = con.addPathToMask(data_new, special_xyz, value=150)

    skel_new = skel_new + binary_new
    # NOTE(review): `dv` undefined; bare except also hides unrelated errors.
    try:
        con.viewer[0].setSource(skel_new)
        con.viewer[1].setSource(data_new)
    except:
        con.viewer = dv.dualPlot(skel_new, data_new)

    con.viewer[0].setMinMax([0, 8])
    con.viewer[1].setMinMax([24, 160])

    #%%
    # Cell: re-skeletonize the connected binary image.

    import ClearMap.ImageProcessing.Skeletonization.Skeletonize as skl

    skel_2 = skl.skeletonize3D(binary_new.copy())

    #%%
    # Cell: save the connected binary image (hard-coded path).

    np.save('/home/ckirst/Desktop/binarized_con.npy', binary_new)
    #%%

    # write image
    # Cell: write an 8-bit tif: 128 on the connected mask plus 127 on the new
    # skeleton (255 where both). NOTE(review): assumes skel_2[0] is the
    # skeleton array returned by skeletonize3D -- confirm.

    import ClearMap.IO.IO as io

    #r = np.asarray(128 * binary_new, dtype = 'uint8');
    #g = r.copy(); b = r.copy();
    #r[:] = r + 127 * skel_2[0];
    #g[:] = g - 128 * skel_2[0];
    #b[:] = b - 128 * skel_2[0];
    #img = np.stack((r,g,b), axis = 3)

    img = np.asarray(128 * binary_new, dtype='uint8')
    img[:] = img + 127 * skel_2[0]

    io.writeData('/home/ckirst/Desktop/3d.tif', img)
Ejemplo n.º 2
0
def connectPoint(data,
                 mask,
                 endpoints,
                 start_index,
                 radius,
                 tubeness=None,
                 min_quality=None,
                 remove_local_mask=True,
                 skeleton=None,
                 verbose=False,
                 **trace_parameter):
    """Try to connect an end point to a nearby end point or to the mask.

    Outline:

    1. Find other end points near ``endpoints[start_index]`` that are not in
       the same connected mask component as it, and trace a path to the
       nearest one. If the path passes ``min_quality`` it is returned.
    2. Otherwise trace from the end point to the mask (optionally excluding
       the mask component containing the end point). If that path passes
       ``min_quality`` and ``skeleton`` is given, a straight pixel line from
       the path's first point to the closest skeleton voxel in the same mask
       component is appended.
    3. Otherwise an empty path is returned.

    All arrays are assumed to be in Fortran order; ``endpoints`` holds
    Fortran-order flat indices into ``data``.

    Arguments
    ---------
    data : array
        Image intensities used for tracing.
    mask : array
        Binary mask. It is not modified: the local mask component is removed
        from a private copy of the neighbourhood only.
    endpoints : array of int
        Fortran-order flat indices of candidate end points.
    start_index : int
        Index into ``endpoints`` of the point to connect.
    radius : int
        Neighbourhood radius passed to ``extractNeighbourhood`` and
        ``ap.findNeighbours``. The end point is assumed to sit at index
        ``radius`` along every axis of the extracted neighbourhood.
    tubeness : array or None
        Precomputed tubeness volume in the same frame as ``data``. If None
        it is computed locally from a Gaussian-smoothed neighbourhood.
    min_quality : float or None
        Threshold on ``quality / len(path)``; a path is accepted when this
        ratio is strictly below it. None disables the threshold so any
        non-empty path is accepted (comparing against None would raise a
        TypeError under Python 3).
    remove_local_mask : bool
        If True, the mask component containing the end point is excluded
        before tracing to the mask.
    skeleton : array or None
        If given, a successful trace to the mask is extended to the closest
        skeleton voxel.
    verbose : bool
        Print progress messages.
    **trace_parameter
        Forwarded to ``trc.trace`` / ``trc.traceToMask``.

    Returns
    -------
    path : array (n, 3)
        Path coordinates in the frame of ``data``; shape (0, 3) on failure.
        When extended to the skeleton, the appended points are not ordered.
    quality : float
        Quality reported by the tracer (0 on failure).
    """
    # Strides in elements; integer division keeps them integral under
    # Python 3 (true division would produce a float array).
    strides = np.array(data.strides) // data.itemsize
    shape = data.shape

    center_flat = endpoints[start_index]
    center_xyz = np.array(np.unravel_index(center_flat, shape, order='F'))

    # NOTE(review): behaviour of extractNeighbourhood near the image border
    # is not visible here; the centre is assumed to be at `radius`.
    mask_nbh = extractNeighbourhood(mask, center_xyz, radius)
    data_nbh = np.asarray(extractNeighbourhood(data, center_xyz, radius),
                          dtype=float,
                          order='F')
    shape_nbh = mask_nbh.shape

    center_nbh_xyz = np.zeros(3, dtype=int) + radius

    if tubeness is None:
        tubeness_nbh = cur.tubeness(
            ndi.gaussian_filter(np.asarray(data_nbh, dtype=float), sigma=1.0))
        tubeness_nbh = np.asarray(tubeness_nbh, order='F')
    else:
        tubeness_nbh = extractNeighbourhood(tubeness, center_xyz, radius)

    # Label 26-connected components of the local mask; `local_nbh` is the
    # component that contains the end point itself.
    mask_nbh_label = np.empty(shape_nbh, dtype='int32', order='F')
    _ = ndi.label(mask_nbh,
                  structure=np.ones((3, 3, 3), dtype=bool),
                  output=mask_nbh_label)
    local_nbh = mask_nbh_label[tuple(center_nbh_xyz)] == mask_nbh_label

    # 1. Try to connect to the nearest end point outside the local component.
    nbs_flat = ap.findNeighbours(endpoints, start_index, shape, strides,
                                 radius)

    if len(nbs_flat) > 0:
        # Neighbour coordinates / flat indices in the neighbourhood frame.
        nbs_nbh_xyz = np.vstack(np.unravel_index(
            nbs_flat, shape, order='F')).T - center_xyz + center_nbh_xyz
        nbs_nbh_flat = np.ravel_multi_index(nbs_nbh_xyz.T,
                                            shape_nbh,
                                            order='F')

        # Drop neighbours already connected to the end point via the mask.
        non_local_nbh_flat = np.reshape(np.logical_not(local_nbh),
                                        -1,
                                        order='F')
        nbs_nbh_non_local_flat = nbs_nbh_flat[non_local_nbh_flat[nbs_nbh_flat]]

        if len(nbs_nbh_non_local_flat) > 0:
            # These flat indices live in the neighbourhood frame, so they must
            # be unravelled with `shape_nbh`, not the full image shape.
            nbs_nbh_non_local_xyz = np.vstack(
                np.unravel_index(nbs_nbh_non_local_flat,
                                 shape_nbh,
                                 order='F')).T

            offsets = nbs_nbh_non_local_xyz - center_nbh_xyz
            dist2 = np.sum(offsets * offsets, axis=1)
            neighbor_nbh_xyz = nbs_nbh_non_local_xyz[np.argmin(dist2)]

            path, quality = trc.trace(data_nbh,
                                      tubeness_nbh,
                                      center_nbh_xyz,
                                      neighbor_nbh_xyz,
                                      verbose=False,
                                      returnQuality=True,
                                      **trace_parameter)

            if len(path) > 0:
                if _connect_point_path_ok(quality, len(path), min_quality):
                    if verbose:
                        print(
                            'Found good path to neighbour of length = %d with quality = %f (per length = %f) [%d / %d nonlocal neighbours]'
                            % (len(path), quality, quality / len(path),
                               len(nbs_nbh_non_local_flat), len(nbs_flat)))
                    return path + center_xyz - center_nbh_xyz, quality
                elif verbose:
                    print(
                        'Found bad  path to neighbour of length = %d with quality = %f (per length = %f) [%d / %d nonlocal neighbours]'
                        % (len(path), quality, quality / len(path),
                           len(nbs_nbh_non_local_flat), len(nbs_flat)))
            elif verbose:
                print(
                    'Found no path to neighbour [%d / %d nonlocal neighbours]'
                    % (len(nbs_nbh_non_local_flat), len(nbs_flat)))

    # Tracing to the nearest neighbour failed.
    if verbose:
        print('Found no valid path to neighbour, now tracing to binary!')

    # 2. Trace to the nearest mask voxel.
    if remove_local_mask:
        # Edit a copy: extractNeighbourhood may return a view into `mask`,
        # and clearing the component in place would then corrupt the
        # caller's mask for later calls.
        mask_nbh = np.array(mask_nbh, order='F')
        mask_nbh[local_nbh] = False

    distance_nbh = ndi.distance_transform_edt(np.logical_not(mask_nbh))
    distance_nbh = np.asarray(distance_nbh, order='F')

    path, quality = trc.traceToMask(data_nbh,
                                    tubeness_nbh,
                                    center_nbh_xyz,
                                    distance_nbh,
                                    verbose=False,
                                    returnQuality=True,
                                    **trace_parameter)

    if len(path) > 0:
        if _connect_point_path_ok(quality, len(path), min_quality):
            if verbose:
                print(
                    'Found good path to binary of length = %d with quality = %f (per length = %f)'
                    % (len(path), quality, quality / len(path)))

            if skeleton is not None:
                skeleton_nbh = extractNeighbourhood(skeleton, center_xyz,
                                                    radius)
                path = _connect_point_extend_to_skeleton(
                    path, skeleton_nbh, mask_nbh_label)

            return path + center_xyz - center_nbh_xyz, quality
        elif verbose:
            print(
                'Found bad  path to binary of length = %d with quality = %f (per length = %f)'
                % (len(path), quality, quality / len(path)))

    # 3. Not connectable.
    if verbose:
        print('Found no valid path to binary!')

    return np.zeros((0, 3)), 0


def _connect_point_path_ok(quality, length, min_quality):
    """Return True if ``quality / length`` passes the ``min_quality`` limit.

    Lower per-length quality is better: the path is accepted when the ratio
    is strictly below ``min_quality``. ``min_quality=None`` accepts any path.
    """
    if min_quality is None:
        return True
    return quality / length < min_quality


def _connect_point_extend_to_skeleton(path, skeleton_nbh, mask_nbh_label):
    """Append a straight pixel line from ``path[0]`` to the nearest skeleton voxel.

    Only skeleton voxels in the same labelled mask component as ``path[0]``
    are considered; if there are none, ``path`` is returned unchanged. The
    appended line points are deduplicated with ``np.unique`` and are
    therefore not in path order.
    """
    # NOTE(review): assumes traceToMask returns the point on the mask first.
    final_xyz = path[0]
    same_component = mask_nbh_label[tuple(final_xyz)] == mask_nbh_label
    skeleton_dxyz = np.vstack(
        np.where(np.logical_and(skeleton_nbh, same_component))).T - final_xyz
    if len(skeleton_dxyz) == 0:  # no skeleton nearby -> give up for now
        return path

    dist2 = np.sum(skeleton_dxyz * skeleton_dxyz, axis=1)
    closest_dxyz = skeleton_dxyz[np.argmin(dist2)]
    closest_xyz = closest_dxyz + final_xyz

    # Sample enough points that consecutive samples are at most one voxel
    # apart along the longest axis.
    n_samples = np.max(np.abs(closest_dxyz)) + 1
    line_xyz = np.vstack([
        np.asarray(np.linspace(f, c, n_samples), dtype=int)
        for f, c in zip(final_xyz, closest_xyz)
    ]).T
    line_flat = np.ravel_multi_index(line_xyz.T, mask_nbh_label.shape)
    _, ids = np.unique(line_flat, return_index=True)
    return np.vstack([path, line_xyz[ids]])