def test_qat_convbn2d():
    """Compare a float ``ConvBn2d`` with its ``quantize_qat`` copy (fake quant off).

    For every combination of ``groups`` in ``[1, 4]`` and ``bias`` in
    ``[True, False]`` the outputs of both modules are compared in train and
    eval mode; after the train-mode pass the ``bn.running_mean`` and
    ``bn.running_var`` attributes are compared as well.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    for groups, bias in product([1, 4], [True, False]):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)

        data = np.random.randn(4, in_channels, 32, 32).astype(np.float32)
        inputs = tensor(data)

        # Train mode: the float module runs first, then the converted copy.
        expected = float_module(inputs)
        actual = qat_module(inputs)
        assertTensorClose(expected, actual, max_err=5e-6)
        assertTensorClose(
            float_module.bn.running_mean,
            qat_module.bn.running_mean,
            max_err=5e-8,
        )
        assertTensorClose(
            float_module.bn.running_var,
            qat_module.bn.running_var,
            max_err=5e-7,
        )

        # Eval mode: same comparison on the outputs only.
        float_module.eval()
        expected = float_module(inputs)
        qat_module.eval()
        actual = qat_module(inputs)
        assertTensorClose(expected, actual, max_err=5e-6)
def test_disable_fake_quant():
    """Check that ``disable_fake_quant`` restores the float network's output.

    The output of the plain network must equal the output after
    ``quantize_qat`` followed by ``disable_fake_quant``, and must differ from
    the output produced right after ``quantize_qat``.
    """

    class Net(Float.Module):
        """QuantStub -> Linear(3, 3) -> DequantStub with a random bias."""

        def __init__(self):
            super().__init__()
            self.quant = Float.QuantStub()
            self.linear = Float.Linear(3, 3)
            self.dequant = Float.DequantStub()
            self.linear.bias.set_value(np.random.rand(3))

        def forward(self, x):
            return self.dequant(self.linear(self.quant(x)))

    data = np.random.randint(1, 10, size=(3, 3)).astype(np.float32)
    x = tensor(data)

    net = Net()
    float_out = net(x).numpy()

    net = quantize_qat(net, min_max_fakequant_qconfig)
    fake_quant_out = net(x).numpy()

    disable_fake_quant(net)
    restored_out = net(x).numpy()

    np.testing.assert_allclose(float_out, restored_out)
    # The output taken before disabling fake quant is expected to differ.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(fake_quant_out, restored_out)
def test_quant_stub():
    """Compare the ``Float``, ``QAT`` and ``Q`` variants of ``QuantStub``.

    Three comparisons are made on one random ``(3, 3)`` input:
    the ``from_float_module`` copy (observer and fake quant disabled) against
    the float stub, the QAT stub against ``fake_quant_act`` of the float
    output, and the quantized stub (rescaled by ``act_scale``) against the
    same reference.
    """
    float_stub = Float.QuantStub()
    float_stub.eval()

    converted_stub = QAT.QuantStub.from_float_module(float_stub)
    converted_stub.eval()
    disable_observer(converted_stub)
    disable_fake_quant(converted_stub)

    qat_stub = QAT.QuantStub()
    qat_stub.eval()
    disable_observer(qat_stub)
    propagate_qconfig(qat_stub, min_max_fakequant_qconfig)
    init_qat_net(qat_stub)

    quantized_stub = Q.QuantStub.from_qat_module(qat_stub)
    quantized_stub.eval()

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))

    float_out = float_stub(x)
    converted_out = converted_stub(x)
    reference = fake_quant_act(float_stub(x), act_scale)
    qat_out = qat_stub(x)
    quantized_out = quantized_stub(x).numpy() * act_scale

    np.testing.assert_allclose(converted_out, float_out)
    np.testing.assert_allclose(qat_out, reference)
    np.testing.assert_allclose(quantized_out, reference.numpy())
def test_quantize_batchmatmul_activation():
    """Exercise ``BatchMatMulActivation`` through ``quantize_qat``, ``quantize`` and dump.

    For ``bias`` in ``(True, False)``:
    1. the float net and its ``quantize_qat`` copy (fake quant disabled) must
       agree in train and eval mode;
    2. with fake quant re-enabled, the QAT output must match the output of
       the ``quantize`` copy within ``atol=1e-6``;
    3. the traced and dumped quantized net, reloaded through
       ``cgtools.load_and_inference``, must match the quantized output.
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        """QuantStub -> BatchMatMulActivation -> expand_dims(-1) -> DequantStub."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            hidden = self.batch_mm(self.quant(inp))
            hidden = expand_dims(hidden, -1)
            return self.dequant(hidden)

    data = np.random.randn(batch, in_features, out_features).astype(np.float32)
    inputs = tensor(data)

    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)

        # Train mode, fake quant disabled.
        float_out = net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Eval mode, fake quant disabled.
        net.eval()
        float_out = net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Fake quant enabled: QAT output versus the converted quantized net.
        enable_fake_quant(qat_net)
        qat_out = qat_net(inputs)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantized_out = qnet(inputs)
        np.testing.assert_allclose(
            qat_out.numpy(), quantized_out.numpy(), atol=1e-6
        )

        @jit.trace(capture_as_const=True)
        def f(x):
            qnet.eval()
            return qnet(x)

        # Run the traced function once before dumping it.
        f(inputs)
        buffer = io.BytesIO()
        f.dump(buffer, enable_nchw4=True)
        buffer.seek(0)
        dumped_out = cgtools.load_and_inference(buffer, [inputs])[0]
        np.testing.assert_allclose(quantized_out.numpy(), dumped_out, atol=1e-6)
def test_enable_and_disable_fake_quant():
    """Check the ``enabled`` flag of each fake-quant member after toggling.

    After ``disable_fake_quant`` the three inspected members must report
    ``enabled == False``; after ``enable_fake_quant`` they must report
    ``enabled == True``.
    """
    net = init_qat_net()

    toggles = ((disable_fake_quant, False), (enable_fake_quant, True))
    for toggle, expected in toggles:
        toggle(net)
        assert net.quant.act_fake_quant.enabled == expected
        assert net.linear.weight_fake_quant.enabled == expected
        assert net.linear.act_fake_quant.enabled == expected
def test_qat_conv(padding, padding_mode):
    """Compare a float conv net with its ``quantize_qat`` copy (fake quant off).

    Args:
        padding: forwarded to ``Conv2d`` as ``padding``.
        padding_mode: forwarded to ``Conv2d`` as ``padding_mode``.

    For every combination of ``groups`` in ``[1, 4]`` and ``bias`` in
    ``[True, False]`` the outputs must match in train and eval mode.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        """QuantStub -> Conv2d -> ConvRelu2d -> DequantStub."""

        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                groups=groups,
                bias=bias,
                padding=padding,
                padding_mode=padding_mode,
            )
            # Maps the channel count back from out_channels to in_channels.
            self.conv_relu = ConvRelu2d(
                out_channels, in_channels, kernel_size, groups=groups, bias=bias
            )

        def forward(self, inp):
            hidden = self.conv(self.quant(inp))
            hidden = self.conv_relu(hidden)
            return self.dequant(hidden)

    data = np.random.randn(4, in_channels, 32, 32).astype(np.float32)
    inputs = tensor(data)

    for groups, bias in product([1, 4], [True, False]):
        net = TestNet(groups, bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)

        # Train mode.
        float_out = net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Eval mode.
        net.eval()
        float_out = net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
