Ejemplo n.º 1
0
def _roots_no_zeros(p):
    """Return the roots of the polynomial with coefficient vector ``p``.

    The roots are the eigenvalues of the polynomial's companion matrix, whose
    first row is ``-p[1:] / p[0]`` and whose first sub-diagonal is all ones.
    ``p[0]`` is used as a divisor, so it is expected to be nonzero (``roots``
    strips leading zeros before calling this).

    Returns:
      An array of ``p.size - 1`` eigenvalues, or an empty array of the complex
      dtype matching ``p.dtype`` when ``p`` has fewer than two coefficients.
    """
    n_coeffs = p.size
    if n_coeffs < 2:
        # A constant (or empty) polynomial has no roots.
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    # (n_coeffs - 1) x (n_coeffs - 1) matrix with ones just below the diagonal...
    companion = diag(ones((n_coeffs - 2,), p.dtype), -1)
    # ...and the normalized, negated trailing coefficients in the first row.
    companion = companion.at[0, :].set(-p[1:] / p[0])
    return linalg.eigvals(companion)
Ejemplo n.º 2
0
 def testLaxIrfftDoesNotMutateInputs(self, dtype):
     # Calling irfft2 twice on the same complex input must give matching
     # results; a different second result would mean the first call altered
     # its input.
     needs_x64 = dtype == np.float64 and not config.x64_enabled
     if needs_x64:
         raise self.skipTest("float64 requires jax_enable_x64=true")
     complex_dtype = dtypes._to_complex_dtype(dtype)
     base = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=complex_dtype)
     x = (1 + 1j) * base
     first = np.asarray(jnp.fft.irfft2(x))
     second = np.asarray(jnp.fft.irfft2(x))
     self.assertAllClose(first, second)
Ejemplo n.º 3
0
def _promote_dtypes_complex(*args):
  """Promote every argument to one common complex dtype.

  NumPy-style promotion is applied across ``args`` (via
  ``dtypes._lattice_result_type``). The promoted dtype is canonicalized and
  then mapped to its complex counterpart. Each argument is converted to that
  dtype, and the weak-type flag from promotion is carried through.

  Returns:
    A list with one converted value per input argument, in order.
  """
  promoted, is_weak = dtypes._lattice_result_type(*args)
  canonical = dtypes.canonicalize_dtype(promoted)
  target = dtypes._to_complex_dtype(canonical)
  return [lax_internal._convert_element_type(arg, target, is_weak)
          for arg in args]
Ejemplo n.º 4
0
def _complex_uniform(key, shape, dtype):
    """Sample complex values uniformly from a disk centered at the origin.

    The disk has radius sqrt(2), so the samples have zero mean and unit
    variance (E|z|^2 = 2 * E[U] = 1 for U ~ Uniform[0, 1)).

    Args:
      key: PRNG key, split into one key for the radius and one for the angle.
      shape: shape of the returned array.
      dtype: requested dtype; only the precision of its real component is
        used, and the result always has the matching complex dtype.

    Returns:
      A complex array of the given shape.
    """
    radius_key, angle_key = random.split(key)
    real_dtype = np.array(0, dtype).real.dtype
    complex_dtype = dtypes._to_complex_dtype(real_dtype)
    # Taking the square root of a uniform variate makes the density uniform
    # over the disk's area.
    radius = jnp.sqrt(2 * random.uniform(radius_key, shape, real_dtype))
    radius = radius.astype(complex_dtype)
    angle = random.uniform(angle_key, shape, real_dtype).astype(complex_dtype)
    angle = 2 * jnp.pi * angle
    return radius * jnp.exp(1j * angle)
Ejemplo n.º 5
0
def _complex_truncated_normal(key, upper, shape, dtype):
    """Sample a centered complex normal whose modulus is truncated to ``upper``.

    Before truncation the distribution has unit variance, so ``|z|**2``
    follows Exp(1). The radius comes from inverse-CDF sampling of that
    exponential, restricted to ``|z|**2 <= upper**2``. The angle is uniform.

    Args:
      key: PRNG key, split into one key for the radius and one for the angle.
      upper: truncation bound on the modulus.
      shape: shape of the returned array.
      dtype: requested dtype; only the precision of its real component is
        used, and the result always has the matching complex dtype.

    Returns:
      A complex array of the given shape.
    """
    key_r, key_theta = random.split(key)
    real_dtype = np.array(0, dtype).real.dtype
    dtype = dtypes._to_complex_dtype(real_dtype)
    # Probability mass of Exp(1) below upper**2, i.e. 1 - exp(-upper**2).
    # Computing it as -expm1(...) avoids catastrophic cancellation when upper
    # is small; with `1 - exp(...)` every sample would collapse to 0.
    mass = -jnp.expm1(-jnp.square(jnp.asarray(upper, real_dtype)))
    t = mass * random.uniform(key_r, shape, real_dtype)
    # Inverse CDF of Exp(1): -log(1 - t). log1p keeps precision for small t.
    r = jnp.sqrt(-jnp.log1p(-t)).astype(dtype)
    theta = 2 * jnp.pi * random.uniform(key_theta, shape,
                                        real_dtype).astype(dtype)
    return r * jnp.exp(1j * theta)
