Exemple #1
0
    def __call__(self, first: Node, second: Node) -> bool:
        """
        Decide whether the Interpolate nodes 'first' and 'second' can be fused.

        Axes of the nodes accepted so far are kept in 'self.accumulated_axes'. The accumulator is cleared
        every time a pair of nodes is rejected.
        :param first: the first of fused nodes
        :param second: the second of fused nodes
        :return: True, if nodes can be fused, and False otherwise
        """
        if is_next(first, second) and self._compare_attributes(first, second):
            axes_of_first = set(Interpolate.get_axes(first))
            axes_of_second = set(Interpolate.get_axes(second))

            self.accumulated_axes = self.accumulated_axes | axes_of_first

            # Interpolations along different axes do not affect each other, so the nodes can be fused
            # only when 'second' touches none of the axes accumulated so far.
            if self.accumulated_axes.isdisjoint(axes_of_second):
                return True

        # Either the nodes are not consecutive / have different attributes, or their axes overlap.
        self.accumulated_axes = set()
        return False
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the matched NearestNeighborUpsampling sub-graph with a single Interpolate node.

        The target spatial size is the product of the input height/width (read from inputs 1 and 3 of
        'pack_1') and the scales (read from the shape of 'mul_const'). If any of these values cannot be
        obtained, the sub-graph is left untouched.
        :param graph: graph to operate on
        :param match: dictionary with matched nodes
        """
        log.debug('Matched NearestNeighborUpsampling pattern: {}'.format([node.id for node in match.values()]))
        pack_node = match['pack_1']
        mul_const = match['mul_const']
        try:
            input_height = pack_node.in_node(1).value.item()
            input_width = pack_node.in_node(3).value.item()

            height_scale = mul_const.shape[-4]
            width_scale = mul_const.shape[-2]
        except Exception:
            log.warning('Failed to determine scaling parameters from the topology. Do not apply pattern.')
            return

        # Spatial axes depend on the layout recorded in the graph.
        if graph.graph['layout'] == 'NCHW':
            axes = int64_array([2, 3])
        else:
            axes = int64_array([1, 2])

        resample_attrs = {'name': 'Resample_', 'antialias': 0, 'mode': 'nearest', 'axes': axes}
        resample_node = Interpolate(graph, resample_attrs).create_node([match['op']])

        target_shape = np.array([input_height * height_scale, input_width * width_scale])
        const = Const(graph, {'value': target_shape,
                              'name': resample_node.name + '/target_shape'}).create_node()

        match['reshape_2'].replace_node(resample_node)

        # The target shape is fed through input port 1, which must be free at this point.
        resample_node.add_input_port(1, skip_if_exist=True)
        assert resample_node.in_port(1).disconnected()
        const.out_port(0).connect(resample_node.in_port(1))

        # Drop every matched node except the producer 'op'.
        kept_id = match['op'].id
        graph.remove_nodes_from([node.id for node in match.values() if node.id != kept_id])
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the matched resize node with Interpolate over axes [2, 3, 4].

        The replacement is done only when the constants on input 1 of 'mul_1', 'mul_2' and 'mul_3' are equal.
        The output shape of Interpolate is calculated as the output of 'slice' multiplied by that scale.
        :param graph: graph to operate on
        :param match: dictionary with matched nodes
        """
        resize_node = match['resize']
        scale = match['mul_1'].in_node(1).value
        if scale != match['mul_2'].in_node(1).value or scale != match['mul_3'].in_node(1).value:
            log.info(
                'Pattern matched around resize op {} has different scale values.'
                .format(resize_node.name))
            return

        interpolate_attrs = {
            'name': resize_node.name + '/Interpolate',
            'mode': resize_node.mode,
            'axes': int64_array([2, 3, 4])
        }
        interpolate_node = Interpolate(graph, interpolate_attrs).create_node()

        scale_const = Const(graph, {
            'value': int64_array([scale, scale, scale]),
            'name': resize_node.name + '/Scale'
        }).create_node()

        # interpolated_shape = slice output * scale
        interpolated_shape = Mul(graph, {'name': resize_node.name + '/OutputShape'}).create_node()
        match['slice'].out_port(0).connect(interpolated_shape.in_port(0))
        scale_const.out_port(0).connect(interpolated_shape.in_port(1))

        # Re-route the data input and all consumers of the resize node to the new Interpolate node.
        resize_input = resize_node.in_port(0).get_connection()
        resize_input.set_destination(interpolate_node.in_port(0))
        interpolated_shape.out_port(0).connect(interpolate_node.in_port(1))
        resize_output = resize_node.out_port(0).get_connection()
        resize_output.set_source(interpolate_node.out_port(0))
Exemple #4
0
    def test_interpolate4_using_scales_without_axes(self, pads_begin, pads_end, input_shape, output_shape, sizes,
                                                   scales):
        """
        Check shape inference of Interpolate in 'scales' shape calculation mode on the graph without the axes input.
        """
        def const_attrs(array):
            # Attributes of a constant node / data node holding 'array'.
            return {'shape': array.shape, 'value': array}

        interpolate_attrs = {'pads_begin': int64_array(pads_begin),
                             'pads_end': int64_array(pads_end),
                             'shape_calculation_mode': 'scales'}
        graph = build_graph(nodes_attrs=graph_node_attrs_without_axes,
                            edges=graph_edges_without_axes,
                            update_attributes={
                                'input_data': {'shape': input_shape},
                                'sizes': const_attrs(int64_array(sizes)),
                                'sizes_data': const_attrs(int64_array(sizes)),
                                'scales': const_attrs(np.array(scales)),
                                'scales_data': const_attrs(np.array(scales)),
                                'interpolate': interpolate_attrs
                            })

        node = Node(graph, 'interpolate')
        Interpolate(graph=graph, attrs=node.attrs()).infer(node)

        actual_shape = graph.node['interpolate_data']['shape']
        msg = "Interpolate-4 infer failed for case: sizes={}, scales={}, pads_begin={}, pads_end={}," \
              " expected_shape={}, actual_shape={}"

        self.assertTrue(np.array_equal(actual_shape, int64_array(output_shape)),
                        msg.format(sizes, scales, pads_begin, pads_end, output_shape,
                                   graph.node['interpolate_data']['shape']))
Exemple #5
0
    def extract(node):
        """
        Fill attributes of 'node' from the 'resample_param' of its Caffe layer and mark the node as Interpolate.
        :param node: node whose 'pb' attribute holds the Caffe layer
        :return: the 'enabled' flag of the enclosing class
        """
        param = node.pb.resample_param
        # Names of interpolation modes indexed by the numeric 'type' field of the parameter.
        mode_names = ("", 'nearest', 'linear', 'cubic', 'area')
        resample_type = mode_names[param.type]

        mapping_rule = merge_attrs(param, {
            'antialias': int(param.antialias),
            'height': param.height,
            'width': param.width,
            'type': resample_type,
            'factor': param.factor,
            'fw': 'caffe',
        })
        # Interpolate expects the mode under the key 'mode' instead of 'type'.
        mapping_rule['mode'] = mapping_rule.pop('type')
        mapping_rule['axes'] = int64_array([2, 3])
        Interpolate.update_node_stat(node, mapping_rule)
        return __class__.enabled
 def extract(node):
     """
     Mark 'node' as Interpolate in 'nearest' mode over axes [1, 2] without antialiasing.
     :param node: node to update
     :return: the 'enabled' flag of the enclosing class
     """
     Interpolate.update_node_stat(node, {
         'mode': 'nearest',
         'antialias': 0,
         'axes': int64_array([1, 2]),
     })
     return __class__.enabled
