Beispiel #1
0
def odd_ext(x, n, axis=-1):
    """Pad `x` with `n` odd-reflected samples on both ends of `axis`.

    Odd extension mirrors the samples next to each end point through that end
    point. Along ``axis``, the left padding is ``2 * x[0] - x[n:0:-1]`` and the
    right padding is ``2 * x[-1] - x[-2:-(n + 2):-1]``. The result therefore
    has ``x.shape[axis] + 2 * n`` entries along ``axis``.

    This function was previously a part of "scipy.signal.signaltools" but is no
    longer exposed.

    Args:
      x: input array.
      n: number of points added at each end. Values below 1 return ``x``
        itself, unchanged.
      axis: axis along which to extend.

    Returns:
      The odd-extended array.

    Raises:
      ValueError: if ``n`` exceeds ``x.shape[axis] - 1``. There are not enough
        interior samples to mirror in that case.
    """
    if n < 1:
        return x
    if n > x.shape[axis] - 1:
        raise ValueError(
            f"The extension length n ({n}) is too big. "
            f"It must not exceed x.shape[axis]-1, which is {x.shape[axis] - 1}."
        )

    def edge(start, stop):
        # Contiguous slice of x along `axis`. lax.slice_in_dim resolves
        # negative indices from the end and treats `None` as the full length.
        return lax.slice_in_dim(x, start, stop, axis=axis)

    def mirror(start, stop):
        # Same slice, reversed along `axis`.
        return jnp.flip(edge(start, stop), axis=axis)

    head = 2 * edge(0, 1) - mirror(1, n + 1)
    tail = 2 * edge(-1, None) - mirror(-(n + 1), -1)
    return jnp.concatenate((head, x, tail), axis=axis)
Beispiel #2
0
def correlate2d(in1,
                in2,
                mode='full',
                boundary='fill',
                fillvalue=0,
                precision=None):
    """Two-dimensional cross-correlation of `in1` with `in2`.

    Only ``boundary='fill'`` with ``fillvalue=0`` is implemented. The work is
    delegated to ``_convolve_nd`` in one of two operand arrangements:

    * unswapped: convolve ``flip(in1)`` with ``conj(in2)`` and flip the result;
    * swapped: convolve ``flip(in2)`` with ``conj(in1)``. The result is then
      conjugated in every mode except ``'valid'``.

    The swapped arrangement is used only when ``in1`` is no larger than
    ``in2`` in every dimension. It is never used in ``'same'`` mode. In
    ``'valid'`` mode it is also skipped when the shapes are identical.

    Args:
      in1: first 2-D input.
      in2: second 2-D input.
      mode: ``'full'``, ``'same'`` or ``'valid'``; forwarded to ``_convolve_nd``.
      boundary: must be ``'fill'``.
      fillvalue: must be ``0``.
      precision: forwarded to ``_convolve_nd``.

    Returns:
      The 2-D correlation array.

    Raises:
      NotImplementedError: for any other ``boundary`` or a non-zero
        ``fillvalue``.
      ValueError: if either input is not 2-dimensional.
    """
    if boundary != 'fill' or fillvalue != 0:
        raise NotImplementedError(
            "correlate2d() only supports boundary='fill', fillvalue=0")
    if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
        raise ValueError("correlate2d() only supports 2-dimensional inputs.")

    dim_pairs = tuple(zip(in1.shape, in2.shape))
    in1_fits_in_in2 = all(a <= b for a, b in dim_pairs)
    shapes_equal = all(a == b for a, b in dim_pairs)

    # Decide up front which operand arrangement to hand to _convolve_nd.
    if mode == "same":
        use_swapped = False
    elif mode == "valid":
        use_swapped = in1_fits_in_in2 and not shapes_equal
    else:
        use_swapped = in1_fits_in_in2

    if use_swapped:
        swapped_out = _convolve_nd(jnp.flip(in2), in1.conj(), mode,
                                   precision=precision)
        # Only the non-'valid' swapped arrangement conjugates its output.
        return swapped_out if mode == "valid" else swapped_out.conj()

    return jnp.flip(_convolve_nd(jnp.flip(in1), in2.conj(), mode,
                                 precision=precision))
