Example #1
    def validate_backward_pass(self, tc: TestCase):
        original_model = copy.deepcopy(tc.model)
        quant_op = QcQuantizeRecurrent(module_to_quantize=tc.model, weight_bw=8, activation_bw=8, is_symmetric=False,
                                       quant_scheme=QuantScheme.post_training_tf_enhanced, round_mode='nearest')
        encodings = libpymo.TfEncoding()
        encodings.bw = 8
        encodings.max = 3
        encodings.min = -2
        encodings.delta = 1
        encodings.offset = 0.2
        for input_quantizer in quant_op.input_quantizers.values():
            input_quantizer.enabled = True
            input_quantizer.encoding = encodings
        for name, param in quant_op.named_parameters(recurse=False):
            quant_op.param_quantizers[name].enabled = True
            quant_op.param_quantizers[name].encoding = encodings
        for output_quantizer in quant_op.output_quantizers.values():
            output_quantizer.encoding = encodings
        # Check that the params match
        for name, param in original_model.named_parameters():
            self.assertTrue(torch.allclose(param.data, getattr(quant_op, name).data, atol=1e-05))
        inp = torch.rand(tc.input_shape, requires_grad=True).to(tc.device)
        o_qc_rnn, _ = quant_op(inp, hx=None)
        # Check that the params still match after the forward pass
        for name, param in original_model.named_parameters():
            self.assertTrue(torch.allclose(param.data, getattr(quant_op, name).data, atol=1e-05))
        optimizer = torch.optim.SGD(quant_op.parameters(), lr=0.05, momentum=0.5)
        # Create a fake loss as the sum of the output
        loss = o_qc_rnn.flatten().sum()
        loss.backward()
        for name, param in quant_op.module_to_quantize.named_parameters():
            self.assertTrue(param.grad is None)
            self.assertTrue(getattr(quant_op, name).grad is not None)
        optimizer.step()
        for name, param in original_model.named_parameters():
            # check that the custom op's params have been updated
            quant_param = getattr(quant_op, name)
            self.assertFalse(torch.allclose(param.data, quant_param.data, atol=1e-05))

            # check that the 'replaced' module's params are still unchanged
            module_to_quantize_param = getattr(quant_op.module_to_quantize, name)
            self.assertTrue(torch.allclose(param.data, module_to_quantize_param.data, atol=1e-05))
        # Update the 'replaced' module and check the reverse
        quant_op.update_params()
        for name, param in quant_op.module_to_quantize.named_parameters():
            # check that the 'replaced' module's params have been updated
            orig_param = getattr(original_model, name)
            self.assertFalse(torch.allclose(param.data, orig_param.data, atol=1e-05))

            # check that the 'replaced' module's params match the custom op
            quant_param = getattr(quant_op, name)
            self.assertTrue(torch.allclose(param.data, quant_param.data, atol=1e-05))
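The test relies on the wrapper holding its own trainable copies of the wrapped module's parameters, which are only pushed back on update_params(). A minimal, self-contained sketch of that shadow-parameter pattern (illustrative only, not the AIMET implementation):

import torch


class ShadowParamWrapper(torch.nn.Module):
    """Toy wrapper: trains copies of the wrapped module's params, syncs them back on demand."""
    def __init__(self, module):
        super().__init__()
        self.module_to_quantize = module
        # register trainable copies; the wrapped module's own params stay untouched
        for name, param in module.named_parameters():
            setattr(self, name, torch.nn.Parameter(param.clone().detach()))

    def forward(self, x):
        # compute with the shadow params instead of the wrapped module's params
        return torch.nn.functional.linear(x, self.weight, self.bias)

    def update_params(self):
        # copy the (possibly updated) shadow params back into the wrapped module
        with torch.no_grad():
            for name, param in self.module_to_quantize.named_parameters():
                param.copy_(getattr(self, name))


wrapper = ShadowParamWrapper(torch.nn.Linear(4, 2))
wrapper(torch.rand(3, 4)).sum().backward()
print(wrapper.module_to_quantize.weight.grad)  # None: gradients land on the shadow params only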
Example #2
def create_encoding_from_dict(encoding_dict: dict) -> (libpymo.TfEncoding, bool):
    """
    Create encoding object from encoding dictionary
    :param encoding_dict: Dictionary containing encodings
    :return: Encoding object, is_symmetric
    """
    encoding = libpymo.TfEncoding()
    encoding.bw = encoding_dict.get('bitwidth')
    encoding.max = encoding_dict.get('max')
    encoding.min = encoding_dict.get('min')
    encoding.delta = encoding_dict.get('scale')
    encoding.offset = encoding_dict.get('offset')
    is_symmetric = eval(encoding_dict.get('is_symmetric'))  # pylint: disable=eval-used
    return encoding, is_symmetric
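For reference, a minimal usage sketch of this helper. The dictionary keys mirror the lookups above, the values are purely illustrative, and is_symmetric must be a string because it is passed through eval:

encoding_dict = {'bitwidth': 8, 'max': 3.0, 'min': -2.0,
                 'scale': 0.0196, 'offset': -102, 'is_symmetric': 'False'}
encoding, is_symmetric = create_encoding_from_dict(encoding_dict)
print(encoding.bw, encoding.min, encoding.max, is_symmetric)  # 8 -2.0 3.0 False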
Example #3
    def __setstate__(self, state):
        # Restore instance attributes
        self.__dict__.update(state.dict)

        # Create the c++ op
        self._cppOp = AimetTensorQuantizer.AimetTensorQuantizer(
            self.quant_scheme)

        # Create the encoding object
        if hasattr(state, 'min'):
            self.encoding = libpymo.TfEncoding()
            self.encoding.bw = state.bw
            self.encoding.max = state.max
            self.encoding.min = state.min
            self.encoding.delta = state.delta
            self.encoding.offset = state.offset
        else:
            self.encoding = None
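__setstate__ runs when the quantizer is unpickled (for example via torch.load or copy.deepcopy), so the C++ op has to be rebuilt and the encoding reconstructed from the saved fields. A generic sketch of that pickling protocol, unrelated to the AIMET classes:

import pickle


class Holder:
    def __init__(self):
        self.handle = object()   # stands in for a non-picklable native object
        self.bw = 8

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['handle']      # drop what cannot be serialized
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)   # restore plain attributes
        self.handle = object()        # recreate the native resource


restored = pickle.loads(pickle.dumps(Holder()))
print(restored.bw)  # 8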
Example #4
    def test_quantize_only_cpu(self):
        """ Test tensor quantizer quantize only functionality """

        post_training_tensor_quantizer = \
            PostTrainingTensorQuantizer(bitwidth=8, round_mode='nearest',
                                        quant_scheme=MAP_QUANT_SCHEME_TO_PYMO[QuantScheme.post_training_tf],
                                        use_symmetric_encodings=False, enabled_by_default=True)
        encodings = libpymo.TfEncoding()
        encodings.bw = 8
        encodings.max = 2.23
        encodings.min = -5.19
        post_training_tensor_quantizer.encoding = encodings

        inp_tensor = torch.tensor([-7, -5, -3, 0, .1, 2.5])
        quant_out = post_training_tensor_quantizer.quantize(
            inp_tensor, MAP_ROUND_MODE_TO_PYMO['nearest'])
        expected_out = torch.tensor([0, 6, 75, 178, 181, 255],
                                    dtype=torch.float32)
        self.assertTrue(torch.equal(quant_out, expected_out))
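The expected values follow from the usual TF-style asymmetric grid. A sketch of the arithmetic, assuming delta = (max - min) / (2**bw - 1) and offset = round(min / delta):

import torch

bw, enc_min, enc_max = 8, -5.19, 2.23
delta = (enc_max - enc_min) / (2 ** bw - 1)   # ~0.0291
offset = round(enc_min / delta)               # ~-178

x = torch.tensor([-7, -5, -3, 0, .1, 2.5])
q = torch.clamp(torch.round(x / delta) - offset, 0, 2 ** bw - 1)
print(q)  # tensor([  0.,   6.,  75., 178., 181., 255.])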
Example #5
    def test_quantsim_export(self):
        torch.manual_seed(10)
        model = Model2(Add())
        dummy_input = torch.randn(5, 10, 10, 20)
        sim = QuantizationSimModel(model, dummy_input)
        encodings = libpymo.TfEncoding()
        encodings.bw = 8
        encodings.max = 5
        encodings.min = -5
        encodings.delta = 1
        encodings.offset = 0.2
        sim.model.op1.output_quantizer.encoding = encodings
        sim.model.conv1.output_quantizer.encoding = encodings
        sim.model.conv1.param_quantizers['weight'].encoding = encodings
        sim.export(path='./data', filename_prefix='quant_model', dummy_input=dummy_input)

        with open('./data/quant_model.encodings') as f:
            data = json.load(f)

        self.assertTrue(isinstance(data['activation_encodings']['3'], list))
        self.assertTrue(isinstance(data['activation_encodings']['4'], list))
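A hedged sketch for inspecting the exported file: the test only guarantees that each activation entry is a list, and the field names below are assumed to match the ones read by create_encoding_from_dict in Example #2:

import json

with open('./data/quant_model.encodings') as f:
    exported = json.load(f)
for tensor_name, entries in exported['activation_encodings'].items():
    for enc in entries:  # assumed per-tensor encoding dicts
        print(tensor_name, enc.get('bitwidth'), enc.get('min'), enc.get('max'),
              enc.get('scale'), enc.get('offset'), enc.get('is_symmetric'))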
Example #6
    def get_encoding(self) -> libpymo.TfEncoding:
        """
        Get encoding if valid else raise error
        :return: encoding
        """
        if self.is_encoding_valid():
            encoding_min = self.get_variable_from_op(
                QuantizeOpIndices.encoding_min)
            encoding_max = self.get_variable_from_op(
                QuantizeOpIndices.encoding_max)
            bitwidth = self.bitwidth

            # Create Encoding object
            encoding = libpymo.TfEncoding()
            encoding.min = encoding_min
            encoding.max = encoding_max
            encoding.bw = bitwidth
            encoding.delta, encoding.offset = calculate_delta_offset(
                encoding_min, encoding_max, bitwidth)
        else:
            raise AssertionError(
                'Compute encoding or Set encoding must be invoked before')

        return encoding
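calculate_delta_offset itself is not shown here; under the standard asymmetric-grid assumption it reduces to something like the following sketch (a plausible stand-in, not necessarily the library's exact implementation):

def calculate_delta_offset_sketch(encoding_min: float, encoding_max: float, bitwidth: int):
    """Assumed relationship: 2**bitwidth quantization steps spanning [min, max]."""
    delta = (encoding_max - encoding_min) / (2 ** bitwidth - 1)
    offset = round(encoding_min / delta)
    return delta, offset


print(calculate_delta_offset_sketch(-2.0, 3.0, 8))  # (0.0196..., -102)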
Example #7
 def validate_serialize_deserialize(self, tc: TestCase):
     """
    helper method to run quant RNN test
     """
     original_model = copy.deepcopy(tc.model)
     quant_op = QcQuantizeRecurrent(module_to_quantize=tc.model, weight_bw=8, activation_bw=8, is_symmetric=False,
                                    quant_scheme=QuantScheme.post_training_tf_enhanced, round_mode='nearest')
     quant_op.eval()
     inp = torch.rand(tc.input_shape, requires_grad=True).to(tc.device)
     encodings = libpymo.TfEncoding()
     encodings.bw = 8
     encodings.max = 3
     encodings.min = -2
     encodings.delta = 1
     encodings.offset = 0.2
     for input_quantizer in quant_op.input_quantizers.values():
         input_quantizer.enabled = True
         input_quantizer.encoding = encodings
     for name, param in quant_op.named_parameters(recurse=False):
         quant_op.param_quantizers[name].enabled = True
         quant_op.param_quantizers[name].encoding = encodings
     for output_quantizer in quant_op.output_quantizers.values():
         output_quantizer.encoding = encodings
     o_qc_rnn, _ = quant_op(inp, hx=None)
     optimizer = torch.optim.SGD(quant_op.parameters(), lr=0.05, momentum=0.5)
     # Create a fake loss as the sum of the output
     loss = o_qc_rnn.flatten().sum()
     loss.backward()
     optimizer.step()
     # Generate Quantize encodings
     quant_op.compute_encoding()
     quant_op.compute_weight_encodings()
     o_pre, h_pre = quant_op(inp, hx=None)
     # Save and reload the quantized model
     with tempfile.NamedTemporaryFile() as f:
         torch.save(quant_op, f)
         f.seek(0)
         loaded_model = torch.load(f)
         loaded_model.eval()
     # compare the parameters
     for name, param in quant_op.named_parameters(recurse=False):
         loaded_param = getattr(loaded_model, name)
         self.assertTrue(torch.equal(param, loaded_param),
                         msg="param mismatched recurrent op param mis-matched, TestCase:{}".format(tc.test_name))
     for name, param in quant_op.module_to_quantize.named_parameters():
         loaded_param = getattr(loaded_model.module_to_quantize, name)
         self.assertTrue(torch.equal(param, loaded_param),
                         msg="original module mismatched, TestCase:{}".format(tc.test_name))
     # compare the quantizers
     for name, output_quantizer in quant_op.output_quantizers.items():
         if output_quantizer.enabled:
             self.compare_quantizer(output_quantizer, loaded_model.output_quantizers[name])
     for name, quantizer in quant_op.param_quantizers.items():
         if quantizer.enabled:
             self.compare_quantizer(quantizer, loaded_model.param_quantizers[name])
     # check if the loaded module generates the same output
     o_post, h_post = loaded_model(inp, hx=None)
     self.assertTrue(torch.equal(o_pre, o_post),
                     msg="output mismatched, Failed TestCase:{}".format(tc.test_name))
     if isinstance(h_pre, tuple):
         for pre, post in zip(h_pre, h_post):
             self.assertTrue(torch.equal(pre, post),
                             msg="h or c mismatched, Failed TestCase:{}".format(tc.test_name))
     else:
         self.assertTrue(torch.equal(h_pre, h_post),
                         msg="h mis-matched, Failed TestCase:{}".format(tc.test_name))
Example #8
    def test_qc_post_training_wrapper(self):
        torch.manual_seed(0)

        encodings = libpymo.TfEncoding()
        encodings.bw, encodings.max, encodings.min, encodings.delta, encodings.offset = 8, 0.5, -1, 1, 0.2

        encodings_new = libpymo.TfEncoding()
        encodings_new.bw, encodings_new.max, encodings_new.min, encodings_new.delta, encodings_new.offset = 8, 0.4, -0.98, 1, 0.2

        output_grad = []

        def hook_fn(m, _, i):

            for grad in i:
                try:
                    output_grad.append(grad)
                except AttributeError:
                    print("None found for Gradient")

        conv1 = torch.nn.Conv2d(1, 2, 1)
        quantize = QcPostTrainingWrapper(
            conv1,
            weight_bw=8,
            activation_bw=8,
            round_mode='nearest',
            quant_scheme=QuantScheme.post_training_tf_enhanced)
        quantize.train()
        quantize._module_to_wrap.register_backward_hook(hook_fn)

        quantize.input_quantizer.enabled = True
        quantize.output_quantizers[0].enabled = True
        quantize.input_quantizer.encoding = encodings
        quantize.output_quantizers[0].encoding = encodings

        new_input = torch.autograd.Variable(torch.tensor([[[[0.6469]]],
                                                          [[[-0.9]]]]),
                                            requires_grad=True)
        quantize.set_mode(QcQuantizeOpMode.ACTIVE)
        out = quantize(new_input)

        quantize.input_quantizer.encoding = encodings_new
        quantize.output_quantizers[0].encoding = encodings_new
        quantize.param_quantizers['weight'].encoding = encodings_new

        loss = out.flatten().sum()
        loss.backward()

        # Check if input gradient got clipped
        for i, val in enumerate(new_input):
            if encodings_new.min > val or val > encodings_new.max:
                self.assertTrue(new_input.grad[0][i] == 0.0)

        # Check if output gradient got clipped
        output_grad = output_grad[0].flatten()
        self.assertTrue(output_grad[0] == 1.0)
        self.assertTrue(output_grad[1] == 1.0)
        self.assertTrue(output_grad[2] == 1.0)
        self.assertTrue(output_grad[3] == 0.0)

        # Check if weight gradient got clipped
        weight_tensor = quantize._module_to_wrap.weight.flatten()
        weight_tensor_grad = quantize._module_to_wrap.weight.grad.flatten()
        for i, val in enumerate(weight_tensor):
            if encodings_new.min > val or val > encodings_new.max:
                self.assertTrue(weight_tensor_grad[i] == 0.0)
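The gradient assertions encode clip-through behavior: gradients pass unchanged where the value lies inside [encoding.min, encoding.max] and are zeroed outside. A minimal generic sketch of that behavior with a custom autograd function (not AIMET's C++ kernel):

import torch


class ClipThrough(torch.autograd.Function):
    """Identity-like clamp whose gradient is zero outside the clipping range."""
    @staticmethod
    def forward(ctx, x, enc_min, enc_max):
        ctx.save_for_backward(x)
        ctx.enc_min, ctx.enc_max = enc_min, enc_max
        return torch.clamp(x, enc_min, enc_max)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        inside = (x >= ctx.enc_min) & (x <= ctx.enc_max)
        return grad_output * inside.float(), None, None


x = torch.tensor([0.6469, -0.9, -1.5], requires_grad=True)
ClipThrough.apply(x, -0.98, 0.4).sum().backward()
print(x.grad)  # tensor([0., 1., 0.]) -> zeroed where x falls outside [-0.98, 0.4]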