Пример #1
0
    def __init__(self, dataset, bw_method=None, weights=None):
        """Build a Gaussian kernel density estimator from ``dataset``.

        Args:
          dataset: array-like of samples. It is promoted to at least 2-D and
            unpacked as ``(d, n)``: ``d`` dimensions by ``n`` samples.
          bw_method: ``None`` (treated as ``"scott"``), ``"scott"``,
            ``"silverman"``, a non-string scalar used directly as the
            bandwidth factor, or a callable that receives this object (with
            ``dataset``, ``weights`` and ``neff`` already stored) and returns
            the factor.
          weights: optional one-dimensional array-like of ``n`` sample
            weights, normalized here to sum to one. Uniform weights ``1/n``
            are used when omitted.

        Raises:
          NotImplementedError: if ``dataset`` has a complex dtype.
          ValueError: if ``dataset`` has at most one element, ``weights`` is
            not one-dimensional or not of length ``n``, or ``bw_method`` is
            not one of the accepted forms.
        """
        _check_arraylike("gaussian_kde", dataset)
        samples = jnp.atleast_2d(dataset)
        if jnp.issubdtype(lax.dtype(samples), jnp.complexfloating):
            raise NotImplementedError(
                "gaussian_kde does not support complex data")
        if samples.size <= 1:
            raise ValueError("`dataset` input should have multiple elements.")

        ndim, nsamples = samples.shape
        if weights is None:
            samples, = _promote_dtypes_inexact(samples)
            wts = jnp.full(nsamples, 1.0 / nsamples, dtype=samples.dtype)
        else:
            _check_arraylike("gaussian_kde", weights)
            samples, wts = _promote_dtypes_inexact(samples, weights)
            wts = jnp.atleast_1d(wts)
            wts = wts / jnp.sum(wts)
            if wts.ndim != 1:
                raise ValueError("`weights` input should be one-dimensional.")
            if len(wts) != nsamples:
                raise ValueError("`weights` input should be of length n")

        self._setattr("dataset", samples)
        self._setattr("weights", wts)
        # The value returned by _setattr is reused as the effective sample size.
        neff = self._setattr("neff", 1 / jnp.sum(wts**2))

        method = "scott" if bw_method is None else bw_method
        exponent = -1. / (ndim + 4)
        if method == "scott":
            factor = jnp.power(neff, exponent)
        elif method == "silverman":
            factor = jnp.power(neff * (ndim + 2) / 4.0, exponent)
        elif jnp.isscalar(method) and not isinstance(method, str):
            factor = method
        elif callable(method):
            # The callable sees the partially initialised estimator.
            factor = method(self)
        else:
            raise ValueError(
                "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
            )

        data_cov = jnp.atleast_2d(
            jnp.cov(samples, rowvar=1, bias=False, aweights=wts))
        data_inv_cov = jnp.linalg.inv(data_cov)
        factor_sq = factor**2
        scaled_cov = data_cov * factor_sq
        scaled_inv_cov = data_inv_cov / factor_sq
        self._setattr("covariance", scaled_cov)
        self._setattr("inv_cov", scaled_inv_cov)
Пример #2
0
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
    """Remove a constant or piecewise-linear trend from ``data`` along ``axis``.

    Args:
      data: array-like input; promoted to an inexact dtype.
      axis: axis along which the trend is removed.
      type: ``'linear'`` (least-squares line per segment) or ``'constant'``
        (subtract the mean).
      bp: static breakpoint index or indices splitting ``axis`` into
        independently detrended segments (``'linear'`` only).
      overwrite_data: unsupported; must be ``None``.

    Returns:
      Array of the same shape as ``data`` with the trend removed.

    Raises:
      NotImplementedError: if ``overwrite_data`` is given.
      ValueError: for an unknown ``type`` or out-of-range breakpoints.
    """
    if overwrite_data is not None:
        raise NotImplementedError("overwrite_data argument not implemented.")
    if type not in ['constant', 'linear']:
        raise ValueError("Trend type must be 'linear' or 'constant'.")
    arr, = _promote_dtypes_inexact(jnp.asarray(data))
    if type == 'constant':
        return arr - arr.mean(axis, keepdims=True)

    length = arr.shape[axis]
    # Breakpoints are static, so they are handled with NumPy on the host
    # rather than being pushed to the device.
    breaks = np.sort(np.unique(np.r_[0, bp, length]))
    if breaks[0] < 0 or breaks[-1] > length:
        raise ValueError(
            "Breakpoints must be non-negative and less than length of data along given axis."
        )

    moved = jnp.moveaxis(arr, axis, 0)
    moved_shape = moved.shape
    flat = moved.reshape(length, -1)
    for start, stop in zip(breaks[:-1], breaks[1:]):
        npts = stop - start
        # Design matrix with a constant column and a ramp column.
        design = jnp.vstack([
            jnp.ones(npts, dtype=flat.dtype),
            jnp.arange(1, npts + 1, dtype=flat.dtype) / npts
        ]).T
        segment = slice(start, stop)
        coef, *_ = linalg.lstsq(design, flat[segment])
        fitted = jnp.matmul(design, coef, precision=lax.Precision.HIGHEST)
        flat = flat.at[segment].add(-fitted)
    return jnp.moveaxis(flat.reshape(moved_shape), 0, axis)
