def replace_pattern(self, graph: Graph, match: dict):
        """Resolve the QuantizeLinear of a matched QuantizeLinear -> DequantizeLinear pair.

        The quantize node is rewritten with
        ``QuantizeLinearResolver.quantize_to_fakequantize`` and then flagged as
        ``isolated``. The rewrite only happens when ALL of the following hold:

        * the scale (in-port 1) and zero point (in-port 2) are connected on both
          the quantize and the dequantize node;
        * all four of those inputs are produced by ``Const`` nodes;
        * the quantize and dequantize nodes use identical scale values and
          identical zero-point values.

        Otherwise the graph is left untouched.

        :param graph: graph being transformed.
        :param match: pattern match; must contain the ``'quantize'`` and
            ``'dequantize'`` nodes.
        """
        dequantize_node = match['dequantize']
        quantize_node = match['quantize']

        # Scale (port 1) and zero point (port 2) must be connected on both Q and DQ.
        scale_zerop_is_exist = quantize_node.is_in_port_connected(1) and quantize_node.is_in_port_connected(2) and \
                               dequantize_node.is_in_port_connected(1) and dequantize_node.is_in_port_connected(2)
        if not scale_zerop_is_exist:
            return
        q_scale = quantize_node.in_port(1).get_source().node
        q_zerop = quantize_node.in_port(2).get_source().node
        dq_scale = dequantize_node.in_port(1).get_source().node
        dq_zerop = dequantize_node.in_port(2).get_source().node

        # Only constant zero_point/scale inputs are supported. This is checked before
        # any `.value` is read, so the comparison below only ever touches Const nodes.
        scales_and_zerop_is_const = q_scale.soft_get('type') == 'Const' and dq_scale.soft_get('type') == 'Const' and \
                                    q_zerop.soft_get('type') == 'Const' and dq_zerop.soft_get('type') == 'Const'
        if not scales_and_zerop_is_const:
            return

        # Only patterns with the same scale/zero_point values for Q and DQ are supported.
        # Both requirements must hold (previously they were combined with `or`, which
        # let through pairs with mismatching constant values).
        scales_and_zerop_equals = np.array_equal(q_scale.value, dq_scale.value) and \
                                  np.array_equal(q_zerop.value, dq_zerop.value)
        if not scales_and_zerop_equals:
            return

        QuantizeLinearResolver.quantize_to_fakequantize(
            graph, quantize_node, True)
        quantize_node['isolated'] = True
示例#2
0
    def test_quantize_no_zerop(self):
        """QuantizeLinear without a zero-point input resolves to FakeQuantize + Cast.

        The reference graph scales in_low (0) and in_high (255) by the quantize
        scale through mul1/mul2, uses out_low/out_high of 0/255, and casts the
        FakeQuantize output to uint8.
        """
        # The quantize node has only data + scale inputs (no zero point).
        # Each edge is listed exactly once; the fixture previously repeated
        # ('quantize', 'quantize_data').
        graph = build_graph(nodes1_attributes, [
            ('input', 'input_data'),
            ('input_data', 'quantize'),
            ('quantize', 'quantize_data'),
            ('scale_param_q', 'scale_param_q_data'),
            ('scale_param_q_data', 'quantize'),
            ('quantize_data', 'out'),
            ('out', 'out_data'),
            ('out_data', 'result'),
        ], {
            'scale_param_q': {
                'shape': np.array([]),
                'value': np.float32(1.0 / 255)
            },
            'scale_param_q_data': {
                'shape': np.array([]),
                'value': np.float32(1.0 / 255)
            },
        },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_ref_attributes, [
            ('input', 'input_data'),
            ('input_data', 'f_quantize'),
            ('scale_param_q', 'scale_param_q_data'),
            ('scale_param_q_data', 'mul1', {
                'out': 0
            }),
            ('in_low', 'in_low_data'),
            ('in_low_data', 'mul1'),
            ('mul1', 'mul1_data'),
            ('mul1_data', 'f_quantize'),
            ('f_quantize', 'f_quantize_data'),
            ('scale_param_q_data', 'mul2', {
                'out': 0
            }),
            ('in_high', 'in_high_data'),
            ('in_high_data', 'mul2'),
            ('mul2', 'mul2_data'),
            ('mul2_data', 'f_quantize'),
            ('out_low', 'out_low_data'),
            ('out_low_data', 'f_quantize'),
            ('out_high', 'out_high_data'),
            ('out_high_data', 'f_quantize'),
            ('f_quantize_data', 'cast'),
            ('cast', 'cast_data'),
            ('cast_data', 'out'),
            ('out', 'out_data'),
            ('out_data', 'result'),
        ], {
            'in_low': {
                'shape': np.array([]),
                'value': 0
            },
            'in_low_data': {
                'shape': np.array([]),
                'value': 0
            },
            'in_high': {
                'shape': np.array([]),
                'value': 255
            },
            'in_high_data': {
                'shape': np.array([]),
                'value': 255
            },
            'out_low': {
                'shape': np.array([]),
                'value': 0
            },
            'out_low_data': {
                'shape': np.array([]),
                'value': 0
            },
            'out_high': {
                'shape': np.array([]),
                'value': 255
            },
            'out_high_data': {
                'shape': np.array([]),
                'value': 255
            },
            'cast': {
                'dst_type': np.uint8
            }
        },
                                nodes_with_edges_only=True)

        graph.stage = 'middle'
        QuantizeLinearResolver().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'result',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