Exemple #7
0
 def extract(cls, node):
     """
     Mark 'node' as Interpolate in 'linear' mode over axes [1, 2].
     The 'align_corners' attribute is taken from the protobuf attribute with the same name.
     :param node: node whose 'pb' attribute holds the framework node
     :return: the 'enabled' flag of the class
     """
     align_corners = int(node.pb.attr['align_corners'].b)
     Interpolate.update_node_stat(node, {
         'align_corners': align_corners,
         'mode': 'linear',
         'axes': int64_array([1, 2]),
     })
     return cls.enabled
    def replace_op(self, graph: Graph, node: Node):
        """
        Replace 'node' (which wraps an upsampling module in 'node.module') with an interpolation node.

        Three cases are handled:
          * 'linear' ('bilinear') mode without 'align_corners': Interpolate of version 'opset4' with constant
            inputs for sizes, scales and axes;
          * any other case with a non-empty 'size': Interpolate of version 'opset1';
          * otherwise: UpsampleOp with height/width scales taken from 'scale_factor'.
        :param graph: graph to operate on
        :param node: node to replace
        :return: list with the id of the created node
        :raises Error: if neither 'size' nor 'scale_factor' is set on the module
        """
        module = node.module
        mode = module.mode
        if mode == 'bilinear':
            mode = 'linear'
        align_corners = module.align_corners

        if mode == 'linear' and not align_corners:
            height = module.size[0]
            width = module.size[1]
            attrs = {
                'name': node.name,
                'version': 'opset4',
                'height': height,
                'width': width,
                'mode': mode,
                'axes': [2, 3],
                'pads_begin': [0, 0],
                'pads_end': [0, 0],
                'align_corners': align_corners,
                'shape_calculation_mode': 'sizes',
            }

            sizes = Const(graph, {
                'value': np.array([height, width])
            }).create_node()
            axes = Const(graph, {'value': np.array([2, 3])}).create_node()
            scales = Const(graph, {
                'value': np.array([1, 1], dtype=np.float32)
            }).create_node()
            interp = Interpolate(graph, attrs).create_node(
                [node.in_node(0), sizes, scales, axes])
        elif module.size:
            attrs = {
                'name': node.name,
                'version': 'opset1',
                'height': module.size[0],
                'width': module.size[1],
                'mode': mode,
                'axes': [2, 3],
                'align_corners': align_corners,
            }
            interp = Interpolate(graph, attrs).create_node([node.in_node(0)])
        else:
            if not module.scale_factor:
                raise Error('No scale_factor found')
            # The builtin float is used because the alias np.float was removed from NumPy.
            scale_factor = float(module.scale_factor)
            attrs = {
                'name': node.name,
                'height_scale': scale_factor,
                'width_scale': scale_factor,
                'mode': mode,
                'align_corners': align_corners,
            }
            interp = UpsampleOp(graph, attrs).create_node([node.in_node(0)])

        return [interp.id]
Exemple #9
0
 def extract(node):
     """
     Mark 'node' as Interpolate in 'linear' mode over axes [1, 2] with zero pads.
     The 'align_corners' attribute is taken from the protobuf attribute with the same name.
     :param node: node whose 'pb' attribute holds the framework node
     :return: the 'enabled' flag of the enclosing class
     """
     align_corners = int(node.pb.attr['align_corners'].b)
     Interpolate.update_node_stat(node, {
         'pads_begin': 0,
         'pads_end': 0,
         'align_corners': align_corners,
         'mode': 'linear',
         'axes': int64_array([1, 2]),
     })
     return __class__.enabled
Exemple #10
0
    def extract(cls, node):
        """
        Extract attributes of the MXNet layer stored in 'node.symbol_dict'.

        For 'nearest' sample type the node becomes Interpolate; for 'bilinear' sample type the node becomes
        Deconvolution. For any other sample type the node is not changed.
        :param node: node to update
        :return: the 'enabled' flag of the class
        """
        attrs = get_mxnet_layer_attrs(node.symbol_dict)
        scale = attrs.int("scale", 1)
        num_filter = attrs.int("num_filter", 0)
        mode = attrs.str("sample_type", None)

        if mode == 'nearest':
            Interpolate.update_node_stat(node, {
                'factor': scale,
                'mode': mode,
                'antialias': 0,
                'axes': int64_array([2, 3]),
            })
        elif mode == 'bilinear':
            # Bilinear UpSampling uses deconvolution algorithm under the hood.
            # For MXNet Bilinear UpSampling op just wrapper over Deconvolution op.
            # Inputs data:
            #     input1 - input data
            #     input2 - deconvolution weight
            kernel = 2 * scale - scale % 2
            pad = math.ceil((scale - 1) / 2)
            spatial_pad = [pad, pad]

            Convolution.update_node_stat(node, {
                'op': __class__.op,
                'type': 'Deconvolution',
                'bias_addable': True,
                'bias_term': False,
                'pad': int64_array([[0, 0], [0, 0], spatial_pad, spatial_pad]),
                'pad_spatial_shape': int64_array([spatial_pad, spatial_pad]),
                'dilation': None,
                'output_spatial_shape': None,
                'output_shape': None,
                'stride': int64_array([1, 1, scale, scale]),
                'group': num_filter,
                'output': num_filter,
                'kernel_spatial': int64_array([kernel, kernel]),
                'input_feature_channel': 0,
                'output_feature_channel': 1,
                'kernel_spatial_idx': None,
                'reshape_kernel': True,
                'spatial_dims': None,
                'channel_dims': int64_array([1]),
                'batch_dims': int64_array([0]),
                'layout': 'NCHW',
                'get_pad': DeconvFrontExtractor.get_pad,
            })
        return cls.enabled
    def __call__(self, first: Node, second: Node) -> bool:
        """
        This function checks whether Interpolate nodes 'first' and 'second' can be fused.

        Axes of the nodes accepted so far are kept in 'self.accumulated_axes'. The accumulator is cleared on
        every rejection (different attributes or intersecting axes), so that axes of a finished sequence
        do not influence the decision for the next sequence.
        :param first: the first of fused nodes
        :param second: the second of fused nodes
        :return: True, if nodes can be fused, and False otherwise
        """
        # If some of attributes 'mode', 'align_corners', 'antialias', 'pads_begin', 'pads_end' are different,
        # then nodes cannot be fused, because fused result will be incorrect.
        op = Interpolate(graph=first.graph, attrs={})
        for attr in ['mode', 'align_corners', 'antialias', 'pads_begin', 'pads_end']:
            default = op.attrs[attr]
            if first.soft_get(attr, default=default) != second.soft_get(attr, default=default):
                # A rejected pair ends the current sequence, hence the accumulated axes must be dropped.
                self.accumulated_axes = set()
                return False

        fst_axes = set(first.axes)
        snd_axes = set(second.axes)

        self.accumulated_axes = self.accumulated_axes | fst_axes

        # If the set of accumulated axes and the set of axes of 'second' do not intersect then nodes can be fused,
        # because interpolations with respect to various axes do not affect each other.
        if not (self.accumulated_axes & snd_axes):
            return True

        # Otherwise, nodes cannot be fused.
        self.accumulated_axes = set()
        return False
    def extract(node):
        """
        Fill attributes of 'node' from the 'interp_param' of its Caffe layer and mark the node as Interpolate.
        :param node: node whose 'pb' attribute holds the Caffe layer
        :return: the 'enabled' flag of the enclosing class
        """
        param = node.pb.interp_param

        mapping_rule = merge_attrs(param, {
            'height': param.height,
            'width': param.width,
            'zoom_factor': param.zoom_factor,
            'shrink_factor': param.shrink_factor,
        })
        # Attributes that are the same for every layer of this kind, plus pads taken from the parameter.
        mapping_rule['fw'] = 'caffe'
        mapping_rule['mode'] = 'linear'
        mapping_rule['axes'] = int64_array([2, 3])
        mapping_rule['pads_begin'] = param.pad_beg
        mapping_rule['pads_end'] = param.pad_end
        mapping_rule['align_corners'] = 1
        Interpolate.update_node_stat(node, mapping_rule)
        return __class__.enabled
