Example #1
def test_reshape_infer(self, input_value, input_shape, output_shape,
                       ref_value, ref_shape):
    graph = build_graph(
        nodes_attributes,
        [('input', 'data'), ('data', 'reshape'),
         ('output_shape', 'output_shape_data'),
         ('output_shape_data', 'reshape'),
         ('reshape', 'reshape_out')],
        {
            'data': {'shape': input_shape, 'value': input_value},
            'output_shape': {'value': output_shape, 'shape': output_shape.shape},
            'output_shape_data': {'value': output_shape, 'shape': output_shape.shape},
        })
    node = Node(graph, 'reshape')
    Reshape.infer(node)
    if ref_value is not None:
        self.assertTrue(
            strict_compare_tensors(
                node.out_port(0).data.get_value(), shape_array(ref_value)))
    self.assertTrue(
        strict_compare_tensors(
            node.out_port(0).data.get_shape(), shape_array(ref_shape)))
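
For context, here is a minimal numpy analogue of the value/shape propagation this test expects from Reshape.infer; the names and sizes below are illustrative, not part of the test harness:

import numpy as np

input_value = np.arange(6)                        # input tensor, shape [6]
output_shape = np.array([2, -1], dtype=np.int64)  # requested shape; -1 is inferred

ref_value = input_value.reshape(output_shape)     # value is propagated when the input value is known
ref_shape = ref_value.shape                       # the shape Reshape.infer should report
print(ref_shape)                                  # (2, 3)
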
Example #2
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    input_shape = node.in_port(0).data.get_shape()
    if len(input_shape) > 2:
        new_shape = Const(graph, {'value': np.array([0, -1], dtype=np.int64)}).create_node()
        reshape = Reshape(graph, {}).create_node()
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape.out_port(0))
        source.connect(reshape.in_port(0))
        new_shape.out_port(0).connect(reshape.in_port(1))
        new_shape.infer(new_shape)
        reshape.infer(reshape)
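
A note on the magic values: in this Reshape dialect, 0 keeps the corresponding input dimension and -1 absorbs whatever remains, so [0, -1] flattens any tensor of rank > 2 down to 2D. A plain numpy sketch of the same effect (x is an arbitrary example input):

import numpy as np

x = np.random.rand(2, 3, 4)        # rank-3 input, as in the len(input_shape) > 2 branch
flat = x.reshape(x.shape[0], -1)   # numpy equivalent of Reshape with [0, -1]
print(flat.shape)                  # (2, 12)
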
Example #3
    def replace_pattern(self, graph: Graph, match: dict):
        conv = match['conv']

        assert len(conv.out_nodes()) == 1, "Convolution operation {} should have 1 output data node".format(conv.id)
        out_data = conv.out_node()

        assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(conv.id)
        out_shape = out_data.shape

        if out_shape.size != 3:
            return

        assert len(conv.in_nodes()) >= 1, "Convolution operation {} should have at least 1 input data node".format(
            conv.id)
        inp_data = conv.in_node()

        assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(conv.id)
        inp_shape = inp_data.shape
        new_inp_shape = np.insert(inp_shape, 2, 1)

        # setting to None to be overwritten by infer function
        conv.kernel_spatial_idx = None
        conv.spatial_dims = None

        # inserting fake H dimension
        conv.dilation = np.insert(conv.dilation, 2, 1)
        conv.kernel_spatial = np.append([1], conv.kernel_spatial)
        conv.pad = np.insert(conv.pad, 2, [0, 0], axis=0)
        conv.stride = np.insert(conv.stride, 2, 1)

        weights_node = conv.in_node(1)
        weights_node.value = np.reshape(weights_node.value, np.insert(weights_node.value.shape, 2, 1))
        weights_node.shape = np.array(weights_node.value.shape, dtype=np.int64)

        reshape = Reshape(graph, {'name': conv.name + '/reshape'}).create_node()
        reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
        conv.in_port(0).get_connection().insert_node(reshape)
        reshape.in_port(1).connect(reshape_dim.out_port(0))

        reshape_back = Reshape(graph, {'name': conv.name + '/reshape_back'}).create_node()
        reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape_back.id + '/Dim'}).create_node()
        conv.out_port(0).get_connection().insert_node(reshape_back)
        reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

        # re-run shape inference manually for several nodes to override the shapes of the nodes whose behaviour changed
        reshape_dim.infer(reshape_dim)
        reshape.infer(reshape)
        conv.infer(conv)
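
The pass above rewrites a 1D convolution as a 2D one by injecting a fake H dimension of size 1 into every spatial attribute. A small numpy sketch of just the shape bookkeeping (the concrete sizes are made-up examples):

import numpy as np

inp_shape = np.array([1, 3, 101], dtype=np.int64)  # [N, C, W] -- 3D conv input
new_inp_shape = np.insert(inp_shape, 2, 1)         # [N, C, 1, W] -- fake H dimension
print(new_inp_shape)                               # [  1   3   1 101]

kernel_spatial = np.array([5], dtype=np.int64)     # 1D spatial kernel
kernel_2d = np.append([1], kernel_spatial)         # [1, 5] -- matches the conv.kernel_spatial update
print(kernel_2d)                                   # [1 5]
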
Example #4
    def replace_pattern(graph: Graph, match: dict):
        node = match['matmul']
        name = node.soft_get('name', node.id)

        A_shape = node.in_port(0).data.get_shape()
        B_shape = node.in_port(1).data.get_shape()
        out_shape = node.out_port(0).data.get_shape()

        assert A_shape is not None and B_shape is not None and out_shape is not None

        B_value = node.in_port(1).data.get_value()
        if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
                and B_shape[B_shape != 1].size <= 2:
            # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
            # to FullyConnected representation: [I, K] * [O, K] = [I, O]
            B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

            # weights normalization
            if not node.transpose_b:
                # FullyConnected weights layout is OI
                # MatMul second input layout is (B)IO
                transpose_order = list(range(B_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(weights_source)
                transpose.in_port(1).connect(order.out_port(0))

                order.infer(order)
                transpose.infer(transpose)

            if node.in_port(1).data.get_shape().size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(reshape.out_port(0))

                reshape.in_port(0).connect(weights_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K])), \
                "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

            node.in_port(1).bin = 'weights'
            del node['transpose_b']

            # input normalization
            if node.transpose_a:
                transpose_order = list(range(A_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(input_source)
                transpose.in_port(1).connect(order.out_port(0))

                order.infer(order)
                transpose.infer(transpose)

            if A_shape.size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(reshape.out_port(0))
                reshape.in_port(0).connect(input_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K])), \
                "MatMul `{}` was not converted to FullyConnected: wrong input shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

            del node['transpose_a']

            FullyConnected.update_node_stat(node, {'out-size': O})

            # output normalization
            if out_shape.size != 2:
                const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
                reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

                dst = node.out_port(0).get_destination()
                node.out_port(0).get_connection().set_destination(reshape.in_port(0))
                const.out_port(0).connect(reshape.in_port(1))
                reshape.out_port(0).connect(dst)

                node.infer(node)

                const.infer(const)
                reshape.infer(reshape)

        else:
            assert A_shape.size == out_shape.size
            assert B_shape.size <= out_shape.size
            if B_shape.size != out_shape.size:
                unsqueeze_dim = Const(graph, {
                    'value': int64_array(list(range(out_shape.size - B_shape.size)))
                }).create_node()
                unsqueeze = Unsqueeze(graph, {}).create_node()
                B_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
                unsqueeze.in_port(0).connect(B_source)
                unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

                unsqueeze_dim.infer(unsqueeze_dim)
                unsqueeze.infer(unsqueeze)

            Gemm.update_node_stat(node, {
                'transpose_a': node.has_and_set('transpose_a'),
                'transpose_b': node.has_and_set('transpose_b'),
            })
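
The comment in this pass describes the representation change from MatMul [B, I, K] * [B, K, O] = [B, I, O] to FullyConnected [I, K] * [O, K] = [I, O]. A self-contained numpy check of that equivalence (all names and sizes here are illustrative, assuming transpose_a and transpose_b are both false):

import numpy as np

B, I, K, O = 2, 4, 5, 3
A = np.random.rand(B, I, K)                 # MatMul input:   [B, I, K]
W = np.random.rand(K, O)                    # MatMul weights: [K, O] (IO layout)

matmul_out = A @ W                          # [B, I, O]

W_oi = W.T                                  # weights transposed to the FullyConnected OI layout: [O, K]
fc_in = A.reshape(-1, K)                    # input flattened to [B*I, K]
fc_out = (fc_in @ W_oi.T).reshape(B, I, O)  # FullyConnected result reshaped back to [B, I, O]

assert np.allclose(matmul_out, fc_out)
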
Example #5
    def replace_pattern(graph: Graph, match: dict):
        """
        Works around a Tile configuration that the Inference Engine does not support
        (Tile is supported only for 2D and 4D tensors): finds Tiles with 3D shapes
        and wraps them with Reshapes.

        Example: Tile (axis=1, tiles=16):
            in_shape: [1,1,101]
            out_shape: [1,16,101]

        Old behaviour:
            Tile -> [1,16,101]
        New behaviour:
            Reshape [1,1,101,1] -> Tile -> [1,16,101,1] -> Reshape [1,16,101]
        """
        tile = match['tile']

        assert len(tile.out_nodes()) == 1, \
            "Tile operation {} should have 1 output data node".format(tile.id)
        out_data = tile.out_node()

        assert out_data.has_valid('shape'), \
            'Output shape is undefined for {} in back phase'.format(tile.id)
        out_shape = out_data.shape

        if out_shape.size != 3:
            return

        assert len(tile.in_nodes()) == 1, \
            "Tile operation {} should have 1 input data node".format(tile.id)
        inp_data = tile.in_node()

        assert inp_data.has_valid('shape'), \
            'Input shape is undefined for {} in back phase'.format(tile.id)
        inp_shape = inp_data.shape
        new_inp_shape = np.append(inp_shape, [1])

        reshape = Reshape(graph, {'name': tile.name + '/reshape'}).create_node()
        reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
        tile.in_port(0).get_connection().insert_node(reshape)
        reshape.in_port(1).connect(reshape_dim.out_port(0))

        reshape_back = Reshape(graph, {'name': tile.name + '/reshape_back'}).create_node()
        reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape_back.id + '/Dim'}).create_node()
        tile.out_port(0).get_connection().insert_node(reshape_back)
        reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

        # re-run shape inference manually for several nodes to override the shapes of the nodes whose behaviour changed
        reshape_dim.infer(reshape_dim)
        reshape.infer(reshape)
        tile.infer(tile)
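
The docstring's claim that the Reshape/Tile/Reshape sandwich is equivalent to the original 3D Tile can be checked with a quick numpy sketch, using the docstring's shapes:

import numpy as np

inp = np.random.rand(1, 1, 101)            # in_shape: [1, 1, 101]
direct = np.tile(inp, (1, 16, 1))          # old behaviour: Tile -> [1, 16, 101]

wrapped = inp.reshape(1, 1, 101, 1)        # Reshape -> [1, 1, 101, 1]
wrapped = np.tile(wrapped, (1, 16, 1, 1))  # Tile -> [1, 16, 101, 1]
wrapped = wrapped.reshape(1, 16, 101)      # Reshape back -> [1, 16, 101]

assert np.array_equal(direct, wrapped)
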