Ejemplo n.º 1
0
def main(_):
    """Build the experiment config from flags and export an inference graph.

    The base config comes from `FLAGS.experiment`. Each entry of
    `FLAGS.config_file` is applied on top of it in order, followed by
    `FLAGS.params_override` when it is set. The resulting params are
    validated and locked before the export, which uses a
    `basnet.BASNetModule` as the export module.

    Args:
        _: Ignored; the body never reads it (presumably the argv list handed
            over by the app runner -- confirm against the entry point).
    """
    params = exp_factory.get_exp_config(FLAGS.experiment)

    # Collect every override source in application order: config files
    # first, then the optional inline override.
    override_sources = list(FLAGS.config_file or [])
    if FLAGS.params_override:
        override_sources.append(FLAGS.params_override)
    for source in override_sources:
        params = hyperparams.override_params_dict(
            params, source, is_strict=True)

    params.validate()
    params.lock()

    # The flag is a comma-separated string of integers.
    image_size = [int(dim) for dim in FLAGS.input_image_size.split(',')]

    # The module receives its own copy so it never shares a list instance
    # with the `input_image_size` argument passed to the export call below.
    basnet_module = basnet.BASNetModule(
        params=params,
        batch_size=FLAGS.batch_size,
        input_image_size=list(image_size))

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=image_size,
        params=params,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        export_module=basnet_module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model')
Ejemplo n.º 2
0
 def _export_model_with_log_model_flops_and_params(self, params):
     """Run an export into `self.tempdir` with FLOPs/params logging enabled.

     Uses a fixed `image_tensor` input type, batch size 1 and a 64x64 input
     size. The checkpoint path points at `unused-ckpt` inside
     `self.tempdir`.

     Args:
         params: Experiment config forwarded unchanged to
             `export_saved_model_lib.export_inference_graph`.
     """
     export_kwargs = dict(
         input_type='image_tensor',
         batch_size=1,
         input_image_size=[64, 64],
         params=params,
         checkpoint_path=os.path.join(self.tempdir, 'unused-ckpt'),
         export_dir=self.tempdir,
         log_model_flops_and_params=True,
     )
     export_saved_model_lib.export_inference_graph(**export_kwargs)
Ejemplo n.º 3
0
 def test_retinanet_task(self, unused_export):
     """Check that a RetinaNet export writes the params and FLOPs files.

     The COCO RetinaNet config is adjusted (ResNet model id 18, 2 classes,
     max level 6) before exporting with `log_model_flops_and_params=True`.
     The test then asserts that `model_params.txt` and `model_flops.txt`
     exist in the temporary export directory.

     Args:
         unused_export: Not referenced in the body (presumably injected by a
             patch decorator outside this view -- confirm).
     """
     tempdir = self.create_tempdir()
     params = configs.retinanet.retinanet_resnetfpn_coco()

     model_config = params.task.model
     print(model_config.backbone)
     model_config.backbone.resnet.model_id = 18
     model_config.num_classes = 2
     model_config.max_level = 6

     export_kwargs = dict(
         input_type='image_tensor',
         batch_size=1,
         input_image_size=[64, 64],
         params=params,
         checkpoint_path=os.path.join(tempdir, 'unused-ckpt'),
         export_dir=tempdir,
         log_model_flops_and_params=True,
     )
     export_saved_model_lib.export_inference_graph(**export_kwargs)

     # Both report files must have been written next to the export.
     for report_name in ('model_params.txt', 'model_flops.txt'):
         report_path = os.path.join(tempdir, report_name)
         self.assertTrue(tf.io.gfile.exists(report_path))
Ejemplo n.º 4
0
def main(_):
    """Build a panoptic Mask R-CNN model from flags and export it.

    The base config comes from `FLAGS.experiment`. Each entry of
    `FLAGS.config_file` is applied on top of it in order, followed by
    `FLAGS.params_override` when it is set. After validating and locking the
    params, the model is built for a 3-channel input, wrapped in a
    `PanopticSegmentationModule`, and passed to the export call.

    Args:
        _: Ignored; the body never reads it (presumably the argv list handed
            over by the app runner -- confirm against the entry point).
    """
    params = exp_factory.get_exp_config(FLAGS.experiment)

    # Collect every override source in application order: config files
    # first, then the optional inline override.
    override_sources = list(FLAGS.config_file or [])
    if FLAGS.params_override:
        override_sources.append(FLAGS.params_override)
    for source in override_sources:
        params = hyperparams.override_params_dict(
            params, source, is_strict=True)

    params.validate()
    params.lock()

    batch_size = FLAGS.batch_size
    # The flag is a comma-separated string of integers.
    input_image_size = [int(dim) for dim in FLAGS.input_image_size.split(',')]

    # Input spec shape is [batch, *image_size, 3].
    spec_shape = [batch_size] + input_image_size + [3]
    model = factory.build_panoptic_maskrcnn(
        input_specs=tf.keras.layers.InputSpec(shape=spec_shape),
        model_config=params.task.model)

    # The module receives its own copy so it never shares a list instance
    # with the `input_image_size` argument passed to the export call below.
    export_module = panoptic_segmentation.PanopticSegmentationModule(
        params=params,
        model=model,
        batch_size=batch_size,
        input_image_size=list(input_image_size),
        num_channels=3)

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=batch_size,
        input_image_size=input_image_size,
        params=params,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        export_module=export_module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model')
Ejemplo n.º 5
0
def main(_):
    """Export a 3D semantic segmentation inference graph configured by flags.

    The base config comes from `FLAGS.experiment`. Each entry of
    `FLAGS.config_file` is applied on top of it in order, followed by
    `FLAGS.params_override` when it is set. The params are validated and
    locked, a `semantic_segmentation_3d.SegmentationModule` is built, and the
    export is written under `FLAGS.export_dir`.

    Args:
        _: Ignored; the body never reads it (presumably the argv list handed
            over by the app runner -- confirm against the entry point).

    Raises:
        flags.IllegalFlagValueError: If `--export_dir` or `--checkpoint_path`
            was not provided.
    """
    for required_flag in ('export_dir', 'checkpoint_path'):
        flags.mark_flag_as_required(required_flag)
    # mark_flag_as_required only registers a validator. Flags have already
    # been parsed (and validated) by the time main() runs, so the new
    # validators must be run explicitly or a missing flag goes unnoticed
    # until the export fails further down.
    FLAGS.validate_all_flags()

    params = exp_factory.get_exp_config(FLAGS.experiment)
    for config_file in FLAGS.config_file or []:
        params = hyperparams.override_params_dict(params,
                                                  config_file,
                                                  is_strict=True)
    if FLAGS.params_override:
        params = hyperparams.override_params_dict(params,
                                                  FLAGS.params_override,
                                                  is_strict=True)

    params.validate()
    params.lock()

    # Used as-is; no string-to-int conversion is applied here, unlike a
    # comma-separated string flag would need.
    input_image_size = FLAGS.input_image_size

    # NOTE(review): the module is built with batch_size=1 while the export
    # call below receives FLAGS.batch_size -- confirm this is intentional.
    export_module = semantic_segmentation_3d.SegmentationModule(
        params=params,
        batch_size=1,
        input_image_size=input_image_size,
        num_channels=FLAGS.num_channels)

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=input_image_size,
        params=params,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        num_channels=FLAGS.num_channels,
        export_module=export_module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model')