Exemple #1
0
    def find_and_replace_pattern(self, graph: Graph):
        """Replace every ``MVNCaffe`` node with an ``MVN`` node whose axes are computed in-graph.

        The axes are ``Range(start_axis, Rank(input), 1)``. ``start_axis`` is 1 when the node's
        ``across_channels`` attribute equals 1, and 2 otherwise. ``eps`` and ``normalize_variance``
        are copied from the node, defaulting to 1e-9 and 1. The new MVN takes over the original
        node's input edge, its consumers and its name; the original node is then removed.
        """
        for node in graph.get_op_nodes(op='MVNCaffe'):
            node_name = node.soft_get('name', node.id)

            # A missing 'across_channels' attribute is treated as 0 (start_axis = 2) instead of failing on lookup.
            start_axis = 1 if node.soft_get('across_channels', 0) == 1 else 2

            rank = Rank(graph, {'name': node_name + '/Rank'}).create_node()

            # create range of axes based on `start_axis` and rank of input
            rng = create_op_with_const_inputs(graph, Range, {0: int64_array(start_axis), 2: int64_array(1)},
                                              {'name': node_name + '/Range', 'output_type': np.int64})
            rng.in_port(1).connect(rank.out_port(0))

            new_mvn = MVN(graph, {'eps': node.soft_get('eps', 1e-9), 'eps_mode': 'inside_sqrt',
                                  'normalize_variance': node.soft_get('normalize_variance', 1)}).create_node()

            # Move the existing data edge rather than re-creating it from the producer *node*.
            # NOTE(review): create_node() with a bare Node presumably wires that node's output port 0,
            # which is wrong when the data comes from another output of a multi-output producer.
            node.in_port(0).get_connection().set_destination(new_mvn.in_port(0))
            rng.out_port(0).connect(new_mvn.in_port(1))

            # Rank is computed from the same tensor that feeds the MVN data input.
            new_mvn.in_port(0).get_connection().add_destination(rank.in_port(0))

            node.out_port(0).get_connection().set_source(new_mvn.out_port(0))
            rename_nodes([(node, node_name + '/tbd'), (new_mvn, node_name)])

            graph.remove_node(node.id)
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """Turn the reduction axes of the matched ``reduce`` node into an explicit second input.

        The node is rewritten only when exactly one of its input ports is connected.
        - If the node has a valid ``axis`` attribute, that value becomes a Const on input port 1
          and the attribute is deleted.
        - Otherwise the axes are ``Range(0, Rank(data), 1)``, i.e. every dimension of the data input.
        """
        node = match['reduce']
        connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        if len(connected_ports) != 1:
            return

        node_name = node.soft_get('name', node.id)

        if node.has_valid('axis'):
            axis_const = Const(graph, {'name': node_name + '/axis', 'value': node.axis}).create_node()
            node.add_input_port(1, skip_if_exist=True)
            axis_const.out_port(0).connect(node.in_port(1))
            del graph.node[node.id]['axis']
            return

        # No valid 'axis' (absent, or rejected by has_valid, e.g. None): default to reducing over all the
        # dimensions of the input tensor.
        all_axes = create_op_with_const_inputs(graph, Range, {0: int64_array(0), 2: int64_array(1)},
                                               dict(name=node_name + '/axes'))
        input_rank = Rank(graph, dict(name=node_name + '/range_end')).create_node()
        node.in_port(0).get_connection().get_source().connect(input_rank.in_port(0))
        input_rank.out_port(0).connect(all_axes.in_port(1))

        node.add_input_port(1, skip_if_exist=True)
        all_axes.out_port(0).connect(node.in_port(1))
