def build_optimizer(cls, cfg, model):
    """
    Construct the optimizer for ``model`` from the solver section of ``cfg``.

    This is a thin hook that simply delegates to the module-level
    :func:`build_optimizer`; subclasses can override it to plug in a
    different optimizer.

    Returns:
        torch.optim.Optimizer (whatever the module-level builder returns).
    """
    return build_optimizer(cfg, model)
def do_train(cfg, model, resume=False):
    """
    Run a plain epoch-based training loop for a re-ID model.

    Args:
        cfg: project config node (SOLVER/TEST/MODEL/OUTPUT_DIR entries are read).
        model: the model to train; must return a dict of loss tensors when
            called on a batch in training mode.
        resume (bool): if True, resume optimizer/scheduler/epoch state from the
            last checkpoint in ``cfg.OUTPUT_DIR`` instead of only loading weights.
    """
    data_loader = build_reid_train_loader(cfg)
    model.train()
    optimizer = build_optimizer(cfg, model)

    iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
    # `scheduler` is a dict of checkpointable schedulers
    # (expects keys "warmup_sched" and "lr_sched" — used below).
    scheduler = build_lr_scheduler(cfg, optimizer, iters_per_epoch)

    # BUGFIX: original read `optimizer=optimizer**scheduler`, which evaluates
    # `optimizer ** scheduler` (a TypeError at runtime). The intent — matching
    # the `**self.scheduler` usage elsewhere in this file — is to pass the
    # optimizer plus the scheduler dict as extra checkpointables.
    checkpointer = Checkpointer(
        model,
        cfg.OUTPUT_DIR,
        save_to_disk=comm.is_main_process(),
        optimizer=optimizer,
        **scheduler,
    )
    # resume_or_load returns the checkpoint's extra state; "epoch" is the last
    # finished epoch, so training restarts at the following one.
    start_epoch = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("epoch", -1) + 1
    )
    iteration = start_iter = start_epoch * iters_per_epoch

    max_epoch = cfg.SOLVER.MAX_EPOCH
    max_iter = max_epoch * iters_per_epoch
    warmup_iters = cfg.SOLVER.WARMUP_ITERS
    delay_epochs = cfg.SOLVER.DELAY_EPOCHS

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_epoch
    )

    writers = (
        [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
        if comm.is_main_process()
        else []
    )

    # compared to "train_net.py", we do not support some hooks, such as
    # accurate timing, FP16 training and precise BN here,
    # because they are not trivial to implement in a small training loop
    logger.info("Start training from epoch {}".format(start_epoch))
    with EventStorage(start_iter) as storage:
        for epoch in range(start_epoch, max_epoch):
            storage.epoch = epoch
            # zip() caps each epoch at iters_per_epoch batches.
            for data, _ in zip(data_loader, range(iters_per_epoch)):
                storage.iter = iteration

                loss_dict = model(data)
                losses = sum(loss_dict.values())
                assert torch.isfinite(losses).all(), loss_dict

                # Reduce losses across workers for logging only; the backward
                # pass below uses the local (unreduced) losses.
                loss_dict_reduced = {
                    k: v.item() for k, v in comm.reduce_dict(loss_dict).items()
                }
                losses_reduced = sum(loss for loss in loss_dict_reduced.values())
                if comm.is_main_process():
                    storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

                optimizer.zero_grad()
                losses.backward()
                optimizer.step()
                storage.put_scalar(
                    "lr", optimizer.param_groups[0]["lr"], smoothing_hint=False
                )

                # Skip the first few iterations (warm caches / unstable timing),
                # then flush writers every 200 iterations and at the very end.
                if iteration - start_iter > 5 and (
                    (iteration + 1) % 200 == 0 or iteration == max_iter - 1
                ):
                    for writer in writers:
                        writer.write()

                iteration += 1

                # Per-iteration warmup schedule, then per-epoch schedule below.
                if iteration <= warmup_iters:
                    scheduler["warmup_sched"].step()

            # Write metrics after each epoch
            for writer in writers:
                writer.write()

            if iteration > warmup_iters and (epoch + 1) >= delay_epochs:
                scheduler["lr_sched"].step()

            # BUGFIX: original compared `epoch != max_iter - 1`, mixing epoch
            # and iteration units (the guard was then effectively always true).
            # The detectron2 pattern this mirrors skips the in-loop eval on the
            # final index; in an epoch loop that is `max_epoch - 1`.
            if (
                cfg.TEST.EVAL_PERIOD > 0
                and (epoch + 1) % cfg.TEST.EVAL_PERIOD == 0
                and epoch != max_epoch - 1
            ):
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped
                # to EventStorage

            periodic_checkpointer.step(epoch)
def __init__(self, cfg):
    """
    Build all training state for a (possibly Partial-FC) re-ID trainer.

    Construction order matters: the data loader is built first so the config
    can be auto-scaled to the dataset's class count before the model,
    optimizer, schedulers, and checkpointers are created.

    Args:
        cfg: project config node; MODEL/SOLVER/TEST/OUTPUT_DIR entries are read.
    """
    TrainerBase.__init__(self)

    logger = logging.getLogger('fastreid.partial-fc.trainer')
    if not logger.isEnabledFor(
            logging.INFO):  # setup_logger is not called for fastreid
        setup_logger()

    # Assume these objects must be constructed in this order.
    data_loader = self.build_train_loader(cfg)
    # Rewrites cfg (e.g. the head's class count) from the dataset before the
    # model is built.
    cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
    model = self.build_model(cfg)
    # NOTE(review): build_optimizer here returns a pair — presumably
    # (optimizer, parameter wrapper for grad clipping/AMP); confirm in solver.
    optimizer, param_wrapper = self.build_optimizer(cfg, model)

    if cfg.MODEL.HEADS.PFC.ENABLED:
        # fmt: off
        feat_dim      = cfg.MODEL.BACKBONE.FEAT_DIM
        embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
        num_classes   = cfg.MODEL.HEADS.NUM_CLASSES
        sample_rate   = cfg.MODEL.HEADS.PFC.SAMPLE_RATE
        cls_type      = cfg.MODEL.HEADS.CLS_LAYER
        scale         = cfg.MODEL.HEADS.SCALE
        margin        = cfg.MODEL.HEADS.MARGIN
        # fmt: on
        # Partial-FC module: holds the (sampled) classification layer, with
        # its own optimizer built separately from the backbone's.
        embedding_size = embedding_dim if embedding_dim > 0 else feat_dim
        self.pfc_module = PartialFC(embedding_size, num_classes, sample_rate,
                                    cls_type, scale, margin)
        self.pfc_optimizer, _ = build_optimizer(cfg, self.pfc_module, False)

    # For training, wrap with DDP. But don't need this for inference.
    if comm.get_world_size() > 1:
        # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
        # for part of the parameters is not updated.
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
        )

    if cfg.MODEL.HEADS.PFC.ENABLED:
        # Per-GPU batch size drives the gradient scaler's clipping bounds.
        mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
        grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size,
                                        growth_interval=100)
        self._trainer = PFCTrainer(model, data_loader, optimizer,
                                   param_wrapper, self.pfc_module,
                                   self.pfc_optimizer,
                                   cfg.SOLVER.AMP.ENABLED, grad_scaler)
    else:
        # Plain training loop; AMPTrainer when mixed precision is enabled.
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else
                         SimpleTrainer)(model, data_loader, optimizer,
                                        param_wrapper)

    self.iters_per_epoch = len(
        data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
    # NOTE(review): schedulers appear to be dicts of checkpointable objects —
    # they are splatted into the checkpointers below; confirm in solver.
    self.scheduler = self.build_lr_scheduler(cfg, optimizer,
                                             self.iters_per_epoch)
    if cfg.MODEL.HEADS.PFC.ENABLED:
        self.pfc_scheduler = self.build_lr_scheduler(
            cfg, self.pfc_optimizer, self.iters_per_epoch)

    self.checkpointer = Checkpointer(
        # Assume you want to save checkpoints together with logs/statistics
        model,
        cfg.OUTPUT_DIR,
        save_to_disk=comm.is_main_process(),
        optimizer=optimizer,
        **self.scheduler,
    )
    if cfg.MODEL.HEADS.PFC.ENABLED:
        # Partial-FC state is sharded per rank, so it gets its own checkpointer.
        self.pfc_checkpointer = PfcCheckpointer(
            self.pfc_module,
            cfg.OUTPUT_DIR,
            optimizer=self.pfc_optimizer,
            **self.pfc_scheduler,
        )

    self.start_epoch = 0
    self.max_epoch = cfg.SOLVER.MAX_EPOCH
    self.max_iter = self.max_epoch * self.iters_per_epoch
    self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
    self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
    self.cfg = cfg

    self.register_hooks(self.build_hooks())