def stft(signals, frame_length, frame_step, fft_length=None,
         pad_end=False, name=None, *, window_fn=None):
  """Computes the [Short-time Fourier Transform][stft] of `signals`.

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
      signals.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT to apply.
      If not provided, uses the smallest power of 2 enclosing `frame_length`.
    pad_end: Whether to pad the end of `signals` with zeros when the provided
      frame length and step produces a frame that lies partially past its end.
    name: An optional name for the operation.
    window_fn: Keyword-only. A callable that takes a window length and a
      `dtype` keyword argument and returns a `[window_length]` `Tensor` of
      samples in the provided datatype. If `None` (the default), no windowing
      is used.

  Returns:
    A `[..., frames, fft_unique_bins]` `Tensor` of `complex64`/`complex128`
    STFT values where `fft_unique_bins` is `fft_length // 2 + 1` (the unique
    components of the FFT).

  Raises:
    ValueError: If `signals` is not at least rank 1, `frame_length` is
      not scalar, or `frame_step` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  # NOTE(review): `tf.signal.stft` places `window_fn` before `pad_end` and
  # defaults it to a Hann window. Here it is appended as keyword-only so that
  # positional callers of this signature keep working. It defaults to `None`
  # because no window implementation is in scope in this block. Confirm the
  # intended default.
  with ops.name_scope(name, 'stft', [signals, frame_length, frame_step]):
    signals = ops.convert_to_tensor(signals, name='signals')
    # Enforce the shape contract documented under `Raises`.
    signals.shape.with_rank_at_least(1)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)

    if fft_length is None:
      fft_length = _enclosing_power_of_two(frame_length)
    # Converting an existing Tensor is a no-op, so this is safe for both paths.
    fft_length = ops.convert_to_tensor(fft_length, name='fft_length')

    framed_signals = shape_ops.frame(
        signals, frame_length, frame_step, pad_end=pad_end)

    # Optionally window the framed signals.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=framed_signals.dtype)
      framed_signals *= window

    # fft_ops.rfft produces the (fft_length/2 + 1) unique components of the
    # FFT of the real windowed signals in framed_signals.
    return fft_ops.rfft(framed_signals, [fft_length])
def dct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
    """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.

    Types I, II, III and IV are supported.
    Type I is implemented as the real part of a `tf.signal.rfft` of the
     length `2N - 2` even (mirrored) extension of `input`.
    Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
     described by J. Makhoul, "A fast cosine transform in one and two
     dimensions" (IEEE Trans. ASSP, 1980).
    Type III is a fairly straightforward inverse of Type II
     (i.e. using a length `2N` padded `tf.signal.irfft`).
    Type IV is calculated through a length `2N` DCT-II of the zero-padded
     signal, keeping the odd indices.

    Equivalent to [scipy.fftpack.dct] for Type-I, Type-II, Type-III and
    Type-IV DCT.

    Args:
      input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
        signals to take the DCT of.
      type: The DCT type to perform. Must be 1, 2, 3 or 4.
      n: The length of the transform. If length is less than sequence length,
        only the first n elements of the sequence are considered for the DCT.
        If n is greater than the sequence length, zeros are padded and then
        the DCT is computed as usual.
      axis: For future expansion. The axis to compute the DCT along. Must be
        `-1`.
      norm: The normalization to apply. `None` for no normalization or
        `'ortho'` for orthonormal normalization.
      name: An optional name for the operation.

    Returns:
      A `[..., samples]` `float32`/`float64` `Tensor` containing the DCT of
      `input`. When `n` is given, the last dimension has length `n`.

    Raises:
      ValueError: If `type` is not `1`, `2`, `3` or `4`, `axis` is
        not `-1`, `n` is not `None` or greater than 0,
        or `norm` is not `None` or `'ortho'`.
      ValueError: If `type` is `1` and `norm` is `ortho`.

    [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
    [scipy.fftpack.dct]: https://docs.scipy.org/doc/scipy/reference/generated/scipy.fftpack.dct.html
    """
    _validate_dct_arguments(input, type, n, axis, norm)
    with _ops.name_scope(name, "dct", [input]):
        input = _ops.convert_to_tensor(input)
        zero = _ops.convert_to_tensor(0.0, dtype=input.dtype)

        # Static length of the last axis when known, otherwise a dynamic
        # scalar Tensor.
        seq_len = (tensor_shape.dimension_value(input.shape[-1])
                   or _array_ops.shape(input)[-1])
        # NOTE(review): the Python comparison below assumes `seq_len` is a
        # static int (or eager). With an unknown last dimension in graph mode
        # it would be a symbolic Tensor. Confirm callers always have a static
        # length when passing `n`.
        if n is not None:
            if n <= seq_len:
                # Keep only the first `n` samples.
                input = input[..., 0:n]
            else:
                # Zero-pad the end of the last axis up to length `n`.
                rank = len(input.shape)
                padding = [[0, 0] for _ in range(rank)]
                padding[rank - 1][1] = n - seq_len
                padding = _ops.convert_to_tensor(padding, dtype=_dtypes.int32)
                input = _array_ops.pad(input, paddings=padding)

        axis_dim = (tensor_shape.dimension_value(input.shape[-1])
                    or _array_ops.shape(input)[-1])
        axis_dim_float = _math_ops.cast(axis_dim, input.dtype)

        if type == 1:
            # Mirror the interior samples: [x0, ..., x_{N-1}, x_{N-2}, ..., x1]
            # (length 2N - 2). The real part of its RFFT is the DCT-I.
            dct1_input = _array_ops.concat([input, input[..., -2:0:-1]],
                                           axis=-1)
            dct1 = _math_ops.real(fft_ops.rfft(dct1_input))
            return dct1

        if type == 2:
            # Phase twiddle 2 * exp(-i * pi * k / (2N)) for k = 0..N-1.
            scale = 2.0 * _math_ops.exp(
                _math_ops.complex(
                    zero,
                    -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
                    axis_dim_float))

            # TODO(rjryan): Benchmark performance and memory usage of the various
            # approaches to computing a DCT via the RFFT.
            dct2 = _math_ops.real(
                fft_ops.rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim]
                * scale)

            if norm == "ortho":
                n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
                n2 = n1 * _math.sqrt(2.0)
                # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
                weights = _array_ops.pad(_array_ops.expand_dims(n1, 0),
                                         [[0, axis_dim - 1]],
                                         constant_values=n2)
                dct2 *= weights

            return dct2

        elif type == 3:
            if norm == "ortho":
                n1 = _math_ops.sqrt(axis_dim_float)
                n2 = n1 * _math.sqrt(0.5)
                # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
                weights = _array_ops.pad(_array_ops.expand_dims(n1, 0),
                                         [[0, axis_dim - 1]],
                                         constant_values=n2)
                input *= weights
            else:
                # Undo the 1/(2N) factor of irfft for the unnormalized DCT-III.
                input *= axis_dim_float
            # Phase twiddle 2 * exp(+i * pi * k / (2N)) for k = 0..N-1.
            scale = 2.0 * _math_ops.exp(
                _math_ops.complex(
                    zero,
                    _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
                    axis_dim_float))
            dct3 = _math_ops.real(
                fft_ops.irfft(scale * _math_ops.complex(input, zero),
                              fft_length=[2 * axis_dim]))[..., :axis_dim]

            return dct3

        elif type == 4:
            # DCT-2 of 2N length zero-padded signal, unnormalized.
            dct2 = dct(input, type=2, n=2 * axis_dim, axis=axis, norm=None)
            # Get odd indices of DCT-2 of zero padded 2N signal to obtain
            # DCT-4 of the original N length signal.
            dct4 = dct2[..., 1::2]
            if norm == "ortho":
                dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)

            return dct4
def rfft(input_tensor, fft_length=None, name=None):
    """Forward `input_tensor` to `fft_ops.rfft` and return its result.

    Args:
      input_tensor: Tensor handed unchanged to `fft_ops.rfft` (presumably a
        real-valued signal; the wrapper does no validation of its own).
      fft_length: Optional FFT length, passed through as-is.
      name: Optional op name, passed through as-is.

    Returns:
      Whatever `fft_ops.rfft` returns for these arguments.
    """
    return fft_ops.rfft(input_tensor, fft_length=fft_length, name=name)