Ejemplo n.º 6
0
def roots(p, *, strip_zeros=True):
    """Return the roots of the polynomial with rank-1 coefficient array ``p``.

    Args:
      p: array-like of polynomial coefficients; promoted to an inexact dtype.
      strip_zeros: if True, leading zero coefficients are removed before
        solving. This requires the number of leading zeros to be concrete, so
        it cannot be used under JIT tracing. If False, the zero count is
        passed to ``_roots_with_zeros``. As the error message below warns,
        leading zeros then cause some returned roots to be NaN.

    Returns:
      The polynomial's roots. When ``p`` has fewer than two coefficients, an
      empty array of the matching complex dtype.

    Raises:
      ValueError: if ``p`` is not rank-1 after promotion.
    """
    _check_arraylike("roots", p)
    p = atleast_1d(*_promote_dtypes_inexact(p))
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    # argmin over the boolean mask gives the index of the first nonzero
    # coefficient; if every coefficient is zero, all of them count as leading.
    num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

    if strip_zeros:
        num_leading_zeros = core.concrete_or_error(
            int, num_leading_zeros,
            "The error occurred in the jnp.roots() function. To use this within a "
            "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
            "will result in some returned roots being set to NaN.")
        return _roots_no_zeros(p[num_leading_zeros:])
    return _roots_with_zeros(p, num_leading_zeros)
