Example #1
0
 def _pad_time_signal(input_placeholder, frame_size):
   """Pad ``input_placeholder`` along axis 1 when ``frame_size`` exceeds the reference.

   When ``frame_size > self._reference_frame_size``, appends
   ``frame_size - self._reference_frame_size`` entries filled with the small
   constant 1e-7 to axis 1; otherwise returns the input unchanged.
   ``self`` comes from an enclosing scope that is not visible in this fragment.

   :param input_placeholder: rank-3 tensor (presumably (batch, time, channel) — TODO confirm).
   :param frame_size: frame size (in samples) of the current STFT resolution.
   :return: the padded (or unchanged) tensor.
   """
   if frame_size > self._reference_frame_size:
     # Pad the argument itself. Reading an enclosing-scope name here instead
     # would pad whatever that name happens to be bound to, not this input.
     return tf.concat(
       [input_placeholder, tf.ones(
         [get_shape(input_placeholder)[0], frame_size - self._reference_frame_size,
          get_shape(input_placeholder)[2]]) * 1e-7], axis=1)
   else:
     return input_placeholder
Example #2
0
 def _crop_stft_output_to_reference_frame_size_length(
         channel_concatenated_stft, crop_size):
     """Keep only the first ``crop_size`` positions of axis 1 of a rank-3 tensor.

     Axes 0 and 2 keep their full size as reported by ``get_shape``.

     :param channel_concatenated_stft: rank-3 tensor (presumably (batch, time, feature)).
     :param crop_size: number of leading positions on axis 1 to keep.
     :return: the sliced tensor.
     """
     full_shape = get_shape(channel_concatenated_stft)
     slice_start = [0, 0, 0]
     slice_size = [full_shape[0], crop_size, full_shape[2]]
     return tf.slice(channel_concatenated_stft, slice_start, slice_size)
