def test_1(self):
    """Two mutually inverse Transposes must be fused away entirely.

    Layout path: NHWC -> NCHW -> NHWC

        Input->DATA->Transpose->DATA->Transpose->DATA  =>  Input->DATA
    """
    # Graph under test: placeholder -> Transpose(order const_1) -> Transpose(order const_2) -> output.
    edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'permute_1'),
        ('permute_1', 'permute_1_data'),
        ('permute_1_data', 'permute_2'),
        ('permute_2', 'permute_2_data'),
        ('permute_2_data', 'op_output'),
        ('const_1', 'const_1_data'),
        ('const_1_data', 'permute_1', {'in': 1}),
        ('const_2', 'const_2_data'),
        ('const_2_data', 'permute_2', {'in': 1}),
    ]
    node_updates = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
        'const_1_data': {'value': np.array([0, 3, 1, 2])},
        'permute_1_data': {'shape': np.array([1, 3, 227, 227])},
        'const_2_data': {'value': np.array([0, 2, 3, 1])},
        'permute_2_data': {'shape': np.array([1, 227, 227, 3])},
    }
    graph = build_graph(nodes_attributes, edges, node_updates, nodes_with_edges_only=True)
    graph.graph['layout'] = 'NHWC'
    graph.graph['cmd_params'] = Namespace(keep_shape_ops=False)

    # Expected result: the placeholder data feeds the output directly.
    expected_edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'op_output'),
    ]
    expected_updates = {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}}
    graph_ref = build_graph(nodes_attributes, expected_edges, expected_updates, nodes_with_edges_only=True)

    fuser = FuseTransposesSequence()
    fuser.find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def find_and_replace_pattern(self, graph: Graph):
    """Optimize chains of consecutive matching (Transpose/Reshape) operations, then run clean-up passes.

    The graph is walked in ``graph.pseudo_topological_sort()`` order. Starting from every node
    accepted by ``self.is_node_match_for_optimization``, the chain of consecutive accepted
    operations is followed. Each visited node is marked with ``self.OPTIMIZED_NODE_FLAG``.
    The walk stops when a node has more than one consumer operation, or when it has none.
    Chains of three or more nodes are passed to ``self.optimize_permute_reshape_sequence``.

    Afterwards FuseReshapesSequence, RemoveRedundantReshapes and FuseTransposesSequence are run
    on the whole graph.

    :param graph: graph to transform in place
    """
    for start_node in graph.pseudo_topological_sort():
        matched_nodes = []
        next_node = start_node
        # The loop condition is the same check that decides whether start_node begins a chain.
        # NOTE(review): presumably the predicate honours OPTIMIZED_NODE_FLAG so that nodes already
        # visited are not matched again — confirm in is_node_match_for_optimization.
        while self.is_node_match_for_optimization(next_node):
            matched_nodes.append(next_node)
            next_node[self.OPTIMIZED_NODE_FLAG] = True
            next_nodes = get_next_operation(next_node)
            if len(next_nodes) > 1:
                log.debug('There are two consumers of the node {}. Stop matching sequence.'.format(
                    next_node.soft_get('name')))
                break
            if not next_nodes:
                # The chain ends at a node without consumer operations; indexing next_nodes[0]
                # here would raise IndexError.
                break
            next_node = next_nodes[0]
        # optimize sequence of three or more Transpose-Reshape nodes
        if len(matched_nodes) >= 3:
            self.optimize_permute_reshape_sequence(graph, matched_nodes)

    # Fuse consecutive Reshapes and remove dummy (NOP) Reshapes; after that, Transposes that became
    # adjacent can be fused.
    FuseReshapesSequence().find_and_replace_pattern(graph)
    RemoveRedundantReshapes().find_and_replace_pattern(graph)
    FuseTransposesSequence().find_and_replace_pattern(graph)