示例#1
0
    def test_quantize_transpose(self):
        """Quantize a MatMul->Transpose model and validate both QOperator and QDQ outputs."""
        np.random.seed(1)
        fp32_path = 'transpose_fp32.onnx'
        qop_path = 'transpose_uint8.onnx'
        qdq_path = 'transpose_uint8_qdq.onnx'

        self.construct_model_matmul_transpose(fp32_path, [3, 7], [7, 5], [5, 3])
        reader = self.input_feeds(1, {'input': [3, 7]})

        # QOperator format: Transpose must consume the quantized MatMul output,
        # so no node named "transpose_node" may still read 'matmul_output'.
        quantize_static(fp32_path, qop_path, reader)
        check_op_nodes(
            self, qop_path,
            lambda node: not (node.name == "transpose_node" and node.input[0] == 'matmul_output'))
        check_op_type_count(
            self, qop_path,
            **{'QLinearMatMul': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1, 'Transpose': 1})
        reader.rewind()
        check_model_correctness(self, fp32_path, qop_path, reader.get_next())

        # QDQ format: explicit Q/DQ pairs surround the float MatMul.
        reader.rewind()
        quantize_static(fp32_path, qdq_path, reader, quant_format=QuantFormat.QDQ)
        check_op_type_count(
            self, qdq_path,
            **{'MatMul': 1, 'QuantizeLinear': 2, 'DequantizeLinear': 3, 'Transpose': 1})
        reader.rewind()
        check_model_correctness(self, fp32_path, qdq_path, reader.get_next())
示例#2
0
    def test_quantize_resize(self):
        """Quantize a Conv->Resize model and validate both QOperator and QDQ outputs."""
        np.random.seed(1)

        fp32_path = 'resize_fp32.onnx'
        qop_path = 'resize_uint8.onnx'
        qdq_path = 'resize_uint8_qdq.onnx'

        resize_attrs = {'coordinate_transformation_mode': 'asymmetric', 'mode': 'nearest', 'nearest_mode': 'floor'}
        self.construct_model_conv_resize(
            fp32_path,
            [1, 2, 26, 42], [3, 2, 3, 3],
            [1, 3, 24, 40], [1, 3, 48, 80],
            resize_attrs,
            [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 2.0, 2.0],
            None)

        reader = self.input_feeds(1, {'input': [1, 2, 26, 42]})

        # QOperator format: Resize must consume the quantized Conv output,
        # so no node named "resize_node" may still read 'conv_output'.
        quantize_static(fp32_path, qop_path, reader)
        check_op_nodes(
            self, qop_path,
            lambda node: not (node.name == "resize_node" and node.input[0] == 'conv_output'))
        check_op_type_count(
            self, qop_path,
            **{'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'Resize': 1})
        reader.rewind()
        check_model_correctness(self, fp32_path, qop_path, reader.get_next())

        # QDQ format: explicit Q/DQ pairs surround the float Conv.
        reader.rewind()
        quantize_static(fp32_path, qdq_path, reader, quant_format=QuantFormat.QDQ)
        check_op_type_count(
            self, qdq_path,
            **{'Conv': 1, 'QuantizeLinear': 2, 'DequantizeLinear': 3, 'Resize': 1})
        reader.rewind()
        check_model_correctness(self, fp32_path, qdq_path, reader.get_next())
示例#3
0
    def quantize_argmax_test(self, activation_type, weight_type, extra_options=None):
        """Quantize a Conv->ArgMax model and verify QOperator, QDQ, and
        TensorRT-style (ArgMax-only) QDQ outputs.

        Args:
            activation_type: QuantType for activations (QUInt8 or QInt8).
            weight_type: QuantType for weights (QUInt8 or QInt8).
            extra_options: optional dict forwarded to quantize_static.
        """
        # None sentinel instead of a mutable `{}` default: a default dict is
        # shared across calls and could leak state between test invocations.
        extra_options = {} if extra_options is None else extra_options
        np.random.seed(1)
        model_fp32_path = 'argmax_fp32.onnx'

        self.construct_model_argmax(model_fp32_path,
                                    [1, 256, 128, 128],
                                    [1, 32, 128])

        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
        activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8'
        weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8'
        model_uint8_path = 'argmax_{}{}.onnx'.format(activation_type_str, weight_type_str)
        model_uint8_qdq_path = 'argmax_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str)
        model_uint8_qdq_trt_path = 'argmax_{}{}_qdq_trt.onnx'.format(activation_type_str, weight_type_str)

        # Verify QOperator mode
        data_reader = self.input_feeds(1, {'input': [1, 256, 128, 128]})
        quantize_static(model_fp32_path, model_uint8_path, data_reader, quant_format=QuantFormat.QOperator,
                        activation_type=activation_type, weight_type=weight_type, extra_options=extra_options)
        # ArgMax must consume the quantized Conv output, not the float 'conv_output'.
        check_op_nodes(self, model_uint8_path, lambda node: not(node.name == "argmax_node" and node.input[0] == 'conv_output'))
        qnode_counts = {'QuantizeLinear': 1, 'QLinearConv': 1, 'ArgMax': 1}
        check_op_type_count(self, model_uint8_path, **qnode_counts)
        qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
        check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next())

        # Verify QDQ mode
        data_reader.rewind()
        quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ,
                        activation_type=activation_type, weight_type=weight_type, extra_options=extra_options)
        qdqnode_counts = {'QuantizeLinear': 2, 'DequantizeLinear': 3, 'ArgMax': 1}
        check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts)
        qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
        check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next())

        # Verify QDQ mode for TensorRT: only ArgMax is quantized.
        data_reader.rewind()
        quantize_static(model_fp32_path, model_uint8_qdq_trt_path, data_reader, quant_format=QuantFormat.QDQ,
                        activation_type=activation_type, weight_type=weight_type, extra_options=extra_options,
                        op_types_to_quantize=['ArgMax'])
        qdqnode_counts = {'QuantizeLinear': 1, 'DequantizeLinear': 1, 'ArgMax': 1}
        check_op_type_count(self, model_uint8_qdq_trt_path, **qdqnode_counts)
        qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
        check_qtype_by_node_type(self, model_uint8_qdq_trt_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_qdq_trt_path, data_reader.get_next())
