Ejemplo n.º 1
0
    def graph_mode_quantize(
        self,
        inputs,
        data_loader,
        calibration_num_batches=64,
        qconfig_dict=None,
        force_quantize=False,
    ):
        """Quantize the model during export with graph mode quantization.

        When ``force_quantize`` is true, the model is traced, prepared with
        ``prepare_jit``, calibrated on up to ``calibration_num_batches``
        batches from ``data_loader`` and then converted with ``convert_jit``.
        Otherwise the parent class's ``quantize()`` is called and the model
        is traced afterwards.

        Args:
            inputs: Example inputs forwarded to ``self.trace``.
            data_loader: Iterable of 2-tuples; only the second element of
                each tuple (the batch) is used for calibration.
            calibration_num_batches: Maximum number of batches run through
                the prepared module during calibration. Must be >= 1 when
                ``force_quantize`` is true.
            qconfig_dict: qconfig mapping passed to ``prepare_jit``. When
                ``None`` or empty, ``{"": get_default_qconfig("fbgemm")}``
                is used.
            force_quantize: Use graph-mode quantization instead of the
                parent class's ``quantize()``.

        Returns:
            The module returned by ``convert_jit`` when ``force_quantize`` is
            true, otherwise the result of ``self.trace(inputs)``.

        Raises:
            ValueError: If ``force_quantize`` is true and
                ``calibration_num_batches`` is less than 1.
        """
        if not force_quantize:
            # NOTE(review): presumably eager-mode quantization implemented by
            # the parent class -- confirm against the base exporter.
            super().quantize()
            return self.trace(inputs)

        # Reject non-positive counts: the stop check in the calibration loop
        # would never fire for them, so calibration would silently consume
        # the entire data loader.
        if calibration_num_batches < 1:
            raise ValueError(
                "calibration_num_batches must be >= 1, got {}".format(
                    calibration_num_batches
                )
            )

        trace = self.trace(inputs)
        if not qconfig_dict:
            qconfig_dict = {"": get_default_qconfig("fbgemm")}
        prepared = prepare_jit(trace, qconfig_dict, inplace=False)
        prepared.eval()
        with torch.no_grad():
            for i, (_, batch) in enumerate(data_loader):
                print("Running calibration with batch {}".format(i))
                prepared(*self.onnx_trace_input(batch))
                # Stop as soon as the requested number of batches has been
                # processed, without pulling another batch from the loader.
                if i + 1 >= calibration_num_batches:
                    break
        return convert_jit(prepared, inplace=True)
    def graph_mode_quantize(
        self,
        inputs,
        data_loader,
        calibration_num_batches=64,
        qconfig_dict=None,
    ):
        """Quantize the model during export with graph mode quantization for linformer encoder.

        Graph-mode quantization (``prepare_jit`` -> calibration ->
        ``convert_jit``) is applied only when both ``self.right_encoder`` and
        ``self.left_encoder`` are ``RoBERTaEncoder`` instances with
        ``use_linformer_encoder`` set. In every other case the parent class's
        ``quantize()`` is called and the model is traced afterwards.

        Args:
            inputs: Example inputs forwarded to ``self.trace``.
            data_loader: Iterable of 2-tuples; only the second element of
                each tuple (the batch) is used for calibration.
            calibration_num_batches: Maximum number of batches run through
                the prepared module during calibration. Must be >= 1 when the
                graph-mode path is taken.
            qconfig_dict: qconfig mapping passed to ``prepare_jit``. When
                ``None`` or empty, ``{"": get_default_qconfig("fbgemm")}``
                is used (the previous hard-coded behavior).

        Returns:
            The module returned by ``convert_jit`` on the graph-mode path,
            otherwise the result of ``self.trace(inputs)``.

        Raises:
            ValueError: If the graph-mode path is taken and
                ``calibration_num_batches`` is less than 1.
        """

        def _is_linformer(encoder):
            # True only for a RoBERTaEncoder configured with a linformer.
            return (
                isinstance(encoder, RoBERTaEncoder)
                and encoder.use_linformer_encoder
            )

        # Right encoder is checked first; the left one only if the right passes.
        if not (
            _is_linformer(self.right_encoder)
            and _is_linformer(self.left_encoder)
        ):
            # NOTE(review): presumably eager-mode quantization implemented by
            # the parent class -- confirm against the base exporter.
            super().quantize()
            return self.trace(inputs)

        # Reject non-positive counts: the stop check in the calibration loop
        # would never fire for them, so calibration would silently consume
        # the entire data loader.
        if calibration_num_batches < 1:
            raise ValueError(
                "calibration_num_batches must be >= 1, got {}".format(
                    calibration_num_batches
                )
            )

        trace = self.trace(inputs)
        if not qconfig_dict:
            qconfig_dict = {"": get_default_qconfig("fbgemm")}
        prepared = prepare_jit(trace, qconfig_dict, inplace=False)
        prepared.eval()
        with torch.no_grad():
            for i, (_, batch) in enumerate(data_loader):
                print("Running calibration with batch {}".format(i))
                prepared(*self.onnx_trace_input(batch))
                # Stop as soon as the requested number of batches has been
                # processed, without pulling another batch from the loader.
                if i + 1 >= calibration_num_batches:
                    break
        return convert_jit(prepared, inplace=True)