Esempio n. 1
0
    def test_configure_trainer_and_train_two_steps(self):
        print('\n=================================================')
        print('test_configure_trainer_and_train_two_steps')

        train_config_text_proto = """
									optimizer {
										adam_optimizer {
											learning_rate {
												constant_learning_rate {
													learning_rate: 0.01
												}
											}
										}
									}
									
									num_steps: 5
									"""
        train_config = train_pb2.TrainConfig()
        text_format.Merge(train_config_text_proto, train_config)

        train_dir = self.get_temp_dir()
        print('\ntrain_dir: {}\n'.format(train_dir))

        trainer.train(create_tensor_dict_fn=get_input_function,
                      create_model_fn=FakeDetectionModel,
                      train_config=train_config,
                      master='',
                      task=0,
                      num_clones=1,
                      worker_replicas=1,
                      clone_on_cpu=True,
                      ps_tasks=0,
                      worker_job_name='worker',
                      is_chief=True,
                      train_dir=train_dir)
Esempio n. 2
0
def get_configs_from_multiple_files():
  """Loads the train, model and input configs named by command-line flags.

  Each file is parsed as a text-format protobuf, in this order:
    --train_config_path -> train_pb2.TrainConfig
    --model_config_path -> model_pb2.DetectionModel
    --input_config_path -> input_reader_pb2.InputReader

  Returns:
    A `(model_config, train_config, input_config)` tuple.
  """
  def _parse_into(message, path):
    # Read the whole file and merge its text-format contents into `message`.
    with tf.gfile.GFile(path, 'r') as handle:
      text_format.Merge(handle.read(), message)
    return message

  train_config = _parse_into(train_pb2.TrainConfig(), FLAGS.train_config_path)
  model_config = _parse_into(model_pb2.DetectionModel(),
                             FLAGS.model_config_path)
  input_config = _parse_into(input_reader_pb2.InputReader(),
                             FLAGS.input_config_path)
  return model_config, train_config, input_config
Esempio n. 3
0
def update_augmentation_options(config):
    """Merges a random selection of data augmentation options into a config.

    Samples between 1 and ``len(AUGMENTATION_OPTIONS)`` distinct entries of
    ``AUGMENTATION_OPTIONS``. Each entry is read through its ``name``,
    ``params``, ``values`` and ``format`` attributes. For the j-th parameter, a
    value is drawn uniformly from ``values[j][0]..values[j][1]``, rounded to
    two decimals and converted with ``format[j]``. The sampled options are
    rendered as text-format ``data_augmentation_options`` entries and merged
    into the training config with ``MergeFrom``.

    Args:
        config: One of
            * ``pipeline_pb2.TrainEvalPipelineConfig``: its ``train_config``
              field is updated;
            * ``train_pb2.TrainConfig``: the message itself is updated;
            * ``dict``: the message stored under ``'train_config'`` is
              updated.

    Returns:
        The same ``config`` object, modified in place.

    Raises:
        AttributeError: If ``config`` is none of the supported types.
    """
    # Resolve which TrainConfig message to update before drawing any values.
    if isinstance(config, pipeline_pb2.TrainEvalPipelineConfig):
        train_config = config.train_config
    elif isinstance(config, train_pb2.TrainConfig):
        # A TrainConfig has no `train_config` field: merge into it directly.
        train_config = config
    elif isinstance(config, dict):
        train_config = config['train_config']
    else:
        raise AttributeError(
            "Wrong config format. `TrainEvalPipelineConfig` or `TrainConfig` required"
        )

    def _render(option):
        """Renders one sampled augmentation option as a text-format entry."""
        fields = '\n'.join(
            "{}: {}".format(
                param, option.format[j](round(
                    random.uniform(option.values[j][0], option.values[j][1]),
                    2)))
            for j, param in enumerate(option.params))
        return "data_augmentation_options { %s { %s }}" % (option.name, fields)

    num_of_opt = random.randint(1, len(AUGMENTATION_OPTIONS))
    options = random.sample(AUGMENTATION_OPTIONS, num_of_opt)
    opt = train_pb2.TrainConfig()
    text_format.Merge('\n'.join(_render(option) for option in options), opt)

    # MergeFrom concatenates repeated fields, so augmentation options already
    # present are kept and the sampled ones are appended after them.
    train_config.MergeFrom(opt)
    return config
Esempio n. 4
0
def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    offline_eval_input_config_path=""):
    """Reads training configuration from multiple config files.

    Every path that is truthy (a non-empty string) is opened with
    `tf.gfile.GFile` and parsed as a text-format proto. Falsy paths are
    skipped and their key is omitted from the result.

    Args:
      model_config_path: Path to model_pb2.DetectionModel.
      train_config_path: Path to train_pb2.TrainConfig.
      train_input_config_path: Path to input_reader_pb2.InputReader.
      eval_config_path: Path to eval_pb2.EvalConfig.
      eval_input_config_path: Path to input_reader_pb2.InputReader.
      offline_eval_input_config_path: Path to input_reader_pb2.InputReader.

    Returns:
      Dictionary of configuration objects. Keys are `model`, `train_config`,
      `train_input_config`, `eval_config`, `eval_input_config` and
      `offline_eval_input_config`. Key/Values are returned only for valid
      (non-empty) paths.
    """

    def _read_text_proto(message, path):
        """Merges the text-format proto at `path` into `message`; returns it."""
        with tf.gfile.GFile(path, "r") as f:
            text_format.Merge(f.read(), message)
        return message

    configs = {}
    if model_config_path:
        configs["model"] = _read_text_proto(
            model_pb2.DetectionModel(), model_config_path)

    if train_config_path:
        configs["train_config"] = _read_text_proto(
            train_pb2.TrainConfig(), train_config_path)

    if train_input_config_path:
        configs["train_input_config"] = _read_text_proto(
            input_reader_pb2.InputReader(), train_input_config_path)

    if eval_config_path:
        configs["eval_config"] = _read_text_proto(
            eval_pb2.EvalConfig(), eval_config_path)

    if eval_input_config_path:
        configs["eval_input_config"] = _read_text_proto(
            input_reader_pb2.InputReader(), eval_input_config_path)

    if offline_eval_input_config_path:
        configs["offline_eval_input_config"] = _read_text_proto(
            input_reader_pb2.InputReader(), offline_eval_input_config_path)

    return configs