示例#4
0
    def test_quantize_maxpool(self):
        """Quantize a Conv->MaxPool model and validate both QOperator and QDQ outputs."""
        np.random.seed(1)

        fp32_path = 'maxpool_fp32.onnx'
        qop_path = 'maxpool_uint8.onnx'
        qdq_path = 'maxpool_uint8_qdq.onnx'

        self.construct_model_conv_maxpool(
            fp32_path,
            [1, 2, 26, 42],
            [3, 2, 3, 3],
            [1, 3, 24, 40],
            {'kernel_shape': [3, 3]},
            [1, 3, 22, 38])

        reader = self.input_feeds(1, {'input': [1, 2, 26, 42]})

        # QOperator format: MaxPool must consume the quantized Conv output,
        # so no node named "maxpool_node" may still read 'conv_output'.
        quantize_static(fp32_path, qop_path, reader)
        check_op_nodes(
            self, qop_path,
            lambda node: not (node.name == "maxpool_node" and node.input[0] == 'conv_output'))
        check_op_type_count(
            self, qop_path,
            **{'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'MaxPool': 1})
        reader.rewind()
        check_model_correctness(self, fp32_path, qop_path, reader.get_next())

        # QDQ format: explicit Q/DQ pairs surround the float Conv.
        reader.rewind()
        quantize_static(fp32_path, qdq_path, reader, quant_format=QuantFormat.QDQ)
        check_op_type_count(
            self, qdq_path,
            **{'Conv': 1, 'QuantizeLinear': 2, 'DequantizeLinear': 3, 'MaxPool': 1})
        reader.rewind()
        check_model_correctness(self, fp32_path, qdq_path, reader.get_next())
示例#5
0
    def quantize_resize_test(self, activation_type, weight_type, extra_options=None):
        """Quantize a Conv->Resize model with the given quant types and verify
        QOperator and QDQ outputs.

        Args:
            activation_type: QuantType for activations (QUInt8 or QInt8).
            weight_type: QuantType for weights (QUInt8 or QInt8).
            extra_options: optional dict forwarded to quantize_static.
        """
        # None sentinel instead of a mutable `{}` default: a default dict is
        # shared across calls and could leak state between test invocations.
        extra_options = {} if extra_options is None else extra_options
        np.random.seed(1)
        model_fp32_path = 'resize_fp32.onnx'

        kwargs = {'coordinate_transformation_mode': 'asymmetric', 'mode': 'nearest', 'nearest_mode': 'floor'}
        self.construct_model_conv_resize(model_fp32_path,
                                         [1, 2, 26, 42], [3, 2, 3, 3],
                                         [1, 3, 24, 40], [1, 3, 48, 80],
                                         kwargs,
                                         [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 2.0, 2.0], None)

        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
        activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8'
        weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8'
        model_uint8_path = 'resize_{}{}.onnx'.format(activation_type_str, weight_type_str)
        model_uint8_qdq_path = 'resize_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str)

        # Verify QOperator mode
        data_reader = self.input_feeds(1, {'input': [1, 2, 26, 42]})
        quantize_static(model_fp32_path, model_uint8_path, data_reader,
                        activation_type=activation_type, weight_type=weight_type, extra_options=extra_options)
        # Resize must consume the quantized Conv output, not the float 'conv_output'.
        check_op_nodes(self, model_uint8_path, lambda node: (node.name != "resize_node" or node.input[0] != 'conv_output'))
        qnode_counts = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'Resize': 1}
        check_op_type_count(self, model_uint8_path, **qnode_counts)
        qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
        qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]})
        check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next())

        # Verify QDQ mode
        data_reader.rewind()
        quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ,
                        activation_type=activation_type, weight_type=weight_type, extra_options=extra_options)
        qdqnode_counts = {'Conv': 1, 'QuantizeLinear': 3, 'DequantizeLinear': 4, 'Resize': 1}
        check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts)
        qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
        check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next())
