def test_chunked_correctness(self, emb_on_chnks):
  """Checks per-chunk embedding values for an example that splits in two.

  A 16000-sample example is processed with `chunk_len=8000`, so exactly two
  outputs are expected. The fake model returns, for each input row, a constant
  tensor equal to that row's first sample, which lets the test tell which
  audio the model was actually run on.

  Args:
    emb_on_chnks: Forwarded as `compute_embeddings_on_chunked_audio`. When
      truthy the second chunk's embedding is expected to be 8000, otherwise 0.
  """

  class MockModuleConstant(object):
    """Fake model emitting a constant per row: the row's first sample."""

    def __init__(self, output_keys):
      self.signatures = {'waveform': self._fn}
      self.output_keys = output_keys

    def _fn(self, waveform, paddings):
      del paddings
      print(f'waveform.shape: {waveform.shape}')
      batch_size, num_samples = waveform.shape
      # One output frame per 1000 input samples; the length must divide evenly.
      num_frames = num_samples / 1000
      assert num_frames == int(num_frames)
      ones = tf.ones([1, int(num_frames), 10], tf.float32)
      assert waveform[0, 0].numpy().size == 1, waveform[0, 0]
      samples = waveform.numpy()
      rows = []
      for row in range(batch_size):
        rows.append(ones * float(samples[row, 0]))
      stacked = tf.concat(rows, axis=0)
      return dict.fromkeys(self.output_keys, stacked)

  dofn = beam_dofns.ChunkAudioAndComputeEmbeddings(
      name='all',
      module='dummy_name',
      output_key=['okey'],
      embedding_names=['em'],
      audio_key='audio',
      label_key='label',
      speaker_id_key='speaker_id',
      sample_rate_key=None,
      sample_rate=16000,
      average_over_time=True,
      chunk_len=8000,
      compute_embeddings_on_chunked_audio=emb_on_chnks,
      setup_fn=lambda _: MockModuleConstant(['okey']))
  dofn.setup()

  key = 'key_8000'
  example = make_tfexample(16000)
  outputs = list(dofn.process((key, example)))
  self.assertLen(outputs, 2)

  # Expected constant embedding value for the first and second chunk.
  expected_values = [0.0, 8000 if emb_on_chnks else 0]
  for idx, (output, expected) in enumerate(zip(outputs, expected_values)):
    chunk_key, audio, _, _, embeddings = output
    self.assertEqual(f'{key}_{idx}', chunk_key)
    self.assertLen(audio, 8000)
    self.assertLen(embeddings, 1)
    emb = embeddings['em']
    self.assertEqual(emb.shape, (1, 10))
    np.testing.assert_equal(emb, expected)
def test_pipeline_padding(self, process_fn, chunk_len):
  """Check that the model input is of sufficient length.

  Builds the DoFn selected by `process_fn`, runs only its audio preprocessing
  step on a 100-sample example, and checks the resulting shape and the
  reported sample rate.

  Args:
    process_fn: Name selecting which DoFn to build. Anything other than the
      three explicitly handled names must be
      'ComputeBatchedChunkedSingleEmbeddings'.
    chunk_len: Chunk length passed to the two chunking DoFns; falsy means the
      expected output shape is (1, 400).
  """
  key, example = 'key', make_tfexample(100)
  shared_kwargs = {
      'name': 'name',
      'module': None,
      'output_key': ['output_key'],
      'audio_key': 'audio',
      'sample_rate_key': 'sample_rate',
      'sample_rate': None,
      'average_over_time': True,
      'model_input_min_length': 400,
      'setup_fn': lambda _: FakeMod(),
  }
  is_map_fn = process_fn == 'ComputeEmbeddingMapFn'
  is_batched = process_fn == 'ComputeBatchedChunkedSingleEmbeddings'

  if is_map_fn:
    beam_dofn = beam_dofns.ComputeEmbeddingMapFn(**shared_kwargs)
  elif process_fn == 'ComputeMultipleEmbeddings':
    beam_dofn = beam_dofns.ComputeMultipleEmbeddingsFromSingleModel(
        embedding_names=['em1'], chunk_len=chunk_len, **shared_kwargs)
  elif process_fn == 'ChunkAudioAndComputeEmbeddings':
    beam_dofn = beam_dofns.ChunkAudioAndComputeEmbeddings(
        embedding_names=['em1'], chunk_len=chunk_len, **shared_kwargs)
  else:
    assert is_batched
    beam_dofn = beam_dofns.ComputeBatchedChunkedSingleEmbeddings(
        **shared_kwargs)

  # Run preprocessing step.
  beam_dofn.setup()
  if is_map_fn:
    model_input, sample_rate = beam_dofn.read_and_preprocess_audio(
        key, example)
    expected_output_shape = (400,)
  elif is_batched:
    model_input, _, sample_rate = (
        beam_dofn.read_and_preprocess_batched_audio(
            [key, key], [example, example]))
    expected_output_shape = (2, 400)
  else:
    model_input, sample_rate = beam_dofn.tfex_to_chunked_audio(key, example)
    if chunk_len:
      expected_output_shape = (2, chunk_len)
    else:
      expected_output_shape = (1, 400)

  # Original audio is too short, so it should be padded to
  # `model_input_min_length`.
  self.assertEqual(model_input.shape, expected_output_shape)

  # Having a non-standard sample rate should trigger resampling and cause the
  # output to be 16kHz.
  self.assertEqual(sample_rate, 16000)
