def test_model_build_and_export_tflite(self, model_name, image_size):
     """Builds a model without quantization and checks the TFLite file exists.

     Args:
       model_name: Model name forwarded to `export_util.ExportConfig`.
       image_size: Image size forwarded to `export_util.ExportConfig`.
     """
     output_dir = self.create_tempdir().full_path
     export_config = export_util.ExportConfig(
         model_name=model_name, image_size=image_size, output_dir=output_dir)
     # Quantization is switched off explicitly before the model is built.
     export_config.quantization_config.quantize = False

     exported_path = _dump_tflite(_build_model(export_config), export_config)

     self.assertTrue(tf.io.gfile.exists(exported_path))
 def test_model_build_and_export_saved_model(self, model_name, image_size):
     """Builds a model, saves it, and checks a SavedModel exists at the path.

     The model is saved under `<output_dir>/<model_name>`, both read back from
     the export config.

     Args:
       model_name: Model name forwarded to `export_util.ExportConfig`.
       image_size: Image size forwarded to `export_util.ExportConfig`.
     """
     output_dir = self.create_tempdir().full_path
     export_config = export_util.ExportConfig(
         model_name=model_name, image_size=image_size, output_dir=output_dir)

     built_model = _build_model(export_config)
     export_path = os.path.join(
         export_config.output_dir, export_config.model_name)
     built_model.save(export_path)

     self.assertTrue(tf.saved_model.contains_saved_model(export_path))
 def test_segmentation_finalize_methods(self, model_name, finalize_method):
     """Checks the model output shape for the given finalize method(s).

     The model is built unquantized with a fixed image size of 512, called on
     a random `[1, size, size, 3]` input, and its output shape is asserted to
     be `[1, size, size]`.

     Args:
       model_name: Model name forwarded to `export_util.ExportConfig`.
       finalize_method: Comma-separated string; it is split on ',' and the
         resulting list is passed as `finalize_method` to the config.
     """
     output_dir = self.create_tempdir().full_path
     export_config = export_util.ExportConfig(
         model_name=model_name,
         image_size=512,
         output_dir=output_dir,
         finalize_method=finalize_method.split(','))
     export_config.quantization_config.quantize = False
     built_model = _build_model(export_config)

     input_shape = [1, export_config.image_size, export_config.image_size, 3]
     random_batch = tf.random.normal(input_shape)
     output_shape = built_model(random_batch).get_shape().as_list()

     self.assertEqual(
         output_shape,
         [1, export_config.image_size, export_config.image_size])
示例#4
0
def get_export_config_from_flags():
  """Builds an `export_util.ExportConfig` from the command-line FLAGS.

  Returns:
    An `export_util.ExportConfig` whose `quantization_config` is an
    `export_util.QuantizationConfig` populated from the quantization and
    dataset flags.
  """
  quantization_kwargs = {
      'quantize': FLAGS.quantize,
      'quantize_less_restrictive': FLAGS.quantize_less_restrictive,
      'use_experimental_quantizer': FLAGS.use_experimental_quantizer,
      'num_calibration_steps': FLAGS.num_calibration_steps,
      'dataset_name': FLAGS.dataset_name,
      'dataset_dir': FLAGS.dataset_dir,
      'dataset_split': FLAGS.dataset_split,
  }
  quant_cfg = export_util.QuantizationConfig(**quantization_kwargs)

  return export_util.ExportConfig(
      model_name=FLAGS.model_name,
      ckpt_path=FLAGS.ckpt_path,
      ckpt_format=FLAGS.ckpt_format,
      output_dir=FLAGS.output_dir,
      image_size=FLAGS.image_size,
      # The flag value is lower-cased and split on ',' into a list.
      finalize_method=FLAGS.finalize_method.lower().split(','),
      quantization_config=quant_cfg)