def _pad_time_signal(input_placeholder, frame_size):
  """
  Pad the time signal along the time axis (axis=1) with small constants (1e-7)
  so that a frame of length ``frame_size`` fits, when ``frame_size`` exceeds
  the reference frame size; otherwise return the input unchanged.

  :param tf.Tensor input_placeholder: time signal; indexing here assumes rank 3,
    presumably (batch, time, channels) -- TODO confirm with caller
  :param int frame_size: STFT frame size of the current resolution branch
  :return: padded (or original) time signal
  :rtype: tf.Tensor
  """
  if frame_size > self._reference_frame_size:
    # Bug fix: pad the function argument `input_placeholder`, not the
    # enclosing-scope variable `input_signal`, which is unbound at the first
    # call (NameError) and stale on subsequent calls.
    return tf.concat(
      [input_placeholder,
       tf.ones(
         [get_shape(input_placeholder)[0],
          frame_size - self._reference_frame_size,
          get_shape(input_placeholder)[2]]) * 1e-7],
      axis=1)
  else:
    return input_placeholder
def _crop_stft_output_to_reference_frame_size_length(
      channel_concatenated_stft, crop_size):
  """
  Crop the concatenated STFT output to ``crop_size`` frames along the time axis.

  :param tf.Tensor channel_concatenated_stft: rank-3 tensor; cropped on axis 1
  :param tf.Tensor|int crop_size: number of time frames to keep
  :return: tensor with axis 1 reduced to ``crop_size``
  :rtype: tf.Tensor
  """
  full_shape = get_shape(channel_concatenated_stft)
  slice_begin = [0, 0, 0]
  slice_size = [full_shape[0], crop_size, full_shape[2]]
  return tf.slice(channel_concatenated_stft, slice_begin, slice_size)
def _apply_stft_to_input(self):
  """
  Apply a (multi-resolution) STFT to the batch-major input signal.

  For each ``(fft_size, frame_size)`` pair in ``self._fft_sizes`` /
  ``self._frame_sizes``, the input is padded on the time axis if needed,
  transformed channel-wise with an rFFT STFT, and the per-channel spectra are
  flattened into the feature axis. Every resolution after the first is cropped
  to the first resolution's number of time frames, and all resolutions are
  concatenated along the feature axis.

  :return: STFT features, shape (batch, time, feature)
  :rtype: tf.Tensor
  """
  from returnn.tf.util.basic import get_shape

  # noinspection PyShadowingNames
  def _crop_stft_output_to_reference_frame_size_length(channel_concatenated_stft, crop_size):
    """Crop the time axis (axis 1) to ``crop_size`` frames."""
    return tf.slice(
      channel_concatenated_stft, [0, 0, 0],
      [get_shape(channel_concatenated_stft)[0], crop_size,
       get_shape(channel_concatenated_stft)[2]])

  input_placeholder = self.input_data.get_placeholder_as_batch_major()
  channel_wise_stft_res_list = list()
  for fft_size, frame_size in zip(self._fft_sizes, self._frame_sizes):

    def _get_window(window_length, dtype):
      """Return the analysis window tensor for the configured window type."""
      if self._window == "hanning":
        try:
          # noinspection PyPackageRequirements
          from tensorflow.signal import hann_window
        except ImportError:
          # noinspection PyPackageRequirements,PyUnresolvedReferences
          from tensorflow.contrib.signal import hann_window
        window = hann_window(window_length, dtype=dtype)
      elif self._window == "blackman":
        # noinspection PyPackageRequirements
        import scipy.signal
        window = tf.constant(scipy.signal.windows.blackman(frame_size), dtype=tf.float32)
      elif self._window == "None" or self._window == "ones":
        window = tf.ones((window_length,), dtype=dtype)
      else:
        assert False, "Window was not parsed correctly: {}".format(self._window)
      return window

    # noinspection PyShadowingNames
    def _pad_time_signal(input_placeholder, frame_size):
      """Pad the time axis with small constants (1e-7) so a frame of
      ``frame_size`` fits, when it exceeds the reference frame size."""
      if frame_size > self._reference_frame_size:
        # Bug fix: pad the argument `input_placeholder`, not the
        # enclosing-scope variable `input_signal`, which is unbound on the
        # first loop iteration and stale on later ones.
        return tf.concat(
          [input_placeholder,
           tf.ones(
             [get_shape(input_placeholder)[0],
              frame_size - self._reference_frame_size,
              get_shape(input_placeholder)[2]]) * 1e-7],
          axis=1)
      else:
        return input_placeholder

    input_signal = _pad_time_signal(input_placeholder, frame_size)
    if self._use_rfft:
      try:
        # noinspection PyPackageRequirements
        from tensorflow.signal import stft
      except ImportError:
        # noinspection PyPackageRequirements,PyUnresolvedReferences
        from tensorflow.contrib.signal import stft
      # STFT runs over the last axis, so move channels before time:
      # (batch, time, channel) -> (batch, channel, time).
      channel_wise_stft = stft(
        signals=tf.transpose(input_signal, [0, 2, 1]),
        frame_length=frame_size,
        frame_step=self._frame_shift,
        fft_length=fft_size,
        window_fn=_get_window,
        pad_end=self._pad_last_frame)
      # (batch, channel, frames, freq) -> (batch, frames, channel, freq).
      channel_wise_stft = tf.transpose(channel_wise_stft, [0, 2, 1, 3])
      batch_dim = tf.shape(channel_wise_stft)[0]
      time_dim = tf.shape(channel_wise_stft)[1]
      # Static channel*freq size; flattened into one feature axis below.
      concat_feature_dim = channel_wise_stft.shape[2] * channel_wise_stft.shape[3]
      channel_concatenated_stft = tf.reshape(
        channel_wise_stft, (batch_dim, time_dim, concat_feature_dim))
      if channel_wise_stft_res_list:
        # Align this resolution's frame count with the first resolution's.
        channel_concatenated_stft = (
          _crop_stft_output_to_reference_frame_size_length(
            channel_concatenated_stft,
            get_shape(channel_wise_stft_res_list[0])[1]))
      channel_wise_stft_res_list.append(channel_concatenated_stft)
  output_placeholder = tf.concat(channel_wise_stft_res_list, axis=2)
  return output_placeholder