def test_chunk_audio(self, chunk_len, average_over_time, emb_on_chnks):
  """Checks keys, audio lengths, labels and embedding shapes of each chunk.

  Runs `ChunkAudioAndComputeEmbeddings` on examples of 8000, 16000 and 32000
  samples, and feeds every emitted chunk to
  `data_prep_utils.chunked_audio_to_tfex` to confirm the next stage accepts
  it.

  Args:
    chunk_len: Chunk length given to the DoFn; falsy means chunks are expected
      to keep the full example length.
    average_over_time: Forwarded to the DoFn. The expected leading embedding
      dimension is 1 when truthy and 5 otherwise.
    emb_on_chnks: Forwarded as `compute_embeddings_on_chunked_audio`.
  """
  dofn = beam_dofns.ChunkAudioAndComputeEmbeddings(
      name='all',
      module='dummy_name',
      output_key=['okey1', 'okey2'],
      embedding_names=['em1', 'em2'],
      audio_key='audio',
      label_key='label',
      speaker_id_key='speaker_id',
      sample_rate_key=None,
      sample_rate=16000,
      average_over_time=average_over_time,
      chunk_len=chunk_len,
      compute_embeddings_on_chunked_audio=emb_on_chnks,
      setup_fn=lambda _: MockModule(['okey1', 'okey2']))
  dofn.setup()

  expected_emb_shape = (1 if average_over_time else 5, 10)
  for audio_len in [8000, 16000, 32000]:
    key = f'key_{audio_len}'
    example = make_tfexample(audio_len)
    # A chunk is never longer than the example it was cut from.
    expected_chunk_len = min(chunk_len, audio_len) if chunk_len else audio_len

    for idx, chunk in enumerate(dofn.process((key, example))):
      chunk_key, audio, label, speaker, embeddings = chunk
      self.assertEqual(f'{key}_{idx}', chunk_key)
      self.assertLen(audio, expected_chunk_len)
      self.assertEqual(label, b'dummy_lbl')
      self.assertEqual(speaker, b'dummy_spkr')
      for emb in embeddings.values():
        self.assertEqual(emb.shape, expected_emb_shape)

      # Now run the next stage of the pipeline on it.
      # TODO(joelshor): Add correctness checks on the output.
      data_prep_utils.chunked_audio_to_tfex(
          (chunk_key, audio, label, speaker, embeddings),
          delete_audio_from_output=True,
          pass_through_normalized_audio=False,
          chunk_len=chunk_len,
          embedding_length=10)
