def get_configs_from_pipeline_file():
    """Reads the training configuration from the pipeline config file.

    Parses the text-format protobuf at ``FLAGS.pipeline_config_path`` into an
    ``attention_pipeline_pb2.AttentionTrainEvalPipelineConfig`` and returns
    the sub-messages needed for training.

    Returns:
      model_config: the pipeline's ``model`` field (presumably an
        attention_model_pb2 detection-model message -- confirm against the
        .proto definition).
      train_config: the pipeline's ``train_config`` field (presumably a
        train_pb2.TrainConfig -- confirm against the .proto definition).
      input_config: the pipeline's ``train_input_reader`` field.
    """
    pipeline_config = attention_pipeline_pb2.AttentionTrainEvalPipelineConfig()
    # tf.gfile.GFile reads through TensorFlow's file-system layer, so paths
    # that the builtin open() cannot handle (e.g. gs://) also work.
    with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
        # Merge (rather than Parse) tolerates repeated singular fields;
        # the last occurrence wins.
        text_format.Merge(f.read(), pipeline_config)

    model_config = pipeline_config.model
    train_config = pipeline_config.train_config
    input_config = pipeline_config.train_input_reader
    return model_config, train_config, input_config
def get_configs_from_pipeline_file():
    """Reads the evaluation configuration from the pipeline config file.

    Parses the text-format protobuf at ``FLAGS.pipeline_config_path`` into an
    ``attention_pipeline_pb2.AttentionTrainEvalPipelineConfig`` and returns
    the sub-messages needed for evaluation.

    Returns:
      model_config: the pipeline's ``model`` field.
      eval_config: the pipeline's ``train_config`` field when
        ``FLAGS.eval_training_data`` is truthy, otherwise its ``eval_config``
        field. The two are different message types, so callers must only
        rely on what both provide.
      input_config: the pipeline's ``eval_input_reader`` field. This is
        returned regardless of ``FLAGS.eval_training_data``.
        NOTE(review): when evaluating on training data one might expect
        ``train_input_reader`` here -- confirm this is intended.
    """
    pipeline_config = attention_pipeline_pb2.AttentionTrainEvalPipelineConfig()
    # tf.gfile.GFile reads through TensorFlow's file-system layer, so paths
    # that the builtin open() cannot handle (e.g. gs://) also work.
    with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
        # Merge (rather than Parse) tolerates repeated singular fields;
        # the last occurrence wins.
        text_format.Merge(f.read(), pipeline_config)

    model_config = pipeline_config.model
    if FLAGS.eval_training_data:
        # Evaluate with the training settings instead of the eval settings.
        eval_config = pipeline_config.train_config
    else:
        eval_config = pipeline_config.eval_config
    input_config = pipeline_config.eval_input_reader
    return model_config, eval_config, input_config
class _ConfigFileNotFoundError(FileNotFoundError, AssertionError):
    """Raised by read_pipeline_config when the config file is missing.

    Inherits from AssertionError as well as FileNotFoundError so callers
    written against the earlier ``assert``-based check still catch it.
    """


def read_pipeline_config(config_path):
    """Parses a text-format pipeline config file.

    Args:
      config_path: local filesystem path to the text-format
        ``AttentionTrainEvalPipelineConfig`` file (read with the builtin
        ``open``).

    Returns:
      An ``attention_pipeline_pb2.AttentionTrainEvalPipelineConfig`` populated
      from the file.

    Raises:
      FileNotFoundError: if ``config_path`` does not exist. The raised
        exception is also an ``AssertionError`` for backward compatibility.
    """
    # Explicit raise instead of `assert`: assertions are stripped when Python
    # runs with -O, which would silently drop this check.
    if not os.path.exists(config_path):
        raise _ConfigFileNotFoundError('Config file does not exist')

    pipeline_config = attention_pipeline_pb2.AttentionTrainEvalPipelineConfig()
    with open(config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    return pipeline_config