def seeds_to_vol(trk_fn, tom_fn, out_fn, reverse, vol_shape=(145, 174, 145)):
    """Rasterise one endpoint of every streamline into a binary NIfTI volume.

    Args:
        trk_fn: Path of the tractogram to load with ``nib.streamlines.load``.
        tom_fn: Path of the reference volume; only its affine is used, both to
            map coordinates to voxels and as the affine of the saved image.
        out_fn: Path the resulting uint8 volume is written to.
        reverse: ``0`` marks the first point of each streamline; any other
            value marks the last point.
        vol_shape: Shape of the output volume. Defaults to the previously
            hard-coded ``(145, 174, 145)``.

    Raises:
        IndexError: If a seed maps outside ``vol_shape``. Negative indices
            used to wrap around silently and mark the wrong voxel.
    """
    # Only the affine of the reference volume is needed.
    _, ref_affine = load_nifti(tom_fn)
    # Inverse affine: coordinates -> voxel indices.
    world_to_voxel = np.linalg.inv(ref_affine)

    # Iterate the streamlines directly: they may differ in length, so they
    # must not be forced into a single rectangular array.
    streamlines = nib.streamlines.load(trk_fn).streamlines

    endpoint = 0 if reverse == 0 else -1
    seeds = np.asarray([s[endpoint] for s in streamlines], dtype=float)
    seeds = seeds.reshape(-1, 3)  # keeps an empty tractogram well-formed

    result = np.zeros(vol_shape, dtype="uint8")
    if len(seeds):
        # One batched transform; astype(int) truncates toward zero like int().
        voxels = utils.apply_affine(aff=world_to_voxel, pts=seeds).astype(int)
        outside = ((voxels < 0) | (voxels >= np.asarray(vol_shape))).any(axis=1)
        if outside.any():
            raise IndexError(
                f"{int(outside.sum())} seed(s) fall outside the "
                f"{tuple(vol_shape)} volume")
        result[voxels[:, 0], voxels[:, 1], voxels[:, 2]] = 1

    nib.save(nib.Nifti1Image(result, ref_affine), out_fn)
def load_streamlines_v2(fn, ref):
    """Read a TrackVis file and return its streamlines raw and affine-mapped.

    Args:
        fn: Path of the TrackVis file read with ``trackvis.read``.
        ref: Path of the reference volume whose affine is applied.

    Returns:
        A two-element list ``[original, transformed]``. Both are built with
        ``np.array`` from the first element of every record returned by
        ``trackvis.read`` (presumably the point array -- confirm against the
        nibabel trackvis docs); ``transformed`` has the reference affine
        applied to each of them.
    """
    records, _header = trackvis.read(fn)
    _volume, affine = load_nifti(ref)

    # First field of every record.
    point_sets = [record[0] for record in records]

    moved = np.array(
        [utils.apply_affine(aff=affine, pts=pts) for pts in point_sets])
    untouched = np.array(point_sets)
    return [untouched, moved]
def get_data(tom_fn, tractogram_fn, beginnings_fn, endings_fn, mean, sdev,
             coord2voxel):
    """Load one sample: (TOM + endpoint masks, sorted seeds) and its tractogram.

    NOTE(review): a second ``get_data`` with the same signature appears later
    in this file; if both live in one module the later one shadows this one.

    Args:
        tom_fn: Path of the TOM volume (4-D, channels last).
        tractogram_fn: Path of the TrackVis tractogram.
        beginnings_fn: Path of the "beginnings" segmentation (4-D).
        endings_fn: Path of the "endings" segmentation (4-D).
        mean: Value subtracted from the TOM before scaling.
        sdev: Value the centred TOM is divided by.
        coord2voxel: Affine applied to the streamline points.

    Returns:
        ``[[tom_seed, seeds], tractogram]`` where ``tom_seed`` is the noisy,
        normalised TOM concatenated with both segmentations along dim 0,
        ``seeds`` is a ``(3, num_streamlines)`` float32 tensor sorted by x,
        then y, then z, and ``tractogram`` holds the seed-relative streamlines
        (same order as ``seeds``) as ``(num_points * 3, side, side)``.

    Relies on the module-level names ``num_streamlines`` and ``num_points``,
    which are not defined in this function.
    """
    def _load(path):
        # Documented drop-in for the deprecated img.get_data(), which raises
        # on nibabel >= 5.
        return np.asanyarray(nib.load(path).dataobj)

    def _channels_first(volume):
        # float32 tensor with the last axis moved to the front for PyTorch.
        return torch.from_numpy(np.float32(volume)).permute(3, 0, 1, 2)

    # TOM: normalise with the supplied statistics, then channels first.
    tom = _channels_first((_load(tom_fn) - mean) / sdev)

    # On-the-fly augmentation: Gaussian noise with a random stdev in [0, 0.05).
    noise_stdev = torch.rand(1) * 0.05
    noise = torch.normal(mean=torch.zeros(tom.size()),
                         std=torch.ones(tom.size()) * noise_stdev)
    tom += noise

    # Endpoint segmentations (presumably one channel each -- TODO confirm).
    beginnings = _channels_first(_load(beginnings_fn))
    endings = _channels_first(_load(endings_fn))

    # Stack everything along the channel dimension.
    tom_seed = torch.cat((tom, beginnings, endings), dim=0)

    # Tractogram: keep the first field of every record, which requires all
    # streamlines to have the same number of points.
    records, _header = trackvis.read(tractogram_fn)
    streamlines = np.array([record[0] for record in records])

    # Convert to voxel coordinates and normalise by the hard-coded 144
    # (presumably the volume edge length -- TODO confirm).
    streamlines = utils.apply_affine(aff=coord2voxel, pts=streamlines)
    streamlines /= 144

    # Seeds are the first point of each streamline; make streamlines relative.
    seeds = streamlines[:, 0, :].astype(np.float32)
    streamlines = streamlines - streamlines[:, :1, :]

    # Sort by seed x, then y, then z. lexsort is stable and takes its primary
    # key last; the same permutation is applied to seeds and streamlines.
    order = np.lexsort((seeds[:, 2], seeds[:, 1], seeds[:, 0]))
    seeds = torch.from_numpy(seeds[order]).permute(1, 0)  # (3, num_sl)
    streamlines = streamlines[order]

    # (num_sl, num_points, 3) -> (side, side, num_points * 3); a single
    # C-order reshape is equivalent to doing it in two steps.
    side = int(num_streamlines**(1 / 2))
    streamlines = np.reshape(streamlines, (side, side, num_points * 3))
    tractogram = torch.from_numpy(streamlines).permute(2, 0, 1)

    return [[tom_seed, seeds], tractogram]
