# Example #1
# 0
def create_tract_mask(trk_file_path, mask_output_path, ref_img_path, hole_closing=0, blob_th=10, density_th=1):
    """Create a binary mask from the streamlines in a .trk file and save it as NIfTI.

    Adapted from https://github.com/MIC-DKFZ/TractSeg/issues/39#issuecomment-496181262

    The streamlines are upsampled and rasterised into a streamline-density map on
    the grid of the reference image. The map is then thresholded, cleaned of small
    blobs and, optionally, morphologically closed.

    Args:
      trk_file_path: Path for the .trk file.
      mask_output_path: Path to save the binary mask (written as uint8 NIfTI).
      ref_img_path: Path to the reference image that supplies the affine and shape.
        Only its first three dimensions are used, so a 4D reference also works.
      hole_closing: Edge length of the cubic structuring element used for binary
        closing; 0 disables closing. (Default value = 0)
      blob_th: Threshold for removing small blobs (passed to remove_small_blobs).
        (Default value = 10)
      density_th: A voxel belongs to the mask if MORE than this many streamlines
        pass through it. (Default value = 1, the previously hard-coded value)

    Returns:
        None
    """
    ref_img = nib.load(ref_img_path)
    ref_affine = ref_img.affine
    # Use the spatial dims only. A 4D reference (e.g. a peaks image) would
    # otherwise produce a 4D density map and a 4D "mask".
    ref_shape = ref_img.shape[:3]

    streamlines = nib.streamlines.load(trk_file_path).streamlines

    # Upsample Streamlines (very important, especially when using DensityMap Threshold. Without upsampling eroded
    # results). The step is 1/4 of the smallest voxel edge. Voxel sizes are the column norms of the affine's linear
    # part; unlike abs(affine[0, 0]), this stays correct for rotated/oblique affines, where affine[0, 0] can be ~0.
    print("Upsampling...")
    print("Nr of points before upsampling " + str(get_number_of_points(streamlines)))
    voxel_sizes = np.sqrt((np.asarray(ref_affine)[:3, :3] ** 2).sum(axis=0))
    max_seq_len = voxel_sizes.min() / 4
    print("max_seq_len: {}".format(max_seq_len))
    streamlines = list(utils_trk.subsegment(streamlines, max_seq_len))
    print("Nr of points after upsampling " + str(get_number_of_points(streamlines)))

    # Remember: Does not count if a fibers has no node inside of a voxel -> upsampling helps, but not perfect
    # Counts the number of unique streamlines that pass through each voxel -> oversampling does not distort result
    # The call uses keywords because dipy < 1.0 declares density_map(streamlines, vol_dims, ..., affine=None)
    # while dipy >= 1.0 declares density_map(streamlines, affine, vol_dims). The keyword form works with both.
    dm = utils_trk.density_map(streamlines, vol_dims=ref_shape, affine=ref_affine)

    # Create Binary Map. Keep the threshold low: higher values are problematic because the maps are often very sparse.
    dm_binary = dm > density_th

    # Filter Blobs
    dm_binary = remove_small_blobs(dm_binary, threshold=blob_th)

    # Closing of Holes (not ideal because tends to remove valid holes, e.g. in MCP)
    if hole_closing > 0:
        structure = np.ones((hole_closing,) * 3)
        dm_binary = ndimage.binary_closing(dm_binary, structure=structure).astype(bool)

    # Save Binary Mask
    dm_binary_img = nib.Nifti1Image(dm_binary.astype("uint8"), ref_affine)
    nib.save(dm_binary_img, mask_output_path)
# Example #2
# 0
def filter_streamlines_leaving_mask(streamlines, mask, max_seq_len=0.1):
    """Remove all streamlines that exit the mask.

    The streamlines are first subsegmented so that consecutive points are at most
    ``max_seq_len`` apart. This makes it unlikely that a streamline crosses a
    background voxel between two of its points. Each point is mapped to a voxel
    by truncating its coordinates (``floor``, which equals ``int()`` for the
    non-negative coordinates inside the volume). A streamline is kept only if
    every point lies inside the mask volume AND on a non-zero mask voxel.

    Args:
        streamlines: Iterable of (N, 3) point arrays. The coordinates index
            ``mask`` directly, so they are presumably in the mask's voxel space.
        mask: 3D array; voxels equal to 0 are outside the mask.
        max_seq_len: Maximum segment length after subsegmenting
            (default 0.1, the previously hard-coded value).

    Returns:
        List of the kept streamlines. These are the *subsegmented* versions,
        not the input objects.
    """
    streamlines = list(utils_trk.subsegment(streamlines, max_seq_len))
    mask = np.asarray(mask)
    vol_limits = np.asarray(mask.shape[:3])

    def _stays_inside(streamline):
        # True if no point of `streamline` is off the volume or on a zero voxel.
        pts = np.asarray(streamline, dtype=float)
        if pts.size == 0:
            return True  # no points, so nothing can leave the mask
        vox = np.floor(pts[:, :3]).astype(np.intp)
        # A point off the volume has left the mask. Without this check, a
        # negative index would wrap to the opposite border and a too-large
        # index would raise IndexError.
        if not np.all((vox >= 0) & (vox < vol_limits)):
            return False
        return bool(np.all(mask[vox[:, 0], vox[:, 1], vox[:, 2]] != 0))

    return [sl for sl in streamlines if _stays_inside(sl)]
