def test_mini_beam_pipeline(self):
  """Smoke-tests the embedding DoFn and the tfex-combining map inside Beam.

  Two keyed examples are created, passed through
  `ComputeMultipleEmbeddingsFromSingleModel` (backed by `MockModule`), and
  then through `combine_multiple_embeddings_to_tfex`. The test makes no
  assertions on the pipeline output.
  """
  with beam.Pipeline() as root:
    keyed_examples = [('k1', make_tfexample(5)), ('k2', make_tfexample(5))]
    embedding_dofn = beam_dofns.ComputeMultipleEmbeddingsFromSingleModel(
        name='all',
        module='dummy_mod_loc',
        output_key=['k1', 'k2'],
        audio_key='audio',
        sample_rate_key='sample_rate',
        sample_rate=None,
        average_over_time=True,
        feature_fn=None,
        embedding_names=['em1', 'em2'],
        embedding_length=10,
        chunk_len=0,
        # The module location is a dummy; the mock stands in for the model.
        setup_fn=lambda _: MockModule(['k1', 'k2']))
    # The resulting PCollection is intentionally discarded.
    _ = (
        root
        | beam.Create(keyed_examples)
        | beam.ParDo(embedding_dofn)
        | beam.Map(
            data_prep_utils.combine_multiple_embeddings_to_tfex,
            delete_audio_from_output=True,
            audio_key='audio',
            label_key='label',
            speaker_id_key='speaker_id'))
  def test_multiple_embeddings(self, chunk_len, average_over_time):
    """Checks the multi-embedding DoFn output for several input lengths.

    For each length, one example is processed directly through
    `ComputeMultipleEmbeddingsFromSingleModel.process`, the returned key and
    embedding names are verified, and the result is fed to
    `combine_multiple_embeddings_to_tfex` to confirm the next stage accepts it.

    Args:
      chunk_len: Forwarded to the DoFn as `chunk_len`.
      average_over_time: Forwarded to the DoFn as `average_over_time`.
    """
    embedding_dofn = beam_dofns.ComputeMultipleEmbeddingsFromSingleModel(
        name='all',
        module='dummy_name',
        output_key=['k1', 'k2'],  # Sneak the list in.
        audio_key='audio',
        sample_rate_key=None,
        sample_rate=16000,
        average_over_time=average_over_time,
        feature_fn=None,
        embedding_names=['em1', 'em2'],
        embedding_length=10,
        chunk_len=chunk_len,
        setup_fn=lambda _: MockModule(['k1', 'k2']))
    embedding_dofn.setup()

    for length in (8000, 16000, 32000):
      in_key = f'key_{length}'
      in_example = make_tfexample(length)
      outputs = list(embedding_dofn.process((in_key, in_example)))
      out_key, out_example, embeddings = outputs[0]

      self.assertEqual(in_key, out_key)
      self.assertLen(embeddings, 2)
      self.assertSetEqual(set(embeddings.keys()), {'em1', 'em2'})

      # Now run the next stage of the pipeline on it.
      # TODO(joelshor): Add correctness checks on the output.
      data_prep_utils.combine_multiple_embeddings_to_tfex(
          (out_key, out_example, embeddings),
          delete_audio_from_output=True,
          audio_key='audio',
          label_key='label',
          speaker_id_key='speaker_id')
示例#3
0
    def test_pipeline_padding(self, process_fn, chunk_len):
        """Verifies that short audio is padded to `model_input_min_length`.

        Builds the DoFn selected by `process_fn`, runs only its audio
        preprocessing step on an example created with `make_tfexample(100)`,
        and checks the shape of the model input and the returned sample rate.

        Args:
          process_fn: Name selecting which DoFn class to exercise. One of
            'ComputeEmbeddingMapFn', 'ComputeMultipleEmbeddings',
            'ChunkAudioAndComputeEmbeddings', or
            'ComputeBatchedChunkedSingleEmbeddings'.
          chunk_len: Forwarded to the chunking DoFns; when truthy, the expected
            model input shape is `(2, chunk_len)`.
        """
        key = 'key'
        example = make_tfexample(100)
        shared_kwargs = {
            'name': 'name',
            'module': None,
            'output_key': ['output_key'],
            'audio_key': 'audio',
            'sample_rate_key': 'sample_rate',
            'sample_rate': None,
            'average_over_time': True,
            'model_input_min_length': 400,
            'setup_fn': lambda _: FakeMod(),
        }

        # Build the DoFn under test.
        if process_fn == 'ComputeEmbeddingMapFn':
            dofn = beam_dofns.ComputeEmbeddingMapFn(**shared_kwargs)
        elif process_fn == 'ComputeMultipleEmbeddings':
            dofn = beam_dofns.ComputeMultipleEmbeddingsFromSingleModel(
                embedding_names=['em1'], chunk_len=chunk_len, **shared_kwargs)
        elif process_fn == 'ChunkAudioAndComputeEmbeddings':
            dofn = beam_dofns.ChunkAudioAndComputeEmbeddings(
                embedding_names=['em1'], chunk_len=chunk_len, **shared_kwargs)
        else:
            assert process_fn == 'ComputeBatchedChunkedSingleEmbeddings'
            dofn = beam_dofns.ComputeBatchedChunkedSingleEmbeddings(
                **shared_kwargs)

        # Run preprocessing step; each DoFn flavour exposes a different entry
        # point for it.
        dofn.setup()
        if process_fn == 'ComputeEmbeddingMapFn':
            preprocessed = dofn.read_and_preprocess_audio(key, example)
            model_input, sample_rate = preprocessed
            expected_shape = (400, )
        elif process_fn == 'ComputeBatchedChunkedSingleEmbeddings':
            preprocessed = dofn.read_and_preprocess_batched_audio(
                [key, key], [example, example])
            model_input, _, sample_rate = preprocessed
            expected_shape = (2, 400)
        else:
            preprocessed = dofn.tfex_to_chunked_audio(key, example)
            model_input, sample_rate = preprocessed
            if chunk_len:
                expected_shape = (2, chunk_len)
            else:
                expected_shape = (1, 400)

        # The original audio is too short, so it should have been padded to
        # `model_input_min_length`.
        self.assertEqual(model_input.shape, expected_shape)

        # Having a non-standard sample rate should trigger resampling and cause
        # the output to be 16kHz.
        self.assertEqual(sample_rate, 16000)
