import torch.nn.quantized as nnq


def patch_model_output_dequant(model):
    """Monkey-patch `model` so its (quantized) output is dequantized to float."""
    def patched_forward(self, input):
        out = self._original_forward(input)
        out = self.output_dequant(out)
        return out

    model.add_module('output_dequant', nnq.DeQuantize())
    model._original_forward = model.forward
    # Bind patched_forward to this instance only, see:
    # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance#comment66379065_2982
    model.forward = patched_forward.__get__(model)
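
A minimal usage sketch, assuming a toy model whose forward already returns a quantized tensor; TinyQuantModel is a hypothetical stand-in, not part of the original snippet:

import torch
import torch.nn as nn
import torch.nn.quantized as nnq

# Hypothetical toy model: its forward returns a quantized tensor,
# which is what patch_model_output_dequant expects to dequantize.
class TinyQuantModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = nnq.Quantize(scale=1.0, zero_point=0, dtype=torch.quint8)

    def forward(self, input):
        return self.quant(input)

model = TinyQuantModel()
patch_model_output_dequant(model)
out = model(torch.rand(2, 3))
print(out.dtype)  # torch.float32 -- the patch dequantized the output
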
Example #2
    def init(self, C, M, N, dtype, mode):
        # Benchmark setup: mode 'Q' measures Quantize, mode 'D' DeQuantize.
        assert mode in ('Q', 'D')
        self.input = torch.rand(C, M, N)
        self.dtype = dtype
        self.op = nnq.Quantize(scale=1.0, zero_point=0, dtype=dtype)
        self.set_module_name('QuantizePerTensor')

        if mode == 'D':
            # Pre-quantize the input so the op under test is DeQuantize.
            self.input = self.op(self.input)
            self.op = nnq.DeQuantize()
            self.set_module_name('DequantizePerTensor')
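
The same Q/D switching logic can be exercised outside the benchmark harness; a standalone sketch (the helper name make_op_and_input is an assumption for illustration):

import torch
import torch.nn.quantized as nnq

# Mirrors the init above: mode 'Q' returns a Quantize op with a float input,
# mode 'D' pre-quantizes the input and returns a DeQuantize op.
def make_op_and_input(C, M, N, dtype, mode):
    assert mode in ('Q', 'D')
    inp = torch.rand(C, M, N)
    op = nnq.Quantize(scale=1.0, zero_point=0, dtype=dtype)
    if mode == 'D':
        inp = op(inp)
        op = nnq.DeQuantize()
    return op, inp

op, inp = make_op_and_input(2, 4, 4, torch.quint8, 'D')
print(op(inp).dtype)  # torch.float32
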
Example #3
    def test_quant_dequant_api(self):
        r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
        scale, zero_point, dtype = 1.0, 2, torch.qint8
        # testing Quantize API
        qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)  # formerly torch.quantize_linear
        quant_m = nnq.Quantize(scale, zero_point, dtype)
        qr2 = quant_m(r)
        self.assertEqual(qr, qr2)
        # testing Dequantize API
        rqr = qr.dequantize()
        dequant_m = nnq.DeQuantize()
        rqr2 = dequant_m(qr2)
        self.assertEqual(rqr, rqr2)
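
The same round trip can be checked outside a TestCase with plain assertions; a minimal sketch, comparing quantized tensors through int_repr() rather than relying on quantized-tensor equality:

import torch
import torch.nn.quantized as nnq

r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
scale, zero_point, dtype = 1.0, 2, torch.qint8
qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
quant_m = nnq.Quantize(scale, zero_point, dtype)
# The functional and module paths should agree on the underlying int values.
assert torch.equal(qr.int_repr(), quant_m(r).int_repr())
# Both dequantize paths should recover the same float tensor.
assert torch.equal(qr.dequantize(), nnq.DeQuantize()(quant_m(r)))
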
Example #4
    def __init__(self, q_module, float_module, Logger):
        super(Shadow, self).__init__()
        self.orig_module = q_module        # quantized module under comparison
        self.shadow_module = float_module  # float reference ("shadow") module
        self.dequant = nnq.DeQuantize()    # converts quantized outputs to float
        self.logger = Logger()             # records shadow-module outputs
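
A construction sketch, assuming the surrounding Shadow class this __init__ is excerpted from is available; MeanLogger is a hypothetical logger stand-in, not part of the original snippet:

import torch
import torch.nn as nn
import torch.nn.quantized as nnq

# Hypothetical logger: any class instantiable with no arguments fits the
# Logger parameter above.
class MeanLogger(nn.Module):
    def __init__(self):
        super().__init__()
        self.stats = []

    def forward(self, x):
        self.stats.append(x.detach().float().mean().item())
        return x

# Pair a quantized module with its float counterpart and a logger class.
shadow = Shadow(nnq.Linear(4, 4), nn.Linear(4, 4), MeanLogger)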