コード例 #1
0
def _merge_text_proto_file(config_path, message):
    """Parses the text-format proto stored at `config_path` into `message`.

    Args:
      config_path: Path of the file to read; opened with `tf.gfile.GFile`.
      message: Protobuf message instance that receives the parsed fields.

    Returns:
      The same `message` object, after the file contents were merged into it.
    """
    with tf.gfile.GFile(config_path, "r") as f:
        text_format.Merge(f.read(), message)
    return message


def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
    """Reads training configuration from multiple config files.

    Each non-empty path is read and parsed into a fresh message of the
    matching proto type. Empty paths are skipped, so no file is touched and
    no key is added for them.

    Args:
      model_config_path: Path to model_pb2.DetectionModel.
      train_config_path: Path to train_pb2.TrainConfig.
      train_input_config_path: Path to input_reader_pb2.InputReader.
      eval_config_path: Path to eval_pb2.EvalConfig.
      eval_input_config_path: Path to input_reader_pb2.InputReader.
      graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

    Returns:
      Dictionary of configuration objects. Keys are `model`, `train_config`,
      `train_input_config`, `eval_config`, `eval_input_config` and
      `graph_rewriter_config`. Key/Values are returned only for valid
      (non-empty) path strings.
    """
    configs = {}

    # The message instances are created inside each branch so that a proto
    # type is only referenced when its path was actually supplied.
    if model_config_path:
        configs["model"] = _merge_text_proto_file(
            model_config_path, model_pb2.DetectionModel())

    if train_config_path:
        configs["train_config"] = _merge_text_proto_file(
            train_config_path, train_pb2.TrainConfig())

    if train_input_config_path:
        configs["train_input_config"] = _merge_text_proto_file(
            train_input_config_path, input_reader_pb2.InputReader())

    if eval_config_path:
        configs["eval_config"] = _merge_text_proto_file(
            eval_config_path, eval_pb2.EvalConfig())

    if eval_input_config_path:
        configs["eval_input_config"] = _merge_text_proto_file(
            eval_input_config_path, input_reader_pb2.InputReader())

    # The graph rewriter config is loaded by its own dedicated reader.
    if graph_rewriter_config_path:
        configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
            graph_rewriter_config_path)

    return configs
コード例 #2
0
  def test_build_tf_record_input_reader(self):
    """Builds a dataset from a TFRecord input reader proto and checks a batch.

    The reader config is built with batch_size=1 and does not set
    `load_instance_masks`; the test asserts that no instance-mask entry is
    present in the output and verifies the image, class and box values.
    """
    tf_record_path = self.create_tf_record()

    # Doubled braces are literal braces for str.format; {0} is the record path.
    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      output_dict = sess.run(tensor_dict)

    # assertNotIn reports the offending key on failure, unlike assertTrue.
    self.assertNotIn(
        fields.InputDataFields.groundtruth_instance_masks, output_dict)
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual((1, 4, 5, 3),
                     output_dict[fields.InputDataFields.image].shape)
    self.assertAllEqual([[2]],
                        output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0][0])
コード例 #3
0
  def test_build_tf_record_input_reader_with_batch_size_two_and_masks(self):
    """Builds a batch of two with instance masks and checks the mask shape.

    A transform that one-hot encodes the groundtruth classes (depth 3, after
    subtracting 1) is passed to the dataset builder.
    """
    record_path = self.create_tf_record()

    # Doubled braces are literal braces for str.format; {0} is the record path.
    proto_text = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(record_path)
    reader_config = input_reader_pb2.InputReader()
    text_format.Merge(proto_text, reader_config)

    def encode_classes_one_hot(features):
      classes_key = fields.InputDataFields.groundtruth_classes
      shifted_classes = features[classes_key] - 1
      features[classes_key] = tf.one_hot(shifted_classes, depth=3)
      return features

    dataset = dataset_builder.build(
        reader_config,
        transform_input_data_fn=encode_classes_one_hot,
        batch_size=2)
    batch_tensors = dataset_builder.make_initializable_iterator(
        dataset).get_next()

    supervisor = tf.train.Supervisor(logdir=self.get_temp_dir())
    with supervisor.prepare_or_wait_for_session() as session:
      supervisor.start_queue_runners(session)
      fetched = session.run(batch_tensors)

    mask_shape = fetched[
        fields.InputDataFields.groundtruth_instance_masks].shape
    self.assertAllEqual([2, 1, 4, 5], mask_shape)
コード例 #4
0
    def test_build_tf_record_input_reader_and_load_instance_masks(self):
        """Builds an input reader with instance masks enabled and checks it.

        The reader config sets `load_instance_masks: true`; the test verifies
        the image, class, box and instance-mask values of one fetched example.
        """
        tf_record_path = self.create_tf_record()

        # Doubled braces are literal braces for str.format; {0} is the path.
        input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Merge(input_reader_text_proto, input_reader_proto)
        tensor_dict = input_reader_builder.build(input_reader_proto)

        sv = tf.train.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)
            output_dict = sess.run(tensor_dict)

        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        self.assertEqual((4, 5, 3),
                         output_dict[fields.InputDataFields.image].shape)
        # NOTE(review): if this value is a NumPy array, the check relies on
        # the truthiness of a one-element elementwise comparison;
        # assertAllEqual may be the stricter choice -- confirm.
        self.assertEqual(
            [2], output_dict[fields.InputDataFields.groundtruth_classes])
        self.assertEqual(
            (1, 4),
            output_dict[fields.InputDataFields.groundtruth_boxes].shape)
        self.assertAllEqual(
            [0.0, 0.0, 1.0, 1.0],
            output_dict[fields.InputDataFields.groundtruth_boxes][0])
        self.assertAllEqual((1, 4, 5), output_dict[
            fields.InputDataFields.groundtruth_instance_masks].shape)