Exemple #13
0
    def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
        """
        Feed input 1 of 'interpolate' with the shape of another (non-Interpolate) input of 'concat'.

        The shape of the first non-Interpolate source of 'concat' is taken with Shape, the dimensions
        corresponding to the interpolation axes are gathered, and the result replaces the source of
        input port 1 of 'interpolate'. Nothing is done if the axes are unknown or negative, if the
        concatenation axis is one of the interpolation axes, or if 'concat' has no suitable source.
        :param interpolate: node of type Interpolate
        :param concat: node of type Concat
        """
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'
        interp_axes = Interpolate.get_axes(interpolate)
        concat_axis = self.get_concat_axis(concat)

        # Interpolate axes and concat axis must be known, non-negative and must not intersect.
        if concat_axis is None or interp_axes is None:
            return
        if np.any(interp_axes < 0) or concat_axis < 0:
            return
        if concat_axis in interp_axes:
            return

        non_interp_concat_srcs = self.get_non_interpolate_concat_sources(concat)
        if len(non_interp_concat_srcs) == 0:
            # there is no Concat input to take input from
            return

        graph = interpolate.graph
        src = non_interp_concat_srcs[0]

        shape_name = src.node.soft_get('name', src.node.id) + '/Shape'
        shape = Shape(graph, {'name': shape_name}).create_node()
        shape.in_port(0).connect(src)

        gather_inputs = {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)}
        gather = create_op_with_const_inputs(graph, Gather, gather_inputs,
                                             {'name': shape.name + '/Gathered'},
                                             input_node=shape)
        interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
    def make_interpolate_reshapeable(interpolate, concat):
        """
        Feed input 1 of 'interpolate' with the shape of another (non-Interpolate) input of 'concat'.

        Nothing is done if the concatenation axis is one of the interpolation axes or if 'concat' has no
        connected input produced by a node of a type other than Interpolate.
        :param interpolate: node of type Interpolate
        :param concat: node of type Concat
        """
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'

        output_shape = interpolate.out_port(0).data.get_shape()

        # Bring all axes to the canonical form with respect to the output shape of Interpolate.
        interp_axes = []
        for axis in Interpolate.get_axes(interpolate):
            interp_axes.append(get_canonical_axis_index(output_shape, axis))
        concat_axis = get_canonical_axis_index(output_shape, concat.axis)
        if concat_axis in interp_axes:
            return

        # Sources of connected Concat inputs that are not produced by Interpolate nodes.
        non_interp_concat_srcs = []
        for port in concat.in_ports().values():
            if port.disconnected():
                continue
            candidate = port.get_source()
            if candidate.node.soft_get('type') != 'Interpolate':
                non_interp_concat_srcs.append(candidate)
        if not non_interp_concat_srcs:
            return

        graph = interpolate.graph
        src = non_interp_concat_srcs[0]

        shape_name = src.node.soft_get('name', src.node.id) + '/Shape'
        shape = Shape(graph, {'name': shape_name}).create_node()
        shape.in_port(0).connect(src)

        gather_inputs = {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)}
        gather = create_op_with_const_inputs(graph, Gather, gather_inputs,
                                             {'name': shape.name + '/Gathered'}, shape)
        interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
