def test_record_observer(self):
    """Check that prepare() with the debug qconfig attaches a recording
    observer to fc1 and that it captures one tensor per calibration batch
    per evaluation pass."""
    model = AnnotatedSingleLayerLinearModel()
    model.qconfig = default_debug_qconfig
    model = prepare(model)
    # Evaluate twice so the observer accumulates 2 * len(calib_data) tensors.
    test_only_eval_fn(model, self.calib_data)
    test_only_eval_fn(model, self.calib_data)
    observers = {}
    get_observer_dict(model, observers)
    self.assertTrue('fc1.module.observer' in observers.keys(),
                    'observer is not recorded in the dict')
    recorded = observers['fc1.module.observer'].get_tensor_value()
    self.assertEqual(len(recorded), 2 * len(self.calib_data))
    # The first recorded tensor must match a fresh forward on the same input.
    self.assertEqual(recorded[0], model(self.calib_data[0][0]))
def test_record_observer(self):
    """Verify observer recording for every supported quantized engine:
    the fc1 activation observer must hold one tensor per calibration
    batch per evaluation pass."""
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = AnnotatedSingleLayerLinearModel()
            model.qconfig = default_debug_qconfig
            model = prepare(model)
            # Two evaluation passes -> 2 * len(calib_data) recorded tensors.
            test_only_eval_fn(model, self.calib_data)
            test_only_eval_fn(model, self.calib_data)
            observers = {}
            get_observer_dict(model, observers)
            key = 'fc1.module.activation_post_process'
            self.assertTrue(key in observers.keys(),
                            'observer is not recorded in the dict')
            recorded = observers[key].get_tensor_value()
            self.assertEqual(len(recorded), 2 * len(self.calib_data))
            # First recorded tensor matches a fresh forward on the same input.
            self.assertEqual(recorded[0], model(self.calib_data[0][0]))
def _post_eval_hook(self, model, **args):
    """Dump quantizable ops' output tensors and model weights to
    TensorBoard after a completed evaluation.

    Args:
        model (object): the evaluated, observer-instrumented model.
        **args: optional extras; 'accuracy' is appended to the run
            directory name, and 'input' (if given on the first dump)
            is used to write the model graph.

    Returns:
        OrderedDict: mapping "<op_name>.output" -> recorded output tensors.
    """
    from torch.utils.tensorboard import SummaryWriter
    from torch.quantization import get_observer_dict

    # args is a **kwargs dict and can never be None; .get() covers both
    # the "missing key" and the original "is not None" checks.
    accuracy = args.get('accuracy', '')

    # First dump is the baseline run; later dumps are tuning iterations.
    if self.dump_times == 0:
        writer = SummaryWriter(
            'runs/eval/baseline' + '_acc' + str(accuracy), model)
    else:
        writer = SummaryWriter(
            'runs/eval/tune_' + str(self.dump_times) + '_acc' +
            str(accuracy), model)

    if 'input' in args and self.dump_times == 0:
        writer.add_graph(model, args['input'])

    summary = OrderedDict()
    observer_dict = {}
    get_observer_dict(model, observer_dict)
    suffix = '.activation_post_process'
    for key in observer_dict:
        if isinstance(observer_dict[key], torch.nn.modules.linear.Identity):
            continue
        # BUG FIX: the original used key.strip(suffix), but str.strip()
        # removes a *character set* from both ends, not a suffix — it could
        # mangle op names whose edge characters appear in the suffix
        # (e.g. strip leading/trailing 'c', 'o', 'n', 't', ...).
        op_name = key[:-len(suffix)] if key.endswith(suffix) else key
        summary[op_name + ".output"] = observer_dict[key].get_tensor_value()
        # NOTE(review): this iterates whatever get_tensor_value() returns and
        # indexes back into it, mirroring the original access pattern —
        # presumably a dict keyed by iteration; confirm against the observer.
        for idx in summary[op_name + ".output"]:
            # Only collect the last fused child's output; intermediate fused
            # children are skipped, and the tag is merged onto the parent op.
            if self.is_fused_child(op_name) and \
                    self.is_last_fused_child(op_name):
                op = op_name[:op_name.rfind('.')]
            elif self.is_fused_child(op_name):
                continue
            else:
                op = op_name
            tensor = summary[op_name + ".output"][idx]
            if tensor.is_quantized:
                writer.add_histogram(
                    op + "/Output/int8", torch.dequantize(tensor))
            else:
                writer.add_histogram(op + "/Output/fp32", tensor)

    state_dict = model.state_dict()
    for key in state_dict:
        if not isinstance(state_dict[key], torch.Tensor):
            continue
        op = key[:key.rfind('.')]
        weight = key[key.rfind('.') + 1:]
        if self.is_fused_child(op):
            # Fused child: merge the tensorboard tag onto the parent module.
            op = op[:op.rfind('.')] + '/' + weight
        else:
            op = key[:key.rfind('.')] + '/' + weight
        # Fold the '._packed_params' segment into the parent tag.
        op = op.replace('._packed_params', '')
        if state_dict[key].is_quantized:
            writer.add_histogram(
                op + "/int8", torch.dequantize(state_dict[key]))
        else:
            writer.add_histogram(op + "/fp32", state_dict[key])

    writer.close()
    self.dump_times = self.dump_times + 1
    return summary