def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
  """Reads training configuration from multiple config files.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
      `train_input_config`, `eval_config`, `eval_input_config`, and
      `graph_rewriter_config`. Key/value pairs are returned only for valid
      (non-empty) path strings.
  """
  configs = {}
  if model_config_path:
    model_config = model_pb2.DetectionModel()
    with tf.gfile.GFile(model_config_path, "r") as f:
      text_format.Merge(f.read(), model_config)
    configs["model"] = model_config

  if train_config_path:
    train_config = train_pb2.TrainConfig()
    with tf.gfile.GFile(train_config_path, "r") as f:
      text_format.Merge(f.read(), train_config)
    configs["train_config"] = train_config

  if train_input_config_path:
    train_input_config = input_reader_pb2.InputReader()
    with tf.gfile.GFile(train_input_config_path, "r") as f:
      text_format.Merge(f.read(), train_input_config)
    configs["train_input_config"] = train_input_config

  if eval_config_path:
    eval_config = eval_pb2.EvalConfig()
    with tf.gfile.GFile(eval_config_path, "r") as f:
      text_format.Merge(f.read(), eval_config)
    configs["eval_config"] = eval_config

  if eval_input_config_path:
    eval_input_config = input_reader_pb2.InputReader()
    with tf.gfile.GFile(eval_input_config_path, "r") as f:
      text_format.Merge(f.read(), eval_input_config)
    configs["eval_input_config"] = eval_input_config

  if graph_rewriter_config_path:
    configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
        graph_rewriter_config_path)

  return configs
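A minimal usage sketch (not part of the module): the config paths below are hypothetical placeholders, and only the keys for non-empty path arguments appear in the returned dictionary.

configs = get_configs_from_multiple_files(
    model_config_path="/tmp/model.config",
    train_config_path="/tmp/train.config")
model_config = configs["model"]         # model_pb2.DetectionModel
train_config = configs["train_config"]  # train_pb2.TrainConfig
assert "eval_config" not in configs     # eval_config_path was left empty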
def test_build_tf_record_input_reader(self):
  tf_record_path = self.create_tf_record()

  input_reader_text_proto = """
    shuffle: false
    num_readers: 1
    tf_record_input_reader {{
      input_path: '{0}'
    }}
  """.format(tf_record_path)
  input_reader_proto = input_reader_pb2.InputReader()
  text_format.Merge(input_reader_text_proto, input_reader_proto)
  tensor_dict = dataset_builder.make_initializable_iterator(
      dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

  sv = tf.train.Supervisor(logdir=self.get_temp_dir())
  with sv.prepare_or_wait_for_session() as sess:
    sv.start_queue_runners(sess)
    output_dict = sess.run(tensor_dict)

  self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                   output_dict)
  self.assertEqual((1, 4, 5, 3),
                   output_dict[fields.InputDataFields.image].shape)
  self.assertAllEqual([[2]],
                      output_dict[fields.InputDataFields.groundtruth_classes])
  self.assertEqual(
      (1, 1, 4),
      output_dict[fields.InputDataFields.groundtruth_boxes].shape)
  self.assertAllEqual(
      [0.0, 0.0, 1.0, 1.0],
      output_dict[fields.InputDataFields.groundtruth_boxes][0][0])
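The TFRecord tests in this excerpt call a `create_tf_record` helper that is defined elsewhere in the test class. Below is a plausible sketch, inferred from the shapes and values the assertions expect (a 4x5x3 image, one full-image box, class label 2, and a 4x5 instance mask); the exact feature keys and encoding are assumptions based on the standard TF Object Detection record format, not the original helper.

def create_tf_record(self):
  # Sketch of an assumed helper: writes a single example whose contents
  # match what the tests assert on.
  path = os.path.join(self.get_temp_dir(), 'tfrecord')
  writer = tf.python_io.TFRecordWriter(path)

  image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
  flat_mask = (4 * 5) * [1.0]
  with self.test_session():
    encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()

  example = tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
      'image/format': tf.train.Feature(
          bytes_list=tf.train.BytesList(value=['jpeg'.encode('utf8')])),
      'image/height': tf.train.Feature(
          int64_list=tf.train.Int64List(value=[4])),
      'image/width': tf.train.Feature(
          int64_list=tf.train.Int64List(value=[5])),
      'image/object/bbox/xmin': tf.train.Feature(
          float_list=tf.train.FloatList(value=[0.0])),
      'image/object/bbox/xmax': tf.train.Feature(
          float_list=tf.train.FloatList(value=[1.0])),
      'image/object/bbox/ymin': tf.train.Feature(
          float_list=tf.train.FloatList(value=[0.0])),
      'image/object/bbox/ymax': tf.train.Feature(
          float_list=tf.train.FloatList(value=[1.0])),
      'image/object/class/label': tf.train.Feature(
          int64_list=tf.train.Int64List(value=[2])),
      'image/object/mask': tf.train.Feature(
          float_list=tf.train.FloatList(value=flat_mask)),
  }))
  writer.write(example.SerializeToString())
  writer.close()
  return path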
def test_build_tf_record_input_reader_with_batch_size_two_and_masks(self):
  tf_record_path = self.create_tf_record()

  input_reader_text_proto = """
    shuffle: false
    num_readers: 1
    load_instance_masks: true
    tf_record_input_reader {{
      input_path: '{0}'
    }}
  """.format(tf_record_path)
  input_reader_proto = input_reader_pb2.InputReader()
  text_format.Merge(input_reader_text_proto, input_reader_proto)

  def one_hot_class_encoding_fn(tensor_dict):
    # Class labels in the record are 1-indexed; shift to 0-indexed before
    # one-hot encoding with depth 3.
    tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
        tensor_dict[fields.InputDataFields.groundtruth_classes] - 1, depth=3)
    return tensor_dict

  tensor_dict = dataset_builder.make_initializable_iterator(
      dataset_builder.build(
          input_reader_proto,
          transform_input_data_fn=one_hot_class_encoding_fn,
          batch_size=2)).get_next()

  sv = tf.train.Supervisor(logdir=self.get_temp_dir())
  with sv.prepare_or_wait_for_session() as sess:
    sv.start_queue_runners(sess)
    output_dict = sess.run(tensor_dict)

  self.assertAllEqual(
      [2, 1, 4, 5],
      output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
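Both tests above wrap the built dataset in `dataset_builder.make_initializable_iterator`. A sketch of what that helper presumably does, under the assumption that it registers the iterator's initializer in the TABLE_INITIALIZERS collection so that Supervisor-managed sessions run it before the first `sess.run`:

def make_initializable_iterator(dataset):
  """Creates an initializable iterator and registers its initializer.

  Assumed helper: Supervisor/MonitoredSession execute the
  TABLE_INITIALIZERS collection on startup, so get_next() is usable
  without an explicit initializer run in the tests above.
  """
  iterator = dataset.make_initializable_iterator()
  tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  return iterator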
def test_build_tf_record_input_reader_and_load_instance_masks(self):
  tf_record_path = self.create_tf_record()

  input_reader_text_proto = """
    shuffle: false
    num_readers: 1
    load_instance_masks: true
    tf_record_input_reader {{
      input_path: '{0}'
    }}
  """.format(tf_record_path)
  input_reader_proto = input_reader_pb2.InputReader()
  text_format.Merge(input_reader_text_proto, input_reader_proto)
  tensor_dict = input_reader_builder.build(input_reader_proto)

  sv = tf.train.Supervisor(logdir=self.get_temp_dir())
  with sv.prepare_or_wait_for_session() as sess:
    sv.start_queue_runners(sess)
    output_dict = sess.run(tensor_dict)

  self.assertEqual((4, 5, 3),
                   output_dict[fields.InputDataFields.image].shape)
  self.assertAllEqual([2],
                      output_dict[fields.InputDataFields.groundtruth_classes])
  self.assertEqual(
      (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
  self.assertAllEqual(
      [0.0, 0.0, 1.0, 1.0],
      output_dict[fields.InputDataFields.groundtruth_boxes][0])
  self.assertAllEqual(
      (1, 4, 5),
      output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
def test_get_configs_from_multiple_files(self):
  """Tests that proto configs can be read from multiple files."""
  temp_dir = self.get_temp_dir()

  # Write model config file.
  model_config_path = os.path.join(temp_dir, "model.config")
  model = model_pb2.DetectionModel()
  model.faster_rcnn.num_classes = 10
  _write_config(model, model_config_path)

  # Write train config file.
  train_config_path = os.path.join(temp_dir, "train.config")
  train_config = train_pb2.TrainConfig()
  train_config.batch_size = 32
  _write_config(train_config, train_config_path)

  # Write train input config file.
  train_input_config_path = os.path.join(temp_dir, "train_input.config")
  train_input_config = input_reader_pb2.InputReader()
  train_input_config.label_map_path = "path/to/label_map"
  _write_config(train_input_config, train_input_config_path)

  # Write eval config file.
  eval_config_path = os.path.join(temp_dir, "eval.config")
  eval_config = eval_pb2.EvalConfig()
  eval_config.num_examples = 20
  _write_config(eval_config, eval_config_path)

  # Write eval input config file.
  eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
  eval_input_config = input_reader_pb2.InputReader()
  eval_input_config.label_map_path = "path/to/another/label_map"
  _write_config(eval_input_config, eval_input_config_path)

  configs = config_util.get_configs_from_multiple_files(
      model_config_path=model_config_path,
      train_config_path=train_config_path,
      train_input_config_path=train_input_config_path,
      eval_config_path=eval_config_path,
      eval_input_config_path=eval_input_config_path)
  self.assertProtoEquals(model, configs["model"])
  self.assertProtoEquals(train_config, configs["train_config"])
  self.assertProtoEquals(train_input_config, configs["train_input_config"])
  self.assertProtoEquals(eval_config, configs["eval_config"])
  self.assertProtoEquals(eval_input_config, configs["eval_input_config"])
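This test relies on a `_write_config` helper that is not shown in this excerpt. A minimal sketch of what it presumably does, serializing a proto to text format and writing it to disk so `get_configs_from_multiple_files` can merge it back:

def _write_config(config, config_path):
  """Writes a config proto to disk in text format (assumed helper)."""
  config_text = text_format.MessageToString(config)
  with tf.gfile.Open(config_path, "wb") as f:
    f.write(config_text)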
def test_raises_error_with_no_input_paths(self):
  input_reader_text_proto = """
    shuffle: false
    num_readers: 1
    load_instance_masks: true
  """
  input_reader_proto = input_reader_pb2.InputReader()
  text_format.Merge(input_reader_text_proto, input_reader_proto)
  with self.assertRaises(ValueError):
    dataset_builder.build(input_reader_proto, batch_size=1)
def test_disable_shuffle(self):
  config = input_reader_pb2.InputReader()
  config.num_readers = 1
  config.shuffle = False

  data = self._get_dataset_next(
      [self._shuffle_path_template % '*'], config, batch_size=10)
  expected_non_shuffle_output = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

  with self.test_session() as sess:
    self.assertAllEqual(sess.run(data), [expected_non_shuffle_output])
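This test and the three that follow call a `_get_dataset_next` helper and path templates (`self._path_template`, `self._shuffle_path_template`) set up elsewhere in the test class. A sketch under the assumption that the fixtures are small text files with one integer per line and that `dataset_builder.read_dataset` accepts a file-based dataset constructor, the file list, and the InputReader config:

def _get_dataset_next(self, files, config, batch_size):
  def decode_func(value):
    # Each line in the fixture files is assumed to hold one integer.
    return tf.string_to_number(value, out_type=tf.int32)

  dataset = dataset_builder.read_dataset(
      tf.data.TextLineDataset, files, config)
  dataset = dataset.map(decode_func)
  dataset = dataset.batch(batch_size)
  return dataset.make_one_shot_iterator().get_next()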
def test_reduce_num_reader(self):
  config = input_reader_pb2.InputReader()
  config.num_readers = 10
  config.shuffle = False

  data = self._get_dataset_next(
      [self._path_template % '*'], config, batch_size=20)

  with self.test_session() as sess:
    self.assertAllEqual(
        sess.run(data),
        [[1, 10, 2, 20, 3, 30, 4, 40, 5, 50,
          1, 10, 2, 20, 3, 30, 4, 40, 5, 50]])
def test_read_dataset_single_epoch(self):
  config = input_reader_pb2.InputReader()
  config.num_epochs = 1
  config.num_readers = 1
  config.shuffle = False

  data = self._get_dataset_next(
      [self._path_template % '0'], config, batch_size=30)

  with self.test_session() as sess:
    # First batch will retrieve as much as it can, second batch will fail.
    self.assertAllEqual(sess.run(data), [[1, 10]])
    self.assertRaises(tf.errors.OutOfRangeError, sess.run, data)
def test_enable_shuffle(self):
  config = input_reader_pb2.InputReader()
  config.num_readers = 1
  config.shuffle = True

  tf.set_random_seed(1)  # Set graph level seed.
  data = self._get_dataset_next(
      [self._shuffle_path_template % '*'], config, batch_size=10)
  expected_non_shuffle_output = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

  with self.test_session() as sess:
    self.assertTrue(
        np.any(np.not_equal(sess.run(data), expected_non_shuffle_output)))