def test_elemwise(kind):
    """Compare the ``Float``, ``QAT`` and ``Q`` variants of ``Elemwise``.

    Args:
        kind: elemwise mode passed to the module constructors. The kinds
            ``"add"``, ``"mul"`` and ``"fuse_add_relu"`` are called with two
            operands; every other kind is called with one.
    """
    float_net = Float.Elemwise(kind)
    float_net.eval()

    converted_net = QAT.Elemwise.from_float_module(float_net)
    converted_net.eval()
    disable_observer(converted_net)
    disable_fake_quant(converted_net)

    qat_net = QAT.Elemwise(kind)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    q_net = Q.Elemwise.from_qat_module(qat_net)
    q_net.eval()

    def _fake_quantized_operand():
        """Build one random (3, 3) operand and the scale attached to it."""
        scale = np.float32(np.random.rand() + 1)
        operand = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
        operand = fake_quant_act(operand, scale)
        operand.qparams.scale = scale
        return operand, scale

    x1, x1_scale = _fake_quantized_operand()
    x2, x2_scale = _fake_quantized_operand()
    x1_int8 = quant(x1, x1_scale)
    x2_int8 = quant(x2, x2_scale)

    # Test correctness of `Float`, `QAT` and `Quantized`; the kind decides
    # whether the modules are called with one operand or two.
    takes_two = kind in ("add", "mul", "fuse_add_relu")
    float_args = (x1, x2) if takes_two else (x1,)
    int8_args = (x1_int8, x2_int8) if takes_two else (x1_int8,)

    float_out = float_net(*float_args)
    converted_out = converted_net(*float_args)
    reference = fake_quant_act(float_net(*float_args), act_scale)
    qat_out = qat_net(*float_args)
    quantized_out = q_net(*int8_args).numpy() * act_scale

    np.testing.assert_allclose(converted_out, float_out)
    np.testing.assert_allclose(qat_out, reference)
    np.testing.assert_allclose(quantized_out, reference.numpy())