Пример #3
0
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
    """Evaluate a weighted Gaussian kernel sum at each row of ``xi``.

    Args:
      in_log: if true, work in log space (log of values, logsumexp reduction);
        otherwise multiply values by the kernel and sum.
      points: training points; ``points.shape[1]`` is the data dimension d,
        and rows are mapped over together with ``values``.
      values: per-point values, mapped along axis 0 alongside ``points``.
      xi: query points with trailing dimension d; one result per row.
      precision: (d, d) precision matrix of the kernel.

    Returns:
      Kernel-weighted sums (or their logs when ``in_log``), one per query row.

    Raises:
      ValueError: if ``xi`` or ``precision`` does not match dimension d.
    """
    points, values, xi, precision = _promote_dtypes_inexact(
        points, values, xi, precision)
    dim = points.shape[1]

    if xi.shape[1] != dim:
        raise ValueError("points and xi must have same trailing dim")
    if precision.shape != (dim, dim):
        raise ValueError("precision matrix must match data dims")

    # Whitening by the lower Cholesky factor turns the precision-weighted
    # quadratic form into a plain squared Euclidean distance.
    chol = linalg.cholesky(precision, lower=True)
    train = jnp.dot(points, chol)
    test = jnp.dot(xi, chol)
    log_norm = jnp.sum(jnp.log(
        jnp.diag(chol))) - 0.5 * dim * jnp.log(2 * np.pi)

    def log_kernel(x_test, x_train):
        return log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test))

    if in_log:
        def contribution(x_test, x_train, y_train):
            return jnp.log(y_train) + log_kernel(x_test, x_train)
        combine = special.logsumexp
    else:
        def contribution(x_test, x_train, y_train):
            return y_train * jnp.exp(log_kernel(x_test, x_train))
        combine = jnp.sum

    over_train = vmap(contribution, in_axes=(None, 0, 0))

    def evaluate_one(x):
        return combine(over_train(x, train, values), axis=0)

    return vmap(evaluate_one)(test)
Пример #4
0
def logpdf(x, mean, cov, allow_singular=None):
    """Log-density of a multivariate normal distribution.

    Args:
      x: evaluation points. When ``mean`` is not a scalar, the last axis of
        ``x`` indexes the n event dimensions.
      mean: scalar (univariate case) or array whose last axis has length n.
      cov: scalar variance, or an array whose trailing two axes are
        ``(n, n)``. Leading (batch) axes of ``x``, ``mean`` and ``cov``
        broadcast against each other.
      allow_singular: unsupported; must be ``None``.

    Returns:
      The log-density with the event axis reduced.

    Raises:
      NotImplementedError: if ``allow_singular`` is given.
      ValueError: if a non-scalar ``cov`` lacks trailing ``(n, n)`` axes.
    """
    if allow_singular is not None:
        raise NotImplementedError(
            "allow_singular argument of multivariate_normal.logpdf")
    x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
    if not mean.shape:
        # Univariate normal.
        return (-1 / 2 * jnp.square(x - mean) / cov - 1 / 2 *
                (np.log(2 * np.pi) + jnp.log(cov)))

    n = mean.shape[-1]
    if not np.shape(cov):
        # Scalar covariance: the same variance for every dimension.
        y = x - mean
        return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) / cov - n / 2 *
                (np.log(2 * np.pi) + jnp.log(cov)))

    if cov.ndim < 2 or cov.shape[-2:] != (n, n):
        raise ValueError(
            "multivariate_normal.logpdf got incompatible shapes")
    L = lax.linalg.cholesky(cov)
    # Solve y @ L^T = x - mean, so that y . y is the Mahalanobis form
    # (x - mean) cov^{-1} (x - mean)^T. The raw triangular_solve requires
    # the batch dims of L and x - mean to match exactly; vectorizing over the
    # core signature lets them broadcast instead.
    solve = jnp.vectorize(
        lambda chol, rhs: lax.linalg.triangular_solve(
            chol, rhs, lower=True, transpose_a=True),
        signature="(n,n),(n)->(n)")
    y = solve(L, x - mean)
    # log|cov| / 2 == sum(log(diag(L))).
    return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) -
            n / 2 * np.log(2 * np.pi) -
            jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
