def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
    """Reads training configuration from multiple config files.

    Every path is optional: a config is read, and its key added to the
    result, only when the corresponding path is a non-empty string.

    Args:
      model_config_path: Path to model_pb2.DetectionModel.
      train_config_path: Path to train_pb2.TrainConfig.
      train_input_config_path: Path to input_reader_pb2.InputReader.
      eval_config_path: Path to eval_pb2.EvalConfig.
      eval_input_config_path: Path to input_reader_pb2.InputReader.
      graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

    Returns:
      Dictionary of configuration objects. Possible keys are `model`,
      `train_config`, `train_input_config`, `eval_config`,
      `eval_input_configs` (a one-element list wrapping the eval input
      reader config) and `graph_rewriter_config` (whatever
      `get_graph_rewriter_config_from_file` returns). Key/Values are
      returned only for valid (non-empty) path strings.
    """

    def _read_text_proto(path, message):
        """Merges the text-format proto stored at `path` into `message`."""
        with tf.gfile.GFile(path, "r") as f:
            text_format.Merge(f.read(), message)
        return message

    configs = {}

    if model_config_path:
        configs["model"] = _read_text_proto(
            model_config_path, model_pb2.DetectionModel())

    if train_config_path:
        configs["train_config"] = _read_text_proto(
            train_config_path, train_pb2.TrainConfig())

    if train_input_config_path:
        configs["train_input_config"] = _read_text_proto(
            train_input_config_path, input_reader_pb2.InputReader())

    if eval_config_path:
        configs["eval_config"] = _read_text_proto(
            eval_config_path, eval_pb2.EvalConfig())

    if eval_input_config_path:
        # Stored under the plural key as a list, unlike the train input
        # config which is stored as a bare message.
        configs["eval_input_configs"] = [
            _read_text_proto(eval_input_config_path,
                             input_reader_pb2.InputReader())
        ]

    if graph_rewriter_config_path:
        configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
            graph_rewriter_config_path)

    return configs
def get_configs_from_multiple_files():
    """Loads the training configuration spread over three text-proto files.

    The file locations are taken from command-line flags:
      train_config: --train_config_path
      model_config: --model_config_path
      input_config: --input_config_path

    Returns:
      A `(model_config, train_config, input_config)` tuple holding a
      model_pb2.DetectionModel, a train_pb2.TrainConfig and an
      input_reader_pb2.InputReader, in that order.
    """

    def _parse_into(message, path):
        # Read the whole file and merge its text-format contents into message.
        with tf.gfile.GFile(path, 'r') as config_file:
            text_format.Merge(config_file.read(), message)
        return message

    # Files are read in the order train, model, input.
    train_config = _parse_into(train_pb2.TrainConfig(),
                               FLAGS.train_config_path)
    model_config = _parse_into(model_pb2.DetectionModel(),
                               FLAGS.model_config_path)
    input_config = _parse_into(input_reader_pb2.InputReader(),
                               FLAGS.input_config_path)

    return model_config, train_config, input_config
    def test_build_tf_record_input_reader(self):
        """Builds a dataset from a TFRecord input reader and checks one example.

        The reader config does not set `load_instance_masks`, so the output is
        expected to carry no instance-mask entry.
        """
        tf_record_path = self.create_tf_record()

        input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Merge(input_reader_text_proto, input_reader_proto)
        tensor_dict = dataset_util.make_initializable_iterator(
            dataset_builder.build(input_reader_proto)).get_next()

        sv = tf.train.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)
            output_dict = sess.run(tensor_dict)

        self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                         output_dict)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual((4, 5, 3),
                         output_dict[fields.InputDataFields.image].shape)
        # Element-wise comparison: the fetched value is an array, not a list.
        self.assertAllEqual(
            [2], output_dict[fields.InputDataFields.groundtruth_classes])
        self.assertEqual(
            (1, 4),
            output_dict[fields.InputDataFields.groundtruth_boxes].shape)
        self.assertAllEqual(
            [0.0, 0.0, 1.0, 1.0],
            output_dict[fields.InputDataFields.groundtruth_boxes][0])
