def validate_backward_pass(self, tc: TestCase):
    """ helper method to validate the backward pass through a quantized recurrent op """
    original_model = copy.deepcopy(tc.model)

    quant_op = QcQuantizeRecurrent(module_to_quantize=tc.model, weight_bw=8, activation_bw=8,
                                   is_symmetric=False, quant_scheme=QuantScheme.post_training_tf_enhanced,
                                   round_mode='nearest')

    encodings = libpymo.TfEncoding()
    encodings.bw = 8
    encodings.max = 3
    encodings.min = -2
    encodings.delta = 1
    encodings.offset = 0.2

    for input_quantizer in quant_op.input_quantizers.values():
        input_quantizer.enabled = True
        input_quantizer.encoding = encodings

    for name, param in quant_op.named_parameters(recurse=False):
        quant_op.param_quantizers[name].enabled = True
        quant_op.param_quantizers[name].encoding = encodings

    for output_quantizer in quant_op.output_quantizers.values():
        output_quantizer.encoding = encodings

    # Check that the custom op params match the original model params
    for name, param in original_model.named_parameters():
        self.assertTrue(torch.allclose(param.data, getattr(quant_op, name).data, atol=1e-05))

    inp = torch.rand(tc.input_shape, requires_grad=True).to(tc.device)
    o_qc_rnn, _ = quant_op(inp, hx=None)

    # Check that the params still match after the forward pass
    for name, param in original_model.named_parameters():
        self.assertTrue(torch.allclose(param.data, getattr(quant_op, name).data, atol=1e-05))

    optimizer = torch.optim.SGD(quant_op.parameters(), lr=0.05, momentum=0.5)

    # create a fake loss function as the sum of the output
    loss = o_qc_rnn.flatten().sum()
    loss.backward()

    for name, param in quant_op.module_to_quantize.named_parameters():
        self.assertTrue(param.grad is None)
        self.assertTrue(getattr(quant_op, name).grad is not None)

    optimizer.step()

    for name, param in original_model.named_parameters():
        # check that the custom op params have been updated
        quant_param = getattr(quant_op, name)
        self.assertFalse(torch.allclose(param.data, quant_param.data, atol=1e-05))

        # check that the 'replaced' module params are still the same
        module_to_quantize_param = getattr(quant_op.module_to_quantize, name)
        self.assertTrue(torch.allclose(param.data, module_to_quantize_param.data, atol=1e-05))

    # update the 'replaced' module and check the reverse
    quant_op.update_params()
    for name, param in quant_op.module_to_quantize.named_parameters():
        # check that the 'replaced' module params have been updated
        orig_param = getattr(original_model, name)
        self.assertFalse(torch.allclose(param.data, orig_param.data, atol=1e-05))

        # check that the 'replaced' module params match the custom op
        quant_param = getattr(quant_op, name)
        self.assertTrue(torch.allclose(param.data, quant_param.data, atol=1e-05))
def create_encoding_from_dict(encoding_dict: dict) -> (libpymo.TfEncoding, bool):
    """
    Create encoding object from encoding dictionary
    :param encoding_dict: Dictionary containing encodings
    :return: Encoding object, is_symmetric
    """
    encoding = libpymo.TfEncoding()
    encoding.bw = encoding_dict.get('bitwidth')
    encoding.max = encoding_dict.get('max')
    encoding.min = encoding_dict.get('min')
    encoding.delta = encoding_dict.get('scale')
    encoding.offset = encoding_dict.get('offset')
    is_symmetric = eval(encoding_dict.get('is_symmetric'))  # pylint: disable=eval-used
    return encoding, is_symmetric
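
# A minimal usage sketch for create_encoding_from_dict. The dictionary values below are
# illustrative only; the keys ('bitwidth', 'min', 'max', 'scale', 'offset', 'is_symmetric')
# are the ones the function reads, with 'is_symmetric' stored as a string as the eval()
# above implies. The helper name is hypothetical, for illustration only.
def _example_create_encoding_from_dict():
    encoding_dict = {'bitwidth': 8, 'min': -5.19, 'max': 2.23,
                     'scale': 0.0291, 'offset': -178, 'is_symmetric': 'False'}
    encoding, is_symmetric = create_encoding_from_dict(encoding_dict)
    assert encoding.bw == 8 and is_symmetric is False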
def __setstate__(self, state):
    # Restore instance attributes
    self.__dict__.update(state.dict)

    # Create the c++ op
    self._cppOp = AimetTensorQuantizer.AimetTensorQuantizer(self.quant_scheme)

    # Create the encoding object
    if hasattr(state, 'min'):
        self.encoding = libpymo.TfEncoding()
        self.encoding.bw = state.bw
        self.encoding.max = state.max
        self.encoding.min = state.min
        self.encoding.delta = state.delta
        self.encoding.offset = state.offset
    else:
        self.encoding = None
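
# For context, __setstate__ above expects a pickled state object exposing a `dict`
# attribute with the picklable instance attributes, plus flat bw/max/min/delta/offset
# fields whenever an encoding was present. A minimal holder with that shape could look
# like the sketch below (hypothetical class, for illustration only -- not the library's
# actual serialization type).
class _ExampleQuantizerState:
    def __init__(self, instance_dict: dict, encoding=None):
        self.dict = instance_dict
        if encoding is not None:
            self.bw = encoding.bw
            self.max = encoding.max
            self.min = encoding.min
            self.delta = encoding.delta
            self.offset = encoding.offset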
def test_quantize_only_cpu(self):
    """ Test tensor quantizer quantize only functionality """

    post_training_tensor_quantizer = \
        PostTrainingTensorQuantizer(bitwidth=8, round_mode='nearest',
                                    quant_scheme=MAP_QUANT_SCHEME_TO_PYMO[QuantScheme.post_training_tf],
                                    use_symmetric_encodings=False, enabled_by_default=True)

    encodings = libpymo.TfEncoding()
    encodings.bw = 8
    encodings.max = 2.23
    encodings.min = -5.19
    post_training_tensor_quantizer.encoding = encodings

    inp_tensor = torch.tensor([-7, -5, -3, 0, .1, 2.5])
    quant_out = post_training_tensor_quantizer.quantize(inp_tensor, MAP_ROUND_MODE_TO_PYMO['nearest'])

    expected_out = torch.tensor([0, 6, 75, 178, 181, 255], dtype=torch.float32)
    self.assertTrue(torch.equal(quant_out, expected_out))
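
# Back-of-the-envelope check for the expected values above. This is a sketch that assumes
# standard asymmetric uniform quantization: delta = (max - min) / (2^bw - 1),
# offset = round(min / delta), q = clamp(round(x / delta) - offset, 0, 2^bw - 1).
# It mirrors the arithmetic, not necessarily libpymo's exact implementation.
def _reference_quantize(x: torch.Tensor, enc_min: float, enc_max: float, bw: int) -> torch.Tensor:
    steps = 2 ** bw - 1
    delta = (enc_max - enc_min) / steps
    offset = round(enc_min / delta)
    return torch.clamp(torch.round(x / delta) - offset, 0, steps)

# _reference_quantize(torch.tensor([-7, -5, -3, 0, .1, 2.5]), -5.19, 2.23, 8)
# -> tensor([  0.,   6.,  75., 178., 181., 255.])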
def test_quantsim_export(self):
    torch.manual_seed(10)
    model = Model2(Add())
    dummy_input = torch.randn(5, 10, 10, 20)
    sim = QuantizationSimModel(model, dummy_input)

    encodings = libpymo.TfEncoding()
    encodings.bw = 8
    encodings.max = 5
    encodings.min = -5
    encodings.delta = 1
    encodings.offset = 0.2

    sim.model.op1.output_quantizer.encoding = encodings
    sim.model.conv1.output_quantizer.encoding = encodings
    sim.model.conv1.param_quantizers['weight'].encoding = encodings

    sim.export(path='./data', filename_prefix='quant_model', dummy_input=dummy_input)

    with open('./data/quant_model.encodings') as f:
        data = json.load(f)

    self.assertTrue(isinstance(data['activation_encodings']['3'], list))
    self.assertTrue(isinstance(data['activation_encodings']['4'], list))
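
# For reference, the assertions above imply an exported JSON of roughly this shape
# (illustrative only -- each list entry is presumably a dict built from the
# bitwidth/min/max/delta/offset values set on the quantizers):
#
#     {
#         "activation_encodings": {
#             "3": [{"bitwidth": 8, "min": -5, "max": 5, ...}],
#             "4": [{"bitwidth": 8, "min": -5, "max": 5, ...}]
#         },
#         ...
#     }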
def get_encoding(self) -> libpymo.TfEncoding:
    """
    Get encoding if valid else raise error
    :return: encoding
    """
    if self.is_encoding_valid():
        encoding_min = self.get_variable_from_op(QuantizeOpIndices.encoding_min)
        encoding_max = self.get_variable_from_op(QuantizeOpIndices.encoding_max)
        bitwidth = self.bitwidth

        # Create Encoding object
        encoding = libpymo.TfEncoding()
        encoding.min = encoding_min
        encoding.max = encoding_max
        encoding.bw = bitwidth
        encoding.delta, encoding.offset = calculate_delta_offset(encoding_min, encoding_max, bitwidth)
    else:
        raise AssertionError('Compute encoding or Set encoding must be invoked before')

    return encoding
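
# calculate_delta_offset is used above to derive the scale/offset pair from the min/max
# range. A minimal sketch of the expected math, assuming uniform affine quantization
# (illustration only, not the library's actual implementation):
def _example_delta_offset(encoding_min: float, encoding_max: float, bitwidth: int) -> (float, float):
    delta = (encoding_max - encoding_min) / (2 ** bitwidth - 1)
    offset = round(encoding_min / delta)
    return delta, offset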
def validate_serialize_deserialize(self, tc: TestCase):
    """ helper method to run quant RNN test """
    original_model = copy.deepcopy(tc.model)

    quant_op = QcQuantizeRecurrent(module_to_quantize=tc.model, weight_bw=8, activation_bw=8,
                                   is_symmetric=False, quant_scheme=QuantScheme.post_training_tf_enhanced,
                                   round_mode='nearest')
    quant_op.eval()
    inp = torch.rand(tc.input_shape, requires_grad=True).to(tc.device)

    encodings = libpymo.TfEncoding()
    encodings.bw = 8
    encodings.max = 3
    encodings.min = -2
    encodings.delta = 1
    encodings.offset = 0.2

    for input_quantizer in quant_op.input_quantizers.values():
        input_quantizer.enabled = True
        input_quantizer.encoding = encodings

    for name, param in quant_op.named_parameters(recurse=False):
        quant_op.param_quantizers[name].enabled = True
        quant_op.param_quantizers[name].encoding = encodings

    for output_quantizer in quant_op.output_quantizers.values():
        output_quantizer.encoding = encodings

    o_qc_rnn, _ = quant_op(inp, hx=None)

    optimizer = torch.optim.SGD(quant_op.parameters(), lr=0.05, momentum=0.5)

    # create a fake loss function as the sum of the output
    loss = o_qc_rnn.flatten().sum()
    loss.backward()
    optimizer.step()

    # Generate quantize encodings
    quant_op.compute_encoding()
    quant_op.compute_weight_encodings()

    o_pre, h_pre = quant_op(inp, hx=None)

    # Save and load the quantized model
    with tempfile.NamedTemporaryFile() as f:
        torch.save(quant_op, f)
        f.seek(0)
        loaded_model = torch.load(f)
        loaded_model.eval()

    # compare the parameters
    for name, param in quant_op.named_parameters(recurse=False):
        loaded_param = getattr(loaded_model, name)
        self.assertTrue(torch.equal(param, loaded_param),
                        msg="recurrent op param mismatched, TestCase:{}".format(tc.test_name))

    for name, param in quant_op.module_to_quantize.named_parameters():
        loaded_param = getattr(loaded_model.module_to_quantize, name)
        self.assertTrue(torch.equal(param, loaded_param),
                        msg="original module param mismatched, TestCase:{}".format(tc.test_name))

    # compare the quantizers
    for name, output_quantizer in quant_op.output_quantizers.items():
        if output_quantizer.enabled:
            self.compare_quantizer(output_quantizer, loaded_model.output_quantizers[name])

    for name, quantizer in quant_op.param_quantizers.items():
        if quantizer.enabled:
            self.compare_quantizer(quantizer, loaded_model.param_quantizers[name])

    # check if the loaded module generates the same output
    o_post, h_post = loaded_model(inp, hx=None)
    self.assertTrue(torch.equal(o_pre, o_post),
                    msg="output mismatched, Failed TestCase:{}".format(tc.test_name))
    if isinstance(h_pre, tuple):
        for pre, post in zip(h_pre, h_post):
            self.assertTrue(torch.equal(pre, post),
                            msg="h or c mismatched, Failed TestCase:{}".format(tc.test_name))
    else:
        self.assertTrue(torch.equal(h_pre, h_post),
                        msg="h mismatched, Failed TestCase:{}".format(tc.test_name))
def test_qc_post_training_wrapper(self):
    torch.manual_seed(0)

    encodings = libpymo.TfEncoding()
    encodings.bw, encodings.max, encodings.min, encodings.delta, encodings.offset = 8, 0.5, -1, 1, 0.2

    encodings_new = libpymo.TfEncoding()
    encodings_new.bw, encodings_new.max, encodings_new.min, encodings_new.delta, encodings_new.offset = \
        8, 0.4, -0.98, 1, 0.2

    output_grad = []

    def hook_fn(m, _, i):
        for grad in i:
            try:
                output_grad.append(grad)
            except AttributeError:
                print("None found for Gradient")

    conv1 = torch.nn.Conv2d(1, 2, 1)
    quantize = QcPostTrainingWrapper(conv1, weight_bw=8, activation_bw=8, round_mode='nearest',
                                     quant_scheme=QuantScheme.post_training_tf_enhanced)
    quantize.train()
    quantize._module_to_wrap.register_backward_hook(hook_fn)

    quantize.input_quantizer.enabled = True
    quantize.output_quantizers[0].enabled = True
    quantize.input_quantizer.encoding = encodings
    quantize.output_quantizers[0].encoding = encodings

    new_input = torch.autograd.Variable(torch.tensor([[[[0.6469]]], [[[-0.9]]]]), requires_grad=True)
    quantize.set_mode(QcQuantizeOpMode.ACTIVE)
    out = quantize(new_input)

    quantize.input_quantizer.encoding = encodings_new
    quantize.output_quantizers[0].encoding = encodings_new
    quantize.param_quantizers['weight'].encoding = encodings_new

    loss = out.flatten().sum()
    loss.backward()

    # Check if input gradient got clipped
    for i, val in enumerate(new_input):
        if encodings_new.min > val or val > encodings_new.max:
            self.assertTrue(new_input.grad[0][i] == 0.0)

    # Check if output gradient got clipped
    output_grad = output_grad[0].flatten()
    self.assertTrue(output_grad[0] == 1.0)
    self.assertTrue(output_grad[1] == 1.0)
    self.assertTrue(output_grad[2] == 1.0)
    self.assertTrue(output_grad[3] == 0.0)

    # Check if weight gradient got clipped
    weight_tensor = quantize._module_to_wrap.weight.flatten()
    weight_tensor_grad = quantize._module_to_wrap.weight.grad.flatten()
    for i, val in enumerate(weight_tensor):
        if encodings_new.min > val or val > encodings_new.max:
            self.assertTrue(weight_tensor_grad[i] == 0.0)