def seeds_to_vol(trk_fn, tom_fn, out_fn, reverse, shape=(145, 174, 145)):
    """Save a binary volume marking the voxel that contains each streamline's seed.

    Args:
        trk_fn: Path of a tractogram readable by ``nib.streamlines.load``.
        tom_fn: Path of the reference NIfTI volume. Only its affine is used: to
            map seed coordinates to voxel indices and to label the output image.
        out_fn: Path the resulting uint8 NIfTI image is saved to.
        reverse: ``0`` takes the first point of every streamline as its seed;
            any other value takes the last point instead.
        shape: Spatial shape of the output volume. Defaults to the previously
            hard-coded ``(145, 174, 145)``. NOTE(review): this presumably should
            match the grid of ``tom_fn`` -- confirm for non-default data.
    """
    # Only the affine of the reference volume is needed.
    _, ref_affine = load_nifti(tom_fn)

    # Invert the reference affine to go from coordinates to voxel indices.
    inverted_ref_affine = np.linalg.inv(ref_affine)

    # Keep the streamlines as loaded. Converting them with np.array() fails on
    # recent NumPy when streamlines differ in length, and only one endpoint per
    # streamline is needed anyway.
    streamlines = nib.streamlines.load(trk_fn).streamlines

    endpoint = 0 if reverse == 0 else -1
    seeds = np.array([s[endpoint] for s in streamlines], dtype=float).reshape(-1, 3)

    # Template output volume; seed voxels are set to 1.
    result = np.zeros(shape)
    if len(seeds):
        # Map all seeds to voxel space in one call. astype(int) truncates toward
        # zero exactly like the per-point int() used before.
        # NOTE: indices are not bounds-checked (unchanged behaviour): negative
        # ones wrap around and too-large ones raise IndexError.
        voxels = utils.apply_affine(aff=inverted_ref_affine, pts=seeds).astype(int)
        result[voxels[:, 0], voxels[:, 1], voxels[:, 2]] = 1

    # Save result
    nifti_result = nib.Nifti1Image(result.astype("uint8"), ref_affine)
    nib.save(nifti_result, out_fn)
示例#2
0
def load_streamlines_v2(fn, ref):
    """Load a TrackVis file and return its streamlines before and after an affine.

    Args:
        fn: Path of the ``.trk`` file, read with ``trackvis.read``; the first
            element of each returned record is used as that streamline's points.
        ref: Path of a NIfTI volume whose affine is applied to every streamline.
            NOTE(review): confirm the points from ``trackvis.read`` are in the
            space this (voxel -> world) affine expects.

    Returns:
        ``[original, transformed]``: the raw point arrays, and the same points
        with the reference affine applied. Each is a regular
        ``(n_streamlines, n_points, 3)`` array when all streamlines have the same
        number of points; otherwise it is a 1-D object array holding one
        ``(n_points_i, 3)`` array per streamline.
    """
    def _stack(arrays):
        # np.array() on arrays of different shapes raises on NumPy >= 1.24, so
        # ragged input is packed into an object array explicitly.
        if len({np.shape(a) for a in arrays}) <= 1:
            return np.array(arrays)
        packed = np.empty(len(arrays), dtype=object)
        for idx, arr in enumerate(arrays):
            packed[idx] = arr
        return packed

    streams, _ = trackvis.read(fn)
    # Only the affine of the reference volume is used.
    _, ref_affine = load_nifti(ref)

    # Iterate the records once and reuse the point arrays for both outputs.
    original = [sl[0] for sl in streams]
    transformed = [utils.apply_affine(aff=ref_affine, pts=pts) for pts in original]

    return [_stack(original), _stack(transformed)]