Ejemplo n.º 7
0
    def testWelchWithDefaultStepArgsAgainstNumpy(self, *, shape, dtype,
                                                 nperseg, noverlap,
                                                 use_nperseg, use_noverlap,
                                                 timeaxis):
        # Compare jax.scipy.signal.welch with scipy.signal.welch when
        # nperseg/noverlap may be left at their defaults.
        kwargs = {'axis': timeaxis}
        if not use_nperseg:
            # With no explicit nperseg, pass a Hann window whose length
            # implies it.
            hann = osp_signal.get_window('hann', nperseg)
            kwargs['window'] = jnp.array(hann,
                                         dtype=dtypes._to_complex_dtype(dtype))
        else:
            kwargs['nperseg'] = nperseg
        if use_noverlap:
            kwargs['noverlap'] = noverlap

        def osp_fun(x):
            # SciPy returns float64; cast to the real dtype under test.
            out_dtype = _real_dtype(dtype)
            freqs, psd = osp_signal.welch(x, **kwargs)
            return freqs.astype(out_dtype), psd.astype(out_dtype)

        jsp_fun = partial(jsp_signal.welch, **kwargs)
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL
        else:
            tol = {
                np.float32: 1e-5,
                np.float64: 1e-12,
                np.complex64: 1e-5,
                np.complex128: 1e-12
            }

        rng = jtu.rand_default(self.rng())

        def args_maker():
            return [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                                rtol=tol, atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Ejemplo n.º 8
0
def _schur(a, output):
    """Compute the Schur decomposition of ``a`` with ``lax_linalg.schur``.

    When ``output == "complex"``, ``a`` is first cast to the complex dtype
    matching its dtype. Any other ``output`` value passes ``a`` through
    unchanged.
    """
    wants_complex = output == "complex"
    matrix = a.astype(dtypes._to_complex_dtype(a.dtype)) if wants_complex else a
    return lax_linalg.schur(matrix)
Ejemplo n.º 9
0
def _spectral_helper(x,
                     y,
                     fs=1.0,
                     window='hann',
                     nperseg=None,
                     noverlap=None,
                     nfft=None,
                     detrend_type='constant',
                     return_onesided=True,
                     scaling='density',
                     axis=-1,
                     mode='psd',
                     boundary=None,
                     padded=False):
    """LAX-backend implementation of `scipy.signal._spectral_helper`.

    Unlike the original helper function, `y` can be None for explicitly
    indicating auto-spectral (non cross-spectral) computation.  In addition to
    this, `detrend` argument is renamed to `detrend_type` for avoiding internal
    name overlap.

    Returns:
      A tuple ``(freqs, time, result)``: the sample frequencies, the segment
      times, and the spectral result with its frequency axis moved back to
      ``axis``. For auto-spectral (``y is None``) ``mode='psd'``, ``result``
      is real.

    Raises:
      ValueError: for an unknown ``mode``, ``boundary`` or ``scaling``;
        two-argument use with ``mode != 'psd'``; ``x``/``y`` of different
        rank or with non-broadcastable outer shapes; ``nperseg < 1``;
        ``nfft < nperseg``; or ``noverlap >= nperseg``.
    """
    if mode not in ('psd', 'stft'):
        raise ValueError(f"Unknown value for mode {mode}, "
                         "must be one of: ('psd', 'stft')")

    def make_pad(mode, **kwargs):
        # Build a padder that extends `axis` by `n` samples on both sides
        # using the given jnp.pad mode.
        def pad(x, n, axis=-1):
            pad_width = [(0, 0) for unused_n in range(x.ndim)]
            pad_width[axis] = (n, n)
            return jnp.pad(x, pad_width, mode, **kwargs)

        return pad

    boundary_funcs = {
        'even': make_pad('reflect'),
        'odd': odd_ext,
        'constant': make_pad('edge'),
        'zeros': make_pad('constant', constant_values=0.0),
        None: lambda x, *args, **kwargs: x
    }

    # Check/ normalize inputs
    if boundary not in boundary_funcs:
        raise ValueError(f"Unknown boundary option '{boundary}', "
                         f"must be one of: {list(boundary_funcs.keys())}")

    axis = jax.core.concrete_or_error(operator.index, axis,
                                      "axis of windowed-FFT")
    axis = canonicalize_axis(axis, x.ndim)

    if y is None:
        _check_arraylike('spectral_helper', x)
        x, = _promote_dtypes_inexact(x)
        outershape = tuple_delete(x.shape, axis)
    else:
        if mode != 'psd':
            raise ValueError(
                "two-argument mode is available only when mode=='psd'")
        _check_arraylike('spectral_helper', x, y)
        x, y = _promote_dtypes_inexact(x, y)
        if x.ndim != y.ndim:
            # f-string so the actual ranks appear in the message.
            raise ValueError(
                f"two-arguments must have the same rank ({x.ndim} vs {y.ndim})."
            )
        # Check if we can broadcast the outer axes together
        try:
            outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                              tuple_delete(y.shape, axis))
        except ValueError as err:
            raise ValueError('x and y cannot be broadcast together.') from err

    result_dtype = dtypes._to_complex_dtype(x.dtype)
    # Real component dtype of the complex result; used for freqs and times.
    freq_dtype = np.finfo(result_dtype).dtype

    if nperseg is not None:  # if specified by user
        nperseg = jax.core.concrete_or_error(int, nperseg,
                                             "nperseg of windowed-FFT")
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')
    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = signal_helper._triage_segments(window,
                                                  nperseg,
                                                  input_length=x.shape[axis],
                                                  dtype=result_dtype)

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = jax.core.concrete_or_error(int, noverlap,
                                              "noverlap of windowed-FFT")
    if nfft is None:
        nfft = nperseg
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

    # Special cases for size == 0
    if y is None:
        if x.size == 0:
            return jnp.zeros(x.shape, freq_dtype), jnp.zeros(
                x.shape, freq_dtype), jnp.zeros(x.shape, result_dtype)
    else:
        if x.size == 0 or y.size == 0:
            shape = tuple_insert(outershape,
                                 min([x.shape[axis], y.shape[axis]]), axis)
            return jnp.zeros(shape, freq_dtype), jnp.zeros(
                shape, freq_dtype), jnp.zeros(shape, result_dtype)

    # Move time-axis to the end
    x = jnp.moveaxis(x, axis, -1)
    if y is not None and y.ndim > 1:
        y = jnp.moveaxis(y, axis, -1)

    # Check if x and y are the same length, zero-pad if necessary
    if y is not None and x.shape[-1] != y.shape[-1]:
        if x.shape[-1] < y.shape[-1]:
            pad_shape = list(x.shape)
            pad_shape[-1] = y.shape[-1] - x.shape[-1]
            x = jnp.concatenate((x, jnp.zeros_like(x, shape=pad_shape)), -1)
        else:
            pad_shape = list(y.shape)
            pad_shape[-1] = x.shape[-1] - y.shape[-1]
            y = jnp.concatenate((y, jnp.zeros_like(y, shape=pad_shape)), -1)

    if nfft < nperseg:
        raise ValueError('nfft must be greater than or equal to nperseg.')
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Apply paddings
    if boundary is not None:
        ext_func = boundary_funcs[boundary]
        x = ext_func(x, nperseg // 2, axis=-1)
        if y is not None:
            y = ext_func(y, nperseg // 2, axis=-1)

    if padded:
        # Pad to integer number of windowed segments
        # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
        nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        x = jnp.concatenate(
            (x, jnp.zeros_like(x, shape=(*x.shape[:-1], nadd))), axis=-1)
        if y is not None:
            y = jnp.concatenate(
                (y, jnp.zeros_like(y, shape=(*y.shape[:-1], nadd))), axis=-1)

    # Handle detrending and window functions
    if not detrend_type:
        detrend_func = lambda d: d
    elif not callable(detrend_type):
        detrend_func = partial(detrend, type=detrend_type, axis=-1)
    elif axis != -1:
        # NOTE(review): `axis` was canonicalized above (presumably to a
        # non-negative index), so this branch appears to run for every
        # callable detrend. It also moves `axis` to the end, whereas SciPy's
        # wrapper moves the last axis to `axis`. Confirm which layout a
        # user-supplied detrend callable is meant to receive.
        # Wrap this function so that it receives a shape that it could
        # reasonably expect to receive.
        def detrend_func(d):
            d = jnp.moveaxis(d, axis, -1)
            d = detrend_type(d)
            return jnp.moveaxis(d, -1, axis)
    else:
        detrend_func = detrend_type

    # Determine scale
    if scaling == 'density':
        scale = 1.0 / (fs * (win * win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    else:
        raise ValueError(f'Unknown scaling: {scaling}')
    if mode == 'stft':
        scale = jnp.sqrt(scale)

    # Determine onesided/ two-sided
    if return_onesided:
        sides = 'onesided'
        if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
    else:
        sides = 'twosided'

    if sides == 'twosided':
        freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs).astype(freq_dtype)
    elif sides == 'onesided':
        freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs).astype(freq_dtype)

    # Perform the windowed FFTs
    result = _fft_helper(x.astype(result_dtype), win, detrend_func, nperseg,
                         noverlap, nfft, sides)

    if y is not None:
        # All the same operations on the y data
        result_y = _fft_helper(y.astype(result_dtype), win, detrend_func,
                               nperseg, noverlap, nfft, sides)
        result = jnp.conjugate(result) * result_y
    elif mode == 'psd':
        result = jnp.conjugate(result) * result

    result *= scale

    if sides == 'onesided' and mode == 'psd':
        # Double every bin except DC (and Nyquist when nfft is even) to
        # account for the discarded negative frequencies.
        end = None if nfft % 2 else -1
        result = result.at[..., 1:end].mul(2)

    time = jnp.arange(nperseg / 2,
                      x.shape[-1] - nperseg / 2 + 1,
                      nperseg - noverlap,
                      dtype=freq_dtype) / fs
    if boundary is not None:
        time -= (nperseg / 2) / fs

    result = result.astype(result_dtype)

    # All imaginary parts are zero anyways
    if y is None and mode != 'stft':
        result = result.real

    # Move frequency axis back to axis where the data came from
    result = jnp.moveaxis(result, -1, axis)

    return freqs, time, result
Ejemplo n.º 10
0
 def np_fun(arg):
     # NumPy reference: the roots of `arg` cast to the matching complex dtype,
     # right-padded with NaN+NaNj up to len(arg) - 1 entries whenever
     # np.roots returns fewer (presumably because it drops leading zeros).
     result = np.roots(arg)
     result = result.astype(dtypes._to_complex_dtype(arg.dtype))
     missing = len(arg) - 1 - len(result)
     if missing > 0:
         result = np.pad(result, (0, missing),
                         constant_values=complex(np.nan, np.nan))
     return result
Ejemplo n.º 11
0
 def np_fun(arg):
     # NumPy reference: the roots of `arg` cast to the complex dtype
     # matching arg.dtype.
     result = np.roots(arg)
     return result.astype(dtypes._to_complex_dtype(arg.dtype))
Ejemplo n.º 12
0
def _complex_dtype(dtype):
    """Return the complex dtype corresponding to ``dtype``.

    Thin wrapper that delegates to ``dtypes._to_complex_dtype``.
    """
    complex_dtype = dtypes._to_complex_dtype(dtype)
    return complex_dtype