示例#1
0
    def testStreaming(self, input_samples):
        """Check that the streaming STFT matches the reference STFT prefix.

        A causal STFT layer is built in TRAINING mode and wrapped in a Keras
        model. The model is converted to STREAM_INTERNAL_STATE_INFERENCE mode
        and driven over ``self.input_signal`` by
        ``inference.run_stream_inference``. The streamed frames must match the
        leading frames of ``self.stft_out``, which is built in ``setUp``.

        Args:
          input_samples: scales the streaming chunk size. Each chunk is
            ``input_samples * frame_step`` samples long, and ``params.step``
            is also set to this value. Presumably supplied by a parameterized
            test decorator outside this view (TODO confirm).
        """
        # Non-streaming model with the same configuration used in setUp.
        layer = stft.STFT(self.frame_size,
                          self.frame_step,
                          mode=modes.Modes.TRAINING,
                          inference_batch_size=1,
                          padding='causal')
        inputs = tf.keras.layers.Input(shape=(self.input_signal.shape[1], ),
                                       batch_size=1)
        non_stream_model = tf.keras.models.Model(inputs, layer(inputs))

        # Streaming input shape per call (batch dimension excluded).
        stream_params = test_utils.Params([1])
        stream_params.data_shape = (input_samples * layer.frame_step, )
        stream_params.step = input_samples

        stream_model = utils.to_streaming_inference(
            non_stream_model, stream_params,
            modes.Modes.STREAM_INTERNAL_STATE_INFERENCE)
        stream_model.summary()

        stream_out = inference.run_stream_inference(stream_params, stream_model,
                                                    self.input_signal)
        # Compare only against as many leading reference frames as were streamed.
        num_frames = stream_out.shape[1]
        self.assertAllClose(stream_out, self.stft_out[:, :num_frames])
示例#2
0
 def testNonStreaming(self):
     """Check that the causal STFT layer matches the reference STFT.

     Builds a TRAINING-mode STFT layer with causal padding and runs it over
     ``self.input_signal``. The result is compared with ``self.stft_out``,
     which ``setUp`` computes with ``tf.signal.stft``.
     """
     stft_layer = stft.STFT(self.frame_size,
                            self.frame_step,
                            mode=modes.Modes.TRAINING,
                            inference_batch_size=1,
                            padding='causal')
     input_tf = tf.keras.layers.Input(shape=(self.input_signal.shape[1], ),
                                      batch_size=1)
     net = stft_layer(input_tf)
     model_non_stream = tf.keras.models.Model(input_tf, net)
     # Keep the prediction local. unittest runs every test method on a fresh
     # instance, so storing it on self could never be observed elsewhere.
     non_stream_out = model_non_stream.predict(self.input_signal)
     self.assertAllClose(non_stream_out, self.stft_out)
示例#3
0
    def setUp(self):
        """Build the fixtures shared by the STFT tests.

        Sets ``frame_size`` (40) and ``frame_step`` (10), and draws a random
        ``input_signal`` of shape (1, 120) after seeding. It also computes
        ``stft_out``, the reference result: the layer's own temporal padding
        of ``frame_size - 1`` samples followed by ``tf.signal.stft`` with
        ``pad_end=False``.
        """
        super(STFTTest, self).setUp()
        test_utils.set_seed(123)

        self.frame_size = 40
        self.frame_step = 10

        # This layer is never called here. Only its configuration (frame
        # sizes, FFT size, padding mode, window type) is read, to parameterize
        # the reference STFT.
        layer = stft.STFT(self.frame_size,
                          self.frame_step,
                          mode=modes.Modes.TRAINING,
                          inference_batch_size=1,
                          padding='causal')
        # Mirror the layer's window choice: Hann for 'hann_tf', otherwise
        # window_fn=None.
        window_fn = (tf.signal.hann_window
                     if layer.window_type == 'hann_tf' else None)

        self.input_signal = np.random.rand(1, 120)

        # Reference pipeline: pad with the layer's padding mode, then run a
        # plain tf.signal.stft without end padding.
        pad = temporal_padding.TemporalPadding(
            padding_size=layer.frame_size - 1, padding=layer.padding)

        def reference_stft(signal):
            return tf.signal.stft(signal,
                                  layer.frame_size,
                                  layer.frame_step,
                                  fft_length=layer.fft_size,
                                  window_fn=window_fn,
                                  pad_end=False)

        reference_layer = tf.keras.layers.Lambda(reference_stft)
        inputs = tf.keras.layers.Input(shape=(self.input_signal.shape[1], ),
                                       batch_size=1)
        reference_model = tf.keras.models.Model(inputs,
                                                reference_layer(pad(inputs)))

        self.stft_out = reference_model.predict(self.input_signal)