Пример #1
0
def _read_text_proto_file(path, message):
  """Merges the text-format protobuf stored at `path` into `message`.

  Args:
    path: Path to a text-format protobuf file, opened with `tf.gfile.GFile`.
    message: Protobuf message instance that receives the parsed fields.

  Returns:
    `message` itself, after `text_format.Merge` has been applied to it.
  """
  with tf.gfile.GFile(path, "r") as f:
    text_format.Merge(f.read(), message)
  return message


def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
  """Reads training configuration from multiple config files.

  Every argument is optional. An empty string (the default) skips that
  config entirely, and its key is left out of the result.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

  Returns:
    Dictionary of configuration objects. Possible keys are `model`,
    `train_config`, `train_input_config`, `eval_config`, `eval_input_configs`
    and `graph_rewriter_config`. A key is present only when its path argument
    is a non-empty string.

    `eval_input_configs` (plural) maps to a one-element list of
    input_reader_pb2.InputReader rather than to a bare message.
    `graph_rewriter_config` holds whatever
    `get_graph_rewriter_config_from_file` returns.
  """
  configs = {}
  if model_config_path:
    configs["model"] = _read_text_proto_file(
        model_config_path, model_pb2.DetectionModel())

  if train_config_path:
    configs["train_config"] = _read_text_proto_file(
        train_config_path, train_pb2.TrainConfig())

  if train_input_config_path:
    configs["train_input_config"] = _read_text_proto_file(
        train_input_config_path, input_reader_pb2.InputReader())

  if eval_config_path:
    configs["eval_config"] = _read_text_proto_file(
        eval_config_path, eval_pb2.EvalConfig())

  if eval_input_config_path:
    # Stored as a list under a plural key, so consumers must index into it.
    configs["eval_input_configs"] = [
        _read_text_proto_file(eval_input_config_path,
                              input_reader_pb2.InputReader())
    ]

  if graph_rewriter_config_path:
    configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
        graph_rewriter_config_path)

  return configs
Пример #2
0
    def test_build_tf_record_input_reader(self):
        """Builds a batch-size-1 dataset from a TFRecord and checks its tensors.

        `load_instance_masks` is not set in the config, so the
        groundtruth_instance_masks field must be absent from the output. The
        remaining assertions compare the output against the fixture written
        by `self.create_tf_record()`.
        """
        tf_record_path = self.create_tf_record()

        input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Merge(input_reader_text_proto, input_reader_proto)
        tensor_dict = dataset_builder.make_initializable_iterator(
            dataset_builder.build(input_reader_proto,
                                  batch_size=1)).get_next()

        with tf.train.MonitoredSession() as sess:
            output_dict = sess.run(tensor_dict)

        # assertNotIn reports the offending container on failure, whereas
        # assertTrue(... not in ...) only reports "False is not true".
        self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                         output_dict)
        # assertEqual rather than assertEquals: the latter is a deprecated
        # alias that was removed in Python 3.12.
        self.assertEqual((1, 4, 5, 3),
                         output_dict[fields.InputDataFields.image].shape)
        self.assertAllEqual(
            [[2]], output_dict[fields.InputDataFields.groundtruth_classes])
        self.assertEqual(
            (1, 1, 4),
            output_dict[fields.InputDataFields.groundtruth_boxes].shape)
        self.assertAllEqual(
            [0.0, 0.0, 1.0, 1.0],
            output_dict[fields.InputDataFields.groundtruth_boxes][0][0])
