def permute_before_and_after(inp: Node, middle: Node, out: Node, order):
    '''
    Wrap the middle node with a pair of Permute operations.

    A Permute with the given order is placed on the inp -> middle edge. A Permute with the
    inverse order is placed after middle, and out is passed as the data node of that
    second Permute.

    :param inp: data node that currently feeds middle
    :param middle: node to be wrapped
    :param out: data node that is currently produced by middle
    :param order: permutation applied before middle (its inverse is applied after)
    '''
    graph = middle.graph
    base_name = middle.name

    # Input side: inp -> Permute(order) -> new data node -> middle.
    # The attributes of the original edge are re-used for the new incoming edge.
    input_permute = Permute(graph, {'order': np.array(order)})
    saved_edge_attrs = deepcopy(graph.get_edge_data(inp.id, middle.id)[0])
    graph.remove_edge(inp.id, middle.id)
    permuted_inp = input_permute.create_node_with_data(
        [inp], {'name': base_name + '/InputPermute'})
    graph.add_edge(permuted_inp.id, middle.id, **saved_edge_attrs)

    # Output side: middle -> intermediate data node -> Permute(inverse order) -> out.
    # The intermediate data node gets the shape of out re-ordered by `order`.
    output_permute = Permute(graph, {'order': inverse_perm(np.array(order))})
    graph.remove_edge(middle.id, out.id)
    intermediate_out = Op._create_data_node(
        graph,
        name=base_name + '/WithoutPermute',
        attrs={'shape': out.shape[order]})
    graph.add_edge(middle.id, intermediate_out.id, key=0, out=0)
    output_permute.create_node_with_data(
        [intermediate_out], {'name': base_name + '/OutputPermute'}, data_nodes=out)