コード例 #5
0
    def test_get_configs_from_multiple_files(self):
        """Tests that proto configs can be read from multiple files.

        One config file per proto type is written to the test's temp
        directory, all of them are loaded through
        `config_util.get_configs_from_multiple_files`, and each loaded proto
        is compared with the one that was written.
        """
        temp_dir = self.get_temp_dir()

        # Write model config file.
        model_config_path = os.path.join(temp_dir, "model.config")
        model = model_pb2.DetectionModel()
        model.faster_rcnn.num_classes = 10
        _write_config(model, model_config_path)

        # Write train config file.
        train_config_path = os.path.join(temp_dir, "train.config")
        train_config = train_pb2.TrainConfig()
        train_config.batch_size = 32
        _write_config(train_config, train_config_path)

        # Write train input config file.
        train_input_config_path = os.path.join(temp_dir, "train_input.config")
        train_input_config = input_reader_pb2.InputReader()
        train_input_config.label_map_path = "path/to/label_map"
        _write_config(train_input_config, train_input_config_path)

        # Write eval config file.
        eval_config_path = os.path.join(temp_dir, "eval.config")
        eval_config = eval_pb2.EvalConfig()
        eval_config.num_examples = 20
        _write_config(eval_config, eval_config_path)

        # Write eval input config file.
        eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
        eval_input_config = input_reader_pb2.InputReader()
        eval_input_config.label_map_path = "path/to/another/label_map"
        _write_config(eval_input_config, eval_input_config_path)

        configs = config_util.get_configs_from_multiple_files(
            model_config_path=model_config_path,
            train_config_path=train_config_path,
            train_input_config_path=train_input_config_path,
            eval_config_path=eval_config_path,
            eval_input_config_path=eval_input_config_path)

        # Every loaded proto must match the one that was written to disk.
        self.assertProtoEquals(model, configs["model"])
        self.assertProtoEquals(train_config, configs["train_config"])
        self.assertProtoEquals(train_input_config,
                               configs["train_input_config"])
        self.assertProtoEquals(eval_config, configs["eval_config"])
        self.assertProtoEquals(eval_input_config, configs["eval_input_config"])
コード例 #6
0
 def test_raises_error_with_no_input_paths(self):
   """Expects dataset_builder.build to raise ValueError for this config.

   The text proto below configures no input reader section or input path.
   """
   proto_text = """
     shuffle: false
     num_readers: 1
     load_instance_masks: true
   """
   reader_config = input_reader_pb2.InputReader()
   text_format.Merge(proto_text, reader_config)
   self.assertRaises(
       ValueError, dataset_builder.build, reader_config, batch_size=1)
コード例 #7
0
  def test_disable_shuffle_(self):
    """With shuffle off and one reader, a batch of 10 has the fixed order."""
    expected_non_shuffle_output = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = False
    reader_config.num_readers = 1

    file_patterns = [self._shuffle_path_template % '*']
    next_batch = self._get_dataset_next(
        file_patterns, reader_config, batch_size=10)

    with self.test_session() as session:
      fetched = session.run(next_batch)
      self.assertAllEqual(fetched, [expected_non_shuffle_output])
コード例 #8
0
  def test_reduce_num_reader(self):
    """Requests 10 readers without shuffling and checks a batch of 20."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = False
    reader_config.num_readers = 10

    file_patterns = [self._path_template % '*']
    next_batch = self._get_dataset_next(
        file_patterns, reader_config, batch_size=20)

    expected_batch = [[1, 10, 2, 20, 3, 30, 4, 40, 5, 50,
                       1, 10, 2, 20, 3, 30, 4, 40, 5, 50]]
    with self.test_session() as session:
      fetched = session.run(next_batch)
      self.assertAllEqual(fetched, expected_batch)
コード例 #9
0
  def test_read_dataset_single_epoch(self):
    """Reads one epoch with an oversized batch, then expects exhaustion."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = False
    reader_config.num_readers = 1
    reader_config.num_epochs = 1

    file_patterns = [self._path_template % '0']
    next_batch = self._get_dataset_next(
        file_patterns, reader_config, batch_size=30)

    with self.test_session() as session:
      # The first fetch returns whatever is available; the next one must
      # fail because nothing is left.
      first_batch = session.run(next_batch)
      self.assertAllEqual(first_batch, [[1, 10]])
      with self.assertRaises(tf.errors.OutOfRangeError):
        session.run(next_batch)
コード例 #10
0
  def test_enable_shuffle(self):
    """With shuffle on, the batch must differ from the unshuffled order."""
    reader_config = input_reader_pb2.InputReader()
    reader_config.shuffle = True
    reader_config.num_readers = 1

    # Graph-level seed, set before the dataset is created.
    tf.set_random_seed(1)
    file_patterns = [self._shuffle_path_template % '*']
    next_batch = self._get_dataset_next(
        file_patterns, reader_config, batch_size=10)
    expected_non_shuffle_output = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    with self.test_session() as session:
      fetched = session.run(next_batch)
      mismatches = np.not_equal(fetched, expected_non_shuffle_output)
      self.assertTrue(np.any(mismatches))