# Example #3
# 0
def tract_to_binary(file_in: str, file_out: str, ref_img_path: str, tracking_format: str = "trk"):
    """Rasterise the streamlines of a .trk file into a binary NIfTI mask.

    Every voxel traversed by at least one (upsampled) streamline is set to 1.

    Args:
        file_in: Path to the input .trk file.
        file_out: Path the uint8 binary mask is written to.
        ref_img_path: Reference image supplying the affine and the output shape
            (its first three dimensions).
        tracking_format: Selects the reader for ``file_in``.
            "trk" (default) uses ``nib.streamlines.load``; use it for zenodo
            dataset v1.2.0. "trk_legacy" uses ``trackvis.read``; use it for
            zenodo dataset v1.1.0 and below.
    """
    ref_img = nib.load(ref_img_path)
    # .affine / .shape avoid the deprecated get_affine() / get_data(), and reading
    # .shape does not load the whole voxel array just to get its dimensions.
    ref_affine = ref_img.affine
    ref_shape = ref_img.shape[:3]

    # Read the file once, with the reader that matches its format.
    if tracking_format == "trk_legacy":
        streams, _hdr = trackvis.read(file_in)
        # NOTE(review): trackvis.read returns points in its default point space,
        # not necessarily the RAS+ mm space of nib.streamlines. Confirm that legacy
        # files are consistent with density_map(affine=ref_affine) below.
        streamlines = [s[0] for s in streams]  # list of (N, 3) ndarrays
    else:
        streamlines = nib.streamlines.load(file_in).streamlines

    # Upsample Streamlines (very important, especially when using DensityMap Threshold. Without upsampling eroded
    # results). The step is 1/4 of the smallest voxel edge; the column norms of the affine's linear part are the
    # voxel sizes, which stays correct for rotated/oblique affines (unlike abs(affine[0, 0])).
    voxel_sizes = np.sqrt((np.asarray(ref_affine)[:3, :3] ** 2).sum(axis=0))
    max_seq_len = voxel_sizes.min() / 4
    streamlines = list(utils_trk.subsegment(streamlines, max_seq_len))

    # Remember: Does not count if a fibers has no node inside of a voxel -> upsampling helps, but not perfect
    # Counts the number of unique streamlines that pass through each voxel -> oversampling does not distort result
    # The call uses keywords because the positional order of density_map changed in dipy 1.0
    # ((vol_dims, ..., affine) before, (affine, vol_dims) after). The keyword form works with both.
    dm = utils_trk.density_map(streamlines, vol_dims=ref_shape, affine=ref_affine)

    # Create Binary Map. A higher threshold is problematic because it tends to remove valid parts (sparse fibers).
    # Blob filtering and hole closing are deliberately NOT applied here:
    # blob filtering might remove valid parts, and closing tends to remove valid holes (e.g. in MCP).
    dm_binary = dm > 0

    # Save Binary Mask
    dm_binary_img = nib.Nifti1Image(dm_binary.astype("uint8"), ref_affine)
    nib.save(dm_binary_img, file_out)
# Example #4
# 0
    # NOTE(review): this chunk is the body of a function whose header is not shown.
    # ref_img, file_in and tracking_format must be defined before this point.
    # The chunk builds a binary streamline-density mask on the reference grid.
    ref_affine = ref_img.affine
    # NOTE(review): get_data() is deprecated in nibabel and loads the whole voxel
    # array just to read its shape; ref_img.shape gives the same tuple.
    ref_shape = ref_img.get_data().shape

    # NOTE(review): this read is redundant. Both branches below reassign
    # `streamlines`, and this line calls the legacy trackvis reader even when
    # tracking_format is "trk".
    streams, hdr = trackvis.read(file_in)
    streamlines = [s[0] for s in streams]  # list of 2d ndarrays

    # "trk_legacy" uses the legacy trackvis reader; any other value uses nib.streamlines.
    if tracking_format == "trk_legacy":
        streams, hdr = trackvis.read(file_in)
        streamlines = [s[0] for s in streams]
    else:
        sl_file = nib.streamlines.load(file_in)
        streamlines = sl_file.streamlines

    # Upsample Streamlines (very important, especially when using DensityMap Threshold. Without upsampling eroded results)
    # The step is a quarter of abs(affine[0, 0]), i.e. of the x voxel size when the affine is axis-aligned.
    max_seq_len = abs(ref_affine[0, 0] / 4)
    streamlines = list(utils_trk.subsegment(streamlines, max_seq_len))

    # Remember: Does not count if a fibers has no node inside of a voxel -> upsampling helps, but not perfect
    # Counts the number of unique streamlines that pass through each voxel -> oversampling does not distort result
    # NOTE(review): this positional/keyword form appears to target the older dipy
    # density_map(streamlines, vol_dims, ..., affine=None) signature. Verify it
    # against the installed dipy version.
    dm = utils_trk.density_map(streamlines, ref_shape, affine=ref_affine)

    # Create Binary Map: any voxel hit by at least one streamline.
    dm_binary = dm > 0  # Using higher Threshold problematic, because tends to remove valid parts (sparse fibers)
    dm_binary_c = dm_binary

    # Filter Blobs (might remove valid parts) -> do not use
    #dm_binary_c = remove_small_blobs(dm_binary_c, threshold=10)

    # Closing of Holes (not ideal because tends to remove valid holes, e.g. in MCP) -> do not use
    # size = 1
    # dm_binary_c = ndimage.binary_closing(dm_binary_c, structure=np.ones((size, size, size))).astype(dm_binary.dtype)