""" _, affine = load_nifti( '../../data/final_hairnet_dataset/preprocessed/TOMs/599469_0_CST_left.nii.gz' ) inverse_affine = np.linalg.inv(affine) for tractogram_fn in glob( '../../data/final_hairnet_dataset/not_preprocessed/tractograms/*.trk' )[:3]: # Load trk streamlines, header = trackvis.read(tractogram_fn) streamlines = [s[0] for s in streamlines] streamlines = np.array(streamlines) # Convert to voxel space streamlines = utils.apply_affine(aff=inverse_affine, pts=streamlines) coords = np.reshape(streamlines, (-1, 3)) xs, ys, zs = coords[:, 0], coords[:, 1], coords[:, 2] #fig = px.scatter_3d(x=xs, y=ys, z=zs, color=np.array([xs[i] for i in range(len(xs))]), range_x=[0,144], range_y=[0,144], range_z=[0,144]) fig = go.Figure(data=go.Scatter3d( x=xs, y=ys, z=zs, mode='markers', marker=dict(size=1, color=[xs[i] for i in range(len(xs))]))) #fig.update_layout(scene_aspectmode='cube') fig.update_layout(scene_aspectmode='cube', scene=dict(xaxis=dict(range=[0, 144]),
def gen_TOM(streamlines, ref_file, out_fn='test.nii.gz', sample_every=10):
    """Build a per-voxel mean-direction volume from streamlines and save it.

    The streamlines are sub-segmented to a quarter of ``abs(ref_affine[0, 0])``,
    every ``sample_every``-th segment (counted across all streamlines) is
    turned into a unit direction vector, and the vectors are averaged per
    voxel of the reference grid.

    Args:
        streamlines: Streamlines accepted by ``utils.subsegment``.
        ref_file: Path of the reference volume; supplies the grid shape (its
            first three axes) and the affine.
        out_fn: Path the volume is written to. Defaults to the previously
            hard-coded ``'test.nii.gz'``.
        sample_every: Sampling stride over segments. Defaults to the
            previously hard-coded 10.

    Returns:
        Float array of shape ``ref_shape[:3] + (3,)`` with the averaged
        direction of each voxel (zeros where nothing was sampled).

    Raises:
        ValueError: If ``sample_every`` is smaller than 1.
        IndexError: If a sampled point maps beyond the reference grid.
    """
    if sample_every < 1:
        raise ValueError("sample_every must be >= 1")

    ref_data, ref_affine = load_nifti(ref_file)
    # Inverted once: maps point coordinates to array indices.
    coords_to_array = np.linalg.inv(ref_affine)

    grid_shape = tuple(ref_data.shape[:3])
    sums = np.zeros(grid_shape + (3,))
    # NOTE(review): every voxel starts with one zero-vector sample, exactly as
    # the original list-based grid did. It avoids dividing by zero but also
    # shrinks each mean by n / (n + 1) -- confirm whether this is intended.
    counts = np.ones(grid_shape)

    seen = 0  # segments visited so far, across all streamlines
    for sl in utils.subsegment(streamlines, abs(ref_affine[0, 0] / 4)):
        pts = np.asarray(sl)
        n_segments = len(pts) - 1
        if n_segments < 1:
            continue

        # Segment indices k of this streamline with (seen + k) % stride == 0.
        first = (-seen) % sample_every
        picked = np.arange(first, n_segments, sample_every)
        seen += n_segments
        if picked.size == 0:
            continue

        # Unit direction of movement; zero-length segments stay (0, 0, 0).
        steps = pts[picked + 1] - pts[picked]
        lengths = np.sqrt((steps ** 2).sum(axis=1))
        moving = lengths != 0
        units = np.zeros(steps.shape, dtype=float)
        units[moving] = steps[moving] / lengths[moving, None]

        # astype(int) truncates toward zero, matching int().
        voxels = utils.apply_affine(
            aff=coords_to_array, pts=pts[picked]).astype(int)
        vx, vy, vz = voxels[:, 0], voxels[:, 1], voxels[:, 2]

        # NOTE(review): the grid is indexed [z, y, x] although its axes follow
        # the reference volume's own axis order, and negative indices wrap
        # around -- kept from the original; confirm the axis order.
        # np.add.at accumulates correctly when a voxel is hit several times.
        np.add.at(sums, (vz, vy, vx), units)
        np.add.at(counts, (vz, vy, vx), 1)

    collection = sums / counts[..., None]

    generated_nifti = nib.Nifti1Image(collection, affine=ref_affine)
    nib.save(generated_nifti, out_fn)
    return collection
def get_data(tom_fn, tractogram_fn, beginnings_fn, endings_fn, mean, sdev,
             coord2voxel):
    """Load one sample: (TOM + endpoint masks, seeds) and its tractogram.

    NOTE(review): an earlier ``get_data`` with the same signature appears in
    this file; if both live in one module this definition shadows it.

    Args:
        tom_fn: Path of the TOM volume (4-D, channels last).
        tractogram_fn: Path of the TrackVis tractogram.
        beginnings_fn: Path of the "beginnings" segmentation (4-D).
        endings_fn: Path of the "endings" segmentation (4-D).
        mean: Value subtracted from the TOM before scaling.
        sdev: Value the centred TOM is divided by.
        coord2voxel: Affine applied to the streamline points.

    Returns:
        ``[[tom_seed, seeds], tractogram]`` where ``tom_seed`` is the noisy,
        normalised TOM concatenated with both segmentations along dim 0,
        ``seeds`` is a ``(3, num_streamlines)`` float32 tensor in file order,
        and ``tractogram`` holds the seed-relative streamlines as
        ``(num_points * 3, side, side)``.

    Relies on the module-level names ``num_streamlines`` and ``num_points``,
    which are not defined in this function.
    """
    def _load(path):
        # Documented drop-in for the deprecated img.get_data(), which raises
        # on nibabel >= 5.
        return np.asanyarray(nib.load(path).dataobj)

    def _channels_first(volume):
        # float32 tensor with the last axis moved to the front for PyTorch.
        return torch.from_numpy(np.float32(volume)).permute(3, 0, 1, 2)

    # TOM: normalise with the supplied statistics, then channels first.
    tom = _channels_first((_load(tom_fn) - mean) / sdev)

    # On-the-fly augmentation: Gaussian noise with a random stdev in [0, 0.05).
    noise_stdev = torch.rand(1) * 0.05
    noise = torch.normal(mean=torch.zeros(tom.size()),
                         std=torch.ones(tom.size()) * noise_stdev)
    tom += noise

    # Endpoint segmentations (presumably one channel each, giving a
    # 5-channel result -- TODO confirm).
    beginnings = _channels_first(_load(beginnings_fn))
    endings = _channels_first(_load(endings_fn))

    # Stack everything along the channel dimension.
    tom_seed = torch.cat((tom, beginnings, endings), dim=0)

    # Tractogram: keep the first field of every record, which requires all
    # streamlines to have the same number of points.
    records, _header = trackvis.read(tractogram_fn)
    streamlines = np.array([record[0] for record in records])

    # Convert to voxel coordinates and normalise by the hard-coded 144
    # (presumably the volume edge length -- TODO confirm).
    streamlines = utils.apply_affine(aff=coord2voxel, pts=streamlines)
    streamlines /= 144

    # Seeds are the first point of each streamline.
    seeds = streamlines[:, 0, :].astype(np.float32)
    seeds = torch.from_numpy(seeds).permute(1, 0)  # (3, num_sl)

    # Make every streamline relative to its own seed.
    streamlines = streamlines - streamlines[:, :1, :]

    # (num_sl, num_points, 3) -> (side, side, num_points * 3); a single
    # C-order reshape is equivalent to doing it in two steps.
    side = int(num_streamlines**(1 / 2))
    streamlines = np.reshape(streamlines, (side, side, num_points * 3))
    tractogram = torch.from_numpy(streamlines).permute(2, 0, 1)

    return [[tom_seed, seeds], tractogram]