예제 #1
0
파일: signal.py 프로젝트: GJBoth/jax
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
    """Remove a constant or piecewise-linear trend from `data` along `axis`.

    Args:
      data: array-like input; promoted to an inexact dtype before processing.
      axis: axis along which the trend is removed.
      type: either ``'constant'`` (subtract the mean along `axis`) or
        ``'linear'`` (subtract a least-squares line fitted on each segment).
      bp: breakpoint index or sequence of indices delimiting the segments used
        for the linear fit. Only consulted when ``type == 'linear'``.
      overwrite_data: not supported; must be left as ``None``.

    Returns:
      The detrended array, with the same shape as the promoted input.

    Raises:
      NotImplementedError: if `overwrite_data` is not ``None``.
      ValueError: if `type` is unknown or a breakpoint lies outside ``[0, N]``.
    """
    if overwrite_data is not None:
        raise NotImplementedError("overwrite_data argument not implemented.")
    if type not in ('constant', 'linear'):
        raise ValueError("Trend type must be 'linear' or 'constant'.")

    data, = _promote_dtypes_inexact(jnp.asarray(data))

    if type == 'constant':
        return data - data.mean(axis, keepdims=True)

    n_samples = data.shape[axis]
    # Breakpoints are static, so plain NumPy is used to keep them on the host.
    breakpoints = np.sort(np.unique(np.r_[0, bp, n_samples]))
    if breakpoints[0] < 0 or breakpoints[-1] > n_samples:
        raise ValueError(
            "Breakpoints must be non-negative and less than length of data along given axis."
        )

    # Work on a 2-D view with the detrended axis leading.
    data = jnp.moveaxis(data, axis, 0)
    leading_shape = data.shape
    data = data.reshape(n_samples, -1)

    for start, stop in zip(breakpoints[:-1], breakpoints[1:]):
        seg_len = stop - start
        intercept_col = jnp.ones(seg_len, dtype=data.dtype)
        slope_col = jnp.arange(1, seg_len + 1, dtype=data.dtype) / seg_len
        design = jnp.vstack([intercept_col, slope_col]).T
        segment = slice(start, stop)
        coef, *_ = linalg.lstsq(design, data[segment])
        fitted = jnp.matmul(design, coef, precision=lax.Precision.HIGHEST)
        data = data.at[segment].add(-fitted)

    return jnp.moveaxis(data.reshape(leading_shape), 0, axis)
