def replace_pattern(graph: Graph, match: Dict[str, Node]):
    """Replace the matched ``op`` node with a ``StridedSlice`` node.

    The StridedSlice receives an all-zero constant on input port 1 and a
    dynamically built tensor on input port 2. Every mask attribute is zero
    except ``end_mask``, which is set to 1 at each index listed in
    ``node.axes``.

    The port-2 input is chosen as follows:
      * if the shape arrays of both inputs have the same size, it is a
        ``Shape`` node attached to the producer of the second input;
      * otherwise it is built from the second input's shape values (requested
        with ``begin=0, end=axes[-1], include_end=True``); when ``axes[-1]``
        is not the last index of the first input's shape, this is
        concatenated with values from the first input's shape (requested with
        ``begin=axes[-1], end=-1, include_begin=False, include_end=True``).

    Finally the first input and the output connections are moved to the
    StridedSlice, the second input of the old node is disconnected, and the
    StridedSlice takes over the old node's name.

    :param graph: graph that is being transformed
    :param match: matched sub-graph; only the ``'op'`` entry is read
    """
    node = match['op']
    name = node.soft_get('name', node.id)
    data_shape = node.in_port(0).data.get_shape()
    like_shape = node.in_port(1).data.get_shape()
    rank = len(data_shape)

    def zero_mask():
        # Every mask (and the constant on port 1) is an int64 vector of length ``rank``.
        return np.zeros(rank, dtype=np.int64)

    begin_mask = zero_mask()
    end_mask = zero_mask()
    for axis in node.axes:
        end_mask[axis] = np.int64(1)

    new_axis_mask = zero_mask()
    shrink_axis_mask = zero_mask()
    ellipsis_mask = zero_mask()

    strided_slice_attrs = {
        'name': 'StridedSlice',
        'begin_mask': begin_mask,
        'end_mask': end_mask,
        'new_axis_mask': new_axis_mask,
        'shrink_axis_mask': shrink_axis_mask,
        'ellipsis_mask': ellipsis_mask,
    }
    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     port_value_dict={1: zero_mask()},
                                     op_attrs=strided_slice_attrs)

    if data_shape.size == like_shape.size:
        # Same size of both shape arrays: the shape of the second input is used as is.
        end_source = Shape(graph, dict(name=name + '/End')).create_node()
        end_source.in_port(0).connect(node.in_port(1).get_source())
    else:
        last_axis = node.axes[-1]
        like_shape_node, like_rank_node = get_shape_and_rank_nodes_by_port(node.in_port(1).get_source())
        end_source = get_shape_values_by_range_idxs(like_shape_node, like_rank_node, 0, last_axis,
                                                    include_end=True)
        if last_axis != data_shape.size - 1:
            # The remaining values come from the first input's own shape.
            data_shape_node, data_rank_node = get_shape_and_rank_nodes_by_port(node.in_port(0).get_source())
            tail_part = get_shape_values_by_range_idxs(data_shape_node, data_rank_node, last_axis, -1,
                                                       include_begin=False, include_end=True)
            end_source = new_shape_node_from_shape_nodes([end_source, tail_part])

    ss.in_port(2).connect(end_source.out_port(0))

    # Put the StridedSlice in place of the original node.
    node.in_port(0).get_connection().set_destination(ss.in_port(0))
    node.in_port(1).disconnect()
    node.out_port(0).get_connection().set_source(ss.out_port(0))

    rename_nodes([(node, name + '/ShouldBeDeleted'), (ss, name)])
