import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_jit import convert_jit, prepare_jit

from pytext.models.roberta import RoBERTaEncoder


def graph_mode_quantize(
    self,
    inputs,
    data_loader,
    calibration_num_batches=64,
    qconfig_dict=None,
    force_quantize=False,
):
    """Quantize the model during export with graph mode quantization."""
    if force_quantize:
        trace = self.trace(inputs)
        if not qconfig_dict:
            # Default to static quantization with the fbgemm backend,
            # applied globally (the "" key covers the whole model).
            qconfig_dict = {"": get_default_qconfig("fbgemm")}
        # Insert observers into the traced graph.
        prepare_m = prepare_jit(trace, qconfig_dict, inplace=False)
        prepare_m.eval()
        # Calibrate the observers on a bounded number of batches.
        with torch.no_grad():
            for i, (_, batch) in enumerate(data_loader):
                print("Running calibration with batch {}".format(i))
                input_data = self.onnx_trace_input(batch)
                prepare_m(*input_data)
                if i == calibration_num_batches - 1:
                    break
        # Replace observed ops with their quantized counterparts.
        trace = convert_jit(prepare_m, inplace=True)
    else:
        # Fall back to eager-mode quantization, then trace the result.
        super().quantize()
        trace = self.trace(inputs)
    return trace
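
# A minimal usage sketch for the method above. The names `model`,
# `sample_inputs`, `calibration_loader`, and `export_quantized` are
# assumptions for illustration, not part of the original source: `model` is a
# PyText model exposing this method, `sample_inputs` matches what
# `model.trace` expects, and `calibration_loader` yields (key, batch) pairs
# as consumed by the calibration loop above.
def export_quantized(model, sample_inputs, calibration_loader, path="model_q.pt"):
    # Run post-training graph-mode quantization with a short calibration
    # pass, then persist the quantized TorchScript module to disk.
    quantized = model.graph_mode_quantize(
        sample_inputs,
        calibration_loader,
        calibration_num_batches=16,
        force_quantize=True,
    )
    torch.jit.save(quantized, path)
    return quantized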

def graph_mode_quantize(self, inputs, data_loader, calibration_num_batches=64):
    """Quantize the model during export with graph mode quantization for
    Linformer encoders."""
    # Graph-mode quantization is applied only when both the left and right
    # encoders are Linformer-based RoBERTa encoders.
    if (
        isinstance(self.right_encoder, RoBERTaEncoder)
        and self.right_encoder.use_linformer_encoder
        and isinstance(self.left_encoder, RoBERTaEncoder)
        and self.left_encoder.use_linformer_encoder
    ):
        trace = self.trace(inputs)
        qconfig = get_default_qconfig("fbgemm")
        qconfig_dict = {"": qconfig}
        # Insert observers into the traced graph.
        prepare_m = prepare_jit(trace, qconfig_dict, inplace=False)
        prepare_m.eval()
        # Calibrate the observers on a bounded number of batches.
        with torch.no_grad():
            for i, (_, batch) in enumerate(data_loader):
                print("Running calibration with batch {}".format(i))
                input_data = self.onnx_trace_input(batch)
                prepare_m(*input_data)
                if i == calibration_num_batches - 1:
                    break
        trace = convert_jit(prepare_m, inplace=True)
    else:
        # Otherwise fall back to eager-mode quantization, then trace.
        super().quantize()
        trace = self.trace(inputs)
    return trace
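
# Both methods unpack (_, batch) pairs from data_loader, so each element must
# be a (key, batch) tuple whose `batch` is accepted by self.onnx_trace_input.
# A minimal sketch of adapting a plain iterable of batches into that shape;
# `as_calibration_loader` and `batches` are hypothetical names, not part of
# the original source.
def as_calibration_loader(batches):
    # Wrap raw batches into the (key, batch) pairs the calibration loop expects.
    for i, batch in enumerate(batches):
        yield (i, batch)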