示例#3
0
    def test_quantize_with_axis(self, input_shape, scale_param_value,
                                zero_param_value, target_shape, in_low,
                                in_high, out_low, out_high, axis):
        """Per-axis QuantizeLinear resolves to FakeQuantize with reshaped input ranges.

        The scaled in_low/in_high ranges go through Reshape nodes (target shape
        ``target_shape``) before feeding FakeQuantize, and the FakeQuantize
        output is cast to the dtype of ``zero_param_value``.
        """
        def op_and_data(node, value, shape=None):
            # Identical constant attributes for `node` and its '<node>_data' tensor.
            # shape=None yields a fresh empty (scalar) shape array for each entry.
            return {
                name: {'shape': np.array([]) if shape is None else shape, 'value': value}
                for name in (node, node + '_data')
            }

        source_attrs = {
            'quantize': {'axis': axis},
            'input': {'shape': input_shape},
            'input_data': {'shape': input_shape},
            **op_and_data('scale_param_q', scale_param_value, scale_param_value.shape),
            **op_and_data('zerop_param_q', zero_param_value, zero_param_value.shape),
        }
        source_edges = [
            ('input', 'input_data'),
            ('input_data', 'quantize'),
            ('scale_param_q', 'scale_param_q_data'),
            ('scale_param_q_data', 'quantize'),
            ('zerop_param_q', 'zerop_param_q_data'),
            ('zerop_param_q_data', 'quantize'),
            ('quantize', 'quantize_data'),
            ('quantize_data', 'out'),
            ('out', 'out_data'),
            ('out_data', 'result'),
        ]
        graph = build_graph(nodes1_attributes, source_edges, source_attrs,
                            nodes_with_edges_only=True)

        # NOTE(review): 'high_reshape' carries the scaled in_low branch and
        # 'low_reshape' the scaled in_high branch. Both use the same target_shape,
        # so the swapped names look harmless; the edge order is kept as-is.
        expected_edges = [
            ('input', 'input_data'),
            ('input_data', 'f_quantize'),
            ('scale_param_q', 'scale_param_q_data'),
            ('scale_param_q_data', 'mul1', {'out': 0}),
            ('in_low', 'in_low_data'),
            ('in_low_data', 'mul1'),
            ('mul1', 'mul1_data'),
            ('mul1_data', 'high_reshape'),
            ('high_reshape_const', 'high_reshape_const_data'),
            ('high_reshape_const_data', 'high_reshape'),
            ('high_reshape', 'high_reshape_data'),
            ('high_reshape_data', 'f_quantize'),
            ('f_quantize', 'f_quantize_data'),
            ('scale_param_q_data', 'mul2', {'out': 0}),
            ('in_high', 'in_high_data'),
            ('in_high_data', 'mul2'),
            ('mul2', 'mul2_data'),
            ('mul2_data', 'low_reshape'),
            ('low_reshape', 'low_reshape_data'),
            ('low_reshape_data', 'f_quantize'),
            ('low_reshape_const', 'low_reshape_const_data'),
            ('low_reshape_const_data', 'low_reshape'),
            ('out_low', 'out_low_data'),
            ('out_low_data', 'f_quantize'),
            ('out_high', 'out_high_data'),
            ('out_high_data', 'f_quantize'),
            ('f_quantize_data', 'cast'),
            ('cast', 'cast_data'),
            ('cast_data', 'out'),
            ('out', 'out_data'),
            ('out_data', 'result'),
        ]
        expected_attrs = {
            **op_and_data('in_low', in_low, in_low.shape),
            **op_and_data('in_high', in_high, in_high.shape),
            **op_and_data('out_low', out_low),
            **op_and_data('out_high', out_high),
            'cast': {'dst_type': zero_param_value.dtype},
            **{
                reshape_const: {'shape': target_shape.shape, 'value': target_shape}
                for reshape_const in ('low_reshape_const_data', 'high_reshape_const_data')
            },
        }
        graph_ref = build_graph(nodes_ref_attributes, expected_edges, expected_attrs,
                                nodes_with_edges_only=True)

        graph.stage = 'middle'
        QuantizeLinearResolver().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(flag, resp)
