def odd_ext(x, n, axis=-1):
  """Extend `x` along `axis` by odd (antisymmetric) reflection at both ends.

  This helper was once part of "scipy.signal.signaltools" but is no longer
  exposed there.

  Args:
    x : input array
    n : the number of points to be added to the both end
    axis: the axis to be extended
  """
  if n < 1:
    # Nothing to extend; hand back the input untouched.
    return x
  if n > x.shape[axis] - 1:
    raise ValueError(
        f"The extension length n ({n}) is too big. "
        f"It must not exceed x.shape[axis]-1, which is {x.shape[axis] - 1}.")
  # Odd extension reflects each end about its boundary sample: 2*edge - mirror.
  first = lax.slice_in_dim(x, 0, 1, axis=axis)
  last = lax.slice_in_dim(x, -1, None, axis=axis)
  head_mirror = jnp.flip(lax.slice_in_dim(x, 1, n + 1, axis=axis), axis=axis)
  tail_mirror = jnp.flip(lax.slice_in_dim(x, -(n + 1), -1, axis=axis), axis=axis)
  pieces = (2 * first - head_mirror, x, 2 * last - tail_mirror)
  return jnp.concatenate(pieces, axis=axis)
def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0, precision=None):
  """Cross-correlate two 2-D arrays via the shared n-d convolution core.

  Only the zero-filled boundary is implemented.  Correlation is expressed as
  a convolution of a flipped first argument with the conjugated second
  argument; the extra flips/conjugations per branch keep the output oriented
  the way scipy orients it for each mode and operand-size ordering.
  """
  if boundary != 'fill' or fillvalue != 0:
    raise NotImplementedError(
        "correlate2d() only supports boundary='fill', fillvalue=0")
  if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
    raise ValueError("correlate2d() only supports 2-dimensional inputs.")

  # `in2` is at least as large as `in1` along every axis.
  fits_inside = all(a <= b for a, b in zip(in1.shape, in2.shape))
  identical_shape = in1.shape == in2.shape

  if mode == "same":
    return jnp.flip(
        _convolve_nd(jnp.flip(in1), in2.conj(), mode, precision=precision))
  if mode == "valid":
    if fits_inside and not identical_shape:
      return _convolve_nd(jnp.flip(in2), in1.conj(), mode, precision=precision)
    return jnp.flip(
        _convolve_nd(jnp.flip(in1), in2.conj(), mode, precision=precision))
  # "full" (any other mode string is rejected inside _convolve_nd).
  if fits_inside:
    return _convolve_nd(
        jnp.flip(in2), in1.conj(), mode, precision=precision).conj()
  return jnp.flip(
      _convolve_nd(jnp.flip(in1), in2.conj(), mode, precision=precision))
def _convolve_nd(in1, in2, mode, *, precision):
  """Shared n-dimensional convolution core.

  One operand must dominate the other (be at least as large along every
  dimension); the smaller one becomes the kernel.  The work is delegated to
  `lax.conv_general_dilated` over singleton batch and feature dimensions.
  """
  if mode not in ["full", "same", "valid"]:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")
  if in1.ndim != in2.ndim:
    raise ValueError("in1 and in2 must have the same number of dimensions")
  if in1.size == 0 or in2.size == 0:
    raise ValueError(f"zero-size arrays not supported in convolutions, got shapes {in1.shape} and {in2.shape}.")
  in1, in2 = _promote_dtypes_inexact(in1, in2)

  first_dominates = all(a >= b for a, b in zip(in1.shape, in2.shape))
  second_dominates = all(a <= b for a, b in zip(in1.shape, in2.shape))
  if not (first_dominates or second_dominates):
    raise ValueError("One input must be smaller than the other in every dimension.")

  # 'same' centers the output on in2's ORIGINAL extent, so record it
  # before any swap.
  pre_swap_kernel_shape = in2.shape
  if second_dominates:
    in1, in2 = in2, in1
  kernel_shape = in2.shape
  # conv_general_dilated computes cross-correlation; pre-flipping the kernel
  # turns it into a true convolution.
  in2 = jnp.flip(in2)

  if mode == 'valid':
    padding = [(0, 0)] * len(kernel_shape)
  elif mode == 'same':
    padding = [(k - 1 - (k_o - 1) // 2, k - k_o + (k_o - 1) // 2)
               for k, k_o in zip(kernel_shape, pre_swap_kernel_shape)]
  else:  # 'full'
    padding = [(k - 1, k - 1) for k in kernel_shape]

  strides = (1,) * len(kernel_shape)
  out = lax.conv_general_dilated(in1[None, None], in2[None, None],
                                 strides, padding, precision=precision)
  return out[0, 0]
def correlate(in1, in2, mode='full', method='auto', precision=None):
  """N-dimensional cross-correlation of `in1` with `in2`.

  Implemented as a convolution with the reversed, conjugated second operand,
  so complex inputs are supported (this earlier definition used to raise
  NotImplementedError for them, diverging from the later definition of
  `correlate` in this file; real-input behavior is unchanged since conj is
  a no-op on real arrays).

  Args:
    in1: first input array.
    in2: second input array; must have the same rank as `in1`.
    mode: 'full', 'same' or 'valid' (validated by `_convolve_nd`).
    method: accepted for scipy API compatibility; anything other than
      'auto' is ignored with a warning.
    precision: XLA precision forwarded to the underlying convolution.
  """
  if method != 'auto':
    warnings.warn("correlate() ignores method argument")
  # correlate(a, v) == convolve(a, flip(conj(v)))
  return _convolve_nd(in1, jnp.flip(in2.conj()), mode, precision=precision)
def correlate(in1, in2, mode='full', method='auto', precision=None):
  """Cross-correlate `in1` and `in2` (n-dimensional, complex-aware).

  Equivalent to convolving `in1` with the reversed complex conjugate of
  `in2`.  The `method` argument exists only for scipy API compatibility.
  """
  if method != 'auto':
    warnings.warn("correlate() ignores method argument")
  reversed_kernel = jnp.flip(in2.conj())
  return _convolve_nd(in1, reversed_kernel, mode, precision=precision)