Ejemplo n.º 1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Substitute the matched BiasAdd node with an Add node.

        The Add takes over both inputs and the output of the BiasAdd. When the
        BiasAdd ``data_format`` is 'NCHW', an Unsqueeze is additionally inserted
        on the bias input (port 1) of the Add. Its axes are all axis indices of
        the data input except the index returned by ``get_features_dim``.

        :param graph: graph being transformed
        :param match: pattern match; the 'BiasAdd' entry is the node to replace
        :raises AssertionError: if the bias input is not one-dimensional
            (checked for the 'NCHW' case only)
        """
        old_node = match['BiasAdd']

        add_node = Add(graph, {'name': old_node.id + '/Add'}).create_node()

        # Hand both BiasAdd inputs over to the Add, then re-source the consumers.
        for port_idx in (0, 1):
            old_node.in_port(port_idx).get_connection().set_destination(
                add_node.in_port(port_idx))
        old_node.out_port(0).get_connection().set_source(add_node.out_port(0))

        # Only the 'NCHW' layout gets the extra Unsqueeze on the bias input.
        if old_node.data_format != 'NCHW':
            return

        data_shape = add_node.in_port(0).data.get_shape()
        bias_shape = add_node.in_port(1).data.get_shape()
        assert len(bias_shape) == 1

        # Axes to unsqueeze: every axis of the data input except the feature one.
        rank = len(data_shape)
        feature_axis = get_features_dim('NCHW', rank)
        axes = np.delete(np.arange(rank), feature_axis, 0)

        bias_unsqueeze = Unsqueeze(
            graph, {'name': add_node.id + '/BiasUnsqueeze'}).create_node()
        axes_const = Const(
            graph, {'name': add_node.id + '/Dims', 'value': axes}).create_node()

        bias_unsqueeze.in_port(1).connect(axes_const.out_port(0))
        bias_unsqueeze['override_output_shape'] = True

        # Place the Unsqueeze between the bias producer and the Add.
        add_node.in_port(1).get_connection().insert_node(bias_unsqueeze)
Ejemplo n.º 2
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Move the Transpose that consumes a FakeQuantize output onto every
        FakeQuantize input.

        The matched Transpose is detached from the graph. For each input port
        of the FakeQuantize, a copy of the Transpose (fed by the same second
        input as the original) is placed in front of that port. If the input
        has a lower rank than the data that entered the original Transpose, an
        Unsqueeze over the leading axes is put in front of the copy.

        Nothing is done when the FakeQuantize output has more than one
        destination.

        :param graph: graph being transformed
        :param match: pattern match with 'fq' and 'transpose' nodes
        """
        fq_node = match['fq']

        # FQ should have only one child -- Transpose for optimization
        if len(fq_node.out_port(0).get_destinations()) > 1:
            return

        transpose_node = match['transpose']
        base_name = fq_node.soft_get('name', fq_node.id)

        transposed_shape = transpose_node.in_port(0).data.get_shape()

        # Cut the Transpose out: its consumers now read directly from its producer.
        transpose_node.out_port(0).get_connection().set_source(
            transpose_node.in_port(0).get_connection().get_source())
        transpose_node.in_port(0).disconnect()

        for in_idx, fq_in_port in fq_node.in_ports().items():
            per_input_transpose = transpose_node.copy_node({'override_output_shape': True})
            transpose_node.in_port(1).get_source().connect(per_input_transpose.in_port(1))

            # Leading axes needed to bring this input up to the rank of the transposed data.
            rank_gap = len(transposed_shape) - len(fq_in_port.data.get_shape())
            leading_axes = np.arange(rank_gap)

            if leading_axes.size == 0:
                entry_port = per_input_transpose.in_port(0)
            else:
                axes_const = Const(graph, {'name': base_name + '/in_{}_unsqueeze_axis'.format(in_idx),
                                           'value': int64_array(leading_axes)}).create_node()
                rank_aligner = Unsqueeze(graph, {'name': base_name + '/in_{}_unsqueeze'.format(in_idx)}).create_node()
                axes_const.out_port(0).connect(rank_aligner.in_port(1))
                rank_aligner.out_port(0).connect(per_input_transpose.in_port(0))
                entry_port = rank_aligner.in_port(0)

            # Splice: producer -> [Unsqueeze ->] Transpose copy -> FQ input.
            producer_port = fq_in_port.get_source()
            fq_in_port.get_connection().set_source(per_input_transpose.out_port(0))
            producer_port.connect(entry_port)
 def find_and_replace_pattern(self, graph: Graph):
     """
     Replace every single-input ExpandDims node with an Unsqueeze node.

     The ``expand_axis`` attribute of the ExpandDims becomes a Const connected
     to the second input of the Unsqueeze. ExpandDims nodes whose number of
     inputs is not exactly one are reported with an error log record and left
     in place.

     :param graph: graph being transformed
     """
     for node in graph.get_op_nodes(op='ExpandDims'):
         if len(node.in_nodes()) != 1:
             log.error('The ExpandDims node {} has more than 1 input'.format(node.soft_get('name')))
             continue

         # A non-ndarray axis is wrapped into a flat int64 array.
         axes = node.expand_axis
         if not isinstance(axes, np.ndarray):
             axes = int64_array([axes]).flatten()

         unsqueeze = Unsqueeze(graph, {'name': node.id}).create_node()
         axes_const = Const(graph, {'name': node.id + '/Dims', 'value': axes}).create_node()

         # Re-route the data input and the consumers, then attach the axes.
         node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
         node.out_port(0).get_connection().set_source(unsqueeze.out_port(0))
         unsqueeze.in_port(1).connect(axes_const.out_port(0))
