Example #1
0
  def __getitem__(self, key):
    """Concatenate the items of ``key``, numpy ``r_``/``c_`` style.

    ``key`` may begin with a string directive:

    * ``"r"`` / ``"c"``: if the concatenated result is 1-D, a new axis is
      inserted at position 0 (``"r"``) or 1 (``"c"``) via ``expand_dims``.
    * ``"axis[,ndmin[,trans1d]]"``: comma-separated integers overriding
      ``self.axis``, ``self.ndmin`` and ``self.trans1d``. Omitted trailing
      fields keep those defaults; anything after the third field is ignored.

    The remaining items are converted one by one: slices become 1-D grids
    via ``_make_1d_grid_from_slice``, everything else is used as-is. Each
    result is passed through ``array(..., ndmin=ndmin)``, optionally
    transposed according to ``trans1d``, and the results are concatenated
    along ``axis``.

    Raises:
      ValueError: if a numeric directive's first three fields are not
        integers, or if a string appears anywhere but the first position.
    """
    if not isinstance(key, tuple):
      key = (key,)

    # [axis, ndmin, trans1d, matrix]; matrix == -1 means "no r/c directive".
    params = [self.axis, self.ndmin, self.trans1d, -1]

    # `key and ...` avoids an IndexError on an empty key such as `()`.
    if key and isinstance(key[0], str):
      # split off the directive
      directive, *key = key  # pytype: disable=bad-unpacking
      # check two special cases: matrix directives
      if directive == "r":
        params[-1] = 0
      elif directive == "c":
        params[-1] = 1
      else:
        # Keep at most the first three fields (axis, ndmin, trans1d) and pad
        # missing ones from the defaults. The matrix flag (params[3]) cannot
        # be set by a numeric directive, so it always keeps its default.
        fields = directive.split(",")[:3]
        fields += params[len(fields):]
        try:
          params = [int(field) for field in fields]
        except ValueError as err:
          raise ValueError(
            f"could not understand directive {directive!r}"
          ) from err

    axis, ndmin, trans1d, matrix = params

    output = []
    for item in key:
      if isinstance(item, slice):
        newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
      elif isinstance(item, str):
        raise ValueError("string directive must be placed at the beginning")
      else:
        newobj = item

      newobj = array(newobj, copy=False, ndmin=ndmin)

      # When ndmin padded the item with extra leading axes, cyclically rotate
      # the axis order so that trans1d controls where the item's original
      # axes end up.
      if trans1d != -1 and ndmin - np.ndim(item) > 0:
        shape_obj = list(range(ndmin))
        # Calculate number of left shifts, with overflow protection by mod
        num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
        shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])

        newobj = transpose(newobj, shape_obj)

      output.append(newobj)

    res = concatenate(tuple(output), axis=axis)

    if matrix != -1 and res.ndim == 1:
      # insert 2nd dim at axis 0 or 1
      res = expand_dims(res, matrix)

    return res