Esempio n. 5
0
    def test_get_configs_from_multiple_files(self):
        """Tests that proto configs can be read from multiple files.

        Writes one text-format config per proto type into the test's temp
        directory, reads them back through
        `config_util.get_configs_from_multiple_files`, and checks that each
        parsed proto equals the one written.
        """
        print(
            '\n=========================================================================='
        )
        print('test_get_configs_from_multiple_files')

        # Use the per-test temp dir instead of a hard-coded, user-specific
        # path so the test is hermetic and runs on any machine.
        temp_dir = self.get_temp_dir()

        # Write model config file.
        model_config_path = os.path.join(temp_dir, "model.config")
        model = model_pb2.DetectionModel()
        model.faster_rcnn.num_classes = 10
        _write_config(model, model_config_path)

        # Write train config file.
        train_config_path = os.path.join(temp_dir, "train.config")
        train_config = train_pb2.TrainConfig()
        train_config.batch_size = 32
        _write_config(train_config, train_config_path)

        # Write train input config file.
        train_input_config_path = os.path.join(temp_dir, "train_input.config")
        train_input_config = input_reader_pb2.InputReader()
        train_input_config.label_map_path = "path/to/label_map"
        _write_config(train_input_config, train_input_config_path)

        # Write eval config file.
        eval_config_path = os.path.join(temp_dir, "eval.config")
        eval_config = eval_pb2.EvalConfig()
        eval_config.num_examples = 20
        _write_config(eval_config, eval_config_path)

        # Write eval input config file.
        eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
        eval_input_config = input_reader_pb2.InputReader()
        eval_input_config.label_map_path = "path/to/another/label_map"
        _write_config(eval_input_config, eval_input_config_path)

        configs = config_util.get_configs_from_multiple_files(
            model_config_path=model_config_path,
            train_config_path=train_config_path,
            train_input_config_path=train_input_config_path,
            eval_config_path=eval_config_path,
            eval_input_config_path=eval_input_config_path)

        self.assertProtoEquals(model, configs["model"])
        self.assertProtoEquals(train_config, configs["train_config"])
        self.assertProtoEquals(train_input_config,
                               configs["train_input_config"])
        self.assertProtoEquals(eval_config, configs["eval_config"])
        self.assertProtoEquals(eval_input_config, configs["eval_input_config"])
Esempio n. 6
0
    def test_configure_trainer_with_multiclass_scores_and_train_two_steps(
            self):
        print(
            '\n========================================================================'
        )
        print(
            'test_configure_trainer_with_multiclass_scores_and_train_two_steps'
        )

        train_config_text_proto = """
									optimizer {
										adam_optimizer {
											learning_rate {
												constant_learning_rate {
													learning_rate: 0.01
												}
											}
										}
									}
									data_augmentation_options {
										random_adjust_brightness {
											max_delta: 0.2
										}
									}
									data_augmentation_options {
										random_adjust_contrast {
											min_delta: 0.7
											max_delta: 1.1
										}
									}
									num_steps: 2
									use_multiclass_scores: true
									"""
        train_config = train_pb2.TrainConfig()
        text_format.Merge(train_config_text_proto, train_config)

        train_dir = self.get_temp_dir()

        trainer.train(create_tensor_dict_fn=get_input_function,
                      create_model_fn=FakeDetectionModel,
                      train_config=train_config,
                      master='',
                      task=0,
                      num_clones=1,
                      worker_replicas=1,
                      clone_on_cpu=True,
                      ps_tasks=0,
                      worker_job_name='worker',
                      is_chief=True,
                      train_dir=train_dir)
def get_configs_from_multiple_files():
  """Parses the train, model and input-reader configs from flag-given paths.

  The text-format files are read in the order train, model, input, from
  FLAGS.train_config_path, FLAGS.model_config_path and FLAGS.input_config_path.

  Returns:
    Tuple `(model_config, train_config, input_config)` of types
    (model_pb2.DetectionModel, train_pb2.TrainConfig,
    input_reader_pb2.InputReader).
  """
  def _load(message_cls, config_path):
    # Instantiate an empty message and merge the file's text proto into it.
    message = message_cls()
    with tf.gfile.GFile(config_path, 'r') as config_file:
      text_format.Merge(config_file.read(), message)
    return message

  train_config = _load(train_pb2.TrainConfig, FLAGS.train_config_path)
  model_config = _load(model_pb2.DetectionModel, FLAGS.model_config_path)
  input_config = _load(input_reader_pb2.InputReader, FLAGS.input_config_path)
  return model_config, train_config, input_config