Exemple #1
0
 def testFft(self, fft_ndims, shape, bdims):
     """Run the batching check on a forward FFT over the trailing axes.

     The transform is ``lax.fft`` with ``FftType.FFT``, applied over the last
     ``fft_ndims`` axes of an input of the given ``shape``; the FFT lengths
     are the sizes of those axes. Inputs are requested as ``complex64``.
     """
     sampler = jtu.rand_default(self.rng())
     rank = len(shape)
     # Sizes of the last `fft_ndims` axes, in order.
     trailing_lengths = [shape[i] for i in range(rank - fft_ndims, rank)]

     def forward_fft(x):
         return lax.fft(x, xla_client.FftType.FFT, trailing_lengths)

     # NOTE(review): the literal 5 is presumably the batch size expected by
     # _CheckBatching -- confirm against that helper.
     self._CheckBatching(
         forward_fft, 5, bdims, [shape], [np.complex64], sampler)
Exemple #2
0
def _fft_core(func_name, fft_type, a, s, axes, norm):
    full_name = "jax.numpy.fft." + func_name

    if s is not None:
        s = tuple(map(operator.index, s))
        if np.any(np.less(s, 0)):
            raise ValueError("Shape should be non-negative.")
    if norm is not None:
        raise NotImplementedError("%s only supports norm=None, got %s" %
                                  (full_name, norm))
    if s is not None and axes is not None and len(s) != len(axes):
        # Same error as numpy.
        raise ValueError("Shape and axes have different lengths.")

    orig_axes = axes
    if axes is None:
        if s is None:
            axes = range(a.ndim)
        else:
            axes = range(a.ndim - len(s), a.ndim)

    if len(axes) != len(set(axes)):
        raise ValueError("%s does not support repeated axes. Got axes %s." %
                         (full_name, axes))

    if len(axes) > 3:
        # XLA does not support FFTs over more than 3 dimensions
        raise ValueError("%s only supports 1D, 2D, and 3D FFTs. "
                         "Got axes %s with input rank %s." %
                         (full_name, orig_axes, a.ndim))

    # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
    if orig_axes is not None:
        axes = tuple(range(a.ndim - len(axes), a.ndim))
        a = jnp.moveaxis(a, orig_axes, axes)

    if s is not None:
        a = jnp.asarray(a)
        in_s = list(a.shape)
        for axis, x in safe_zip(axes, s):
            in_s[axis] = x
        if fft_type == xla_client.FftType.IRFFT:
            in_s[-1] = (in_s[-1] // 2 + 1)
        # Cropping
        a = a[tuple(map(slice, in_s))]
        # Padding
        a = jnp.pad(a, [(0, x - y) for x, y in zip(in_s, a.shape)])
    else:
        if fft_type == xla_client.FftType.IRFFT:
            s = [a.shape[axis] for axis in axes[:-1]]
            if axes:
                s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
        else:
            s = [a.shape[axis] for axis in axes]

    transformed = lax.fft(a, fft_type, s)

    if orig_axes is not None:
        transformed = jnp.moveaxis(transformed, axes, orig_axes)
    return transformed
Exemple #3
0
def _fft_core(func_name, fft_type, a, s, axes, norm):
  # TODO(skye): implement padding/cropping based on 's'.
  full_name = "jax.numpy.fft." + func_name
  if s is not None:
    raise NotImplementedError("%s only supports s=None, got %s" % (full_name, s))
  if norm is not None:
    raise NotImplementedError("%s only supports norm=None, got %s" % (full_name, norm))
  if s is not None and axes is not None and len(s) != len(axes):
    # Same error as numpy.
    raise ValueError("Shape and axes have different lengths.")

  orig_axes = axes
  if axes is None:
    if s is None:
      axes = range(a.ndim)
    else:
      axes = range(a.ndim - len(s), a.ndim)

  if len(axes) != len(set(axes)):
    raise ValueError(
        "%s does not support repeated axes. Got axes %s." % (full_name, axes))

  if len(axes) > 3:
    # XLA does not support FFTs over more than 3 dimensions
    raise ValueError(
        "%s only supports 1D, 2D, and 3D FFTs. "
        "Got axes %s with input rank %s." % (full_name, orig_axes, a.ndim))

  # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
  if orig_axes is not None:
    axes = tuple(range(a.ndim - len(axes), a.ndim))
    a = jnp.moveaxis(a, orig_axes, axes)

  if s is None:
    if fft_type == xla_client.FftType.IRFFT:
      s = [a.shape[axis] for axis in axes[:-1]]
      if axes:
        s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
    else:
      s = [a.shape[axis] for axis in axes]

  transformed = lax.fft(a, fft_type, s)

  if orig_axes is not None:
    transformed = jnp.moveaxis(transformed, axes, orig_axes)
  return transformed
Exemple #4
0
 def testLaxFftAcceptsStringTypes(self):
     """``lax.fft`` called with the string "FFT" must match ``np.fft.fft``."""
     sampler = jtu.rand_default(self.rng())
     signal = sampler((10, ), np.complex64)
     # NumPy's FFT result is cast to complex64 before the comparison.
     expected = np.fft.fft(signal).astype(np.complex64)
     # The FFT type is given as a plain string rather than an FftType enum.
     actual = lax.fft(signal, "FFT", fft_lengths=(10, ))
     self.assertAllClose(expected, actual)