示例#6
0
    def quantize_maxpool_test(self, activation_type, weight_type, extra_options=None):
        """Quantize a Conv->MaxPool model with the given quant types and verify
        QOperator and QDQ outputs.

        Args:
            activation_type: QuantType for activations (QUInt8 or QInt8).
            weight_type: QuantType for weights (QUInt8 or QInt8).
            extra_options: optional dict forwarded to quantize_static.
        """
        # None sentinel instead of a mutable `{}` default: a default dict is
        # shared across calls and could leak state between test invocations.
        extra_options = {} if extra_options is None else extra_options
        np.random.seed(1)
        model_fp32_path = 'maxpool_fp32.onnx'
        self.construct_model_conv_maxpool(model_fp32_path,
                                          [1, 2, 26, 42], [3, 2, 3, 3],
                                          [1, 3, 24, 40], {'kernel_shape': [3, 3]},
                                          [1, 3, 22, 38])
        data_reader = self.input_feeds(1, {'input': [1, 2, 26, 42]})

        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
        activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8'
        weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8'
        model_q8_path = 'maxpool_{}{}.onnx'.format(activation_type_str, weight_type_str)
        # 'qdq' (was misspelled 'dqd') to match the naming used by the other tests.
        model_q8_qdq_path = 'maxpool_qdq_{}{}.onnx'.format(activation_type_str, weight_type_str)

        # Verify QOperator mode
        data_reader.rewind()
        quantize_static(model_fp32_path, model_q8_path, data_reader, quant_format=QuantFormat.QOperator,
                        activation_type=activation_type, weight_type=weight_type, extra_options=extra_options)
        # MaxPool must consume the quantized Conv output, not the float 'conv_output'.
        check_op_nodes(self, model_q8_path, lambda node: (node.name != "maxpool_node" or node.input[0] != 'conv_output'))
        qnode_counts = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'MaxPool': 1}
        check_op_type_count(self, model_q8_path, **qnode_counts)
        qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
        qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]})
        check_qtype_by_node_type(self, model_q8_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_q8_path, data_reader.get_next())

        # Verify QDQ mode
        data_reader.rewind()
        quantize_static(model_fp32_path, model_q8_qdq_path, data_reader, quant_format=QuantFormat.QDQ,
                        activation_type=activation_type, weight_type=weight_type, extra_options=extra_options)
        qdqnode_counts = {'Conv': 1, 'QuantizeLinear': 3, 'DequantizeLinear': 4, 'MaxPool': 1}
        check_op_type_count(self, model_q8_qdq_path, **qdqnode_counts)
        qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
        qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]})
        check_qtype_by_node_type(self, model_q8_qdq_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_q8_qdq_path, data_reader.get_next())
