def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = np.array([get_split_scale(split)], dtype=np.float32)
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name
    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph,
                                                     StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {
                                                        'name': split_node_name + '/StridedSlice',
                                                        'begin_mask': int64_array([1]),
                                                        'end_mask': int64_array([1]),
                                                        'new_axis_mask': int64_array([0]),
                                                        'shrink_axis_mask': int64_array([0]),
                                                        'ellipsis_mask': int64_array([0])
                                                     })
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest',
                                   antialias=0, pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel', nearest_mode='round_prefer_floor',
                                   cube_coeff=-0.75, version='opset4', shape_calculation_mode='scales',
                                   in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
    def floor_div_replacement(floor_div: Node):
        graph = floor_div.graph
        name = floor_div.soft_get('name', floor_div.id)

        div = Div(graph, {'name': name + '/Div'}).create_node()
        floor = Floor(graph, {'name': name}).create_node()
        div.out_port(0).connect(floor.in_port(0))

        div.in_port(0).connect(floor_div.in_port(0).get_source())
        div.in_port(1).connect(floor_div.in_port(1).get_source())
        floor_div.out_port(0).get_connection().set_source(floor.out_port(0))

        graph.remove_node(floor_div.id)
        rename_node(floor, name)
Example #3
0
 def extract(cls, node):
     Floor.update_node_stat(node, {})
     return cls.enabled
def replace_resize(graph: Graph, resize: Node):
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(
                  resize.soft_get('name', resize.id)))

    input_shape = resize.in_port(0).data.get_shape()
    input_rank = len(input_shape)
    resize_name = resize.soft_get('name', resize.id)
    if input_rank not in {4, 5}:
        log.warning(
            'The input shape is not 4D or 5D for op with name {}'.format(
                resize_name))
        return

    num_of_inputs = len([
        port for port in resize.in_ports().values() if not port.disconnected()
    ])
    assert num_of_inputs in {3, 4}, \
        "Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    layout = graph.graph['layout']

    if input_rank == 4:
        begin_dim = get_height_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1
    else:
        begin_dim = get_depth_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1

    sizes_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_sizes',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    scales_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_scales',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    axes_node = Const(
        graph, {
            'name': resize_name + '/axis',
            'value': int64_array(np.arange(begin_dim, end_dim))
        }).create_node()

    shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'

    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': convert_mode(resize.mode),
            'coordinate_transformation_mode':
            resize.coordinate_transformation_mode,
            'cube_coeff': resize.cube_coeff,
            'nearest_mode': resize.nearest_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': shape_calculation_mode,
            'in_ports_count': 4
        }).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    if num_of_inputs == 3:
        cast_shape_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
        cast_add_result_to_int = Cast(graph, {
            'dst_type': np.int64
        }).create_node()
        floor_node = Floor(graph, {
            'name': resize_name + '/Floor'
        }).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
        cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_scales = resize.in_port(2).get_connection()
        connection_of_scales.set_destination(scales_ss.in_port(0))

        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_scales.get_source().connect(mul_node.in_port(1))
    else:
        cast_shape_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        cast_sizes_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
        cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_sizes = resize.in_port(3).get_connection()
        connection_of_sizes.set_destination(sizes_ss.in_port(0))

        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_sizes.get_source().connect(
            cast_sizes_to_float.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))
def replace_resize(graph: Graph, resize: Node):
    log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
              "is triggered for node {}.".format(
                  resize.soft_get('name', resize.id)))

    resize_name = resize.soft_get('name', resize.id)

    rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
    range_node = create_op_with_const_inputs(graph, Range, {
        0: int64_array(2),
        2: int64_array(1)
    }, {'name': resize_name + '/axes'})

    sizes_ss = create_op_with_const_inputs(graph, StridedSlice, {
        1: int64_array([2]),
        2: int64_array([0]),
        3: int64_array([1])
    }, {
        'name': resize_name + '/sizes_ss',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([0]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0])
    })
    scales_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([2]),
            2: int64_array([0]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/scales_ss',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([0]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })

    rank_node.out_port(0).connect(range_node.in_port(1))

    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest',
            'coordinate_transformation_mode': 'asymmetric',
            'cube_coeff': -0.75,
            'nearest_mode': 'simple',
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': 'scales',
            'in_ports_count': 4
        }).create_node()

    range_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # When we calculate 'sizes' input as floor(input_shape * scales), we can get incorrect 'sizes' if, e.g.,
    # scales = [1.0, 1.0, 1.33333, 2.0], input_shape = [1, 3, 30, 200], because
    # input_shape * scales = [1, 3, 39.9999, 400], and floor(input_shape * scales)[2] == 39, not 40.
    # Maybe we need to calculate 'sizes' input as floor(input_shape * scales + eps), where eps is some small
    # floating point number, e.g. 1.0e-5. But, in this case, if scales = [1.0, 1.0, 1.333333, 2.0],
    # input_shape = [1, 3, 30, 200], floor(input_shape * scales + eps) = 39, not 40, because
    # input_shape[2] * scales[2] + 1.0e-5 =  39.99991.
    # Hence, we need to calculate 'sizes' as floor(input_shape * (scales + eps)).
    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    cast_shape_to_float = Cast(graph, {
        'dst_type': input_data_type
    }).create_node()

    shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
    mul_node = Mul(graph, {
        'name': resize_name + '/Mul'
    }).create_node([cast_shape_to_float, add_node])
    floor_node = Floor(graph, {
        'name': resize_name + '/Floor'
    }).create_node([mul_node])
    cast_mul_result_to_int = Cast(graph, {
        'dst_type': np.int64
    }).create_node([floor_node])
    cast_mul_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
    sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

    scales_ss.out_port(0).connect(interpolate_node.in_port(2))

    connection_of_resize_input = resize.in_port(0).get_connection()
    connection_of_resize_input.set_destination(interpolate_node.in_port(0))

    connection_of_scales = resize.in_port(1).get_connection()
    connection_of_scales.set_destination(scales_ss.in_port(0))

    connection_of_resize_input.get_source().connect(shape_of.in_port(0))
    connection_of_resize_input.get_source().connect(rank_node.in_port(0))
    connection_of_scales.get_source().connect(add_node.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))
