Example #1
0
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length,
                 window_fn=functools.partial(window_ops.hann_window,
                                             periodic=True),
                 name=None):
    """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.

    Every frame is inverted with a real inverse FFT of length `fft_length`,
    truncated to its first `frame_length` samples, optionally multiplied by
    the window produced by `window_fn`, and the frames are then overlap-added
    with a hop of `frame_step` samples. Implemented with GPU-compatible ops
    and supports gradients.

    NOTE(review): this implementation only truncates and never zero-pads, so
    it appears to assume `frame_length <= fft_length`. With a larger
    `frame_length` each frame keeps only `fft_length` samples and will not
    line up with a `frame_length`-sample window -- confirm callers never do
    this.

    Args:
      stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT
        bins representing a batch of `fft_length`-point STFTs, where
        `fft_unique_bins` is `fft_length // 2 + 1`.
      frame_length: An integer scalar `Tensor`. The window length in samples.
      frame_step: An integer scalar `Tensor`. The number of samples to step.
      fft_length: An integer scalar `Tensor`. The size of the FFT that
        produced `stfts`.
      window_fn: A callable that takes a window length and a `dtype` keyword
        argument and returns a `[window_length]` `Tensor` of samples in the
        provided datatype. If `None`, no windowing is applied.
      name: An optional name for the operation.

    Returns:
      A `[..., samples]` `Tensor` of real-valued signals (`float32` for
      `complex64` input) holding the inverse STFT of each STFT in `stfts`.

    Raises:
      ValueError: If `stfts` is not at least rank 2, or if any of
        `frame_length`, `frame_step` or `fft_length` is not scalar.

    [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
    """
    def _scalar_tensor(value, tensor_name):
        # Convert `value` to a Tensor and statically require it to be rank 0.
        tensor = ops.convert_to_tensor(value, name=tensor_name)
        tensor.shape.assert_has_rank(0)
        return tensor

    with ops.name_scope(name, 'inverse_stft', [stfts]):
        stfts = ops.convert_to_tensor(stfts, name='stfts')
        stfts.shape.with_rank_at_least(2)
        frame_length = _scalar_tensor(frame_length, 'frame_length')
        frame_step = _scalar_tensor(frame_step, 'frame_step')
        fft_length = _scalar_tensor(fft_length, 'fft_length')

        # Back to the time domain, keeping the leading `frame_length` samples
        # of every frame.
        frames = spectral_ops.irfft(stfts, [fft_length])
        frames = frames[..., :frame_length]

        # Optionally window, then overlap-add the two innermost dimensions
        # into a single [samples] dimension.
        if window_fn is not None:
            frames = frames * window_fn(frame_length,
                                        dtype=stfts.dtype.real_dtype)
        return reconstruction_ops.overlap_and_add(frames, frame_step)
