def get_configs_from_pipeline_file(pipeline_config_path=None):
    """Reads training configuration from an AttentionTrainEvalPipelineConfig.

    Parses the text-format protobuf file at ``pipeline_config_path`` into an
    ``attention_pipeline_pb2.AttentionTrainEvalPipelineConfig`` and returns
    the sections needed for training.

    Args:
      pipeline_config_path: Optional path to the pipeline config file. When
        None (the default), ``FLAGS.pipeline_config_path`` is used, which
        preserves the original flag-driven behavior.

    Returns:
      A tuple ``(model_config, train_config, input_config)`` holding the
      ``model``, ``train_config`` and ``train_input_reader`` fields of the
      parsed pipeline config. (An earlier docstring listed these as
      attention_model_pb2.CoLocDetectionModel, train_pb2.TrainConfig and
      coloc_input_reader_pb2.CoLocInputReader -- TODO confirm against the
      .proto definition.)
    """
    if pipeline_config_path is None:
        pipeline_config_path = FLAGS.pipeline_config_path

    pipeline_config = attention_pipeline_pb2.AttentionTrainEvalPipelineConfig()
    with tf.gfile.GFile(pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)

    model_config = pipeline_config.model
    train_config = pipeline_config.train_config
    input_config = pipeline_config.train_input_reader

    return model_config, train_config, input_config
# Example #2
# 0
def get_configs_from_pipeline_file(pipeline_config_path=None):
    """Reads evaluation configuration from an AttentionTrainEvalPipelineConfig.

    Parses the text-format protobuf file at ``pipeline_config_path`` into an
    ``attention_pipeline_pb2.AttentionTrainEvalPipelineConfig`` and returns
    the sections needed for evaluation.

    Args:
      pipeline_config_path: Optional path to the pipeline config file. When
        None (the default), ``FLAGS.pipeline_config_path`` is used, which
        preserves the original flag-driven behavior.

    Returns:
      A tuple ``(model_config, eval_config, input_config)``:
        model_config: the ``model`` field of the parsed config.
        eval_config: the ``eval_config`` field of the parsed config.
        input_config: the input reader to evaluate on --
          ``train_input_reader`` when ``FLAGS.eval_training_data`` is set,
          otherwise ``eval_input_reader``.
    """
    if pipeline_config_path is None:
        pipeline_config_path = FLAGS.pipeline_config_path

    pipeline_config = attention_pipeline_pb2.AttentionTrainEvalPipelineConfig()
    with tf.gfile.GFile(pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)

    model_config = pipeline_config.model
    eval_config = pipeline_config.eval_config
    # "Evaluate on training data" swaps the *input* reader, not the eval
    # settings. Previously train_config (a training message, not the eval
    # config the evaluator expects) was returned as eval_config while
    # eval_input_reader was still used, so the flag never changed the data.
    if FLAGS.eval_training_data:
        input_config = pipeline_config.train_input_reader
    else:
        input_config = pipeline_config.eval_input_reader
    return model_config, eval_config, input_config
# Example #3
# 0
def read_pipeline_config(config_path):
  """Parses a local text-format AttentionTrainEvalPipelineConfig file.

  Args:
    config_path: Filesystem path to the pipeline config file.

  Returns:
    An ``attention_pipeline_pb2.AttentionTrainEvalPipelineConfig`` populated
    from the file via ``text_format.Merge``.

  Raises:
    IOError: If ``config_path`` does not name an existing regular file.
  """
  # An explicit check instead of ``assert`` so the validation is not stripped
  # under ``python -O``. IOError exists on both Python 2 and 3 (on Python 3
  # it is an alias of OSError).
  if not os.path.isfile(config_path):
    raise IOError('Config file does not exist: %s' % config_path)
  pipeline_config = attention_pipeline_pb2.AttentionTrainEvalPipelineConfig()
  with open(config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  return pipeline_config