コード例 #1
0
    def test_get_eval_metric_ops_for_coco_detections_and_resized_masks(
            self, batch_size=1, max_gt_boxes=None, scale_to_absolute=False):
        """Checks COCO box and mask mAP ops with resized groundtruth masks.

        Requests both `coco_detection_metrics` and `coco_mask_metrics`, builds
        the evaluation dict with `resized_groundtruth_masks=True`, runs each
        update op once and expects both mAP values to be (almost) 1.0.

        Args:
            batch_size: Forwarded to `self._make_evaluation_dict`.
            max_gt_boxes: Forwarded to `self._make_evaluation_dict`.
            scale_to_absolute: Forwarded to `self._make_evaluation_dict`.
        """
        eval_config = eval_pb2.EvalConfig()
        eval_config.metrics_set.extend(
            ['coco_detection_metrics', 'coco_mask_metrics'])
        categories = self._get_categories_list()
        eval_dict = self._make_evaluation_dict(
            batch_size=batch_size,
            max_gt_boxes=max_gt_boxes,
            scale_to_absolute=scale_to_absolute,
            resized_groundtruth_masks=True)
        metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
            eval_config, categories, eval_dict)
        # Each entry is a (value_op, update_op) pair; only the update ops of
        # the two mAP metrics are run below.
        _, update_op_boxes = metric_ops['DetectionBoxes_Precision/mAP']
        _, update_op_masks = metric_ops['DetectionMasks_Precision/mAP']

        with self.test_session() as sess:
            # `.items()` is available on both Python 2 and Python 3, whereas
            # `.iteritems()` only exists on Python 2.
            metrics = {
                key: value_op for key, (value_op, _) in metric_ops.items()}
            # Update ops must run before the value ops are fetched.
            sess.run(update_op_boxes)
            sess.run(update_op_masks)
            metrics = sess.run(metrics)
            self.assertAlmostEqual(1.0,
                                   metrics['DetectionBoxes_Precision/mAP'])
            self.assertAlmostEqual(1.0,
                                   metrics['DetectionMasks_Precision/mAP'])
コード例 #2
0
 def test_get_eval_metric_ops_raises_error_with_unsupported_metric(self):
     """An 'unsupported_metric' entry in metrics_set must raise ValueError."""
     config = eval_pb2.EvalConfig()
     config.metrics_set.extend(['unsupported_metric'])
     category_list = self._get_categories_list()
     evaluation_inputs = self._make_evaluation_dict()
     # Callable form of assertRaises: the call itself is expected to fail.
     self.assertRaises(
         ValueError,
         eval_util.get_eval_metric_ops_for_evaluators,
         config, category_list, evaluation_inputs)
コード例 #3
0
def _merge_text_proto_file(path, message):
  """Merges the text-format proto stored at `path` into `message`.

  Args:
    path: Path of the file to read; opened with tf.gfile.GFile in "r" mode.
    message: Proto message instance that the file contents are merged into.

  Returns:
    The same `message` object, after the merge.
  """
  with tf.gfile.GFile(path, "r") as f:
    text_format.Merge(f.read(), message)
  return message


def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
  """Reads training configuration from multiple config files.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

  Returns:
    Dictionary of configuration objects. Possible keys are `model`,
      `train_config`, `train_input_config`, `eval_config`,
      `eval_input_configs` (a single-element list holding the eval input
      reader) and `graph_rewriter_config`. Key/Values are returned only for
      valid (non-empty) strings, so calling with no paths yields an empty
      dictionary.
  """
  configs = {}

  # Proto instances are created only for the paths that were actually given.
  if model_config_path:
    configs["model"] = _merge_text_proto_file(
        model_config_path, model_pb2.DetectionModel())

  if train_config_path:
    configs["train_config"] = _merge_text_proto_file(
        train_config_path, train_pb2.TrainConfig())

  if train_input_config_path:
    configs["train_input_config"] = _merge_text_proto_file(
        train_input_config_path, input_reader_pb2.InputReader())

  if eval_config_path:
    configs["eval_config"] = _merge_text_proto_file(
        eval_config_path, eval_pb2.EvalConfig())

  if eval_input_config_path:
    # Stored under the plural key as a list, unlike the train input config.
    configs["eval_input_configs"] = [_merge_text_proto_file(
        eval_input_config_path, input_reader_pb2.InputReader())]

  if graph_rewriter_config_path:
    configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
        graph_rewriter_config_path)

  return configs
