def test_qat_transforms(self):
        """ Tests the appropropriate ModelTransforms are defined with QAT."""
        qat = QuantizationAwareTraining(
            start_step=300, enable_observer=(350, 500), freeze_bn_step=550
        )

        trainer = Trainer()
        module = TestModule()

        qat.setup(trainer, module, stage="train")

        self.assertGreater(len(qat.transforms), 0)

        def assertContainsTransformsAtStep(step):
            """
            Asserts at least one transform exists at the specified step and
            that it is removed after the step begins.
            """
            self.assertGreater(
                len(
                    [
                        transform
                        for transform in qat.transforms
                        if transform.step == step
                    ]
                ),
                0,
                f"step={step}",
            )
            trainer.global_step = step
            qat.on_train_batch_start(
                trainer, module, batch=None, batch_idx=0, dataloader_idx=0
            )

            self.assertEqual(
                len(
                    [
                        transform
                        for transform in qat.transforms
                        if transform.step == step
                    ]
                ),
                0,
                f"step={step}",
            )

        assertContainsTransformsAtStep(step=300)
        assertContainsTransformsAtStep(step=350)
        assertContainsTransformsAtStep(step=500)
        assertContainsTransformsAtStep(step=550)

    def test_attribute_preservation_qat(self, root_dir):
        """ Validates that the specified attributes are preserved on the module. """
        seed_everything(100)

        model = TestModule()
        model.layer._added_property = 10
        model._not_preserved = 15
        model._added_property = 20

        num_epochs = 2
        qat = QuantizationAwareTraining(
            preserved_attrs=["_added_property", "layer._added_property"]
        )
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )

        trainer.fit(model)

        self.assertIsNotNone(qat.prepared)
        self.assertIsNotNone(qat.quantized)

        # Assert properties are maintained.
        self.assertTrue(hasattr(qat.prepared, "_added_property"))
        self.assertTrue(hasattr(qat.prepared.layer, "_added_property"))

        with self.assertRaises(AttributeError):
            qat.prepared._not_preserved

    def test_module_quantized_during_train(self, root_dir):
        """ Validates that quantization-aware training works as expected. """
        seed_everything(100)

        model = TestModule()
        test_in = torch.randn(1, 32)
        before_train = model.eval()(test_in)
        num_epochs = 2
        qat = QuantizationAwareTraining()
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.fit(model)

        self.assertIsNotNone(qat.prepared)
        self.assertIsNotNone(qat.quantized)

        test_out = model.eval()(test_in)
        self.assertGreater(
            (test_out ** 2).sum(),
            0.03,
            "With the given seed, the squared L2 norm should be > 0.03.",
        )

        base_out = qat.quantized.eval()(test_in)
        self.assertTrue(torch.allclose(base_out, test_out))
        # Weight changed during training.
        self.assertFalse(torch.allclose(before_train, test_out))

        # Validate .test() call works as expected and does not change model weights.
        trainer.test(model)

        self.assertTrue(torch.allclose(test_out, model.eval()(test_in)))

    def test_qat_interval_transform(self, root_dir):
        """ Tests that an interval transform is applied multiple times. """
        seed_everything(100)

        def linear_fn_counter(mod):
            if isinstance(mod, torch.nn.Linear):
                linear_fn_counter.count += 1

        linear_fn_counter.count = 0

        model = TestModule()
        num_epochs = 2
        qat = QuantizationAwareTraining()
        qat.transforms.append(
            ModelTransform(fn=linear_fn_counter, message="Counter", interval=10)
        )
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.fit(model)

        # Model has 2 linear layers.
        self.assertEqual(linear_fn_counter.count, 2 * (trainer.global_step // 10 + 1))

Example #5
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
    """Gets the trainer callbacks based on the given D2Go Config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        A list of configured Callbacks to be used by the Lightning Trainer.
    """
    callbacks: List[Callback] = [
        LearningRateMonitor(logging_interval="step"),
        ModelCheckpoint(
            dirpath=cfg.OUTPUT_DIR,
            save_last=True,
        ),
    ]
    if cfg.QUANTIZATION.QAT.ENABLED:
        qat = cfg.QUANTIZATION.QAT
        qconfig_dicts = (
            {submodule: None for submodule in cfg.QUANTIZATION.MODULES}
            if cfg.QUANTIZATION.MODULES
            else None
        )
        callbacks.append(
            QuantizationAwareTraining(
                qconfig_dicts=qconfig_dicts,
                start_step=qat.START_ITER,
                enable_observer=(qat.ENABLE_OBSERVER_ITER,
                                 qat.DISABLE_OBSERVER_ITER),
                freeze_bn_step=qat.FREEZE_BN_ITER,
            )
        )
    return callbacks
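
A minimal sketch of how this helper might be consumed is shown below. The `build_trainer` wrapper, the `task` argument, and the `SOLVER.MAX_ITER` key are illustrative assumptions, not part of the example above:

# Hypothetical usage sketch; assumes `cfg` is an already-populated D2Go CfgNode
# and `task` is the LightningModule to train. Only _get_trainer_callbacks comes
# from the example above; everything else here is illustrative.
import pytorch_lightning as pl


def build_trainer(cfg: CfgNode) -> pl.Trainer:
    return pl.Trainer(
        default_root_dir=cfg.OUTPUT_DIR,
        max_steps=cfg.SOLVER.MAX_ITER,  # assumed config key, for illustration
        callbacks=_get_trainer_callbacks(cfg),
        logger=False,
    )


# trainer = build_trainer(cfg)
# trainer.fit(task)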

Example #6
    def test_qat(self, tmp_dir):
        @META_ARCH_REGISTRY.register()
        class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
            custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

            def __init__(self, cfg):
                super().__init__(cfg)
                self.avgpool.preserved_attr = "foo"
                self.avgpool.not_preserved_attr = "bar"

            def prepare_for_quant(self, cfg):
                example_inputs = (torch.rand(1, 3, 3, 3),)
                qconfig = set_backend_and_create_qconfig(cfg, is_train=self.training)
                self.avgpool = prepare_qat_fx(
                    self.avgpool,
                    {"": qconfig},
                    example_inputs,
                    self.custom_config_dict,
                )
                return self

            def prepare_for_quant_convert(self, cfg):
                self.avgpool = convert_fx(
                    self.avgpool,
                    convert_custom_config_dict=self.custom_config_dict)
                return self

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
        cfg.QUANTIZATION.QAT.ENABLED = True
        task = GeneralizedRCNNTask(cfg)

        callbacks = [
            QuantizationAwareTraining.from_config(cfg),
            ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
        ]
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=callbacks,
            logger=False,
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
        prepared_avgpool = task._prepared.model.avgpool
        self.assertEqual(prepared_avgpool.preserved_attr, "foo")
        self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))

        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertTrue(isinstance(model.avgpool, torch.fx.GraphModule))

    def test_qat_misconfiguration(self):
        """ Tests failure when misconfiguring the QAT Callback. """
        invalid_params = [
            {"start_step": -1},
            {"enable_observer": (42, 42)},
            {"enable_observer": (42, 21)},
            {"enable_observer": (-1, None)},
            {"freeze_bn_step": -1},
        ]
        for invalid_param in invalid_params:
            with self.assertRaises(ValueError):
                _ = QuantizationAwareTraining(**invalid_param)

    def test_quantization_without_train(self, root_dir):
        """ Validate quantization occurs even without a call to .fit() first. """
        seed_everything(100)

        model = TestModule()
        num_epochs = 2
        qat = QuantizationAwareTraining()
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            checkpoint_callback=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )

        trainer.test(model)

        self.assertIsNotNone(qat.prepared)
        self.assertIsNotNone(qat.quantized)

Example #9
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
    """Gets the trainer callbacks based on the given D2Go Config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        A list of configured Callbacks to be used by the Lightning Trainer.
    """
    callbacks: List[Callback] = [
        LearningRateMonitor(logging_interval="step"),
        ModelCheckpoint(
            dirpath=cfg.OUTPUT_DIR,
            save_last=True,
        ),
    ]
    if cfg.QUANTIZATION.QAT.ENABLED:
        callbacks.append(QuantizationAwareTraining.from_config(cfg))
    return callbacks

    def test_quantization_and_checkpointing(self, root_dir):
        """ Validate written checkpoints can be loaded back as expected. """
        seed_everything(100)

        model = TestModule()
        num_epochs = 2
        qat = QuantizationAwareTraining()
        checkpoint_dir = os.path.join(root_dir, "checkpoints")
        checkpoint = ModelCheckpoint(dirpath=checkpoint_dir, save_last=True)
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            callbacks=[qat, checkpoint],
            max_epochs=num_epochs,
            logger=False,
        )
        # Mimic failing mid-training by not running on_fit_end.
        with mock.patch.object(qat, "on_fit_end"):
            trainer.fit(model)

        ckpt = torch.load(os.path.join(checkpoint_dir, "last.ckpt"))
        model.load_state_dict(ckpt["state_dict"])

    def test_submodule_qat(self, root_dir):
        """Tests that we can customize QAT through the exposed API."""
        seed_everything(100)

        model = TestModule()
        test_in = torch.randn(1, 32)
        before_train = model.eval()(test_in)
        num_epochs = 2
        qat = QuantizationAwareTraining(
            qconfig_dicts={"another_layer": {"": get_default_qat_qconfig()}}
        )
        trainer = Trainer(
            default_root_dir=os.path.join(root_dir, "quantized"),
            enable_checkpointing=False,
            callbacks=[qat],
            max_epochs=num_epochs,
            logger=False,
        )
        trainer.fit(model)

        self.assertIsNotNone(qat.prepared)
        self.assertIsNotNone(qat.quantized)

        test_out = model.eval()(test_in)
        self.assertGreater(
            (test_out ** 2).sum(),
            0.03,
            "With the given seed, the squared L2 norm should be > 0.03.",
        )

        base_out = qat.quantized.eval()(test_in)
        self.assertTrue(torch.allclose(base_out, test_out))
        # Weight changed during training.
        self.assertFalse(torch.allclose(before_train, test_out))

        # Validate .test() call works as expected and does not change model weights.
        trainer.test(model)

        self.assertTrue(torch.allclose(test_out, model.eval()(test_in)))
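
These tests construct a TestModule fixture that is not shown on this page. For context, below is a minimal sketch of what such a LightningModule might look like, assuming a 32-dimensional input and two Linear submodules named layer and another_layer (consistent with the assertions above); the dataset and optimizer choices are illustrative assumptions, not taken from the actual fixture.

# Hypothetical sketch of the TestModule fixture used above; the layer names and
# 32-dim input match the tests, everything else is an illustrative assumption.
import torch
import pytorch_lightning as pl


class TestModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Two Linear layers, matching the "Model has 2 linear layers" comment
        # and the "another_layer" key used in test_submodule_qat.
        self.layer = torch.nn.Linear(32, 32)
        self.another_layer = torch.nn.Linear(32, 32)

    def forward(self, x):
        return self.another_layer(self.layer(x))

    def training_step(self, batch, batch_idx):
        # Arbitrary regression target so .fit() updates the weights.
        return torch.nn.functional.mse_loss(self(batch), torch.zeros_like(batch))

    def test_step(self, batch, batch_idx):
        return self(batch)

    def train_dataloader(self):
        # Random 32-dim inputs; dataset size and batch size are illustrative.
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(16, 32), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)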