def test_configure_trainer_and_train_two_steps(self):
    """Runs `trainer.train` end to end with an Adam-optimizer train config.

    The config is parsed from text format into a `train_pb2.TrainConfig` and
    training writes into the test's temporary directory.

    NOTE(review): the test name says "two steps" but the config requests
    `num_steps: 5` -- confirm which of the two is intended.
    """
    banner = '\n================================================='
    print(banner)
    print('test_configure_trainer_and_train_two_steps')

    # Text-format TrainConfig: Adam optimizer with a constant learning rate,
    # followed by the number of training steps.
    config_text = (
        ' optimizer { adam_optimizer { learning_rate {'
        ' constant_learning_rate { learning_rate: 0.01 } } } }'
        ' num_steps: 5 '
    )
    parsed_config = train_pb2.TrainConfig()
    text_format.Merge(config_text, parsed_config)

    output_dir = self.get_temp_dir()
    print('\ntrain_dir: {}\n'.format(output_dir))

    # Single clone, single worker, no parameter servers, everything on CPU.
    trainer.train(
        create_tensor_dict_fn=get_input_function,
        create_model_fn=FakeDetectionModel,
        train_config=parsed_config,
        master='',
        task=0,
        num_clones=1,
        worker_replicas=1,
        clone_on_cpu=True,
        ps_tasks=0,
        worker_job_name='worker',
        is_chief=True,
        train_dir=output_dir,
    )
def get_configs_from_multiple_files():
    """Loads the model, train and input configs from three text-proto files.

    The file locations come from command-line flags:
        FLAGS.train_config_path -> train_pb2.TrainConfig
        FLAGS.model_config_path -> model_pb2.DetectionModel
        FLAGS.input_config_path -> input_reader_pb2.InputReader

    Returns:
        A `(model_config, train_config, input_config)` tuple holding the
        populated protos, in that order.
    """

    def _fill(message, path):
        # Read the whole file and merge its text-format contents into message.
        with tf.gfile.GFile(path, 'r') as config_file:
            text_format.Merge(config_file.read(), message)
        return message

    # Read order: train, model, input (the return order differs).
    train_config = _fill(train_pb2.TrainConfig(), FLAGS.train_config_path)
    model_config = _fill(model_pb2.DetectionModel(), FLAGS.model_config_path)
    input_config = _fill(input_reader_pb2.InputReader(),
                         FLAGS.input_config_path)

    return model_config, train_config, input_config
def update_augmentation_options(config):
    """Merges a random selection of augmentation options into a train config.

    Picks between 1 and `len(AUGMENTATION_OPTIONS)` entries of
    `AUGMENTATION_OPTIONS` at random. For every parameter of a picked option a
    value is drawn uniformly from that parameter's `values` range, rounded to
    two decimals and passed through the matching `format` callable. The result
    is rendered as text-format `data_augmentation_options` entries, parsed into
    a `train_pb2.TrainConfig` and merged into the target train config.

    Args:
        config: one of
            * `pipeline_pb2.TrainEvalPipelineConfig` -- merged into its
              `train_config` field;
            * `train_pb2.TrainConfig` -- merged into the message itself;
            * `dict` -- merged into the message stored under 'train_config'.
            The object is modified in place.

    Returns:
        The same `config` object that was passed in.

    Raises:
        AttributeError: if `config` is none of the supported types.
    """
    # Keep the order of random draws: count, selection, then one uniform draw
    # per parameter, so results are reproducible under a fixed seed.
    option_count = random.randint(1, len(AUGMENTATION_OPTIONS))
    chosen_options = random.sample(AUGMENTATION_OPTIONS, option_count)

    entries = []
    for option in chosen_options:
        fields = []
        for index, param in enumerate(option.params):
            bounds = option.values[index]
            drawn = round(random.uniform(bounds[0], bounds[1]), 2)
            fields.append("{}: {}".format(param, option.format[index](drawn)))
        entries.append("data_augmentation_options { %s { %s }}" %
                       (option.name, '\n'.join(fields)))
    text = '\n'.join(entries)

    augmentation = train_pb2.TrainConfig()
    text_format.Merge(text, augmentation)

    if isinstance(config, pipeline_pb2.TrainEvalPipelineConfig):
        config.train_config.MergeFrom(augmentation)
    elif isinstance(config, train_pb2.TrainConfig):
        # A bare TrainConfig is itself the merge target. `train_config` is a
        # field of the pipeline config, not of TrainConfig, so the previous
        # `config.train_config.MergeFrom(...)` failed on this branch.
        config.MergeFrom(augmentation)
    elif isinstance(config, dict):
        train_config = config['train_config']
        train_config.MergeFrom(augmentation)
        config['train_config'] = train_config
    else:
        raise AttributeError(
            "Wrong config format. `TrainEvalPipelineConfig` or `TrainConfig` required"
        )
    return config
