def main():
    experiment = comet_ml.Experiment()
    cfg = setup()

    for d in ["train", "val"]:
        DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("balloon/" + d))
        MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
    balloon_metadata = MetadataCatalog.get("balloon_train")

    # Wrap the Detectron Default Trainer
    trainer = CometDefaultTrainer(cfg, experiment)
    trainer.resume_or_load(resume=False)

    # Register Hook to compute metrics using an Evaluator Object
    trainer.register_hooks(
        [hooks.EvalHook(10, lambda: trainer.evaluate_metrics(cfg, trainer.model))]
    )

    # Register Hook to compute eval loss
    trainer.register_hooks(
        [hooks.EvalHook(10, lambda: trainer.evaluate_loss(cfg, trainer.model))]
    )
    trainer.train()

    # Evaluate Model Predictions
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
    predictor = DefaultPredictor(cfg)
    log_predictions(predictor, get_balloon_dicts("balloon/val"), experiment)
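# --- A minimal sketch (an assumption, not part of the snippet above) of the
# get_balloon_dicts loader it relies on, following the shape of the detectron2
# balloon tutorial; the VIA JSON filename and annotation keys come from that
# tutorial, not from this code.
import json
import os

import cv2
from detectron2.structures import BoxMode

def get_balloon_dicts(img_dir):
    with open(os.path.join(img_dir, "via_region_data.json")) as f:
        imgs_anns = json.load(f)
    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]
        objs = []
        for anno in v["regions"].values():
            px = anno["shape_attributes"]["all_points_x"]
            py = anno["shape_attributes"]["all_points_y"]
            poly = [coord for xy in zip(px, py) for coord in xy]
            objs.append({
                "bbox": [min(px), min(py), max(px), max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,  # single "balloon" class
            })
        dataset_dicts.append({
            "file_name": filename,
            "image_id": idx,
            "height": height,
            "width": width,
            "annotations": objs,
        })
    return dataset_dicts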
def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        # Save the evaluation results
        pd.DataFrame(res).to_csv(f'{cfg.OUTPUT_DIR}/eval.csv')
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    trainer.register_hooks(
        [hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
    )
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
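# --- Hypothetical sketch of the eval_and_save method the hook above calls (it is
# not defined in this snippet). An EvalHook test function must return a (possibly
# nested) dict of scalar metrics, which the hook flattens and logs to the
# trainer's EventStorage; saving to CSV mirrors the eval-only branch above.
class Trainer(DefaultTrainer):  # assumed subclass; only the hook callable is shown
    def eval_and_save(self, cfg, model):
        res = self.test(cfg, model)  # DefaultTrainer.test runs the configured evaluators
        if comm.is_main_process():
            pd.DataFrame(res).to_csv(f'{cfg.OUTPUT_DIR}/eval.csv')
        return res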
def build_hooks(self):
    cfg = self.cfg.clone()
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

    ret = [
        hooks.IterationTimer(),
        hooks.LRScheduler(self.optimizer, self.scheduler),
        hooks.PreciseBN(
            # Run at the same freq as (but before) evaluation.
            cfg.TEST.EVAL_PERIOD,
            self.model,
            # Build a new data loader to not affect training
            self.build_train_loader(cfg),
            cfg.TEST.PRECISE_BN.NUM_ITER,
        )
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
        else None,
    ]

    # Do PreciseBN before checkpointer, because it updates the model and needs to
    # be saved by the checkpointer. This is not always the best: if checkpointing
    # has a different frequency, some checkpoints may have more precise statistics
    # than others.
    if comm.is_main_process():
        ret.append(
            hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)
        )

    def test_and_save_results_student():
        self._last_eval_results_student = self.test(self.cfg, self.model)
        _last_eval_results_student = {
            k + "_student": self._last_eval_results_student[k]
            for k in self._last_eval_results_student.keys()
        }
        return _last_eval_results_student

    def test_and_save_results_teacher():
        self._last_eval_results_teacher = self.test(self.cfg, self.model_teacher)
        return self._last_eval_results_teacher

    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results_student))
    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results_teacher))

    if comm.is_main_process():
        # run writers in the end, so that evaluation metrics are written
        ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
    return ret
def main(args):
    setup_logger_global_cfg_global_textlogger(args, tl_textdir=args.tl_textdir)

    cfg = setup(args)
    cfg = D2Utils.cfg_merge_from_easydict(cfg, global_cfg)

    if comm.is_main_process():
        path = os.path.join(cfg.OUTPUT_DIR, "config.yaml")
        with open(path, "w") as f:
            f.write(cfg.dump())

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def main(args):
    cfg = setup(args)  # needed by the training path; the eval-only path overrides it below

    if args.eval_only:
        # TODO(jen)
        # model = Trainer.build_model(cfg)
        from models.rpn import ProposalNetwork

        rpn = ProposalNetwork('cuda')
        cfg = rpn.cfg
        model = rpn.predictor.model
        cfg.DATASETS.TRAIN = ('lvis_v0.5_train', )
        cfg.DATASETS.TEST = ('lvis_v0.5_val', )
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def main(args): print("enter main") cfg = setup(args) print("setup cfg ") register_lofar_datasets(cfg) print("register lofar datasets") print(f"output dir is {cfg.OUTPUT_DIR} datasets") if args.eval_only: model = LOFARTrainer.build_model(cfg) DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( cfg.MODEL.WEIGHTS, resume=args.resume) res = LOFARTrainer.test(cfg, model) if comm.is_main_process(): verify_results(cfg, res) if cfg.TEST.AUG.ENABLED: res.update(LOFARTrainer.test_with_TTA(cfg, model)) return res """ If you'd like to do anything fancier than the standard training logic, consider writing your own training loop or subclassing the trainer. """ trainer = LOFARTrainer(cfg) trainer.resume_or_load(resume=args.resume) print("set up trainer") if cfg.TEST.AUG.ENABLED: trainer.register_hooks([ hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model)) ]) return trainer.train()
def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)

    from detectron2.lib.utils.net import convert_bn2affine_model

    trainer.model = convert_bn2affine_model(trainer.model)
    trainer.model.to(cfg.MODEL.DEVICE)  # move the model; the trainer itself has no .cuda()

    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def main(args):
    cfg = setup(args)

    if args.test_img:
        results = test_on_img(cfg, args.test_img)
        print(results)
        return results

    if args.eval_only:
        model = Trainer.build_model(cfg)
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        evaluators = [
            Trainer.build_evaluator(cfg, name) for name in cfg.DATASETS.TEST
        ]
        res = Trainer.test(cfg, model, evaluators)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
def main(args):
    cfg = setup(args)

    # TODO REGISTRATION OF DATASET
    register_isprs_train_instance(cfg)
    register_isprs_val_instance(cfg)
    register_isprs_test_instance(cfg)
    register_isprs_train_panoptic(cfg)
    register_isprs_val_panoptic(cfg)
    register_isprs_test_panoptic(cfg)

    if args.eval_only:
        model = ISPRSTrainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = ISPRSTrainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(ISPRSTrainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = ISPRSTrainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def main(args):
    cfg = setup(args)

    # CSD: set up wandb (for tracking visualizations)
    if comm.is_main_process() and cfg.USE_WANDB:
        wandb.login()
        wandb.init(project=cfg.WANDB_PROJECT_NAME, config=cfg, sync_tensorboard=True)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def main(args):
    cfg = setup(args)
    # Load cfg as python dict
    config = load_yaml(args.config_file)

    # If evaluation
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
    # If training
    else:
        trainer = Trainer(cfg)
        # Load model weights (if specified)
        trainer.resume_or_load(resume=args.resume)
        if cfg.TEST.AUG.ENABLED:
            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
            )
        # Will evaluation be done at end of training?
        res = trainer.train()

    return res
def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    if args.tpu:
        import torch_xla.core.xla_model as xm

        _ = xm.xla_device()

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def main(args):
    cfg = setup(args)

    if args.eval_only:
        cfg.DATASETS.TEST = ("usopen_nadal_test",)
        cfg.freeze()
        default_setup(cfg, args)
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=True
        )
        evaluator = Trainer.build_evaluator(
            cfg,
            dataset_name=cfg.DATASETS.TEST[0],
            output_folder=os.path.join(cfg.OUTPUT_DIR, "inference", "test"),
        )
        res = Trainer.test(cfg, model, evaluator)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    cfg.freeze()
    default_setup(cfg, args)
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
def main(args):
    print('setting up configs')
    cfg = setup(args)
    print('configs loaded')

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    print('building model from configs')
    trainer = Trainer(cfg)
    print('model built')
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
def main(args):
    # TODO: remove this hardcoded stuff
    root = '/home/arash/Software/datasets/VOCdevkit/'
    register_all_pascal_voc_org(root)

    cfg = setup(args)
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def main(args):
    # Gets and sets up config
    cfg = setup(args)

    # If eval only
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def do_train(cfg):
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    trainer.register_hooks([
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
        hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
        if comm.is_main_process()
        else None,
        hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
        hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        )
        if comm.is_main_process()
        else None,
    ])

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=True)
    start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
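# --- The do_test referenced above is not shown in this snippet; in detectron2's
# reference lazyconfig training script it looks roughly like the sketch below
# (assuming the config defines dataloader.test and dataloader.evaluator):
from detectron2.evaluation import inference_on_dataset, print_csv_format

def do_test(cfg, model):
    if "evaluator" in cfg.dataloader:
        # run the configured evaluator over the test loader and return its metrics dict,
        # which EvalHook then flattens into scalar entries for the writers
        ret = inference_on_dataset(
            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
        )
        print_csv_format(ret)
        return ret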
def main(args):
    register_coco_instances("object365_train", {},
                            "data/object_365/objects365_train.json",
                            "data/object_365/train")
    register_coco_instances("object365_val", {},
                            "data/object_365/objects365_val.json",
                            "data/object_365/val")

    cfg = setup(args)
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg, linear_eval=args.linear_eval, mini=args.mini, mini_ratio=args.mini_ratio)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def main(args):
    cfg = setup(args)

    from thop import profile

    if args.eval_only:
        model = Trainer.build_model(cfg)
        # input_size = (1, 3, 288, 800)
        # input_size = torch.zeros(*input_size)
        # flops, params = profile(model, inputs=(input_size, ))
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def test_writer_hooks(self):
    model = _SimpleModel(sleep_sec=0.1)
    trainer = SimpleTrainer(
        model, self._data_loader("cpu"), torch.optim.SGD(model.parameters(), 0.1)
    )

    max_iter = 50

    with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
        json_file = os.path.join(d, "metrics.json")
        writers = [CommonMetricPrinter(max_iter), JSONWriter(json_file)]

        trainer.register_hooks([
            hooks.EvalHook(0, lambda: {"metric": 100}),
            hooks.PeriodicWriter(writers),
        ])
        with self.assertLogs(writers[0].logger) as logs:
            trainer.train(0, max_iter)

        with open(json_file, "r") as f:
            data = [json.loads(line.strip()) for line in f]
            self.assertEqual([x["iteration"] for x in data], [19, 39, 49, 50])
            # the eval metric is in the last line with iter 50
            self.assertIn("metric", data[-1], "Eval metric must be in last line of JSON!")

        # test logged messages from CommonMetricPrinter
        self.assertEqual(len(logs.output), 3)
        for log, iter in zip(logs.output, [19, 39, 49]):
            self.assertIn(f"iter: {iter}", log)
        self.assertIn("eta: 0:00:00", logs.output[-1], "Last ETA must be 0!")
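# Reading aid for the assertions above: with max_iter=50 and PeriodicWriter's
# default period of 20, writes land on iterations 19 and 39 plus the final
# iteration 49; the extra JSON entry at iteration 50 comes from EvalHook's final
# evaluation, which runs after the last training step.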
def main(args):
    cfg = setup(args)

    # disable strict kwargs checking: allow one to specify path handle
    # hints through kwargs, like timeout in DP evaluation
    PathManager.set_strict_kwargs_checking(False)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
def main(args):
    cfg = setup(args)

    # Select which trainer to use
    Trainer = MetaReweightTrainer if cfg.get("META_REWEIGHT", False) else PlainTrainer

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def test_best_checkpointer(self):
    model = _SimpleModel()
    dataloader = self._data_loader("cpu")
    opt = torch.optim.SGD(model.parameters(), 0.1)
    metric_name = "metric"
    total_iter = 40
    test_period = 10
    test_cases = [
        ("max", iter([0.3, 0.4, 0.35, 0.5]), 3),
        ("min", iter([1.0, 0.8, 0.9, 0.9]), 2),
        ("min", iter([math.nan, 0.8, 0.9, 0.9]), 1),
    ]
    for mode, metrics, call_count in test_cases:
        trainer = SimpleTrainer(model, dataloader, opt)
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer)
            trainer.register_hooks([
                hooks.EvalHook(test_period, lambda: {metric_name: next(metrics)}),
                hooks.BestCheckpointer(test_period, checkpointer, metric_name, mode=mode),
            ])
            with mock.patch.object(checkpointer, "save") as mock_save_method:
                trainer.train(0, total_iter)
                self.assertEqual(mock_save_method.call_count, call_count)
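# Reading aid for the expected call counts above: BestCheckpointer saves whenever
# the tracked metric improves, the first finite value always counts as an
# improvement, and a NaN value is skipped. Hence "max" over [0.3, 0.4, 0.35, 0.5]
# saves 3 times, "min" over [1.0, 0.8, 0.9, 0.9] saves twice, and the NaN case once.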
def main(args):
    cfg = setup(args)
    # Load cfg as python dict
    config = load_yaml(args.config_file)

    # Setup wandb
    wandb.init(
        # Use exp name to resume run later on
        # id="cascade_df_resume",
        id=args.exp_name,
        project="piplup-od",
        # name="cascade_df_resume",
        name=args.exp_name,
        sync_tensorboard=True,
        config=config,
        # Resume making use of the same exp name
        resume=args.exp_name if args.resume else False,
        # dir=cfg.OUTPUT_DIR,
    )

    # Auto upload any checkpoints to wandb as they are written
    # wandb.save(os.path.join(cfg.OUTPUT_DIR, "*.pth"))

    # TODO: Visualize and log training examples and annotations
    # training_imgs = viz_data(cfg)
    # wandb.log({"training_examples": training_imgs})

    # If evaluation
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        # FIXME: TTA
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
    # If training
    else:
        trainer = Trainer(cfg)
        # Load model weights (if specified)
        trainer.resume_or_load(resume=args.resume)
        # FIXME: TTA
        if cfg.TEST.AUG.ENABLED:
            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
            )
        # Will evaluation be done at end of training?
        res = trainer.train()

    # TODO: Visualize and log predictions and groundtruth annotations
    pred_imgs = viz_preds(cfg)
    wandb.log({"prediction_examples": pred_imgs})

    return res
def build_hooks(self):
    """
    Build a list of default hooks, including timing, evaluation,
    checkpointing, lr scheduling, precise BN, writing events.

    Returns:
        list[HookBase]:
    """
    logger = logging.getLogger(__name__)
    cfg = self.cfg.clone()
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

    ret = [
        hooks.IterationTimer(),
        hooks.LRScheduler(self.optimizer, self.scheduler),
    ]

    if cfg.SOLVER.SWA.ENABLED:
        ret.append(
            additional_hooks.SWA(
                cfg.SOLVER.MAX_ITER,
                cfg.SOLVER.SWA.PERIOD,
                cfg.SOLVER.SWA.LR_START,
                cfg.SOLVER.SWA.ETA_MIN_LR,
                cfg.SOLVER.SWA.LR_SCHED,
            )
        )

    if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
        logger.info("Prepare precise BN dataset")
        ret.append(hooks.PreciseBN(
            # Run at the same freq as (but before) evaluation.
            cfg.TEST.EVAL_PERIOD,
            self.model,
            # Build a new data loader to not affect training
            self.build_train_loader(cfg),
            cfg.TEST.PRECISE_BN.NUM_ITER,
        ))

    # Do PreciseBN before checkpointer, because it updates the model and needs to
    # be saved by the checkpointer. This is not always the best: if checkpointing
    # has a different frequency, some checkpoints may have more precise statistics
    # than others.
    if comm.is_main_process():
        ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

    def test_and_save_results():
        self._last_eval_results = self.test(self.cfg, self.model)
        return self._last_eval_results

    # Do evaluation after checkpointer, because then if it fails,
    # we can use the saved checkpoint to debug.
    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

    if comm.is_main_process():
        # run writers in the end, so that evaluation metrics are written
        ret.append(hooks.PeriodicWriter(self.build_writers(), 20))
    return ret
def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    trainer.register_hooks([
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
        hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
        if comm.is_main_process()
        else None,
        hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
        hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        )
        if comm.is_main_process()
        else None,
    ])

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
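# --- A minimal driver for do_train above (a sketch following detectron2's
# lazyconfig_train_net.py entry point; the flag names come from
# default_argument_parser, not from this snippet):
from detectron2.config import LazyConfig
from detectron2.engine import default_argument_parser, launch

def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    do_train(args, cfg)

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )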
def train(cfg, args):
    # args: parsed command-line arguments providing the resume flag
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()
def build_hooks(self):
    """
    Build a list of default hooks, including timing, evaluation,
    checkpointing, lr scheduling, precise BN, writing events.

    Returns:
        list[HookBase]:
    """
    cfg = self.cfg.clone()
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

    ret = [
        hooks.IterationTimer(),
        hooks.LRScheduler(self.optimizer, self.scheduler),
        hooks.PreciseBN(
            # Run at the same freq as (but before) evaluation.
            cfg.TEST.EVAL_PERIOD,
            self.model,
            # Build a new data loader to not affect training
            self.build_train_loader(cfg),
            cfg.TEST.PRECISE_BN.NUM_ITER,
        )
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
        else None,
    ]

    # Do PreciseBN before checkpointer, because it updates the model and needs to
    # be saved by the checkpointer. This is not always the best: if checkpointing
    # has a different frequency, some checkpoints may have more precise statistics
    # than others.
    if comm.is_main_process():
        if cfg.SOLVER.CHECKPOINT_BY_EPOCH:
            ckpt_period = cfg.SOLVER.CHECKPOINT_PERIOD * self.iters_per_epoch
        else:
            ckpt_period = cfg.SOLVER.CHECKPOINT_PERIOD
        ret.append(
            MyPeriodicCheckpointer(
                self.checkpointer,
                ckpt_period,
                max_to_keep=cfg.SOLVER.get("NUM_CKPT_KEEP", 5),
                iters_per_epoch=self.iters_per_epoch,
            )
        )

    def test_and_save_results():
        self._last_eval_results = self.test(self.cfg, self.model)
        return self._last_eval_results

    # Do evaluation after checkpointer, because then if it fails,
    # we can use the saved checkpoint to debug.
    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

    if comm.is_main_process():
        # run writers in the end, so that evaluation metrics are written
        ret.append(
            hooks.PeriodicWriter(self.build_writers(), period=cfg.TRAIN.get("PRINT_FREQ", 100))
        )
    return ret
def test_eval_hook(self):
    model = _SimpleModel()
    dataloader = self._data_loader("cpu")
    opt = torch.optim.SGD(model.parameters(), 0.1)

    for total_iter, period, eval_count in [(30, 15, 2), (31, 15, 3), (20, 0, 1)]:
        test_func = mock.Mock(return_value={"metric": 3.0})
        trainer = SimpleTrainer(model, dataloader, opt)
        trainer.register_hooks([hooks.EvalHook(period, test_func)])
        trainer.train(0, total_iter)
        self.assertEqual(test_func.call_count, eval_count)
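# Reading aid: EvalHook(period, fn) calls fn every `period` iterations during
# training (skipping the step that coincides with the final iteration) and once
# more after training ends; period=0 means "only after training". That yields
# 2 calls for 30 iters at period 15, 3 calls for 31 iters, and 1 call at period 0.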
def main(args):
    cfg = setup(args)

    from detectron2.data.datasets import register_coco_instances

    thing_classes = [
        'Cerebellum', 'CN8', 'CN5', 'CN7', 'SCA', 'AICA',
        'SuperiorPetrosalVein', 'Vein', 'Brainstem', 'Suction', 'Bipolar',
        'Forcep', 'BluntProbe', 'Drill', 'Kerrison', 'Cottonoid', 'Scissors',
        'Unknown',
    ]

    register_coco_instances("surgery_train2", {},
                            "data/coco/annotations/instances_train2017.json",
                            "data/coco/train2017")
    MetadataCatalog.get("surgery_train2").thing_classes = thing_classes
    DatasetCatalog.get("surgery_train2")

    # NOTE: as written, the val split is registered against the same training
    # annotations and images as the train split.
    register_coco_instances("surgery_val2", {},
                            "data/coco/annotations/instances_train2017.json",
                            "data/coco/train2017")
    MetadataCatalog.get("surgery_val2").thing_classes = thing_classes
    DatasetCatalog.get("surgery_val2")

    if args.eval_only:
        model = Trainer.build_model(cfg)
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop or subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks([
            hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        ])
    return trainer.train()