示例#4
0
    def test_quantize_dequantize_linear_resolver(self):
        """Check the two-stage resolution of two Q/DQ pairs.

        QuantizeDequantizeLinearResolver is expected to convert only the Q/DQ pair
        fed by 'const_input'. A subsequent QuantizeLinearResolver run is expected
        to also convert the pair fed by 'input'.
        """
        def unconverted_non_const_branch():
            # 'input' -> QuantizeLinear -> DequantizeLinear, left as-is.
            return [
                *connect('input', '0:non_const_quantize'),
                *connect('y_scale_2', '1:non_const_quantize'),
                *connect('y_zeropoint_2', '2:non_const_quantize'),
                *connect('non_const_quantize', '0:non_const_dequantize'),
                *connect('x_scale_2', '1:non_const_dequantize'),
                *connect('x_zeropoint_2', '2:non_const_dequantize'),
            ]

        def converted_const_branch():
            # 'const_input' -> FakeQuantize -> Cast -> DequantizeLinear.
            return [
                *connect('const_input', '0:const_fq'),
                *connect('y_scale_1:0', '0:mul_low'),
                *connect('in_low', '1:mul_low'),
                ('y_scale_1_d', 'mul_high', {'out': 1, 'in': 0}),
                *connect('in_high', '1:mul_high'),
                *connect('mul_low', '1:const_fq'),
                *connect('mul_high', '2:const_fq'),
                *connect('out_low', '3:const_fq'),
                *connect('out_high', '4:const_fq'),
                *connect('const_fq', 'const_cast'),
                *connect('const_cast', '0:const_dequantize'),
                *connect('x_scale_1', '1:const_dequantize'),
                *connect('x_zeropoint_1', '2:const_dequantize'),
            ]

        def sum_into_result():
            # Both dequantized branches are added and fed to the result.
            return [
                *connect('const_dequantize', '0:add'),
                *connect('non_const_dequantize', '1:add'),
                *connect('add', 'result'),
            ]

        graph = build_graph(nodes_attrs=nodes_attributes,
                            edges=[
                                *unconverted_non_const_branch(),
                                *connect('const_input', '0:const_quantize'),
                                *connect('y_scale_1', '1:const_quantize'),
                                *connect('y_zeropoint_1', '2:const_quantize'),
                                *connect('const_quantize', '0:const_dequantize'),
                                *connect('x_scale_1', '1:const_dequantize'),
                                *connect('x_zeropoint_1', '2:const_dequantize'),
                                *sum_into_result(),
                            ], nodes_with_edges_only=True)

        const_ref_graph = build_graph(nodes_attrs=nodes_attributes,
                                      edges=[
                                          *unconverted_non_const_branch(),
                                          *converted_const_branch(),
                                          *sum_into_result(),
                                      ], nodes_with_edges_only=True)

        # Stage 1: only the const-input pair is expected to be converted.
        QuantizeDequantizeLinearResolver().find_and_replace_pattern(graph)
        graph.graph['layout'] = 'NCHW'
        flag, resp = compare_graphs(graph, const_ref_graph, 'result')
        self.assertTrue(flag, resp)

        ref_graph = build_graph(nodes_attrs=nodes_attributes,
                                edges=[
                                    *connect('input', '0:non_const_fq'),
                                    *connect('y_scale_2:0', '0:non_const_mul_low'),
                                    *connect('non_const_in_low', '1:non_const_mul_low'),
                                    ('y_scale_2_d', 'non_const_mul_high', {'out': 1, 'in': 0}),
                                    *connect('non_const_in_high', '1:non_const_mul_high'),
                                    *connect('non_const_mul_low', '1:non_const_fq'),
                                    *connect('non_const_mul_high', '2:non_const_fq'),
                                    *connect('non_const_out_low', '3:non_const_fq'),
                                    *connect('non_const_out_high', '4:non_const_fq'),
                                    *connect('non_const_fq', 'non_const_cast'),
                                    *connect('non_const_cast', '0:non_const_dequantize'),
                                    *connect('x_scale_2', '1:non_const_dequantize'),
                                    *connect('x_zeropoint_2', '2:non_const_dequantize'),
                                    *converted_const_branch(),
                                    *sum_into_result(),
                                ], nodes_with_edges_only=True)

        # Stage 2: QuantizeLinearResolver is expected to convert the remaining pair.
        QuantizeLinearResolver().find_and_replace_pattern(graph)
        flag, resp = compare_graphs(graph, ref_graph, 'result')
        self.assertTrue(flag, resp)