コード例 #1
0
    def _export_graph(self, pipeline_config, num_channels=3):
        """Exports a tflite graph without the TFLite postprocessing op.

        A checkpoint is first written from the mock model. The export then runs
        inside a fresh `tf.Graph`, while `model_builder.build` is patched to
        return a `FakeModel`.

        Args:
          pipeline_config: pipeline config proto. Whether it has a
            `graph_rewriter` field is passed on as `quantize`, and
            `eval_config.use_moving_averages` as `use_moving_averages`.
          num_channels: channel count forwarded to the mock-model checkpoint.

        Returns:
          Path of `tflite_graph.pb` inside the test temp directory, i.e. where
          the exported graph is expected to be written.
        """
        temp_dir = self.get_temp_dir()
        ckpt_prefix = os.path.join(temp_dir, 'model.ckpt')
        graph_path = os.path.join(temp_dir, 'tflite_graph.pb')

        # Presence of a graph_rewriter section is treated as "quantize".
        use_quantization = pipeline_config.HasField('graph_rewriter')
        moving_averages = pipeline_config.eval_config.use_moving_averages
        self._save_checkpoint_from_mock_model(
            ckpt_prefix,
            use_moving_averages=moving_averages,
            quantize=use_quantization,
            num_channels=num_channels)

        with mock.patch.object(model_builder, 'build',
                               autospec=True) as patched_build:
            # While patched, model_builder.build hands back the fake model.
            patched_build.return_value = FakeModel()
            with tf.Graph().as_default():
                export_tflite_ssd_graph_lib.export_tflite_graph(
                    pipeline_config=pipeline_config,
                    trained_checkpoint_prefix=ckpt_prefix,
                    output_dir=temp_dir,
                    add_postprocessing_op=False,
                    max_detections=10,
                    max_classes_per_detection=1)
        return graph_path
コード例 #2
0
    def _export_graph_with_postprocessing_op(self,
                                             pipeline_config,
                                             num_channels=3,
                                             additional_output_tensors=()):
        """Exports a tflite graph that includes the custom postprocessing op.

        Before exporting, an extra uint8 constant exposed through an identity
        op named `UnattachedTensor` is added to the fresh graph. It is not
        connected to the model.

        Args:
          pipeline_config: pipeline config proto. Whether it has a
            `graph_rewriter` field is passed on as `quantize`, and
            `eval_config.use_moving_averages` as `use_moving_averages`.
          num_channels: channel count forwarded to the mock-model checkpoint.
          additional_output_tensors: forwarded unchanged to
            `export_tflite_graph`.

        Returns:
          Path of `tflite_graph.pb` inside the test temp directory, i.e. where
          the exported graph is expected to be written.
        """
        temp_dir = self.get_temp_dir()
        ckpt_prefix = os.path.join(temp_dir, 'model.ckpt')
        graph_path = os.path.join(temp_dir, 'tflite_graph.pb')

        # Presence of a graph_rewriter section is treated as "quantize".
        use_quantization = pipeline_config.HasField('graph_rewriter')
        moving_averages = pipeline_config.eval_config.use_moving_averages
        self._save_checkpoint_from_mock_model(
            ckpt_prefix,
            use_moving_averages=moving_averages,
            quantize=use_quantization,
            num_channels=num_channels)

        with mock.patch.object(model_builder, 'build',
                               autospec=True) as patched_build:
            # While patched, model_builder.build hands back the fake model.
            patched_build.return_value = FakeModel()
            with tf.Graph().as_default():
                # NOTE(review): presumably lets callers name this tensor in
                # additional_output_tensors to test exporting extra outputs —
                # confirm against the tests that use it.
                unattached = tf.constant([[1, 2], [3, 4]], tf.uint8)
                tf.identity(unattached, name='UnattachedTensor')
                export_tflite_ssd_graph_lib.export_tflite_graph(
                    pipeline_config=pipeline_config,
                    trained_checkpoint_prefix=ckpt_prefix,
                    output_dir=temp_dir,
                    add_postprocessing_op=True,
                    max_detections=10,
                    max_classes_per_detection=1,
                    additional_output_tensors=additional_output_tensors)
        return graph_path
