Example #1
0
    def _train(self):
        """Run a training pass for the model described by ``self.pipeline``."""
        tf.disable_eager_execution()
        # Single-machine setup: one worker, no parameter servers, acting chief.
        num_ps_tasks = 0
        num_worker_replicas = 1
        job_name = 'lonely_worker'
        task_index = 0
        chief = True
        session_master = ''
        rewriter_fn = None
        # Load the pipeline proto and split it into its component configs.
        configs = create_configs_from_pipeline_proto(self.pipeline)
        train_config = configs['train_config']
        # Build the detection model factory from the model config.
        model_fn = functools.partial(
            model_builder.build,
            model_config=configs['model'],
            is_training=True)

        def _next_batch(reader_config):
            iterator = dataset_builder.make_initializable_iterator(
                dataset_builder.build(reader_config))
            return iterator.get_next()

        create_input_dict_fn = functools.partial(
            _next_batch, configs['train_input_config'])
        if 'graph_rewriter_config' in configs:
            rewriter_fn = graph_rewriter_builder.build(
                configs['graph_rewriter_config'], is_training=True)
        # Train with the freshly built model and input pipeline.
        trainer.train(create_input_dict_fn, model_fn, train_config,
                      session_master, task_index, 1, num_worker_replicas,
                      False, num_ps_tasks, job_name, chief,
                      str(self._out_folder), graph_hook_fn=rewriter_fn)
Example #2
0
    def test_create_configs_from_pipeline_proto(self):
        """Tests creating configs dictionary from pipeline proto.

        Builds a pipeline proto with representative train/eval settings and
        checks that each sub-proto is copied verbatim into the configs dict
        under its expected key.
        """
        # Debug print banners removed: test frameworks already report the
        # running test's name, and the prints only cluttered the output.
        pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
        pipeline_config.model.faster_rcnn.num_classes = 10
        pipeline_config.train_config.batch_size = 32
        pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
        pipeline_config.eval_config.num_examples = 20
        pipeline_config.eval_input_reader.queue_capacity = 100

        configs = config_util.create_configs_from_pipeline_proto(
            pipeline_config)
        self.assertProtoEquals(pipeline_config.model, configs["model"])
        self.assertProtoEquals(pipeline_config.train_config,
                               configs["train_config"])
        self.assertProtoEquals(pipeline_config.train_input_reader,
                               configs["train_input_config"])
        self.assertProtoEquals(pipeline_config.eval_config,
                               configs["eval_config"])
        self.assertProtoEquals(pipeline_config.eval_input_reader,
                               configs["eval_input_config"])
Example #3
0
  def test_create_configs_from_pipeline_proto(self):
    """Tests creating configs dictionary from pipeline proto."""

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 10
    pipeline_config.train_config.batch_size = 32
    pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
    pipeline_config.eval_config.num_examples = 20
    pipeline_config.eval_input_reader.queue_capacity = 100

    configs = config_util.create_configs_from_pipeline_proto(pipeline_config)
    # Each sub-proto must land under its expected key, unmodified.
    expected_pairs = (
        (pipeline_config.model, "model"),
        (pipeline_config.train_config, "train_config"),
        (pipeline_config.train_input_reader, "train_input_config"),
        (pipeline_config.eval_config, "eval_config"),
        (pipeline_config.eval_input_reader, "eval_input_config"),
    )
    for sub_proto, config_key in expected_pairs:
      self.assertProtoEquals(sub_proto, configs[config_key])
Example #4
0
 def _set_config_paths(self):
     """Point the pipeline config at this run's records, labels and checkpoint.

     Rewrites the input paths in the loaded pipeline configs so that the
     train reader consumes the training TFRecord and the eval reader
     consumes the validation TFRecord, updates the label map path and the
     fine-tune checkpoint, then rebuilds ``self._pipeline``.
     """
     configs = create_configs_from_pipeline_proto(self.pipeline)
     # BUG FIX: the record files were swapped — the train input reader was
     # given self._val_record_file and the eval reader self._train_record_file.
     update_input_reader_config(configs,
                                key_name="train_input_config",
                                input_name=None,
                                field_name="input_path",
                                value=str(self._train_record_file),
                                path_updater=_update_tf_record_input_path)
     update_input_reader_config(configs,
                                key_name="eval_input_configs",
                                input_name=None,
                                field_name="input_path",
                                value=str(self._val_record_file),
                                path_updater=_update_tf_record_input_path)
     update_dict = {
         "label_map_path": str(self._labels_map_file),
         "train_config.fine_tune_checkpoint": str(self._checkpoint_model_folder.joinpath("model.ckpt"))
     }
     configs = merge_external_params_with_configs(configs, kwargs_dict=update_dict)
     self._pipeline = create_pipeline_proto_from_configs(configs)
def get_configs_from_pipeline_string(pipeline_config_string):
    """Parse a text-format pipeline proto and split it into a configs dict.

    Args:
        pipeline_config_string: text-format serialized
            TrainEvalPipelineConfig proto.

    Returns:
        The dict produced by config_util.create_configs_from_pipeline_proto.
    """
    proto = pipeline_pb2.TrainEvalPipelineConfig()
    text_format.Merge(pipeline_config_string, proto)
    return config_util.create_configs_from_pipeline_proto(proto)