def __call__(self, tensor, mode=0):
    r""" Converts float weights to quantized weights.

    Args:
        - tensor: input data
        - mode: GFPQ mode for param
            GFPQ_MODE_INIT(0): There is no valid parameter in param[].
                Generate a parameter and fill it in param[].
            GFPQ_MODE_UPDATE(1): There is a parameter in param[].
                Generate a new parameter and update param[] when the new
                parameter is better.
            GFPQ_MODE_APPLY_ONLY(2): There is a parameter in param[].
                Don't generate a parameter; just use param[].
    """
    global _USE_GFPQ_QUANT_LIB
    if _USE_GFPQ_QUANT_LIB:
        ret = -1  # stays nonzero if the library call raises before assigning
        try:
            if isinstance(tensor, tuple):
                for tensor_item in tensor:
                    data_cuda_array = cuda.as_cuda_array(
                        tensor_item.data.detach())
                    data_p = data_cuda_array.device_ctypes_pointer
                    self._param.mode = mode
                    ret = self._libquant.HI_GFPQ_QuantAndDeQuant_GPU_PY(
                        data_p, data_cuda_array.size, self._bit_width,
                        ctypes.byref(self._param), self._stream.handle,
                        self._cublas_handle)
            else:
                data_cuda_array = cuda.as_cuda_array(tensor.data.detach())
                data_p = data_cuda_array.device_ctypes_pointer
                self._param.mode = mode
                ret = self._libquant.HI_GFPQ_QuantAndDeQuant_GPU_PY(
                    data_p, data_cuda_array.size, self._bit_width,
                    ctypes.byref(self._param), self._stream.handle,
                    self._cublas_handle)
        except Exception:
            pass  # fall through to the pure-Python fallback below
        finally:
            if ret != 0:
                _USE_GFPQ_QUANT_LIB = False
                logger = logging.getLogger(__name__)
                logger.setLevel(logging.WARNING)
                logger.warning(
                    "Failed to quantize data with the default HiSVP GFPQ "
                    "library; using the built-in quantization algorithm "
                    "instead.")
                if isinstance(tensor, tuple):
                    for tensor_item in tensor:
                        tensor_item.data = fake_quantize(
                            tensor_item.data.detach(), self._bit_width)
                else:
                    tensor.data = fake_quantize(tensor.data.detach(),
                                                self._bit_width)
    else:
        if isinstance(tensor, tuple):
            for tensor_item in tensor:
                tensor_item.data = fake_quantize(tensor_item.data.detach(),
                                                 self._bit_width)
        else:
            tensor.data = fake_quantize(tensor.data.detach(),
                                        self._bit_width)
    return tensor
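# Usage sketch (hypothetical, not part of the library API): given an instance
# of this quantizer class, assumed here to be named QuantAndDeQuantGPU and
# constructed with no arguments, the three GFPQ modes map onto a typical
# quantization-aware training step like so:
#
#     quantizer = QuantAndDeQuantGPU()  # assumed class name
#     weight = module.weight            # CUDA tensor, quantized in place
#     quantizer(weight, mode=0)         # INIT: derive and store a param
#     quantizer(weight, mode=1)         # UPDATE: keep the better param
#     quantizer(weight, mode=2)         # APPLY_ONLY: reuse the stored param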
def test(self, data):
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Load the GFPQ quantization library and cuBLAS.
    dl = ctypes.cdll.LoadLibrary
    quant_lib = dl("nnieqat/gpu/lib/libgfpq_gpu.so")
    _libcublas = ctypes.cdll.LoadLibrary("libcublas.so")

    # Mirrors struct GFPQ_PARAM_ST in gfpq.hpp.
    class GFPQ_PARAM_ST(ctypes.Structure):
        _fields_ = [("mode", ctypes.c_int), ("buf", ctypes.c_byte * 16)]

    class _types:
        """Some alias types."""
        handle = ctypes.c_void_p
        stream = ctypes.c_void_p

    data_origin = data.copy()
    print(
        "----------------------------------------------------------------------"
    )
    print("\n\nOriginal data:")
    print(data)

    data = data.astype(np.float32)
    stream = cuda.stream()
    _libcublas.cublasCreate_v2.restype = int
    _libcublas.cublasCreate_v2.argtypes = [ctypes.c_void_p]
    cublas_handle = _types.handle()
    _libcublas.cublasCreate_v2(ctypes.byref(cublas_handle))

    data_gpu = cuda.to_device(data, stream=stream)
    data_p = data_gpu.device_ctypes_pointer
    bit_width = 8
    param = GFPQ_PARAM_ST()

    # Initialize the quantization parameter first (GFPQ_MODE_INIT).
    param.mode = 0
    ret = quant_lib.HI_GFPQ_QuantAndDeQuant_GPU_PY(data_p, data.size,
                                                   bit_width,
                                                   ctypes.byref(param),
                                                   stream.handle,
                                                   cublas_handle)
    if ret != 0:
        print("HI_GFPQ_QuantAndDeQuant failed(%d)" % ret)

    # Then apply the stored parameter (GFPQ_MODE_APPLY_ONLY).
    param.mode = 2
    ret = quant_lib.HI_GFPQ_QuantAndDeQuant_GPU_PY(data_p, data.size,
                                                   bit_width,
                                                   ctypes.byref(param),
                                                   stream.handle,
                                                   cublas_handle)
    if ret != 0:
        print("HI_GFPQ_QuantAndDeQuant failed(%d)" % ret)

    # The copy is asynchronous: data is not valid until the stream syncs.
    data_gpu.copy_to_host(data, stream=stream)
    stream.synchronize()
    _libcublas.cublasDestroy_v2(cublas_handle)

    # Compare the GFPQ result against the pure-Python fallback implementation.
    import nnieqat
    from quant_impl import fake_quantize
    import torch
    tensor = torch.Tensor(data_origin).cuda()
    tensor.data = fake_quantize(tensor.data.detach(), 8)
    diff = abs(tensor.cpu().numpy() - data)
    # diff_thres = np.max(abs(data)) * 0.001
    # print("\nDIFF > 0.1%: ")
    # print("idx: ", np.where(diff > diff_thres))
    # print("Original data:", data_origin[np.where(diff > diff_thres)])
    # print("GFPQ result:", data[np.where(diff > diff_thres)])
    # print("Impl result:", tensor.cpu().numpy()[np.where(diff > diff_thres)])
    diff_max = np.max(diff)
    print("\nDIFF MAX: " + str(diff_max))
    print("\nDIFF RATIO: " +
          str(diff_max / max(np.max(abs(data)), pow(10, -18))))
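# Minimal smoke test for the GFPQ round trip above. A sketch under stated
# assumptions, not shipped behavior: the enclosing class name
# (QuantAndDeQuantGPU here) is assumed, and running it needs a CUDA GPU plus
# libgfpq_gpu.so and libcublas.so resolvable by the loader.
if __name__ == "__main__":
    import numpy as np
    quantizer = QuantAndDeQuantGPU()  # assumed class name
    sample = np.random.uniform(-1.0, 1.0, (8, 32)).astype(np.float32)
    quantizer.test(sample)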