def test_generate_embedding_data_with_bottom_k_boxes(self):
    """Tests GenerateEmbeddingDataFn when only bottom-k boxes are embedded.

    Builds the DoFn with top_k_embedding_count=0 and
    bottom_k_embedding_count=2, then checks the produced example with
    `assert_expected_example(..., botk=True)`. The DoFn is driven the same
    way as in `test_generate_embedding_data_fn`:
    - `setup()` is called first;
    - `process()` receives a `(key, serialized example)` tuple;
    - the example is read from the first `(key, example)` output pair.
    """
    saved_model_path = self._export_saved_model()
    top_k_embedding_count = 0
    bottom_k_embedding_count = 2
    inference_fn = generate_embedding_data.GenerateEmbeddingDataFn(
        saved_model_path, top_k_embedding_count, bottom_k_embedding_count)
    # Initialize with setup(), matching test_generate_embedding_data_fn.
    # The previous start_bundle() call diverged from how that test prepares
    # the DoFn.
    inference_fn.setup()
    generated_example = self._create_tf_example()
    # Sanity-check the input example before running inference.
    input_features = tf.train.Example.FromString(
        generated_example).features.feature
    self.assertAllEqual(
        input_features['image/object/class/label'].int64_list.value, [5])
    # BytesList values are `bytes` under Python 3, so compare against
    # b'hyena' rather than the str 'hyena'.
    self.assertAllEqual(
        input_features['image/object/class/text'].bytes_list.value, [b'hyena'])
    # process() takes a (key, serialized example) pair. The example is the
    # second element of the first returned item, as in
    # test_generate_embedding_data_fn.
    output = inference_fn.process(('dummy_key', generated_example))
    output_example = output[0][1]
    self.assert_expected_example(output_example, botk=True)
 def test_generate_embedding_data_fn(self):
   saved_model_path = self._export_saved_model()
   top_k_embedding_count = 1
   bottom_k_embedding_count = 0
   inference_fn = generate_embedding_data.GenerateEmbeddingDataFn(
       saved_model_path, top_k_embedding_count, bottom_k_embedding_count)
   inference_fn.setup()
   generated_example = self._create_tf_example()
   self.assertAllEqual(tf.train.Example.FromString(
       generated_example).features.feature['image/object/class/label']
                       .int64_list.value, [5])
   self.assertAllEqual(tf.train.Example.FromString(
       generated_example).features.feature['image/object/class/text']
                       .bytes_list.value, [b'hyena'])
   output = inference_fn.process(('dummy_key', generated_example))
   output_example = output[0][1]
   self.assert_expected_example(output_example)