def test_quantize_batchmatmul_activation():
    """Check BatchMatMulActivation through float, QAT, quantized and dumped nets.

    For both ``bias=True`` and ``bias=False`` this verifies that:

    * the QAT copy with fake quantization disabled matches the float net,
      first in train mode and then in eval mode;
    * the quantized net matches the QAT net with fake quantization enabled
      (``atol=1e-6``);
    * a traced graph dumped with ``enable_nchw4=True`` reproduces the
      quantized net's output (``atol=1e-6``).
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            # quant -> batched matmul -> append trailing axis -> dequant
            hidden = self.batch_mm(self.quant(inp))
            return self.dequant(expand_dims(hidden, -1))

    raw = np.random.randn(batch, in_features, out_features).astype(np.float32)
    inputs = tensor(raw)

    for bias in (True, False):
        float_net = TestNet(bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)

        # With fake quantization off, the QAT copy must reproduce the float
        # network exactly: first in training mode ...
        disable_fake_quant(qat_net)
        expected = float_net(inputs)
        actual = qat_net(inputs)
        np.testing.assert_allclose(expected.numpy(), actual.numpy())

        # ... and then in eval mode.
        float_net.eval()
        expected = float_net(inputs)
        qat_net.eval()
        actual = qat_net(inputs)
        np.testing.assert_allclose(expected.numpy(), actual.numpy())

        # With fake quantization back on, the QAT output is the reference
        # for the converted (quantized) network.
        enable_fake_quant(qat_net)
        qat_outputs = qat_net(inputs)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantize_outputs = qnet(inputs)
        np.testing.assert_allclose(
            qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
        )

        # Trace the quantized net, dump it with NCHW4 enabled, reload it and
        # check the dumped graph agrees with the in-process quantized net.
        @jit.trace(capture_as_const=True)
        def traced(x):
            qnet.eval()
            return qnet(x)

        traced(inputs)
        buffer = io.BytesIO()
        traced.dump(buffer, enable_nchw4=True)
        buffer.seek(0)
        dumped_outputs = cgtools.load_and_inference(buffer, [inputs])[0]
        np.testing.assert_allclose(
            quantize_outputs.numpy(), dumped_outputs, atol=1e-6
        )
def test_enable_and_disable_fake_quant():
    """Check that the fake-quant toggles reach every fake-quant node of the net.

    After ``disable_fake_quant`` the activation fake-quant of ``quant`` and the
    weight and activation fake-quants of ``linear`` must all be disabled;
    after ``enable_fake_quant`` they must all be enabled again.
    """
    net = init_qat_net()

    def fake_quants():
        # Re-read the attributes after every toggle rather than caching them,
        # so the check does not depend on the toggles mutating in place.
        return {
            "quant.act_fake_quant": net.quant.act_fake_quant,
            "linear.weight_fake_quant": net.linear.weight_fake_quant,
            "linear.act_fake_quant": net.linear.act_fake_quant,
        }

    disable_fake_quant(net)
    for name, fake_quant in fake_quants().items():
        assert not fake_quant.enabled, f"{name} still enabled after disable"

    enable_fake_quant(net)
    for name, fake_quant in fake_quants().items():
        assert fake_quant.enabled, f"{name} still disabled after enable"
def test_enable_and_disable_all():
    """Toggling fake quantization off and on restores the matching outputs.

    With fake quantization disabled, the QAT net must reproduce the output of
    the original (pre-QAT) ``Net``. Re-enabling it must reproduce the first
    fake-quantized output, which must differ from the disabled-state output.
    """
    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))

    net = Net()
    before_qat = net(x).numpy()

    net = quantize_qat(net, min_max_fakequant_qconfig)
    init_observer(net, x)
    fake_quant_on = net(x).numpy()

    disable_fake_quant(net)
    fake_quant_off = net(x).numpy()

    enable_fake_quant(net)
    fake_quant_on_again = net(x).numpy()

    np.testing.assert_allclose(before_qat, fake_quant_off)
    np.testing.assert_allclose(fake_quant_on, fake_quant_on_again)
    # Fake quantization must actually change the result; otherwise the two
    # equality checks above would hold trivially.
    # NOTE(review): inputs are unseeded random ints -- confirm this cannot
    # occasionally produce exactly representable values and flake.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(fake_quant_on, fake_quant_off)
def init_observer(module, data):
    """Run *module* once on *data* with observers on and fake quantization off.

    Before the forward pass, observers are enabled and fake quantization is
    disabled. Afterwards, observers are disabled and fake quantization is
    enabled again, regardless of its state before the call. The forward
    output is discarded. Presumably the point is to let the observers record
    statistics for *data* -- confirm against the observer implementation.
    """
    for switch in (enable_observer, disable_fake_quant):
        switch(module)

    module(data)

    for switch in (disable_observer, enable_fake_quant):
        switch(module)