Exemple #15
0
def replace_interpolate_pattern(graph: Graph, match: dict):
    """
    Replace the matched 'split' ... 'concat' sub-graph with Interpolate in 'nearest' mode along the split axis.

    The target size for Interpolate is calculated in the graph: the dimension 'axis' of the input shape
    (Shape -> StridedSlice) is multiplied by the scale obtained from 'get_split_scale'.
    :param graph: graph to operate on
    :param match: dictionary with matched nodes; keys 'split' and 'concat' are used
    """
    split = match['split']
    scale = int64_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    prefix = split.name

    shape_node = Shape(graph, {'name': prefix + '/Shape_'}).create_node()
    scales_node = Const(graph, {'name': prefix + '/scales_', 'value': scale}).create_node()
    mul_node = Mul(graph, {'name': prefix + '/Mul_'}).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # Constants selecting the single element [axis:axis + 1] of the shape.
    slice_begin = Const(graph, {'name': prefix + '/slice_begin_',
                                'value': int64_array([axis])}).create_node()
    slice_end = Const(graph, {'name': prefix + '/slice_end_',
                              'value': int64_array([axis + 1])}).create_node()

    strided_slice_attrs = {
        'name': prefix + '/StridedSlice_',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    strided_slice_node = StridedSlice(graph, strided_slice_attrs).create_node(
        [shape_node, slice_begin, slice_end])
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_attrs = {'name': prefix + '/Interpolate_', 'axes': int64_array([axis]), 'mode': 'nearest'}
    interp_node = Interpolate(graph, interp_attrs).create_node()
    mul_node.out_port(0).connect(interp_node.in_port(1))

    # Consumers of 'concat' now read from Interpolate.
    concat_output = match['concat'].out_port(0).get_connection()
    concat_output.set_source(interp_node.out_port(0))

    # The former input of 'split' feeds both Interpolate and the Shape node.
    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
def replace_interpolate_pattern(graph: Graph, match: dict):
    """
    Replace the matched 'split' ... 'concat' sub-graph with Interpolate of version 'opset4' in 'nearest' mode.

    Inputs of the created Interpolate node:
      0 - the former input of 'split';
      1 - sizes, calculated in the graph as int64(floor(float32(shape[axis:axis + 1]) * scale));
      2 - constant with scales;
      3 - constant with the axis.
    :param graph: graph to operate on
    :param match: dictionary with matched nodes; keys 'split' and 'concat' are used
    """
    split = match['split']
    scale = np.array([get_split_scale(split)], dtype=np.float32)
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    prefix = split.name
    axis_node = Const(graph, {'name': prefix + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, {'name': prefix + '/Shape'}).create_node()
    scales_node = Const(graph, {'name': prefix + '/scales', 'value': scale}).create_node()
    mul_node = Mul(graph, {'name': prefix + '/Mul'}).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    # StridedSlice selecting the single element [axis:axis + 1] of the shape.
    slice_bounds = {1: int64_array([axis]), 2: int64_array([axis + 1])}
    strided_slice_attrs = {
        'name': prefix + '/StridedSlice',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0])
    }
    strided_slice_node = create_op_with_const_inputs(graph, StridedSlice, slice_bounds, strided_slice_attrs)
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    # The dimension is cast to float32 before the multiplication by the float32 scale.
    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_attrs = {
        'name': prefix + '/Interpolate',
        'mode': 'nearest',
        'antialias': 0,
        'pads_begin': int64_array([0]),
        'pads_end': int64_array([0]),
        'coordinate_transformation_mode': 'half_pixel',
        'nearest_mode': 'round_prefer_floor',
        'cube_coeff': -0.75,
        'version': 'opset4',
        'shape_calculation_mode': 'scales',
        'in_ports_count': 4,
        'maybe_part_of_sequence': True,
    }
    interp_node = Interpolate(graph, interp_attrs).create_node()

    floor_node = Floor(graph, {'name': prefix + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    # sizes = int64(floor(dimension * scale))
    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    # Consumers of 'concat' now read from Interpolate.
    concat_output = match['concat'].out_port(0).get_connection()
    concat_output.set_source(interp_node.out_port(0))

    # The former input of 'split' feeds both Interpolate and the Shape node.
    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
 def _compare_attributes_of_interpolate1(self, first: Node, second: Node) -> bool:
     """
     Check that Interpolate-1 nodes 'first' and 'second' have the same values of the attributes
     'mode', 'align_corners', 'antialias', 'pads_begin', 'pads_end' (attribute 'axes' is not compared).
     A missing attribute is replaced with the value from attributes of a freshly created Interpolate op.
     :param first: the first of compared nodes
     :param second: the second of compared nodes
     :return: True, if attributes of nodes are identical and False otherwise
     """
     defaults = Interpolate(graph=first.graph, attrs={}).attrs
     compared_attrs = ('mode', 'align_corners', 'antialias', 'pads_begin', 'pads_end')
     # The first differing attribute makes the nodes not identical.
     return not any(first.soft_get(name, default=defaults[name]) != second.soft_get(name, default=defaults[name])
                    for name in compared_attrs)
 def make_interpolate_reshapeable(interpolate):
     """
     Replace the source of input 1 of 'interpolate' with a sub-graph that calculates it from the input shape.

     The sub-graph is ShapeOf(input) -> Gather(interpolation axes) -> Mul(output_shape[axes] / input_shape[axes]).
     Nothing is done unless the output shape is element-wise divisible by the input shape or vice versa.
     :param interpolate: node of type Interpolate
     """
     assert interpolate.soft_get('type') == 'Interpolate'
     axes = Interpolate.get_axes(interpolate)
     input_shape = interpolate.in_port(0).data.get_shape()
     output_shape = interpolate.out_port(0).data.get_shape()
     if not (np.all(np.remainder(output_shape, input_shape) == 0)
             or np.all(np.remainder(input_shape, output_shape) == 0)):
         return

     graph = interpolate.graph
     name = interpolate.soft_get('name', interpolate.id)

     shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
     shape.in_port(0).connect(interpolate.in_port(0).get_source())

     gather_inputs = {1: np.array(axes, dtype=np.int32), 2: int64_array(0)}
     gather = create_op_with_const_inputs(graph, Gather, gather_inputs,
                                          {'name': shape.name + '/Gathered'}, shape)

     # Ratio between output and input sizes along the interpolation axes.
     multipliers = output_shape[axes] / input_shape[axes]
     mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)
     interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Remove the Transpose nodes around the matched Interpolate node and switch its axes from [1, 2] to [2, 3].

        The transformation is applied only when the axes of Interpolate are exactly [1, 2].
        :param graph: graph to operate on
        :param match: dictionary with matched nodes 'interpolate', 'transpose_1' and 'transpose_2'
        """
        interpolate = match['interpolate']
        transpose_1 = match['transpose_1']
        transpose_2 = match['transpose_2']

        axes = Interpolate.get_axes(interpolate)
        if axes is None:
            return
        if not np.array_equal(axes, int64_array([1, 2])):
            return

        # because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout
        opset = interpolate.get_opset()
        assert opset in ['opset1', 'opset4'], \
            'Interpolate node with name {} has unsupported opset'.format(interpolate.soft_get('name', interpolate.id))
        new_axes = int64_array([2, 3])
        if opset == 'opset1':
            # axes are stored as an attribute of the node
            interpolate.axes = new_axes
        else:
            # axes are stored as the value on input port 3
            interpolate.in_port(3).data.set_value(new_axes)

        # Connect Interpolate directly to the producer of 'transpose_1' and to the consumers of 'transpose_2'.
        incoming = transpose_1.in_port(0).get_connection()
        incoming.set_destination(interpolate.in_port(0))
        outgoing = transpose_2.out_port(0).get_connection()
        outgoing.set_source(interpolate.out_port(0))

        graph.remove_nodes_from([transpose_1.id, transpose_2.id])
    def replace_op(self, graph: Graph, node: Node):
        """
        Replace 'node' (which wraps an upsampling module in 'node.module') with an interpolation node.

        Three cases are handled:
          * mode ending with 'linear' (like bilinear or trilinear): Interpolate of version 'opset4' with
            constant inputs for sizes, scales and axes; height and width are -1 when 'size' is None;
          * any other mode with a non-empty 'size': Interpolate of version 'opset1';
          * otherwise: UpsampleOp with height/width scales taken from 'scale_factor'.
        :param graph: graph to operate on
        :param node: node to replace
        :return: list with the id of the created node
        :raises Error: if neither 'size' nor 'scale_factor' is set on the module for a non-linear mode
        """
        module = node.module
        mode = module.mode
        if mode.endswith('linear'):  # like bilinear or trilinear
            mode = 'linear'
        align_corners = module.align_corners

        if mode == 'linear':
            has_size = module.size is not None
            height = module.size[0] if has_size else -1
            width = module.size[1] if has_size else -1
            dims = module.dims
            # Axes 2 .. dims - 1 are interpolated; one scale value per interpolated axis.
            axes = np.arange(2, dims)
            pads = np.zeros(dims, dtype=np.int32)
            scales = np.repeat(module.scale_factor, dims - 2).astype(np.float32)
            attrs = {
                'name': node.name,
                'version': 'opset4',
                'height': height,
                'width': width,
                'mode': mode,
                'axes': axes,
                'pads_begin': pads,
                'pads_end': pads,
                'coordinate_transformation_mode': 'align_corners' if align_corners else 'half_pixel',
                'shape_calculation_mode': 'sizes' if has_size else 'scales',
            }

            sizes_node = Const(graph, {'value': np.array([height, width])}).create_node()
            axes_node = Const(graph, {'value': axes}).create_node()
            scales_node = Const(graph, {'value': scales}).create_node()
            interp = Interpolate(graph, attrs).create_node(
                [node.in_node(0), sizes_node, scales_node, axes_node])
        elif module.size:
            attrs = {
                'name': node.name,
                'version': 'opset1',
                'height': module.size[0],
                'width': module.size[1],
                'mode': mode,
                'axes': [2, 3],
                'align_corners': align_corners,
            }
            interp = Interpolate(graph, attrs).create_node([node.in_node(0)])
        else:
            if not module.scale_factor:
                raise Error('No scale_factor found')
            # The builtin float is used because the alias np.float was removed from NumPy.
            scale_factor = float(module.scale_factor)
            attrs = {
                'name': node.name,
                'height_scale': scale_factor,
                'width_scale': scale_factor,
                'mode': mode,
                'align_corners': align_corners,
            }
            interp = UpsampleOp(graph, attrs).create_node([node.in_node(0)])

        return [interp.id]
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched 'unsqueeze' -> 'tile' -> 'reshape' sub-graph with Interpolate of version 'opset4'.

        The transformation is skipped when the second input of 'unsqueeze' or of 'tile' has no valid value,
        or when the rank of the input of 'unsqueeze' is neither 4 nor 5.
        Inputs of the created Interpolate node:
          0 - the former input of 'unsqueeze';
          1 - sizes, calculated in the graph as int64(floor(float32(shape[axis:axis + 1]) * scale));
          2 - constant with scales;
          3 - constant with the axis.
        The created node takes the name of 'reshape'.
        :param graph: graph to operate on
        :param match: dictionary with matched nodes; keys 'unsqueeze', 'tile' and 'reshape' are used
        """
        unsqueeze_node = match['unsqueeze']
        prefix = unsqueeze_node.name

        unsqueeze_axis_source = unsqueeze_node.in_port(1).get_connection().get_source().node
        if not unsqueeze_axis_source.has_valid('value'):
            return

        d_idx = int(unsqueeze_axis_source.value)

        tile_repeats_source = match['tile'].in_port(1).get_connection().get_source().node
        if not tile_repeats_source.has_valid('value'):
            return

        input_rank = len(unsqueeze_node.in_port(0).data.get_shape())
        if input_rank not in {4, 5}:
            return

        scale = float32_array([tile_repeats_source.value[d_idx]])
        axis = d_idx - 1
        axis_node = Const(graph, {'name': prefix + '/axis', 'value': int64_array([axis])}).create_node()

        shape_node = Shape(graph, {'name': prefix + '/Shape'}).create_node()
        scales_node = Const(graph, {'name': prefix + '/scales', 'value': scale}).create_node()
        mul_node = Mul(graph, {'name': prefix + '/Mul'}).create_node()
        scales_node.out_port(0).connect(mul_node.in_port(1))

        # Constants selecting the single element [axis:axis + 1] of the shape.
        slice_begin = Const(graph, {'name': prefix + '/slice_begin',
                                    'value': int64_array([axis])}).create_node()
        slice_end = Const(graph, {'name': prefix + '/slice_end',
                                  'value': int64_array([axis + 1])}).create_node()

        strided_slice_attrs = {
            'name': prefix + '/StridedSlice',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }
        strided_slice_node = StridedSlice(graph, strided_slice_attrs).create_node()
        shape_node.out_port(0).connect(strided_slice_node.in_port(0))
        slice_begin.out_port(0).connect(strided_slice_node.in_port(1))
        slice_end.out_port(0).connect(strided_slice_node.in_port(2))

        # The dimension is cast to float32 before the multiplication by the float32 scale.
        cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()

        strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

        interp_attrs = {
            'mode': 'nearest',
            'antialias': 0,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'coordinate_transformation_mode': 'half_pixel',
            'nearest_mode': 'round_prefer_floor',
            'cube_coeff': -0.75,
            'version': 'opset4',
            'shape_calculation_mode': 'scales',
            'in_ports_count': 4,
            'maybe_part_of_sequence': True,
        }
        interp_node = Interpolate(graph, interp_attrs).create_node()

        floor_node = Floor(graph, {'name': prefix + '/Floor'}).create_node()
        cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

        # sizes = int64(floor(dimension * scale))
        mul_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

        cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
        scales_node.out_port(0).connect(interp_node.in_port(2))
        axis_node.out_port(0).connect(interp_node.in_port(3))

        # Consumers of 'reshape' now read from Interpolate, which also takes over the name of 'reshape'.
        reshape_node = match['reshape']
        reshape_node.out_port(0).get_connection().set_source(interp_node.out_port(0))
        reshape_name = reshape_node.soft_get('name', reshape_node.id)
        rename_nodes([(reshape_node, reshape_name + '/delete'),
                      (interp_node, reshape_name)])

        # The former input of 'unsqueeze' feeds both Interpolate and the Shape node.
        unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
        before_unsqueeze = unsqueeze_connection.get_source().node
        unsqueeze_connection.set_destination(interp_node.in_port(0))
        before_unsqueeze.out_port(0).connect(shape_node.in_port(0))
Exemple #22
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replaces the matched 'unsqueeze' -> 'tile' -> 'reshape' chain with one nearest-mode Interpolate node.

        The size fed to port 1 of the new Interpolate is computed inside the graph as
        Shape(data)[axis:axis + 1] * scale, where 'scale' is the tile repeat value taken at the unsqueezed
        dimension and 'axis' is that dimension minus one. The graph is left untouched when the second input
        of 'unsqueeze' or of 'tile' has no valid 'value', or when the rank of the 'unsqueeze' data input
        is neither 4 nor 5.
        :param graph: graph to operate on
        :param match: dictionary with matched nodes; the keys 'unsqueeze', 'tile' and 'reshape' are used
        :return: Nothing
        """
        unsqueeze = match['unsqueeze']
        base_name = unsqueeze.name

        axis_source = unsqueeze.in_port(1).get_connection().get_source().node
        if not axis_source.has_valid('value'):
            return
        tiled_dim = int(axis_source.value)

        repeats_source = match['tile'].in_port(1).get_connection().get_source().node
        if not repeats_source.has_valid('value'):
            return

        data_rank = len(unsqueeze.in_port(0).data.get_shape())
        if data_rank not in (4, 5):
            return

        scale_value = int64_array([repeats_source.value[tiled_dim]])
        # The tiled dimension was inserted by 'unsqueeze', so the interpolated axis of the original data
        # is the one just before it.
        interp_axis = tiled_dim - 1

        shape_of = Shape(graph, {'name': base_name + '/Shape_'}).create_node()
        scale_const = Const(graph, {'name': base_name + '/scales_', 'value': scale_value}).create_node()
        size_mul = Mul(graph, {'name': base_name + '/Mul_'}).create_node()
        scale_const.out_port(0).connect(size_mul.in_port(1))

        begin_const = Const(graph, {'name': base_name + '/slice_begin_',
                                    'value': int64_array([interp_axis])}).create_node()
        end_const = Const(graph, {'name': base_name + '/slice_end_',
                                  'value': int64_array([interp_axis + 1])}).create_node()

        # Shape(data)[interp_axis:interp_axis + 1] -> current size of the interpolated axis.
        axis_slice = StridedSlice(graph, dict(name=base_name + '/StridedSlice_',
                                              begin_mask=int64_array([1]),
                                              end_mask=int64_array([1]),
                                              new_axis_mask=int64_array([0]),
                                              shrink_axis_mask=int64_array([0]),
                                              ellipsis_mask=int64_array([0]))).create_node()
        shape_of.out_port(0).connect(axis_slice.in_port(0))
        begin_const.out_port(0).connect(axis_slice.in_port(1))
        end_const.out_port(0).connect(axis_slice.in_port(2))
        axis_slice.out_port(0).connect(size_mul.in_port(0))

        interpolate = Interpolate(graph, {'name': base_name + '/Interpolate_',
                                          'axes': int64_array([interp_axis]),
                                          'mode': 'nearest'}).create_node()
        size_mul.out_port(0).connect(interpolate.in_port(1))

        # Consumers of 'reshape' now read from the new Interpolate node.
        match['reshape'].out_port(0).get_connection().set_source(interpolate.out_port(0))

        # The data that used to feed 'unsqueeze' goes to Interpolate and to the Shape sub-graph.
        data_connection = unsqueeze.in_port(0).get_connection()
        data_producer = data_connection.get_source().node
        data_connection.set_destination(interpolate.in_port(0))
        data_producer.out_port(0).connect(shape_of.in_port(0))
Exemple #23
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Replaces the matched 'upsample' node with an Interpolate node whose target sizes are computed
        in the graph as a slice of Shape(input) multiplied by the [height, width] scale factors.

        The scale factors come from the constant second input of 'upsample' when it has two inputs
        (the first two scales must then be close to 1), otherwise from its 'height_scale' and
        'width_scale' attributes. The node is left untouched when the input rank is not 4 or 5 or
        when the second input has no value.
        :param graph: graph to operate on
        :param match: dictionary with matched nodes; the key 'upsample' is used
        :return: Nothing
        """
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        rank = len(upsample.in_port(0).data.get_shape())
        if rank not in [4, 5]:
            log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
            return

        if len(upsample.in_nodes()) == 2:
            scales = upsample.in_node(1).value
            if scales is None:
                return
            # NOTE(review): exactly four scales are required and only two factors are produced below,
            # while a rank-5 input slices from the depth dimension - confirm 5D inputs are handled as intended.
            assert scales.shape == (4, )
            if not math.isclose(scales[0], 1, rel_tol=1e-5) or not math.isclose(scales[1], 1, rel_tol=1e-5):
                return
            height_scale, width_scale = scales[2], scales[3]
        else:
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']

        if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
            upsample.in_port(1).disconnect()

        factor_const = Const(graph, dict(value=np.array([height_scale, width_scale]))).create_node()

        shape_of = Shape(graph, dict(name=upsample.name + '/0_port')).create_node()

        layout = graph.graph['layout']
        if rank == 4:
            first_spatial_dim = get_height_dim(layout, rank)
        else:
            first_spatial_dim = get_depth_dim(layout, rank)
        begin_const = Const(graph, dict(value=int64_array([first_spatial_dim]))).create_node()
        end_const = Const(graph, dict(value=int64_array([get_width_dim(layout, rank) + 1]))).create_node()
        stride_const = Const(graph, dict(value=int64_array([1]))).create_node()

        # NOTE(review): 'end_mask' is 0 here although an explicit end constant is connected - confirm
        # whether the end value is meant to be honoured.
        spatial_slice = StridedSlice(graph, dict(name=upsample.name + '/ss_0_port',
                                                 begin_mask=np.array([1]),
                                                 end_mask=np.array([0]),
                                                 new_axis_mask=np.array([0]),
                                                 shrink_axis_mask=int64_array([0]),
                                                 ellipsis_mask=int64_array([0]))).create_node()

        size_mul = Mul(graph, dict(name=upsample.name + '/factor_mul_')).create_node()

        # Shape(input) -> StridedSlice -> Mul(factors) yields the target spatial sizes.
        data_source = upsample.in_port(0).get_connection().get_source()
        data_source.connect(shape_of.in_port(0))
        shape_of.out_port(0).connect(spatial_slice.in_port(0))
        begin_const.out_port(0).connect(spatial_slice.in_port(1))
        end_const.out_port(0).connect(spatial_slice.in_port(2))
        stride_const.out_port(0).connect(spatial_slice.in_port(3))
        spatial_slice.out_port(0).connect(size_mul.in_port(0))
        factor_const.out_port(0).connect(size_mul.in_port(1))

        # Axes to be interpolated: (H, W) for rank 4, (D, H, W) for rank 5.
        if rank == 4:
            spatial_axes = [get_height_dim(layout, rank), get_width_dim(layout, rank)]
        else:
            spatial_axes = [get_depth_dim(layout, rank),
                            get_height_dim(layout, rank),
                            get_width_dim(layout, rank)]
        axes = int64_array(spatial_axes)

        interpolate = Interpolate(graph, {'name': 'Interpolate/{}'.format(upsample.name),
                                          'axes': axes,
                                          'mode': upsample.attrs()['mode'],
                                          'antialias': 0,
                                          'convert_to_resample': True}).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        size_mul.out_port(0).connect(interpolate.in_port(1))

        # Move the data input and all consumers from 'upsample' to the new node.
        upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
        upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))
