Code example #1
File: kde.py  Project: romanngg/jax
 def logpdf(self, x):
     _check_arraylike("logpdf", x)
     x = self._reshape_points(x)
     result = _gaussian_kernel_eval(True, self.dataset.T,
                                    self.weights[:, None], x.T, self.inv_cov)
     return result[:, 0]
Code example #2
File: kde.py  Project: romanngg/jax
 def evaluate(self, points):
     _check_arraylike("evaluate", points)
     points = self._reshape_points(points)
     result = _gaussian_kernel_eval(False, self.dataset.T,
                                    self.weights[:, None], points.T,
                                    self.inv_cov)
     return result[:, 0]
Code example #3
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
    jnp._check_arraylike(name, data, segment_ids)
    mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    num_segments = core.concrete_or_error(
        int, num_segments, "segment_sum() `num_segments` argument.")
    if num_segments is not None and num_segments < 0:
        raise ValueError("num_segments must be non-negative.")


    num_buckets = 1 if bucket_size is None \
                    else util.ceil_of_ratio(segment_ids.size, bucket_size)
    if num_buckets == 1:
        out = jnp.full((num_segments, ) + data.shape[1:],
                       _get_identity(scatter_op, dtype),
                       dtype=dtype)
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False,
                               mode=mode)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    assert reducer is not None
    out = jnp.full((num_buckets, num_segments) + data.shape[1:],
                   _get_identity(scatter_op, dtype),
                   dtype=dtype)
    out = _scatter_update(
        out,
        np.index_exp[lax.div(jnp.arange(segment_ids.shape[0]), bucket_size),
                     segment_ids[None, :]],
        data,
        scatter_op,
        indices_are_sorted,
        unique_indices,
        normalize_indices=False,
        mode=mode)
    return reducer(out, axis=0).astype(dtype)
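For context, this helper backs the public `jax.ops.segment_*` reductions. A minimal usage sketch of `jax.ops.segment_sum`, exercising the single-bucket path above (the sample values are made up):

import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 1, 2])

# Sum within each segment; a static `num_segments` keeps the output shape
# known at trace time, so the call also works under `jax.jit`.
totals = ops.segment_sum(data, segment_ids, num_segments=3)
# totals -> [3., 7., 5.]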
Code example #4
File: kde.py  Project: romanngg/jax
    def __init__(self, dataset, bw_method=None, weights=None):
        _check_arraylike("gaussian_kde", dataset)
        dataset = jnp.atleast_2d(dataset)
        if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
            raise NotImplementedError(
                "gaussian_kde does not support complex data")
        if not dataset.size > 1:
            raise ValueError("`dataset` input should have multiple elements.")

        d, n = dataset.shape
        if weights is not None:
            _check_arraylike("gaussian_kde", weights)
            dataset, weights = _promote_dtypes_inexact(dataset, weights)
            weights = jnp.atleast_1d(weights)
            weights /= jnp.sum(weights)
            if weights.ndim != 1:
                raise ValueError("`weights` input should be one-dimensional.")
            if len(weights) != n:
                raise ValueError("`weights` input should be of length n")
        else:
            dataset, = _promote_dtypes_inexact(dataset)
            weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

        self._setattr("dataset", dataset)
        self._setattr("weights", weights)
        neff = self._setattr("neff", 1 / jnp.sum(weights**2))

        bw_method = "scott" if bw_method is None else bw_method
        if bw_method == "scott":
            factor = jnp.power(neff, -1. / (d + 4))
        elif bw_method == "silverman":
            factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
        elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
            factor = bw_method
        elif callable(bw_method):
            factor = bw_method(self)
        else:
            raise ValueError(
                "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
            )

        data_covariance = jnp.atleast_2d(
            jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
        data_inv_cov = jnp.linalg.inv(data_covariance)
        covariance = data_covariance * factor**2
        inv_cov = data_inv_cov / factor**2
        self._setattr("covariance", covariance)
        self._setattr("inv_cov", inv_cov)
Code example #5
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None) -> Array:
    jnp._check_arraylike(name, data, segment_ids)
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    num_segments = core.concrete_or_error(
        int, num_segments, "segment_sum() `num_segments` argument.")
    if num_segments is not None and num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    out = jnp.full((num_segments, ) + data.shape[1:],
                   _get_identity(scatter_op, dtype),
                   dtype=dtype)

    num_buckets = 1 if bucket_size is None \
                    else util.ceil_of_ratio(segment_ids.size, bucket_size)
    if num_buckets == 1:
        return _scatter_update(out,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    assert reducer is not None
    outs = []
    for sub_data, sub_segment_ids in zip(
            jnp.array_split(data, num_buckets),
            jnp.array_split(segment_ids, num_buckets)):
        outs.append(
            _segment_update(name, sub_data, sub_segment_ids, scatter_op,
                            num_segments, indices_are_sorted, unique_indices))
    return reducer(jnp.stack(outs), axis=0).astype(dtype)
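This older variant implements the bucketized path by splitting the inputs with `jnp.array_split` and recursing. Roughly the same behaviour can be triggered from the public API by passing `bucket_size` (the values below are illustrative):

import jax.numpy as jnp
from jax import ops

data = jnp.arange(8, dtype=jnp.float32)
segment_ids = jnp.array([0, 0, 1, 1, 2, 2, 3, 3])

# With `bucket_size` set, the input is split into buckets, each bucket is
# reduced on its own, and the per-bucket partials are combined at the end.
out = ops.segment_sum(data, segment_ids, num_segments=4, bucket_size=4)
# out -> [1., 5., 9., 13.]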
Code example #6
File: signal.py  Project: frederikwilde/jax
def _overlap_and_add(x, step_size):
    """Utility function compatible with tf.signal.overlap_and_add.

  Args:
    x: An array with `(..., frames, frame_length)`-shape.
    step_size: An integer denoting overlap offsets. Must be less than
      `frame_length`.

  Returns:
    An array with `(..., output_size)`-shape containing overlapped signal.
  """
    _check_arraylike("_overlap_and_add", x)
    step_size = jax.core.concrete_or_error(int, step_size,
                                           "step_size for overlap_and_add")
    if x.ndim < 2:
        raise ValueError('Input must have (..., frames, frame_length) shape.')

    *batch_shape, nframes, segment_len = x.shape
    flat_batchsize = np.prod(batch_shape, dtype=np.int64)
    x = x.reshape((flat_batchsize, nframes, segment_len))
    output_size = step_size * (nframes - 1) + segment_len
    nstep_per_segment = 1 + (segment_len - 1) // step_size

    # Here, we use shorter notation for axes.
    # B: batch_size, N: nframes, S: nstep_per_segment,
    # T: segment_len divided by S

    padded_segment_len = nstep_per_segment * step_size
    x = jnp.pad(x, ((0, 0), (0, 0), (0, padded_segment_len - segment_len)))
    x = x.reshape((flat_batchsize, nframes, nstep_per_segment, step_size))

    # To obtain shifted signals, this routine reinterprets the flattened array
    # with a shrunken axis.  With appropriate truncation/padding, this operation
    # pushes the last padded elements of the previous row to the head of the
    # current row.
    # See the implementation of `overlap_and_add` in TensorFlow for details.
    x = x.transpose((0, 2, 1, 3))  # x: (B, S, N, T)
    x = jnp.pad(x, ((0, 0), (0, 0), (0, nframes), (0, 0)))  # x: (B, S, N*2, T)
    shrinked = x.shape[2] - 1
    x = x.reshape((flat_batchsize, -1))
    x = x[:, :(nstep_per_segment * shrinked * step_size)]
    x = x.reshape((flat_batchsize, nstep_per_segment, shrinked * step_size))

    # Finally, sum shifted segments, and truncate results to the output_size.
    x = x.sum(axis=1)[:, :output_size]
    return x.reshape(tuple(batch_shape) + (-1, ))
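For intuition, overlap-and-add places frame `i` at offset `i * step_size` in the output and sums where frames overlap. A naive reference sketch of the same computation (the helper name and the sample frames are made up; this is not the JAX implementation above):

import jax.numpy as jnp

def naive_overlap_and_add(frames, step_size):
    # frames: (nframes, frame_length) -> (step_size * (nframes - 1) + frame_length,)
    nframes, frame_length = frames.shape
    out = jnp.zeros(step_size * (nframes - 1) + frame_length, dtype=frames.dtype)
    for i in range(nframes):
        start = i * step_size
        out = out.at[start:start + frame_length].add(frames[i])
    return out

frames = jnp.array([[1.0, 1.0, 1.0, 1.0],
                    [2.0, 2.0, 2.0, 2.0],
                    [3.0, 3.0, 3.0, 3.0]])
print(naive_overlap_and_add(frames, step_size=2))
# [1. 1. 3. 3. 5. 5. 3. 3.]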
Code example #7
File: interpolate.py  Project: wayfeng/jax
    def __init__(self,
                 points,
                 values,
                 method="linear",
                 bounds_error=False,
                 fill_value=nan):
        if method not in ("linear", "nearest"):
            raise ValueError(f"method {method!r} is not defined")
        self.method = method
        self.bounds_error = bounds_error
        if self.bounds_error:
            raise NotImplementedError(
                "`bounds_error` takes no effect under JIT")

        _check_arraylike("RegularGridInterpolator", values)
        if len(points) > values.ndim:
            ve = f"there are {len(points)} point arrays, but values has {values.ndim} dimensions"
            raise ValueError(ve)

        values, = _promote_dtypes_inexact(values)

        if fill_value is not None:
            _check_arraylike("RegularGridInterpolator", fill_value)
            fill_value = asarray(fill_value)
            if not can_cast(
                    fill_value.dtype, values.dtype, casting='same_kind'):
                ve = "fill_value must be either 'None' or of a type compatible with values"
                raise ValueError(ve)
        self.fill_value = fill_value

        # TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
        _check_arraylike("RegularGridInterpolator", *points)
        self.grid = tuple(asarray(p) for p in points)
        self.values = values
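A minimal usage sketch of the public `jax.scipy.interpolate.RegularGridInterpolator` built on this constructor (the grid and query points are made up):

import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

# f(x, y) = x + 2*y sampled on a small regular grid (made-up data).
xs = jnp.array([0.0, 1.0, 2.0])
ys = jnp.array([0.0, 1.0])
values = xs[:, None] + 2.0 * ys[None, :]

interp = RegularGridInterpolator((xs, ys), values, method="linear")

# Query points have shape (..., ndim); out-of-bounds points get `fill_value`.
pts = jnp.array([[0.5, 0.5], [1.5, 0.25], [5.0, 0.0]])
print(interp(pts))  # approximately [1.5, 2.0, nan]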
Code example #8
File: interpolate.py  Project: wayfeng/jax
def _ndim_coords_from_arrays(points, ndim=None):
    """Convert a tuple of coordinate arrays to a (..., ndim)-shaped array."""
    if isinstance(points, tuple) and len(points) == 1:
        # handle argument tuple
        points = points[0]
    if isinstance(points, tuple):
        p = broadcast_arrays(*points)
        for p_other in p[1:]:
            if p_other.shape != p[0].shape:
                raise ValueError(
                    "coordinate arrays do not have the same shape")
        points = empty(p[0].shape + (len(points), ), dtype=float)
        for j, item in enumerate(p):
            points = points.at[..., j].set(item)
    else:
        _check_arraylike("_ndim_coords_from_arrays", points)
        points = asarray(points)  # SciPy: asanyarray(points)
        if points.ndim == 1:
            if ndim is None:
                points = points.reshape(-1, 1)
            else:
                points = points.reshape(-1, ndim)
    return points
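The tuple branch is essentially a broadcast-and-stack along a new trailing axis; a rough equivalent with public `jax.numpy` ops (the arrays are chosen arbitrarily):

import jax.numpy as jnp

x = jnp.array([0.0, 1.0, 2.0])          # shape (3,)
y = jnp.array([[10.0], [20.0]])         # shape (2, 1)

# Broadcast the coordinate arrays to a common shape, then stack them along a
# new trailing axis so the last dimension indexes the coordinate.
coords = jnp.stack(jnp.broadcast_arrays(x, y), axis=-1)
print(coords.shape)  # (2, 3, 2)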
Code example #9
File: signal.py  Project: GJBoth/jax
def _spectral_helper(x,
                     y,
                     fs=1.0,
                     window='hann',
                     nperseg=None,
                     noverlap=None,
                     nfft=None,
                     detrend_type='constant',
                     return_onesided=True,
                     scaling='density',
                     axis=-1,
                     mode='psd',
                     boundary=None,
                     padded=False):
    """LAX-backend implementation of `scipy.signal._spectral_helper`.

  Unlike the original helper function, `y` can be None for explicitly
  indicating auto-spectral (non cross-spectral) computation.  In addition to
  this, `detrend` argument is renamed to `detrend_type` for avoiding internal
  name overlap.
  """
    if mode not in ('psd', 'stft'):
        raise ValueError(f"Unknown value for mode {mode}, "
                         "must be one of: ('psd', 'stft')")

    def make_pad(mode, **kwargs):
        def pad(x, n, axis=-1):
            pad_width = [(0, 0) for unused_n in range(x.ndim)]
            pad_width[axis] = (n, n)
            return jnp.pad(x, pad_width, mode, **kwargs)

        return pad

    boundary_funcs = {
        'even': make_pad('reflect'),
        'odd': odd_ext,
        'constant': make_pad('edge'),
        'zeros': make_pad('constant', constant_values=0.0),
        None: lambda x, *args, **kwargs: x
    }

    # Check/ normalize inputs
    if boundary not in boundary_funcs:
        raise ValueError(f"Unknown boundary option '{boundary}', "
                         f"must be one of: {list(boundary_funcs.keys())}")

    axis = jax.core.concrete_or_error(operator.index, axis,
                                      "axis of windowed-FFT")
    axis = canonicalize_axis(axis, x.ndim)

    if nperseg is not None:  # if specified by user
        nperseg = jax.core.concrete_or_error(int, nperseg,
                                             "nperseg of windowed-FFT")
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')
    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = signal_helper._triage_segments(window,
                                                  nperseg,
                                                  input_length=x.shape[axis])

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = jax.core.concrete_or_error(int, noverlap,
                                              "noverlap of windowed-FFT")
    if nfft is None:
        nfft = nperseg
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

    _check_arraylike("_spectral_helper", x)
    x = jnp.asarray(x)

    if y is None:
        outdtype = jax.dtypes.canonicalize_dtype(
            np.result_type(x, np.complex64))
    else:
        _check_arraylike("_spectral_helper", y)
        y = jnp.asarray(y)
        outdtype = jax.dtypes.canonicalize_dtype(
            np.result_type(x, y, np.complex64))
        if mode != 'psd':
            raise ValueError(
                "two-argument mode is available only when mode=='psd'")
        if x.ndim != y.ndim:
            raise ValueError(
                "two-arguments must have the same rank ({x.ndim} vs {y.ndim})."
            )

        # Check if we can broadcast the outer axes together
        try:
            outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                              tuple_delete(y.shape, axis))
        except ValueError as e:
            raise ValueError('x and y cannot be broadcast together.') from e

    # Special cases for size == 0
    if y is None:
        if x.size == 0:
            return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
    else:
        if x.size == 0 or y.size == 0:
            outshape = tuple_insert(outershape,
                                    min([x.shape[axis], y.shape[axis]]), axis)
            emptyout = jnp.zeros(outshape)
            return emptyout, emptyout, emptyout

    # Move time-axis to the end
    if x.ndim > 1:
        if axis != -1:
            x = jnp.moveaxis(x, axis, -1)
            if y is not None and y.ndim > 1:
                y = jnp.moveaxis(y, axis, -1)

    # Check if x and y are the same length, zero-pad if necessary
    if y is not None:
        if x.shape[-1] != y.shape[-1]:
            if x.shape[-1] < y.shape[-1]:
                pad_shape = list(x.shape)
                pad_shape[-1] = y.shape[-1] - x.shape[-1]
                x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
            else:
                pad_shape = list(y.shape)
                pad_shape[-1] = x.shape[-1] - y.shape[-1]
                y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

    if nfft < nperseg:
        raise ValueError('nfft must be greater than or equal to nperseg.')
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Apply paddings
    if boundary is not None:
        ext_func = boundary_funcs[boundary]
        x = ext_func(x, nperseg // 2, axis=-1)
        if y is not None:
            y = ext_func(y, nperseg // 2, axis=-1)

    if padded:
        # Pad to integer number of windowed segments
        # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
        nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        zeros_shape = list(x.shape[:-1]) + [nadd]
        x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
        if y is not None:
            zeros_shape = list(y.shape[:-1]) + [nadd]
            y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

    # Handle detrending and window functions
    if not detrend_type:

        def detrend_func(d):
            return d
    elif not hasattr(detrend_type, '__call__'):

        def detrend_func(d):
            return detrend(d, type=detrend_type, axis=-1)
    elif axis != -1:
        # Wrap this function so that it receives a shape that it could
        # reasonably expect to receive.
        def detrend_func(d):
            d = jnp.moveaxis(d, axis, -1)
            d = detrend_type(d)
            return jnp.moveaxis(d, -1, axis)
    else:
        detrend_func = detrend_type

    if np.result_type(win, np.complex64) != outdtype:
        win = win.astype(outdtype)

    # Determine scale
    if scaling == 'density':
        scale = 1.0 / (fs * (win * win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    else:
        raise ValueError(f'Unknown scaling: {scaling}')
    if mode == 'stft':
        scale = jnp.sqrt(scale)

    # Determine onesided/ two-sided
    if return_onesided:
        sides = 'onesided'
        if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
    else:
        sides = 'twosided'

    if sides == 'twosided':
        freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs)
    elif sides == 'onesided':
        freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs)

    # Perform the windowed FFTs
    result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

    if y is not None:
        # All the same operations on the y data
        result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                               sides)
        result = jnp.conjugate(result) * result_y
    elif mode == 'psd':
        result = jnp.conjugate(result) * result

    result *= scale

    if sides == 'onesided' and mode == 'psd':
        end = None if nfft % 2 else -1
        result = result.at[..., 1:end].mul(2)

    time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                      nperseg - noverlap) / fs
    if boundary is not None:
        time -= (nperseg / 2) / fs

    result = result.astype(outdtype)

    # All imaginary parts are zero anyways
    if y is None and mode != 'stft':
        result = result.real

    # Move frequency axis back to axis where the data came from
    result = jnp.moveaxis(result, -1, axis)

    return freqs, time, result
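`_spectral_helper` underlies the public spectral routines in `jax.scipy.signal`. A hedged usage sketch of those wrappers (the signal and parameters are illustrative):

import jax
import jax.numpy as jnp
from jax.scipy import signal

# A noisy 50 Hz tone, purely for illustration.
fs = 1000.0
t = jnp.arange(0.0, 1.0, 1.0 / fs)
x = jnp.sin(2 * jnp.pi * 50.0 * t)
x = x + 0.1 * jax.random.normal(jax.random.PRNGKey(0), t.shape)

# Power spectral density (the 'psd' mode of the helper above).
f, pxx = signal.welch(x, fs=fs, nperseg=256)

# Cross-spectral density and STFT go through the same helper.
f_csd, pxy = signal.csd(x, x, fs=fs, nperseg=256)
f_stft, times, zxx = signal.stft(x, fs=fs, nperseg=256)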
Code example #10
File: signal.py  Project: frederikwilde/jax
def istft(Zxx,
          fs=1.0,
          window='hann',
          nperseg=None,
          noverlap=None,
          nfft=None,
          input_onesided=True,
          boundary=True,
          time_axis=-1,
          freq_axis=-2):
    # Input validation
    _check_arraylike("istft", Zxx)
    if Zxx.ndim < 2:
        raise ValueError('Input stft must be at least 2d!')
    freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
    time_axis = canonicalize_axis(time_axis, Zxx.ndim)
    if freq_axis == time_axis:
        raise ValueError('Must specify differing time and frequency axes!')

    Zxx = jnp.asarray(Zxx,
                      dtype=jax.dtypes.canonicalize_dtype(
                          np.result_type(Zxx, np.complex64)))

    n_default = (2 * (Zxx.shape[freq_axis] - 1)
                 if input_onesided else Zxx.shape[freq_axis])

    nperseg = jax.core.concrete_or_error(int, nperseg or n_default,
                                         "nperseg: segment length of STFT")
    if nperseg < 1:
        raise ValueError('nperseg must be a positive integer')

    if nfft is None:
        nfft = n_default
        if input_onesided and nperseg == n_default + 1:
            nfft += 1  # Odd nperseg, no FFT padding
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
    if nfft < nperseg:
        raise ValueError(
            f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

    noverlap = jax.core.concrete_or_error(int, noverlap or nperseg // 2,
                                          "noverlap of STFT")
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Rearrange axes if necessary
    if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
        outer_idxs = tuple(idx for idx in range(Zxx.ndim)
                           if idx not in {time_axis, freq_axis})
        Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

    # Perform IFFT
    ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft
    # xsubs: [..., T, N], N is the number of frames, T is the frame length.
    xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

    # Get window as array
    if isinstance(window, (str, tuple)):
        win = osp_signal.get_window(window, nperseg)
        win = jnp.asarray(win)
    else:
        win = jnp.asarray(window)
        if len(win.shape) != 1:
            raise ValueError('window must be 1-D')
        if win.shape[0] != nperseg:
            raise ValueError('window must have length of {0}'.format(nperseg))
    win = win.astype(xsubs.dtype)

    xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

    # make win broadcastable over xsubs
    win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1, ))
    x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
    win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
    norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

    # Remove extension points
    if boundary:
        x = x[..., nperseg // 2:-(nperseg // 2)]
        norm = norm[..., nperseg // 2:-(nperseg // 2)]
    x /= jnp.where(norm > 1e-10, norm, 1.0)

    # Put axes back
    if x.ndim > 1:
        if time_axis != Zxx.ndim - 1:
            if freq_axis < time_axis:
                time_axis -= 1
            x = jnp.moveaxis(x, -1, time_axis)

    time = jnp.arange(x.shape[0]) / fs
    return time, x
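A round-trip sketch pairing `jax.scipy.signal.stft` with the `istft` above (the window/overlap choices are illustrative; with a COLA-compliant window such as the default Hann at 50% overlap the reconstruction error should be small):

import jax
import jax.numpy as jnp
from jax.scipy import signal

fs = 1000.0
x = jax.random.normal(jax.random.PRNGKey(0), (2048,))

# Analyse with stft, then resynthesise with istft using the same parameters.
f, t, zxx = signal.stft(x, fs=fs, nperseg=256, noverlap=128)
t_rec, x_rec = signal.istft(zxx, fs=fs, nperseg=256, noverlap=128)

# Compare over the common length; the residual should be close to zero.
n = min(x.shape[0], x_rec.shape[0])
print(jnp.max(jnp.abs(x_rec[:n] - x[:n])))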