def test_qat_transforms(self):
    """Tests the appropriate ModelTransforms are defined with QAT."""
    callback = QuantizationAwareTraining(
        start_step=300, enable_observer=(350, 500), freeze_bn_step=550
    )
    trainer = Trainer()
    module = TestModule()
    callback.setup(trainer, module, stage="train")

    # Setup alone must already have registered some transforms.
    self.assertGreater(len(callback.transforms), 0)

    def pending_at(step):
        """Counts the transforms currently registered for ``step``."""
        return len([entry for entry in callback.transforms if entry.step == step])

    def assertContainsTransformsAtStep(step):
        """Asserts a transform exists at ``step`` and is gone once it starts.

        At least one transform must be registered for ``step`` beforehand;
        after ``on_train_batch_start`` runs with ``trainer.global_step`` set
        to ``step``, none may remain for that step.
        """
        self.assertGreater(pending_at(step), 0, f"step={step}")
        trainer.global_step = step
        callback.on_train_batch_start(
            trainer, module, batch=None, batch_idx=0, dataloader_idx=0
        )
        self.assertEqual(pending_at(step), 0, f"step={step}")

    # One check per configured boundary: start, observer on/off, BN freeze.
    for boundary in (300, 350, 500, 550):
        assertContainsTransformsAtStep(step=boundary)
def test_attribute_preservation_qat(self, root_dir):
    """Validates that attributes listed in ``preserved_attrs`` are kept.

    Two attributes are marked as preserved (one on the module, one on its
    ``layer`` submodule) and one is not; after fitting, the prepared model
    must expose the former and raise ``AttributeError`` for the latter.
    """
    seed_everything(100)

    model = TestModule()
    model.layer._added_property = 10
    model._not_preserved = 15
    model._added_property = 20

    num_epochs = 2
    preserved = ["_added_property", "layer._added_property"]
    qat = QuantizationAwareTraining(preserved_attrs=preserved)
    output_dir = os.path.join(root_dir, "quantized")
    trainer = Trainer(
        default_root_dir=output_dir,
        checkpoint_callback=False,
        callbacks=[qat],
        max_epochs=num_epochs,
        logger=False,
    )
    trainer.fit(model)

    self.assertIsNotNone(qat.prepared)
    self.assertIsNotNone(qat.quantized)

    # Preserved attributes are present on the prepared model and submodule.
    prepared = qat.prepared
    self.assertTrue(hasattr(prepared, "_added_property"))
    self.assertTrue(hasattr(prepared.layer, "_added_property"))

    # The attribute that was not listed must not be carried over.
    with self.assertRaises(AttributeError):
        getattr(prepared, "_not_preserved")
def test_module_quantized_during_train(self, root_dir):
    """Validates that quantization aware training works as expected.

    Fits a model with the default QAT callback, then checks that the
    quantized model's output matches the trained model's, that the output
    differs from the pre-training output, and that ``.test()`` leaves the
    model output unchanged.
    """
    seed_everything(100)
    model = TestModule()
    test_in = torch.randn(1, 32)

    def eval_forward(net):
        """Switches ``net`` to eval mode and runs ``test_in`` through it."""
        return net.eval()(test_in)

    before_train = eval_forward(model)

    num_epochs = 2
    qat = QuantizationAwareTraining()
    output_dir = os.path.join(root_dir, "quantized")
    trainer = Trainer(
        default_root_dir=output_dir,
        checkpoint_callback=False,
        callbacks=[qat],
        max_epochs=num_epochs,
        logger=False,
    )
    trainer.fit(model)

    self.assertIsNotNone(qat.prepared)
    self.assertIsNotNone(qat.quantized)

    test_out = eval_forward(model)
    squared_norm = (test_out ** 2).sum()
    self.assertGreater(
        squared_norm, 0.03, "With the given seend, L2^2 should be > 0.03."
    )

    # The quantized model must agree with the trained model.
    base_out = eval_forward(qat.quantized)
    self.assertTrue(torch.allclose(base_out, test_out))

    # Training must have changed the output relative to the initial model.
    self.assertFalse(torch.allclose(before_train, test_out))

    # A subsequent .test() call must run and leave the output unchanged.
    trainer.test(model)
    after_test = eval_forward(model)
    self.assertTrue(torch.allclose(test_out, after_test))
