Пример #1
0
def gate_dtypes(dtypes, allowed, disallowed, augmenter=None):
    """Validate dtypes against lists of allowed and disallowed dtype names.

    Dtypes whose ``name`` appears in `allowed` are accepted silently.
    Dtypes whose ``name`` appears in `disallowed` cause a ``ValueError``.
    Every other dtype triggers a warning through ``ia.warn``.

    Parameters
    ----------
    dtypes : object
        Dtype input that is passed through ``normalize_dtypes()``; each
        resulting item must expose a ``name`` attribute. (The accepted
        input forms are defined by ``normalize_dtypes()``, which is not
        visible here.)

    allowed : list of str
        Names of accepted dtypes. Must contain at least one entry. Only
        the first entry is checked to be a string.

    disallowed : list of str
        Names of rejected dtypes. Must not share entries with `allowed`.
        Only the first entry is checked to be a string.

    augmenter : None or object, optional
        If provided, its ``name`` attribute and class name are included
        in the error and warning messages.

    Raises
    ------
    AssertionError
        If `allowed` is empty, if the first entry of `allowed` or
        `disallowed` is not a string, or if the two lists overlap.

    ValueError
        If the name of one of the dtypes is listed in `disallowed`.

    """
    assert len(allowed) > 0, (
        "Expected at least one dtype to be allowed, but got an empty list.")

    # Only the leading entries are type-checked; checking every entry
    # would add cost to each call.
    head_allowed = allowed[0]
    assert ia.is_string(head_allowed), (
        "Expected only strings as dtypes, but got type %s." %
        (type(head_allowed), ))

    if len(disallowed) != 0:
        head_disallowed = disallowed[0]
        assert ia.is_string(head_disallowed), (
            "Expected only strings as dtypes, but got type %s." %
            (type(head_disallowed), ))

    # A dtype name must not be both allowed and disallowed.
    overlap = set(allowed) & set(disallowed)
    assert not overlap, (
        "Expected 'allowed' and 'disallowed' to not contain the same dtypes, "
        "but %d appeared in both arguments. Got allowed: %s, "
        "disallowed: %s, intersection: %s" %
        (len(overlap), ", ".join(allowed), ", ".join(disallowed),
         ", ".join(overlap)))

    for current in normalize_dtypes(dtypes):
        current_name = current.name

        if current_name in allowed:
            continue

        if current_name in disallowed:
            if augmenter is not None:
                raise ValueError(
                    "Got dtype '%s' in augmenter '%s' (class '%s'), which "
                    "is a forbidden dtype (%s)." %
                    (current_name, augmenter.name,
                     augmenter.__class__.__name__, ", ".join(disallowed)))
            raise ValueError(
                "Got dtype '%s', which is a forbidden dtype (%s)." %
                (current_name, ", ".join(disallowed)))

        # Neither allowed nor disallowed: accept, but warn.
        if augmenter is not None:
            ia.warn(
                "Got dtype '%s' in augmenter '%s' (class '%s'), which was "
                "neither explicitly allowed (%s), nor explicitly "
                "disallowed (%s). Generated outputs may contain "
                "errors." %
                (current_name, augmenter.name, augmenter.__class__.__name__,
                 ", ".join(allowed), ", ".join(disallowed)))
        else:
            ia.warn(
                "Got dtype '%s', which was neither explicitly allowed "
                "(%s), nor explicitly disallowed (%s). Generated "
                "outputs may contain errors." %
                (current_name, ", ".join(allowed), ", ".join(disallowed)))
