def replace_pattern(self, graph: Graph, match: dict):
    """
    Convert the QuantizeLinear of a matched QuantizeLinear/DequantizeLinear pair.

    The conversion is done by QuantizeLinearResolver.quantize_to_fakequantize, and the
    quantize node is then marked with 'isolated' = True. The pair is left untouched unless
    all of the following hold:
      * both Q and DQ have their scale (port 1) and zero-point (port 2) inputs connected;
      * all four of those inputs are produced by nodes whose type is 'Const';
      * the Q and DQ scales are equal and the Q and DQ zero points are equal.

    :param graph: graph being transformed
    :param match: pattern match; must contain the 'quantize' and 'dequantize' nodes
    """
    dequantize_node = match['dequantize']
    quantize_node = match['quantize']

    # Both scale (port 1) and zero point (port 2) must be present on Q and DQ.
    scale_zerop_is_exist = quantize_node.is_in_port_connected(1) and quantize_node.is_in_port_connected(2) and \
        dequantize_node.is_in_port_connected(1) and dequantize_node.is_in_port_connected(2)
    if not scale_zerop_is_exist:
        return

    q_scale = quantize_node.in_port(1).get_source().node
    q_zerop = quantize_node.in_port(2).get_source().node
    dq_scale = dequantize_node.in_port(1).get_source().node
    dq_zerop = dequantize_node.in_port(2).get_source().node

    # Only constant zero_point/scale inputs are supported. This is checked first, so that
    # `.value` below is only read from Const nodes.
    scales_and_zerop_is_const = q_scale.soft_get('type') == 'Const' and dq_scale.soft_get('type') == 'Const' and \
        q_zerop.soft_get('type') == 'Const' and dq_zerop.soft_get('type') == 'Const'
    if not scales_and_zerop_is_const:
        return

    # Only patterns where Q and DQ share the same scale/zero_point values are supported.
    scales_and_zerop_equals = np.array_equal(q_scale.value, dq_scale.value) and \
        np.array_equal(q_zerop.value, dq_zerop.value)
    if not scales_and_zerop_equals:
        return

    # NOTE(review): the meaning of the third argument (True) is defined by
    # quantize_to_fakequantize — TODO confirm.
    QuantizeLinearResolver.quantize_to_fakequantize(graph, quantize_node, True)
    # NOTE(review): 'isolated' is consumed elsewhere; semantics not visible here.
    quantize_node['isolated'] = True
