示例#1
0
def get_stft(
    y,
    return_magnitude=True,
    frame_length=4096,
    frame_step=1024,
    T=512,
    F=1024,
):
    """Compute the STFT of a waveform and, optionally, its magnitude chunks.

    One frame of zeros (``frame_length`` samples) is prepended to the signal
    before the transform.

    :param y: waveform; it is given a trailing axis of size 1 and joined on
        axis 0 with a ``(frame_length, 1)`` block, so it must be rank 1.
    :param return_magnitude: when true, also return the magnitude feature.
    :param frame_length: STFT window length in samples.
    :param frame_step: hop between successive STFT frames in samples.
    :param T: second argument forwarded to ``pad_and_partition``
        (presumably the segment length in frames -- TODO confirm).
    :param F: number of leading entries kept on axis 2 of the magnitude.
    :returns: ``stft_feature`` alone when ``return_magnitude`` is false,
        otherwise the tuple ``(stft_feature, D)``.
    """

    def _periodic_hann(window_length, dtype):
        # Window callback: periodic Hann of the requested length and dtype.
        return hann_window(window_length, periodic=True, dtype=dtype)

    # Leading zero frame, then the signal as a single-column matrix.
    leading_zeros = tf.zeros((frame_length, 1))
    padded = tf.concat([leading_zeros, tf.expand_dims(y, -1)], 0)

    spectrum = stft(
        tf.transpose(padded),
        frame_length,
        frame_step,
        window_fn=_periodic_hann,
        pad_end=True,
    )
    # Move the leading axis of the transform output to the end.
    stft_feature = tf.transpose(spectrum, perm=[1, 2, 0])

    if not return_magnitude:
        return stft_feature

    partitioned = pad_and_partition(stft_feature, T)
    D = tf.abs(partitioned)[:, :, :F, :]
    return stft_feature, D
示例#2
0
    def _build_stft_feature(self):
        """ Compute STFT of waveform and slice the STFT in segment
         with the right length to feed the network.

        Results are cached in ``self._features``: the complex STFT under
        ``self.stft_name`` and its magnitude under ``self.spectrogram_name``.
        Each is only computed when its key is not already present.
        """

        stft_name = self.stft_name
        spec_name = self.spectrogram_name

        if stft_name not in self._features:
            # pad input with a frame of zeros
            waveform = tf.concat([
                tf.zeros((self._frame_length, self._n_channels)),
                self._features['waveform']
            ], 0)
            stft_feature = tf.transpose(
                stft(tf.transpose(waveform),
                     self._frame_length,
                     self._frame_step,
                     window_fn=lambda frame_length, dtype:
                     (hann_window(frame_length, periodic=True, dtype=dtype)),
                     pad_end=True),
                perm=[1, 2, 0])
            # Store under the same key that the guard above tests and the
            # spectrogram branch below reads, so the lookup cannot miss.
            self._features[stft_name] = stft_feature
        if spec_name not in self._features:
            # Magnitude of the partitioned STFT, keeping the first self._F
            # entries along axis 2.
            self._features[spec_name] = tf.abs(
                pad_and_partition(self._features[stft_name],
                                  self._T))[:, :, :self._F, :]
示例#3
0
def istft(
    stft_t,
    y,
    time_crop=None,
    factor=2 / 3,
    frame_length=4096,
    frame_step=1024,
    T=512,
    F=1024,
):
    """Invert an STFT to a waveform and drop the leading padding frame.

    :param stft_t: input STFT; its last axis is moved to the front before
        inversion.
    :param y: reference waveform; only used to derive the output length
        when ``time_crop`` is not given.
    :param time_crop: number of samples to keep; defaults to
        ``tf.shape(y)[0]``.
    :param factor: scalar gain applied to the inverted signal.
    :param frame_length: STFT window length in samples; also the number of
        leading samples skipped in the result.
    :param frame_step: hop between successive STFT frames in samples.
    :param T: unused here.
    :param F: unused here.
    :returns: the inverted signal, transposed and cropped to rows
        ``frame_length`` .. ``frame_length + time_crop``.
    """

    def _periodic_hann(window_length, dtype):
        # Window callback: periodic Hann of the requested length and dtype.
        return hann_window(window_length, periodic=True, dtype=dtype)

    channels_first = tf.transpose(stft_t, perm=[2, 0, 1])
    signal = inverse_stft(
        channels_first,
        frame_length,
        frame_step,
        window_fn=_periodic_hann,
    )
    samples_first = tf.transpose(signal * factor)

    if time_crop is None:
        time_crop = tf.shape(y)[0]

    # Skip the first frame_length samples, then keep time_crop samples.
    first = frame_length
    last = frame_length + time_crop
    return samples_first[first:last, :]
 def get_stft(X):
     """Return the end-padded STFT of ``X`` using a periodic Hann window.

     ``frame_length`` and ``frame_step`` are taken from the enclosing scope.
     """

     def _periodic_hann(window_length, dtype):
         # Window callback: periodic Hann of the requested length and dtype.
         return hann_window(window_length, periodic=True, dtype=dtype)

     spectrum = tf.signal.stft(
         X,
         frame_length,
         frame_step,
         window_fn=_periodic_hann,
         pad_end=True,
     )
     return spectrum