def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    offline_eval_input_config_path=""):
    """Loads whichever text-format config files were given into protos.

    Each non-empty path is read and merged into a fresh message of the type
    listed below. Empty paths are skipped and produce no entry.

    Args:
        model_config_path: Path to a model_pb2.DetectionModel text proto.
        train_config_path: Path to a train_pb2.TrainConfig text proto.
        train_input_config_path: Path to an input_reader_pb2.InputReader text
            proto.
        eval_config_path: Path to an eval_pb2.EvalConfig text proto.
        eval_input_config_path: Path to an input_reader_pb2.InputReader text
            proto.
        offline_eval_input_config_path: Path to an input_reader_pb2.InputReader
            text proto.

    Returns:
        Dictionary of configuration objects. Possible keys are `model`,
        `train_config`, `train_input_config`, `eval_config`,
        `eval_input_config` and `offline_eval_input_config`; a key is present
        only when its path argument was non-empty.
    """

    def _read_into(message, path):
        # Merge the file's full text-format contents into the given message.
        with tf.gfile.GFile(path, "r") as config_file:
            text_format.Merge(config_file.read(), message)
        return message

    loaded = {}
    if model_config_path:
        loaded["model"] = _read_into(model_pb2.DetectionModel(),
                                     model_config_path)
    if train_config_path:
        loaded["train_config"] = _read_into(train_pb2.TrainConfig(),
                                            train_config_path)
    if train_input_config_path:
        loaded["train_input_config"] = _read_into(
            input_reader_pb2.InputReader(), train_input_config_path)
    if eval_config_path:
        loaded["eval_config"] = _read_into(eval_pb2.EvalConfig(),
                                           eval_config_path)
    if eval_input_config_path:
        loaded["eval_input_config"] = _read_into(
            input_reader_pb2.InputReader(), eval_input_config_path)
    if offline_eval_input_config_path:
        loaded["offline_eval_input_config"] = _read_into(
            input_reader_pb2.InputReader(), offline_eval_input_config_path)
    return loaded