Example #2
0
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length,
                 window_fn=functools.partial(window_ops.hann_window,
                                             periodic=True),
                 name=None):
  """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.

  The STFT frames are mapped back to the time domain with an inverse real FFT
  of length `fft_length`, cut to `frame_length` samples, optionally windowed,
  and overlap-added with a hop of `frame_step` samples. Implemented with
  GPU-compatible ops and supports gradients.

  NOTE(review): frames are only truncated, never zero-padded, so this appears
  to assume `frame_length <= fft_length`; verify against callers.

  Args:
    stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT bins
      representing a batch of `fft_length`-point STFTs, where
      `fft_unique_bins` is `fft_length // 2 + 1`.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT that produced
      `stfts`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If `None`, no windowing is applied.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of real-valued signals (`float32` for
    `complex64` input) holding the inverse STFT of each STFT in `stfts`.

  Raises:
    ValueError: If `stfts` is not at least rank 2, or if any of
      `frame_length`, `frame_step` or `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'inverse_stft', [stfts]):
    stfts = ops.convert_to_tensor(stfts, name='stfts')
    stfts.shape.with_rank_at_least(2)

    # Convert and rank-check the three scalar arguments, in signature order.
    scalars = []
    for value, arg_name in ((frame_length, 'frame_length'),
                            (frame_step, 'frame_step'),
                            (fft_length, 'fft_length')):
      tensor = ops.convert_to_tensor(value, name=arg_name)
      tensor.shape.assert_has_rank(0)
      scalars.append(tensor)
    frame_length, frame_step, fft_length = scalars

    # Inverse real FFT of length `fft_length`, trimmed to `frame_length`.
    time_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length]

    if window_fn is None:
      windowed = time_frames
    else:
      windowed = time_frames * window_fn(
          frame_length, dtype=stfts.dtype.real_dtype)

    # Collapse the two innermost dimensions into a single [samples] dimension.
    return reconstruction_ops.overlap_and_add(windowed, frame_step)
Example #3
0
 def _tf_fn(x):
   """Inverse real FFT of `x` with an output length of 2 * (bins - 1).

   That is the length numpy.fft.irfft uses by default. The last dimension of
   `x` must be statically known, otherwise `.value` is None and the
   arithmetic raises TypeError.
   """
   num_bins = x.shape[-1].value
   return spectral_ops.irfft(x, fft_length=[2 * (num_bins - 1)])
Example #4
0
 def _tf_fn(x):
     # Output length 2 * (bins - 1), numpy.fft.irfft's default `n`. The last
     # dimension of `x` must be statically known for `.value` to be an int.
     fft_size = 2 * (x.shape[-1].value - 1)
     return spectral_ops.irfft(x, fft_length=[fft_size])
Example #5
0
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length=None,
                 window_fn=functools.partial(window_ops.hann_window,
                                             periodic=True),
                 name=None):
  """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.

  To reconstruct the original waveform, `inverse_stft` should be given a
  window function that is complementary to the one used by the forward STFT.
  Such a window function can be constructed with
  `tf.contrib.signal.inverse_stft_window_fn`.

  Example:

  ```python
  frame_length = 400
  frame_step = 160
  waveform = tf.placeholder(dtype=tf.float32, shape=[1000])
  stft = tf.contrib.signal.stft(waveform, frame_length, frame_step)
  inverse_stft = tf.contrib.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.contrib.signal.inverse_stft_window_fn(frame_step))
  ```

  If a custom `window_fn` was used in `stft`, it must also be passed to
  `inverse_stft_window_fn`:

  ```python
  frame_length = 400
  frame_step = 160
  window_fn = functools.partial(window_ops.hamming_window, periodic=True)
  waveform = tf.placeholder(dtype=tf.float32, shape=[1000])
  stft = tf.contrib.signal.stft(
      waveform, frame_length, frame_step, window_fn=window_fn)
  inverse_stft = tf.contrib.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.contrib.signal.inverse_stft_window_fn(
          frame_step, forward_window_fn=window_fn))
  ```

  Each frame is inverted with a real inverse FFT of length `fft_length`. It is
  then zero-padded or truncated to exactly `frame_length` samples, so
  `frame_length` may be larger or smaller than `fft_length`. Implemented with
  GPU-compatible ops and supports gradients.

  Args:
    stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT bins
      representing a batch of `fft_length`-point STFTs, where
      `fft_unique_bins` is `fft_length // 2 + 1`.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT that produced
      `stfts`. If not provided, uses the smallest power of 2 enclosing
      `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If `None`, no windowing is applied.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of real-valued signals (`float32` for
    `complex64` input) holding the inverse STFT of each STFT in `stfts`.

  Raises:
    ValueError: If `stfts` is not at least rank 2, or if any of
      `frame_length`, `frame_step` or `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  def _as_scalar(value, tensor_name):
    # Convert `value` to a Tensor and statically require it to be rank 0.
    tensor = ops.convert_to_tensor(value, name=tensor_name)
    tensor.shape.assert_has_rank(0)
    return tensor

  with ops.name_scope(name, 'inverse_stft', [stfts]):
    stfts = ops.convert_to_tensor(stfts, name='stfts')
    stfts.shape.with_rank_at_least(2)
    frame_length = _as_scalar(frame_length, 'frame_length')
    frame_step = _as_scalar(frame_step, 'frame_step')
    fft_length = (_enclosing_power_of_two(frame_length) if fft_length is None
                  else _as_scalar(fft_length, 'fft_length'))

    frames = spectral_ops.irfft(stfts, [fft_length])

    # Static facts (None when unknown). They decide whether the last axis can
    # be fitted to `frame_length` at graph-construction time.
    static_frame_length = tensor_util.constant_value(frame_length)
    static_rank = frames.shape.ndims
    static_inner = None if static_rank is None else frames.shape[-1].value

    if static_frame_length is None or static_inner is None:
      # Sizes only known at run time: truncate, then append
      # max(0, frame_length - current_length) zeros to the last axis only.
      frames = frames[..., :frame_length]
      dynamic_rank = array_ops.rank(frames)
      dynamic_shape = array_ops.shape(frames)
      outer_paddings = array_ops.zeros([dynamic_rank - 1, 2],
                                       dtype=frame_length.dtype)
      shortfall = math_ops.maximum(0, frame_length - dynamic_shape[-1])
      frames = array_ops.pad(
          frames, array_ops.concat([outer_paddings, [[0, shortfall]]], 0))
    elif static_inner > static_frame_length:
      # Statically too long: drop trailing samples.
      frames = frames[..., :static_frame_length]
    elif static_inner < static_frame_length:
      # Statically too short: append zeros to the last axis only.
      frames = array_ops.pad(
          frames,
          [[0, 0]] * (static_rank - 1) +
          [[0, static_frame_length - static_inner]])

    # The dynamic padding above may hide the static size of the last axis;
    # restore it when both the target length and the rank are known.
    final_rank = frames.shape.ndims
    if static_frame_length is not None and final_rank is not None:
      frames.set_shape([None] * (final_rank - 1) + [static_frame_length])

    # Optionally window, then overlap-add the two innermost dimensions into a
    # single [samples] dimension.
    if window_fn is not None:
      frames = frames * window_fn(frame_length, dtype=stfts.dtype.real_dtype)
    return reconstruction_ops.overlap_and_add(frames, frame_step)
