def replace_pattern(self, graph: Graph, match: dict):
    """Decompose the matched op into a Reshape -> Transpose -> Reshape sub-graph.

    The 4D NHWC input of shape (N, H, W, C) is reshaped to a 6D tensor split by
    ``block_size``, permuted, and collapsed back to
    (N, H * block_size, W * block_size, C / block_size**2) — the classic
    depth-to-space decomposition. The final Reshape is wired directly to the
    original output data node, replacing the matched op.

    :param graph: graph being transformed in place
    :param match: pattern match dict with 'op', 'in_data' and 'out_data' nodes
    """
    node = match['op']
    # N is unpacked for readability only; the batch dim is carried through as 0
    # in the Reshape dims (meaning "keep this dimension" in MO Reshape semantics
    # — presumably; confirm against Reshape op docs).
    N, H, W, C = match['in_data'].shape
    block_size = node['block_size']

    # Detach the matched op from its input/output data nodes; the new sub-graph
    # is spliced in between them instead.
    graph.remove_edge(match['in_data'].id, node.id)
    graph.remove_edge(node.id, match['out_data'].id)

    # Integer floor division instead of int(C / bs**2): exact for arbitrarily
    # large ints (no float round-trip) and computed once for both shapes.
    new_depth = C // (block_size ** 2)
    dim_6D = int64_array([0, block_size, block_size, new_depth, H, W])
    order_6D = int64_array([0, 3, 4, 1, 5, 2])
    dim_4D = int64_array([0, int(H * block_size), int(W * block_size), new_depth])

    # 4D -> 6D reshape; its input edge keeps the original (already correct) layout.
    reshape_6_op = Reshape(graph, dict(name=node.id + '/Reshape_to_6D'))
    reshape_6_const_data = Const(graph, dict(value=dim_6D)).create_node_with_data()
    reshape_6_data_node = reshape_6_op.create_node_with_data([match['in_data'], reshape_6_const_data])
    mark_as_correct_data_layout(reshape_6_data_node.in_node(0))

    # Permute the 6D tensor so the block dims interleave with spatial dims.
    order_const_data = Const(graph, dict(value=order_6D)).create_node_with_data()
    transpose_op = Transpose(graph, dict(name=node.id + '/Transpose'))
    transpose_data_node = transpose_op.create_node_with_data([reshape_6_data_node, order_const_data])
    mark_as_correct_data_layout(transpose_data_node.in_node(0))

    # 6D -> 4D reshape producing directly into the original output data node.
    reshape_4_op = Reshape(graph, dict(name=node.id + '/Reshape_to_4D'))
    reshape_4_const_data = Const(graph, dict(value=dim_4D)).create_node_with_data()
    reshape_4_data_node = reshape_4_op.create_node_with_data([transpose_data_node, reshape_4_const_data],
                                                             data_nodes=[match['out_data']])
    # Both markers are applied to the op producing the final Reshape's data
    # (its in_node(0)), disabling layout permutation on its port 0 in both
    # directions.
    mark_input_as_in_correct_layout(reshape_4_data_node.in_node(0), 0)
    mark_output_as_in_correct_layout(reshape_4_data_node.in_node(0), 0)
def find_and_replace_pattern(self, graph: Graph):
    """Find sub-graphs that must run in the original (non-permuted) layout and mark them.

    Starting from every op node satisfying one of ``self.op_conditions``, a
    forward and a backward BFS collect connected nodes that do not require a
    layout-conversion Transpose. Adjacent "reinterp_shape"-like boundary nodes
    (rank-changing ops just outside the visited region) are added too. All
    collected nodes are marked with correct data layout and ``nchw_layout``.

    :param graph: graph being transformed in place
    """
    visited = set()
    marked_nodes = set()

    # PEP 8: named functions instead of lambda assignments (E731).
    def condition_forward(n):
        return not InsertLayoutPropagationTranspose.is_nhwc_to_nchw_transpose_needed(n)

    def condition_backward(n):
        return not InsertLayoutPropagationTranspose.is_nchw_to_nhwc_transpose_needed(n)

    for node_condition in self.op_conditions:
        for node in graph.get_op_nodes():
            # Guard clause keeps the main body flat.
            if not node_condition(node):
                continue
            log.debug('Detected node "{}" as a node which should be executed in the original layout'
                      ''.format(node.soft_get('name', node.id)))
            # The shared `visited` set prevents re-traversing nodes across
            # different seed nodes/conditions.
            forward_visited_nodes = self.bfs([node], visited, condition_forward, True)
            backward_visited_nodes = self.bfs([node], visited, condition_backward, False)

            # find "reinterp_shape" like ops which change rank of input to 4D or 5D from smaller dimensions
            for back_node in backward_visited_nodes:
                for input_node in self.get_input_nodes(back_node):
                    if input_node not in backward_visited_nodes and not condition_forward(input_node):
                        marked_nodes.add(input_node)

            # find "reinterp_shape" like ops which change rank of input from 4D or 5D to smaller dimensions
            for forward_node in forward_visited_nodes:
                for output_node in self.get_output_nodes(forward_node):
                    if output_node not in forward_visited_nodes and not condition_backward(output_node):
                        marked_nodes.add(output_node)

            marked_nodes.update(forward_visited_nodes + backward_visited_nodes)

    # Truthiness check instead of len(...) per PEP 8.
    if marked_nodes:
        log.debug('The following nodes will be executed in the original layout: {}'
                  ''.format([n.soft_get('name', n.id) for n in marked_nodes]))
        # mark all matched nodes as in correct layout and disable attributes permutation for them
        for visited_node in marked_nodes:
            mark_as_correct_data_layout(visited_node)
            visited_node['nchw_layout'] = True