コード例 #3
0
ファイル: tflite_creator.py プロジェクト: sbooher2023/fmltc
def create_tflite_graph_pb(team_uuid, model_uuid):
    """Exports the TFLite-compatible SSD graph for a trained model.

    Does nothing if `blob_storage.tflite_graph_pb_exists` already reports a
    graph for this model. Otherwise it performs these steps:
      1. Loads the model's pipeline config.
      2. Checks that the model entity records a trained checkpoint.
      3. Exports the graph (with the postprocessing op) into the folder
         returned by `blob_storage.get_tflite_folder_path`.

    Args:
        team_uuid: team owning the model.
        model_uuid: model whose trained checkpoint is exported.

    Raises:
        exceptions.HttpErrorNotFound: if the model entity has no trained
            checkpoint path (empty or None).
    """
    if blob_storage.tflite_graph_pb_exists(team_uuid, model_uuid):
        return

    model_entity = model_trainer.retrieve_model_entity(team_uuid, model_uuid)

    # The following code is inspired by
    # https://github.com/tensorflow/models/tree/e5c9661aadbcb90cb4fd3ef76066f6d1dab116ff/research/object_detection/export_tflite_ssd_graph.py
    pipeline_config_path = blob_storage.get_pipeline_config_path(
        team_uuid, model_uuid)
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    trained_checkpoint_path = model_entity['trained_checkpoint_path']
    # Treat None as well as '' as "no checkpoint"; otherwise a None path would
    # be handed straight to the exporter.
    if not trained_checkpoint_path:
        message = 'Error: Trained checkpoint not found for model_uuid=%s.' % model_uuid
        logging.critical(message)
        raise exceptions.HttpErrorNotFound(message)
    output_directory = blob_storage.get_tflite_folder_path(
        team_uuid, model_uuid)
    # Keyword arguments guard against positional-order changes in
    # export_tflite_graph's signature.
    export_tflite_ssd_graph_lib.export_tflite_graph(
        pipeline_config=pipeline_config,
        trained_checkpoint_prefix=trained_checkpoint_path,
        output_dir=output_directory,
        add_postprocessing_op=True,
        # This matches the default for TFObjectDetector.Parameters.maxNumDetections
        # in the FTC SDK.
        max_detections=10,
        max_classes_per_detection=1,
        use_regular_nms=False)
コード例 #4
0
  def _export_graph_with_postprocessing_op(self,
                                           pipeline_config,
                                           num_channels=3):
    """Exports a tflite graph that includes the custom postprocessing op.

    A checkpoint is first written from the mock model. The export then runs
    inside a fresh `tf.Graph`, while `model_builder.build` is patched to
    return a `FakeModel`.

    Args:
      pipeline_config: pipeline config proto. Whether it has a
        `graph_rewriter` field is passed on as `quantize`, and
        `eval_config.use_moving_averages` as `use_moving_averages`.
      num_channels: channel count forwarded to the mock-model checkpoint.

    Returns:
      Path of `tflite_graph.pb` inside the test temp directory, i.e. where the
      exported graph is expected to be written.
    """
    temp_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(temp_dir, 'model.ckpt')
    graph_path = os.path.join(temp_dir, 'tflite_graph.pb')

    # Presence of a graph_rewriter section is treated as "quantize".
    use_quantization = pipeline_config.HasField('graph_rewriter')
    moving_averages = pipeline_config.eval_config.use_moving_averages
    self._save_checkpoint_from_mock_model(
        ckpt_prefix,
        use_moving_averages=moving_averages,
        quantize=use_quantization,
        num_channels=num_channels)

    with mock.patch.object(
        model_builder, 'build', autospec=True) as patched_build:
      # While patched, model_builder.build hands back the fake model.
      patched_build.return_value = FakeModel()
      with tf.Graph().as_default():
        export_tflite_ssd_graph_lib.export_tflite_graph(
            pipeline_config=pipeline_config,
            trained_checkpoint_prefix=ckpt_prefix,
            output_dir=temp_dir,
            add_postprocessing_op=True,
            max_detections=10,
            max_classes_per_detection=1)
    return graph_path
コード例 #5
0
def main(argv):
  """Exports a TFLite-compatible SSD graph using command-line flags.

  Parses the pipeline config named by --pipeline_config_path, merges
  --config_override over it, and calls
  export_tflite_ssd_graph_lib.export_tflite_graph with the remaining flags.

  Args:
    argv: positional command-line arguments; unused.
  """
  del argv  # Unused.
  flags.mark_flag_as_required('output_directory')
  flags.mark_flag_as_required('pipeline_config_path')
  flags.mark_flag_as_required('trained_checkpoint_prefix')

  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()

  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  # Merged after the file, so scalar fields set here override the file's.
  text_format.Merge(FLAGS.config_override, pipeline_config)
  # use_regular_nms must be passed by keyword. The seventh positional
  # parameter of export_tflite_graph is detections_per_class, so passing the
  # flag positionally would set detections_per_class to a bool and leave
  # use_regular_nms at its default.
  export_tflite_ssd_graph_lib.export_tflite_graph(
      pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
      FLAGS.add_postprocessing_op, FLAGS.max_detections,
      FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms)
