def test_record_observer(self):
    """Check that a model prepared with the debug qconfig records tensors.

    The test expects the observer under ``fc1.module.observer`` to hold one
    recorded tensor per calibration item per evaluation pass, and the first
    recorded tensor to equal the model output for the first calibration
    input.
    """
    observer_key = 'fc1.module.observer'

    prepared = AnnotatedSingleLayerLinearModel()
    prepared.qconfig = default_debug_qconfig
    prepared = prepare(prepared)

    # Run the evaluation twice so that tensors from both passes get dumped.
    for _ in range(2):
        test_only_eval_fn(prepared, self.calib_data)

    recorded = {}
    get_observer_dict(prepared, recorded)

    self.assertTrue(observer_key in recorded,
                    'observer is not recorded in the dict')

    recorded_count = len(recorded[observer_key].get_tensor_value())
    self.assertEqual(recorded_count, 2 * len(self.calib_data))

    # Fetch the first recorded tensor before the extra forward pass below.
    first_recorded = recorded[observer_key].get_tensor_value()[0]
    first_input = self.calib_data[0][0]
    self.assertEqual(first_recorded, prepared(first_input))
Example #2
0
    def test_record_observer(self):
        """Check tensor recording with the debug qconfig on every engine.

        For each entry of ``supported_qengines`` the test expects the
        observer under ``fc1.module.activation_post_process`` to hold one
        recorded tensor per calibration item per evaluation pass, and the
        first recorded tensor to equal the model output for the first
        calibration input.
        """
        observer_key = 'fc1.module.activation_post_process'

        for engine in supported_qengines:
            with override_quantized_engine(engine):
                prepared = AnnotatedSingleLayerLinearModel()
                prepared.qconfig = default_debug_qconfig
                prepared = prepare(prepared)

                # Run the evaluation twice so both passes get dumped.
                for _ in range(2):
                    test_only_eval_fn(prepared, self.calib_data)

                recorded = {}
                get_observer_dict(prepared, recorded)

                self.assertTrue(observer_key in recorded,
                                'observer is not recorded in the dict')

                recorded_count = len(
                    recorded[observer_key].get_tensor_value())
                self.assertEqual(recorded_count, 2 * len(self.calib_data))

                # Fetch the first recorded tensor before the extra forward
                # pass below.
                first_recorded = recorded[observer_key].get_tensor_value()[0]
                first_input = self.calib_data[0][0]
                self.assertEqual(first_recorded, prepared(first_input))
Example #3
0
    def _post_eval_hook(self, model, **args):
        """Dump op outputs and model tensors to TensorBoard after evaluation.

        Every observer found by ``get_observer_dict`` (``Identity`` modules
        are skipped) has its recorded values stored in the returned summary
        and written as histograms. Every tensor in ``model.state_dict()`` is
        written as a histogram as well. ``self.dump_times`` is incremented
        once the dump has completed.

        Args:
            model (object): input model
            **args: optional extras. ``accuracy`` is appended to the log
                directory name; ``input`` is passed to ``add_graph`` when
                ``self.dump_times`` is 0.

        Returns:
            OrderedDict: maps ``"<op name>.output"`` to the value returned by
            that op's observer ``get_tensor_value()``.
        """
        from torch.utils.tensorboard import SummaryWriter
        from torch.quantization import get_observer_dict

        observer_suffix = ".activation_post_process"

        accuracy = args.get('accuracy', '')
        if self.dump_times == 0:
            log_dir = 'runs/eval/baseline' + '_acc' + str(accuracy)
        else:
            log_dir = ('runs/eval/tune_' + str(self.dump_times) + '_acc' +
                       str(accuracy))

        # Only the log directory is passed: SummaryWriter's second positional
        # parameter is a comment string, not a model.
        writer = SummaryWriter(log_dir)
        summary = OrderedDict()
        try:
            if 'input' in args and self.dump_times == 0:
                writer.add_graph(model, args['input'])

            observer_dict = {}
            get_observer_dict(model, observer_dict)
            for key, observer in observer_dict.items():
                if isinstance(observer, torch.nn.modules.linear.Identity):
                    continue

                # Remove the suffix as a whole. str.strip() would treat it as
                # a set of characters and also eat leading/trailing letters
                # of the op name itself.
                if key.endswith(observer_suffix):
                    op_name = key[:-len(observer_suffix)]
                else:
                    op_name = key

                outputs = observer.get_tensor_value()
                summary[op_name + ".output"] = outputs

                # Only collect last fused child output, tagged with the
                # parent module's name.
                if self.is_fused_child(op_name):
                    if not self.is_last_fused_child(op_name):
                        continue
                    tag = op_name.rpartition('.')[0]
                else:
                    tag = op_name

                # NOTE(review): indexing by the iterated value assumes
                # get_tensor_value() returns a mapping -- confirm.
                for step in outputs:
                    tensor = outputs[step]
                    if tensor.is_quantized:
                        writer.add_histogram(tag + "/Output/int8",
                                             torch.dequantize(tensor))
                    else:
                        writer.add_histogram(tag + "/Output/fp32", tensor)

            for key, value in model.state_dict().items():
                if not isinstance(value, torch.Tensor):
                    continue

                # rpartition keeps a key without any '.' intact instead of
                # chopping off its last character.
                module_name, _, weight = key.rpartition('.')
                if module_name and self.is_fused_child(module_name):
                    # fused child tensorboard tag will be merged into the
                    # parent's tag
                    module_name = module_name.rpartition('.')[0]
                tag = module_name + '/' + weight if module_name else weight

                # To merge ._packed_params
                tag = tag.replace('._packed_params', '')

                if value.is_quantized:
                    writer.add_histogram(tag + "/int8",
                                         torch.dequantize(value))
                else:
                    writer.add_histogram(tag + "/fp32", value)
        finally:
            # Always release the event file, even if a dump step failed.
            writer.close()

        self.dump_times = self.dump_times + 1

        return summary