def _export_graph(self, pipeline_config, num_channels=3):
    """Run the tflite export, without the postprocessing op, on a mock model.

    A checkpoint is saved from the mock model into the test's temp directory.
    `model_builder.build` is then patched to return a `FakeModel` while
    `export_tflite_graph` runs inside a fresh `tf.Graph`.

    Args:
      pipeline_config: pipeline config proto. Whether it has a
        `graph_rewriter` field and its `eval_config.use_moving_averages`
        value are forwarded to the checkpoint-saving helper.
      num_channels: forwarded to `_save_checkpoint_from_mock_model`.

    Returns:
      The path `<temp dir>/tflite_graph.pb`.
    """
    tmp_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(tmp_dir, 'model.ckpt')

    self._save_checkpoint_from_mock_model(
        ckpt_prefix,
        use_moving_averages=pipeline_config.eval_config.use_moving_averages,
        quantize=pipeline_config.HasField('graph_rewriter'),
        num_channels=num_channels)

    export_kwargs = dict(
        pipeline_config=pipeline_config,
        trained_checkpoint_prefix=ckpt_prefix,
        output_dir=tmp_dir,
        add_postprocessing_op=False,
        max_detections=10,
        max_classes_per_detection=1)

    # The exporter builds its model through model_builder.build; swap in the
    # fake so that no real detection model is constructed.
    with mock.patch.object(model_builder, 'build', autospec=True) as fake_build:
        fake_build.return_value = FakeModel()
        with tf.Graph().as_default():
            export_tflite_ssd_graph_lib.export_tflite_graph(**export_kwargs)

    return os.path.join(tmp_dir, 'tflite_graph.pb')
def _export_graph_with_postprocessing_op(self,
                                         pipeline_config,
                                         num_channels=3,
                                         additional_output_tensors=()):
    """Run the tflite export with the custom postprocessing op enabled.

    Saves a checkpoint from the mock model, patches `model_builder.build` to
    return a `FakeModel`, adds an identity op named 'UnattachedTensor' to a
    fresh graph, and then calls `export_tflite_graph` within that graph.

    Args:
      pipeline_config: pipeline config proto. Whether it has a
        `graph_rewriter` field and its `eval_config.use_moving_averages`
        value are forwarded to the checkpoint-saving helper.
      num_channels: forwarded to `_save_checkpoint_from_mock_model`.
      additional_output_tensors: forwarded unchanged to `export_tflite_graph`.

    Returns:
      The path `<temp dir>/tflite_graph.pb`.
    """
    export_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(export_dir, 'model.ckpt')
    graph_pb_path = os.path.join(export_dir, 'tflite_graph.pb')

    eval_config = pipeline_config.eval_config
    self._save_checkpoint_from_mock_model(
        checkpoint_prefix,
        use_moving_averages=eval_config.use_moving_averages,
        quantize=pipeline_config.HasField('graph_rewriter'),
        num_channels=num_channels)

    build_patch = mock.patch.object(model_builder, 'build', autospec=True)
    with build_patch as fake_build:
        fake_build.return_value = FakeModel()
        with tf.Graph().as_default():
            # A tensor not connected to the model, created in the graph
            # before the export runs.
            unattached_values = tf.constant([[1, 2], [3, 4]], tf.uint8)
            tf.identity(unattached_values, name='UnattachedTensor')
            export_tflite_ssd_graph_lib.export_tflite_graph(
                pipeline_config=pipeline_config,
                trained_checkpoint_prefix=checkpoint_prefix,
                output_dir=export_dir,
                add_postprocessing_op=True,
                max_detections=10,
                max_classes_per_detection=1,
                additional_output_tensors=additional_output_tensors)

    return graph_pb_path
