def test_qat(self, tmp_dir):
    """Fit a quantizable test meta-arch for one QAT step and verify the result.

    Checks two things:
      * after the fit, the prepared ``avgpool`` still carries the attribute
        listed under ``preserved_attributes`` and has lost the one that is
        not listed;
      * a model rebuilt in eval-only mode from ``last.ckpt`` has a
        ``torch.fx.GraphModule`` as its ``avgpool``.

    Args:
        tmp_dir: Directory used to build the config and to locate the
            ``last.ckpt`` checkpoint that is reloaded at the end.
    """

    @META_ARCH_REGISTRY.register()
    class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
        # Passed to both the prepare and the convert call below.
        custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

        def __init__(self, cfg):
            super().__init__(cfg)
            # One attribute is listed in custom_config_dict, the other is
            # not; the assertions at the end of the test rely on that.
            self.avgpool.preserved_attr = "foo"
            self.avgpool.not_preserved_attr = "bar"

        def prepare_for_quant(self, cfg):
            example_inputs = (torch.rand(1, 3, 3, 3),)
            self.avgpool = prepare_qat_fx(
                self.avgpool,
                {"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
                example_inputs,
                self.custom_config_dict,
            )
            return self

        def prepare_for_quant_convert(self, cfg):
            self.avgpool = convert_fx(
                self.avgpool,
                convert_custom_config_dict=self.custom_config_dict,
            )
            return self

    cfg = self._get_cfg(tmp_dir)
    # Must match the name of the class registered above.
    cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
    cfg.QUANTIZATION.QAT.ENABLED = True
    task = GeneralizedRCNNTask(cfg)

    callbacks = [
        QuantizationAwareTraining.from_config(cfg),
        ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
    ]
    # A single training step on a single batch keeps the test fast.
    trainer = pl.Trainer(
        max_steps=1,
        limit_train_batches=1,
        num_sanity_val_steps=0,
        callbacks=callbacks,
        logger=False,
    )
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)

    prepared_avgpool = task._prepared.model.avgpool
    self.assertEqual(prepared_avgpool.preserved_attr, "foo")
    self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))

    with temp_defrost(cfg):
        # NOTE(review): assumes the ModelCheckpoint above wrote last.ckpt
        # into tmp_dir, i.e. cfg.OUTPUT_DIR == tmp_dir - confirm in _get_cfg.
        cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(model.avgpool, torch.fx.GraphModule)
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
    """Build the Lightning Trainer callbacks for a D2Go config.

    A step-level learning-rate monitor and a checkpoint callback writing to
    ``cfg.OUTPUT_DIR`` (with ``save_last=True``) are always included. A
    ``QuantizationAwareTraining`` callback built from ``cfg`` is appended
    when ``cfg.QUANTIZATION.QAT.ENABLED`` is truthy.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        The list of configured callbacks, in the order described above.
    """
    lr_monitor = LearningRateMonitor(logging_interval="step")
    checkpointer = ModelCheckpoint(dirpath=cfg.OUTPUT_DIR, save_last=True)
    trainer_callbacks: List[Callback] = [lr_monitor, checkpointer]

    # Without QAT the two default callbacks are all that is needed.
    if not cfg.QUANTIZATION.QAT.ENABLED:
        return trainer_callbacks

    trainer_callbacks.append(QuantizationAwareTraining.from_config(cfg))
    return trainer_callbacks