def test_nested3(self):
    r"""More complicated nested test case with child qconfig overrides
    parent qconfig.

    ``fc3`` and ``sub2`` are given ``default_qconfig``; ``sub2.fc1`` overrides
    the qconfig it would inherit from ``sub2`` with a custom qconfig. ``sub1``
    has no entry in ``qconfig_dict`` and is expected to stay unprepared.
    """
    # Post-training quantization: prepare/calibrate the model in eval mode
    # (the dynamic-quantization variant of this test does the same).
    model = NestedModel().eval()
    custom_options = {
        'dtype': torch.quint8,
        'qscheme': torch.per_tensor_affine
    }
    # Weights must use the weight observer, exactly like the custom qconfig in
    # AnnotatedCustomConfigNestedModel; only the activation observer is
    # customized here. (Previously the activation observer was also used for
    # the weight, which is inconsistent with every other custom qconfig.)
    custom_qconfig = QConfig(weight=default_weight_observer(),
                             activation=default_observer(**custom_options))
    qconfig_dict = {
        'fc3': default_qconfig,
        'sub2': default_qconfig,
        'sub2.fc1': custom_qconfig
    }
    model = prepare(model, qconfig_dict)

    def checkPrepModules(model, before_calib=False):
        # Observers are only checked before calibration/conversion.
        if before_calib:
            self.checkObservers(model)
        # Containers and the unconfigured sub1 must not be prepared ...
        self.checkNoPrepModules(model)
        self.checkNoPrepModules(model.sub1)
        self.checkNoPrepModules(model.sub1.fc)
        self.checkNoPrepModules(model.sub1.relu)
        self.checkNoPrepModules(model.sub2)
        # ... while every leaf covered by qconfig_dict must be.
        self.checkHasPrepModules(model.sub2.fc1)
        self.checkHasPrepModules(model.sub2.fc2)
        self.checkHasPrepModules(model.fc3)

    checkPrepModules(model, before_calib=True)
    default_eval_fn(model, calib_data)
    # convert's return value is not used here; the checks below inspect
    # `model` itself after the call.
    convert(model)

    def checkQuantized(model):
        checkPrepModules(model)
        self.checkQuantizedLinear(model.sub2.fc1)
        self.checkQuantizedLinear(model.sub2.fc2)
        self.checkQuantizedLinear(model.fc3)
        # The quantized model must still run end to end.
        default_eval_fn(model, calib_data)

    checkQuantized(model)

    # test one line API
    model = quantize(NestedModel().eval(), default_eval_fn, calib_data, qconfig_dict)
    checkQuantized(model)
def test_nested3(self):
    r"""Dynamic quantization of a nested model in which a child's qconfig
    (``sub2.fc1``) takes precedence over the qconfig it would otherwise
    inherit from its parent (``sub2``).
    """
    model = NestedModel().eval()
    # Override for sub2.fc1: the standard weight observer plus an activation
    # observer configured explicitly as quint8 / per-tensor affine.
    override_qconfig = QConfig(
        weight=default_weight_observer(),
        activation=default_observer(dtype=torch.quint8,
                                    qscheme=torch.per_tensor_affine),
    )
    qconfig_dict = {
        'fc3': default_qconfig,
        'sub2': default_qconfig,
        'sub2.fc1': override_qconfig,
    }

    def checkQuantized(quantized):
        # Each Linear layer reachable through qconfig_dict (fc3, and both
        # children of sub2) is checked in turn.
        for linear in (quantized.sub2.fc1, quantized.sub2.fc2, quantized.fc3):
            self.checkDynamicQuantizedLinear(linear)

    model = prepare_dynamic(model, qconfig_dict)
    convert_dynamic(model)
    checkQuantized(model)

    # test one line API
    checkQuantized(quantize_dynamic(NestedModel().eval(), qconfig_dict))
def __init__(self):
    super(AnnotatedCustomConfigNestedModel, self).__init__()
    # Submodules are constructed in a fixed order: sub1, sub2, fc3.
    self.sub1 = LinearReluModel()
    self.sub2 = TwoLayerLinearModel()
    self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))

    # fc3 and sub2 are annotated with the default qconfig; sub1 is left
    # without one.
    self.fc3.qconfig = default_qconfig
    sub2 = self.sub2
    sub2.qconfig = default_qconfig

    # sub2.fc1 carries its own qconfig, overriding the one on sub2: standard
    # weight observer, activation observer set to quint8 / per-tensor affine.
    sub2.fc1.qconfig = QConfig(
        weight=default_weight_observer(),
        activation=default_observer(dtype=torch.quint8,
                                    qscheme=torch.per_tensor_affine),
    )

    # Wrap both children of sub2. The override qconfig was attached to the
    # inner fc1 module before wrapping, so it stays on that inner module.
    for child_name in ('fc1', 'fc2'):
        setattr(sub2, child_name, QuantWrapper(getattr(sub2, child_name)))