def test_quantize_no_zerop(self):
    """
    QuantizeLinear with only a scale input (no zero point) is resolved at the middle stage.

    Expected result: FakeQuantize whose input low/high are 0 * scale and 255 * scale, whose
    output low/high are 0 and 255, followed by a Cast to uint8.
    """
    # The edge ('quantize', 'quantize_data') was previously listed twice, which created
    # a spurious parallel edge out of the quantize node under test.
    graph = build_graph(nodes1_attributes,
                        [('input', 'input_data'),
                         ('input_data', 'quantize'),
                         ('quantize', 'quantize_data'),
                         ('scale_param_q', 'scale_param_q_data'),
                         ('scale_param_q_data', 'quantize'),
                         ('quantize_data', 'out'),
                         ('out', 'out_data'),
                         ('out_data', 'result'),
                         ],
                        {'scale_param_q': {'shape': np.array([]), 'value': np.float32(1.0 / 255)},
                         'scale_param_q_data': {'shape': np.array([]), 'value': np.float32(1.0 / 255)},
                         }, nodes_with_edges_only=True)

    graph_ref = build_graph(nodes_ref_attributes,
                            [('input', 'input_data'),
                             ('input_data', 'f_quantize'),
                             ('scale_param_q', 'scale_param_q_data'),
                             ('scale_param_q_data', 'mul1', {'out': 0}),
                             ('in_low', 'in_low_data'),
                             ('in_low_data', 'mul1'),
                             ('mul1', 'mul1_data'),
                             ('mul1_data', 'f_quantize'),
                             ('f_quantize', 'f_quantize_data'),
                             ('scale_param_q_data', 'mul2', {'out': 0}),
                             ('in_high', 'in_high_data'),
                             ('in_high_data', 'mul2'),
                             ('mul2', 'mul2_data'),
                             ('mul2_data', 'f_quantize'),
                             ('out_low', 'out_low_data'),
                             ('out_low_data', 'f_quantize'),
                             ('out_high', 'out_high_data'),
                             ('out_high_data', 'f_quantize'),
                             ('f_quantize_data', 'cast'),
                             ('cast', 'cast_data'),
                             ('cast_data', 'out'),
                             ('out', 'out_data'),
                             ('out_data', 'result'),
                             ],
                            {'in_low': {'shape': np.array([]), 'value': 0},
                             'in_low_data': {'shape': np.array([]), 'value': 0},
                             'in_high': {'shape': np.array([]), 'value': 255},
                             'in_high_data': {'shape': np.array([]), 'value': 255},
                             'out_low': {'shape': np.array([]), 'value': 0},
                             'out_low_data': {'shape': np.array([]), 'value': 0},
                             'out_high': {'shape': np.array([]), 'value': 255},
                             'out_high_data': {'shape': np.array([]), 'value': 255},
                             'cast': {'dst_type': np.uint8}
                             }, nodes_with_edges_only=True)

    graph.stage = 'middle'
    QuantizeLinearResolver().find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_quantize_with_axis(self, input_shape, scale_param_value, zero_param_value, target_shape, in_low, in_high,
                            out_low, out_high, axis):
    """
    Per-axis QuantizeLinear (with scale and zero point) is resolved at the middle stage.

    Expected result: in_low and in_high are each multiplied by the scale and reshaped to
    `target_shape`, then fed into FakeQuantize with out_low/out_high. The output is cast
    to the dtype of the zero point.
    """
    def shaped(shape, value):
        # Node attribute dict carrying both a shape and a value (fresh dict per call).
        return {'shape': shape, 'value': value}

    source_edges = [
        ('input', 'input_data'),
        ('input_data', 'quantize'),
        ('scale_param_q', 'scale_param_q_data'),
        ('scale_param_q_data', 'quantize'),
        ('zerop_param_q', 'zerop_param_q_data'),
        ('zerop_param_q_data', 'quantize'),
        ('quantize', 'quantize_data'),
        ('quantize_data', 'out'),
        ('out', 'out_data'),
        ('out_data', 'result'),
    ]
    source_attrs = {
        'quantize': {'axis': axis},
        'input': {'shape': input_shape},
        'input_data': {'shape': input_shape},
        'scale_param_q': shaped(scale_param_value.shape, scale_param_value),
        'scale_param_q_data': shaped(scale_param_value.shape, scale_param_value),
        'zerop_param_q': shaped(zero_param_value.shape, zero_param_value),
        'zerop_param_q_data': shaped(zero_param_value.shape, zero_param_value),
    }
    graph = build_graph(nodes1_attributes, source_edges, source_attrs, nodes_with_edges_only=True)

    # NOTE(review): mul1 (the in_low product) feeds the node named 'high_reshape', and mul2
    # (the in_high product) feeds 'low_reshape'. The names look swapped. This is harmless only
    # if both reshape nodes share the same attributes in nodes_ref_attributes — confirm.
    expected_edges = [
        ('input', 'input_data'),
        ('input_data', 'f_quantize'),
        ('scale_param_q', 'scale_param_q_data'),
        ('scale_param_q_data', 'mul1', {'out': 0}),
        ('in_low', 'in_low_data'),
        ('in_low_data', 'mul1'),
        ('mul1', 'mul1_data'),
        ('mul1_data', 'high_reshape'),
        ('high_reshape_const', 'high_reshape_const_data'),
        ('high_reshape_const_data', 'high_reshape'),
        ('high_reshape', 'high_reshape_data'),
        ('high_reshape_data', 'f_quantize'),
        ('f_quantize', 'f_quantize_data'),
        ('scale_param_q_data', 'mul2', {'out': 0}),
        ('in_high', 'in_high_data'),
        ('in_high_data', 'mul2'),
        ('mul2', 'mul2_data'),
        ('mul2_data', 'low_reshape'),
        ('low_reshape', 'low_reshape_data'),
        ('low_reshape_data', 'f_quantize'),
        ('low_reshape_const', 'low_reshape_const_data'),
        ('low_reshape_const_data', 'low_reshape'),
        ('out_low', 'out_low_data'),
        ('out_low_data', 'f_quantize'),
        ('out_high', 'out_high_data'),
        ('out_high_data', 'f_quantize'),
        ('f_quantize_data', 'cast'),
        ('cast', 'cast_data'),
        ('cast_data', 'out'),
        ('out', 'out_data'),
        ('out_data', 'result'),
    ]
    expected_attrs = {
        'in_low': shaped(in_low.shape, in_low),
        'in_low_data': shaped(in_low.shape, in_low),
        'in_high': shaped(in_high.shape, in_high),
        'in_high_data': shaped(in_high.shape, in_high),
        'out_low': shaped(np.array([]), out_low),
        'out_low_data': shaped(np.array([]), out_low),
        'out_high': shaped(np.array([]), out_high),
        'out_high_data': shaped(np.array([]), out_high),
        'cast': {'dst_type': zero_param_value.dtype},
        'low_reshape_const_data': shaped(target_shape.shape, target_shape),
        'high_reshape_const_data': shaped(target_shape.shape, target_shape),
    }
    graph_ref = build_graph(nodes_ref_attributes, expected_edges, expected_attrs, nodes_with_edges_only=True)

    graph.stage = 'middle'
    QuantizeLinearResolver().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_quantize_dequantize_linear_resolver(self):
    """
    Two Q/DQ chains feed one Add: one starts at `const_input`, the other at `input`.

    Stage 1: QuantizeDequantizeLinearResolver is expected to rewrite only the const-fed
    pair's Quantize into const_fq + const_cast, leaving the input-fed pair unchanged.
    Stage 2: QuantizeLinearResolver is expected to then rewrite the remaining Quantize
    into non_const_fq + non_const_cast.
    """
    # Each helper returns a freshly built list, so no edge object is shared between graphs.
    def non_const_q_dq_edges():
        # Untouched QuantizeLinear -> DequantizeLinear chain fed by `input`.
        return [
            *connect('input', '0:non_const_quantize'),
            *connect('y_scale_2', '1:non_const_quantize'),
            *connect('y_zeropoint_2', '2:non_const_quantize'),
            *connect('non_const_quantize', '0:non_const_dequantize'),
            *connect('x_scale_2', '1:non_const_dequantize'),
            *connect('x_zeropoint_2', '2:non_const_dequantize'),
        ]

    def const_q_dq_edges():
        # Untouched QuantizeLinear -> DequantizeLinear chain fed by `const_input`.
        return [
            *connect('const_input', '0:const_quantize'),
            *connect('y_scale_1', '1:const_quantize'),
            *connect('y_zeropoint_1', '2:const_quantize'),
            *connect('const_quantize', '0:const_dequantize'),
            *connect('x_scale_1', '1:const_dequantize'),
            *connect('x_zeropoint_1', '2:const_dequantize'),
        ]

    def const_fq_edges():
        # Expected rewrite of the const-fed Quantize: FakeQuantize + Cast before DequantizeLinear.
        return [
            *connect('const_input', '0:const_fq'),
            *connect('y_scale_1:0', '0:mul_low'),
            *connect('in_low', '1:mul_low'),
            ('y_scale_1_d', 'mul_high', {'out': 1, 'in': 0}),
            *connect('in_high', '1:mul_high'),
            *connect('mul_low', '1:const_fq'),
            *connect('mul_high', '2:const_fq'),
            *connect('out_low', '3:const_fq'),
            *connect('out_high', '4:const_fq'),
            *connect('const_fq', 'const_cast'),
            *connect('const_cast', '0:const_dequantize'),
            *connect('x_scale_1', '1:const_dequantize'),
            *connect('x_zeropoint_1', '2:const_dequantize'),
        ]

    def non_const_fq_edges():
        # Expected rewrite of the input-fed Quantize after QuantizeLinearResolver.
        return [
            *connect('input', '0:non_const_fq'),
            *connect('y_scale_2:0', '0:non_const_mul_low'),
            *connect('non_const_in_low', '1:non_const_mul_low'),
            ('y_scale_2_d', 'non_const_mul_high', {'out': 1, 'in': 0}),
            *connect('non_const_in_high', '1:non_const_mul_high'),
            *connect('non_const_mul_low', '1:non_const_fq'),
            *connect('non_const_mul_high', '2:non_const_fq'),
            *connect('non_const_out_low', '3:non_const_fq'),
            *connect('non_const_out_high', '4:non_const_fq'),
            *connect('non_const_fq', 'non_const_cast'),
            *connect('non_const_cast', '0:non_const_dequantize'),
            *connect('x_scale_2', '1:non_const_dequantize'),
            *connect('x_zeropoint_2', '2:non_const_dequantize'),
        ]

    def add_edges():
        # Both dequantized branches are summed into the graph result.
        return [
            *connect('const_dequantize', '0:add'),
            *connect('non_const_dequantize', '1:add'),
            *connect('add', 'result'),
        ]

    graph = build_graph(nodes_attrs=nodes_attributes,
                        edges=non_const_q_dq_edges() + const_q_dq_edges() + add_edges(),
                        nodes_with_edges_only=True)
    const_ref_graph = build_graph(nodes_attrs=nodes_attributes,
                                  edges=non_const_q_dq_edges() + const_fq_edges() + add_edges(),
                                  nodes_with_edges_only=True)

    # Stage 1: only the const-fed pair is rewritten.
    QuantizeDequantizeLinearResolver().find_and_replace_pattern(graph)
    # NOTE(review): layout is set only after stage 1; presumably a later pass reads it — confirm.
    graph.graph['layout'] = 'NCHW'
    flag, resp = compare_graphs(graph, const_ref_graph, 'result')
    self.assertTrue(flag, resp)

    # Stage 2: the remaining input-fed Quantize is rewritten as well.
    ref_graph = build_graph(nodes_attrs=nodes_attributes,
                            edges=non_const_fq_edges() + const_fq_edges() + add_edges(),
                            nodes_with_edges_only=True)
    QuantizeLinearResolver().find_and_replace_pattern(graph)
    flag, resp = compare_graphs(graph, ref_graph, 'result')
    self.assertTrue(flag, resp)