def test_export_yields_saved_model(self):
    """Exporting an SSD pipeline writes a SavedModel and its variable files.

    The model builder is patched to return FakeModel, so only the files
    emitted under <output>/saved_model are checked, not their contents.
    """
    tmp_dir = self.get_temp_dir()
    self._save_checkpoint_from_mock_model(tmp_dir)
    with mock.patch.object(model_builder, 'build', autospec=True) as mock_builder:
        mock_builder.return_value = FakeModel()
        output_directory = os.path.join(tmp_dir, 'output')
        export_tflite_graph_lib_tf2.export_tflite_model(
            pipeline_config=self._get_ssd_config(),
            trained_checkpoint_dir=tmp_dir,
            output_directory=output_directory,
            max_detections=10,
            use_regular_nms=False)

        saved_model_dir = os.path.join(output_directory, 'saved_model')
        # Path components (relative to saved_model_dir) that must exist.
        expected_files = (
            ('saved_model.pb',),
            ('variables', 'variables.index'),
            ('variables', 'variables.data-00000-of-00001'),
        )
        for parts in expected_files:
            self.assertTrue(os.path.exists(os.path.join(saved_model_dir, *parts)))
    def test_unsupported_architecture(self):
        """Exporting a Faster R-CNN pipeline config is rejected with ValueError.

        The raised error's message must contain the expected
        'Only ssd or center_net models are supported in tflite' text.
        """
        tmp_dir = self.get_temp_dir()
        self._save_checkpoint_from_mock_model(tmp_dir)

        pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
        pipeline_config.model.faster_rcnn.num_classes = 10

        with mock.patch.object(model_builder, 'build',
                               autospec=True) as mock_builder:
            mock_builder.return_value = FakeModel()
            output_directory = os.path.join(tmp_dir, 'output')
            expected_message = 'Only ssd or center_net models are supported in tflite'
            # assertRaises replaces a hand-rolled try/except/else: a missing
            # exception, or a ValueError with the wrong message, is now reported
            # as a test failure instead of an error from a re-raised exception.
            with self.assertRaises(ValueError) as raised:
                export_tflite_graph_lib_tf2.export_tflite_model(
                    pipeline_config=pipeline_config,
                    trained_checkpoint_dir=tmp_dir,
                    output_directory=output_directory,
                    max_detections=10,
                    use_regular_nms=False)
            # Substring match, same as the original `expected_message in str(e)`.
            self.assertIn(expected_message, str(raised.exception))
Example #3
0
def main(argv):
    """Build a pipeline config from flags and run the TFLite export.

    Reads the text-format pipeline config at --pipeline_config_path, applies
    the text-format overrides in --config_override, then calls
    export_tflite_graph_lib_tf2.export_tflite_model with the checkpoint and
    output directories and the SSD detection/NMS flags.

    Args:
      argv: Command-line arguments; unused (all inputs are read from FLAGS).

    Raises:
      text_format.ParseError: if the config file or the override text is not
        valid text format for TrainEvalPipelineConfig.
    """
    del argv  # Unused.
    # NOTE(review): absl validates required flags when flags are parsed; marking
    # them here, inside main (i.e. after app.run has parsed argv), may not
    # enforce them. Consider moving these calls before app.run.
    flags.mark_flag_as_required('pipeline_config_path')
    flags.mark_flag_as_required('trained_checkpoint_dir')
    flags.mark_flag_as_required('output_directory')

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()

    with tf.io.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
        text_format.Parse(f.read(), pipeline_config)

    # text_format.Parse does not clear its target message and raises a
    # ParseError when a singular field is already set. Parsing the override
    # straight into pipeline_config therefore failed for any override of a
    # field present in the config file. Parse into a fresh message and merge
    # instead: MergeFrom overwrites singular fields and appends repeated ones.
    override_config = pipeline_pb2.TrainEvalPipelineConfig()
    text_format.Parse(FLAGS.config_override, override_config)
    pipeline_config.MergeFrom(override_config)

    export_tflite_graph_lib_tf2.export_tflite_model(
        pipeline_config, FLAGS.trained_checkpoint_dir, FLAGS.output_directory,
        FLAGS.ssd_max_detections, FLAGS.ssd_use_regular_nms)
    def test_center_net_inference_object_detection(self):
        """A CenterNet export can be reloaded and served on a dummy image.

        The exported SavedModel's 'serving_default' signature is called on an
        all-zeros [1, 10, 10, 3] float32 image and must return 4 outputs.
        """
        tmp_dir = self.get_temp_dir()
        output_directory = os.path.join(tmp_dir, 'output')
        self._save_checkpoint_from_mock_model(tmp_dir)

        with mock.patch.object(model_builder, 'build', autospec=True) as mock_builder:
            mock_builder.return_value = FakeModel()
            export_tflite_graph_lib_tf2.export_tflite_model(
                pipeline_config=self._get_center_net_config(),
                trained_checkpoint_dir=tmp_dir,
                output_directory=output_directory,
                max_detections=10,
                use_regular_nms=False)

        loaded_model = tf.saved_model.load(
            os.path.join(output_directory, 'saved_model'))
        serving_fn = loaded_model.signatures['serving_default']
        outputs = serving_fn(tf.zeros(shape=[1, 10, 10, 3], dtype=tf.float32))

        # FakeModel makes the values meaningless; only the number of output
        # tensors in the signature is checked.
        self.assertEqual(len(outputs), 4)
def main(argv):
    """Assemble the pipeline config from flags and hand it to the exporter.

    The base config is read from --pipeline_config_path. Overrides from
    --config_override are parsed separately and merged on top with MergeFrom.
    The merged config, the checkpoint/output directories, and the detection,
    NMS and keypoint flags are then passed positionally to
    export_tflite_graph_lib_tf2.export_tflite_model.

    Args:
      argv: Command-line arguments; unused (all inputs are read from FLAGS).
    """
    del argv  # Unused.
    # NOTE(review): absl checks required flags at parse time; registering them
    # inside main may be too late to take effect. Verify against app.run usage.
    for required_flag in ("pipeline_config_path",
                          "trained_checkpoint_dir",
                          "output_directory"):
        flags.mark_flag_as_required(required_flag)

    merged_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(FLAGS.pipeline_config_path, "r") as config_file:
        base_text = config_file.read()
    text_format.Parse(base_text, merged_config)

    # Overrides go into a separate message so they can replace fields already
    # set by the base config (Parse itself would reject such duplicates).
    overrides = pipeline_pb2.TrainEvalPipelineConfig()
    text_format.Parse(FLAGS.config_override, overrides)
    merged_config.MergeFrom(overrides)

    export_args = (
        merged_config,
        FLAGS.trained_checkpoint_dir,
        FLAGS.output_directory,
        FLAGS.max_detections,
        FLAGS.ssd_use_regular_nms,
        FLAGS.centernet_include_keypoints,
        FLAGS.keypoint_label_map_path,
    )
    export_tflite_graph_lib_tf2.export_tflite_model(*export_args)