示例#7
0
    def quantize_reshape_test(self, activation_type, weight_type, extra_options=None):
        """Quantize a MatMul->Reshape model with the given quant types and verify
        QOperator and QDQ outputs.

        Args:
            activation_type: QuantType for activations (QUInt8 or QInt8).
            weight_type: QuantType for weights (QUInt8 or QInt8).
            extra_options: optional dict forwarded to quantize_static.
        """
        # None sentinel instead of a mutable `{}` default: a default dict is
        # shared across calls and could leak state between test invocations.
        extra_options = {} if extra_options is None else extra_options
        np.random.seed(1)
        model_fp32_path = "reshape_fp32.onnx"

        self.construct_model_matmul_reshape(model_fp32_path, [3, 7], [7, 3], [1, 9])

        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
        activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8"
        weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8"
        model_uint8_path = "reshape_{}{}.onnx".format(activation_type_str, weight_type_str)
        model_uint8_qdq_path = "reshape_{}{}_qdq.onnx".format(activation_type_str, weight_type_str)

        # Verify QOperator mode
        data_reader = self.input_feeds(1, {"input": [3, 7]})
        quantize_static(
            model_fp32_path,
            model_uint8_path,
            data_reader,
            quant_format=QuantFormat.QOperator,
            activation_type=activation_type,
            weight_type=weight_type,
            extra_options=extra_options,
        )
        # Reshape must consume the quantized MatMul output, not the float 'matmul_output'.
        check_op_nodes(
            self,
            model_uint8_path,
            lambda node: (node.name != "reshape_node" or node.input[0] != "matmul_output"),
        )
        qnode_counts = {
            "QLinearMatMul": 1,
            "QuantizeLinear": 1,
            "DequantizeLinear": 1,
            "Reshape": 1,
        }
        check_op_type_count(self, model_uint8_path, **qnode_counts)
        qnode_io_qtypes = {
            "QuantizeLinear": [
                ["i", 2, activation_proto_qtype],
                ["o", 0, activation_proto_qtype],
            ]
        }
        qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]})
        check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next())

        # Verify QDQ mode
        data_reader.rewind()
        quantize_static(
            model_fp32_path,
            model_uint8_qdq_path,
            data_reader,
            quant_format=QuantFormat.QDQ,
            activation_type=activation_type,
            weight_type=weight_type,
            extra_options=extra_options,
        )
        qdqnode_counts = {
            "MatMul": 1,
            "QuantizeLinear": 3,
            "DequantizeLinear": 4,
            "Reshape": 1,
        }
        check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts)
        qnode_io_qtypes = {
            "QuantizeLinear": [
                ["i", 2, activation_proto_qtype],
                ["o", 0, activation_proto_qtype],
            ]
        }
        check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next())