Example #2
0
def istft(Zxx,
          fs=1.0,
          window='hann',
          nperseg=None,
          noverlap=None,
          nfft=None,
          input_onesided=True,
          boundary=True,
          time_axis=-1,
          freq_axis=-2):
    """Inverse short-time Fourier transform, following ``scipy.signal.istft``.

    Args:
      Zxx: STFT array with at least 2 dimensions.
      fs: sampling frequency; the returned time vector is divided by it.
      window: a window name or tuple (passed to ``osp_signal.get_window``),
        or an explicit 1-D array of length ``nperseg``.
      nperseg: segment length. Defaults to ``2 * (F - 1)`` if
        ``input_onesided`` else ``F``, where ``F = Zxx.shape[freq_axis]``.
      noverlap: number of overlapping samples between segments. Defaults to
        ``nperseg // 2``; an explicit ``0`` means no overlap.
      nfft: FFT length. Defaults to the length implied by ``Zxx`` (plus one
        when ``input_onesided`` and ``nperseg`` is that length plus one).
      input_onesided: if True invert with ``irfft``, otherwise with ``ifft``.
      boundary: if True, drop ``nperseg // 2`` samples from each end of the
        reconstructed signal.
      time_axis: axis of ``Zxx`` indexing segments (frames).
      freq_axis: axis of ``Zxx`` indexing frequencies.

    Returns:
      ``(time, x)``, where ``x`` is the reconstructed signal (its time axis
      moved back to ``time_axis`` position when ``x`` is multi-dimensional)
      and ``time = arange(x.shape[0]) / fs``.

    Raises:
      ValueError: if ``Zxx`` has fewer than 2 dims, the axes coincide,
        ``nperseg < 1``, ``nfft < nperseg``, ``noverlap >= nperseg``, or an
        explicit window is not 1-D of length ``nperseg``.
    """
    # Input validation
    _check_arraylike("istft", Zxx)
    if Zxx.ndim < 2:
        raise ValueError('Input stft must be at least 2d!')
    freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
    time_axis = canonicalize_axis(time_axis, Zxx.ndim)
    if freq_axis == time_axis:
        raise ValueError('Must specify differing time and frequency axes!')

    Zxx = jnp.asarray(Zxx,
                      dtype=jax.dtypes.canonicalize_dtype(
                          np.result_type(Zxx, np.complex64)))

    n_default = (2 * (Zxx.shape[freq_axis] - 1)
                 if input_onesided else Zxx.shape[freq_axis])

    # Compare against None (not truthiness) so an explicit 0 is rejected
    # below instead of silently replaced by the default.
    if nperseg is None:
        nperseg = n_default
    nperseg = jax.core.concrete_or_error(int, nperseg,
                                         "nperseg: segment length of STFT")
    if nperseg < 1:
        raise ValueError('nperseg must be a positive integer')

    if nfft is None:
        nfft = n_default
        if input_onesided and nperseg == n_default + 1:
            nfft += 1  # Odd nperseg, no FFT padding
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
    if nfft < nperseg:
        raise ValueError(
            f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

    # Compare against None so that noverlap=0 (no overlap) is honored rather
    # than being treated as "use the default".
    if noverlap is None:
        noverlap = nperseg // 2
    noverlap = jax.core.concrete_or_error(int, noverlap, "noverlap of STFT")
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Rearrange axes if necessary: move freq and time to the last two axes.
    if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
        outer_idxs = tuple(idx for idx in range(Zxx.ndim)
                           if idx not in {time_axis, freq_axis})
        Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

    # Perform IFFT
    ifunc = jnp.fft.irfft if input_onesided else jnp.fft.ifft
    # xsubs: [..., T, N], N is the number of frames, T is the frame length.
    xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

    # Get window as array
    if isinstance(window, (str, tuple)):
        win = osp_signal.get_window(window, nperseg)
        win = jnp.asarray(win)
    else:
        win = jnp.asarray(window)
        if len(win.shape) != 1:
            raise ValueError('window must be 1-D')
        if win.shape[0] != nperseg:
            raise ValueError(f'window must have length of {nperseg}')
    win = win.astype(xsubs.dtype)

    xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

    # make win broadcastable over xsubs: [1, ..., 1, T, 1]
    win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1, ))
    x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
    # Overlap-added squared window, used to normalize the reconstruction.
    win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
    norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

    # Remove extension points. Explicit end indices are used because the
    # slice `pad:-pad` would be empty when pad == 0 (i.e. nperseg == 1).
    if boundary:
        pad = nperseg // 2
        x = x[..., pad:x.shape[-1] - pad]
        norm = norm[..., pad:norm.shape[-1] - pad]
    # Avoid dividing by (near-)zero normalization values.
    x /= jnp.where(norm > 1e-10, norm, 1.0)

    # Put axes back. x has no frequency axis, so time_axis shifts down by one
    # when the frequency axis preceded it.
    if x.ndim > 1:
        if time_axis != Zxx.ndim - 1:
            if freq_axis < time_axis:
                time_axis -= 1
            x = jnp.moveaxis(x, -1, time_axis)

    # NOTE(review): like scipy.signal.istft this uses x.shape[0], which is
    # the signal length only when x is 1-D; for batched input it is a
    # leading batch dimension. Kept for parity with scipy -- confirm intent.
    time = jnp.arange(x.shape[0]) / fs
    return time, x