def get_configs_from_multiple_files():
  """Loads the model, eval and input configs used for evaluation.

  The three text-format protos are parsed, in this order, from the files
  named by the --eval_config_path, --model_config_path and
  --input_config_path flags.

  Returns:
    A 3-tuple (model_config, eval_config, input_config) where
      model_config: a model_pb2.DetectionModel
      eval_config: an eval_pb2.EvalConfig
      input_config: an input_reader_pb2.InputReader
  """

  def _parse_into(message, path):
    # Read the whole file and merge its text-format contents into `message`.
    with tf.gfile.GFile(path, 'r') as config_file:
      text_format.Merge(config_file.read(), message)
    return message

  eval_cfg = _parse_into(eval_pb2.EvalConfig(), FLAGS.eval_config_path)
  model_cfg = _parse_into(model_pb2.DetectionModel(), FLAGS.model_config_path)
  input_cfg = _parse_into(input_reader_pb2.InputReader(),
                          FLAGS.input_config_path)
  return model_cfg, eval_cfg, input_cfg
def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
  """Reads training configuration from multiple config files.

  Every path is optional. A config is read (as a text-format proto) only
  when its path is a non-empty string.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

  Returns:
    Dictionary of configuration objects. Possible keys are `model`,
    `train_config`, `train_input_config`, `eval_config`,
    `eval_input_configs` and `graph_rewriter_config`. Note that
    `eval_input_configs` maps to a single-element list containing the parsed
    input_reader_pb2.InputReader, not to the message itself. A key is
    present only when the corresponding path is a non-empty string.
  """

  def _read_text_proto(path, message):
    """Merges the text-format proto file at `path` into `message`."""
    with tf.gfile.GFile(path, "r") as f:
      text_format.Merge(f.read(), message)
    return message

  configs = {}
  if model_config_path:
    configs["model"] = _read_text_proto(
        model_config_path, model_pb2.DetectionModel())

  if train_config_path:
    configs["train_config"] = _read_text_proto(
        train_config_path, train_pb2.TrainConfig())

  if train_input_config_path:
    configs["train_input_config"] = _read_text_proto(
        train_input_config_path, input_reader_pb2.InputReader())

  if eval_config_path:
    configs["eval_config"] = _read_text_proto(
        eval_config_path, eval_pb2.EvalConfig())

  if eval_input_config_path:
    # Stored under the plural key as a one-element list. The key name and
    # list form are part of this function's returned contract.
    configs["eval_input_configs"] = [
        _read_text_proto(eval_input_config_path,
                         input_reader_pb2.InputReader())
    ]

  if graph_rewriter_config_path:
    configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
        graph_rewriter_config_path)

  return configs
def test_get_configs_from_multiple_files(self):
  """Tests that proto configs can be read from multiple files."""
  temp_dir = self.get_temp_dir()

  # Write model config file.
  model_config_path = os.path.join(temp_dir, "model.config")
  model = model_pb2.DetectionModel()
  model.faster_rcnn.num_classes = 10
  _write_config(model, model_config_path)

  # Write train config file.
  train_config_path = os.path.join(temp_dir, "train.config")
  train_config = train_pb2.TrainConfig()
  train_config.batch_size = 32
  _write_config(train_config, train_config_path)

  # Write train input config file.
  train_input_config_path = os.path.join(temp_dir, "train_input.config")
  train_input_config = input_reader_pb2.InputReader()
  train_input_config.label_map_path = "path/to/label_map"
  _write_config(train_input_config, train_input_config_path)

  # Write eval config file.
  eval_config_path = os.path.join(temp_dir, "eval.config")
  eval_config = eval_pb2.EvalConfig()
  eval_config.num_examples = 20
  _write_config(eval_config, eval_config_path)

  # Write eval input config file.
  eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
  eval_input_config = input_reader_pb2.InputReader()
  eval_input_config.label_map_path = "path/to/another/label_map"
  _write_config(eval_input_config, eval_input_config_path)

  configs = config_util.get_configs_from_multiple_files(
      model_config_path=model_config_path,
      train_config_path=train_config_path,
      train_input_config_path=train_input_config_path,
      eval_config_path=eval_config_path,
      eval_input_config_path=eval_input_config_path)
  self.assertProtoEquals(model, configs["model"])
  self.assertProtoEquals(train_config, configs["train_config"])
  self.assertProtoEquals(train_input_config, configs["train_input_config"])
  self.assertProtoEquals(eval_config, configs["eval_config"])
  # get_configs_from_multiple_files stores the eval input reader as a
  # one-element list under the plural key `eval_input_configs`.
  self.assertEqual(1, len(configs["eval_input_configs"]))
  self.assertProtoEquals(eval_input_config, configs["eval_input_configs"][0])