Ejemplo n.º 1
0
 def testDecodeFlacToWav(self):
     """Checks that audio_lib.DecodeFlacToWav reproduces the reference WAV.

     Decodes tools/testdata/gan_or_vae.flac and compares the result byte for
     byte with tools/testdata/gan_or_vae.wav. The decoder is expected to need
     the external `sox` tool; when the decode call raises OSError the test is
     reported as skipped instead of silently passing.
     """
     with open(
             test_helper.test_src_dir_path('tools/testdata/gan_or_vae.wav'),
             'rb') as f:
         wav = f.read()
     with open(
             test_helper.test_src_dir_path(
                 'tools/testdata/gan_or_vae.flac'), 'rb') as f:
         flac = f.read()
     tf.logging.info('flac: %d bytes', len(flac))
     # Only the decode call is guarded: an OSError here means the external
     # tool is unavailable, which is an environment problem, not a failure.
     try:
         converted = audio_lib.DecodeFlacToWav(flac)
     except OSError as e:
         # sox is not installed; surface that as a skip, not a vacuous pass.
         self.skipTest(
             'DecodeFlacToWav unavailable (is sox installed?): %s' % e)
     tf.logging.info('wav: %d bytes, converted: %d bytes', len(wav),
                     len(converted))
     self.assertEqual(wav, converted)
Ejemplo n.º 2
0
def _CreateAsrFeatures():
    """Converts FLAC utterances in FLAGS.input_tarball into serialized examples.

    Runs in two passes:
      1. Loads the transcriptions: from FLAGS.transcripts_filepath when that
         path exists, otherwise by calling _ReadTranscriptions().
      2. Streams through the gzip-compressed tarball. For every `.flac` member
         assigned to this shard (FLAGS.shard_id out of FLAGS.num_shards), it
         decodes the audio to WAV, computes log-mel features with audio_lib,
         builds an example with _MakeTfExample, and writes it to the writer
         returned by _SelectRandomShard.

    The input file, the tarball and the sub-shard writers are released even
    if processing fails partway through.

    Raises:
      ValueError: if an utterance selected for this shard has no transcript.
    """
    # First pass: extract transcription files.
    if os.path.exists(FLAGS.transcripts_filepath):
        trans = _LoadTranscriptionsFromFile()
    else:
        tf.logging.info('Running first pass on the fly')
        trans = _ReadTranscriptions()
    tf.logging.info('Total transcripts: %d', len(trans))

    # Build the feature graph once; each utterance's WAV bytes are fed into
    # `tf_bytes` below.
    tf_bytes = tf.placeholder(dtype=tf.string)
    log_mel = audio_lib.ExtractLogMelFeatures(tf_bytes)

    tfconf = tf.config_pb2.ConfigProto()
    # Allocate GPU memory on demand rather than reserving it up front.
    tfconf.gpu_options.allow_growth = True

    # Second pass: transcode the flac.
    with tf.io.gfile.GFile(FLAGS.input_tarball, mode='rb') as file_obj:
        with tarfile.open(fileobj=file_obj, mode='r:gz') as tar:
            recordio_writers = _OpenSubShards()
            try:
                with tf.Session(config=tfconf) as sess:
                    n = 0
                    for tarinfo in tar:
                        if not tarinfo.name.endswith('.flac'):
                            continue
                        # n is the 1-based index of this FLAC member; members
                        # are distributed round-robin across the shards.
                        n += 1
                        if n % FLAGS.num_shards != FLAGS.shard_id:
                            continue
                        # Utterance id = member basename without '.flac';
                        # this also works for members with no directory part.
                        uttid = tarinfo.name.rsplit('/', 1)[-1][:-len('.flac')]
                        # Check for a transcript before the expensive decode.
                        # An explicit raise, unlike assert, survives python -O.
                        if uttid not in trans:
                            raise ValueError(
                                'No transcript for utterance %s' % uttid)
                        with tar.extractfile(tarinfo) as f:
                            flac_bytes = f.read()
                        wav_bytes = audio_lib.DecodeFlacToWav(flac_bytes)
                        frames = sess.run(
                            log_mel, feed_dict={tf_bytes: wav_bytes})
                        num_words = len(trans[uttid])
                        tf.logging.info(
                            'utt[%d]: %s [%d frames, %d words]', n, uttid,
                            frames.shape[1], num_words)
                        ex = _MakeTfExample(uttid, frames, trans[uttid])
                        outf = _SelectRandomShard(recordio_writers)
                        outf.write(ex.SerializeToString())
            finally:
                # Close the writers even on failure so partial output is
                # flushed and file handles are released.
                _CloseSubShards(recordio_writers)