示例#8
0
    def quantize_resize_test(self, activation_type, weight_type, extra_options=None):
        """Quantize a Conv->Resize model with the given quant types and verify
        QOperator and QDQ outputs.

        Args:
            activation_type: QuantType for activations (QUInt8 or QInt8).
            weight_type: QuantType for weights (QUInt8 or QInt8).
            extra_options: optional dict forwarded to quantize_static.
        """
        # None sentinel instead of a mutable `{}` default: a default dict is
        # shared across calls and could leak state between test invocations.
        extra_options = {} if extra_options is None else extra_options
        np.random.seed(1)
        model_fp32_path = "resize_fp32.onnx"

        kwargs = {
            "coordinate_transformation_mode": "asymmetric",
            "mode": "nearest",
            "nearest_mode": "floor",
        }
        self.construct_model_conv_resize(
            model_fp32_path,
            [1, 2, 26, 42],
            [3, 2, 3, 3],
            [1, 3, 24, 40],
            [1, 3, 48, 80],
            kwargs,
            [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 2.0, 2.0],
            None,
        )

        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
        activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8"
        weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8"
        model_uint8_path = "resize_{}{}.onnx".format(activation_type_str, weight_type_str)
        model_uint8_qdq_path = "resize_{}{}_qdq.onnx".format(activation_type_str, weight_type_str)

        # Verify QOperator mode
        data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]})
        quantize_static(
            model_fp32_path,
            model_uint8_path,
            data_reader,
            quant_format=QuantFormat.QOperator,
            activation_type=activation_type,
            weight_type=weight_type,
            extra_options=extra_options,
        )
        # Resize must consume the quantized Conv output, not the float 'conv_output'.
        check_op_nodes(
            self,
            model_uint8_path,
            lambda node: (node.name != "resize_node" or node.input[0] != "conv_output"),
        )
        qnode_counts = {
            "QLinearConv": 1,
            "QuantizeLinear": 1,
            "DequantizeLinear": 2,
            "Resize": 1,
        }
        check_op_type_count(self, model_uint8_path, **qnode_counts)
        qnode_io_qtypes = {
            "QuantizeLinear": [
                ["i", 2, activation_proto_qtype],
                ["o", 0, activation_proto_qtype],
            ]
        }
        qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]})
        check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next())

        # Verify QDQ mode
        data_reader.rewind()
        quantize_static(
            model_fp32_path,
            model_uint8_qdq_path,
            data_reader,
            quant_format=QuantFormat.QDQ,
            activation_type=activation_type,
            weight_type=weight_type,
            extra_options=extra_options,
        )
        qdqnode_counts = {
            "Conv": 1,
            "QuantizeLinear": 3,
            "DequantizeLinear": 4,
            "Resize": 1,
        }
        check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts)
        qnode_io_qtypes = {
            "QuantizeLinear": [
                ["i", 2, activation_proto_qtype],
                ["o", 0, activation_proto_qtype],
            ]
        }
        check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next())
示例#9
0
    def quantize_reshape_test(self, activation_type, weight_type, extra_options=None):
        """Quantize a MatMul->Reshape model with the given quant types and verify
        QOperator and QDQ outputs.

        Args:
            activation_type: QuantType for activations (QUInt8 or QInt8).
            weight_type: QuantType for weights (QUInt8 or QInt8).
            extra_options: optional dict forwarded to quantize_static.
        """
        # None sentinel instead of a mutable `{}` default: a default dict is
        # shared across calls and could leak state between test invocations.
        extra_options = {} if extra_options is None else extra_options
        np.random.seed(1)
        model_fp32_path = 'reshape_fp32.onnx'

        self.construct_model_matmul_reshape(model_fp32_path, [3, 7], [7, 3],
                                            [1, 9])

        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
        activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8'
        weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8'
        model_uint8_path = 'reshape_{}{}.onnx'.format(activation_type_str,
                                                      weight_type_str)
        model_uint8_qdq_path = 'reshape_{}{}_qdq.onnx'.format(
            activation_type_str, weight_type_str)

        # Verify QOperator mode
        data_reader = self.input_feeds(1, {'input': [3, 7]})
        quantize_static(model_fp32_path,
                        model_uint8_path,
                        data_reader,
                        activation_type=activation_type,
                        weight_type=weight_type,
                        extra_options=extra_options)
        # Reshape must consume the quantized MatMul output, not the float 'matmul_output'.
        check_op_nodes(
            self, model_uint8_path, lambda node:
            (node.name != "reshape_node" or node.input[0] != 'matmul_output'))
        qnode_counts = {
            'QLinearMatMul': 1,
            'QuantizeLinear': 1,
            'DequantizeLinear': 1,
            'Reshape': 1
        }
        check_op_type_count(self, model_uint8_path, **qnode_counts)
        qnode_io_qtypes = {
            'QuantizeLinear': [['i', 2, activation_proto_qtype],
                               ['o', 0, activation_proto_qtype]]
        }
        qnode_io_qtypes.update(
            {'DequantizeLinear': [['i', 2, activation_proto_qtype]]})
        check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_path,
                                data_reader.get_next())

        # Verify QDQ mode
        data_reader.rewind()
        quantize_static(model_fp32_path,
                        model_uint8_qdq_path,
                        data_reader,
                        quant_format=QuantFormat.QDQ,
                        activation_type=activation_type,
                        weight_type=weight_type,
                        extra_options=extra_options)
        qdqnode_counts = {
            'MatMul': 1,
            'QuantizeLinear': 3,
            'DequantizeLinear': 4,
            'Reshape': 1
        }
        check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts)
        qnode_io_qtypes = {
            'QuantizeLinear': [['i', 2, activation_proto_qtype],
                               ['o', 0, activation_proto_qtype]]
        }
        check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_qdq_path,
                                data_reader.get_next())