Example #6
0
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length=None,
                 window_fn=functools.partial(window_ops.hann_window,
                                             periodic=True),
                 name=None):
    """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.

    To reconstruct the original waveform, `inverse_stft` should be given a
    window function that is complementary to the one used by the forward
    STFT. Such a window function can be constructed with
    `tf.contrib.signal.inverse_stft_window_fn`.

    Example:

    ```python
    frame_length = 400
    frame_step = 160
    waveform = tf.placeholder(dtype=tf.float32, shape=[1000])
    stft = tf.contrib.signal.stft(waveform, frame_length, frame_step)
    inverse_stft = tf.contrib.signal.inverse_stft(
        stft, frame_length, frame_step,
        window_fn=tf.contrib.signal.inverse_stft_window_fn(frame_step))
    ```

    If a custom `window_fn` was used in `stft`, it must also be passed to
    `inverse_stft_window_fn`:

    ```python
    frame_length = 400
    frame_step = 160
    window_fn = functools.partial(window_ops.hamming_window, periodic=True)
    waveform = tf.placeholder(dtype=tf.float32, shape=[1000])
    stft = tf.contrib.signal.stft(
        waveform, frame_length, frame_step, window_fn=window_fn)
    inverse_stft = tf.contrib.signal.inverse_stft(
        stft, frame_length, frame_step,
        window_fn=tf.contrib.signal.inverse_stft_window_fn(
            frame_step, forward_window_fn=window_fn))
    ```

    Each frame is inverted with a real inverse FFT of length `fft_length`. It
    is then zero-padded or truncated to exactly `frame_length` samples, so
    `frame_length` may be larger or smaller than `fft_length`. Implemented
    with GPU-compatible ops and supports gradients.

    Args:
      stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT
        bins representing a batch of `fft_length`-point STFTs, where
        `fft_unique_bins` is `fft_length // 2 + 1`.
      frame_length: An integer scalar `Tensor`. The window length in samples.
      frame_step: An integer scalar `Tensor`. The number of samples to step.
      fft_length: An integer scalar `Tensor`. The size of the FFT that
        produced `stfts`. If not provided, uses the smallest power of 2
        enclosing `frame_length`.
      window_fn: A callable that takes a window length and a `dtype` keyword
        argument and returns a `[window_length]` `Tensor` of samples in the
        provided datatype. If `None`, no windowing is applied.
      name: An optional name for the operation.

    Returns:
      A `[..., samples]` `Tensor` of real-valued signals (`float32` for
      `complex64` input) holding the inverse STFT of each STFT in `stfts`.

    Raises:
      ValueError: If `stfts` is not at least rank 2, or if any of
        `frame_length`, `frame_step` or `fft_length` is not scalar.

    [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
    """
    def _fit_last_axis(frames, target_length):
        """Zero-pads or truncates the last axis of `frames` to `target_length`.

        Returns the fitted frames and the statically known value of
        `target_length` (None when it is not a graph-time constant).
        """
        target_static = tensor_util.constant_value(target_length)
        ndims = frames.shape.ndims
        last_static = frames.shape[-1].value if ndims is not None else None

        if target_static is not None and last_static is not None:
            # Both sizes known while building the graph.
            if last_static > target_static:
                return frames[..., :target_static], target_static
            if last_static < target_static:
                extra = target_static - last_static
                padded = array_ops.pad(
                    frames, [[0, 0]] * (ndims - 1) + [[0, extra]])
                return padded, target_static
            return frames, target_static

        # Run-time path: truncate, then append max(0, target - current) zeros
        # to the last axis (the leading axes get no padding).
        frames = frames[..., :target_length]
        rank_t = array_ops.rank(frames)
        shape_t = array_ops.shape(frames)
        paddings = array_ops.concat([
            array_ops.zeros([rank_t - 1, 2], dtype=target_length.dtype),
            [[0, math_ops.maximum(0, target_length - shape_t[-1])]],
        ], 0)
        return array_ops.pad(frames, paddings), target_static

    with ops.name_scope(name, 'inverse_stft', [stfts]):
        stfts = ops.convert_to_tensor(stfts, name='stfts')
        stfts.shape.with_rank_at_least(2)

        checked = []
        for raw, label in ((frame_length, 'frame_length'),
                           (frame_step, 'frame_step')):
            as_tensor = ops.convert_to_tensor(raw, name=label)
            as_tensor.shape.assert_has_rank(0)
            checked.append(as_tensor)
        frame_length, frame_step = checked

        if fft_length is not None:
            fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
            fft_length.shape.assert_has_rank(0)
        else:
            fft_length = _enclosing_power_of_two(frame_length)

        frames = spectral_ops.irfft(stfts, [fft_length])
        frames, frame_length_static = _fit_last_axis(frames, frame_length)

        # Padding built from dynamic shapes may leave the last axis statically
        # unknown; restore it when the target length and rank are known.
        if frame_length_static is not None and frames.shape.ndims is not None:
            frames.set_shape([None] * (frames.shape.ndims - 1) +
                             [frame_length_static])

        # Optionally window, then overlap-add the two innermost dimensions
        # into a single [samples] dimension.
        if window_fn is not None:
            frames = frames * window_fn(frame_length,
                                        dtype=stfts.dtype.real_dtype)
        return reconstruction_ops.overlap_and_add(frames, frame_step)