def test_model_build_and_export_tflite(self, model_name, image_size):
  """Checks that a non-quantized TFLite export writes a file to disk."""
  export_config = export_util.ExportConfig(
      model_name=model_name,
      image_size=image_size,
      output_dir=self.create_tempdir().full_path,
  )
  # Quantization is explicitly switched off for this export.
  export_config.quantization_config.quantize = False
  built_model = _build_model(export_config)
  exported_file = _dump_tflite(built_model, export_config)
  self.assertTrue(tf.io.gfile.exists(exported_file))
def test_model_build_and_export_saved_model(self, model_name, image_size):
  """Checks that the built model can be saved and detected as a SavedModel.

  The model is written to `<output_dir>/<model_name>` inside a fresh
  temporary directory.
  """
  output_root = self.create_tempdir().full_path
  export_config = export_util.ExportConfig(
      model_name=model_name, image_size=image_size, output_dir=output_root)
  built_model = _build_model(export_config)
  target_dir = os.path.join(export_config.output_dir, export_config.model_name)
  built_model.save(target_dir)
  self.assertTrue(tf.saved_model.contains_saved_model(target_dir))
def test_segmentation_finalize_methods(self, model_name, finalize_method):
  """Checks the output shape of a model built with given finalize methods.

  `finalize_method` is a comma-separated string that is split into a list
  for the config. A single random input of size 512x512x3 must produce an
  output of shape [1, image_size, image_size].
  """
  export_config = export_util.ExportConfig(
      model_name=model_name,
      image_size=512,
      output_dir=self.create_tempdir().full_path,
      finalize_method=finalize_method.split(','),
  )
  export_config.quantization_config.quantize = False
  net = _build_model(export_config)
  side = export_config.image_size
  random_batch = tf.random.normal([1, side, side, 3])
  predicted_shape = net(random_batch).get_shape().as_list()
  self.assertEqual(predicted_shape, [1, side, side])
def get_export_config_from_flags():
  """Creates ExportConfig from cmd line flags.

  Returns:
    An `export_util.ExportConfig` built from the module-level `FLAGS`, with
    a nested `export_util.QuantizationConfig` built from the
    quantization-related flags.
  """
  quantization_config = export_util.QuantizationConfig(
      quantize=FLAGS.quantize,
      quantize_less_restrictive=FLAGS.quantize_less_restrictive,
      use_experimental_quantizer=FLAGS.use_experimental_quantizer,
      num_calibration_steps=FLAGS.num_calibration_steps,
      dataset_name=FLAGS.dataset_name,
      dataset_dir=FLAGS.dataset_dir,
      dataset_split=FLAGS.dataset_split)
  # --finalize_method is a comma-separated list. Lower-case it, strip the
  # whitespace around each entry and drop empty entries. This way inputs
  # such as "a, b" or a trailing comma do not yield names like " b" or "".
  finalize_method = [
      name.strip()
      for name in FLAGS.finalize_method.lower().split(',')
      if name.strip()
  ]
  export_config = export_util.ExportConfig(
      model_name=FLAGS.model_name,
      ckpt_path=FLAGS.ckpt_path,
      ckpt_format=FLAGS.ckpt_format,
      output_dir=FLAGS.output_dir,
      image_size=FLAGS.image_size,
      finalize_method=finalize_method,
      quantization_config=quantization_config)
  return export_config