コード例 #6
0
def main(argv):
  """Command-line entry point that exports a TFLite-compatible SSD graph.

  The config file named by --pipeline_config_path is parsed, --config_override
  is merged over it, and the result is exported together with the checkpoint,
  output and postprocessing flags.
  """
  del argv  # Positional arguments are not used.
  for required_flag in ('output_directory', 'pipeline_config_path',
                        'trained_checkpoint_prefix'):
    flags.mark_flag_as_required(required_flag)

  config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as config_file:
    config_text = config_file.read()
    text_format.Merge(config_text, config)
  text_format.Merge(FLAGS.config_override, config)

  export_tflite_ssd_graph_lib.export_tflite_graph(
      pipeline_config=config,
      trained_checkpoint_prefix=FLAGS.trained_checkpoint_prefix,
      output_dir=FLAGS.output_directory,
      add_postprocessing_op=FLAGS.add_postprocessing_op,
      max_detections=FLAGS.max_detections,
      max_classes_per_detection=FLAGS.max_classes_per_detection,
      use_regular_nms=FLAGS.use_regular_nms)
コード例 #7
0
ファイル: tf_trainable.py プロジェクト: 5l1v3r1/FalconCV
    def to_tflite(self,
                  checkpoint,
                  out_folder=None,
                  max_detections=10,
                  add_postprocessing_op=True,
                  use_regular_nms=True,
                  max_classes_per_detection=1):
        """Export an SSD checkpoint to a TFLite model via the `toco` CLI.

        First writes `tflite_graph.pb` with
        `export_tflite_ssd_graph_lib.export_tflite_graph`, then runs `toco` to
        produce `model.tflite` in the same folder. The converter's output is
        printed.

        Args:
            checkpoint: checkpoint suffix; the checkpoint used is
                `<self._out_folder>/model.ckpt-<checkpoint>`.
            out_folder: folder for the exported graph and TFLite model;
                defaults to `self._out_folder` when falsy.
            max_detections: forwarded to `export_tflite_graph`.
            add_postprocessing_op: forwarded to `export_tflite_graph`.
            use_regular_nms: forwarded to `export_tflite_graph`.
            max_classes_per_detection: forwarded to `export_tflite_graph`.

        Returns:
            self, to allow call chaining.

        Raises:
            Exception: wraps any failure (a non-"ssd" `self.arch`, export
                errors, a missing `toco` or a non-zero exit), keeping the
                original exception as `__cause__`.
        """
        try:
            # Explicit check rather than `assert`, which is stripped under -O.
            if self.arch != "ssd":
                raise ValueError("This method is only supported for ssd models")
            model_checkpoint = str(
                self._out_folder.joinpath("model.ckpt-{}".format(checkpoint)))
            tflite_model_folder = Path(
                out_folder) if out_folder else self._out_folder

            export_tflite_ssd_graph_lib.export_tflite_graph(
                self._pipeline,
                model_checkpoint,
                str(tflite_model_folder),
                add_postprocessing_op,
                max_detections,
                max_classes_per_detection,
                use_regular_nms=use_regular_nms)

            # Convert to TFLite. Arguments go to toco as a list with no shell,
            # so paths containing spaces, quotes or shell metacharacters are
            # passed through verbatim instead of being re-parsed by a shell.
            # NOTE(review): the output arrays name the postprocessing op; with
            # add_postprocessing_op=False they presumably don't exist in the
            # exported graph and conversion would fail — TODO confirm.
            cmd = [
                "toco",
                "--output_format=TFLITE",
                "--graph_def_file={}".format(
                    tflite_model_folder.joinpath("tflite_graph.pb")),
                "--output_file={}".format(
                    tflite_model_folder.joinpath("model.tflite")),
                "--input_shapes=1,{},{},3".format(
                    self.input_size[0], self.input_size[1]),
                "--input_arrays=normalized_input_image_tensor",
                "--output_arrays=TFLite_Detection_PostProcess,"
                "TFLite_Detection_PostProcess:1,"
                "TFLite_Detection_PostProcess:2,"
                "TFLite_Detection_PostProcess:3",
                "--inference_type=FLOAT",
                "--allow_custom_ops",
            ]
            print(subprocess.check_output(cmd).decode())
            return self
        except Exception as ex:
            raise Exception("Error converting the model {}".format(ex)) from ex