def test_dequant_stub():
    """Check that the float, QAT and quantized DequantStub variants agree.

    Four versions of the stub are compared on one fake-quantized random
    input:

    * the float ``Float.DequantStub`` (the reference),
    * a QAT stub converted from it via ``from_float_module`` with fake-quant
      and observers disabled,
    * a freshly built QAT stub configured with ``min_max_fakequant_qconfig``
      and initialised by ``init_qat_net``,
    * the quantized ``Q.DequantStub`` converted from that QAT stub, which is
      fed ``quant(x, inp_scale)`` instead of ``x``.
    """
    normal_net = Float.DequantStub()
    normal_net.eval()

    qat_from_float = QAT.DequantStub.from_float_module(normal_net)
    qat_from_float.eval()
    disable_fake_quant(qat_from_float)
    disable_observer(qat_from_float)

    qat_net = QAT.DequantStub()
    qat_net.eval()
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)
    # Disable observers only after the qconfig has been propagated.
    # propagate_qconfig presumably creates new observer instances from the
    # qconfig, so disabling before it would leave those new observers enabled.
    # NOTE(review): confirm against QATModule.set_qconfig.
    disable_observer(qat_net)
    init_qat_net(qat_net)

    q_net = Q.DequantStub.from_qat_module(qat_net)
    q_net.eval()

    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    # Snap the input onto the inp_scale grid and record that scale on the
    # tensor (presumably consumed by the quantized path -- verify).
    x = fake_quant(x, inp_scale)
    x.q_dict["scale"] = inp_scale

    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
    # No extra fake_quant is applied to this reference output; x itself was
    # already fake-quantized above.
    fake_quant_normal = normal_net(x)
    qat = qat_net(x)
    q = q_net(quant(x, inp_scale)).numpy()

    # Convert every tensor explicitly so all three comparisons are done on
    # plain numpy arrays.
    np.testing.assert_allclose(qat_without_fakequant.numpy(), normal.numpy())
    np.testing.assert_allclose(qat.numpy(), fake_quant_normal.numpy())
    np.testing.assert_allclose(q, fake_quant_normal.numpy())
def __init__(self):
    """Build the submodules of a two-linear-layer QAT network.

    Registers a ``QAT.QuantStub`` as ``quant``, a ``Float.Sequential`` of two
    ``QAT.Linear(3, 3)`` layers as ``linear``, and a ``QAT.DequantStub`` as
    ``dequant``. Each linear layer's bias is then overwritten in place with
    values from ``np.random.rand(3)``.
    """
    super().__init__()
    self.quant = QAT.QuantStub()
    fc_first = QAT.Linear(3, 3)
    fc_second = QAT.Linear(3, 3)
    self.linear = Float.Sequential(fc_first, fc_second)
    self.dequant = QAT.DequantStub()
    # Layer 0 is filled before layer 1 so the random draws happen in a fixed order.
    for layer_idx in (0, 1):
        self.linear[layer_idx].bias[...] = Parameter(np.random.rand(3))
def __init__(self):
    """Build the submodules of a single-linear-layer QAT network.

    Registers a ``QAT.QuantStub`` as ``quant``, one ``QAT.Linear(3, 3)`` as
    ``linear`` and a ``QAT.DequantStub`` as ``dequant``. The linear layer's
    bias is then replaced via ``set_value`` with values from
    ``np.random.rand(3)``.
    """
    super().__init__()
    self.quant = QAT.QuantStub()
    self.linear = QAT.Linear(3, 3)
    self.dequant = QAT.DequantStub()
    # Draw the random bias only after all submodules exist.
    random_bias = np.random.rand(3)
    self.linear.bias.set_value(random_bias)