Пример #3
0
    def test_build_tf_record_input_reader_with_batch_size_two_and_masks(self):
        """Batches two examples with masks while one-hot encoding classes.

        The batched groundtruth_instance_masks tensor is expected to have
        shape [2, 1, 4, 5].
        """
        record_path = self.create_tf_record()

        text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(record_path)
        reader_config = input_reader_pb2.InputReader()
        text_format.Merge(text_proto, reader_config)

        classes_key = fields.InputDataFields.groundtruth_classes

        def to_one_hot(tensors):
            # Subtracts 1 before one_hot; presumably class ids are 1-based
            # (label-map convention) -- TODO confirm against the fixture.
            tensors[classes_key] = tf.one_hot(tensors[classes_key] - 1,
                                              depth=3)
            return tensors

        dataset = dataset_builder.build(
            reader_config, transform_input_data_fn=to_one_hot, batch_size=2)
        next_batch = dataset_builder.make_initializable_iterator(
            dataset).get_next()

        with tf.train.MonitoredSession() as session:
            batch = session.run(next_batch)

        masks = batch[fields.InputDataFields.groundtruth_instance_masks]
        self.assertAllEqual([2, 1, 4, 5], masks.shape)

    def test_build_tf_record_input_reader_and_load_instance_masks(self):
        """Reads one unbatched example via input_reader_builder with masks.

        With `load_instance_masks: true`, the output must contain a
        groundtruth_instance_masks tensor next to the image, class and box
        tensors.
        """
        tf_record_path = self.create_tf_record()

        input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Merge(input_reader_text_proto, input_reader_proto)
        tensor_dict = input_reader_builder.build(input_reader_proto)

        with tf.train.MonitoredSession() as sess:
            output_dict = sess.run(tensor_dict)

        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed in Python 3.12.
        self.assertEqual((4, 5, 3),
                         output_dict[fields.InputDataFields.image].shape)
        # The classes value is an ndarray. assertAllEqual compares it
        # element-wise instead of relying on the truthiness of `[2] == array`.
        self.assertAllEqual(
            [2], output_dict[fields.InputDataFields.groundtruth_classes])
        self.assertEqual(
            (1, 4),
            output_dict[fields.InputDataFields.groundtruth_boxes].shape)
        self.assertAllEqual(
            [0.0, 0.0, 1.0, 1.0],
            output_dict[fields.InputDataFields.groundtruth_boxes][0])
        self.assertAllEqual((1, 4, 5), output_dict[
            fields.InputDataFields.groundtruth_instance_masks].shape)
Пример #5
0
    def test_get_configs_from_multiple_files(self):
        """Tests that proto configs can be read from multiple files.

        The test first writes one config per path argument into a temp dir
        and reads them back with config_util.get_configs_from_multiple_files.
        It then checks that each returned proto equals the one that was
        written. `eval_input_configs` is returned as a list, so its first
        element is compared.
        """
        temp_dir = self.get_temp_dir()

        # Write model config file.
        model_config_path = os.path.join(temp_dir, "model.config")
        model = model_pb2.DetectionModel()
        model.faster_rcnn.num_classes = 10
        _write_config(model, model_config_path)

        # Write train config file.
        train_config_path = os.path.join(temp_dir, "train.config")
        train_config = train_pb2.TrainConfig()
        train_config.batch_size = 32
        _write_config(train_config, train_config_path)

        # Write train input config file.
        train_input_config_path = os.path.join(temp_dir, "train_input.config")
        train_input_config = input_reader_pb2.InputReader()
        train_input_config.label_map_path = "path/to/label_map"
        _write_config(train_input_config, train_input_config_path)

        # Write eval config file.
        eval_config_path = os.path.join(temp_dir, "eval.config")
        eval_config = eval_pb2.EvalConfig()
        eval_config.num_examples = 20
        _write_config(eval_config, eval_config_path)

        # Write eval input config file.
        eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
        eval_input_config = input_reader_pb2.InputReader()
        eval_input_config.label_map_path = "path/to/another/label_map"
        _write_config(eval_input_config, eval_input_config_path)

        configs = config_util.get_configs_from_multiple_files(
            model_config_path=model_config_path,
            train_config_path=train_config_path,
            train_input_config_path=train_input_config_path,
            eval_config_path=eval_config_path,
            eval_input_config_path=eval_input_config_path)
        self.assertProtoEquals(model, configs["model"])
        self.assertProtoEquals(train_config, configs["train_config"])
        self.assertProtoEquals(train_input_config,
                               configs["train_input_config"])
        self.assertProtoEquals(eval_config, configs["eval_config"])
        self.assertProtoEquals(eval_input_config,
                               configs["eval_input_configs"][0])
Пример #6
0
 def test_raises_error_with_no_input_paths(self):
     """Building a dataset from a config without input paths must fail.

     The text proto below sets no `tf_record_input_reader`, and
     `dataset_builder.build` is expected to raise ValueError for it.
     """
     text_proto = """
   shuffle: false
   num_readers: 1
   load_instance_masks: true
 """
     config = input_reader_pb2.InputReader()
     text_format.Merge(text_proto, config)
     self.assertRaises(ValueError, dataset_builder.build, config,
                       batch_size=1)
