# Example #1
# 0
def get_preview_batch(h5data: Tuple[str, str],
                      preview_shape: Optional[Tuple[int, ...]] = None,
                      transform: Callable = transforms.Identity(),
                      in_memory: bool = False) -> torch.Tensor:
    """Load a (centrally cropped) preview input batch from an HDF5 dataset.

    Parameters
    ----------
    h5data: tuple of (str, str)
        ``(fname, key)``: path to the HDF5 file and the dataset key in it.
    preview_shape: tuple of int or None
        Spatial shape of the region that is cut from the center of the input.
        Its length determines the spatial dimensionality (2D or 3D).
        If ``None``, the whole spatial extent of the input is used.
    transform: callable
        Called as ``transform(inp, None)``; only its first return value is used.
    in_memory: bool
        If ``True``, the whole dataset is read into memory and the HDF5 file
        is closed before slicing.

    Returns
    -------
    torch.Tensor
        Preview input with a prepended batch axis (and an added channel axis
        if the sliced data had none).

    Raises
    ------
    ValueError
        If ``preview_shape`` does not fit into the spatial shape of the input.
    """
    fname, key = h5data
    h5file = h5py.File(fname, 'r')
    inp_h5 = h5file[key]
    if in_memory:
        # ``Dataset.value`` was removed in h5py 3.0; ``ds[()]`` reads the
        # complete dataset in all h5py versions.
        inp_h5 = inp_h5[()]
        h5file.close()  # Data now lives in memory, the file is not needed.
    if preview_shape is None:
        # The spatial dimensionality can't be read from preview_shape here,
        # so it is derived from the data: at most 3 trailing spatial axes.
        # NOTE(review): a 2D source with a channel axis (C, H, W) is treated
        # as 3D by this rule — confirm against callers.
        dim = min(inp_h5.ndim, 3)
    else:
        dim = len(preview_shape)  # 2D or 3D
    inp_shape = np.array(inp_h5.shape[-dim:])
    if preview_shape is None:  # Slice everything
        inp_lo = np.zeros_like(inp_shape)
        inp_hi = inp_shape
    else:  # Slice only a preview_shape-sized region from the center of the input
        halfshape = np.array(preview_shape) // 2
        inp_center = inp_shape // 2
        # NOTE(review): the region spans 2 * (preview_shape // 2) voxels per
        # axis, i.e. one voxel less than requested for odd entries.
        inp_lo = inp_center - halfshape
        inp_hi = inp_center + halfshape
        if np.any(inp_center < halfshape):
            raise ValueError(
                'preview_shape is too big for shape of input source. '
                f'Requested {preview_shape}, but can only deliver {tuple(inp_shape)}.'
            )
    memstr = ' (in memory)' if in_memory else ''
    logger.info(f'\nPreview data{memstr}:')
    logger.info(
        f'  input:       {fname}[{key}]: {inp_h5.shape} ({inp_h5.dtype})\n')
    inp_np = slice_h5(inp_h5, inp_lo, inp_hi, prepend_empty_axis=True)
    if inp_np.ndim == dim + 1:  # Should be dim + 2 for (N, C) dims
        inp_np = inp_np[:, None]  # Add missing C dim
    inp_np, _ = transform(inp_np, None)
    inp = torch.from_numpy(inp_np)
    return inp
