def test_export_yields_saved_model(self):
    """Checks that an SSD export writes a SavedModel with its variable files."""
    checkpoint_dir = self.get_temp_dir()
    self._save_checkpoint_from_mock_model(checkpoint_dir)
    with mock.patch.object(
            model_builder, 'build', autospec=True) as build_stub:
        build_stub.return_value = FakeModel()
        export_dir = os.path.join(checkpoint_dir, 'output')
        export_tflite_graph_lib_tf2.export_tflite_model(
            pipeline_config=self._get_ssd_config(),
            trained_checkpoint_dir=checkpoint_dir,
            output_directory=export_dir,
            max_detections=10,
            use_regular_nms=False)
        # Each entry is the path (relative to export_dir) of a file the export
        # must produce; checked in the same order as separate assertions.
        expected_files = (
            ('saved_model', 'saved_model.pb'),
            ('saved_model', 'variables', 'variables.index'),
            ('saved_model', 'variables', 'variables.data-00000-of-00001'),
        )
        for parts in expected_files:
            self.assertTrue(os.path.exists(os.path.join(export_dir, *parts)))
def test_unsupported_architecture(self):
    """Checks that exporting an unsupported meta-architecture raises ValueError.

    The pipeline config selects Faster R-CNN. The export is expected to
    reject it with a ValueError whose message mentions that only SSD or
    CenterNet models are supported.
    """
    tmp_dir = self.get_temp_dir()
    self._save_checkpoint_from_mock_model(tmp_dir)
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 10
    with mock.patch.object(
            model_builder, 'build', autospec=True) as mock_builder:
        mock_builder.return_value = FakeModel()
        output_directory = os.path.join(tmp_dir, 'output')
        expected_message = 'Only ssd or center_net models are supported in tflite'
        # assertRaises replaces the manual try/except/else; a missing or
        # mismatched exception is now reported as a test failure, not an error.
        with self.assertRaises(ValueError) as raised:
            export_tflite_graph_lib_tf2.export_tflite_model(
                pipeline_config=pipeline_config,
                trained_checkpoint_dir=tmp_dir,
                output_directory=output_directory,
                max_detections=10,
                use_regular_nms=False)
        # Substring match, same semantics as the original `in` check.
        self.assertIn(expected_message, str(raised.exception))
def main(argv):
    """Exports a TFLite-compatible SavedModel from a trained checkpoint.

    Reads the pipeline config from --pipeline_config_path and applies
    --config_override on top of it. It then calls
    export_tflite_graph_lib_tf2.export_tflite_model with the checkpoint
    directory, output directory, and SSD detection flags.

    Args:
      argv: Command-line arguments left over after flag parsing (unused).
    """
    del argv  # Unused.
    flags.mark_flag_as_required('pipeline_config_path')
    flags.mark_flag_as_required('trained_checkpoint_dir')
    flags.mark_flag_as_required('output_directory')

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
        text_format.Parse(f.read(), pipeline_config)

    # text_format.Parse does not clear its target and raises on singular
    # fields that are already set. Parsing the override straight into
    # pipeline_config would therefore fail for any field present in the
    # config file. Parse into a fresh message, then merge it on top.
    override_config = pipeline_pb2.TrainEvalPipelineConfig()
    text_format.Parse(FLAGS.config_override, override_config)
    pipeline_config.MergeFrom(override_config)

    export_tflite_graph_lib_tf2.export_tflite_model(
        pipeline_config, FLAGS.trained_checkpoint_dir,
        FLAGS.output_directory, FLAGS.ssd_max_detections,
        FLAGS.ssd_use_regular_nms)
def test_center_net_inference_object_detection(self):
    """Exports a CenterNet model and runs its serving signature on zeros.

    Only the number of signature outputs is checked; their values are not
    meaningful for the fake model.
    """
    workdir = self.get_temp_dir()
    export_dir = os.path.join(workdir, 'output')
    self._save_checkpoint_from_mock_model(workdir)
    with mock.patch.object(
            model_builder, 'build', autospec=True) as build_stub:
        build_stub.return_value = FakeModel()
        export_tflite_graph_lib_tf2.export_tflite_model(
            pipeline_config=self._get_center_net_config(),
            trained_checkpoint_dir=workdir,
            output_directory=export_dir,
            max_detections=10,
            use_regular_nms=False)

    loaded = tf.saved_model.load(os.path.join(export_dir, 'saved_model'))
    serving_fn = loaded.signatures['serving_default']
    blank_image = tf.zeros(shape=[1, 10, 10, 3], dtype=tf.float32)
    outputs = serving_fn(blank_image)
    # The exported graph doesn't have numerically correct outputs, but there
    # should be 4.
    self.assertEqual(4, len(outputs))
def main(argv):
    """Exports a TFLite-compatible SavedModel, with optional keypoint output.

    Loads the pipeline config file and merges --config_override on top of
    it. Then forwards the config, paths, detection flags, and CenterNet
    keypoint flags to export_tflite_graph_lib_tf2.export_tflite_model.

    Args:
      argv: Command-line arguments left over after flag parsing (unused).
    """
    del argv  # Unused.
    for required_flag in ("pipeline_config_path",
                          "trained_checkpoint_dir",
                          "output_directory"):
        flags.mark_flag_as_required(required_flag)

    config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(FLAGS.pipeline_config_path, "r") as config_file:
        text_format.Parse(config_file.read(), config)

    # Parse overrides separately so they can replace fields already set by
    # the config file instead of tripping Parse's duplicate-field check.
    overrides = pipeline_pb2.TrainEvalPipelineConfig()
    text_format.Parse(FLAGS.config_override, overrides)
    config.MergeFrom(overrides)

    export_tflite_graph_lib_tf2.export_tflite_model(
        config,
        FLAGS.trained_checkpoint_dir,
        FLAGS.output_directory,
        FLAGS.max_detections,
        FLAGS.ssd_use_regular_nms,
        FLAGS.centernet_include_keypoints,
        FLAGS.keypoint_label_map_path,
    )