def replace_sequence(seq: List[Node], graph: Graph):
    """
    Fuses a chain of consecutive Interpolate layers into a single Interpolate layer.

    Nothing is done for an empty or one-element sequence, or when the nodes of the sequence do not all
    have the same mode. The fused node takes its name prefix, mode, 'align_corners', 'antialias',
    'pads_begin' and 'pads_end' from the first node of the sequence.
    :param seq: sequence of Interpolate layers
    :param graph: graph to which nodes of seq belong
    :return: Nothing
    """
    if not seq or len(seq) == 1:
        return

    if len({node.mode for node in seq}) != 1:
        return

    # Collect (axis, output size for this axis) pairs from every node of the sequence.
    axis_size_pairs = []
    for node in seq:
        node_axes = node.axes
        node_sizes = node.in_port(1).get_connection().get_source().node.value
        axis_size_pairs.extend(zip(node_axes, node_sizes))

    # Going through a dict keeps one size per axis (the last one seen); then order by axis.
    ordered_pairs = sorted(dict(axis_size_pairs).items(), key=lambda pair: pair[0])
    fused_axes = int64_array([axis for axis, _ in ordered_pairs])
    fused_sizes = int64_array([size for _, size in ordered_pairs])

    head = seq[0]
    tail = seq[-1]
    head_name = head.name
    fused_attrs = {
        'name': head_name + '/Interpolate_',
        'axes': fused_axes,
        'mode': head.mode,
        'align_corners': head.soft_get('align_corners', default=0),
        'antialias': head.soft_get('antialias', default=0),
        'pads_begin': head.soft_get('pads_begin', default=0),
        'pads_end': head.soft_get('pads_end', default=0),
    }
    fused = Interpolate(graph, fused_attrs).create_node()

    sizes_const = Const(graph, {'name': head_name + '/scales_', 'value': fused_sizes}).create_node()
    sizes_const.out_port(0).connect(fused.in_port(1))

    # The fused node takes the input of the first node and the consumers of the last one.
    head.in_port(0).get_connection().set_destination(fused.in_port(0))
    tail.out_port(0).get_connection().set_source(fused.out_port(0))
