def test_compute_embedding_map_fn_tflite(self, average_over_time,
                                             sample_rate_key, sample_rate):
        """Run `ComputeEmbeddingMapFn` with a '.tflite' module and a feature fn.

        A `tf.train.Example` holding 2000 zero-valued float samples is pushed
        through the DoFn's `setup()` / `process()` pair. The first yielded
        `(key, value)` pair must keep the input key, and the value's shape
        must be `(1, BASE_SHAPE_[1])` when `average_over_time` is truthy and
        `BASE_SHAPE_` otherwise.

        Args:
            average_over_time: Forwarded to the DoFn; selects the expected
                output shape.
            sample_rate_key: Forwarded to the DoFn. When truthy, an int64
                feature with value 8000 is stored under this key.
            sample_rate: Forwarded to the DoFn unchanged.

        NOTE(review): how 'file.tflite' gets resolved is not visible in this
        block; presumably model loading is patched elsewhere -- confirm.
        """
        audio_key = 'audio_key'
        input_key = 'oldkey'

        # Assemble the input proto: zero audio plus an optional sample rate.
        example = tf.train.Example()
        feature_map = example.features.feature
        feature_map[audio_key].float_list.value.extend(
            np.zeros(2000, np.float32))
        if sample_rate_key:
            feature_map[sample_rate_key].int64_list.value.append(8000)

        def _feature_fn(x, s):
            # Frontend features, plus a trailing axis, as a float32 ndarray.
            frontend = tf_frontend.compute_frontend_features(
                x, s, overlap_seconds=79)
            with_channel = tf.expand_dims(frontend, axis=-1)
            return with_channel.numpy().astype(np.float32)

        embedding_fn = audio_to_embeddings_beam_utils.ComputeEmbeddingMapFn(
            name='module_name',
            module='file.tflite',
            output_key=0,
            audio_key=audio_key,
            sample_rate_key=sample_rate_key,
            sample_rate=sample_rate,
            average_over_time=average_over_time,
            feature_fn=_feature_fn)
        embedding_fn.setup()

        # Only the first element produced by `process` is inspected.
        outputs = embedding_fn.process((input_key, example))
        result_key, embedding = next(outputs)

        if average_over_time:
            expected_shape = (1, BASE_SHAPE_[1])
        else:
            expected_shape = BASE_SHAPE_

        self.assertEqual(result_key, input_key)
        self.assertEqual(embedding.shape, expected_shape)
    def test_compute_embedding_map_fn(self, average_over_time, sample_rate_key,
                                      sample_rate):
        """Run `ComputeEmbeddingMapFn` with its default call path.

        A `tf.train.Example` holding 2000 zero-valued float samples is pushed
        through the DoFn's `setup()` / `process()` pair. The first yielded
        `(key, value)` pair must keep the input key, and the value's shape
        must be `(1, BASE_SHAPE_[1])` when `average_over_time` is truthy and
        `BASE_SHAPE_` otherwise.

        Args:
            average_over_time: Forwarded to the DoFn; selects the expected
                output shape.
            sample_rate_key: Forwarded to the DoFn. When truthy, an int64
                feature with value 8000 is stored under this key.
            sample_rate: Forwarded to the DoFn unchanged.

        NOTE(review): how the '@loc' module is resolved is not visible in
        this block; presumably module loading is patched elsewhere -- confirm.
        """
        audio_key = 'audio_key'
        input_key = 'oldkey'

        # Assemble the input proto: zero audio plus an optional sample rate.
        example = tf.train.Example()
        feature_map = example.features.feature
        feature_map[audio_key].float_list.value.extend(
            np.zeros(2000, np.float32))
        if sample_rate_key:
            feature_map[sample_rate_key].int64_list.value.append(8000)

        embedding_fn = audio_to_embeddings_beam_utils.ComputeEmbeddingMapFn(
            name='module_name',
            module='@loc',
            output_key='unnecessary',
            audio_key=audio_key,
            sample_rate_key=sample_rate_key,
            sample_rate=sample_rate,
            average_over_time=average_over_time)
        embedding_fn.setup()

        # Only the first element produced by `process` is inspected.
        outputs = embedding_fn.process((input_key, example))
        result_key, embedding = next(outputs)

        if average_over_time:
            expected_shape = (1, BASE_SHAPE_[1])
        else:
            expected_shape = BASE_SHAPE_

        self.assertEqual(result_key, input_key)
        self.assertEqual(embedding.shape, expected_shape)
    def test_compute_embedding_map_fn_custom_call(self, average_over_time,
                                                  sample_rate_key,
                                                  sample_rate):
        """Run `ComputeEmbeddingMapFn` with caller-supplied call/setup hooks.

        The DoFn receives a `module_call_fn` that ignores its inputs and
        returns a float32 zero array of shape `(5, 25)`, plus a `setup_fn`
        that returns None. The first `(key, value)` pair yielded by
        `process()` must keep the input key, and the value's shape must be
        `(1, 25)` when `average_over_time` is truthy and `(5, 25)` otherwise.

        Args:
            average_over_time: Forwarded to the DoFn; selects the expected
                output shape.
            sample_rate_key: Forwarded to the DoFn. When truthy, an int64
                feature with value 8000 is stored under this key.
            sample_rate: Forwarded to the DoFn unchanged.
        """
        audio_key = 'audio_key'
        input_key = 'oldkey'
        custom_call_shape = (5, 25)

        def test_call_fn(audio_samples, sample_rate, module_location,
                         output_key, name):
            """Stand-in embedding call: zeros of `custom_call_shape`."""
            del audio_samples, sample_rate, module_location, output_key, name
            return np.zeros(custom_call_shape, dtype=np.float32)

        def _noop_setup(_):
            """Setup hook that does nothing and returns None."""
            return None

        # Assemble the input proto: zero audio plus an optional sample rate.
        example = tf.train.Example()
        feature_map = example.features.feature
        feature_map[audio_key].float_list.value.extend(
            np.zeros(2000, np.float32))
        if sample_rate_key:
            feature_map[sample_rate_key].int64_list.value.append(8000)

        embedding_fn = audio_to_embeddings_beam_utils.ComputeEmbeddingMapFn(
            name='module_name',
            module='@loc',
            output_key='unnecessary',
            audio_key=audio_key,
            sample_rate_key=sample_rate_key,
            sample_rate=sample_rate,
            average_over_time=average_over_time,
            module_call_fn=test_call_fn,
            setup_fn=_noop_setup)
        embedding_fn.setup()

        # Only the first element produced by `process` is inspected.
        outputs = embedding_fn.process((input_key, example))
        result_key, embedding = next(outputs)

        if average_over_time:
            expected_shape = (1, custom_call_shape[1])
        else:
            expected_shape = custom_call_shape

        self.assertEqual(result_key, input_key)
        self.assertEqual(embedding.shape, expected_shape)