def get_voxel_grid_real_space(images, append_ones=False):
    """
    Map the voxel index grid of ``images`` into real space, centered at the origin.

    Only the linear part of the affine (``images.affine[:-1, :-1]``) is applied.
    The translation column is dropped, which is harmless because the resulting
    points are re-centered on their mean afterwards.

    Args:
        images: object exposing ``.shape`` (last axis is treated as channels and
            excluded) and ``.affine`` (assumed to be a 4x4 homogeneous voxel-to-
            real matrix -- TODO confirm).
        append_ones: if True, a column of ones is appended to every centered
            point (homogeneous coordinates) before reshaping back into a grid.

    Returns:
        The centered real-space points reshaped with ``points_to_mgrid`` to
        ``images.shape[:-1]``.
    """
    # Shape excluding the channel axis
    shape = images.shape[:-1]

    # Linear part of the voxel -> real-space transform
    vox_to_real_affine = images.affine[:-1, :-1]

    # Regular voxel index grid, then move it into real space
    grid_vox_space = np.mgrid[0:shape[0]:1, 0:shape[1]:1, 0:shape[2]:1]
    grid_points_real_space = vox_to_real_affine.dot(
        mgrid_to_points(grid_vox_space).T).T

    # Center the point cloud on the origin
    centered_grid_points_real_space = grid_points_real_space - \
        np.mean(grid_points_real_space, axis=0)

    if append_ones:
        # Stack onto the CENTERED points; the previous version used the
        # un-centered points here and silently discarded the centering.
        centered_grid_points_real_space = np.column_stack(
            (centered_grid_points_real_space,
             np.ones(len(centered_grid_points_real_space))))

    # Return real space grid in mgrid layout
    return points_to_mgrid(centered_grid_points_real_space, shape)
def get_base_patches_from(self, image, return_y=False):
    """
    Yield boxes placed on a regular real-space lattice that covers ``image``.

    Along each axis the sampled extent is ``max(real_dim, self.real_box_dim)``.
    ``ceil(extent / self.real_box_dim)`` evenly spaced box origins are placed
    from ``-extent / 2`` up to ``extent / 2 - self.real_box_dim``, so the boxes
    jointly span ``[-extent / 2, extent / 2]`` around the origin.

    Each box is sampled with ``sample_box_at`` without rotation noise, then
    interpolated and normalized via ``self._intrp_and_norm``.

    Args:
        image: object exposing ``.real_shape`` (real-space extent per axis).
        return_y: forwarded to ``self._intrp_and_norm``; also selects the
            yielded tuple layout.

    Yields:
        ``(im, lab, grid, axes, inv_mat, n_placements)`` if ``return_y`` else
        ``(im, grid, axes, inv_mat, n_placements)``. ``n_placements`` is the
        total number of boxes this generator yields.
    """
    real_dims = image.real_shape

    # Per-axis extent to cover and the free room for box origins
    sample_space = np.asarray([max(i, self.real_box_dim) for i in real_dims])
    d = sample_space - self.real_box_dim

    # Minimum number of boxes per axis to cover the extent.
    # (np.int was removed in NumPy 1.24; builtin int is the same dtype.)
    min_cov = [np.ceil(sample_space[i] / self.real_box_dim).astype(int)
               for i in range(3)]
    ds = [np.linspace(0, d[i], min_cov[i]) - sample_space[i] / 2
          for i in range(3)]

    # Get placement coordinate points
    placements = mgrid_to_points(np.meshgrid(*ds))
    n_placements = len(placements)

    for p in placements:
        grid, axes, inv_mat = sample_box_at(real_placement=p,
                                            sample_dim=self.sample_dim,
                                            real_box_dim=self.real_box_dim,
                                            noise_sd=0.0,
                                            test_mode=True)
        im, lab = self._intrp_and_norm(image, grid, return_y)
        if return_y:
            yield im, lab, grid, axes, inv_mat, n_placements
        else:
            yield im, grid, axes, inv_mat, n_placements