예제 #2
0
파일: fft.py 프로젝트: xueeinstein/jax
def _fft_core(func_name, fft_type, a, s, axes, norm):
    full_name = "jax.numpy.fft." + func_name

    if s is not None:
        s = tuple(map(operator.index, s))
        if np.any(np.less(s, 0)):
            raise ValueError("Shape should be non-negative.")

    if s is not None and axes is not None and len(s) != len(axes):
        # Same error as numpy.
        raise ValueError("Shape and axes have different lengths.")

    orig_axes = axes
    if axes is None:
        if s is None:
            axes = range(a.ndim)
        else:
            axes = range(a.ndim - len(s), a.ndim)

    if len(axes) != len(set(axes)):
        raise ValueError(
            f"{full_name} does not support repeated axes. Got axes {axes}.")

    if len(axes) > 3:
        # XLA does not support FFTs over more than 3 dimensions
        raise ValueError("%s only supports 1D, 2D, and 3D FFTs. "
                         "Got axes %s with input rank %s." %
                         (full_name, orig_axes, a.ndim))

    # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
    if orig_axes is not None:
        axes = tuple(range(a.ndim - len(axes), a.ndim))
        a = jnp.moveaxis(a, orig_axes, axes)

    if s is not None:
        a = jnp.asarray(a)
        in_s = list(a.shape)
        for axis, x in safe_zip(axes, s):
            in_s[axis] = x
        if fft_type == xla_client.FftType.IRFFT:
            in_s[-1] = (in_s[-1] // 2 + 1)
        # Cropping
        a = a[tuple(map(slice, in_s))]
        # Padding
        a = jnp.pad(a, [(0, x - y) for x, y in zip(in_s, a.shape)])
    else:
        if fft_type == xla_client.FftType.IRFFT:
            s = [a.shape[axis] for axis in axes[:-1]]
            if axes:
                s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
        else:
            s = [a.shape[axis] for axis in axes]
    transformed = lax.fft(a, fft_type, tuple(s))
    transformed *= _fft_norm(jnp.array(s, dtype=transformed.dtype), func_name,
                             norm)

    if orig_axes is not None:
        transformed = jnp.moveaxis(transformed, axes, orig_axes)
    return transformed
예제 #3
0
def _unique_sorted_mask(ar, axis):
    """Sort `ar` along `axis` and flag the first entry of each run of equals.

    Args:
      ar: input array.
      axis: axis along which slices are compared.

    Returns:
      A tuple ``(aux, mask, perm)``: `aux` is `ar` with `axis` moved to the
      front and indexed by `perm`; `mask` is a boolean vector that is True
      where a slice of `aux` differs from the one before it (and at index 0);
      `perm` is the index vector applied to produce `aux`.
    """
    aux = moveaxis(ar, axis, 0)
    if np.issubdtype(aux.dtype, np.complexfloating):
        # Work around an issue in sorting complex numbers with NaN only in the
        # imaginary component. This can be removed if sorting in this situation
        # is fixed to match numpy.
        aux = where(isnan(aux), _lax_const(aux, np.nan), aux)

    size, *out_shape = aux.shape
    n_cols = _prod(out_shape)
    if n_cols == 0:
        size = 1
        perm = zeros(1, dtype=int)
    else:
        perm = lexsort(aux.reshape(size, n_cols).T[::-1])
    aux = aux[perm]

    if not aux.size:
        return aux, zeros(size, dtype=bool), perm

    if dtypes.issubdtype(aux.dtype, np.inexact):
        # Appropriate for both float and complex due to the documented
        # behavior of np.unique, which treats NaNs as equal to each other:
        # https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
        def differs(x, y):
            return lax.ne(x, y) & ~(isnan(x) & isnan(y))
    else:
        differs = lax.ne

    trailing_axes = tuple(range(1, aux.ndim))
    changed = any(differs(aux[1:], aux[:-1]), trailing_axes)
    mask = ones(size, dtype=bool).at[1:].set(changed)
    return aux, mask, perm
예제 #4
0
파일: linalg.py 프로젝트: ahoenselaar/jax
def norm(x,
         ord=None,
         axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
    """Compute a vector or matrix norm of `x`.

    Args:
      x: array-like input.
      ord: order of the norm. For a single axis: ``None``, ``2``, ``inf``,
        ``-inf``, ``0``, ``1`` or any other numeric value (generic p-norm).
        For two axes: ``None``, ``'f'``/``'fro'``, ``1``, ``-1``, ``inf``,
        ``-inf``, ``2``, ``-2`` or ``'nuc'``.
      axis: ``None``, an int (vector norm along that axis) or a tuple of ints.
        With ``None`` and ``ord=None`` the 2-norm of all elements is returned;
        with ``None`` and an explicit `ord`, all axes of `x` are used.
      keepdims: if True, reduced axes are kept with size one.

    Returns:
      The requested norm.

    Raises:
      ValueError: if `ord` is not valid for the number of axes, or if the
        number of axes is neither one nor two.
    """
    x = _promote_arg_dtypes(jnp.asarray(x))
    x_shape = jnp.shape(x)
    ndim = len(x_shape)

    if axis is None:
        # NumPy has an undocumented behavior that admits arbitrary rank inputs if
        # `ord` is None: https://github.com/numpy/numpy/issues/14215
        if ord is None:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
        axis = tuple(range(ndim))
    elif isinstance(axis, tuple):
        axis = tuple(canonicalize_axis(ax, ndim) for ax in axis)
    else:
        axis = (canonicalize_axis(axis, ndim), )

    num_axes = len(axis)
    if num_axes == 1:
        if ord is None or ord == 2:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == jnp.inf:
            return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == 0:
            return jnp.sum(x != 0,
                           dtype=jnp.finfo(lax.dtype(x)).dtype,
                           axis=axis,
                           keepdims=keepdims)
        elif ord == 1:
            # Numpy has a special case for ord == 1 as an optimization. We don't
            # really need the optimization (XLA could do it for us), but the Numpy
            # code has slightly different type promotion semantics, so we need a
            # special case too.
            return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif isinstance(ord, str):
            # String orders ('fro', 'nuc', ...) are only meaningful for matrix
            # norms. Reject them here rather than letting the numeric
            # conversion below fail with an unrelated message.
            raise ValueError(f"Invalid order '{ord}' for vector norm.")
        else:
            abs_x = jnp.abs(x)
            ord = lax._const(abs_x, ord)
            out = jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims)
            return jnp.power(out, 1. / ord)

    elif num_axes == 2:
        row_axis, col_axis = cast(Tuple[int, ...], axis)
        if ord is None or ord in ('f', 'fro'):
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == 1:
            # Without keepdims the first reduction removes row_axis, shifting
            # col_axis down by one when it comes after it.
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == -1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == jnp.inf:
            # Same index adjustment as above, with the roles of the axes swapped.
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord == -jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord in ('nuc', 2, -2):
            # These norms are reductions over the singular values, so place the
            # matrix axes last before calling svd.
            x = jnp.moveaxis(x, axis, (-2, -1))
            if ord == 2:
                reducer = jnp.amax
            elif ord == -2:
                reducer = jnp.amin
            else:
                reducer = jnp.sum
            y = reducer(svd(x, compute_uv=False), axis=-1)
            if keepdims:
                result_shape = list(x_shape)
                result_shape[axis[0]] = 1
                result_shape[axis[1]] = 1
                y = jnp.reshape(y, result_shape)
            return y
        else:
            raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
    else:
        raise ValueError(
            "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
예제 #5
0
def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: Optional[lax.ConvGeneralDilatedDimensionNumbers] = None,
    precision: Optional[lax.PrecisionType] = None,
    preferred_element_type: Optional[DType] = None,
) -> lax.Array:
    """Extract patches subject to the receptive field of `conv_general_dilated`.

    The input is convolved with a kernel built from an identity matrix, so that
    the output channel dimension `"C"` holds flattened image patches: a single
    `"C"` dimension stands for, e.g., the three dimensions `"chw"` collapsed.
    The order of the collapsed dimensions is
    `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
    `rhs_spec == dimension_numbers[1]`. The size of `"C"` is therefore the size
    of one patch, `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`,
    where `lhs_spec == dimension_numbers[0]`.

    Argument descriptions are adapted from `jax.lax.conv_general_dilated`.

    See Also:
      https://www.tensorflow.org/xla/operation_semantics#conv_convolution

    Args:
      lhs: a rank `n+2` dimensional input array.
      filter_shape: a sequence of `n` integers giving the spatial shape of the
        receptive window, in the order specified by
        `rhs_spec = dimension_numbers[1]`.
      window_strides: a sequence of `n` integers giving the inter-window
        strides.
      padding: the string `'SAME'`, the string `'VALID'`, or a sequence of `n`
        `(low, high)` integer pairs giving the padding applied before and
        after each spatial dimension.
      lhs_dilation: `None`, or a sequence of `n` integers giving the dilation
        factor applied in each spatial dimension of `lhs` (also known as
        transposed convolution).
      rhs_dilation: `None`, or a sequence of `n` integers giving the dilation
        factor applied in each spatial dimension of `rhs` (also known as
        atrous convolution).
      dimension_numbers: `None`, or a 3-tuple `(lhs_spec, rhs_spec, out_spec)`
        of strings of length `n+2`. `None` defaults to
        `("NCHWD...", "OIHWD...", "NCHWD...")`.
      precision: `None` (backend default) or a ``lax.Precision`` enum value
        (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).
      preferred_element_type: `None` (default accumulation type for the input
        types) or a datatype to accumulate in and return.

    Returns:
      A rank `n+2` array with the flattened image patches in the output
      channel (`"C"`) dimension. For example, with
      `dimension_numbers = ("NcHW", "OIwh", "CNHW")` the output has dimension
      numbers `"CNHW" = "{cwh}NHW"`, and the size of `"C"` equals the size of
      one patch (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
    """
    filter_shape = tuple(filter_shape)
    dimension_numbers = lax.conv_dimension_numbers(
        lhs.shape, (1, 1) + filter_shape, dimension_numbers)
    lhs_spec, rhs_spec, _ = dimension_numbers

    patch_size = prod(filter_shape)
    channel_count = lhs.shape[lhs_spec[1]]

    # Identity kernel: each spatial location of `lhs` inside the window is
    # routed to its own `rhs` output channel.
    kernel = jnp.eye(patch_size, dtype=lhs.dtype)
    kernel = kernel.reshape(filter_shape * 2)
    kernel = kernel.reshape((patch_size, 1) + filter_shape)

    # Replicate the kernel once per input channel; the grouped convolution
    # below treats every channel independently.
    tile_reps = (channel_count, ) + (1, ) * (kernel.ndim - 1)
    kernel = jnp.tile(kernel, tile_reps)
    kernel = jnp.moveaxis(kernel, (0, 1), (rhs_spec[0], rhs_spec[1]))

    if precision is None:
        conv_precision = None
    else:
        conv_precision = (precision, lax.Precision.DEFAULT)

    return lax.conv_general_dilated(
        lhs=lhs,
        rhs=kernel,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        precision=conv_precision,
        feature_group_count=channel_count,
        preferred_element_type=preferred_element_type)
예제 #6
0
def conv_general_dilated_local(
        lhs: jnp.ndarray,
        rhs: jnp.ndarray,
        window_strides: Sequence[int],
        padding: Union[str, Sequence[Tuple[int, int]]],
        filter_shape: Sequence[int],
        lhs_dilation: Optional[Sequence[int]] = None,
        rhs_dilation: Optional[Sequence[int]] = None,
        dimension_numbers: Optional[
            convolution.ConvGeneralDilatedDimensionNumbers] = None,
        precision: lax.PrecisionLike = None) -> jnp.ndarray:
    """General n-dimensional unshared convolution operator with optional dilation.

    Also known as a locally connected layer: the operation is equivalent to a
    convolution in which a separate (unshared) `rhs` kernel is used at each
    output spatial location. Argument descriptions are adapted from
    `jax.lax.conv_general_dilated`.

    See Also:
      https://www.tensorflow.org/xla/operation_semantics#conv_convolution

    Args:
      lhs: a rank `n+2` dimensional input array.
      rhs: a rank `n+2` dimensional array of kernel weights. Unlike in regular
        CNNs, its spatial coordinates (`H`, `W`, ...) correspond to output
        spatial locations, while input spatial locations are fused with the
        input channel locations in the single `I` dimension, in the order
        `"C" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
        `rhs_spec = dimension_numbers[1]`. For example, if
        `rhs_spec == "WHIO"`, the unfolded kernel shape is
        `"[output W][output H]{I[receptive window W][receptive window H]}O"`.
      window_strides: a sequence of `n` integers giving the inter-window
        strides.
      padding: the string `'SAME'`, the string `'VALID'`, or a sequence of `n`
        `(low, high)` integer pairs giving the padding applied before and
        after each spatial dimension.
      filter_shape: a sequence of `n` integers giving the spatial shape of the
        receptive window, in the order specified by
        `rhs_spec = dimension_numbers[1]`.
      lhs_dilation: `None`, or a sequence of `n` integers giving the dilation
        factor applied in each spatial dimension of `lhs` (also known as
        transposed convolution).
      rhs_dilation: `None`, or a sequence of `n` integers giving the dilation
        factor applied in each input spatial dimension of `rhs` (also known as
        atrous convolution).
      dimension_numbers: `None`, a `ConvDimensionNumbers` object, or a 3-tuple
        `(lhs_spec, rhs_spec, out_spec)` of strings of length `n+2`.
      precision: `None` (backend default), a ``lax.Precision`` enum value
        (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``),
        or a tuple of two ``lax.Precision`` enums giving the precision of
        ``lhs`` and ``rhs``.

    Returns:
      An array containing the unshared convolution result.

    In the string form of `dimension_numbers`, each character identifies by
    position:

    - the batch dimensions in `lhs`, `rhs`, and the output with the character
      'N',
    - the feature dimensions in `lhs` and the output with the character 'C',
    - the input and output feature dimensions in `rhs` with the characters 'I'
      and 'O' respectively, and
    - spatial dimension correspondences between `lhs`, `rhs`, and the output
      using any distinct characters.

    For example, `('NCHW', 'OIHW', 'NCHW')` gives dimension numbers consistent
    with the `conv` function with two spatial dimensions, and
    `('NHWC', 'HWIO', 'NHWC')` gives dimension numbers consistent with the
    TensorFlow Conv2D operation. With the latter form, window strides are
    associated with spatial dimension labels in the order in which the labels
    appear in `rhs_spec`, so `window_strides[0]` is matched with the first
    character of `rhs_spec` that is not `'I'` or `'O'`.

    If `dimension_numbers` is `None`, the default is
    `('NCHW', 'OIHW', 'NCHW')` (for a 2D convolution).
    """
    # Patch extraction only involves `lhs`, so when a (lhs, rhs) precision pair
    # is given only its first element is forwarded.
    c_precision = lax.canonicalize_precision(precision)
    if isinstance(c_precision, tuple) and len(c_precision) == 2:
        patch_precision = c_precision[0]
    else:
        patch_precision = c_precision
    patch_precision = type_cast(Optional[lax.PrecisionType], patch_precision)

    patches = conv_general_dilated_patches(
        lhs=lhs,
        filter_shape=filter_shape,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        precision=patch_precision)

    _, rhs_spec, out_spec = convolution.conv_dimension_numbers(
        lhs.shape, (1, 1) + tuple(filter_shape), dimension_numbers)

    # Contract the patch-channel dimension of `patches` with the input-feature
    # dimension of `rhs`.
    contract_dims = ([out_spec[1]], [rhs_spec[1]])

    # Spatial dimensions act as batch dimensions of the dot product. Order the
    # `rhs` spatial dims by their matching `patches` dims, then sort the
    # latter, so both lists pair up in ascending `patches` order.
    patch_spatial = out_spec[2:]
    kernel_spatial = rhs_spec[2:]
    order = sorted(range(len(kernel_spatial)), key=lambda k: patch_spatial[k])
    kernel_spatial = [kernel_spatial[i] for i in order]
    patch_spatial = sorted(patch_spatial)

    dn = (contract_dims, (patch_spatial, kernel_spatial))
    result = lax.dot_general(patches,
                             rhs,
                             dimension_numbers=dn,
                             precision=precision)
    return jnp.moveaxis(result, (-2, -1), (out_spec[0], out_spec[1]))
예제 #7
0
파일: signal.py 프로젝트: GJBoth/jax
 def detrend_func(d):
     """Apply `detrend_type` with axis `axis` of `d` moved to the last position.

     `axis` and `detrend_type` are taken from the enclosing scope. The axis is
     moved back to its original position before returning.
     """
     moved = jnp.moveaxis(d, axis, -1)
     processed = detrend_type(moved)
     return jnp.moveaxis(processed, -1, axis)
예제 #8
0
파일: signal.py 프로젝트: GJBoth/jax
def _spectral_helper(x,
                     y,
                     fs=1.0,
                     window='hann',
                     nperseg=None,
                     noverlap=None,
                     nfft=None,
                     detrend_type='constant',
                     return_onesided=True,
                     scaling='density',
                     axis=-1,
                     mode='psd',
                     boundary=None,
                     padded=False):
    """LAX-backend implementation of `scipy.signal._spectral_helper`.

  Unlike the original helper function, `y` can be None for explicitly
  indicating auto-spectral (non cross-spectral) computation.  In addition to
  this, `detrend` argument is renamed to `detrend_type` for avoiding internal
  name overlap.
  """
    if mode not in ('psd', 'stft'):
        raise ValueError(f"Unknown value for mode {mode}, "
                         "must be one of: ('psd', 'stft')")

    def make_pad(mode, **kwargs):
        def pad(x, n, axis=-1):
            pad_width = [(0, 0) for unused_n in range(x.ndim)]
            pad_width[axis] = (n, n)
            return jnp.pad(x, pad_width, mode, **kwargs)

        return pad

    boundary_funcs = {
        'even': make_pad('reflect'),
        'odd': odd_ext,
        'constant': make_pad('edge'),
        'zeros': make_pad('constant', constant_values=0.0),
        None: lambda x, *args, **kwargs: x
    }

    # Check/ normalize inputs
    if boundary not in boundary_funcs:
        raise ValueError(f"Unknown boundary option '{boundary}', "
                         f"must be one of: {list(boundary_funcs.keys())}")

    axis = jax.core.concrete_or_error(operator.index, axis,
                                      "axis of windowed-FFT")
    axis = canonicalize_axis(axis, x.ndim)

    if nperseg is not None:  # if specified by user
        nperseg = jax.core.concrete_or_error(int, nperseg,
                                             "nperseg of windowed-FFT")
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')
    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = signal_helper._triage_segments(window,
                                                  nperseg,
                                                  input_length=x.shape[axis])

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = jax.core.concrete_or_error(int, noverlap,
                                              "noverlap of windowed-FFT")
    if nfft is None:
        nfft = nperseg
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

    _check_arraylike("_spectral_helper", x)
    x = jnp.asarray(x)

    if y is None:
        outdtype = jax.dtypes.canonicalize_dtype(
            np.result_type(x, np.complex64))
    else:
        _check_arraylike("_spectral_helper", y)
        y = jnp.asarray(y)
        outdtype = jax.dtypes.canonicalize_dtype(
            np.result_type(x, y, np.complex64))
        if mode != 'psd':
            raise ValueError(
                "two-argument mode is available only when mode=='psd'")
        if x.ndim != y.ndim:
            raise ValueError(
                "two-arguments must have the same rank ({x.ndim} vs {y.ndim})."
            )

        # Check if we can broadcast the outer axes together
        try:
            outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                              tuple_delete(y.shape, axis))
        except ValueError as e:
            raise ValueError('x and y cannot be broadcast together.') from e

    # Special cases for size == 0
    if y is None:
        if x.size == 0:
            return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
    else:
        if x.size == 0 or y.size == 0:
            outshape = tuple_insert(outershape,
                                    min([x.shape[axis], y.shape[axis]]), axis)
            emptyout = jnp.zeros(outshape)
            return emptyout, emptyout, emptyout

    # Move time-axis to the end
    if x.ndim > 1:
        if axis != -1:
            x = jnp.moveaxis(x, axis, -1)
            if y is not None and y.ndim > 1:
                y = jnp.moveaxis(y, axis, -1)

    # Check if x and y are the same length, zero-pad if necessary
    if y is not None:
        if x.shape[-1] != y.shape[-1]:
            if x.shape[-1] < y.shape[-1]:
                pad_shape = list(x.shape)
                pad_shape[-1] = y.shape[-1] - x.shape[-1]
                x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
            else:
                pad_shape = list(y.shape)
                pad_shape[-1] = x.shape[-1] - y.shape[-1]
                y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

    if nfft < nperseg:
        raise ValueError('nfft must be greater than or equal to nperseg.')
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Apply paddings
    if boundary is not None:
        ext_func = boundary_funcs[boundary]
        x = ext_func(x, nperseg // 2, axis=-1)
        if y is not None:
            y = ext_func(y, nperseg // 2, axis=-1)

    if padded:
        # Pad to integer number of windowed segments
        # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
        nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        zeros_shape = list(x.shape[:-1]) + [nadd]
        x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
        if y is not None:
            zeros_shape = list(y.shape[:-1]) + [nadd]
            y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

    # Handle detrending and window functions
    if not detrend_type:

        def detrend_func(d):
            return d
    elif not hasattr(detrend_type, '__call__'):

        def detrend_func(d):
            return detrend(d, type=detrend_type, axis=-1)
    elif axis != -1:
        # Wrap this function so that it receives a shape that it could
        # reasonably expect to receive.
        def detrend_func(d):
            d = jnp.moveaxis(d, axis, -1)
            d = detrend_type(d)
            return jnp.moveaxis(d, -1, axis)
    else:
        detrend_func = detrend_type

    if np.result_type(win, np.complex64) != outdtype:
        win = win.astype(outdtype)

    # Determine scale
    if scaling == 'density':
        scale = 1.0 / (fs * (win * win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    else:
        raise ValueError(f'Unknown scaling: {scaling}')
    if mode == 'stft':
        scale = jnp.sqrt(scale)

    # Determine onesided/ two-sided
    if return_onesided:
        sides = 'onesided'
        if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
    else:
        sides = 'twosided'

    if sides == 'twosided':
        freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs)
    elif sides == 'onesided':
        freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs)

    # Perform the windowed FFTs
    result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

    if y is not None:
        # All the same operations on the y data
        result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                               sides)
        result = jnp.conjugate(result) * result_y
    elif mode == 'psd':
        result = jnp.conjugate(result) * result

    result *= scale

    if sides == 'onesided' and mode == 'psd':
        end = None if nfft % 2 else -1
        result = result.at[..., 1:end].mul(2)

    time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                      nperseg - noverlap) / fs
    if boundary is not None:
        time -= (nperseg / 2) / fs

    result = result.astype(outdtype)

    # All imaginary parts are zero anyways
    if y is None and mode != 'stft':
        result = result.real

    # Move frequency axis back to axis where the data came from
    result = jnp.moveaxis(result, -1, axis)

    return freqs, time, result