def replace_resize(graph: Graph, resize: Node):
    """
    Replaces an ONNX Resize-10 node with an opset4 Interpolate node (as stated by the debug message).

    The Interpolate inputs are built as follows:
      port 1 ('sizes')  - floor(Shape(input) * (scales + 1.0e-5)) cast to int64 and sliced from index 2;
      port 2 ('scales') - the original scales input of 'resize', sliced from index 2;
      port 3 ('axes')   - Range with const inputs 2 (port 0) and 1 (port 2), and Rank(input) on port 1.
    :param graph: graph to operate on
    :param resize: the Resize node to be replaced
    :return: Nothing
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
              "is triggered for node {}.".format(resize_name))

    def make_slice_from_second_dim(suffix):
        # StridedSlice with begin = [2]; 'end_mask' is 0, the end constant [0] is still connected.
        return create_op_with_const_inputs(
            graph, StridedSlice,
            {1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1])},
            dict(name=resize_name + suffix,
                 begin_mask=int64_array([1]),
                 end_mask=int64_array([0]),
                 new_axis_mask=int64_array([0]),
                 shrink_axis_mask=int64_array([0]),
                 ellipsis_mask=int64_array([0])))

    rank = Rank(graph, dict(name=resize_name + '/max_axes')).create_node()
    axes_range = create_op_with_const_inputs(graph, Range,
                                             {0: int64_array(2), 2: int64_array(1)},
                                             dict(name=resize_name + '/axes'))

    sizes_slice = make_slice_from_second_dim('/sizes_ss')
    scales_slice = make_slice_from_second_dim('/scales_ss')

    rank.out_port(0).connect(axes_range.in_port(1))

    interpolate = Interpolate(graph, dict(version='opset4',
                                          mode='linear_onnx' if resize.mode == 'linear' else 'nearest',
                                          coordinate_transformation_mode='asymmetric',
                                          cube_coeff=-0.75,
                                          nearest_mode='simple',
                                          pads_begin=int64_array([0]),
                                          pads_end=int64_array([0]),
                                          antialias=0,
                                          shape_calculation_mode='scales',
                                          in_ports_count=4)).create_node()

    axes_range.out_port(0).connect(interpolate.in_port(3))
    shape_of = Shape(graph, dict(name=resize_name + '/ShapeOf')).create_node()

    # Computing 'sizes' as floor(input_shape * scales) can be off by one: for
    # scales = [1.0, 1.0, 1.33333, 2.0] and input_shape = [1, 3, 30, 200] the product is
    # [1, 3, 39.9999, 400], so the floor of the third element is 39 instead of 40.
    # Adding a small eps (e.g. 1.0e-5) to the product does not help either: with
    # scales[2] = 1.333333 we get 30 * 1.333333 + 1.0e-5 = 39.99991, which still floors to 39.
    # Therefore eps is added to the scales: sizes = floor(input_shape * (scales + eps)).
    scales_plus_eps = create_op_with_const_inputs(graph, Add,
                                                  {1: float_array([1.0e-5])},
                                                  dict(name=resize_name + '/Add'))

    float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    shape_as_float = Cast(graph, dict(dst_type=float_type)).create_node()
    shape_of.out_port(0).connect(shape_as_float.in_port(0))

    scaled_shape = Mul(graph, dict(name=resize_name + '/Mul')).create_node([shape_as_float, scales_plus_eps])
    floored_shape = Floor(graph, dict(name=resize_name + '/Floor')).create_node([scaled_shape])
    sizes_as_int = Cast(graph, dict(dst_type=np.int64)).create_node([floored_shape])
    sizes_as_int.out_port(0).connect(sizes_slice.in_port(0))
    sizes_slice.out_port(0).connect(interpolate.in_port(1))

    scales_slice.out_port(0).connect(interpolate.in_port(2))

    # Re-route the data and scales inputs of 'resize' into the new sub-graph.
    data_connection = resize.in_port(0).get_connection()
    data_connection.set_destination(interpolate.in_port(0))

    scales_connection = resize.in_port(1).get_connection()
    scales_connection.set_destination(scales_slice.in_port(0))

    data_connection.get_source().connect(shape_of.in_port(0))
    data_connection.get_source().connect(rank.in_port(0))
    scales_connection.get_source().connect(scales_plus_eps.in_port(0))

    # The new node inherits the name of 'resize'; consumers are switched to it afterwards.
    rename_nodes([(resize, resize_name + '/delete'), (interpolate, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate.out_port(0))
Exemple #26
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Replaces the matched 'upsample' node with an Interpolate node whose target sizes are computed
        in the graph as int64(float32(Shape(input)[spatial dims]) * factors).

        Scale factors are read from the constant second input of 'upsample' when it has two inputs
        (4 or 5 scales, the first two must be close to 1), otherwise from the 'height_scale' and
        'width_scale' attributes. The node is left untouched when the input rank is not 4 or 5, when
        the second input has no value, or when the height, width and (if present) depth scales are
        not all close to each other.
        :param graph: graph to operate on
        :param match: dictionary with matched nodes; the key 'upsample' is used
        :return: Nothing
        """
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        upsample_name = upsample.soft_get('name', upsample.id)
        rank = len(upsample.in_port(0).data.get_shape())
        if rank not in [4, 5]:
            log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
            return

        depth_scale = None
        if len(upsample.in_nodes()) == 2:
            scales = upsample.in_node(1).value
            if scales is None:
                return
            assert len(scales) in (4, 5), \
                'Supported scales rank is 4 or 5, but it is {} for node {}'.format(len(scales), upsample_name)
            if not math.isclose(scales[0], 1, rel_tol=1e-5) or not math.isclose(scales[1], 1, rel_tol=1e-5):
                return
            height_scale, width_scale = scales[2], scales[3]
            if len(scales) == 5:
                depth_scale = scales[4]
        else:
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']

        # Only uniform spatial scaling is converted.
        if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
            log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
                width_scale, height_scale, upsample_name))
            return
        if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
            log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
                depth_scale, height_scale, upsample_name))
            return

        if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
            upsample.in_port(1).disconnect()

        shape_of = Shape(graph, dict(name=upsample_name + '/0_port')).create_node()

        layout = graph.graph['layout']

        if rank == 4:
            slice_begin = int64_array([get_height_dim(layout, rank)])
            factors = np.array([height_scale, width_scale])
        else:
            slice_begin = int64_array([get_depth_dim(layout, rank)])
            # NOTE(review): 'depth_scale' is still None here when the scales come from the node
            # attributes - confirm that rank-5 inputs always arrive with a 5-element scales input.
            factors = np.array([depth_scale, height_scale, width_scale])

        slice_inputs = {
            1: slice_begin,
            2: int64_array([get_width_dim(layout, rank) + 1]),
            3: int64_array([1]),
        }
        slice_attrs = dict(name=upsample_name + '/ss_0_port',
                           begin_mask=int64_array([1]),
                           end_mask=int64_array([1]),
                           new_axis_mask=int64_array([0]),
                           shrink_axis_mask=int64_array([0]),
                           ellipsis_mask=int64_array([0]))
        spatial_slice = create_op_with_const_inputs(graph, StridedSlice, slice_inputs, slice_attrs)

        size_mul = create_op_node_with_second_input(graph, Mul, factors,
                                                    dict(name=upsample_name + '/factor_mul_'))

        data_source = upsample.in_port(0).get_connection().get_source()
        data_source.connect(shape_of.in_port(0))
        shape_of.out_port(0).connect(spatial_slice.in_port(0))

        spatial_slice.out_port(0).connect(size_mul.in_port(0))

        # Axes to be interpolated: (H, W) for rank 4, (D, H, W) for rank 5.
        if rank == 4:
            spatial_axes = [get_height_dim(layout, rank), get_width_dim(layout, rank)]
        else:
            spatial_axes = [get_depth_dim(layout, rank),
                            get_height_dim(layout, rank),
                            get_width_dim(layout, rank)]
        axes = int64_array(spatial_axes)

        interpolate = Interpolate(graph, {'name': upsample_name + '/Interpolate',
                                          'axes': axes,
                                          'mode': upsample.attrs()['mode'],
                                          'antialias': 0,
                                          'convert_to_resample': True}).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        size_mul.out_port(0).connect(interpolate.in_port(1))

        # Move the data input and all consumers from 'upsample' to the new node.
        upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
        upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))

        # The shape slice is cast to float32 before the multiplication and the product back to int64.
        to_float = Cast(graph, {'dst_type': np.float32}).create_node()
        to_int = Cast(graph, {'dst_type': np.int64}).create_node()

        size_mul.in_port(0).get_connection().insert_node(to_float)
        size_mul.out_port(0).get_connection().insert_node(to_int)