def apply_rotation(self, mgrid):
    """
    Rotate an mgrid-layout coordinate array by ``self.rot_mat``.

    If ``self.rot_mat`` is None the input object is returned unchanged.
    Otherwise the grid is flattened to points, multiplied by the rotation
    matrix, and reshaped back to the grid shape ``mgrid[0].shape``.
    """
    if self.rot_mat is None:
        return mgrid
    grid_shape = mgrid[0].shape
    flat_points = mgrid_to_points(mgrid)
    rotated_points = self.rot_mat.dot(flat_points.T).T
    return points_to_mgrid(rotated_points, grid_shape)
def get_voxel_grid(images, as_points=False):
    """
    Build the integer voxel index grid over the first three axes of ``images``.

    Returns the dense index array in ``np.mgrid`` layout, of shape
    ``(3,) + images.shape[:3]``. When ``as_points`` is True, the grid is
    converted with ``mgrid_to_points`` instead.
    """
    dims = images.shape[:3]
    voxel_indices = np.indices((dims[0], dims[1], dims[2]))
    return mgrid_to_points(voxel_indices) if as_points else voxel_indices
def sample_plane_at(norm_vector, sample_dim, real_space_span,
                    offset_from_center, noise_sd, test_mode=False):
    """
    Sample a regular 2D grid of points on a plane in real space.

    The plane is orthogonal to a (noisy) version of ``norm_vector`` and lies
    ``offset_from_center`` units along that normal from the origin.

    Args:
        norm_vector: normal vector of the plane (normalized internally).
        sample_dim: number of grid points along each in-plane axis.
        real_space_span: in-plane extent. The grid covers
            ``[-(span // 2), span // 2]`` along both in-plane axes.
        offset_from_center: displacement of the plane along its normal.
        noise_sd: either a scalar std used to draw a 3-vector of Gaussian noise
            added to the unit normal, or an explicit ``np.ndarray`` noise vector
            that is added as-is.
        test_mode: if True, also return the in-plane axis and the inverse basis.

    Returns:
        ``real_grid`` (the plane grid mapped through the ``[u, v, n_hat]``
        basis and reshaped to ``(sample_dim, sample_dim, 1)`` via
        ``points_to_mgrid``). When ``test_mode`` is True, returns
        ``(real_grid, g, inv_basis)`` instead, where ``g`` is
        ``np.linspace(-hd, hd, sample_dim)`` and ``inv_basis`` is the inverse
        of the basis matrix.
    """
    # Prepare unit normal vector to the plane
    n_hat = np.array(norm_vector, np.float32)
    n_hat /= np.linalg.norm(n_hat)

    # A scalar noise_sd becomes a random noise vector; an ndarray is used as-is
    if not isinstance(noise_sd, np.ndarray):
        noise_sd = np.random.normal(scale=noise_sd, size=3)
    n_hat += noise_sd
    n_hat /= np.linalg.norm(n_hat)

    # Vector pointing primarily up (x/y components small in MAGNITUDE): noise
    # has a large effect on the in-plane orientation, so force the first two
    # components positive to control sampling variability. np.abs is required
    # here; without it every vector with negative x/y components qualified and
    # strongly tilted views were mirrored.
    if np.all(np.abs(n_hat[:-1]) < 0.2):
        n_hat[:-1] = np.abs(n_hat[:-1])

    if np.all(np.isclose(n_hat[:-1], 0)):
        # Normal is (anti)parallel to z: use the canonical x/y axes in-plane
        u = np.array([1, 0, 0])
        v = np.array([0, 1, 0])
    else:
        # Vector in the same vertical plane as n_hat (same x/y, larger z)
        nhat_vs = n_hat.copy()
        nhat_vs[-1] = nhat_vs[-1] + 1
        nhat_vs /= np.linalg.norm(nhat_vs)

        # Two orthogonal in-plane vectors; u is intended to point down in z
        u = get_rotation_matrix(np.cross(n_hat, nhat_vs), -90).dot(n_hat)
        v = np.cross(n_hat, u)

    # Basis matrix mapping plane coordinates to real space
    basis = np.column_stack((u, v, n_hat))

    # Regular grid centered at the origin, single slice at the offset
    hd = real_space_span // 2
    g = np.linspace(-hd, hd, sample_dim)
    j = complex(sample_dim)
    grid = np.mgrid[-hd:hd:j, -hd:hd:j,
                    offset_from_center:offset_from_center:1j]

    # Map grid points into real space and back into mgrid layout
    points = mgrid_to_points(grid)
    real_points = basis.dot(points.T).T
    real_grid = points_to_mgrid(real_points, grid.shape[1:])

    if test_mode:
        return real_grid, g, np.linalg.inv(basis)
    return real_grid