示例#5
0
    def _inverse_stft(self, stft_t, time_crop=None):
        """ Invert the given STFT and crop it back to waveform length.

        :param stft_t: input STFT; its last axis is moved to the front
            before inversion.
        :param time_crop: number of samples to keep; defaults to the
            length of ``self._features['waveform']`` along axis 0.
        :returns: inverse STFT (waveform), scaled by
            ``WINDOW_COMPENSATION_FACTOR``, with the first
            ``self._frame_length`` samples skipped.
        """

        def _periodic_hann(window_length, dtype):
            # Window callback: periodic Hann of the requested length/dtype.
            return hann_window(window_length, periodic=True, dtype=dtype)

        channels_first = tf.transpose(stft_t, perm=[2, 0, 1])
        signal = inverse_stft(
            channels_first,
            self._frame_length,
            self._frame_step,
            window_fn=_periodic_hann,
        )
        samples_first = tf.transpose(signal * self.WINDOW_COMPENSATION_FACTOR)

        if time_crop is None:
            time_crop = tf.shape(self._features['waveform'])[0]

        first = self._frame_length
        last = self._frame_length + time_crop
        return samples_first[first:last, :]
示例#6
0
 def _get_window(window_length, dtype):
   """
   Build the analysis window selected by ``self._window``.

   ``self`` comes from the enclosing scope. The signature matches a
   ``window_fn(window_length, dtype)`` callback (presumably handed to an STFT op).

   :param window_length: number of samples in the window
   :param dtype: dtype of the returned window tensor
   :return: 1-D window tensor of length ``window_length``
   :raises AssertionError: if ``self._window`` is not a supported name
   """
   if self._window == "hanning":
     try:
       # noinspection PyPackageRequirements
       from tensorflow.signal import hann_window
     except ImportError:
       # noinspection PyPackageRequirements,PyUnresolvedReferences
       from tensorflow.contrib.signal import hann_window
     window = hann_window(window_length, dtype=dtype)
   elif self._window == "blackman":
     # noinspection PyPackageRequirements
     import scipy.signal
     # Honour the callback arguments like the other branches do, instead of
     # an outer-scope frame size and a hard-coded float32.
     window = tf.constant(scipy.signal.windows.blackman(window_length), dtype=dtype)
   elif self._window in ("None", "ones"):
     window = tf.ones((window_length,), dtype=dtype)
   else:
     # Explicit raise (not `assert`) so the check survives `python -O`.
     raise AssertionError("Window was not parsed correctly: {}".format(self._window))
   return window
