def conv2d_relu(
    g,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export a fused quantized conv2d + relu as dequantize -> conv -> relu -> quantize.

    Args:
        g: The graph being built; every helper below receives it.
        q_input: Quantized input, unpacked by ``symbolic_helper.dequantize_helper``.
        q_weight: Quantized weight, unpacked by ``symbolic_helper.dequantize_helper``.
        bias: Bias value. It is requantized with the input/weight scales (and
            the weight's quantization axis) and then dequantized again before
            being fed to the convolution.
        stride, padding, dilation, groups: Forwarded unchanged to ``opset9.conv2d``.
        op_scale, op_zero_point: Quantization parameters applied to the output.

    Returns:
        The value returned by ``symbolic_helper.quantize_helper`` for the relu output.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    # Keep the weight's quantization axis (4th element) instead of discarding
    # it, so a per-channel quantized weight produces a bias carrying the same
    # axis -- matching the sibling ``conv2d_relu`` that forwards ``axis``.
    dq_weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
def add_relu(g, x, y, op_scale, op_zero_point):
    """Export a fused quantized add + relu.

    Both operands are dequantized, summed with ``opset9.add``, passed through
    ``opset9.relu``, and the result is quantized with ``op_scale`` and
    ``op_zero_point`` via ``symbolic_helper.quantize_helper``.
    """
    lhs, _lhs_scale, _lhs_zp, _lhs_axis = symbolic_helper.dequantize_helper(g, x)
    rhs, _rhs_scale, _rhs_zp, _rhs_axis = symbolic_helper.dequantize_helper(g, y)

    # Arguments are evaluated left to right: add is emitted before relu.
    rectified = opset9.relu(g, opset9.add(g, lhs, rhs))
    return symbolic_helper.quantize_helper(g, rectified, op_scale, op_zero_point)
def relu(g, input):
    """Export relu, emitting the Caffe2 ``Int8Relu`` op for quantized inputs.

    If ``input`` is not in ``symbolic_helper._quantized_ops``, the export is
    delegated to ``opset9.relu``. Otherwise an ``_caffe2::Int8Relu`` node is
    created whose ``Y_scale`` / ``Y_zero_point`` attributes are copied from the
    node that produced ``input``, and the new output is added to
    ``symbolic_helper._quantized_ops``.
    """
    if input in symbolic_helper._quantized_ops:
        producer = input.node()
        int8_out = g.op(
            "_caffe2::Int8Relu",
            input,
            Y_scale_f=producer["Y_scale"],
            Y_zero_point_i=producer["Y_zero_point"],
        )
        symbolic_helper._quantized_ops.add(int8_out)
        return int8_out

    return opset9.relu(g, input)
def relu(g, input):
    """Export relu, emitting the Caffe2 ``Int8Relu`` op for quantized inputs.

    Inputs not found in ``sym_help._quantized_ops`` fall back to the opset 9
    ``relu`` symbolic (imported lazily inside the function). For quantized
    inputs an ``_caffe2::Int8Relu`` node is created that copies ``Y_scale`` /
    ``Y_zero_point`` from the producer node of ``input``; the new output is
    then added to ``sym_help._quantized_ops``.
    """
    if input in sym_help._quantized_ops:
        src_node = input.node()
        int8_out = g.op(
            "_caffe2::Int8Relu",
            input,
            Y_scale_f=src_node["Y_scale"],
            Y_zero_point_i=src_node["Y_zero_point"],
        )
        sym_help._quantized_ops.add(int8_out)
        return int8_out

    # Lazy import, kept local as in the original; aliased so it does not
    # shadow this function's own name.
    from torch.onnx.symbolic_opset9 import relu as opset9_relu

    return opset9_relu(g, input)
def conv2d_relu(
    g,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export a fused quantized conv2d + relu (dequantize -> conv -> relu -> quantize).

    The weight's quantization axis returned by ``sym_help.dequantize_helper``
    is forwarded to ``sym_help.requantize_bias_helper`` together with the
    input and weight scales. The requantized bias is dequantized again before
    ``conv2d`` runs. The relu output is quantized with ``op_scale`` /
    ``op_zero_point``.
    """
    dq_input, in_scale, _, _ = sym_help.dequantize_helper(g, q_input)
    dq_weight, w_scale, _, w_axis = sym_help.dequantize_helper(g, q_weight)

    # Inner call runs first: requantize the bias, then dequantize it.
    dq_bias, _, _, _ = sym_help.dequantize_helper(
        g, sym_help.requantize_bias_helper(g, bias, in_scale, w_scale, w_axis)
    )

    conv_out = conv2d(g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups)
    return sym_help.quantize_helper(g, relu(g, conv_out), op_scale, op_zero_point)