Example #3
0
    def _apply_stft_to_input(self):
        """Apply a multi-resolution STFT to the batch-major input and concatenate the results.

        For every ``(fft_size, frame_size)`` pair in
        ``zip(self._fft_sizes, self._frame_sizes)`` the input is optionally padded on
        axis 1, an STFT is taken over that axis for every channel, and the channel and
        frequency axes are flattened into one feature axis. Every resolution after the
        first is cropped to the frame count of the first one, and all resolutions are
        concatenated along the feature axis (axis 2).

        Assumes the batch-major input is (batch, time, channel) — TODO confirm.

        :return: tensor (batch, frames, sum over resolutions of channels * fft bins).
        :rtype: tf.Tensor
        """
        from returnn.tf.util.basic import get_shape
        try:
            # noinspection PyPackageRequirements
            from tensorflow.signal import hann_window, stft
        except ImportError:
            # noinspection PyPackageRequirements,PyUnresolvedReferences
            from tensorflow.contrib.signal import hann_window, stft

        # noinspection PyShadowingNames
        def _crop_stft_output_to_reference_frame_size_length(
                channel_concatenated_stft, crop_size):
            # Keep the first `crop_size` frames (axis 1); axes 0 and 2 stay whole.
            return tf.slice(channel_concatenated_stft, [0, 0, 0], [
                get_shape(channel_concatenated_stft)[0], crop_size,
                get_shape(channel_concatenated_stft)[2]
            ])

        def _pad_time_signal(signal, frame_size):
            """Append (frame_size - reference frame size) entries of 1e-7 to axis 1 of
            `signal` if frame_size exceeds the reference frame size; else return it."""
            if frame_size <= self._reference_frame_size:
                return signal
            # Pad the signal passed in, so every resolution pads the original input
            # independently (not a previously padded one from an enclosing scope).
            padding = tf.ones([
                get_shape(signal)[0],
                frame_size - self._reference_frame_size,
                get_shape(signal)[2]
            ]) * 1e-7
            return tf.concat([signal, padding], axis=1)

        input_placeholder = self.input_data.get_placeholder_as_batch_major()
        channel_wise_stft_res_list = []
        for fft_size, frame_size in zip(self._fft_sizes, self._frame_sizes):

            def _get_window(window_length, dtype):
                # `stft` below calls this synchronously within the current iteration,
                # so `frame_size` refers to this iteration's value.
                if self._window == "hanning":
                    return hann_window(window_length, dtype=dtype)
                if self._window == "blackman":
                    # NOTE(review): uses the Python int `frame_size` rather than
                    # `window_length`, presumably because scipy needs a concrete int and
                    # `window_length` may arrive as a tensor — confirm.
                    # noinspection PyPackageRequirements
                    import scipy.signal
                    return tf.constant(
                        scipy.signal.windows.blackman(frame_size), dtype=dtype)
                if self._window in ("None", "ones"):
                    return tf.ones((window_length, ), dtype=dtype)
                # Explicit raise instead of `assert False`: survives `python -O`.
                raise ValueError(
                    "Window was not parsed correctly: {}".format(self._window))

            if not self._use_rfft:
                # NOTE(review): nothing is computed/appended on this path; if no
                # resolution is computed, tf.concat below gets an empty list.
                continue
            input_signal = _pad_time_signal(input_placeholder, frame_size)
            channel_wise_stft = stft(
                signals=tf.transpose(input_signal, [0, 2, 1]),
                frame_length=frame_size,
                frame_step=self._frame_shift,
                fft_length=fft_size,
                window_fn=_get_window,
                pad_end=self._pad_last_frame)
            # (batch, channel, frames, bins) -> (batch, frames, channel, bins)
            channel_wise_stft = tf.transpose(channel_wise_stft, [0, 2, 1, 3])
            batch_dim = tf.shape(channel_wise_stft)[0]
            time_dim = tf.shape(channel_wise_stft)[1]
            # Static feature size: channels * frequency bins.
            concat_feature_dim = (
                channel_wise_stft.shape[2] * channel_wise_stft.shape[3])
            channel_concatenated_stft = tf.reshape(
                channel_wise_stft, (batch_dim, time_dim, concat_feature_dim))
            if channel_wise_stft_res_list:
                # Align later resolutions to the frame count of the first one.
                channel_concatenated_stft = (
                    _crop_stft_output_to_reference_frame_size_length(
                        channel_concatenated_stft,
                        get_shape(channel_wise_stft_res_list[0])[1]))
            channel_wise_stft_res_list.append(channel_concatenated_stft)
        return tf.concat(channel_wise_stft_res_list, axis=2)
Example #4
0
 def _cropStftOutputToReferenceFrameSizeLength(
         channel_concatenated_stft, crop_size):
     """Slice a rank-3 tensor down to its first ``crop_size`` entries on axis 1.

     The leading and trailing axes are taken in full (sizes from ``get_shape``).

     :param channel_concatenated_stft: rank-3 tensor (presumably (batch, time, feature)).
     :param crop_size: how many leading positions of axis 1 to keep.
     :return: the cropped tensor.
     """
     dims = get_shape(channel_concatenated_stft)
     return tf.slice(
         channel_concatenated_stft,
         begin=[0] * 3,
         size=[dims[0], crop_size, dims[2]])
