コード例 #1
0
    def test_unsqueeze_infer_negative_indices(self):
        """Unsqueeze a rank-4 input along axis -1.

        The reference graph expects both axes nodes to hold [4] after
        inference and the output data node to have shape [2, 3, 64, 64, 1].
        """
        axes = np.array([-1])
        edges = [
            ('data_1', 'unsq'),
            ('unsq_dims_const', 'unsq_dims'),
            ('unsq_dims', 'unsq'),
            ('unsq', 'data_2'),
        ]

        def graph_attrs(dims_value, dims_const_value):
            # A fresh input-shape array is created for every graph built.
            return {
                'data_1': {'shape': np.array([2, 3, 64, 64])},
                'unsq_dims': {'value': dims_value, 'shape': axes.shape},
                'unsq_dims_const': {'value': dims_const_value, 'shape': axes.shape},
            }

        # Graph under test: both axes nodes share the very same array object.
        graph = build_graph(self.nodes_attributes, edges, graph_attrs(axes, axes))

        expected_attrs = graph_attrs(int64_array([4]), int64_array([4]))
        expected_attrs['data_2'] = {'shape': np.array([2, 3, 64, 64, 1])}
        graph_ref = build_graph(self.nodes_attributes, edges, expected_attrs)

        Unsqueeze.infer(Node(graph, 'unsq'))

        is_equal, message = compare_graphs(graph, graph_ref, 'data_2')
        self.assertTrue(is_equal, message)
コード例 #2
0
    def test_unsqueeze_infer(self):
        """Unsqueeze a [1, 3, 64, 64] input at axes [0, 4].

        The axes come from the node's ``unsqueeze_dims`` attribute and the
        expected output shape is [1, 1, 3, 64, 1, 64].
        """
        edges = [('data_1', 'unsq'), ('unsq', 'data_2')]

        def base_attrs():
            # New arrays on every call so the graph under test and the
            # reference graph never share array objects.
            return {
                'data_1': {'shape': np.array([1, 3, 64, 64])},
                'unsq': {'unsqueeze_dims': np.array([0, 4])},
            }

        graph = build_graph(self.nodes_attributes, edges, base_attrs())

        expected = base_attrs()
        expected['data_2'] = {'shape': np.array([1, 1, 3, 64, 1, 64])}
        graph_ref = build_graph(self.nodes_attributes, edges, expected)

        Unsqueeze.infer(Node(graph, 'unsq'))

        same, details = compare_graphs(graph, graph_ref, 'data_2')
        self.assertTrue(same, details)
コード例 #3
0
ファイル: unsqueeze_test.py プロジェクト: yding10/openvino
    def test_unsqueeze_infer(self, input_shape, unsq_dims, output_shape,
                             ref_uns_dims, input_value, output_value):
        """Infer Unsqueeze with axes supplied through a Const and verify it.

        The inferred graph is compared with a reference graph whose axes
        nodes hold ``ref_uns_dims`` and whose output node holds
        ``output_shape`` / ``output_value``. Besides ``compare_graphs``, the
        output shape is checked with ``strict_compare_tensors``, and so is the
        output value whenever the reference defines one.
        """
        edges = [('data_1', 'unsq'),
                 ('unsq_dims_const', 'unsq_dims'),
                 ('unsq_dims', 'unsq'),
                 ('unsq', 'data_2')]

        def node_attrs(axes, **extra):
            # Each axes node gets its own attribute dict; both point at `axes`.
            attrs = {
                'data_1': {'shape': input_shape, 'value': input_value},
                'unsq_dims': {'value': axes, 'shape': axes.shape},
                'unsq_dims_const': {'value': axes, 'shape': axes.shape},
            }
            attrs.update(extra)
            return attrs

        graph = build_graph(self.nodes_attributes, edges, node_attrs(unsq_dims))
        graph_ref = build_graph(
            self.nodes_attributes, edges,
            node_attrs(ref_uns_dims, data_2={'shape': output_shape, 'value': output_value}))

        Unsqueeze.infer(Node(graph, 'unsq'))

        flag, resp = compare_graphs(graph, graph_ref, 'data_2')
        self.assertTrue(flag, resp)

        inferred_out = Node(graph, 'data_2')
        reference_out = Node(graph_ref, 'data_2')
        self.assertTrue(strict_compare_tensors(inferred_out.shape, reference_out.shape))
        if reference_out.value is not None:
            self.assertTrue(strict_compare_tensors(inferred_out.value, reference_out.value))
