def get_preview_batch(
        h5data: Tuple[str, str],
        preview_shape: Optional[Tuple[int, ...]] = None,
        transform: Optional[Callable] = None,
        in_memory: bool = False
) -> torch.Tensor:
    """Load a centered preview region of an HDF5 dataset as a batch tensor.

    Parameters
    ----------
    h5data
        ``(fname, key)`` pair: path of the HDF5 file and the dataset key
        inside it.
    preview_shape
        Spatial shape of the region that is cut out of the center of the
        input. Its length determines the number of spatial dimensions.
        If ``None``, the whole spatial extent is used and the input is
        treated as 3D (NOTE(review): confirm this default for 2D sources).
    transform
        Optional callable invoked as ``transform(inp_np, None)``; only the
        first element of its returned pair is kept.
    in_memory
        If ``True``, the full dataset is read into memory before slicing.

    Returns
    -------
    Tensor created with ``torch.from_numpy`` from the sliced (and optionally
    transformed) data. A singleton leading axis is prepended by
    ``slice_3d(..., prepend_empty_axis=True)`` and a channel axis is
    inserted if the data has none.

    Raises
    ------
    ValueError
        If ``preview_shape`` does not fit into the spatial shape of the input.
    """
    fname, key = h5data
    # Use a context manager so the file handle is released once the preview
    # has been read. slice_3d's result is passed to torch.from_numpy below,
    # so it is a materialized ndarray that stays valid after closing.
    with h5py.File(fname, 'r') as h5file:
        inp_h5 = h5file[key]
        if in_memory:
            # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads the
            # whole dataset in every h5py version.
            inp_h5 = inp_h5[()]
        # Number of spatial dims (2D or 3D). Without an explicit preview
        # shape there is nothing to infer it from, so fall back to 3D.
        dim = 3 if preview_shape is None else len(preview_shape)
        inp_shape = np.array(inp_h5.shape[-dim:])
        if preview_shape is None:  # Slice everything
            inp_lo = np.zeros_like(inp_shape)
            inp_hi = inp_shape
        else:
            # Slice only a preview_shape-sized region from the center of
            # the input. ``lo + shape`` (instead of ``center + shape // 2``)
            # keeps the extent exact for odd preview shapes.
            preview = np.array(preview_shape)
            halfshape = preview // 2
            inp_center = inp_shape // 2
            inp_lo = inp_center - halfshape
            inp_hi = inp_lo + preview
            if np.any(inp_lo < 0) or np.any(inp_hi > inp_shape):
                raise ValueError(
                    'preview_shape is too big for shape of input source. '
                    f'Requested {preview_shape}, but can only deliver {tuple(inp_shape)}.'
                )
        memstr = ' (in memory)' if in_memory else ''
        logger.info(f'\nPreview data{memstr}:')
        logger.info(
            f'  input:       {fname}[{key}]: {inp_h5.shape} ({inp_h5.dtype})\n')
        inp_np = slice_3d(inp_h5, inp_lo, inp_hi, prepend_empty_axis=True)

    if inp_np.ndim == dim + 1:  # Should be dim + 2 for (N, C) dims
        inp_np = inp_np[:, None]  # Add missing C dim
    if transform is not None:
        inp_np, _ = transform(inp_np, None)
    inp = torch.from_numpy(inp_np)
    return inp
