예제 #1
0
def _broadcast_to(arr, shape):
    """Broadcast ``arr`` to ``shape`` following NumPy-style broadcasting rules.

    Objects that define their own ``broadcast_to`` method are delegated to
    directly. Otherwise ``arr`` is validated as array-like, converted with
    ``_asarray`` if it is not already an ``ndarray``, and broadcast with
    ``lax.broadcast_in_dim``.

    Args:
      arr: an array-like value.
      shape: the target shape. A bare scalar (``np.ndim(shape) == 0``) is
        treated as a 1-tuple.

    Returns:
      ``arr`` itself when its shape already equals ``shape``; otherwise a new
      array of shape ``shape``.

    Raises:
      ValueError: if ``arr`` cannot be broadcast to ``shape``, including the
        case where ``shape`` has fewer dimensions than ``arr``.
    """
    if hasattr(arr, "broadcast_to"):
        return arr.broadcast_to(shape)
    _check_arraylike("broadcast_to", arr)
    arr = arr if isinstance(arr, ndarray) else _asarray(arr)
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
        shape = (shape, )
    shape = core.canonicalize_shape(shape)  # check that shape is concrete
    arr_shape = np.shape(arr)
    if core.symbolic_equal_shape(arr_shape, shape):
        return arr

    msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
    nlead = len(shape) - len(arr_shape)
    # Reject a target rank smaller than the input rank *before* zipping the
    # trailing dimensions: with nlead < 0, ``shape[nlead:]`` can never have
    # ``len(arr_shape)`` elements, so ``safe_zip`` below would presumably fail
    # with its own length error instead of this ValueError.
    if nlead < 0:
        raise ValueError(msg.format(arr_shape, shape))
    shape_tail = shape[nlead:]
    compatible = all(
        core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
        for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
    if not compatible:
        raise ValueError(msg.format(arr_shape, shape))
    # Indices (within the trailing dims) where arr's dim differs from the
    # target; given the compatibility check above, these are the size-1 dims
    # that have to be stretched.
    diff, = np.where(
        tuple(not core.symbolic_equal_dim(arr_d, shape_d)
              for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
    # Output dims created by broadcasting: all leading dims plus the stretched
    # ones (which are squeezed out of arr below and re-created).
    new_dims = tuple(range(nlead)) + tuple(nlead + diff)
    # Output dims that the squeezed input maps onto, in order.
    kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
    return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape,
                                kept_dims)
예제 #2
0
def resize(image,
           shape: core.Shape,
           method: Union[str, ResizeMethod],
           antialias: bool = True,
           precision=lax.Precision.HIGHEST):
    """Resize an image.

    The ``method`` argument expects one of the following resize methods:

    ``ResizeMethod.NEAREST``, ``"nearest"``
      `Nearest neighbor interpolation`_. The values of ``antialias`` and
      ``precision`` are ignored.

    ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, ``"triangle"``
      `Linear interpolation`_. If ``antialias`` is ``True``, uses a triangular
      filter when downsampling.

    ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
      `Cubic interpolation`_, using the Keys cubic kernel.

    ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
      `Lanczos resampling`_, using a kernel of radius 3.

    ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
      `Lanczos resampling`_, using a kernel of radius 5.

    .. _Nearest neighbor interpolation: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
    .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
    .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
    .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

    Args:
      image: a JAX array.
      shape: the output shape, as a sequence of integers with length equal to
        the number of dimensions of `image`. Note that :func:`resize` does not
        distinguish spatial dimensions from batch or channel dimensions, so this
        includes all dimensions of the image. To represent a batch or a channel
        dimension, simply leave that element of the shape unchanged.
      method: the resizing method to use; either a ``ResizeMethod`` instance or a
        string. Available methods are: NEAREST, LINEAR, LANCZOS3, LANCZOS5,
        CUBIC.
      antialias: should an antialiasing filter be used when downsampling? Defaults
        to ``True``. Has no effect when upsampling.
      precision: a ``lax.Precision`` value forwarded to the underlying resize
        implementation. Defaults to ``lax.Precision.HIGHEST``. Ignored for
        ``NEAREST``.

    Returns:
      The resized image.
    """
    # Canonicalize ``shape`` here, then delegate all method dispatch and
    # validation to ``_resize``.
    return _resize(image, core.canonicalize_shape(shape), method, antialias,
                   precision)
예제 #3
0
def scale_and_translate(image,
                        shape: core.Shape,
                        spatial_dims: Sequence[int],
                        scale,
                        translation,
                        method: Union[str, ResizeMethod],
                        antialias: bool = True,
                        precision=lax.Precision.HIGHEST):
    """Apply a scale and translation to an image.

    Generates a new image of shape ``shape`` by resampling from the input image
    using the sampling method corresponding to ``method``. For 2D images, this
    operation transforms a location in the input images, (x, y), to a location
    in the output image according to::

      (x * scale[1] + translation[1], y * scale[0] + translation[0])

    (Note the *inverse* warp is used to generate the sample locations.)
    Assumes half-centered pixels, i.e. the pixel at integer location
    ``row, col`` has coordinates ``y, x = row + 0.5, col + 0.5``.
    Similarly for other input image dimensions.

    If an output location (pixel) maps to an input sample location that is
    outside the input boundaries then the value for the output location will
    be set to zero.

    The ``method`` argument expects one of the following resize methods:

    ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, ``"triangle"``
      `Linear interpolation`_. If ``antialias`` is ``True``, uses a
      triangular filter when downsampling.

    ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
      `Cubic interpolation`_, using the Keys cubic kernel.

    ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
      `Lanczos resampling`_, using a kernel of radius 3.

    ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
      `Lanczos resampling`_, using a kernel of radius 5.

    ``ResizeMethod.NEAREST`` is not supported by this function.

    .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
    .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
    .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

    Args:
      image: a JAX array.
      shape: the output shape, as a sequence of integers with length equal to
        the number of dimensions of `image`.
      spatial_dims: A length K tuple specifying the spatial dimensions that the
        passed scale and translation should be applied to.
      scale: A [K] array (one entry per element of ``spatial_dims``) containing
        the scale to apply in each of those dimensions. Python sequences and
        NumPy arrays are also accepted.
      translation: A [K] array (one entry per element of ``spatial_dims``)
        containing the translation to apply in each of those dimensions.
        Python sequences and NumPy arrays are also accepted.
      method: the resizing method to use; either a ``ResizeMethod`` instance or a
        string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC.
      antialias: Should an antialiasing filter be used when downsampling? Defaults
        to ``True``. Has no effect when upsampling.
      precision: a ``lax.Precision`` value forwarded to the internal
        ``_scale_and_translate`` implementation. Defaults to
        ``lax.Precision.HIGHEST``.

    Returns:
      The scaled and translated image.

    Raises:
      ValueError: if ``len(shape) != image.ndim``, or if ``method`` is
        ``ResizeMethod.NEAREST``.
      TypeError: if ``method`` is neither a string nor a ``ResizeMethod``.
    """
    shape = core.canonicalize_shape(shape)
    if len(shape) != image.ndim:
        msg = (
            'shape must have length equal to the number of dimensions of image; '
            f'{shape} vs {image.shape}')
        raise ValueError(msg)
    if isinstance(method, str):
        method = ResizeMethod.from_string(method)
    if method == ResizeMethod.NEAREST:
        # Nearest neighbor is currently special-cased for straight resize, so skip
        # for now.
        raise ValueError(
            'Nearest neighbor resampling is not currently supported '
            'for scale_and_translate.')
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(method, ResizeMethod):
        raise TypeError(
            f'method must be a string or a ResizeMethod, got {method!r}')

    kernel = _kernels[method]

    def _as_inexact(x):
        # Promote non-floating (integer/bool) inputs via
        # ``jnp.result_type(x, jnp.float32)`` so resampling happens in a
        # floating-point dtype; inexact inputs are returned unchanged.
        if not jnp.issubdtype(x.dtype, jnp.inexact):
            x = lax.convert_element_type(x, jnp.result_type(x, jnp.float32))
        return x

    image = _as_inexact(image)
    # ``jnp.asarray`` lets callers pass plain Python sequences (which have no
    # ``.dtype``) for scale and translation.
    scale = _as_inexact(jnp.asarray(scale))
    translation = _as_inexact(jnp.asarray(translation))
    return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                                kernel, antialias, precision)