Ejemplo n.º 1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap the matched node for an explicit ``Div`` by a square root.

        The divisor is derived from the shape of the node's first input: the
        shape value requested with index ``-1`` is cast to FP32 and fed to an
        ``AttributedPower`` with ``power=0.5``. The new ``Div`` node takes
        over the name of the matched node.

        :param graph: graph to operate on.
        :param match: matched sub-graph; the node to replace is under ``'op'``.
        """
        matched = match['op']
        base_name = matched.soft_get('name', matched.id)

        # ShapeOf reads the same producer that feeds the matched node.
        shape_of = Shape(graph, {'name': base_name + '/Shape'}).create_node()
        producer_port = matched.in_port(0).get_source()
        shape_of.in_port(0).connect(producer_port)

        last_dim = node_to_get_shape_value_of_indices(shape_node=shape_of,
                                                      indices=[-1])

        sqrt = AttributedPower(graph, {
            'name': base_name + '/Sqrt',
            'power': mo_array(0.5),
        }).create_node()

        # Due to specification, Power must have inputs with the same data type.
        to_fp32 = Cast(graph, {
            'dst_type': np.float32,
            'name': last_dim.name + '/ConvertToFP32',
        }).create_node()
        divide = Div(graph, {'name': "Div"}).create_node()

        # shape value -> FP32 -> power 0.5
        last_dim.out_port(0).connect(to_fp32.in_port(0))
        to_fp32.out_port(0).connect(sqrt.in_port(0))

        # Move the data input and all consumers of the matched node to Div.
        matched.in_port(0).get_connection().set_destination(divide.in_port(0))
        divide.in_port(1).connect(sqrt.out_port(0))
        matched.out_port(0).get_connection().set_source(divide.out_port(0))

        rename_nodes([
            (matched, base_name + '/ShouldBeDeleted'),
            (divide, base_name),
        ])
Ejemplo n.º 2
0
    def find_and_replace_pattern(self, graph: Graph):
        """Decompose every DequantizeLinear node into Cast -> [Sub] -> Mul.

        Input 0 is cast to the model data type taken from ``cmd_params``; if
        input 2 is connected it is subtracted from the cast result; the result
        is multiplied by input 1. When ``axis`` is set and the first dimension
        of the input-1 shape is greater than 1, input 1 (and input 2, when
        present) is passed through a Reshape whose target shape is all ones
        except for the ``axis`` dimension.

        :param graph: graph to operate on.
        """
        for dq in graph.get_op_nodes(op='DequantizeLinear'):
            dq_name = dq.soft_get('name', dq.id)
            axis = dq.soft_get('axis', None)
            scale_shape = dq.in_port(1).data.get_shape()
            target_type = data_type_str_to_np(
                graph.graph['cmd_params'].data_type)

            cast = Cast(graph, dict(dst_type=target_type,
                                    name=dq_name + '/Cast')).create_node()
            dq.in_port(0).get_connection().set_destination(cast.in_port(0))
            mul = Mul(graph, dict(can_be_fused=False)).create_node()

            has_zero_point = dq.is_in_port_connected(2)
            if not has_zero_point:
                cast.out_port(0).connect(mul.in_port(0))
            else:
                # The Sub is tagged with 'zero_point_sub' so that it is not
                # replaced before the offline transformations match it.
                # See ConvertQuantizeDequantize transformation in ngraph
                sub = Sub(graph, dict(name=dq_name + '/Sub',
                                      zero_point_sub=True)).create_node()
                cast.out_port(0).connect(sub.in_port(0))
                dq.in_port(2).get_connection().set_destination(sub.in_port(1))
                sub.out_port(0).connect(mul.in_port(0))

            dq.in_port(1).get_connection().set_destination(mul.in_port(1))
            dq.out_port(0).get_connection().set_source(mul.out_port(0))
            rename_nodes([(dq, dq_name + '/TBD'), (mul, dq_name)])

            assert scale_shape is not None
            needs_reshape = (axis is not None and len(scale_shape) > 0
                             and scale_shape[0] > 1)
            if not needs_reshape:
                continue

            # Target shape: ones everywhere except the quantization axis.
            input_shape = cast.in_port(0).data.get_shape()
            target_shape = np.ones(len(input_shape), np.int64)
            target_shape[axis] = input_shape[axis]

            # Put a Reshape in front of the second input of Mul and, when the
            # zero point is present, of Sub as well.
            reshape_consumers = [(mul, '/Reshape/Mul')]
            if has_zero_point:
                reshape_consumers.append((sub, '/Reshape/Sub'))

            for consumer, suffix in reshape_consumers:
                reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': dq_name + suffix})
                consumer.in_port(1).get_connection().set_destination(
                    reshape.in_port(0))
                reshape.out_port(0).connect(consumer.in_port(1))
Ejemplo n.º 3
0
def create_ss_interval_border(graph: Graph, slice_border_port: Port,
                              shape: np.ndarray, axes: np.ndarray,
                              node_name: str):
    """
    Build the "begin" or "end" input of a StridedSlice from the "starts" or "ends" input of a Slice.

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice; only its length is used here.
    :param axes: axes that "starts" and "ends" apply to
    :param node_name: Slice node name, used as a prefix for the names of the created nodes.
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    # "starts"/"ends" may hold the extreme int64 values. They are clamped to the int32 range because such big
    # values do not fit into the int32 which is supported by the StridedSlice layer.
    int32_range = np.iinfo(np.int32)
    clamp = create_op_with_const_inputs(
        graph,
        Clamp,
        port_value_dict={1: int32_range.min, 2: int32_range.max},
        op_attrs={'name': node_name + '/Clamp'})
    clamp.in_port(0).connect(slice_border_port)

    # The network-provided values are cast to the same data type as the constants created below, otherwise the
    # Concat node would get inputs of different types.
    to_i64 = Cast(graph, {'name': node_name + '/CastToI64', 'dst_type': np.int64}).create_node()
    to_i64.in_port(0).connect(clamp.out_port(0))

    concat = Concat(graph, {'name': node_name + '/Concat', 'axis': 0}).create_node()

    # "axes" may not be sorted, so every value is extracted separately with a Gather and connected to the Concat
    # input port that corresponds to its axis.
    for value_idx, port_idx in enumerate(axes):
        concat.add_input_port(port_idx)
        picked_value = create_op_with_const_inputs(
            graph,
            Gather,
            port_value_dict={1: int64_array([value_idx]), 2: int64_array(0)},
            op_attrs={'name': node_name + '/Gather'})
        to_i64.out_port(0).connect(picked_value.in_port(0))
        picked_value.out_port(0).connect(concat.in_port(port_idx))

    # Every remaining Concat port gets a zero constant.
    for port_idx in range(len(shape)):
        if concat.is_in_port_connected(port_idx):
            continue
        concat.add_input_port(port_idx)
        # This border value would be ignored in StridedSlice because of the begin_mask\end_mask
        filler = Const(graph, {'name': node_name + '/Const', 'value': int64_array([0])}).create_node()
        filler.out_port(0).connect(concat.in_port(port_idx))

    return concat