def test_get_configs_from_multiple_files(self):
    """Tests that proto configs can be read from multiple files.

    Writes one config file per proto type into a temporary directory, loads
    them back through `config_util.get_configs_from_multiple_files` and checks
    that every loaded proto equals the one that was written.
    """
    print(
        '\n=========================================================================='
    )
    print('test_get_configs_from_multiple_files')

    # Use the scratch directory supplied by the test case instead of a
    # hard-coded, machine-specific path so the test runs on any host.
    temp_dir = self.get_temp_dir()

    # Write model config file.
    model_config_path = os.path.join(temp_dir, "model.config")
    model = model_pb2.DetectionModel()
    model.faster_rcnn.num_classes = 10
    _write_config(model, model_config_path)

    # Write train config file.
    train_config_path = os.path.join(temp_dir, "train.config")
    train_config = train_pb2.TrainConfig()
    train_config.batch_size = 32
    _write_config(train_config, train_config_path)

    # Write train input config file.
    train_input_config_path = os.path.join(temp_dir, "train_input.config")
    train_input_config = input_reader_pb2.InputReader()
    train_input_config.label_map_path = "path/to/label_map"
    _write_config(train_input_config, train_input_config_path)

    # Write eval config file.
    eval_config_path = os.path.join(temp_dir, "eval.config")
    eval_config = eval_pb2.EvalConfig()
    eval_config.num_examples = 20
    _write_config(eval_config, eval_config_path)

    # Write eval input config file.
    eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
    eval_input_config = input_reader_pb2.InputReader()
    eval_input_config.label_map_path = "path/to/another/label_map"
    _write_config(eval_input_config, eval_input_config_path)

    configs = config_util.get_configs_from_multiple_files(
        model_config_path=model_config_path,
        train_config_path=train_config_path,
        train_input_config_path=train_input_config_path,
        eval_config_path=eval_config_path,
        eval_input_config_path=eval_input_config_path)

    self.assertProtoEquals(model, configs["model"])
    self.assertProtoEquals(train_config, configs["train_config"])
    self.assertProtoEquals(train_input_config, configs["train_input_config"])
    self.assertProtoEquals(eval_config, configs["eval_config"])
    self.assertProtoEquals(eval_input_config, configs["eval_input_config"])
def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
    """Runs `trainer.train` for two steps with multiclass scores enabled.

    The train config (Adam optimizer, brightness and contrast augmentation,
    `num_steps: 2`, `use_multiclass_scores: true`) is parsed from text format
    and training writes into the test's temporary directory.
    """
    banner = (
        '\n========================================================================'
    )
    print(banner)
    print('test_configure_trainer_with_multiclass_scores_and_train_two_steps')

    # Text-format TrainConfig, split into optimizer / augmentation / run
    # settings; the pieces concatenate to a single text proto.
    config_text = (
        ' optimizer { adam_optimizer { learning_rate {'
        ' constant_learning_rate { learning_rate: 0.01 } } } }'
        ' data_augmentation_options {'
        ' random_adjust_brightness { max_delta: 0.2 } }'
        ' data_augmentation_options {'
        ' random_adjust_contrast { min_delta: 0.7 max_delta: 1.1 } }'
        ' num_steps: 2 use_multiclass_scores: true '
    )
    parsed_config = train_pb2.TrainConfig()
    text_format.Merge(config_text, parsed_config)

    output_dir = self.get_temp_dir()

    # Single clone, single worker, no parameter servers, everything on CPU.
    trainer.train(
        create_tensor_dict_fn=get_input_function,
        create_model_fn=FakeDetectionModel,
        train_config=parsed_config,
        master='',
        task=0,
        num_clones=1,
        worker_replicas=1,
        clone_on_cpu=True,
        ps_tasks=0,
        worker_job_name='worker',
        is_chief=True,
        train_dir=output_dir,
    )
def get_configs_from_multiple_files():
    """Reads the train, model and input configs from flag-specified files.

    Each file is read as text and merged into a fresh proto message:
        FLAGS.train_config_path -> train_pb2.TrainConfig
        FLAGS.model_config_path -> model_pb2.DetectionModel
        FLAGS.input_config_path -> input_reader_pb2.InputReader

    Returns:
        The tuple `(model_config, train_config, input_config)`.
    """
    # Training options.
    train_config = train_pb2.TrainConfig()
    with tf.gfile.GFile(FLAGS.train_config_path, 'r') as train_file:
        train_text = train_file.read()
        text_format.Merge(train_text, train_config)

    # Detection model definition.
    model_config = model_pb2.DetectionModel()
    with tf.gfile.GFile(FLAGS.model_config_path, 'r') as model_file:
        model_text = model_file.read()
        text_format.Merge(model_text, model_config)

    # Input reader settings.
    input_config = input_reader_pb2.InputReader()
    with tf.gfile.GFile(FLAGS.input_config_path, 'r') as input_file:
        input_text = input_file.read()
        text_format.Merge(input_text, input_config)

    return model_config, train_config, input_config