def compute_kmeans_anchors_hook(runner, cfg):
    """
    Create a ``before_train`` hook that installs K-means clustered anchors.

    When KMEANS_ANCHORS_ON is set, the hook will:
        1. Build a train loader using the provided KMEANS_ANCHORS.DATASETS.
        2. Collect box statistics from that loader's outputs, using up to
           KMEANS_ANCHORS.NUM_TRAINING_IMG images.
        3. Run K-means with KMEANS_ANCHORS.NUM_CLUSTERS clusters.
        4. Update the buffers in the model's anchor generator.
    """

    def before_train_callback(trainer):
        # No-op unless K-means anchors are explicitly enabled in the config.
        if not cfg.MODEL.KMEANS_ANCHORS.KMEANS_ANCHORS_ON:
            return

        # Build a loader over the dataset(s) designated for anchor statistics.
        new_cfg = cfg.clone()
        with temp_defrost(new_cfg):
            new_cfg.DATASETS.TRAIN = cfg.MODEL.KMEANS_ANCHORS.DATASETS
        data_loader = runner.build_detection_train_loader(new_cfg)

        anchors = compute_kmeans_anchors(cfg, data_loader).tolist()

        # Only GeneralizedRCNN with an RPN whose anchor generator is a
        # KMeansAnchorGenerator supports in-place anchor replacement.
        assert isinstance(trainer.model, GeneralizedRCNN)
        assert isinstance(trainer.model.proposal_generator, RPN)
        anchor_generator = trainer.model.proposal_generator.anchor_generator
        assert isinstance(anchor_generator, KMeansAnchorGenerator)
        anchor_generator.update_cell_anchors(anchors)

    return hooks.CallbackHook(before_train=before_train_callback)
def _create_after_step_hook(self, cfg, model, optimizer, scheduler, periodic_checkpointer):
    """
    Create a hook that performs some pre-defined tasks used in this script
    (evaluation, LR scheduling, checkpointing).
    """

    def after_step_callback(trainer):
        # Record the current learning rate so it appears in the metrics.
        current_lr = optimizer.param_groups[0]["lr"]
        trainer.storage.put_scalar("lr", current_lr, smoothing_hint=False)
        scheduler.step()
        # Note: when precise BN is enabled, some checkpoints will have more
        # precise statistics than others, if they are saved immediately
        # after eval.
        if comm.is_main_process():
            periodic_checkpointer.step(trainer.iter)

    return hooks.CallbackHook(after_step=after_step_callback)
def _create_qat_hook(self, cfg):
    """
    Create a hook to start QAT (during training) and/or change the phase of QAT.

    The hook drives four one-shot phase transitions keyed on ``trainer.iter``
    (enable fake quant, enable observer, disable observer, freeze BN stats),
    plus an optional periodic observer-statistics refresh. The ``applied``
    dict records which transitions have already fired so each runs once.
    """
    applied = {
        "enable_fake_quant": False,
        "enable_observer": False,
        "disable_observer": False,
        "freeze_bn_stats": False,
    }
    assert (
        cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
        <= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
    ), "Can't disable observer before enabling it"  # fixed typo: "diable"

    def qat_before_step_callback(trainer):
        # Phase 1: turn on fake quantization once START_ITER is reached.
        if (
            not applied["enable_fake_quant"]
            and trainer.iter >= cfg.QUANTIZATION.QAT.START_ITER
        ):
            logger.info(
                "[QAT] enable fake quant to start QAT, iter = {}".format(trainer.iter)
            )
            trainer.model.apply(torch.quantization.enable_fake_quant)
            applied["enable_fake_quant"] = True

            # Optionally rebuild the train loader with a scaled per-GPU batch
            # size for the QAT stage of training.
            if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
                loader_cfg = cfg.clone()
                loader_cfg.defrost()
                num_gpus = comm.get_world_size()
                old_bs = cfg.SOLVER.IMS_PER_BATCH // num_gpus
                new_bs = math.ceil(old_bs * cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR)
                loader_cfg.SOLVER.IMS_PER_BATCH = new_bs * num_gpus
                loader_cfg.freeze()
                logger.info(
                    "[QAT] Rebuild data loader with batch size per GPU: {} -> {}"
                    .format(old_bs, new_bs)
                )
                # This method assumes the data loader can be replaced from trainer
                assert trainer.__class__ == SimpleTrainer
                del trainer._data_loader_iter
                del trainer.data_loader
                data_loader = self.build_detection_train_loader(loader_cfg)
                trainer.data_loader = data_loader
                trainer._data_loader_iter = iter(data_loader)

        # Phase 2: enable observers only inside the window
        # [ENABLE_OBSERVER_ITER, DISABLE_OBSERVER_ITER).
        if (
            not applied["enable_observer"]
            and trainer.iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
            and trainer.iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
        ):
            logger.info("[QAT] enable observer, iter = {}".format(trainer.iter))
            trainer.model.apply(torch.quantization.enable_observer)
            applied["enable_observer"] = True

        # Phase 3: permanently disable observers once past DISABLE_OBSERVER_ITER.
        if (
            not applied["disable_observer"]
            and trainer.iter >= cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
        ):
            logger.info(
                "[QAT] disabling observer for sub seq iters, iter = {}".format(
                    trainer.iter
                )
            )
            trainer.model.apply(torch.quantization.disable_observer)
            applied["disable_observer"] = True

        # Phase 4: freeze batch-norm statistics after FREEZE_BN_ITER.
        if (
            not applied["freeze_bn_stats"]
            and trainer.iter >= cfg.QUANTIZATION.QAT.FREEZE_BN_ITER
        ):
            logger.info(
                "[QAT] freezing BN for subseq iters, iter = {}".format(trainer.iter)
            )
            trainer.model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            applied["freeze_bn_stats"] = True

        # Periodic refresh of observer statistics while fake quant is active.
        if (
            applied["enable_fake_quant"]
            and cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY
            and trainer.iter % cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD == 0
        ):
            logger.info(f"[QAT] updating observers, iter = {trainer.iter}")
            trainer.model.apply(observer_update_stat)

    return hooks.CallbackHook(before_step=qat_before_step_callback)