def test_mark_unfused_nodes_2(self):
    # Placeholder->Mul->Add->Mul->Concat (the Placeholder output also feeds Concat directly)
    graph = build_graph(nodes_attributes,
                        [('placeholder_1', 'placeholder_1_data'),
                         ('placeholder_1_data', 'mul_1'),
                         ('mul_1_w', 'mul_1'),
                         ('mul_1', 'mul_1_data'),
                         ('mul_1_data', 'add_1'),
                         ('add_1_w', 'add_1'),
                         ('add_1', 'add_1_data'),
                         ('add_1_data', 'mul_2'),
                         ('mul_2_w', 'mul_2'),
                         ('mul_2', 'mul_2_data'),
                         ('mul_2_data', 'concat_1'),
                         ('concat_1', 'concat_1_data'),
                         ('placeholder_1_data', 'concat_1'),
                         ('concat_1_data', 'op_output')],
                        {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                         'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                         'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                         'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                         'mul_1_w': {'shape': np.array([1]), 'value': 6},
                         'add_1_w': {'shape': np.array([1]), 'value': 6},
                         'mul_2_w': {'shape': np.array([1]), 'value': 6},
                         })
    graph.graph['layout'] = 'NHWC'

    # The '.*' pattern matches every node name, so fusing must be disabled for all nodes
    mark_unfused_nodes(graph, '.*')

    self.assertFalse(graph.node['mul_1']['can_be_fused'], "can_be_fused should be False")
    self.assertFalse(graph.node['mul_2']['can_be_fused'], "can_be_fused should be False")
    self.assertFalse(graph.node['add_1']['can_be_fused'], "can_be_fused should be False")
    self.assertFalse(graph.node['placeholder_1']['can_be_fused'], "can_be_fused should be False")
    self.assertFalse(graph.node['concat_1']['can_be_fused'], "can_be_fused should be False")
def find_and_replace_pattern(self, graph: Graph):
    fw = graph.graph['fw']
    argv = graph.graph['cmd_params']
    layout = graph.graph['layout']

    for_graph_and_each_sub_graph_recursively(graph, fuse_pad)
    for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    # IE doesn't support batchNormInference with 4 inputs, so we have to split it to two ScaleShift
    for_graph_and_each_sub_graph_recursively(graph, convert_batch_norm)

    if fw == 'caffe':
        # Converting ScaleShift layer to Mul->Add
        for_graph_and_each_sub_graph_recursively(graph, convert_scale_shift_to_mul_add)

    for_graph_and_each_sub_graph_recursively(graph, Div().find_and_replace_pattern)
    for_graph_and_each_sub_graph_recursively(graph, Sub().find_and_replace_pattern)
    for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

    if not argv.disable_fusing:
        if fw != 'caffe':
            # Converting ScaleShift layer to Mul->Add
            for_graph_and_each_sub_graph_recursively(graph, convert_scale_shift_to_mul_add)
            for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

        # Fusing the sequences of Mul/Add operations
        for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence)
        for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

        normalize_eltwise_inputs(graph)
        for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

        # Fusing linear operation to Convolution
        for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
        for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

    if not argv.disable_gfusing:
        for_graph_and_each_sub_graph_recursively(graph, grouped_convolutions_fusing)
        for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
        if not argv.disable_fusing:
            for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
            for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

    for_graph_and_each_sub_graph_recursively(graph, normalize_eltwise_inputs)
    for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

    if not argv.disable_fusing:
        MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
        FakeQuantizeFuse().find_and_replace_pattern(graph)
        AddFakeQuantizeFuse().find_and_replace_pattern(graph)
        MulFakeQuantizeFuse().find_and_replace_pattern(graph)
        for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

    for_graph_and_each_sub_graph_recursively(graph, fuse_pad)
    for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

    if layout != 'NHWC' and not argv.disable_resnet_optimization:
        stride_optimization(graph)
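# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the sources above): both snippets rely on
# mark_unfused_nodes(graph, regex) setting 'can_be_fused' to False on every
# node whose name matches the regex supplied via --finegrain_fusing ('.*' in
# the test disables fusing for all nodes). A minimal stand-alone
# approximation of that regex-gating idea, assuming the graph is a networkx
# graph, could look like the hypothetical mark_unfused_nodes_sketch below;
# the real helper handles more attributes and node types.
import re

import networkx as nx


def mark_unfused_nodes_sketch(graph: nx.MultiDiGraph, fuse_nodes_pattern: str):
    # Compile the user-supplied pattern once; an empty pattern leaves fusing enabled.
    pattern = re.compile(fuse_nodes_pattern) if fuse_nodes_pattern else None
    for node_name, attrs in graph.nodes(data=True):
        if pattern is not None and pattern.match(str(node_name)):
            # Matching nodes are excluded from all subsequent fusing passes.
            attrs['can_be_fused'] = False


# Usage example: with the pattern '.*' every node ends up with can_be_fused=False.
_g = nx.MultiDiGraph()
_g.add_edge('placeholder_1', 'mul_1')
mark_unfused_nodes_sketch(_g, '.*')
assert all(not attrs['can_be_fused'] for _, attrs in _g.nodes(data=True))
# ---------------------------------------------------------------------------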