def _broadcast_to(arr, shape):
  """Broadcast ``arr`` to ``shape`` following NumPy-style broadcasting rules.

  If ``arr`` defines its own ``broadcast_to`` method, the call is delegated to
  it and its result is returned. Otherwise ``arr`` must be array-like; it is
  converted with ``_asarray`` when it is not already an ``ndarray`` and then
  expanded with ``lax.broadcast_in_dim``.

  Args:
    arr: the array (or array-like value) to broadcast.
    shape: the target shape. A scalar (0-d) value is treated as a length-1
      shape.

  Returns:
    ``arr`` itself when its shape already equals ``shape``; otherwise an array
    of shape ``shape``.

  Raises:
    ValueError: if ``arr`` has more dimensions than ``shape``, or if some
      dimension of ``arr`` is neither 1 nor equal to the corresponding
      trailing dimension of ``shape``.
  """
  if hasattr(arr, "broadcast_to"):
    return arr.broadcast_to(shape)
  _check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, ndarray) else _asarray(arr)
  if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    shape = (shape,)
  shape = core.canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = np.shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr

  msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
  nlead = len(shape) - len(arr_shape)
  # Broadcasting can only add leading dimensions. Reject rank reduction before
  # pairing arr_shape with the tail of shape: when nlead < 0, shape[nlead:] is
  # always shorter than arr_shape, so safe_zip would reject the length mismatch
  # with its own error instead of the ValueError below.
  if nlead < 0:
    raise ValueError(msg.format(arr_shape, shape))
  shape_tail = shape[nlead:]
  # Every existing dimension must either be 1 (to be expanded) or already
  # match the corresponding target dimension.
  compatible = all(
      core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
      for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
  if not compatible:
    raise ValueError(msg.format(arr_shape, shape))
  # Positions (within arr) of the size-1 dimensions that must be expanded.
  diff, = np.where(
      tuple(not core.symbolic_equal_dim(arr_d, shape_d)
            for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
  # Output dimensions not supplied by the squeezed input: the new leading
  # dimensions plus the expanded size-1 dimensions (shifted by nlead).
  new_dims = tuple(range(nlead)) + tuple(nlead + diff)
  # The remaining output dimensions receive the squeezed input's dimensions,
  # in order.
  kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
  return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
def resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
           antialias: bool = True,
           precision=lax.Precision.HIGHEST):
  """Image resize.

  The ``method`` argument expects one of the following resize methods:

  ``ResizeMethod.NEAREST``, ``"nearest"``
    `Nearest neighbor interpolation`_. The values of ``antialias`` and
    ``precision`` are ignored.

  ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, ``"triangle"``
    `Linear interpolation`_. If ``antialias`` is ``True``, uses a triangular
    filter when downsampling.

  ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
    `Cubic interpolation`_, using the Keys cubic kernel.

  ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
    `Lanczos resampling`_, using a kernel of radius 3.

  ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
    `Lanczos resampling`_, using a kernel of radius 5.

  .. _Nearest neighbor interpolation: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
  .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
  .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
  .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

  Args:
    image: a JAX array.
    shape: the output shape, as a sequence of integers with length equal to
      the number of dimensions of `image`. Note that :func:`resize` does not
      distinguish spatial dimensions from batch or channel dimensions, so this
      includes all dimensions of the image. To represent a batch or a channel
      dimension, simply leave that element of the shape unchanged.
    method: the resizing method to use; either a ``ResizeMethod`` instance or a
      string. Available methods are: NEAREST, LINEAR, LANCZOS3, LANCZOS5,
      CUBIC.
    antialias: should an antialiasing filter be used when downsampling?
      Defaults to ``True``. Has no effect when upsampling. Ignored for
      NEAREST.
    precision: the numerical precision forwarded to the resampling
      computation (e.g. a ``lax.Precision`` value). Defaults to
      ``lax.Precision.HIGHEST``. Ignored for NEAREST.

  Returns:
    The resized image.
  """
  target_shape = core.canonicalize_shape(shape)
  return _resize(image, target_shape, method, antialias, precision)
def scale_and_translate(image, shape: core.Shape,
                        spatial_dims: Sequence[int],
                        scale, translation,
                        method: Union[str, ResizeMethod],
                        antialias: bool = True,
                        precision=lax.Precision.HIGHEST):
  """Apply a scale and translation to an image.

  Generates a new image of shape ``shape`` by resampling from the input image
  using the sampling method corresponding to ``method``. For 2D images, this
  operation transforms a location in the input images, (x, y), to a location
  in the output image according to::

    (x * scale[1] + translation[1], y * scale[0] + translation[0])

  (Note the *inverse* warp is used to generate the sample locations.)
  Assumes half-centered pixels, i.e. the pixel at integer location
  ``row, col`` has coordinates ``y, x = row + 0.5, col + 0.5``. The same holds
  for other input image dimensions.

  If an output location (pixel) maps to an input sample location that is
  outside the input boundaries, then the value for the output location will
  be set to zero.

  The ``method`` argument expects one of the following resize methods:

  ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, ``"triangle"``
    `Linear interpolation`_. If ``antialias`` is ``True``, uses a triangular
    filter when downsampling.

  ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
    `Cubic interpolation`_, using the Keys cubic kernel.

  ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
    `Lanczos resampling`_, using a kernel of radius 3.

  ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
    `Lanczos resampling`_, using a kernel of radius 5.

  .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
  .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
  .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

  Args:
    image: a JAX array (array-likes are converted with ``jnp.asarray``).
      Non-floating dtypes are promoted to a floating dtype.
    shape: the output shape, as a sequence of integers with length equal to
      the number of dimensions of `image`.
    spatial_dims: A length K tuple specifying the spatial dimensions that the
      passed scale and translation should be applied to.
    scale: A [K] array (or array-like), one entry per dimension listed in
      ``spatial_dims``, containing the scale to apply in each of those
      dimensions.
    translation: A [K] array (or array-like), one entry per dimension listed
      in ``spatial_dims``, containing the translation to apply in each of
      those dimensions.
    method: the resizing method to use; either a ``ResizeMethod`` instance or
      a string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC.
    antialias: Should an antialiasing filter be used when downsampling?
      Defaults to ``True``. Has no effect when upsampling.
    precision: the numerical precision forwarded to the resampling
      computation (e.g. a ``lax.Precision`` value). Defaults to
      ``lax.Precision.HIGHEST``.

  Returns:
    The scaled and translated image.

  Raises:
    ValueError: if ``len(shape)`` differs from the number of dimensions of
      ``image``, or if nearest-neighbor resampling is requested.
    TypeError: if ``method`` is neither a string nor a ``ResizeMethod``.
  """

  def _as_inexact(x):
    # Accept array-likes (e.g. a tuple of Python floats for `scale`) as well
    # as arrays, and promote integer/bool dtypes to a floating dtype.
    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
      x = lax.convert_element_type(x, jnp.result_type(x, jnp.float32))
    return x

  shape = core.canonicalize_shape(shape)
  # jnp.ndim / jnp.shape also accept array-likes, unlike `.ndim` / `.shape`.
  if len(shape) != jnp.ndim(image):
    msg = ('shape must have length equal to the number of dimensions of '
           f'image; {shape} vs {jnp.shape(image)}')
    raise ValueError(msg)
  if isinstance(method, str):
    method = ResizeMethod.from_string(method)
  if method == ResizeMethod.NEAREST:
    # Nearest neighbor is currently special-cased for straight resize, so skip
    # for now.
    raise ValueError('Nearest neighbor resampling is not currently supported '
                     'for scale_and_translate.')
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if not isinstance(method, ResizeMethod):
    raise TypeError(
        f'method must be a string or a ResizeMethod, got {method!r}')
  kernel = _kernels[method]

  image = _as_inexact(image)
  scale = _as_inexact(scale)
  translation = _as_inexact(translation)
  return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                              kernel, antialias, precision)