def replace_interpolate_pattern(graph: Graph, match: dict):
    """
    Replace the matched Split ... Concat sub-graph with a single opset4 Interpolate in 'nearest' mode.

    The "sizes" input of the Interpolate is computed in the graph: the dimension of the input shape at the
    split axis is cast to float, multiplied by the scale, floored and cast to int64. The scale itself and the
    axis are passed to the Interpolate as constants.

    :param graph: graph to operate on.
    :param match: matched sub-graph; the 'split' and 'concat' nodes are used.
    """
    split = match['split']
    scale = float32_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    prefix = split.name

    axis_const = Const(graph, dict(name=prefix + '/axis', value=int64_array([axis]))).create_node()
    shape_of = Shape(graph, {'name': prefix + '/Shape'}).create_node()
    scales_const = Const(graph, {'name': prefix + '/scales', 'value': scale}).create_node()
    scaled_dim = Mul(graph, {'name': prefix + '/Mul'}).create_node()
    scales_const.out_port(0).connect(scaled_dim.in_port(1))

    # StridedSlice with begin=[axis], end=[axis + 1] applied to the shape.
    slice_borders = {1: int64_array([axis]), 2: int64_array([axis + 1])}
    slice_attrs = {
        'name': prefix + '/StridedSlice',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0])
    }
    dim_slice = create_op_with_const_inputs(graph, StridedSlice, slice_borders, slice_attrs)
    shape_of.out_port(0).connect(dim_slice.in_port(0))

    dim_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    dim_slice.out_port(0).connect(dim_to_float.in_port(0))
    dim_to_float.out_port(0).connect(scaled_dim.in_port(0))

    interpolate_attrs = {
        'name': prefix + '/Interpolate',
        'mode': 'nearest',
        'antialias': 0,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'coordinate_transformation_mode': 'half_pixel',
        'nearest_mode': 'round_prefer_floor',
        'cube_coeff': -0.75,
        'version': 'opset4',
        'shape_calculation_mode': 'scales',
        'in_ports_count': 4,
        'maybe_part_of_sequence': True,
    }
    interpolate = Interpolate(graph, interpolate_attrs).create_node()

    floor = Floor(graph, dict(name=prefix + '/Floor')).create_node()
    sizes_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

    scaled_dim.out_port(0).connect(floor.in_port(0))
    floor.out_port(0).connect(sizes_to_int.in_port(0))

    # Interpolate inputs: 1 - sizes, 2 - scales, 3 - axes.
    sizes_to_int.out_port(0).connect(interpolate.in_port(1))
    scales_const.out_port(0).connect(interpolate.in_port(2))
    axis_const.out_port(0).connect(interpolate.in_port(3))

    # Consumers of the matched Concat now read from the Interpolate.
    match['concat'].out_port(0).get_connection().set_source(interpolate.out_port(0))

    # The data that used to enter the Split goes to the Interpolate and to ShapeOf.
    data_connection = split.in_port(0).get_connection()
    data_connection.set_destination(interpolate.in_port(0))
    data_connection.get_source().connect(shape_of.in_port(0))
