예제 #1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Rewire the matched prior-box node (``match['pb']``) to consume shape slices and unsqueeze its output.

        For each of the node's two input ports a ``Shape -> StridedSlice(begin=2, end=4, stride=1)``
        chain is inserted between the original producer and the port. The node is then flagged for
        re-inference, renamed to ``<name>/naked_not_unsqueezed``, and an ``Unsqueeze`` over axis 0,
        which takes over the original name, is placed in front of all former consumers.

        Side effect: disables ``graph.graph['cmd_params'].static_shape``.
        """
        prior_box = match['pb']
        pb_name = prior_box.soft_get('name', prior_box.id)

        graph.graph['cmd_params'].static_shape = False

        assert len(prior_box.in_ports()) == 2

        def int32_vec(scalar):
            # A brand-new one-element int32 array per call, so no two nodes share a buffer.
            return np.array([scalar], dtype=np.int32)

        ss_begin = Const(graph, {'value': int32_vec(2), 'name': pb_name + '/ss_begin'}).create_node()
        ss_end = Const(graph, {'value': int32_vec(4), 'name': pb_name + '/ss_end'}).create_node()
        ss_stride = Const(graph, {'value': int32_vec(1), 'name': pb_name + '/ss_stride'}).create_node()
        slice_params = (ss_begin, ss_end, ss_stride)

        slicers = []
        for port_idx in (0, 1):
            shape_of = Shape(graph, {'name': '{}/{}_port'.format(pb_name, port_idx)}).create_node()
            slicer = StridedSlice(graph, {'name': '{}/ss_{}_port'.format(pb_name, port_idx),
                                          'begin_mask': int32_vec(1),
                                          'end_mask': int32_vec(0),
                                          'new_axis_mask': int32_vec(0),
                                          'shrink_axis_mask': int32_vec(0),
                                          'ellipsis_mask': int32_vec(0)}).create_node()

            # Shape feeds slice port 0; begin/end/stride feed ports 1..3.
            shape_of.out_port(0).connect(slicer.in_port(0))
            for ss_port, param in enumerate(slice_params, start=1):
                param.out_port(0).connect(slicer.in_port(ss_port))

            # Route the original producer into Shape and feed the node from the slice instead.
            producer = prior_box.in_port(port_idx).get_connection().get_source()
            prior_box.in_port(port_idx).disconnect()
            producer.connect(shape_of.in_port(0))
            slicer.out_port(0).connect(prior_box.in_port(port_idx))

            slicers.append(slicer)

        for slicer in slicers:
            slicer['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}

        for flag in ('need_shape_inference', 'override_output_shape', 'V10_infer'):
            prior_box[flag] = True

        unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]), {'name': pb_name + '/unsqueeze'})
        rename_nodes([(prior_box, pb_name + '/naked_not_unsqueezed'), (unsqueeze, pb_name)])

        # First hand every consumer over to the Unsqueeze, then feed the Unsqueeze from the node.
        prior_box.out_port(0).get_connection().set_source(unsqueeze.out_port(0))
        prior_box.out_port(0).connect(unsqueeze.in_port(0))
예제 #2
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Insert ``Shape -> StridedSlice`` chains in front of both inputs of ``match['pb']``.

        Each input's producer is redirected into a ``Shape`` node (created with
        ``stop_value_propagation=True``). That shape is sliced with begin=2, end=4, stride=1,
        and the slice result then feeds the corresponding input port of the matched node.
        The begin/end/stride constants are shared by both slices.
        """
        pb = match['pb']
        assert len(pb.in_ports()) == 2

        # begin, end and stride constants, created in exactly this order
        shared_params = [Const(graph, {'value': np.array([v])}).create_node() for v in (2, 4, 1)]

        new_slices = []
        for idx in range(2):
            shape_of = Shape(graph, {
                'name': pb.name + '/{}_port'.format(idx),
                'stop_value_propagation': True
            }).create_node()
            sliced = StridedSlice(graph, {
                'name': pb.name + '/ss_{}_port'.format(idx),
                'begin_mask': np.array([1]),
                'end_mask': np.array([0]),
                'new_axis_mask': np.array([0]),
                'shrink_axis_mask': np.array([0]),
                'ellipsis_mask': np.array([0])
            }).create_node()

            shape_of.out_port(0).connect(sliced.in_port(0))
            for in_idx, param in enumerate(shared_params, start=1):
                param.out_port(0).connect(sliced.in_port(in_idx))

            # Move the original producer onto the Shape node; the slice now feeds the port.
            producer = pb.in_port(idx).get_connection().get_source()
            pb.in_port(idx).disconnect()
            producer.connect(shape_of.in_port(0))
            sliced.out_port(0).connect(pb.in_port(idx))
            new_slices.append(sliced)

        for sliced in new_slices:
            sliced['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
예제 #3
0
def replace_interpolate_pattern(graph: Graph, match: dict):
    """
    Replace a matched Split ... Concat pattern with a single 'nearest' Interpolate.

    The Interpolate's size input is computed at runtime: the Split's data-input shape is
    sliced with begin=axis and end=axis + 1, then multiplied by the scale obtained from
    ``get_split_scale(match['split'])``. The Interpolate takes over the consumers of
    ``match['concat']`` and reads from the producer of the Split's data input.
    """
    split_op = match['split']
    scale_value = int64_array([get_split_scale(split_op)])
    split_axis = int(split_op.in_port(1).get_connection().get_source().node.value)
    prefix = split_op.name

    input_shape = Shape(graph, {'name': prefix + '/Shape_'}).create_node()
    scale_const = Const(graph, {'name': prefix + '/scales_', 'value': scale_value}).create_node()
    scaled_size = Mul(graph, {'name': prefix + '/Mul_'}).create_node()
    scale_const.out_port(0).connect(scaled_size.in_port(1))

    axis_begin = Const(graph, {'name': prefix + '/slice_begin_',
                               'value': int64_array([split_axis])}).create_node()
    axis_end = Const(graph, {'name': prefix + '/slice_end_',
                             'value': int64_array([split_axis + 1])}).create_node()

    axis_size = StridedSlice(graph, {
        'name': prefix + '/StridedSlice_',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }).create_node([input_shape, axis_begin, axis_end])
    axis_size.out_port(0).connect(scaled_size.in_port(0))

    interpolate = Interpolate(graph, {'name': prefix + '/Interpolate_',
                                      'axes': int64_array([split_axis]),
                                      'mode': 'nearest'}).create_node()
    scaled_size.out_port(0).connect(interpolate.in_port(1))

    # Everything that consumed the Concat now consumes the Interpolate.
    match['concat'].out_port(0).get_connection().set_source(interpolate.out_port(0))

    # The Split's data producer now feeds the Interpolate, and also the Shape node.
    data_connection = split_op.in_port(0).get_connection()
    data_connection.set_destination(interpolate.in_port(0))
    data_connection.get_source().connect(input_shape.in_port(0))
    def append_variances(priors_scale_node: Node, variance: list):
        """
        Append broadcast variances to a priors tensor and reshape the result to ``[1, 2, -1]``.

        Builds the following on the graph owning ``priors_scale_node``:
          * ``Broadcast(variance, Concat(StridedSlice(ShapeOf(priors), begin=-2, end=-1), [4]))``;
          * ``Reshape(priors, [-1, 4])``;
          * the ``Concat`` of those two along axis 0, reshaped to ``[1, 2, -1]``.

        :param priors_scale_node: node whose output port 0 provides the priors
        :param variance: values used as the Broadcast data input
        :return: the final Reshape node
        """
        graph = priors_scale_node.graph
        base_name = priors_scale_node.name

        priors_shape = Shape(graph, {'name': base_name + '/shape'}).create_node()
        priors_scale_node.out_port(0).connect(priors_shape.in_port(0))

        # begin, end and stride constants, created in exactly this order
        slice_consts = [Const(graph, {'value': np.array([v])}).create_node() for v in (-2, -1, 1)]
        tiling_dim = StridedSlice(graph, {'name': base_name + '/get_-2_dim', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

        for port_idx, src in enumerate([priors_shape] + slice_consts):
            src.out_port(0).connect(tiling_dim.in_port(port_idx))

        four = Const(graph, {'value': np.array([4])}).create_node()
        tiling_shape = Concat(graph, {'name': base_name + '/shape_for_tiling', 'in_ports_count': 2,
                                      'axis': np.array(0)}).create_node()
        tiling_dim.out_port(0).connect(tiling_shape.in_port(0))
        four.out_port(0).connect(tiling_shape.in_port(1))

        variance_const = Const(graph, {'name': base_name + '/variance', 'value': np.array(variance)}).create_node()
        variances = Broadcast(graph, {'name': base_name + '/variance_tile'}).create_node()
        variance_const.out_port(0).connect(variances.in_port(0))
        tiling_shape.out_port(0).connect(variances.in_port(1))

        flat_shape = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        flat_priors = Reshape(graph, {'name': base_name + '/reshape'}).create_node()
        flat_priors.in_port(0).connect(priors_scale_node.out_port(0))
        flat_priors.in_port(1).connect(flat_shape.out_port(0))

        joined = Concat(graph,
                        {'name': base_name + '/priors_concat', 'axis': np.array(0), 'in_ports_count': 2}).create_node()
        flat_priors.out_port(0).connect(joined.in_port(0))
        variances.out_port(0).connect(joined.in_port(1))

        final_shape = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        result = Reshape(graph, {'name': base_name + '/3D_priors_wth_variances'}).create_node()
        joined.out_port(0).connect(result.in_port(0))
        final_shape.out_port(0).connect(result.in_port(1))

        return result
    def placeholder_scales(self, placeholder: Node):
        """
        Helper function to get scales for prior boxes out of input image size:
                [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]

        The placeholder must carry a 4-element ``np.ndarray`` ``shape`` attribute, which is
        checked with assertions. The subgraph proceeds as follows:
          * slice the runtime shape with begin=1, end=3;
          * cast it to float32 and raise it to the power -1;
          * reorder it with Gather(indices=[1, 0], axis=0);
          * concatenate that pair with itself along axis 0.

        Returns the final Concat node.
        """
        graph = placeholder.graph
        name = placeholder.soft_get('name', placeholder.id)

        # Messages are formatted only on failure, exactly as before.
        shape_value = placeholder.soft_get('shape', None)
        assert shape_value is not None, \
            "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
        assert isinstance(shape_value, np.ndarray), \
            "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
        assert shape_value.size == 4, \
            "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

        image_shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
        image_shape.in_port(0).connect(placeholder.out_port(0))

        # begin, end and stride constants, created in exactly this order
        slice_consts = [Const(graph, {'value': int64_array([v])}).create_node() for v in (1, 3, 1)]
        spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]),
                                       'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                       'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
        for port_idx, src in enumerate([image_shape] + slice_consts):
            spatial.in_port(port_idx).connect(src.out_port(0))

        minus_one = Const(graph, {'value': float32_array([-1.])}).create_node()
        inverted = Pow(graph, {}).create_node()
        inverted.in_port(0).connect(spatial.out_port(0))
        inverted.in_port(1).connect(minus_one.out_port(0))

        # Power `type_infer` requires inputs to have equal data type
        to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
        inverted.in_port(0).get_connection().insert_node(to_fp32)

        gather_indices = Const(graph, {'value': int64_array([1, 0])}).create_node()
        gather_axis = Const(graph, {'value': int64_array(0)}).create_node()
        reordered = Gather(graph, {}).create_node()
        reordered.in_port(0).connect(inverted.out_port(0))
        reordered.in_port(1).connect(gather_indices.out_port(0))
        gather_axis.out_port(0).connect(reordered.in_port(2))

        priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
        for port_idx in (0, 1):
            priors_scale_node.add_input_port(port_idx, skip_if_exist=True)
        # The same reordered pair feeds both Concat inputs, giving a 4-element result.
        for port_idx in (0, 1):
            priors_scale_node.in_port(port_idx).connect(reordered.out_port(0))
        return priors_scale_node
예제 #6
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Replace every ``Slice`` op in ``graph`` with an equivalent ``StridedSlice``.

        Optional inputs:
          * port 3 (axes) must be constant when connected; without it all dimensions are sliced;
          * port 4 (steps) must be constant when connected; without it every step is 1.

        Begin/end borders are built by ``create_ss_interval_border``. Begin/end masks are 1 only
        on the sliced axes. The new StridedSlice takes over the Slice's name, and the Slice is
        renamed to ``<name>/ShouldBeDeleted``.

        Raises AssertionError when the axes or steps input is connected but not constant.
        """
        for node in graph.get_op_nodes(op='Slice'):
            node_name = node.soft_get('name', node.id)

            input_shape = node.in_port(0).data.get_shape()
            if node.is_in_port_connected(3):
                axes_value = node.in_port(3).data.get_value()
                # Check before copying: a non-constant input yields None, and None has no .copy().
                assert axes_value is not None, 'The input with axes is not constant for node {}'.format(node_name)
                # Copy so normalizing negative axes does not mutate the producer's value.
                axes = axes_value.copy()
                for i, val in enumerate(axes):
                    axes[i] = get_canonical_axis_index(input_shape, val)
            else:
                axes = int64_array(range(len(input_shape)))

            ss_begin = create_ss_interval_border(graph, node.in_port(1).get_source(), input_shape, axes, node_name)
            ss_end = create_ss_interval_border(graph, node.in_port(2).get_source(), input_shape, axes, node_name)
            node.in_port(1).disconnect()
            node.in_port(2).disconnect()
            rename_nodes([(ss_begin, node_name + '/Begin'), (ss_end, node_name + '/End')])

            if node.is_in_port_connected(4):
                steps = node.in_port(4).data.get_value()
                assert steps is not None, 'The input with steps is not constant for node {}'.format(node_name)
            else:
                steps = np.ones([axes.size])

            rank = len(input_shape)
            ss_begin_mask = np.zeros(rank, dtype=np.int64)
            ss_end_mask = np.zeros(rank, dtype=np.int64)
            ss_step = np.ones(rank, dtype=np.int64)

            # Only the sliced axes get masks set and a non-default step.
            for i, axis in enumerate(axes):
                ss_begin_mask[axis] = 1
                ss_end_mask[axis] = 1
                ss_step[axis] = steps[i]

            ss_strides = Const(graph, dict(name=node_name + '/Strides', value=ss_step)).create_node()

            # Temporary name 'ss'; the node takes over `node_name` via rename_nodes below.
            ss = StridedSlice(graph, dict(name='ss', new_axis_mask=np.zeros(rank, dtype=np.int64),
                                          shrink_axis_mask=np.zeros(rank, dtype=np.int64),
                                          ellipsis_mask=np.zeros(rank, dtype=np.int64),
                                          begin_mask=ss_begin_mask,
                                          end_mask=ss_end_mask)).create_node()

            node.in_port(0).get_connection().set_destination(ss.in_port(0))
            ss.in_port(1).connect(ss_begin.out_port(0))
            ss.in_port(2).connect(ss_end.out_port(0))
            ss.in_port(3).connect(ss_strides.out_port(0))
            node.out_port(0).get_connection().set_source(ss.out_port(0))

            rename_nodes([(node, node_name + '/ShouldBeDeleted'), (ss, node_name)])
예제 #7
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the matched op (``match['op']``) with a ``StridedSlice``.

        The slice is configured from the op's attributes:
          * masks: ``begin_mask`` has ones over ``len(crop_begin)``, ``end_mask`` has ones over
            ``len(crop_end)``, and the remaining masks are zeros;
          * begin/end inputs: Consts holding ``crop_begin`` and ``crop_end``;
          * strides input: a Const holding ``step``, connected only when ``step`` is non-empty.

        The Const names are stored under ``symbol_dict['name']``, not as a ``name`` attribute.
        """
        crop = match['op']
        begin_rank = len(crop.crop_begin)

        ss_attrs = {
            'name': crop.id + '/strided_slice_',
            'shrink_axis_mask': np.zeros(begin_rank, dtype=np.int64),
            'new_axis_mask': np.zeros(begin_rank, dtype=np.int64),
            'ellipsis_mask': np.zeros(begin_rank, dtype=np.int64),
            'begin_mask': np.ones(begin_rank, dtype=np.int64),
            'end_mask': np.ones(len(crop.crop_end), dtype=np.int64),
        }
        strided_slice = StridedSlice(graph, ss_attrs).create_node()

        # The slice takes over the op's data input and all of its consumers.
        crop.in_port(0).get_connection().set_destination(strided_slice.in_port(0))
        crop.out_port(0).get_connection().set_source(strided_slice.out_port(0))

        def make_const(value, suffix):
            return Const(graph, dict(value=value, symbol_dict={'name': crop.id + suffix})).create_node()

        begin_const = make_const(crop.crop_begin, '/crop_begin_const')
        end_const = make_const(crop.crop_end, '/crop_end_const')
        strided_slice.in_port(1).get_connection().set_source(begin_const.out_port(0))
        strided_slice.in_port(2).get_connection().set_source(end_const.out_port(0))

        if len(crop.step) > 0:
            steps_const = make_const(crop.step, '/steps_const')
            strided_slice.in_port(3).get_connection().set_source(steps_const.out_port(0))
예제 #8
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Replace an ``Upsample`` node (``match['upsample']``) with an ``Interpolate``.

        The Interpolate size is computed at runtime: the spatial part of ``ShapeOf(input)`` is
        sliced out and multiplied by one scale factor per sliced dimension. The slice covers
        height..width for 4D inputs and depth..width for 5D inputs.

        The factors come from one of two sources:
          * a constant second input holding one scale per input dim, used only when its first
            two entries are 1;
          * the ``height_scale``/``width_scale`` attributes, plus ``depth_scale`` for 5D inputs.

        The node is left untouched in these cases:
          * the input is neither 4D nor 5D;
          * the scales input is not constant;
          * the first two scales are not 1;
          * a 5D node without a scales input has no valid ``depth_scale``.
        """
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        input_shape = upsample.in_port(0).data.get_shape()
        input_shape_rank = len(input_shape)
        if input_shape_rank not in [4, 5]:
            log.warning('The input shape is not 4D or 5D for op {}'.format(
                upsample.soft_get('name')))
            return

        if len(upsample.in_nodes()) == 2:
            if upsample.in_node(1).value is None:
                return
            scales = upsample.in_node(1).value
            # One scale per input dimension: 4 values for a 4D input, 5 for a 5D one.
            assert scales.shape == (input_shape_rank, )
            if not (math.isclose(scales[0], 1, rel_tol=1e-5)
                    and math.isclose(scales[1], 1, rel_tol=1e-5)):
                return
            # The remaining entries scale the sliced dims. As in the previous [2]/[3] indexing,
            # this assumes a channels-first order: (H, W) for 4D, (D, H, W) for 5D.
            spatial_scales = list(scales[2:])
        else:
            spatial_scales = [upsample['height_scale'], upsample['width_scale']]
            if input_shape_rank == 5:
                # A 5D slice spans three dims (depth..width), so a depth factor is required too;
                # otherwise the 3-element shape slice and the factor could not be multiplied.
                if not upsample.has_valid('depth_scale'):
                    log.warning('The depth_scale attribute is not set for 5D op {}'.format(
                        upsample.soft_get('name')))
                    return
                spatial_scales.insert(0, upsample['depth_scale'])

        if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
            upsample.in_port(1).disconnect()

        factor = Const(graph, {
            'value': np.array(spatial_scales)
        }).create_node()

        shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()

        layout = graph.graph['layout']
        if input_shape_rank == 4:
            begin = Const(graph, {
                'value':
                int64_array([get_height_dim(layout, input_shape_rank)])
            }).create_node()
        else:
            begin = Const(graph, {
                'value':
                int64_array([get_depth_dim(layout, input_shape_rank)])
            }).create_node()
        end = Const(graph, {
            'value':
            int64_array([get_width_dim(layout, input_shape_rank) + 1])
        }).create_node()

        stride = Const(graph, {'value': int64_array([1])}).create_node()
        ss = StridedSlice(
            graph, {
                'name': upsample.name + '/ss_0_port',
                'begin_mask': np.array([1]),
                'end_mask': np.array([0]),
                'new_axis_mask': np.array([0]),
                'shrink_axis_mask': int64_array([0]),
                'ellipsis_mask': int64_array([0])
            }).create_node()

        mul = Mul(graph, {
            'name': upsample.name + '/factor_mul_'
        }).create_node()

        # size = ShapeOf(input)[spatial dims] * factor
        source = upsample.in_port(0).get_connection().get_source()
        source.connect(shape.in_port(0))
        shape.out_port(0).connect(ss.in_port(0))
        begin.out_port(0).connect(ss.in_port(1))
        end.out_port(0).connect(ss.in_port(2))
        stride.out_port(0).connect(ss.in_port(3))
        ss.out_port(0).connect(mul.in_port(0))
        factor.out_port(0).connect(mul.in_port(1))

        # Create Interpolate operation
        if input_shape_rank == 4:
            axes = int64_array([
                get_height_dim(layout, input_shape_rank),
                get_width_dim(layout, input_shape_rank)
            ])
        else:
            axes = int64_array([
                get_depth_dim(layout, input_shape_rank),
                get_height_dim(layout, input_shape_rank),
                get_width_dim(layout, input_shape_rank)
            ])

        resample_op = Interpolate(
            graph,
            dict(name='Interpolate/{}'.format(upsample.name),
                 axes=axes,
                 mode=upsample.attrs()['mode'],
                 antialias=0,
                 convert_to_resample=True)).create_node()

        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        mul.out_port(0).connect(resample_op.in_port(1))

        upsample.in_port(0).get_connection().set_destination(
            resample_op.in_port(0))
        upsample.out_port(0).get_connection().set_source(
            resample_op.out_port(0))
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Normalize the matched node (``match['node']``) into a (Group)ConvolutionBackpropData.

        * If input port 2 is connected, its source is routed through
          ``StridedSlice(begin=2, end=rank(input 0), stride=1)`` before reaching port 2, and
          the ``pad`` attribute is deleted.
        * If ``group`` (default 1) is not 1, the weights on port 1 are reshaped to
          ``[group, dim1 // group, output // group, *weights_shape[2:]]``, where ``dim1`` is
          dimension 1 of input 0. The node type becomes ``GroupConvolutionBackpropData``.
          Otherwise the type becomes ``ConvolutionBackpropData``.

        Raises AssertionError if the grouping preconditions (group > 1, known weights shape,
        divisibility, matching element count) do not hold.
        """
        node = match['node']

        if 2 in node.in_ports() and not node.in_port(2).disconnected():
            in_rank = node.in_port(0).data.get_shape().size

            shape_src = node.in_port(2).get_source()
            node.in_port(2).disconnect()

            begin = Const(graph, {
                'value': np.array([2], dtype=np.int32)
            }).create_node()
            end = Const(graph, {
                'value': np.array([in_rank], dtype=np.int32)
            }).create_node()
            stride = Const(graph, {
                'value': np.array([1], dtype=np.int32)
            }).create_node()

            ss_0 = StridedSlice(
                graph, {
                    'name': node.name + '/ss_0_port',
                    'begin_mask': np.array([1], dtype=np.int32),
                    'end_mask': np.array([0], dtype=np.int32),
                    'new_axis_mask': np.array([0], dtype=np.int32),
                    'shrink_axis_mask': np.array([0], dtype=np.int32),
                    'ellipsis_mask': np.array([0], dtype=np.int32)
                }).create_node()

            shape_src.connect(ss_0.in_port(0))
            begin.out_port(0).connect(ss_0.in_port(1))
            end.out_port(0).connect(ss_0.in_port(2))
            stride.out_port(0).connect(ss_0.in_port(3))

            ss_0.out_port(0).connect(node.in_port(2))

            del node['pad']

        group = node.soft_get('group', 1)

        if group != 1:
            assert group > 1

            weights_shape = node.in_port(1).data.get_shape()
            assert weights_shape is not None
            # presumably the channel dimension of the data input (channels-first) — confirm
            input_dim1 = node.in_port(0).data.get_shape()[1]
            assert input_dim1 % group == 0
            assert node.output % group == 0

            # Both values are asserted divisible by `group`, so use exact integer division
            # rather than float division followed by an int64 cast.
            new_shape = int64_array(
                [group, input_dim1 // group, node.output // group, *weights_shape[2:]])

            assert np.prod(weights_shape) == np.prod(new_shape), \
                'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
            reshape = create_op_node_with_second_input(
                graph, Reshape, new_shape,
                {'override_output_shape': True},
                node.in_port(1).get_source().node)

            node.in_port(1).get_connection().set_source(reshape.out_port(0))

            node['type'] = 'GroupConvolutionBackpropData'
        else:
            node['type'] = 'ConvolutionBackpropData'
예제 #10
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched ``Crop`` node (``match['crop']``) with a ``StridedSlice``.

        Begin and end masks are 1 on the dimensions listed in ``node.axis`` and 0 elsewhere.
        The begin/end values depend on which attributes the Crop carries:

        * Type 1 (two inputs and ``offset``): begin is built from ``offset``; end is
          ``ShapeOf(input 1) + begin``. Input 1 is disconnected from the Crop.
        * Type 2 (``dim`` and ``offset``): begin is built from ``offset``; end from
          ``offset + dim``. Both are constants.
        * Type 3 (``crop_begin`` and ``crop_end``): begin is built from ``crop_begin``; end is
          ``ShapeOf(input 0) - crop_end``.

        The slice reads from the Crop's input-0 source and uses unit strides. It takes over
        all consumers of the Crop's output, and the Crop's input 0 is disconnected.

        Raises:
            AssertionError: if ``axis`` is not valid or the per-type lengths disagree.
            Exception: if none of the three attribute combinations is present.
        """
        node = match['crop']
        assert node.has_valid('axis')
        node.axis = self.list_to_ndarray(node.axis)

        in_shape = node.in_port(0).data.get_shape()
        shape_rank = in_shape.size
        # 1 on the cropped axes, 0 elsewhere; the same mask is used for begin and end.
        axis_mask = int64_array(
            [1 if i in node.axis else 0 for i in range(shape_rank)])
        begin_mask = axis_mask.copy()
        end_mask = axis_mask.copy()

        # NOTE(review): self.mask_normalizer presumably scatters per-axis values into a
        # shape_rank-long vector at the positions in node.axis — confirm against its definition.
        if len(node.in_nodes()) == 2 and node.has_valid('offset'):
            # Crop Type 1
            begin = Const(graph, {
                'value':
                self.mask_normalizer(shape_rank, node.axis, node.offset)
            }).create_node()
            shape = Shape(graph, {
                'name': node.name + '/shape_of_crop'
            }).create_node()
            # end = ShapeOf(second input) + begin
            end = Add(graph, {'name': node.name + '/end'}).create_node()
            node.in_port(1).get_connection().get_source().connect(
                shape.in_port(0))
            node.in_port(1).disconnect()
            shape.out_port(0).connect(end.in_port(0))
            begin.out_port(0).connect(end.in_port(1))
        elif node.has_valid('dim') and node.has_valid('offset'):
            # Crop Type 2
            node.dim = self.list_to_ndarray(node.dim)
            node.offset = self.list_to_ndarray(node.offset)
            assert node.dim.size == node.offset.size == node.axis.size

            begin = Const(graph, {
                'value':
                self.mask_normalizer(shape_rank, node.axis, node.offset)
            }).create_node()
            # end = offset + dim, element-wise per cropped axis
            end_values = np.array(
                [node.offset[i] + node.dim[i] for i in range(len(node.dim))])
            end = Const(graph, {
                'value':
                self.mask_normalizer(shape_rank, node.axis, end_values)
            }).create_node()
        elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
            # Crop Type 3
            node.crop_begin = self.list_to_ndarray(node.crop_begin)
            node.crop_end = self.list_to_ndarray(node.crop_end)
            assert len(node.crop_begin) == len(node.crop_end) == len(node.axis)

            begin = Const(
                graph, {
                    'value':
                    self.mask_normalizer(shape_rank, node.axis,
                                         node.crop_begin)
                }).create_node()
            shape = Shape(graph, {
                'name': node.name + '/shape_of_crop'
            }).create_node()
            # Negated crop_end, so that end = ShapeOf(input 0) + (-crop_end)
            const = Const(
                graph, {
                    'value':
                    -1 *
                    self.mask_normalizer(shape_rank, node.axis, node.crop_end)
                }).create_node()
            end = Add(graph, {'name': node.name + '/end'}).create_node()

            node.in_port(0).get_connection().get_source().connect(
                shape.in_port(0))
            shape.out_port(0).connect(end.in_port(0))
            const.out_port(0).connect(end.in_port(1))

        else:
            raise Exception("Unknown type of Crop")

        source = node.in_port(0).get_connection().get_source()

        # Unit stride on every dimension
        stride = Const(graph, {
            'value': np.ones(shape_rank, dtype=np.int64)
        }).create_node()
        # NOTE(review): the name 'Crop_' is fixed, not derived from the Crop's own name.
        ss = StridedSlice(
            graph, {
                'name': 'Crop_',
                'begin_mask': begin_mask,
                'end_mask': end_mask,
                'new_axis_mask': np.array([0]),
                'shrink_axis_mask': np.array([0]),
                'ellipsis_mask': np.array([0])
            }).create_node()

        source.connect(ss.in_port(0))
        begin.out_port(0).connect(ss.in_port(1))
        end.out_port(0).connect(ss.in_port(2))
        stride.out_port(0).connect(ss.in_port(3))

        # Detach the Crop from its data input and hand its consumers to the slice.
        node.in_port(0).disconnect()
        node.out_port(0).get_connection().set_source(ss.out_port(0))

        ss['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
예제 #11
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']

        if 1 not in node.in_ports() or node.in_port(1).disconnected():

            if node.has_valid('factor') and not node.has_valid('width') and not node.has_valid('height'):
                factor = Const(graph, {'value': np.array(node.factor)}).create_node()

                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                mul = Mul(graph, {'name': node.name + '/factor_mul_'}).create_node()

                source = node.in_port(0).get_connection().get_source()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))
                ss.out_port(0).connect(mul.in_port(0))
                factor.out_port(0).connect(mul.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                mul.out_port(0).connect(node.in_port(1))

            else:
                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                source = node.in_port(0).get_connection().get_source()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))

                pads_value = node.pads_begin + node.pads_end
                pads_const = Const(graph, {'value': np.array(pads_value)}).create_node()
                add = Add(graph, {'name': node.name + '/pad_add'}).create_node()
                ss.out_port(0).connect(add.in_port(0))
                add.in_port(1).connect(pads_const.out_port(0))

                if node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') == 1:
                    shrink_factor = node.shrink_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None

                    const = Const(graph, {'name': node.name + '/pre_shrink_sub_const',
                                          'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/pre_shrink_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    sub.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'value': np.array(1 / shrink_factor),
                                          'name': node.name + 'shrink_factor_div_const'}).create_node()
                    div = Mul(graph, {'name': node.name + 'shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    div.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'name': node.name + '/shrink_factor_add_one_const', 'value': np.array(1)
                                          }).create_node()
                    add = Add(graph, {'name': node.name + '/shrink_factor_add_one'}).create_node()
                    div.out_port(0).connect(add.in_port(0))
                    const.out_port(0).connect(add.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    add.out_port(0).connect(node.in_port(1))

                elif node.soft_get('shrink_factor') == 1 and node.soft_get('zoom_factor') != 1:
                    zoom_factor = node.zoom_factor
                    if zoom_factor < 1:
                        log.error('Zoom factor should be positive in node {}'.format(node.id))
                        return None

                    node['debug_message'] = 'Interpolate layer replacer may be wrong, please, try to update it in the' \
                                            ' file (extensions/front/InterpolateNormalizer.py at the line {}).' \
                                            ''.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)

                    # Reshape methods can be different in some cases
                    # Commented out section represents reshape that used in deeplab-caffe
                    # Uncomment the following lines, if your model was trained with deeplab-caffe
                    # or have the same reshape method
                    # const = Const(graph, {'value': np.array(-1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_sub_const'}).create_node()
                    # sub = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sub'}).create_node()
                    # add.out_port(0).connect(sub.in_port(0))
                    # const.out_port(0).connect(sub.in_port(1))
                    #
                    # const = Const(graph, {'value': np.array(zoom_factor - 1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_mul_const'}).create_node()
                    # mul = Mul(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_mul'}).create_node()
                    # sub.out_port(0).connect(mul.in_port(0))
                    # const.out_port(0).connect(mul.in_port(1))
                    #
                    # sum = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sum'}).create_node()
                    # add.out_port(0).connect(sum.in_port(0))
                    # mul.out_port(0).connect(sum.in_port(1))
                    #
                    # node.add_input_port(1, skip_if_exist=True)
                    # assert node.in_port(1).disconnected()
                    # sum.out_port(0).connect(node.in_port(1))

                    # Comment out the following lines if you use the reshape method from previous section
                    const = Const(graph, {'value': np.array(zoom_factor),
                                          'name': node.name + '/zoom_factor_mul_const'}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()

                    add.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    mul.out_port(0).connect(node.in_port(1))

                elif node.soft_get('width') != 0 and node.soft_get('height') != 0:
                    const = Const(graph, {'value': np.array([node.height, node.width])}).create_node()
                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    const.out_port(0).connect(node.in_port(1))

                elif node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') != 1:
                    shrink_factor = node.shrink_factor
                    zoom_factor = node.zoom_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None
                    if zoom_factor < 1:
                        log.error('Zoom factor should be positive in node {}'.format(node.id))
                        return None

                    const = Const(graph, {'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/shrink_zoom_factor_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    const.out_port(0).connect(sub.in_port(1))

                    const = Const(graph, {'value': np.array(1 / (shrink_factor + 1))}).create_node()
                    div = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    const.out_port(0).connect(div.in_port(1))

                    const = Const(graph, {'value': np.array(-1),
                                          'name': node.name + 'shrink_zoom_factor_sum_const'}).create_node()
                    sum = Add(graph, {'name': node.name + '/shrink_zoom_factor_sum'}).create_node()
                    div.out_port(0).connect(sum.in_port(0))
                    const.out_port(0).connect(sum.in_port(1))

                    const = Const(graph, {'value': np.array(zoom_factor - 1)}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
                    sum.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    sum = Add(graph, {'name': node.name + '/final_shrink_zoom_factor_sum'}).create_node()
                    div.out_port(0).connect(sum.in_port(0))
                    mul.out_port(0).connect(sum.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    sum.out_port(0).connect(node.in_port(1))
        else:
            if node.soft_get('fw') == 'caffe':
                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                source = node.in_port(1).get_connection().get_source()
                node.in_port(1).disconnect()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))
                ss.out_port(0).connect(node.in_port(1))
예제 #12
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Normalize a deconvolution node to ConvolutionBackpropData / GroupConvolutionBackpropData.

        * If input port 2 (output shape) is connected, its value is cut to the spatial part
          ([N, C, spatial_dim_0, ..., spatial_dim_n] -> [spatial_dim_0, ..., spatial_dim_n]) with a
          StridedSlice, and the 'pad' attribute is removed.
        * Otherwise, if the node has a valid 'original_output_spatial_shape' attribute, that shape is
          restored as a Const on input port 2 and 'pad' is removed.
        * When 'group' != 1, the weights on port 1 are reshaped to
          [group, in_channels // group, output // group, *weights_shape[2:]], where in_channels is
          dimension 1 of the port-0 shape. The node type becomes 'GroupConvolutionBackpropData'.
          Otherwise the type becomes 'ConvolutionBackpropData'.

        :param graph: graph being transformed
        :param match: pattern match; the node to normalize is stored under the 'node' key
        """
        node = match['node']
        node_name = node.soft_get('name', node.id)

        if 2 in node.in_ports() and not node.in_port(2).disconnected():
            # Third input represents output shape. Cutting its value according to scheme:
            # [N, C, spatial_dim_0, ..., spatial_dim_n] -> [spatial_dim_0, ..., spatial_dim_n]
            in_rank = node.in_port(0).data.get_shape().size

            shape_src = node.in_port(2).get_source()
            node.in_port(2).disconnect()

            begin = Const(graph, {'value': np.array([2], dtype=np.int32),
                                  'name': node_name + '/ss_begin'}).create_node()
            end = Const(graph, {'value': np.array([in_rank], dtype=np.int32),
                                'name': node_name + '/ss_end'}).create_node()
            stride = Const(graph, {'value': np.array([1], dtype=np.int32),
                                   'name': node_name + '/ss_stride'}).create_node()

            ss_0 = StridedSlice(graph, {'name': node_name + '/ss_0_port',
                                        'begin_mask': np.array([1], dtype=np.int32),
                                        'end_mask': np.array([0], dtype=np.int32),
                                        'new_axis_mask': np.array([0], dtype=np.int32),
                                        'shrink_axis_mask': np.array([0], dtype=np.int32),
                                        'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

            shape_src.connect(ss_0.in_port(0))
            begin.out_port(0).connect(ss_0.in_port(1))
            end.out_port(0).connect(ss_0.in_port(2))
            stride.out_port(0).connect(ss_0.in_port(3))

            ss_0.out_port(0).connect(node.in_port(2))

            # Specification: *padding amount* is deduced from relation of input and output spatial shapes
            del node['pad']

        elif node.has_valid('original_output_spatial_shape'):
            # node had fixed output spatial shape set in original framework, so we restore it here
            const = Const(graph, {'value': int64_array(node.original_output_spatial_shape),
                                  'name': node_name + '/original_spatial_shape'}).create_node()
            node.add_input_port(2, skip_if_exist=True)
            const.out_port(0).connect(node.in_port(2))

            # Specification: *padding amount* is deduced from relation of input and output spatial shapes
            del node['pad']

        group = node.soft_get('group', 1)

        if group != 1:
            assert group > 1

            weights_shape = node.in_port(1).data.get_shape()
            assert weights_shape is not None
            in_channels = node.in_port(0).data.get_shape()[1]
            assert in_channels % group == 0
            assert node.output % group == 0

            # Floor division keeps the shape integral. Both quotients are exact thanks to the asserts above.
            new_shape = int64_array([group, in_channels // group, node.output // group, *weights_shape[2:]])

            assert np.prod(weights_shape) == np.prod(new_shape), \
                'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
            reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
                                                       {'override_output_shape': True},
                                                       node.in_port(1).get_source().node)

            node.in_port(1).get_connection().set_source(reshape.out_port(0))

            node['type'] = 'GroupConvolutionBackpropData'
        else:
            node['type'] = 'ConvolutionBackpropData'
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Collapse an Unsqueeze -> Tile -> Reshape chain into a single opset4 Interpolate node.

        The graph is left untouched when any of these hold:
        * the Unsqueeze axis is not constant;
        * the Tile repeats are not constant;
        * the Unsqueeze input is neither 4D nor 5D.

        Otherwise the new Interpolate receives:
          port 0 - the original Unsqueeze input;
          port 1 - floor(float(StridedSlice(shape, [axis], [axis + 1])) * scale), cast to int64;
          port 2 - the scale, i.e. the Tile repeat at the unsqueezed position (float32);
          port 3 - the interpolation axis, equal to the Unsqueeze axis minus one.
        The Interpolate takes over the Reshape's output connections and its name; the Reshape is
        renamed with a '/delete' suffix.
        """
        unsqueeze = match['unsqueeze']
        prefix = unsqueeze.name

        axis_producer = unsqueeze.in_port(1).get_connection().get_source().node
        if not axis_producer.has_valid('value'):
            return
        unsqueeze_axis = int(axis_producer.value)

        repeats_producer = match['tile'].in_port(1).get_connection().get_source().node
        if not repeats_producer.has_valid('value'):
            return

        if len(unsqueeze.in_port(0).data.get_shape()) not in {4, 5}:
            return

        scale_value = float32_array([repeats_producer.value[unsqueeze_axis]])
        target_axis = unsqueeze_axis - 1
        axes_const = Const(graph, {'name': prefix + '/axis',
                                   'value': int64_array([target_axis])}).create_node()

        shape_of = Shape(graph, {'name': prefix + '/Shape'}).create_node()
        scales_const = Const(graph, {'name': prefix + '/scales', 'value': scale_value}).create_node()
        size_mul = Mul(graph, {'name': prefix + '/Mul'}).create_node()
        scales_const.out_port(0).connect(size_mul.in_port(1))

        begin_const = Const(graph, {'name': prefix + '/slice_begin',
                                    'value': int64_array([target_axis])}).create_node()
        end_const = Const(graph, {'name': prefix + '/slice_end',
                                  'value': int64_array([target_axis + 1])}).create_node()

        # Take the single input dimension that is going to be scaled out of the input shape
        dim_slice = StridedSlice(graph, dict(name=prefix + '/StridedSlice',
                                             begin_mask=int64_array([1]),
                                             end_mask=int64_array([1]),
                                             new_axis_mask=int64_array([0]),
                                             shrink_axis_mask=int64_array([0]),
                                             ellipsis_mask=int64_array([0]))).create_node()
        for port_idx, producer in enumerate((shape_of, begin_const, end_const)):
            producer.out_port(0).connect(dim_slice.in_port(port_idx))

        to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
        dim_slice.out_port(0).connect(to_float.in_port(0))
        to_float.out_port(0).connect(size_mul.in_port(0))

        interpolate = Interpolate(graph, {
            'mode': 'nearest',
            'antialias': 0,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'coordinate_transformation_mode': 'half_pixel',
            'nearest_mode': 'round_prefer_floor',
            'cube_coeff': -0.75,
            'version': 'opset4',
            'shape_calculation_mode': 'scales',
            'in_ports_count': 4,
            'maybe_part_of_sequence': True,
        }).create_node()

        floor = Floor(graph, dict(name=prefix + '/Floor')).create_node()
        to_int = Cast(graph, dict(dst_type=np.int64)).create_node()
        size_mul.out_port(0).connect(floor.in_port(0))
        floor.out_port(0).connect(to_int.in_port(0))

        to_int.out_port(0).connect(interpolate.in_port(1))
        scales_const.out_port(0).connect(interpolate.in_port(2))
        axes_const.out_port(0).connect(interpolate.in_port(3))

        reshape = match['reshape']
        reshape.out_port(0).get_connection().set_source(interpolate.out_port(0))
        old_name = reshape.soft_get('name', reshape.id)
        rename_nodes([(reshape, old_name + '/delete'), (interpolate, old_name)])

        # The original Unsqueeze input now feeds both the Interpolate and the Shape used for the size
        data_connection = unsqueeze.in_port(0).get_connection()
        data_producer = data_connection.get_source().node
        data_connection.set_destination(interpolate.in_port(0))
        data_producer.out_port(0).connect(shape_of.in_port(0))
예제 #14
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Collapse an Unsqueeze -> Tile -> Reshape chain into a single nearest-mode Interpolate node.

        The graph is left untouched when any of these hold:
        * the Unsqueeze axis is not constant;
        * the Tile repeats are not constant;
        * the Unsqueeze input is neither 4D nor 5D.

        Otherwise the Interpolate (with 'axes' set to the Unsqueeze axis minus one) receives:
        * port 0 - the Unsqueeze input;
        * port 1 - a target size built in-graph: a StridedSlice of the input shape with begin
          [axis] and end [axis + 1], multiplied by the Tile repeat at the unsqueezed position.
        The Interpolate output replaces the Reshape output.
        """
        unsqueeze = match['unsqueeze']
        prefix = unsqueeze.name

        axis_producer = unsqueeze.in_port(1).get_connection().get_source().node
        if not axis_producer.has_valid('value'):
            return
        unsqueeze_axis = int(axis_producer.value)

        repeats_producer = match['tile'].in_port(1).get_connection().get_source().node
        if not repeats_producer.has_valid('value'):
            return

        if len(unsqueeze.in_port(0).data.get_shape()) not in {4, 5}:
            return

        repeat_factor = int64_array([repeats_producer.value[unsqueeze_axis]])
        target_axis = unsqueeze_axis - 1

        shape_of = Shape(graph, {'name': prefix + '/Shape_'}).create_node()
        factor_const = Const(graph, {'name': prefix + '/scales_', 'value': repeat_factor}).create_node()
        size_mul = Mul(graph, {'name': prefix + '/Mul_'}).create_node()
        factor_const.out_port(0).connect(size_mul.in_port(1))

        begin_const = Const(graph, {'name': prefix + '/slice_begin_',
                                    'value': int64_array([target_axis])}).create_node()
        end_const = Const(graph, {'name': prefix + '/slice_end_',
                                  'value': int64_array([target_axis + 1])}).create_node()

        # Extract the single input dimension that is being repeated
        dim_slice = StridedSlice(graph, dict(name=prefix + '/StridedSlice_',
                                             begin_mask=int64_array([1]),
                                             end_mask=int64_array([1]),
                                             new_axis_mask=int64_array([0]),
                                             shrink_axis_mask=int64_array([0]),
                                             ellipsis_mask=int64_array([0]))).create_node()
        for port_idx, producer in enumerate((shape_of, begin_const, end_const)):
            producer.out_port(0).connect(dim_slice.in_port(port_idx))
        dim_slice.out_port(0).connect(size_mul.in_port(0))

        interpolate = Interpolate(graph, {'name': prefix + '/Interpolate_',
                                          'axes': int64_array([target_axis]),
                                          'mode': 'nearest'}).create_node()
        size_mul.out_port(0).connect(interpolate.in_port(1))

        match['reshape'].out_port(0).get_connection().set_source(interpolate.out_port(0))

        # The original Unsqueeze input now feeds both the Interpolate and the Shape used for the size
        data_connection = unsqueeze.in_port(0).get_connection()
        data_producer = data_connection.get_source().node
        data_connection.set_destination(interpolate.in_port(0))
        data_producer.out_port(0).connect(shape_of.in_port(0))
예제 #15
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Convert a Slice node to StridedSlice, or to Crop when more than one axis is actually cut.

        ONNX opset-10 Slice nodes are delegated to convert_onnx_slice_opset10. These are nodes with
        three or more inputs and 'format' == 'onnx'. Other nodes must carry 'start' and 'end'
        attributes and are left untouched otherwise. 'axis' is optional and defaults to
        0 .. start.size - 1.

        The new node takes over the Slice's name; the Slice is renamed with a '_delete' suffix.
        """
        node = match['slice']

        input_data = node.in_node(0)
        output_data = node.out_node()

        # ONNX 10 opset case
        if len(node.in_nodes()) >= 3 and node.has_valid('format') and node['format'] == 'onnx':
            self.convert_onnx_slice_opset10(node)
            return

        # Caffe case
        if not node.has_valid('start') or not node.has_valid('end'):
            return

        begin = node.start
        end = node.end
        axis = node.axis if node.has_valid('axis') else np.arange(begin.size)

        input_shape = input_data.shape
        rank = len(input_shape)
        # Per-dimension StridedSlice inputs. A mask entry is set to 1 only where a non-trivial
        # begin/end value is supplied for that dimension.
        axes_begin = np.zeros(rank, dtype=np.int32)
        axes_end = np.zeros(rank, dtype=np.int32)
        ss_begin = np.zeros(rank, dtype=np.int32)
        ss_end = np.zeros(rank, dtype=np.int32)

        # Count how many of the listed axes are actually cut, and remember which ones
        dims = 0
        axes = np.zeros(begin.size, dtype=bool)
        for i, ax in enumerate(axis):
            cuts_begin = begin[i] != 0
            cuts_end = end[i] < input_shape[ax]
            if not (cuts_begin or cuts_end):
                continue
            dims += 1
            axes[i] = True
            if cuts_begin:
                axes_begin[ax] = 1
                ss_begin[ax] = begin[i]
            if cuts_end:
                axes_end[ax] = 1
                ss_end[ax] = end[i]

        slice_node_name = node.soft_get('name', node.id)

        if dims <= 1:
            # If Slice uses only one axis or no axis, then convert Slice to StridedSlice
            out_rank = len(output_data.shape)
            ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(out_rank, dtype=np.int32),
                                          shrink_axis_mask=np.zeros(out_rank, dtype=np.int32),
                                          ellipsis_mask=np.zeros(out_rank, dtype=np.int32),
                                          begin_mask=axes_begin,
                                          end_mask=axes_end)).create_node()

            convert_negative_indices(ss_begin, input_shape)
            convert_negative_indices(ss_end, input_shape)

            begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node()
            end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node()

            rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)])

            node.in_port(0).get_connection().set_destination(ss.in_port(0))
            begin_node.out_port(0).connect(ss.in_port(1))
            end_node.out_port(0).connect(ss.in_port(2))
            node.out_port(0).get_connection().set_source(ss.out_port(0))
        else:
            # If Slice uses more than one axis, use the Crop layer
            crop = Crop(graph, dict(axis=axis[axes],
                                    offset=begin[axes],
                                    dim=end[axes] - begin[axes])).create_node()
            rename_nodes([(node, slice_node_name + '_delete'), (crop, slice_node_name)])

            node.in_port(0).get_connection().set_destination(crop.in_port(0))
            node.out_port(0).get_connection().set_source(crop.out_port(0))
예제 #16
0
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Replace the matched RetinaNet post-processing sub-graph with a DetectionOutput node.

        As used below, match input 0 provides the box regression values, input 1 the class scores
        and input 2 the prior boxes.

        :param graph: graph being transformed
        :param match: sub-graph match described by the custom replacement config
        :return: dict with the created node under the 'detection_output_node' key
        :raises Error: if 'variance' cannot be read from the replacement config `custom_attributes`,
                       or if `iou_threshold` cannot be taken from a NonMaxSuppression node of the graph
        """
        reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                dict(name='do_reshape_classes'),
                                                                match.single_input_node(1)[0])

        initial_priors_node = match.single_input_node(2)[0]
        priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
        # model calculates identical prior boxes for each batch, so we take first slice of them
        begin = Const(graph, {'value': np.array([0, 0, 0], dtype=np.int32)}).create_node()
        end = Const(graph, {'value': np.array([1, 0, 0], dtype=np.int32)}).create_node()
        stride = Const(graph, {'value': np.array([1, 1, 1], dtype=np.int32)}).create_node()

        priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                           'begin_mask': np.array([1, 1, 1], dtype=np.int32),
                                           'end_mask': np.array([1, 0, 0], dtype=np.int32),
                                           'new_axis_mask': np.array([0], dtype=np.int32),
                                           'shrink_axis_mask': np.array([0], dtype=np.int32),
                                           'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

        initial_priors_node.out_port(0).connect(priors_node.in_port(0))
        begin.out_port(0).connect(priors_node.in_port(1))
        end.out_port(0).connect(priors_node.in_port(2))
        stride.out_port(0).connect(priors_node.in_port(3))

        placeholders = graph.get_op_nodes(type='Parameter')
        assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                       "{} placeholders".format(self.replacement_id, len(placeholders))
        placeholder = placeholders[0]

        # scale prior boxes to the [0, 1] interval
        node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
        priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

        broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
        shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
        priors_node.out_port(0).connect(shape_of_priors.in_port(0))
        broadcast.in_port(1).connect(shape_of_priors.out_port(0))
        broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

        priors_scale_node.in_port(0).connect(priors_node.out_port(0))
        priors_scale_node.in_port(1).connect(broadcast.out_port(0))

        try:
            variance = match.custom_replacement_desc.custom_attributes['variance']
        except (AttributeError, KeyError, TypeError) as e:
            # Catch only lookup failures on the config. A bare `except:` would also swallow
            # KeyboardInterrupt/SystemExit and hide the original cause.
            raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                        ''.format(self.replacement_id)) from e

        priors = self.append_variances(priors_scale_node, variance)

        # calculate prior boxes widths and heights
        split_node = create_op_with_const_inputs(graph, VariadicSplit,
                                                 {1: int64_array(2), 2: int64_array([1, 1, 1, 1])},
                                                 {'out_ports_count': 4}, priors_scale_node)

        priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')).create_node(
            [(split_node, 2), (split_node, 0)])
        priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')).create_node(
            [(split_node, 3), (split_node, 1)])

        # Concat widths and heights into a single [w, h, w, h] tensor and multiply it with the box
        # coordinate regression values. Workaround: three 2-input Concats are used instead of a single
        # 4-input Concat to keep the model reshapable.
        concat_1 = Concat(graph, {'name': 'concat_width_height',
                                  'axis': -1,
                                  'in_ports_count': 2}).create_node([priors_width_node, priors_height_node])
        concat_2 = Concat(graph, {'name': 'concat_width_height_width',
                                  'axis': -1,
                                  'in_ports_count': 2}).create_node([concat_1, priors_width_node])
        concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height',
                                                  'axis': -1,
                                                  'in_ports_count': 2}).create_node([concat_2, priors_height_node])

        applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
            [concat_width_height_node, match.single_input_node(0)[0]])

        # reshape to 2D tensor as Inference Engine Detection Output layer expects
        reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                                   dict(name='reshape_regression'),
                                                                   applied_width_height_regressions_node)

        detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
        # get nms from the original network
        iou_threshold = None
        nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
        if len(nms_nodes) > 0:
            # it is highly unlikely that for different classes NMS has different
            # moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
            iou_threshold = nms_nodes[0].in_node(3).value
        if iou_threshold is None:
            raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'
                        .format(self.replacement_id))

        detection_output_node = detection_output_op.create_node(
            [reshape_regression_node, reshape_classes_node, priors],
            dict(name=detection_output_op.attrs['type'],
                 nms_threshold=iou_threshold,
                 clip_after_nms=1,
                 normalized=1,
                 variance_encoded_in_target=0,
                 background_label_id=1000))

        return {'detection_output_node': detection_output_node}
예제 #17
0
    @staticmethod
    def convert_onnx_slice_opset10(node: Node):
        """
        Converts the Slice node from ONNX opset10 to StridedSlice.

        starts (port 1) and ends (port 2) must be constant. axes (port 3) and steps (port 4) are
        optional; when connected they must be constant too. Missing axes default to
        [0, 1, ..., starts.size - 1] and missing steps default to 1.
        The created StridedSlice takes over the Slice's name; the Slice is renamed with a '_delete' suffix.

        :param node: Slice node
        :return: None
        :raises Error: if starts/ends, or a connected axes/steps input, is not constant
        """
        graph = node.graph

        input_shape = node.in_port(0).data.get_shape()
        output_shape = node.out_port(0).data.get_shape()
        starts = node.in_port(1).data.get_value()
        ends = node.in_port(2).data.get_value()
        if starts is None or ends is None:
            raise Error('The input with starts or end is not constant for node {}'.format(node.id))

        # In ONNX, out-of-range sentinels are common. 'ends' is usually -1 translated to the maximum
        # int64, and 'starts' may be the int64 extremes when steps are negative. Both must be brought
        # into the int32 range: the StridedSlice layer supports int32, and the values are stored in int32
        # arrays below, where an out-of-range int64 would overflow.
        int32_info = np.iinfo(np.int32)
        starts = np.clip(starts, int32_info.min, int32_info.max)
        ends = np.clip(ends, int32_info.min, int32_info.max)

        if node.is_in_port_connected(3):
            axes = node.in_port(3).data.get_value()
            if axes is None:
                raise Error('The input with axes is not constant for node {}'.format(node.id))
        else:
            axes = int64_array(list(range(starts.size)))

        if node.is_in_port_connected(4):
            steps = node.in_port(4).data.get_value()
            if steps is None:
                raise Error('The input with steps is not constant for node {}'.format(node.id))
        else:
            steps = np.ones([starts.size])

        ss_begin_mask = np.zeros(len(input_shape), dtype=np.int32)
        ss_end_mask = np.zeros(len(input_shape), dtype=np.int32)
        ss_begin = np.zeros(len(input_shape), dtype=np.int32)
        ss_end = np.zeros(len(input_shape), dtype=np.int32)
        ss_steps = np.ones(len(input_shape), dtype=np.int32)

        # prepare inputs and attributes for the StridedSlice layer
        for i, axis in enumerate(axes):
            # begin mask is set only for non-zero starts; the end mask is always set for sliced axes
            if starts[i] != 0:
                ss_begin_mask[axis] = 1
                ss_begin[axis] = starts[i]

            ss_end_mask[axis] = 1
            ss_end[axis] = ends[i]

            ss_steps[axis] = steps[i]

        slice_node_name = node.soft_get('name', node.id)

        begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node()
        end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node()
        strides_node = Const(graph, {'value': ss_steps, 'name': slice_node_name + '/stride'}).create_node()

        ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
                                      shrink_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
                                      ellipsis_mask=np.zeros(len(output_shape), dtype=np.int32),
                                      begin_mask=ss_begin_mask,
                                      end_mask=ss_end_mask)).create_node()
        rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)])
        node.in_port(0).get_connection().set_destination(ss.in_port(0))
        begin_node.out_port(0).connect(ss.in_port(1))
        end_node.out_port(0).connect(ss.in_port(2))
        strides_node.out_port(0).connect(ss.in_port(3))
        node.out_port(0).get_connection().set_source(ss.out_port(0))