def conv_flatten_concat_action(graph: Graph, match: dict):
    """
    Insert a Permute between the matched convolution output and the Reshape that consumes it.

    The graph layout must be NHWC. Nothing is changed when:
      * the Reshape keeps the number of dimensions, or is already flagged with 'nchw_layout';
      * the only consumer of the Reshape output is a FullyConnected node whose weights can be
        repacked from NHWC to NCHW.

    :param graph: graph being transformed
    :param match: matched sub-graph with keys 'reshape', 'reshape_data', 'conv' and 'conv_data'
    """
    assert graph.graph['layout'] == 'NHWC'

    reshape_op = match['reshape']
    reshape_out = match['reshape_data']
    conv_name = match['conv'].name
    conv_out = match['conv_data']

    # The pattern is applied only when the Reshape changes the number of dimensions
    # and has not been processed already.
    same_rank = len(reshape_out.shape) == len(conv_out.shape)
    if same_rank or reshape_op.has_and_set('nchw_layout'):
        return

    # A single FullyConnected consumer with repackable weights makes the Permute unnecessary.
    if len(reshape_out.out_nodes()) == 1:
        consumer = reshape_out.out_node()
        if consumer.has_valid('type') and consumer.type == 'FullyConnected' and \
                can_repack_fully_connected_weights_nhwc_to_nchw(consumer):
            log.info(
                'There is a FullyConnected layer after the node "{}" which weights will be repacked. So there is no '
                'need to insert Permute'.format(reshape_op.soft_get('name')))
            return

    # Re-route conv_data -> Reshape through a new Permute operation.
    graph.remove_edge(conv_out.id, reshape_op.id)

    order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_out.shape)).perm
    permute_op = Permute(graph, {'order': order})
    permuted_data = permute_op.create_node_with_data([conv_out], {'name': conv_name + '/Permute_'})
    graph.create_edge(permuted_data, reshape_op)

    # Disable permutation of the Reshape attributes and flag the node as handled.
    PermuteAttrs.set_permutation(reshape_op, reshape_out, None)
    reshape_op['nchw_layout'] = True
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replace the matched operation with a Reshape -> Permute -> Reshape chain.

    The input is reshaped to [N, group, C // group, -1], the two middle axes are swapped with a
    Permute of order [0, 2, 1, 3], and the result is reshaped back to the shape of the original
    output data node. The transformation is applied only when the graph layout is NCHW.

    :param graph: graph being transformed
    :param match: matched sub-graph; the operation is taken from match['op']
    :raises Error: if the 'group' attribute does not divide the number of input channels
    """
    if graph.graph['layout'] != "NCHW":
        return

    node = match['op']
    in_node = node.in_node(0)
    out_node = node.out_node(0)
    group = int(node['group'])

    # Validate before touching the graph, so that a failure does not leave the matched
    # operation disconnected from its input and output data nodes.
    rows = group
    cols = in_node.shape[1] // group
    if rows * cols != in_node.shape[1]:
        raise Error("Group {} should divide input channels number {} without reminder for node {}"
                    "".format(group, in_node.shape[1], node.id))

    # Detach the original operation; the new chain connects in_node to out_node instead.
    graph.remove_edge(in_node.id, node.id)
    graph.remove_edge(node.id, out_node.id)

    reshape_split = Reshape(graph, attrs={'name': node.id + '/Reshape_split_',
                                          'dim': np.array([in_node.shape[0], rows, cols, -1])})
    reshape_split_node = reshape_split.create_node_with_data([in_node])

    transpose = Permute(graph, attrs={'name': node.id + '/Transpose_',
                                      'order': np.array([0, 2, 1, 3])})
    transpose_node = transpose.create_node_with_data([reshape_split_node])

    # The last Reshape re-uses the original output data node, so consumers stay connected.
    reshape_concat = Reshape(graph, attrs={'name': node.id + '/Reshape_concat_',
                                           'dim': out_node.shape})
    reshape_concat.create_node_with_data([transpose_node], data_nodes=[out_node])
def permute_before_and_after(inp: Node, middle: Node, out: Node, input_order, output_order):
    """
    Insert two permutes: before middle node and after middle node.
    Both permutes have a given order (input/output).

    :param inp: data node that currently feeds middle
    :param middle: node to be wrapped with Permute operations
    :param out: data node currently produced by middle; it is passed as the data node of the
                output Permute
    :param input_order: order of the Permute inserted before middle (any integer sequence)
    :param output_order: order of the Permute inserted after middle (any integer sequence)
    """
    graph = middle.graph

    # Normalise the output order once, the same way the input order is normalised.
    # Without this a tuple would be interpreted by numpy as a multi-dimensional index
    # when it is used on out.shape below, and the Permute would receive a non-array order.
    output_order = np.array(output_order, dtype=np.int64)

    # Permute before input
    permute = Permute(graph, dict(order=np.array(input_order)))

    edge_attrs = deepcopy(graph.get_edge_data(inp.id, middle.id)[0])
    graph.remove_edge(inp.id, middle.id)
    new_inp = permute.create_node_with_data([inp], dict(name=middle.name + '/InputPermute'))
    graph.add_edge(new_inp.id, middle.id, **edge_attrs)

    # Permute after output
    permute = Permute(graph, dict(order=output_order))

    graph.remove_edge(middle.id, out.id)
    # Intermediate data node between middle and the output Permute; its shape is the shape
    # of out re-ordered by output_order.
    new_out = Op._create_data_node(graph, name=middle.name + '/WithoutPermute',
                                   attrs={'shape': out.shape[output_order]})
    graph.add_edge(middle.id, new_out.id, key=0, out=0)
    permute.create_node_with_data([new_out], dict(name=middle.name + '/OutputPermute'), data_nodes=out)
def add_output_reshape(graph: Graph, match: dict):
    """
    Since MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions]
    we need to add reshape from above common format [batch_size, num_directions, seq_len, hidden_size]
    to MXNet format.

    The recurrent layer output is redirected to a new data node, followed by a Permute with order
    [0, 2, 1, 3] and a Reshape to the original output shape; the Reshape writes into the original
    output data node. Nothing is done when the layer has no num_directions dimension.

    :param graph: graph being transformed
    :param match: matched sub-graph with keys 'rnn_layer' and 'input'
    """
    lstm = match['rnn_layer']
    input_data = match['input']
    if not lstm.has_num_directions:
        return

    old_data_node = lstm.out_node(0)
    num_directions = 2 if lstm.direction == 'bidirectional' else 1
    # Remember the original output shape: it is the target of the final Reshape.
    mxnet_shape = old_data_node.shape.copy()

    # The first two dimensions follow the batch/sequence order of the layer.
    if lstm.batch_dim == 0:
        first_dim, second_dim = lstm.batch_dim, lstm.sequence_dim
    else:
        first_dim, second_dim = lstm.sequence_dim, lstm.batch_dim
    mo_shape = np.array([input_data.shape[first_dim],
                         input_data.shape[second_dim],
                         lstm.hidden_size],
                        dtype=np.int64)

    # has_num_directions is known to be set here (see the early return above),
    # so the directions dimension is always inserted at position 1.
    mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))

    new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
    graph.remove_edge(lstm.id, old_data_node.id)
    graph.add_edge(lstm.id, new_data.id, key=0, out=0)

    # Add Permute
    permute_order = np.array([0, 2, 1, 3], dtype=np.int64)
    permute = Permute(graph, dict(order=permute_order))
    permute_data = permute.create_node_with_data([new_data], dict(name=lstm.name + '/Permute_mxnet/'))

    # Add Reshape
    reshape = Reshape(graph, dict(dim=mxnet_shape))
    reshape.create_node_with_data([permute_data], dict(name=lstm.name + '/Reshape_mxnet/'),
                                  data_nodes=[old_data_node])
def replace_pattern(graph: Graph, match: dict):
    """
    Put a Permute in front of the matched Reshape when it turns a 4D (or higher) input into a
    3D output and the input has a non-trivial channel dimension and spatial dimensions.

    Nothing is done when the graph layout is NCHW, when the Reshape is already flagged with
    'nchw_layout', or when its 'correct_data_layout' attribute is True.

    :param graph: graph being transformed
    :param match: matched sub-graph; the Reshape is taken from match['reshape']
    """
    reshape = match['reshape']
    assert len(reshape.in_nodes()) > 0

    if graph.graph['layout'] == 'NCHW':
        return
    if reshape.has_and_set('nchw_layout'):
        return
    if reshape.soft_get('correct_data_layout') is True:
        return

    src_data = reshape.in_node()
    dst_data = reshape.out_node()
    src_shape = src_data.shape
    dst_shape = dst_data.shape

    if len(src_shape) < 4 or len(dst_shape) != 3:
        return

    # Check that the permutation pass would actually permute some shapes in this Reshape:
    # the channel dimension and at least one spatial dimension must differ from 1.
    layout = 'NCHW'
    rank = len(src_shape)
    channel_axis = get_features_dim(layout, rank)
    spatial_axes = [get_width_dim(layout, rank), get_height_dim(layout, rank)]

    if not (src_shape[channel_axis] != 1 and np.any(src_shape[spatial_axes] != [1, 1])):
        return

    # In that case the nhwc -> nchw permutation can change shapes significantly, so the node
    # is wrapped with an NCHW -> NHWC Permute.
    permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(rank)
    # NOTE(review): permutation_back is computed with the same call as `permutation` and is
    # not used in the visible part of this function - confirm whether it is still needed.
    permutation_back = PermuteAttrs.get_nchw_to_nhwc_permutation(rank)

    # 1. Insert input Permute
    # This Permute will permute input from original input layout to operation layout.
    # The attributes of the removed edge are re-used for the new edge into the Reshape.
    saved_edge_attrs = graph.get_edge_data(src_data.id, reshape.id)[0]
    graph.remove_edge(src_data.id, reshape.id)

    permute_op = Permute(graph, {'order': permutation.perm, 'name': reshape.name + '/Permute_'})
    permuted_data = permute_op.create_node_with_data([src_data])
    graph.add_edge(permuted_data.id, reshape.id, **saved_edge_attrs)
def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
    """
    Rewrite the matched reshape_split -> reshape_pack -> softmax -> reshape_unpack -> strided_slice
    sub-graph for graphs with NHWC layout.

    When self.is_reshape_bad(...) reports a problematic pack/unpack Reshape pair, the method:
      * inserts a Permute with order [0, 2, 3, 1] before 'reshape_split';
      * rewrites the 'dim' of 'reshape_pack' and inserts a Permute with order [0, 2, 1] after it;
      * swaps 'reshape_unpack' and 'strided_slice' so that the slice goes first;
      * updates the slice parameters and shapes and marks the Reshape nodes with 'nchw_layout'.

    :param graph: graph being transformed
    :param match: matched sub-graph with keys 'reshape_split', 'reshape_pack', 'softmax',
                  'reshape_unpack' and 'strided_slice'
    """
    if graph.graph['layout'] != 'NHWC':
        return

    # NOTE(review): the whole transformation below is assumed to be guarded by this check
    # (the log message announces the detection that triggers it) - confirm.
    if self.is_reshape_bad(match['reshape_pack'], match['reshape_unpack'], match['strided_slice']):
        log.info("Reshape that pack/unpack several dimensions detected {}".format(match['reshape_pack'].id))
        node_split = match['reshape_split']

        # insert Permute before reshape: in_node -> Permute -> data_node -> reshape_split
        data_node = Op._create_data_node(graph, node_split.name + "/Permute_before_data")
        permute_before = Permute(graph, dict(name=node_split.name + "/Permute_before",
                                             order=np.array([0, 2, 3, 1])))
        in_node = node_split.in_node(0)
        # keep a copy of the original edge attributes to re-use them on the new incoming edge
        attrs = deepcopy(graph.get_edge_data(in_node.id, node_split.id)[0])
        graph.remove_edge(in_node.id, node_split.id)
        permute_before_node = permute_before.create_node_with_data([in_node], permute_before.attrs,
                                                                   data_nodes=[data_node])
        graph.add_edge(permute_before_node.id, node_split.id, **attrs)

        # new 'dim' of the packing Reshape: [shape[0], prod(shape[1], shape[2], shape[3]), shape[-1]]
        # of its input; the node is flagged with 'nchw_layout'
        node = match['reshape_pack']
        node['nchw_layout'] = True
        new_reshape_shape = np.concatenate((np.array([node.in_node(0).shape[0]]),
                                            np.array([np.prod(node.in_node(0).shape[[1, 2, 3]])]),
                                            np.array([node.in_node(0).shape[-1]])))
        node.dim = new_reshape_shape

        # insert Permute after reshape: reshape_pack -> data_node -> Permute -> out_node
        data_node = Op._create_data_node(graph, node.name + "/Permute_after_data", {'shape': node.dim})
        permute_after = Permute(graph, dict(name=node.name + "/Permute_after", order=np.array([0, 2, 1])))
        out_node = node.out_node(0)
        # the original output data node now holds the permuted ([0, 2, 1]) shape
        out_node.shape = new_reshape_shape[np.array([0, 2, 1])]
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)

        # permute_after_node is not used afterwards; the call is made for its effect on the graph
        permute_after_node = permute_after.create_node_with_data([data_node], permute_after.attrs,
                                                                 data_nodes=[out_node])
        graph.add_edge(node.id, data_node.id, **attrs)

        # update softmax shape
        node_softmax = match['softmax']
        node_softmax.out_node(0).shape = out_node.shape

        # revert strided slice and reshape
        node_ss = match['strided_slice']
        node_unpack = match['reshape_unpack']

        # ids of the data nodes currently produced by the unpacking Reshape and by the StridedSlice
        unpack_out = node_unpack.out_node(0).id
        ss_out = node_ss.out_node(0).id

        # gather edge attributes (copied before the edges are removed)
        soft_reshape_attrs = deepcopy(graph.get_edge_data(node_softmax.out_node(0).id, node_unpack.id)[0])
        reshape_data_attrs = deepcopy(graph.get_edge_data(node_unpack.id, unpack_out)[0])
        reshape_ss_attrs = deepcopy(graph.get_edge_data(unpack_out, node_ss.id)[0])
        ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])

        # remove all edges in Softmax->Reshape->StridedSlice chain
        graph.remove_edge(node_softmax.out_node(0).id, node_unpack.id)
        graph.remove_edge(node_unpack.id, unpack_out)
        graph.remove_edge(unpack_out, node_ss.id)
        graph.remove_edge(node_ss.id, ss_out)

        # add new edges to get chain Softmax->StridedSlice->Reshape;
        # the two data nodes stay in place, only the operations around them are swapped
        graph.add_edge(node_softmax.out_node(0).id, node_ss.id, **soft_reshape_attrs)
        graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
        graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
        graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)

        # update output shape and parameters for StridedSlice
        # (after the re-wiring above node_ss is connected to the unpack_out data node)
        # NOTE(review): np.zeros(3) creates a float array, so this shape is stored as floats -
        # confirm whether an integer shape is expected here.
        node_ss.out_node(0).shape = np.zeros(3)
        node_ss.out_node(0).shape[0] = out_node.shape[0]
        node_ss.out_node(0).shape[1] = 1
        node_ss.out_node(0).shape[2] = out_node.shape[2]

        # keep the first and the last of the old slices and add a full slice over the last axis
        old_slices = node_ss.slices.copy()
        node_ss.slices = []
        node_ss.slices.append(old_slices[0])
        node_ss.slices.append(old_slices[-1])
        node_ss.slices.append(slice(0, out_node.shape[2], 1))
        node_ss.shrink_axis_mask = [False, False, False]
        node_ss.new_axis_mask = [False, False, False]

        # update Reshape attribute: drop the element at index 4 of 'dim'
        node_unpack.dim = np.delete(node_unpack.dim, 4)

        # prevent permute for reshape because it gives wrong result
        node_unpack['nchw_layout'] = True
        node_unpack.out_node(0)['nchw_layout'] = True
def find_and_replace_pattern(self, graph: Graph):
    """
    Wrap every operation whose own 'layout' differs from the graph layout with two Permutes.

    For each such operation a Permute is inserted on its input edge and another one (with the
    inverse order) on its output edge, and the attributes of the operation are permuted through
    the 'permutation' attribute temporarily set on its data nodes.

    :param graph: graph being transformed
    """
    for node_id in list(graph.nodes()):
        node = Node(graph, node_id)

        # Only operations with a layout that mismatches the graph layout are processed.
        # For example: NHWC and NCHW or NCDHW and NDHWC
        if node.kind != 'op':
            continue
        if not node.has_valid('layout'):
            continue
        if node.layout == indices_mapping[len(node.layout)][graph.graph['layout']]:
            continue

        src_data = node.in_node()
        dst_data = node.out_node()

        # Calculate permutation for further Permute operations
        rank = len(node.layout)
        if graph.graph['layout'] == 'NCHW':
            # graph layout is NCHW: use the NHWC -> NCHW permutation
            permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(rank)
        else:
            # any other graph layout: use the NCHW -> NHWC permutation
            permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(rank)

        # Schematic representation of the transformation:
        #
        #   before:  data --> Operation --> data
        #
        #   after:   data --> Permute --> data --> Operation --> data --> Permute --> data
        #                  (permutation.perm)                          (permutation.inv)

        # 1. Insert input Permute
        # This Permute will permute input from original input layout to operation layout
        saved_edge_attrs = graph.get_edge_data(src_data.id, node.id)[0]
        graph.remove_edge(src_data.id, node.id)

        input_permute_op = Permute(graph, {'order': permutation.perm})
        permuted_src = input_permute_op.create_node_with_data([src_data], {'name': node.name + '/Permute_'})
        graph.add_edge(permuted_src.id, node.id, **saved_edge_attrs)

        # 2. Insert output Permute
        # This Permute will permute output from operation layout to original input layout
        saved_edge_attrs = graph.get_edge_data(node.id, dst_data.id)[0]
        graph.remove_edge(node.id, dst_data.id)

        op_out_data = Op.create_data_node(graph, node,
                                          {'shape': dst_data.shape[permutation.perm]},
                                          saved_edge_attrs)
        output_permute_op = Permute(graph, {'order': permutation.inv})
        output_permute_op.create_node_with_data([op_out_data], {'name': node.name + '/Permute_'},
                                                data_nodes=dst_data)

        # 3. Add permutations for Node
        # The permutation mechanism reads the 'permutation' attribute from data nodes, so it is
        # set on the data nodes around the operation, permute_attrs is called to permute the
        # node attributes accordingly, and then the attribute is reset to None.
        for data in (node.in_node(), node.out_node()):
            data['permutation'] = permutation
        node.permute_attrs.permute_attrs(node)
        for data in (node.in_node(), node.out_node()):
            data['permutation'] = None