def main(_) -> None:
    """Convert the SavedModel named by the flags into a TFLite model file.

    Builds the experiment config from ``FLAGS.experiment``, applies every
    entry of ``FLAGS.config_file`` and then ``FLAGS.params_override`` on top
    of it (strict mode), validates and locks the result, runs the TFLite
    conversion and writes the returned model to ``FLAGS.tflite_path``.

    Args:
        _: Unused positional argument.
    """
    config = exp_factory.get_exp_config(FLAGS.experiment)

    # Overrides are applied in order: config files first, inline override last.
    overrides = []
    if FLAGS.config_file is not None:
        overrides.extend(FLAGS.config_file)
    if FLAGS.params_override:
        overrides.append(FLAGS.params_override)
    for override in overrides:
        config = hyperparams.override_params_dict(
            config, override, is_strict=True)

    config.validate()
    config.lock()

    logging.info('Converting SavedModel from %s to TFLite model...',
                 FLAGS.saved_model_dir)
    converted = export_tflite_lib.convert_tflite_model(
        saved_model_dir=FLAGS.saved_model_dir,
        quant_type=FLAGS.quant_type,
        params=config,
        calibration_steps=FLAGS.calibration_steps)

    # Binary mode: the converted model is written out as-is.
    with tf.io.gfile.GFile(FLAGS.tflite_path, 'wb') as out_file:
        out_file.write(converted)

    logging.info('TFLite model converted and saved to %s.', FLAGS.tflite_path)
def test_export_tflite_detection(self, experiment, quant_type,
                                 input_image_size):
    """Export a detection module and check TFLite conversion returns bytes.

    Writes a small detection TFRecord into the test's temp directory, points
    both the train and validation input paths of the experiment config at it,
    exports a ``DetectionModule`` as a SavedModel and converts that to TFLite.

    Args:
        experiment: Experiment name passed to ``exp_factory.get_exp_config``.
        quant_type: Quantization type forwarded to ``convert_tflite_model``.
        input_image_size: Indexable image size; element 0 is used as the
            example height and element 1 as the example width.
    """
    # Build the synthetic dataset the experiment config will point at.
    tfrecord_path = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
    detection_example = tfexample_utils.create_detection_test_example(
        image_height=input_image_size[0],
        image_width=input_image_size[1],
        image_channel=3,
        num_instances=10)
    self._create_test_tfrecord(
        tfrecord_file=tfrecord_path,
        example=detection_example,
        num_samples=10)

    config = exp_factory.get_exp_config(experiment)
    config.task.validation_data.input_path = tfrecord_path
    config.task.train_data.input_path = tfrecord_path

    # Export the serving module, then convert the resulting SavedModel.
    saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    serving_module = detection_serving.DetectionModule(
        params=config,
        batch_size=1,
        input_image_size=input_image_size,
        input_type='tflite')
    self._export_from_module(
        module=serving_module,
        input_type='tflite',
        saved_model_dir=saved_model_dir)

    converted = export_tflite_lib.convert_tflite_model(
        saved_model_dir=saved_model_dir,
        quant_type=quant_type,
        params=config,
        calibration_steps=5)

    self.assertIsInstance(converted, bytes)
def test_export_tflite_detection(self, experiment, quant_type,
                                 input_image_size):
    """Export a detection module and check TFLite conversion returns bytes.

    Args:
        experiment: Experiment name passed to ``exp_factory.get_exp_config``.
        quant_type: Quantization type forwarded to ``convert_tflite_model``.
        input_image_size: Image size forwarded to ``DetectionModule``.
    """
    config = exp_factory.get_exp_config(experiment)
    saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')

    # Export the serving module, then convert the resulting SavedModel.
    serving_module = detection_serving.DetectionModule(
        params=config,
        batch_size=1,
        input_image_size=input_image_size)
    self._export_from_module(
        module=serving_module,
        input_type='tflite',
        saved_model_dir=saved_model_dir)

    converted = export_tflite_lib.convert_tflite_model(
        saved_model_dir=saved_model_dir,
        quant_type=quant_type,
        params=config,
        calibration_steps=5)

    self.assertIsInstance(converted, bytes)
def test_export_tflite(self, experiment, quant_type, input_image_size):
    """Export a classification module and check TFLite conversion gives bytes.

    Points the train and validation input paths of the experiment config at
    ``self._test_tfrecord_file``, exports a ``ClassificationModule`` as a
    SavedModel and converts that to TFLite.

    Args:
        experiment: Experiment name passed to ``exp_factory.get_exp_config``.
        quant_type: Quantization type forwarded to ``convert_tflite_model``.
        input_image_size: Image size forwarded to ``ClassificationModule``.
    """
    config = exp_factory.get_exp_config(experiment)
    config.task.validation_data.input_path = self._test_tfrecord_file
    config.task.train_data.input_path = self._test_tfrecord_file

    saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')

    # Export the serving module, then convert the resulting SavedModel.
    serving_module = image_classification_serving.ClassificationModule(
        params=config,
        batch_size=1,
        input_image_size=input_image_size)
    self._export_from_module(
        module=serving_module,
        input_type='tflite',
        saved_model_dir=saved_model_dir)

    converted = export_tflite_lib.convert_tflite_model(
        saved_model_dir=saved_model_dir,
        quant_type=quant_type,
        params=config,
        calibration_steps=5)

    self.assertIsInstance(converted, bytes)