Beispiel #3
0
def _convolve_nd(in1, in2, mode, *, precision):
  """N-dimensional convolution of `in1` and `in2` via lax.conv_general_dilated.

  Both inputs are first promoted with ``_promote_dtypes_inexact``. The operand
  that is no smaller in every dimension acts as the convolution input. The
  other operand, reversed along all axes, acts as the kernel. Each operand
  gets singleton leading batch/feature axes for ``lax.conv_general_dilated``,
  and those axes are stripped from the result.

  Padding per ``mode``, where ``k`` is the kernel size along an axis:

  * ``'full'``: ``k - 1`` on both sides.
  * ``'valid'``: no padding.
  * ``'same'``: asymmetric padding chosen so the output matches the shape of
    the original ``in1``.

  Args:
    in1: first array (non-empty).
    in2: second array (non-empty, same ``ndim`` as ``in1``).
    mode: one of ``'full'``, ``'same'``, ``'valid'``.
    precision: forwarded to ``lax.conv_general_dilated``.

  Returns:
    The convolution result.

  Raises:
    ValueError: on an unknown ``mode``, mismatched ``ndim``, a zero-size
      input, or when neither input is smaller than the other in every
      dimension.
  """
  if mode not in ["full", "same", "valid"]:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")
  if in1.ndim != in2.ndim:
    raise ValueError("in1 and in2 must have the same number of dimensions")
  if in1.size == 0 or in2.size == 0:
    raise ValueError(f"zero-size arrays not supported in convolutions, got shapes {in1.shape} and {in2.shape}.")
  in1, in2 = _promote_dtypes_inexact(in1, in2)

  dim_pairs = tuple(zip(in1.shape, in2.shape))
  in1_is_larger = all(a >= b for a, b in dim_pairs)
  in1_is_smaller = all(a <= b for a, b in dim_pairs)
  if not (in1_is_larger or in1_is_smaller):
    raise ValueError("One input must be smaller than the other in every dimension.")

  # Shape of the original second operand, needed for 'same' padding.
  ref_shape = in2.shape
  # Equal shapes count as "smaller" and get swapped; this is harmless.
  lhs, kernel = (in2, in1) if in1_is_smaller else (in1, in2)
  kernel_shape = kernel.shape
  ndim = len(kernel_shape)

  if mode == 'full':
    padding = [(k - 1, k - 1) for k in kernel_shape]
  elif mode == 'same':
    padding = [(k - 1 - (r - 1) // 2, k - r + (r - 1) // 2)
               for k, r in zip(kernel_shape, ref_shape)]
  else:  # 'valid' -- mode was validated above.
    padding = [(0, 0)] * ndim

  out = lax.conv_general_dilated(lhs[None, None], jnp.flip(kernel)[None, None],
                                 (1,) * ndim, padding, precision=precision)
  return out[0, 0]
Beispiel #4
0
def correlate(in1, in2, mode='full', method='auto', precision=None):
    """Cross-correlate real-valued `in1` and `in2`.

    For real inputs, the correlation equals the convolution of ``in1`` with
    ``in2`` reversed along every axis. That convolution is delegated to
    ``_convolve_nd``.

    Args:
      in1: first input array; must expose ``.dtype``.
      in2: second input array; must expose ``.dtype``.
      mode: ``'full'``, ``'same'`` or ``'valid'``; forwarded to ``_convolve_nd``.
      method: accepted for API compatibility only. Any value other than
        ``'auto'`` triggers a warning and is otherwise ignored.
      precision: forwarded to ``_convolve_nd``.

    Returns:
      The correlation result produced by ``_convolve_nd``.

    Raises:
      NotImplementedError: if either input has a complex floating dtype.
    """
    if method != 'auto':
        warnings.warn("correlate() ignores method argument")
    if any(jnp.issubdtype(arr.dtype, jnp.complexfloating) for arr in (in1, in2)):
        raise NotImplementedError(
            "correlate() does not support complex inputs")
    reversed_in2 = jnp.flip(in2)
    return _convolve_nd(in1, reversed_in2, mode, precision=precision)
Beispiel #5
0
def correlate(in1, in2, mode='full', method='auto', precision=None):
    """Cross-correlate `in1` and `in2`.

    The correlation is evaluated as the convolution of ``in1`` with the
    complex conjugate of ``in2``, reversed along every axis. Conjugating
    ``in2`` yields the standard complex cross-correlation; it is a no-op for
    real data.

    Args:
      in1: first input array.
      in2: second input array; must support ``.conj()``.
      mode: ``'full'``, ``'same'`` or ``'valid'``; forwarded to ``_convolve_nd``.
      method: accepted for API compatibility only. Any value other than
        ``'auto'`` triggers a warning and is otherwise ignored.
      precision: forwarded to ``_convolve_nd``.

    Returns:
      The correlation result produced by ``_convolve_nd``.
    """
    ignored_method = method != 'auto'
    if ignored_method:
        warnings.warn("correlate() ignores method argument")
    kernel = jnp.flip(in2.conj())
    return _convolve_nd(in1, kernel, mode, precision=precision)