Ejemplo n.º 5
0
    def quantize_data(fake_quantize: Node, dst_type: type,
                      quantized_type: type, mode: str):
        """Insert a quantizing copy of ``fake_quantize`` followed by a Cast in front of it.

        The copy takes over inputs 0, 1 and 2 of the original node and gets new
        integer output limits on inputs 3 and 4. The same limits become inputs
        1 and 2 of the original node, whose input 0 is now the Cast output.

        :param fake_quantize: node to process; its ``levels`` attribute is read.
        :param dst_type: data type of the created limit constants.
        :param quantized_type: ``dst_type`` of the Cast placed after the copy.
        :param mode: either "signed" or "unsigned"; selects the lower limit.
        """
        graph = fake_quantize.graph
        fq_name = fake_quantize.soft_get('name', fake_quantize.id)
        levels = fake_quantize.levels

        copy_attrs = {'name': fq_name + '/Copy', 'stop_value_propagation': False}
        quantize = fake_quantize.copy_node(copy_attrs, graph)

        # The copy inherits the data input and the input limits.
        fake_quantize.in_port(0).get_connection().set_destination(quantize.in_port(0))
        fake_quantize.in_port(1).get_connection().set_destination(quantize.in_port(1))
        fake_quantize.in_port(2).get_connection().set_destination(quantize.in_port(2))

        # calculate output limits for quantized weights
        assert mode in ["signed", "unsigned"]
        if mode == "signed":
            i_min_value = -(levels // 2)
        else:
            i_min_value = 0

        # The lower limit is wrapped into a 1-element list only when the shape
        # of the copy's first input node is non-empty.
        if quantize.in_node(0).shape.size:
            i_min = mo_array([i_min_value], dtype=dst_type)
        else:
            i_min = mo_array(i_min_value, dtype=dst_type)
        i_max = mo_array(levels + i_min - 1, dtype=dst_type)

        assert i_max - i_min == levels - 1
        out_low = Const(graph, {'name': fq_name + '/Copy/out_low', 'value': i_min}).create_node()
        out_high = Const(graph, {'name': fq_name + '/Copy/out_high', 'value': i_max}).create_node()

        # Output limits of the copy are at the same time the input limits of
        # the original node.
        out_low.out_port(0).connect(quantize.in_port(3))
        out_high.out_port(0).connect(quantize.in_port(4))
        out_low.out_port(0).connect(fake_quantize.in_port(1))
        out_high.out_port(0).connect(fake_quantize.in_port(2))

        data_producer = quantize.in_port(0).get_source().node
        cast_name = data_producer.soft_get('name', data_producer.id) + '/quantized'
        cast = Cast(graph, {
            'name': cast_name,
            'dst_type': quantized_type,
            'stop_value_propagation': False,
        }).create_node()

        quantize.out_port(0).connect(cast.in_port(0))

        cast.out_port(0).connect(fake_quantize.in_port(0))
    def find_and_replace_pattern(self, graph: Graph):
        """Express every ThresholdedRelu node as ``x * cast(x > alpha)``.

        The Greater result is cast to the model data type taken from
        ``cmd_params`` and multiplied by the original input. The Mul node
        takes over the name of the ThresholdedRelu node, which is then
        removed from the graph.

        :param graph: graph to operate on.
        """
        for relu in graph.get_op_nodes(op='ThresholdedRelu'):
            relu_name = relu.soft_get('name', relu.id)

            # mask = x > alpha
            mask = create_op_with_const_inputs(
                graph, Greater, {1: float_array([relu.alpha])})
            mask.in_port(0).connect(relu.in_port(0).get_source())

            model_type = data_type_str_to_np(
                graph.graph['cmd_params'].data_type)
            mask_as_float = Cast(graph, dict(dst_type=model_type)).create_node()
            mask.out_port(0).connect(mask_as_float.in_port(0))

            # result = x * mask
            product = Mul(graph, dict()).create_node()
            relu.out_port(0).get_connection().set_source(product.out_port(0))
            product.in_port(0).connect(relu.in_port(0).get_source())
            product.in_port(1).connect(mask_as_float.out_port(0))

            rename_nodes([(relu, relu_name + '/TBR'), (product, relu_name)])
            graph.remove_node(relu.id)
Ejemplo n.º 7
0
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
    """
    Replace a TF resize node with an opset4 Interpolate working in 'sizes' mode over axes [1, 2].

    The "scales" input of the Interpolate is computed in the graph as the requested sizes divided by the
    slice [1:3] of the input shape, both cast to float32. The Interpolate takes over the name of the resize node.

    :param graph: graph to operate on.
    :param resize: node to replace; must have exactly two connected inputs.
    :param interpolation_mode: value for the 'mode' attribute of the Interpolate.
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of {} to Interpolate-4 is triggered for node {}.".format(resize.op, resize_name))

    connected_inputs = sum(1 for port in resize.in_ports().values() if not port.disconnected())
    assert connected_inputs == 2, \
        "Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)

    # half_pixel_centers and align_corners must not both be set.
    attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                "the attribute align_corners must be False"
    assert not (resize.half_pixel_centers and resize.align_corners), \
        attrs_msg.format(resize_name, resize.op)

    shape_of = Shape(graph, dict(name=resize_name + '/shapeof')).create_node()

    # StridedSlice with begin=[1], end=[3], stride=[1] applied to the input shape.
    slice_inputs = {1: int64_array([1]), 2: int64_array([3]), 3: int64_array([1])}
    slice_attrs = {
        'name': resize_name + '/StridedSlice',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0])
    }
    spatial_dims = create_op_with_const_inputs(graph, StridedSlice, slice_inputs, slice_attrs)

    scales = Div(graph, dict(name=resize_name + '/Div')).create_node()

    shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
    size_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    # scales = float(sizes) / float(shape[1:3])
    size_to_float.out_port(0).connect(scales.in_port(0))
    shape_to_float.out_port(0).connect(scales.in_port(1))
    spatial_dims.out_port(0).connect(shape_to_float.in_port(0))
    shape_of.out_port(0).connect(spatial_dims.in_port(0))

    align_corners = resize.align_corners
    half_pixel_centers = resize.half_pixel_centers

    is_nearest = interpolation_mode == 'nearest'
    nearest_mode = 'floor' if is_nearest else 'round_prefer_floor'
    if align_corners:
        coordinate_transformation_mode = 'align_corners'
        if is_nearest:
            nearest_mode = 'round_prefer_ceil'
    elif half_pixel_centers:
        if is_nearest:
            coordinate_transformation_mode = 'tf_half_pixel_for_nn'
        else:
            coordinate_transformation_mode = 'half_pixel'
    else:
        coordinate_transformation_mode = 'asymmetric'

    interpolate_attrs = {
        'name': resize_name + '/interpolate_4',
        'mode': interpolation_mode,
        'antialias': False,
        'coordinate_transformation_mode': coordinate_transformation_mode,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'nearest_mode': nearest_mode,
        'cube_coeff': -0.75,
        'shape_calculation_mode': 'sizes',
        'version': 'opset4',
        'in_ports_count': 4,
    }
    interpolate4 = create_op_with_const_inputs(graph, Interpolate, {3: int64_array([1, 2])}, interpolate_attrs)

    # The resized data goes to the Interpolate and to ShapeOf.
    data_connection = resize.in_port(0).get_connection()
    data_connection.set_destination(interpolate4.in_port(0))
    data_connection.get_source().connect(shape_of.in_port(0))

    scales.out_port(0).connect(interpolate4.in_port(2))

    # The requested sizes go to the Interpolate and to the scales computation.
    sizes_connection = resize.in_port(1).get_connection()
    sizes_connection.set_destination(interpolate4.in_port(1))
    sizes_connection.get_source().connect(size_to_float.in_port(0))

    resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
    rename_nodes([(resize, resize_name + '/delete'), (interpolate4, resize_name)])
Ejemplo n.º 8
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched node with a Range-based sub-graph.

        The node attributes ``axis``, ``start``, ``step`` and ``repeat`` are
        read. The "stop" input (port 1) of the Range is derived from the shape
        of the node's input and is then adjusted step by step: for
        ``repeat != 1``, for ``step != 1`` and for ``start != 0``. Each
        adjustment is inserted into the connection that currently feeds port 1
        of the Range, so the order of the sections below matters.

        :param graph: graph to operate on.
        :param match: matched sub-graph; the node to replace is under ``'op'``.
        :raises Error: if ``repeat`` is less than 1.
        """
        node = match['op']
        name = node.soft_get('name', node.id)
        axis = node.axis
        input_shape_node = Shape(graph, {
            'name': name + '/ShapeOf'
        }).create_node()
        # Range inputs: 0 - start and 2 - step are constants, 1 - stop is connected below.
        range_node = create_op_with_const_inputs(graph, Range, {
            0: mo_array(node.start),
            2: mo_array(node.step)
        }, {'name': name + '/Range'})
        # The data input of the replaced node is only needed for its shape.
        node.in_port(0).get_connection().set_destination(
            input_shape_node.in_port(0))

        if axis is not None:
            '''
            Replace arange_like op to subgraph:
            Shape - Gather - Range
            '''
            gather_node = create_op_with_const_inputs(graph, Gather, {
                1: int64_array([axis]),
                2: int64_array(0)
            }, {'name': name + '/Gather'})
            input_shape_node.out_port(0).connect(gather_node.in_port(0))
            gather_node.out_port(0).connect(range_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                range_node.out_port(0))
            rename_nodes([(node, name + '/ShouldBeDeleted'),
                          (range_node, name)])
        else:
            r'''
            Replace arange_like op to subgraph:
                    |
                 ShapeOf ----------- | 
                    |                |
                 ReduceProd          |
                    |                |
                  Range              |
                    |                |
                 Reshape ----------- | 
                    |
            '''

            flattened_shape_node = create_op_with_const_inputs(
                graph, ReduceProd, {1: int64_array([0])}, {
                    'name': input_shape_node.name + '/ReduceProd',
                    'keep_dims': True
                })
            reshape_backward_node = Reshape(graph, {
                'name': name + '/Reshape_backward'
            }).create_node()

            input_shape_node.out_port(0).connect(
                flattened_shape_node.in_port(0))
            flattened_shape_node.out_port(0).connect(range_node.in_port(1))
            range_node.out_port(0).connect(reshape_backward_node.in_port(0))
            input_shape_node.out_port(0).connect(
                reshape_backward_node.in_port(1))
            node.out_port(0).get_connection().set_source(
                reshape_backward_node.out_port(0))
            rename_nodes([(node, name + '/ShouldBeDeleted'),
                          (reshape_backward_node, name)])

        if node.repeat != 1:
            r"""
            First, we generate the correct stop value for Range like new_stop_value = stop_value // repeat + 1.
            Then repeats each value of the interval using Tile. After that we can get a longer interval
            so we reduce it with Slice.
            
            Sub-graph after Range node will be look like
            
            Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
            
            """

            if node.repeat < 1:
                raise Error(
                    "Unexpected value {} of the attribute 'repeat' for the node {}"
                    .format(node.repeat, name))

            div_node = create_op_with_const_inputs(
                graph, Div, {1: int64_array([node.repeat])},
                {'name': name + '/Divide'})
            add_node = create_op_with_const_inputs(
                graph, Add, {1: int64_array([1])},
                {'name': div_node.name + '/Add'})
            cast_node = Cast(graph, {
                'name': name + '/ConvertToI64',
                'dst_type': np.int64
            }).create_node()

            # Current stop value -> Cast(int64) -> Div(repeat) -> Add(1) -> Range port 1.
            cast_node.out_port(0).connect(div_node.in_port(0))
            div_node.out_port(0).connect(add_node.in_port(0))
            range_node.in_port(1).get_connection().set_destination(
                cast_node.in_port(0))
            add_node.out_port(0).connect(range_node.in_port(1))

            tile_forward_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array([-1, 1])},
                {'name': range_node.name + '/ForwardReshape'})
            tile = create_op_with_const_inputs(
                graph, Tile, {1: int64_array([1, node.repeat])},
                {'name': tile_forward_reshape.name + '/Tile'})
            tile_backward_reshape = create_op_with_const_inputs(
                graph, Reshape, {1: int64_array([-1])},
                {'name': tile.name + '/BackwardReshape'})
            slice_node = create_op_with_const_inputs(
                graph, Slice, {
                    1: int64_array([0]),
                    3: int64_array([0]),
                    4: int64_array([1])
                }, {'name': tile_backward_reshape.name + '/Slice'})

            tile_forward_reshape.out_port(0).connect(tile.in_port(0))
            tile.out_port(0).connect(tile_backward_reshape.in_port(0))
            tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
            # Port 2 of the Slice is fed by the same producer as the Div input, i.e. the Cast output.
            slice_node.in_port(2).connect(div_node.in_port(0).get_source())

            # Existing consumers of the Range are moved to the Slice output first; only after that
            # the Range output is connected to the Reshape/Tile chain.
            range_node.out_port(0).get_connection().set_source(
                slice_node.out_port(0))
            range_node.out_port(0).connect(tile_forward_reshape.in_port(0))

            if axis is not None:
                rename_nodes([(range_node, name + '/Range'),
                              (slice_node, name)])

        # MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
        # we have to correct the stop value for the Range node if step != 1 or start != 0
        if node.step != 1:
            # If step attribute is not integer, we will generate an interval with a larger size and then reduce it
            # using Slice
            # The producer of the stop value is remembered before the Mul is inserted; it is used as
            # input 2 of the Slice created below.
            true_elements_count_port = range_node.in_port(1).get_source()
            mul_value = np.ceil(node.step) if node.step > 0 else np.floor(
                node.step)
            stop_value = create_op_with_const_inputs(
                graph,
                Mul,
                port_value_dict={1: mo_array(np.ceil(mul_value))},
                op_attrs={'name': range_node.name + '/Stop'})
            range_node.in_port(1).get_connection().insert_node(stop_value)

            slice_range_values = create_op_with_const_inputs(
                graph, Slice, {
                    1: int64_array([0]),
                    3: int64_array([0]),
                    4: int64_array([1])
                }, {'name': range_node.name + '/Slice'})
            slice_range_values.in_port(2).connect(true_elements_count_port)
            range_node.out_port(0).get_connection().insert_node(
                slice_range_values)

            # With repeat != 1 the original name has already been given to the Slice created in the
            # 'repeat' section above.
            if axis is not None and node.repeat == 1:
                rename_nodes([(range_node, name + '/Range'),
                              (slice_range_values, name)])

        if node.start != 0:
            # An Add with the 'start' constant is inserted in front of the stop input of the Range.
            correct_stop_value = create_op_with_const_inputs(
                graph,
                Add,
                port_value_dict={1: mo_array(node.start)},
                op_attrs={'name': range_node.name + '/Correct_Stop'})
            range_node.in_port(1).get_connection().insert_node(
                correct_stop_value)

        # Range node supports only scalar inputs
        squeeze_node = create_op_with_const_inputs(
            graph,
            Squeeze,
            port_value_dict={1: int64_array(0)},
            op_attrs={"name": range_node.name + '/Stop/Squeeze'})
        range_node.in_port(1).get_connection().insert_node(squeeze_node)
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Fuse the matched pattern into a single EmbeddingSegmentsSum/EmbeddingSegmentsMean node.

        EmbeddingSegmentsSum is created when the matched sparse segment operation is 'SparseSegmentSum',
        EmbeddingSegmentsMean otherwise. Indices, segment ids, number of segments and the default value are
        cast to int32 before being connected to the new node, which takes over the name of the 'select' node.

        :param graph: graph to operate on.
        :param match: matched sub-graph with the nodes 'identity_spw', 'gather0_1', 'gather0_2',
                      'greaterequal0', 'sparse_fill_empty_rows', 'gather', 'select', 'where0' and
                      'sparse_segment_op'.
        """
        identity_spw = match['identity_spw']
        gather0_1 = match['gather0_1']
        gather0_2 = match['gather0_2']
        greaterequal0 = match['greaterequal0']
        sparse_fill_empty_rows = match['sparse_fill_empty_rows']
        gather = match['gather']
        select = match['select']
        where0 = match['where0']
        sparse_segment_op = match['sparse_segment_op']
        output_node_name = select.soft_get('name', select.id)

        log.debug('Found EmbeddingSparseSegmentsSingleFeature pattern after {} with name {}'.format(
            sparse_fill_empty_rows.op,
            sparse_fill_empty_rows.name))

        # Split along axis 1 + Squeeze of axis 1: used below to obtain the segment ids.
        split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'num_splits': 2})
        squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
        # Split along axis 0 + Squeeze of axis 0: used below to obtain the number of segments.
        split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
        squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})

        # TODO: remove Cast nodes once we start to support EmbeddingSegmentSum (new version) with segment_ids,
        #  indices, and num_segments of different integer type.
        #  Because the real cases show that it is possible to have it in TensorFlow
        cast_indices = Cast(graph, {'name': output_node_name + '/CastIndices', 'dst_type': np.int32}).create_node()
        cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds',
                                        'dst_type': np.int32}).create_node()
        cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue',
                                          'dst_type': np.int32}).create_node()
        cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber',
                                         'dst_type': np.int32}).create_node()
        if sparse_segment_op.op == 'SparseSegmentSum':
            embedding_segments_op = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
        else:
            embedding_segments_op = EmbeddingSegmentsMean(graph, {'name': output_node_name}).create_node()
        # The new node inherits the name of the pattern output node.
        rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_op, output_node_name)])

        # connect parameters table
        gather.in_port(0).get_connection().set_destination(embedding_segments_op.in_port(0))
        # connect indices values
        greaterequal0.in_port(0).get_connection().set_destination(cast_indices.in_port(0))
        embedding_segments_op.in_port(1).connect(cast_indices.out_port(0))
        # split and connect segment ids
        gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
        squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
        cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
        embedding_segments_op.in_port(2).connect(cast_segment_ids.out_port(0))
        # split and connect number of segments
        identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
        squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
        cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
        embedding_segments_op.in_port(3).connect(cast_num_segments.out_port(0))
        # connect default value
        sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
        embedding_segments_op.in_port(4).connect(cast_default_value.out_port(0))
        # no input port for per_sample_weight

        # Detach the listed input ports of the pattern nodes.
        identity_spw.in_port(0).disconnect()
        gather0_1.in_port(0).disconnect()
        gather0_2.in_port(0).disconnect()
        greaterequal0.in_port(0).disconnect()
        sparse_fill_empty_rows.in_port(2).disconnect()
        gather.in_port(0).disconnect()

        # Consumers of the pattern output now read from the new node; the listed pattern nodes are removed.
        select.out_port(0).get_connection().set_source(embedding_segments_op.out_port(0))
        graph.remove_nodes_from(
            [gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
Ejemplo n.º 10
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """Decompose the matched GroupNorm node into Reshape -> MVN -> Reshape -> Mul -> Add.

        The input is reshaped to a shape concatenated from [batch * num_groups, features / num_groups,
        spatial dims], normalized with MVN over the axes produced by Range(1, rank, 1), reshaped back to the
        initial shape, multiplied by gamma (input 1) and summed with beta (input 2). The values of gamma and
        beta must be constant; they are reshaped in place to all-ones shapes with -1 at index 1.

        :param graph: graph to operate on.
        :param match: matched sub-graph; the GroupNorm node is under 'op'.
        """
        group_norm_node = match['op']
        group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
        initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

        # The shape is cast to the model data type because the dimension arithmetic below uses a
        # floating point constant (the reciprocal of the number of groups).
        initial_shape_op_node_float = Cast(
            graph, {'name': initial_shape_op_node.name + '/to_float',
                    'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
        initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))

        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
        initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
        # The spatial dimensions are taken from the integer shape and cast afterwards.
        initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
        initial_spatial_dims_node = Cast(
            graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
                    'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
        initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))

        group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                        'name': group_norm_node.name + '/GroupSize'}).create_node()

        # calculate "features // group_size" value
        # (implemented as a multiplication by 1 / num_groups)
        reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                                   'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        # calculate "batch * group_size" value
        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                                initial_spatial_dims_node])
        # The concatenated shape is cast to int64 before it is used as the Reshape target.
        new_shape_node = Cast(graph,
                              {'name': new_shape_node_float.name + '/to_int64', 'dst_type': np.int64}).create_node()
        new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        # The values are rewritten on the producers of inputs 1 and 2, before those connections are
        # moved to the Mul and Add nodes below.
        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                               'normalize_variance': 1,
                               'eps': group_norm_node.eps,
                               'eps_mode': 'inside_sqrt'}).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # MVN axes
        # Range(start=1, stop=rank of the MVN input, step=1)
        _, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
                                                   return_as_a_scalar=True)
        rng = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                          {'name': group_norm_node.name + '/Range', 'output_type': np.int64})
        mvn_node.in_port(1).connect(rng.out_port(0))
        rng.in_port(1).connect(rank.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

        # multiply by gamma: the former input 1 of GroupNorm becomes input 1 of Mul
        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

        # add beta: the former input 2 of GroupNorm becomes input 1 of Add
        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

        # Consumers of GroupNorm now read from the Add output.
        group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
Ejemplo n.º 11
0
def replace_resize(graph: Graph, resize: Node):
    """
    Substitute an ONNX Resize-10 node with an opset4 Interpolate node.

    The new Interpolate node receives four inputs:
      0 - the data that used to enter port 0 of ``resize``;
      1 - 'sizes': int64(floor(float32(ShapeOf(data)) * (scales + 1.0e-5))) passed through a StridedSlice;
      2 - 'scales': the producer of ``resize`` port 1 passed through a StridedSlice;
      3 - 'axes': a Range node with constant 2 on port 0, Rank(data) on port 1 and constant 1 on port 2.
    Afterwards the Interpolate node takes over the name and the output connection of ``resize``, while
    ``resize`` itself is renamed to '<name>/delete'.

    :param graph: graph that contains ``resize`` and receives the new nodes
    :param resize: Resize node to be replaced
    :return: None
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
              "is triggered for node {}.".format(resize_name))

    def make_strided_slice(suffix):
        # 'sizes' and 'scales' are sliced by identically configured StridedSlice nodes; only the name differs.
        # NOTE(review): begin=[2] with end_mask=[0] presumably keeps everything from index 2 on - confirm.
        const_inputs = {1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1])}
        attributes = {'name': resize_name + suffix,
                      'begin_mask': int64_array([1]),
                      'end_mask': int64_array([0]),
                      'new_axis_mask': int64_array([0]),
                      'shrink_axis_mask': int64_array([0]),
                      'ellipsis_mask': int64_array([0])}
        return create_op_with_const_inputs(graph, StridedSlice, const_inputs, attributes)

    rank = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
    axes = create_op_with_const_inputs(graph, Range,
                                       {0: int64_array(2), 2: int64_array(1)},
                                       {'name': resize_name + '/axes'})
    sizes_slice = make_strided_slice('/sizes_ss')
    scales_slice = make_strided_slice('/scales_ss')

    rank.out_port(0).connect(axes.in_port(1))

    if resize.mode == 'linear':
        interpolation_mode = 'linear_onnx'
    else:
        interpolation_mode = 'nearest'

    interpolate = Interpolate(graph, {'version': 'opset4',
                                      'mode': interpolation_mode,
                                      'coordinate_transformation_mode': 'asymmetric',
                                      'cube_coeff': -0.75,
                                      'nearest_mode': 'simple',
                                      'pads_begin': int64_array([0]),
                                      'pads_end': int64_array([0]),
                                      'antialias': 0,
                                      'shape_calculation_mode': 'scales',
                                      'in_ports_count': 4}).create_node()

    axes.out_port(0).connect(interpolate.in_port(3))
    data_shape = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # Why the epsilon is added to 'scales' and not to the product:
    #   floor(input_shape * scales) may undershoot, e.g. scales = [1.0, 1.0, 1.33333, 2.0] and
    #   input_shape = [1, 3, 30, 200] give 39.9999 -> 39 for the third dimension instead of 40.
    #   floor(input_shape * scales + eps) with eps = 1.0e-5 does not help either: 39.9999 + 1.0e-5 = 39.99991.
    #   floor(input_shape * (scales + eps)) is therefore the formula built below.
    scales_plus_eps = create_op_with_const_inputs(graph, Add,
                                                  {1: float_array([1.0e-5])},
                                                  {'name': resize_name + '/Add'})

    # shape values are always processed in float32, even if data_type=FP16
    shape_as_float = Cast(graph, {'dst_type': np.float32}).create_node()
    data_shape.out_port(0).connect(shape_as_float.in_port(0))

    scaled_shape = Mul(graph, {'name': resize_name + '/Mul'}).create_node([shape_as_float, scales_plus_eps])
    floored_shape = Floor(graph, {'name': resize_name + '/Floor'}).create_node([scaled_shape])
    sizes_as_int = Cast(graph, {'dst_type': np.int64}).create_node([floored_shape])

    sizes_as_int.out_port(0).connect(sizes_slice.in_port(0))
    sizes_slice.out_port(0).connect(interpolate.in_port(1))
    scales_slice.out_port(0).connect(interpolate.in_port(2))

    # Move the inputs of Resize to the new sub-graph ...
    data_connection = resize.in_port(0).get_connection()
    data_connection.set_destination(interpolate.in_port(0))

    scales_connection = resize.in_port(1).get_connection()
    scales_connection.set_destination(scales_slice.in_port(0))

    # ... and let the same producers feed the auxiliary nodes as well.
    data_connection.get_source().connect(data_shape.in_port(0))
    data_connection.get_source().connect(rank.in_port(0))
    scales_connection.get_source().connect(scales_plus_eps.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'), (interpolate, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate.out_port(0))
def replace_resize(graph: Graph, resize: Node):
    """
    Substitute an ONNX Resize-11 node with an opset4 Interpolate node.

    Only 4D and 5D inputs are processed; for any other input rank a warning is logged and the graph is
    left unchanged. The Interpolate 'axes' input is a constant arange(begin_dim, end_dim) computed from the
    graph layout. What happens next depends on whether port 3 ('sizes') of ``resize`` is connected:
      * connected: 'sizes' comes from the model and 'scales' is built as
        float32(sizes) / float32(ShapeOf(data)) + 1.0e-5;
      * not connected: 'scales' comes from port 2 of ``resize`` and 'sizes' is built as
        int64(floor(float32(ShapeOf(data)) * scales + 1.0e-5)).
    Both inputs go through a StridedSlice [begin_dim:end_dim] before entering Interpolate. At the end the
    Interpolate node takes over the name and the output connection of ``resize``.

    :param graph: graph that contains ``resize`` and receives the new nodes
    :param resize: Resize node to be replaced
    :return: None
    :raises AssertionError: if the data input or both 'scales' and 'sizes' inputs are not connected, or if
                            coordinate_transformation_mode is 'tf_crop_and_resize'
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(resize_name))

    input_rank = len(resize.in_port(0).data.get_shape())
    if input_rank not in {4, 5}:
        log.warning('The input shape is not 4D or 5D for op with name {}'.format(resize_name))
        return

    assert (resize.is_in_port_connected(0) and (resize.is_in_port_connected(2) or resize.is_in_port_connected(3))), \
        "Scales or sizes inputs must be connected to Node {} with op {}.".format(resize_name, resize.op)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    layout = graph.graph['layout']

    # The processed range starts at the height dimension for 4D and at the depth dimension for 5D input;
    # in both cases it ends right after the width dimension.
    first_dim_getter = get_height_dim if input_rank == 4 else get_depth_dim
    begin_dim = first_dim_getter(layout, input_rank)
    end_dim = get_width_dim(layout, input_rank) + 1

    def make_strided_slice(suffix):
        # 'sizes' and 'scales' are cut by identically configured StridedSlice nodes; only the name differs.
        const_inputs = {1: int64_array([begin_dim]), 2: int64_array([end_dim]), 3: int64_array([1])}
        attributes = {'name': resize_name + suffix,
                      'begin_mask': int64_array([1]),
                      'end_mask': int64_array([1]),
                      'new_axis_mask': int64_array([0]),
                      'shrink_axis_mask': int64_array([0]),
                      'ellipsis_mask': int64_array([0])}
        return create_op_with_const_inputs(graph, StridedSlice, const_inputs, attributes)

    sizes_slice = make_strided_slice('/StridedSlice_sizes')
    scales_slice = make_strided_slice('/StridedSlice_scales')
    axes = Const(graph, {'name': resize_name + '/axis',
                         'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()

    sizes_are_given = resize.is_in_port_connected(3)

    interpolate = Interpolate(graph, {'version': 'opset4',
                                      'mode': convert_mode(resize.mode),
                                      'coordinate_transformation_mode': resize.coordinate_transformation_mode,
                                      'cube_coeff': resize.cube_coeff,
                                      'nearest_mode': resize.nearest_mode,
                                      'pads_begin': int64_array([0]),
                                      'pads_end': int64_array([0]),
                                      'antialias': 0,
                                      'shape_calculation_mode': 'sizes' if sizes_are_given else 'scales',
                                      'in_ports_count': 4}).create_node()

    axes.out_port(0).connect(interpolate.in_port(3))
    data_shape = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    eps_add = create_op_with_const_inputs(graph, Add,
                                          {1: float_array([1.0e-5])},
                                          {'name': resize_name + '/Add'})

    # shape values are always processed in float32, even if data_type=FP16
    float_type = np.float32

    if sizes_are_given:
        # 'sizes' is provided by the model: scales = sizes / input_shape + eps
        shape_as_float = Cast(graph, {'dst_type': float_type}).create_node()
        sizes_as_float = Cast(graph, {'dst_type': float_type}).create_node()
        ratio = Div(graph, {'name': resize_name + '/Div'}).create_node()
        sizes_as_float.out_port(0).connect(ratio.in_port(0))
        shape_as_float.out_port(0).connect(ratio.in_port(1))
        data_shape.out_port(0).connect(shape_as_float.in_port(0))
        ratio.out_port(0).connect(eps_add.in_port(0))
        eps_add.out_port(0).connect(scales_slice.in_port(0))
        scales_slice.out_port(0).connect(interpolate.in_port(2))
        sizes_slice.out_port(0).connect(interpolate.in_port(1))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate.in_port(0))

        sizes_connection = resize.in_port(3).get_connection()
        sizes_connection.set_destination(sizes_slice.in_port(0))

        data_connection.get_source().connect(data_shape.in_port(0))
        sizes_connection.get_source().connect(sizes_as_float.in_port(0))
    else:
        # 'scales' is provided by the model: sizes = floor(input_shape * scales + eps)
        shape_as_float = Cast(graph, {'dst_type': float_type}).create_node()
        product = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        data_shape.out_port(0).connect(shape_as_float.in_port(0))
        shape_as_float.out_port(0).connect(product.in_port(0))
        sizes_as_int = Cast(graph, {'dst_type': np.int64}).create_node()
        floored = Floor(graph, {'name': resize_name + '/Floor'}).create_node()
        product.out_port(0).connect(eps_add.in_port(0))
        eps_add.out_port(0).connect(floored.in_port(0))
        floored.out_port(0).connect(sizes_as_int.in_port(0))
        sizes_as_int.out_port(0).connect(sizes_slice.in_port(0))
        sizes_slice.out_port(0).connect(interpolate.in_port(1))
        scales_slice.out_port(0).connect(interpolate.in_port(2))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate.in_port(0))

        scales_connection = resize.in_port(2).get_connection()
        scales_connection.set_destination(scales_slice.in_port(0))

        data_connection.get_source().connect(data_shape.in_port(0))
        scales_connection.get_source().connect(product.in_port(1))

    rename_nodes([(resize, resize_name + '/delete'), (interpolate, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate.out_port(0))
Ejemplo n.º 13
0
    def dequantize_data(fake_quantize: Node, dst_type: type,
                        quantized_type: type) -> Node:
        """
        Replace a FakeQuantize that follows compressed (Convert-ed) data with an explicit
        dequantization sub-graph: Mul(Sub(Cast(x), zero_point), scale).

        With in_low, in_high, out_low, out_high being the producers of FakeQuantize inputs 1..4:
            scale      = (out_high - out_low) / (in_high - in_low)
            shift      = in_low - out_low / scale
            zero_point = Select(scale == 0, 0, shift)
        The consumers of ``fake_quantize`` are switched to the final Mul and the FakeQuantize is removed.

        NOTE(review): the signature is annotated with ``-> Node`` but no value is returned by the body -
        confirm what callers expect.

        :param fake_quantize: FakeQuantize node whose input 0 is produced by a Convert to ``quantized_type``
        :param dst_type: data type the quantized data is cast to before dequantization
        :param quantized_type: expected ``dst_type`` attribute of the Convert producing input 0
        :raises AssertionError: if input 0 is not produced by a Convert with ``dst_type == quantized_type``
        """
        graph = fake_quantize.graph
        compressed = fake_quantize.in_port(0).get_source().node
        fq_name = fake_quantize.soft_get('name', fake_quantize.id)

        assert compressed.soft_get('type') == 'Convert' and compressed.dst_type == quantized_type, \
            'Weights aren`t compressed as expected for node {}'.format(fq_name)

        cast_name = compressed.name + "/to_{}".format(np_data_type_to_destination_type(dst_type))
        to_dst_type = Cast(graph, {'name': cast_name,
                                   'dst_type': dst_type,
                                   'stop_value_propagation': True}).create_node()
        fake_quantize.in_port(0).get_connection().set_destination(to_dst_type.in_port(0))

        # producers of the FakeQuantize limits
        input_low = fake_quantize.in_port(1).get_source()
        input_high = fake_quantize.in_port(2).get_source()
        output_low = fake_quantize.in_port(3).get_source()
        output_high = fake_quantize.in_port(4).get_source()

        # scale = (out_high - out_low) / (in_high - in_low)
        out_range = Sub(graph, {'name': fq_name + '/output_range'}).create_node()
        out_range.in_port(0).connect(output_high)
        out_range.in_port(1).connect(output_low)

        in_range = Sub(graph, {'name': fq_name + '/input_range'}).create_node()
        in_range.in_port(0).connect(input_high)
        in_range.in_port(1).connect(input_low)

        scale_node = Div(graph, {'name': fq_name + '/scale'}).create_node()
        scale_node.in_port(0).connect(out_range.out_port(0))
        scale_node.in_port(1).connect(in_range.out_port(0))

        # shift = in_low - out_low / scale
        out_low_by_scale = Div(graph, {'name': fq_name + '/descaled_output_low'}).create_node()
        out_low_by_scale.in_port(0).connect(output_low)
        out_low_by_scale.in_port(1).connect(scale_node.out_port(0))

        shift_node = Sub(graph, {'name': fq_name + '/shift'}).create_node()
        shift_node.in_port(0).connect(input_low)
        shift_node.in_port(1).connect(out_low_by_scale.out_port(0))

        # zero_point = 0 where scale == 0, otherwise shift
        zero_const = Const(graph, {'name': fq_name + '/zero',
                                   'value': mo_array(0, dtype=dst_type)}).create_node()
        is_scale_zero = Equal(graph, {'name': fq_name + '/scale_eq_zero'}).create_node()
        is_scale_zero.in_port(0).connect(scale_node.out_port(0))
        is_scale_zero.in_port(1).connect(zero_const.out_port(0))

        zero_point_node = Select(graph, {'name': fq_name + '/zero_point'}).create_node()
        zero_point_node.in_port(0).connect(is_scale_zero.out_port(0))
        zero_point_node.in_port(1).connect(zero_const.out_port(0))
        zero_point_node.in_port(2).connect(shift_node.out_port(0))

        # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
        minus_zero_point = Sub(graph, {'name': fq_name + '/minus_zp'}).create_node()
        minus_zero_point.in_port(0).connect(to_dst_type.out_port(0))
        minus_zero_point.in_port(1).connect(zero_point_node.out_port(0))

        times_scale = Mul(graph, {'name': fq_name + '/mulpiply_by_scale'}).create_node()
        times_scale.in_port(0).connect(minus_zero_point.out_port(0))
        times_scale.in_port(1).connect(scale_node.out_port(0))

        # consumers of FakeQuantize now read the dequantized data
        fake_quantize.out_port(0).get_connection().set_source(times_scale.out_port(0))

        graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])
Ejemplo n.º 14
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Replace every QuantizeLinear operation in the graph with FakeQuantize followed by Cast.

        For each QuantizeLinear node:
          * the zero point (input 2, or a uint8 zero constant when the input is absent) must be a Const;
            its data type selects the output range: [-128, 127] for int8 and [0, 255] for uint8;
          * a FakeQuantize with 256 levels is created; its output limits (inputs 3 and 4) are the range
            above and its input limits (inputs 1 and 2) are scale * (output_low - zero_point) and
            scale * (output_high - zero_point), where scale is the producer of QuantizeLinear input 1;
          * the FakeQuantize output is cast to the zero point data type and the Cast node takes over the
            name and the output connection of the QuantizeLinear node;
          * when 'axis' is set and the first dimension of the scale shape is greater than 1, the input
            limits are reshaped to [1, ..., input_shape[axis], ..., 1] before entering FakeQuantize.

        :param graph: graph to be transformed in place
        :return: None
        :raises AssertionError: if the zero point is not a constant or the scale shape is unknown
        :raises Error: if the zero point data type is neither int8 nor uint8
        """
        for quantize_node in graph.get_op_nodes(op='QuantizeLinear'):
            node_name = quantize_node.soft_get('name', quantize_node.id)
            axis = quantize_node.soft_get('axis', None)
            scale_y_shape = quantize_node.in_port(1).data.get_shape()

            if quantize_node.is_in_port_connected(2):
                zerop = quantize_node.in_port(2).get_source().node
            else:
                # no zero point input: fall back to a uint8 zero constant
                zerop = Const(
                    graph, {
                        'value': mo_array(0, dtype=np.uint8),
                        'name': node_name + '/ZeroPoint'
                    }).create_node()

            assert zerop.soft_get(
                'type'
            ) == 'Const', 'only constant for zero_point is supported for QuantizeLinear'
            zero_point_type = zerop.value.dtype
            # data type affects range of output values: [-128..127] or [0..255]
            if zero_point_type == np.int8:
                output_low_value = -128.0
                output_high_value = 127.0
            elif zero_point_type == np.uint8:
                output_low_value = 0.0
                output_high_value = 255.0
            else:
                raise Error(
                    'Not expected type {} for zero point value in node {}'.
                    format(zero_point_type, zerop.soft_get('name')))

            fake_quantize = create_op_with_const_inputs(
                graph, FakeQuantize, {
                    3: float_array(output_low_value),
                    4: float_array(output_high_value)
                }, {
                    'levels': 256,
                    'name': node_name + '/FakeQuantize'
                })
            quantize_node.in_port(0).get_connection().set_destination(
                fake_quantize.in_port(0))

            # Calculate input_low value: scale * (output_low - zero_point)
            mul_low = create_op_with_const_inputs(
                graph, Mul, {1: float_array(output_low_value - zerop.value)},
                {'name': node_name + '/Mul/Low'})
            quantize_node.in_port(1).get_connection().set_destination(
                mul_low.in_port(0))
            mul_low.out_port(0).connect(fake_quantize.in_port(1))

            # Calculate input_high value: scale * (output_high - zero_point)
            mul_high = create_op_with_const_inputs(
                graph, Mul, {1: float_array(output_high_value - zerop.value)},
                {'name': node_name + '/Mul/High'})
            # the scale connection already ends at mul_low, so mul_high is added as one more consumer
            mul_low.in_port(0).get_connection().add_destination(
                mul_high.in_port(0))
            mul_high.out_port(0).connect(fake_quantize.in_port(2))

            cast = Cast(graph, {
                'dst_type': zero_point_type,
                'name': node_name + '/Cast'
            }).create_node()
            fake_quantize.out_port(0).connect(cast.in_port(0))
            quantize_node.out_port(0).get_connection().set_source(
                cast.out_port(0))
            rename_nodes([(quantize_node, node_name + '/TBD'),
                          (cast, node_name)])

            assert scale_y_shape is not None
            if axis is not None and len(
                    scale_y_shape) > 0 and scale_y_shape[0] > 1:
                input_shape = fake_quantize.in_port(0).data.get_shape()
                # np.int64 instead of the `np.int` alias, which was removed in NumPy 1.24;
                # the array is converted with int64_array below, so the result is unchanged.
                target_shape = np.ones(len(input_shape), np.int64)
                target_shape[axis] = input_shape[axis]
                mul_low_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul/Low'})
                mul_high_reshape = create_op_with_const_inputs(
                    graph, Reshape, {1: int64_array(target_shape)},
                    {'name': node_name + '/Reshape/Mul/high'})

                # insert the Reshape nodes between the Mul nodes and the FakeQuantize input limits
                fake_quantize.in_port(1).get_connection().set_destination(
                    mul_low_reshape.in_port(0))
                fake_quantize.in_port(2).get_connection().set_destination(
                    mul_high_reshape.in_port(0))

                mul_low_reshape.out_port(0).connect(fake_quantize.in_port(1))
                mul_high_reshape.out_port(0).connect(fake_quantize.in_port(2))
Ejemplo n.º 15
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Wrap the matched 'pool' node with a Reshape on its input and a Reshape on its output.

        The input Reshape pattern is computed in the graph as
        [in_shape[0], pool_stride, 1, in_shape[1] / pool_stride]; the output Reshape uses the constant
        pattern [0, -1].

        :param graph: graph that contains the matched node and receives the new nodes
        :param match: pattern match dictionary; only the 'pool' entry is used
        """
        node = match['pool']
        node_name = node.soft_get('name', node.id)

        # no explicit pool step: the stride is derived from the last element of the window
        if node.pool_step is None:
            node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

        # create Reshape before the matched node
        # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
        i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()

        dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values
        shape = Cast(graph, {
            'name': node_name + '/to_float',
            'dst_type': dst_dtype
        }).create_node()
        i_shape.in_port(0).connect(node.in_port(0).get_source())
        shape.in_port(0).connect(i_shape.out_port(0))

        # N and H are the nodes producing elements 0 and 1 of the (float) input shape
        N, H = node_to_get_shape_value_of_indices(
            shape, [0]), node_to_get_shape_value_of_indices(shape, [1])

        # in_shape[1] / pool_stride
        div = create_op_with_const_inputs(
            graph, Div, {1: float32_array([node.pool_stride])},
            {'name': node_name + '/div_stride_h'})
        div.in_port(0).connect(H.out_port(0))

        # ports 1 and 2 hold the constants pool_stride and 1; ports 0 and 3 are connected below
        concat = create_op_with_const_inputs(
            graph, Concat, {
                1: float32_array([node.pool_stride]),
                2: float32_array([1])
            }, {
                'name': node_name + '/concat_all_dims',
                'in_ports_count': 4,
                'axis': 0
            })
        concat.in_port(0).connect(N.out_port(0))
        concat.in_port(3).connect(div.out_port(0))

        # the pattern was assembled in float32, so it is cast to int64 before feeding Reshape
        reshape_pattern = Cast(graph, {
            'name': node_name + '/to_int',
            'dst_type': np.int64
        }).create_node()
        concat.out_port(0).connect(reshape_pattern.in_port(0))

        reshape_in = Reshape(graph, {
            'name': node_name + '/reshape_in'
        }).create_node()
        reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

        # create Reshape after the matched node
        reshape_out = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            {'name': node_name + '/reshape_out'})

        # connect input_reshape_node
        # the original producer is remembered first, because set_source below re-points the connection
        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
        reshape_in.in_port(0).connect(source)
        # connect output_reshape_node
        # the existing output connection is handed to reshape_out first, then the node feeds reshape_out
        node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
        node.out_port(0).connect(reshape_out.in_port(0))