Exemple #3
0
    def find_and_replace_pattern(self, graph: Graph):
        """Replace ATen ``embedding_bag`` nodes with EmbeddingBagPackedSum or EmbeddingBagOffsetsSum.

        Only ``mode == 0`` is accepted; the assertion message calls it "sum". The replacement node
        takes the original name and consumers, and the original node is renamed with a '/TBR' suffix.

        Input routing:
          * Input 2 is missing or disconnected -> EmbeddingBagPackedSum.
            Inputs 0 and 1 go to ports 0 and 1. A connected input 3 goes to port 2.
          * Otherwise -> EmbeddingBagOffsetsSum.
            Inputs 0, 1 and 2 go to ports 0, 1 and 2. A connected input 3 goes to port 4
            (per_sample_weights). In that case the table arriving on port 0 is extended with one
            extra all-zero row (Concat along axis 0). Port 3 receives the table's original first
            dimension, i.e. the index of that appended zero row.
        """
        for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
            assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                               'mode is supported for node {}.'.format(node.id)
            node_name = node.soft_get('name', node.id)
            # free the original name so the replacement node can take it
            rename_node(node, node_name + '/TBR')
            is_packed = False
            if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
                is_packed = True
                embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
            else:
                embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
                node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
            rename_node(embedding_bag, node_name)
            node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
            node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
            node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
            if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
                if is_packed:
                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
                else:
                    # connect per_sample_weights
                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

                    # Shape/Rank of the original table (wired to the Concat's input 0 source further below)
                    weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()

                    weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
                    last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
                    weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

                    weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

                    # zero row: Broadcast([0], [last_dim]) then Unsqueeze(axis=0)
                    # NOTE(review): the zero row is built from an int64 constant; if the table is floating point,
                    # confirm the Concat input types are reconciled somewhere downstream.
                    zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
                                                                {'name': node_name + '/Broadcast'})
                    zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

                    default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                                          {'name': node_name + '/Unsqueeze'})
                    default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

                    # expand embedding table with zeros
                    weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                                    'name': node_name + '/Concat'}).create_node()
                    # move the table edge onto the Concat, then let Shape/Rank read the same (unexpanded) table
                    embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
                    weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
                    weights_concat.out_port(0).connect(embedding_bag.in_port(0))

                    # point default index to expanded part of embedding table
                    weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
Exemple #4
0
    def convert_ifft_to_dft(self, graph: Graph, mx_fft: Node):
        """Replace an inverse MXNet FFT node with Reshape -> IDFT -> Split -> Squeeze.

        The data is reshaped with target ``[0] * (rank - 1) + [-1, 2]`` and IDFT runs on axis -1.
        The IDFT result is split into two halves along axis -1, and the Squeeze built on the Split
        produces the output. The Squeeze takes over the original node's name and consumers.
        """
        fft_name = mx_fft.soft_get('name', mx_fft.id)

        # Reshape target: (rank - 1) zeros followed by [-1, 2].
        # The zeros presumably mean "copy input dim" via Reshape's special_zero — confirm.
        input_rank = Rank(graph, {'name': fft_name + '/rank'}).create_node()
        rank_minus_one = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)},
                                                     {'name': fft_name + '/Sub'})
        input_rank.out_port(0).connect(rank_minus_one.in_port(0))
        leading_zeros = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                    {'name': fft_name + '/broadcast'})
        rank_minus_one.out_port(0).connect(leading_zeros.in_port(1))
        target_shape = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1, 2])},
                                                   {'name': fft_name + '/new_shape', 'in_ports_count': 2, 'axis': 0},
                                                   leading_zeros)

        to_pairs = Reshape(graph, {'name': fft_name + '/reshape'}).create_node()
        target_shape.out_port(0).connect(to_pairs.in_port(1))

        # the former FFT input now feeds the Reshape, and its source also feeds the Rank
        data_connection = mx_fft.in_port(0).get_connection()
        data_connection.set_destination(to_pairs.in_port(0))
        data_connection.get_source().connect(input_rank.in_port(0))

        idft = create_op_with_const_inputs(graph, IDFT, {1: int64_array([-1])},
                                           {'name': fft_name + '/idft', 'in_ports_count': 2},
                                           to_pairs)
        halves = create_op_with_const_inputs(graph, Split, {1: int64_array(-1)},
                                             {'name': fft_name + '/split', 'num_splits': 2},
                                             idft)
        result = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([-1])}, {}, halves)

        mx_fft.out_port(0).get_connection().set_source(result.out_port(0))
        rename_nodes([(mx_fft, fft_name + '/to_be_removed'), (result, fft_name)])
