def conv_flatten_concat_action(graph: Graph, match: dict):
    """Put a Permute between a convolution output and a rank-changing Reshape.

    ``match`` must provide the keys ``'reshape'``, ``'reshape_data'``,
    ``'conv'`` and ``'conv_data'``. The graph layout is asserted to be ``'NHWC'``.

    The graph is left untouched when:
      * the Reshape output has the same rank as the convolution output, or the
        Reshape already carries a set ``nchw_layout`` flag; or
      * the Reshape output has exactly one consumer, that consumer is a
        ``FullyConnected`` layer, and
        ``can_repack_fully_connected_weights_nhwc_to_nchw`` accepts it. In this
        case only an info message is logged.

    Otherwise the ``conv_data -> reshape`` edge is replaced by
    ``conv_data -> Permute(nchw->nhwc order) -> new data -> reshape``. Attribute
    permutation is then disabled for the Reshape, and the Reshape is flagged
    ``nchw_layout = True``.
    """
    assert graph.graph['layout'] == 'NHWC'
    reshape = match['reshape']
    reshape_out = match['reshape_data']
    conv_name = match['conv'].name
    conv_out = match['conv_data']

    # Only a Reshape that changes the number of dimensions is handled; an
    # already-flagged Reshape was processed before.
    same_rank = len(reshape_out.shape) == len(conv_out.shape)
    if same_rank or reshape.has_and_set('nchw_layout'):
        return

    # A single FullyConnected consumer whose weights can be repacked makes the
    # Permute unnecessary.
    consumers = reshape_out.out_nodes()
    if len(consumers) == 1:
        consumer = reshape_out.out_node()
        if consumer.has_valid('type') and consumer.type == 'FullyConnected' and \
                can_repack_fully_connected_weights_nhwc_to_nchw(consumer):
            log.info('There is a FullyConnected layer after the node "{}" which weights will be repacked. So there is no '
                     'need to insert Permute'.format(reshape.soft_get('name')))
            return

    graph.remove_edge(conv_out.id, reshape.id)

    nchw_to_nhwc = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_out.shape))
    permute = Permute(graph, {'order': nchw_to_nhwc.perm})
    permuted = permute.create_node_with_data([conv_out], {'name': conv_name + '/Permute_'})
    graph.create_edge(permuted, reshape)

    # Switch off attribute permutation for the Reshape and flag it so that the
    # early-return check above skips it next time.
    PermuteAttrs.set_permutation(reshape, reshape_out, None)
    reshape['nchw_layout'] = True
def replace_pattern(graph: Graph, match: dict):
    """Insert an input Permute in front of a Reshape that collapses >=4D data to 3D.

    Nothing is done when any of the following holds:
      * the graph layout is ``'NCHW'``;
      * the Reshape already has a set ``nchw_layout`` flag;
      * the Reshape's ``correct_data_layout`` attribute is ``True``;
      * the Reshape input rank is below 4, or its output rank is not 3;
      * the channel dimension (NCHW indexing) of the input is 1, or both
        spatial H/W dimensions are 1.

    Otherwise the ``input -> reshape`` edge is replaced by
    ``input -> Permute(nchw->nhwc order) -> new data -> reshape``. The
    attributes of the removed edge are reused for the new edge. Only the input
    side is wrapped; no output Permute is inserted here.

    :param graph: graph being transformed (modified in place).
    :param match: pattern match; must contain the ``'reshape'`` node.
    """
    reshape = match['reshape']
    assert len(reshape.in_nodes()) > 0
    if graph.graph['layout'] == 'NCHW' or reshape.has_and_set('nchw_layout') or \
            reshape.soft_get('correct_data_layout') is True:
        return

    input_node = reshape.in_node()
    output_node = reshape.out_node()
    input_shape = input_node.shape
    output_shape = output_node.shape

    if len(input_shape) < 4 or len(output_shape) != 3:
        return

    # Check that the layout permutation pass would really move non-trivial
    # dimensions of this Reshape's input.
    layout = 'NCHW'
    rank = len(input_shape)
    c_idx = get_features_dim(layout, rank)
    hw_idx = [get_width_dim(layout, rank), get_height_dim(layout, rank)]
    if not (input_shape[c_idx] != 1 and np.any(input_shape[hw_idx] != [1, 1])):
        return

    # The permutation would change the shapes significantly, so pin the input
    # with an explicit NCHW -> NHWC Permute.
    permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(rank)

    # 1. Insert input Permute.
    #    It permutes the input from the original input layout to the operation layout.
    edge_attrs = graph.get_edge_data(input_node.id, reshape.id)[0]
    graph.remove_edge(input_node.id, reshape.id)

    permute_op = Permute(graph, {'order': permutation.perm, 'name': reshape.name + '/Permute_'})
    permute_data_node = permute_op.create_node_with_data([input_node])

    graph.add_edge(permute_data_node.id, reshape.id, **edge_attrs)
def find_and_replace_pattern(self, graph: Graph):
    """Wrap every op whose ``layout`` disagrees with the graph layout in Transposes.

    An op qualifies when it has a valid ``layout`` attribute that differs from
    ``indices_mapping[len(node.layout)][graph.graph['layout']]``. For each such
    node:

      1. a Transpose with order ``permutation.perm`` is inserted between the
         node and its input data node (``node.in_node()``);
      2. the node gets a new output data node of shape
         ``out_data.shape[permutation.perm]``. A Transpose with order
         ``permutation.inv`` then feeds the original output data node
         (``node.out_node()``) from it;
      3. ``permutation`` is set on the node's (new) in/out data nodes,
         ``node.permute_attrs.permute_attrs(node)`` is called, and the
         ``permutation`` attributes are cleared again.

    Schematic::

                                        \\          NCHW                          NCHW
                NHWC                    --\\         |  permutation      permutation |
        data-->Convolution(example)-->data--/         |        NCHW                  |
                                        /   data->Transpose->data->Convolution->data->Transpose->data

    NOTE(review): both inserted Transposes receive the same name
    ``node.name + '/Transpose_'``.
    """
    for node_id in list(graph.nodes()):
        node = Node(graph, node_id)
        # Skip nodes whose layout matches the graph layout,
        # e.g. NHWC vs NCHW or NCDHW vs NDHWC.
        if not (node.kind == 'op' and node.has_valid('layout') and
                node.layout != indices_mapping[len(node.layout)][graph.graph['layout']]):
            continue

        in_data = node.in_node()
        out_data = node.out_node()

        # Permutation used for the Transpose orders below.
        if graph.graph['layout'] == 'NCHW':
            # Graph layout is NCHW and the node's layout differs from it
            # (presumably NHWC-like; depends on indices_mapping).
            permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.layout))
        else:
            # Graph layout is not NCHW (e.g. NHWC) and the node's layout differs
            # from it (presumably NCHW-like; depends on indices_mapping).
            permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(len(node.layout))

        # 1. Insert input Transpose.
        #    It permutes the input from the original input layout to the operation layout.
        edge_attrs = graph.get_edge_data(in_data.id, node.id)[0]
        graph.remove_edge(in_data.id, node.id)

        input_order_const = Const(graph, {'value': permutation.perm}).create_node_with_data()
        input_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_'))
        input_permute_data_node = input_permute_op.create_node_with_data([in_data, input_order_const])
        graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)

        # 2. Insert output Transpose.
        #    It permutes the output from the operation layout back to the original layout,
        #    writing into the original output data node (data_nodes=out_data).
        edge_attrs = graph.get_edge_data(node.id, out_data.id)[0]
        graph.remove_edge(node.id, out_data.id)

        node_out_data = Op.create_data_node(graph, node, {'shape': out_data.shape[permutation.perm]}, edge_attrs)
        output_order_const = Const(graph, {'value': permutation.inv}).create_node_with_data()
        Transpose(graph, dict(name=node.name + '/Transpose_')).create_node_with_data(
            [node_out_data, output_order_const], data_nodes=out_data)

        # 3. Permute the node's own attributes.
        #    The permutation mechanism reads a 'permutation' attribute from the data
        #    nodes, so set it temporarily, permute, and then clear it.
        node.in_node()['permutation'] = permutation
        node.out_node()['permutation'] = permutation
        node.permute_attrs.permute_attrs(node)
        node.in_node()['permutation'] = None
        node.out_node()['permutation'] = None
def permute_nchw_to_nhwc(shape):
    """Return ``shape`` reordered with the NCHW->NHWC permutation, as a NumPy array.

    The index order is ``PermuteAttrs.get_nchw_to_nhwc_permutation(len(shape)).perm``.
    The input is first converted with ``np.array`` and is not modified.
    """
    order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(shape)).perm
    as_array = np.array(shape)
    return as_array[order]
def permute_nchw_to_nhwc(shape, use_new_frontend=False):
    """Reorder ``shape`` with the NCHW->NHWC permutation unless the new frontend is used.

    :param shape: sequence of dimensions.
    :param use_new_frontend: when truthy, ``shape`` is returned as-is (the very same object).
    :return: otherwise a NumPy array built from ``shape`` and indexed with
        ``PermuteAttrs.get_nchw_to_nhwc_permutation(len(shape)).perm``.
    """
    if not use_new_frontend:
        nhwc_order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(shape)).perm
        return np.array(shape)[nhwc_order]
    return shape