def test_qat_interval_transform(self, root_dir):
    """Tests that an interval transform is applied multiple times.

    A counting transform with ``interval=10`` is appended to the callback;
    after fitting, the number of ``torch.nn.Linear`` modules it saw must be
    two per application, with ``global_step // 10 + 1`` applications.
    """
    seed_everything(100)

    def linear_fn_counter(mod):
        """Increments the function-attribute counter for Linear modules."""
        if not isinstance(mod, torch.nn.Linear):
            return
        linear_fn_counter.count += 1

    # The tally lives on the function object itself.
    linear_fn_counter.count = 0

    model = TestModule()
    num_epochs = 2
    qat = QuantizationAwareTraining()
    counting_transform = ModelTransform(
        fn=linear_fn_counter, message="Counter", interval=10
    )
    qat.transforms.append(counting_transform)

    output_dir = os.path.join(root_dir, "quantized")
    trainer = Trainer(
        default_root_dir=output_dir,
        checkpoint_callback=False,
        callbacks=[qat],
        max_epochs=num_epochs,
        logger=False,
    )
    trainer.fit(model)

    # Model has 2 linear layers, each counted once per application.
    applications = trainer.global_step // 10 + 1
    self.assertEqual(linear_fn_counter.count, 2 * applications)
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
    """Builds the Lightning Trainer callbacks for the given D2Go config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        A list holding a step-level ``LearningRateMonitor``, a
        ``ModelCheckpoint`` writing to ``cfg.OUTPUT_DIR`` with
        ``save_last=True``, and, when ``cfg.QUANTIZATION.QAT.ENABLED`` is
        truthy, a ``QuantizationAwareTraining`` callback configured from
        ``cfg.QUANTIZATION``.
    """
    lr_monitor = LearningRateMonitor(logging_interval="step")
    checkpointer = ModelCheckpoint(dirpath=cfg.OUTPUT_DIR, save_last=True)
    callbacks: List[Callback] = [lr_monitor, checkpointer]

    if not cfg.QUANTIZATION.QAT.ENABLED:
        return callbacks

    qat = cfg.QUANTIZATION.QAT
    modules = cfg.QUANTIZATION.MODULES

    # Every configured submodule name maps to None; with no submodules
    # configured, no qconfig mapping is passed at all.
    if modules:
        qconfig_dicts = dict.fromkeys(modules)
    else:
        qconfig_dicts = None

    observer_window = (qat.ENABLE_OBSERVER_ITER, qat.DISABLE_OBSERVER_ITER)
    qat_callback = QuantizationAwareTraining(
        qconfig_dicts=qconfig_dicts,
        start_step=qat.START_ITER,
        enable_observer=observer_window,
        freeze_bn_step=qat.FREEZE_BN_ITER,
    )
    callbacks.append(qat_callback)
    return callbacks
def test_qat(self, tmp_dir):
    """Runs a single QAT training step and checks FX attribute handling.

    Registers a test meta-architecture whose ``avgpool`` is prepared and
    converted through FX with a ``preserved_attributes`` config, trains it
    for one step, and verifies that only the listed attribute survives on
    the prepared module and that a model rebuilt from ``last.ckpt`` has an
    FX ``GraphModule`` as its ``avgpool``.
    """

    @META_ARCH_REGISTRY.register()
    class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
        # Shared by both the prepare and the convert step below.
        custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

        def __init__(self, cfg):
            super().__init__(cfg)
            # One attribute listed in custom_config_dict, one that is not.
            self.avgpool.preserved_attr = "foo"
            self.avgpool.not_preserved_attr = "bar"

        def prepare_for_quant(self, cfg):
            example_inputs = (torch.rand(1, 3, 3, 3),)
            qconfig_dict = {
                "": set_backend_and_create_qconfig(cfg, is_train=self.training)
            }
            self.avgpool = prepare_qat_fx(
                self.avgpool,
                qconfig_dict,
                example_inputs,
                self.custom_config_dict,
            )
            return self

        def prepare_for_quant_convert(self, cfg):
            self.avgpool = convert_fx(
                self.avgpool,
                convert_custom_config_dict=self.custom_config_dict,
            )
            return self

    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
    cfg.QUANTIZATION.QAT.ENABLED = True
    task = GeneralizedRCNNTask(cfg)

    qat_callback = QuantizationAwareTraining.from_config(cfg)
    checkpoint_callback = ModelCheckpoint(
        dirpath=task.cfg.OUTPUT_DIR, save_last=True
    )
    trainer = pl.Trainer(
        max_steps=1,
        limit_train_batches=1,
        num_sanity_val_steps=0,
        callbacks=[qat_callback, checkpoint_callback],
        logger=False,
    )
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)

    # Only the attribute named in preserved_attributes should remain.
    prepared_avgpool = task._prepared.model.avgpool
    self.assertEqual(prepared_avgpool.preserved_attr, "foo")
    self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))

    # Rebuild the model in eval-only mode from the checkpoint written above.
    with temp_defrost(cfg):
        cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
        is_graph_module = isinstance(model.avgpool, torch.fx.GraphModule)
        self.assertTrue(is_graph_module)
def test_qat_misconfiguration(self):
    """Tests failure when misconfiguring the QAT Callback.

    Each entry of ``invalid_params`` is a set of keyword arguments that the
    ``QuantizationAwareTraining`` constructor is expected to reject with a
    ``ValueError``. Every entry runs in its own ``subTest`` so a failure
    names the offending arguments and the remaining entries still run.
    """
    invalid_params = [
        {"start_step": -1},  # negative start step
        {"enable_observer": (42, 42)},  # both bounds equal
        {"enable_observer": (42, 21)},  # second bound below the first
        {"enable_observer": (-1, None)},  # negative first bound
        {"freeze_bn_step": -1},  # negative freeze step
    ]
    for invalid_param in invalid_params:
        with self.subTest(invalid_param=invalid_param):
            with self.assertRaises(ValueError):
                QuantizationAwareTraining(**invalid_param)