Пример #7
0
    def test_disable_shuffle_(self):
        """With shuffle disabled, one batch equals the fixed unshuffled order."""
        reader_config = input_reader_pb2.InputReader()
        reader_config.shuffle = False
        reader_config.num_readers = 1

        next_batch = self._get_dataset_next(
            [self._shuffle_path_template % '*'], reader_config, batch_size=10)
        unshuffled = [0] * 5 + [1] * 5

        with self.test_session() as session:
            self.assertAllEqual(session.run(next_batch), [unshuffled])
Пример #8
0
    def test_read_dataset_single_epoch(self):
        """A one-epoch dataset yields a single partial batch, then runs dry."""
        reader_config = input_reader_pb2.InputReader()
        reader_config.num_epochs = 1
        reader_config.num_readers = 1
        reader_config.shuffle = False

        file_pattern = self._path_template % '0'
        next_batch = self._get_dataset_next([file_pattern], reader_config,
                                            batch_size=30)
        with self.test_session() as session:
            # The first run returns whatever the single epoch holds; the
            # second run has nothing left and must raise OutOfRangeError.
            self.assertAllEqual(session.run(next_batch), [[1, 10]])
            with self.assertRaises(tf.errors.OutOfRangeError):
                session.run(next_batch)
Пример #9
0
    def test_reduce_num_reader(self):
        """Reads an unshuffled glob with num_readers=10 in a deterministic order.

        The test name suggests the builder reduces `num_readers` to fit the
        matching files -- NOTE(review): confirm in dataset_builder.
        """
        reader_config = input_reader_pb2.InputReader()
        reader_config.shuffle = False
        reader_config.num_readers = 10

        next_batch = self._get_dataset_next([self._path_template % '*'],
                                            reader_config,
                                            batch_size=20)
        one_pass = [1, 10, 2, 20, 3, 30, 4, 40, 5, 50]
        expected = [one_pass * 2]
        with self.test_session() as session:
            self.assertAllEqual(session.run(next_batch), expected)
Пример #10
0
    def test_enable_shuffle(self):
        """With shuffle enabled, the batch differs from the unshuffled order.

        A graph-level seed is fixed before the dataset is built, so the
        shuffled order is reproducible across runs.
        """
        reader_config = input_reader_pb2.InputReader()
        reader_config.shuffle = True
        reader_config.num_readers = 1

        # Must be set before the dataset ops are created.
        tf.set_random_seed(1)
        next_batch = self._get_dataset_next(
            [self._shuffle_path_template % '*'], reader_config, batch_size=10)
        unshuffled = [0] * 5 + [1] * 5

        with self.test_session() as session:
            batch = session.run(next_batch)
            differs = np.not_equal(batch, unshuffled)
            self.assertTrue(np.any(differs))
Пример #11
0
    def test_build_tf_record_input_reader_and_load_instance_masks(self):
        """With load_instance_masks set, a batch of one carries mask tensors."""
        record_path = self.create_tf_record()

        text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(record_path)
        reader_config = input_reader_pb2.InputReader()
        text_format.Merge(text_proto, reader_config)
        dataset = dataset_builder.build(reader_config, batch_size=1)
        next_example = dataset_builder.make_initializable_iterator(
            dataset).get_next()

        with tf.train.MonitoredSession() as session:
            outputs = session.run(next_example)

        masks = outputs[fields.InputDataFields.groundtruth_instance_masks]
        self.assertAllEqual((1, 1, 4, 5), masks.shape)
Пример #12
0
    def test_sample_one_of_n_shards(self):
        """With sample_1_of_n_examples: 2, only every second example is read.

        Four records are written. The first two reads must return the
        examples whose source ids are '0' and '2'.
        """
        tf_record_path = self.create_tf_record(num_examples=4)

        input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      sample_1_of_n_examples: 2
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
        input_reader_proto = input_reader_pb2.InputReader()
        text_format.Merge(input_reader_text_proto, input_reader_proto)
        tensor_dict = dataset_builder.make_initializable_iterator(
            dataset_builder.build(input_reader_proto,
                                  batch_size=1)).get_next()

        with tf.train.MonitoredSession() as sess:
            output_dict = sess.run(tensor_dict)
            self.assertAllEqual(['0'],
                                output_dict[fields.InputDataFields.source_id])
            output_dict = sess.run(tensor_dict)
            # Use the same element-wise check as for the first read.
            # assertEquals is a deprecated alias (removed in Python 3.12),
            # and comparing a list to an ndarray with it relies on the
            # truthiness of the comparison result.
            self.assertAllEqual(['2'],
                                output_dict[fields.InputDataFields.source_id])