Exemple #5
0
    def find_and_replace_pattern(self, graph: Graph):
        """Convert every global Pooling node into a keep_dims Reduce over axes Range(2, rank, 1).

        The Reduce class is chosen via ``self.pool_method_to_reduce_type[pool_method]``. NHWC graphs
        are rejected by an assertion.
        """
        global_poolings = graph.get_op_nodes(type='Pooling', global_pool=True)
        if not global_poolings:
            return

        assert graph.graph['layout'] != 'NHWC', \
            'Global pooling transformation depends on layout (NHWC not enabled)'

        for pooling in global_poolings:
            name = pooling.soft_get('name', pooling.id)
            assert pooling.has_valid('pool_method'), \
                'Global Pooling {} has no `pool_method` attribute'.format(name)
            method = pooling['pool_method']
            assert method in self.pool_method_to_reduce_type, \
                'Unexpected Global Pooling method `{}` for node `{}`'.format(method, name)

            reduce_cls = self.pool_method_to_reduce_type[method]
            reduce = reduce_cls(graph, {'name': name + '/reduce', 'keep_dims': True}).create_node()

            # consumers now read from the Reduce, which reads the pooling's data source
            pooling.out_port(0).get_connection().set_source(reduce.out_port(0))
            data_source = pooling.in_port(0).get_connection().get_source()
            reduce.in_port(0).get_connection().set_source(data_source)

            # axes = Range(2, Rank(data), 1): every axis after the first two
            range_start = Const(graph, {'value': int64_array(2)}).create_node()
            range_end = Rank(graph, {'name': name + '/input_rank'}).create_node()
            range_step = Const(graph, {'value': int64_array(1)}).create_node()
            axes = Range(graph, {'name': name + '/global_pooling_reduce_axis'}).create_node()

            axes.in_port(0).connect(range_start.out_port(0))
            data_source.connect(range_end.in_port(0))
            axes.in_port(1).connect(range_end.out_port(0))
            axes.in_port(2).connect(range_step.out_port(0))
            axes.out_port(0).connect(reduce.in_port(1))

            log.debug('Global {} pooling was converted to reduce: `{}`'.format(method, name))