def test_conv(module):
    """Compare the ``Float``, ``QAT`` and ``Q`` variants of a conv module.

    Args:
        module: attribute name looked up on ``Float``, ``QAT`` and ``Q``.
            For ``"ConvBn2d"`` and ``"ConvBnRelu2d"`` the weight and bias
            are written through the ``.conv`` submodule; otherwise they are
            written on the module itself.
    """
    ctor_args = (3, 3, 3, 1, 1, 1)

    normal_net = getattr(Float, module)(*ctor_args, bias=True)
    normal_net.eval()

    qat_net = getattr(QAT, module)(*ctor_args, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    inp_scale = gen_inp_scale()
    x = fake_quant_act(x, inp_scale)
    x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
    x_int8 = quant(x, inp_scale)

    weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
    bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")

    # Pick the objects that own the weight and bias parameters.
    has_inner_conv = module in ("ConvBn2d", "ConvBnRelu2d")
    float_target = normal_net.conv if has_inner_conv else normal_net
    qat_target = qat_net.conv if has_inner_conv else qat_net

    # The float net gets fake-quantized parameters, the QAT net the raw ones.
    float_target.weight[...] = fake_quant_weight(weight, weight_scale)
    float_target.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
    qat_target.weight[...] = Parameter(weight)
    qat_target.bias[...] = Parameter(bias)

    qat_from_float = getattr(QAT, module).from_float_module(normal_net)
    qat_from_float.eval()
    disable_observer(qat_from_float)
    disable_fake_quant(qat_from_float)

    q_net = getattr(Q, module).from_qat_module(qat_net)
    q_net.eval()

    float_out = normal_net(x)
    converted_out = qat_from_float(x)
    reference = fake_quant_act(normal_net(x), act_scale)
    qat_out = qat_net(x)
    quantized_out = q_net(x_int8).numpy() * act_scale

    np.testing.assert_allclose(converted_out, float_out, atol=1e-5)
    np.testing.assert_allclose(qat_out, reference, atol=act_scale)
    np.testing.assert_allclose(quantized_out, reference.numpy(), atol=act_scale)
def test_conv(module):
    """Compare the ``Float``, ``QAT`` and ``Q`` variants of a conv module.

    Args:
        module: attribute name looked up on ``Float``, ``QAT`` and ``Q``.
            For ``"ConvBn2d"`` and ``"ConvBnRelu2d"`` the weight and bias
            are set through the ``.conv`` submodule; otherwise they are set
            on the module itself.
    """
    ctor_args = (3, 3, 3, 1, 1, 1)

    normal_net = getattr(Float, module)(*ctor_args, bias=True)
    normal_net.eval()

    qat_net = getattr(QAT, module)(*ctor_args, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale
    x_int8 = quant(x, inp_scale)

    weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
    bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")

    # Pick the objects that own the weight and bias parameters.
    has_inner_conv = module in ("ConvBn2d", "ConvBnRelu2d")
    float_target = normal_net.conv if has_inner_conv else normal_net
    qat_target = qat_net.conv if has_inner_conv else qat_net

    # The float net gets fake-quantized parameters, the QAT net the raw ones.
    float_target.weight.set_value(fake_quant(weight, weight_scale))
    float_target.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
    qat_target.weight.set_value(weight)
    qat_target.bias.set_value(bias)

    qat_from_float = getattr(QAT, module).from_float_module(normal_net)
    qat_from_float.eval()
    disable_observer(qat_from_float)
    disable_fake_quant(qat_from_float)

    q_net = getattr(Q, module).from_qat_module(qat_net)
    q_net.eval()

    float_out = normal_net(x)
    converted_out = qat_from_float(x)
    reference = fake_quant(normal_net(x), act_scale)
    qat_out = qat_net(x)
    quantized_out = q_net(x_int8).numpy() * act_scale

    np.testing.assert_allclose(converted_out, float_out, atol=1e-6)
    np.testing.assert_allclose(qat_out, reference)
    np.testing.assert_allclose(quantized_out, reference.numpy())
def test_enable_and_disable_all():
    """Toggle fake quant on a ``quantize_qat`` network and compare outputs.

    Outputs are captured at four points: the plain ``Net``, after
    ``quantize_qat`` plus ``init_observer``, after ``disable_fake_quant`` and
    after ``enable_fake_quant``. The first must equal the third, the second
    must equal the fourth, and the second must differ from the third.
    """
    data = np.random.randint(1, 10, size=(3, 3)).astype(np.float32)
    x = tensor(data)

    net = Net()
    float_out = net(x).numpy()

    net = quantize_qat(net, min_max_fakequant_qconfig)
    init_observer(net, x)
    fake_quant_out = net(x).numpy()

    disable_fake_quant(net)
    disabled_out = net(x).numpy()

    enable_fake_quant(net)
    reenabled_out = net(x).numpy()

    np.testing.assert_allclose(float_out, disabled_out)
    np.testing.assert_allclose(fake_quant_out, reenabled_out)
    # Outputs with fake quant on and off are expected to differ.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(fake_quant_out, disabled_out)
def test_linear():
    """Compare the ``Float``, ``QAT`` and ``Q`` variants of ``Linear(3, 3)``.

    The float module receives fake-quantized weight and bias, while the QAT
    module receives the raw values. The ``from_float_module`` copy must match
    the float output; the QAT output and the rescaled quantized output must
    match ``fake_quant_act`` of the float output.
    """
    float_net = Float.Linear(3, 3, bias=True)
    float_net.eval()

    qat_net = QAT.Linear(3, 3, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    inp_scale = gen_inp_scale()
    x = fake_quant_act(x, inp_scale)
    x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
    x_int8 = quant(x, inp_scale)

    weight = np.random.normal(size=(3, 3)).astype("float32")
    bias = np.random.normal(size=(3,)).astype("float32")

    # The bias is fake-quantized with the product of input and weight scales.
    float_net.weight[...] = fake_quant_weight(weight, weight_scale)
    float_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
    qat_net.weight[...] = Parameter(weight)
    qat_net.bias[...] = Parameter(bias)

    converted_net = QAT.Linear.from_float_module(float_net)
    converted_net.eval()
    disable_fake_quant(converted_net)
    disable_observer(converted_net)

    q_net = Q.Linear.from_qat_module(qat_net)
    q_net.eval()

    float_out = float_net(x)
    converted_out = converted_net(x)
    reference = fake_quant_act(float_net(x), act_scale)
    qat_out = qat_net(x)
    quantized_out = q_net(x_int8).numpy() * act_scale

    np.testing.assert_allclose(converted_out, float_out)
    np.testing.assert_allclose(qat_out, reference.numpy())
    np.testing.assert_allclose(quantized_out, reference.numpy())
def test_linear():
    """Compare the ``Float``, ``QAT`` and ``Q`` variants of ``Linear(3, 3)``.

    The float module receives fake-quantized weight and bias through
    ``set_value``, while the QAT module receives the raw values. The
    ``from_float_module`` copy must match the float output; the QAT output
    and the rescaled quantized output must match ``fake_quant`` of the float
    output.
    """
    float_net = Float.Linear(3, 3, bias=True)
    float_net.eval()

    qat_net = QAT.Linear(3, 3, bias=True)
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale
    x_int8 = quant(x, inp_scale)

    weight = np.random.normal(size=(3, 3)).astype("float32")
    bias = np.random.normal(size=(3,)).astype("float32")

    # The bias is fake-quantized with the product of input and weight scales.
    float_net.weight.set_value(fake_quant(weight, weight_scale))
    float_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
    qat_net.weight.set_value(weight)
    qat_net.bias.set_value(bias)

    converted_net = QAT.Linear.from_float_module(float_net)
    converted_net.eval()
    disable_fake_quant(converted_net)
    disable_observer(converted_net)

    q_net = Q.Linear.from_qat_module(qat_net)
    q_net.eval()

    float_out = float_net(x)
    converted_out = converted_net(x)
    reference = fake_quant(float_net(x), act_scale)
    qat_out = qat_net(x)
    quantized_out = q_net(x_int8).numpy() * act_scale

    np.testing.assert_allclose(converted_out, float_out)
    np.testing.assert_allclose(qat_out, reference)
    np.testing.assert_allclose(quantized_out, reference.numpy())
def test_qat_convbn2d():
    """Compare a float ``ConvBn2d`` with its ``quantize_qat`` copy (fake quant off).

    For every combination of ``groups`` in ``[1, 4]`` and ``bias`` in
    ``[True, False]`` the outputs of both modules are compared in train and
    eval mode; after the train-mode pass the ``bn.running_mean`` and
    ``bn.running_var`` attributes are compared as well.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    for groups, bias in product([1, 4], [True, False]):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)

        data = np.random.randn(4, in_channels, 32, 32).astype(np.float32)
        inputs = tensor(data)

        # Train mode: the float module runs first, then the converted copy.
        expected = float_module(inputs)
        actual = qat_module(inputs)
        np.testing.assert_allclose(expected.numpy(), actual.numpy(), atol=5e-6)

        float_bn, qat_bn = float_module.bn, qat_module.bn
        np.testing.assert_allclose(
            float_bn.running_mean.numpy(),
            qat_bn.running_mean.numpy(),
            atol=5e-8,
        )
        np.testing.assert_allclose(
            float_bn.running_var.numpy(),
            qat_bn.running_var.numpy(),
            atol=5e-7,
        )

        # Eval mode: same comparison on the outputs only.
        float_module.eval()
        expected = float_module(inputs)
        qat_module.eval()
        actual = qat_module(inputs)
        np.testing.assert_allclose(expected.numpy(), actual.numpy(), atol=5e-6)
def test_concat():
    """Compare the ``Float``, ``QAT`` and ``Q`` variants of ``Concat``.

    Three random ``(3, 3)`` inputs are built, each with its own scale from
    ``gen_inp_scale``. The ``from_float_module`` copy must match the float
    output; the QAT output and the rescaled quantized output must match
    ``fake_quant_act`` of the float output.
    """
    float_net = Float.Concat()
    float_net.eval()

    qat_net = QAT.Concat()
    qat_net.eval()
    disable_observer(qat_net)
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    init_qat_net(qat_net)

    inps = []
    inps_int8 = []
    for _ in range(3):
        inp_scale = gen_inp_scale()
        inp = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
        inp = fake_quant_act(inp, inp_scale)
        inp.qparams.update(
            create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)
        )
        inps.append(inp)
        inps_int8.append(quant(inp, inp_scale))

    converted_net = QAT.Concat.from_float_module(float_net)
    converted_net.eval()
    disable_fake_quant(converted_net)
    disable_observer(converted_net)

    q_net = Q.Concat.from_qat_module(qat_net)
    q_net.eval()

    float_out = float_net(inps)
    converted_out = converted_net(inps)
    reference = fake_quant_act(float_net(inps), act_scale)
    qat_out = qat_net(inps)
    quantized_out = q_net(inps_int8).numpy() * act_scale

    np.testing.assert_allclose(converted_out, float_out)
    np.testing.assert_allclose(qat_out, reference.numpy())
    np.testing.assert_allclose(quantized_out, reference.numpy())
def test_qat_batchmatmul_activation():
    """Compare a float ``BatchMatMulActivation`` net with its ``quantize_qat`` copy.

    For ``bias`` in ``(True, False)`` the float net and the converted net
    (fake quant disabled) must produce matching outputs in train and eval
    mode.
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        """QuantStub -> BatchMatMulActivation -> DequantStub."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.batch_mm(self.quant(inp)))

    data = np.random.randn(batch, in_features, out_features).astype(np.float32)
    inputs = tensor(data)

    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)

        # Train mode.
        float_out = net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Eval mode.
        net.eval()
        float_out = net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
def test_qat_conv_transpose2d():
    """Compare a float ``ConvTranspose2d`` net with its ``quantize_qat`` copy.

    For ``bias`` in ``[True, False]`` the float net and the converted net
    (fake quant disabled) must produce matching outputs in train and eval
    mode.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        """QuantStub -> ConvTranspose2d -> DequantStub."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = ConvTranspose2d(
                in_channels, out_channels, kernel_size, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.conv(self.quant(inp)))

    data = np.random.randn(4, in_channels, 32, 32).astype(np.float32)
    inputs = tensor(data)

    for bias in [True, False]:
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)

        # Train mode.
        float_out = net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Eval mode.
        net.eval()
        float_out = net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
def init_observer(module, data):
    """Run one forward pass of ``module`` with observers on and fake quant off.

    Args:
        module: network passed to the observer / fake-quant switches and
            then called on ``data``.
        data: input forwarded to ``module``.

    On return the observer has been disabled again and fake quant enabled.
    """
    # Switch into observing mode before the forward pass.
    for switch in (enable_observer, disable_fake_quant):
        switch(module)

    module(data)

    # Switch back: observers off, fake quant on.
    for switch in (disable_observer, enable_fake_quant):
        switch(module)