def test_weight_only_activation_only_fakequant(self):
     for qengine in ["fbgemm", "qnnpack"]:
         if qengine not in torch.backends.quantized.supported_engines:
             continue
         if qengine == 'qnnpack':
             if IS_PPC or TEST_WITH_UBSAN:
                 continue
         with override_quantized_engine(qengine):
             torch.manual_seed(67)
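             # Larger batch for observer calibration; smaller batch for the numerics comparison.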
             calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
             eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
             # Use a list (not a set) so iteration order is deterministic and lines up
             # with the per-qconfig SQNR targets below.
             qconfigs = [torch.quantization.default_weight_only_quant_qconfig,
                         torch.quantization.default_activation_only_quant_qconfig]
             SQNRTarget = [35, 45]
             for idx, qconfig in enumerate(qconfigs):
                 my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
                 my_model.eval()
                 out_ref = my_model(eval_data)
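                 # Wrap the float model with quant/dequant stubs and prepare it for QAT under this qconfig.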
                 fq_model = torch.quantization.QuantWrapper(my_model)
                 fq_model.train()
                 fq_model.qconfig = qconfig
                 torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
                 torch.quantization.prepare_qat(fq_model, inplace=True)
                 fq_model.eval()
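                 # Calibrate: run data through the observers with fake quant disabled and BN stats frozen.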
                 fq_model.apply(torch.quantization.disable_fake_quant)
                 fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                 fq_model(calib_data)
                 fq_model.apply(torch.quantization.enable_fake_quant)
                 fq_model.apply(torch.quantization.disable_observer)
                 out_fq = fq_model(eval_data)
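                 # Signal-to-quantization-noise ratio (dB) between the float reference and the fake-quantized output.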
                 SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
                 self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')

def test_fake_quant_true_quant_compare(self):
     for qengine in ["fbgemm", "qnnpack"]:
         if qengine not in torch.backends.quantized.supported_engines:
             continue
         if qengine == 'qnnpack':
             if IS_PPC or TEST_WITH_UBSAN:
                 continue
         with override_quantized_engine(qengine):
             torch.manual_seed(67)
             my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
             calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
             eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
             my_model.eval()
             out_ref = my_model(eval_data)
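             # Wrap with quant/dequant stubs, fuse conv/bn/relu, and prepare the model for QAT.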
             fq_model = torch.quantization.QuantWrapper(my_model)
             fq_model.train()
             fq_model.qconfig = torch.quantization.default_qat_qconfig
             torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
             torch.quantization.prepare_qat(fq_model, inplace=True)
             fq_model.eval()
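             # Calibrate observers with fake quant disabled and BN stats frozen,
             # then evaluate with fake quant enabled and observers frozen.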
             fq_model.apply(torch.quantization.disable_fake_quant)
             fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
             fq_model(calib_data)
             fq_model.apply(torch.quantization.enable_fake_quant)
             fq_model.apply(torch.quantization.disable_observer)
             out_fq = fq_model(eval_data)
             SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
             # Quantized model output should be close to floating point model output numerically
             # Setting target SQNR to be 35 dB
             self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
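             # Convert the fake-quantized model to a true quantized model and compare the outputs.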
             torch.quantization.convert(fq_model, inplace=True)
             out_q = fq_model(eval_data)
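             # The small epsilon guards against division by zero when fake-quant and true-quant outputs match exactly.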
             SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
             self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')