def _init_interpolators(self, image, labels, bg_value, bg_class, affine):
    """
    Build regular-grid interpolators for the image channels and the labels.

    The real-space voxel axes are obtained from
    ``get_voxel_axes_real_space``; the rotation matrix it returns is stored
    on ``self.rot_mat``. Any axis whose basis diagonal entry is negative is
    reversed (together with the image and label data along that axis) so
    that all grid axes are increasing.

    Args:
        image:    Image array indexed as ``image[..., channel]``.
        labels:   Label array, or None.
        bg_value: Fill value for image points outside the grid.
        bg_class: Fill value for label points outside the grid.
        affine:   Affine passed on to ``get_voxel_axes_real_space``.

    Returns:
        A tuple ``(im_intrps, lab_intrp)``: a list with one linear
        interpolator per channel (``self.n_channels`` in total) and a
        nearest-neighbour label interpolator. ``lab_intrp`` is None when
        constructing it raises AttributeError, TypeError or ValueError.
    """
    axes, basis, rot_mat = get_voxel_axes_real_space(image, affine,
                                                     return_basis=True)
    axes = list(axes)

    # Keep the rotation matrix on the instance
    self.rot_mat = rot_mat

    # Reverse every axis with a negative basis direction, and flip the data
    # along that axis so values stay aligned with the grid positions
    needs_flip = np.sign(np.diagonal(basis)) == -1
    for axis, (grid_axis, flip_axis) in enumerate(zip(axes, needs_flip)):
        if not flip_axis:
            continue
        axes[axis] = np.flip(grid_axis, 0)
        image = np.flip(image, axis)
        if labels is not None:
            labels = np.flip(labels, axis)
    g_xx, g_yy, g_zz = axes
    grid = (g_xx, g_yy, g_zz)

    # One linear interpolator per image channel
    im_intrps = [
        RegularGridInterpolator(grid,
                                image[..., channel].squeeze(),
                                bounds_error=False,
                                fill_value=bg_value,
                                method="linear",
                                dtype=np.float32)
        for channel in range(self.n_channels)
    ]

    # Nearest-neighbour interpolator for the labels, if it can be built
    try:
        lab_intrp = RegularGridInterpolator(grid,
                                            labels,
                                            bounds_error=False,
                                            fill_value=bg_class,
                                            method="nearest",
                                            dtype=np.uint8)
    except (AttributeError, TypeError, ValueError):
        lab_intrp = None

    return im_intrps, lab_intrp
def map_real_space_pred(pred, grid, inv_basis, voxel_grid_real_space, method="nearest"):
    """
    Interpolate a prediction volume onto a real-space voxel grid.

    The points of ``voxel_grid_real_space`` are transformed with
    ``inv_basis`` and ``pred`` (defined on ``grid``) is interpolated at the
    transformed positions. Points outside ``grid`` receive a fill vector
    that is 1.0 in the first channel and 0.0 in all others.

    Args:
        pred:                  Prediction array; its last axis is the
                               channel/class axis.
        grid:                  Grid definition passed to
                               ``RegularGridInterpolator``.
        inv_basis:             Object with a ``.dot`` method applied to the
                               transposed real-space points.
        voxel_grid_real_space: Grid passed to ``mgrid_to_points``; element
                               ``[0]`` supplies the output spatial shape.
        method:                Interpolation method ("nearest" by default).

    Returns:
        Array of shape ``transformed_grid[0].shape + (pred.shape[-1],)``
        with the same dtype as ``pred``.
    """
    from concurrent.futures import ThreadPoolExecutor
    from multiprocessing import cpu_count

    print("Mapping to real coordinate space...")

    # Fill value vector: out-of-grid points are assigned to channel 0
    fill = np.zeros(shape=pred.shape[-1], dtype=np.float32)
    fill[0] = 1.0

    # Initialize interpolator object
    intrp = RegularGridInterpolator(grid, pred, fill_value=fill,
                                    bounds_error=False, method=method)

    points = inv_basis.dot(mgrid_to_points(voxel_grid_real_space).T).T
    transformed_grid = points_to_mgrid(points, voxel_grid_real_space[0].shape)

    # Prepare mapped pred volume
    mapped = np.empty(transformed_grid[0].shape + (pred.shape[-1],),
                      dtype=pred.dtype)

    def _do(xs, ys, zs, index):
        # Return the index with the values so each result can be written
        # to its own slice of the output volume
        return intrp((xs, ys, zs)), index

    n_slices = transformed_grid.shape[1]
    inds = np.arange(n_slices)

    # Thread pool with at least 7 workers. The context manager guarantees
    # the pool is shut down even if an interpolation call raises.
    with ThreadPoolExecutor(max_workers=max(7, cpu_count())) as pool:
        result = pool.map(_do,
                          transformed_grid[0],
                          transformed_grid[1],
                          transformed_grid[2],
                          inds)
        for n_done, (values, ind) in enumerate(result, start=1):
            # Print status
            print(" %i/%i" % (n_done, n_slices), end="\r", flush=True)

            # Map the interpolation results into the volume
            mapped[ind] = values
    print("")
    return mapped
