def test_generate_embedding_data_with_bottom_k_boxes(self):
    """Check embedding generation when only bottom-k boxes are requested.

    Builds a `GenerateEmbeddingDataFn` with a top-k count of 0 and a
    bottom-k count of 2, verifies the label and class text of the input
    example, runs the example through `process`, and validates the first
    output via `assert_expected_example(..., botk=True)`.
    """
    saved_model_path = self._export_saved_model()
    top_k_embedding_count = 0
    bottom_k_embedding_count = 2
    inference_fn = generate_embedding_data.GenerateEmbeddingDataFn(
        saved_model_path, top_k_embedding_count, bottom_k_embedding_count)
    # NOTE(review): test_generate_embedding_data_fn initialises the DoFn
    # with setup() and feeds process() a (key, example) tuple, whereas this
    # test uses start_bundle() and a bare example -- confirm which form the
    # current GenerateEmbeddingDataFn expects.
    inference_fn.start_bundle()
    generated_example = self._create_tf_example()

    # Parse the serialized input once and reuse it for both sanity checks.
    parsed_input = tf.train.Example.FromString(generated_example)
    input_features = parsed_input.features.feature
    self.assertAllEqual(
        input_features['image/object/class/label'].int64_list.value, [5])
    # bytes_list holds `bytes`; a plain str literal only matches on
    # Python 2, so compare against a bytes literal.
    self.assertAllEqual(
        input_features['image/object/class/text'].bytes_list.value,
        [b'hyena'])

    output = inference_fn.process(generated_example)
    output_example = output[0]
    self.assert_expected_example(output_example, botk=True)
def test_generate_embedding_data_fn(self):
    """Run one keyed example through GenerateEmbeddingDataFn (top-k = 1).

    The DoFn is created with a top-k count of 1 and a bottom-k count of 0
    and initialised with `setup()`. After confirming the input example's
    label and class text, the example is passed to `process` as a
    `('dummy_key', serialized_example)` pair, and the second element of the
    first output is checked with `assert_expected_example`.
    """
    model_dir = self._export_saved_model()
    num_top_k = 1
    num_bottom_k = 0
    embedding_fn = generate_embedding_data.GenerateEmbeddingDataFn(
        model_dir, num_top_k, num_bottom_k)
    embedding_fn.setup()

    serialized_example = self._create_tf_example()
    parsed_example = tf.train.Example.FromString(serialized_example)
    feature_map = parsed_example.features.feature

    # Sanity-check the fixture before exercising the DoFn.
    labels = feature_map['image/object/class/label'].int64_list.value
    self.assertAllEqual(labels, [5])
    class_names = feature_map['image/object/class/text'].bytes_list.value
    self.assertAllEqual(class_names, [b'hyena'])

    keyed_input = ('dummy_key', serialized_example)
    results = embedding_fn.process(keyed_input)
    # The first result is indexed at [1] for the example; presumably the
    # outputs are (key, example) pairs mirroring the input.
    first_result = results[0]
    self.assert_expected_example(first_result[1])