def replace_sub_graph(self, graph: Graph, match: dict):
        """Lower an MXNet-style reshape (``dim`` attribute) to ShapeOf -> Concat -> Reshape.

        ``self.resolve`` turns the entries of ``dim`` into nodes that each produce a piece of the
        target shape. The pieces are concatenated along axis 0 and fed, as the shape input, to a
        regular Reshape that takes over the original node's input and output connections.
        """
        mx_reshape = match['mxreshape']
        dims = mx_reshape.dim

        src_cursor = 0
        dim_cursor = 0
        input_shape = Shape(graph, dict(name=mx_reshape.id + '/ShapeMXReshape')).create_node()
        input_shape.in_port(0).connect(mx_reshape.in_port(0).get_source())

        # resolve() hands back the advanced cursors and the grown list of shape pieces; the guard
        # only matters if a single call consumes more than one ``dim`` entry.
        shape_pieces = []
        for _ in range(len(dims)):
            if dim_cursor < len(dims):
                src_cursor, dim_cursor, shape_pieces = self.resolve(
                    src_cursor, dim_cursor, dims, input_shape, shape_pieces)

        target_shape = Concat(
            input_shape.graph,
            dict(name=input_shape.id + '/ConcatMXReshape_',
                 axis=0,
                 in_ports_count=len(shape_pieces))).create_node()
        for port_idx, piece in enumerate(shape_pieces):
            target_shape.in_port(port_idx).connect(piece.out_port(0))

        new_reshape = Reshape(graph, dict(name=mx_reshape.id + '/Reshape_')).create_node()
        new_reshape.in_port(1).connect(target_shape.out_port(0))
        mx_reshape.in_port(0).get_connection().set_destination(new_reshape.in_port(0))
        mx_reshape.out_port(0).get_connection().set_source(new_reshape.out_port(0))