示例#4
0
def multiple_embeddings_from_single_model_pipeline(
        root,
        input_filenames,
        output_filename,
        sample_rate,
        debug,
        embedding_names,
        embedding_modules,
        module_output_keys,
        audio_key,
        sample_rate_key,
        label_key,
        speaker_id_key,
        average_over_time,
        delete_audio_from_output,
        split_embeddings_into_separate_tables=False,
        use_frontend_fn=False,
        normalize_to_pm_one=True,
        model_input_min_length=None,
        embedding_length=None,
        chunk_len=None,
        input_format='tfrecord',
        output_format='tfrecord',
        suffix='Main',
        module_call_fn=utils.samples_to_embedding_tfhub_w2v2,
        setup_fn=hub.load):
    """Construct a beam pipeline computing several embeddings from one model.

    The pipeline reads the input examples, runs a single
    `ComputeMultipleEmbeddingsFromSingleModel` DoFn over them (using only the
    first entry of `embedding_modules`), converts each result with
    `utils.combine_multiple_embeddings_to_tfex`, and hands the result to the
    common pipeline ending for writing.

    Args:
      root: The beam root node.
      input_filenames: Python list. List of input files.
      output_filename: Python string. Output filename.
      sample_rate: Python int, or `None`. The sample rate for all embeddings,
        or `None` if this is a TFDS dataset, or if each example has its own
        sample rate.
      debug: Python bool. Whether to operate in debug mode.
      embedding_names: Python list of embeddings.
      embedding_modules: Python list of TF-Hub modules. Only the first element
        is used by this pipeline.
      module_output_keys: Python list of strings, names of output modules.
      audio_key: Python string, the key of the audio.
      sample_rate_key: Python string or `None`, the key for the sample rate.
      label_key: Python string. Field for label.
      speaker_id_key: Python string or `None`. Key for speaker ID, or `None`.
      average_over_time: Python bool. If `True`, average over the time axis.
      delete_audio_from_output: Python bool. Whether to remove audio from
        outputs.
      split_embeddings_into_separate_tables: Unused; discarded on entry.
      use_frontend_fn: Unused; discarded on entry.
      normalize_to_pm_one: Forwarded unchanged to the embedding DoFn.
      model_input_min_length: Forwarded unchanged to the embedding DoFn
        (presumably the minimum model input length -- confirm in
        `beam_dofns`).
      embedding_length: Forwarded unchanged to the embedding DoFn.
      chunk_len: Forwarded unchanged to the embedding DoFn.
      input_format: Python string. Must correspond to a function in
        `reader_functions`.
      output_format: Python string. Must correspond to a function in
        `writer_functions`.
      suffix: Python string. Suffix to stage names to make them unique.
      module_call_fn: Function for inference on audio.
      setup_fn: Forwarded unchanged to the embedding DoFn; defaults to
        `hub.load`.
    """
    # These arguments are accepted but have no effect in this pipeline.
    del split_embeddings_into_separate_tables, use_frontend_fn

    # Common sanity checks and preprocessing.
    _common_pipeline_sanity_checks(embedding_modules, embedding_names,
                                   module_output_keys)
    examples = _common_pipeline_beginning(root, input_format, input_filenames,
                                          suffix, debug)
    # A single model produces every embedding, so only one module is needed.
    single_module = embedding_modules[0]

    # Compute all the embeddings simultaneously.
    logging.info('Adding all signals: %s', module_output_keys)
    embedding_dofn = beam_dofns.ComputeMultipleEmbeddingsFromSingleModel(
        name='all',
        module=single_module,
        output_key=module_output_keys,
        audio_key=audio_key,
        sample_rate_key=sample_rate_key,
        sample_rate=sample_rate,
        average_over_time=average_over_time,
        feature_fn=None,
        normalize_to_pm_one=normalize_to_pm_one,
        model_input_min_length=model_input_min_length,
        embedding_names=embedding_names,
        embedding_length=embedding_length,
        chunk_len=chunk_len,
        module_call_fn=module_call_fn,
        setup_fn=setup_fn)
    embedded = (
        examples
        | f'ComputeEmbedding-{suffix}' >> beam.ParDo(embedding_dofn))
    embedded = embedded | f'Reshuffle2-{suffix}' >> beam.Reshuffle()
    tf_examples = (
        embedded
        | f'ToTFExample-{suffix}' >> beam.Map(
            utils.combine_multiple_embeddings_to_tfex,
            delete_audio_from_output=delete_audio_from_output,
            audio_key=audio_key,
            label_key=label_key,
            speaker_id_key=speaker_id_key))
    tf_examples = tf_examples | f'Reshuffle3-{suffix}' >> beam.Reshuffle()

    # Output sanity checks and write embeddings to disk.
    _common_pipeline_ending(tf_examples, output_filename, output_format,
                            suffix)