def warp_slice(
        inp_src: DataSource,
        patch_shape: Union[Tuple[int, ...], np.ndarray],
        M: np.ndarray,
        target_src: Optional[DataSource] = None,
        target_patch_shape: Optional[Union[Tuple[int, ...], np.ndarray]] = None,
        target_discrete_ix: Optional[Sequence[int]] = None,
        debug: bool = False
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Cuts a warped slice out of the input image and out of the target_src image.
    Warping is applied by multiplying the original source coordinates with
    the inverse of the homogeneous (forward) transformation matrix ``M``.

    "Source coordinates" (``src_coords``) signify the coordinates of voxels in
    ``inp_src`` and ``target_src`` that are used to compose their respective
    warped versions. The idea here is that not the images themselves, but the
    coordinates from where they are read are warped. This allows for much
    higher efficiency for large image volumes because we don't have to
    calculate the expensive warping transform for the whole image, but only
    for the voxels that we eventually want to use for the new warped image.
    The transformed coordinates usually don't align to the discrete voxel
    grids of the original images (meaning they are not integers), so the new
    voxel values are obtained by interpolation.

    Parameters
    ----------
    inp_src
        Input image source (in HDF5). Must be 3D ``(D, H, W)`` or
        4D ``(C, D, H, W)``.
    patch_shape
        Spatial patch shape ``(D, H, W)`` (spatial shape of the neural
        network's input node).
    M
        Forward warping transformation matrix (4x4).
        Must contain translations in source and target_src array.
    target_src
        Optional target source array to be extracted from in the same way.
    target_patch_shape
        Patch size for the ``target_src`` array. If ``None`` while
        ``target_src`` is given, ``patch_shape`` is used.
    target_discrete_ix
        List of target channels that contain discrete values.
        By default (``None``), every channel is seen as discrete (this is
        generally the case for classification tasks).
        This information is used to decide what kind of interpolation should
        be used for reading target data:
        - discrete targets are obtained by nearest-neighbor interpolation
        - non-discrete (continuous) targets are linearly interpolated.
    debug
        If ``True`` (default is ``False``), enable additional sanity checks
        to catch warping issues early.

    Returns
    -------
    inp
        Warped input image slice
    target
        Warped target_src image slice
        or ``None``, if ``target_src is None``.

    Raises
    ------
    ValueError
        If ``inp_src`` is neither 3D nor 4D, or if targets are not centered
        w.r.t. the input.
    WarpingOOBError
        If the warped region reaches outside of ``inp_src`` / ``target_src``.
    WarpingSanityError
        (only with ``debug=True``) if interpolation coordinates fall outside
        the cut-out region, or discrete targets gain new values.
    RuntimeError
        (only with ``debug=True``) if the warped output contains NaN.
    """
    patch_shape = tuple(patch_shape)
    if len(inp_src.shape) == 3:
        n_f = 1
    elif len(inp_src.shape) == 4:
        n_f = inp_src.shape[0]
    else:
        raise ValueError(f'Can\'t handle inp_src shape {inp_src.shape}')

    # Spatial shapes of input and target data sources
    inp_src_shape = np.array(inp_src.shape[-3:])

    M_inv = np.linalg.inv(M.astype(np.float64)).astype(floatX)  # stability...
    dest_corners = make_dest_corners(patch_shape)
    src_corners = np.dot(M_inv, dest_corners.T).T
    if np.any(M[3, :3] != 0):  # homogeneous divide
        src_corners /= src_corners[:, 3][:, None]

    # check corners
    src_corners = src_corners[:, :3]
    # ``np.int`` was removed in NumPy 1.24; it was an alias of builtin ``int``.
    lo = np.min(np.floor(src_corners), 0).astype(int)
    hi = np.max(np.ceil(src_corners + 1), 0).astype(int)

    # compute/transform dense coords
    dest_coords = make_dest_coords(patch_shape)
    src_coords = np.tensordot(dest_coords, M_inv, axes=[[-1], [1]])
    if np.any(M[3, :3] != 0):  # homogeneous divide
        src_coords /= src_coords[..., 3][..., None]
    # cut patch
    src_coords = src_coords[..., :3]

    # TODO: WIP code, integrate this into the warping pipeline with config options
    # Perform elastic deformation on warped coordinates so we don't have
    # to interpolate twice.
    # For more details, see elektronn3.data.transforms.ElasticTransform
    elastic = False
    if elastic:
        sigma = 4
        alpha = 40
        aniso_factor = 2
        for i in range(3):
            # For each coordinate of dimension i, build a random displacement,
            # smooth it with sigma and multiply it by alpha
            elastic_displacement = gaussian_filter(
                np.random.rand(*patch_shape) * 2 - 1,
                sigma,
                mode='constant',
                cval=0
            ) * alpha
            # Apply anisotropy correction
            if i == 0 and aniso_factor != 1:
                elastic_displacement /= aniso_factor
            # Apply deformation
            src_coords[..., i] += elastic_displacement
            # Clip out-of-bounds coordinates back to original cube edges to
            # prevent out-of-bounds reading
            np.clip(src_coords[..., i], lo[i], hi[i] - 1, out=src_coords[..., i])

    if target_src is not None:
        target_src_shape = np.array(target_src.shape[-3:])
        if target_patch_shape is None:
            # Previously crashed in tuple(None); default to the input patch.
            target_patch_shape = patch_shape
        target_patch_shape = tuple(target_patch_shape)
        # Use len(shape) like the inp_src handling above (no .ndim needed).
        n_f_t = target_src.shape[0] if len(target_src.shape) == 4 else 1

        target_src_offset = np.subtract(inp_src_shape, target_src.shape[-3:])
        if np.any(np.mod(target_src_offset, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        target_src_offset //= 2

        target_offset = np.subtract(patch_shape, target_patch_shape)
        if np.any(np.mod(target_offset, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        target_offset //= 2

        src_coords_target = src_coords[
            target_offset[0]:(target_offset[0] + target_patch_shape[0]),
            target_offset[1]:(target_offset[1] + target_patch_shape[1]),
            target_offset[2]:(target_offset[2] + target_patch_shape[2])]
        # shift coords to be w.r.t. to origin of target_src array
        lo_targ = np.floor(
            src_coords_target.min(2).min(1).min(0) - target_src_offset
        ).astype(int)
        hi_targ = np.ceil(
            src_coords_target.max(2).max(1).max(0) + 1 - target_src_offset
        ).astype(int)
        if np.any(lo_targ < 0) or np.any(hi_targ >= target_src_shape - 1):
            raise WarpingOOBError("Out of bounds for target_src")

    if np.any(lo < 0) or np.any(hi >= inp_src_shape - 1):
        raise WarpingOOBError("Out of bounds for inp_src")

    # Slice and interpolate input
    # Slice to hi + 1 because interpolation potentially needs this value.
    img_cut = slice_3d(inp_src, lo, hi + 1, dtype=floatX)
    if img_cut.ndim == 3:
        img_cut = img_cut[None]

    inp = np.zeros((n_f, ) + patch_shape, dtype=floatX)
    lo = lo.astype(floatX)
    if debug and np.any((src_coords - lo).max(2).max(1).max(0) >= img_cut.shape[-3:]):
        raise WarpingSanityError(
            f'src_coords check failed (too high).\n{(src_coords - lo).max(2).max(1).max(0), img_cut.shape[-3:]}'
        )
    if debug and np.any((src_coords - lo).min(2).min(1).min(0) < 0):
        raise WarpingSanityError(
            f'src_coords check failed (negative indices).\n{(src_coords - lo).min(2).min(1).min(0)}'
        )
    for k in range(n_f):
        map_coordinates_linear(img_cut[k], src_coords, lo, inp[k])

    # Slice and interpolate target
    if target_src is not None:
        # dtype is float as well here because of the static typing of the
        # numba-compiled map_coordinates functions
        # Slice to hi + 1 because interpolation potentially needs this value.
        target_cut = slice_3d(target_src, lo_targ, hi_targ + 1, dtype=floatX)
        if target_cut.ndim == 3:
            target_cut = target_cut[None]
        src_coords_target = np.ascontiguousarray(src_coords_target, dtype=floatX)
        target = np.zeros((n_f_t, ) + target_patch_shape, dtype=floatX)
        lo_targ = (lo_targ + target_src_offset).astype(floatX)
        if target_discrete_ix is None:
            target_discrete_ix = [True for i in range(n_f_t)]
        else:
            target_discrete_ix = [i in target_discrete_ix for i in range(n_f_t)]

        if debug and np.any(
                (src_coords_target - lo_targ).max(2).max(1).max(0) >= target_cut.shape[-3:]):
            raise WarpingSanityError(
                f'src_coords_target check failed (too high).\n{(src_coords_target - lo_targ).max(2).max(1).max(0)}\n{target_cut.shape[-3:]}'
            )
        if debug and np.any((src_coords_target - lo_targ).min(2).min(1).min(0) < 0):
            raise WarpingSanityError(
                f'src_coords_target check failed (negative indices).\n{(src_coords_target - lo_targ).min(2).min(1).min(0)}'
            )

        for k, discr in enumerate(target_discrete_ix):
            if discr:
                map_coordinates_nearest(target_cut[k], src_coords_target, lo_targ, target[k])

                if debug:
                    unique_cut = set(list(np.unique(target_cut[k])))
                    unique_warp = set(list(np.unique(target[k])))
                    # If new values appear in discrete targets, there is something wrong.
                    # unique_warp can have less values than unique_cut though, for example
                    # if the warping transform coincidentally slices away all values of a class.
                    if not unique_warp.issubset(unique_cut):
                        # Raise like the other debug checks instead of dropping
                        # into an IPython shell, which needs an optional dependency
                        # and cannot work in data loader worker processes.
                        raise WarpingSanityError(
                            f'Invalid target encountered:\n\nunique_cut=\n{unique_cut}\n'
                            f'unique_warp=\n{unique_warp}\nM_inv=\n{M_inv}\n'
                            f'src_coords_target - lo_targ=\n{src_coords_target - lo_targ}\n'
                        )
            else:
                map_coordinates_linear(target_cut[k], src_coords_target, lo_targ, target[k])

    else:
        target = None

    if debug and np.any(np.isnan(inp)):
        raise RuntimeError('Warping is broken: inp contains NaN.')
    # np.isnan(None) raises TypeError, so only check an actual target array.
    if debug and target is not None and np.any(np.isnan(target)):
        raise RuntimeError('Warping is broken: target contains NaN.')

    return inp, target