def test_qconfig_dict(self):
    r"""Compare graph-mode quantization driven by a qconfig_dict against
    eager-mode quantization of the same model.

    In both modes only 'sub2.fc1' and 'fc3' are given a qconfig. All Linear
    weights are filled with 1.0, and the scripted model reuses the eager
    model's weight tensors. The final forward outputs of the two quantized
    models are then asserted equal.
    """
    data = [(torch.randn(10, 5, dtype=torch.float) * 20, 1)]
    x = data[0][0]

    # Eager mode
    qconfig = QConfig(activation=Observer, weight=WeightObserver)
    eager_module = AnnotatedNestedModel()
    eager_module.fc3.qconfig = qconfig
    eager_module.sub2.fc1.qconfig = qconfig
    # Assign deterministic weights. The annotated 'fc3' and 'sub2.fc1' expose
    # their inner layer through `.module`.
    eager_module.sub1.fc.weight.data.fill_(1.0)
    eager_module.sub2.fc1.module.weight.data.fill_(1.0)
    eager_module.sub2.fc2.weight.data.fill_(1.0)
    eager_module.fc3.module.weight.data.fill_(1.0)

    script_module = torch.jit.script(NestedModel())
    # Copy the eager module's weights INTO script_module so both models start
    # from identical parameters.
    script_module.sub1.fc.weight = eager_module.sub1.fc.weight
    script_module.sub2.fc1.weight = eager_module.sub2.fc1.module.weight
    script_module.sub2.fc2.weight = eager_module.sub2.fc2.weight
    script_module.fc3.weight = eager_module.fc3.module.weight

    # Quantize eager module
    quantized_eager_module = quantize(eager_module, default_eval_fn, data)

    def get_forward(m):
        """Return the underlying C++ 'forward' method of scripted module `m`."""
        return m._c._get_method('forward')

    # Quantize script_module
    torch._C._jit_pass_constant_propagation(get_forward(script_module).graph)

    scripted_observer = torch.jit.script(Observer())
    scripted_weight_observer = torch.jit.script(WeightObserver())
    # The JIT passes take the C++ module objects (`._c`), not the Python
    # wrappers.
    scripted_qconfig = QConfig(activation=scripted_observer._c,
                               weight=scripted_weight_observer._c)
    # Keys are qualified submodule names, mirroring the eager-mode qconfig
    # assignments above.
    qconfig_dict = {'sub2.fc1': scripted_qconfig, 'fc3': scripted_qconfig}
    torch._C._jit_pass_insert_observers(script_module._c, "forward",
                                        qconfig_dict)

    # Run script_module and collect statistics
    get_forward(script_module)(x)

    # Insert quantize and dequantize calls; the pass returns the module that
    # replaces script_module._c.
    script_module._c = torch._C._jit_pass_insert_quant_dequant(
        script_module._c, "forward")
    # Note that observer modules are not removed right now
    torch._C._jit_pass_quant_fusion(get_forward(script_module).graph)

    # NOTE(review): this result is discarded; presumably a warm-up run after
    # fusion. Confirm whether it is required before removing it.
    get_forward(script_module)(x)

    eager_result = quantized_eager_module(x)
    script_result = get_forward(script_module)(x)
    self.assertEqual(eager_result, script_result)
def test_nested1(self):
    r"""Quantize a nested model in which only the top-level 'fc3' and the
    submodule 'sub2.fc1' are quantized; 'sub1.fc' and 'sub2.fc2' are expected
    to remain float Linear modules.

    The check runs twice: once via the explicit prepare -> calibrate -> convert
    flow and once via the one-line `quantize` API.
    """
    def checkPrepState(m, before_calib=False):
        # Observers are only verified on the prepared (not yet converted)
        # model.
        if before_calib:
            self.checkObservers(m)
        # Modules that must not carry quant/dequant preparation, checked in
        # the same order as before.
        for unprepared in (m, m.sub1, m.sub1.fc, m.sub1.relu, m.sub2):
            self.checkNoPrepModules(unprepared)
        self.checkHasPrepModules(m.sub2.fc1)
        self.checkNoPrepModules(m.sub2.fc2)
        self.checkHasPrepModules(m.fc3)

    def checkConverted(m):
        checkPrepState(m)
        # Unannotated layers stay float Linear modules.
        self.checkLinear(m.sub1.fc)
        self.checkWrappedQuantizedLinear(m.fc3)
        self.checkWrappedQuantizedLinear(m.sub2.fc1)
        self.checkLinear(m.sub2.fc2)
        # The converted model must still run and be scriptable.
        test_only_eval_fn(m, self.calib_data)
        self.checkScriptable(m, self.calib_data)

    # Step-by-step flow: prepare, calibrate, convert.
    prepared = prepare(AnnotatedNestedModel())
    checkPrepState(prepared, True)
    test_only_eval_fn(prepared, self.calib_data)
    checkConverted(convert(prepared))

    # test one line API
    checkConverted(quantize(AnnotatedNestedModel(), test_only_eval_fn,
                            self.calib_data))