Example #5
0
    def _apply_stft_to_input(self):
        """Apply a multi-resolution STFT to the batch-major input and concatenate the results.

        For every ``(fft_size, frame_size)`` pair in
        ``zip(self._fft_sizes, self._frame_sizes)`` the input is optionally padded on
        axis 1, an STFT is taken over that axis for every channel, and the channel and
        frequency axes are flattened into one feature axis. Every resolution after the
        first is cropped to the frame count of the first one, and all resolutions are
        concatenated along the feature axis (axis 2).

        Assumes the batch-major input is (batch, time, channel) — TODO confirm.

        :return: tensor (batch, frames, sum over resolutions of channels * fft bins).
        :rtype: tf.Tensor
        """
        from returnn.tf.util.basic import get_shape
        # `tf.contrib` does not exist in TF2; prefer `tensorflow.signal` when available.
        try:
            # noinspection PyPackageRequirements
            from tensorflow.signal import hann_window, stft
        except ImportError:
            # noinspection PyPackageRequirements,PyUnresolvedReferences
            from tensorflow.contrib.signal import hann_window, stft

        def _crop_stft_output_to_reference_frame_size_length(
                channel_concatenated_stft, crop_size):
            # Keep the first `crop_size` frames (axis 1); axes 0 and 2 stay whole.
            return tf.slice(channel_concatenated_stft, [0, 0, 0], [
                get_shape(channel_concatenated_stft)[0], crop_size,
                get_shape(channel_concatenated_stft)[2]
            ])

        def _pad_time_signal(signal, frame_size):
            """Append (frame_size - reference frame size) entries of 1e-7 to axis 1 of
            `signal` if frame_size exceeds the reference frame size; else return it."""
            if frame_size <= self._reference_frame_size:
                return signal
            # Pad the signal passed in, so every resolution pads the original input
            # independently (not a previously padded one from an enclosing scope).
            padding = tf.ones([
                get_shape(signal)[0],
                frame_size - self._reference_frame_size,
                get_shape(signal)[2]
            ]) * 1e-7
            return tf.concat([signal, padding], axis=1)

        input_placeholder = self.input_data.get_placeholder_as_batch_major()
        channel_wise_stft_res_list = []
        for fft_size, frame_size in zip(self._fft_sizes, self._frame_sizes):

            def _get_window(window_length, dtype):
                # `stft` below calls this synchronously within the current iteration,
                # so `frame_size` refers to this iteration's value.
                if self._window == "hanning":
                    return hann_window(window_length, dtype=dtype)
                if self._window == "blackman":
                    # The scipy window is built from `frame_size`; check it matches the
                    # requested length. NOTE(review): the returned assert op is not used
                    # as a control dependency, so in graph mode it only has an effect
                    # when TF can check it statically — confirm.
                    tf_compat.v1.assert_equal(frame_size, window_length)
                    # noinspection PyPackageRequirements
                    import scipy.signal
                    return tf.constant(
                        scipy.signal.windows.blackman(frame_size), dtype=dtype)
                if self._window in ("None", "ones"):
                    return tf.ones((window_length, ), dtype=dtype)
                # Previously fell through to an unbound local; fail with a clear message.
                raise ValueError(
                    "Window was not parsed correctly: {}".format(self._window))

            if not self._use_rfft:
                # NOTE(review): nothing is computed/appended on this path; if no
                # resolution is computed, tf.concat below gets an empty list.
                continue
            input_signal = _pad_time_signal(input_placeholder, frame_size)
            channel_wise_stft = stft(
                signals=tf.transpose(input_signal, [0, 2, 1]),
                frame_length=frame_size,
                frame_step=self._frame_shift,
                fft_length=fft_size,
                window_fn=_get_window,
                pad_end=self._pad_last_frame)
            # (batch, channel, frames, bins) -> (batch, frames, channel, bins)
            channel_wise_stft = tf.transpose(channel_wise_stft, [0, 2, 1, 3])
            batch_dim = tf.shape(channel_wise_stft)[0]
            time_dim = tf.shape(channel_wise_stft)[1]
            # Static feature size: channels * frequency bins.
            concat_feature_dim = (
                channel_wise_stft.shape[2] * channel_wise_stft.shape[3])
            channel_concatenated_stft = tf.reshape(
                channel_wise_stft, (batch_dim, time_dim, concat_feature_dim))
            if channel_wise_stft_res_list:
                # Align later resolutions to the frame count of the first one.
                channel_concatenated_stft = (
                    _crop_stft_output_to_reference_frame_size_length(
                        channel_concatenated_stft,
                        get_shape(channel_wise_stft_res_list[0])[1]))
            channel_wise_stft_res_list.append(channel_concatenated_stft)
        return tf.concat(channel_wise_stft_res_list, axis=2)