Exemple #6
0
    def convert_fft_to_dft(self, graph: Graph, mx_fft: Node):
        """Replace a forward MXNet FFT node with Unsqueeze -> Pad -> DFT -> Reshape.

        Steps:
        - The real input gets a new trailing axis (Unsqueeze at -1).
        - That axis is padded from size 1 to 2. Pad mode is 'constant' and no explicit pad value
          is connected.
        - DFT runs on axis -1.
        - The result is reshaped with target ``[0] * (rank - 1) + [-1]``.

        Assuming DFT keeps its ``[..., N, 2]`` input shape (no signal_size input is connected),
        this flattens the trailing (N, 2) pairs into one interleaved axis of size 2 * N. The output
        therefore has the same rank as the FFT input. The Reshape takes over the original node's
        name and consumers.
        """
        mx_fft_name = mx_fft.soft_get('name', mx_fft.id)
        unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([-1])},
                                                     {'name': mx_fft_name + '/Unsqueeze'})
        rank_node = Rank(graph, {'name': mx_fft_name + '/Rank'}).create_node()

        mx_fft_connection = mx_fft.in_port(0).get_connection()
        mx_fft_connection.set_destination(unsqueeze_node.in_port(0))
        mx_fft_connection.get_source().connect(rank_node.in_port(0))

        # pads_begin = zeros(rank + 1)
        add_node = create_op_with_const_inputs(graph, Add, {1: int64_array(1)},
                                               {'name': mx_fft_name + '/Add'}, rank_node)
        broadcast_node1 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                      {'name': mx_fft_name + '/Pad_broadcast'})
        add_node.out_port(0).connect(broadcast_node1.in_port(1))

        # pads_end = zeros(rank + 1) with 1 scattered at index `rank`, so only the new trailing axis grows
        scatter_node = create_op_with_const_inputs(graph, ScatterUpdate, {2: int64_array(1), 3: int64_array(0)},
                                                   {'name': mx_fft_name + '/ScatterUpdate'})
        broadcast_node1.out_port(0).connect(scatter_node.in_port(0))
        rank_node.out_port(0).connect(scatter_node.in_port(1))

        pad_node = Pad(graph, {'name': mx_fft_name + '/Pad', 'mode': 'constant'}).create_node(
            [unsqueeze_node, broadcast_node1, scatter_node])

        dft_node = create_op_with_const_inputs(graph, DFT, {1: int64_array([-1])},
                                               {'name': mx_fft_name + '/DFT', 'in_ports_count': 2},
                                               pad_node)

        # Reshape target = (rank - 1) zeros followed by [-1].
        # The zeros presumably copy the leading dims (Reshape special_zero), and -1 merges the last
        # DFT axis with the trailing pair axis. A [-1, 2] tail here would reproduce the DFT output
        # shape unchanged (a no-op Reshape), leaving an extra axis instead of the interleaved layout.
        sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)}, {'name': mx_fft_name + '/Sub'})
        rank_node.out_port(0).connect(sub_node.in_port(0))
        broadcast_node2 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                      {'name': mx_fft_name + '/Reshape_broadcast'})
        sub_node.out_port(0).connect(broadcast_node2.in_port(1))
        concat_node = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1])},
                                                  {'name': mx_fft_name + '/New_shape', 'in_ports_count': 2,
                                                   'axis': 0},
                                                  broadcast_node2)

        reshape_node = Reshape(graph, {}).create_node([dft_node, concat_node])

        mx_fft.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'), (reshape_node, mx_fft_name)])
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched Flatten node (attributes ``axis``, ``end_axis``) with a Reshape.

        Fast path (``end_axis == -1`` and ``axis >= 0``): static target ``[0] * axis + [-1]``.
        General case: the target shape is assembled in-graph from Shape/Rank of the input. It is the
        concatenation of:
          * the dims before ``axis``;
          * the product of dims ``axis..end_axis`` (end included), reduced with ReduceProd;
          * the dims after ``end_axis``.
        The Reshape takes over the Flatten's name, input and consumers.

        :raises AssertionError: if ``axis`` or ``end_axis`` is missing
        """
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
        assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

        axis = node.axis
        end_axis = node.end_axis

        if end_axis == -1 and axis >= 0:
            # Zeros presumably mean "copy the input dim" in Reshape; -1 collapses everything from `axis` on.
            begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
            middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
            end_dims = Const(graph, {'value': int64_array([])}).create_node()
        else:
            rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
            node.in_port(0).get_source().connect(rank.in_port(0))

            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            begin_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=0, end=axis)
            middle_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=axis, end=end_axis, include_end=True)
            end_dims = get_shape_values_by_range_idxs(
                shape=shape, rank=rank, begin=end_axis, end=-1, include_begin=False, include_end=True)

            middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
            middle_dims.out_port(0).connect(middle_dim.in_port(0))

        dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

        # Reuse `name`, which falls back to node.id. A bare node.soft_get('name') would presumably yield a
        # placeholder for unnamed nodes and give the Reshape a bogus name.
        reshape_node = Reshape(graph, {}).create_node()
        # Keep node with the same name to avoid confuse with renaming
        rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_node, name)])
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
    def replace_op(self, graph: Graph, node: Node):
        """Decompose ``node`` into MVN(axes = Range(2, rank, 1)) -> Mul -> Add.

        Input 0 of ``node`` moves to the MVN data port. Input 1 becomes the second Mul operand and
        input 2 the second Add operand; both are created with ``axis=1``. MVN uses ``node.epsilon``
        with 'inside_sqrt' and normalize_variance=1. The Add inherits the original node name.

        :return: list containing the id of the Add node that replaces ``node``
        """
        name = node.soft_get('name', node.id)

        # MVN axes: Range(2, Rank(data), 1), i.e. every axis after the first two
        input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        norm_axes = create_op_with_const_inputs(graph, Range, {0: int64_array(2), 2: int64_array(1)},
                                                {'name': name + '/Range', 'output_type': np.int64})
        mvn_attrs = {
            'eps': node.epsilon,
            'eps_mode': 'inside_sqrt',
            'normalize_variance': 1,
            'name': name + '/Ins_Norm/MVN_',
        }
        mvn = MVN(graph, mvn_attrs).create_node()
        node.in_port(0).get_connection().set_destination(mvn.in_port(0))
        norm_axes.out_port(0).connect(mvn.in_port(1))

        scaled = Mul(graph, {'axis': 1, 'name': name + '/Ins_Norm/mul_'}).create_node()
        mvn.out_port(0).connect(scaled.in_port(0))
        node.in_port(1).get_connection().set_destination(scaled.in_port(1))

        shifted = Add(graph, {'axis': 1, 'name': name + '/Ins_Norm/add_'}).create_node()
        scaled.out_port(0).connect(shifted.in_port(0))
        node.in_port(2).get_connection().set_destination(shifted.in_port(1))

        # the Range end is the rank of the tensor now feeding the MVN data input
        mvn.in_port(0).get_connection().add_destination(input_rank.in_port(0))
        norm_axes.in_port(1).connect(input_rank.out_port(0))

        rename_nodes([(node, name + '/TBD'), (shifted, name)])
        return [shifted.id]
    def find_and_replace_pattern(self, graph: Graph):
        """Normalize the inputs of SpaceToBatch / BatchToSpace nodes.

        1. Input 2 holds pads/crops as a single tensor (the comment below describes it as [N, 2]).
           It is Transposed with perm [1, 0] and Split into 2 along axis 0. Each part is Squeezed on
           axis 0: part 0 feeds port 2 (begin) and part 1 feeds the newly added port 3 (end).
        2. Inputs 1 (block shape), 2 (begin) and 3 (end) are each routed through a Cast to int64 and
           then a constant-mode Pad. Each Pad has pads_begin = [1] and
           pads_end = [Rank(input 0)] - Shape(input 1) - [1].
           The block-shape Pad uses pad value 1; the begin/end Pads have no pad-value input connected.
        """
        for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(
                op='BatchToSpace'):
            node.add_input_port(3, skip_if_exist=True)

            # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
            transposed_pads = create_op_with_const_inputs(
                graph, Transpose, {1: int64_array([1, 0])})
            node.in_port(2).get_connection().set_destination(
                transposed_pads.in_port(0))
            split_pads = create_op_with_const_inputs(graph, Split,
                                                     {1: int64_array(0)},
                                                     {'num_splits': 2})
            transposed_pads.out_port(0).connect(split_pads.in_port(0))
            # split output 0 -> port 2, split output 1 -> port 3, each through a Squeeze(axis 0)
            for port_ind in range(2):
                node.in_port(port_ind + 2).connect(
                    split_pads.out_port(port_ind))
                node.in_port(port_ind + 2).get_connection().insert_node(
                    create_op_with_const_inputs(graph, Squeeze,
                                                {1: int64_array([0])}))

            # add zeros/ones to related inputs to align it with data input
            # NOTE(review): node.name is used directly (no soft_get fallback), and the Shape node is named
            # '/rank_1' although it computes a shape.
            in0_rank = Rank(graph, {
                'name': node.name + '/rank_0'
            }).create_node()
            in1_shape = Shape(graph, {
                'name': node.name + '/rank_1'
            }).create_node()

            # diff = [rank(input 0)] - shape(input 1) - [1]: number of trailing elements to pad
            diff_size = Sub(graph, {
                'name': node.name + '/sub_0'
            }).create_node()
            diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()
            # [1] is used both as the subtrahend above and as pads_begin of every Pad below
            const_begin = Const(graph, {
                'value': int64_array([1])
            }).create_node()
            # pad value for the block-shape Pad only
            const_pad_val = Const(graph, {
                'value': int64_array(1)
            }).create_node()

            block_shape = Pad(graph, {
                'name': node.name + '/aligned_block_shape',
                'mode': 'constant'
            }).create_node()

            # in case of SpaceToBatch begin = pads_begin, end = pads_end
            # in case of BatchToSpace begin = crops_begin, end = crops_end
            new_begin_name = '/aligned_pads_begin'
            new_end_name = '/aligned_pads_end'
            if node.type == 'BatchToSpace':
                new_begin_name = '/aligned_crops_begin'
                new_end_name = '/aligned_crops_end'

            begin = Pad(graph, {
                'name': node.name + new_begin_name,
                'mode': 'constant'
            }).create_node()
            end = Pad(graph, {
                'name': node.name + new_end_name,
                'mode': 'constant'
            }).create_node()

            in0_rank_1d = create_op_node_with_second_input(
                graph, Unsqueeze, int64_array([0]),
                {'name': node.name + '/1d_rank_of_0'}, in0_rank)

            node.in_port(0).get_source().connect(in0_rank.in_port(0))
            node.in_port(1).get_source().connect(in1_shape.in_port(0))
            in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
            in1_shape.out_port(0).connect(diff_size.in_port(1))
            diff_size.out_port(0).connect(diff.in_port(0))
            const_begin.out_port(0).connect(diff.in_port(1))
            const_pad_val.out_port(0).connect(block_shape.in_port(3))

            # Ports 1, 2, 3 of the node: current source -> Cast(int64) -> Pad -> node port.
            inputs_array = [block_shape, begin, end]
            for idx, input_to_node in enumerate(inputs_array):
                name_of_input_to_node = input_to_node.name
                node.in_port(idx + 1).get_connection().set_destination(
                    input_to_node.in_port(0))
                const_begin.out_port(0).connect(input_to_node.in_port(1))
                diff.out_port(0).connect(input_to_node.in_port(2))
                input_to_node.out_port(0).connect(node.in_port(idx + 1))
                convert = Cast(graph, {
                    'name': name_of_input_to_node + '/i64',
                    'dst_type': np.int64
                }).create_node()
                input_to_node.in_port(0).get_connection().insert_node(convert)
Exemple #10
0
def replace_resize(graph: Graph, resize: Node):
    """Replace an ONNX Resize-10 node with an opset4 Interpolate node (shape_calculation_mode='scales').

    Interpolate inputs built here:
      * port 0: the Resize data input;
      * port 1 (sizes): int64(floor(float32(ShapeOf(data)) * (scales + 1e-5))), then a StridedSlice;
      * port 2 (scales): the Resize scales input, then a StridedSlice;
      * port 3 (axes): Range(2, Rank(data), 1).
    Interpolate mode is 'linear_onnx' when ``resize.mode == 'linear'`` and 'nearest' otherwise.
    The Interpolate node takes over the Resize node's name and consumers.
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
              "is triggered for node {}.".format(resize_name))

    def make_tail_slice(name_suffix):
        # StridedSlice with begin=[2], end=[0], stride=[1] and the mask attributes below. It presumably keeps
        # entries [2:] so that sizes/scales line up with the axes Range(2, rank) — TODO confirm mask convention.
        return create_op_with_const_inputs(
            graph, StridedSlice,
            {1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1])},
            {'name': resize_name + name_suffix,
             'begin_mask': int64_array([1]),
             'end_mask': int64_array([0]),
             'new_axis_mask': int64_array([0]),
             'shrink_axis_mask': int64_array([0]),
             'ellipsis_mask': int64_array([0])})

    rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
    axes_node = create_op_with_const_inputs(graph, Range, {0: int64_array(2), 2: int64_array(1)},
                                            {'name': resize_name + '/axes'})
    sizes_ss = make_tail_slice('/sizes_ss')
    scales_ss = make_tail_slice('/scales_ss')

    rank_node.out_port(0).connect(axes_node.in_port(1))

    interpolate_mode = 'linear_onnx' if resize.mode == 'linear' else 'nearest'
    interpolate_node = Interpolate(graph, {
        'version': 'opset4',
        'mode': interpolate_mode,
        'coordinate_transformation_mode': 'asymmetric',
        'cube_coeff': -0.75,
        'nearest_mode': 'simple',
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'antialias': 0,
        'shape_calculation_mode': 'scales',
        'in_ports_count': 4,
    }).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # Computing 'sizes' as floor(input_shape * scales) can be off by one. With scales = [1.0, 1.0, 1.33333, 2.0]
    # and input_shape = [1, 3, 30, 200], input_shape * scales = [1, 3, 39.9999, 400], so the floor on axis 2 is
    # 39 instead of 40. Adding eps after the product does not help for the same inputs:
    # 30 * 1.33333 + 1.0e-5 = 39.99991. Hence 'sizes' = floor(input_shape * (scales + eps)), eps = 1.0e-5.
    add_node = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    # shape values are multiplied in float32 even if data_type=FP16
    shape_as_float = Cast(graph, {'dst_type': np.float32}).create_node()
    shape_of.out_port(0).connect(shape_as_float.in_port(0))

    mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([shape_as_float, add_node])
    floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node([mul_node])
    sizes_as_int = Cast(graph, {'dst_type': np.int64}).create_node([floor_node])
    sizes_as_int.out_port(0).connect(sizes_ss.in_port(0))
    sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

    scales_ss.out_port(0).connect(interpolate_node.in_port(2))

    data_connection = resize.in_port(0).get_connection()
    data_connection.set_destination(interpolate_node.in_port(0))

    scales_connection = resize.in_port(1).get_connection()
    scales_connection.set_destination(scales_ss.in_port(0))

    data_connection.get_source().connect(shape_of.in_port(0))
    data_connection.get_source().connect(rank_node.in_port(0))
    scales_connection.get_source().connect(add_node.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
    def mxrepeat_decomposition(node: Node):
        """Decompose an MXNet repeat node (attributes ``axis``, ``repeats``) into Unsqueeze -> Tile -> Reshape.

        Steps:
        - The input is unsqueezed right after the (canonicalized) axis.
        - It is tiled ``repeats`` times along that new axis.
        - It is reshaped to ``shape[:axis] + [shape[axis] * repeats] + shape[axis + 1:]``.

        All shape values are computed in-graph from Rank/Shape of the input. The final Reshape takes
        the original node's name; the original node is renamed with a '/to_be_removed' suffix.
        """
        graph = node.graph
        name = node.soft_get('name', node.id)

        rename_node(node, name + '/to_be_removed')

        # Unsqueeze
        input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(input_rank.in_port(0))

        # non-negative axis index computed from the rank (presumably normalizes a negative node.axis)
        axis = get_canonical_axis_index_node(input_rank, node.axis)
        # the new axis goes right after `axis`
        unsqueeze_axis = create_op_node_with_second_input(
            graph,
            Add,
            int64_array([1]), {'name': name + '/Unsqueeze/Axis'},
            input_node=axis)

        unsqueeze = Unsqueeze(graph, {
            'name': name + '/Unsqueeze'
        }).create_node()
        unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

        # Tile (1, 1, ..., repeats, ..., 1)
        # we generate tile array according to the following table:

        # parts:       |      first      |  repeats |  second     |
        # i:           | 0, 1, ..., axis,| axis + 1,| ..., rank   |
        # tile_array:  | 1, 1, ...,  1  ,| repeats ,| ...,   1    |
        # (first part has axis + 1 ones, second part has rank - (axis + 1) ones; total length rank + 1)

        one = Const(graph, {
            'name': name + '/Broadcast/One',
            'value': int64_array([1])
        }).create_node()
        first_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_first_part'
        }).create_node()
        first_ones.in_port(0).connect(one.out_port(0))
        first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

        repeats = Const(graph, {
            'name': name + '/repeats',
            'value': int64_array([node.repeats])
        }).create_node()

        second_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_second_part'
        }).create_node()
        second_part_broadcast_shape = Sub(
            graph, {
                'name': name + '/Broadcast/Shape/second_part'
            }).create_node()
        second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
        second_part_broadcast_shape.in_port(1).connect(
            unsqueeze_axis.out_port(0))
        second_ones.in_port(0).connect(one.out_port(0))
        second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

        tile_repeats = new_shape_node_from_shape_nodes(
            [first_ones, repeats, second_ones])
        tile = Tile(graph, {'name': name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
        # we generate reshape dim array according to the following table:

        # parts:       |    first   |                rep           |  second       |
        # i:           | 0, 1, ... ,|               axis,          | ..., rank - 1 |
        # dim_array:   | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i]     |

        input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(input_shape.in_port(0))

        # input_shape[:axis] (node.axis passed as-is; the helper receives input_rank, presumably to resolve negatives)
        first_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=0,
            end=node.axis,
            include_begin=True,
            include_end=False)

        # input_shape[axis] via Gather on axis 0 with the canonical axis index
        original_axis_dim = create_op_with_const_inputs(
            graph,
            Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'},
            input_node=input_shape)
        original_axis_dim.in_port(1).connect(axis.out_port(0))

        # input_shape[axis] * repeats
        repeated_dimention = Mul(graph, {
            'name': name + '/RepeatedDim'
        }).create_node()
        repeated_dimention.in_port(0).connect(original_axis_dim.out_port(0))
        repeated_dimention.in_port(1).connect(repeats.out_port(0))

        # input_shape[axis + 1:]
        second_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=node.axis,
            end=-1,
            include_begin=False,
            include_end=True)

        output_shape = new_shape_node_from_shape_nodes([
            first_input_shape_part, repeated_dimention, second_input_shape_part
        ])

        reshape = Reshape(graph, {'name': name}).create_node()
        rename_node(reshape, name)
        reshape.in_port(1).connect(output_shape.out_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))
 def extract(cls, node: Node):
     """Mark ``node`` as a Rank operation producing int32 output; return the class ``enabled`` flag."""
     rank_attrs = {'output_type': np.int32}
     Rank.update_node_stat(node, rank_attrs)
     return cls.enabled