Example #1
    def replace_op(self, graph: Graph, node: Node):
        inp0 = node.in_port(0).get_source().node
        inp1 = node.in_port(1).get_source().node

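        # slice out element [1:2] of the second input, i.e. the size of dimension 1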
        begin_id = Const(graph, {"value": int64_array([1])}).create_node()
        end_id = Const(graph, {"value": int64_array([2])}).create_node()
        dim1 = StridedSlice(
            graph,
            dict(
                name=inp0.name + "/dim1",
                begin_mask=[1],
                end_mask=[1],
                shrink_axis_mask=[0],
                new_axis_mask=[0],
                ellipsis_mask=[0],
            ),
        ).create_node([inp1, begin_id, end_id])

        rows = Div(graph, dict(name=node.name + "/rows")).create_node([inp0, dim1])

        inp0 = Cast(
            graph, dict(name=inp0.name + "/fp32", dst_type=np.float32)
        ).create_node([inp0])
        dim1 = Cast(
            graph, dict(name=dim1.name + "/fp32", dst_type=np.float32)
        ).create_node([dim1])
        cols = FloorMod(graph, dict(name=node.name + "/cols")).create_node([inp0, dim1])
        cols = Cast(
            graph, dict(name=cols.name + "/i64", dst_type=np.int64)
        ).create_node([cols])

        concat = PackOp(graph, dict(name=node.name + "/merged", axis=0)).create_node(
            [rows, cols]
        )
        return [concat.id]
Example #2
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = float32_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name
    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

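    # the 'sizes' input for Interpolate is computed as int64(floor(float(shape[axis]) * scale)) by the chain below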
    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph,
                                                     StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {
                                                        'name': split_node_name + '/StridedSlice',
                                                        'begin_mask': int64_array([1]),
                                                        'end_mask': int64_array([1]),
                                                        'new_axis_mask': int64_array([0]),
                                                        'shrink_axis_mask': int64_array([0]),
                                                        'ellipsis_mask': int64_array([0])
                                                     })
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest',
                                   antialias=0, pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel', nearest_mode='round_prefer_floor',
                                   cube_coeff=-0.75, version='opset4', shape_calculation_mode='scales',
                                   in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
Example #3
def add_removed_converts(graph: Graph):
    for data_node_name in graph.get_nodes_with_attributes(
            Insert_Convert_operation_after=True):
        data_node = Node(graph, data_node_name)
        # Get access to Const node connected to data node
        const_op = data_node.in_node(0)

        if const_op.type != 'Const':
            logger.debug('Cannot insert Convert operation after {} node with type {}'.format(
                const_op.soft_get('name'), const_op.soft_get('type')))
            continue

        if const_op.data_type != np.float32:
            logger.debug('Cannot insert Convert operation after non-FP32 Const: {}'.format(
                const_op.soft_get('name')))
            continue

        convert_op = Cast(
            graph, {
                'dst_type': np.float32,
                'name': const_op.name + '/restored_convert',
                'stop_value_propagation': True
            }).create_node()

        # Insert Convert operation after Const operation
        const_op.out_port(0).get_connection().insert_node(convert_op)
        convert_op.out_node().value = None

        # Convert Const value to FP16 to make types in graph consistent
        const_op.value, _, _ = convert_blob(const_op.value, np.float16)
        const_op.infer(const_op)
Example #4
    def replace_sub_graph(self, graph: Graph, match: dict):
        div_sqrt = match['op']
        div_sqrt_name = div_sqrt.soft_get('name', div_sqrt.id)
        shape_node = Shape(graph,
                           dict(name=div_sqrt_name + '/Shape')).create_node()
        data_out_port = div_sqrt.in_port(0).get_source()
        shape_node.in_port(0).connect(data_out_port)

        shape_values_node = node_to_get_shape_value_of_indices(
            shape_node=shape_node, indices=[-1])

        pow_node = AttributedPower(
            graph, dict(name=div_sqrt_name + '/Sqrt',
                        power=mo_array(0.5))).create_node()

        # Due to specification, Power must have inputs with the same data type.
        convert_pow_input = Cast(
            graph,
            dict(dst_type=np.float32,
                 name=shape_values_node.name +
                 '/ConvertToFP32')).create_node()
        div_node = Div(graph, dict(name="Div")).create_node()

        shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
        convert_pow_input.out_port(0).connect(pow_node.in_port(0))
        div_sqrt.in_port(0).get_connection().set_destination(
            div_node.in_port(0))
        div_node.in_port(1).connect(pow_node.out_port(0))
        div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))

        rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'),
                      (div_node, div_sqrt_name)])
Example #5
    def replace_pattern(graph: Graph, match: dict):
        node = match['node']
        for in_port, precision in node.force_precision_in_ports.items():
            if in_port in node.in_ports().keys() and not node.in_port(in_port).disconnected():
                cast = Cast(graph, {'name': node.name + '/Cast_' + str(in_port),
                                    'dst_type': data_type_str_to_np(precision)}).create_node()
                node.in_port(in_port).get_connection().insert_node(cast)
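
# Illustration (hypothetical attribute value): for a node with
#     force_precision_in_ports == {1: 'int64'}
# this pass inserts a Cast named '<node_name>/Cast_1' with dst_type=np.int64
# on the connection feeding input port 1.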
Example #6
    def replace_op(self, graph: Graph, node: Node):
        name = node.soft_get('name', node.id)
        axis = node.soft_get('axis', 0)

        rename_node(node=node, name=name + '/to_be_removed')
        cumsum_node = create_op_node_with_second_input(graph, CumSum,
                                                       int64_array(axis), {
                                                           'name': name,
                                                           'reverse': False,
                                                           'exclusive': False
                                                       })
        rename_node(cumsum_node, name)

        node.in_port(0).get_connection().set_destination(
            cumsum_node.in_port(0))
        if node.has_valid('mx_out_type') and node['mx_out_type'] is not None:
            rename_node(node=cumsum_node, name=name + '/CumSum')
            convert = Cast(graph, {
                'name': name,
                'dst_type': node['mx_out_type']
            }).create_node()
            rename_node(convert, name)
            cumsum_node.out_port(0).connect(convert.in_port(0))
            return [convert.id]
        else:
            return [cumsum_node.id]
Example #7
    def find_and_replace_pattern(self, graph: Graph):
        for dequantize_node in graph.get_op_nodes(op='DequantizeLinear'):
            node_name = dequantize_node.soft_get('name', dequantize_node.id)
            axis = dequantize_node.soft_get('axis', None)
            scale_y_shape = dequantize_node.in_port(1).data.get_shape()
            model_data_type = data_type_str_to_np(
                graph.graph['cmd_params'].data_type)
            cast = Cast(graph, {
                'dst_type': model_data_type,
                'name': node_name + '/Cast'
            }).create_node()
            dequantize_node.in_port(0).get_connection().set_destination(
                cast.in_port(0))
            mul = Mul(graph, {'can_be_fused': False}).create_node()

            is_second_port_connected = dequantize_node.is_in_port_connected(2)
            if is_second_port_connected:
                # it is necessary not to replace the Subtract, so that the pattern can be matched by offline
                # transformations; see the ConvertQuantizeDequantize transformation in nGraph
                sub = Sub(graph, {
                    'name': node_name + '/Sub',
                    'zero_point_sub': True
                }).create_node()
                cast.out_port(0).connect(sub.in_port(0))
                dequantize_node.in_port(2).get_connection().set_destination(
                    sub.in_port(1))
                sub.out_port(0).connect(mul.in_port(0))
            else:
                cast.out_port(0).connect(mul.in_port(0))

            dequantize_node.in_port(1).get_connection().set_destination(
                mul.in_port(1))
            dequantize_node.out_port(0).get_connection().set_source(
                mul.out_port(0))
            rename_nodes([(dequantize_node, node_name + '/TBD'),
                          (mul, node_name)])

            assert scale_y_shape is not None
            if axis is not None and len(
                    scale_y_shape) > 0 and scale_y_shape[0] > 1:
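                # per-channel case: reshape the scale (and zero point) so they broadcast along 'axis'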
                input_shape = cast.in_port(0).data.get_shape()
                target_shape = np.ones(len(input_shape), np.int64)
                target_shape[axis] = input_shape[axis]

                mul_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul'})
                mul.in_port(1).get_connection().set_destination(
                    mul_reshape.in_port(0))
                mul_reshape.out_port(0).connect(mul.in_port(1))

                if is_second_port_connected:
                    sub_reshape = create_op_with_const_inputs(
                        graph, Reshape, {1: int64_array(target_shape)},
                        {'name': node_name + '/Reshape/Sub'})
                    sub.in_port(1).get_connection().set_destination(
                        sub_reshape.in_port(0))
                    sub_reshape.out_port(0).connect(sub.in_port(1))
Example #8
def insert_do(graph: Graph, replacement_descriptions: dict):
    do_outputs = replacement_descriptions['do_outputs']
    prior_boxes_node = Node(graph, 'ROIFeatureExtractor_2')
    num_classes = 81
    box_regressions_input_node = Node(
        graph, replacement_descriptions['box_regressions_input_node'])
    box_regressions_node = create_op_node_with_second_input(
        graph, Reshape, int64_array([-1, 4 * num_classes]),
        dict(name='box_regressions'), box_regressions_input_node)

    class_predicitons_node = Node(
        graph, replacement_descriptions['class_predicitons_node'])
    im_info_node = Parameter(graph, {
        "name": 'im_info',
        'shape': int64_array([1, 3])
    }).create_node()

    do_node = ExperimentalDetectronDetectionOutput(
        graph, {
            'name': 'DetectionOutput',
            'class_agnostic_box_regression': 0,
            'deltas_weights': np.array([10.0, 10.0, 5.0, 5.0]),
            'max_delta_log_wh': replacement_descriptions['max_delta_log_wh'],
            'nms_threshold': replacement_descriptions['nms_threshold'],
            'score_threshold': replacement_descriptions['score_threshold'],
            'num_classes': num_classes,
            'max_detections_per_image': replacement_descriptions['max_detections_per_image'],
            'post_nms_count': replacement_descriptions['post_nms_count']
        }).create_node()
    prior_boxes_node.out_port(1).connect(do_node.in_port(0))
    box_regressions_node.out_port(0).connect(do_node.in_port(1))
    class_predicitons_node.out_port(0).connect(do_node.in_port(2))
    im_info_node.out_port(0).connect(do_node.in_port(3))

    do_output_ports = [
        do_node.out_port(0),
        do_node.out_port(1),
        do_node.out_port(2)
    ]
    old_do_output_nodes = [Node(graph, node_id) for node_id in do_outputs]
    for old_node, new_port in zip(old_do_output_nodes, do_output_ports):
        old_node.out_port(0).get_connection().set_source(new_port)
    # the consumer of the second output port of the ExperimentalDetectronDetectionOutput is a Mul node whose
    # second input is of type int64, so a Cast is inserted to make the data types match
    do_node.out_port(1).get_connection().insert_node(
        Cast(graph, {
            'dst_type': np.int64
        }).create_node())
Example #9
def create_ss_interval_border(graph: Graph, slice_border_port: Port,
                              shape: np.ndarray, axes: np.ndarray,
                              node_name: str):
    """
    This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice
    :param axes: axes that "starts" and "ends" apply to
    :param node_name: Slice node name
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    # the value for 'starts' or 'ends' may be the maximum/minimum possible value of int64. Such a value
    # must be clamped to the maximum/minimum of int32 because values that large do not fit into int32,
    # which is the type supported by the StridedSlice layer
    clamp = create_op_with_const_inputs(graph,
                                        Clamp,
                                        port_value_dict={
                                            1: np.iinfo(np.int32).min,
                                            2: np.iinfo(np.int32).max
                                        },
                                        op_attrs=dict(name=node_name +
                                                      '/Clamp'))
    clamp.in_port(0).connect(slice_border_port)
    # convert the "starts"/"ends" values coming from the network to the same data type as the constant
    # values created here, to prevent type errors in the Concat node
    cast = Cast(graph, dict(name=node_name + '/CastToI64',
                            dst_type=np.int64)).create_node()
    cast.in_port(0).connect(clamp.out_port(0))
    concat = Concat(graph, dict(name=node_name + '/Concat',
                                axis=0)).create_node()
    for value_idx, port_idx in enumerate(axes):
        concat.add_input_port(port_idx)
        # "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct
        # Concat input port
        value = create_op_with_const_inputs(
            graph,
            Gather,
            port_value_dict={
                1: int64_array([value_idx]),
                2: int64_array(0)
            },
            op_attrs={'name': node_name + '/Gather'})
        cast.out_port(0).connect(value.in_port(0))
        value.out_port(0).connect(concat.in_port(port_idx))
    for port_idx in range(len(shape)):
        if not concat.is_in_port_connected(port_idx):
            concat.add_input_port(port_idx)
            # this border value will be ignored by StridedSlice because of the begin_mask/end_mask
            const = Const(
                graph, dict(name=node_name + '/Const',
                            value=int64_array([0]))).create_node()
            const.out_port(0).connect(concat.in_port(port_idx))

    return concat
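
A minimal sketch of how this helper might be wired up when converting a Slice node; the function name, the all-axes assumption, and the use of ports 1 and 2 of Slice for "starts"/"ends" are illustrative assumptions, not taken from the original transformation:

def convert_slice_borders(graph: Graph, slice_node: Node):
    input_shape = slice_node.in_port(0).data.get_shape()
    axes = int64_array(range(len(input_shape)))  # assume the Slice applies to all axes, in order
    name = slice_node.soft_get('name', slice_node.id)
    # ports 1 and 2 of Slice are assumed to provide "starts" and "ends" respectively
    begin = create_ss_interval_border(graph, slice_node.in_port(1).get_source(), input_shape, axes, name)
    end = create_ss_interval_border(graph, slice_node.in_port(2).get_source(), input_shape, axes, name)
    return begin, end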
Example #10
    def placeholder_scales(self, placeholder: Node):
        """
        Helper function to get scales for prior boxes out of input image size:
                [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]
        """
        graph = placeholder.graph
        name = placeholder.soft_get('name', placeholder.id)

        shape_value = placeholder.soft_get('shape', None)
        assert shape_value is not None, \
            "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
        assert isinstance(shape_value, np.ndarray), \
            "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
        assert shape_value.size == 4, \
            "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

        shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
        shape.in_port(0).connect(placeholder.out_port(0))

        begin = Const(graph, {'value': int64_array([1])}).create_node()
        end = Const(graph, {'value': int64_array([3])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()
        spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': int64_array([1]),
                                       'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0])}).create_node()

        spatial.in_port(0).connect(shape.out_port(0))
        spatial.in_port(1).connect(begin.out_port(0))
        spatial.in_port(2).connect(end.out_port(0))
        spatial.in_port(3).connect(stride.out_port(0))

        power = Const(graph, {'value': float32_array([-1.])}).create_node()
        spatial_scale = Pow(graph, {}).create_node()

        spatial_scale.in_port(0).connect(spatial.out_port(0))
        spatial_scale.in_port(1).connect(power.out_port(0))

        # Power `type_infer` requires inputs to have equal data type
        convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
        spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

        order = Const(graph, {'value': int64_array([1, 0])}).create_node()
        axis_const = Const(graph, {'value': int64_array(0)}).create_node()
        reverse = Gather(graph, {}).create_node()

        reverse.in_port(0).connect(spatial_scale.out_port(0))
        reverse.in_port(1).connect(order.out_port(0))
        axis_const.out_port(0).connect(reverse.in_port(2))

        priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
        priors_scale_node.add_input_port(0, skip_if_exist=True)
        priors_scale_node.add_input_port(1, skip_if_exist=True)

        priors_scale_node.in_port(0).connect(reverse.out_port(0))
        priors_scale_node.in_port(1).connect(reverse.out_port(0))
        return priors_scale_node
Example #11
    def quantize_data(fake_quantize: Node, dst_type: type,
                      quantized_type: type, mode: str):
        graph = fake_quantize.graph
        name = fake_quantize.soft_get('name', fake_quantize.id)
        levels = fake_quantize.levels

        quantize = fake_quantize.copy_node(
            dict(name=name + '/Copy', stop_value_propagation=False), graph)
        fake_quantize.in_port(0).get_connection().set_destination(
            quantize.in_port(0))

        # inherit input limits
        fake_quantize.in_port(1).get_connection().set_destination(
            quantize.in_port(1))
        fake_quantize.in_port(2).get_connection().set_destination(
            quantize.in_port(2))

        # calculate output limits for quantized weights
        assert mode in ["signed", "unsigned"]
        i_min_value = -(levels // 2) if mode == "signed" else 0
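        # e.g. levels == 256: "signed" gives the range [-128, 127], "unsigned" gives [0, 255]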

        i_min = mo_array(i_min_value, dtype=dst_type) if not quantize.in_node(
            0).shape.size else mo_array([i_min_value], dtype=dst_type)
        i_max = mo_array(levels + i_min - 1, dtype=dst_type)

        assert i_max - i_min == levels - 1
        out_low = Const(graph, dict(name=name + '/Copy/out_low',
                                    value=i_min)).create_node()
        out_high = Const(graph, dict(name=name + '/Copy/out_high',
                                     value=i_max)).create_node()

        out_low.out_port(0).connect(quantize.in_port(3))
        out_high.out_port(0).connect(quantize.in_port(4))
        out_low.out_port(0).connect(fake_quantize.in_port(1))
        out_high.out_port(0).connect(fake_quantize.in_port(2))

        original_const = quantize.in_port(0).get_source().node
        quantized_data_name = original_const.soft_get(
            'name', original_const.id) + '/quantized'
        cast = Cast(
            graph,
            dict(name=quantized_data_name,
                 dst_type=quantized_type,
                 stop_value_propagation=False)).create_node()

        quantize.out_port(0).connect(cast.in_port(0))

        cast.out_port(0).connect(fake_quantize.in_port(0))
Example #12
    def find_and_replace_pattern(self, graph: Graph):
        ir_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

        for node in graph.get_op_nodes(op='RandomUniform'):
            assert node.has_valid('output_type')

            if node.has_and_set('returns_shape_value'):
                continue

            if node.output_type != ir_data_type and np.issubdtype(
                    node.output_type, np.floating):
                node_name = node.soft_get('name', node.id)
                convert_node = Cast(graph, {
                    'name': node_name + "/cast",
                    'dst_type': ir_data_type
                }).create_node()
                node.out_port(0).get_connection().insert_node(convert_node)
Example #13
    def replace_op(self, graph: Graph, node: Node):
        if node.has_and_set('inputs_preprocessed'):
            log.debug('Node "{}" has already been preprocessed'.format(
                node.soft_get('name')))
            return []
        # reshape tensor with batch indices to 2d
        unsqueeze_node = create_op_node_with_second_input(
            graph, Unsqueeze, int64_array([1]),
            {'name': node.name + '/Unsqueeze'}, node.in_node(2))

        convert_node = Cast(
            graph, {
                'name': unsqueeze_node.name + '/ToFloat',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()

        convert_node.in_port(0).connect(unsqueeze_node.out_port(0))

        concat_op = Concat(
            graph, {
                'axis': 1,
                'name': node.name + '/concat_batch_indices_and_boxes',
                'in_ports_count': 2
            })
        concat_node = concat_op.create_node([convert_node, node.in_node(1)])

        # do not remove edge with crop_size because it is needed in the partial infer
        graph.remove_edge(node.in_node(1).id, node.id)

        # the input to CropAndResize contains box coordinates in YXYX layout, but the IE ROIPooling layer
        # expects coordinates in XYXY layout, so a convolution is added here to swap the coordinates
        swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(
            graph, concat_node, 5)

        # reshape the locations tensor to 2D so it can be passed to Eltwise, which will be converted to ScaleShift
        reshape_2d_node = create_op_node_with_second_input(
            graph, Reshape, int64_array([-1, 5]),
            dict(name=swapped_box_coordinates_node.id + '/reshape_2d_'),
            swapped_box_coordinates_node)
        graph.create_edge(reshape_2d_node, node, 0, 1)

        # do not replace any output edge
        return []
Example #14
    def replace_sub_graph(self, graph: Graph, match: dict):
        tf_slice_node = match['op']
        slice_name = tf_slice_node.soft_get('name', tf_slice_node.id)
        slice_node = Slice(graph).create_node()
        rename_nodes([(tf_slice_node, slice_name + '/to_be_removed'),
                      (slice_node, slice_name)])
        ends_node = Add(graph, {'name': slice_name + '/ends'}).create_node()

        # reconnect input, begin, and size from TFSlice to the subgraph with Slice
        tf_slice_node.in_port(0).get_connection().set_destination(
            slice_node.in_port(0))
        tf_slice_node.in_port(1).get_connection().set_destination(
            slice_node.in_port(1))
        tf_slice_node.in_port(2).get_connection().set_destination(
            ends_node.in_port(0))
        slice_node.in_port(1).get_connection().add_destination(
            ends_node.in_port(1))
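        # ends = begin + size (TFSlice specifies 'begin' and 'size' instead of 'ends')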

        max_ends = Shape(graph, {
            'name': slice_name + '/ShapeOf'
        }).create_node()
        slice_node.in_port(0).get_connection().add_destination(
            max_ends.in_port(0))

        # check whether size[i] == -1; the check is applied elementwise: len(size) == len(begin) == input_rank
        where_max_ends_is_needed = create_op_with_const_inputs(
            graph, Equal, {0: int64_array(-1)},
            {'name': slice_name + '/where_max_ends_is_needed'})
        ends_node.in_port(0).get_connection().add_destination(
            where_max_ends_is_needed.in_port(1))
        # Select requires equal dtypes, so convert ends to int64
        ends_casted_to_i64 = Cast(graph, {
            'name': slice_name + '/CastToI64',
            'dst_type': np.int64
        }).create_node([ends_node])
        # if size[i] == -1 then take max_ends values
        correct_ends = Select(graph, {
            'name': slice_name + '/chosen_ends'
        }).create_node(
            [where_max_ends_is_needed, max_ends, ends_casted_to_i64])
        correct_ends.out_port(0).connect(slice_node.in_port(2))

        tf_slice_node.out_port(0).get_connection().set_source(
            slice_node.out_port(0))
Example #15
def convert_inputs_of_specific_ops(graph: Graph):
    type_port = {'Broadcast': {1: 'int64', 2: 'int64'},
                 'ConvolutionBackpropData': {2: 'int64'},
                 'Deconvolution': {2: 'int64'},
                 'Gather': {2: 'int64'},
                 'GroupConvolutionBackpropData': {2: 'int64'},
                 'Interpolate': {1: 'int64'},
                 'LRN': {1: 'int64'},
                 'NonMaxSuppression': {2: 'int64'},
                 'NormalizeL2': {1: 'int64'},
                 'OneHot': {1: 'int64'},
                 'Pad': {1: 'int64', 2: 'int64'},
                 'PriorBox': {0: 'int64', 1: 'int64'},
                 'PriorBoxClustered': {0: 'int64', 1: 'int64'},
                 'ReduceLogicalAnd': {1: 'int64'},
                 'ReduceLogicalOr': {1: 'int64'},
                 'ReduceMax': {1: 'int64'},
                 'ReduceMean': {1: 'int64'},
                 'ReduceMin': {1: 'int64'},
                 'ReduceProd': {1: 'int64'},
                 'ReduceSum': {1: 'int64'},
                 'Reshape': {1: 'int64'},
                 'Squeeze': {1: 'int64'},
                 'StridedSlice': {1: 'int64', 2: 'int64', 3: 'int64'},
                 'Split': {1: 'int64'},
                 'Tile': {1: 'int64'},
                 'Transpose': {1: 'int64'},
                 'Unsqueeze': {1: 'int64'},
                 'VariadicSplit': {1: 'int64', 2: 'int64'},
                 }

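    # values feeding Const inputs are converted in place; other producers get a Cast inserted on the connection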
    for node in graph.get_op_nodes():
        if node.soft_get('type') in type_port:
            ports_to_update = type_port[node.soft_get('type')]
            for port_id, precision in ports_to_update.items():
                if port_id in node.in_ports() and not node.in_port(port_id).disconnected():
                    log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
                              ''.format(port_id, node.soft_get('name', node.id), precision))
                    in_port = node.in_port(port_id)
                    np_type = data_type_str_to_np(precision)
                    if in_port.get_source().node.type == 'Const':
                        convert_const_node_value_type(node.in_port(port_id).get_source().node, np_type)
                    else:
                        in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
Example #16
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='ThresholdedRelu'):
            name = node.soft_get('name', node.id)

            greater = create_op_with_const_inputs(
                graph, Greater, {1: float_array([node.alpha])})
            greater.in_port(0).connect(node.in_port(0).get_source())
            float_greater = Cast(
                graph, {
                    'dst_type':
                    data_type_str_to_np(graph.graph['cmd_params'].data_type)
                }).create_node()
            greater.out_port(0).connect(float_greater.in_port(0))

            mul = Mul(graph, {}).create_node()
            node.out_port(0).get_connection().set_source(mul.out_port(0))
            mul.in_port(0).connect(node.in_port(0).get_source())
            mul.in_port(1).connect(float_greater.out_port(0))

            rename_nodes([(node, name + '/TBR'), (mul, name)])
            graph.remove_node(node.id)
Example #17
    def find_and_replace_pattern(self, graph: Graph):
        for quantize_node in graph.get_op_nodes(op='QuantizeLinear'):
            node_name = quantize_node.soft_get('name', quantize_node.id)
            axis = quantize_node.soft_get('axis', None)
            scale_y_shape = quantize_node.in_port(1).data.get_shape()

            if quantize_node.is_in_port_connected(2):
                zerop = quantize_node.in_port(2).get_source().node
            else:
                zerop = Const(
                    graph, {
                        'value': mo_array(0, dtype=np.uint8),
                        'name': node_name + '/ZeroPoint'
                    }).create_node()

            assert zerop.soft_get('type') == 'Const', \
                'only constant for zero_point is supported for QuantizeLinear'
            zero_point_type = zerop.value.dtype
            # data type affects range of output values: [-128..127] or [0..255]
            if zero_point_type == np.int8:
                output_low_value = -128.0
                output_high_value = 127.0
            elif zero_point_type == np.uint8:
                output_low_value = 0.0
                output_high_value = 255.0
            else:
                raise Error('Unexpected type {} for zero point value in node {}'.format(
                    zero_point_type, zerop.soft_get('name')))

            fake_quantize = create_op_with_const_inputs(
                graph, FakeQuantize, {
                    3: float_array(output_low_value),
                    4: float_array(output_high_value)
                }, {
                    'levels': 256,
                    'name': node_name + '/FakeQuantize'
                })
            quantize_node.in_port(0).get_connection().set_destination(
                fake_quantize.in_port(0))

            # Calculate input_low value
            mul_low = create_op_with_const_inputs(
                graph, Mul, {1: float_array(output_low_value - zerop.value)},
                {'name': node_name + '/Mul/Low'})
            quantize_node.in_port(1).get_connection().set_destination(
                mul_low.in_port(0))
            mul_low.out_port(0).connect(fake_quantize.in_port(1))

            # Calculate input_high value
            mul_high = create_op_with_const_inputs(
                graph, Mul, {1: float_array(output_high_value - zerop.value)},
                {'name': node_name + '/Mul/High'})
            mul_low.in_port(0).get_connection().add_destination(
                mul_high.in_port(0))
            mul_high.out_port(0).connect(fake_quantize.in_port(2))

            cast = Cast(graph, {
                'dst_type': zero_point_type,
                'name': node_name + '/Cast'
            }).create_node()
            fake_quantize.out_port(0).connect(cast.in_port(0))
            quantize_node.out_port(0).get_connection().set_source(
                cast.out_port(0))
            rename_nodes([(quantize_node, node_name + '/TBD'),
                          (cast, node_name)])

            assert scale_y_shape is not None
            if axis is not None and len(
                    scale_y_shape) > 0 and scale_y_shape[0] > 1:
                input_shape = fake_quantize.in_port(0).data.get_shape()
                target_shape = np.ones(len(input_shape), np.int64)
                target_shape[axis] = input_shape[axis]
                mul_low_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul/Low'})
                mul_high_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul/High'})

                fake_quantize.in_port(1).get_connection().set_destination(
                    mul_low_reshape.in_port(0))
                fake_quantize.in_port(2).get_connection().set_destination(
                    mul_high_reshape.in_port(0))

                mul_low_reshape.out_port(0).connect(fake_quantize.in_port(1))
                mul_high_reshape.out_port(0).connect(fake_quantize.in_port(2))
Example #18
def replace_resize(graph: Graph, resize: Node):
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(
                  resize.soft_get('name', resize.id)))

    input_shape = resize.in_port(0).data.get_shape()
    input_rank = len(input_shape)
    resize_name = resize.soft_get('name', resize.id)
    if input_rank not in {4, 5}:
        log.warning(
            'The input shape is not 4D or 5D for op with name {}'.format(
                resize_name))
        return

    assert (resize.is_in_port_connected(0) and (resize.is_in_port_connected(2) or resize.is_in_port_connected(3))), \
        "Scales or sizes inputs must be connected to Node {} with op {}.".format(resize.soft_get("name", resize.id),
                                                                                 resize.op)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op,
                                                                                 resize.soft_get("name", resize.id))

    layout = graph.graph['layout']

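    # interpolate only over the spatial dimensions: (H, W) for 4D input, (D, H, W) for 5D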
    if input_rank == 4:
        begin_dim = get_height_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1
    else:
        begin_dim = get_depth_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1

    sizes_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_sizes',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    scales_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_scales',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    axes_node = Const(
        graph, {
            'name': resize_name + '/axis',
            'value': int64_array(np.arange(begin_dim, end_dim))
        }).create_node()

    shape_calculation_mode = 'sizes' if resize.is_in_port_connected(3) else 'scales'

    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': convert_mode(resize.mode),
            'coordinate_transformation_mode':
            resize.coordinate_transformation_mode,
            'cube_coeff': resize.cube_coeff,
            'nearest_mode': resize.nearest_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': shape_calculation_mode,
            'in_ports_count': 4
        }).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values

    if not resize.is_in_port_connected(3):
        cast_shape_to_float = Cast(graph, {
            'dst_type': dst_dtype
        }).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
        cast_add_result_to_int = Cast(graph, {
            'dst_type': np.int64
        }).create_node()
        floor_node = Floor(graph, {
            'name': resize_name + '/Floor'
        }).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
        cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_scales = resize.in_port(2).get_connection()
        connection_of_scales.set_destination(scales_ss.in_port(0))

        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_scales.get_source().connect(mul_node.in_port(1))
    else:
        cast_shape_to_float = Cast(graph, {
            'dst_type': dst_dtype
        }).create_node()
        cast_sizes_to_float = Cast(graph, {
            'dst_type': dst_dtype
        }).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
        cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_sizes = resize.in_port(3).get_connection()
        connection_of_sizes.set_destination(sizes_ss.in_port(0))

        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_sizes.get_source().connect(
            cast_sizes_to_float.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))
Example #19
def replace_resize(graph: Graph, resize: Node):
    log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
              "is triggered for node {}.".format(
                  resize.soft_get('name', resize.id)))

    resize_name = resize.soft_get('name', resize.id)

    rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
    range_node = create_op_with_const_inputs(graph, Range, {
        0: int64_array(2),
        2: int64_array(1)
    }, {'name': resize_name + '/axes'})

    sizes_ss = create_op_with_const_inputs(graph, StridedSlice, {
        1: int64_array([2]),
        2: int64_array([0]),
        3: int64_array([1])
    }, {
        'name': resize_name + '/sizes_ss',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([0]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0])
    })
    scales_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([2]),
            2: int64_array([0]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/scales_ss',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([0]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })

    rank_node.out_port(0).connect(range_node.in_port(1))

    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest',
            'coordinate_transformation_mode': 'asymmetric',
            'cube_coeff': -0.75,
            'nearest_mode': 'simple',
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': 'scales',
            'in_ports_count': 4
        }).create_node()

    range_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # When we calculate 'sizes' input as floor(input_shape * scales), we can get incorrect 'sizes' if, e.g.,
    # scales = [1.0, 1.0, 1.33333, 2.0], input_shape = [1, 3, 30, 200], because
    # input_shape * scales = [1, 3, 39.9999, 400], and floor(input_shape * scales)[2] == 39, not 40.
    # Maybe we need to calculate 'sizes' input as floor(input_shape * scales + eps), where eps is some small
    # floating point number, e.g. 1.0e-5. But, in this case, if scales = [1.0, 1.0, 1.33333, 2.0],
    # input_shape = [1, 3, 30, 200], floor(input_shape * scales + eps) = 39, not 40, because
    # input_shape[2] * scales[2] + 1.0e-5 =  39.99991.
    # Hence, we need to calculate 'sizes' as floor(input_shape * (scales + eps)).
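    # A quick sanity check of the arithmetic above (illustrative, assuming NumPy semantics):
    #   np.floor(np.array([1, 3, 30, 200]) * np.array([1.0, 1.0, 1.33333, 2.0]))
    #     -> [1., 3., 39., 400.]   (39, not the desired 40)
    #   np.floor(np.array([1, 3, 30, 200]) * (np.array([1.0, 1.0, 1.33333, 2.0]) + 1e-5))
    #     -> [1., 3., 40., 400.]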
    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values

    cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()

    shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
    mul_node = Mul(graph, {
        'name': resize_name + '/Mul'
    }).create_node([cast_shape_to_float, add_node])
    floor_node = Floor(graph, {
        'name': resize_name + '/Floor'
    }).create_node([mul_node])
    cast_mul_result_to_int = Cast(graph, {
        'dst_type': np.int64
    }).create_node([floor_node])
    cast_mul_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
    sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

    scales_ss.out_port(0).connect(interpolate_node.in_port(2))

    connection_of_resize_input = resize.in_port(0).get_connection()
    connection_of_resize_input.set_destination(interpolate_node.in_port(0))

    connection_of_scales = resize.in_port(1).get_connection()
    connection_of_scales.set_destination(scales_ss.in_port(0))

    connection_of_resize_input.get_source().connect(shape_of.in_port(0))
    connection_of_resize_input.get_source().connect(rank_node.in_port(0))
    connection_of_scales.get_source().connect(add_node.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))
Example #20
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(
                op='BatchToSpace'):
            node.add_input_port(3, skip_if_exist=True)

            # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
            transposed_pads = create_op_with_const_inputs(
                graph, Transpose, {1: int64_array([1, 0])})
            node.in_port(2).get_connection().set_destination(
                transposed_pads.in_port(0))
            split_pads = create_op_with_const_inputs(graph, Split,
                                                     {1: int64_array(0)},
                                                     {'num_splits': 2})
            transposed_pads.out_port(0).connect(split_pads.in_port(0))
            for port_ind in range(2):
                node.in_port(port_ind + 2).connect(
                    split_pads.out_port(port_ind))
                node.in_port(port_ind + 2).get_connection().insert_node(
                    create_op_with_const_inputs(graph, Squeeze,
                                                {1: int64_array([0])}))

            # add zeros/ones to the related inputs to align them with the data input
            in0_rank = Rank(graph, {
                'name': node.name + '/rank_0'
            }).create_node()
            in1_shape = Shape(graph, {
                'name': node.name + '/rank_1'
            }).create_node()

            diff_size = Sub(graph, {
                'name': node.name + '/sub_0'
            }).create_node()
            diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()
            const_begin = Const(graph, {
                'value': int64_array([1])
            }).create_node()
            const_pad_val = Const(graph, {
                'value': int64_array(1)
            }).create_node()

            block_shape = Pad(graph, {
                'name': node.name + '/aligned_block_shape',
                'mode': 'constant'
            }).create_node()

            # in case of SpaceToBatch begin = pads_begin, end = pads_end
            # in case of BatchToSpace begin = crops_begin, end = crops_end
            new_begin_name = '/aligned_pads_begin'
            new_end_name = '/aligned_pads_end'
            if node.type == 'BatchToSpace':
                new_begin_name = '/aligned_crops_begin'
                new_end_name = '/aligned_crops_end'

            begin = Pad(graph, {
                'name': node.name + new_begin_name,
                'mode': 'constant'
            }).create_node()
            end = Pad(graph, {
                'name': node.name + new_end_name,
                'mode': 'constant'
            }).create_node()

            in0_rank_1d = create_op_node_with_second_input(
                graph, Unsqueeze, int64_array([0]),
                {'name': node.name + '/1d_rank_of_0'}, in0_rank)

            node.in_port(0).get_source().connect(in0_rank.in_port(0))
            node.in_port(1).get_source().connect(in1_shape.in_port(0))
            in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
            in1_shape.out_port(0).connect(diff_size.in_port(1))
            diff_size.out_port(0).connect(diff.in_port(0))
            const_begin.out_port(0).connect(diff.in_port(1))
            const_pad_val.out_port(0).connect(block_shape.in_port(3))

            inputs_array = [block_shape, begin, end]
            for idx, input_to_node in enumerate(inputs_array):
                name_of_input_to_node = input_to_node.name
                node.in_port(idx + 1).get_connection().set_destination(
                    input_to_node.in_port(0))
                const_begin.out_port(0).connect(input_to_node.in_port(1))
                diff.out_port(0).connect(input_to_node.in_port(2))
                input_to_node.out_port(0).connect(node.in_port(idx + 1))
                convert = Cast(graph, {
                    'name': name_of_input_to_node + '/i64',
                    'dst_type': np.int64
                }).create_node()
                input_to_node.in_port(0).get_connection().insert_node(convert)
Example #21
    def replace_pattern(graph: Graph, match: dict):
        node = match['pool']
        node_name = node.soft_get('name', node.id)

        if node.pool_step is None:
            node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

        # create Reshape before convolution
        # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()

        dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values
        shape = Cast(graph, {
            'name': node_name + '/to_float',
            'dst_type': dst_dtype
        }).create_node()
        i_shape.in_port(0).connect(node.in_port(0).get_source())
        shape.in_port(0).connect(i_shape.out_port(0))

        N, H = node_to_get_shape_value_of_indices(
            shape, [0]), node_to_get_shape_value_of_indices(shape, [1])

        div = create_op_with_const_inputs(
            graph, Div, {1: float32_array([node.pool_stride])},
            {'name': node_name + '/div_stride_h'})
        div.in_port(0).connect(H.out_port(0))

        concat = create_op_with_const_inputs(
            graph, Concat, {
                1: float32_array([node.pool_stride]),
                2: float32_array([1])
            }, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })
        concat.in_port(0).connect(N.out_port(0))
        concat.in_port(3).connect(div.out_port(0))

        reshape_pattern = Cast(graph, {
            'name': node_name + '/to_int',
            'dst_type': np.int64
        }).create_node()
        concat.out_port(0).connect(reshape_pattern.in_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

        # create Reshape after Convolution
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))
Example #22
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
    resize_name = resize.soft_get('name', resize.id)
    log.debug(
        "Converting of {} to Interpolate-4 is triggered for node {}.".format(
            resize.op, resize_name))

    num_of_inputs = len([
        port for port in resize.in_ports().values() if not port.disconnected()
    ])
    assert num_of_inputs == 2, \
        "Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)

    attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                "the attribute align_corners must be False"
    assert not resize.half_pixel_centers or (resize.half_pixel_centers and not resize.align_corners), \
        attrs_msg.format(resize_name, resize.op)

    shape = Shape(graph, {'name': resize_name + '/shapeof'}).create_node()

    ss = create_op_with_const_inputs(graph, StridedSlice, {
        1: int64_array([1]),
        2: int64_array([3]),
        3: int64_array([1])
    }, {
        'name': resize_name + '/StridedSlice',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0])
    })

    div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()

    shape_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    size_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()

    size_to_float.out_port(0).connect(div_node.in_port(0))
    shape_to_float.out_port(0).connect(div_node.in_port(1))
    ss.out_port(0).connect(shape_to_float.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    align_corners = resize.align_corners
    half_pixel_centers = resize.half_pixel_centers

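    # map the TF align_corners / half_pixel_centers flags onto Interpolate-4 coordinate transformation modes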
    nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
    if align_corners:
        coordinate_transformation_mode = 'align_corners'
        if interpolation_mode == 'nearest':
            nearest_mode = 'round_prefer_ceil'
    elif half_pixel_centers:
        coordinate_transformation_mode = 'tf_half_pixel_for_nn' if interpolation_mode == 'nearest' else 'half_pixel'
    else:
        coordinate_transformation_mode = 'asymmetric'

    interpolate4 = create_op_with_const_inputs(
        graph, Interpolate, {3: int64_array([1, 2])}, {
            'name': resize_name + '/interpolate_4',
            'mode': interpolation_mode,
            'antialias': False,
            'coordinate_transformation_mode': coordinate_transformation_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'nearest_mode': nearest_mode,
            'cube_coeff': -0.75,
            'shape_calculation_mode': 'sizes',
            'version': 'opset4',
            'in_ports_count': 4,
        })

    resize_input_connection = resize.in_port(0).get_connection()
    resize_input_connection.set_destination(interpolate4.in_port(0))
    resize_input_connection.get_source().connect(shape.in_port(0))

    div_node.out_port(0).connect(interpolate4.in_port(2))

    sizes_connection = resize.in_port(1).get_connection()
    sizes_connection.set_destination(interpolate4.in_port(1))
    sizes_connection.get_source().connect(size_to_float.in_port(0))

    resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate4, resize_name)])
Example #23
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']
        name = node.soft_get('name', node.id)
        axis = node.axis
        input_shape_node = Shape(graph, {
            'name': name + '/ShapeOf'
        }).create_node()
        range_node = create_op_with_const_inputs(graph, Range, {
            0: mo_array(node.start),
            2: mo_array(node.step)
        }, {'name': name + '/Range'})
        node.in_port(0).get_connection().set_destination(
            input_shape_node.in_port(0))

        if axis is not None:
            '''
            Replace the arange_like op with the sub-graph:
            Shape - Gather - Range
            '''
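            # Illustrative example: for input shape [2, 5] and axis=1,
            # Gather(ShapeOf(input), [1]) -> [5], so Range(0, 5, 1) -> [0, 1, 2, 3, 4].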
            gather_node = create_op_with_const_inputs(graph, Gather, {
                1: int64_array([axis]),
                2: int64_array(0)
            }, {'name': name + '/Gather'})
            input_shape_node.out_port(0).connect(gather_node.in_port(0))
            gather_node.out_port(0).connect(range_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                range_node.out_port(0))
            rename_nodes([(node, name + '/ShouldBeDeleted'),
                          (range_node, name)])
        else:
            r'''
            Replace the arange_like op with the sub-graph:
                    |
                 ShapeOf ----------- |
                    |                |
                 ReduceProd          |
                    |                |
                  Range              |
                    |                |
                 Reshape ----------- |
                    |
            '''
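            # Illustrative example: for input shape [2, 3],
            # ReduceProd(ShapeOf(input)) -> [6], Range(0, 6, 1) -> [0, 1, 2, 3, 4, 5],
            # and the backward Reshape restores [[0, 1, 2], [3, 4, 5]].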

            flattened_shape_node = create_op_with_const_inputs(
                graph, ReduceProd, {1: int64_array([0])}, {
                    'name': input_shape_node.name + '/ReduceProd',
                    'keep_dims': True
                })
            reshape_backward_node = Reshape(graph, {
                'name': name + '/Reshape_backward'
            }).create_node()

            input_shape_node.out_port(0).connect(
                flattened_shape_node.in_port(0))
            flattened_shape_node.out_port(0).connect(range_node.in_port(1))
            range_node.out_port(0).connect(reshape_backward_node.in_port(0))
            input_shape_node.out_port(0).connect(
                reshape_backward_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                reshape_backward_node.out_port(0))
            rename_nodes([(node, name + '/ShouldBeDeleted'),
                          (reshape_backward_node, name)])

        if node.repeat != 1:
            r"""
            First, we generate the correct stop value for Range like new_stop_value = stop_value // repeat + 1.
            Then repeats each value of the interval using Tile. After that we can get a longer interval
            so we reduce it with Slice.
            
            Sub-graph after Range node will be look like
            
            Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
            
            """

            if node.repeat < 1:
                raise Error(
                    "Unexpected value {} of the attribute 'repeat' for the node {}"
                    .format(node.repeat, name))

            div_node = create_op_with_const_inputs(
                graph, Div, {1: int64_array([node.repeat])},
                {'name': name + '/Divide'})
            add_node = create_op_with_const_inputs(
                graph, Add, {1: int64_array([1])},
                {'name': div_node.name + '/Add'})
            cast_node = Cast(graph, {
                'name': name + '/ConvertToI64',
                'dst_type': np.int64
            }).create_node()

            cast_node.out_port(0).connect(div_node.in_port(0))
            div_node.out_port(0).connect(add_node.in_port(0))
            range_node.in_port(1).get_connection().set_destination(
                cast_node.in_port(0))
            add_node.out_port(0).connect(range_node.in_port(1))

            tile_forward_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array([-1, 1])},
                {'name': range_node.name + '/ForwardReshape'})
            tile = create_op_with_const_inputs(
                graph, Tile, {1: int64_array([1, node.repeat])},
                {'name': tile_forward_reshape.name + '/Tile'})
            tile_backward_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array([-1])},
                {'name': tile.name + '/BackwardReshape'})
            slice_node = create_op_with_const_inputs(
                graph, Slice, {
                    1: int64_array([0]),
                    3: int64_array([0]),
                    4: int64_array([1])
                }, {'name': tile_backward_reshape.name + '/Slice'})

            tile_forward_reshape.out_port(0).connect(tile.in_port(0))
            tile.out_port(0).connect(tile_backward_reshape.in_port(0))
            tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
            slice_node.in_port(2).connect(div_node.in_port(0).get_source())

            range_node.out_port(0).get_connection().set_source(
                slice_node.out_port(0))
            range_node.out_port(0).connect(tile_forward_reshape.in_port(0))

            if axis is not None:
                rename_nodes([(range_node, name + '/Range'),
                              (slice_node, name)])

        # MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
        # we have to correct the stop value for the Range node if step != 1 or start != 0
        if node.step != 1:
            # If the step attribute is not an integer, we generate an interval of a larger size and then
            # truncate it using Slice
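            # Worked example (illustrative): for 4 output elements and step=2.5,
            # the stop value becomes 4 * ceil(2.5) = 12, Range(0, 12, 2.5) -> [0, 2.5, 5, 7.5, 10],
            # and Slice keeps the first 4: [0, 2.5, 5, 7.5].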
            true_elements_count_port = range_node.in_port(1).get_source()
            mul_value = np.ceil(node.step) if node.step > 0 else np.floor(
                node.step)
            stop_value = create_op_with_const_inputs(
                graph,
                Mul,
                port_value_dict={1: mo_array(mul_value)},
                op_attrs={'name': range_node.name + '/Stop'})
            range_node.in_port(1).get_connection().insert_node(stop_value)

            slice_range_values = create_op_with_const_inputs(
                graph, Slice, {
                    1: int64_array([0]),
                    3: int64_array([0]),
                    4: int64_array([1])
                }, {'name': range_node.name + '/Slice'})
            slice_range_values.in_port(2).connect(true_elements_count_port)
            range_node.out_port(0).get_connection().insert_node(
                slice_range_values)

            if axis is not None and node.repeat == 1:
                rename_nodes([(range_node, name + '/Range'),
                              (slice_range_values, name)])

        if node.start != 0:
            correct_stop_value = create_op_with_const_inputs(
                graph,
                Add,
                port_value_dict={1: mo_array(node.start)},
                op_attrs={'name': range_node.name + '/Correct_Stop'})
            range_node.in_port(1).get_connection().insert_node(
                correct_stop_value)

        # Range node supports only scalar inputs
        squeeze_node = create_op_with_const_inputs(
            graph,
            Squeeze,
            port_value_dict={1: int64_array(0)},
            op_attrs={"name": range_node.name + '/Stop/Squeeze'})
        range_node.in_port(1).get_connection().insert_node(squeeze_node)
Example #24
    def replace_sub_graph(self, graph: Graph, match: dict):
        identity_spw = match['identity_spw']
        gather0_1 = match['gather0_1']
        gather0_2 = match['gather0_2']
        greaterequal0 = match['greaterequal0']
        sparse_fill_empty_rows = match['sparse_fill_empty_rows']
        gather = match['gather']
        select = match['select']
        where0 = match['where0']
        sparse_segment_op = match['sparse_segment_op']
        output_node_name = select.soft_get('name', select.id)

        log.debug('Found EmbeddingSparseSegmentsSingleFeature pattern after {} with name {}'.format(
            sparse_fill_empty_rows.op,
            sparse_fill_empty_rows.name))

        split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'num_splits': 2})
        squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
        split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
        squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})

        # TODO: remove Cast nodes once we start to support EmbeddingSegmentSum (new version) with segment_ids,
        #  indices, and num_segments of different integer type.
        #  Because the real cases show that it is possible to have it in TensorFlow
        cast_indices = Cast(graph, {'name': output_node_name + '/CastIndices', 'dst_type': np.int32}).create_node()
        cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds',
                                        'dst_type': np.int32}).create_node()
        cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue',
                                          'dst_type': np.int32}).create_node()
        cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber',
                                         'dst_type': np.int32}).create_node()
        if sparse_segment_op.op == 'SparseSegmentSum':
            embedding_segments_op = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
        else:
            embedding_segments_op = EmbeddingSegmentsMean(graph, {'name': output_node_name}).create_node()
        rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_op, output_node_name)])

        # connect parameters table
        gather.in_port(0).get_connection().set_destination(embedding_segments_op.in_port(0))
        # connect indices values
        greaterequal0.in_port(0).get_connection().set_destination(cast_indices.in_port(0))
        embedding_segments_op.in_port(1).connect(cast_indices.out_port(0))
        # split and connect segment ids
        gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
        squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
        cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
        embedding_segments_op.in_port(2).connect(cast_segment_ids.out_port(0))
        # split and connect number of segments
        identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
        squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
        cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
        embedding_segments_op.in_port(3).connect(cast_num_segments.out_port(0))
        # connect default value
        sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
        embedding_segments_op.in_port(4).connect(cast_default_value.out_port(0))
        # no input port for per_sample_weight

        identity_spw.in_port(0).disconnect()
        gather0_1.in_port(0).disconnect()
        gather0_2.in_port(0).disconnect()
        greaterequal0.in_port(0).disconnect()
        sparse_fill_empty_rows.in_port(2).disconnect()
        gather.in_port(0).disconnect()

        select.out_port(0).get_connection().set_source(embedding_segments_op.out_port(0))
        graph.remove_nodes_from(
            [gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
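
A small numpy illustration of what the Split/Squeeze wiring above extracts (made-up values, not MO API): segment ids come from the first column of the 2-D sparse indices, and the number of segments comes from the first element of the dense shape:

import numpy as np

indices = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)  # COO [row, col] pairs
dense_shape = np.array([3, 4], dtype=np.int64)

segment_ids = indices[:, 0]    # Split(axis=1) + Squeeze -> [0 0 2]
num_segments = dense_shape[0]  # Split(axis=0) + Squeeze -> 3 (the batch size)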
Example #25
    def dequantize_data(fake_quantize: Node, dst_type: type,
                        quantized_type: type) -> Node:
        graph = fake_quantize.graph
        quantized_data = fake_quantize.in_port(0).get_source().node
        name = fake_quantize.soft_get('name', fake_quantize.id)

        assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
            "Weights aren't compressed as expected for node {}".format(name)

        dequantizing_cast = Cast(
            graph,
            dict(name=quantized_data.name +
                 "/to_{}".format(np_data_type_to_destination_type(dst_type)),
                 dst_type=dst_type,
                 stop_value_propagation=True)).create_node()
        fake_quantize.in_port(0).get_connection().set_destination(
            dequantizing_cast.in_port(0))

        # limits of dequantize
        in_low = fake_quantize.in_port(1).get_source()
        in_high = fake_quantize.in_port(2).get_source()
        out_low = fake_quantize.in_port(3).get_source()
        out_high = fake_quantize.in_port(4).get_source()

        # scale calculation
        output_range = Sub(graph, {
            'name': name + '/output_range'
        }).create_node()
        output_range.in_port(0).connect(out_high)
        output_range.in_port(1).connect(out_low)

        input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
        input_range.in_port(0).connect(in_high)
        input_range.in_port(1).connect(in_low)

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        scale.in_port(0).connect(output_range.out_port(0))
        scale.in_port(1).connect(input_range.out_port(0))

        # shift calculation
        descaled_output_low = Div(graph, {
            'name': name + '/descaled_output_low'
        }).create_node()
        descaled_output_low.in_port(0).connect(out_low)
        descaled_output_low.in_port(1).connect(scale.out_port(0))

        shift = Sub(graph, {'name': name + '/shift'}).create_node()
        shift.in_port(0).connect(in_low)
        shift.in_port(1).connect(descaled_output_low.out_port(0))

        zero = Const(graph, {
            'name': name + '/zero',
            'value': mo_array(0, dtype=dst_type)
        }).create_node()
        scale_eq_zero = Equal(graph, {
            'name': name + '/scale_eq_zero'
        }).create_node()
        scale_eq_zero.in_port(0).connect(scale.out_port(0))
        scale_eq_zero.in_port(1).connect(zero.out_port(0))

        zero_point = Select(graph, {
            'name': name + '/zero_point'
        }).create_node()
        zero_point.in_port(0).connect(scale_eq_zero.out_port(0))
        zero_point.in_port(1).connect(zero.out_port(0))
        zero_point.in_port(2).connect(shift.out_port(0))

        # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
        sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
        sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
        sub_zp.in_port(1).connect(zero_point.out_port(0))

        mul_scale = Mul(graph, {
            'name': name + '/multiply_by_scale'
        }).create_node()
        mul_scale.in_port(0).connect(sub_zp.out_port(0))
        mul_scale.in_port(1).connect(scale.out_port(0))

        fake_quantize.out_port(0).get_connection().set_source(
            mul_scale.out_port(0))

        graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0).id])
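
A minimal numpy check of the scale/shift algebra built above (illustrative limits, not MO API; assumes scale != 0, which is what the Select around zero_point guards against):

import numpy as np

in_low, in_high = -128.0, 127.0  # quantized-domain limits (FakeQuantize inputs 1, 2)
out_low, out_high = -1.0, 1.0    # original float limits (FakeQuantize inputs 3, 4)

scale = (out_high - out_low) / (in_high - in_low)
zero_point = in_low - out_low / scale

x = np.array([-128.0, 0.0, 127.0], dtype=np.float32)  # "quantized" weights
dequantized = (x - zero_point) * scale
reference = (x - in_low) * scale + out_low            # FakeQuantize's linear mapping
assert np.allclose(dequantized, reference)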
Example #26
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        group_norm_node = match['op']
        group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
        initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

        initial_shape_op_node_float = Cast(
            graph, {'name': initial_shape_op_node.name + '/to_float',
                    'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
        initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))

        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
        initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
        initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
        initial_spatial_dims_node = Cast(
            graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
                    'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
        initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))

        group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                        'name': group_norm_node.name + '/GroupSize'}).create_node()

        # calculate the "features / num_groups" value as float (by multiplying with the reciprocal)
        reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                                   'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                                initial_spatial_dims_node])
        new_shape_node = Cast(graph,
                              {'name': new_shape_node_float.name + '/to_int64', 'dst_type': np.int64}).create_node()
        new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to the correct layout: from [C] to [1,C], [1,C,1], [1,C,1,1], etc.
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                               'normalize_variance': 1,
                               'eps': group_norm_node.eps,
                               'eps_mode': 'inside_sqrt'}).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # MVN axes
        _, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
                                                   return_as_a_scalar=True)
        rng = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                          {'name': group_norm_node.name + '/Range', 'output_type': np.int64})
        mvn_node.in_port(1).connect(rng.out_port(0))
        rng.in_port(1).connect(rank.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
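
The decomposition above is equivalent to the following numpy reference (a sketch, not MO API; eps_mode='inside_sqrt' means eps is added under the square root):

import numpy as np

def group_norm_ref(x, gamma, beta, num_groups, eps=1e-5):
    n, c = x.shape[0], x.shape[1]
    spatial = x.shape[2:]
    # Reshape [N, C, *spatial] -> [N*G, C//G, *spatial], matching the Reshape above
    g = x.reshape(n * num_groups, c // num_groups, *spatial)
    # MVN over all axes except 0 (the Range(1, rank, 1) axes above), eps inside sqrt
    axes = tuple(range(1, g.ndim))
    g = (g - g.mean(axis=axes, keepdims=True)) / np.sqrt(g.var(axis=axes, keepdims=True) + eps)
    x = g.reshape(n, c, *spatial)
    # gamma/beta are reshaped from [C] to [1, C, 1, ...] before the Mul/Add
    bshape = [1, c] + [1] * len(spatial)
    return x * gamma.reshape(bshape) + beta.reshape(bshape)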
Example #27
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        upsample_name = upsample.soft_get('name', upsample.id)
        input_shape = upsample.in_port(0).data.get_shape()
        input_shape_rank = len(input_shape)
        if input_shape_rank not in [4, 5]:
            log.warning('The input shape is not 4D or 5D for op {}'.format(
                upsample.soft_get('name')))
            return

        depth_scale = None
        layout = graph.graph['layout']

        if len(upsample.in_nodes()) == 2:
            if upsample.in_node(1).value is None:
                return
            scales = upsample.in_node(1).value
            assert len(scales) in (
                4, 5
            ), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
                len(scales), upsample_name)
            if not (math.isclose(scales[0], 1, rel_tol=1e-5)
                    and math.isclose(scales[1], 1, rel_tol=1e-5)):
                return
            height_scale = scales[get_height_dim(layout, input_shape_rank)]
            width_scale = scales[get_width_dim(layout, input_shape_rank)]
            if len(scales) == 5:
                depth_scale = scales[get_depth_dim(layout, input_shape_rank)]
        else:
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']

        if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
            upsample.in_port(1).disconnect()

        shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

        if input_shape_rank == 4:
            begin_value = int64_array(
                [get_height_dim(layout, input_shape_rank)])
            factor_value = float32_array([height_scale, width_scale])
        else:
            begin_value = int64_array(
                [get_depth_dim(layout, input_shape_rank)])
            factor_value = float32_array(
                [depth_scale, height_scale, width_scale])

        ss = create_op_with_const_inputs(
            graph, StridedSlice, {
                1: begin_value,
                2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
                3: int64_array([1])
            }, {
                'name': upsample_name + '/ss_0_port',
                'begin_mask': int64_array([1]),
                'end_mask': int64_array([1]),
                'new_axis_mask': int64_array([0]),
                'shrink_axis_mask': int64_array([0]),
                'ellipsis_mask': int64_array([0])
            })

        mul = create_op_node_with_second_input(
            graph, Mul, factor_value, {'name': upsample_name + '/factor_mul'})

        source = upsample.in_port(0).get_connection().get_source()
        source.connect(shape.in_port(0))
        shape.out_port(0).connect(ss.in_port(0))

        ss.out_port(0).connect(mul.in_port(0))

        # Create Interpolate operation
        if input_shape_rank == 4:
            axes = int64_array([
                get_height_dim(layout, input_shape_rank),
                get_width_dim(layout, input_shape_rank)
            ])
        else:
            axes = int64_array([
                get_depth_dim(layout, input_shape_rank),
                get_height_dim(layout, input_shape_rank),
                get_width_dim(layout, input_shape_rank)
            ])

        axes_node = Const(graph, {
            'name': upsample_name + '/axis',
            'value': axes
        }).create_node()

        interpolate = Interpolate(
            graph, {
                'mode': upsample.attrs()['mode'],
                'antialias': 0,
                'pads_begin': int64_array([0]),
                'pads_end': int64_array([0]),
                'coordinate_transformation_mode': 'half_pixel',
                'nearest_mode': 'round_prefer_floor',
                'cube_coeff': -0.75,
                'shape_calculation_mode': 'scales',
                'version': 'opset4',
                'in_ports_count': 4
            }).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        mul.out_port(0).connect(interpolate.in_port(1))
        axes_node.out_port(0).connect(interpolate.in_port(3))

        scales_node = Const(graph, {
            'name': upsample_name + '/scales',
            'value': factor_value
        }).create_node()
        scales_node.out_port(0).connect(interpolate.in_port(2))

        upsample.in_port(0).get_connection().set_destination(
            interpolate.in_port(0))
        upsample.out_port(0).get_connection().set_source(
            interpolate.out_port(0))

        rename_nodes([(upsample, upsample_name + '/delete'),
                      (interpolate, upsample_name)])

        convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
        convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

        mul.in_port(0).get_connection().insert_node(convert_to_float)
        mul.out_port(0).get_connection().insert_node(convert_to_int)
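
As a sanity check, a numpy sketch of the factor computation wired above (illustrative values, not MO API; assumes an NCHW 4D input, so height and width are dims 2 and 3):

import numpy as np

input_shape = np.array([1, 3, 32, 48], dtype=np.int64)  # ShapeOf result (NCHW)
height_scale, width_scale = 2.0, 2.0
spatial = input_shape[2:4]                        # StridedSlice over [H, W]
factor = np.float32([height_scale, width_scale])  # the '/factor_mul' constant
sizes = (spatial.astype(np.float32) * factor).astype(np.int64)  # Cast - Mul - Cast
print(sizes)   # [64 96] -> 'sizes' input of Interpolate (port 1)
print(factor)  # -> 'scales' input of Interpolate (port 2)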