def map_real_space_pred(pred, grid, inv_basis, voxel_grid_real_space,
                        method="nearest"):
    """
    Resample a prediction volume onto a real-space voxel grid.

    Every position in ``voxel_grid_real_space`` is mapped through ``inv_basis``
    into the coordinate system of ``grid``, and ``pred`` is interpolated there.
    Positions outside ``grid`` receive a fill vector that is 1.0 for class 0
    and 0.0 for all other classes (background).

    Args:
        pred: array whose last axis holds per-class values, defined on the
            axes given in ``grid``.
        grid: axis coordinates passed to ``RegularGridInterpolator``.
        inv_basis: 3x3 matrix mapping real-space points into ``grid``
            coordinates.
        voxel_grid_real_space: mgrid-layout array of real-space voxel
            positions.
        method: interpolation method forwarded to ``RegularGridInterpolator``.

    Returns:
        Array of shape ``voxel_grid_real_space[0].shape + (pred.shape[-1],)``
        with dtype ``pred.dtype``.
    """
    from concurrent.futures import ThreadPoolExecutor

    print("Mapping to real coordinate space...")
    n_classes = pred.shape[-1]

    # Fill value vector for out-of-bounds points: background class = 1.0
    fill = np.zeros(shape=n_classes, dtype=np.float32)
    fill[0] = 1.0

    # Initialize interpolator object
    intrp = RegularGridInterpolator(grid, pred, fill_value=fill,
                                    bounds_error=False, method=method)

    # Map voxel positions into the coordinate system of `grid`
    points = inv_basis.dot(mgrid_to_points(voxel_grid_real_space).T).T
    transformed_grid = points_to_mgrid(points, voxel_grid_real_space[0].shape)

    # Prepare mapped pred volume
    mapped = np.empty(transformed_grid[0].shape + (n_classes,),
                      dtype=pred.dtype)

    def _do(xs, ys, zs, index):
        # Interpolate one slice along the first grid axis
        return intrp((xs, ys, zs)), index

    n_slices = transformed_grid.shape[1]

    # Thread pool of 7 workers; the context manager guarantees shutdown even
    # if an interpolation task raises.
    with ThreadPoolExecutor(max_workers=7) as pool:
        result = pool.map(_do, transformed_grid[0], transformed_grid[1],
                          transformed_grid[2], range(n_slices))
        for i, (values, ind) in enumerate(result, start=1):
            # Print status
            print(" %i/%i" % (i, n_slices), end="\r", flush=True)
            # Map the interpolation results into the volume
            mapped[ind] = values
    print("")
    return mapped
def get_base_patches(self, image):
    """
    Yield scaled voxel-space patches of edge length ``self.dim`` covering ``image``.

    Along each axis, ``ceil(max(size, self.dim) / self.dim)`` integer patch
    origins are spaced evenly in ``[0, max(size, self.dim) - self.dim]``.

    NOTE(review): when an axis is shorter than ``self.dim``, the slice is
    truncated by NumPy and the patch is smaller than ``self.dim`` along that
    axis -- confirm callers expect this.

    Yields:
        ``(patch, p)`` where ``patch`` is
        ``image.scaler.transform(image.image[x:x+dim, y:y+dim, z:z+dim])`` and
        ``p`` is the patch origin ``(x, y, z)``.
    """
    X = image.image

    # Per-axis extent to cover and the free room for patch origins
    sample_space = np.asarray([max(i, self.dim) for i in image.shape[:3]])
    d = sample_space - self.dim

    # np.int was removed in NumPy 1.24; builtin int is the same dtype.
    min_cov = [np.ceil(sample_space[i] / self.dim).astype(int)
               for i in range(3)]
    ds = [np.linspace(0, d[i], min_cov[i], dtype=int) for i in range(3)]

    # Get placement coordinate points
    placements = mgrid_to_points(np.meshgrid(*ds))

    dim = self.dim
    for p in placements:
        x, y, z = p[0], p[1], p[2]
        yield image.scaler.transform(X[x:x + dim, y:y + dim, z:z + dim]), p
