def insert_transpose(node, in_port_idx):
    """Insert a Transpose that swaps the two innermost dimensions in front of the given input port."""
    graph = node.graph
    name = node.soft_get('name', node.id)

    assert in_port_idx in node.in_ports() and not node.in_port(in_port_idx).disconnected(), \
        'Input port with index {} should be connected for node {}'.format(in_port_idx, name)

    in_port = node.in_port(in_port_idx)
    port_shape = in_port.data.get_shape()
    assert port_shape is not None, \
        'Shape is unknown for input port with index {} for node {}'.format(in_port_idx, name)

    transpose_order = list(range(port_shape.size))
    transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

    order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
    transpose = Transpose(graph, {'name': name + '/{}_port_transpose'.format(in_port_idx)}).create_node()

    port_source = in_port.get_source()
    in_port.get_connection().set_source(transpose.out_port(0))
    transpose.in_port(0).connect(port_source)
    transpose.in_port(1).connect(order.out_port(0))

    transpose['override_output_shape'] = True
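
# A minimal numpy sketch (illustration only, not part of the transformation above) of the
# transpose order built by insert_transpose: for a rank-N input the order swaps the last
# two axes, which matches what numpy's transpose does with the same permutation.
import numpy as np

x = np.arange(24).reshape(2, 3, 4)           # rank-3 input, shape [2, 3, 4]
order = list(range(x.ndim))                  # [0, 1, 2]
order[-1], order[-2] = order[-2], order[-1]  # [0, 2, 1] - swap the last two axes
assert np.transpose(x, order).shape == (2, 4, 3)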
def find_and_replace_pattern(self, graph: Graph):
    if graph.graph['layout'] != 'NHWC':
        # we check it here because this transformation is called explicitly from the pipeline
        return

    # reshape from 4D-5D -> ND. Insert Transpose(NC(D)HW->N(D)HWC) before Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.in_nodes(), 'Node {} does not have 0 input. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        input_shape = reinterp_shape_node.in_node(0).shape
        if not is_input_data_in_correct_layout(reinterp_shape_node, 0) and len(input_shape) >= 4:
            order_const = Const(graph, {
                'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm}).create_node()
            permute_node = Transpose(graph, {
                'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'}).create_node()
            reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
            order_const.out_port(0).connect(permute_node.in_port(1))
            order_const.infer(order_const)

            # do not infer the Transpose node because it should have its input data node in NCHW layout (but
            # currently it is NHWC because the data node attributes have not been permuted yet) and produce
            # output in NHWC layout (which is true at this moment)
            permute_node['need_shape_inference'] = False
            # mark the Transpose output data node as having the correct layout so its shape will not be permuted
            mark_output_as_in_correct_layout(permute_node, 0)

            # keep the reinterp_shape_node in NHWC layout
            mark_input_as_in_correct_layout(reinterp_shape_node, 0)
            mark_input_as_in_correct_layout(reinterp_shape_node, 1)

    # reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.out_nodes(), 'Node {} does not have 0 output. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        output_shape = reinterp_shape_node.out_node(0).shape
        if not is_output_data_in_correct_layout(reinterp_shape_node, 0) and len(output_shape) >= 4:
            order_const = Const(graph, {
                'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
            permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
            reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
            order_const.out_port(0).connect(permute_node.in_port(1))

            # the Reshape and Transpose operations should work in the original (NHWC) layout, so the Transpose
            # will convert it to NCHW
            mark_input_as_in_correct_layout(permute_node, 0)
            mark_input_as_in_correct_layout(permute_node, 1)
            # do not set the Transpose output data node 'correct_data_layout' attribute so the data node shape
            # will be permuted

            # keep the reinterp_shape_node in NHWC layout
            mark_output_as_in_correct_layout(reinterp_shape_node, 0)
            mark_input_as_in_correct_layout(reinterp_shape_node, 1)

            # do not re-infer the Transpose node because its output data node should be in NHWC layout to make
            # the rest of the graph consistent
            permute_node['need_shape_inference'] = False
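
# A hedged numpy sketch of the two permutations the pass inserts around a Reshape
# (assuming the usual rank-4 orders: NCHW->NHWC is [0, 2, 3, 1], NHWC->NCHW is
# [0, 3, 1, 2]); it shows that the Transpose pair is mutually inverse, so the Reshape
# in between sees data in its original layout.
import numpy as np

nchw = np.random.rand(1, 3, 8, 8)        # N, C, H, W
nhwc = np.transpose(nchw, (0, 2, 3, 1))  # NCHW -> NHWC, shape (1, 8, 8, 3)
back = np.transpose(nhwc, (0, 3, 1, 2))  # NHWC -> NCHW round trip
assert np.array_equal(back, nchw)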
def replace_op(self, graph: Graph, node: Node):
    pb = node.parameters
    weights_size = read_binary_integer32_token(pb)
    weights = read_blob(pb, weights_size, dtype=np.int32) - 1

    node_name = node.soft_get('name', node.id)
    const_attrs = {
        'name': node_name + '/indexes',
        'value': np.array(weights),
        'shape': [weights_size],
        'data_type': np.int32
    }
    indexes_node = Const(graph).create_node(attrs=const_attrs)

    perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node()
    perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node([node.in_node(0)])
    perm1_node.in_port(0).connect(node.in_port(0).get_source())
    perm1_node.in_port(1).connect(perm_in_1.out_port(0))

    gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
                                              {'name': node_name + '/gather'})
    gather_node.in_port(0).connect(perm1_node.out_port(0))
    gather_node.in_port(1).connect(indexes_node.out_port(0))

    perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node()
    perm2_node.in_port(0).connect(gather_node.out_port(0))
    perm2_node.in_port(1).connect(perm_in_1.out_port(0))

    return [perm2_node.id]
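
# A numpy sketch (illustration, not the MO API) of the Transpose-Gather-Transpose
# pattern built above: gathering rows of the transposed input along axis 0 and
# transposing back is the same as selecting columns of the original input.
import numpy as np

x = np.arange(12).reshape(3, 4)
indexes = np.array([2, 0, 3])
pattern = np.take(x.T, indexes, axis=0).T  # Transpose -> Gather(axis=0) -> Transpose
assert np.array_equal(pattern, x[:, indexes])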
def replace_op(self, graph: Graph, node: Node):
    pb = node.parameters
    weights_size = read_binary_integer32_token(pb)
    weights = read_blob(pb, weights_size, dtype=np.int32) - 1

    const_attrs = {
        'name': 'indexes/{}'.format(node.id),
        'value': np.array(weights),
        'shape': [weights_size],
        'data_type': np.int32
    }
    indexes_node = Const(graph).create_node(attrs=const_attrs)

    perm_in_1 = Const(graph, {
        'value': np.array([1, 0], dtype=np.int64),
        'shape': [2],
        'data_type': np.int64
    }).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()

    perm1_node = Transpose(graph, {'name': 'input_permute'}).create_node([node.in_node(0)])
    perm1_node.in_port(0).connect(node.in_port(0).get_source())
    perm1_node.in_port(1).connect(perm_in_1.out_port(0))

    gather_node = Gather(graph, {}).create_node()
    gather_node.in_port(0).connect(perm1_node.out_port(0))
    gather_node.in_port(1).connect(indexes_node.out_port(0))
    gather_node.in_port(2).connect(axis_const.out_port(0))

    perm2_node = Transpose(graph, {'name': 'output_permute'}).create_node()
    perm2_node.in_port(0).connect(gather_node.out_port(0))
    perm2_node.in_port(1).connect(perm_in_1.out_port(0))

    return [perm2_node.id]
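
# Sketch of the 1-based -> 0-based index shift done by `read_blob(...) - 1` in both
# variants above (assuming, as the subtraction suggests, that the Kaldi blob stores
# 1-based row indices).
import numpy as np

kaldi_indexes = np.array([1, 3, 2], dtype=np.int32)  # 1-based, as stored in the model
zero_based = kaldi_indexes - 1                       # [0, 2, 1], usable with Gather(axis=0)
rows = np.arange(12, dtype=np.float32).reshape(4, 3)
selected = np.take(rows, zero_based, axis=0)         # what the Gather node computes
assert np.array_equal(zero_based, [0, 2, 1]) and selected.shape == (3, 3)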
def replace_sub_graph(self, graph: Graph, match: dict):
    target_node = match['target_node']
    nodes_with_weights = self.dfs(graph, target_node.name,
                                  ('Convolution', 'FullyConnected', 'ScaleShift'), True)
    convolution_nodes = [node for node in nodes_with_weights if Node(graph, node).op == 'Convolution']
    for convolution_node in convolution_nodes:
        target_node = self.search_target_node(Node(graph, convolution_node))
        order_const = Const(graph, dict(value=np.array([0, 3, 2, 1]))).create_node()
        permute_node = Transpose(graph, dict(name=target_node.name + '/Transpose')).create_node()
        target_node.insert_node_after(permute_node, 0)
        order_const.out_port(0).connect(permute_node.in_port(1))
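
# A small numpy sketch of the permutation inserted above (illustration only):
# order [0, 3, 2, 1] swaps axes 1 and 3, e.g. turning an NCHW blob into NWHC.
import numpy as np

x = np.random.rand(1, 2, 3, 4)
assert np.transpose(x, (0, 3, 2, 1)).shape == (1, 4, 3, 2)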
def replace_pattern(graph: Graph, match: dict):
    node = match['matmul']
    name = node.soft_get('name', node.id)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()
    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
            and B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            transpose_order = list(range(B_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(weights_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        if node.in_port(1).data.get_shape().size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(weights_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            transpose_order = list(range(A_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(input_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        if A_shape.size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(input_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)

            node.infer(node)
            const.infer(const)
            reshape.infer(reshape)
    else:
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size
        if B_shape.size != out_shape.size:
            unsqueeze_dim = Const(graph, {
                'value': int64_array(list(range(out_shape.size - B_shape.size)))
            }).create_node()
            unsqueeze = Unsqueeze(graph, {}).create_node()

            B_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
            unsqueeze.in_port(0).connect(B_source)
            unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

            unsqueeze_dim.infer(unsqueeze_dim)
            unsqueeze.infer(unsqueeze)

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })
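
# A hedged numpy check of the reshape algebra used in the constant-weights branch above:
# a batched MatMul [B, I, K] with weights in the FullyConnected OK layout collapses to a
# 2D multiply [B*I, K] x [K, O], and the result is reshaped back to [*B, I, O]. The
# names B, I, K, O follow the comments in the transformation.
import numpy as np

B, I, K, O = (2,), 3, 4, 5
A = np.random.rand(*B, I, K)
W_OK = np.random.rand(O, K)  # FullyConnected weights layout is OI (here OK)

matmul_out = A @ W_OK.T                                  # MatMul view: [2, 3, 5]
fc_out = (A.reshape(-1, K) @ W_OK.T).reshape(*B, I, O)   # FullyConnected + output reshape
assert np.allclose(matmul_out, fc_out)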