Example #6
0
 def extract(node):
     Floor.update_node_stat(node)
     return __class__.enabled
    def replace_pattern(self, graph: Graph, match: dict):
        unsqueeze_node = match['unsqueeze']
        unsqueeze_name = unsqueeze_node.name

        second_input_of_unsqueeze = unsqueeze_node.in_port(
            1).get_connection().get_source().node
        if not second_input_of_unsqueeze.has_valid('value'):
            return

        d_idx = int(second_input_of_unsqueeze.value)

        second_input_of_tile = match['tile'].in_port(
            1).get_connection().get_source().node
        if not second_input_of_tile.has_valid('value'):
            return

        input_shape_of_unsqueeze = unsqueeze_node.in_port(0).data.get_shape()
        if len(input_shape_of_unsqueeze) not in {4, 5}:
            return

        scale = float32_array([second_input_of_tile.value[d_idx]])
        axis = d_idx - 1
        axis_node = Const(graph, {
            'name': unsqueeze_name + '/axis',
            'value': int64_array([axis])
        }).create_node()

        shape_node = Shape(graph,
                           dict(name=unsqueeze_name + '/Shape')).create_node()
        scales_node = Const(graph,
                            dict(name=unsqueeze_name + '/scales',
                                 value=scale)).create_node()
        mul_node = Mul(graph, dict(name=unsqueeze_name + '/Mul')).create_node()
        scales_node.out_port(0).connect(mul_node.in_port(1))

        slice_begin = Const(
            graph,
            dict(name=unsqueeze_name + '/slice_begin',
                 value=int64_array([axis]))).create_node()
        slice_end = Const(
            graph,
            dict(name=unsqueeze_name + '/slice_end',
                 value=int64_array([axis + 1]))).create_node()

        strided_slice_node = StridedSlice(
            graph, {
                'name': unsqueeze_name + '/StridedSlice',
                'begin_mask': int64_array([1]),
                'end_mask': int64_array([1]),
                'new_axis_mask': int64_array([0]),
                'shrink_axis_mask': int64_array([0]),
                'ellipsis_mask': int64_array([0]),
            }).create_node()
        shape_node.out_port(0).connect(strided_slice_node.in_port(0))
        slice_begin.out_port(0).connect(strided_slice_node.in_port(1))
        slice_end.out_port(0).connect(strided_slice_node.in_port(2))

        cast_shape_to_float = Cast(graph, {
            'dst_type': np.float32
        }).create_node()

        strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

        interp_node = Interpolate(
            graph,
            dict(mode='nearest',
                 antialias=0,
                 pads_begin=int64_array([0]),
                 pads_end=int64_array([0]),
                 coordinate_transformation_mode='half_pixel',
                 nearest_mode='round_prefer_floor',
                 cube_coeff=-0.75,
                 version='opset4',
                 shape_calculation_mode='scales',
                 in_ports_count=4,
                 maybe_part_of_sequence=True)).create_node()

        floor_node = Floor(graph, {
            'name': unsqueeze_name + '/Floor'
        }).create_node()
        cast_mul_result_to_int = Cast(graph, {
            'dst_type': np.int64
        }).create_node()

        mul_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

        cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
        scales_node.out_port(0).connect(interp_node.in_port(2))
        axis_node.out_port(0).connect(interp_node.in_port(3))

        reshape_node = match['reshape']

        reshape_node.out_port(0).get_connection().set_source(
            interp_node.out_port(0))
        reshape_name = reshape_node.soft_get('name', reshape_node.id)
        rename_nodes([(reshape_node, reshape_name + '/delete'),
                      (interp_node, reshape_name)])

        unsqueeze_connection = match['unsqueeze'].in_port(0).get_connection()
        before_unsqueeze = unsqueeze_connection.get_source().node
        unsqueeze_connection.set_destination(interp_node.in_port(0))
        before_unsqueeze.out_port(0).connect(shape_node.in_port(0))