def precompute_chunked_audio_pipeline(
    root,
    input_filenames,
    output_filename,
    sample_rate,
    debug,
    embedding_names,
    embedding_modules,
    module_output_keys,
    audio_key,
    sample_rate_key,
    label_key=None,
    speaker_id_key=None,
    average_over_time=True,
    delete_audio_from_output=True,
    split_embeddings_into_separate_tables=False,
    use_frontend_fn=False,
    normalize_to_pm_one=True,
    compute_embeddings_on_chunked_audio=True,
    model_input_min_length=None,
    embedding_length=1024,
    chunk_len=None,
    input_format='tfrecord',
    output_format='tfrecord',
    suffix='Main',
    module_call_fn=utils.samples_to_embedding_tfhub_w2v2,
    setup_fn=hub.load):
  """Construct beam pipeline for mapping from chunked audio to embeddings.

  The stages are: read the inputs, chunk the audio and compute all embeddings
  with a single model, reshuffle, convert each chunk to a tf.Example,
  reshuffle again, and write the result.

  Args:
    root: The beam root node.
    input_filenames: Python list. List of input files.
    output_filename: Python string. Output filename.
    sample_rate: Python int, or `None`. The sample rate for all embeddings, or
      `None` if this is a TFDS dataset, or if each example has its own sample
      rate.
    debug: Python bool. Whether to operate in debug mode.
    embedding_names: Python list of embeddings.
    embedding_modules: Python list of TF-Hub modules. Only the first element
      is used by this pipeline.
    module_output_keys: Python list of strings, names of output modules.
    audio_key: Python string, the key of the audio.
    sample_rate_key: Python string or `None`. Forwarded to the embedding DoFn;
      presumably the example field that holds the sample rate.
    label_key: Python string. Field for label.
    speaker_id_key: Python string. Field for speaker id.
    average_over_time: Whether to average over time.
    delete_audio_from_output: Whether to remove audio.
    split_embeddings_into_separate_tables: Unused; accepted and discarded.
    use_frontend_fn: Unused; accepted and discarded.
    normalize_to_pm_one: Forwarded unchanged to the embedding DoFn.
    compute_embeddings_on_chunked_audio: Forwarded unchanged to the embedding
      DoFn.
    model_input_min_length: Python int or `None`. Minimum model input length
      forwarded to the embedding DoFn. When `chunk_len` is set, this is raised
      to at least `chunk_len`.
    embedding_length: Length of embedding.
    chunk_len: Python int or `None`. Chunk length, forwarded to both the
      embedding DoFn and the tf.Example conversion step.
    input_format: Python string. Must correspond to a function in
      `reader_functions`.
    output_format: Python string. Must correspond to a function
      `writer_functions`.
    suffix: Python string. Suffix to stage names to make them unique.
    module_call_fn: Function for inference on audio.
    setup_fn: Function for creating audio inference model.
  """
  del split_embeddings_into_separate_tables, use_frontend_fn

  # Common sanity checks and preprocessing.
  _common_pipeline_sanity_checks(
      embedding_modules, embedding_names, module_output_keys)
  input_examples = _common_pipeline_beginning(
      root, input_format, input_filenames, suffix, debug)
  embedding_module = embedding_modules[0]

  # Chunk-specific logic: we need to pad inputs to at least the chunk length.
  min_input_length = model_input_min_length
  if chunk_len:
    min_input_length = max(min_input_length or 0, chunk_len)

  # Compute all the embeddings simultaneously.
  logging.info('Adding all signals: %s', module_output_keys)
  embedding_dofn = beam_dofns.ChunkAudioAndComputeEmbeddings(
      name='all',
      module=embedding_module,
      output_key=module_output_keys,
      embedding_names=embedding_names,
      audio_key=audio_key,
      label_key=label_key,
      speaker_id_key=speaker_id_key,
      sample_rate_key=sample_rate_key,
      sample_rate=sample_rate,
      average_over_time=average_over_time,
      normalize_to_pm_one=normalize_to_pm_one,
      model_input_min_length=min_input_length,
      chunk_len=chunk_len,
      module_call_fn=module_call_fn,
      compute_embeddings_on_chunked_audio=(
          compute_embeddings_on_chunked_audio),
      setup_fn=setup_fn)
  embedded = (
      input_examples
      | f'ComputeEmbedding-{suffix}' >> beam.ParDo(embedding_dofn))
  reshuffled = embedded | f'Reshuffle2-{suffix}' >> beam.Reshuffle()
  tf_examples = (
      reshuffled
      | f'ToTFExample-{suffix}' >> beam.Map(
          utils.chunked_audio_to_tfex,
          delete_audio_from_output=delete_audio_from_output,
          chunk_len=chunk_len,
          label_key=label_key,
          speaker_id_key=speaker_id_key,
          embedding_length=embedding_length))
  output_table = tf_examples | f'Reshuffle3-{suffix}' >> beam.Reshuffle()

  # Output sanity checks and write embeddings to disk.
  _common_pipeline_ending(output_table, output_filename, output_format, suffix)