Ejemplo n.º 1
0
def _MakeLogMelFromTensorflowBuiltin(tf_wav_bytes):
    """Builds MFCC features from WAV bytes using TF built-in audio ops.

    Args:
      tf_wav_bytes: a string tensor holding the raw bytes of a WAV file.

    Returns:
      A tensor of MFCC features computed at a fixed 16 kHz sample rate.
    """
    decoded_rate, samples = audio_lib.DecodeWav(tf_wav_bytes)
    expected_rate = 16000
    # Fail at session run time if the WAV is not at the expected rate.
    rate_check = tf.assert_equal(decoded_rate, expected_rate)
    with tf.control_dependencies([rate_check]):
        log_mel = audio_lib.AudioToMfcc(expected_rate, samples, 25, 25, 40)
    return log_mel
Ejemplo n.º 2
0
 def testDecodeWav(self):
     """Checks DecodeWav's reported sample rate and sample count on a fixture."""
     wav_path = test_helper.test_src_dir_path('tools/testdata/gan_or_vae.wav')
     with open(wav_path, 'rb') as wav_file:
         wav = wav_file.read()
     with self.session() as sess:
         sample_rate, audio = sess.run(audio_lib.DecodeWav(wav))
         self.assertEqual(24000, sample_rate)
         self.assertEqual(75900, len(audio))
Ejemplo n.º 3
0
 def testAudioToMfcc(self):
     """Checks the MFCC feature shape produced for a known test WAV file."""
     with open(
             test_helper.test_src_dir_path('tools/testdata/gan_or_vae.wav'),
             'rb') as f:
         wav = f.read()
     sample_rate, audio = audio_lib.DecodeWav(wav)
     static_sample_rate = 24000
     mfcc = audio_lib.AudioToMfcc(static_sample_rate, audio, 32, 25, 40)
     with self.session() as sess:
         audio_sample_rate, mfcc = sess.run([sample_rate, mfcc])
         # Use assertEqual rather than a bare assert: asserts are stripped
         # under `python -O`, and assertEqual matches the sibling tests'
         # style and produces a readable failure message.
         self.assertEqual(static_sample_rate, audio_sample_rate)
         self.assertAllEqual(mfcc.shape, [1, 126, 40])
Ejemplo n.º 4
0
def _MakeLogMel(tf_wav_bytes):
    """Builds a log-mel feature graph from WAV bytes via the ASR frontend.

    Args:
      tf_wav_bytes: a string tensor holding the raw bytes of a WAV file.

    Returns:
      The log-mel features produced by the frontend's FPropDefaultTheta.
    """
    sample_rate, audio = audio_lib.DecodeWav(tf_wav_bytes)
    # NOTE(review): presumably rescales decoded float samples to the int16
    # value range the frontend expects — confirm against DecodeWav docs.
    audio *= 32768
    # Drop the single channel axis, then add a leading batch axis of 1.
    # TODO(drpng): make batches.
    audio = tf.expand_dims(tf.squeeze(audio, axis=1), axis=0)
    static_sample_rate = 16000
    mel_frontend = _CreateAsrFrontend()
    # Fail at session run time if the WAV is not at the expected rate.
    rate_check = tf.assert_equal(sample_rate, static_sample_rate)
    with tf.control_dependencies([rate_check]):
        log_mel, _ = mel_frontend.FPropDefaultTheta(audio)
    return log_mel
Ejemplo n.º 5
0
    def _InferenceSubgraph_Default(self):
        """Constructs the graph for offline inference from raw WAV bytes.

        Builds a placeholder-fed pipeline: decode WAV -> frontend features ->
        encoder -> beam-search decoder -> top-K hypotheses.

        Returns:
          (fetches, feeds) where both fetches and feeds are dictionaries. Each
          dictionary consists of keys corresponding to tensor names, and values
          corresponding to a tensor in the graph which should be input/read
          from.
        """
        p = self.params
        with tf.name_scope('default'):
            # TODO(laurenzo): Once the migration to integrated frontends is complete,
            # this model should be upgraded to use the MelAsrFrontend in its
            # params vs relying on pre-computed feature generation and the inference
            # special casing.
            wav_bytes = tf.placeholder(dtype=tf.string, name='wav')
            frontend = self.frontend if p.frontend else None
            if not frontend:
                # No custom frontend. Instantiate the default.
                frontend_p = asr_frontend.MelAsrFrontend.Params()
                frontend = frontend_p.cls(frontend_p)

            # Decode the wave bytes and use the explicit frontend.
            unused_sample_rate, audio = audio_lib.DecodeWav(wav_bytes)
            # NOTE(review): presumably rescales decoded float samples to the
            # int16 value range expected by the frontend — confirm.
            audio *= 32768
            # Remove channel dimension, since we have a single channel.
            audio = tf.squeeze(audio, axis=1)
            # Add batch.
            audio = tf.expand_dims(audio, axis=0)
            # Single utterance, so no padded frames: paddings are all zeros.
            input_batch_src = py_utils.NestedMap(src_inputs=audio,
                                                 paddings=tf.zeros_like(audio))
            input_batch_src = frontend.FPropDefaultTheta(input_batch_src)

            # Undo default stacking, if specified: reshape features back to
            # [batch=1, time, p.input.frame_size, channels=1] for the encoder.
            input_batch_src.src_inputs = tf.reshape(
                input_batch_src.src_inputs, [1, -1, p.input.frame_size, 1])
            # Rebuild paddings to match the reshaped time dimension.
            input_batch_src.paddings = tf.zeros(
                shape=[1, tf.shape(input_batch_src.src_inputs)[1]])

            encoder_outputs = self.encoder.FPropDefaultTheta(input_batch_src)
            decoder_outputs = self.decoder.BeamSearchDecode(encoder_outputs)
            # Extract the top-K hypotheses and scores from the beam output.
            topk = self._GetTopK(decoder_outputs)

            # Feed: raw WAV bytes. Fetches: decoded text, scores, and the
            # intermediate frontend/encoder tensors for debugging.
            feeds = {'wav': wav_bytes}
            fetches = {
                'hypotheses': topk.decoded,
                'scores': topk.scores,
                'src_frames': input_batch_src.src_inputs,
                'encoder_frames': encoder_outputs.encoded
            }

            return fetches, feeds