Beispiel #1
0
def get_preview_batch(h5data: Tuple[str, str],
                      preview_shape: Optional[Tuple[int, ...]] = None,
                      transform: Optional[Callable] = None,
                      in_memory: bool = False) -> torch.Tensor:
    """Read a preview region from an HDF5 dataset and return it as a tensor.

    The region is cut from the center of the trailing (spatial) axes of the
    dataset, gets a leading batch axis (and a channel axis if the dataset has
    none), is optionally transformed and finally converted with
    ``torch.from_numpy``.

    Parameters
    ----------
    h5data
        ``(file name, dataset key)`` pair locating the input dataset.
    preview_shape
        Spatial shape of the region to cut from the center of the input.
        Its length determines how many trailing axes are treated as spatial.
        If ``None``, the full spatial extent is used; in that case up to
        three trailing axes are treated as spatial (an assumption, because
        the dimensionality cannot be derived from a missing shape).
        NOTE(review): each axis is sliced as ``center +/- shape // 2``, so
        odd entries yield one element less than requested.
    transform
        Optional callable invoked as ``transform(inp, None)``; the first
        element of its return value replaces the input array.
    in_memory
        If ``True``, the whole dataset is read into memory before slicing.

    Returns
    -------
    torch.Tensor
        The preview batch.

    Raises
    ------
    ValueError
        If ``preview_shape`` exceeds the spatial shape of the dataset.
    """
    fname, key = h5data
    # Keep the file open only as long as data is read from it, so the handle
    # is released even if slicing or validation fails.
    with h5py.File(fname, 'r') as h5file:
        inp_h5 = h5file[key]
        if in_memory:
            # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads the
            # full dataset in all h5py versions.
            inp_h5 = inp_h5[()]
        if preview_shape is None:
            # No shape to infer the dimensionality from: default to (at most)
            # 3 spatial dims. Presumably slice_3d expects 3D coordinates --
            # TODO confirm for datasets with fewer than 3 axes.
            dim = min(3, len(inp_h5.shape))
        else:
            dim = len(preview_shape)  # 2D or 3D
        inp_shape = np.array(inp_h5.shape[-dim:])
        if preview_shape is None:  # Slice everything
            inp_lo = np.zeros_like(inp_shape)
            inp_hi = inp_shape
        else:  # Slice only a preview_shape-sized region from the center of the input
            halfshape = np.array(preview_shape) // 2
            inp_center = inp_shape // 2
            if np.any(inp_center < halfshape):
                raise ValueError(
                    'preview_shape is too big for shape of input source. '
                    f'Requested {preview_shape}, but can only deliver {tuple(inp_shape)}.'
                )
            inp_lo = inp_center - halfshape
            inp_hi = inp_center + halfshape
        memstr = ' (in memory)' if in_memory else ''
        logger.info(f'\nPreview data{memstr}:')
        logger.info(
            f'  input:       {fname}[{key}]: {inp_h5.shape} ({inp_h5.dtype})\n')
        inp_np = slice_3d(inp_h5, inp_lo, inp_hi, prepend_empty_axis=True)
    if inp_np.ndim == dim + 1:  # Should be dim + 2 for (N, C) dims
        inp_np = inp_np[:, None]  # Add missing C dim
    if transform is not None:
        inp_np, _ = transform(inp_np, None)
    inp = torch.from_numpy(inp_np)
    return inp
