def do_train(cfg):
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=True)
    start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
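
# The trainer above registers an EvalHook that calls do_test(cfg, model), which is not
# shown here. A minimal sketch, assuming the test loader and evaluator are exposed as
# cfg.dataloader.test / cfg.dataloader.evaluator as in detectron2's common LazyConfig
# layout, and that inference_on_dataset / print_csv_format are imported from
# detectron2.evaluation:
def do_test(cfg, model):
    if "evaluator" in cfg.dataloader:
        # Run (possibly distributed) inference and print results in the standard CSV format.
        ret = inference_on_dataset(
            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
        )
        print_csv_format(ret)
        return ret
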
def training_step(self, batch, batch_idx):
    data_time = time.perf_counter() - self.data_start
    # Need to manually enter/exit since trainer may launch processes
    # This ideally belongs in setup, but setup seems to run before processes are spawned
    if self.storage is None:
        self.storage = EventStorage(0)
        self.storage.__enter__()
        self.iteration_timer.trainer = weakref.proxy(self)
        self.iteration_timer.before_step()
        self.writers = (
            default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
            if comm.is_main_process()
            else {}
        )

    loss_dict = self.model(batch)
    SimpleTrainer.write_metrics(loss_dict, data_time)

    opt = self.optimizers()
    self.storage.put_scalar(
        "lr", opt.param_groups[self._best_param_group_id]["lr"], smoothing_hint=False
    )
    self.iteration_timer.after_step()
    self.storage.step()
    # A little odd to put before_step here, but it's the best way to get a proper timing
    self.iteration_timer.before_step()

    if self.storage.iter % 20 == 0:
        for writer in self.writers:
            writer.write()
    return sum(loss_dict.values())
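
# training_step() above reads the learning rate from
# opt.param_groups[self._best_param_group_id], so the owning LightningModule also has to
# build the optimizer, record that param-group id, and return a per-step scheduler. A
# minimal companion sketch (an illustrative assumption, not part of the snippet above):
def configure_optimizers(self):
    optimizer = build_optimizer(self.cfg, self.model)
    # Remember which param group to log the LR from, mirroring detectron2's LRScheduler hook.
    self._best_param_group_id = hooks.LRScheduler.get_best_param_group_id(optimizer)
    scheduler = build_lr_scheduler(self.cfg, optimizer)
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
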
def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to an fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        optimizer=optim,
        trainer=trainer,
    )
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
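
# A LazyConfig entrypoint that wires do_train(args, cfg) into detectron2's launcher could
# look roughly like the sketch below (a hedged example: the argument names follow
# detectron2's default_argument_parser, and the eval-only branch is omitted):
def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)
    do_train(args, cfg)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
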
def do_train(cfg, model, resume=False):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    )
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else []

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement in a small training loop
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            storage.iter = iteration

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            if (
                cfg.TEST.EVAL_PERIOD > 0
                and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter - 1
            ):
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            if iteration - start_iter > 5 and (
                (iteration + 1) % 20 == 0 or iteration == max_iter - 1
            ):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
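
# The loop above periodically calls do_test(cfg, model), which is not shown. With a
# yacs-style cfg, a minimal evaluation helper might be sketched as follows (the choice
# of COCOEvaluator is an assumption for COCO-format test sets, not part of the original
# snippet):
def do_test(cfg, model):
    results = {}
    for dataset_name in cfg.DATASETS.TEST:
        data_loader = build_detection_test_loader(cfg, dataset_name)
        evaluator = COCOEvaluator(dataset_name, output_dir=cfg.OUTPUT_DIR)
        results[dataset_name] = inference_on_dataset(model, data_loader, evaluator)
        if comm.is_main_process():
            print_csv_format(results[dataset_name])
    return results
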
def do_train(cfg, model, resume=False, output_dir="./output", backbone_weights=None):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(
        model, output_dir, optimizer=optimizer, scheduler=scheduler
    )
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    )
    max_iter = cfg.SOLVER.MAX_ITER

    ckpt_file = backbone_weights
    if ckpt_file is not None:
        # Load externally pretrained weights and remap the "online_network.encoder."
        # prefix so the keys match the detectron2 backbone's state dict.
        print(f"Loading backbone weights from {ckpt_file}!")
        ckpt = torch.load(ckpt_file)
        state_dict = ckpt["state_dict"]
        backbone_state_dict = OrderedDict()
        mapped = []
        for key in state_dict:
            if "online_network.encoder" in key and "num_batches_tracked" not in key:
                new_key = key.replace("online_network.encoder.", "")
                backbone_state_dict[new_key] = state_dict[key]
                mapped.append(key)
        model.backbone.load_state_dict(backbone_state_dict)
        print("Successfully mapped the following weights:")
        print(mapped)

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    writers = default_writers(output_dir, max_iter) if comm.is_main_process() else []

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement in a small training loop
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            storage.iter = iteration

            with autocast():
                loss_dict = model(data)
                losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            if (
                cfg.TEST.EVAL_PERIOD > 0
                and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter - 1
            ):
                do_test(cfg, model, output_dir)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            if iteration - start_iter > 5 and (
                (iteration + 1) % 20 == 0 or iteration == max_iter - 1
            ):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
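
# The variant above runs the forward pass under autocast() but calls losses.backward()
# and optimizer.step() directly; without gradient scaling, small fp16 gradients can
# underflow. A hedged, self-contained sketch of a scaled update step (the helper name
# amp_update_step is illustrative, not part of the original code):
from torch.cuda.amp import GradScaler, autocast


def amp_update_step(model, data, optimizer, scheduler, scaler: GradScaler):
    optimizer.zero_grad()
    with autocast():
        loss_dict = model(data)
        losses = sum(loss_dict.values())
    # Scale the loss before backward so fp16 gradients do not flush to zero; scaler.step()
    # unscales the gradients before the optimizer update, and scaler.update() adapts the scale.
    scaler.scale(losses).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    return loss_dict
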