def _train(self):
    """Train the detection model from ``self.pipeline`` as one local worker.

    Builds the model factory, the training-input factory and (when the
    pipeline defines one) a graph-rewriter hook, then hands them to
    ``trainer.train`` with single-worker settings. Training output goes to
    ``str(self._out_folder)``.
    """
    # The legacy trainer is graph-based, so eager execution is switched off
    # before anything else is built.
    tf.disable_eager_execution()

    configs = create_configs_from_pipeline_proto(self.pipeline)
    train_config = configs['train_config']

    # Deferred model construction: trainer.train calls this itself.
    build_model = functools.partial(
        model_builder.build,
        model_config=configs['model'],
        is_training=True,
    )

    def _next_batch(reader_config):
        """Return the next-element tensors of the dataset built from ``reader_config``."""
        dataset = dataset_builder.build(reader_config)
        iterator = dataset_builder.make_initializable_iterator(dataset)
        return iterator.get_next()

    build_inputs = functools.partial(_next_batch, configs['train_input_config'])

    graph_hook = None
    if 'graph_rewriter_config' in configs:
        graph_hook = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=True)

    # Positional arguments for a single, chief, local worker with no
    # parameter servers. The unnamed 1 / False are presumably num_clones /
    # clone_on_cpu in the legacy trainer.train signature; confirm.
    trainer.train(
        build_inputs,
        build_model,
        train_config,
        '',               # master
        0,                # task
        1,
        1,                # worker_replicas
        False,
        0,                # ps_tasks
        'lonely_worker',  # worker_job_name
        True,             # is_chief
        str(self._out_folder),
        graph_hook_fn=graph_hook,
    )
def test_create_configs_from_pipeline_proto(self):
    """Tests creating configs dictionary from pipeline proto.

    Populates one field in each pipeline section, converts the proto with
    ``config_util.create_configs_from_pipeline_proto`` and checks that every
    section appears unchanged under its expected dictionary key.
    """
    print(
        '\n=================================================================='
    )
    print('test_create_configs_from_pipeline_proto')

    config = pipeline_pb2.TrainEvalPipelineConfig()
    config.model.faster_rcnn.num_classes = 10
    config.train_config.batch_size = 32
    config.train_input_reader.label_map_path = "path/to/label_map"
    config.eval_config.num_examples = 20
    config.eval_input_reader.queue_capacity = 100

    result = config_util.create_configs_from_pipeline_proto(config)

    # (expected sub-message, key it should be stored under), checked in order.
    expectations = (
        (config.model, "model"),
        (config.train_config, "train_config"),
        (config.train_input_reader, "train_input_config"),
        (config.eval_config, "eval_config"),
        (config.eval_input_reader, "eval_input_config"),
    )
    for expected_proto, key in expectations:
        self.assertProtoEquals(expected_proto, result[key])
def test_create_configs_from_pipeline_proto(self):
    """Tests creating configs dictionary from pipeline proto.

    Each populated section of the pipeline proto must be found, equal, under
    its key in the dictionary returned by
    ``config_util.create_configs_from_pipeline_proto``.
    """
    proto = pipeline_pb2.TrainEvalPipelineConfig()
    proto.model.faster_rcnn.num_classes = 10
    proto.train_config.batch_size = 32
    proto.train_input_reader.label_map_path = "path/to/label_map"
    proto.eval_config.num_examples = 20
    proto.eval_input_reader.queue_capacity = 100

    configs = config_util.create_configs_from_pipeline_proto(proto)

    # Key in the configs dict -> proto section expected there. Dict insertion
    # order keeps the assertions in the same sequence as before.
    expected_sections = {
        "model": proto.model,
        "train_config": proto.train_config,
        "train_input_config": proto.train_input_reader,
        "eval_config": proto.eval_config,
        "eval_input_config": proto.eval_input_reader,
    }
    for key, section in expected_sections.items():
        self.assertProtoEquals(section, configs[key])
def _set_config_paths(self):
    """Rewrite the pipeline config to use this instance's data and checkpoint paths.

    Starting from ``self.pipeline``, this:
      * points the ``train_input_config`` reader's ``input_path`` at
        ``self._train_record_file``;
      * points the ``eval_input_configs`` reader's ``input_path`` at
        ``self._val_record_file``;
      * applies the ``label_map_path`` override with ``self._labels_map_file``;
      * sets ``train_config.fine_tune_checkpoint`` to ``model.ckpt`` inside
        ``self._checkpoint_model_folder``.
    The resulting proto is stored in ``self._pipeline``.
    """
    configs = create_configs_from_pipeline_proto(self.pipeline)

    # Training must read the training records and evaluation the validation
    # records. Crossing them would train on held-out data and make the
    # evaluation metrics meaningless.
    update_input_reader_config(
        configs,
        key_name="train_input_config",
        input_name=None,
        field_name="input_path",
        value=str(self._train_record_file),
        path_updater=_update_tf_record_input_path,
    )
    update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name=None,
        field_name="input_path",
        value=str(self._val_record_file),
        path_updater=_update_tf_record_input_path,
    )

    update_dict = {
        "label_map_path": str(self._labels_map_file),
        "train_config.fine_tune_checkpoint": str(
            self._checkpoint_model_folder.joinpath("model.ckpt")),
    }
    configs = merge_external_params_with_configs(configs, kwargs_dict=update_dict)

    # NOTE(review): reads ``self.pipeline`` but writes ``self._pipeline``;
    # presumably ``pipeline`` is a property backed by ``_pipeline``. Confirm.
    self._pipeline = create_pipeline_proto_from_configs(configs)
def get_configs_from_pipeline_string(pipeline_config_string):
    """Parse a text-format pipeline config and split it into a configs dict.

    Args:
      pipeline_config_string: text-format protobuf describing a
        ``pipeline_pb2.TrainEvalPipelineConfig``.

    Returns:
      The result of ``config_util.create_configs_from_pipeline_proto`` for the
      parsed pipeline.
    """
    empty_config = pipeline_pb2.TrainEvalPipelineConfig()
    # text_format.Merge fills the message in place and returns that same message.
    parsed_config = text_format.Merge(pipeline_config_string, empty_config)
    return config_util.create_configs_from_pipeline_proto(parsed_config)