Пример #2
0
def _warn_on_suspicious_multi_image_shapes(images):
    """Warn when a multi-image input array looks like a single image.

    A 3-dimensional numpy array given as multi-image input is interpreted
    as ``(N, H, W)``. If its last axis has size 1 or 3, the array more
    plausibly is one ``(H, W, C)`` image. In that case a
    ``SuspiciousMultiImageShapeWarning`` is emitted via ``ia.warn``.
    ``None`` and inputs for which ``ia.is_np_array`` is false are ignored.

    Parameters
    ----------
    images : None or numpy.ndarray or object
        The multi-image augmentation input.

    """
    if images is None or not ia.is_np_array(images):
        return

    # Short-circuit keeps shape[-1] from being read on arrays with ndim != 3.
    is_three_dimensional = (images.ndim == 3)
    if not (is_three_dimensional and images.shape[-1] in (1, 3)):
        return

    ia.warn("You provided a numpy array of shape %s as a "
            "multi-image augmentation input, which was interpreted as "
            "(N, H, W). The last dimension however has value 1 or "
            "3, which indicates that you provided a single image "
            "with shape (H, W, C) instead. If that is the case, "
            "you should use e.g. augmenter(image=<your input>) or "
            "augment_image(<your input>) -- note the singular 'image' "
            "instead of 'imageS'. Otherwise your single input image "
            "will be interpreted as multiple images of shape (H, W) "
            "during augmentation." % (images.shape, ),
            category=SuspiciousMultiImageShapeWarning)
Пример #3
0
def _warn_on_suspicious_single_image_shape(image):
    """Warn when a single-image input array looks like a batch of images.

    A single image is interpreted as ``(H, W, C)``. A 3-dimensional array
    whose last axis has size 32 or more more plausibly is a
    ``(N, H, W)`` batch. In that case a
    ``SuspiciousSingleImageShapeWarning`` is emitted via ``ia.warn``.

    Parameters
    ----------
    image : None or numpy.ndarray
        The single-image augmentation input. ``None`` is ignored.

    """
    if image is None:
        return

    # Arrays of shape (1, 1, C) are exempt on purpose: they are used
    # heavily in unit tests.
    if (image.ndim != 3
            or image.shape[-1] < 32
            or image.shape[0:2] == (1, 1)):
        return

    ia.warn("You provided a numpy array of shape %s as a "
            "single-image augmentation input, which was interpreted as "
            "(H, W, C). The last dimension however has a size of >=32, "
            "which indicates that you provided a multi-image array "
            "with shape (N, H, W) instead. If that is the case, "
            "you should use e.g. augmenter(imageS=<your input>) or "
            "augment_imageS(<your input>). Otherwise your multi-image "
            "input will be interpreted as a single image during "
            "augmentation." % (image.shape, ),
            category=SuspiciousSingleImageShapeWarning)