def get_data(tom_fn, tractogram_fn, beginnings_fn, endings_fn, mean, sdev,
             coord2voxel):
    """Load one training sample: input volumes, sorted seeds and target streamlines.

    Args:
        tom_fn: NIfTI path of the tract orientation map (channel axis last).
        tractogram_fn: TrackVis path of the target streamlines. All streamlines
            must have the same number of points, and their count must be a
            perfect square.
        beginnings_fn: NIfTI path of the beginnings segmentation (channel axis last).
        endings_fn: NIfTI path of the endings segmentation (channel axis last).
        mean: Dataset mean subtracted from the TOM (broadcast against its array).
        sdev: Dataset standard deviation the TOM is divided by.
        coord2voxel: Affine mapping streamline coordinates to voxel indices.

    Returns:
        ``[[tom_seed, seeds], tractogram]``:

        * ``tom_seed`` -- normalised, noise-augmented TOM concatenated with the
          beginnings and endings volumes along the (first) channel axis.
        * ``seeds`` -- float32 tensor of shape ``(3, num_streamlines)`` holding
          each streamline's first point (voxel coordinates / 144), sorted by x,
          then y, then z.
        * ``tractogram`` -- tensor of shape ``(num_points * 3, side, side)`` with
          ``side = sqrt(num_streamlines)``: every streamline expressed relative
          to its own seed, in the same order as ``seeds``.

    Raises:
        ValueError: If the tractogram is empty, its streamlines differ in
            length, or their number is not a perfect square.
    """
    # Load TOM volume and normalise based on dataset mean/stdev.
    # np.asanyarray(img.dataobj) is nibabel's documented replacement for
    # Image.get_data(), which was removed in nibabel 5.
    tom = np.asanyarray(nib.load(tom_fn).dataobj)  # presumably 144 x 144 x 144 x 3
    tom = (tom - mean) / sdev
    tom = torch.from_numpy(np.float32(tom))
    tom = tom.permute(3, 0, 1, 2)  # channels first for pytorch

    # On-the-fly augmentation: additive Gaussian noise whose stdev is drawn
    # uniformly from [0, 0.05) for every sample.
    noise_stdev = torch.rand(1) * 0.05
    noise = torch.normal(mean=torch.zeros(tom.size()),
                         std=torch.ones(tom.size()) * noise_stdev)
    tom += noise

    # Seed segmentation volumes; permute(3, 0, 1, 2) requires them to be 4-D
    # with the channel axis last.
    beginnings = np.asanyarray(nib.load(beginnings_fn).dataobj)
    beginnings = torch.from_numpy(np.float32(beginnings)).permute(3, 0, 1, 2)
    endings = np.asanyarray(nib.load(endings_fn).dataobj)
    endings = torch.from_numpy(np.float32(endings)).permute(3, 0, 1, 2)

    # Append the seed volumes as extra channels after the TOM channels.
    tom_seed = torch.cat((tom, beginnings, endings), dim=0)

    # Load the tractogram; the first element of each record holds its points.
    records, _ = trackvis.read(tractogram_fn)
    streamlines = np.array([record[0] for record in records])
    if streamlines.ndim != 3:
        raise ValueError('expected a non-empty tractogram whose streamlines all '
                         'have the same number of points')

    # Convert to voxel coordinates and scale by the 144-voxel grid size so that
    # points inside the grid fall in [0, 1].
    streamlines = utils.apply_affine(aff=coord2voxel, pts=streamlines)
    streamlines /= 144

    # Seeds are the first point of each streamline; the streamlines themselves
    # are then expressed relative to their own seed.
    seeds = streamlines[:, 0, :].astype(np.float32)
    streamlines = streamlines - streamlines[:, :1, :]

    # Sort seeds and streamlines by seed x, then y, then z. np.lexsort uses the
    # LAST key as the primary one and is stable, so ties keep file order exactly
    # as the previous sorted(..., key=[x, y, z]) did.
    order = np.lexsort((seeds[:, 2], seeds[:, 1], seeds[:, 0]))
    seeds = seeds[order]
    streamlines = streamlines[order]

    # Convert seeds to torch: (3, num_streamlines), channels first for pytorch.
    seeds = torch.from_numpy(seeds).permute(1, 0)

    # (num_sl, num_points, 3) -> (side, side, num_points * 3), side = sqrt(num_sl).
    # A single C-order reshape is equivalent to the previous two-step reshape;
    # the sizes now come from the data instead of undefined globals.
    num_sl, num_points = streamlines.shape[:2]
    side = int(round(num_sl ** 0.5))
    if side * side != num_sl:
        raise ValueError(
            'number of streamlines ({}) is not a perfect square'.format(num_sl))
    streamlines = np.reshape(streamlines, (side, side, num_points * 3))

    # NOTE(review): this tensor keeps the dtype produced by utils.apply_affine
    # (likely float64), unlike the float32 inputs -- confirm the loss handles it.
    tractogram = torch.from_numpy(streamlines)
    tractogram = tractogram.permute(2, 0, 1)  # channels first for pytorch

    return [[tom_seed, seeds], tractogram]
