def replace_op(self, graph: Graph, node: Node):
    pb = node.parameters
    weights_size = read_binary_integer32_token(pb)
    # Kaldi Copy layer stores 1-based indices; convert them to 0-based
    weights = read_blob(pb, weights_size, dtype=np.int32) - 1
    node_name = node.soft_get('name', node.id)

    const_attrs = {
        'name': node_name + '/indexes',
        'value': np.array(weights),
        'shape': [weights_size],
        'data_type': np.int32
    }
    indexes_node = Const(graph).create_node(attrs=const_attrs)

    # implement Copy as Transpose -> Gather(axis=0) -> Transpose
    perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node()
    perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node()
    perm1_node.in_port(0).connect(node.in_port(0).get_source())
    perm1_node.in_port(1).connect(perm_in_1.out_port(0))

    gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
                                              {'name': node_name + '/gather'})
    gather_node.in_port(0).connect(perm1_node.out_port(0))
    gather_node.in_port(1).connect(indexes_node.out_port(0))

    perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node()
    perm2_node.in_port(0).connect(gather_node.out_port(0))
    perm2_node.in_port(1).connect(perm_in_1.out_port(0))

    return [perm2_node.id]
def replace_op(self, graph: Graph, node: Node):
    pb = node.parameters
    weights_size = read_binary_integer32_token(pb)
    # Kaldi Copy layer stores 1-based indices; convert them to 0-based
    weights = read_blob(pb, weights_size, dtype=np.int32) - 1

    const_attrs = {
        'name': 'indexes/{}'.format(node.id),
        'value': np.array(weights),
        'shape': [weights_size],
        'data_type': np.int32
    }
    indexes_node = Const(graph).create_node(attrs=const_attrs)

    perm_in_1 = Const(graph, {
        'value': np.array([1, 0], dtype=np.int64),
        'shape': [2],
        'data_type': np.int64
    }).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()

    # implement Copy as Transpose -> Gather(axis=0) -> Transpose
    perm1_node = Transpose(graph, {'name': 'input_permute'}).create_node()
    perm1_node.in_port(0).connect(node.in_port(0).get_source())
    perm1_node.in_port(1).connect(perm_in_1.out_port(0))

    gather_node = Gather(graph, {}).create_node()
    gather_node.in_port(0).connect(perm1_node.out_port(0))
    gather_node.in_port(1).connect(indexes_node.out_port(0))
    gather_node.in_port(2).connect(axis_const.out_port(0))

    perm2_node = Transpose(graph, {'name': 'output_permute'}).create_node()
    perm2_node.in_port(0).connect(gather_node.out_port(0))
    perm2_node.in_port(1).connect(perm_in_1.out_port(0))

    return [perm2_node.id]
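# For reference, a minimal NumPy sketch (not part of either replacer above) of what
# the emitted Transpose -> Gather(axis=0) -> Transpose chain computes on a 2D input.
# The data values and the 1-based Kaldi index blob below are made up for illustration.
import numpy as np

data = np.arange(12, dtype=np.float32).reshape(3, 4)    # hypothetical [batch, features] input
kaldi_indexes = np.array([2, 1, 4, 3], dtype=np.int32)  # hypothetical 1-based index blob

indexes = kaldi_indexes - 1     # 0-based, mirroring `read_blob(...) - 1` above
transposed = data.T             # input_permute: [features, batch]
gathered = transposed[indexes]  # Gather along axis 0 selects rows == original columns
result = gathered.T             # output_permute: back to [batch, features]

# the whole chain is equivalent to permuting the feature columns directly
assert np.array_equal(result, data[:, indexes])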
def insert_transpose(node, in_port_idx):
    graph = node.graph
    name = node.soft_get('name', node.id)

    assert in_port_idx in node.in_ports() and not node.in_port(in_port_idx).disconnected(), \
        'Input port with index {} should be connected for node {}'.format(in_port_idx, name)

    in_port = node.in_port(in_port_idx)
    port_shape = in_port.data.get_shape()
    assert port_shape is not None, \
        'Shape is unknown for input port with index {} for node {}'.format(in_port_idx, name)

    # build a permutation that swaps the two innermost dimensions
    transpose_order = list(range(port_shape.size))
    transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

    order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
    transpose = Transpose(graph, {'name': name + '/{}_port_transpose'.format(in_port_idx)}).create_node()

    # splice the Transpose between the original producer and the consuming port
    port_source = in_port.get_source()
    in_port.get_connection().set_source(transpose.out_port(0))
    transpose.in_port(0).connect(port_source)
    transpose.in_port(1).connect(order.out_port(0))

    transpose['override_output_shape'] = True
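# A small sketch of the order computed by insert_transpose, assuming a rank-4
# input; the shape here is illustrative only.
import numpy as np

x = np.zeros((2, 3, 4, 5))

order = list(range(x.ndim))
order[-1], order[-2] = order[-2], order[-1]  # swap the two innermost dimensions
assert order == [0, 1, 3, 2]
assert np.transpose(x, order).shape == (2, 3, 5, 4)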
def replace_pattern(graph: Graph, match: dict):
    node = match['matmul']
    name = node.soft_get('name', node.id)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()
    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
            and B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            transpose_order = list(range(B_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(weights_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        if node.in_port(1).data.get_shape().size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(weights_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            transpose_order = list(range(A_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(input_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        if A_shape.size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(input_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)

            node.infer(node)
            const.infer(const)
            reshape.infer(reshape)
    else:
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size

        if B_shape.size != out_shape.size:
            unsqueeze_dim = Const(graph, {
                'value': int64_array(list(range(out_shape.size - B_shape.size)))
            }).create_node()
            unsqueeze = Unsqueeze(graph, {}).create_node()

            B_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
            unsqueeze.in_port(0).connect(B_source)
            unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

            unsqueeze_dim.infer(unsqueeze_dim)
            unsqueeze.infer(unsqueeze)

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })
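# A minimal NumPy sketch of the shape algebra the MatMul -> FullyConnected branch
# relies on; the dimensions are made up. It checks that transposing the weights to
# OI layout and flattening the input to [B * I, K] reproduces the MatMul result.
import numpy as np

B, I, K, O = 2, 3, 4, 5
A = np.random.rand(B, I, K).astype(np.float32)  # MatMul input, [B, I, K]
W = np.random.rand(K, O).astype(np.float32)     # MatMul weights, IO layout [K, O]

matmul_out = A @ W                               # [B, I, K] * [K, O] = [B, I, O]

weights = W.T                                    # weights_transpose: OI layout [O, K]
inp = A.reshape(-1, K)                           # input_reshape: [B * I, K]
fc_out = (inp @ weights.T).reshape(B, I, O)      # FullyConnected + output_reshape

assert np.allclose(matmul_out, fc_out, atol=1e-6)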