Пример #2
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched ``flatten`` node with a ``Reshape`` node.

        The target shape fed to the Reshape (input port 1) is a concatenation
        of three parts: leading dims, one middle dim and trailing dims.

        * When ``end_axis == -1`` and ``axis >= 0`` the parts are constants:
          ``axis`` zeros, a single ``-1`` and an empty array.
        * Otherwise the parts are taken from the input shape at run time:
          values requested for ``[0, axis)``, a ``ReduceProd`` (axis 0,
          ``keep_dims=True``) over the values requested for
          ``[axis, end_axis]``, and the values requested after ``end_axis``.

        The Reshape inherits the name of the Flatten node, which is renamed
        with a ``/ShouldBeDeleted`` suffix.

        :param graph: graph that is being transformed
        :param match: matched sub-graph; only the ``'flatten'`` entry is read
        """
        flatten = match['flatten']
        name = flatten.soft_get('name', flatten.id)

        assert flatten.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
        assert flatten.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

        axis, end_axis = flatten.axis, flatten.end_axis

        constant_dims = end_axis == -1 and axis >= 0
        if not constant_dims:
            # Dynamic case: slice the run-time shape of the input into three ranges.
            rank_node = Rank(graph, {'name': name + '/input_rank'}).create_node()
            flatten.in_port(0).get_source().connect(rank_node.in_port(0))

            shape_node = Shape(graph, {'name': name + '/input_shape'}).create_node()
            flatten.in_port(0).get_source().connect(shape_node.in_port(0))

            leading_dims = get_shape_values_by_range_idxs(shape=shape_node, rank=rank_node, begin=0, end=axis)
            flattened_dims = get_shape_values_by_range_idxs(shape=shape_node, rank=rank_node, begin=axis,
                                                            end=end_axis, include_end=True)
            trailing_dims = get_shape_values_by_range_idxs(shape=shape_node, rank=rank_node, begin=end_axis,
                                                           end=-1, include_begin=False, include_end=True)

            # The flattened range collapses into a single dimension: the product of its values.
            collapsed_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
                                                             {'keep_dims': True})
            flattened_dims.out_port(0).connect(collapsed_dim.in_port(0))
        else:
            # Static case: the whole target shape is expressed with constants.
            leading_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
            collapsed_dim = Const(graph, {'value': int64_array([-1])}).create_node()
            trailing_dims = Const(graph, {'value': int64_array([])}).create_node()

        target_shape = new_shape_node_from_shape_nodes([leading_dims, collapsed_dim, trailing_dims])

        kept_name = flatten.soft_get('name')
        discarded_name = kept_name + '/ShouldBeDeleted'
        reshape = Reshape(graph, {}).create_node()
        # The Reshape takes the Flatten's name so that the name stays present in the graph.
        rename_nodes([(flatten, discarded_name), (reshape, kept_name)])
        reshape.in_port(1).connect(target_shape.out_port(0))

        flatten.out_port(0).get_connection().set_source(reshape.out_port(0))
        flatten.in_port(0).get_connection().set_destination(reshape.in_port(0))
Пример #3
0
 def resolve_minus2(self, shape_node, input_index, reshape_index, dims):
     """Build a node with the shape values from ``input_index`` to the end.

     A ``Shape`` node is attached to the output of ``shape_node`` and passed
     as the ``rank`` argument of ``get_shape_values_by_range_idxs``, which is
     asked for the range ``[input_index, -1]`` with both ends included.

     :param shape_node: node whose output port 0 provides the shape
     :param input_index: index the requested range starts from
     :param reshape_index: current reshape index; returned incremented by 1
     :param dims: returned unchanged
     :return: tuple ``(None, reshape_index + 1, dims, shape_values_node)``
     """
     rank_node_name = shape_node.id + '/RankShapeMXReshapeMinus2'
     shape_of_shape = Shape(shape_node.graph, dict(name=rank_node_name)).create_node()
     shape_of_shape.in_port(0).connect(shape_node.out_port(0))

     remaining_values = get_shape_values_by_range_idxs(shape=shape_node, rank=shape_of_shape,
                                                       begin=input_index, end=-1,
                                                       include_begin=True, include_end=True)

     # The input index is reported back as None; the reshape index advances by one.
     return None, reshape_index + 1, dims, remaining_values
Пример #4
0
    def mxrepeat_decomposition(node: Node):
        """Rewire ``node`` into an ``Unsqueeze -> Tile -> Reshape`` chain.

        The original node is renamed with a ``/to_be_removed`` suffix, its
        input connection is moved to the Unsqueeze and its output connection
        is re-sourced from the Reshape, which takes the original name.

        The node attributes ``axis`` and ``repeats`` are read; all auxiliary
        inputs (unsqueeze axis, tile repeats, reshape dims) are built as graph
        nodes from the rank and shape of the input.

        :param node: node to decompose
        """
        graph = node.graph
        name = node.soft_get('name', node.id)

        # The original name is given to the final Reshape, so release it first.
        rename_node(node, name + '/to_be_removed')

        # ---- Unsqueeze: the axis input is (canonical axis + 1) ----
        rank_node = Rank(graph, {'name': name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(rank_node.in_port(0))

        canonical_axis = get_canonical_axis_index_node(rank_node, node.axis)
        axis_plus_one = create_op_node_with_second_input(graph, Add, int64_array([1]),
                                                         {'name': name + '/Unsqueeze/Axis'},
                                                         input_node=canonical_axis)

        unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
        unsqueeze.in_port(1).connect(axis_plus_one.out_port(0))

        # ---- Tile: repeats array (1, 1, ..., repeats, ..., 1) ----
        # The array is a concatenation of three parts:
        #   leading ones  : indices 0 .. axis
        #   repeats       : index axis + 1
        #   trailing ones : the remaining indices
        const_one = Const(graph, {'name': name + '/Broadcast/One', 'value': int64_array([1])}).create_node()

        leading_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
        leading_ones.in_port(0).connect(const_one.out_port(0))
        leading_ones.in_port(1).connect(axis_plus_one.out_port(0))

        repeats_const = Const(graph, {'name': name + '/repeats',
                                      'value': int64_array([node.repeats])}).create_node()

        trailing_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
        # Second input of the trailing Broadcast: rank - (axis + 1).
        trailing_count = Sub(graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
        trailing_count.in_port(0).connect(rank_node.out_port(0))
        trailing_count.in_port(1).connect(axis_plus_one.out_port(0))
        trailing_ones.in_port(0).connect(const_one.out_port(0))
        trailing_ones.in_port(1).connect(trailing_count.out_port(0))

        tile_repeats = new_shape_node_from_shape_nodes([leading_ones, repeats_const, trailing_ones])
        tile = Tile(graph, {'name': name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # ---- Reshape: (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:]) ----
        # The dim array is again a concatenation of three parts:
        #   leading dims  : input_shape[i] for i before axis
        #   repeated dim  : input_shape[axis] * repeats
        #   trailing dims : input_shape[i] for i after axis
        shape_node = Shape(graph, {'name': name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(shape_node.in_port(0))

        leading_dims = get_shape_values_by_range_idxs(shape_node, rank_node, begin=0, end=node.axis,
                                                      include_begin=True, include_end=False)

        axis_dim = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
                                               {'name': name + '/OriginalDim'}, input_node=shape_node)
        axis_dim.in_port(1).connect(canonical_axis.out_port(0))

        repeated_dim = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
        repeated_dim.in_port(0).connect(axis_dim.out_port(0))
        repeated_dim.in_port(1).connect(repeats_const.out_port(0))

        trailing_dims = get_shape_values_by_range_idxs(shape_node, rank_node, begin=node.axis, end=-1,
                                                       include_begin=False, include_end=True)

        reshape_dims = new_shape_node_from_shape_nodes([leading_dims, repeated_dim, trailing_dims])

        reshape = Reshape(graph, {'name': name}).create_node()
        rename_node(reshape, name)
        reshape.in_port(1).connect(reshape_dims.out_port(0))

        # ---- Final wiring: input -> Unsqueeze -> Tile -> Reshape -> consumers ----
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))