예제 #18
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Replace a matched Upsample node with an Interpolate node.

        The Interpolate target sizes are computed in the graph as
            Cast(int)(Cast(f32)(StridedSlice(Shape(input))) * factor)
        where the StridedSlice takes input_shape[begin:end], with begin being the
        first spatial axis (height for 4D, depth for 5D) and end one past the width
        axis, and factor holds the spatial scales.

        The node is left untouched (early return) when:
          * the input shape is unknown or not 4D/5D;
          * the scales input exists but is not constant;
          * scales[0] or scales[1] differ from 1 (only checked for the scales input);
          * the spatial scales are not all equal (math.isclose, rel_tol=1e-5);
          * the input is 5D but no depth scale is available.

        :param graph: graph being transformed
        :param match: matched pattern; match['upsample'] is the node to replace
        """
        log.debug('UpsampleToResample is triggered')
        upsample = match['upsample']
        # Fall back to the node id so unnamed nodes do not raise on attribute access.
        upsample_name = upsample.soft_get('name', upsample.id)
        input_shape = upsample.in_port(0).data.get_shape()
        if input_shape is None or len(input_shape) not in [4, 5]:
            log.warning('The input shape is not 4D or 5D for op {}'.format(upsample_name))
            return
        input_shape_rank = len(input_shape)

        depth_scale = None
        if len(upsample.in_nodes()) == 2:
            if upsample.in_node(1).value is None:
                # Non-constant scales: a static Mul factor cannot be built.
                return
            scales = upsample.in_node(1).value
            assert len(scales) in (4, 5), \
                'Supported scales rank is 4 or 5, but it is {} for node {}'.format(len(scales), upsample_name)
            if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
                return
            if len(scales) == 5:
                # Read the spatial scales in axis order. This assumes channels-first
                # order (N, C, D, H, W), consistent with the check above that the
                # first two scales are 1.
                depth_scale, height_scale, width_scale = scales[2], scales[3], scales[4]
            else:
                height_scale, width_scale = scales[2], scales[3]
        else:
            height_scale = upsample['height_scale']
            width_scale = upsample['width_scale']

        if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
            log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
                width_scale, height_scale, upsample_name))
            return
        if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
            log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
                depth_scale, height_scale, upsample_name))
            return
        if input_shape_rank == 5 and depth_scale is None:
            # Without this guard the 5D factor would become np.array([None, h, w]).
            log.debug('Depth scale is not specified for 5D input of node {}'.format(upsample_name))
            return

        # Drop the original scales input; the new sizes input is wired to the Interpolate.
        if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
            upsample.in_port(1).disconnect()

        shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

        layout = graph.graph['layout']
        height_dim = get_height_dim(layout, input_shape_rank)
        width_dim = get_width_dim(layout, input_shape_rank)
        if input_shape_rank == 4:
            spatial_axes = [height_dim, width_dim]
            factor_value = np.array([height_scale, width_scale])
        else:
            spatial_axes = [get_depth_dim(layout, input_shape_rank), height_dim, width_dim]
            factor_value = np.array([depth_scale, height_scale, width_scale])

        # Creation order (begin, factor, end, stride) matches the previous
        # implementation, so auto-generated node ids are unchanged.
        begin = Const(graph, {'value': int64_array([spatial_axes[0]])}).create_node()
        factor = Const(graph, {'value': factor_value}).create_node()
        end = Const(graph, {'value': int64_array([width_dim + 1])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()

        ss = StridedSlice(graph, {
            'name': upsample_name + '/ss_0_port',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0]),
        }).create_node()

        mul = Mul(graph, {'name': upsample_name + '/factor_mul_'}).create_node()

        source = upsample.in_port(0).get_connection().get_source()
        source.connect(shape.in_port(0))
        shape.out_port(0).connect(ss.in_port(0))
        begin.out_port(0).connect(ss.in_port(1))
        end.out_port(0).connect(ss.in_port(2))
        stride.out_port(0).connect(ss.in_port(3))
        ss.out_port(0).connect(mul.in_port(0))
        factor.out_port(0).connect(mul.in_port(1))

        # Create Interpolate operation
        axes = int64_array(spatial_axes)
        resample_op = Interpolate(graph, dict(name='Interpolate/{}'.format(upsample_name),
                                              axes=axes,
                                              mode=upsample.attrs()['mode'],
                                              antialias=0,
                                              convert_to_resample=True)).create_node()

        # NOTE(review): port 1 is re-created (and asserted disconnected) on the
        # upsample node; the reason is not visible in this block.
        upsample.add_input_port(1, skip_if_exist=True)
        assert upsample.in_port(1).disconnected()
        mul.out_port(0).connect(resample_op.in_port(1))

        upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
        upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))

        # The shape is integer and the factor is float: compute in f32, then cast back
        # (int64 for experimental IR v10, int32 otherwise).
        convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
        int_np_type = np.int64 if graph.graph['cmd_params'].generate_experimental_IR_V10 else np.int32
        convert_to_int = Cast(graph, dict(dst_type=int_np_type)).create_node()

        mul.in_port(0).get_connection().insert_node(convert_to_float)
        mul.out_port(0).get_connection().insert_node(convert_to_int)