Example #1
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        cmp = match['complex']
        complex_abs = match['abs']
        complex_abs_name = complex_abs.soft_get('name', complex_abs.id)

        power_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

        pow0 = create_op_with_const_inputs(
            graph, Pow, {1: power_type(2.0)},
            {'name': complex_abs_name + '/real_part_squared'})
        pow1 = create_op_with_const_inputs(
            graph, Pow, {1: power_type(2.0)},
            {'name': complex_abs_name + '/imag_part_squared'})

        cmp.in_port(0).get_connection().set_destination(pow0.in_port(0))
        cmp.in_port(1).get_connection().set_destination(pow1.in_port(0))

        add = Add(graph, {
            'name': complex_abs_name + '/squared_abs'
        }).create_node([pow0, pow1])
        sqrt = create_op_with_const_inputs(graph, Pow, {1: power_type(0.5)},
                                           {})
        add.out_port(0).connect(sqrt.in_port(0))

        complex_abs.out_port(0).get_connection().set_source(sqrt.out_port(0))

        rename_nodes([(complex_abs, complex_abs_name + '/to_be_removed'),
                      (sqrt, complex_abs_name)])
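
The replacement above computes |a + bi| as sqrt(a^2 + b^2) using two Pow(2) nodes, an Add, and a Pow(0.5). A minimal NumPy sketch of the same decomposition, with hypothetical values that are not part of the transformation:

import numpy as np

re_part, im_part = np.float32(3.0), np.float32(4.0)
squared_abs = re_part ** 2.0 + im_part ** 2.0        # Pow(2), Pow(2), Add
assert squared_abs ** 0.5 == abs(complex(3.0, 4.0))  # Pow(0.5) gives |3 + 4i| == 5.0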
Example #2
    def find_and_replace_pattern(self, graph: Graph):
        for dequantize_node in graph.get_op_nodes(op='DequantizeLinear'):
            node_name = dequantize_node.soft_get('name', dequantize_node.id)
            axis = dequantize_node.soft_get('axis', None)
            scale_y_shape = dequantize_node.in_port(1).data.get_shape()
            model_data_type = data_type_str_to_np(
                graph.graph['cmd_params'].data_type)
            cast = Cast(graph, {
                'dst_type': model_data_type,
                'name': node_name + '/Cast'
            }).create_node()
            dequantize_node.in_port(0).get_connection().set_destination(
                cast.in_port(0))
            mul = Mul(graph, {'can_be_fused': False}).create_node()

            is_second_port_connected = dequantize_node.is_in_port_connected(2)
            if is_second_port_connected:
                # it is necessary not to replace the subtract, so the pattern can be matched in offline transformations
                # See ConvertQuantizeDequantize transformation in ngraph
                sub = Sub(graph, {
                    'name': node_name + '/Sub',
                    'zero_point_sub': True
                }).create_node()
                cast.out_port(0).connect(sub.in_port(0))
                dequantize_node.in_port(2).get_connection().set_destination(
                    sub.in_port(1))
                sub.out_port(0).connect(mul.in_port(0))
            else:
                cast.out_port(0).connect(mul.in_port(0))

            dequantize_node.in_port(1).get_connection().set_destination(
                mul.in_port(1))
            dequantize_node.out_port(0).get_connection().set_source(
                mul.out_port(0))
            rename_nodes([(dequantize_node, node_name + '/TBD'),
                          (mul, node_name)])

            assert scale_y_shape is not None
            if axis is not None and len(
                    scale_y_shape) > 0 and scale_y_shape[0] > 1:
                input_shape = cast.in_port(0).data.get_shape()
                target_shape = np.ones(len(input_shape), np.int64)
                target_shape[axis] = input_shape[axis]

                mul_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul'})
                mul.in_port(1).get_connection().set_destination(
                    mul_reshape.in_port(0))
                mul_reshape.out_port(0).connect(mul.in_port(1))

                if is_second_port_connected:
                    sub_reshape = create_op_with_const_inputs(
                        graph, Reshape, {1: int64_array(target_shape)},
                        {'name': node_name + '/Reshape/Sub'})
                    sub.in_port(1).get_connection().set_destination(
                        sub_reshape.in_port(0))
                    sub_reshape.out_port(0).connect(sub.in_port(1))
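
The decomposition above implements DequantizeLinear as y = (x - zero_point) * scale using a Cast, an optional Sub, a Mul, and Reshape nodes that make a per-axis scale/zero-point broadcastable. A minimal NumPy sketch of the same formula; the shapes, values, and axis below are hypothetical:

import numpy as np

x = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.uint8)
zero_point = np.array([1, 2, 3], dtype=np.uint8)
scale = np.array([0.1, 0.2, 0.3], dtype=np.float32)

axis = 1
target_shape = np.ones(x.ndim, np.int64)   # the Reshape target built above
target_shape[axis] = x.shape[axis]         # e.g. [1, 3]

# Cast -> Sub(zero_point) -> Mul(scale), mirroring the inserted sub-graph
y = (x.astype(np.float32) - zero_point.reshape(target_shape)) * scale.reshape(target_shape)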
Example #3
def create_ss_interval_border(graph: Graph, slice_border_port: Port,
                              shape: np.ndarray, axes: np.ndarray,
                              node_name: str):
    """
    This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice
    :param axes: axes that "starts" and "ends" apply to
    :param node_name: Slice node name
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    # the value for 'starts' or 'ends' might be the maximum/minimum possible int64 value. Such a value must be
    # clamped to the maximum/minimum of int32, because it does not fit into int32, which is the type supported
    # by the StridedSlice layer
    clamp = create_op_with_const_inputs(graph,
                                        Clamp,
                                        port_value_dict={
                                            1: np.iinfo(np.int32).min,
                                            2: np.iinfo(np.int32).max
                                        },
                                        op_attrs=dict(name=node_name +
                                                      '/Clamp'))
    clamp.in_port(0).connect(slice_border_port)
    # the "starts"/"ends" values coming from the network have to be cast to the same data type as the constant
    # values created here, to prevent type errors in the Concat node
    cast = Cast(graph, dict(name=node_name + '/CastToI64',
                            dst_type=np.int64)).create_node()
    cast.in_port(0).connect(clamp.out_port(0))
    concat = Concat(graph, dict(name=node_name + '/Concat',
                                axis=0)).create_node()
    for value_idx, port_idx in enumerate(axes):
        concat.add_input_port(port_idx)
        # "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct
        # Concat input port
        value = create_op_with_const_inputs(
            graph,
            Gather,
            port_value_dict={
                1: int64_array([value_idx]),
                2: int64_array(0)
            },
            op_attrs={'name': node_name + '/Gather'})
        cast.out_port(0).connect(value.in_port(0))
        value.out_port(0).connect(concat.in_port(port_idx))
    for port_idx in range(len(shape)):
        if not concat.is_in_port_connected(port_idx):
            concat.add_input_port(port_idx)
            # this border value will be ignored by StridedSlice because of the begin_mask/end_mask
            const = Const(
                graph, dict(name=node_name + '/Const',
                            value=int64_array([0]))).create_node()
            const.out_port(0).connect(concat.in_port(port_idx))

    return concat
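
The Clamp created at the top of this function squeezes possible int64 sentinel values into the int32 range before they reach StridedSlice. A one-line NumPy sketch of that step, with a hypothetical 'ends' value:

import numpy as np

ends = np.int64(np.iinfo(np.int64).max)    # e.g. a Slice 'ends' meaning "to the end"
clamped = np.clip(ends, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
assert clamped == np.iinfo(np.int32).max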
Example #4
    def replace_pattern(self, graph: Graph, match: dict):
        if not self.is_applicable(match):
            return

        unsqueeze_node = match['unsqueeze']
        unsqueeze_name = unsqueeze_node.soft_get('name', unsqueeze_node.id)
        second_input_of_unsqueeze = unsqueeze_node.in_port(
            1).get_connection().get_source().node
        d_idx = int(second_input_of_unsqueeze.value)
        axis = d_idx - 1

        shape_node = Shape(graph,
                           dict(name=unsqueeze_name + '/Shape')).create_node()
        axis_len_node = node_to_get_shape_value_of_indices(shape_node, [axis])

        second_input_of_tile = match['tile'].in_port(
            1).get_connection().get_source().node
        scale = int64_array([second_input_of_tile.value[d_idx]])
        float_scale = float32_array([second_input_of_tile.value[d_idx]])
        mul_node = create_op_with_const_inputs(
            graph, Mul, {1: scale}, {'name': unsqueeze_name + '/Mul'})

        axis_len_node.out_port(0).connect(mul_node.in_port(0))

        interp_node = create_op_with_const_inputs(
            graph, Interpolate, {
                2: float_scale,
                3: int64_array([axis])
            }, {
                'mode': 'nearest',
                'antialias': 0,
                'pads_begin': int64_array([0]),
                'pads_end': int64_array([0]),
                'coordinate_transformation_mode': 'half_pixel',
                'nearest_mode': 'round_prefer_floor',
                'cube_coeff': -0.75,
                'version': 'opset4',
                'shape_calculation_mode': 'scales',
                'in_ports_count': 4,
                'maybe_part_of_sequence': True
            })
        mul_node.out_port(0).connect(interp_node.in_port(1))

        reshape_node = match['reshape']
        reshape_node.out_port(0).get_connection().set_source(
            interp_node.out_port(0))
        reshape_name = reshape_node.soft_get('name', reshape_node.id)
        rename_nodes([(reshape_node, reshape_name + '/delete'),
                      (interp_node, reshape_name)])

        unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
        unsqueeze_connection.set_destination(interp_node.in_port(0))
        unsqueeze_connection.get_source().connect(shape_node.in_port(0))
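
The pattern above (Unsqueeze at d_idx, Tile by the scale, Reshape back) is nearest-neighbour upsampling along axis d_idx - 1, which is why it can be replaced with Interpolate. A NumPy sketch of the equivalence with hypothetical data and scale:

import numpy as np

x, d_idx, scale = np.array([[1, 2, 3]]), 2, 2
axis = d_idx - 1
tiled = np.tile(np.expand_dims(x, d_idx), (1, 1, scale)).reshape(1, -1)
assert np.array_equal(tiled, np.repeat(x, scale, axis=axis))   # [[1 1 2 2 3 3]]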
Example #5
    def replace(node: Node, const: Node):
        graph = node.graph
        shape = const.shape
        const_name = const.soft_get('name', const.id)

        non_one_dims = np.argwhere(shape != 1).flatten()
        one_dims = np.argwhere(shape == 1).flatten()

        if not (non_one_dims.size == 1 and 5 < np.prod(shape) < 500):
            # the (5, 500) range was chosen empirically, to affect fewer models
            return

        value = const.value
        if not np.array_equal(np.arange(0, np.prod(shape), 1).reshape(shape), value):
            return

        positive_idx = non_one_dims.item(0)
        negative_idx = positive_idx - len(shape)

        node_name = node.soft_get('name', node.id)
        gather = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
                                             {'name': node_name + '/BroadcastingDim'})
        gather_for_const = create_op_with_const_inputs(graph, Gather, {1: int64_array(negative_idx), 2: int64_array(0)},
                                                       {'name': const_name + '/BroadcastingDim'})
        shapeof_node = Shape(graph, {'name': const_name + '/ShapeOf'}).create_node()
        shapeof_node.out_port(0).connect(gather_for_const.in_port(0))

        equal_node = create_op_with_const_inputs(graph, Equal, {1: int64_array(1)}, {'name': node_name + '/ConstOne'})
        gather.out_port(0).connect(equal_node.in_port(0))

        select_node = Select(graph, {'name': node_name + '/Select',
                                      'auto_broadcast': 'numpy'}).create_node([equal_node, gather_for_const, gather])

        const.out_port(0).connect(shapeof_node.in_port(0))

        range_node = create_op_with_const_inputs(graph, Range,
                                                 {0: mo_array(0, dtype=value.dtype),
                                                  2: mo_array(1, dtype=value.dtype)},
                                                 {'name': const_name + '/Range', 'dtype': value.dtype})
        select_node.out_port(0).connect(range_node.in_port(1))

        node.in_port(1).get_connection().add_destination(gather.in_port(0))

        node.in_port(0).get_connection().set_source(range_node.out_port(0))

        if one_dims.size:
            unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, one_dims,
                                                         {'name': const_name + '/KeepShape'})
            range_node.out_port(0).get_connection().insert_node(unsqueeze)
            rename_nodes([(const, const_name + '/ToBeDeleted'), (unsqueeze, const_name)])
        else:
            rename_nodes([(const, const_name + '/ToBeDeleted'), (range_node, const_name)])
Example #6
    def find_and_replace_pattern(self, graph: Graph):
        shape_ops = graph.get_op_nodes(op='ShapeOf')

        # 1. Inserting Gather to N*C format on constant shape paths
        for shape in shape_ops:
            source_port = shape.in_port(0).get_source()
            if is_output_data_in_correct_layout(source_port.node, source_port.idx):
                continue  # data is already in N*C format

            name = shape.soft_get('name', shape.id)
            rank = source_port.data.get_shape().size

            if rank in [4, 5]:
                index = int64_array([0, *list(range(2, rank)), 1])
            else:
                continue  # data is layout independent

            gather = create_op_with_const_inputs(graph, op=Gather, port_value_dict={1: index, 2: int64_array(0)},
                                                 op_attrs={'name': name + '/GatherNCHWtoNHWC'})
            shape.out_port(0).get_connection().insert_node(gather)

        # 2. Inserting Gather/Transpose to NC* format
        shape_sub_graph_end_points = self.find_shape_subgraph_endpoints([shape.out_port(0) for shape in shape_ops])
        for in_port in shape_sub_graph_end_points:
            name = in_port.node.soft_get('name', in_port.node.id)
            shape = in_port.data.get_shape()

            should_switch_layout = not any([is_output_data_in_correct_layout(port.node, port.idx)
                                            for port in in_port.node.out_ports().values() if not port.disconnected()])
            should_insert_gather = should_switch_layout and len(shape) == 1 and shape.item(0) in [4, 5]
            should_insert_transpose = should_switch_layout and len(shape) in [4, 5]

            if should_insert_gather:
                # turn the input permutation off; it is performed by the Gather inserted below
                in_port.__setattr__('input_permutation', None)
                index = int64_array([0, shape.item(0) - 1, *list(range(1, shape.item(0) - 1))])
                gather = create_op_with_const_inputs(graph, op=Gather,
                                                     port_value_dict={1: index, 2: int64_array(0)},
                                                     op_attrs={'name': name + '/GatherNHWCtoNCHW'})
                in_port.get_connection().insert_node(gather)
            elif should_insert_transpose:
                # turn the input permutation off; it is performed by the Transpose inserted below
                in_port.__setattr__('input_permutation', None)
                order = int64_array([0, len(shape) - 1, *list(range(1, len(shape) - 1))])
                transpose = create_op_with_const_inputs(graph, op=Transpose, port_value_dict={1: order},
                                                        op_attrs={'name': name + '/TransposeNHWCtoNCHW',
                                                                  'override_output_shape': True})
                mark_input_as_in_correct_layout(transpose, 0)
                mark_output_as_in_correct_layout(transpose, 0)
                in_port.get_connection().insert_node(transpose)
            else:
                continue  # data is layout independent
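
For a rank-4 input, the index built in step 1 is [0, 2, 3, 1], so gathering with it reorders an NCHW shape into NHWC. A small NumPy check with a hypothetical shape:

import numpy as np

rank = 4
index = np.array([0, *range(2, rank), 1], dtype=np.int64)
shape_nchw = np.array([1, 3, 224, 224], dtype=np.int64)
assert np.array_equal(shape_nchw[index], [1, 224, 224, 3])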
Example #7
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']
        name = node.soft_get('name', node.id)

        assert node.has_valid(
            'axis'), 'Slice operation `{}` has no `axis` parameter'.format(
                name)
        axis = int64_array(node.axis)
        if axis.size != 1:
            return

        assert node.has_valid(
            'slice_point'
        ), 'Slice operation `{}` has no `slice_point` parameter'.format(name)
        slice_point = node.slice_point

        if slice_point.size == 0:
            num_splits = len(node.out_ports())
            split_node = create_op_with_const_inputs(graph,
                                                     op=Split,
                                                     port_value_dict={1: axis},
                                                     op_attrs={
                                                         'name': name,
                                                         'num_splits':
                                                         num_splits
                                                     })
        else:
            size_splits = []
            curr_pos = 0
            for point in slice_point:
                assert point > curr_pos
                size_splits.append(point - curr_pos)
                curr_pos = point
            size_splits.append(-1)

            split_node = create_op_with_const_inputs(
                graph,
                op=VariadicSplit,
                port_value_dict={
                    1: axis,
                    2: int64_array(size_splits)
                },
                op_attrs={
                    'name': name,
                    'out_ports_count': len(slice_point) + 1
                })

        node.in_port(0).get_connection().set_destination(split_node.in_port(0))
        for i, port in node.out_ports().items():
            node.out_port(i).get_connection().set_source(
                split_node.out_port(i))
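
The else-branch converts Caffe slice_point values into VariadicSplit size_splits by taking differences between consecutive points, with a trailing -1 so the last chunk takes the remainder. A worked example with hypothetical points on an axis of length 8:

slice_point, curr_pos, size_splits = [2, 5], 0, []
for point in slice_point:
    size_splits.append(point - curr_pos)
    curr_pos = point
size_splits.append(-1)            # the remainder: 8 - 5 = 3 elements
assert size_splits == [2, 3, -1]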
Example #8
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
            assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                               'mode is supported for node {}.'.format(node.id)
            node_name = node.soft_get('name', node.id)
            rename_node(node, node_name + '/TBR')
            is_packed = False
            if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
                is_packed = True
                embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
            else:
                embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
                node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
            rename_node(embedding_bag, node_name)
            node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
            node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
            node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
            if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
                if is_packed:
                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
                else:
                    # connect per_sample_weights
                    node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

                    weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()

                    weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
                    last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
                    weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

                    weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

                    zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
                                                                {'name': node_name + '/Broadcast'})
                    zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

                    default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                                          {'name': node_name + '/Unsqueeze'})
                    default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

                    # expand embedding table with zeros
                    weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                                    'name': node_name + '/Concat'}).create_node()
                    embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
                    weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
                    weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
                    weights_concat.out_port(0).connect(embedding_bag.in_port(0))

                    # point default index to expanded part of embedding table
                    weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
Example #9
    def convert_fft_to_dft(self, graph: Graph, mx_fft: Node):
        mx_fft_name = mx_fft.soft_get('name', mx_fft.id)
        unsqueeze_node = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array([-1])},
            {'name': mx_fft_name + '/Unsqueeze'})
        rank_node = Rank(graph, {'name': mx_fft_name + '/Rank'}).create_node()

        mx_fft_connection = mx_fft.in_port(0).get_connection()
        mx_fft_connection.set_destination(unsqueeze_node.in_port(0))
        mx_fft_connection.get_source().connect(rank_node.in_port(0))

        add_node = create_op_with_const_inputs(graph, Add, {1: int64_array(1)},
                                               {'name': mx_fft_name + '/Add'},
                                               rank_node)
        broadcast_node1 = create_op_with_const_inputs(
            graph, Broadcast, {0: int64_array(0)},
            {'name': mx_fft_name + '/Pad_broadcast'})
        add_node.out_port(0).connect(broadcast_node1.in_port(1))

        scatter_node = create_op_with_const_inputs(
            graph, ScatterUpdate, {
                2: int64_array(1),
                3: int64_array(0)
            }, {'name': mx_fft_name + '/ScatterUpdate'})
        broadcast_node1.out_port(0).connect(scatter_node.in_port(0))
        rank_node.out_port(0).connect(scatter_node.in_port(1))

        pad_node = Pad(graph, {
            'name': mx_fft_name + '/Pad',
            'mode': 'constant'
        }).create_node([unsqueeze_node, broadcast_node1, scatter_node])

        dft_node = create_op_with_const_inputs(
            graph, DFT, {1: int64_array([-1])}, {
                'name': mx_fft_name + '/DFT',
                'in_ports_count': 2
            }, pad_node)

        sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)},
                                               {'name': mx_fft_name + '/Sub'})
        rank_node.out_port(0).connect(sub_node.in_port(0))
        broadcast_node2 = create_op_with_const_inputs(
            graph, Broadcast, {0: int64_array(0)},
            {'name': mx_fft_name + '/Reshape_broadcast'})
        sub_node.out_port(0).connect(broadcast_node2.in_port(1))
        concat_node = create_op_with_const_inputs(
            graph, Concat, {1: int64_array([-1, 2])}, {
                'name': mx_fft_name + '/New_shape',
                'in_ports_count': 2,
                'axis': 0
            }, broadcast_node2)

        reshape_node = Reshape(graph, {}).create_node([dft_node, concat_node])

        mx_fft.out_port(0).get_connection().set_source(
            reshape_node.out_port(0))
        rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'),
                      (reshape_node, mx_fft_name)])
Example #10
def insert_transpose(graph: Graph, input_port: Port, before_input=True):
    from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs

    input_rank = len(input_port.data.get_shape())
    if input_rank > 3:
        if before_input:
            axis_order = np.concatenate(
                (int64_array([0]), int64_array(list(range(2, input_rank))),
                 int64_array([1])))
            source_node = input_port.get_source().node
            transpose_name = source_node.soft_get(
                'name', source_node.id) + '/TransposeToNHWC'
        else:
            axis_order = np.concatenate(
                (int64_array([0]), int64_array([input_rank - 1]),
                 int64_array(list(range(1, input_rank - 1)))))
            transpose_name = input_port.node.soft_get(
                'name', input_port.node.id) + '/TransposeToNCHW'
            input_port.node['need_shape_inference'] = True
            input_port.node['override_output_shape'] = True
        transpose = create_op_with_const_inputs(graph, Transpose,
                                                {1: axis_order},
                                                {'name': transpose_name})
        input_port.get_connection().insert_node(transpose)
        transpose['need_shape_inference'] = True
        transpose['override_output_shape'] = True
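
The two axis orders built above are mutually inverse permutations: [0, 2, ..., rank-1, 1] moves channels last (to NHWC) and [0, rank-1, 1, ..., rank-2] moves them back (to NCHW). A NumPy check for a hypothetical rank-4 input:

import numpy as np

input_rank = 4
to_nhwc = np.array([0, *range(2, input_rank), 1])                  # [0, 2, 3, 1]
to_nchw = np.array([0, input_rank - 1, *range(1, input_rank - 1)]) # [0, 3, 1, 2]
assert np.array_equal(to_nhwc[to_nchw], np.arange(input_rank))     # composition is identity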
Example #11
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='AttributedVariadicSplit'):
            name = node.soft_get('name', node.id)

            axis = node.soft_get('axis', None)
            assert axis is not None, \
                'AttributedVariadicSplit should have `axis` parameter set, but it is not set for node {}'.format(name)

            size_splits = node.soft_get('size_splits', None)
            assert size_splits is not None, \
                'AttributedVariadicSplit should have `size_splits` parameter set, but it is not set for node {}'.format(name)

            split = create_op_with_const_inputs(
                graph, VariadicSplit, {
                    1: np.int64(axis),
                    2: size_splits
                }, {
                    'name': name + '/VariadicSplit',
                    'out_ports_count': len(size_splits)
                })

            for idx, port in node.out_ports().items():
                port.get_connection().set_source(split.out_port(idx))

            node.in_port(0).get_connection().set_destination(split.in_port(0))
            graph.remove_node(node.id)
Example #12
    def normalize_body_graph(loop_node: Node):
        loop_name = loop_node.soft_get('name', loop_node.id)
        # if the "trip count" input is not connected, connect it with the default value "Infinity" (-1)
        if not loop_node.is_in_port_connected(0):
            loop_node.add_input_port(0, skip_if_exist=True)
            Const(loop_node.graph, {'name': loop_name + '/trip_count', 'value': int64_array(-1)}).\
                create_node().out_port(0).connect(loop_node.in_port(0))

        # if the "execution condition" input is not connected, connect it with the default value True
        if not loop_node.is_in_port_connected(1):
            loop_node.add_input_port(1, skip_if_exist=True)
            Const(loop_node.graph, {'name': loop_name + '/execution_cond', 'value': np.array(True, dtype=bool)}).\
                create_node().out_port(0).connect(loop_node.in_port(1))

        # scan outputs need an Unsqueeze over axis 0
        for record in loop_node.output_port_map:
            body_node = Loop.get_body_node_by_internal_id(loop_node, record['internal_layer_id'])
            assert body_node is not None
            assert body_node.soft_get('type') == 'Result'

            if record['axis'] is not None:
                unsqueeze = create_op_with_const_inputs(loop_node.body, Unsqueeze, {1: int64_array([0])})
                body_node.in_port(0).get_connection().insert_node(unsqueeze)

        Loop.normalize_input_output_ports(loop_node)
Example #13
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='MVNCaffe'):
            node_name = node.soft_get('name', node.id)

            start_axis = 2
            if node['across_channels'] == 1:
                start_axis = 1

            rank = Rank(graph, {'name': node_name + '/Rank'}).create_node()

            # create range of axes based on `start_axis` and rank of input
            rng = create_op_with_const_inputs(graph, Range, {
                0: int64_array(start_axis),
                2: int64_array(1)
            }, {
                'name': node_name + '/Range',
                'output_type': np.int64
            })
            rng.in_port(1).connect(rank.out_port(0))

            new_mvn = MVN(
                graph, {
                    'eps': node.soft_get('eps', 1e-9),
                    'eps_mode': 'inside_sqrt',
                    'normalize_variance': node.soft_get(
                        'normalize_variance', 1)
                }).create_node([node.in_port(0).get_source().node, rng])
            new_mvn.in_port(0).get_connection().add_destination(
                rank.in_port(0))
            node.out_port(0).get_connection().set_source(new_mvn.out_port(0))
            rename_nodes([(node, node_name + '/tbd'), (new_mvn, node_name)])

            graph.remove_node(node.id)
Example #14
    def replace_op(self, graph: Graph, node: Node):
        # save the original node name to use it in the new Pad op instance
        original_name = node.soft_get('name', node.id)
        rename_node(node, original_name + '/TBR')

        new_pad = Pad(graph, {
            'mode': node.soft_get('mode', None)
        }).create_node()
        rename_node(new_pad, original_name)

        node.in_port(0).get_connection().set_destination(new_pad.in_port(0))

        if node.soft_get('mode') == 'constant':
            # the input with fill value is an optional third input in ONNX
            if not node.in_port(2).disconnected():
                node.in_port(2).get_connection().set_destination(
                    new_pad.in_port(3))
            else:
                new_pad.in_port(3).connect(
                    Const(graph, {
                        'value': 0.0
                    }).create_node().out_port(0))

        # convert ONNX representation of the pads as [2 * N] to MO representation: [N] and [N]
        split_pads = create_op_with_const_inputs(graph, Split,
                                                 {1: int64_array(0)},
                                                 {'num_splits': 2})
        node.in_port(1).get_connection().set_destination(split_pads.in_port(0))
        split_pads.out_port(0).connect(new_pad.in_port(1))
        split_pads.out_port(1).connect(new_pad.in_port(2))

        return [new_pad.id]
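
ONNX stores pads as a single [2 * N] vector [x1_begin, x2_begin, ..., x1_end, x2_end, ...], so splitting it in two along axis 0 yields exactly the pads_begin/pads_end inputs the MO Pad op expects. A sketch with hypothetical pads for a 2-D input:

import numpy as np

pads = np.array([1, 2, 3, 4], dtype=np.int64)     # begins [1, 2], ends [3, 4]
pads_begin, pads_end = np.split(pads, 2, axis=0)  # the Split with num_splits=2 above
assert np.array_equal(pads_begin, [1, 2]) and np.array_equal(pads_end, [3, 4])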
Example #15
    def find_and_replace_pattern(self, graph: Graph):
        for fake_output in graph.get_op_nodes(op='FakeOutput'):
            name = fake_output.soft_get('name', fake_output.id)

            producer = fake_output.in_port(0).get_source().node
            producer_outputs = 0
            for port in producer.out_ports().values():
                if not port.disconnected():
                    producer_outputs += 1
            if producer_outputs != 1:
                # At this stage we don't know the output type, so we rely on the MO transformation that updates the
                # Const type for elementwise operations in case of an input data type mismatch
                add = create_op_with_const_inputs(graph, Add, {1: int64_array(0)}, {'can_be_fused': False})
                rename_nodes([(fake_output, name + '/TBD'), (add, name)])

                prev_op_in_port = fake_output.in_port(0).get_connection().get_source()
                # Get tensor names incoming to FakeOutput
                tensor_names = prev_op_in_port.get_tensor_names()

                # Remove tensor info from data node
                prev_op_in_port.remove_tensor_names()

                fake_output.in_port(0).get_connection().set_destination(add.in_port(0))
                fake_output.out_port(0).get_connection().set_source(add.out_port(0))

                # Move tensor names to Add op, which replaces FakeOutput
                if len(tensor_names) > 0:
                    add.out_port(0).add_tensor_names(tensor_names)

            else:
                result_in_port = fake_output.out_port(0).get_destination()
                result_in_port.disconnect()
                fake_output.in_port(0).get_connection().set_destination(result_in_port)
                rename_nodes([(fake_output, name + '/TBD'), (producer, name)])
Example #16
    def make_interpolate_reshapeable(interpolate, concat):
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'

        output_shape = interpolate.out_port(0).data.get_shape()

        interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in Interpolate.get_axes(interpolate)]
        concat_axis = get_canonical_axis_index(output_shape, concat.axis)
        if concat_axis in interp_axes:
            return

        concat_srcs = [port.get_source() for port in concat.in_ports().values() if not port.disconnected()]
        non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
        if len(non_interp_concat_srcs) == 0:
            return

        graph = interpolate.graph
        src = non_interp_concat_srcs[0]

        shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
        shape.in_port(0).connect(src)
        gather = create_op_with_const_inputs(graph, Gather,
                                             {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                             {'name': shape.name + '/Gathered'}, shape)
        interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
Example #17
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='Interpolate', version='opset1'):
            transformation_mode = 'align_corners' if int(
                node.soft_get('align_corners', 0)) else 'half_pixel'
            interpolate1_name = node.soft_get('name', node.id)
            interpolate4 = create_op_with_const_inputs(
                graph, Interpolate, {
                    2: mo_array([1.0, 1.0]),
                    3: int64_array(node.axes)
                }, {
                    'mode': node.mode,
                    'antialias': node.antialias,
                    'coordinate_transformation_mode': transformation_mode,
                    'pads_begin': correct_pad(node.soft_get('pads_begin', 0)),
                    'pads_end': correct_pad(node.soft_get('pads_end', 0)),
                    'nearest_mode': 'round_prefer_floor',
                    'cube_coeff': -0.75,
                    'shape_calculation_mode': 'sizes',
                    'version': 'opset4',
                    'in_ports_count': 4,
                })

            interpolate1_input_connection = node.in_port(0).get_connection()
            interpolate1_input_connection.set_destination(
                interpolate4.in_port(0))

            sizes_connection = node.in_port(1).get_connection()
            sizes_connection.set_destination(interpolate4.in_port(1))

            node.out_port(0).get_connection().set_source(
                interpolate4.out_port(0))
            rename_nodes([(node, interpolate1_name + '/delete'),
                          (interpolate4, interpolate1_name)])
Example #18
    def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'
        interp_axes = Interpolate.get_axes(interpolate)
        concat_axis = self.get_concat_axis(concat)

        if concat_axis is None or interp_axes is None \
                or np.any(interp_axes < 0) or concat_axis < 0 \
                or concat_axis in interp_axes:
            # checks that interpolate axes and concat axis are valid and do not intersect
            return

        non_interp_concat_srcs = self.get_non_interpolate_concat_sources(
            concat)
        if not len(non_interp_concat_srcs):
            # there is no Concat input to take input from
            return

        graph = interpolate.graph
        src = non_interp_concat_srcs[0]

        shape = Shape(graph, {
            'name': src.node.soft_get('name', src.node.id) + '/Shape'
        }).create_node()
        shape.in_port(0).connect(src)
        gather = create_op_with_const_inputs(
            graph,
            Gather, {
                1: np.array(interp_axes, dtype=np.int32),
                2: int64_array(0)
            }, {'name': shape.name + '/Gathered'},
            input_node=shape)
        interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
Example #19
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        node = match['reduce']
        connected_in_ports = [
            port for port in node.in_ports().values()
            if not port.disconnected()
        ]
        if len(connected_in_ports) == 1:
            node_name = node.soft_get('name', node.id)

            # move the 'axis' attribute to a constant second input of the layer; if 'axis' is absent, the axes
            # input is built dynamically below, because the input rank is known only at inference time
            if node.has_valid('axis'):
                const = Const(graph, {
                    'name': node_name + '/axis',
                    'value': node.axis
                }).create_node()
                node.add_input_port(1, skip_if_exist=True)
                const.out_port(0).connect(node.in_port(1))
                del graph.node[node.id]['axis']
            else:
                # The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor.
                axes = create_op_with_const_inputs(
                    graph, Range, {
                        0: int64_array(0),
                        2: int64_array(1)
                    }, dict(name=node_name + '/axes'))
                end_of_range = Rank(graph, dict(name=node_name +
                                                '/range_end')).create_node()
                node.in_port(0).get_connection().get_source().connect(
                    end_of_range.in_port(0))
                end_of_range.out_port(0).connect(axes.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                axes.out_port(0).connect(node.in_port(1))
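
When 'axis' is absent, the Range(0, rank, 1) sub-graph built above enumerates every axis of the input, so the reduction covers all dimensions. A NumPy sketch of that default, with hypothetical data and a sum reduction:

import numpy as np

data = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
axes = np.arange(0, data.ndim, 1)                 # Range(start=0, stop=rank, step=1)
assert np.sum(data, axis=tuple(axes)) == np.sum(data)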
Example #20
    def replace_pattern(graph: Graph, match: dict):
        node = match['proposal']
        assert len(node.in_ports()) == 3, "Proposal op must have exactly 3 input ports"
        im_info_shape = node.in_port(2).data.get_shape()
        assert im_info_shape is not None

        if np.array_equal(im_info_shape, [1, 6]):
            log.error('The model contains Proposal layer "{}" with input of shape [1, 6]. Inference Engine '
                      'implementation of the Proposal layer uses only 4 first values (indices 0, 1, 2 and 3). '
                      'Elements with indices 4 and 5 will be ignored.'.format(node.soft_get('name', node.id)),
                      extra={'is_warning': True})

            cropped_im_info = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([0, 0], dtype=np.int32),
                                                                                2: np.array([1, 3], dtype=np.int32),
                                                                                3: np.array([1, 1], dtype=np.int32)},
                                                          {'name': 'cropped_im_info',
                                                           'begin_mask': int64_array([1, 1]),
                                                           'end_mask': int64_array([1, 1]),
                                                           'new_axis_mask': int64_array([0, 0]),
                                                           'shrink_axis_mask': int64_array([0, 0]),
                                                           'ellipsis_mask': int64_array([0, 0]),
                                                           'override_output_shape': True,
                                                           })

            node.in_port(2).get_connection().insert_node(cropped_im_info)

            # update the im_info_shape so that the next 'if' statement becomes true
            im_info_shape = int64_array([1, 3])

        if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
            reshape = create_op_node_with_second_input(graph, Reshape, [im_info_shape[1]], {'name': 'im_info/Reshape'})
            node.in_port(2).get_connection().set_destination(reshape.in_port(0))
            reshape.out_port(0).connect(node.in_port(2))
Example #21
    def replace_op(self, graph: Graph, node: Node):
        pb = node.parameters
        weights_size = read_binary_integer32_token(pb)
        weights = read_blob(pb, weights_size, dtype=np.int32) - 1

        node_name = node.soft_get('name', node.id)
        const_attrs = {
                       'name': node_name + '/indexes',
                       'value': np.array(weights),
                       'shape': [weights_size],
                       'data_type': np.int32
                      }
        indexes_node = Const(graph).create_node(attrs=const_attrs)

        perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node()
        perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node([node.in_node(0)])
        perm1_node.in_port(0).connect(node.in_port(0).get_source())
        perm1_node.in_port(1).connect(perm_in_1.out_port(0))

        gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, {'name': node_name + '/gather'})
        gather_node.in_port(0).connect(perm1_node.out_port(0))
        gather_node.in_port(1).connect(indexes_node.out_port(0))

        perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node()
        perm2_node.in_port(0).connect(gather_node.out_port(0))
        perm2_node.in_port(1).connect(perm_in_1.out_port(0))

        return [perm2_node.id]
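
The Transpose -> Gather(axis=0) -> Transpose chain above selects rows of the transposed input, which is the same as gathering columns of the original tensor. A NumPy sketch of the identity with hypothetical data and 0-based indexes:

import numpy as np

x = np.arange(12).reshape(3, 4)
indexes = np.array([2, 0])
assert np.array_equal(x.T[indexes].T, x[:, indexes])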
Example #22
    def replace_layer_norm(self, graph: Graph, match: dict):
        inp = match['pool0']
        node_before = inp.in_port(0).get_source().node
        node_before_name = node_before.soft_get('name', node_before.id)

        # take/check the values of the add, pow and axes for ReduceMean
        pow_param = match['pow_param']
        add_param = match['add_param']
        if add_param.value.size == 1 and pow_param.value.size == 1 and add_param.value.item() <= 1e-05 \
                and pow_param.value.item() == 0.5 and match['pool0_param'].value == match['pool1_param'].value:
            log.debug('Found LayerNorm pattern after {} with name {}'.format(
                node_before.op, node_before_name))
            mvn = create_op_with_const_inputs(
                graph, MVN, {1: match['pool1_param'].value}, {
                    'eps': add_param.value.item(),
                    'normalize_variance': 1,
                    'eps_mode': 'inside_sqrt'
                })
            div_name = match['div'].soft_get('name', match['div'].id)
            rename_nodes([(match['div'], div_name + '/to_be_removed'),
                          (mvn, div_name)])

            inp.in_port(0).get_connection().set_destination(mvn.in_port(0))
            match['div'].out_port(0).get_connection().set_source(
                mvn.out_port(0))
Example #23
    def find_and_replace_pattern(self, graph: Graph):
        for ctc_greedy_decoder_tf in graph.get_op_nodes(
                op='CTCGreedyDecoderSeqLen', output_sparse_format=True):
            ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get(
                'name', ctc_greedy_decoder_tf.id)

            # TF CTCGreedyDecoder has 4 output tensors. If any of them is connected to a non-Result operation,
            # the transformation is not applicable
            for port_num in ctc_greedy_decoder_tf.out_ports():
                if not ctc_greedy_decoder_tf.out_port(port_num).disconnected()\
                        and ctc_greedy_decoder_tf.out_port(port_num).get_destination().node.soft_get('op') != 'Result':
                    return

            # If the first or second output is not connected to a Result operation,
            # create a Result operation and connect it to the appropriate output
            if ctc_greedy_decoder_tf.out_port(0).disconnected():
                first_result = Result(
                    graph, {
                        'name': ctc_greedy_decoder_tf_name + '/decoded_classes'
                    }).create_node()
                ctc_greedy_decoder_tf.out_port(0).connect(
                    first_result.in_port(0))

            if ctc_greedy_decoder_tf.out_port(1).disconnected():
                second_result = Result(graph, {
                    'name':
                    ctc_greedy_decoder_tf_name + '/seq_lengths_output'
                }).create_node()
                ctc_greedy_decoder_tf.out_port(1).connect(
                    second_result.in_port(0))

            # To normalize the input, the data needs to be transposed from [T, N, C] to [N, T, C],
            # the layout supported by the CTCGreedyDecoderSeqLen op.
            log.warning(
                'Found a TF CTCGreedyDecoder operation at the end of the network. '
                'PLEASE NOTE: the corresponding network output operation CTCGreedyDecoderSeqLen {} '
                'will have dense format, not sparse format!'.format(
                    ctc_greedy_decoder_tf_name))
            ctc_data_permute = create_op_with_const_inputs(
                graph, Transpose, {1: int64_array([1, 0, 2])},
                {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})

            assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
                'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(
                    ctc_greedy_decoder_tf_name)

            ctc_greedy_decoder_tf.in_port(0).get_source().connect(
                ctc_data_permute.in_port(0))
            ctc_greedy_decoder_tf.in_port(0).disconnect()
            ctc_data_permute.out_port(0).connect(
                ctc_greedy_decoder_tf.in_port(0))

            del ctc_greedy_decoder_tf['output_sparse_format']

            for port_num in [2, 3]:  # MO CTCGreedyDecoderSeqLen may have 2 outputs
                if port_num in ctc_greedy_decoder_tf.out_ports():
                    if not ctc_greedy_decoder_tf.out_port(
                            port_num).disconnected():
                        ctc_greedy_decoder_tf.out_port(port_num).disconnect()
Example #24
    def replace_with_split_concat(node):
        graph = node.graph

        name = node.soft_get('name', node.id)
        axis = node.axis
        order = node.order

        split = create_op_with_const_inputs(graph, Split,
                                            {1: int64_array(axis)}, {
                                                'name': name + '/Split',
                                                'num_splits': order.size
                                            })
        concat = Concat(graph, {
            'name': name + '/Concat',
            'axis': axis,
            'in_ports_count': order.size
        }).create_node()

        for out_port_idx, in_port_idx in enumerate(order):
            split.out_port(out_port_idx).connect(concat.in_port(in_port_idx))

        node.out_port(0).get_connection().set_source(concat.out_port(0))
        node.in_port(0).get_connection().set_destination(split.in_port(0))

        graph.remove_node(node.id)
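
The Split/Concat pair above routes slice i of the input to output position order[i] along the given axis, i.e. it scatters the slices by 'order'. A NumPy sketch with hypothetical data, axis=0 and order=[2, 0, 1]:

import numpy as np

x, axis, order = np.arange(6).reshape(3, 2), 0, np.array([2, 0, 1])
chunks = np.split(x, len(order), axis)              # Split with num_splits=order.size
out = np.concatenate([chunks[i] for i in np.argsort(order)], axis)
assert np.array_equal(out, x[np.argsort(order)])    # slice i landed at position order[i]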
Example #25
    def transform_map_fn_input_slicing(external_match: dict, internal_match: dict):
        """
        Transforms TensorFlow 2 input slicing into use of axis attribute for input port of Loop node
        :param external_match: a match used for handling a part of the main graph responsible for input slicing
        :param internal_match: a match used for handling a part of the body graph responsible for input slicing
        """
        loop_node = external_match['while']
        unstack_node = external_match['unstack']
        body_graph = loop_node['body']

        tensor_list_get_item_node = internal_match['slicing']
        unstack_placeholder = internal_match['tensor_list']
        tensor_list_get_item_node_name = tensor_list_get_item_node.soft_get('name', tensor_list_get_item_node.id)

        # 1. process the body graph to avoid unsupported operations: TensorListGetItem and TensorListSetItem
        # replace TensorListGetItem with Squeeze node and iterate through slices using axis for input port
        squeeze_list_element = create_op_with_const_inputs(body_graph, Squeeze, {1: int64_array(0)},
                                                           {'name': 'TensorListGetItemSqueeze'})
        tensor_list_get_item_node.in_port(0).get_connection().set_destination(squeeze_list_element.in_port(0))
        tensor_list_get_item_node.out_port(0).get_connection().set_source(squeeze_list_element.out_port(0))
        rename_nodes([(tensor_list_get_item_node, tensor_list_get_item_node_name + '/AbandonedName'),
                      (squeeze_list_element, tensor_list_get_item_node_name)])
        unstack_placeholder_layer_id = unstack_placeholder.internal_layer_id
        Loop.update_port_map_value_ext(loop_node.input_port_map, 'internal_layer_id', unstack_placeholder_layer_id,
                                       'axis', 0)

        # 2. process the surroundings of the Loop node in the main graph to avoid unsupported operations:
        # TensorListFromTensor, TensorListReserve, and TensorListStack
        # remove TensorListFromTensor and pass a tensor to Loop as is
        unstack_node.out_port(0).get_connection().set_source(unstack_node.in_port(0).get_connection().get_source())
Example #26
    def find_and_replace_pattern(self, graph: Graph):
        for attr_pad in graph.get_op_nodes(op='AttributedPad'):
            # save the original node name to use it in the new Pad op instance
            original_name = attr_pad.soft_get('name', attr_pad.id)

            new_pad = Pad(graph, {
                'mode': attr_pad.soft_get('mode', None),
            }).create_node()
            rename_nodes([(attr_pad, original_name + '/to_be_removed'),
                          (new_pad, original_name)])

            attr_pad.in_port(0).get_connection().set_destination(
                new_pad.in_port(0))
            new_pad.in_port(1).connect(
                Const(graph, {
                    'value': attr_pad.pads[:, 0]
                }).create_node().out_port(0))
            new_pad.in_port(2).connect(
                Const(graph, {
                    'value': attr_pad.pads[:, 1]
                }).create_node().out_port(0))
            if attr_pad.soft_get('mode') == 'constant':
                # create Constant node of proper data type (equal to the data type of the Pad first input)
                convert_pad_value = create_op_with_const_inputs(
                    graph, ConvertLike, {0: attr_pad.fill_value},
                    {'name': original_name + '/pad_value_convert'})
                convert_pad_value.in_port(1).connect(
                    new_pad.in_port(0).get_source())
                new_pad.in_port(3).connect(convert_pad_value.out_port(0))

            attr_pad.out_port(0).get_connection().set_source(
                new_pad.out_port(0))
            graph.remove_node(attr_pad.id)
Example #27
    def div_to_mul_replacement(div: Node):
        # this transformation is also executed for V10 IR later, on the middle phase, despite graph_condition,
        # so we prevent the Div replacement on shape-calculating sub-graphs
        if div.in_port(0).data.get_value() is not None and div.in_port(1).data.get_value() is not None:
            return

        # Div cannot be replaced with Mul when the divisor is an integer, because the reciprocal would be truncated to 0
        value = div.in_port(1).data.get_value()
        if value is not None and type(value.item(0)) == int:
            return

        graph = div.graph
        name = div.soft_get('name', div.id)

        # keep Mul name the same as Div -- because of mathematical equality of output tensors
        rename_node(node=div, name=name + '/to_be_removed')

        # reconnect Div in(out)puts to Mul
        mul = Mul(graph, {'name': name}).create_node()
        rename_node(mul, name)

        div.in_port(0).get_connection().set_destination(mul.in_port(0))
        div.in_port(1).get_connection().set_destination(mul.in_port(1))
        div.out_port(0).get_connection().set_source(mul.out_port(0))

        # restore mathematical equivalence to Div operation: Div(A, B) = Mul(A, Pow(B, -1))
        reciprocal = create_op_with_const_inputs(graph, Pow, {1: np.float64(-1)}, {'name': name + '/reciprocal_'})
        mul.in_port(1).get_connection().insert_node(reciprocal)
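
The inserted Pow restores the mathematical equivalence noted in the comment: Div(A, B) = Mul(A, Pow(B, -1)). A quick NumPy check with hypothetical float inputs (integer divisors are rejected earlier in this function):

import numpy as np

a = np.array([6.0, 1.5], dtype=np.float32)
b = np.array([3.0, 0.5], dtype=np.float32)
assert np.allclose(a / b, a * b ** np.float64(-1))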
Example #28
    def replace_pattern(graph: Graph, match: dict):
        node = match['onehot']
        node_name = node.soft_get('name', node.id)
        reshape = create_op_with_const_inputs(graph, Reshape,
                                              {1: int64_array([])},
                                              {'name': node_name + '/Reshape'})
        node.in_port(1).get_connection().insert_node(reshape)
Example #29
    def insert_pre_processing(graph: Graph, input_node: Node, node_mean_scale_values: np.array,
                              preprocessing_name: str):
        assert preprocessing_name in ['scale', 'mean']
        if node_mean_scale_values.get(preprocessing_name) is None:
            return
        user_value = node_mean_scale_values[preprocessing_name]
        value = 1 / user_value if preprocessing_name == 'scale' else user_value * (-1)
        optimize_value = int(preprocessing_name == 'scale')
        op = Mul if preprocessing_name == 'scale' else Add

        if all([x == optimize_value for x in value]):
            return
        assert input_node.has_valid('shape')
        features_dim_idx = get_features_dim(graph.graph['layout'], len(input_node.shape))
        assert compatible_dims(value.size, input_node.shape[features_dim_idx]) or value.size == 1

        shape = np.ones(len(input_node.shape), dtype=np.int64)
        shape[features_dim_idx] = value.size
        value = value.reshape(shape)

        name = input_node.soft_get('name', input_node.id) + '/' + preprocessing_name
        preprocessing = create_op_with_const_inputs(graph, op=op, port_value_dict={1: value}, op_attrs={'name': name})

        for dst in input_node.out_port(0).get_destinations():
            if dst.node.soft_get('type') != 'ShapeOf':
                # After the insertion of additional operations, the Model Optimizer
                # should keep the link to the input layer: a Parameter node in the framework
                # should map to a Parameter node in the IR.
                # For this reason 'fw_tensor_debug_info' should be kept in the data node.
                dst.get_connection().set_source(preprocessing.out_port(0), "source")

        input_node.out_port(0).connect(preprocessing.in_port(0))
Example #30
    def sub_to_add_replacement(sub: Node):
        # this transformation is also executed for V10 IR later, on the middle phase, despite graph_condition,
        # so we prevent the Sub replacement on shape-calculating sub-graphs
        if sub.in_port(0).data.get_value() is not None and sub.in_port(
                1).data.get_value() is not None:
            return

        graph = sub.graph
        name = sub.soft_get('name', sub.id)

        # keep Add name the same as Sub -- because of mathematical equality of output tensors
        rename_node(node=sub, name=name + '/to_be_removed')

        # reconnect Sub in(out)puts to Add
        add = Add(graph, {'name': name}).create_node()
        rename_node(add, name)

        sub.in_port(0).get_connection().set_destination(add.in_port(0))
        sub.in_port(1).get_connection().set_destination(add.in_port(1))
        sub.out_port(0).get_connection().set_source(add.out_port(0))

        # restore mathematical equivalence to Sub operation: Sub(A, B) = Add(A, Mul(B, -1))
        const_dtype = sub.soft_get('data_type', np.float32)
        negate = create_op_with_const_inputs(
            graph, Mul, {1: np.array(-1, dtype=const_dtype)},
            {'name': name + '/neg_'})
        add.in_port(1).get_connection().insert_node(negate)