def test_unsqueeze_infer_negative_indices(self):
    """Unsqueeze with a negative axis (-1) on a 4-D input.

    The reference graph expects the axis to be normalized to the positive
    index 4 and the output shape to become [2, 3, 64, 64, 1].
    """
    dims = np.array([-1])
    edges = [
        ('data_1', 'unsq'),
        ('unsq_dims_const', 'unsq_dims'),
        ('unsq_dims', 'unsq'),
        ('unsq', 'data_2'),
    ]

    graph = build_graph(self.nodes_attributes, edges, {
        'data_1': {'shape': np.array([2, 3, 64, 64])},
        'unsq_dims': {'value': dims, 'shape': dims.shape},
        'unsq_dims_const': {'value': dims, 'shape': dims.shape},
    })
    graph_ref = build_graph(self.nodes_attributes, edges, {
        'data_1': {'shape': np.array([2, 3, 64, 64])},
        'unsq_dims': {'value': int64_array([4]), 'shape': dims.shape},
        'unsq_dims_const': {'value': int64_array([4]), 'shape': dims.shape},
        'data_2': {'shape': np.array([2, 3, 64, 64, 1])},
    })

    Unsqueeze.infer(Node(graph, 'unsq'))

    flag, resp = compare_graphs(graph, graph_ref, 'data_2')
    self.assertTrue(flag, resp)
def test_unsqueeze_infer(self):
    """Unsqueeze driven by the node's 'unsqueeze_dims' attribute ([0, 4]).

    The reference graph expects [1, 3, 64, 64] to become
    [1, 1, 3, 64, 1, 64] after inference.
    """
    edges = [('data_1', 'unsq'), ('unsq', 'data_2')]

    graph = build_graph(self.nodes_attributes, edges, {
        'data_1': {'shape': np.array([1, 3, 64, 64])},
        'unsq': {'unsqueeze_dims': np.array([0, 4])},
    })
    graph_ref = build_graph(self.nodes_attributes, edges, {
        'data_1': {'shape': np.array([1, 3, 64, 64])},
        'unsq': {'unsqueeze_dims': np.array([0, 4])},
        'data_2': {'shape': np.array([1, 1, 3, 64, 1, 64])},
    })

    Unsqueeze.infer(Node(graph, 'unsq'))

    flag, resp = compare_graphs(graph, graph_ref, 'data_2')
    self.assertTrue(flag, resp)
def test_unsqueeze_infer(self, input_shape, unsq_dims, output_shape, ref_uns_dims, input_value, output_value):
    """Parameterized Unsqueeze inference check.

    Builds an actual graph with raw `unsq_dims` and a reference graph with
    the normalized `ref_uns_dims` plus the expected output shape/value,
    runs Unsqueeze.infer, then compares graphs and strictly compares the
    output tensors.
    """
    edges = [
        ('data_1', 'unsq'),
        ('unsq_dims_const', 'unsq_dims'),
        ('unsq_dims', 'unsq'),
        ('unsq', 'data_2'),
    ]

    graph = build_graph(self.nodes_attributes, edges, {
        'data_1': {'shape': input_shape, 'value': input_value},
        'unsq_dims': {'value': unsq_dims, 'shape': unsq_dims.shape},
        'unsq_dims_const': {'value': unsq_dims, 'shape': unsq_dims.shape},
    })
    graph_ref = build_graph(self.nodes_attributes, edges, {
        'data_1': {'shape': input_shape, 'value': input_value},
        'unsq_dims': {'value': ref_uns_dims, 'shape': ref_uns_dims.shape},
        'unsq_dims_const': {'value': ref_uns_dims, 'shape': ref_uns_dims.shape},
        'data_2': {'shape': output_shape, 'value': output_value},
    })

    Unsqueeze.infer(Node(graph, 'unsq'))

    flag, resp = compare_graphs(graph, graph_ref, 'data_2')
    self.assertTrue(flag, resp)

    out_node = Node(graph, 'data_2')
    ref_node = Node(graph_ref, 'data_2')
    self.assertTrue(strict_compare_tensors(out_node.shape, ref_node.shape))
    # The value is only checked when the reference actually carries one
    # (constant-folding cases).
    if ref_node.value is not None:
        self.assertTrue(strict_compare_tensors(out_node.value, ref_node.value))
def replace_pattern(graph: Graph, match: dict):
    """Convert a matched MatMul node either to FullyConnected or to Gemm.

    If the second (weights) input is constant (or value propagation is
    stopped on it) and has at most 2 non-unit dimensions, the MatMul
    [B, I, K] x [B, K, O] = [B, I, O] is rewritten as a FullyConnected
    [I, K] x [O, K] = [I, O], inserting Transpose/Reshape nodes to
    normalize the weights, the input and the output.  Otherwise the node
    is turned into a Gemm, unsqueezing the second input to the output
    rank if needed.

    NOTE(review): all inserted helper nodes are shape-inferred immediately
    via explicit `infer` calls, so the statement order here is significant.
    """
    node = match['matmul']
    name = node.soft_get('name', node.id)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()
    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    # FullyConnected path: weights are constant (or marked
    # 'stop_value_propagation') and have <= 2 non-unit dimensions.
    if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
            and B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            # Swap the two innermost axes of the weights.
            transpose_order = list(range(B_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(weights_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        # Collapse batched weights to 2-D [O, K] if they are still higher-rank.
        if node.in_port(1).data.get_shape().size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(weights_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            # Swap the two innermost axes of the first input.
            transpose_order = list(range(A_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(input_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        # Flatten batch dims of the input into the row dimension: [-1, K].
        if A_shape.size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(input_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        # Mutate the node itself into a FullyConnected.
        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization: restore the original batched shape [*B, I, O].
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)

            node.infer(node)
            const.infer(const)
            reshape.infer(reshape)
    else:
        # Gemm path: non-constant weights.
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size
        # Align the rank of the second input with the output by prepending
        # unit dimensions.
        if B_shape.size != out_shape.size:
            unsqueeze_dim = Const(graph, {'value': int64_array(list(range(out_shape.size - B_shape.size)))
                                          }).create_node()
            unsqueeze = Unsqueeze(graph, {}).create_node()

            B_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
            unsqueeze.in_port(0).connect(B_source)
            unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

            unsqueeze_dim.infer(unsqueeze_dim)
            unsqueeze.infer(unsqueeze)

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })