Example #1
0
def is_spatial_squeeze(layout: str, input_shape: np.ndarray, squeeze_dims: np.ndarray):
    """
    Checks that the squeeze operation removes all spatial dimensions.

    Only 4D and 5D input shapes can pass the check. Every spatial dimension (height, width and, for 5D inputs,
    depth) must be equal to 1 and the number of squeezed dims must be equal to the number of spatial dims.
    The squeezed dims are compared by count only: their values are not matched against the spatial axes.
    :param layout: graph layout.
    :param input_shape: numpy array with input shape.
    :param squeeze_dims: numpy array with dims to squeeze.
    :return: result of the check.
    """
    rank = len(input_shape)
    if rank not in (4, 5):
        return False

    spatial_dims = [get_height_dim(layout, rank), get_width_dim(layout, rank)]
    if rank == 5:
        # the depth axis exists for 5D shapes only
        spatial_dims.append(get_depth_dim(layout, rank))

    has_non_unit_spatial_dim = any(input_shape[dim] != 1 for dim in spatial_dims)
    if has_non_unit_spatial_dim or len(squeeze_dims) != len(spatial_dims):
        log.debug('The reshape from "{}" with squeezed dims "{}" is not a spatial squeeze'.format(input_shape,
                                                                                                  squeeze_dims))
        return False

    # The success path must not report "is not a spatial squeeze": the message would contradict the result.
    log.debug('The reshape from "{}" with squeezed dims "{}" is a spatial squeeze'.format(input_shape,
                                                                                          squeeze_dims))
    return True
Example #2
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Replaces the matched 'upsample' node with an Interpolate node.

        The target spatial sizes are computed inside the graph as Shape -> StridedSlice -> Mul (by the height and
        width scales) and are passed to input port 1 of the Interpolate node. The method returns without touching
        the graph when the input is not 4D/5D, when the second input has no constant value, or when the first two
        scales are not close to 1.
        :param graph: graph to operate on.
        :param match: dictionary with matched nodes; only the 'upsample' key is used.
        :return: None
        """
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        rank = len(upsample.in_port(0).data.get_shape())
        if rank not in (4, 5):
            log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
            return

        if len(upsample.in_nodes()) != 2:
            # there is no scales input, so the scales are taken from the node attributes
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']
        else:
            scales = upsample.in_node(1).value
            if scales is None:
                return
            assert scales.shape == (4, )
            # the node is left as is unless scales[0] and scales[1] are (almost) equal to 1
            if not math.isclose(scales[0], 1, rel_tol=1e-5) or not math.isclose(scales[1], 1, rel_tol=1e-5):
                return
            height_scale, width_scale = scales[2], scales[3]

        if 1 in upsample.in_ports():
            scales_port = upsample.in_port(1)
            if not scales_port.disconnected():
                scales_port.disconnect()

        factor = Const(graph, {'value': np.array([height_scale, width_scale])}).create_node()
        shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()

        layout = graph.graph['layout']
        # the slice of the shape starts from height for 4D inputs and from depth for 5D inputs
        if rank == 4:
            first_spatial_dim = get_height_dim(layout, rank)
        else:
            first_spatial_dim = get_depth_dim(layout, rank)
        begin = Const(graph, {'value': int64_array([first_spatial_dim])}).create_node()
        end = Const(graph, {'value': int64_array([get_width_dim(layout, rank) + 1])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()

        # NOTE(review): 'end_mask' is 0 here, so presumably the 'end' input is ignored by the slice;
        # confirm this is intended for layouts where width is not the last axis.
        ss_attrs = {
            'name': upsample.name + '/ss_0_port',
            'begin_mask': np.array([1]),
            'end_mask': np.array([0]),
            'new_axis_mask': np.array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }
        ss = StridedSlice(graph, ss_attrs).create_node()
        mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node()

        # data producer -> Shape -> StridedSlice(begin, end, stride) -> Mul(factor)
        source = upsample.in_port(0).get_connection().get_source()
        source.connect(shape.in_port(0))
        shape.out_port(0).connect(ss.in_port(0))
        for port_idx, slice_input in enumerate((begin, end, stride), start=1):
            slice_input.out_port(0).connect(ss.in_port(port_idx))
        ss.out_port(0).connect(mul.in_port(0))
        factor.out_port(0).connect(mul.in_port(1))

        # Create Interpolate operation
        if rank == 4:
            dim_getters = (get_height_dim, get_width_dim)
        else:
            dim_getters = (get_depth_dim, get_height_dim, get_width_dim)
        axes = int64_array([get_dim(layout, rank) for get_dim in dim_getters])

        resample_attrs = {
            'name': 'Interpolate/{}'.format(upsample.name),
            'axes': axes,
            'mode': upsample.attrs()['mode'],
            'antialias': 0,
            'convert_to_resample': True,
        }
        resample_op = Interpolate(graph, resample_attrs).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        mul.out_port(0).connect(resample_op.in_port(1))

        # re-route the data input and the output of the upsample node to the new node
        upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
        upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))
def replace_resize(graph: Graph, resize: Node):
    """
    Replaces the given ONNX Resize node with an opset4 Interpolate node.

    The Interpolate node gets the data on port 0, the spatial sizes on port 1, the spatial scales on port 2 and
    the spatial axes on port 3. With 3 connected inputs the scales are taken from input port 2 of the Resize and
    the sizes are computed as floor(shape * scales + 1e-5). With 4 connected inputs the sizes are taken from
    input port 3 and the scales are computed as sizes / shape + 1e-5. Nothing is done for inputs which are not
    4D or 5D.
    :param graph: graph to operate on.
    :param resize: the Resize node to replace.
    :return: None
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 is triggered for node {}.".format(resize_name))

    input_rank = len(resize.in_port(0).data.get_shape())
    if input_rank not in {4, 5}:
        log.warning('The input shape is not 4D or 5D for op with name {}'.format(resize_name))
        return

    num_of_inputs = sum(1 for port in resize.in_ports().values() if not port.disconnected())
    assert num_of_inputs in {3, 4}, \
        "Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    layout = graph.graph['layout']

    # the spatial range starts from height for 4D inputs and from depth for 5D inputs
    get_first_spatial_dim = get_height_dim if input_rank == 4 else get_depth_dim
    begin_dim = get_first_spatial_dim(layout, input_rank)
    end_dim = get_width_dim(layout, input_rank) + 1

    def _make_spatial_slice(name_suffix):
        # StridedSlice with constant begin/end/stride inputs which keeps elements [begin_dim, end_dim) of a 1D input
        return create_op_with_const_inputs(
            graph, StridedSlice,
            {1: int64_array([begin_dim]), 2: int64_array([end_dim]), 3: int64_array([1])},
            {
                'name': resize_name + name_suffix,
                'begin_mask': int64_array([1]),
                'end_mask': int64_array([1]),
                'new_axis_mask': int64_array([0]),
                'shrink_axis_mask': int64_array([0]),
                'ellipsis_mask': int64_array([0]),
            })

    sizes_ss = _make_spatial_slice('/StridedSlice_sizes')
    scales_ss = _make_spatial_slice('/StridedSlice_scales')
    axes_node = Const(graph, {'name': resize_name + '/axis',
                              'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()

    if num_of_inputs == 3:
        shape_calculation_mode = 'scales'
    else:
        shape_calculation_mode = 'sizes'

    interpolate_attrs = {
        'version': 'opset4',
        'mode': convert_mode(resize.mode),
        'coordinate_transformation_mode': resize.coordinate_transformation_mode,
        'cube_coeff': resize.cube_coeff,
        'nearest_mode': resize.nearest_mode,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'antialias': 0,
        'shape_calculation_mode': shape_calculation_mode,
        'in_ports_count': 4,
    }
    interpolate_node = Interpolate(graph, interpolate_attrs).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # the small constant is added before Floor (3 inputs) or before slicing the computed scales (4 inputs)
    add_node = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])}, {'name': resize_name + '/Add'})

    input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    if num_of_inputs == 3:
        # sizes = int64(floor(float(shape) * scales + 1e-5)), then sliced to the spatial range
        shape_as_float = Cast(graph, {'dst_type': input_data_type}).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(shape_as_float.in_port(0))
        shape_as_float.out_port(0).connect(mul_node.in_port(0))
        sizes_as_int = Cast(graph, {'dst_type': np.int64}).create_node()
        floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(sizes_as_int.in_port(0))
        sizes_as_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate_node.in_port(0))

        scales_connection = resize.in_port(2).get_connection()
        scales_connection.set_destination(scales_ss.in_port(0))

        # the same producers also feed the shape calculation sub-graph
        data_connection.get_source().connect(shape_of.in_port(0))
        scales_connection.get_source().connect(mul_node.in_port(1))
    else:
        # scales = float(sizes) / float(shape) + 1e-5, then sliced to the spatial range
        shape_as_float = Cast(graph, {'dst_type': input_data_type}).create_node()
        sizes_as_float = Cast(graph, {'dst_type': input_data_type}).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        sizes_as_float.out_port(0).connect(div_node.in_port(0))
        shape_as_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(shape_as_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate_node.in_port(0))

        sizes_connection = resize.in_port(3).get_connection()
        sizes_connection.set_destination(sizes_ss.in_port(0))

        # the same producers also feed the scales calculation sub-graph
        data_connection.get_source().connect(shape_of.in_port(0))
        sizes_connection.get_source().connect(sizes_as_float.in_port(0))

    # the new node takes over the name and the consumers of the Resize node
    rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
 def test_get_depth_dim_NDHWC(self):
     """get_depth_dim must return axis 1 for the 'NHWC' layout string and a 5D shape."""
     expected_depth_axis = 1
     self.assertEqual(get_depth_dim('NHWC', 5), expected_depth_axis)
 def test_get_depth_dim_NCDHW(self):
     """get_depth_dim must return axis 2 for the 'NCHW' layout string and a 5D shape."""
     expected_depth_axis = 2
     self.assertEqual(get_depth_dim('NCHW', 5), expected_depth_axis)
Example #6
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Replaces the matched 'upsample' node with an opset4 Interpolate node in the 'scales' mode.

        The Interpolate node gets the data on port 0, the int64 sizes computed in the graph as
        Shape -> StridedSlice -> Cast(float32) -> Mul(scales) -> Cast(int64) on port 1, the constant scales on
        port 2 and the constant spatial axes on port 3. The method returns without touching the graph when the
        input is not 4D/5D, when the second input has no constant value, or when the first two scales are not
        close to 1.
        :param graph: graph to operate on.
        :param match: dictionary with matched nodes; only the 'upsample' key is used.
        :return: None
        """
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        upsample_name = upsample.soft_get('name', upsample.id)
        rank = len(upsample.in_port(0).data.get_shape())
        if rank not in (4, 5):
            log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
            return

        layout = graph.graph['layout']
        # NOTE(review): depth_scale stays None unless a 5-element scales input is given, while it is used
        # for every 5D input below; confirm that a 5D input always comes with 5 scales.
        depth_scale = None

        if len(upsample.in_nodes()) != 2:
            # there is no scales input, so the scales are taken from the node attributes
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']
        else:
            scales = upsample.in_node(1).value
            if scales is None:
                return
            assert len(scales) in (4, 5), \
                'Supported scales rank is 4 or 5, but it is {} for node {}'.format(len(scales), upsample_name)
            # the node is left as is unless scales[0] and scales[1] are (almost) equal to 1
            if not math.isclose(scales[0], 1, rel_tol=1e-5) or not math.isclose(scales[1], 1, rel_tol=1e-5):
                return
            height_scale = scales[get_height_dim(layout, rank)]
            width_scale = scales[get_width_dim(layout, rank)]
            if len(scales) == 5:
                depth_scale = scales[get_depth_dim(layout, rank)]

        if 1 in upsample.in_ports():
            scales_port = upsample.in_port(1)
            if not scales_port.disconnected():
                scales_port.disconnect()

        shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

        # the slice of the shape starts from height for 4D inputs and from depth for 5D inputs
        if rank == 4:
            begin_value = int64_array([get_height_dim(layout, rank)])
            factor_value = float32_array([height_scale, width_scale])
        else:
            begin_value = int64_array([get_depth_dim(layout, rank)])
            factor_value = float32_array([depth_scale, height_scale, width_scale])

        ss_const_inputs = {
            1: begin_value,
            2: int64_array([get_width_dim(layout, rank) + 1]),
            3: int64_array([1]),
        }
        ss_attrs = {
            'name': upsample_name + '/ss_0_port',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }
        ss = create_op_with_const_inputs(graph, StridedSlice, ss_const_inputs, ss_attrs)

        mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul'})

        # data producer -> Shape -> StridedSlice -> Mul
        source = upsample.in_port(0).get_connection().get_source()
        source.connect(shape.in_port(0))
        shape.out_port(0).connect(ss.in_port(0))
        ss.out_port(0).connect(mul.in_port(0))

        # Create Interpolate operation
        if rank == 4:
            dim_getters = (get_height_dim, get_width_dim)
        else:
            dim_getters = (get_depth_dim, get_height_dim, get_width_dim)
        axes = int64_array([get_dim(layout, rank) for get_dim in dim_getters])

        axes_node = Const(graph, {'name': upsample_name + '/axis', 'value': axes}).create_node()

        interpolate_attrs = {
            'mode': upsample.attrs()['mode'],
            'antialias': 0,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'coordinate_transformation_mode': 'half_pixel',
            'nearest_mode': 'round_prefer_floor',
            'cube_coeff': -0.75,
            'shape_calculation_mode': 'scales',
            'version': 'opset4',
            'in_ports_count': 4,
        }
        interpolate = Interpolate(graph, interpolate_attrs).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        mul.out_port(0).connect(interpolate.in_port(1))
        axes_node.out_port(0).connect(interpolate.in_port(3))

        scales_node = Const(graph, {'name': upsample_name + '/scales', 'value': factor_value}).create_node()
        scales_node.out_port(0).connect(interpolate.in_port(2))

        # re-route the data input and the output of the upsample node to the new node
        upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
        upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))

        # the new node takes over the name of the upsample node
        rename_nodes([(upsample, upsample_name + '/delete'), (interpolate, upsample_name)])

        # the multiplication is done in float32 and its result is converted back to int64
        convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
        convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

        mul.in_port(0).get_connection().insert_node(convert_to_float)
        mul.out_port(0).get_connection().insert_node(convert_to_int)
Example #7
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Replaces the matched 'upsample' node with an Interpolate node.

        The int64 target sizes are computed in the graph as
        Shape -> StridedSlice -> Cast(float32) -> Mul(scales) -> Cast(int64) and are passed to input port 1 of
        the Interpolate node. The method returns without touching the graph when the input is not 4D/5D, when the
        second input has no constant value, when the first two scales are not close to 1, or when the height,
        width and (if present) depth scales are not close to each other.
        :param graph: graph to operate on.
        :param match: dictionary with matched nodes; only the 'upsample' key is used.
        :return: None
        """
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        upsample_name = upsample.soft_get('name', upsample.id)
        rank = len(upsample.in_port(0).data.get_shape())
        if rank not in (4, 5):
            log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
            return

        depth_scale = None
        if len(upsample.in_nodes()) != 2:
            # there is no scales input, so the scales are taken from the node attributes
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']
        else:
            scales = upsample.in_node(1).value
            if scales is None:
                return
            assert len(scales) in (4, 5), \
                'Supported scales rank is 4 or 5, but it is {} for node {}'.format(len(scales), upsample_name)
            # the node is left as is unless scales[0] and scales[1] are (almost) equal to 1
            if not math.isclose(scales[0], 1, rel_tol=1e-5) or not math.isclose(scales[1], 1, rel_tol=1e-5):
                return
            height_scale, width_scale = scales[2], scales[3]
            if len(scales) == 5:
                depth_scale = scales[4]

        # only the same scale over all spatial dimensions is handled
        if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
            log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
                width_scale, height_scale, upsample_name))
            return
        if depth_scale is not None:
            if not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
                log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
                    depth_scale, height_scale, upsample_name))
                return

        if 1 in upsample.in_ports():
            scales_port = upsample.in_port(1)
            if not scales_port.disconnected():
                scales_port.disconnect()

        shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

        layout = graph.graph['layout']

        # the slice of the shape starts from height for 4D inputs and from depth for 5D inputs
        if rank == 4:
            begin_value = int64_array([get_height_dim(layout, rank)])
            factor_value = np.array([height_scale, width_scale])
        else:
            begin_value = int64_array([get_depth_dim(layout, rank)])
            factor_value = np.array([depth_scale, height_scale, width_scale])

        ss_const_inputs = {
            1: begin_value,
            2: int64_array([get_width_dim(layout, rank) + 1]),
            3: int64_array([1]),
        }
        ss_attrs = {
            'name': upsample_name + '/ss_0_port',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }
        ss = create_op_with_const_inputs(graph, StridedSlice, ss_const_inputs, ss_attrs)

        mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})

        # data producer -> Shape -> StridedSlice -> Mul
        source = upsample.in_port(0).get_connection().get_source()
        source.connect(shape.in_port(0))
        shape.out_port(0).connect(ss.in_port(0))
        ss.out_port(0).connect(mul.in_port(0))

        # Create Interpolate operation
        if rank == 4:
            dim_getters = (get_height_dim, get_width_dim)
        else:
            dim_getters = (get_depth_dim, get_height_dim, get_width_dim)
        axes = int64_array([get_dim(layout, rank) for get_dim in dim_getters])

        resample_attrs = {
            'name': upsample_name + '/Interpolate',
            'axes': axes,
            'mode': upsample.attrs()['mode'],
            'antialias': 0,
            'convert_to_resample': True,
        }
        resample_op = Interpolate(graph, resample_attrs).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        mul.out_port(0).connect(resample_op.in_port(1))

        # re-route the data input and the output of the upsample node to the new node
        upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
        upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))

        # the multiplication is done in float32 and its result is converted back to int64
        convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
        convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

        mul.in_port(0).get_connection().insert_node(convert_to_float)
        mul.out_port(0).get_connection().insert_node(convert_to_int)