def create_tflite_graph_pb(team_uuid, model_uuid):
    """Export the tflite graph for a model unless it already exists.

    Returns immediately when `blob_storage.tflite_graph_pb_exists` reports
    that the graph is already present. Otherwise the pipeline config is read
    from blob storage and `export_tflite_graph` is run against the model's
    trained checkpoint, writing into the model's tflite folder.

    Args:
      team_uuid: team identifier, passed through to the storage/model helpers.
      model_uuid: model identifier, passed through to the storage/model helpers.

    Raises:
      exceptions.HttpErrorNotFound: if the model entity has no trained
        checkpoint path (empty or None).
    """
    if blob_storage.tflite_graph_pb_exists(team_uuid, model_uuid):
        return

    model_entity = model_trainer.retrieve_model_entity(team_uuid, model_uuid)

    # The following code is inspired by
    # https://github.com/tensorflow/models/tree/e5c9661aadbcb90cb4fd3ef76066f6d1dab116ff/research/object_detection/export_tflite_ssd_graph.py
    pipeline_config_path = blob_storage.get_pipeline_config_path(
        team_uuid, model_uuid)
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)

    trained_checkpoint_path = model_entity['trained_checkpoint_path']
    # Falsy check rather than `== ''`: a None value must also be reported as
    # "not found" instead of being handed to the exporter.
    if not trained_checkpoint_path:
        message = 'Error: Trained checkpoint not found for model_uuid=%s.' % model_uuid
        logging.critical(message)
        raise exceptions.HttpErrorNotFound(message)

    output_directory = blob_storage.get_tflite_folder_path(
        team_uuid, model_uuid)
    add_postprocessing_op = True
    # This matches the default for TFObjectDetector.Parameters.maxNumDetections
    # in the FTC SDK.
    max_detections = 10
    max_classes_per_detection = 1
    use_regular_nms = False
    export_tflite_ssd_graph_lib.export_tflite_graph(
        pipeline_config, trained_checkpoint_path, output_directory,
        add_postprocessing_op, max_detections, max_classes_per_detection,
        use_regular_nms=use_regular_nms)
def _export_graph_with_postprocessing_op(self, pipeline_config, num_channels=3):
    """Run the tflite export with the custom postprocessing op enabled.

    Saves a checkpoint from the mock model into the test's temp directory,
    then calls `export_tflite_graph` in a fresh `tf.Graph` while
    `model_builder.build` is patched to return a `FakeModel`.

    Args:
      pipeline_config: pipeline config proto. Whether it has a
        `graph_rewriter` field and its `eval_config.use_moving_averages`
        value are forwarded to the checkpoint-saving helper.
      num_channels: forwarded to `_save_checkpoint_from_mock_model`.

    Returns:
      The path `<temp dir>/tflite_graph.pb`.
    """
    work_dir = self.get_temp_dir()
    ckpt_prefix, graph_path = (
        os.path.join(work_dir, basename)
        for basename in ('model.ckpt', 'tflite_graph.pb'))

    has_graph_rewriter = pipeline_config.HasField('graph_rewriter')
    moving_averages = pipeline_config.eval_config.use_moving_averages
    self._save_checkpoint_from_mock_model(
        ckpt_prefix,
        use_moving_averages=moving_averages,
        quantize=has_graph_rewriter,
        num_channels=num_channels)

    with mock.patch.object(model_builder, 'build', autospec=True) as build_mock:
        # Every model the exporter asks for is the fake one.
        build_mock.return_value = FakeModel()
        with tf.Graph().as_default():
            export_tflite_ssd_graph_lib.export_tflite_graph(
                pipeline_config=pipeline_config,
                trained_checkpoint_prefix=ckpt_prefix,
                output_dir=work_dir,
                add_postprocessing_op=True,
                max_detections=10,
                max_classes_per_detection=1)

    return graph_path
def main(argv):
    """Parse the pipeline config named by the flags and export a tflite graph.

    Marks `output_directory`, `pipeline_config_path` and
    `trained_checkpoint_prefix` as required flags, reads the pipeline config
    text file, merges `FLAGS.config_override` on top of it, and calls
    `export_tflite_graph` with the flag values.

    Args:
      argv: unused positional command-line arguments.
    """
    del argv  # Unused.
    flags.mark_flag_as_required('output_directory')
    flags.mark_flag_as_required('pipeline_config_path')
    flags.mark_flag_as_required('trained_checkpoint_prefix')

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()

    with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    text_format.Merge(FLAGS.config_override, pipeline_config)

    # Pass use_regular_nms by keyword: as a 7th positional argument it would
    # bind to whatever optional parameter follows max_classes_per_detection in
    # the library signature, which is not guaranteed to be use_regular_nms.
    export_tflite_ssd_graph_lib.export_tflite_graph(
        pipeline_config, FLAGS.trained_checkpoint_prefix,
        FLAGS.output_directory, FLAGS.add_postprocessing_op,
        FLAGS.max_detections, FLAGS.max_classes_per_detection,
        use_regular_nms=FLAGS.use_regular_nms)