def replace_resize(graph: Graph, resize: Node):
    """
    Replaces an ONNX Resize-11 node with an opset4 Interpolate node (as stated by the debug message).

    The number of connected inputs of 'resize' selects the shape calculation mode:
      3 inputs - 'scales': sizes are computed as int64(floor(Shape(input) * scales + 1.0e-5)),
                 scales are taken from input port 2 of 'resize';
      4 inputs - 'sizes': sizes are taken from input port 3 of 'resize',
                 scales are computed as sizes / Shape(input) + 1.0e-5.
    Both values are sliced to the spatial dimensions before they reach Interpolate. The node is left
    untouched (with a warning) when the input rank is not 4 or 5.
    :param graph: graph to operate on
    :param resize: the Resize node to be replaced
    :return: Nothing
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(resize_name))

    input_rank = len(resize.in_port(0).data.get_shape())
    if input_rank not in {4, 5}:
        log.warning('The input shape is not 4D or 5D for op with name {}'.format(resize_name))
        return

    connected_inputs = [port for port in resize.in_ports().values() if not port.disconnected()]
    num_of_inputs = len(connected_inputs)
    assert num_of_inputs in {3, 4}, \
        "Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    layout = graph.graph['layout']

    # Spatial dimensions span [begin_dim, end_dim): H..W for rank 4, D..W for rank 5.
    begin_dim = get_height_dim(layout, input_rank) if input_rank == 4 else get_depth_dim(layout, input_rank)
    end_dim = get_width_dim(layout, input_rank) + 1

    def make_spatial_slice(suffix):
        # StridedSlice that keeps elements [begin_dim:end_dim] of its input.
        return create_op_with_const_inputs(
            graph, StridedSlice,
            {1: int64_array([begin_dim]), 2: int64_array([end_dim]), 3: int64_array([1])},
            dict(name=resize_name + suffix,
                 begin_mask=int64_array([1]),
                 end_mask=int64_array([1]),
                 new_axis_mask=int64_array([0]),
                 shrink_axis_mask=int64_array([0]),
                 ellipsis_mask=int64_array([0])))

    sizes_slice = make_spatial_slice('/StridedSlice_sizes')
    scales_slice = make_spatial_slice('/StridedSlice_scales')
    axes_const = Const(graph, dict(name=resize_name + '/axis',
                                   value=int64_array(np.arange(begin_dim, end_dim)))).create_node()

    shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'

    interpolate = Interpolate(graph, dict(version='opset4',
                                          mode=convert_mode(resize.mode),
                                          coordinate_transformation_mode=resize.coordinate_transformation_mode,
                                          cube_coeff=resize.cube_coeff,
                                          nearest_mode=resize.nearest_mode,
                                          pads_begin=int64_array([0]),
                                          pads_end=int64_array([0]),
                                          antialias=0,
                                          shape_calculation_mode=shape_calculation_mode,
                                          in_ports_count=4)).create_node()

    axes_const.out_port(0).connect(interpolate.in_port(3))
    shape_of = Shape(graph, dict(name=resize_name + '/ShapeOf')).create_node()

    # Small epsilon added before flooring (scales mode) or to the computed scales (sizes mode).
    eps_add = create_op_with_const_inputs(graph, Add,
                                          {1: float_array([1.0e-5])},
                                          dict(name=resize_name + '/Add'))

    float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    if num_of_inputs == 3:
        # Scales are given: sizes = int64(floor(float(Shape(input)) * scales + eps)).
        shape_as_float = Cast(graph, dict(dst_type=float_type)).create_node()
        scaled_shape = Mul(graph, dict(name=resize_name + '/Mul')).create_node()
        shape_of.out_port(0).connect(shape_as_float.in_port(0))
        shape_as_float.out_port(0).connect(scaled_shape.in_port(0))
        sizes_as_int = Cast(graph, dict(dst_type=np.int64)).create_node()
        floored = Floor(graph, dict(name=resize_name + '/Floor')).create_node()
        scaled_shape.out_port(0).connect(eps_add.in_port(0))
        eps_add.out_port(0).connect(floored.in_port(0))
        floored.out_port(0).connect(sizes_as_int.in_port(0))
        sizes_as_int.out_port(0).connect(sizes_slice.in_port(0))
        sizes_slice.out_port(0).connect(interpolate.in_port(1))
        scales_slice.out_port(0).connect(interpolate.in_port(2))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate.in_port(0))

        scales_connection = resize.in_port(2).get_connection()
        scales_connection.set_destination(scales_slice.in_port(0))

        data_connection.get_source().connect(shape_of.in_port(0))
        scales_connection.get_source().connect(scaled_shape.in_port(1))
    else:
        # Sizes are given: scales = float(sizes) / float(Shape(input)) + eps.
        shape_as_float = Cast(graph, dict(dst_type=float_type)).create_node()
        sizes_as_float = Cast(graph, dict(dst_type=float_type)).create_node()
        ratio = Div(graph, dict(name=resize_name + '/Div')).create_node()
        sizes_as_float.out_port(0).connect(ratio.in_port(0))
        shape_as_float.out_port(0).connect(ratio.in_port(1))
        shape_of.out_port(0).connect(shape_as_float.in_port(0))
        ratio.out_port(0).connect(eps_add.in_port(0))
        eps_add.out_port(0).connect(scales_slice.in_port(0))
        scales_slice.out_port(0).connect(interpolate.in_port(2))
        sizes_slice.out_port(0).connect(interpolate.in_port(1))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate.in_port(0))

        sizes_connection = resize.in_port(3).get_connection()
        sizes_connection.set_destination(sizes_slice.in_port(0))

        data_connection.get_source().connect(shape_of.in_port(0))
        sizes_connection.get_source().connect(sizes_as_float.in_port(0))

    # The new node inherits the name of 'resize'; consumers are switched to it afterwards.
    rename_nodes([(resize, resize_name + '/delete'), (interpolate, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate.out_port(0))
Exemple #28
0
def replace_sequence(seq: List[Node], graph: Graph):
    """
    Fuses a chain of consecutive Interpolate layers into a single Interpolate layer.

    Nothing is done for an empty or one-element sequence, or when the nodes of the sequence do not all
    have the same mode. The opset of the first node selects the form of the fused node:
      'opset1'  - axes are stored in the 'axes' attribute, sizes are a constant on port 1;
      otherwise - sizes, scales and axes are constants on ports 1, 2 and 3.
    The fused node takes the attributes of the first node and the name of the last node.
    :param seq: sequence of Interpolate layers
    :param graph: graph to which nodes of seq belong
    :return: Nothing
    """
    if not seq or len(seq) == 1:
        return

    if len({node.mode for node in seq}) != 1:
        return

    head = seq[0]
    tail = seq[-1]
    is_opset1 = head.get_opset() == 'opset1'

    if is_opset1:
        # Pairs (axis, output size for this axis); a dict keeps one size per axis (the last one seen).
        axis_size_pairs = []
        for node in seq:
            node_axes = Interpolate.get_axes(node)
            node_sizes = node.in_port(1).get_connection().get_source().data.get_value()
            axis_size_pairs.extend(zip(node_axes, node_sizes))

        ordered = sorted(dict(axis_size_pairs).items(), key=lambda item: item[0])
        fused_axes = int64_array([axis for axis, _ in ordered])
        fused_sizes = int64_array([size for _, size in ordered])
    else:
        # Triples (axis, output size for this axis, output scale for this axis).
        axis_size_scale = []
        for node in seq:
            node_axes = Interpolate.get_axes(node)
            node_sizes = node.in_port(1).get_connection().get_source().data.get_value()
            node_scales = node.in_port(2).get_connection().get_source().data.get_value()
            axis_size_scale.extend(zip(node_axes, node_sizes, node_scales))

        ordered = sorted(axis_size_scale, key=lambda item: item[0])
        fused_axes = int64_array([item[0] for item in ordered])
        fused_sizes = int64_array([item[1] for item in ordered])
        fused_scales = np.array([item[2] for item in ordered])

    tail_name = tail.soft_get('name', tail.id)
    attributes = get_interpolate_attributes(head)

    if is_opset1:
        attributes['axes'] = fused_axes
        const_inputs = {1: fused_sizes}
    else:
        attributes['in_ports_count'] = 4
        const_inputs = {1: fused_sizes, 2: fused_scales, 3: fused_axes}

    fused = create_op_with_const_inputs(graph, Interpolate, const_inputs, attributes)

    # The fused node takes the input of the first node and the consumers of the last one.
    head.in_port(0).get_connection().set_destination(fused.in_port(0))
    tail.out_port(0).get_connection().set_source(fused.out_port(0))

    rename_nodes([(tail, tail_name + '/delete_'), (fused, tail_name)])
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Replaces the matched 'upsample' node with an opset4 Interpolate node in 'scales' mode.

        Interpolate receives:
          port 1 - int64(float32(Shape(input)[spatial dims]) * factors);
          port 2 - a constant with the same scale factors;
          port 3 - a constant with the spatial axes.
        Scale factors are read from the constant second input of 'upsample' when it has two inputs
        (4 or 5 scales, the first two must be close to 1), otherwise from the 'height_scale' and
        'width_scale' attributes. The node is left untouched when the input rank is not 4 or 5 or
        when the second input has no value. The new node takes over the name of 'upsample'.
        :param graph: graph to operate on
        :param match: dictionary with matched nodes; the key 'upsample' is used
        :return: Nothing
        """
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        upsample_name = upsample.soft_get('name', upsample.id)
        rank = len(upsample.in_port(0).data.get_shape())
        if rank not in [4, 5]:
            log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
            return

        depth_scale = None
        layout = graph.graph['layout']

        if len(upsample.in_nodes()) == 2:
            scales = upsample.in_node(1).value
            if scales is None:
                return
            assert len(scales) in (4, 5), \
                'Supported scales rank is 4 or 5, but it is {} for node {}'.format(len(scales), upsample_name)
            if not math.isclose(scales[0], 1, rel_tol=1e-5) or not math.isclose(scales[1], 1, rel_tol=1e-5):
                return
            # Spatial scales are picked by the layout-dependent dimension indices.
            height_scale = scales[get_height_dim(layout, rank)]
            width_scale = scales[get_width_dim(layout, rank)]
            if len(scales) == 5:
                depth_scale = scales[get_depth_dim(layout, rank)]
        else:
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']

        if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
            upsample.in_port(1).disconnect()

        shape_of = Shape(graph, dict(name=upsample_name + '/0_port')).create_node()

        if rank == 4:
            slice_begin = int64_array([get_height_dim(layout, rank)])
            factors = float32_array([height_scale, width_scale])
        else:
            slice_begin = int64_array([get_depth_dim(layout, rank)])
            # NOTE(review): 'depth_scale' is still None here when the scales come from the node
            # attributes - confirm that rank-5 inputs always arrive with a 5-element scales input.
            factors = float32_array([depth_scale, height_scale, width_scale])

        slice_inputs = {
            1: slice_begin,
            2: int64_array([get_width_dim(layout, rank) + 1]),
            3: int64_array([1]),
        }
        slice_attrs = dict(name=upsample_name + '/ss_0_port',
                           begin_mask=int64_array([1]),
                           end_mask=int64_array([1]),
                           new_axis_mask=int64_array([0]),
                           shrink_axis_mask=int64_array([0]),
                           ellipsis_mask=int64_array([0]))
        spatial_slice = create_op_with_const_inputs(graph, StridedSlice, slice_inputs, slice_attrs)

        size_mul = create_op_node_with_second_input(graph, Mul, factors,
                                                    dict(name=upsample_name + '/factor_mul'))

        data_source = upsample.in_port(0).get_connection().get_source()
        data_source.connect(shape_of.in_port(0))
        shape_of.out_port(0).connect(spatial_slice.in_port(0))

        spatial_slice.out_port(0).connect(size_mul.in_port(0))

        # Axes to be interpolated: (H, W) for rank 4, (D, H, W) for rank 5.
        if rank == 4:
            spatial_axes = [get_height_dim(layout, rank), get_width_dim(layout, rank)]
        else:
            spatial_axes = [get_depth_dim(layout, rank),
                            get_height_dim(layout, rank),
                            get_width_dim(layout, rank)]
        axes = int64_array(spatial_axes)

        axes_const = Const(graph, dict(name=upsample_name + '/axis', value=axes)).create_node()

        interpolate = Interpolate(graph, dict(mode=upsample.attrs()['mode'],
                                              antialias=0,
                                              pads_begin=int64_array([0]),
                                              pads_end=int64_array([0]),
                                              coordinate_transformation_mode='half_pixel',
                                              nearest_mode='round_prefer_floor',
                                              cube_coeff=-0.75,
                                              shape_calculation_mode='scales',
                                              version='opset4',
                                              in_ports_count=4)).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        size_mul.out_port(0).connect(interpolate.in_port(1))
        axes_const.out_port(0).connect(interpolate.in_port(3))

        scales_const = Const(graph, dict(name=upsample_name + '/scales', value=factors)).create_node()
        scales_const.out_port(0).connect(interpolate.in_port(2))

        # Move the data input and all consumers from 'upsample' to the new node.
        upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
        upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))

        rename_nodes([(upsample, upsample_name + '/delete'), (interpolate, upsample_name)])

        # The shape slice is cast to float32 before the multiplication and the product back to int64.
        to_float = Cast(graph, {'dst_type': np.float32}).create_node()
        to_int = Cast(graph, {'dst_type': np.int64}).create_node()

        size_mul.in_port(0).get_connection().insert_node(to_float)
        size_mul.out_port(0).get_connection().insert_node(to_int)
    def replace_sub_graph(self, graph: Graph, match: dict):
        log.debug('Matched NearestNeighborUpsampling pattern: {}'.format(
            [node.id for node in match.values()]))
        try:
            input_height = match['pack_1'].in_node(1).value.item()
            input_width = match['pack_1'].in_node(3).value.item()

            height_scale = match['mul_const'].shape[-4]
            width_scale = match['mul_const'].shape[-2]
        except Exception as ex:
            log.warning(
                'Failed to determine scaling parameters from the topology. Do not apply pattern.'
            )
            return

        reshape2_name = match['reshape_2'].name
        resample_op = Interpolate(
            graph, {
                'mode': 'nearest',
                'antialias': 0,
                'pads_begin': int64_array([0]),
                'pads_end': int64_array([0]),
                'coordinate_transformation_mode': 'half_pixel',
                'nearest_mode': 'round_prefer_floor',
                'cube_coeff': -0.75,
                'version': 'opset4',
                'name': reshape2_name + '/Resample',
                'shape_calculation_mode': 'scales',
                'in_ports_count': 4
            })
        resample_node = resample_op.create_node([match['op']])
        axes_node = Const(
            graph, {
                'name':
                resample_node.name + '/axes',
                'value':
                int64_array([2, 3])
                if graph.graph['layout'] == 'NCHW' else int64_array([1, 2])
            }).create_node()
        sizes_node = Const(
            graph, {
                'value':
                np.array(
                    [input_height * height_scale, input_width * width_scale]),
                'name':
                resample_node.name + '/target_shape'
            }).create_node()
        scales_node = Const(
            graph, {
                'value': np.array([height_scale, width_scale],
                                  dtype=np.float32),
                'name': resample_node.name + '/scales'
            }).create_node()

        match['reshape_2'].replace_node(resample_node)

        resample_node.add_input_port(1, skip_if_exist=True)
        assert resample_node.in_port(1).disconnected()
        sizes_node.out_port(0).connect(resample_node.in_port(1))
        scales_node.out_port(0).connect(resample_node.in_port(2))
        axes_node.out_port(0).connect(resample_node.in_port(3))

        graph.remove_nodes_from(
            [node.id for node in match.values() if node.id != match['op'].id])