コード例 #4
0
    def replace_pattern(graph: Graph, match: dict):
        """Rewrite the matched MatMul node in place as FullyConnected or Gemm.

        FullyConnected path. Taken when the second input has a value (or its
        producer has ``stop_value_propagation`` set) AND at most two of the
        second input's dimensions differ from 1:
          * the weights are transposed unless ``transpose_b`` is set, then
            reshaped to ``[-1, K]`` when not 2D; they must end up ``[O, K]``;
          * the data input is transposed if ``transpose_a`` is set, then
            reshaped to ``[-1, K]`` when not 2D; it must end up
            ``[prod(B) * I, K]``;
          * the node becomes FullyConnected with ``out-size`` = O. A Reshape
            to ``[*B, I, O]`` is appended when the original output was not 2D.

        Gemm path. Taken otherwise: the second input is unsqueezed on leading
        axes up to the output rank when the ranks differ. The node then
        becomes Gemm with the original ``transpose_a`` / ``transpose_b``
        flags (False when unset).

        :param graph: graph containing the matched node; helper Const,
            Transpose, Reshape and Unsqueeze nodes are created in it.
        :param match: pattern match; ``match['matmul']`` is the node to convert.
        :raises AssertionError: if an input/output shape is unknown, if the
            normalized weights or data shapes are not the expected ones, or
            (Gemm path) if the input/output ranks are incompatible.
        """
        node = match['matmul']
        name = node.soft_get('name', node.id)

        A_shape = node.in_port(0).data.get_shape()
        B_shape = node.in_port(1).data.get_shape()
        out_shape = node.out_port(0).data.get_shape()

        # Both conversion paths below rely on fully known shapes.
        assert A_shape is not None and B_shape is not None and out_shape is not None

        B_value = node.in_port(1).data.get_value()
        # FullyConnected path only for constant-like weights with at most two
        # non-unit dimensions (presumably so they reduce to a single weights matrix).
        if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) and B_shape[
            B_shape != 1].size <= 2:
            # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
            # to FullyConnected representation: [I, K] * [O, K] = [I, O]
            B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

            # weights normalization
            if not node.transpose_b:
                # FullyConnected weights layout is OI
                # MatMul second input layout is (B)IO
                # Swap only the two innermost axes; any leading axes stay in place.
                transpose_order = list(range(B_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

                # Splice the Transpose between the original weights producer and the MatMul.
                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(weights_source)
                transpose.in_port(1).connect(order.out_port(0))

                # Infer the new nodes immediately: the shape check below reads
                # node.in_port(1)'s shape, which now comes from the Transpose.
                order.infer(order)
                transpose.infer(transpose)

            # Collapse the (possibly transposed) weights to 2D [-1, K]; the
            # assert below checks that the result is exactly [O, K].
            if node.in_port(1).data.get_shape().size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(reshape.out_port(0))

                reshape.in_port(0).connect(weights_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
                "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

            # Tag the second input as the weights port (presumably read later,
            # e.g. during serialization -- confirm).
            node.in_port(1).bin = 'weights'
            # transpose_b has been materialized as an explicit Transpose (or was
            # already set), so the attribute is no longer needed on the node.
            del node['transpose_b']

            # input normalization
            if node.transpose_a:
                # Swap the two innermost axes so the data input ends with K.
                transpose_order = list(range(A_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(input_source)
                transpose.in_port(1).connect(order.out_port(0))

                order.infer(order)
                transpose.infer(transpose)

            # Transposing does not change rank, so the original A_shape rank
            # decides whether the data input must be flattened to [-1, K].
            if A_shape.size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(reshape.out_port(0))
                reshape.in_port(0).connect(input_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
                "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

            del node['transpose_a']

            FullyConnected.update_node_stat(node, {'out-size': O})

            # output normalization
            # Restore the original batched output layout [*B, I, O] with a Reshape.
            if out_shape.size != 2:
                const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
                reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

                # NOTE(review): this rewiring goes through a single
                # get_destination(); if the MatMul output feeds several
                # consumers, confirm that all of them end up behind the Reshape.
                dst = node.out_port(0).get_destination()
                node.out_port(0).get_connection().set_destination(reshape.in_port(0))
                const.out_port(0).connect(reshape.in_port(1))
                reshape.out_port(0).connect(dst)

                # Re-infer the node, now a FullyConnected, before the Reshape
                # that consumes its output.
                node.infer(node)

                const.infer(const)
                reshape.infer(reshape)

        else:
            # Gemm path: ranks must be compatible with the output rank.
            assert A_shape.size == out_shape.size
            assert B_shape.size <= out_shape.size
            if B_shape.size != out_shape.size:
                # Prepend unit axes to B (axes 0 .. rank difference - 1) so it
                # has the same rank as the output.
                unsqueeze_dim = Const(graph, {'value': int64_array(list(range(out_shape.size - B_shape.size)))
                                              }).create_node()
                unsqueeze = Unsqueeze(graph, {}).create_node()
                B_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
                unsqueeze.in_port(0).connect(B_source)
                unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

                unsqueeze_dim.infer(unsqueeze_dim)
                unsqueeze.infer(unsqueeze)

            # Keep the transpose flags as Gemm attributes (False when unset).
            Gemm.update_node_stat(node, {
                'transpose_a': node.has_and_set('transpose_a'),
                'transpose_b': node.has_and_set('transpose_b'),
            })