Beispiel #2
0
def warp_slice(inp_src: DataSource,
               patch_shape: Union[Tuple[int, ...], np.ndarray],
               M: np.ndarray,
               target_src: Optional[DataSource] = None,
               target_patch_shape: Optional[Union[Tuple[int],
                                                  np.ndarray]] = None,
               target_discrete_ix: Optional[Sequence[int]] = None,
               debug: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Cuts a warped slice out of the input image and out of the target_src image.
    Warping is applied by multiplying the original source coordinates with
    the inverse of the homogeneous (forward) transformation matrix ``M``.

    "Source coordinates" (``src_coords``) signify the coordinates of voxels in
    ``inp_src`` and ``target_src`` that are used to compose their respective warped
    versions. The idea here is that not the images themselves, but the
    coordinates from where they are read are warped. This allows for much higher
    efficiency for large image volumes because we don't have to calculate the
    expensive warping transform for the whole image, but only for the voxels
    that we eventually want to use for the new warped image.
    The transformed coordinates usually don't align to the discrete
    voxel grids of the original images (meaning they are not integers), so the
    new voxel values are obtained by linear interpolation.

    Parameters
    ----------
    inp_src
        Input image source (in HDF5). Its shape must have 3 (spatial only)
        or 4 (leading channel axis) entries.
    patch_shape
        (spatial only) Patch shape ``(D, H, W)``
        (spatial shape of the neural network's input node)
    M
        Forward warping transformation matrix (4x4).
        Must contain translations in source and target_src array.
    target_src
        Optional target source array to be extracted from in the same way.
    target_patch_shape
        Patch size for the ``target_src`` array. Required if ``target_src``
        is given.
    target_discrete_ix
        List of target channels that contain discrete values.
        By default (``None``), every channel is seen as discrete (this is
        generally the case for classification tasks).
        This information is used to decide what kind of interpolation should
        be used for reading target data:
        - discrete targets are obtained by nearest-neighbor interpolation
        - non-discrete (continuous) targets are linearly interpolated.
    debug
        If ``True`` (default: ``False``), enable additional sanity checks to
        catch warping issues early.

    Returns
    -------
    inp
        Warped input image slice
    target
        Warped target_src image slice
        or ``None``, if ``target_src is None``.

    Raises
    ------
    ValueError
        If ``inp_src`` has an unsupported number of dimensions, if
        ``target_src`` is given without ``target_patch_shape``, or if the
        target (source or patch) is not centered w.r.t. the input.
    WarpingOOBError
        If the warped coordinates reach outside of ``inp_src`` or
        ``target_src``.
    WarpingSanityError
        Only with ``debug=True``: if the coordinates are inconsistent with
        the region that was cut out of the source.
    RuntimeError
        Only with ``debug=True``: if a result contains NaN.
    SystemExit
        Only with ``debug=True``: after the interactive shell that is opened
        when a discrete target contains values that do not exist in the cut
        source region.
    """

    patch_shape = tuple(patch_shape)
    if len(inp_src.shape) == 3:
        n_f = 1
    elif len(inp_src.shape) == 4:
        n_f = inp_src.shape[0]
    else:
        raise ValueError(f'Can\'t handle inp_src shape {inp_src.shape}')

    # Spatial shapes of input and target data sources
    inp_src_shape = np.array(inp_src.shape[-3:])

    # Invert in float64 and cast afterwards for numerical stability.
    M_inv = np.linalg.inv(M.astype(np.float64)).astype(floatX)
    # A non-zero perspective part of M requires a homogeneous divide after
    # transforming coordinates.
    needs_homogeneous_divide = np.any(M[3, :3] != 0)

    dest_corners = make_dest_corners(patch_shape)
    src_corners = np.dot(M_inv, dest_corners.T).T
    if needs_homogeneous_divide:
        src_corners /= src_corners[:, 3][:, None]

    # Integer bounding box of the warped patch corners in source coordinates.
    # (Builtin ``int`` is what the removed ``np.int`` alias referred to.)
    src_corners = src_corners[:, :3]
    lo = np.min(np.floor(src_corners), 0).astype(int)
    hi = np.max(np.ceil(src_corners + 1), 0).astype(int)
    # compute/transform dense coords
    dest_coords = make_dest_coords(patch_shape)
    src_coords = np.tensordot(dest_coords, M_inv, axes=[[-1], [1]])
    if needs_homogeneous_divide:
        src_coords /= src_coords[..., 3][..., None]
    # Drop the homogeneous component
    src_coords = src_coords[..., :3]

    # TODO: WIP code, integrate this into the warping pipeline with config options
    # Perform elastic deformation on warped coordinates so we don't have
    #  to interpolate twice.
    # For more details, see elektronn3.data.transforms.ElasticTransform
    elastic = False
    if elastic:
        sigma = 4
        alpha = 40
        aniso_factor = 2

        for i in range(3):
            # For each coordinate of dimension i, build a random displacement,
            #  smooth it with sigma and multiply it by alpha
            elastic_displacement = gaussian_filter(
                np.random.rand(*patch_shape) * 2 - 1,
                sigma,
                mode='constant',
                cval=0) * alpha
            # Apply anisotropy correction
            if i == 0 and aniso_factor != 1:
                elastic_displacement /= aniso_factor
            # Apply deformation
            src_coords[..., i] += elastic_displacement
            # Clip out-of-bounds coordinates back to original cube edges to
            #  prevent out-of-bounds reading
            np.clip(src_coords[..., i],
                    lo[i],
                    hi[i] - 1,
                    out=src_coords[..., i])

    if target_src is not None:
        if target_patch_shape is None:
            raise ValueError(
                'target_patch_shape is required if target_src is given.')
        target_src_shape = np.array(target_src.shape[-3:])
        target_patch_shape = tuple(target_patch_shape)
        n_f_t = target_src.shape[0] if len(target_src.shape) == 4 else 1

        # Offset of the target source within the input source (per side)
        target_src_offset = np.subtract(inp_src_shape, target_src_shape)
        if np.any(np.mod(target_src_offset, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        target_src_offset //= 2

        # Offset of the target patch within the input patch (per side)
        target_offset = np.subtract(patch_shape, target_patch_shape)
        if np.any(np.mod(target_offset, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        target_offset //= 2

        src_coords_target = src_coords[
            target_offset[0]:(target_offset[0] + target_patch_shape[0]),
            target_offset[1]:(target_offset[1] + target_patch_shape[1]),
            target_offset[2]:(target_offset[2] + target_patch_shape[2])]
        # shift coords to be w.r.t. to origin of target_src array
        lo_targ = np.floor(
            src_coords_target.min(2).min(1).min(0) - target_src_offset
        ).astype(int)
        hi_targ = np.ceil(
            src_coords_target.max(2).max(1).max(0) + 1 - target_src_offset
        ).astype(int)
        if np.any(lo_targ < 0) or np.any(hi_targ >= target_src_shape - 1):
            raise WarpingOOBError("Out of bounds for target_src")

    if np.any(lo < 0) or np.any(hi >= inp_src_shape - 1):
        raise WarpingOOBError("Out of bounds for inp_src")

    # Slice and interpolate input
    # Slice to hi + 1 because interpolation potentially needs this value.
    img_cut = slice_3d(inp_src, lo, hi + 1, dtype=floatX)
    if img_cut.ndim == 3:
        img_cut = img_cut[None]
    inp = np.zeros((n_f, ) + patch_shape, dtype=floatX)
    lo = lo.astype(floatX)

    if debug and np.any(
        (src_coords - lo).max(2).max(1).max(0) >= img_cut.shape[-3:]):
        raise WarpingSanityError(
            f'src_coords check failed (too high).\n{(src_coords - lo).max(2).max(1).max(0), img_cut.shape[-3:]}'
        )
    if debug and np.any((src_coords - lo).min(2).min(1).min(0) < 0):
        raise WarpingSanityError(
            f'src_coords check failed (negative indices).\n{(src_coords - lo).min(2).min(1).min(0)}'
        )

    for k in range(n_f):
        map_coordinates_linear(img_cut[k], src_coords, lo, inp[k])

    # Slice and interpolate target
    if target_src is not None:
        # dtype is float as well here because of the static typing of the
        # numba-compiled map_coordinates functions
        # Slice to hi + 1 because interpolation potentially needs this value.
        target_cut = slice_3d(target_src, lo_targ, hi_targ + 1, dtype=floatX)
        if target_cut.ndim == 3:
            target_cut = target_cut[None]
        src_coords_target = np.ascontiguousarray(src_coords_target,
                                                 dtype=floatX)
        target = np.zeros((n_f_t, ) + target_patch_shape, dtype=floatX)
        # src_coords_target is expressed in input-source coordinates, so the
        # origin of the cut region has to be shifted back into that frame.
        lo_targ = (lo_targ + target_src_offset).astype(floatX)
        # Per-channel flags: True -> nearest-neighbor, False -> linear
        if target_discrete_ix is None:
            target_discrete_ix = [True] * n_f_t
        else:
            target_discrete_ix = [
                i in target_discrete_ix for i in range(n_f_t)
            ]

        if debug and np.any(
            (src_coords_target -
             lo_targ).max(2).max(1).max(0) >= target_cut.shape[-3:]):
            raise WarpingSanityError(
                f'src_coords_target check failed (too high).\n{(src_coords_target - lo_targ).max(2).max(1).max(0)}\n{target_cut.shape[-3:]}'
            )
        if debug and np.any(
            (src_coords_target - lo_targ).min(2).min(1).min(0) < 0):
            raise WarpingSanityError(
                f'src_coords_target check failed (negative indices).\n{(src_coords_target - lo_targ).min(2).min(1).min(0)}'
            )

        for k, discr in enumerate(target_discrete_ix):
            if discr:
                map_coordinates_nearest(target_cut[k], src_coords_target,
                                        lo_targ, target[k])

                if debug:
                    unique_cut = set(np.unique(target_cut[k]))
                    unique_warp = set(np.unique(target[k]))
                    # If new values appear in discrete targets, there is something wrong.
                    # unique_warp can have less values than unique_cut though, for example
                    #  if the warping transform coincidentally slices away all values of a class.
                    if not unique_warp.issubset(unique_cut):
                        print(
                            f'Invalid target encountered:\n\nunique_cut=\n{unique_cut}\n'
                            f'unique_warp=\n{unique_warp}\nM_inv=\n{M_inv}\n'
                            f'src_coords_target - lo_targ=\n{src_coords_target - lo_targ}\n'
                        )
                        # Try dropping to an IPython shell (Won't work with num_workers > 0).
                        import IPython
                        IPython.embed()
                        raise SystemExit

            else:
                map_coordinates_linear(target_cut[k], src_coords_target,
                                       lo_targ, target[k])

    else:
        target = None

    if debug and np.any(np.isnan(inp)):
        raise RuntimeError('Warping is broken: inp contains NaN.')
    # target is None if no target_src was given; np.isnan(None) would raise.
    if debug and target is not None and np.any(np.isnan(target)):
        raise RuntimeError('Warping is broken: target contains NaN.')

    return inp, target