Example #1
import copy

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import QuantizationAwareTraining
from pytorch_lightning.utilities.memory import get_model_size_mb
from torchmetrics.functional import mean_absolute_percentage_error as mape

# RegressionModel and RegressDataModule are simple regression helpers from the
# Lightning test suite (import paths assumed):
from tests.helpers.datamodules import RegressDataModule
from tests.helpers.simple_models import RegressionModel


def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
    """Parity test for quant model."""
    seed_everything(42)
    dm = RegressDataModule()
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer_args = dict(default_root_dir=tmpdir, max_epochs=7, accelerator=accelerator, devices=1)
    model = RegressionModel()
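    # keep an identical copy so the QAT run starts from the same initial weights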
    qmodel = copy.deepcopy(model)

    trainer = Trainer(**trainer_args)
    trainer.fit(model, datamodule=dm)
    org_size = get_model_size_mb(model)
    org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()]))

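    # optionally fuse adjacent module pairs so each pair is quantized as a single unit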
    fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None
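    # the callback attaches observers/fake-quant modules during fit; with
    # quantize_on_fit_end=True the model is converted to int8 once fit() ends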
    qcb = QuantizationAwareTraining(
        observer_type=observe,
        modules_to_fuse=fusing_layers,
        quantize_on_fit_end=convert,
        observer_enabled_stages=("train", "validate"),
    )
    trainer = Trainer(callbacks=[qcb], **trainer_args)
    trainer.fit(qmodel, datamodule=dm)

    # sanity check that the QAT-wrapped forward was actually exercised during fit
    quant_calls = qcb._forward_calls
    assert quant_calls > 0
    quant_score = torch.mean(torch.tensor([mape(qmodel(x), y) for x, y in dm.test_dataloader()]))
    # test that the test score is almost the same as with pure training
    assert torch.allclose(org_score, quant_score, atol=0.45)
    model_path = trainer.checkpoint_callback.best_model_path
    curr_epoch = trainer.current_epoch

    trainer_args.update(dict(max_epochs=1, enable_checkpointing=False))
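    # if the callback did not already convert the model on fit end, run one more
    # short fit with a fresh QAT callback and convert manually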
    if not convert:
        trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
        trainer.fit(qmodel, datamodule=dm)
        qmodel.eval()
        torch.quantization.convert(qmodel, inplace=True)

    quant_size = get_model_size_mb(qmodel)
    # test that the quantized model is smaller than the initial one
    # (int8 weights take roughly a quarter of the space of float32 weights)
    size_ratio = quant_size / org_size
    assert size_ratio < 0.65

    # todo: make it work also with strict loading
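    # strict=False: the QAT checkpoint carries extra quantization state
    # (observers, quant/dequant stubs) that a fresh fp32 model does not have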
    qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
    quant2_score = torch.mean(torch.tensor([mape(qmodel2(x), y) for x, y in dm.test_dataloader()]))
    assert torch.allclose(org_score, quant2_score, atol=0.45)

    # resume from the QAT checkpoint into a fresh model, again with the QAT callback
    trainer_args.update(max_epochs=curr_epoch + 1)
    qmodel2 = RegressionModel()
    trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
    trainer.fit(qmodel2, datamodule=dm, ckpt_path=model_path)
    quant2_score = torch.mean(torch.tensor([mape(qmodel2(x), y) for x, y in dm.test_dataloader()]))
    # test that the test score is almost the same as with pure training
    assert torch.allclose(org_score, quant2_score, atol=0.45)
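
The `observe`, `fuse`, and `convert` arguments imply this test is driven by
`pytest.mark.parametrize`. A minimal sketch of such a driver follows; "average"
and "histogram" are the observer types `QuantizationAwareTraining` accepts, but
the exact combinations here are illustrative assumptions, not necessarily the
project's own:

import pytest

# illustrative parameter grid (assumed, not taken from the project)
@pytest.mark.parametrize("observe", ["average", "histogram"])
@pytest.mark.parametrize("fuse", [True, False])
@pytest.mark.parametrize("convert", [True, False])
def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
    ...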
Example #2

import copy

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import QuantizationAwareTraining
# this variant targets an older Lightning API (gpus=, checkpoint_callback=);
# mean_relative_error lived in the old pytorch_lightning.metrics package (path assumed):
from pytorch_lightning.metrics.functional import mean_relative_error

# test helpers as in Example #1 (import paths assumed):
from tests.helpers.datamodules import RegressDataModule
from tests.helpers.simple_models import RegressionModel


def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
    """Parity test for quant model."""
    seed_everything(42)
    dm = RegressDataModule()
    trainer_args = dict(default_root_dir=tmpdir,
                        max_epochs=7,
                        gpus=int(torch.cuda.is_available()))
    model = RegressionModel()
    qmodel = copy.deepcopy(model)

    trainer = Trainer(**trainer_args)
    trainer.fit(model, datamodule=dm)
    org_size = model.model_size
    org_score = torch.mean(
        torch.tensor([
            mean_relative_error(model(x), y) for x, y in dm.test_dataloader()
        ]))

    fusing_layers = [(f"layer_{i}", f"layer_{i}a")
                     for i in range(3)] if fuse else None
    qcb = QuantizationAwareTraining(observer_type=observe,
                                    modules_to_fuse=fusing_layers,
                                    quantize_on_fit_end=convert)
    trainer = Trainer(callbacks=[qcb], **trainer_args)
    trainer.fit(qmodel, datamodule=dm)

    # sanity check that the QAT-wrapped forward was actually exercised during fit
    quant_calls = qcb._forward_calls
    assert quant_calls > 0
    quant_score = torch.mean(
        torch.tensor([
            mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()
        ]))
    # test that the test score is almost the same as with pure training
    assert torch.allclose(org_score, quant_score, atol=0.45)
    model_path = trainer.checkpoint_callback.best_model_path

    trainer_args.update(dict(max_epochs=1, checkpoint_callback=False))
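    # as in Example #1: convert manually when the callback did not do it on fit end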
    if not convert:
        trainer = Trainer(callbacks=[QuantizationAwareTraining()],
                          **trainer_args)
        trainer.fit(qmodel, datamodule=dm)
        qmodel.eval()
        torch.quantization.convert(qmodel, inplace=True)

    quant_size = qmodel.model_size
    # test that the quantized model is smaller than the initial one
    size_ratio = quant_size / org_size
    assert size_ratio < 0.65

    # todo: make it work also with strict loading
    qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
    quant2_score = torch.mean(
        torch.tensor([
            mean_relative_error(qmodel2(x), y)
            for x, y in dm.test_dataloader()
        ]))
    assert torch.allclose(org_score, quant2_score, atol=0.45)
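
Both examples compare serialized model sizes (`get_model_size_mb` in Example #1,
the `model_size` property in Example #2) to confirm the quantized model shrinks.
A minimal sketch of such a size helper, assuming only standard PyTorch (the
helper name is illustrative):

import io

import torch


def model_size_mb(model: torch.nn.Module) -> float:
    """Approximate serialized size of the model's state_dict in megabytes."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6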