def extract(cls, node):
    attrs = {
        'op': cls.op,
        'order': node.module.order,
    }
    Transpose.update_node_stat(node, attrs)
    return cls.enabled

def replace_pattern(self, graph: Graph, match: dict):
    node = match['op']
    N, H, W, C = match['in_data'].shape
    block_size = node['block_size']

    graph.remove_edge(match['in_data'].id, node.id)
    graph.remove_edge(node.id, match['out_data'].id)

    dim_6D = int64_array([0, block_size, block_size, int(C / (block_size ** 2)), H, W])
    order_6D = int64_array([0, 3, 4, 1, 5, 2])
    dim_4D = int64_array([0, int(H * block_size), int(W * block_size), int(C / (block_size ** 2))])

    reshape_6_op = Reshape(graph, dict(name=node.id + '/Reshape_to_6D'))
    reshape_6_const_data = Const(graph, dict(value=dim_6D)).create_node_with_data()
    reshape_6_data_node = reshape_6_op.create_node_with_data([match['in_data'], reshape_6_const_data])
    mark_as_correct_data_layout(reshape_6_data_node.in_node(0))

    order_const_data = Const(graph, dict(value=order_6D)).create_node_with_data()
    transpose_op = Transpose(graph, dict(name=node.id + '/Transpose'))
    transpose_data_node = transpose_op.create_node_with_data([reshape_6_data_node, order_const_data])
    mark_as_correct_data_layout(transpose_data_node.in_node(0))

    reshape_4_op = Reshape(graph, dict(name=node.id + '/Reshape_to_4D'))
    reshape_4_const_data = Const(graph, dict(value=dim_4D)).create_node_with_data()
    reshape_4_data_node = reshape_4_op.create_node_with_data([transpose_data_node, reshape_4_const_data],
                                                             data_nodes=[match['out_data']])
    mark_input_as_in_correct_layout(reshape_4_data_node.in_node(0), 0)
    mark_output_as_in_correct_layout(reshape_4_data_node.in_node(0), 0)

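# A minimal NumPy sketch of the same 6D reshape -> transpose -> reshape trick, assuming the
# data is NCHW at runtime (the function name below is illustrative, not part of the pass above).
import numpy as np

def depth_to_space_nchw(x, b):
    n, c, h, w = x.shape
    x = x.reshape(n, b, b, c // (b * b), h, w)  # split C into b*b blocks
    x = x.transpose(0, 3, 4, 1, 5, 2)           # the order_6D = [0, 3, 4, 1, 5, 2] above
    return x.reshape(n, c // (b * b), h * b, w * b)

print(depth_to_space_nchw(np.arange(32).reshape(1, 8, 2, 2), 2).shape)  # (1, 2, 4, 4)
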
def test_transpose_infer_1(self, order):
    graph = self._create_graph_with_transpose(order)
    transpose_node = Node(graph, 'transpose')

    Transpose.infer(transpose_node)

    ref = [transpose_node.in_node().shape[i] for i in order]
    self.assertTrue(np.array_equal(transpose_node.out_node().shape, np.array(ref)))

def extract(node):
    # If the 'perm' attribute is undefined, the ONNX Transpose operation reverses the dimensions
    order = onnx_attr(node, 'perm', 'ints', default=None)
    attrs = {
        'order': int64_array(order) if order is not None else None,
        'reverse_order': order is None
    }
    Transpose.update_node_stat(node, attrs)
    return __class__.enabled

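# Hedged illustration of the 'reverse_order' fallback: an ONNX Transpose without 'perm'
# reverses the axes, which matches NumPy's default transpose behaviour.
import numpy as np

x = np.zeros((2, 3, 5))
assert np.transpose(x).shape == (5, 3, 2)                   # implicit reversed order
assert np.transpose(x, axes=[2, 1, 0]).shape == (5, 3, 2)   # explicit equivalent
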
def test_transpose_infer_2(self):
    order = None
    graph = self._create_graph_with_transpose(order)
    transpose_node = Node(graph, 'transpose')
    transpose_node['reverse_order'] = True

    Transpose.infer(transpose_node)

    ref = np.array([x for x in reversed(transpose_node.in_node().shape)])
    self.assertTrue(np.array_equal(transpose_node.out_node().shape, ref),
                    "Shapes are not the same: {} and {}".format(transpose_node.out_node().shape, ref))

def replace_pattern(self, graph: Graph, match: dict):
    swapaxis = match['op']
    assert len(swapaxis.in_ports()) == 1
    assert swapaxis.has_and_set('order')
    order = swapaxis.order

    swapaxis.add_input_port(1)
    const = Const(graph, {'value': order}).create_node()
    const.out_port(0).connect(swapaxis.in_port(1))

    Transpose.update_node_stat(swapaxis, {'need_shape_inference': True})

    del swapaxis['order']

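# A short NumPy check of why SwapAxis reduces to Transpose: swapping two axes is a
# transpose whose order is the identity permutation with those two axes exchanged.
import numpy as np

x = np.random.rand(2, 3, 4)
a, b = 0, 2
order = list(range(x.ndim))
order[a], order[b] = order[b], order[a]
assert np.array_equal(np.swapaxes(x, a, b), np.transpose(x, order))
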
def add_output_reshape(graph: Graph, match: dict):
    """
    Since the MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions],
    we need to add a reshape from the common format
    [batch_size, num_directions, seq_len, hidden_size] to the MXNet format.
    """
    lstm = match['rnn_layer']
    input = match['input']
    if not lstm.has_num_directions:
        return
    old_data_node = lstm.out_node(0)
    num_directions = 2 if lstm.direction in ['bidirectional'] else 1
    mxnet_shape = lstm.out_node(0).shape.copy()

    if lstm.batch_dim == 0:
        mo_shape = np.array([input.shape[lstm.batch_dim],
                             input.shape[lstm.sequence_dim],
                             lstm.hidden_size], dtype=np.int64)
    else:
        mo_shape = np.array([input.shape[lstm.sequence_dim],
                             input.shape[lstm.batch_dim],
                             lstm.hidden_size], dtype=np.int64)

    if lstm.has_num_directions:
        mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))

    lstm_name = lstm.soft_get('name', lstm.id)

    new_data = Op._create_data_node(graph, name=lstm_name + '/Data/Reshape_mxnet/',
                                    attrs={'shape': mo_shape})
    graph.remove_edge(lstm.id, old_data_node.id)
    graph.add_edge(lstm.id, new_data.id, key=0, out=0)

    # Add Transpose
    permute_order = Const(graph, {'name': lstm_name + '/Transpose_mxnet_order',
                                  'value': int64_array([0, 2, 1, 3])}).create_node_with_data()
    permute_data = Transpose(graph, {'name': lstm_name + '/Transpose_mxnet/'
                                     }).create_node_with_data([new_data, permute_order])

    # Add Reshape
    reshape = Reshape(graph, {'name': lstm_name + '/Reshape_mxnet/'})
    reshape_dim_data = Const(graph, {'name': lstm_name + '/Reshape_mxnet_dim',
                                     'value': mxnet_shape}).create_node_with_data()
    reshape.create_node_with_data([permute_data, reshape_dim_data], dict(), data_nodes=[old_data_node])

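# Sketch of the layout fix-up above in plain NumPy: from the common RNN output layout
# [batch, num_directions, seq_len, hidden] to MXNet's [batch, seq_len, hidden * num_directions],
# via the same transpose order [0, 2, 1, 3] followed by a reshape.
import numpy as np

batch, dirs, seq, hidden = 2, 2, 5, 8
y = np.random.rand(batch, dirs, seq, hidden)
y_mxnet = y.transpose(0, 2, 1, 3).reshape(batch, seq, dirs * hidden)
print(y_mxnet.shape)  # (2, 5, 16)
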
def insert_transpose(node, in_port_idx):
    graph = node.graph
    name = node.soft_get('name', node.id)

    assert in_port_idx in node.in_ports() and not node.in_port(in_port_idx).disconnected(), \
        'Input port with index {} should be connected for node {}'.format(in_port_idx, name)

    in_port = node.in_port(in_port_idx)
    port_shape = in_port.data.get_shape()
    assert port_shape is not None, \
        'Shape is unknown for input port with index {} for node {}'.format(in_port_idx, name)

    # identity permutation with the two innermost axes swapped
    transpose_order = list(range(port_shape.size))
    transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

    order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
    transpose = Transpose(graph, {'name': name + '/{}_port_transpose'.format(in_port_idx)}).create_node()

    port_source = in_port.get_source()
    in_port.get_connection().set_source(transpose.out_port(0))
    transpose.in_port(0).connect(port_source)
    transpose.in_port(1).connect(order.out_port(0))

    transpose['override_output_shape'] = True

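# A minimal NumPy check of the order built above: swapping only the last two axes
# transposes each innermost matrix while leaving the leading (batch) dims intact.
import numpy as np

x = np.random.rand(2, 3, 4, 5)
order = list(range(x.ndim))
order[-1], order[-2] = order[-2], order[-1]
assert np.array_equal(x.transpose(order)[0, 1], x[0, 1].T)  # result shape (2, 3, 5, 4)
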
def replace_sub_graph(self, graph: Graph, match: dict):
    target_node = match['target_node']
    nodes_with_weights = self.dfs(graph, target_node.name,
                                  ('Convolution', 'FullyConnected', 'ScaleShift'), True)
    convolution_nodes = [node for node in nodes_with_weights if Node(graph, node).op == 'Convolution']
    for convolution_node in convolution_nodes:
        target_node = self.search_target_node(Node(graph, convolution_node))
        order_const = Const(graph, dict(value=np.array([0, 3, 2, 1]))).create_node()
        permute_node = Transpose(graph, dict(name=target_node.name + '/Transpose')).create_node()
        target_node.insert_node_after(permute_node, 0)
        order_const.out_port(0).connect(permute_node.in_port(1))

def find_and_replace_pattern(self, graph: Graph):
    if graph.graph['layout'] != 'NHWC':
        # we check it here because this transformation is called explicitly from the pipeline
        return

    # reshape from 4D-5D -> ND: insert Transpose(NC(D)HW -> N(D)HWC) before the Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.in_nodes(), 'Node {} does not have 0 input. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        input_shape = reinterp_shape_node.in_node(0).shape
        if not is_input_data_in_correct_layout(reinterp_shape_node, 0) and len(input_shape) >= 4:
            order_const = Const(graph, {
                'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm}).create_node()
            permute_node = Transpose(graph, {
                'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'}).create_node()
            reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
            order_const.out_port(0).connect(permute_node.in_port(1))
            order_const.infer(order_const)

            # do not infer the Transpose node because it should have its input data node in NCHW layout
            # (but currently it is NHWC because the data node attributes have not been permuted yet)
            # and produce output in NHWC layout (which is true at this moment)
            permute_node['need_shape_inference'] = False
            # mark the Transpose output data node as having the correct layout so its shape will not be permuted
            mark_output_as_in_correct_layout(permute_node, 0)

            # keep the reinterp_shape_node in NHWC layout
            mark_input_as_in_correct_layout(reinterp_shape_node, 0)
            mark_input_as_in_correct_layout(reinterp_shape_node, 1)

    # reshape from ND -> 4D-5D: insert Transpose(N(D)HWC -> NC(D)HW) after the Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.out_nodes(), 'Node {} does not have 0 output. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        output_shape = reinterp_shape_node.out_node(0).shape
        if not is_output_data_in_correct_layout(reinterp_shape_node, 0) and len(output_shape) >= 4:
            order_const = Const(graph, {
                'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
            permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
            reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
            order_const.out_port(0).connect(permute_node.in_port(1))

            # the Reshape and Transpose operations should work in the original (NHWC) layout,
            # so the Transpose will convert it to NCHW
            mark_input_as_in_correct_layout(permute_node, 0)
            mark_input_as_in_correct_layout(permute_node, 1)
            # do not set the Transpose output data node 'correct_data_layout' attribute,
            # so the data node shape will be permuted

            # keep the reinterp_shape_node in NHWC layout
            mark_output_as_in_correct_layout(reinterp_shape_node, 0)
            mark_input_as_in_correct_layout(reinterp_shape_node, 1)

            # do not re-infer the Transpose node because its output data node should stay in NHWC layout
            # to keep the rest of the graph consistent
            permute_node['need_shape_inference'] = False

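# Hedged sketch of the permutation vectors used above for ranks 4 and 5. The helper names
# are illustrative and assume the results match what PermuteAttrs returns: the channel axis
# moves between position 1 (NCHW) and the last position (NHWC).
import numpy as np

def nchw_to_nhwc_perm(rank):
    return np.array([0] + list(range(2, rank)) + [1], dtype=np.int64)

def nhwc_to_nchw_perm(rank):
    return np.array([0, rank - 1] + list(range(1, rank - 1)), dtype=np.int64)

print(nchw_to_nhwc_perm(4))  # [0 2 3 1]
print(nhwc_to_nchw_perm(5))  # [0 4 1 2 3]
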
def extract(cls, node):
    Transpose.update_node_stat(node, {'order': None})
    return cls.enabled

def extract(node):
    Transpose.update_node_stat(node, {'order': None})
    return __class__.enabled

def find_and_replace_pattern(self, graph: Graph):
    for node in list(graph.nodes()):
        node = Node(graph, node)
        # Check that the node layout mismatches the graph layout,
        # for example: NHWC vs NCHW, or NCDHW vs NDHWC
        if node.kind == 'op' and node.has_valid('layout') and \
                node.layout != indices_mapping[len(node.layout)][graph.graph['layout']]:
            input = node.in_node()
            output = node.out_node()

            # Calculate the permutation for the Transpose operations inserted below
            if graph.graph['layout'] == 'NCHW':
                # the node has NHWC layout while the graph has NCHW layout
                permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.layout))
            else:
                # the node has NCHW layout while the graph has NHWC layout
                permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(len(node.layout))

            # Schematic representation of the transformation:
            #
            #            NHWC
            #   data-->Convolution(example)-->data
            #
            # becomes
            #
            #        permutation                             inverse permutation
            #   data-->Transpose-->data-->Convolution-->data-->Transpose-->data

            # 1. Insert the input Transpose
            #    It permutes the input from the original input layout to the operation layout
            edge_attrs = graph.get_edge_data(input.id, node.id)[0]
            graph.remove_edge(input.id, node.id)

            input_order_const = Const(graph, {'value': permutation.perm}).create_node_with_data()
            input_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_'))
            input_permute_data_node = input_permute_op.create_node_with_data([input, input_order_const])

            graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)

            # 2. Insert the output Transpose
            #    It permutes the output from the operation layout back to the original input layout
            edge_attrs = graph.get_edge_data(node.id, output.id)[0]
            graph.remove_edge(node.id, output.id)

            input_data_node = Op.create_data_node(graph, node, {'shape': output.shape[permutation.perm]},
                                                  edge_attrs)

            output_order_const = Const(graph, {'value': permutation.inv}).create_node_with_data()
            output_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_')).create_node_with_data(
                [input_data_node, output_order_const], data_nodes=output)

            # 3. Add permutations for the node attributes.
            #    Here we use the permutation mechanism where data nodes take a 'permutation' attribute,
            #    and then call the permute_attrs method that permutes the node attributes according to
            #    the permutations on the data nodes.
            node.in_node()['permutation'] = permutation
            node.out_node()['permutation'] = permutation
            node.permute_attrs.permute_attrs(node)

            node.in_node()['permutation'] = None
            node.out_node()['permutation'] = None

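# Small NumPy check of the perm / inverse-perm pair used above: transposing by `perm`
# and then by its inverse restores the original layout.
import numpy as np

perm = np.array([0, 3, 1, 2])       # e.g. NHWC -> NCHW
inv = np.empty_like(perm)
inv[perm] = np.arange(perm.size)    # inverse permutation: [0, 2, 3, 1]
x = np.random.rand(1, 5, 6, 3)
assert np.array_equal(x, x.transpose(perm).transpose(inv))
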
def replace_pattern(graph: Graph, match: dict):
    node = match['matmul']
    name = node.soft_get('name', node.id)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()
    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
            and B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            transpose_order = list(range(B_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(weights_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        if node.in_port(1).data.get_shape().size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(weights_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            transpose_order = list(range(A_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(input_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        if A_shape.size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(input_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)

            node.infer(node)
            const.infer(const)
            reshape.infer(reshape)
    else:
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size

        if B_shape.size != out_shape.size:
            unsqueeze_dim = Const(graph, {
                'value': int64_array(list(range(out_shape.size - B_shape.size)))}).create_node()
            unsqueeze = Unsqueeze(graph, {}).create_node()

            B_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
            unsqueeze.in_port(0).connect(B_source)
            unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

            unsqueeze_dim.infer(unsqueeze_dim)
            unsqueeze.infer(unsqueeze)

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })

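# Hedged NumPy sketch of the shape algebra in the pass above: once the weights are 2D in
# the FullyConnected OI layout ([O, K]), a batched MatMul over [B, I, K] equals a 2D
# FullyConnected applied to the input flattened to [B * I, K].
import numpy as np

B, I, K, O = 4, 3, 5, 7
A = np.random.rand(B, I, K)
W = np.random.rand(O, K)                            # FullyConnected weight layout OI
matmul_out = A @ W.T                                # (B, I, O)
fc_out = (A.reshape(-1, K) @ W.T).reshape(B, I, O)  # flatten, multiply, restore batch dims
assert np.allclose(matmul_out, fc_out)
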
def replace_op(self, graph: Graph, node: Node):
    pb = node.parameters
    weights_size = read_binary_integer32_token(pb)
    weights = read_blob(pb, weights_size, dtype=np.int32) - 1

    const_attrs = {
        'name': 'indexes/{}'.format(node.id),
        'value': np.array(weights),
        'shape': [weights_size],
        'data_type': np.int32
    }
    indexes_node = Const(graph).create_node(attrs=const_attrs)

    perm_in_1 = Const(graph, {'value': np.array([1, 0], dtype=np.int64),
                              'shape': [2],
                              'data_type': np.int64}).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()

    perm1_node = Transpose(graph, {'name': 'input_permute'}).create_node([node.in_node(0)])
    perm1_node.in_port(0).connect(node.in_port(0).get_source())
    perm1_node.in_port(1).connect(perm_in_1.out_port(0))

    gather_node = Gather(graph, {}).create_node()
    gather_node.in_port(0).connect(perm1_node.out_port(0))
    gather_node.in_port(1).connect(indexes_node.out_port(0))
    gather_node.in_port(2).connect(axis_const.out_port(0))

    perm2_node = Transpose(graph, {'name': 'output_permute'}).create_node()
    perm2_node.in_port(0).connect(gather_node.out_port(0))
    perm2_node.in_port(1).connect(perm_in_1.out_port(0))

    return [perm2_node.id]

def extract(cls, node):
    order = node.pb.permute_param.order
    Transpose.update_node_stat(node, {'order': np.array(order, dtype=np.int32)})
    return cls.enabled

def replace_op(self, graph: Graph, node: Node):
    pb = node.parameters
    weights_size = read_binary_integer32_token(pb)
    weights = read_blob(pb, weights_size, dtype=np.int32) - 1
    node_name = node.soft_get('name', node.id)

    const_attrs = {
        'name': node_name + '/indexes',
        'value': np.array(weights),
        'shape': [weights_size],
        'data_type': np.int32
    }
    indexes_node = Const(graph).create_node(attrs=const_attrs)

    perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node()

    perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node([node.in_node(0)])
    perm1_node.in_port(0).connect(node.in_port(0).get_source())
    perm1_node.in_port(1).connect(perm_in_1.out_port(0))

    gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
                                              {'name': node_name + '/gather'})
    gather_node.in_port(0).connect(perm1_node.out_port(0))
    gather_node.in_port(1).connect(indexes_node.out_port(0))

    perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node()
    perm2_node.in_port(0).connect(gather_node.out_port(0))
    perm2_node.in_port(1).connect(perm_in_1.out_port(0))

    return [perm2_node.id]

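# A small NumPy sketch of the Transpose -> Gather -> Transpose pattern above: gathering
# along axis 0 of the transposed 2D input selects columns, which is how Kaldi's Copy
# layer reorders feature columns (its index blob is 1-based, hence the -1 above).
import numpy as np

x = np.arange(12).reshape(3, 4)
idx = np.array([2, 0, 3])
out = np.take(x.T, idx, axis=0).T
assert np.array_equal(out, x[:, idx])
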
def extract(node):
    attrs = get_mxnet_layer_attrs(node.symbol_dict)
    order = list(attrs.tuple("axes", int, None))
    Transpose.update_node_stat(node, {'order': np.array(order, dtype=np.int32)})
    return __class__.enabled