def get_preview_batch(
        h5data: Tuple[str, str],
        preview_shape: Optional[Tuple[int, ...]] = None,
        transform: Callable = transforms.Identity(),
        in_memory: bool = False
) -> torch.Tensor:
    fname, key = h5data
    inp_h5 = h5py.File(fname, 'r')[key]
    if in_memory:
        inp_h5 = inp_h5[()]  # Load the full dataset into memory (h5py's .value is deprecated)
    if preview_shape is None:  # Slice everything
        # Assumption: all axes are spatial if no preview_shape is given
        dim = inp_h5.ndim
        inp_shape = np.array(inp_h5.shape)
        inp_lo = np.zeros_like(inp_shape)
        inp_hi = inp_shape
    else:  # Slice only a preview_shape-sized region from the center of the input
        dim = len(preview_shape)  # 2D or 3D
        inp_shape = np.array(inp_h5.shape[-dim:])
        halfshape = np.array(preview_shape) // 2
        inp_center = inp_shape // 2
        inp_lo = inp_center - halfshape
        inp_hi = inp_center + halfshape
        if np.any(inp_center < halfshape):
            raise ValueError(
                'preview_shape is too big for shape of input source. '
                f'Requested {preview_shape}, but can only deliver {tuple(inp_shape)}.'
            )
    memstr = ' (in memory)' if in_memory else ''
    logger.info(f'\nPreview data{memstr}:')
    logger.info(f'  input: {fname}[{key}]: {inp_h5.shape} ({inp_h5.dtype})\n')
    inp_np = slice_h5(inp_h5, inp_lo, inp_hi, prepend_empty_axis=True)
    if inp_np.ndim == dim + 1:  # Should be dim + 2 for (N, C) dims
        inp_np = inp_np[:, None]  # Add missing C dim
    inp_np, _ = transform(inp_np, None)
    inp = torch.from_numpy(inp_np)
    return inp
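# A minimal usage sketch for get_preview_batch (hypothetical file/key names,
# assuming a 3D HDF5 volume with a channel axis):
#
#   preview_batch = get_preview_batch(
#       h5data=('/path/to/raw.h5', 'raw'),
#       preview_shape=(32, 64, 64),  # (D, H, W) region cut from the volume center
#   )
#   # preview_batch is a torch.Tensor of shape (N=1, C, 32, 64, 64).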
def _create_preview_batch(
        self,
        inp_source: h5py.Dataset,
        target_source: h5py.Dataset,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    # Central slicing
    halfshape = np.array(self.preview_shape) // 2
    if inp_source.ndim == 4:
        inp_shape = np.array(inp_source.shape[1:])
        target_shape = np.array(target_source.shape[1:])
    elif inp_source.ndim == 3:
        inp_shape = np.array(inp_source.shape)
        target_shape = np.array(target_source.shape)
    else:  # Guard against unbound inp_shape/target_shape below
        raise ValueError(f'Can\'t handle inp_source with ndim {inp_source.ndim}')
    inp_center = inp_shape // 2
    inp_lo = inp_center - halfshape
    inp_hi = inp_center + halfshape
    target_center = target_shape // 2
    target_lo = target_center - halfshape
    target_hi = target_center + halfshape
    if np.any(inp_center < halfshape):
        raise ValueError(
            'preview_shape is too big for shape of input source. '
            f'Requested {self.preview_shape}, but can only deliver {tuple(inp_shape)}.'
        )
    elif np.any(target_center < halfshape):
        raise ValueError(
            'preview_shape is too big for shape of target source. '
            f'Requested {self.preview_shape}, but can only deliver {tuple(target_shape)}.'
        )
    inp_np = slice_h5(inp_source, inp_lo, inp_hi, prepend_empty_axis=True)
    target_np = slice_h5(
        target_source, target_lo, target_hi,
        dtype=self._target_dtype, prepend_empty_axis=True
    )
    inp_np, target_np = self.transform(inp_np, target_np)
    inp = torch.from_numpy(inp_np)
    target = torch.from_numpy(target_np)
    return inp, target
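# Worked example of the central-slicing arithmetic above: for an input of
# spatial shape (100, 200, 200) and self.preview_shape == (32, 64, 64):
#
#   halfshape  = (16, 32, 32)
#   inp_center = (50, 100, 100)
#   inp_lo     = inp_center - halfshape = (34, 68, 68)
#   inp_hi     = inp_center + halfshape = (66, 132, 132)
#
# i.e. a (32, 64, 64) region centered in the volume is read from the HDF5 file.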
def warp_slice(inp_src, patch_shape, M, target_src=None,
               target_patch_shape=None, target_discrete_ix=None
               ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Cuts a warped slice out of the input image and out of the target_src image.

    Warping is applied by multiplying the original source coordinates with
    the inverse of the homogeneous (forward) transformation matrix ``M``.

    "Source coordinates" (``src_coords``) signify the coordinates of voxels in
    ``inp_src`` and ``target_src`` that are used to compose their respective
    warped versions. The idea here is that not the images themselves, but the
    coordinates from where they are read are warped. This allows for much
    higher efficiency for large image volumes because we don't have to
    calculate the expensive warping transform for the whole image, but only
    for the voxels that we eventually want to use for the new warped image.
    The transformed coordinates usually don't align to the discrete voxel
    grids of the original images (meaning they are not integers), so the new
    voxel values are obtained by linear interpolation.

    Parameters
    ----------
    inp_src: h5py.Dataset
        Input image source (in HDF5)
    patch_shape: tuple or np.ndarray
        (spatial only) Patch shape ``(D, H, W)``
        (spatial shape of the neural network's input node)
    M: np.ndarray
        Forward warping transformation matrix (4x4).
        Must contain translations in source and target_src array.
    target_src: h5py.Dataset or None
        Optional target source array to be extracted from in the same way.
    target_patch_shape: tuple or np.ndarray
        Patch size for the ``target_src`` array.
    target_discrete_ix: list
        List of target channels that contain discrete values.
        By default (``None``), every channel is seen as discrete (this is
        generally the case for classification tasks).
        This information is used to decide what kind of interpolation should
        be used for reading target data:
        - discrete targets are obtained by nearest-neighbor interpolation
        - non-discrete (continuous) targets are linearly interpolated.

    Returns
    -------
    inp: np.ndarray
        Warped input image slice
    target: np.ndarray or None
        Warped target_src image slice or ``None``, if ``target_src is None``.
    """
    patch_shape = tuple(patch_shape)
    if len(inp_src.shape) == 3:
        print(f'inp_src.shape: {inp_src.shape}')
        raise NotImplementedError(
            'elektronn3 has dropped support for data stored in raw 3D form without a channel axis. '
            'Please always supply it with a prepended channel, so it\n'
            'has the form (C, D, H, W) (or in ELEKTRONN2 terms: (f, z, x, y)).'
        )
    elif len(inp_src.shape) == 4:
        n_f = inp_src.shape[0]
        sh = inp_src.shape[1:]
    else:
        raise ValueError('inp_src has wrong dim/shape')

    M_inv = np.linalg.inv(M.astype(np.float64)).astype(floatX)  # stability...
    dest_corners = make_dest_corners(patch_shape)
    src_corners = np.dot(M_inv, dest_corners.T).T
    if np.any(M[3, :3] != 0):  # homogeneous divide
        src_corners /= src_corners[:, 3][:, None]

    # check corners
    src_corners = src_corners[:, :3]
    lo = np.min(np.floor(src_corners), 0).astype(np.int64)
    # np.int is removed in modern NumPy, so np.int64 is used here instead.
    hi = np.max(np.ceil(src_corners + 1), 0).astype(np.int64)  # add 1 because linear interp
    if np.any(lo < 0) or np.any(hi >= sh):
        raise WarpingOOBError("Out of bounds")
    # compute/transform dense coords
    dest_coords = make_dest_coords(patch_shape)
    src_coords = np.tensordot(dest_coords, M_inv, axes=[[-1], [1]])
    if np.any(M[3, :3] != 0):  # homogeneous divide
        src_coords /= src_coords[..., 3][..., None]
    # cut patch
    src_coords = src_coords[..., :3]
    # Add 1 to hi to include this coordinate!
    img_cut = slice_h5(inp_src, lo, hi + 1, dtype=floatX)
    inp = np.zeros((n_f,) + patch_shape, dtype=floatX)
    lo = lo.astype(floatX)
    for k in range(n_f):
        map_coordinates_linear(img_cut[k], src_coords, lo, inp[k])
    if target_src is not None:
        target_patch_shape = tuple(target_patch_shape)
        n_f_t = target_src.shape[0]
        off = np.subtract(sh, target_src.shape[-3:])
        if np.any(np.mod(off, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        off //= 2
        off_ps = np.subtract(patch_shape, target_patch_shape)
        if np.any(np.mod(off_ps, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        off_ps //= 2

        src_coords_target = src_coords[
            off_ps[0]:off_ps[0] + target_patch_shape[0],
            off_ps[1]:off_ps[1] + target_patch_shape[1],
            off_ps[2]:off_ps[2] + target_patch_shape[2]
        ]
        # shift coords to be w.r.t. the origin of the target_src array
        lo_targ = np.floor(src_coords_target.min(2).min(1).min(0) - off).astype(np.int64)
        hi_targ = np.ceil(src_coords_target.max(2).max(1).max(0) - off + 1).astype(np.int64)  # add 1 because linear interp
        if np.any(lo_targ < 0) or np.any(hi_targ >= target_src.shape[-3:]):
            raise WarpingOOBError("Out of bounds for target_src")
        # dtype is float as well here because of the static typing of the
        # numba-compiled map_coordinates functions
        target_cut = slice_h5(target_src, lo_targ, hi_targ + 1, dtype=floatX)
        # TODO: This and the checks below only make sense for discrete targets.
        #       Continuous targets are currently BROKEN.
        n_target_classes = target_cut.max()
        src_coords_target = np.ascontiguousarray(src_coords_target, dtype=floatX)
        target = np.zeros((n_f_t,) + target_patch_shape, dtype=floatX)
        lo_targ = (lo_targ + off).astype(floatX)
        if target_discrete_ix is None:
            target_discrete_ix = [True for i in range(n_f_t)]
        else:
            target_discrete_ix = [i in target_discrete_ix for i in range(n_f_t)]
        for k, discr in enumerate(target_discrete_ix):
            if discr:
                map_coordinates_nearest(target_cut[k], src_coords_target, lo_targ, target[k])
            else:
                map_coordinates_linear(target_cut[k], src_coords_target, lo_targ, target[k])
        if np.any(target > n_target_classes):
            print(f'warp_slice: Invalid target: max = {target.max()}. Clipping target...')
            target = target.clip(0, n_target_classes)
            # TODO: Or should we just throw an error? (~ WarpingOOB)
    else:
        target = None
    return inp, target
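# A minimal usage sketch for the function above (hypothetical file/key names;
# assumes a (C, D, H, W)-shaped HDF5 dataset and the helpers of this module).
# Note: source coordinates are computed as M_inv @ dest, so a translation of
# -offset in M makes the patch be read starting at +offset in the volume:
#
#   import h5py
#   import numpy as np
#
#   inp_src = h5py.File('/path/to/raw.h5', 'r')['raw']  # shape (C, D, H, W)
#   offset = np.array([8, 32, 32])
#   M = np.eye(4, dtype=np.float64)
#   M[:3, 3] = -offset  # M_inv then translates dest coords by +offset
#   inp, _ = warp_slice(inp_src, patch_shape=(16, 64, 64), M=M)
#   # inp has shape (C, 16, 64, 64); WarpingOOBError is raised if the warped
#   # patch would read outside inp_src.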
def warp_slice(
        inp_src: Union[h5py.Dataset, np.ndarray],
        patch_shape: Union[Tuple[int, ...], np.ndarray],
        M: np.ndarray,
        target_src: Optional[Union[h5py.Dataset, np.ndarray]] = None,
        target_patch_shape: Optional[Union[Tuple[int, ...], np.ndarray]] = None,
        target_discrete_ix: Optional[Sequence[int]] = None,
        debug: bool = True  # TODO: This has some performance impact. Switch this off by default when we're sure everything works.
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Cuts a warped slice out of the input image and out of the target_src image.

    Warping is applied by multiplying the original source coordinates with
    the inverse of the homogeneous (forward) transformation matrix ``M``.

    "Source coordinates" (``src_coords``) signify the coordinates of voxels in
    ``inp_src`` and ``target_src`` that are used to compose their respective
    warped versions. The idea here is that not the images themselves, but the
    coordinates from where they are read are warped. This allows for much
    higher efficiency for large image volumes because we don't have to
    calculate the expensive warping transform for the whole image, but only
    for the voxels that we eventually want to use for the new warped image.
    The transformed coordinates usually don't align to the discrete voxel
    grids of the original images (meaning they are not integers), so the new
    voxel values are obtained by linear interpolation.

    Parameters
    ----------
    inp_src
        Input image source (in HDF5)
    patch_shape
        (spatial only) Patch shape ``(D, H, W)``
        (spatial shape of the neural network's input node)
    M
        Forward warping transformation matrix (4x4).
        Must contain translations in source and target_src array.
    target_src
        Optional target source array to be extracted from in the same way.
    target_patch_shape
        Patch size for the ``target_src`` array.
    target_discrete_ix
        List of target channels that contain discrete values.
        By default (``None``), every channel is seen as discrete (this is
        generally the case for classification tasks).
        This information is used to decide what kind of interpolation should
        be used for reading target data:
        - discrete targets are obtained by nearest-neighbor interpolation
        - non-discrete (continuous) targets are linearly interpolated.
    debug
        If ``True`` (default), enable additional sanity checks to catch
        warping issues early.

    Returns
    -------
    inp
        Warped input image slice
    target
        Warped target_src image slice or ``None``, if ``target_src is None``.
    """
    patch_shape = tuple(patch_shape)
    if len(inp_src.shape) == 3:
        n_f = 1
    elif len(inp_src.shape) == 4:
        n_f = inp_src.shape[0]
    else:
        raise ValueError(f'Can\'t handle inp_src shape {inp_src.shape}')
    # Spatial shapes of input and target data sources
    inp_src_shape = np.array(inp_src.shape[-3:])
    # Guard against target_src being None (it is an optional argument)
    target_src_shape = None if target_src is None else np.array(target_src.shape[-3:])

    M_inv = np.linalg.inv(M.astype(np.float64)).astype(floatX)  # stability...
    dest_corners = make_dest_corners(patch_shape)
    src_corners = np.dot(M_inv, dest_corners.T).T
    if np.any(M[3, :3] != 0):  # homogeneous divide
        src_corners /= src_corners[:, 3][:, None]
    # check corners
    src_corners = src_corners[:, :3]
    # compute/transform dense coords
    dest_coords = make_dest_coords(patch_shape)
    src_coords = np.tensordot(dest_coords, M_inv, axes=[[-1], [1]])
    if np.any(M[3, :3] != 0):  # homogeneous divide
        src_coords /= src_coords[..., 3][..., None]
    # cut patch
    src_coords = src_coords[..., :3]

    if target_src is not None:
        target_patch_shape = tuple(target_patch_shape)
        n_f_t = target_src.shape[0] if target_src.ndim == 4 else 1
        target_src_offset = np.subtract(inp_src_shape, target_src.shape[-3:])
        if np.any(np.mod(target_src_offset, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        target_src_offset //= 2
        target_offset = np.subtract(patch_shape, target_patch_shape)
        if np.any(np.mod(target_offset, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        target_offset //= 2

        src_coords_target = src_coords[
            target_offset[0]:(target_offset[0] + target_patch_shape[0]),
            target_offset[1]:(target_offset[1] + target_patch_shape[1]),
            target_offset[2]:(target_offset[2] + target_patch_shape[2])
        ]
        # shift coords to be w.r.t. the origin of the target_src array
        lo_targ = np.floor(
            src_coords_target.min(2).min(1).min(0) - target_src_offset
        ).astype(np.int64)  # np.int is removed in modern NumPy
        hi_targ = np.ceil(
            src_coords_target.max(2).max(1).max(0) + 1 - target_src_offset
        ).astype(np.int64)
        if np.any(lo_targ < 0) or np.any(hi_targ >= target_src_shape - 1):
            raise WarpingOOBError("Out of bounds for target_src")

    lo = np.min(np.floor(src_corners), 0).astype(np.int64)
    hi = np.max(np.ceil(src_corners + 1), 0).astype(np.int64)
    if np.any(lo < 0) or np.any(hi >= inp_src_shape - 1):
        raise WarpingOOBError("Out of bounds for inp_src")

    # Slice and interpolate input
    # Slice to hi + 1 because interpolation potentially needs this value.
    img_cut = slice_h5(inp_src, lo, hi + 1, dtype=floatX)
    if img_cut.ndim == 3:
        img_cut = img_cut[None]
    inp = np.zeros((n_f,) + patch_shape, dtype=floatX)
    lo = lo.astype(floatX)
    if debug and np.any((src_coords - lo).max(2).max(1).max(0) >= img_cut.shape[-3:]):
        raise WarpingSanityError(
            f'src_coords check failed (too high).\n'
            f'{(src_coords - lo).max(2).max(1).max(0), img_cut.shape[-3:]}'
        )
    if debug and np.any((src_coords - lo).min(2).min(1).min(0) < 0):
        raise WarpingSanityError(
            f'src_coords check failed (negative indices).\n'
            f'{(src_coords - lo).min(2).min(1).min(0)}'
        )
    for k in range(n_f):
        map_coordinates_linear(img_cut[k], src_coords, lo, inp[k])

    # Slice and interpolate target
    if target_src is not None:
        # dtype is float as well here because of the static typing of the
        # numba-compiled map_coordinates functions
        # Slice to hi + 1 because interpolation potentially needs this value.
        target_cut = slice_h5(target_src, lo_targ, hi_targ + 1, dtype=floatX)
        if target_cut.ndim == 3:
            target_cut = target_cut[None]
        src_coords_target = np.ascontiguousarray(src_coords_target, dtype=floatX)
        target = np.zeros((n_f_t,) + target_patch_shape, dtype=floatX)
        lo_targ = (lo_targ + target_src_offset).astype(floatX)
        if target_discrete_ix is None:
            target_discrete_ix = [True for i in range(n_f_t)]
        else:
            target_discrete_ix = [i in target_discrete_ix for i in range(n_f_t)]

        if debug and np.any(
                (src_coords_target - lo_targ).max(2).max(1).max(0) >= target_cut.shape[-3:]):
            raise WarpingSanityError(
                f'src_coords_target check failed (too high).\n'
                f'{(src_coords_target - lo_targ).max(2).max(1).max(0)}\n'
                f'{target_cut.shape[-3:]}'
            )
        if debug and np.any((src_coords_target - lo_targ).min(2).min(1).min(0) < 0):
            raise WarpingSanityError(
                f'src_coords_target check failed (negative indices).\n'
                f'{(src_coords_target - lo_targ).min(2).min(1).min(0)}'
            )
        for k, discr in enumerate(target_discrete_ix):
            if discr:
                map_coordinates_nearest(target_cut[k], src_coords_target, lo_targ, target[k])
                if debug:
                    unique_cut = set(np.unique(target_cut[k]))
                    unique_warp = set(np.unique(target[k]))
                    # If new values appear in discrete targets, there is something wrong.
                    # unique_warp can have fewer values than unique_cut though, for example
                    # if the warping transform coincidentally slices away all values of a class.
                    if not unique_warp.issubset(unique_cut):
                        print(
                            f'Invalid target encountered:\n\nunique_cut=\n{unique_cut}\n'
                            f'unique_warp=\n{unique_warp}\nM_inv=\n{M_inv}\n'
                            f'src_coords_target - lo_targ=\n{src_coords_target - lo_targ}\n'
                        )
                        # Try dropping to an IPython shell (won't work with num_workers > 0).
                        import IPython
                        IPython.embed()
                        raise SystemExit
            else:
                map_coordinates_linear(target_cut[k], src_coords_target, lo_targ, target[k])
    else:
        target = None

    if debug and np.any(np.isnan(inp)):
        raise RuntimeError('Warping is broken: inp contains NaN.')
    # Guard the NaN check: np.isnan(None) would raise a TypeError
    if debug and target is not None and np.any(np.isnan(target)):
        raise RuntimeError('Warping is broken: target contains NaN.')

    return inp, target
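# Usage is analogous to the earlier variant, now with targets and the debug
# sanity checks (hypothetical file/key names; assumes input and target
# datasets of equal spatial extent, stored as (C, D, H, W)):
#
#   f = h5py.File('/path/to/data.h5', 'r')
#   M = np.eye(4, dtype=np.float64)
#   M[:3, 3] = -np.array([8, 32, 32])  # M_inv then reads starting at (8, 32, 32)
#   inp, target = warp_slice(
#       f['raw'], patch_shape=(16, 64, 64), M=M,
#       target_src=f['lab'], target_patch_shape=(16, 64, 64),
#       target_discrete_ix=[0],  # channel 0 holds class labels -> nearest-neighbor
#       debug=True,  # enables the extra checks above; costs some performance
#   )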