import numpy as np
import tensorrt as trt
import torch
import torch.nn as nn

# `acc_utils`, `get_trt_tensor`, and `append_ones` are module-local helpers
# from the surrounding converter code and are assumed to be in scope.


def acc_ops_dequantize(network, target, args, kwargs, name):
    """
    Maps acc_ops.dequantize to an explicit TensorRT dequantize layer.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(f"{name} received input {input_val} that is not part "
                           "of the TensorRT region!")

    q_scale = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "q_scale")
    q_zero_point = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "q_zero_point")
    dtype = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "dtype")

    if dtype not in (torch.quint8, torch.qint8, torch.qint32):
        raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) "
                           f"quantized type in dequantize, get {dtype}.")

    if q_zero_point != 0:
        raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")

    # The scale must be fed to the dequantize layer as a 1D constant tensor.
    scale_layer = network.add_constant(
        (1,), trt.Weights(np.ascontiguousarray([q_scale], dtype=np.float32)))
    scale_layer.name = input_val.name + ".dequant.scale"
    scale = scale_layer.get_output(0)

    assert trt.__version__ > "8.0", (
        "Explicit dequantize op is only supported in TensorRT 8.0 or above, "
        "current TensorRT version: " + trt.__version__
    )

    layer = network.add_dequantize(input=input_val, scale=scale)
    layer.name = input_val.name + ".dequant"
    layer.axis = 0
    return layer.get_output(0)

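# Illustrative sketch (not part of the converter): the eager-mode op this
# converter lowers is torch.dequantize. Values below are made up.
def _demo_dequantize():
    x = torch.randn(2, 3)
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
    return torch.dequantize(qx)  # back to float32; scale/zero_point are consumed
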
def quantized_conv2d(
    *,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    padding_mode,
    acc_out_ty=None,
):
    assert acc_out_ty is not None
    return torch.nn.quantized.functional.conv2d(
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        groups,
        padding_mode,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )

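# Illustrative usage sketch of the underlying quantized functional conv2d.
# Shapes and qparams are made up; a quantized engine (e.g. fbgemm) must be
# available. The trailing scale/zero_point mirror the mapper's
# q_scale / q_zero_point fields and set the *output* quantization.
def _demo_quantized_conv2d():
    x = torch.quantize_per_tensor(
        torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8)
    w = torch.quantize_per_tensor(
        torch.randn(4, 3, 3, 3), scale=0.05, zero_point=0, dtype=torch.qint8)
    b = torch.randn(4)
    return torch.nn.quantized.functional.conv2d(
        x, w, b, stride=1, padding=1, scale=0.2, zero_point=0)
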
def acc_ops_quantize_per_tensor(network, target, args, kwargs, name):
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(f"{name} received input {input_val} that is not part "
                           "of the TensorRT region!")

    q_scale = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "q_scale")
    q_zero_point = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "q_zero_point")
    dtype = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "dtype")

    if dtype not in (torch.quint8, torch.qint8, torch.qint32):
        raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) "
                           f"quantized type in quantize_per_tensor, get {dtype}.")

    if q_zero_point != 0:
        raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")

    # Temporarily set q_scale to 1 to make sure the q_scale is different for
    # quantize and dequantize, to avoid the error.
    # TODO: follow up with the NVIDIA TensorRT team to repro and fix the problem.
    q_scale = 1

    scale_layer = network.add_constant(
        (1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
    scale_layer.name = input_val.name + ".quant.scale"
    scale = scale_layer.get_output(0)

    assert trt.__version__ > "8.0", (
        "Explicit quantize op is only supported in TensorRT 8.0 or above, "
        "current TensorRT version: " + trt.__version__
    )

    layer = network.add_quantize(input=input_val, scale=scale)
    layer.axis = 0
    layer.name = input_val.name + ".quant"
    return layer.get_output(0)

def quantize_per_tensor(*, input, acc_out_ty=None):
    assert acc_out_ty is not None
    return torch.quantize_per_tensor(
        input,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype"),
    )

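# Illustrative sketch of the call the mapper above forwards to; the literal
# scale / zero_point / dtype stand in for the fields read from acc_out_ty.
def _demo_quantize_per_tensor():
    x = torch.randn(2, 3)
    return torch.quantize_per_tensor(x, 0.1, 0, torch.quint8)
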
def quantized_mul(*, input, other, acc_out_ty=None):
    assert acc_out_ty is not None
    return torch.ops.quantized.mul(
        input,
        other,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )

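# Illustrative sketch with made-up qparams: torch.ops.quantized.mul takes two
# quantized tensors plus the scale and zero_point of the *output*.
def _demo_quantized_mul():
    a = torch.quantize_per_tensor(torch.randn(2, 3), 0.1, 0, torch.quint8)
    b = torch.quantize_per_tensor(torch.randn(2, 3), 0.1, 0, torch.quint8)
    return torch.ops.quantized.mul(a, b, 0.2, 0)
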
def quantized_linear(*, input, weight, bias, acc_out_ty=None):
    assert acc_out_ty is not None
    return nn.quantized.functional.linear(
        input,
        weight,
        bias,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )

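# Illustrative sketch (made-up shapes and qparams): quint8 activation,
# qint8 weight, float bias; scale/zero_point set the output quantization.
def _demo_quantized_linear():
    x = torch.quantize_per_tensor(torch.randn(2, 4), 0.1, 0, torch.quint8)
    w = torch.quantize_per_tensor(torch.randn(3, 4), 0.05, 0, torch.qint8)
    b = torch.randn(3)
    return nn.quantized.functional.linear(x, w, b, scale=0.2, zero_point=0)
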
def quantized_batch_norm2d(
    *, input, running_mean, running_var, weight, bias, eps, acc_out_ty
):
    return torch.ops.quantized.batch_norm2d(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        eps,
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_scale"),
        acc_utils.get_field_from_acc_out_ty(acc_out_ty, "q_zero_point"),
    )

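# Illustrative sketch with made-up statistics: note the argument order the op
# expects (weight/bias before running mean/var), which the mapper above follows.
def _demo_quantized_batch_norm2d():
    x = torch.quantize_per_tensor(torch.randn(1, 2, 4, 4), 0.1, 0, torch.quint8)
    weight, bias = torch.ones(2), torch.zeros(2)
    mean, var = torch.zeros(2), torch.ones(2)
    return torch.ops.quantized.batch_norm2d(
        x, weight, bias, mean, var, 1e-5, 0.2, 0)
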
def acc_ops_reshape(network, target, args, kwargs, name):
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(f"Reshape received input {input_val} that is not part "
                           "of the TensorRT region!")

    shape = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "shape")
    if network.has_implicit_batch_dimension:
        shape = shape[1:]

    layer = network.add_shuffle(input_val)

    if all(isinstance(s, int) for s in shape):
        # All dimensions are static, so they can be set directly.
        layer.reshape_dims = tuple(shape)
    else:
        # Convert all the dimensions to trt Tensors.
        trt_shape = []

        for i, s in enumerate(shape):
            if isinstance(s, trt.tensorrt.ITensor):
                if len(s.shape) == 0:
                    s = append_ones(network, s, f"{name}_{i}", 1)
                trt_shape.append(s)
            else:
                trt_shape.append(get_trt_tensor(network, s, f"{name}_{i}"))

        # Concatenate the per-dimension tensors into a single shape tensor
        # and feed it to the shuffle layer as its second input.
        shape_layer = network.add_concatenation(inputs=trt_shape)
        shape_layer.axis = 0
        shape_layer.name = f"{name}_output_shape"
        layer.set_input(1, shape_layer.get_output(0))

    layer.name = name
    return layer.get_output(0)

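# Illustrative eager-mode analogue of the dynamic-shape path above: one target
# dimension is only known at runtime (here via x.size(0)), which is why the
# converter assembles a shape tensor instead of using static reshape_dims.
def _demo_dynamic_reshape():
    x = torch.randn(5, 2, 3)
    return x.reshape(x.size(0), -1)  # (5, 6)
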
def quantized_mul(*, input, other, acc_out_ty=None):
    assert acc_out_ty is not None
    qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
    return torch.ops.quantized.mul(
        input,
        other,
        qparams["scale"],
        qparams["zero_point"],
    )

def quantized_linear(*, input, weight, bias, acc_out_ty=None):
    assert acc_out_ty is not None
    qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
    return nn.quantized.functional.linear(
        input,
        weight,
        bias,
        qparams["scale"],
        qparams["zero_point"],
    )

def quantized_batch_norm2d(
    *, input, running_mean, running_var, weight, bias, eps, acc_out_ty
):
    qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
    return torch.ops.quantized.batch_norm2d(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        eps,
        qparams["scale"],
        qparams["zero_point"],
    )

def reshape(*, input, acc_out_ty=None):
    assert acc_out_ty is not None
    return torch.reshape(
        input, tuple(acc_utils.get_field_from_acc_out_ty(acc_out_ty, "shape"))
    )

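# Illustrative sketch: the literal shape stands in for the "shape" field the
# mapper reads from acc_out_ty.
def _demo_reshape():
    x = torch.arange(6)
    return torch.reshape(x, (2, 3))
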
def to_dtype(input, acc_out_ty=None):
    assert acc_out_ty is not None, "valid acc_out_ty needed"
    return input.to(dtype=acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype"))

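# Illustrative sketch: the dtype literal stands in for the "dtype" field the
# mapper reads from acc_out_ty.
def _demo_to_dtype():
    x = torch.randn(2, 3)
    return x.to(dtype=torch.float16)
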
def quantize_per_tensor(*, input, acc_out_ty=None):
    assert acc_out_ty is not None
    qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
    dtype = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype")
    return torch.quantize_per_tensor(input, qparams["scale"], qparams["zero_point"], dtype)