示例#7
0
def compute_spectrogram_tf(
    waveform: tf.Tensor,
    frame_length: int = 2048,
    frame_step: int = 512,
    spec_exponent: float = 1.0,
    window_exponent: float = 1.0,
) -> tf.Tensor:
    """
    Compute a magnitude / power spectrogram from a waveform given as a
    `n_samples x n_channels` tensor.

    Parameters:
        waveform (tensorflow.Tensor):
            Input waveform as `(times x number of channels)` tensor.
        frame_length (int):
            Length of a STFT frame to use.
        frame_step (int):
            Hop between successive frames.
        spec_exponent (float):
            Exponent applied to the STFT magnitude (usually 1 for a
            magnitude spectrogram, or 2 for a power spectrogram).
        window_exponent (float):
            Exponent applied to the Hann windowing function (may be
            useful for making perfect STFT/iSTFT reconstruction).

    Returns:
        tensorflow.Tensor:
            Computed magnitude / power spectrogram as a
            `(T x F x n_channels)` tensor.
    """

    def _powered_hann(length, dtype):
        # The window dtype follows the waveform, not the `dtype` argument.
        base = hann_window(length, periodic=True, dtype=waveform.dtype)
        return base ** window_exponent

    channels_first = tf.transpose(waveform)
    spectrum = stft(
        channels_first,
        frame_length,
        frame_step,
        window_fn=_powered_hann,
    )
    # Reorder to (T x F x n_channels).
    stft_tensor: tf.Tensor = tf.transpose(spectrum, perm=[1, 2, 0])
    magnitude = tf.abs(stft_tensor)
    return magnitude ** spec_exponent
    def __init__(self, X, Y, frame_length=4096, frame_step=1024):
        """Build the graph for two U-Net models and their L1 training cost.

        One model (scope ``model_mag``) is fed the STFT magnitude of ``X``;
        the other (scope ``model_angle``) is fed the feature named
        ``angle_X``. Their outputs are compared with the same features of
        ``Y`` to form ``self.mag_l1``, ``self.angle_l1`` and ``self.cost``.

        :param X: input signal passed to ``tf.signal.stft``
            (presumably the mixture waveform -- TODO confirm).
        :param Y: target signal passed to ``tf.signal.stft``.
        :param frame_length: STFT window length in samples.
        :param frame_step: hop between successive STFT frames in samples.
        """
        def get_stft(X):
            # End-padded STFT with a periodic Hann window; frame_length and
            # frame_step come from the enclosing __init__ arguments.
            return tf.signal.stft(
                X,
                frame_length,
                frame_step,
                window_fn=lambda frame_length, dtype:
                (hann_window(frame_length, periodic=True, dtype=dtype)),
                pad_end=True,
            )

        stft_X = get_stft(X)
        stft_Y = get_stft(Y)
        mag_X = tf.abs(stft_X)
        mag_Y = tf.abs(stft_Y)

        # NOTE(review): despite the names, these hold the imaginary part of
        # the STFT (tf.math.imag), not its phase. `_angle` is later used as a
        # phase in tf.exp(tf.complex(0.0, _angle)) -- confirm whether
        # tf.math.angle was intended.
        angle_X = tf.math.imag(stft_X)
        angle_Y = tf.math.imag(stft_Y)

        # Partition the input features with a second argument of 512
        # (presumably the segment length in frames -- TODO confirm).
        partitioned_mag_X = tf_featurization.pad_and_partition(mag_X, 512)
        partitioned_angle_X = tf_featurization.pad_and_partition(angle_X, 512)
        # Six filter counts: 32, 64, 128, 256, 512, 1024.
        params = {'conv_n_filters': [32 * (2**i) for i in range(6)]}

        with tf.variable_scope('model_mag'):
            # Add a trailing axis and drop the last entry along axis 2.
            mix_mag = tf.expand_dims(partitioned_mag_X, 3)[:, :, :-1, :]
            mix_mag_logits = unet.Model(
                mix_mag,
                output_mask_logit=True,
                dropout=0.0,
                training=True,
                params=params,
            ).logits
            mix_mag_logits = tf.squeeze(mix_mag_logits, 3)
            # Zero-pad one entry at the end of axis 2, restoring the size
            # removed by the [:-1] slice above.
            mix_mag_logits = tf.pad(mix_mag_logits, [(0, 0), (0, 0), (0, 1)],
                                    mode='CONSTANT')
            # Clamp the magnitude prediction to be non-negative.
            mix_mag_logits = tf.nn.relu(mix_mag_logits)

        with tf.variable_scope('model_angle'):
            # Same pipeline as the magnitude model, without the final ReLU.
            mix_angle = tf.expand_dims(partitioned_angle_X, 3)[:, :, :-1, :]
            mix_angle_logits = unet.Model(
                mix_angle,
                output_mask_logit=True,
                dropout=0.0,
                training=True,
                params=params,
            ).logits
            mix_angle_logits = tf.squeeze(mix_angle_logits, 3)
            mix_angle_logits = tf.pad(mix_angle_logits, [(0, 0), (0, 0),
                                                         (0, 1)],
                                      mode='CONSTANT')

        # Targets are partitioned the same way as the inputs.
        partitioned_mag_Y = tf_featurization.pad_and_partition(mag_Y, 512)
        partitioned_angle_Y = tf_featurization.pad_and_partition(angle_Y, 512)

        # Mean absolute error per feature; the total cost is their sum.
        self.mag_l1 = tf.reduce_mean(tf.abs(partitioned_mag_Y -
                                            mix_mag_logits))
        self.angle_l1 = tf.reduce_mean(
            tf.abs(partitioned_angle_Y - mix_angle_logits))
        self.cost = self.mag_l1 + self.angle_l1

        def get_original_shape(D, stft):
            # Merge the first two axes of D into one and keep only as many
            # leading rows as `stft` has along axis 0.
            instrument_mask = D

            old_shape = tf.shape(instrument_mask)
            new_shape = tf.concat(
                [[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
            instrument_mask = tf.reshape(instrument_mask, new_shape)
            instrument_mask = instrument_mask[:tf.shape(stft)[0]]
            return instrument_mask

        _mag = get_original_shape(tf.expand_dims(mix_mag_logits, -1), stft_X)
        _angle = get_original_shape(tf.expand_dims(mix_angle_logits, -1),
                                    stft_X)

        # Rebuild a complex STFT as magnitude * exp(i * angle).
        # NOTE(review): `stft` is a local name here; if the module also
        # imports a function called `stft`, it is shadowed inside __init__.
        stft = tf.multiply(tf.complex(_mag, 0.0),
                           tf.exp(tf.complex(0.0, _angle)))

        # Invert the reconstructed STFT (trailing axis index 0) back to a
        # signal. The result is not used within the lines visible here.
        inverse_stft_X = inverse_stft(
            stft[:, :, 0],
            frame_length,
            frame_step,
            window_fn=lambda frame_length, dtype:
            (hann_window(frame_length, periodic=True, dtype=dtype)),
        )