コード例 #4
0
    def test_get_evaluator_with_evaluator_options(self):
        """Options built from the eval config are applied to the evaluator.

        With `include_metrics_per_category` enabled in the config and the
        derived options passed to `get_evaluators`, the first evaluator's
        `_include_metrics_per_category` attribute must be truthy.
        """
        config = eval_pb2.EvalConfig()
        config.metrics_set.extend(['coco_detection_metrics'])
        config.include_metrics_per_category = True
        category_list = self._get_categories_list()

        options = eval_util.evaluator_options_from_eval_config(config)
        evaluators = eval_util.get_evaluators(config, category_list, options)
        first_evaluator = evaluators[0]

        self.assertTrue(first_evaluator._include_metrics_per_category)
コード例 #5
0
    def test_get_eval_metric_ops_for_evaluators(self):
        """Per-category flag is set in the options of both COCO metrics.

        Despite the method name, this exercises
        `eval_util.evaluator_options_from_eval_config` only.
        """
        metric_names = ['coco_detection_metrics', 'coco_mask_metrics']
        config = eval_pb2.EvalConfig()
        config.metrics_set.extend(metric_names)
        config.include_metrics_per_category = True

        options = eval_util.evaluator_options_from_eval_config(config)
        # Same assertion for each requested metric, in the original order.
        for metric_name in metric_names:
            self.assertTrue(
                options[metric_name]['include_metrics_per_category'])
コード例 #6
0
    def test_get_evaluator_with_no_evaluator_options(self):
        """Without evaluator options the per-category flag stays off."""
        config = eval_pb2.EvalConfig()
        config.metrics_set.extend(['coco_detection_metrics'])
        config.include_metrics_per_category = True
        category_list = self._get_categories_list()

        evaluators = eval_util.get_evaluators(
            config, category_list, evaluator_options=None)

        # The config enables include_metrics_per_category, but that setting
        # only reaches the DetectionEvaluator constructor through
        # `evaluator_options`, which is None here.
        per_category_flag = evaluators[0]._include_metrics_per_category
        self.assertFalse(per_category_flag)
コード例 #7
0
    def test_get_configs_from_multiple_files(self):
        """Tests that proto configs can be read from multiple files.

        Writes one config file per proto type into a temp directory, reads
        them back through `config_util.get_configs_from_multiple_files` and
        checks that every returned proto equals the one that was written.
        """
        temp_dir = self.get_temp_dir()

        # Write model config file.
        model_config_path = os.path.join(temp_dir, "model.config")
        model = model_pb2.DetectionModel()
        model.faster_rcnn.num_classes = 10
        _write_config(model, model_config_path)

        # Write train config file.
        train_config_path = os.path.join(temp_dir, "train.config")
        train_config = train_pb2.TrainConfig()
        train_config.batch_size = 32
        _write_config(train_config, train_config_path)

        # Write train input config file.
        train_input_config_path = os.path.join(temp_dir, "train_input.config")
        train_input_config = input_reader_pb2.InputReader()
        train_input_config.label_map_path = "path/to/label_map"
        _write_config(train_input_config, train_input_config_path)

        # Write eval config file.
        eval_config_path = os.path.join(temp_dir, "eval.config")
        eval_config = eval_pb2.EvalConfig()
        eval_config.num_examples = 20
        _write_config(eval_config, eval_config_path)

        # Write eval input config file.
        eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
        eval_input_config = input_reader_pb2.InputReader()
        eval_input_config.label_map_path = "path/to/another/label_map"
        _write_config(eval_input_config, eval_input_config_path)

        configs = config_util.get_configs_from_multiple_files(
            model_config_path=model_config_path,
            train_config_path=train_config_path,
            train_input_config_path=train_input_config_path,
            eval_config_path=eval_config_path,
            eval_input_config_path=eval_input_config_path)
        self.assertProtoEquals(model, configs["model"])
        self.assertProtoEquals(train_config, configs["train_config"])
        self.assertProtoEquals(train_input_config,
                               configs["train_input_config"])
        self.assertProtoEquals(eval_config, configs["eval_config"])
        # The eval input reader is read back from the first element of the
        # list stored under the plural key `eval_input_configs`.
        self.assertProtoEquals(eval_input_config,
                               configs["eval_input_configs"][0])