Пример #5
0
    def __init__(self,
                 points,
                 values,
                 method="linear",
                 bounds_error=False,
                 fill_value=nan):
        """Store the grid, data and options for regular-grid interpolation.

        Args:
          points: sequence of array-likes, one per grid dimension; there may
            not be more of them than ``values`` has dimensions. Stored as a
            tuple of arrays in ``self.grid``.
          values: array-like grid data, promoted to an inexact dtype.
          method: ``"linear"`` or ``"nearest"``.
          bounds_error: must be falsy; a truthy value is rejected.
          fill_value: ``None`` or an array-like whose dtype can be cast
            (``same_kind``) to the dtype of ``values``.

        Raises:
          ValueError: for an unknown ``method``, too many point arrays, or an
            incompatible ``fill_value``.
          NotImplementedError: if ``bounds_error`` is truthy.
        """
        if method not in ("linear", "nearest"):
            raise ValueError(f"method {method!r} is not defined")
        self.method = method
        self.bounds_error = bounds_error
        if self.bounds_error:
            raise NotImplementedError(
                "`bounds_error` takes no effect under JIT")

        _check_arraylike("RegularGridInterpolator", values)
        # Promote before reading ``ndim``: _check_arraylike may accept plain
        # Python scalars, which have no ``ndim`` attribute.
        values, = _promote_dtypes_inexact(values)
        if len(points) > values.ndim:
            ve = f"there are {len(points)} point arrays, but values has {values.ndim} dimensions"
            raise ValueError(ve)

        if fill_value is not None:
            _check_arraylike("RegularGridInterpolator", fill_value)
            fill_value = asarray(fill_value)
            if not can_cast(
                    fill_value.dtype, values.dtype, casting='same_kind'):
                ve = "fill_value must be either 'None' or of a type compatible with values"
                raise ValueError(ve)
        self.fill_value = fill_value

        # TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
        _check_arraylike("RegularGridInterpolator", *points)
        self.grid = tuple(asarray(p) for p in points)
        self.values = values
