Example #1
0
def test_quantization_disable_observers(tmpdir, observer_enabled_stages):
    """Check observer flags around each stage when some observers start disabled.

    Every other fake-quantize module has its observer disabled up front. For
    each of the train/validate/test/predict stages the test then verifies:

    * after the stage-start hook, the flags are unchanged when the stage is in
      ``observer_enabled_stages`` and all ``False`` otherwise;
    * after the stage-end hook, the flags equal the snapshot taken before the
      stage started.
    """
    model = RegressionModel()
    callback = QuantizationAwareTraining(
        observer_enabled_stages=observer_enabled_stages)
    trainer = Trainer(callbacks=[callback], default_root_dir=tmpdir)

    # Run the fit-start hook so the model is prepared for quantization.
    callback.on_fit_start(trainer, model)
    fake_quants = [
        mod for mod in model.modules() if isinstance(mod, FakeQuantizeBase)
    ]

    # Switch off every second observer before any stage hook runs.
    for fq in fake_quants[::2]:
        fq.disable_observer()

    def current_flags():
        """Snapshot the observer-enabled flag of every fake-quantize module."""
        return torch.as_tensor(
            [_get_observer_enabled(fq) for fq in fake_quants])

    stage_hooks = (
        ("train", callback.on_train_start, callback.on_train_end),
        ("validate", callback.on_validation_start, callback.on_validation_end),
        ("test", callback.on_test_start, callback.on_test_end),
        ("predict", callback.on_predict_start, callback.on_predict_end),
    )
    for stage, start_hook, end_hook in stage_hooks:
        flags_before = current_flags()

        start_hook(trainer, model)
        if stage in observer_enabled_stages:
            # Stage keeps observers on: pre-disabled ones must stay disabled.
            expected_flags = torch.as_tensor(flags_before)
        else:
            # Stage is not listed: every observer must be off.
            expected_flags = torch.as_tensor([False] * len(fake_quants))
        assert torch.equal(current_flags(), expected_flags)

        # Leaving the stage must restore the pre-stage flags.
        end_hook(trainer, model)
        assert torch.equal(current_flags(), flags_before)
def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
    """Check the callback's forward-call counter for a given collection trigger.

    Fits for four epochs with a single train batch and a single validation
    batch per epoch, then compares ``_forward_calls`` with ``expected_count``.
    """
    datamodule = RegressDataModule()
    model = RegressionModel()
    callback = QuantizationAwareTraining(collect_quantization=trigger_fn)

    trainer_kwargs = dict(
        callbacks=[callback],
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=4,
    )
    Trainer(**trainer_kwargs).fit(model, datamodule=datamodule)

    assert callback._forward_calls == expected_count
def test_quantization_val_test_predict(tmpdir):
    """Check that validating, testing and predicting do not alter default QAT training.

    Two deep copies of the same model are fitted with the same callback
    settings. One of them additionally runs validation (during and after
    fitting), test and predict; the other is fitted with validation disabled.
    Their state dicts must end up numerically close.
    """
    seed_everything(42)
    num_features = 16
    dm = RegressDataModule(num_features=num_features)
    base_model = RegressionModel()

    # Model that goes through fit + validate + test + predict.
    exercised_model = copy.deepcopy(base_model)
    exercised_trainer = Trainer(
        callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        val_check_interval=1,
        num_sanity_val_steps=1,
        max_epochs=4,
    )
    exercised_trainer.fit(exercised_model, datamodule=dm)
    exercised_trainer.validate(
        model=exercised_model, datamodule=dm, verbose=False)
    exercised_trainer.test(model=exercised_model, datamodule=dm, verbose=False)
    predict_loader = torch.utils.data.DataLoader(
        RandomDataset(num_features, 16))
    exercised_trainer.predict(
        model=exercised_model, dataloaders=[predict_loader])

    # Reference model: same training, but validation is switched off.
    reference_model = copy.deepcopy(base_model)
    reference_trainer = Trainer(
        callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=4,
    )
    reference_trainer.fit(reference_model, datamodule=dm)

    reference_state = reference_model.state_dict()
    for name, tensor in exercised_model.state_dict().items():
        assert torch.allclose(tensor, reference_state[name])
def test_quantization_exceptions(tmpdir):
    """Check that invalid callback arguments and unknown fuse layers raise.

    The first part covers constructor arguments that must be rejected with a
    ``MisconfigurationException``; the second part covers ``modules_to_fuse``
    entries that are only rejected once fitting starts.
    """
    # (constructor kwargs, expected error-message pattern)
    bad_init_cases = (
        (dict(qconfig=["abc"]), "Unsupported qconfig"),
        (dict(observer_type="abc"), "Unsupported observer type"),
        (dict(collect_quantization="abc"), "Unsupported `collect_quantization`"),
        (dict(collect_quantization=1.2), "Unsupported `collect_quantization`"),
    )
    for init_kwargs, message in bad_init_cases:
        with pytest.raises(MisconfigurationException, match=message):
            QuantizationAwareTraining(**init_kwargs)

    # Fuse pairs whose second member uses a "NONE-" name; fitting is expected to fail.
    bogus_pairs = [
        (f"layers.mlp_{idx}", f"layers.NONE-mlp_{idx}a") for idx in range(3)
    ]
    callback = QuantizationAwareTraining(modules_to_fuse=bogus_pairs)
    trainer = Trainer(callbacks=[callback], default_root_dir=tmpdir, max_epochs=1)
    with pytest.raises(
        MisconfigurationException,
        match="one or more of them is not your model attributes",
    ):
        trainer.fit(RegressionModel(), datamodule=RegressDataModule())