Пример #4
0
def blur_gaussian_(image, sigma, ksize=None, backend="auto", eps=1e-3):
    """Blur an image using gaussian blurring in-place.

    This operation *may* change the input image in-place.

    dtype support::

        if (backend="auto")::

            * ``uint8``: yes; fully tested (1)
            * ``uint16``: yes; tested (1)
            * ``uint32``: yes; tested (2)
            * ``uint64``: yes; tested (2)
            * ``int8``: yes; tested (1)
            * ``int16``: yes; tested (1)
            * ``int32``: yes; tested (1)
            * ``int64``: yes; tested (2)
            * ``float16``: yes; tested (1)
            * ``float32``: yes; tested (1)
            * ``float64``: yes; tested (1)
            * ``float128``: no
            * ``bool``: yes; tested (1)

            - (1) Handled by ``cv2``. See ``backend="cv2"``.
            - (2) Handled by ``scipy``. See ``backend="scipy"``.

        if (backend="cv2")::

            * ``uint8``: yes; fully tested
            * ``uint16``: yes; tested
            * ``uint32``: no (2)
            * ``uint64``: no (3)
            * ``int8``: yes; tested (4)
            * ``int16``: yes; tested
            * ``int32``: yes; tested (5)
            * ``int64``: no (6)
            * ``float16``: yes; tested (7)
            * ``float32``: yes; tested
            * ``float64``: yes; tested
            * ``float128``: no (8)
            * ``bool``: yes; tested (1)

            - (1) Mapped internally to ``float32``. Otherwise causes
                  ``TypeError: src data type = 0 is not supported``.
            - (2) Causes ``TypeError: src data type = 6 is not supported``.
            - (3) Causes ``cv2.error: OpenCV(3.4.5) (...)/filter.cpp:2957:
                  error: (-213:The function/feature is not implemented)
                  Unsupported combination of source format (=4), and buffer
                  format (=5) in function 'getLinearRowFilter'``.
            - (4) Mapped internally to ``int16``. Otherwise causes
                  ``cv2.error: OpenCV(3.4.5) (...)/filter.cpp:2957: error:
                  (-213:The function/feature is not implemented) Unsupported
                  combination of source format (=1), and buffer format (=5)
                  in function 'getLinearRowFilter'``.
            - (5) Mapped internally to ``float64``. Otherwise causes
                  ``cv2.error: OpenCV(3.4.5) (...)/filter.cpp:2957: error:
                  (-213:The function/feature is not implemented) Unsupported
                  combination of source format (=4), and buffer format (=5)
                  in function 'getLinearRowFilter'``.
            - (6) Causes ``cv2.error: OpenCV(3.4.5) (...)/filter.cpp:2957:
                  error: (-213:The function/feature is not implemented)
                  Unsupported combination of source format (=4), and buffer
                  format (=5) in function 'getLinearRowFilter'``.
            - (7) Mapped internally to ``float32``. Otherwise causes
                  ``TypeError: src data type = 23 is not supported``.
            - (8) Causes ``TypeError: src data type = 13 is not supported``.

        if (backend="scipy")::

            * ``uint8``: yes; fully tested
            * ``uint16``: yes; tested
            * ``uint32``: yes; tested
            * ``uint64``: yes; tested
            * ``int8``: yes; tested
            * ``int16``: yes; tested
            * ``int32``: yes; tested
            * ``int64``: yes; tested
            * ``float16``: yes; tested (1)
            * ``float32``: yes; tested
            * ``float64``: yes; tested
            * ``float128``: no (2)
            * ``bool``: yes; tested (3)

            - (1) Mapped internally to ``float32``. Otherwise causes
                  ``RuntimeError: array type dtype('float16') not supported``.
            - (2) Causes ``RuntimeError: array type dtype('float128') not
                  supported``.
            - (3) Mapped internally to ``float32``. Otherwise too inaccurate.

    Parameters
    ----------
    image : numpy.ndarray
        The image to blur. Expected to be of shape ``(H, W)`` or ``(H, W, C)``.

    sigma : number
        Standard deviation of the gaussian blur. Larger numbers result in
        more large-scale blurring, which is overall slower than small-scale
        blurring. If ``sigma <= eps`` the image is returned unchanged.

    ksize : None or int, optional
        Size in height/width of the gaussian kernel. This argument is only
        understood by the ``cv2`` backend. If it is set to ``None``, an
        appropriate value for `ksize` will automatically be derived from
        `sigma`. The value is chosen tighter for larger sigmas to avoid as
        much as possible very large kernel sizes and thereby improve
        performance. Even values are incremented by one (presumably
        because cv2 expects odd kernel sizes). If the final value is not
        positive, the ``cv2`` backend performs no blurring.

    backend : {'auto', 'cv2', 'scipy'}, optional
        Backend library to use. If ``auto``, then the likely best library
        will be automatically picked per image. That is usually equivalent
        to ``cv2`` (OpenCV) and it will fall back to ``scipy`` for datatypes
        not supported by OpenCV. Any other value raises an
        ``AssertionError``.

    eps : number, optional
        A threshold used to decide whether `sigma` can be considered zero.

    Returns
    -------
    numpy.ndarray
        The blurred image. Same shape and dtype as the input.
        (Input image *might* have been altered in-place.)

    Raises
    ------
    AssertionError
        If `backend` is not one of ``auto``, ``cv2``, ``scipy``; if
        ``backend="cv2"`` is combined with a dtype that cv2 cannot handle;
        or if `ksize` is provided but is not an integer.

    """
    has_zero_sized_axes = (image.size == 0)
    if sigma > 0 + eps and not has_zero_sized_axes:
        dtype = image.dtype

        # Reject unknown backends up front. Without this check, any
        # unrecognized value would silently fall through to the cv2 code
        # path below and bypass the cv2 dtype check.
        assert backend in ["auto", "cv2", "scipy"], (
            "Expected 'backend' to be one of 'auto', 'cv2' or 'scipy', "
            "got %s." % (backend, ))

        iadt.gate_dtypes(image,
                         allowed=[
                             "bool", "uint8", "uint16", "uint32", "int8",
                             "int16", "int32", "int64", "uint64", "float16",
                             "float32", "float64"
                         ],
                         disallowed=[
                             "uint128", "uint256", "int128", "int256",
                             "float96", "float128", "float256"
                         ],
                         augmenter=None)

        dts_not_supported_by_cv2 = ["uint32", "uint64", "int64", "float128"]
        backend_to_use = backend
        if backend == "auto":
            backend_to_use = ("cv2" if image.dtype.name
                              not in dts_not_supported_by_cv2 else "scipy")
        elif backend == "cv2":
            assert image.dtype.name not in dts_not_supported_by_cv2, (
                "Requested 'cv2' backend, but provided %s input image, which "
                "cannot be handled by that backend. Choose a different "
                "backend or set backend to 'auto' or use a different "
                "datatype." % (image.dtype.name, ))
        elif backend == "scipy":
            # can handle all dtypes that were allowed in gate_dtypes()
            pass

        if backend_to_use == "scipy":
            if dtype.name == "bool":
                # We convert bool to float32 here, because gaussian_filter()
                # seems to only return True when the underlying value is
                # approximately 1.0, not when it is above 0.5. So we do that
                # here manually. cv2 does not support bool for gaussian blur.
                image = image.astype(np.float32, copy=False)
            elif dtype.name == "float16":
                image = image.astype(np.float32, copy=False)

            # gaussian_filter() has no ksize argument
            # TODO it does have a truncate argument that truncates at x
            #      standard deviations -- maybe can be used similarly to ksize
            if ksize is not None:
                ia.warn(
                    "Requested 'scipy' backend or picked it automatically by "
                    "backend='auto' in blur_gaussian_(), but also provided "
                    "'ksize' argument, which is not understood by that "
                    "backend and will be ignored.")

            # Note that while gaussian_filter can be applied to all channels
            # at the same time, that should not be done here, because then
            # the blurring would also happen across channels (e.g. red values
            # might be mixed with blue values in RGB)
            if image.ndim == 2:
                image[:, :] = ndimage.gaussian_filter(image[:, :],
                                                      sigma,
                                                      mode="mirror")
            else:
                nb_channels = image.shape[2]
                for channel in sm.xrange(nb_channels):
                    image[:, :, channel] = ndimage.gaussian_filter(
                        image[:, :, channel], sigma, mode="mirror")
        else:
            # Map dtypes that cv2 cannot blur directly to ones it can
            # (see the dtype support table in the docstring).
            if dtype.name == "bool":
                image = image.astype(np.float32, copy=False)
            elif dtype.name == "float16":
                image = image.astype(np.float32, copy=False)
            elif dtype.name == "int8":
                image = image.astype(np.int16, copy=False)
            elif dtype.name == "int32":
                image = image.astype(np.float64, copy=False)

            # ksize here is derived from the equation to compute sigma based
            # on ksize, see
            # https://docs.opencv.org/3.1.0/d4/d86/group__imgproc__filter.html
            # -> cv::getGaussianKernel()
            # example values:
            #   sig = 0.1 -> ksize = -1.666
            #   sig = 0.5 -> ksize = 0.9999
            #   sig = 1.0 -> ksize = 1.0
            #   sig = 2.0 -> ksize = 11.0
            #   sig = 3.0 -> ksize = 17.666
            # ksize = ((sig - 0.8)/0.3 + 1)/0.5 + 1

            if ksize is None:
                ksize = _compute_gaussian_blur_ksize(sigma)
            else:
                assert ia.is_single_integer(ksize), (
                    "Expected 'ksize' argument to be an integer, "
                    "got %s." % (type(ksize), ))

            # Round even kernel sizes up to the next odd value.
            ksize = ksize + 1 if ksize % 2 == 0 else ksize

            if ksize > 0:
                image_warped = cv2.GaussianBlur(
                    _normalize_cv2_input_arr_(image), (ksize, ksize),
                    sigmaX=sigma,
                    sigmaY=sigma,
                    borderType=cv2.BORDER_REFLECT_101)

                # re-add channel axis removed by cv2 if input was (H, W, 1)
                image = (image_warped[..., np.newaxis] if image.ndim == 3
                         and image_warped.ndim == 2 else image_warped)

        # Convert back to the input dtype. Bool is recovered by
        # thresholding the float result at 0.5.
        if dtype.name == "bool":
            image = image > 0.5
        elif dtype.name != image.dtype.name:
            image = iadt.restore_dtypes_(image, dtype)

    return image