Esempio n. 2
0
    def append_variances(priors_scale_node: Node, variance: list):
        """Attach prior-box variances to the priors produced by ``priors_scale_node``.

        The priors are flattened to ``[-1, 4]``. The ``variance`` values are broadcast to the
        shape ``concat(priors_shape[-2:-1], [4])``. Both tensors are concatenated along axis 0,
        and the result is reshaped to ``[1, 2, -1]``.

        :param priors_scale_node: node whose output port 0 carries the priors
        :param variance: variance values, stored as a float32 Const
        :return: the final Reshape node (``<name>/3D_priors_wth_variances``)
        """
        graph = priors_scale_node.graph
        prefix = priors_scale_node.name

        priors_shape = Shape(graph, {'name': prefix + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(priors_shape.in_port(0))

        # StridedSlice(shape, [-2], [-1], [1]); per its name this picks the second-to-last dim
        slice_begin = Const(graph, {'value': int64_array([-2])}).create_node()
        slice_end = Const(graph, {'value': int64_array([-1])}).create_node()
        slice_stride = Const(graph, {'value': int64_array([1])}).create_node()
        slice_attrs = {
            'name': prefix + '/get_-2_dim',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }
        second_last_dim = StridedSlice(graph, slice_attrs).create_node()
        for port_idx, producer in enumerate((priors_shape, slice_begin, slice_end, slice_stride)):
            producer.out_port(0).connect(second_last_dim.in_port(port_idx))

        tiling_shape = create_op_node_with_second_input(graph, Concat, int64_array([4]),
                                                        {'name': prefix + '/shape_for_tiling', 'in_ports_count': 2,
                                                         'axis': int64_array(0)},
                                                        second_last_dim)

        # Broadcast the variance values to [second_last_dim, 4]
        variance_const = Const(graph, {'name': prefix + '/variance', 'value': float32_array(variance)}).create_node()
        tiled_variances = Broadcast(graph, {'name': prefix + '/variance_tile'}).create_node()
        variance_const.out_port(0).connect(tiled_variances.in_port(0))
        tiling_shape.out_port(0).connect(tiled_variances.in_port(1))

        # Flatten the priors to rows of 4 values
        flat_shape = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        flat_priors = Reshape(graph, {'name': prefix + '/reshape'}).create_node()
        flat_priors.in_port(0).connect(priors_scale_node.out_port(0))
        flat_priors.in_port(1).connect(flat_shape.out_port(0))

        priors_with_variances = Concat(graph, {'name': prefix + '/priors_concat', 'axis': int64_array(0),
                                               'in_ports_count': 2}).create_node()
        flat_priors.out_port(0).connect(priors_with_variances.in_port(0))
        tiled_variances.out_port(0).connect(priors_with_variances.in_port(1))

        result_shape = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        result = Reshape(graph, {'name': prefix + '/3D_priors_wth_variances'}).create_node()
        priors_with_variances.out_port(0).connect(result.in_port(0))
        result_shape.out_port(0).connect(result.in_port(1))

        return result
Esempio n. 3
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Decompose an MXReshape with ``reverse=True`` into reverse / reshape / reverse steps.

        The input is first reshaped to its own shape reversed (ShapeOf -> Unsqueeze(0) ->
        Reverse(axis=1) -> Squeeze(0)). It is then reshaped with the reversed ``dim``. Finally it
        is reshaped back to the reverse of that intermediate shape. This presumably realises
        MXNet's ``reverse`` flag (special ``dim`` values resolved right to left).

        Nodes whose ``reverse`` attribute is falsy are left untouched.
        """
        mxreshape = match['op']
        if not mxreshape.reverse:
            return

        shape_node = Shape(graph, dict(name=mxreshape.id + '/Shape')).create_node()
        forward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                          dict(name=str(mxreshape.id) + '/ForwardUnsqueeze'))
        forward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/ForwardReverse', axis=1)).create_node()

        forward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                        dict(name=str(mxreshape.id) + '/ForwardSqueeze'))
        reshape_node = Reshape(graph, dict(name=mxreshape.id + '/Reshape')).create_node()
        shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
        mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))

        # Unsqueeze/Squeeze around Reverse(axis=1) reverse the 1D shape vector
        forward_reverse_unsqueeze_node.in_port(0).connect(shape_node.out_port(0))
        forward_reverse_node.in_port(0).connect(forward_reverse_unsqueeze_node.out_port(0))
        forward_reverse_squeeze_node.in_port(0).connect(forward_reverse_node.out_port(0))
        reshape_node.in_port(1).connect(forward_reverse_squeeze_node.out_port(0))

        reversed_dim = int64_array(np.flip(mxreshape.dim, 0))
        # The special codes -2/-3/-4 need MXReshape's own resolution logic; otherwise a plain Reshape
        # with a constant target shape is enough. Only the node that is actually used is created.
        if np.any(np.isin([-2, -3, -4], mxreshape.dim)):
            # NOTE(review): this name duplicates reshape_node's name ('<id>/Reshape'); kept for compatibility.
            reshape_shape_node = MXReshape(graph, dict(name=mxreshape.id + '/Reshape',
                                                       dim=reversed_dim)).create_node()
        else:
            reshape_shape_node = create_op_node_with_second_input(graph, Reshape, reversed_dim,
                                                                  dict(name=str(mxreshape.id) + '/ReshapeShape'))

        reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

        backward_shape_node = Shape(graph, dict(name=mxreshape.id + '/BackwardShape')).create_node()
        backward_reverse_unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                                           dict(name=str(mxreshape.id) + '/BackwardUnsqueeze'))
        backward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/BackwardReverse', axis=1)).create_node()
        backward_reverse_squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                                         dict(name=str(mxreshape.id) + '/BackwardSqueeze'))
        backward_reshape_node = Reshape(graph, dict(name=mxreshape.id + '/BackwardReshape')).create_node()

        # Reshape the intermediate result to the reverse of its own shape
        backward_shape_node.in_port(0).connect(reshape_shape_node.out_port(0))
        backward_reverse_unsqueeze_node.in_port(0).connect(backward_shape_node.out_port(0))
        backward_reverse_node.in_port(0).connect(backward_reverse_unsqueeze_node.out_port(0))
        backward_reverse_squeeze_node.in_port(0).connect(backward_reverse_node.out_port(0))

        backward_reshape_node.in_port(0).connect(reshape_shape_node.out_port(0))
        backward_reshape_node.in_port(1).connect(backward_reverse_squeeze_node.out_port(0))

        mxreshape.out_port(0).get_connection().set_source(backward_reshape_node.out_port(0))
Esempio n. 4
0
    def find_and_replace_pattern(self, graph: Graph):
        """Reshape 1D eltwise inputs so they broadcast along the output's feature axis.

        For every eltwise op with a 4D or 5D output, each input that is 1D and whose length equals
        the output feature dimension (located with ``graph.graph['layout']``) is routed through a
        Reshape. The Reshape target comes from ``shape_for_layout`` with batch/height/width set to 1,
        and depth set to 1 for 5D outputs.
        """
        layout = graph.graph['layout']
        for eltwise_op_node in graph.get_op_nodes(is_eltwise=True):
            out_shape = eltwise_op_node.out_port().data.get_shape()
            if not 4 <= len(out_shape) <= 5:
                continue
            out_features = out_shape[get_features_dim(layout, len(out_shape))]
            for port, node in eltwise_op_node.in_nodes().items():
                # A 1D input can never match the rank of a 4D/5D output, so a separate rank
                # inequality check is redundant.
                if len(node.shape) == 1 and out_features == node.shape[0]:
                    new_shape = shape_for_layout(layout, batch=1, features=out_features, height=1, width=1,
                                                 depth=1 if len(out_shape) == 5 else None)
                    dim_const = Const(graph, {'value': new_shape, 'name': node.id + '/Dim'}).create_node()
                    reshape_op = Reshape(graph, attrs={'dim': new_shape, 'name': node.id + '/Broadcast'}).create_node()

                    # Re-route only this edge. The in-port's connection covers just this eltwise
                    # input, so other consumers of a shared producer stay connected. The actual source
                    # port is used instead of assuming out_port(0).
                    eltwise_op_node.in_port(port).get_connection().set_destination(reshape_op.in_port(0))
                    reshape_op.in_port(1).connect(dim_const.out_port(0))

                    reshape_op.out_port(0).connect(eltwise_op_node.in_port(port))
Esempio n. 5
0
    def convert_fft_to_dft(self, graph: Graph, mx_fft: Node):
        """Replace an MXNet FFT node with an Unsqueeze -> Pad -> DFT -> Reshape sub-graph.

        Steps:
        - The input gets a trailing axis via Unsqueeze(-1).
        - That axis is padded by one element at its end (constant mode), turning each value into
          a real/imaginary pair.
        - DFT runs with axes ``[-1]``.
        - The result is reshaped with a target of ``rank(input) - 1`` zeros followed by ``[-1, 2]``.

        The Reshape takes over the FFT node's name.
        """
        fft_name = mx_fft.soft_get('name', mx_fft.id)

        unsqueeze = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([-1])},
                                                {'name': fft_name + '/Unsqueeze'})
        input_rank = Rank(graph, {'name': fft_name + '/Rank'}).create_node()

        # The FFT's input now feeds both the Unsqueeze and the Rank node
        fft_input = mx_fft.in_port(0).get_connection()
        fft_input.set_destination(unsqueeze.in_port(0))
        fft_input.get_source().connect(input_rank.in_port(0))

        # pads_begin: zeros of length rank + 1 (the rank of the unsqueezed tensor)
        rank_plus_one = create_op_with_const_inputs(graph, Add, {1: int64_array(1)},
                                                    {'name': fft_name + '/Add'},
                                                    input_node=input_rank)
        pads_begin = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                 {'name': fft_name + '/Pad_broadcast'})
        rank_plus_one.out_port(0).connect(pads_begin.in_port(1))

        # pads_end: the same zeros with a 1 written at index `rank`, i.e. on the new trailing axis
        pads_end = create_op_with_const_inputs(graph, ScatterUpdate, {2: int64_array(1), 3: int64_array(0)},
                                               {'name': fft_name + '/ScatterUpdate'})
        pads_begin.out_port(0).connect(pads_end.in_port(0))
        input_rank.out_port(0).connect(pads_end.in_port(1))

        padded = Pad(graph, {'name': fft_name + '/Pad', 'mode': 'constant'}).create_node(
            [unsqueeze, pads_begin, pads_end])

        dft = create_op_with_const_inputs(graph, DFT, {1: int64_array([-1])},
                                          {'name': fft_name + '/DFT', 'in_ports_count': 2},
                                          input_node=padded)

        # Output shape: (rank - 1) zeros followed by [-1, 2]
        rank_minus_one = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)},
                                                     {'name': fft_name + '/Sub'})
        input_rank.out_port(0).connect(rank_minus_one.in_port(0))
        leading_zeros = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                    {'name': fft_name + '/Reshape_broadcast'})
        rank_minus_one.out_port(0).connect(leading_zeros.in_port(1))
        out_shape = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1, 2])},
                                                {'name': fft_name + '/New_shape', 'in_ports_count': 2, 'axis': 0},
                                                input_node=leading_zeros)

        result = Reshape(graph, {}).create_node([dft, out_shape])

        mx_fft.out_port(0).get_connection().set_source(result.out_port(0))
        rename_nodes([(mx_fft, fft_name + '/to_be_removed'), (result, fft_name)])
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace a Flatten (single ``axis`` attribute) with a 2D Reshape.

        - ``axis == 0`` uses the constant target ``[1, -1]``.
        - ``axis == 1`` uses ``[0, -1]``.
        - Any other positive ``axis`` collapses the leading ``axis`` dims into one value followed
          by -1.
        - A negative ``axis`` puts -1 first, followed by the product of the trailing ``|axis|`` dims.
        """
        flatten = match['flatten']
        name = flatten.soft_get('name', flatten.id)

        assert flatten.has_valid(
            'axis'
        ), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(
            name)
        axis = flatten.axis

        reshape = Reshape(graph, {'name': flatten.id + '/Reshape'}).create_node()

        if axis in (0, 1):
            target = int64_array([1, -1]) if axis == 0 else int64_array([0, -1])
            dim = Const(graph, {'value': target, 'name': reshape.name + '/shape'}).create_node()
        else:
            input_shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

            # Indices of the dims to multiply: leading ones for axis > 0, trailing ones for axis < 0
            indices = list(range(axis) if axis > 0 else range(axis, 0))

            collapsed_dims = node_to_get_shape_value_of_indices(input_shape, indices)
            collapsed_prod = create_op_node_with_second_input(
                graph, ReduceProd, int64_array([0]),
                {'name': name + '/first_dims', 'keep_dims': True})
            remainder = Const(graph, {'value': int64_array([-1]), 'name': name + '/second_dims'}).create_node()

            flatten.in_port(0).get_source().connect(input_shape.in_port(0))
            collapsed_dims.out_port(0).connect(collapsed_prod.in_port(0))

            pieces = [collapsed_prod, remainder]
            if axis < 0:
                pieces.reverse()
            dim = new_shape_node_from_shape_nodes(pieces)

        reshape.in_port(1).connect(dim.out_port(0))

        flatten.out_port(0).get_connection().set_source(reshape.out_port(0))
        flatten.in_port(0).get_connection().set_destination(reshape.in_port(0))
Esempio n. 7
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace a Flatten with ``axis``/``end_axis`` attributes by an equivalent Reshape.

        Dimensions ``axis`` through ``end_axis`` (inclusive) are collapsed into one dimension. The
        dimensions before and after that range are kept. The new Reshape takes over the Flatten's
        name and the Flatten is renamed for removal.
        """
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
        assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

        axis = node.axis
        end_axis = node.end_axis

        if end_axis == -1 and axis >= 0:
            # Static target shape: `axis` zeros (presumably "copy the input dim" under Reshape's
            # special-zero semantics) followed by -1, which absorbs all remaining dims.
            begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
            middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
            end_dims = Const(graph, {'value': int64_array([])}).create_node()
        else:
            # General case: split the runtime shape into the parts before, within and after
            # [axis, end_axis], and multiply the middle part into a single dimension.
            rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
            node.in_port(0).get_source().connect(rank.in_port(0))

            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            begin_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=0, end=axis)
            middle_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=axis, end=end_axis, include_end=True)
            end_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=end_axis, end=-1, include_begin=False, include_end=True)

            middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
            middle_dims.out_port(0).connect(middle_dim.in_port(0))

        dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

        # The Reshape keeps the original name so consumers see no renaming. `name` falls back to the
        # node id, whereas soft_get('name') without a default would presumably yield a placeholder.
        reshape_node = Reshape(graph, {}).create_node()
        rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_node, name)])
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
Esempio n. 8
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace an arange_like node with a Range-based sub-graph.

        The node's data input is used only for its shape: its connection is moved to a ShapeOf
        node. The Range stop value is one of:
        - the size of dimension ``axis``, when ``axis`` is set;
        - the total number of input elements otherwise. In that case the 1D Range result is
          reshaped back to the input shape.

        ``repeat``, ``step`` and ``start`` are handled by adjusting Range's stop input and by
        post-processing its output with Reshape/Tile/Slice.

        Raises:
            Error: if ``repeat`` is less than 1.
        """
        node = match['op']
        name = node.soft_get('name', node.id)
        axis = node.axis
        input_shape_node = Shape(graph, {
            'name': name + '/ShapeOf'
        }).create_node()
        # Range(start, <stop, connected below>, step)
        range_node = create_op_with_const_inputs(graph, Range, {
            0: mo_array(node.start),
            2: mo_array(node.step)
        }, {'name': name + '/Range'})
        # Only the shape of the data input is needed, so its connection moves to ShapeOf
        node.in_port(0).get_connection().set_destination(
            input_shape_node.in_port(0))

        if axis is not None:
            '''
            Replace arange_like op to subgraph:
            Shape - Gather - Range
            '''
            # stop = input_shape[[axis]]: the element count along `axis`
            gather_node = create_op_with_const_inputs(graph, Gather, {
                1: int64_array([axis]),
                2: int64_array(0)
            }, {'name': name + '/Gather'})
            input_shape_node.out_port(0).connect(gather_node.in_port(0))
            gather_node.out_port(0).connect(range_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                range_node.out_port(0))
            rename_nodes([(node, name + '/ShouldBeDeleted'),
                          (range_node, name)])
        else:
            r'''
            Replace arange_like op to subgraph:
                    |
                 ShapeOf ----------- | 
                    |                |
                 ReduceProd          |
                    |                |
                  Range              |
                    |                |
                 Reshape ----------- | 
                    |
            '''

            # stop = prod(input_shape) (keep_dims=True): the total element count
            flattened_shape_node = create_op_with_const_inputs(
                graph, ReduceProd, {1: int64_array([0])}, {
                    'name': input_shape_node.name + '/ReduceProd',
                    'keep_dims': True
                })
            reshape_backward_node = Reshape(graph, {
                'name': name + '/Reshape_backward'
            }).create_node()

            input_shape_node.out_port(0).connect(
                flattened_shape_node.in_port(0))
            flattened_shape_node.out_port(0).connect(range_node.in_port(1))
            # The 1D Range output is reshaped back to the full input shape
            range_node.out_port(0).connect(reshape_backward_node.in_port(0))
            input_shape_node.out_port(0).connect(
                reshape_backward_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                reshape_backward_node.out_port(0))
            rename_nodes([(node, name + '/ShouldBeDeleted'),
                          (reshape_backward_node, name)])

        if node.repeat != 1:
            r"""
            First, we generate the correct stop value for Range like new_stop_value = stop_value // repeat + 1.
            Then repeats each value of the interval using Tile. After that we can get a longer interval
            so we reduce it with Slice.
            
            Sub-graph after Range node will be look like
            
            Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
            
            """

            if node.repeat < 1:
                raise Error(
                    "Unexpected value {} of the attribute 'repeat' for the node {}"
                    .format(node.repeat, name))

            # New stop: Cast(stop, int64) / repeat + 1
            div_node = create_op_with_const_inputs(
                graph, Div, {1: int64_array([node.repeat])},
                {'name': name + '/Divide'})
            add_node = create_op_with_const_inputs(
                graph, Add, {1: int64_array([1])},
                {'name': div_node.name + '/Add'})
            cast_node = Cast(graph, {
                'name': name + '/ConvertToI64',
                'dst_type': np.int64
            }).create_node()

            cast_node.out_port(0).connect(div_node.in_port(0))
            div_node.out_port(0).connect(add_node.in_port(0))
            # The previous stop producer now feeds the Cast; the Add result becomes Range's stop
            range_node.in_port(1).get_connection().set_destination(
                cast_node.in_port(0))
            add_node.out_port(0).connect(range_node.in_port(1))

            tile_forward_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array([-1, 1])},
                {'name': range_node.name + '/ForwardReshape'})
            tile = create_op_with_const_inputs(
                graph, Tile, {1: int64_array([1, node.repeat])},
                {'name': tile_forward_reshape.name + '/Tile'})
            tile_backward_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array([-1])},
                {'name': tile.name + '/BackwardReshape'})
            # Slice(begin=[0], end=<original count>, axes=[0], step=[1]) trims the tiled sequence
            slice_node = create_op_with_const_inputs(
                graph, Slice, {
                    1: int64_array([0]),
                    3: int64_array([0]),
                    4: int64_array([1])
                }, {'name': tile_backward_reshape.name + '/Slice'})

            tile_forward_reshape.out_port(0).connect(tile.in_port(0))
            tile.out_port(0).connect(tile_backward_reshape.in_port(0))
            tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
            # Slice end = the int64-cast original count (the Cast output also feeds the Div)
            slice_node.in_port(2).connect(div_node.in_port(0).get_source())

            # Range's former consumers now read from the Slice; Range feeds the Tile chain instead
            range_node.out_port(0).get_connection().set_source(
                slice_node.out_port(0))
            range_node.out_port(0).connect(tile_forward_reshape.in_port(0))

            if axis is not None:
                rename_nodes([(range_node, name + '/Range'),
                              (slice_node, name)])

        # MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
        # we have to correct the stop value for the Range node if step != 1 or start != 0
        if node.step != 1:
            # If step attribute is not integer, we will generate an interval with a larger size and then reduce it
            # using Slice
            true_elements_count_port = range_node.in_port(1).get_source()
            # ceil(step) for positive steps, floor(step) for negative ones, so |mul_value| >= |step|
            mul_value = np.ceil(node.step) if node.step > 0 else np.floor(
                node.step)
            # NOTE(review): mul_value is already integral here, so the extra np.ceil below is a no-op
            stop_value = create_op_with_const_inputs(
                graph,
                Mul,
                port_value_dict={1: mo_array(np.ceil(mul_value))},
                op_attrs={'name': range_node.name + '/Stop'})
            range_node.in_port(1).get_connection().insert_node(stop_value)

            # Slice(begin=[0], end=<count before scaling>, axes=[0], step=[1]) on Range's output
            slice_range_values = create_op_with_const_inputs(
                graph, Slice, {
                    1: int64_array([0]),
                    3: int64_array([0]),
                    4: int64_array([1])
                }, {'name': range_node.name + '/Slice'})
            slice_range_values.in_port(2).connect(true_elements_count_port)
            range_node.out_port(0).get_connection().insert_node(
                slice_range_values)

            # With repeat != 1 the final output is the repeat Slice, so the name stays there
            if axis is not None and node.repeat == 1:
                rename_nodes([(range_node, name + '/Range'),
                              (slice_range_values, name)])

        if node.start != 0:
            # Shift the stop value by `start`
            correct_stop_value = create_op_with_const_inputs(
                graph,
                Add,
                port_value_dict={1: mo_array(node.start)},
                op_attrs={'name': range_node.name + '/Correct_Stop'})
            range_node.in_port(1).get_connection().insert_node(
                correct_stop_value)

        # Range node supports only scalar inputs
        squeeze_node = create_op_with_const_inputs(
            graph,
            Squeeze,
            port_value_dict={1: int64_array(0)},
            op_attrs={"name": range_node.name + '/Stop/Squeeze'})
        range_node.in_port(1).get_connection().insert_node(squeeze_node)