示例#4
0
    """

# Debug visualisation: scatter-plot every point of three tractograms (the first
# three paths returned by glob, whose order is not guaranteed) in the voxel
# space of one reference TOM.
# NOTE(review): a single subject's TOM affine (599469, CST left) is used for
# every tractogram -- this assumes all of them share that grid; confirm.
_, affine = load_nifti(
    '../../data/final_hairnet_dataset/preprocessed/TOMs/599469_0_CST_left.nii.gz'
)
# World -> voxel transform.
inverse_affine = np.linalg.inv(affine)
for tractogram_fn in glob(
        '../../data/final_hairnet_dataset/not_preprocessed/tractograms/*.trk'
)[:3]:
    # Load trk; keep only the point array (first element) of each record.
    streamlines, header = trackvis.read(tractogram_fn)
    streamlines = [s[0] for s in streamlines]
    # A regular (n, points, 3) array -- and hence the reshape below -- requires
    # every streamline to have the same number of points.
    streamlines = np.array(streamlines)

    # Convert to voxel space
    streamlines = utils.apply_affine(aff=inverse_affine, pts=streamlines)

    # Flatten the points of all streamlines into one (N, 3) coordinate list.
    coords = np.reshape(streamlines, (-1, 3))
    xs, ys, zs = coords[:, 0], coords[:, 1], coords[:, 2]

    #fig = px.scatter_3d(x=xs, y=ys, z=zs, color=np.array([xs[i] for i in range(len(xs))]), range_x=[0,144], range_y=[0,144], range_z=[0,144])
    # One size-1 marker per point, coloured by its first (x) voxel coordinate.
    fig = go.Figure(data=go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        marker=dict(size=1, color=[xs[i] for i in range(len(xs))])))
    #fig.update_layout(scene_aspectmode='cube')

    # Cube aspect ratio, with the x-axis range fixed to [0, 144] voxels.
    fig.update_layout(scene_aspectmode='cube',
                      scene=dict(xaxis=dict(range=[0, 144]),
示例#5
0
def gen_TOM(streamlines, ref_file, out_fn='test.nii.gz', sample_every=10):
    """Build a tract orientation map (TOM) from streamlines on a reference grid.

    The streamlines are first subdivided with ``utils.subsegment`` using a
    maximum segment length of a quarter of ``|affine[0, 0]|``. Then every
    ``sample_every``-th segment, counted across all streamlines, contributes its
    unit direction vector to the voxel containing the segment's start point.
    Each voxel's value is the mean of the vectors it received, or zero if it
    received none.

    Args:
        streamlines: Iterable of streamlines (sequences of 3-D points) in the
            world coordinates of ``ref_file``'s affine.
        ref_file: NIfTI file providing the output grid shape and affine.
        out_fn: Where the TOM is saved as NIfTI. Defaults to the previously
            hard-coded ``'test.nii.gz'``; pass ``None`` to skip saving.
        sample_every: Keep one segment out of this many (positive int).
            Defaults to the previously hard-coded 10.

    Returns:
        ``numpy.ndarray`` of shape ``ref_data.shape[:3] + (3,)``, indexed as
        ``[i, j, k]`` in the voxel order of ``ref_file``'s affine, holding the
        mean unit direction per voxel.
    """
    ref_data, ref_affine = load_nifti(ref_file)
    # Invert the voxel -> world affine once (it used to be re-inverted per point).
    world_to_voxel = np.linalg.inv(ref_affine)
    grid_shape = tuple(ref_data.shape[:3])

    # Per-voxel running sums of unit vectors plus sample counts. This fixes two
    # defects of the previous nested-list version. First, each voxel's list was
    # seeded with a zero vector, so every mean was divided by n + 1 instead of n.
    # Second, voxel (i, j, k) was stored at [k][j][i], which transposed the data
    # relative to the affine used to save it.
    direction_sums = np.zeros(grid_shape + (3,))
    counts = np.zeros(grid_shape, dtype=np.int64)

    streamlines = list(utils.subsegment(streamlines, abs(ref_affine[0, 0] / 4)))

    segment_index = 0
    for sl in streamlines:
        for point in range(len(sl) - 1):
            if segment_index % sample_every == 0:
                start = np.asarray(sl[point])

                # Unit direction of movement along this segment (zero if degenerate).
                vector = np.asarray(sl[point + 1]) - start
                size = np.linalg.norm(vector)
                unit = vector / size if size != 0 else np.zeros(3)

                # World -> voxel indices. int() truncates toward zero, as before.
                # NOTE(review): with nibabel's convention that integer indices
                # are voxel centres, rounding may be what was intended; confirm.
                # Negative indices wrap around, as with the previous list indexing.
                voxel = utils.apply_affine(aff=world_to_voxel, pts=start[np.newaxis])[0]
                x, y, z = (int(c) for c in voxel)

                direction_sums[x, y, z] += unit
                counts[x, y, z] += 1
            segment_index += 1

    # Mean direction per voxel; voxels without samples stay at zero.
    collection = direction_sums / np.maximum(counts, 1)[..., np.newaxis]

    if out_fn is not None:
        generated_nifti = nib.Nifti1Image(collection, affine=ref_affine)
        nib.save(generated_nifti, out_fn)

    return collection
示例#6
0
def get_data(tom_fn, tractogram_fn, beginnings_fn, endings_fn, mean, sdev,
             coord2voxel):
    """Load one training sample: input volumes, seeds and target streamlines.

    Args:
        tom_fn: NIfTI path of the tract orientation map (channel axis last).
        tractogram_fn: TrackVis path of the target streamlines. All streamlines
            must have the same number of points, and their count must be a
            perfect square.
        beginnings_fn: NIfTI path of the beginnings segmentation (channel axis last).
        endings_fn: NIfTI path of the endings segmentation (channel axis last).
        mean: Dataset mean subtracted from the TOM (broadcast against its array).
        sdev: Dataset standard deviation the TOM is divided by.
        coord2voxel: Affine mapping streamline coordinates to voxel indices.

    Returns:
        ``[[tom_seed, seeds], tractogram]``:

        * ``tom_seed`` -- normalised, noise-augmented TOM concatenated with the
          beginnings and endings volumes along the (first) channel axis.
        * ``seeds`` -- float32 tensor of shape ``(3, num_streamlines)`` holding
          each streamline's first point (voxel coordinates / 144), in file order.
        * ``tractogram`` -- tensor of shape ``(num_points * 3, side, side)`` with
          ``side = sqrt(num_streamlines)``: every streamline expressed relative
          to its own seed.

    Raises:
        ValueError: If the tractogram is empty, its streamlines differ in
            length, or their number is not a perfect square.
    """
    # Load TOM volume and normalise based on dataset mean/stdev.
    # np.asanyarray(img.dataobj) is nibabel's documented replacement for
    # Image.get_data(), which was removed in nibabel 5.
    tom = np.asanyarray(nib.load(tom_fn).dataobj)  # presumably 144 x 144 x 144 x 3
    tom = (tom - mean) / sdev

    # Convert to torch, channels first for pytorch.
    tom = torch.from_numpy(np.float32(tom))
    tom = tom.permute(3, 0, 1, 2)

    # On-the-fly augmentation: additive Gaussian noise whose stdev is drawn
    # uniformly from [0, 0.05) for every sample.
    noise_stdev = torch.rand(1) * 0.05
    noise = torch.normal(mean=torch.zeros(tom.size()),
                         std=torch.ones(tom.size()) * noise_stdev)
    tom += noise

    # Seed segmentation volumes; permute(3, 0, 1, 2) requires them to be 4-D
    # with the channel axis last.
    beginnings = np.asanyarray(nib.load(beginnings_fn).dataobj)
    beginnings = torch.from_numpy(np.float32(beginnings)).permute(3, 0, 1, 2)
    endings = np.asanyarray(nib.load(endings_fn).dataobj)
    endings = torch.from_numpy(np.float32(endings)).permute(3, 0, 1, 2)

    # Append the seed volumes as extra channels after the TOM channels.
    tom_seed = torch.cat((tom, beginnings, endings), dim=0)

    # Load the tractogram; the first element of each record holds its points.
    records, _ = trackvis.read(tractogram_fn)
    streamlines = np.array([record[0] for record in records])
    if streamlines.ndim != 3:
        raise ValueError('expected a non-empty tractogram whose streamlines all '
                         'have the same number of points')

    # Convert to voxel coordinates and scale by the 144-voxel grid size so that
    # points inside the grid fall in [0, 1].
    streamlines = utils.apply_affine(aff=coord2voxel, pts=streamlines)
    streamlines /= 144

    # Seeds are the first point of each streamline (float32, channels first).
    seeds = streamlines[:, 0, :].astype(np.float32)
    seeds = torch.from_numpy(seeds).permute(1, 0)  # (3, num_streamlines)

    # Express every streamline relative to its own seed point.
    streamlines = streamlines - streamlines[:, :1, :]

    # (num_sl, num_points, 3) -> (side, side, num_points * 3), side = sqrt(num_sl).
    # A single C-order reshape is equivalent to the previous two-step reshape;
    # the sizes now come from the data instead of undefined globals.
    num_sl, num_points = streamlines.shape[:2]
    side = int(round(num_sl ** 0.5))
    if side * side != num_sl:
        raise ValueError(
            'number of streamlines ({}) is not a perfect square'.format(num_sl))
    streamlines = np.reshape(streamlines, (side, side, num_points * 3))

    # NOTE(review): this tensor keeps the dtype produced by utils.apply_affine
    # (likely float64), unlike the float32 inputs -- confirm the loss handles it.
    tractogram = torch.from_numpy(streamlines)
    tractogram = tractogram.permute(2, 0, 1)  # channels first for pytorch

    return [[tom_seed, seeds], tractogram]