Пример #5
0
def _gate_dtypes(dtypes, allowed, disallowed, augmenter=None):
    """Verify that input dtypes are among allowed and not disallowed dtypes.

    Added in 0.5.0.

    Dtypes contained in `allowed` pass silently. Dtypes contained in
    `disallowed` (and not in `allowed`) raise a ``ValueError``. All other
    dtypes produce a warning via ``ia.warn``.

    Parameters
    ----------
    dtypes : numpy.ndarray or numpy.dtype or list or tuple or set
        One or more input dtypes to verify. Accepted forms:

        * a single ``numpy.ndarray`` or numpy scalar (its ``.dtype`` is
          used),
        * a single ``numpy.dtype``,
        * a list or tuple whose elements are ``numpy.dtype`` instances or
          objects with a ``.dtype`` attribute (e.g. arrays),
        * a set of ``numpy.dtype`` instances (used as-is).

        Must not be a dtype function (like ``np.int64``), only a proper
        dtype (like ``np.dtype("int64")``). For performance reasons this is
        not validated.

    allowed : set of numpy.dtype
        One or more allowed dtypes.

    disallowed : None or set of numpy.dtype
        Any number of disallowed dtypes. Should not intersect with allowed
        dtypes.

    augmenter : None or imgaug.augmenters.meta.Augmenter, optional
        If the gating happens for an augmenter, it should be provided
        here. This information will be used to improve output error
        messages and warnings.

    Raises
    ------
    ValueError
        If any input dtype is explicitly disallowed. When several are,
        the one with the alphabetically first name is reported.

    """
    # Normalize the input to a set of dtypes. Lists, tuples and single
    # dtypes are checked before the numpy-scalar test because they can
    # never be numpy scalars.
    if isinstance(dtypes, np.ndarray):
        dtypes = {dtypes.dtype}
    elif isinstance(dtypes, (list, tuple)):
        # Elements may already be dtypes (which have no .dtype attribute)
        # or arrays/scalars carrying one.
        dtypes = {
            elem if isinstance(elem, np.dtype) else elem.dtype
            for elem in dtypes
        }
    elif isinstance(dtypes, np.dtype):
        dtypes = {dtypes}
    elif ia.is_np_scalar(dtypes):
        dtypes = {dtypes.dtype}
    # Anything else (e.g. a set of dtypes) is used as-is.

    dts_not_explicitly_allowed = dtypes - allowed
    if not dts_not_explicitly_allowed:
        return

    if disallowed is None:
        disallowed = set()

    dts_explicitly_disallowed = dts_not_explicitly_allowed.intersection(
        disallowed
    )
    dts_undefined = dts_not_explicitly_allowed - disallowed

    def _name_of(dt):
        return np.dtype(dt).name

    # Sort by name so that the reported dtype and the warning order are
    # deterministic instead of depending on set iteration order.
    for dtype in sorted(dts_explicitly_disallowed, key=_name_of):
        if augmenter is None:
            raise ValueError(
                "Got dtype '%s', which is a forbidden dtype (%s)." % (
                    _name_of(dtype),
                    _dtype_names_to_string(disallowed)
                ))

        raise ValueError(
            "Got dtype '%s' in augmenter '%s' (class '%s'), which "
            "is a forbidden dtype (%s)." % (
                _name_of(dtype),
                augmenter.name,
                augmenter.__class__.__name__,
                _dtype_names_to_string(disallowed),
            ))

    for dtype in sorted(dts_undefined, key=_name_of):
        if augmenter is None:
            ia.warn(
                "Got dtype '%s', which was neither explicitly allowed "
                "(%s), nor explicitly disallowed (%s). Generated "
                "outputs may contain errors." % (
                    _name_of(dtype),
                    _dtype_names_to_string(allowed),
                    _dtype_names_to_string(disallowed),
                )
            )
        else:
            ia.warn(
                "Got dtype '%s' in augmenter '%s' (class '%s'), which was "
                "neither explicitly allowed (%s), nor explicitly "
                "disallowed (%s). Generated outputs may contain "
                "errors." % (
                    _name_of(dtype),
                    augmenter.name,
                    augmenter.__class__.__name__,
                    _dtype_names_to_string(allowed),
                    _dtype_names_to_string(disallowed),
                )
            )