예제 #9
0
파일: signal.py 프로젝트: frederikwilde/jax
def istft(Zxx,
          fs=1.0,
          window='hann',
          nperseg=None,
          noverlap=None,
          nfft=None,
          input_onesided=True,
          boundary=True,
          time_axis=-1,
          freq_axis=-2):
    """Perform the inverse short-time Fourier transform of `Zxx`.

    Args:
      Zxx: array-like STFT with at least two dimensions.
      fs: sampling frequency, used to scale the returned time vector.
      window: a string or tuple passed to ``scipy.signal.get_window``, or a
        1-D array of length `nperseg`.
      nperseg: segment length. ``None`` selects ``2 * (n_freq - 1)`` for
        one-sided input and ``n_freq`` otherwise.
      noverlap: number of overlapping samples between segments. ``None``
        selects ``nperseg // 2``; an explicit ``0`` means no overlap.
      nfft: FFT length. ``None`` derives it from the frequency axis length.
      input_onesided: if True, `Zxx` is treated as one-sided and inverted with
        ``irfft``; otherwise ``ifft`` is used.
      boundary: if truthy, ``nperseg // 2`` samples are trimmed from both ends
        of the reconstructed signal.
      time_axis: axis of `Zxx` holding the segments.
      freq_axis: axis of `Zxx` holding the frequencies.

    Returns:
      A tuple ``(time, x)`` of the time vector and the reconstructed signal.

    Raises:
      ValueError: if `Zxx` has fewer than two dimensions, the two axes
        coincide, `nperseg` is not positive, ``nfft < nperseg``,
        ``noverlap >= nperseg``, or `window` has the wrong shape.
    """
    # Input validation
    _check_arraylike("istft", Zxx)
    if Zxx.ndim < 2:
        raise ValueError('Input stft must be at least 2d!')
    freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
    time_axis = canonicalize_axis(time_axis, Zxx.ndim)
    if freq_axis == time_axis:
        raise ValueError('Must specify differing time and frequency axes!')

    Zxx = jnp.asarray(Zxx,
                      dtype=jax.dtypes.canonicalize_dtype(
                          np.result_type(Zxx, np.complex64)))

    n_default = (2 * (Zxx.shape[freq_axis] - 1)
                 if input_onesided else Zxx.shape[freq_axis])

    # Test against None explicitly: `nperseg or n_default` would replace an
    # invalid nperseg=0 with the default instead of rejecting it below.
    nperseg = jax.core.concrete_or_error(
        int, n_default if nperseg is None else nperseg,
        "nperseg: segment length of STFT")
    if nperseg < 1:
        raise ValueError('nperseg must be a positive integer')

    if nfft is None:
        nfft = n_default
        if input_onesided and nperseg == n_default + 1:
            nfft += 1  # Odd nperseg, no FFT padding
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
    if nfft < nperseg:
        raise ValueError(
            f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

    # Test against None explicitly: noverlap=0 is a valid request for
    # non-overlapping segments and must not fall back to the default.
    noverlap = jax.core.concrete_or_error(
        int, nperseg // 2 if noverlap is None else noverlap,
        "noverlap of STFT")
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Rearrange axes if necessary
    if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
        outer_idxs = tuple(idx for idx in range(Zxx.ndim)
                           if idx not in {time_axis, freq_axis})
        Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

    # Perform IFFT
    ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft
    # xsubs: [..., T, N], N is the number of frames, T is the frame length.
    xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

    # Get window as array
    if isinstance(window, (str, tuple)):
        win = osp_signal.get_window(window, nperseg)
        win = jnp.asarray(win)
    else:
        win = jnp.asarray(window)
        if len(win.shape) != 1:
            raise ValueError('window must be 1-D')
        if win.shape[0] != nperseg:
            raise ValueError('window must have length of {0}'.format(nperseg))
    win = win.astype(xsubs.dtype)

    xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

    # make win broadcastable over xsubs
    win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1, ))
    x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
    win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
    norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

    # Remove extension points
    if boundary:
        x = x[..., nperseg // 2:-(nperseg // 2)]
        norm = norm[..., nperseg // 2:-(nperseg // 2)]
    # Divide by the summed squared window, skipping near-zero entries.
    x /= jnp.where(norm > 1e-10, norm, 1.0)

    # Put axes back
    if x.ndim > 1:
        if time_axis != Zxx.ndim - 1:
            # The frequency axis no longer exists in `x`, so a time axis that
            # came after it shifts down by one.
            if freq_axis < time_axis:
                time_axis -= 1
            x = jnp.moveaxis(x, -1, time_axis)

    # NOTE(review): uses x.shape[0], which is the time axis only when that
    # axis is leading -- confirm against the intended multi-dimensional
    # behaviour.
    time = jnp.arange(x.shape[0]) / fs
    return time, x
예제 #10
0
def _unique(ar,
            axis,
            return_index=False,
            return_inverse=False,
            return_counts=False,
            size=None,
            fill_value=None,
            return_true_size=False):
    """Find the unique elements of an array along a particular axis.

    Args:
      ar: input array.
      axis: axis along which slices are compared.
      return_index: if True, also return indices into `ar` along `axis`
        selecting the unique slices.
      return_inverse: if True, also return, for each slice of `ar`, the index
        of its value in the unique result.
      return_counts: if True, also return how often each unique slice occurs.
      size: optional static number of unique slices to return. When None, the
        mask must be concrete (see ``UNIQUE_SIZE_HINT``).
      fill_value: value written into result rows at or beyond the true number
        of unique slices when `size` is given.
      return_true_size: if True, also return ``mask.sum()``, the number of
        unique slices found.

    Returns:
      The unique array alone if no optional output is requested; otherwise a
      tuple starting with it, followed by the requested outputs in the order
      index, inverse, counts, true size.

    Raises:
      ValueError: if `ar` is empty along `axis`, `size` is nonzero and
        `fill_value` is None.
    """
    if ar.shape[axis] == 0 and size and fill_value is None:
        raise ValueError(
            "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified"
        )

    # aux: sorted data with `axis` leading; mask: True at the first entry of
    # each run of equal slices; perm: the indices that produced aux.
    aux, mask, perm = _unique_sorted_mask(ar, axis)
    if size is None:
        # Without a static size the mask itself is used as a boolean index, so
        # it has to be a concrete value.
        ind = core.concrete_or_error(
            None, mask, "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)
    else:
        ind = nonzero(mask, size=size)[0]
    result = aux[ind] if aux.size else aux
    if fill_value is not None:
        fill_value = asarray(fill_value, dtype=result.dtype)
    if size is not None and fill_value is not None:
        if result.shape[0]:
            # Keep the first mask.sum() rows and overwrite the rest.
            valid = lax.expand_dims(
                arange(size) < mask.sum(), tuple(range(1, result.ndim)))
            result = where(valid, result, fill_value)
        else:
            # Nothing to select from: the whole output is fill_value.
            result = full_like(result,
                               fill_value,
                               shape=(size, *result.shape[1:]))
    # Restore the compared axis to its original position.
    result = moveaxis(result, 0, axis)

    ret = (result, )
    if return_index:
        if aux.size:
            ret += (perm[ind], )
        else:
            ret += (perm, )
    if return_inverse:
        if aux.size:
            # cumsum(mask) - 1 numbers the runs in sorted order; scattering
            # through perm maps those numbers back to the input order.
            imask = cumsum(mask) - 1
            inv_idx = zeros(mask.shape,
                            dtype=dtypes.canonicalize_dtype(dtypes.int_))
            inv_idx = inv_idx.at[perm].set(imask)
        else:
            inv_idx = zeros(ar.shape[axis], dtype=int)
        ret += (inv_idx, )
    if return_counts:
        if aux.size:
            # Counts are the gaps between consecutive run starts, with
            # mask.size acting as the end marker.
            if size is None:
                idx = append(nonzero(mask)[0], mask.size)
            else:
                idx = nonzero(mask, size=size + 1)[0]
                # Zero entries after the first slot are replaced by the end
                # marker so that diff() stays non-negative.
                idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
            ret += (diff(idx), )
        elif ar.shape[axis]:
            ret += (array([ar.shape[axis]],
                          dtype=dtypes.canonicalize_dtype(dtypes.int_)), )
        else:
            ret += (empty(0, dtype=int), )
    if return_true_size:
        # Useful for internal uses of unique().
        ret += (mask.sum(), )
    return ret[0] if len(ret) == 1 else ret