def _cropStftOutputToReferenceFrameSizeLength(
      channel_concatenated_stft, crop_size):
  """
  Slice the STFT output down to ``crop_size`` frames on the time axis.

  :param tf.Tensor channel_concatenated_stft: rank-3 tensor; cropped on axis 1
  :param tf.Tensor|int crop_size: number of time frames to keep
  :return: tensor with axis 1 reduced to ``crop_size``
  :rtype: tf.Tensor
  """
  shape = get_shape(channel_concatenated_stft)
  return tf.slice(
    channel_concatenated_stft,
    [0, 0, 0],
    [shape[0], crop_size, shape[2]])
def _apply_stft_to_input(self):
  """
  Apply a (multi-resolution) STFT to the batch-major input signal
  (older ``tf.contrib.signal`` variant).

  For each ``(fft_size, frame_size)`` pair the input is padded on the time
  axis if needed, transformed channel-wise, and the per-channel spectra are
  flattened into the feature axis; later resolutions are cropped to the first
  resolution's frame count and all are concatenated on the feature axis.

  :return: STFT features, shape (batch, time, feature)
  :rtype: tf.Tensor
  """
  from returnn.tf.util.basic import get_shape

  def _cropStftOutputToReferenceFrameSizeLength(channel_concatenated_stft, crop_size):
    """Crop the time axis (axis 1) to ``crop_size`` frames."""
    return tf.slice(
      channel_concatenated_stft, [0, 0, 0],
      [get_shape(channel_concatenated_stft)[0], crop_size,
       get_shape(channel_concatenated_stft)[2]])

  input_placeholder = self.input_data.get_placeholder_as_batch_major()
  channel_wise_stft_res_list = list()
  for fft_size, frame_size in zip(self._fft_sizes, self._frame_sizes):

    def _get_window(window_length, dtype):
      """Return the analysis window; fail loudly on unknown window types."""
      if self._window == "hanning":
        window = tf.contrib.signal.hann_window(window_length, dtype=dtype)
      elif self._window == "blackman":
        # NOTE(review): this graph-mode assert op is created but never
        # explicitly run/depended on -- confirm it is intentional.
        tf_compat.v1.assert_equal(frame_size, window_length)
        import scipy.signal
        # scipy.signal.blackman was deprecated and removed (SciPy 1.13);
        # scipy.signal.windows.blackman is the supported, identical spelling.
        window = tf.constant(scipy.signal.windows.blackman(frame_size), dtype=tf.float32)
      elif self._window == "None" or self._window == "ones":
        window = tf.ones((window_length,), dtype=dtype)
      else:
        # Bug fix: previously independent `if`s left `window` unbound for an
        # unknown window type, raising an opaque UnboundLocalError.
        assert False, "Window was not parsed correctly: {}".format(self._window)
      return window

    def _padTimeSignal(input_placeholder, frame_size):
      """Pad the time axis with small constants (1e-7) so a frame of
      ``frame_size`` fits, when it exceeds the reference frame size."""
      if frame_size > self._reference_frame_size:
        # Bug fix: pad the argument `input_placeholder`, not the
        # enclosing-scope variable `input_signal`, which is unbound on the
        # first loop iteration and stale on later ones.
        return tf.concat(
          [input_placeholder,
           tf.ones(
             [get_shape(input_placeholder)[0],
              frame_size - self._reference_frame_size,
              get_shape(input_placeholder)[2]]) * 1e-7],
          axis=1)
      else:
        return input_placeholder

    input_signal = _padTimeSignal(input_placeholder, frame_size)
    if self._use_rfft:
      # STFT runs over the last axis, so move channels before time:
      # (batch, time, channel) -> (batch, channel, time).
      channel_wise_stft = tf.contrib.signal.stft(
        signals=tf.transpose(input_signal, [0, 2, 1]),
        frame_length=frame_size,
        frame_step=self._frame_shift,
        fft_length=fft_size,
        window_fn=_get_window,
        pad_end=self._pad_last_frame)
      # (batch, channel, frames, freq) -> (batch, frames, channel, freq).
      channel_wise_stft = tf.transpose(channel_wise_stft, [0, 2, 1, 3])
      batch_dim = tf.shape(channel_wise_stft)[0]
      time_dim = tf.shape(channel_wise_stft)[1]
      # Static channel*freq size; flattened into one feature axis below.
      concat_feature_dim = channel_wise_stft.shape[2] * channel_wise_stft.shape[3]
      channel_concatenated_stft = tf.reshape(
        channel_wise_stft, (batch_dim, time_dim, concat_feature_dim))
      if channel_wise_stft_res_list:
        # Align this resolution's frame count with the first resolution's.
        channel_concatenated_stft = _cropStftOutputToReferenceFrameSizeLength(
          channel_concatenated_stft,
          get_shape(channel_wise_stft_res_list[0])[1])
      channel_wise_stft_res_list.append(channel_concatenated_stft)
  output_placeholder = tf.concat(channel_wise_stft_res_list, axis=2)
  return output_placeholder