def do_train(cfg):
    """
    Build every training component from ``cfg`` and run the training loop.

    This variant always asks the checkpointer to resume. When a checkpoint is
    present in ``cfg.train.output_dir`` training continues from the iteration
    after the one that was saved; otherwise ``cfg.train.init_checkpoint`` is
    loaded and training starts at iteration 0.

    Args:
        cfg: a config object. This function reads:
            model, optimizer, dataloader.train, lr_multiplier: passed to
                ``instantiate``
            train.device, train.ddp, train.amp.enabled, train.output_dir,
            train.checkpointer, train.eval_period, train.log_period,
            train.max_iter, train.init_checkpoint
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # The optimizer config needs the live model before it can be instantiated.
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(
        model, train_loader, optim
    )
    # The trainer is registered with the checkpointer, so its state is saved
    # and loaded together with the model and optimizer.
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    # Checkpointing and metric writing are only set up on the main process;
    # other ranks pass None in those slots.
    # NOTE(review): presumably register_hooks ignores None entries - confirm.
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=True)
    if checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus
        # we start at the next iteration instead of restarting the count at 0.
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
def do_train(args, cfg):
    """
    Assemble model, optimizer, data loader, trainer and hooks from a config,
    then run training (optionally resuming from the last checkpoint).

    Args:
        args: command-line arguments; only ``args.resume`` is read here.
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    train_cfg = cfg.train

    net = instantiate(cfg.model)
    logging.getLogger("detectron2").info("Model:\n{}".format(net))
    net.to(train_cfg.device)

    # The optimizer config is completed with the live model before instantiation.
    cfg.optimizer.params.model = net
    optimizer = instantiate(cfg.optimizer)
    data_loader = instantiate(cfg.dataloader.train)

    net = create_ddp_model(net, **train_cfg.ddp)

    if train_cfg.amp.enabled:
        trainer_cls = AMPTrainer
    else:
        trainer_cls = SimpleTrainer
    trainer = trainer_cls(net, data_loader, optimizer)

    ckpt = DetectionCheckpointer(
        net, train_cfg.output_dir, optimizer=optimizer, trainer=trainer
    )

    # Hooks are collected in registration order. The checkpoint and writer
    # slots hold None on every process except the main one.
    hook_list = [
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
    ]
    if comm.is_main_process():
        hook_list.append(hooks.PeriodicCheckpointer(ckpt, **train_cfg.checkpointer))
    else:
        hook_list.append(None)
    hook_list.append(hooks.EvalHook(train_cfg.eval_period, lambda: do_test(cfg, net)))
    if comm.is_main_process():
        writers = default_writers(train_cfg.output_dir, train_cfg.max_iter)
        hook_list.append(hooks.PeriodicWriter(writers, period=train_cfg.log_period))
    else:
        hook_list.append(None)
    trainer.register_hooks(hook_list)

    ckpt.resume_or_load(train_cfg.init_checkpoint, resume=args.resume)
    # A saved checkpoint records the iteration that has just completed, so a
    # resumed run picks up at the following one; a fresh run begins at 0.
    first_iter = trainer.iter + 1 if args.resume and ckpt.has_checkpoint() else 0
    trainer.train(first_iter, train_cfg.max_iter)
def main(
    cfg,
    output_dir,
    runner=None,
    eval_only=False,
    # NOTE: resume is left on by default for runs on the cluster
    resume=True,
):
    """
    Train (or only evaluate) the model built by ``runner`` and report metrics.

    Args:
        cfg: config object; reads ``MODEL.WEIGHTS``, ``MODEL.DEVICE``,
            ``MODEL.DDP_FP16_GRAD_COMPRESS``,
            ``MODEL.DDP_FIND_UNUSED_PARAMETERS`` and ``OUTPUT_DIR``.
        output_dir: directory handed to ``setup_after_launch`` and, in
            eval-only mode, used as the checkpointer's save directory.
        runner: object providing ``build_model``, ``build_checkpointer``,
            ``do_train`` and ``do_test``.
        eval_only: when true, load weights and evaluate without training.
        resume: whether to resume from an existing checkpoint.

    Returns:
        dict with keys ``"accuracy"`` and ``"metrics"`` (both the test metrics)
        and ``"model_configs"`` (empty dict in eval-only mode, otherwise the
        result of ``dump_trained_model_configs``).
    """
    setup_after_launch(cfg, output_dir, runner)

    model = runner.build_model(cfg)
    logger.info("Model:\n{}".format(model))

    if eval_only:
        ckpt_manager = runner.build_checkpointer(cfg, model, save_dir=output_dir)
        # resume_or_load() skips the additional checkpointables (EMA states,
        # for example), which may be unwanted, so it is used only when we are
        # resuming and a checkpoint actually exists. Otherwise do a plain load.
        if not resume or not ckpt_manager.has_checkpoint():
            loaded = ckpt_manager.load(cfg.MODEL.WEIGHTS)
        else:
            loaded = ckpt_manager.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
        iteration = loaded.get("iteration", None)

        model.eval()
        eval_metrics = runner.do_test(cfg, model, train_iter=iteration)
        print_metrics_table(eval_metrics)
        return {
            "accuracy": eval_metrics,
            "model_configs": {},
            "metrics": eval_metrics,
        }

    ddp_kwargs = {
        "fp16_compression": cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
        "device_ids": None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
        "broadcast_buffers": False,
        "find_unused_parameters": cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
    }
    model = create_ddp_model(model, **ddp_kwargs)

    trained_cfgs = runner.do_train(cfg, model, resume=resume)

    final_metrics = runner.do_test(cfg, model)
    print_metrics_table(final_metrics)

    # Write out the config files of the trained models.
    dumped_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs)
    return {
        # consumed by e2e_workflow
        "accuracy": final_metrics,
        # consumed by unit_workflow
        "model_configs": dumped_configs,
        "metrics": final_metrics,
    }
def main(args):
    """
    Load the config named by ``args``, then either evaluate or train.

    Args:
        args: command-line arguments. Reads ``config_file`` (path given to
            ``LazyConfig.load``), ``opts`` (overrides applied on top of the
            loaded config) and ``eval_only``; the whole object is also passed
            to ``default_setup`` and ``do_train``.
    """
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)

    if args.eval_only:
        model = instantiate(cfg.model)
        # Place the model on the configured device before wrapping it, the
        # same way the training path does; otherwise it stays wherever
        # instantiate() created it.
        model.to(cfg.train.device)
        model = create_ddp_model(model)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
        print(do_test(cfg, model))
    else:
        do_train(args, cfg)