Ejemplo n.º 1
0
 def _get_detection_module(self, experiment_name, image_size=(640, 640)):
   """Build a DetectionModule for `experiment_name` with fixed test settings.

   The experiment config is fetched from `exp_factory`. Its backbone is
   switched to ResNet model_id 18 and batched NMS is enabled on the detection
   generator. The module is then built for a batch size of 1.

   Args:
     experiment_name: Name of a registered experiment config.
     image_size: Input image size, converted to a list before being passed to
       `DetectionModule` (defaults to (640, 640)).

   Returns:
     The constructed `detection.DetectionModule`.
   """
   config = exp_factory.get_exp_config(experiment_name)
   model_cfg = config.task.model
   model_cfg.backbone.resnet.model_id = 18
   model_cfg.detection_generator.use_batched_nms = True
   return detection.DetectionModule(
       config, batch_size=1, input_image_size=list(image_size))
Ejemplo n.º 2
0
def main(_):
    """Build the experiment config from flags and export a detection SavedModel.

    Steps:
      1. Start from the registered config for ``FLAGS.experiment``.
      2. Apply each ``FLAGS.config_file`` in order, then ``FLAGS.params_override``.
         All overrides use ``is_strict=True``.
      3. Validate and lock the resulting config.
      4. Export an inference graph and a checkpoint under ``FLAGS.export_dir``,
         in the ``saved_model`` and ``checkpoint`` subdirectories.

    Args:
      _: Unused positional argument (presumably argv forwarded by the app
        runner; confirm at the call site).
    """
    params = exp_factory.get_exp_config(FLAGS.experiment)
    # Config files are applied first, in the order given. The inline override
    # string is applied last, so it presumably takes precedence over file values.
    for config_file in FLAGS.config_file or []:
        params = hyperparams.override_params_dict(
            params, config_file, is_strict=True)
    if FLAGS.params_override:
        params = hyperparams.override_params_dict(
            params, FLAGS.params_override, is_strict=True)

    params.validate()
    params.lock()

    # Parse the comma-separated size flag once, so the export module and the
    # export call are guaranteed to receive the same dimensions.
    input_image_size = [int(x) for x in FLAGS.input_image_size.split(',')]

    export_module = detection.DetectionModule(
        params=params,
        batch_size=FLAGS.batch_size,
        input_image_size=input_image_size,
        num_channels=3)

    export_saved_model_lib.export_inference_graph(
        input_type=FLAGS.input_type,
        batch_size=FLAGS.batch_size,
        input_image_size=input_image_size,
        params=params,
        checkpoint_path=FLAGS.checkpoint_path,
        export_dir=FLAGS.export_dir,
        export_module=export_module,
        export_checkpoint_subdir='checkpoint',
        export_saved_model_subdir='saved_model')