def sample_box_at(real_placement, sample_dim, real_box_dim, noise_sd, test_mode):
    """
    Sample a regular cubic grid of points in real space, optionally rotated.

    The box starts at ``real_placement`` and extends ``real_box_dim`` along
    each axis, with ``sample_dim`` points per axis (endpoints included). If
    ``noise_sd`` is truthy, the grid is rotated about its own center. The
    rotation axis comes from ``get_random_views`` and the angle is
    ``|N(0, noise_sd)|``, rejection-sampled to lie in ``(0, 2*pi)``.

    Returns:
        The (possibly rotated) grid in mgrid layout. When ``test_mode`` is
        truthy, returns ``(grid, axes, inv_rot)`` instead. ``axes`` are the
        three un-rotated 1D axes and ``inv_rot`` is the inverse of the applied
        rotation matrix (identity when no rotation is applied).
    """
    n_steps = complex(sample_dim)
    x0, y0, z0 = real_placement
    box_grid = np.mgrid[x0:x0 + real_box_dim:n_steps,
                        y0:y0 + real_box_dim:n_steps,
                        z0:z0 + real_box_dim:n_steps]

    rotation = np.eye(3)
    if not noise_sd:
        rotated_grid = box_grid
    else:
        # Random rotation axis
        rot_axis = get_random_views(N=1, dim=3, pos_z=True)

        # Rejection-sample a nonzero angle below 2*pi
        angle = False
        while not angle:
            candidate = np.abs(np.random.normal(scale=noise_sd, size=1)[0])
            if candidate < 2 * np.pi:
                angle = candidate
        rotation = get_rotation_matrix(rot_axis, angle_rad=angle)

        # Rotate about the grid centroid, then restore mgrid layout
        flat = mgrid_to_points(box_grid)
        centroid = np.mean(flat, axis=0)
        flat = rotation.dot((flat - centroid).T).T + centroid
        rotated_grid = points_to_mgrid(flat, box_grid.shape[1:])

    if not test_mode:
        return rotated_grid

    axes = tuple(np.linspace(start, start + real_box_dim, sample_dim)
                 for start in (x0, y0, z0))
    return rotated_grid, axes, np.linalg.inv(rotation)
