def test_nested3(self):
    r"""More complicated nested test case where a child qconfig overrides the parent qconfig.

    Verifies that a qconfig entry keyed by the child module path ('sub2.fc1')
    takes precedence over the entry for its parent ('sub2'), both through the
    explicit prepare/convert flow and through the one-line ``quantize_dynamic``
    API.
    """
    # NOTE(review): removed an unused (and misspelled) local dict
    # `custum_options` that was never referenced in this test.
    model = NestedModel().eval()
    # 'sub2.fc1' entry overrides the qconfig its parent 'sub2' would provide.
    custom_dynamic_qconfig = QConfigDynamic(weight=default_weight_observer)
    qconfig_dynamic_dict = {
        'fc3': default_dynamic_qconfig,
        'sub2': default_dynamic_qconfig,
        'sub2.fc1': custom_dynamic_qconfig
    }
    prepare_dynamic(model, qconfig_dynamic_dict)

    convert_dynamic(model)

    def checkQuantized(model):
        # All three linears should have been dynamically quantized.
        self.checkDynamicQuantizedLinear(model.sub2.fc1)
        self.checkDynamicQuantizedLinear(model.sub2.fc2)
        self.checkDynamicQuantizedLinear(model.fc3)
        self.checkScriptable(model, self.calib_data, check_save_load=True)

    checkQuantized(model)

    # test one line API
    model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict)
    checkQuantized(model)
def test_prepare_dynamic_lstm(self):
    """Dynamic-prepare on a scripted LSTM wrapper should attach exactly one
    tensor-list weight observer and keep the ``aten::lstm`` op in the graph."""
    class LSTMWrapper(torch.nn.Module):
        def __init__(self):
            super(LSTMWrapper, self).__init__()
            self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

        def forward(self, x):
            return self.lstm(x)

    from torch.quantization.observer import default_dynamic_quant_observer, _MinMaxTensorListObserver
    qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                             weight=_MinMaxTensorListObserver)
    scripted = torch.jit.script(LSTMWrapper())
    scripted = prepare_dynamic_script(scripted, {'': qconfig})
    # Exactly one observer attribute should be attached to the inner LSTM.
    assert len(attrs_with_prefix(scripted.lstm, '_observer_')) == 1
    forward_graph = str(get_module_method(scripted, 'lstm', 'forward__0').graph)
    FileCheck().check('_MinMaxTensorListObserver = prim::GetAttr[name="_observer_0') \
               .check("aten::lstm") \
               .check("return") \
               .run(forward_graph)