示例#10
0
    def quantize_argmax_test(self, activation_type, weight_type, extra_options=None):
        """Quantize a Conv->ArgMax model and verify QOperator, QDQ, and
        TensorRT-style (ArgMax-only) QDQ outputs.

        Args:
            activation_type: QuantType for activations (QUInt8 or QInt8).
            weight_type: QuantType for weights (QUInt8 or QInt8).
            extra_options: optional dict forwarded to quantize_static.
        """
        # None sentinel instead of a mutable `{}` default: a default dict is
        # shared across calls and could leak state between test invocations.
        extra_options = {} if extra_options is None else extra_options
        np.random.seed(1)
        model_fp32_path = "argmax_fp32.onnx"

        self.construct_model_argmax(model_fp32_path, [1, 256, 128, 128],
                                    [1, 32, 128])

        activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
        activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8"
        weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8"
        model_uint8_path = "argmax_{}{}.onnx".format(activation_type_str,
                                                     weight_type_str)
        model_uint8_qdq_path = "argmax_{}{}_qdq.onnx".format(
            activation_type_str, weight_type_str)
        model_uint8_qdq_trt_path = "argmax_{}{}_qdq_trt.onnx".format(
            activation_type_str, weight_type_str)

        # Verify QOperator mode
        data_reader = self.input_feeds(1, {"input": [1, 256, 128, 128]})
        quantize_static(
            model_fp32_path,
            model_uint8_path,
            data_reader,
            quant_format=QuantFormat.QOperator,
            activation_type=activation_type,
            weight_type=weight_type,
            extra_options=extra_options,
        )
        # ArgMax must consume the quantized Conv output, not the float 'conv_output'.
        check_op_nodes(
            self,
            model_uint8_path,
            lambda node: not (node.name == "argmax_node" and node.input[0] ==
                              "conv_output"),
        )
        qnode_counts = {"QuantizeLinear": 1, "QLinearConv": 1, "ArgMax": 1}
        check_op_type_count(self, model_uint8_path, **qnode_counts)
        qnode_io_qtypes = {
            "QuantizeLinear": [
                ["i", 2, activation_proto_qtype],
                ["o", 0, activation_proto_qtype],
            ]
        }
        check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_path,
                                data_reader.get_next())

        # Verify QDQ mode
        data_reader.rewind()
        quantize_static(
            model_fp32_path,
            model_uint8_qdq_path,
            data_reader,
            quant_format=QuantFormat.QDQ,
            activation_type=activation_type,
            weight_type=weight_type,
            extra_options=extra_options,
        )
        qdqnode_counts = {
            "QuantizeLinear": 2,
            "DequantizeLinear": 3,
            "ArgMax": 1
        }
        check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts)
        qnode_io_qtypes = {
            "QuantizeLinear": [
                ["i", 2, activation_proto_qtype],
                ["o", 0, activation_proto_qtype],
            ]
        }
        check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self, model_fp32_path, model_uint8_qdq_path,
                                data_reader.get_next())

        # Verify QDQ mode for TensorRT: only ArgMax is quantized.
        data_reader.rewind()
        quantize_static(
            model_fp32_path,
            model_uint8_qdq_trt_path,
            data_reader,
            quant_format=QuantFormat.QDQ,
            activation_type=activation_type,
            weight_type=weight_type,
            extra_options=extra_options,
            op_types_to_quantize=["ArgMax"],
        )
        qdqnode_counts = {
            "QuantizeLinear": 1,
            "DequantizeLinear": 1,
            "ArgMax": 1
        }
        check_op_type_count(self, model_uint8_qdq_trt_path, **qdqnode_counts)
        qnode_io_qtypes = {
            "QuantizeLinear": [
                ["i", 2, activation_proto_qtype],
                ["o", 0, activation_proto_qtype],
            ]
        }
        check_qtype_by_node_type(self, model_uint8_qdq_trt_path,
                                 qnode_io_qtypes)
        data_reader.rewind()
        check_model_correctness(self,
                                model_fp32_path, model_uint8_qdq_trt_path,
                                data_reader.get_next())