Exemplo n.º 1
0
 def _get_panoptic_segmentation_module(self, experiment_name):
     """Builds a small panoptic segmentation export module for tests.

     Args:
       experiment_name: Name of the experiment config to load via
         `exp_factory.get_exp_config`.

     Returns:
       A `panoptic_segmentation.PanopticSegmentationModule` wrapping a Panoptic
       Mask R-CNN model built with a ResNet-18 backbone, 'batched' NMS,
       batch size 1 and a 128x128 input image size.
     """
     height, width = 128, 128
     config = exp_factory.get_exp_config(experiment_name)
     model_cfg = config.task.model
     # ResNet-18 backbone (presumably chosen to keep the test model small).
     model_cfg.backbone.resnet.model_id = 18
     model_cfg.detection_generator.nms_version = 'batched'
     spec = tf.keras.layers.InputSpec(shape=[1, height, width, 3])
     built_model = factory.build_panoptic_maskrcnn(
         input_specs=spec, model_config=model_cfg)
     return panoptic_segmentation.PanopticSegmentationModule(
         config,
         model=built_model,
         batch_size=1,
         input_image_size=[height, width])
Exemplo n.º 2
0
 def test_build_model_fail_with_none_batch_size(self):
     """Checks that building the export module with batch_size=None fails.

     Expects `PanopticSegmentationModule` to raise a `ValueError` whose
     message matches the batch-size error text.
     """
     exp_config = exp_factory.get_exp_config('panoptic_fpn_coco')
     panoptic_model = factory.build_panoptic_maskrcnn(
         input_specs=tf.keras.layers.InputSpec(shape=[1, 128, 128, 3]),
         model_config=exp_config.task.model)
     expected_message = (
         'batch_size cannot be None for panoptic segmentation model.')
     with self.assertRaisesRegex(ValueError, expected_message):
         panoptic_segmentation.PanopticSegmentationModule(
             exp_config,
             model=panoptic_model,
             batch_size=None,
             input_image_size=[128, 128])
Exemplo n.º 3
0
def main(_):
    """Builds a Panoptic Mask R-CNN model from flags and exports it.

    Loads the experiment config named by `FLAGS.experiment`, applies each
    `FLAGS.config_file` and then `FLAGS.params_override` in strict mode,
    validates and locks the config, builds the model for `FLAGS.batch_size`
    and `FLAGS.input_image_size`, and exports it with
    `export_saved_model_lib.export_inference_graph` using the 'checkpoint'
    and 'saved_model' subdirectory names.

    Args:
      _: Unused positional argument (presumably the argv list passed by the
        app runner).
    """
    params = exp_factory.get_exp_config(FLAGS.experiment)
    # Overrides are applied in order: config files first, then the inline
    # override string.
    for config_file in FLAGS.config_file or []:
        params = hyperparams.override_params_dict(
            params, config_file, is_strict=True)
    if FLAGS.params_override:
        params = hyperparams.override_params_dict(
            params, FLAGS.params_override, is_strict=True)

    params.validate()
    params.lock()

    # Parse the comma-separated image size flag once so the model's input
    # spec, the export module and the exporter all use the same values.
    input_image_size = [int(x) for x in FLAGS.input_image_size.split(',')]
    num_channels = 3

    input_specs = tf.keras.layers.InputSpec(
        shape=[FLAGS.batch_size, *input_image_size, num_channels])
    model = factory.build_panoptic_maskrcnn(
        input_specs=input_specs, model_config=params.task.model)

    export_module = panoptic_segmentation.PanopticSegmentationModule(
        params=params,
        model=model,
        batch_size=FLAGS.batch_size,
        input_image_size=input_image_size,
        num_channels=num_channels)

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=input_image_size,
        params=params,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        export_module=export_module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model')