def test_quantization(tmpdir, observe: str, fuse: bool):
    """Parity test between a plainly trained model and a QAT-trained copy.

    Both models start from the same weights. After training, the quantized
    model must be smaller than 65% of the baseline size and its mean relative
    error on the test loader must be within 0.45 of the baseline's.
    """
    seed_everything(42)
    dm = RegressDataModule()
    gpu_count = 1 if torch.cuda.is_available() else None
    shared_args = dict(default_root_dir=tmpdir, max_epochs=10, gpus=gpu_count)

    def mean_test_error(net):
        """Average ``mean_relative_error`` of ``net`` over the test loader."""
        errors = [
            mean_relative_error(net(x), y) for x, y in dm.test_dataloader()
        ]
        return torch.mean(torch.tensor(errors))

    baseline = RegressionModel()
    quantized = copy.deepcopy(baseline)

    # Train the baseline without any callback.
    Trainer(**shared_args).fit(baseline, datamodule=dm)
    baseline_size = baseline.model_size
    baseline_score = mean_test_error(baseline)

    # Train the copy with quantization-aware training, optionally fusing layers.
    if fuse:
        layers_to_fuse = [(f'layer_{i}', f'layer_{i}a') for i in range(3)]
    else:
        layers_to_fuse = None
    callback = QuantizationAwareTraining(
        observer_type=observe, modules_to_fuse=layers_to_fuse)
    Trainer(callbacks=[callback], **shared_args).fit(quantized, datamodule=dm)

    # NOTE(review): this compares the counter with itself, so it can only fail
    # if the attribute is missing -- confirm what value was meant to be checked.
    recorded_calls = callback._forward_calls
    assert recorded_calls == callback._forward_calls

    quantized_size = quantized.model_size
    quantized_score = mean_test_error(quantized)

    # The quantized model must be smaller than the baseline.
    assert quantized_size / baseline_size < 0.65
    # The test score must stay close to that of plain training.
    assert torch.allclose(baseline_score, quantized_score, atol=0.45)
Example #6
0
def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
    """Parity test between a plainly trained model and a QAT-trained copy.

    Steps, in order:

    1. Train a baseline and a QAT copy from the same initial weights and
       compare their MAPE on the test loader.
    2. If ``convert`` is false, fit the QAT copy for one more epoch with a
       default callback and convert it manually.
    3. Check the quantized model is smaller than 65% of the baseline size.
    4. Reload the best checkpoint (non-strict) and compare its score.
    5. Resume a fresh model from that checkpoint with the QAT callback and
       compare its score.
    """
    seed_everything(42)
    dm = RegressDataModule()
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=7,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
    )

    def _mean_mape(net):
        """Average ``mape`` of ``net`` over the test loader."""
        return torch.mean(
            torch.tensor([mape(net(x), y) for x, y in dm.test_dataloader()]))

    baseline = RegressionModel()
    qat_model = copy.deepcopy(baseline)

    # Baseline: plain training without any callback.
    Trainer(**trainer_args).fit(baseline, datamodule=dm)
    baseline_size = get_model_size_mb(baseline)
    baseline_score = _mean_mape(baseline)

    # QAT run, optionally with fused layers.
    if fuse:
        layers_to_fuse = [(f"layer_{i}", f"layer_{i}a") for i in range(3)]
    else:
        layers_to_fuse = None
    callback = QuantizationAwareTraining(
        observer_type=observe,
        modules_to_fuse=layers_to_fuse,
        quantize_on_fit_end=convert,
        observer_enabled_stages=("train", "validate"),
    )
    qat_trainer = Trainer(callbacks=[callback], **trainer_args)
    qat_trainer.fit(qat_model, datamodule=dm)

    # NOTE(review): this compares the counter with itself, so it can only fail
    # if the attribute is missing -- confirm what value was meant to be checked.
    recorded_calls = callback._forward_calls
    assert recorded_calls == callback._forward_calls
    qat_score = _mean_mape(qat_model)
    # The test score must stay close to that of plain training.
    assert torch.allclose(baseline_score, qat_score, atol=0.45)

    # Remember checkpoint and epoch before further trainers are created.
    best_ckpt = qat_trainer.checkpoint_callback.best_model_path
    finished_epoch = qat_trainer.current_epoch

    trainer_args.update(max_epochs=1, enable_checkpointing=False)
    if not convert:
        # The callback did not convert at fit end: fit once more with a
        # default callback, then convert the model by hand.
        extra_trainer = Trainer(
            callbacks=[QuantizationAwareTraining()], **trainer_args)
        extra_trainer.fit(qat_model, datamodule=dm)
        qat_model.eval()
        torch.quantization.convert(qat_model, inplace=True)

    # The quantized model must be smaller than the baseline.
    qat_size = get_model_size_mb(qat_model)
    assert qat_size / baseline_size < 0.65

    # todo: make it work also with strict loading
    reloaded = RegressionModel.load_from_checkpoint(best_ckpt, strict=False)
    reloaded_score = _mean_mape(reloaded)
    assert torch.allclose(baseline_score, reloaded_score, atol=0.47)

    # Resume a fresh model from the checkpoint with the QAT callback attached.
    trainer_args.update(max_epochs=finished_epoch + 1)
    resumed = RegressionModel()
    resume_trainer = Trainer(
        callbacks=[QuantizationAwareTraining()], **trainer_args)
    resume_trainer.fit(resumed, datamodule=dm, ckpt_path=best_ckpt)
    resumed_score = _mean_mape(resumed)
    # The test score must stay close to that of plain training.
    assert torch.allclose(baseline_score, resumed_score, atol=0.45)