# Example #5
# 0
def gen_TOM(streamlines, ref_file, out_file="test.nii.gz", sample_every=10):
    """Generate a tract orientation map (TOM): the mean streamline direction per voxel.

    The streamlines are subsegmented to a quarter of the smallest reference voxel
    edge. Segments are counted over all streamlines together, and every
    ``sample_every``-th segment contributes its unit direction vector to the voxel
    that contains the segment's start point. Each output voxel holds the mean of
    its contributions, or zeros if none landed there.

    Args:
        streamlines: Iterable of (N, 3) point arrays. Points are mapped to voxels
            with the inverse of the reference affine, so they are expected in the
            reference image's world (mm) space.
        ref_file: NIfTI file whose affine and first three dimensions define the
            output grid.
        out_file: Path the TOM is saved to as NIfTI. The default "test.nii.gz" is
            the previously hard-coded name. Pass None to skip saving.
        sample_every: Use one segment out of every ``sample_every`` (default 10).

    Returns:
        float64 ndarray of shape ``ref_shape[:3] + (3,)``, indexed
        ``[i, j, k, component]`` in the reference voxel grid. This layout matches
        ``ref_affine``, which is also used for the saved image.

    Raises:
        ValueError: If ``sample_every`` < 1.
    """
    if sample_every < 1:
        raise ValueError("sample_every must be >= 1, got {}".format(sample_every))

    ref_data, ref_affine = load_nifti(ref_file)
    ref_affine = np.asarray(ref_affine, dtype=float)
    vol_shape = tuple(ref_data.shape[:3])
    vol_limits = np.asarray(vol_shape)

    # World (mm) -> voxel coordinate mapping, computed once instead of per point.
    world_to_vox = np.linalg.inv(ref_affine)
    lin, offset = world_to_vox[:3, :3], world_to_vox[:3, 3]

    # Per-voxel running sum of unit vectors and number of contributions. Keeping
    # a separate count avoids seeding every voxel with a zero vector, which would
    # bias each mean to sum / (n + 1).
    sums = np.zeros(vol_shape + (3,))
    counts = np.zeros(vol_shape, dtype=np.int64)

    # Voxel sizes are the column norms of the affine's linear part; this stays
    # correct for rotated/oblique affines, unlike abs(affine[0, 0]).
    voxel_sizes = np.sqrt((ref_affine[:3, :3] ** 2).sum(axis=0))
    streamlines = list(utils.subsegment(streamlines, voxel_sizes.min() / 4))

    seg_counter = 0  # segments seen so far, across all streamlines
    for sl in streamlines:
        pts = np.asarray(sl, dtype=float)
        n_seg = len(pts) - 1
        if n_seg < 1:
            continue
        # Indices (within this streamline) of the segments that are sampled.
        starts = np.flatnonzero((seg_counter + np.arange(n_seg)) % sample_every == 0)
        seg_counter += n_seg
        if starts.size == 0:
            continue

        start_pts = pts[starts, :3]
        vecs = pts[starts + 1, :3] - start_pts
        norms = np.linalg.norm(vecs, axis=1, keepdims=True)
        # Zero-length segments contribute a zero vector instead of NaN.
        vecs = np.divide(vecs, norms, out=np.zeros_like(vecs), where=norms != 0)

        # Voxel centres sit at integer voxel coordinates, so round with
        # floor(c + 0.5), the same mapping dipy.tracking.utils uses, instead of
        # truncating.
        vox = np.floor(start_pts @ lin.T + offset + 0.5).astype(np.intp)
        # Drop points off the grid; negative indices would otherwise wrap around.
        inside = np.all((vox >= 0) & (vox < vol_limits), axis=1)
        vox, vecs = vox[inside], vecs[inside]

        idx = (vox[:, 0], vox[:, 1], vox[:, 2])  # (i, j, k): same axis order as the grid
        np.add.at(sums, idx, vecs)  # unbuffered, so repeated voxels accumulate
        np.add.at(counts, idx, 1)

    tom = np.zeros_like(sums)
    hit = counts > 0
    tom[hit] = sums[hit] / counts[hit][:, np.newaxis]

    # NOTE(review): directions are averaged with their sign, so segments of one
    # bundle traversed in opposite directions cancel out. Confirm the streamlines
    # are consistently oriented.

    if out_file is not None:
        nib.save(nib.Nifti1Image(tom, affine=ref_affine), out_file)

    return tom