# Example #4
# 0
def get_configs_from_multiple_files():
    """Loads the evaluation configuration from three text-proto files.

    The file locations are taken from command-line flags:
      eval_config: --eval_config_path
      model_config: --model_config_path
      input_config: --input_config_path

    Returns:
      A `(model_config, eval_config, input_config)` tuple holding a
      model_pb2.DetectionModel, an eval_pb2.EvalConfig and an
      input_reader_pb2.InputReader, in that order.
    """

    def _load(proto, config_path):
        # Merge the text-format file contents into the given empty message.
        with tf.gfile.GFile(config_path, 'r') as handle:
            contents = handle.read()
            text_format.Merge(contents, proto)
        return proto

    # Files are read in the order eval, model, input.
    eval_config = _load(eval_pb2.EvalConfig(), FLAGS.eval_config_path)
    model_config = _load(model_pb2.DetectionModel(), FLAGS.model_config_path)
    input_config = _load(input_reader_pb2.InputReader(),
                         FLAGS.input_config_path)

    return model_config, eval_config, input_config
    def test_get_configs_from_multiple_files(self):
        """Tests that proto configs can be read from multiple files.

        Writes one config file per proto type into the test's temp directory,
        reads them all back through
        `config_util.get_configs_from_multiple_files`, and checks that every
        returned proto equals the one that was written.
        """
        temp_dir = self.get_temp_dir()

        # Write model config file.
        model_config_path = os.path.join(temp_dir, "model.config")
        model = model_pb2.DetectionModel()
        model.faster_rcnn.num_classes = 10
        _write_config(model, model_config_path)

        # Write train config file.
        train_config_path = os.path.join(temp_dir, "train.config")
        train_config = train_pb2.TrainConfig()
        train_config.batch_size = 32
        _write_config(train_config, train_config_path)

        # Write train input config file.
        train_input_config_path = os.path.join(temp_dir, "train_input.config")
        train_input_config = input_reader_pb2.InputReader()
        train_input_config.label_map_path = "path/to/label_map"
        _write_config(train_input_config, train_input_config_path)

        # Write eval config file.
        eval_config_path = os.path.join(temp_dir, "eval.config")
        eval_config = eval_pb2.EvalConfig()
        eval_config.num_examples = 20
        _write_config(eval_config, eval_config_path)

        # Write eval input config file.
        eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
        eval_input_config = input_reader_pb2.InputReader()
        eval_input_config.label_map_path = "path/to/another/label_map"
        _write_config(eval_input_config, eval_input_config_path)

        configs = config_util.get_configs_from_multiple_files(
            model_config_path=model_config_path,
            train_config_path=train_config_path,
            train_input_config_path=train_input_config_path,
            eval_config_path=eval_config_path,
            eval_input_config_path=eval_input_config_path)
        self.assertProtoEquals(model, configs["model"])
        self.assertProtoEquals(train_config, configs["train_config"])
        self.assertProtoEquals(train_input_config,
                               configs["train_input_config"])
        self.assertProtoEquals(eval_config, configs["eval_config"])
        # NOTE(review): this expects the singular `eval_input_config` key.
        # If the config_util under test stores the eval input reader as a
        # list under `eval_input_configs`, this lookup needs updating -
        # confirm against the config_util version in use.
        self.assertProtoEquals(eval_input_config, configs["eval_input_config"])
 def test_raises_error_with_no_input_paths(self):
     """Checks that building from a reader config with no input reader fails.

     The text proto below sets only scalar options and no
     `tf_record_input_reader` section; `dataset_builder.build` is expected to
     raise ValueError for it.
     """
     reader_text = """
   shuffle: false
   num_readers: 1
   load_instance_masks: true
 """
     reader_config = input_reader_pb2.InputReader()
     text_format.Merge(reader_text, reader_config)
     self.assertRaises(ValueError, dataset_builder.build, reader_config)
  def test_read_dataset(self):
    """Reads one batch of 20 values from all files matching the template."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_readers = 1
    reader_config.shuffle = False

    file_patterns = [self._path_template % '*']
    next_batch = self._get_dataset_next(file_patterns, reader_config,
                                        batch_size=20)
    # The ten-value sequence is expected twice within the single batch.
    expected_batch = [[1, 10, 2, 20, 3, 30, 4, 40, 5, 50] * 2]
    with self.test_session() as sess:
      actual_batch = sess.run(next_batch)
      self.assertAllEqual(actual_batch, expected_batch)
  def test_read_dataset_single_epoch(self):
    """Reads a single epoch from one file and expects exhaustion afterwards."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.num_epochs = 1
    reader_config.num_readers = 1
    reader_config.shuffle = False

    file_patterns = [self._path_template % '0']
    next_batch = self._get_dataset_next(file_patterns, reader_config,
                                        batch_size=30)
    with self.test_session() as sess:
      # The first fetch returns whatever is available, short of batch_size.
      first_batch = sess.run(next_batch)
      self.assertAllEqual(first_batch, [[1, 10]])
      # The second fetch has nothing left to read and must fail.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_batch)
    def test_build_tf_record_input_reader_with_batch_size_two(self):
        """Builds a batched (size 2) dataset and checks the batched shapes.

        Classes are one-hot encoded by a transform function, and boxes are
        expected to come back with two rows per image, the second all zeros.
        """
        record_path = self.create_tf_record()

        reader_text = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(record_path)
        reader_config = input_reader_pb2.InputReader()
        text_format.Merge(reader_text, reader_config)

        def to_one_hot_classes(example):
            # Shift class ids down by one before one-hot encoding to depth 3.
            classes_key = fields.InputDataFields.groundtruth_classes
            shifted = example[classes_key] - 1
            example[classes_key] = tf.one_hot(shifted, depth=3)
            return example

        dataset = dataset_builder.build(
            reader_config,
            transform_input_data_fn=to_one_hot_classes,
            batch_size=2,
            max_num_boxes=2,
            num_classes=3,
            spatial_image_shape=[4, 5])
        iterator = dataset_util.make_initializable_iterator(dataset)
        next_batch = iterator.get_next()

        supervisor = tf.train.Supervisor(logdir=self.get_temp_dir())
        with supervisor.prepare_or_wait_for_session() as sess:
            supervisor.start_queue_runners(sess)
            batch = sess.run(next_batch)

        images = batch[fields.InputDataFields.image]
        self.assertAllEqual([2, 4, 5, 3], images.shape)

        classes = batch[fields.InputDataFields.groundtruth_classes]
        self.assertAllEqual([2, 2, 3], classes.shape)

        boxes = batch[fields.InputDataFields.groundtruth_boxes]
        self.assertAllEqual([2, 2, 4], boxes.shape)
        expected_boxes_per_image = [[0.0, 0.0, 1.0, 1.0],
                                    [0.0, 0.0, 0.0, 0.0]]
        self.assertAllEqual(
            [expected_boxes_per_image, expected_boxes_per_image],
            batch[fields.InputDataFields.groundtruth_boxes])
  def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """Builds a dataset with `load_instance_masks` on and checks mask shape."""
    record_path = self.create_tf_record()

    reader_text = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(record_path)
    reader_config = input_reader_pb2.InputReader()
    text_format.Merge(reader_text, reader_config)

    dataset = dataset_builder.build(reader_config, batch_size=1)
    iterator = dataset_util.make_initializable_iterator(dataset)
    next_batch = iterator.get_next()

    supervisor = tf.train.Supervisor(logdir=self.get_temp_dir())
    with supervisor.prepare_or_wait_for_session() as sess:
      supervisor.start_queue_runners(sess)
      batch = sess.run(next_batch)

    # The leading dimension of 1 matches the batch_size passed to build().
    masks = batch[fields.InputDataFields.groundtruth_instance_masks]
    self.assertAllEqual((1, 1, 4, 5), masks.shape)