def main(argv):
    """Entry point: load the pipeline config and export the tflite graph.

    The three path flags are marked as required. The pipeline config file is
    read and parsed, `FLAGS.config_override` is merged over it, and the result
    is passed to `export_tflite_graph` together with the remaining flags.

    Args:
      argv: unused positional command-line arguments.
    """
    del argv  # Unused.
    for required_flag in ('output_directory',
                          'pipeline_config_path',
                          'trained_checkpoint_prefix'):
        flags.mark_flag_as_required(required_flag)

    config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as config_file:
        config_text = config_file.read()
    # Base config first, then the command-line override on top of it.
    for text in (config_text, FLAGS.config_override):
        text_format.Merge(text, config)

    export_tflite_ssd_graph_lib.export_tflite_graph(
        config,
        FLAGS.trained_checkpoint_prefix,
        FLAGS.output_directory,
        FLAGS.add_postprocessing_op,
        FLAGS.max_detections,
        FLAGS.max_classes_per_detection,
        use_regular_nms=FLAGS.use_regular_nms)
def to_tflite(self, checkpoint, out_folder=None, max_detections=10,
              add_postprocessing_op=True, use_regular_nms=True,
              max_classes_per_detection=1):
    """Export an SSD checkpoint to a tflite graph and convert it with `toco`.

    First calls `export_tflite_graph` for `model.ckpt-<checkpoint>` under
    `self._out_folder`, then runs the `toco` command-line tool on the
    resulting `tflite_graph.pb` to produce `model.tflite` in the same folder.

    Args:
        checkpoint: checkpoint identifier, used to form `model.ckpt-<checkpoint>`.
        out_folder: destination folder; falls back to `self._out_folder` when
            falsy.
        max_detections: forwarded to `export_tflite_graph`.
        add_postprocessing_op: forwarded to `export_tflite_graph`.
        use_regular_nms: forwarded to `export_tflite_graph`.
        max_classes_per_detection: forwarded to `export_tflite_graph`.

    Returns:
        `self`.

    Raises:
        Exception: wraps any failure (unsupported architecture, export error,
            non-zero `toco` exit status) with the original as `__cause__`.
    """
    try:
        # Explicit raise instead of `assert`, so the guard is still enforced
        # when Python runs with -O (which strips assert statements).
        if self.arch != "ssd":
            raise ValueError("This method is only supported for ssd models")

        model_checkpoint = str(
            self._out_folder.joinpath("model.ckpt-{}".format(checkpoint)))
        tflite_model_folder = Path(out_folder) if out_folder else self._out_folder

        export_tflite_ssd_graph_lib.export_tflite_graph(
            self._pipeline,
            model_checkpoint,
            str(tflite_model_folder),
            add_postprocessing_op,
            max_detections,
            max_classes_per_detection,
            use_regular_nms=use_regular_nms)

        # convert to tflite
        # The command is an argv list executed without a shell, so folder
        # names containing spaces, quotes or shell metacharacters are passed
        # to toco verbatim instead of being interpreted by the shell.
        cmd = [
            "toco",
            "--output_format=TFLITE",
            "--graph_def_file={}".format(
                tflite_model_folder.joinpath("tflite_graph.pb")),
            "--output_file={}".format(
                tflite_model_folder.joinpath("model.tflite")),
            "--input_shapes=1,{},{},3".format(
                self.input_size[0], self.input_size[1]),
            "--input_arrays=normalized_input_image_tensor",
            "--output_arrays=TFLite_Detection_PostProcess,"
            "TFLite_Detection_PostProcess:1,"
            "TFLite_Detection_PostProcess:2,"
            "TFLite_Detection_PostProcess:3",
            "--inference_type=FLOAT",
            "--allow_custom_ops",
        ]
        print(subprocess.check_output(cmd).decode())
        return self
    except Exception as ex:
        raise Exception("Error converting the model {}".format(ex)) from ex