def pred_3D_iso(model, sequence, image, extra_boxes, min_coverage=None):
    """
    Predict on patches sampled from an image and sum results per voxel.

    All base patches from ``sequence.get_base_patches_from`` are predicted
    first, followed by randomly sampled extra patches. Each patch
    prediction is mapped to the nearest voxel of the image grid and added
    into the output volume. Sampling continues until all base patches and
    the requested number of extra patches have been processed and, when
    ``min_coverage`` is given, the covered fraction of voxels reaches it.

    Args:
        model:        Object with a ``predict`` method taking a batch.
        sequence:     Object exposing ``n_classes``,
                      ``get_base_patches_from`` and
                      ``get_N_random_patches_from``.
        image:        Image object exposing ``shape``, ``image``,
                      ``affine`` and ``interpolator.apply_rotation``.
        extra_boxes:  Number of extra patches, either as a number or as a
                      string multiplier of the number of base patches
                      (e.g. '2x', '2.5x').
        min_coverage: Optional minimum fraction of voxels that must have
                      received a non-zero prediction. A falsy value
                      disables the coverage check.

    Returns:
        float32 array of shape ``image.shape[:3] + (n_classes,)`` holding
        the summed predictions. OBS: not normalized.
    """
    total_extra_boxes = extra_boxes

    n_classes = sequence.n_classes
    pred_shape = tuple(image.shape[:3]) + (n_classes,)
    vox_shape = tuple(image.shape[:3]) + (3,)

    # Voxel index grid of the image
    vox_grid = get_voxel_grid(image, as_points=False)

    # Get voxel regular grid centered in real space
    g_all, basis, _ = get_voxel_axes_real_space(image.image, image.affine,
                                                return_basis=True)
    g_all = list(g_all)

    # Flip axes? Must be strictly increasing
    flip = np.sign(np.diagonal(basis)) == -1
    for i, (g, f) in enumerate(zip(g_all, flip)):
        if f:
            g_all[i] = np.flip(g, 0)
            # vox_grid has a leading component axis, hence the offset
            vox_grid = np.flip(vox_grid, i + 1)
    vox_points = mgrid_to_points(vox_grid).reshape(vox_shape).astype(np.float32)

    # Setup interpolator - takes a point in the scanner space and returns
    # the nearest voxel coordinate (NaN for points outside the grid)
    intrp = RegularGridInterpolator(tuple(g_all), vox_points,
                                    method="nearest", bounds_error=False,
                                    fill_value=np.nan, dtype=np.float32)

    # Prepare prediction volume
    pred_vol = np.zeros(shape=pred_shape, dtype=np.float32)

    # Predict on base patches first
    base_patches = sequence.get_base_patches_from(image, return_y=False)

    # Sample boxes and predict --> sum into pred_vol
    is_covered = not min_coverage
    base_reached = extra_reached = False
    N_base = N_extra = 0
    # Defined up front so the status line and convergence check also work
    # when the base patch generator yields nothing
    total_base = 0
    while not is_covered or not base_reached or not extra_reached:
        try:
            im, rgrid, _, _, total_base = next(base_patches)
            N_base += 1
        except StopIteration:
            p = sequence.get_N_random_patches_from(image, 1, return_y=False)
            im, rgrid, _, _ = next(p)
            N_extra += 1

        if isinstance(total_extra_boxes, str):
            # Number specified in string format '2x', '2.5x' etc. as a
            # multiplier of number of base patches
            multiplier = float(total_extra_boxes.split("x")[0])
            total_extra_boxes = int(multiplier * total_base)

        # Predict on the box
        pred = model.predict(np.expand_dims(im, 0))[0]

        # Apply rotation if needed
        rgrid = image.interpolator.apply_rotation(rgrid)

        # Interpolate to nearest vox grid positions
        vox_inds = intrp(tuple(rgrid)).reshape(-1, 3)

        # Flatten and mask results; all-NaN rows fell outside the grid
        mask = np.logical_not(np.all(np.isnan(vox_inds), axis=-1))
        # Builtin int replaces the removed ``np.int`` alias
        vox_inds = tuple(vox_inds[mask].astype(int).T)

        # Add to volume. NOTE: with fancy indexing, a voxel that occurs
        # more than once in vox_inds is only updated once per patch.
        pred_vol[vox_inds] += pred.reshape(-1, n_classes)[mask]

        # Check coverage fraction
        if min_coverage:
            covered = np.logical_not(np.all(np.isclose(pred_vol, 0), axis=-1))
            coverage = np.sum(covered) / np.prod(pred_vol.shape[:3])
            # Both values must be passed as a tuple to the format operator
            cov_string = "%.3f/%.3f" % (coverage, min_coverage)
            is_covered = coverage >= min_coverage
        else:
            cov_string = "[Not calculated]"
        print(" N base patches: %i/%i --- N extra patches %i/%i --- "
              "Coverage: %s" % (N_base, total_base, N_extra,
                                total_extra_boxes, cov_string),
              end="\r", flush=True)

        # Check convergence
        base_reached = N_base >= total_base
        extra_reached = N_extra >= total_extra_boxes
    print("")

    # Return prediction volume - OBS not normalized
    return pred_vol