def pred_3D_iso(model, sequence, image, extra_boxes, min_coverage=None):
    """
    Predict a full volume by sampling boxes in real space and summing the
    per-box predictions into a voxel-space volume.

    All base patches from ``sequence.get_base_patches_from`` are predicted
    first. After that, random patches from
    ``sequence.get_N_random_patches_from`` are predicted until at least
    ``extra_boxes`` extra boxes have been processed. If ``min_coverage`` is
    set, prediction also continues until that fraction of voxels has a
    non-zero summed prediction vector.

    Args:
        model: object exposing ``predict_on_batch``.
        sequence: object exposing ``n_classes``, ``get_base_patches_from`` and
            ``get_N_random_patches_from``.
        image: object exposing ``.shape``, ``.image``, ``.affine`` and
            ``.interpolator.apply_rotation``.
        extra_boxes: number of extra (random) boxes, either as an int or as a
            string like ``'2x'`` / ``'2.5x'`` meaning a multiple of the number
            of base patches.
        min_coverage: optional minimum fraction of covered voxels; a falsy
            value disables the coverage check.

    Returns:
        Array of shape ``image.shape[:3] + (n_classes,)`` with the summed
        predictions. The output is NOT normalized.
    """
    total_extra_boxes = extra_boxes

    n_classes = sequence.n_classes
    pred_shape = tuple(image.shape[:3]) + (n_classes,)
    vox_shape = tuple(image.shape[:3]) + (3,)

    # Voxel index grid and voxel axes centered in real space
    vox_grid = get_voxel_grid(image, as_points=False)
    g_all, basis, _ = get_voxel_axes_real_space(image.image, image.affine,
                                                return_basis=True)
    g_all = list(g_all)

    # Flip axes? The interpolator requires strictly increasing axes.
    flip = np.sign(np.diagonal(basis)) == -1
    for i, (g, f) in enumerate(zip(g_all, flip)):
        if f:
            g_all[i] = np.flip(g, 0)
            vox_grid = np.flip(vox_grid, i + 1)
    vox_points = mgrid_to_points(vox_grid).reshape(vox_shape).astype(
        np.float32)

    # Interpolator: real-space point -> nearest voxel coordinate (NaN outside)
    intrp = RegularGridInterpolator(tuple(g_all), vox_points,
                                    method="nearest",
                                    bounds_error=False,
                                    fill_value=np.nan,
                                    dtype=np.float32)

    # Prepare prediction volume
    pred_vol = np.zeros(shape=pred_shape, dtype=np.float32)

    # Predict on base patches first
    base_patches = sequence.get_base_patches_from(image, return_y=False)

    # Sample boxes and predict --> sum into pred_vol
    is_covered = not min_coverage
    base_reached = extra_reached = False
    N_base = N_extra = 0
    while not is_covered or not base_reached or not extra_reached:
        try:
            im, rgrid, _, _, total_base = next(base_patches)
        except StopIteration:
            # Base patches exhausted: draw one random extra patch
            p = sequence.get_N_random_patches_from(image, 1, return_y=False)
            im, rgrid, _, _ = next(p)
            N_extra += 1
        else:
            N_base += 1
            if isinstance(total_extra_boxes, str):
                # Number specified in string format '2x', '2.5x' etc. as a
                # multiplier of number of base patches
                total_extra_boxes = int(
                    float(total_extra_boxes.split("x")[0]) * total_base)

        # Predict on the box
        pred = model.predict_on_batch(np.expand_dims(im, 0))[0]

        # Apply rotation if needed
        rgrid = image.interpolator.apply_rotation(rgrid)

        # Interpolate to nearest vox grid positions; drop out-of-bounds points
        vox_inds = intrp(tuple(rgrid)).reshape(-1, 3)
        mask = ~np.all(np.isnan(vox_inds), axis=-1)
        vox_inds = tuple(vox_inds[mask].astype(int).T)

        # Add to volume.
        # NOTE(review): fancy-index `+=` applies each duplicate index only
        # once; if several box points map to the same voxel, only one
        # contributes. Consider np.add.at if accumulation is intended.
        pred_vol[vox_inds] += pred.reshape(-1, n_classes)[mask]

        # Check coverage fraction
        if min_coverage:
            covered = ~np.all(np.isclose(pred_vol, 0), axis=-1)
            coverage = np.sum(covered) / np.prod(pred_vol.shape[:3])
            # Both values must be in one tuple; the previous version formatted
            # only `coverage` and raised TypeError.
            cov_string = "%.3f/%.3f" % (coverage, min_coverage)
            is_covered = coverage >= min_coverage
        else:
            cov_string = "[Not calculated]"

        print(" N base patches: %i/%i --- N extra patches %i/%i --- "
              "Coverage: %s" % (N_base, total_base, N_extra,
                                total_extra_boxes, cov_string),
              end="\r", flush=True)

        # Check convergence
        base_reached = N_base >= total_base
        extra_reached = N_extra >= total_extra_boxes
    print("")

    # Return prediction volume - OBS not normalized
    return pred_vol
def get_patch_corners(self):
    """
    Return the integer corner coordinates of all patches.

    Along axis ``k``, ``self.strides[k]`` evenly spaced positions in
    ``[0, self.dim_r[k]]`` are truncated to int. The per-axis positions are
    combined with ``np.meshgrid`` and converted via ``mgrid_to_points``.
    """
    # np.int was removed in NumPy 1.24; builtin int is the same dtype.
    xc, yc, zc = (
        np.linspace(0, self.dim_r[k], self.strides[k]).astype(int)
        for k in range(3)
    )
    return mgrid_to_points(np.meshgrid(xc, yc, zc))