def main(_):
    """Exports a BASNet inference graph configured from command-line flags.

    Starts from the experiment config named by `FLAGS.experiment`, applies
    every entry of `FLAGS.config_file` and then `FLAGS.params_override` as
    strict overrides, validates and locks the result, and passes it to
    `export_saved_model_lib.export_inference_graph` together with a
    `basnet.BASNetModule`.

    Args:
      _: Unused positional argument.
    """
    config = exp_factory.get_exp_config(FLAGS.experiment)

    # Config files are applied first, the inline override string last.
    override_sources = list(FLAGS.config_file or [])
    if FLAGS.params_override:
        override_sources.append(FLAGS.params_override)
    for source in override_sources:
        config = hyperparams.override_params_dict(
            config, source, is_strict=True)

    config.validate()
    config.lock()

    # The flag is a comma-separated string of integers.
    image_size = [int(dim) for dim in FLAGS.input_image_size.split(',')]

    module = basnet.BASNetModule(
        params=config,
        batch_size=FLAGS.batch_size,
        # The module gets its own list so it never shares one with the
        # export call below.
        input_image_size=list(image_size),
    )

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=image_size,
        params=config,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        export_module=module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model',
    )
def _export_model_with_log_model_flops_and_params(self, params):
    """Runs the export for `params` with FLOPs/params logging switched on.

    The export uses an `image_tensor` input of size 64x64 with batch size 1,
    writes into `self.tempdir`, and is given a checkpoint path named
    'unused-ckpt' inside that directory.

    Args:
      params: Config object forwarded as the `params` argument of
        `export_saved_model_lib.export_inference_graph`.
    """
    ckpt_path = os.path.join(self.tempdir, 'unused-ckpt')
    export_kwargs = dict(
        input_type='image_tensor',
        batch_size=1,
        input_image_size=[64, 64],
        params=params,
        checkpoint_path=ckpt_path,
        export_dir=self.tempdir,
        log_model_flops_and_params=True,
    )
    export_saved_model_lib.export_inference_graph(**export_kwargs)
def test_retinanet_task(self, unused_export):
    """Checks that a RetinaNet export writes the params and FLOPs log files.

    Exports a RetinaNet config with `log_model_flops_and_params=True` into a
    fresh temporary directory and asserts that both 'model_params.txt' and
    'model_flops.txt' exist there afterwards.

    Args:
      unused_export: Not used by the test body (presumably injected by a
        patch decorator outside this block -- confirm against the decorator).
    """
    tempdir = self.create_tempdir()

    params = configs.retinanet.retinanet_resnetfpn_coco()
    # Override the backbone and head settings before exporting (presumably to
    # keep the exported model small for the test).
    params.task.model.backbone.resnet.model_id = 18
    params.task.model.num_classes = 2
    params.task.model.max_level = 6

    export_saved_model_lib.export_inference_graph(
        input_type='image_tensor',
        batch_size=1,
        input_image_size=[64, 64],
        params=params,
        checkpoint_path=os.path.join(tempdir, 'unused-ckpt'),
        export_dir=tempdir,
        log_model_flops_and_params=True,
    )

    # Both log files must have been written into the export directory.
    for filename in ('model_params.txt', 'model_flops.txt'):
        self.assertTrue(
            tf.io.gfile.exists(os.path.join(tempdir, filename)))
def main(_):
    """Exports a panoptic Mask R-CNN inference graph configured from flags.

    Builds the experiment config from `FLAGS.experiment`, applies the strict
    overrides from `FLAGS.config_file` and `FLAGS.params_override`, validates
    and locks it, builds the model with `factory.build_panoptic_maskrcnn`,
    wraps it in a `panoptic_segmentation.PanopticSegmentationModule`, and
    exports it via `export_saved_model_lib.export_inference_graph`.

    Args:
      _: Unused positional argument.
    """
    config = exp_factory.get_exp_config(FLAGS.experiment)

    # Config files are applied first, the inline override string last.
    for path in FLAGS.config_file or []:
        config = hyperparams.override_params_dict(
            config, path, is_strict=True)
    inline_override = FLAGS.params_override
    if inline_override:
        config = hyperparams.override_params_dict(
            config, inline_override, is_strict=True)

    config.validate()
    config.lock()

    # The flag is a comma-separated string of integers.
    image_size = [int(dim) for dim in FLAGS.input_image_size.split(',')]

    # Input spec shape is [batch, *image_size, 3].
    spec_shape = [FLAGS.batch_size, *image_size, 3]
    model = factory.build_panoptic_maskrcnn(
        input_specs=tf.keras.layers.InputSpec(shape=spec_shape),
        model_config=config.task.model,
    )

    module = panoptic_segmentation.PanopticSegmentationModule(
        params=config,
        model=model,
        batch_size=FLAGS.batch_size,
        # The module gets its own list so it never shares one with the
        # export call below.
        input_image_size=list(image_size),
        num_channels=3,
    )

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=image_size,
        params=config,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        export_module=module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model',
    )
def main(_):
    """Exports a 3D semantic segmentation inference graph configured from flags.

    Requires `--export_dir` and `--checkpoint_path`, builds the experiment
    config from `FLAGS.experiment` with the strict overrides from
    `FLAGS.config_file` and `FLAGS.params_override`, validates and locks it,
    and exports a `semantic_segmentation_3d.SegmentationModule` through
    `export_saved_model_lib.export_inference_graph`.

    Args:
      _: Unused positional argument.

    Raises:
      flags.IllegalFlagValueError: If `export_dir` or `checkpoint_path` was
        not provided.
    """
    flags.mark_flag_as_required('export_dir')
    flags.mark_flag_as_required('checkpoint_path')
    # mark_flag_as_required only registers a validator. Validators run when
    # flags are parsed or assigned, and parsing has already happened by the
    # time main() is entered, so run them explicitly to enforce the two
    # requirements above.
    FLAGS.validate_all_flags()

    params = exp_factory.get_exp_config(FLAGS.experiment)
    # Config files are applied first, the inline override string last.
    for config_file in FLAGS.config_file or []:
        params = hyperparams.override_params_dict(
            params, config_file, is_strict=True)
    if FLAGS.params_override:
        params = hyperparams.override_params_dict(
            params, FLAGS.params_override, is_strict=True)

    params.validate()
    params.lock()

    # The flag value is forwarded as-is; no parsing is done here.
    input_image_size = FLAGS.input_image_size

    # NOTE(review): the module is built with batch_size=1 while the export
    # call below receives FLAGS.batch_size -- confirm this is intentional.
    export_module = semantic_segmentation_3d.SegmentationModule(
        params=params,
        batch_size=1,
        input_image_size=input_image_size,
        num_channels=FLAGS.num_channels,
    )

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=input_image_size,
        params=params,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        num_channels=FLAGS.num_channels,
        export_module=export_module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model',
    )