def test_transpose_infer_1(self, order):
    """Check that Transpose.infer permutes the input shape by an explicit ``order``.

    ``order`` is the permutation passed to the graph builder; presumably it is
    supplied by a parametrizing decorator that is outside this view.
    """
    graph = self._create_graph_with_transpose(order)
    transpose_node = Node(graph, 'transpose')

    Transpose.infer(transpose_node)

    # Expected output shape: the input dims picked in ``order``.
    in_shape = transpose_node.in_node().shape
    ref = np.array([in_shape[i] for i in order])
    out_shape = transpose_node.out_node().shape
    # Same diagnostic message as the reverse-order test, so failures show both shapes.
    self.assertTrue(np.array_equal(out_shape, ref),
                    "Shapes are not the same: {} and {}".format(out_shape, ref))
def test_transpose_infer_2(self):
    """Check that Transpose.infer with no order and ``reverse_order`` set reverses the input shape."""
    order = None
    graph = self._create_graph_with_transpose(order)
    transpose_node = Node(graph, 'transpose')
    transpose_node['reverse_order'] = True

    Transpose.infer(transpose_node)

    # Expected output shape: the input dims in reverse order.
    ref = np.array(list(reversed(transpose_node.in_node().shape)))
    out_shape = transpose_node.out_node().shape
    self.assertTrue(np.array_equal(out_shape, ref),
                    "Shapes are not the same: {} and {}".format(out_shape, ref))
def replace_pattern(graph: Graph, match: dict):
    """Rewrite a matched MatMul node as either FullyConnected or Gemm.

    If the second input (port 1) has a constant value, or its producer has
    ``stop_value_propagation`` set, and that input has at most two dimensions
    not equal to 1, the node becomes FullyConnected:

    * the weights (port 1) are transposed when ``transpose_b`` is not set;
    * the weights are reshaped to ``[-1, K]`` when they are not 2-D, and must
      end up as ``[O, K]``;
    * the data input (port 0) is transposed when ``transpose_a`` is set;
    * the data input is reshaped to ``[-1, K]`` when it is not 2-D, and must
      end up as ``[prod(B) * I, K]``;
    * an output Reshape to ``[*B, I, O]`` is appended when the original output
      rank is not 2.

    Otherwise the second input is unsqueezed at leading axes until its rank
    matches the output rank, and the node is turned into Gemm with its
    transpose flags preserved.

    :param graph: graph that owns the matched node; helper nodes are created in it.
    :param match: pattern-match dict; the MatMul node is stored under ``'matmul'``.
    """
    node = match['matmul']
    name = node.soft_get('name', node.id)

    def _swapped_last_two_axes(rank):
        """Return the permutation ``[0, ..., rank-1]`` with its last two entries swapped."""
        perm = list(range(rank))
        perm[-1], perm[-2] = perm[-2], perm[-1]
        return perm

    def _insert_before_input(port_idx, op_cls, op_attrs, const_value):
        """Splice a new ``op_cls`` node between ``node``'s input ``port_idx`` and its producer.

        The original producer feeds port 0 of the new op. A new Const holding
        ``int64_array(const_value)`` feeds port 1. Both new nodes are inferred
        immediately, Const first, so later shape reads on ``port_idx`` see the result.
        """
        const = Const(graph, {'value': int64_array(const_value)}).create_node()
        op = op_cls(graph, op_attrs).create_node()

        source = node.in_port(port_idx).get_source()
        node.in_port(port_idx).get_connection().set_source(op.out_port(0))
        op.in_port(0).connect(source)
        op.in_port(1).connect(const.out_port(0))

        const.infer(const)
        op.infer(op)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()

    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
            and B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, _, _ = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            _insert_before_input(1, Transpose, {'name': name + '/weights_transpose'},
                                 _swapped_last_two_axes(B_shape.size))

        # Re-read the shape: it may have changed after the transpose above.
        if node.in_port(1).data.get_shape().size != 2:
            _insert_before_input(1, Reshape, {'name': name + '/weights_reshape'}, [-1, K])

        assert np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K])), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            _insert_before_input(0, Transpose, {'name': name + '/input_transpose'},
                                 _swapped_last_two_axes(A_shape.size))

        if A_shape.size != 2:
            _insert_before_input(0, Reshape, {'name': name + '/input_reshape'}, [-1, K])

        assert np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K])), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization: reshape the 2-D result back to [*B, I, O]
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)

            # Re-infer the rewritten node first, so the new Reshape sees its output shape.
            node.infer(node)

            const.infer(const)
            reshape.infer(reshape)
    else:
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size
        if B_shape.size != out_shape.size:
            # Unsqueeze the second input at leading axes 0..(diff-1) so its rank
            # matches the output rank. The Unsqueeze node is intentionally left unnamed.
            _insert_before_input(1, Unsqueeze, {}, list(range(out_shape.size - B_shape.size)))

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })