def replace_pattern(self, graph: Graph, match: dict):
    """Replace a matched DepthToSpace op with Reshape(6D) -> Transpose -> Reshape(4D).

    The matched op is disconnected from its input and output data nodes. The input
    data node then feeds a new 6D Reshape, a Transpose with order [0, 3, 4, 1, 5, 2]
    and a 4D Reshape. The 4D Reshape writes into the original ``match['out_data']``
    node, so consumers of that data node are left untouched.

    :param graph: graph to modify in place.
    :param match: pattern match holding 'in_data' (input data node whose 4-element
        shape is unpacked as N, H, W, C), 'op' (the matched op, which must carry a
        'block_size' attribute) and 'out_data' (output data node).
    :raises ValueError: if C is not divisible by ``block_size ** 2``. The check runs
        before the graph is modified, so a rejected node leaves the graph intact.
    """
    node = match['op']

    # Unpacking into four names also enforces a rank-4 input shape. The batch size is
    # not needed because the Reshape targets below use 0 for that position
    # (presumably "copy from input" Reshape semantics -- confirm against Reshape op).
    _, H, W, C = match['in_data'].shape
    block_size = node['block_size']
    block_area = block_size ** 2

    # A non-divisible channel count cannot be redistributed into block_size x block_size
    # spatial blocks; truncating it would silently build a Reshape with a wrong element count.
    if C % block_area != 0:
        raise ValueError(
            'Node "{}": channel dimension {} is not divisible by block_size ** 2 = {}'.format(
                node.id, C, block_area))
    # Exact integer division (no float round-trip), computed once for both targets.
    out_channels = int(C // block_area)

    dim_6D = int64_array([0, block_size, block_size, out_channels, H, W])
    order_6D = int64_array([0, 3, 4, 1, 5, 2])
    dim_4D = int64_array([0, int(H * block_size), int(W * block_size), out_channels])

    graph.remove_edge(match['in_data'].id, node.id)
    graph.remove_edge(node.id, match['out_data'].id)

    # Step 1: reshape the input to 6D; its producer op keeps the given dims as-is.
    reshape_6_op = Reshape(graph, dict(name=node.id + '/Reshape_to_6D'))
    reshape_6_const_data = Const(graph, dict(value=dim_6D)).create_node_with_data()
    reshape_6_data_node = reshape_6_op.create_node_with_data([match['in_data'], reshape_6_const_data])
    mark_as_correct_data_layout(reshape_6_data_node.in_node(0))

    # Step 2: transpose the 6D tensor with the fixed order above.
    order_const_data = Const(graph, dict(value=order_6D)).create_node_with_data()
    transpose_op = Transpose(graph, dict(name=node.id + '/Transpose'))
    transpose_data_node = transpose_op.create_node_with_data([reshape_6_data_node, order_const_data])
    mark_as_correct_data_layout(transpose_data_node.in_node(0))

    # Step 3: reshape back to 4D, writing into the original output data node.
    reshape_4_op = Reshape(graph, dict(name=node.id + '/Reshape_to_4D'))
    reshape_4_const_data = Const(graph, dict(value=dim_4D)).create_node_with_data()
    reshape_4_data_node = reshape_4_op.create_node_with_data([transpose_data_node, reshape_4_const_data],
                                                             data_nodes=[match['out_data']])
    # Only port 0 (the data input) and output 0 of the final Reshape are marked.
    mark_input_as_in_correct_layout(reshape_4_data_node.in_node(0), 0)
    mark_output_as_in_correct_layout(reshape_4_data_node.in_node(0), 0)
def find_and_replace_pattern(self, graph: Graph):
    """Wrap ops whose 'layout' attribute disagrees with the graph layout in Transposes.

    An op node is processed when it has a valid ``layout`` attribute that differs from
    ``indices_mapping[len(node.layout)][graph.graph['layout']]`` (for example NHWC vs
    NCHW, or NCDHW vs NDHWC). For each such node:

    1. a Transpose with order ``permutation.perm`` is inserted between the node and the
       data node returned by ``node.in_node()``;
    2. the node is redirected into a newly created data node, and a Transpose with order
       ``permutation.inv`` connects that data node to the original data node returned
       by ``node.out_node()``;
    3. the node's attributes are permuted by ``node.permute_attrs.permute_attrs(node)``,
       driven by ``permutation`` temporarily stored on its input/output data nodes.

    Before:  data -> Op -> data
    After:   data -> Transpose(perm) -> data -> Op -> data -> Transpose(inv) -> data

    :param graph: graph to modify in place.
    """
    # The node list is copied up front because the loop body adds nodes to the graph.
    for node_id in list(graph.nodes()):
        node = Node(graph, node_id)
        # Only op nodes whose layout mismatches the graph layout are rewritten.
        if not (node.kind == 'op' and node.has_valid('layout') and
                node.layout != indices_mapping[len(node.layout)][graph.graph['layout']]):
            continue

        in_data = node.in_node()
        out_data = node.out_node()
        rank = len(node.layout)

        # Permutation shared by both Transposes: perm on the input side, inv on the output side.
        if graph.graph['layout'] == 'NCHW':
            # Graph is channels-first, so the mismatching node is channels-last (e.g. NHWC).
            permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(rank)
        else:
            # Graph layout is not NCHW (presumably NHWC), so the node is channels-first.
            permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(rank)

        # 1. Input Transpose: in_data -> Transpose(perm) -> new data -> node.
        #    The attributes of the removed edge (first edge, key 0) are reused for the new edge.
        edge_attrs = graph.get_edge_data(in_data.id, node.id)[0]
        graph.remove_edge(in_data.id, node.id)
        input_order_const = Const(graph, {'value': permutation.perm}).create_node_with_data()
        input_transpose = Transpose(graph, dict(name=node.name + '/Transpose_'))
        transposed_input = input_transpose.create_node_with_data([in_data, input_order_const])
        graph.add_edge(transposed_input.id, node.id, **edge_attrs)

        # 2. Output Transpose: node -> new data -> Transpose(inv) -> out_data.
        #    The new data node gets out_data's shape reordered by permutation.perm.
        edge_attrs = graph.get_edge_data(node.id, out_data.id)[0]
        graph.remove_edge(node.id, out_data.id)
        node_output = Op.create_data_node(graph, node, {'shape': out_data.shape[permutation.perm]},
                                          edge_attrs)
        output_order_const = Const(graph, {'value': permutation.inv}).create_node_with_data()
        # NOTE(review): both Transposes receive the same 'name' attribute; kept as-is because
        # changing it would change the names emitted for these nodes.
        Transpose(graph, dict(name=node.name + '/Transpose_')).create_node_with_data(
            [node_output, output_order_const], data_nodes=out_data)

        # 3. Permute the node's own attributes. The permute_attrs mechanism reads the
        #    'permutation' attribute of the data nodes. After the rewiring above,
        #    node.in_node()/node.out_node() refer to the new data nodes (not in_data/out_data),
        #    so they are looked up again here.
        node.in_node()['permutation'] = permutation
        node.out_node()['permutation'] = permutation
        node.permute_attrs.permute_attrs(node)

        node.in_node()['permutation'] = None
        node.out_node()['permutation'] = None