Пример #6
0
def _convolve_nd(in1, in2, mode, *, precision):
  """N-dimensional convolution of ``in1`` and ``in2`` via ``lax.conv_general_dilated``.

  Args:
    in1, in2: non-empty arrays of equal rank; one must be no larger than the
      other along every axis.
    mode: ``'full'``, ``'same'`` (output shaped like ``in1``) or ``'valid'``.
    precision: forwarded to ``lax.conv_general_dilated``.

  Returns:
    The convolution result.

  Raises:
    ValueError: for an unknown mode, rank mismatch, empty input, or inputs
      that are not ordered by size in every dimension.
  """
  if mode not in ["full", "same", "valid"]:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")
  if in1.ndim != in2.ndim:
    raise ValueError("in1 and in2 must have the same number of dimensions")
  if in1.size == 0 or in2.size == 0:
    raise ValueError(f"zero-size arrays not supported in convolutions, got shapes {in1.shape} and {in2.shape}.")
  in1, in2 = _promote_dtypes_inexact(in1, in2)

  dims = list(zip(in1.shape, in2.shape))
  in1_not_smaller = all(a >= b for a, b in dims)
  in1_not_larger = all(a <= b for a, b in dims)
  if not in1_not_smaller and not in1_not_larger:
    raise ValueError("One input must be smaller than the other in every dimension.")

  # Shape of the original second operand, needed to centre 'same' output.
  other_shape = in2.shape
  # The smaller operand always acts as the (flipped) kernel.
  signal, kernel = (in2, in1) if in1_not_larger else (in1, in2)
  kernel_shape = kernel.shape
  kernel = jnp.flip(kernel)

  if mode == 'full':
    padding = [(k - 1, k - 1) for k in kernel_shape]
  elif mode == 'same':
    padding = [(k - 1 - (o - 1) // 2, k - o + (o - 1) // 2)
               for k, o in zip(kernel_shape, other_shape)]
  else:  # 'valid'
    padding = [(0, 0)] * len(kernel_shape)

  window_strides = (1,) * len(kernel_shape)
  # Add singleton batch and feature dimensions for conv_general_dilated.
  out = lax.conv_general_dilated(signal[None, None], kernel[None, None],
                                 window_strides, padding, precision=precision)
  return out[0, 0]
Пример #7
0
def logpdf(x, mean, cov):
    """Log-density of a multivariate normal distribution.

    Args:
      x: evaluation points. When ``mean`` is not a scalar, the last axis of
        ``x`` indexes the n event dimensions.
      mean: scalar (univariate case) or array whose last axis has length n.
      cov: scalar variance, or an array whose trailing two axes are
        ``(n, n)``. Leading (batch) axes of ``x``, ``mean`` and ``cov``
        broadcast against each other.

    Returns:
      The log-density with the event axis reduced.

    Raises:
      ValueError: if a non-scalar ``cov`` lacks trailing ``(n, n)`` axes.
    """
    x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
    if not mean.shape:
        # Univariate normal.
        return (-1 / 2 * jnp.square(x - mean) / cov - 1 / 2 *
                (np.log(2 * np.pi) + jnp.log(cov)))

    n = mean.shape[-1]
    if not np.shape(cov):
        # Scalar covariance: the same variance for every dimension.
        y = x - mean
        return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) / cov - n / 2 *
                (np.log(2 * np.pi) + jnp.log(cov)))

    if cov.ndim < 2 or cov.shape[-2:] != (n, n):
        raise ValueError(
            "multivariate_normal.logpdf got incompatible shapes")
    L = cholesky(cov)
    # Solve y @ L^T = x - mean per (n, n)/(n) core, broadcasting batch dims.
    solve = jnp.vectorize(
        lambda chol, rhs: triangular_solve(
            chol, rhs, lower=True, transpose_a=True),
        signature="(n,n),(n)->(n)")
    y = solve(L, x - mean)
    # Take the diagonal over the trailing matrix axes and reduce only that
    # axis; the default diagonal()/sum() would mix in batch axes of cov.
    return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) -
            n / 2 * np.log(2 * np.pi) -
            jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
Пример #8
0
def logpdf(x, alpha):
  """Log-density of a Dirichlet distribution with concentration ``alpha``.

  Args:
    x: points whose axis 0 holds either ``len(alpha)`` entries or one fewer;
      in the latter case the missing last entry is filled in as
      ``1 - x.sum(0)``.
    alpha: one-dimensional concentration parameters.

  Returns:
    ``sum((alpha - 1) * log(x)) - (sum(gammaln(alpha)) - gammaln(sum(alpha)))``
    reduced over axis 0, or ``-inf`` wherever ``_is_simplex(x)`` is false.

  Raises:
    ValueError: if ``alpha`` is not 1-D or ``x`` has the wrong leading size.
  """
  x, alpha = _promote_dtypes_inexact(x, alpha)
  if alpha.ndim != 1:
    raise ValueError(
      f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  k = alpha.shape[0]
  if x.shape[0] not in (k, k - 1):
    raise ValueError(
      "`x` must have either the same number of entries as `alpha` "
      f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )
  one = lax._const(x, 1)
  if x.shape[0] != k:
    # Append the implied final coordinate.
    remainder = lax.sub(one, x.sum(0, keepdims=True))
    x = jnp.concatenate([x, remainder], axis=0)
  normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    # Broadcast alpha along the trailing (batch) axes of x.
    trailing = (1,) * (x.ndim - 1)
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + trailing, (0,))
  weighted_logs = xlogy(lax.sub(alpha, one), x)
  log_probs = lax.sub(jnp.sum(weighted_logs, axis=0), normalize_term)
  return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