Ejemplo n.º 4
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Work around the Inference Engine Tile limitation (Tile is supported for 2D or 4D tensors only)
        by wrapping a Tile that produces a 3D tensor into Unsqueeze and Squeeze operations.

        Example for Tile (axis=1, tiles=16):
            input shape:  [1, 1, 101]
            output shape: [1, 16, 101]

        Before:
            Tile -> [1, 16, 101]
        After:
            Unsqueeze [1, 1, 101, 1] -> Tile -> [1, 16, 101, 1] -> Squeeze [1, 16, 101]

        A Tile whose output is not 3D is left untouched.

        :param graph: graph being transformed
        :param match: pattern match; the 'tile' entry is the Tile node
        :raises AssertionError: if the output or input shape of the Tile is undefined
        """
        tile = match['tile']
        tile_name = tile.soft_get('name', tile.id)

        out_shape = tile.out_port(0).data.get_shape()
        assert out_shape is not None, 'Output shape is undefined for {} in back phase'.format(tile_name)
        if out_shape.size != 3:
            return

        inp_shape = tile.in_port(0).data.get_shape()
        assert inp_shape is not None, 'Input shape is undefined for {} in back phase'.format(tile_name)

        # Unsqueeze over axis 3 for the data input.
        unsqueeze_axis = Const(graph, {'name': tile_name + '/3D_Tile_Unsqueeze_dim',
                                       'value': int64_array([3])}).create_node()
        unsqueeze = Unsqueeze(graph, {'name': tile_name + '/3D_Tile_Unsqueeze',
                                      'override_output_shape': True}).create_node()
        unsqueeze_axis.out_port(0).connect(unsqueeze.in_port(1))

        # The tiles input is rebuilt from the original producer plus a constant [1].
        extra_tile = Const(graph, {'name': tile_name + '/additional_axis',
                                   'value': int64_array([1])}).create_node()
        tiles_producer = tile.in_port(1).get_source().node
        extended_tiles = new_shape_node_from_shape_nodes([tiles_producer, extra_tile])

        tile.in_port(1).get_connection().set_source(extended_tiles.out_port(0))

        # Squeeze over axis 3 for the Tile output.
        squeeze_axis = Const(graph, {'name': tile_name + '/3D_Tile_Squeeze_dim',
                                     'value': int64_array([3])}).create_node()
        squeeze = Squeeze(graph, {'name': tile_name + '/3D_Tile_Squeeze',
                                  'override_output_shape': True}).create_node()
        squeeze_axis.out_port(0).connect(squeeze.in_port(1))

        # producer -> Unsqueeze -> Tile
        data_producer = tile.in_port(0).get_source()
        tile.in_port(0).get_connection().set_source(unsqueeze.out_port(0))
        unsqueeze.in_port(0).connect(data_producer)

        # Tile -> Squeeze -> former consumers of the Tile
        tile.out_port(0).get_connection().set_source(squeeze.out_port(0))
        tile.out_port(0).connect(squeeze.in_port(0))

        for reshaped_node in (tile, extended_tiles):
            reshaped_node['override_output_shape'] = True
        tile['need_shape_inference'] = True
Ejemplo n.º 5
0
    def add_unsqueeze_for_new(graph: Graph, ss_node: Node):
        """
        Move the "new axis" behaviour of a StridedSlice into a following Unsqueeze.

        Entries set in ``new_axis_mask`` are cleared in place, the StridedSlice
        output shape is reduced to the dimensions whose mask entry was not set,
        and an Unsqueeze restoring the original output shape is inserted after
        the StridedSlice. The Unsqueeze axes are the positions that were set in
        the mask.

        Nothing is changed unless the node has exactly 4 inputs and 1 output.

        :param graph: graph being transformed
        :param ss_node: StridedSlice node to process
        """
        log.info("StridedSlice op with new axis mask '{}' has been detected".format(ss_node.id))
        if not (len(ss_node.in_nodes()) == 4 and len(ss_node.out_nodes()) == 1):
            return

        full_shape = ss_node.out_node().shape
        mask = ss_node['new_axis_mask']

        # Must be computed before the mask entries are cleared below.
        new_axes = np.arange(len(mask))[np.array(mask, dtype=bool)]

        kept_dims = []
        for axis, is_new in enumerate(mask):
            if is_new:
                mask[axis] = 0
            else:
                kept_dims.append(full_shape[axis])

        ss_node.out_port(0).data.set_shape(kept_dims)

        # insert Unsqueeze
        unsqueeze = Unsqueeze(graph, {'name': ss_node.name + '/Unsqueeze_new'}).create_node()
        ss_node.out_port(0).get_connection().insert_node(unsqueeze)
        unsqueeze.out_port(0).data.set_shape(full_shape)

        axes_const = Const(graph, {'name': unsqueeze.id + '/Indices',
                                   'value': int64_array(new_axes)}).create_node()
        axes_const.out_port(0).connect(unsqueeze.in_port(1))
Ejemplo n.º 6
0
    def mxrepeat_decomposition(node: Node):
        """
        Replace ``node`` with an Unsqueeze -> Tile -> Reshape chain.

        The data input of ``node`` is re-routed to the Unsqueeze and the
        consumers of ``node`` are re-sourced to the Reshape. The original node
        is renamed to ``<name>/to_be_removed`` and the Reshape takes over its
        name. The Unsqueeze axis, the Tile repeats and the Reshape target shape
        are all built as sub-graphs from the Rank and Shape of the data input.

        :param node: node to decompose; its ``axis`` and ``repeats`` attributes
            and its ``graph`` are read
        """
        graph = node.graph
        name = node.soft_get('name', node.id)

        # Free the original name: the final Reshape is given it below.
        rename_node(node, name + '/to_be_removed')

        # Unsqueeze
        input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(input_rank.in_port(0))

        # presumably resolves a possibly negative axis against the input rank -- TODO confirm
        axis = get_canonical_axis_index_node(input_rank, node.axis)
        # Add node fed by `axis` and a constant [1]; presumably yields axis + 1 -- TODO confirm
        unsqueeze_axis = create_op_node_with_second_input(
            graph,
            Add,
            int64_array([1]), {'name': name + '/Unsqueeze/Axis'},
            input_node=axis)

        unsqueeze = Unsqueeze(graph, {
            'name': name + '/Unsqueeze'
        }).create_node()
        unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

        # Tile (1, 1, ..., repeats, ..., 1)
        # we generate tile array according to the following table:

        # parts:       |      first      |  repeats |  second     |
        # i:           | 0, 1, ..., axis,| axis + 1,| ..., rank+1 |
        # tile_array:  | 1, 1, ...,  1  ,| repeats ,| ...,   1    |

        # Shared constant [1] that both Broadcast nodes take as their first input.
        one = Const(graph, {
            'name': name + '/Broadcast/One',
            'value': int64_array([1])
        }).create_node()
        # First part: Broadcast of `one` with the unsqueeze axis as its second input.
        first_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_first_part'
        }).create_node()
        first_ones.in_port(0).connect(one.out_port(0))
        first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

        repeats = Const(graph, {
            'name': name + '/repeats',
            'value': int64_array([node.repeats])
        }).create_node()

        # Second part: Broadcast of `one` with (rank - unsqueeze axis) as its second input.
        second_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_second_part'
        }).create_node()
        second_part_broadcast_shape = Sub(
            graph, {
                'name': name + '/Broadcast/Shape/second_part'
            }).create_node()
        second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
        second_part_broadcast_shape.in_port(1).connect(
            unsqueeze_axis.out_port(0))
        second_ones.in_port(0).connect(one.out_port(0))
        second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

        # The three parts are combined into the Tile repeats input.
        tile_repeats = new_shape_node_from_shape_nodes(
            [first_ones, repeats, second_ones])
        tile = Tile(graph, {'name': name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
        # we generate reshape dim array according to the following table:

        # parts:       |    first   |                rep           |  second   |
        # i:           | 0, 1, ... ,|               axis,          | ..., rank |
        # dim_array:   | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i] |

        input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(input_shape.in_port(0))

        # NOTE(review): the raw `node.axis` attribute is passed here (and for the
        # second part below), not the canonical `axis` node -- confirm that
        # get_shape_values_by_range_idxs handles negative values as intended.
        first_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=0,
            end=node.axis,
            include_begin=True,
            include_end=False)

        # Gather over the input shape with the canonical axis as indices and a
        # constant 0 on port 2 (presumably the gather axis -- TODO confirm).
        original_axis_dim = create_op_with_const_inputs(
            graph,
            Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'},
            input_node=input_shape)
        original_axis_dim.in_port(1).connect(axis.out_port(0))

        # input_shape[axis] * repeats
        repeated_dimention = Mul(graph, {
            'name': name + '/RepeatedDim'
        }).create_node()
        repeated_dimention.in_port(0).connect(original_axis_dim.out_port(0))
        repeated_dimention.in_port(1).connect(repeats.out_port(0))

        second_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=node.axis,
            end=-1,
            include_begin=False,
            include_end=True)

        output_shape = new_shape_node_from_shape_nodes([
            first_input_shape_part, repeated_dimention, second_input_shape_part
        ])

        # The Reshape is created under, and then renamed to, the original node name.
        reshape = Reshape(graph, {'name': name}).create_node()
        rename_node(reshape, name)
        reshape.in_port(1).connect(output_shape.out_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))
Ejemplo n.º 7
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Normalize the matched MatMul node into either FullyConnected or Gemm.

        FullyConnected branch -- taken when the second input has a value (or its
        producer has 'stop_value_propagation' set) and its shape has at most two
        dimensions different from 1:
          * weights: a Transpose swapping the last two axes is inserted when
            ``transpose_b`` is not set, then a Reshape to [-1, K] when the
            weights are not 2D; the resulting shape must be [O, K];
          * input: a Transpose swapping the last two axes is inserted when
            ``transpose_a`` is set, then a Reshape to [-1, K] when the first
            input is not 2D; the resulting shape must be [prod(B) * I, K];
          * the node is updated with FullyConnected attributes ('out-size' = O)
            and the ``transpose_a`` / ``transpose_b`` attributes are removed;
          * output: a Reshape to [*B, I, O] is appended when the original
            output is not 2D.

        Gemm branch -- otherwise: the second input is unsqueezed over leading
        axes up to the output rank if needed, and the node is updated with Gemm
        attributes carrying the ``transpose_a`` / ``transpose_b`` flags.

        :param graph: graph being transformed
        :param match: pattern match; the 'matmul' entry is the MatMul node
        :raises AssertionError: if a shape is undefined or a resulting shape
            differs from the expected one
        """
        node = match['matmul']
        name = node.soft_get('name', node.id)

        A_shape = node.in_port(0).data.get_shape()
        B_shape = node.in_port(1).data.get_shape()
        out_shape = node.out_port(0).data.get_shape()

        assert A_shape is not None and B_shape is not None and out_shape is not None

        B_value = node.in_port(1).data.get_value()
        # `B_shape[B_shape != 1]` keeps only the dimensions that differ from 1.
        if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) and B_shape[
            B_shape != 1].size <= 2:
            # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
            # to FullyConnected representation: [I, K] * [O, K] = [I, O]
            B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

            # weights normalization
            if not node.transpose_b:
                # FullyConnected weights layout is OI
                # MatMul second input layout is (B)IO
                # Permutation that swaps the last two axes and keeps the others.
                transpose_order = list(range(B_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(weights_source)
                transpose.in_port(1).connect(order.out_port(0))

                # Infer the new nodes right away: the shape of input 1 is read again below.
                order.infer(order)
                transpose.infer(transpose)

            # Collapse non-2D weights to [-1, K].
            if node.in_port(1).data.get_shape().size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(reshape.out_port(0))

                reshape.in_port(0).connect(weights_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
                "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

            # Mark input 1 as weights and drop the transpose flag.
            node.in_port(1).bin = 'weights'
            del node['transpose_b']

            # input normalization
            if node.transpose_a:
                # Same last-two-axes swap as for the weights, applied to the first input.
                transpose_order = list(range(A_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(input_source)
                transpose.in_port(1).connect(order.out_port(0))

                order.infer(order)
                transpose.infer(transpose)

            # Collapse a non-2D first input to [-1, K]; the check uses the original A_shape.
            if A_shape.size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(reshape.out_port(0))
                reshape.in_port(0).connect(input_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
                "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

            del node['transpose_a']

            FullyConnected.update_node_stat(node, {'out-size': O})

            # output normalization
            if out_shape.size != 2:
                # Restore the original batched output layout after the node.
                const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
                reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

                # NOTE(review): get_destination() is used here; presumably it expects
                # a single consumer of the MatMul output -- confirm.
                dst = node.out_port(0).get_destination()
                node.out_port(0).get_connection().set_destination(reshape.in_port(0))
                const.out_port(0).connect(reshape.in_port(1))
                reshape.out_port(0).connect(dst)

                # Re-infer the converted node first, then the new nodes after it.
                node.infer(node)

                const.infer(const)
                reshape.infer(reshape)

        else:
            assert A_shape.size == out_shape.size
            assert B_shape.size <= out_shape.size
            if B_shape.size != out_shape.size:
                # Unsqueeze the second input over leading axes up to the output rank.
                unsqueeze_dim = Const(graph, {'value': int64_array(list(range(out_shape.size - B_shape.size)))
                                              }).create_node()
                unsqueeze = Unsqueeze(graph, {}).create_node()
                B_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
                unsqueeze.in_port(0).connect(B_source)
                unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

                unsqueeze_dim.infer(unsqueeze_dim)
                unsqueeze.infer(unsqueeze)

            # Turn the node into Gemm, carrying the transpose flags over.
            Gemm.update_node_stat(node, {
                'transpose_a': node.has_and_set('transpose_a'),
                'transpose_b': node.has_and_set('transpose_b'),
            })