Esempio n. 9
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """Decompose GroupNorm into Reshape -> MVN -> Reshape -> Mul(gamma) -> Add(beta).

        The input is reshaped to ``[batch * num_groups, features / num_groups, *spatial]``, where the
        shape pieces are computed in floating point and cast to int64. MVN normalizes over axes
        ``1 .. rank - 1``. The result is reshaped back to the input shape, multiplied by gamma and
        shifted by beta.

        Gamma and beta must be constants. Their values are reshaped in place to
        ``[1, -1, 1, ...]`` (rank of the input) so they broadcast over axis 1.
        """
        group_norm = match['op']
        input_rank = len(group_norm.in_port(0).data.get_shape())

        # Runtime shape of the GroupNorm input, plus a float copy for the division by num_groups
        input_shape = Shape(graph, {'name': group_norm.name + '/Shape'}).create_node()
        input_shape.in_port(0).connect(group_norm.in_port(0).get_source())

        float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        input_shape_float = Cast(graph, {'name': input_shape.name + '/to_float',
                                         'dst_type': float_type}).create_node()
        input_shape.out_port(0).connect(input_shape_float.in_port(0))

        batch = node_to_get_batch_value(input_shape_float)
        features = node_to_get_features_dimension_value(input_shape_float)
        spatial_int = node_to_get_spatial_dimensions_value(input_shape)
        spatial = Cast(graph, {'name': spatial_int.name + '/to_float',
                               'dst_type': float_type}).create_node()
        spatial_int.out_port(0).connect(spatial.in_port(0))

        num_groups = Const(graph, {'value': int64_array([group_norm.num_groups]),
                                   'name': group_norm.name + '/GroupSize'}).create_node()

        # features / num_groups, computed as features * (1 / num_groups)
        inv_num_groups = Const(graph, {'value': np.array([1.0 / group_norm.num_groups]),
                                       'name': group_norm.name + '/ReciprocalGroupSize'}).create_node()

        per_group_features = Mul(graph, {}).create_node()
        per_group_features.in_port(0).connect(features.out_port(0))
        per_group_features.in_port(1).connect(inv_num_groups.out_port(0))

        batch_times_groups = Mul(graph, {}).create_node()
        batch_times_groups.in_port(0).connect(batch.out_port(0))
        batch_times_groups.in_port(1).connect(num_groups.out_port(0))

        # [batch * groups, features / groups, *spatial] as int64
        grouped_shape_float = new_shape_node_from_shape_nodes([batch_times_groups, per_group_features, spatial])
        grouped_shape = Cast(graph, {'name': grouped_shape_float.name + '/to_int64',
                                     'dst_type': np.int64}).create_node()
        grouped_shape_float.out_port(0).connect(grouped_shape.in_port(0))

        to_groups = Reshape(graph, {}).create_node()
        group_norm.in_port(0).get_connection().set_destination(to_groups.in_port(0))
        to_groups.in_port(1).connect(grouped_shape.out_port(0))

        # Gamma/beta [C] -> [1, C, 1, ...] so they broadcast against the restored input
        affine_shape = np.ones([input_rank], dtype=np.int64)
        affine_shape[1] = -1

        gamma_src = group_norm.in_port(1).get_source()
        beta_src = group_norm.in_port(2).get_source()
        gamma = gamma_src.data.get_value()
        beta = beta_src.data.get_value()
        assert gamma is not None, 'The gamma should be constant'
        assert beta is not None, 'The beta should be constant'
        gamma_src.data.set_value(np.reshape(gamma, affine_shape))
        beta_src.data.set_value(np.reshape(beta, affine_shape))

        mvn = MVN(graph, {'name': group_norm.name + '/MVN',
                          'normalize_variance': 1,
                          'eps': group_norm.eps,
                          'eps_mode': 'inside_sqrt'}).create_node()
        mvn.in_port(0).connect(to_groups.out_port(0))

        # MVN axes = Range(1, rank(mvn_input), 1)
        _, mvn_rank = get_shape_and_rank_nodes_by_port(mvn.in_port(0).get_connection().get_source(),
                                                       return_as_a_scalar=True)
        mvn_axes = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                               {'name': group_norm.name + '/Range', 'output_type': np.int64})
        mvn.in_port(1).connect(mvn_axes.out_port(0))
        mvn_axes.in_port(1).connect(mvn_rank.out_port(0))

        # Back to the original input shape before applying gamma and beta
        restored = Reshape(graph, {}).create_node()
        restored.in_port(0).connect(mvn.out_port(0))
        restored.in_port(1).connect(input_shape.out_port(0))

        scaled = Mul(graph, {'name': mvn.name + '/Mul'}).create_node()
        scaled.in_port(0).connect(restored.out_port(0))
        group_norm.in_port(1).get_connection().set_destination(scaled.in_port(1))

        shifted = Add(graph, {'name': scaled.name + '/Add'}).create_node()
        shifted.in_port(0).connect(scaled.out_port(0))
        group_norm.in_port(2).get_connection().set_destination(shifted.in_port(1))

        group_norm.out_port(0).get_connection().set_source(shifted.out_port(0))