Пример #9
0
def _spectral_helper(x,
                     y,
                     fs=1.0,
                     window='hann',
                     nperseg=None,
                     noverlap=None,
                     nfft=None,
                     detrend_type='constant',
                     return_onesided=True,
                     scaling='density',
                     axis=-1,
                     mode='psd',
                     boundary=None,
                     padded=False):
    """LAX-backend implementation of `scipy.signal._spectral_helper`.

    Unlike the original helper function, `y` can be None for explicitly
    indicating auto-spectral (non cross-spectral) computation.  In addition to
    this, `detrend` argument is renamed to `detrend_type` for avoiding internal
    name overlap.

    A callable ``detrend_type`` receives each batch of segments laid out as
    the caller's data was, i.e. with the segment samples on ``axis``, as in
    SciPy.

    Returns:
      ``(freqs, time, result)``: sample frequencies, segment times, and the
      windowed spectra with the frequency axis moved back to ``axis``.

    Raises:
      ValueError: for an unknown ``mode``, ``boundary`` or ``scaling``;
        non-positive ``nperseg``; ``nfft < nperseg``; ``noverlap >= nperseg``;
        two-argument use outside ``mode='psd'``; or ``x``/``y`` with
        different ranks or non-broadcastable outer shapes.
    """
    if mode not in ('psd', 'stft'):
        raise ValueError(f"Unknown value for mode {mode}, "
                         "must be one of: ('psd', 'stft')")

    def make_pad(mode, **kwargs):
        # Return a function padding `n` samples on both ends of `axis`.
        def pad(x, n, axis=-1):
            pad_width = [(0, 0) for unused_n in range(x.ndim)]
            pad_width[axis] = (n, n)
            return jnp.pad(x, pad_width, mode, **kwargs)

        return pad

    boundary_funcs = {
        'even': make_pad('reflect'),
        'odd': odd_ext,
        'constant': make_pad('edge'),
        'zeros': make_pad('constant', constant_values=0.0),
        None: lambda x, *args, **kwargs: x
    }

    # Check/ normalize inputs
    if boundary not in boundary_funcs:
        raise ValueError(f"Unknown boundary option '{boundary}', "
                         f"must be one of: {list(boundary_funcs.keys())}")

    axis = jax.core.concrete_or_error(operator.index, axis,
                                      "axis of windowed-FFT")
    # Keep the caller's (possibly negative) axis: a user-supplied detrend
    # callable expects data laid out for that axis, not the canonical one.
    user_axis = axis
    axis = canonicalize_axis(axis, x.ndim)

    if y is None:
        _check_arraylike('spectral_helper', x)
        x, = _promote_dtypes_inexact(x)
        outershape = tuple_delete(x.shape, axis)
    else:
        if mode != 'psd':
            raise ValueError(
                "two-argument mode is available only when mode=='psd'")
        _check_arraylike('spectral_helper', x, y)
        x, y = _promote_dtypes_inexact(x, y)
        if x.ndim != y.ndim:
            raise ValueError(
                f"two-arguments must have the same rank ({x.ndim} vs {y.ndim})."
            )
        # Check if we can broadcast the outer axes together
        try:
            outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                              tuple_delete(y.shape, axis))
        except ValueError as err:
            raise ValueError('x and y cannot be broadcast together.') from err

    result_dtype = dtypes._to_complex_dtype(x.dtype)
    freq_dtype = np.finfo(result_dtype).dtype

    if nperseg is not None:  # if specified by user
        nperseg = jax.core.concrete_or_error(int, nperseg,
                                             "nperseg of windowed-FFT")
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')
    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = signal_helper._triage_segments(window,
                                                  nperseg,
                                                  input_length=x.shape[axis],
                                                  dtype=result_dtype)

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = jax.core.concrete_or_error(int, noverlap,
                                              "noverlap of windowed-FFT")
    if nfft is None:
        nfft = nperseg
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

    # Special cases for size == 0
    if y is None:
        if x.size == 0:
            return jnp.zeros(x.shape, freq_dtype), jnp.zeros(
                x.shape, freq_dtype), jnp.zeros(x.shape, result_dtype)
    else:
        if x.size == 0 or y.size == 0:
            shape = tuple_insert(outershape,
                                 min([x.shape[axis], y.shape[axis]]), axis)
            return jnp.zeros(shape, freq_dtype), jnp.zeros(
                shape, freq_dtype), jnp.zeros(shape, result_dtype)

    # Move time-axis to the end
    x = jnp.moveaxis(x, axis, -1)
    if y is not None and y.ndim > 1:
        y = jnp.moveaxis(y, axis, -1)

    # Check if x and y are the same length, zero-pad if necessary
    if y is not None and x.shape[-1] != y.shape[-1]:
        if x.shape[-1] < y.shape[-1]:
            pad_shape = list(x.shape)
            pad_shape[-1] = y.shape[-1] - x.shape[-1]
            x = jnp.concatenate((x, jnp.zeros_like(x, shape=pad_shape)), -1)
        else:
            pad_shape = list(y.shape)
            pad_shape[-1] = x.shape[-1] - y.shape[-1]
            y = jnp.concatenate((y, jnp.zeros_like(y, shape=pad_shape)), -1)

    if nfft < nperseg:
        raise ValueError('nfft must be greater than or equal to nperseg.')
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Apply paddings
    if boundary is not None:
        ext_func = boundary_funcs[boundary]
        x = ext_func(x, nperseg // 2, axis=-1)
        if y is not None:
            y = ext_func(y, nperseg // 2, axis=-1)

    if padded:
        # Pad to integer number of windowed segments
        # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
        nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        x = jnp.concatenate(
            (x, jnp.zeros_like(x, shape=(*x.shape[:-1], nadd))), axis=-1)
        if y is not None:
            y = jnp.concatenate(
                (y, jnp.zeros_like(y, shape=(*y.shape[:-1], nadd))), axis=-1)

    # Handle detrending and window functions. Segments reach detrend_func
    # with their samples on the last axis (hence axis=-1 for the string case).
    if not detrend_type:
        detrend_func = lambda d: d
    elif not callable(detrend_type):
        detrend_func = partial(detrend, type=detrend_type, axis=-1)
    elif user_axis != -1:
        # Wrap this function so that it receives a shape that it could
        # reasonably expect to receive: move the sample axis to where the
        # caller's time axis was, apply the callable, then move it back.
        def detrend_func(d):
            d = jnp.moveaxis(d, -1, user_axis)
            d = detrend_type(d)
            return jnp.moveaxis(d, user_axis, -1)
    else:
        detrend_func = detrend_type

    # Determine scale
    if scaling == 'density':
        scale = 1.0 / (fs * (win * win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    else:
        raise ValueError(f'Unknown scaling: {scaling}')
    if mode == 'stft':
        scale = jnp.sqrt(scale)

    # Determine onesided/ two-sided
    if return_onesided:
        sides = 'onesided'
        if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
    else:
        sides = 'twosided'

    if sides == 'twosided':
        freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs).astype(freq_dtype)
    else:
        freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs).astype(freq_dtype)

    # Perform the windowed FFTs
    result = _fft_helper(x.astype(result_dtype), win, detrend_func, nperseg,
                         noverlap, nfft, sides)

    if y is not None:
        # All the same operations on the y data
        result_y = _fft_helper(y.astype(result_dtype), win, detrend_func,
                               nperseg, noverlap, nfft, sides)
        result = jnp.conjugate(result) * result_y
    elif mode == 'psd':
        result = jnp.conjugate(result) * result

    result *= scale

    if sides == 'onesided' and mode == 'psd':
        # Double all but DC (and, for even nfft, the unpaired Nyquist bin).
        end = None if nfft % 2 else -1
        result = result.at[..., 1:end].mul(2)

    time = jnp.arange(nperseg / 2,
                      x.shape[-1] - nperseg / 2 + 1,
                      nperseg - noverlap,
                      dtype=freq_dtype) / fs
    if boundary is not None:
        time -= (nperseg / 2) / fs

    result = result.astype(result_dtype)

    # All imaginary parts are zero anyways
    if y is None and mode != 'stft':
        result = result.real

    # Move frequency axis back to axis where the data came from
    result = jnp.moveaxis(result, -1, axis)

    return freqs, time, result