# Example #2
# 0
    def _create_preview_batch(
            self,
            inp_source: h5py.Dataset,
            target_source: h5py.Dataset,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Cut centered preview patches out of the input and target sources.

        A region of spatial shape ``self.preview_shape`` is sliced from the
        center of each source, passed through ``self.transform`` and
        converted to tensors.

        Parameters
        ----------
        inp_source
            Input data source of shape ``(C, D, H, W)`` or ``(D, H, W)``.
        target_source
            Target data source of shape ``(C, D, H, W)`` or ``(D, H, W)``.
            Its number of dimensions may differ from ``inp_source``.

        Returns
        -------
        inp, target
            Transformed preview input and target tensors.

        Raises
        ------
        ValueError
            If a source is neither 3D nor 4D, or if ``self.preview_shape``
            does not fit into the spatial shape of a source.
        """
        def spatial_shape(source, name: str) -> np.ndarray:
            # (C, D, H, W) and (D, H, W) sources share the same three trailing
            # spatial axes, so each source is inspected independently (input
            # and target may differ in whether they have a channel axis).
            if source.ndim not in (3, 4):
                raise ValueError(
                    f'Can\'t handle {name} source shape {source.shape}: '
                    'expected (C, D, H, W) or (D, H, W).'
                )
            return np.array(source.shape[-3:])

        # Central slicing
        halfshape = np.array(self.preview_shape) // 2
        inp_shape = spatial_shape(inp_source, 'input')
        target_shape = spatial_shape(target_source, 'target')
        # NOTE(review): each region spans 2 * (preview_shape // 2) voxels per
        # axis, i.e. one voxel less than requested for odd entries.
        inp_center = inp_shape // 2
        inp_lo = inp_center - halfshape
        inp_hi = inp_center + halfshape
        target_center = target_shape // 2
        target_lo = target_center - halfshape
        target_hi = target_center + halfshape
        if np.any(inp_center < halfshape):
            raise ValueError(
                'preview_shape is too big for shape of input source. '
                f'Requested {self.preview_shape}, but can only deliver {tuple(inp_shape)}.'
            )
        if np.any(target_center < halfshape):
            raise ValueError(
                'preview_shape is too big for shape of target source. '
                f'Requested {self.preview_shape}, but can only deliver {tuple(target_shape)}.'
            )
        inp_np = slice_h5(inp_source, inp_lo, inp_hi, prepend_empty_axis=True)
        target_np = slice_h5(
            target_source, target_lo, target_hi,
            dtype=self._target_dtype, prepend_empty_axis=True
        )
        inp_np, target_np = self.transform(inp_np, target_np)

        inp = torch.from_numpy(inp_np)
        target = torch.from_numpy(target_np)

        return inp, target
def warp_slice(inp_src,
               patch_shape,
               M,
               target_src=None,
               target_patch_shape=None,
               target_discrete_ix=None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Cuts a warped slice out of the input image and out of the target_src image.
    Warping is applied by multiplying the original source coordinates with
    the inverse of the homogeneous (forward) transformation matrix ``M``.

    "Source coordinates" (``src_coords``) signify the coordinates of voxels in
    ``inp_src`` and ``target_src`` that are used to compose their respective warped
    versions. The idea here is that not the images themselves, but the
    coordinates from where they are read are warped. This allows for much higher
    efficiency for large image volumes because we don't have to calculate the
    expensive warping transform for the whole image, but only for the voxels
    that we eventually want to use for the new warped image.
    The transformed coordinates usually don't align to the discrete
    voxel grids of the original images (meaning they are not integers), so the
    new voxel values are obtained by linear interpolation.

    Parameters
    ----------
    inp_src: h5py.Dataset
        Input image source (in HDF5) of shape ``(C, D, H, W)``.
    patch_shape: tuple or np.ndarray
        (spatial only) Patch shape ``(D, H, W)``
        (spatial shape of the neural network's input node)
    M: np.ndarray
        Forward warping transformation matrix (4x4).
        Must contain translations in source and target_src array.
    target_src: h5py.Dataset or None
        Optional target source array to be extracted from in the same way.
    target_patch_shape: tuple or np.ndarray
        Patch size for the ``target_src`` array.
    target_discrete_ix: list
        List of target channels that contain discrete values.
        By default (``None``), every channel is seen as discrete (this is
        generally the case for classification tasks).
        This information is used to decide what kind of interpolation should
        be used for reading target data:
        - discrete targets are obtained by nearest-neighbor interpolation
        - non-discrete (continuous) targets are linearly interpolated.

    Returns
    -------
    inp: np.ndarray
        Warped input image slice
    target: np.ndarray or None
        Warped target_src image slice
        or ``None``, if ``target_src is None``.

    Raises
    ------
    NotImplementedError
        If ``inp_src`` has no channel axis (is 3D).
    ValueError
        If ``inp_src`` is neither 3D nor 4D, or if the target source / target
        patch are not centered w.r.t. the input.
    WarpingOOBError
        If the warped patch would read outside of ``inp_src`` or ``target_src``.
    """

    patch_shape = tuple(patch_shape)
    if len(inp_src.shape) == 3:
        print(f'inp_src.shape: {inp_src.shape}')
        raise NotImplementedError(
            'elektronn3 has dropped support for data stored in raw 3D form without a channel axis. '
            'Please always supply it with a prepended channel, so it\n'
            'has the form (C, D, H, W) (or in ELEKTRONN2 terms: (f, z, x, y)).'
        )
    elif len(inp_src.shape) == 4:
        n_f = inp_src.shape[0]
        sh = inp_src.shape[1:]  # spatial shape (D, H, W)
    else:
        raise ValueError('inp_src wrong dim/shape')

    M_inv = np.linalg.inv(M.astype(np.float64)).astype(floatX)  # stability...
    # A non-zero projective row in M requires a homogeneous divide.
    homogeneous = np.any(M[3, :3] != 0)
    dest_corners = make_dest_corners(patch_shape)
    src_corners = np.dot(M_inv, dest_corners.T).T
    if homogeneous:
        src_corners /= src_corners[:, 3][:, None]

    # check corners
    src_corners = src_corners[:, :3]
    # ``np.int`` was merely an alias of the builtin ``int`` and has been
    # removed from NumPy (1.24), so the builtin is used directly.
    lo = np.min(np.floor(src_corners), 0).astype(int)
    hi = np.max(np.ceil(src_corners + 1), 0).astype(int)  # add 1 because linear interp
    if np.any(lo < 0) or np.any(hi >= sh):
        raise WarpingOOBError("Out of bounds")
    # compute/transform dense coords
    dest_coords = make_dest_coords(patch_shape)
    src_coords = np.tensordot(dest_coords, M_inv, axes=[[-1], [1]])
    if homogeneous:
        src_coords /= src_coords[..., 3][..., None]
    # cut patch
    src_coords = src_coords[..., :3]
    # Add 1 to hi to include this coordinate!
    img_cut = slice_h5(inp_src, lo, hi + 1, dtype=floatX)

    inp = np.zeros((n_f, ) + patch_shape, dtype=floatX)
    lo = lo.astype(floatX)
    for k in range(n_f):
        map_coordinates_linear(img_cut[k], src_coords, lo, inp[k])

    if target_src is None:
        return inp, None

    target_patch_shape = tuple(target_patch_shape)
    n_f_t = target_src.shape[0]

    # Offset of the target source within the input source (spatial axes).
    off = np.subtract(sh, target_src.shape[-3:])
    if np.any(np.mod(off, 2)):
        raise ValueError("targets must be centered w.r.t. images")
    off //= 2

    # Offset of the target patch within the input patch.
    off_ps = np.subtract(patch_shape, target_patch_shape)
    if np.any(np.mod(off_ps, 2)):
        raise ValueError("targets must be centered w.r.t. images")
    off_ps //= 2

    src_coords_target = src_coords[
        off_ps[0]:off_ps[0] + target_patch_shape[0],
        off_ps[1]:off_ps[1] + target_patch_shape[1],
        off_ps[2]:off_ps[2] + target_patch_shape[2],
    ]
    # shift coords to be w.r.t. to origin of target_src array
    lo_targ = np.floor(src_coords_target.min(2).min(1).min(0) - off).astype(int)
    # add 1 because linear interp
    hi_targ = np.ceil(src_coords_target.max(2).max(1).max(0) - off + 1).astype(int)
    if np.any(lo_targ < 0) or np.any(hi_targ >= target_src.shape[-3:]):
        raise WarpingOOBError("Out of bounds for target_src")
    # dtype is float as well here because of the static typing of the
    # numba-compiled map_coordinates functions
    target_cut = slice_h5(target_src, lo_targ, hi_targ + 1, dtype=floatX)

    # TODO: This and the checks below only make sense for discrete targets. Continuous targets are currently BROKEN.
    n_target_classes = target_cut.max()
    src_coords_target = np.ascontiguousarray(src_coords_target, dtype=floatX)
    target = np.zeros((n_f_t, ) + target_patch_shape, dtype=floatX)
    lo_targ = (lo_targ + off).astype(floatX)
    if target_discrete_ix is None:
        target_discrete_ix = [True for i in range(n_f_t)]
    else:
        target_discrete_ix = [i in target_discrete_ix for i in range(n_f_t)]

    for k, discr in enumerate(target_discrete_ix):
        if discr:
            map_coordinates_nearest(target_cut[k], src_coords_target,
                                    lo_targ, target[k])
        else:
            map_coordinates_linear(target_cut[k], src_coords_target,
                                   lo_targ, target[k])

    if np.any(target > n_target_classes):
        print(
            f'warp_slice: Invalid target: max = {target.max()}. Clipping target...'
        )
        target = target.clip(0, n_target_classes)
        # TODO: Or should we just throw an error? (~ WarpingOOB)
    return inp, target
# Example #4
# 0
def warp_slice(
    inp_src: Union[h5py.Dataset, np.ndarray],
    patch_shape: Union[Tuple[int], np.ndarray],
    M: np.ndarray,
    target_src: Optional[Union[h5py.Dataset, np.ndarray]] = None,
    target_patch_shape: Optional[Union[Tuple[int], np.ndarray]] = None,
    target_discrete_ix: Optional[Sequence[int]] = None,
    debug:
    bool = True  # TODO: This has some performance impact. Switch this off by default when we're sure everything works.
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Cuts a warped slice out of the input image and out of the target_src image.
    Warping is applied by multiplying the original source coordinates with
    the inverse of the homogeneous (forward) transformation matrix ``M``.

    "Source coordinates" (``src_coords``) signify the coordinates of voxels in
    ``inp_src`` and ``target_src`` that are used to compose their respective warped
    versions. The idea here is that not the images themselves, but the
    coordinates from where they are read are warped. This allows for much higher
    efficiency for large image volumes because we don't have to calculate the
    expensive warping transform for the whole image, but only for the voxels
    that we eventually want to use for the new warped image.
    The transformed coordinates usually don't align to the discrete
    voxel grids of the original images (meaning they are not integers), so the
    new voxel values are obtained by linear interpolation.

    Parameters
    ----------
    inp_src
        Input image source (HDF5 dataset or array) of shape
        ``(C, D, H, W)`` or ``(D, H, W)``.
    patch_shape
        (spatial only) Patch shape ``(D, H, W)``
        (spatial shape of the neural network's input node)
    M
        Forward warping transformation matrix (4x4).
        Must contain translations in source and target_src array.
    target_src
        Optional target source array to be extracted from in the same way.
    target_patch_shape
        Patch size for the ``target_src`` array.
    target_discrete_ix
        List of target channels that contain discrete values.
        By default (``None``), every channel is seen as discrete (this is
        generally the case for classification tasks).
        This information is used to decide what kind of interpolation should
        be used for reading target data:
        - discrete targets are obtained by nearest-neighbor interpolation
        - non-discrete (continuous) targets are linearly interpolated.
    debug
        If ``True`` (default), enable additional sanity checks to catch
        warping issues early.

    Returns
    -------
    inp
        Warped input image slice
    target
        Warped target_src image slice
        or ``None``, if ``target_src is None``.

    Raises
    ------
    ValueError
        If ``inp_src`` is neither 3D nor 4D, or if the target source / target
        patch are not centered w.r.t. the input.
    WarpingOOBError
        If the warped patch would read outside of ``inp_src`` or ``target_src``.
    WarpingSanityError
        (only with ``debug=True``) If the computed source coordinates or the
        warped discrete targets are inconsistent.
    RuntimeError
        (only with ``debug=True``) If the warped data contains NaN.
    """

    patch_shape = tuple(patch_shape)
    if len(inp_src.shape) == 3:
        n_f = 1
    elif len(inp_src.shape) == 4:
        n_f = inp_src.shape[0]
    else:
        raise ValueError(f'Can\'t handle inp_src shape {inp_src.shape}')

    # Spatial shape of the input data source
    inp_src_shape = np.array(inp_src.shape[-3:])

    M_inv = np.linalg.inv(M.astype(np.float64)).astype(floatX)  # stability...
    # A non-zero projective row in M requires a homogeneous divide.
    homogeneous = np.any(M[3, :3] != 0)
    dest_corners = make_dest_corners(patch_shape)
    src_corners = np.dot(M_inv, dest_corners.T).T
    if homogeneous:
        src_corners /= src_corners[:, 3][:, None]
    src_corners = src_corners[:, :3]

    # compute/transform dense coords
    dest_coords = make_dest_coords(patch_shape)
    src_coords = np.tensordot(dest_coords, M_inv, axes=[[-1], [1]])
    if homogeneous:
        src_coords /= src_coords[..., 3][..., None]
    src_coords = src_coords[..., :3]

    if target_src is not None:
        # Spatial shape of the target data source. Only read here because
        # ``target_src`` is ``None`` by default.
        target_src_shape = np.array(target_src.shape[-3:])
        target_patch_shape = tuple(target_patch_shape)
        n_f_t = target_src.shape[0] if target_src.ndim == 4 else 1

        target_src_offset = np.subtract(inp_src_shape, target_src_shape)
        if np.any(np.mod(target_src_offset, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        target_src_offset //= 2

        target_offset = np.subtract(patch_shape, target_patch_shape)
        if np.any(np.mod(target_offset, 2)):
            raise ValueError("targets must be centered w.r.t. images")
        target_offset //= 2

        src_coords_target = src_coords[
            target_offset[0]:(target_offset[0] + target_patch_shape[0]),
            target_offset[1]:(target_offset[1] + target_patch_shape[1]),
            target_offset[2]:(target_offset[2] + target_patch_shape[2])]
        # shift coords to be w.r.t. to origin of target_src array.
        # ``np.int`` was removed from NumPy (1.24); it was an alias of ``int``.
        lo_targ = np.floor(
            src_coords_target.min(2).min(1).min(0) - target_src_offset
        ).astype(int)
        hi_targ = np.ceil(
            src_coords_target.max(2).max(1).max(0) + 1 - target_src_offset
        ).astype(int)
        # NOTE(review): this rejects hi_targ == shape - 1, one voxel stricter
        # than slicing to hi_targ + 1 requires — confirm the margin is intended.
        if np.any(lo_targ < 0) or np.any(hi_targ >= target_src_shape - 1):
            raise WarpingOOBError("Out of bounds for target_src")

    lo = np.min(np.floor(src_corners), 0).astype(int)
    hi = np.max(np.ceil(src_corners + 1), 0).astype(int)
    # NOTE(review): same extra one-voxel margin as for the target check above.
    if np.any(lo < 0) or np.any(hi >= inp_src_shape - 1):
        raise WarpingOOBError("Out of bounds for inp_src")

    # Slice and interpolate input
    # Slice to hi + 1 because interpolation potentially needs this value.
    img_cut = slice_h5(inp_src, lo, hi + 1, dtype=floatX)
    if img_cut.ndim == 3:
        img_cut = img_cut[None]  # Add channel axis for 3D sources
    inp = np.zeros((n_f, ) + patch_shape, dtype=floatX)
    lo = lo.astype(floatX)

    if debug:
        rel_coords = src_coords - lo
        if np.any(rel_coords.max(2).max(1).max(0) >= img_cut.shape[-3:]):
            raise WarpingSanityError(
                f'src_coords check failed (too high).\n{rel_coords.max(2).max(1).max(0), img_cut.shape[-3:]}'
            )
        if np.any(rel_coords.min(2).min(1).min(0) < 0):
            raise WarpingSanityError(
                f'src_coords check failed (negative indices).\n{rel_coords.min(2).min(1).min(0)}'
            )

    for k in range(n_f):
        map_coordinates_linear(img_cut[k], src_coords, lo, inp[k])

    # Slice and interpolate target
    if target_src is not None:
        # dtype is float as well here because of the static typing of the
        # numba-compiled map_coordinates functions
        # Slice to hi + 1 because interpolation potentially needs this value.
        target_cut = slice_h5(target_src, lo_targ, hi_targ + 1, dtype=floatX)
        if target_cut.ndim == 3:
            target_cut = target_cut[None]  # Add channel axis for 3D sources
        src_coords_target = np.ascontiguousarray(src_coords_target,
                                                 dtype=floatX)
        target = np.zeros((n_f_t, ) + target_patch_shape, dtype=floatX)
        lo_targ = (lo_targ + target_src_offset).astype(floatX)
        if target_discrete_ix is None:
            target_discrete_ix = [True for i in range(n_f_t)]
        else:
            target_discrete_ix = [
                i in target_discrete_ix for i in range(n_f_t)
            ]

        if debug:
            rel_coords_target = src_coords_target - lo_targ
            if np.any(rel_coords_target.max(2).max(1).max(0) >= target_cut.shape[-3:]):
                raise WarpingSanityError(
                    f'src_coords_target check failed (too high).\n{rel_coords_target.max(2).max(1).max(0)}\n{target_cut.shape[-3:]}'
                )
            if np.any(rel_coords_target.min(2).min(1).min(0) < 0):
                raise WarpingSanityError(
                    f'src_coords_target check failed (negative indices).\n{rel_coords_target.min(2).min(1).min(0)}'
                )

        for k, discr in enumerate(target_discrete_ix):
            if not discr:
                map_coordinates_linear(target_cut[k], src_coords_target,
                                       lo_targ, target[k])
                continue

            map_coordinates_nearest(target_cut[k], src_coords_target,
                                    lo_targ, target[k])
            if debug:
                # If new values appear in discrete targets, there is something wrong.
                # unique_warp can have less values than unique_cut though, for example
                #  if the warping transform coincidentally slices away all values of a class.
                unique_cut = set(np.unique(target_cut[k]))
                unique_warp = set(np.unique(target[k]))
                if not unique_warp.issubset(unique_cut):
                    # Raise like the other debug checks instead of dropping
                    # into an interactive shell and exiting the process.
                    raise WarpingSanityError(
                        f'Invalid target encountered:\n\nunique_cut=\n{unique_cut}\n'
                        f'unique_warp=\n{unique_warp}\nM_inv=\n{M_inv}\n'
                        f'src_coords_target - lo_targ=\n{src_coords_target - lo_targ}\n'
                    )
    else:
        target = None

    if debug and np.any(np.isnan(inp)):
        raise RuntimeError('Warping is broken: inp contains NaN.')
    # ``target`` is None without a target source; np.isnan(None) would raise.
    if debug and target is not None and np.any(np.isnan(target)):
        raise RuntimeError('Warping is broken: target contains NaN.')

    return inp, target