def test_quantization_without_train(self, root_dir):
    """Validates that quantization occurs even without a prior ``.fit()``.

    Only ``trainer.test`` is invoked; the callback must still end up with
    both a prepared and a quantized model.
    """
    seed_everything(100)
    model = TestModule()
    num_epochs = 2
    qat = QuantizationAwareTraining()

    output_dir = os.path.join(root_dir, "quantized")
    trainer = Trainer(
        default_root_dir=output_dir,
        checkpoint_callback=False,
        callbacks=[qat],
        max_epochs=num_epochs,
        logger=False,
    )

    # No fit() call on purpose: go straight to testing.
    trainer.test(model)

    self.assertIsNotNone(qat.prepared)
    self.assertIsNotNone(qat.quantized)
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
    """Builds the Lightning Trainer callbacks for the given D2Go config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        A list holding a step-level ``LearningRateMonitor``, a
        ``ModelCheckpoint`` writing to ``cfg.OUTPUT_DIR`` with
        ``save_last=True``, and, when ``cfg.QUANTIZATION.QAT.ENABLED`` is
        truthy, the callback returned by
        ``QuantizationAwareTraining.from_config(cfg)``.
    """
    callbacks: List[Callback] = []
    callbacks.append(LearningRateMonitor(logging_interval="step"))
    callbacks.append(ModelCheckpoint(dirpath=cfg.OUTPUT_DIR, save_last=True))

    # The QAT callback is optional and always comes last.
    qat_enabled = cfg.QUANTIZATION.QAT.ENABLED
    if qat_enabled:
        qat_callback = QuantizationAwareTraining.from_config(cfg)
        callbacks.append(qat_callback)

    return callbacks
def test_quantization_and_checkpointing(self, root_dir):
    """Validates that written checkpoints can be loaded back as expected.

    Training runs with the QAT callback's ``on_fit_end`` patched out, then
    the ``last.ckpt`` state dict is loaded back into the model.
    """
    seed_everything(100)
    model = TestModule()
    num_epochs = 2
    qat = QuantizationAwareTraining()

    checkpoint_dir = os.path.join(root_dir, "checkpoints")
    checkpoint = ModelCheckpoint(dirpath=checkpoint_dir, save_last=True)
    output_dir = os.path.join(root_dir, "quantized")
    trainer = Trainer(
        default_root_dir=output_dir,
        callbacks=[qat, checkpoint],
        max_epochs=num_epochs,
        logger=False,
    )

    # Mimic failing mid-training by not running on_fit_end.
    with mock.patch.object(qat, "on_fit_end"):
        trainer.fit(model)

    last_ckpt_path = os.path.join(checkpoint_dir, "last.ckpt")
    ckpt = torch.load(last_ckpt_path)
    state_dict = ckpt["state_dict"]
    model.load_state_dict(state_dict)
def test_submodule_qat(self, root_dir):
    """Tests that QAT can be customized through the exposed API.

    A qconfig mapping is supplied for the ``another_layer`` submodule only.
    After fitting, the quantized model's output must match the trained
    model's, the output must differ from the pre-training output, and
    ``.test()`` must leave the model output unchanged.
    """
    seed_everything(100)
    model = TestModule()
    test_in = torch.randn(1, 32)

    def eval_forward(net):
        """Switches ``net`` to eval mode and runs ``test_in`` through it."""
        return net.eval()(test_in)

    before_train = eval_forward(model)

    num_epochs = 2
    submodule_qconfigs = {"another_layer": {"": get_default_qat_qconfig()}}
    qat = QuantizationAwareTraining(qconfig_dicts=submodule_qconfigs)
    output_dir = os.path.join(root_dir, "quantized")
    trainer = Trainer(
        default_root_dir=output_dir,
        enable_checkpointing=False,
        callbacks=[qat],
        max_epochs=num_epochs,
        logger=False,
    )
    trainer.fit(model)

    self.assertIsNotNone(qat.prepared)
    self.assertIsNotNone(qat.quantized)

    test_out = eval_forward(model)
    squared_norm = (test_out**2).sum()
    self.assertGreater(
        squared_norm, 0.03, "With the given seend, L2^2 should be > 0.03."
    )

    # The quantized model must agree with the trained model.
    base_out = eval_forward(qat.quantized)
    self.assertTrue(torch.allclose(base_out, test_out))

    # Training must have changed the output relative to the initial model.
    self.assertFalse(torch.allclose(before_train, test_out))

    # A subsequent .test() call must run and leave the output unchanged.
    trainer.test(model)
    after_test = eval_forward(model)
    self.assertTrue(torch.allclose(test_out, after_test))