Esempio n. 10
0
    def mxrepeat_decomposition(node: Node):
        """Decompose an MXNet ``repeat`` node into Unsqueeze -> Tile -> Reshape.

        1. A new axis is inserted right after ``axis`` (canonicalized against the input rank).
        2. That new axis is tiled ``repeats`` times.
        3. The result is reshaped to ``shape[:axis] + [shape[axis] * repeats] + shape[axis + 1:]``.

        The final Reshape takes over the original node's name.
        """
        graph = node.graph
        name = node.soft_get('name', node.id)

        rename_node(node, name + '/to_be_removed')

        # --- Unsqueeze at canonical(axis) + 1 ---
        rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(rank.in_port(0))

        canonical_axis = get_canonical_axis_index_node(rank, node.axis)
        new_axis = create_op_node_with_second_input(graph, Add, int64_array([1]),
                                                    {'name': name + '/Unsqueeze/Axis'},
                                                    input_node=canonical_axis)

        unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
        unsqueeze.in_port(1).connect(new_axis.out_port(0))

        # --- Tile repeats over the unsqueezed tensor (rank + 1 entries) ---
        # parts:       |      first      |  repeats |    second    |
        # i:           | 0, 1, ..., axis | axis + 1 | ..., rank    |
        # tile_array:  | 1, 1, ...,  1   | repeats  | ...,   1     |
        one = Const(graph, {'name': name + '/Broadcast/One', 'value': int64_array([1])}).create_node()
        leading_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
        leading_ones.in_port(0).connect(one.out_port(0))
        leading_ones.in_port(1).connect(new_axis.out_port(0))

        repeats = Const(graph, {'name': name + '/repeats', 'value': int64_array([node.repeats])}).create_node()

        # rank - (axis + 1) trailing ones
        trailing_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
        trailing_count = Sub(graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
        trailing_count.in_port(0).connect(rank.out_port(0))
        trailing_count.in_port(1).connect(new_axis.out_port(0))
        trailing_ones.in_port(0).connect(one.out_port(0))
        trailing_ones.in_port(1).connect(trailing_count.out_port(0))

        tile_repeats = new_shape_node_from_shape_nodes([leading_ones, repeats, trailing_ones])
        tile = Tile(graph, {'name': name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # --- Reshape target ---
        # parts:       |    first    |             repeated           |   second   |
        # i:           | 0, 1, ...   |               axis             | ..., rank  |
        # dim_array:   | inp_sh[i]   | input_shape[axis] * repeats    | inp_sh[i]  |
        shape = Shape(graph, {'name': name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(shape.in_port(0))

        head_dims = get_shape_values_by_range_idxs(shape, rank,
                                                   begin=0, end=node.axis,
                                                   include_begin=True, include_end=False)

        axis_dim = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
                                               {'name': name + '/OriginalDim'},
                                               input_node=shape)
        axis_dim.in_port(1).connect(canonical_axis.out_port(0))

        repeated_dim = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
        repeated_dim.in_port(0).connect(axis_dim.out_port(0))
        repeated_dim.in_port(1).connect(repeats.out_port(0))

        tail_dims = get_shape_values_by_range_idxs(shape, rank,
                                                   begin=node.axis, end=-1,
                                                   include_begin=False, include_end=True)

        target_shape = new_shape_node_from_shape_nodes([head_dims, repeated_dim, tail_dims])

        reshape = Reshape(graph, {'name': name}).create_node()
        rename_node(reshape, name)
        reshape.in_port(1).connect(target_shape.out_port(0))

        # --- Data path: input -> Unsqueeze -> Tile -> Reshape -> former consumers ---
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))
Esempio n. 11
0
    def replace_pattern(graph: Graph, match: dict):
        """Wrap a Kaldi pooling node between two Reshape operations.

        The data feeding the pooling node is reshaped to
        ``[in_shape[0], pool_stride, 1, in_shape[1] / pool_stride]`` before
        pooling, and the pooling result is reshaped with pattern ``[0, -1]``
        afterwards. The graph is modified in place.

        :param graph: graph that contains the matched node
        :param match: pattern match; ``match['pool']`` is the pooling node
        """
        pool = match['pool']
        pool_name = pool.soft_get('name', pool.id)

        # Without an explicit pool_step, use a stride equal to the last
        # window dimension.
        if pool.pool_step is None:
            last_window = pool.window[-1]
            pool.stride = int64_array([1, 1, last_window, last_window])

        # ---- Reshape in front of the pooling node ----
        # Shape arithmetic is done in float32 (even if data_type=FP16) and the
        # resulting pattern is cast back to int64 before it reaches Reshape.
        input_shape = Shape(graph, {'name': pool_name + '/Shape'}).create_node()
        shape_as_float = Cast(graph, {'name': pool_name + '/to_float', 'dst_type': np.float32}).create_node()
        input_shape.in_port(0).connect(pool.in_port(0).get_source())
        shape_as_float.in_port(0).connect(input_shape.out_port(0))

        batch_dim = node_to_get_shape_value_of_indices(shape_as_float, [0])
        height_dim = node_to_get_shape_value_of_indices(shape_as_float, [1])

        height_per_stride = create_op_with_const_inputs(
            graph, Div,
            {1: float32_array([pool.pool_stride])},
            {'name': pool_name + '/div_stride_h'})
        height_per_stride.in_port(0).connect(height_dim.out_port(0))

        # Ports 1 and 2 hold constants (pool_stride and 1); ports 0 and 3 are
        # fed with N and H / pool_stride below.
        all_dims = create_op_with_const_inputs(
            graph, Concat,
            {1: float32_array([pool.pool_stride]), 2: float32_array([1])},
            dict(name=pool_name + '/concat_all_dims', in_ports_count=4, axis=0))
        all_dims.in_port(0).connect(batch_dim.out_port(0))
        all_dims.in_port(3).connect(height_per_stride.out_port(0))

        int_pattern = Cast(graph, {'name': pool_name + '/to_int', 'dst_type': np.int64}).create_node()
        all_dims.out_port(0).connect(int_pattern.in_port(0))

        reshape_before = Reshape(graph, {'name': pool_name + '/reshape_in'}).create_node()
        reshape_before.in_port(1).connect(int_pattern.out_port(0))

        # ---- Reshape after the pooling node: back to a rank-2 pattern ----
        reshape_after = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]), {'name': pool_name + '/reshape_out'})

        # Splice reshape_before between the original producer and the pooling node.
        producer = pool.in_port(0).get_source()
        pool.in_port(0).get_connection().set_source(reshape_before.out_port(0))
        reshape_before.in_port(0).connect(producer)

        # Hand the pooling node's consumers over to reshape_after, then feed
        # reshape_after from the pooling node.
        pool.out_port(0).get_connection().set_source(reshape_after.out_port(0))
        pool.out_port(0).connect(reshape_after.in_port(0))
    def replace_pattern(graph: Graph, match: dict):
        """Surround a Kaldi Convolution node with Reshape operations.

        Only dimensions 0 (N) and 1 (H) of the convolution input are read;
        the input is presumably 2D [N, H] (TODO confirm). Before the
        convolution it is reshaped to 4D:

          * if the node has a valid ``patch_stride``:
                [N, H / (patch_stride * kernel[1]), kernel[1], patch_stride]
          * otherwise (old models, kept to avoid failures on GNA - should be
            removed as soon as GNA is changed):
                [N, H / (height_in * kernel[1]), height_in, kernel[1]]
            and, when ``node.channel_dims == 1``, a Transpose with order
            [0, 3, 1, 2] (NHWC -> NCHW) is inserted after that Reshape.

        After the convolution, the output is reshaped with pattern [0, -1].

        :param graph: graph that contains the matched node; modified in place
        :param match: pattern match; ``match['conv']`` is the Convolution node
        """
        node = match['conv']
        node_name = node.soft_get('name', node.id)

        # Positions inside the 4D reshape pattern. Position 1 always receives
        # H / (frame_height * kernel[1]); position channel_dim receives
        # kernel[1]; position sp_dim_2 receives frame_height.
        sp_dim_1 = 1
        if node.has_valid('patch_stride'):
            channel_dim = 2
            sp_dim_2 = 3
            frame_height = node.patch_stride
        else:
            channel_dim = 3
            sp_dim_2 = 2
            frame_height = node.height_in

        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
        i_shape.in_port(0).connect(node.in_port(0).get_source())

        N = node_to_get_shape_value_of_indices(i_shape, [0])
        H = node_to_get_shape_value_of_indices(i_shape, [1])

        div = create_op_with_const_inputs(
            graph, Div, {1: int64_array([frame_height * node.kernel[1]])},
            {'name': node_name + '/div_stride_h'})
        div.in_port(0).connect(H.out_port(0))

        # Constant ports: sp_dim_2 -> frame_height, channel_dim -> kernel[1].
        # Port 0 gets N and port sp_dim_1 gets the division result below.
        concat = create_op_with_const_inputs(
            graph, Concat, {
                sp_dim_2: int64_array([frame_height]),
                channel_dim: int64_array([node.kernel[1]])
            }, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })
        concat.in_port(0).connect(N.out_port(0))
        concat.in_port(sp_dim_1).connect(div.out_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(concat.out_port(0))

        # change layout from NHWC to NCHW
        # should be replaced by common Permute logic in future
        transpose = None
        if channel_dim == 3 and node.channel_dims == 1:
            # Use node_name (soft_get with id fallback) like every other new
            # node here, so a node without a 'name' attribute does not fail.
            transpose = create_op_node_with_second_input(
                graph, Transpose, int64_array([0, 3, 1, 2]),
                {'name': node_name + '/Transpose'}, reshape_in)

        # create Reshape after Convolution
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node: the convolution now reads from the
        # Transpose (if created) or directly from reshape_in
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(
            transpose.out_port(0) if transpose is not None else reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))