Example #1
def compute_kmeans_anchors_hook(runner, cfg):
    """
    This function will create a before_train hook, it will:
        1: create a train loader using provided KMEANS_ANCHORS.DATASETS.
        2: collecting statistics of boxes using outputs from train loader, use up
            to KMEANS_ANCHORS.NUM_TRAINING_IMG images.
        3: compute K-means using KMEANS_ANCHORS.NUM_CLUSTERS clusters
        4: update the buffers in anchor_generator.
    """
    def before_train_callback(trainer):
        if not cfg.MODEL.KMEANS_ANCHORS.KMEANS_ANCHORS_ON:
            return

        new_cfg = cfg.clone()
        with temp_defrost(new_cfg):
            new_cfg.DATASETS.TRAIN = cfg.MODEL.KMEANS_ANCHORS.DATASETS
            data_loader = runner.build_detection_train_loader(new_cfg)

        anchors = compute_kmeans_anchors(cfg, data_loader)
        anchors = anchors.tolist()

        assert isinstance(trainer.model, GeneralizedRCNN)
        assert isinstance(trainer.model.proposal_generator, RPN)
        anchor_generator = trainer.model.proposal_generator.anchor_generator
        assert isinstance(anchor_generator, KMeansAnchorGenerator)
        anchor_generator.update_cell_anchors(anchors)

    return hooks.CallbackHook(before_train=before_train_callback)
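
A minimal sketch of how the returned hook could be wired into a detectron2 trainer. The `model`, `data_loader`, and `optimizer` setup is assumed here, not shown in the example above:

# Hypothetical wiring: register the returned CallbackHook so that
# before_train fires once, before the first training iteration.
from detectron2.engine import SimpleTrainer

trainer = SimpleTrainer(model, data_loader, optimizer)  # assumed to exist
trainer.register_hooks([compute_kmeans_anchors_hook(runner, cfg)])
trainer.train(start_iter=0, max_iter=cfg.SOLVER.MAX_ITER)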
Example #2
    def _create_after_step_hook(self, cfg, model, optimizer, scheduler,
                                periodic_checkpointer):
        """
        Create a hook that performs some pre-defined tasks used in this script
        (evaluation, LR scheduling, checkpointing).
        """
        def after_step_callback(trainer):
            trainer.storage.put_scalar("lr",
                                       optimizer.param_groups[0]["lr"],
                                       smoothing_hint=False)
            scheduler.step()
            # Note: when precise-BN is enabled, checkpoints saved immediately
            # after an eval carry more precise BN statistics than the rest.
            if comm.is_main_process():
                periodic_checkpointer.step(trainer.iter)

        return hooks.CallbackHook(after_step=after_step_callback)
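
For context, a sketch of how the objects this hook closes over are typically built with detectron2's stock helpers. This is a plausible construction under those assumptions, not necessarily what this script does:

from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.solver import build_lr_scheduler, build_optimizer

# Build the optimizer/scheduler pair the callback steps, plus a periodic
# checkpointer that saves every cfg.SOLVER.CHECKPOINT_PERIOD iterations.
optimizer = build_optimizer(cfg, model)
scheduler = build_lr_scheduler(cfg, optimizer)
checkpointer = DetectionCheckpointer(
    model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
)
periodic_checkpointer = PeriodicCheckpointer(
    checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=cfg.SOLVER.MAX_ITER
)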
Example #3
    def _create_qat_hook(self, cfg):
        """
        Create a hook to start QAT (during training) and/or change the phase of QAT.
        """
        applied = {
            "enable_fake_quant": False,
            "enable_observer": False,
            "disable_observer": False,
            "freeze_bn_stats": False,
        }

        assert (cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER <=
                cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
                ), "Can't diable observer before enabling it"

        def qat_before_step_callback(trainer):
            if (not applied["enable_fake_quant"]
                    and trainer.iter >= cfg.QUANTIZATION.QAT.START_ITER):
                logger.info(
                    "[QAT] enable fake quant to start QAT, iter = {}".format(
                        trainer.iter))
                trainer.model.apply(torch.quantization.enable_fake_quant)
                applied["enable_fake_quant"] = True

                if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
                    loader_cfg = cfg.clone()
                    loader_cfg.defrost()
                    num_gpus = comm.get_world_size()
                    old_bs = cfg.SOLVER.IMS_PER_BATCH // num_gpus
                    new_bs = math.ceil(old_bs *
                                       cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR)
                    loader_cfg.SOLVER.IMS_PER_BATCH = new_bs * num_gpus
                    loader_cfg.freeze()

                    logger.info(
                        "[QAT] Rebuild data loader with batch size per GPU: {} -> {}"
                        .format(old_bs, new_bs))
                    # This assumes the trainer's data loader can be swapped in place
                    assert trainer.__class__ == SimpleTrainer
                    del trainer._data_loader_iter
                    del trainer.data_loader
                    data_loader = self.build_detection_train_loader(loader_cfg)
                    trainer.data_loader = data_loader
                    trainer._data_loader_iter = iter(data_loader)

            if (not applied["enable_observer"] and
                    trainer.iter >= cfg.QUANTIZATION.QAT.ENABLE_OBSERVER_ITER
                    and
                    trainer.iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER):
                logger.info("[QAT] enable observer, iter = {}".format(
                    trainer.iter))
                trainer.model.apply(torch.quantization.enable_observer)
                applied["enable_observer"] = True

            if (not applied["disable_observer"] and trainer.iter >=
                    cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER):
                logger.info(
                    "[QAT] disabling observer for sub seq iters, iter = {}".
                    format(trainer.iter))
                trainer.model.apply(torch.quantization.disable_observer)
                applied["disable_observer"] = True

            if (not applied["freeze_bn_stats"]
                    and trainer.iter >= cfg.QUANTIZATION.QAT.FREEZE_BN_ITER):
                logger.info(
                    "[QAT] freezing BN for subseq iters, iter = {}".format(
                        trainer.iter))
                trainer.model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                applied["freeze_bn_stats"] = True

            if (applied["enable_fake_quant"]
                    and cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY
                    and trainer.iter %
                    cfg.QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD == 0):
                logger.info(f"[QAT] updating observers, iter = {trainer.iter}")
                trainer.model.apply(observer_update_stat)

        return hooks.CallbackHook(before_step=qat_before_step_callback)
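
A standalone sketch of the four QAT phase toggles the callback drives, applied to a toy model. The layer sizes and the "fbgemm" backend are assumptions for illustration only:

import torch
import torch.quantization as tq

# Toy eager-mode QAT model; the hook above flips these same switches
# on trainer.model once the configured iterations are reached.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
model.train()
model.qconfig = tq.get_default_qat_qconfig("fbgemm")  # assumed backend
tq.prepare_qat(model, inplace=True)

model.apply(tq.disable_fake_quant)  # before START_ITER: plain float training
model.apply(tq.enable_fake_quant)   # START_ITER: fake-quantize weights/activations
model.apply(tq.enable_observer)     # ENABLE_OBSERVER_ITER: collect value ranges
model.apply(tq.disable_observer)    # DISABLE_OBSERVER_ITER: freeze value ranges
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)  # FREEZE_BN_ITER (no-op here: no fused BN)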