예제 #1
0
def main(_) -> None:
    """Convert the SavedModel at ``FLAGS.saved_model_dir`` into a TFLite file.

    The experiment config named by ``FLAGS.experiment`` is loaded and then
    strictly overridden by every entry of ``FLAGS.config_file`` (in order),
    followed by ``FLAGS.params_override`` when that flag is non-empty. The
    validated and locked config is passed to the converter together with
    ``FLAGS.quant_type`` and ``FLAGS.calibration_steps``. The returned
    serialized model is written to ``FLAGS.tflite_path``.

    Args:
      _: Unused positional argument (presumably the argv list supplied by the
        app runner — not visible here).
    """
    params = exp_factory.get_exp_config(FLAGS.experiment)

    # Gather override sources in application order: each config file first,
    # then the inline override string, if one was given.
    override_sources = []
    if FLAGS.config_file is not None:
        override_sources.extend(FLAGS.config_file)
    if FLAGS.params_override:
        override_sources.append(FLAGS.params_override)

    for source in override_sources:
        params = hyperparams.override_params_dict(
            params, source, is_strict=True)

    # Validate the merged config, then lock it (assumed to prevent further
    # mutation — per the hyperparams Config semantics; confirm).
    params.validate()
    params.lock()

    logging.info('Converting SavedModel from %s to TFLite model...',
                 FLAGS.saved_model_dir)
    serialized_model = export_tflite_lib.convert_tflite_model(
        saved_model_dir=FLAGS.saved_model_dir,
        quant_type=FLAGS.quant_type,
        params=params,
        calibration_steps=FLAGS.calibration_steps)

    # Binary mode: the converter output is written out verbatim.
    with tf.io.gfile.GFile(FLAGS.tflite_path, 'wb') as out_file:
        out_file.write(serialized_model)

    logging.info('TFLite model converted and saved to %s.', FLAGS.tflite_path)
예제 #2
0
  def test_export_tflite_detection(self, experiment, quant_type,
                                   input_image_size):
    """Exports a DetectionModule to SavedModel and converts it to TFLite.

    The test first writes a small synthetic detection TFRecord (10 samples,
    each with 10 instances and 3 image channels). It then points the
    experiment's validation and train inputs at that file. These inputs are
    presumably what the converter reads for calibration when `quant_type`
    needs it — confirm. Finally it exports a 'tflite'-typed module, converts
    it with 5 calibration steps, and checks that the result is ``bytes``.
    """
    scratch_dir = self.get_temp_dir()
    record_path = os.path.join(scratch_dir, 'det_test.tfrecord')
    saved_model_path = os.path.join(scratch_dir, 'saved_model')

    synthetic_example = tfexample_utils.create_detection_test_example(
        image_height=input_image_size[0],
        image_width=input_image_size[1],
        image_channel=3,
        num_instances=10)
    self._create_test_tfrecord(
        tfrecord_file=record_path, example=synthetic_example, num_samples=10)

    experiment_config = exp_factory.get_exp_config(experiment)
    # Same assignment order as before: validation first, then train.
    for data_config in (experiment_config.task.validation_data,
                        experiment_config.task.train_data):
      data_config.input_path = record_path

    detection_module = detection_serving.DetectionModule(
        params=experiment_config,
        batch_size=1,
        input_image_size=input_image_size,
        input_type='tflite')
    self._export_from_module(
        module=detection_module,
        input_type='tflite',
        saved_model_dir=saved_model_path)

    converted_model = export_tflite_lib.convert_tflite_model(
        saved_model_dir=saved_model_path,
        quant_type=quant_type,
        params=experiment_config,
        calibration_steps=5)

    self.assertIsInstance(converted_model, bytes)
예제 #3
0
    def test_export_tflite_detection(self, experiment, quant_type,
                                     input_image_size):
        """Exports a DetectionModule to SavedModel, then converts it to TFLite.

        The experiment config is used exactly as returned by
        ``exp_factory.get_exp_config``; no input paths are overridden here.
        The only check is that the converter returns serialized ``bytes``.
        """
        experiment_config = exp_factory.get_exp_config(experiment)
        export_dir = os.path.join(self.get_temp_dir(), 'saved_model')

        detector = detection_serving.DetectionModule(
            params=experiment_config,
            batch_size=1,
            input_image_size=input_image_size)
        self._export_from_module(
            module=detector,
            input_type='tflite',
            saved_model_dir=export_dir)

        converted_model = export_tflite_lib.convert_tflite_model(
            saved_model_dir=export_dir,
            quant_type=quant_type,
            params=experiment_config,
            calibration_steps=5)

        self.assertIsInstance(converted_model, bytes)
예제 #4
0
    def test_export_tflite(self, experiment, quant_type, input_image_size):
        """Exports a ClassificationModule and converts it to TFLite bytes.

        Before export, the experiment's validation and train inputs are both
        pointed at ``self._test_tfrecord_file``. That file is prepared outside
        this view, presumably in ``setUp``. The test then converts the
        exported SavedModel with 5 calibration steps and checks that the
        result is ``bytes``.
        """
        experiment_config = exp_factory.get_exp_config(experiment)
        # Same assignment order as before: validation first, then train.
        for data_config in (experiment_config.task.validation_data,
                            experiment_config.task.train_data):
            data_config.input_path = self._test_tfrecord_file

        export_dir = os.path.join(self.get_temp_dir(), 'saved_model')
        classifier = image_classification_serving.ClassificationModule(
            params=experiment_config,
            batch_size=1,
            input_image_size=input_image_size)
        self._export_from_module(
            module=classifier,
            input_type='tflite',
            saved_model_dir=export_dir)

        converted_model = export_tflite_lib.convert_tflite_model(
            saved_model_dir=export_dir,
            quant_type=quant_type,
            params=experiment_config,
            calibration_steps=5)

        self.assertIsInstance(converted_model, bytes)