def replace_pattern(self, graph: Graph, match: dict):
        """
        Convert a Reduce* node over consecutive axes into Reshape -> Pooling -> Reshape.

        The input is reshaped to 4D ``[d0, d1, reduction_dim, end_dim]`` where
        ``reduction_dim`` is the product of the reduced dimensions, ``end_dim`` is the
        product of the dimensions after the last reduced axis, and ``d0, d1`` are the
        dimensions before the first reduced axis collapsed (or padded with 1) to exactly
        two values. A Pooling with window ``[1, 1, reduction_dim, 1]`` and
        ``pool_method = self.pool_method_map[node.type]`` performs the reduction, and a
        final Reshape restores the original output shape. For ReduceSum the pooled
        result is additionally scaled by ``reduction_dim`` (which assumes ReduceSum is
        mapped to an averaging pool method - TODO confirm against pool_method_map).

        The node is left untouched when its output is a constant, its type is not in
        ``self.pool_method_map``, or its axes are non-constant, empty or not consecutive.
        """
        node = match['reduce']

        if node.out_port(0).data.get_value() is not None:
            # We leave Reduce* operations located in constant sub-graph as is
            # to keep model reshapable with --keep_shape_ops cli key
            return

        reduce_type = node.type
        if reduce_type not in self.pool_method_map:
            log.error(
                "Reduce type {} is not included in pool_method_map. Please update pool_method_map with new key "
                "{}".format(reduce_type, reduce_type))
            return

        input_data = node.in_node()
        output_data = node.out_node()

        input_shape = node.in_port(0).data.get_shape()
        output_shape = node.out_port(0).data.get_shape()

        # The reshape dimensions can only be computed for constant, non-empty axes
        axes_data_value = node.in_port(1).data.get_value()
        if axes_data_value is None:
            log.error("Reduce node {} has non-constant axes and is not supported".format(node.id))
            return
        if axes_data_value.size == 0:
            log.error("Reduce node {} has empty axes and is not supported".format(node.id))
            return

        # normalize node axes to exclude negative indices
        axes = int64_array([axes_data_value.item()]) if axes_data_value.size == 1 else axes_data_value
        axes = sorted(get_canonical_axis_index(input_shape, a) for a in axes)

        # Check that values in axes list are consecutive
        if any(cur != prev + 1 for prev, cur in zip(axes, axes[1:])):
            log.error(
                "Reduce with not consecutive axes {} is not supported ".
                format(axes))
            return
        # So now we are sure that we can convert Reduce to appropriate operation

        # 1. Calculate shape that will be used in reduction
        reduction_dim = np.prod([input_shape[idx] for idx in axes])
        begin_dims = np.array([input_shape[idx] for idx in range(axes[0])])
        end_dim = np.prod([
            input_shape[idx] for idx in range(axes[-1] + 1, len(input_shape))
        ])

        # 2. Create reshape with appropriate shape
        if len(begin_dims) > 2:
            # More than two leading dims: keep the first one and collapse the rest.
            # Here axes[0] > 2, so axis 0 is never among the reduced axes.
            begin_dims = int64_array(
                [begin_dims[0], np.prod(begin_dims[1:])])
        else:
            # Expand begin_dims to 2
            begin_dims = int64_array(
                np.append(begin_dims, [1] * (2 - len(begin_dims))))

        reshape_shape = np.array([*begin_dims, reduction_dim, end_dim],
                                 dtype=np.int64)
        pool_window = np.array([1, 1, reduction_dim, 1], dtype=np.int64)

        # 3. Reduce => Reshape->Pooling->Reshape
        reshape_op = Reshape(graph, {'name': node.id + '/Reshape'})
        reshape_dim_const_data = Const(graph, {
            'name': node.id + '/Reshape/Dim',
            'value': reshape_shape
        }).create_node_with_data()

        final_reshape_op = Reshape(graph, {'name': node.id + '/FinalReshape'})
        final_reshape_dim_const_data = Const(graph, {
            'name': node.id + '/FinalReshape/Dim',
            'value': output_shape
        }).create_node_with_data()
        pooling_op = Pooling(
            graph,
            dict(name=node.id + '/Pool',
                 window=pool_window,
                 output_spatial_shape=None,
                 batch_dims=int64_array([0]),
                 channel_dims=int64_array([1]),
                 exclude_pad='false',
                 pool_method=self.pool_method_map[reduce_type]))

        graph.remove_edge(input_data.id, node.id)
        graph.remove_edge(node.id, output_data.id)

        final_reshape_op.create_node_with_data(inputs=[
            pooling_op.create_node_with_data(inputs=[
                reshape_op.create_node_with_data(
                    inputs=[input_data, reshape_dim_const_data])
            ]), final_reshape_dim_const_data
        ],
                                               data_nodes=output_data)

        # convert batch dimension to 0 to produce reshape-able IR over the batch dimension
        if 0 not in axes:
            reshape_dim_const_data.in_node(0).value[0] = 0
            final_reshape_dim_const_data.in_node(0).value[0] = 0

        # 4. If it is reduction with summation, we need to multiply by size of the reduction slice with Mul op
        if reduce_type == 'ReduceSum':
            output_data.in_node().insert_node_with_data_after(
                output_data, AttributedPower, {
                    'name': node.soft_get('name', node.id) + '/Mul',
                    'scale': float(reduction_dim)
                })
    def append_variances(priors_scale_node: Node, variance: list):
        """
        Attach prior-box variances to the priors produced by ``priors_scale_node``.

        The priors (output port 0 of ``priors_scale_node``) are reshaped to ``[-1, 4]``.
        The ``variance`` values are broadcast to ``[dim, 4]``, where ``dim`` is the
        second-to-last element of the priors' shape. Both tensors are concatenated
        along axis 0 and the result is reshaped to ``[1, 2, -1]``.

        :param priors_scale_node: node whose output port 0 yields the prior boxes
        :param variance: variance values to broadcast (presumably 4 values - TODO confirm)
        :return: the final Reshape node producing the ``[1, 2, -1]`` tensor
        """
        graph = priors_scale_node.graph
        base_name = priors_scale_node.name

        def feed(consumer, *producers):
            # Wire output port 0 of each producer to consecutive input ports of consumer.
            for port_idx, producer in enumerate(producers):
                producer.out_port(0).connect(consumer.in_port(port_idx))

        priors_shape = Shape(graph, {'name': base_name + '/shape'}).create_node()
        feed(priors_shape, priors_scale_node)

        # shape[-2:-1] of the priors: number of rows to tile the variances over
        slice_begin = Const(graph, {'value': np.array([-2])}).create_node()
        slice_end = Const(graph, {'value': np.array([-1])}).create_node()
        slice_step = Const(graph, {'value': np.array([1])}).create_node()
        rows_dim = StridedSlice(graph, dict([
            ('name', base_name + '/get_-2_dim'),
            ('begin_mask', np.array([1])),
            ('end_mask', np.array([1])),
            ('new_axis_mask', np.array([0])),
            ('shrink_axis_mask', np.array([0])),
            ('ellipsis_mask', np.array([0])),
        ])).create_node()
        feed(rows_dim, priors_shape, slice_begin, slice_end, slice_step)

        # Target broadcast shape: [rows, 4]
        four = Const(graph, {'value': np.array([4])}).create_node()
        tile_shape = Concat(graph, dict([
            ('name', base_name + '/shape_for_tiling'),
            ('in_ports_count', 2),
            ('axis', np.array(0)),
        ])).create_node()
        feed(tile_shape, rows_dim, four)

        variance_const = Const(graph, dict([
            ('name', base_name + '/variance'),
            ('value', np.array(variance)),
        ])).create_node()
        tiled_variance = Broadcast(graph, {'name': base_name + '/variance_tile'}).create_node()
        feed(tiled_variance, variance_const, tile_shape)

        # Priors flattened to rows of 4 values
        rows_of_four = Const(graph, {'value': int64_array([-1, 4])}).create_node()
        flat_priors = Reshape(graph, {'name': base_name + '/reshape'}).create_node()
        flat_priors.in_port(0).connect(priors_scale_node.out_port(0))
        flat_priors.in_port(1).connect(rows_of_four.out_port(0))

        joined = Concat(graph, dict([
            ('name', base_name + '/priors_concat'),
            ('axis', np.array(0)),
            ('in_ports_count', 2),
        ])).create_node()
        feed(joined, flat_priors, tiled_variance)

        target_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
        result = Reshape(graph, {'name': base_name + '/3D_priors_wth_variances'}).create_node()
        feed(result, joined, target_dims)

        return result
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Provide an explicit output spatial size on input port 1 of an Interpolate node.

        When port 1 is absent or disconnected, the size is computed from elements 2..3
        of the data input shape:

        * only ``factor`` set (no ``width``/``height``): ``spatial * factor``;
        * otherwise the effective size ``spatial + pads_begin + pads_end`` is transformed
          following the Caffe Interp output size rules for ``shrink_factor`` and/or
          ``zoom_factor``, or replaced by the explicit ``height``/``width`` attributes.

        When port 1 is connected and the node comes from Caffe, port 1 is rewired to
        elements 2..3 of the shape of the tensor that originally fed it.

        Returns None early (after logging an error) for shrink/zoom factors below 1.
        """
        node = match['op']

        if 1 not in node.in_ports() or node.in_port(1).disconnected():

            if node.has_valid('factor') and not node.has_valid('width') and not node.has_valid('height'):
                factor = Const(graph, {'value': np.array(node.factor)}).create_node()

                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                mul = Mul(graph, {'name': node.name + '/factor_mul_'}).create_node()

                source = node.in_port(0).get_connection().get_source()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))
                ss.out_port(0).connect(mul.in_port(0))
                factor.out_port(0).connect(mul.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                mul.out_port(0).connect(node.in_port(1))

            else:
                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                source = node.in_port(0).get_connection().get_source()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))

                # Effective input size: spatial dims plus padding
                pads_value = node.pads_begin + node.pads_end
                pads_const = Const(graph, {'value': np.array(pads_value)}).create_node()
                add = Add(graph, {'name': node.name + '/pad_add'}).create_node()
                ss.out_port(0).connect(add.in_port(0))
                add.in_port(1).connect(pads_const.out_port(0))

                if node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') == 1:
                    # out = (in_eff - 1) / shrink_factor + 1
                    shrink_factor = node.shrink_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None

                    const = Const(graph, {'name': node.name + '/pre_shrink_sub_const',
                                          'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/pre_shrink_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    sub.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'value': np.array(1 / shrink_factor),
                                          'name': node.name + 'shrink_factor_div_const'}).create_node()
                    div = Mul(graph, {'name': node.name + 'shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    div.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'name': node.name + '/shrink_factor_add_one_const', 'value': np.array(1)
                                          }).create_node()
                    add = Add(graph, {'name': node.name + '/shrink_factor_add_one'}).create_node()
                    div.out_port(0).connect(add.in_port(0))
                    const.out_port(0).connect(add.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    add.out_port(0).connect(node.in_port(1))

                elif node.soft_get('shrink_factor') == 1 and node.soft_get('zoom_factor') != 1:
                    zoom_factor = node.zoom_factor
                    if zoom_factor < 1:
                        log.error('Zoom factor should be positive in node {}'.format(node.id))
                        return None

                    node['debug_message'] = 'Interpolate layer replacer may be wrong, please, try to update it in the' \
                                            ' file (extensions/front/InterpolateNormalizer.py at the line {}).' \
                                            ''.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)

                    # Reshape methods can be different in some cases
                    # Commented out section represents reshape that used in deeplab-caffe
                    # Uncomment the following lines, if your model was trained with deeplab-caffe
                    # or have the same reshape method
                    # const = Const(graph, {'value': np.array(-1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_sub_const'}).create_node()
                    # sub = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sub'}).create_node()
                    # add.out_port(0).connect(sub.in_port(0))
                    # const.out_port(0).connect(sub.in_port(1))
                    #
                    # const = Const(graph, {'value': np.array(zoom_factor - 1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_mul_const'}).create_node()
                    # mul = Mul(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_mul'}).create_node()
                    # sub.out_port(0).connect(mul.in_port(0))
                    # const.out_port(0).connect(mul.in_port(1))
                    #
                    # sum = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sum'}).create_node()
                    # add.out_port(0).connect(sum.in_port(0))
                    # mul.out_port(0).connect(sum.in_port(1))
                    #
                    # node.add_input_port(1, skip_if_exist=True)
                    # assert node.in_port(1).disconnected()
                    # sum.out_port(0).connect(node.in_port(1))

                    # Comment out the following lines if you use the reshape method from previous section
                    const = Const(graph, {'value': np.array(zoom_factor),
                                          'name': node.name + '/zoom_factor_mul_const'}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()

                    add.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    mul.out_port(0).connect(node.in_port(1))

                elif node.soft_get('width') != 0 and node.soft_get('height') != 0:
                    const = Const(graph, {'value': np.array([node.height, node.width])}).create_node()
                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    const.out_port(0).connect(node.in_port(1))

                elif node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') != 1:
                    # Caffe Interp with both factors:
                    #   h   = (in_eff - 1) / shrink_factor + 1
                    #   out = h + (h - 1) * (zoom_factor - 1)
                    # Note that (h - 1) == (in_eff - 1) / shrink_factor.
                    shrink_factor = node.shrink_factor
                    zoom_factor = node.zoom_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None
                    if zoom_factor < 1:
                        log.error('Zoom factor should be positive in node {}'.format(node.id))
                        return None

                    # in_eff - 1
                    const = Const(graph, {'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/shrink_zoom_factor_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    const.out_port(0).connect(sub.in_port(1))

                    # h - 1 = (in_eff - 1) / shrink_factor
                    const = Const(graph, {'value': np.array(1 / shrink_factor)}).create_node()
                    div = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    const.out_port(0).connect(div.in_port(1))

                    # h = (h - 1) + 1
                    const = Const(graph, {'value': np.array(1),
                                          'name': node.name + 'shrink_zoom_factor_sum_const'}).create_node()
                    shrunk = Add(graph, {'name': node.name + '/shrink_zoom_factor_sum'}).create_node()
                    div.out_port(0).connect(shrunk.in_port(0))
                    const.out_port(0).connect(shrunk.in_port(1))

                    # (h - 1) * (zoom_factor - 1)
                    const = Const(graph, {'value': np.array(zoom_factor - 1)}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
                    div.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    # out = h + (h - 1) * (zoom_factor - 1)
                    total = Add(graph, {'name': node.name + '/final_shrink_zoom_factor_sum'}).create_node()
                    shrunk.out_port(0).connect(total.in_port(0))
                    mul.out_port(0).connect(total.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    total.out_port(0).connect(node.in_port(1))
        else:
            if node.soft_get('fw') == 'caffe':
                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                # Port 1 fed a tensor whose shape (not values) defines the output size
                source = node.in_port(1).get_connection().get_source()
                node.in_port(1).disconnect()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))
                ss.out_port(0).connect(node.in_port(1))
예제 #4
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace a Flatten node (with ``axis``/``end_axis`` attributes) by a Reshape.

        The target shape is the concatenation of three parts: the dimensions before
        ``axis``, one dimension for the flattened range ``axis``..``end_axis``, and
        the dimensions after ``end_axis``.

        * For ``end_axis == -1`` and a non-negative ``axis``, the target shape is the
          constant ``[0] * axis + [-1]``. The 0 entries presumably mean "copy the input
          dimension" under Reshape's special-zero semantics (TODO confirm).
        * Otherwise the parts are sliced at runtime from the input Shape/Rank. The
          middle dimension is their ReduceProd.

        The Reshape takes over the Flatten's name. The Flatten is renamed with a
        ``/ShouldBeDeleted`` suffix and disconnected.
        """
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
        assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

        axis = node.axis
        end_axis = node.end_axis

        if end_axis == -1 and axis >= 0:
            # Static case: the target shape does not depend on the input rank
            begin_dims = Const(graph, {
                'value': int64_array([0] * axis)
            }).create_node()
            middle_dim = Const(graph, {
                'value': int64_array([-1])
            }).create_node()
            end_dims = Const(graph, {'value': int64_array([])}).create_node()
        else:
            rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
            node.in_port(0).get_source().connect(rank.in_port(0))

            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            begin_dims = get_shape_values_by_range_idxs(shape=shape,
                                                        rank=rank,
                                                        begin=0,
                                                        end=axis)
            middle_dims = get_shape_values_by_range_idxs(shape=shape,
                                                         rank=rank,
                                                         begin=axis,
                                                         end=end_axis,
                                                         include_end=True)
            end_dims = get_shape_values_by_range_idxs(shape=shape,
                                                      rank=rank,
                                                      begin=end_axis,
                                                      end=-1,
                                                      include_begin=False,
                                                      include_end=True)

            # Collapse the flattened range into a single (1-element) dimension
            middle_dim = create_op_node_with_second_input(
                graph, ReduceProd, int64_array([0]), {'keep_dims': True})
            middle_dims.out_port(0).connect(middle_dim.in_port(0))

        dim = new_shape_node_from_shape_nodes(
            [begin_dims, middle_dim, end_dims])

        # Keep node with the same name to avoid confuse with renaming.
        # Reuse `name` (which falls back to the node id) so an unnamed Flatten is
        # handled consistently with the messages and helper node names above.
        original_name = name
        abandoned_name = original_name + '/ShouldBeDeleted'
        reshape_node = Reshape(graph, {}).create_node()
        rename_nodes([(node, abandoned_name), (reshape_node, original_name)])
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(
            reshape_node.in_port(0))
    def placeholder_scales(self, placeholder: Node):
        """
        Build a sub-graph computing prior-box scales from the input image size.

        The returned Concat node produces
        ``[1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]``.
        Elements 1..2 of the placeholder's runtime shape are read as
        (height, width), which presumes an NHWC layout (TODO confirm).

        :param placeholder: 4D input node with a numpy ``shape`` attribute
        :return: Concat node producing the four scale values
        """
        graph = placeholder.graph
        name = placeholder.soft_get('name', placeholder.id)

        shape_value = placeholder.soft_get('shape', None)
        assert shape_value is not None, \
            "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
        assert isinstance(shape_value, np.ndarray), \
            "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
        assert shape_value.size == 4, \
            "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

        image_shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
        image_shape.in_port(0).connect(placeholder.out_port(0))

        # image_shape[1:3] -> (height, width)
        slice_begin = Const(graph, {'value': np.array([1])}).create_node()
        slice_end = Const(graph, {'value': np.array([3])}).create_node()
        slice_step = Const(graph, {'value': np.array([1])}).create_node()
        height_width = StridedSlice(graph, dict([
            ('name', name + '/get_h_w'),
            ('begin_mask', np.array([1])),
            ('end_mask', np.array([1])),
            ('new_axis_mask', np.array([0])),
            ('shrink_axis_mask', np.array([0])),
            ('ellipsis_mask', np.array([0])),
        ])).create_node()
        for port_idx, producer in enumerate((image_shape, slice_begin, slice_end, slice_step)):
            height_width.in_port(port_idx).connect(producer.out_port(0))

        # (1 / height, 1 / width)
        minus_one = Const(graph, {'value': np.array([-1.])}).create_node()
        inverted = Pow(graph, {}).create_node()
        for port_idx, producer in enumerate((height_width, minus_one)):
            inverted.in_port(port_idx).connect(producer.out_port(0))

        # Power `type_infer` requires inputs to have equal data type, so the
        # shape values are cast to float32 before the inversion
        to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
        inverted.in_port(0).get_connection().insert_node(to_fp32)

        # Swap to (1 / width, 1 / height)
        swap_order = Const(graph, {'value': int64_array([1, 0])}).create_node()
        gather_axis = Const(graph, {'value': int64_array(0)}).create_node()
        width_height = Gather(graph, {}).create_node()

        width_height.in_port(0).connect(inverted.out_port(0))
        width_height.in_port(1).connect(swap_order.out_port(0))
        gather_axis.out_port(0).connect(width_height.in_port(2))

        # Repeat the pair twice -> 4 values
        scales = Concat(graph, {
            'axis': 0,
            'in_ports_count': 2
        }).create_node()
        for port_idx in (0, 1):
            scales.add_input_port(port_idx, skip_if_exist=True)
        for port_idx in (0, 1):
            scales.in_port(port_idx).connect(width_height.out_port(0))
        return scales
예제 #6
0
def _fuse_mul(graph: Graph,
              node: Node,
              fuse_nodes: list,
              backward: bool = True):
    """
    Fuse a Mul node with a constant input into the weights of Convolution/FC nodes.

    Parameters
    ----------
    graph : Graph
        Graph containing ``node`` and ``fuse_nodes``.
    node : Node
        Mul node; one input must carry a constant value, the other the tensor.
    fuse_nodes : list
        Convolution/FC nodes whose weights (input port 1) get scaled.
    backward : bool
        If False, the Convolution/FC nodes go after the Mul node and the constant
        is applied along the weights' ``input_channel_dim``. Otherwise the Mul
        goes after them, the constant is applied along ``output_channel_dim``,
        and a connected bias on port 2 is scaled too.

    Returns
    -------
    bool
        True if the Mul was fused into every node of ``fuse_nodes`` and removed
        from the graph. False if nothing was changed.
    """
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning(
            'Cannot do fuse_mul for node {} because this node has wrong inputs'
            .format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning(
                'Node {} can\'t be used in fusing because attr can_be_fused = False'
                .format(fuse_node.name))
            return False

        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            # Implicit literal concatenation so that .format() applies to the whole message
            log.warning(
                'Cannot do fuse_mul for node {} because there is no field '
                'output_channel_dim and/or input_channel_dim in weights.'
                .format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(
                fuse_node.name))
            return False

    # Compute every multiplier before touching the graph. If one of them cannot be
    # built, nothing has been modified yet. Otherwise some weights would already be
    # scaled while the Mul node stays in place, applying the constant twice.
    const_value = np.squeeze(np.array(const_port.data.get_value()))
    prepared = []
    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = const_value

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according output/input channel dimension
        ch_dim = weights_port.data.get_attr(
            'output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            # np.repeat/np.tile require an integer count
            cnt = int(shape[-1] // value.shape[0])
            if fuse_node.layout == 'NCHW':
                # each value repeated cnt times in a row: [a, b] -> [a, a, b, b]
                value = np.repeat(value, cnt)
            else:
                # whole vector repeated cnt times: [a, b] -> [a, b, a, b]
                value = np.tile(value, cnt)

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for _ in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        # If the value fails to reshape to the provided shape, skip fusing.
        # This can happen in case of group != 1 of the convolution.
        try:
            value = np.reshape(value, shape)
        except ValueError:
            log.error(
                "Cannot fuse const from {} to {}. Reshape failed. Skipping.".
                format(node.soft_get('name', node.id),
                       fuse_node.soft_get('name', fuse_node.id)),
                extra={'is_warning': True})
            return False

        prepared.append((fuse_node, value, mul_val))

    for fuse_node, value, mul_val in prepared:
        weights_port = fuse_node.in_port(1)

        # Weights multiplication: a copy of the Mul is inserted between the
        # weights producer and the fuse node, with the prepared constant
        mul_name = node.name + '_copy'
        mul_const = Const(graph, {
            'value': value,
            'name': mul_name + '/const'
        }).create_node()
        w_mul = node.copy_node({
            'name': mul_name,
            'in_ports_count': len(node.in_ports()),
            'out_ports_count': len(node.out_ports()),
            'can_be_fused': False
        })
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        w_const = weights_port.get_source()
        weights_port.get_connection().set_source(w_mul.out_port(0))
        w_const.connect(w_mul.in_port(tensor_port.idx))

        # If we fuse in backward direction we should multiply biases if they exists
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() *
                                     np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))

    is_fused = bool(prepared)
    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        node.out_port(0).get_connection().set_source(producer_port)

    return is_fused
예제 #7
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace an Unsqueeze -> Tile -> Reshape chain by a nearest Interpolate.

        Let ``d_idx`` be the Unsqueeze axis. The Interpolate works on axis
        ``d_idx - 1`` with mode 'nearest'. The new size is ``shape[d_idx - 1]``
        multiplied by the Tile repeat at ``d_idx``. Consumers of the Reshape are
        re-sourced to the Interpolate output.

        The transformation is skipped when the Unsqueeze axis or Tile repeats are
        not constant, or when the Unsqueeze input is not 4D/5D. It is also skipped
        when the (normalized) Unsqueeze axis is 0 or outside the Tile repeats.
        """
        unsqueeze_node = match['unsqueeze']
        unsqueeze_name = unsqueeze_node.name

        second_input_of_unsqueeze = unsqueeze_node.in_port(
            1).get_connection().get_source().node
        if not second_input_of_unsqueeze.has_valid('value'):
            return

        d_idx = int(second_input_of_unsqueeze.value)

        second_input_of_tile = match['tile'].in_port(
            1).get_connection().get_source().node
        if not second_input_of_tile.has_valid('value'):
            return

        input_shape_of_unsqueeze = unsqueeze_node.in_port(0).data.get_shape()
        if input_shape_of_unsqueeze is None or len(input_shape_of_unsqueeze) not in {4, 5}:
            return

        # A negative Unsqueeze axis counts from the end of the output (rank + 1)
        if d_idx < 0:
            d_idx += len(input_shape_of_unsqueeze) + 1
        # The inserted dim must follow the dim being upsampled (d_idx >= 1), and the
        # Tile repeats must have an entry for it
        if not 1 <= d_idx < len(second_input_of_tile.value):
            return

        scale = int64_array([second_input_of_tile.value[d_idx]])
        axis = d_idx - 1

        shape_node = Shape(graph, dict(name=unsqueeze_name +
                                       '/Shape_')).create_node()
        scales_node = Const(
            graph, dict(name=unsqueeze_name + '/scales_',
                        value=scale)).create_node()
        mul_node = Mul(graph,
                       dict(name=unsqueeze_name + '/Mul_')).create_node()
        scales_node.out_port(0).connect(mul_node.in_port(1))

        # shape[axis:axis + 1] * scale -> new size along `axis`
        slice_begin = Const(
            graph,
            dict(name=unsqueeze_name + '/slice_begin_',
                 value=int64_array([axis]))).create_node()
        slice_end = Const(
            graph,
            dict(name=unsqueeze_name + '/slice_end_',
                 value=int64_array([axis + 1]))).create_node()

        strided_slice_node = StridedSlice(
            graph, {
                'name': unsqueeze_name + '/StridedSlice_',
                'begin_mask': int64_array([1]),
                'end_mask': int64_array([1]),
                'new_axis_mask': int64_array([0]),
                'shrink_axis_mask': int64_array([0]),
                'ellipsis_mask': int64_array([0]),
            }).create_node()
        shape_node.out_port(0).connect(strided_slice_node.in_port(0))
        slice_begin.out_port(0).connect(strided_slice_node.in_port(1))
        slice_end.out_port(0).connect(strided_slice_node.in_port(2))
        strided_slice_node.out_port(0).connect(mul_node.in_port(0))

        interp_node = Interpolate(
            graph,
            dict(name=unsqueeze_name + '/Interpolate_',
                 axes=int64_array([axis]),
                 mode='nearest')).create_node()
        mul_node.out_port(0).connect(interp_node.in_port(1))

        match['reshape'].out_port(0).get_connection().set_source(
            interp_node.out_port(0))

        # Feed both the Interpolate and the Shape from the exact output port that
        # fed the Unsqueeze (not necessarily port 0 of the producer node)
        unsqueeze_connection = unsqueeze_node.in_port(0).get_connection()
        unsqueeze_source = unsqueeze_connection.get_source()
        unsqueeze_connection.set_destination(interp_node.in_port(0))
        unsqueeze_source.connect(shape_node.in_port(0))
예제 #8
0
def order(op_node: Node, port_info: str, input_port: int):
    """
        Performs layout change related transformation of the data on the `input_port` port of op_node.
        Translates ordered shape indexes from one layout to another according to permutation

        Transformation inserts two Gather operations

        1 Gather reorders data to new layout according to direct permutation:
            actual data to translate as 0-port input data of Gather and
            permutation as 1-port input indices
        2 Gather translates shape indexes from one layout to another according to inverse permutation
            permutation as 0-port input data and
            actual data to translate as 1-port input indexes of Gather

    For example:
        NHWC Transpose operation has 0-port input with data of shape [1, 2, 3, 4] and
        1-port input with new order indices [0, 1, 3, 2].

        After translating such operation to NCHW layout:
            0-port input shape = [1, 4, 2, 3]

        1 phase (after first Gather insertion):
            1-port input order indices = [0, 2, 1, 3]
        2 phase (after second Gather insertion):
            1-port input order indices = [0, 3, 2, 1]

    :param op_node: operation whose input on `input_port` is rewritten in place
    :param port_info: port descriptor passed to get_node_with_permutation to find the data node that
        carries the `permutation` attribute
    :param input_port: index of the op_node input port whose data is reordered
    :raises AssertionError: if the located data node has no permutation set
    """
    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    assert permutation_data_node.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
                                                             'port_info "{}".'.format(permutation_data_node.id,
                                                                                      op_node.id, port_info)
    permutation = permutation_data_node.permutation
    if len(permutation.perm) == 0:
        # Empty permutation: nothing to reorder
        return

    data_node = op_node.in_node(input_port)
    op_name = op_node.soft_get('name', op_node.id)

    # Phase 1: Gather(data, perm, axis=0) -- reorder the data values by the direct permutation
    gather_name = op_name + '/OrderGather_1'
    const = Const(
        graph, {
            'value': permutation.perm,
            'name': gather_name + '/const',
            'need_shape_inference': True
        }).create_node_with_data()
    axis_const = Const(graph, {
        'value': int64_array(0),
        'name': gather_name + '/axis'
    }).create_node_with_data()
    gather = Gather(graph, {
        'name': gather_name,
        'need_shape_inference': True
    }).create_node_with_data([data_node, const, axis_const])

    # Phase 2: Gather(inv_perm, reordered, axis=0) -- translate the index values through the inverse permutation
    gather_1_name = op_name + '/OrderGather_2'
    const_1 = Const(
        graph, {
            'value': permutation.inv,
            'name': gather_1_name + '/const',
            'need_shape_inference': True
        }).create_node_with_data()
    axis_const_1 = Const(graph, {
        'value': int64_array(0),
        'name': gather_1_name + '/axis'
    }).create_node_with_data()
    gather_1 = Gather(graph, {
        'name': gather_1_name,
        'need_shape_inference': True
    }).create_node_with_data([const_1, gather, axis_const_1])

    # Re-route only the edge that feeds `input_port`. The same data node may feed several ports of op_node
    # (e.g. Add(x, x)); always using key 0 and removing an arbitrary edge would rewire the wrong port.
    edges = graph.get_edge_data(data_node.id, op_node.id)
    edge_key = next((key for key, edge_attrs in edges.items() if edge_attrs.get('in') == input_port), 0)
    attrs = edges[edge_key].copy()
    graph.add_edge(gather_1.id, op_node.id, **attrs)
    graph.remove_edge(data_node.id, op_node.id, key=edge_key)
    op_node['need_shape_inference'] = True
예제 #9
0
    def replace_pattern(graph, match: dict):
        """Fold a matched TensorIterator condition pattern into a single TensorIterator operation.

        Collects the TensorIteratorInput / TensorIteratorOutput / TensorIteratorBackEdge nodes that consume the
        outputs of match['condition'], moves the nodes between them into a separate `body` Graph, assigns
        internal layer/port ids inside the body, and creates a TensorIterator node in `graph` connected to the
        external input/output data nodes.

        :param graph: graph being transformed in place
        :param match: pattern match dictionary; only match['condition'] is read
        """
        # Here we find all parts of the TI: condition, inputs/outputs, back edges and body; then create the
        # TensorIterator op and make all checks needed for TensorIterator to work.
        # Output 0 of the condition op is the condition data; output 1, when present, is the time data.
        cond_data = match['condition'].out_node(0)
        time_data = match['condition'].out_node(1) if len(
            match['condition'].out_nodes()) > 1 else None
        name = match['condition'].name

        back_edges = []
        inputs = []
        outputs = []

        # Classify consumers of the condition data by op type (node ids are collected here, not Node objects);
        # consumers of any other type are ignored
        for node in cond_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                back_edges.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)

        # Every consumer of the time data must be a TensorIteratorInput or TensorIteratorOutput
        if time_data is not None:
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node[
                        'op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # something goes wrong here
                    assert False
        # Remove the condition op, its condition data, its 0-th input (tensor_sequence_length) and the time data
        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)
        graph.remove_nodes_from(
            [condition.id, cond_data.id, tensor_sequence_length.id])
        if time_data is not None:
            graph.remove_nodes_from([time_data.id])

        # presumably: body node ids between the TI inputs and outputs, plus extra input ids found on the way
        body_nodes, extra_inputs = get_body(graph, inputs, outputs)
        # NOTE(review): cond_data is a Node while body_nodes looks like a list of ids, so this subtraction
        # is likely a no-op (cond_data was already removed from the graph above) -- confirm
        body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        # External inputs: the outer data is input 1 of a partitioned TensorIteratorInput (valid 'axis'),
        # otherwise input 0; the inner data is its output 0
        external_inputs = [{
            'external_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id':
            node.out_node(0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in inputs]

        # External outputs: the outer data is output 0 of the TensorIteratorOutput; the inner data is its
        # input 1 when 'axis' is valid, otherwise input 0
        external_outputs = [{
            'external_data_id':
            node.out_node(0),
            'internal_data_id':
            node.in_node(1 if node.has_valid('axis') else 0),
            'axis':
            node.axis,
            'start':
            node.start,
            'end':
            node.end,
            'stride':
            node.stride,
            'part_size':
            node.part_size
        } for node in outputs]

        # Back edge: value produced in the body (input 1) -> data it replaces on the next iteration (output 0),
        # initialized from input 0
        back_edges_data = [{
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges]

        # Build the body graph from the body nodes and all edges whose both ends are in the body.
        # body.graph is the same attribute dict object as graph.graph (shared, not copied).
        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from([
            (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
            if u in body_nodes and v in body_nodes
        ])

        graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                                [inp.id for inp in inputs] +
                                [out.id for out in outputs])
        # Single counter used to hand out both internal layer ids and internal port ids
        internal_id_count = 0
        real_back_edges = []
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid(
                    'internal_layer_id'):
                edge['from_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node(
            )['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge(
            )['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(
                    edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(
                    edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                # NOTE(review): this branch always fails on `assert False`; the assignment after it only runs
                # when asserts are disabled (python -O), in which case the `elif` below is also skipped
                if real_edge['to_data_id'].in_node().has_valid(
                        'internal_layer_id'):
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert not 'internal_port_id' in real_edge[
                    'init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                # Deep copy taken after internal_port_id was set, so the new edge below carries it too
                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (real_edge['init_data_id'].id, real_edge['consumer'].id,
                 real_edge['consumer_key'], real_edge['attrs'])
                for real_edge in current_real_back_edges
            ])

            # Drop the back-edge end data node and its producer (the node feeding to_data_id)
            body.remove_nodes_from(
                [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        # One entry per (external input, body consumer) pair
        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body,
                                               ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning:
                # a new input data node with a 1 inserted at `axis`, squeezed back into the original data node
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                new_input_data = Op._create_data_node(
                    body,
                    ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                    dict(shape=np.insert(shape, ext_inp['axis'], 1)))
                # NOTE(review): `dim` is computed but never used below (leftover from a Reshape-based version)
                dim = shape.copy()
                # try to do it dynamically reshapable along one of the axis
                # it is practically useful to reshape along batch dimension, but here we cannot detect where it is
                # so, we are guessing based on other transformations that it is the major dimension
                dim[0] = -1
                # Despite the variable name this is a Squeeze along `axis`
                reshape_op = Squeeze(
                    body,
                    dict(name=ext_inp['internal_data_id'].name +
                         '/InputSqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_inp['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_inp['axis']
                    }).create_node_with_data()
                reshape_op.create_node_with_data(
                    [new_input_data, reshape_dim_data],
                    data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            # Each consumer of the body input becomes its own port map entry
            for _, consumer, edge_attrs in body.out_edges(
                    ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if not 'internal_port_id' in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer[
                    'internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs[
                    'internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body,
                                               ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                # NOTE(review): `dim` is computed but never used below
                dim = ext_out['internal_data_id'].shape.copy()
                # trying to make it dynamically reshapable (see related comment above for the first Reshape)
                dim[0] = -1
                assert not ext_out['internal_data_id'].has_valid('value')
                # Despite the variable name this is an Unsqueeze along `axis`
                reshape_op = Unsqueeze(
                    body,
                    dict(name=ext_out['internal_data_id'].name +
                         '/OutputUnsqueeze'))
                reshape_dim_data = Const(
                    body, {
                        'name':
                        ext_out['internal_data_id'].name + '/ReshapeDim',
                        'value': ext_out['axis']
                    }).create_node_with_data()
                ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                    [ext_out['internal_data_id'], reshape_dim_data])

            # TODO: add here working with simple outputs

            # Make sure the body output data is consumed by a Result op
            if not any([
                    out_node.soft_get('op', None) == 'Result'
                    for out_node in ext_out['internal_data_id'].out_nodes()
            ]):
                add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if not 'internal_layer_id' in ext_out['internal_data_id'].in_node(
            ):
                ext_out['internal_data_id'].in_node(
                )['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge(
                )['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node(
            )['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge(
            )['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # Input port map has one entry per (input, consumer) pair; output port map one entry per output
        ti_op = TensorIterator(
            graph, {
                'name':
                name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                len(external_inputs),
                'out_ports_count':
                len(external_outputs),
                'input_port_map': [{
                    field: external_input[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_input in real_external_inputs],
                'output_port_map': [{
                    field: external_output[field]
                    for field in [
                        'external_port_id', 'internal_layer_id',
                        'internal_port_id', 'axis', 'stride', 'part_size',
                        'start', 'end'
                    ]
                } for external_output in external_outputs],
                'back_edges': [{
                    field: edge[field]
                    for field in
                    ['from_layer', 'from_port', 'to_layer', 'to_port']
                } for edge in real_back_edges],
            })

        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{
                'external_port_id': inp['external_port_id']
            } for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs])

        # create_node_with_data may return a single data node rather than a list
        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge(
            )['external_port_id'] = external_outputs[i]['external_port_id']

        ti = ti_outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
예제 #10
0
def replace_resize(graph: Graph, resize: Node):
    """Replace an ONNX Resize-11 node with an Interpolate-4 (opset4) node.

    Only 4D and 5D inputs are converted; for any other rank a warning is logged and the graph is left as is.
    The spatial part of the shape (from the height dimension for 4D / depth dimension for 5D through the width
    dimension of the graph layout) is cut out with StridedSlice nodes:

    * if the sizes input (port 3) is connected, Interpolate uses shape_calculation_mode='sizes' and its scales
      are computed as sizes / input_shape + 1e-5;
    * otherwise the scales input (port 2) is used and the sizes are computed as
      floor(input_shape * scales + 1e-5) cast to int64.

    Shape arithmetic is done in float32 even when the model data type is FP16. The new Interpolate takes over
    the Resize name; the Resize itself is renamed with a '/delete' suffix.

    :param graph: graph containing `resize`; graph.graph['layout'] selects the spatial dimensions
    :param resize: ONNX Resize node to convert
    :raises AssertionError: if neither scales nor sizes are connected, or the mode is tf_crop_and_resize
    """
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(resize_name))

    input_rank = len(resize.in_port(0).data.get_shape())
    if input_rank not in {4, 5}:
        log.warning('The input shape is not 4D or 5D for op with name {}'.format(resize_name))
        return

    has_sizes = resize.is_in_port_connected(3)
    assert (resize.is_in_port_connected(0) and (resize.is_in_port_connected(2) or has_sizes)), \
        "Scales or sizes inputs must be connected to Node {} with op {}.".format(resize_name, resize.op)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    layout = graph.graph['layout']
    first_spatial_dim = get_height_dim if input_rank == 4 else get_depth_dim
    begin_dim = first_spatial_dim(layout, input_rank)
    end_dim = get_width_dim(layout, input_rank) + 1

    def spatial_slice(name_suffix):
        # StridedSlice over [begin_dim, end_dim) of a 1D shape-like tensor
        return create_op_with_const_inputs(graph, StridedSlice,
                                           {1: int64_array([begin_dim]),
                                            2: int64_array([end_dim]),
                                            3: int64_array([1])},
                                           {'name': resize_name + name_suffix,
                                            'begin_mask': int64_array([1]),
                                            'end_mask': int64_array([1]),
                                            'new_axis_mask': int64_array([0]),
                                            'shrink_axis_mask': int64_array([0]),
                                            'ellipsis_mask': int64_array([0])})

    sizes_ss = spatial_slice('/StridedSlice_sizes')
    scales_ss = spatial_slice('/StridedSlice_scales')
    axes_node = Const(graph,
                      {'name': resize_name + '/axis',
                       'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()

    interpolate_node = Interpolate(graph, {'version': 'opset4',
                                           'mode': convert_mode(resize.mode),
                                           'coordinate_transformation_mode': resize.coordinate_transformation_mode,
                                           'cube_coeff': resize.cube_coeff,
                                           'nearest_mode': resize.nearest_mode,
                                           'pads_begin': int64_array([0]),
                                           'pads_end': int64_array([0]),
                                           'antialias': 0,
                                           'shape_calculation_mode': 'sizes' if has_sizes else 'scales',
                                           'in_ports_count': 4}).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # Small epsilon guards the float shape arithmetic against rounding just below an integer
    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values

    if has_sizes:
        # scales = sizes / input_shape + eps
        shape_as_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
        sizes_as_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        sizes_as_float.out_port(0).connect(div_node.in_port(0))
        shape_as_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(shape_as_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate_node.in_port(0))

        sizes_connection = resize.in_port(3).get_connection()
        sizes_connection.set_destination(sizes_ss.in_port(0))

        data_connection.get_source().connect(shape_of.in_port(0))
        sizes_connection.get_source().connect(sizes_as_float.in_port(0))
    else:
        # sizes = int64(floor(input_shape * scales + eps))
        shape_as_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(shape_as_float.in_port(0))
        shape_as_float.out_port(0).connect(mul_node.in_port(0))
        sizes_as_int = Cast(graph, {'dst_type': np.int64}).create_node()
        floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(sizes_as_int.in_port(0))
        sizes_as_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))

        data_connection = resize.in_port(0).get_connection()
        data_connection.set_destination(interpolate_node.in_port(0))

        scales_connection = resize.in_port(2).get_connection()
        scales_connection.set_destination(scales_ss.in_port(0))

        data_connection.get_source().connect(shape_of.in_port(0))
        scales_connection.get_source().connect(mul_node.in_port(1))

    rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
    def quantize_data(fake_quantize: Node, dst_type: type,
                      quantized_type: type, mode: str):
        """Split `fake_quantize` into a quantizing copy, a Cast to `quantized_type` and a dequantizing original.

        After the call:

        * a copy of `fake_quantize` (named '<name>/Copy') takes over the data input and the input limits
          (ports 0-2) and gets [i_min, i_max] as its output limits (ports 3/4);
        * the copy's output is cast to `quantized_type` and fed into port 0 of `fake_quantize`;
        * `fake_quantize` gets [i_min, i_max] as its new input limits (ports 1/2); its ports 3/4 are not touched.

        i_min is -(levels // 2) when mode == "signed" and 0 otherwise; i_max = i_min + levels - 1. Both are
        created as one-element arrays of `dst_type`.

        :param fake_quantize: FakeQuantize node to split; its `levels` attribute is read
        :param dst_type: numpy dtype of the integer-range constants
        :param quantized_type: destination type of the inserted Cast
        :param mode: "signed" or "unsigned"
        """
        graph = fake_quantize.graph
        fq_name = fake_quantize.soft_get('name', fake_quantize.id)
        levels = fake_quantize.levels

        quantize = fake_quantize.copy_node(
            dict(name=fq_name + '/Copy', stop_value_propagation=False), graph)

        # The copy takes over the data input (port 0) and inherits the input limits (ports 1 and 2)
        for port_idx in (0, 1, 2):
            fake_quantize.in_port(port_idx).get_connection().set_destination(
                quantize.in_port(port_idx))

        # Integer range the weights are quantized into
        assert mode in ["signed", "unsigned"]
        if mode == "signed":
            low_value = -(levels // 2)
        else:
            low_value = 0

        i_min = np.array([low_value], dtype=dst_type)
        i_max = np.array(levels + i_min - 1, dtype=dst_type)

        assert i_max - i_min == levels - 1
        low_const = Const(graph, dict(name=fq_name + '/Copy/out_low',
                                      value=i_min)).create_node()
        high_const = Const(graph, dict(name=fq_name + '/Copy/out_high',
                                       value=i_max)).create_node()

        # The same limits are the copy's output range and the original's new input range
        wiring = ((low_const, quantize, 3), (high_const, quantize, 4),
                  (low_const, fake_quantize, 1), (high_const, fake_quantize, 2))
        for limit_node, target, target_port in wiring:
            limit_node.out_port(0).connect(target.in_port(target_port))

        # presumably the producer of the weights (a Const) -- its name is reused for the Cast
        weights_source = quantize.in_port(0).get_source().node
        to_quantized = Cast(
            graph,
            dict(name=weights_source.soft_get('name', weights_source.id) + '/quantized',
                 dst_type=quantized_type,
                 stop_value_propagation=False)).create_node()

        quantize.out_port(0).connect(to_quantized.in_port(0))
        to_quantized.out_port(0).connect(fake_quantize.in_port(0))
예제 #12
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace a GatherNd with constant indices by an equivalent Reshape + Gather pair.

        Nothing is changed when the indices value is unknown, or when self.indices_check returns None
        (a warning is logged in that case). Otherwise:

        * the data is reshaped to [-1] + input_shape[indices.shape[-1]:];
        * the indices are reduced to the flattened column indices[..., gather_idx];
        * a Gather along axis 0 takes over the GatherNd's output, and the GatherNd node is removed.

        :param graph: graph being transformed in place
        :param match: pattern match; match['GatherNd'] is the node to replace
        """
        gather = match['GatherNd']
        # soft_get keeps the pass working for nodes without an explicit 'name' attribute
        # (same convention as the other transformations)
        gather_name = gather.soft_get('name', gather.id)
        input_shape = gather.in_node(0).shape
        indices = gather.in_node(1).value
        if indices is None:
            # We can't do such special pass without indices value
            return

        # 0. All needed checks that we can replace GatherNd by Gather
        gather_idx = self.indices_check(indices, input_shape)
        if gather_idx is None:
            log.warning(
                'Node {} with op=GatherNd  can\'t be normalized to op=Gather.'.
                format(gather_name))
            return

        # 1. Add Reshape and connect: the leading indices.shape[-1] dims are flattened into one
        new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
        reshape = Reshape(graph, {
            'name': gather_name + '/Reshape_for_GatherNd/'
        }).create_node()
        reshape_const_node = Const(graph, {
            'name': reshape.name + '/Dim',
            'value': new_shape
        }).create_node()
        gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_const_node.out_port(0))

        # 2. Change indices from Nd to 1d:
        new_indices = np.reshape(
            np.take(indices, indices=[gather_idx], axis=-1), [-1])
        new_indices_const = Const(graph, dict(name=gather_name + '/NewGather/indices',
                                              value=new_indices)).create_node()
        axis_const = Const(graph, {'name': gather_name + '/NewGather/axis',
                                   'value': int64_array(0)}).create_node()

        # 3. Create new Gather operation and reconnect all inputs/outputs
        new_gather = Gather(graph, {
            'name': gather_name + '/NewGather/'
        }).create_node()
        reshape.out_port(0).connect(new_gather.in_port(0))
        new_indices_const.out_port(0).connect(new_gather.in_port(1))
        axis_const.out_port(0).connect(new_gather.in_port(2))

        gather.out_port(0).get_connection().set_source(new_gather.out_port(0))

        # 4. Remove old Gather node
        graph.remove_node(gather.id)
예제 #13
0
    def find_and_replace_pattern(self, graph: Graph):
        """Align SpaceToBatch / BatchToSpace inputs with the IE representation.

        For every SpaceToBatch and BatchToSpace op:

        * the TF-style pads/crops input of shape [N, 2] on port 2 is transposed to [2, N], split in two and
          squeezed; the first half feeds port 2 (begin) and the second half port 3 (end);
        * block_shape (port 1) and begin/end (ports 2/3) are each padded with a Pad op: one element is
          prepended and rank(data) - len(block_shape) - 1 elements are appended. block_shape is padded with 1;
          the begin/end Pads leave their pad-value input unconnected;
        * a Cast to int64 is inserted in front of each of these Pad ops.
        """
        nodes_to_align = graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(
            op='BatchToSpace')
        for op_node in nodes_to_align:
            op_node.add_input_port(3, skip_if_exist=True)

            # TF keeps pads/crops as one [N, 2] tensor; IE expects separate [N] begin and [N] end inputs
            pads_transpose = create_op_with_const_inputs(
                graph, Transpose, {1: int64_array([1, 0])})
            op_node.in_port(2).get_connection().set_destination(
                pads_transpose.in_port(0))
            pads_split = create_op_with_const_inputs(graph, Split,
                                                     {1: int64_array(0)},
                                                     {'num_splits': 2})
            pads_transpose.out_port(0).connect(pads_split.in_port(0))
            for split_idx, target_port in enumerate((2, 3)):
                op_node.in_port(target_port).connect(
                    pads_split.out_port(split_idx))
                op_node.in_port(target_port).get_connection().insert_node(
                    create_op_with_const_inputs(graph, Squeeze,
                                                {1: int64_array([0])}))

            # Pad block_shape/begin/end so that their length matches the rank of the data input
            node_name = op_node.name
            data_rank = Rank(graph, {'name': node_name + '/rank_0'}).create_node()
            block_shape_len = Shape(graph, {'name': node_name + '/rank_1'}).create_node()

            rank_minus_len = Sub(graph, {'name': node_name + '/sub_0'}).create_node()
            tail_len = Sub(graph, {'name': node_name + '/sub_1'}).create_node()
            one_1d = Const(graph, {'value': int64_array([1])}).create_node()
            one_scalar = Const(graph, {'value': int64_array(1)}).create_node()

            aligned_block_shape = Pad(graph, {
                'name': node_name + '/aligned_block_shape',
                'mode': 'constant'
            }).create_node()

            # SpaceToBatch: begin/end are pads_begin/pads_end; BatchToSpace: crops_begin/crops_end
            if op_node.type == 'BatchToSpace':
                begin_suffix, end_suffix = '/aligned_crops_begin', '/aligned_crops_end'
            else:
                begin_suffix, end_suffix = '/aligned_pads_begin', '/aligned_pads_end'

            aligned_begin = Pad(graph, {
                'name': node_name + begin_suffix,
                'mode': 'constant'
            }).create_node()
            aligned_end = Pad(graph, {
                'name': node_name + end_suffix,
                'mode': 'constant'
            }).create_node()

            data_rank_1d = create_op_node_with_second_input(
                graph, Unsqueeze, int64_array([0]),
                {'name': node_name + '/1d_rank_of_0'}, data_rank)

            # tail_len = [rank(data)] - shape(block_shape) - [1]
            op_node.in_port(0).get_source().connect(data_rank.in_port(0))
            op_node.in_port(1).get_source().connect(block_shape_len.in_port(0))
            data_rank_1d.out_port(0).connect(rank_minus_len.in_port(0))
            block_shape_len.out_port(0).connect(rank_minus_len.in_port(1))
            rank_minus_len.out_port(0).connect(tail_len.in_port(0))
            one_1d.out_port(0).connect(tail_len.in_port(1))
            one_scalar.out_port(0).connect(aligned_block_shape.in_port(3))

            pads_in_port_order = (aligned_block_shape, aligned_begin, aligned_end)
            for port_idx, pad_node in enumerate(pads_in_port_order, start=1):
                pad_node_name = pad_node.name
                op_node.in_port(port_idx).get_connection().set_destination(
                    pad_node.in_port(0))
                one_1d.out_port(0).connect(pad_node.in_port(1))
                tail_len.out_port(0).connect(pad_node.in_port(2))
                pad_node.out_port(0).connect(op_node.in_port(port_idx))
                to_int64 = Cast(graph, {
                    'name': pad_node_name + '/i64',
                    'dst_type': np.int64
                }).create_node()
                pad_node.in_port(0).get_connection().insert_node(to_int64)
예제 #14
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Guard input 0 of the matched node with a Select so that real data only reaches it once enough
        iterations have passed to fill the context of the upstream Splice layers.

        Nothing is changed when:
          * the matched node is named 'iteration_number_out' (the name this method gives to the Assign
            it creates);
          * invert_sub_graph_between_nodes raises Error;
          * no Splice node in the searched sub-graph widens the context (context_len stays 1).

        The Select condition comes from an iteration counter kept in a ReadValue/Assign pair.
        Each iteration the counter drops its first element along axis 1 and has a 1 appended (Concat
        with a ones constant). Its element 0 along axis 1 is then compared for equality with that ones
        constant. While the comparison is false, the Select forwards the 'zero_else' constant instead
        of the original input. An existing counter sub-graph matching context_len is reused when
        found; otherwise a new one is built.

        :param graph: graph being transformed; modified in place
        :param match: pattern match; match['op'] is the node whose input 0 is guarded
        """
        node = match['op']

        # skip the node named like the Assign this method creates below ('iteration_number_out')
        if node.name == 'iteration_number_out':
            return

        # calculate length of context when state of inference becomes meaningful
        inputs = []
        for n in graph.get_op_nodes(**{'op': 'Parameter'}):
            inputs.append(n)

        # names of all nodes that directly consume an output of a Parameter
        in_nodes = []
        for inp in inputs:
            for ins in inp.out_port(0).get_destinations():
                in_nodes.append(ins.node.name)

        # context_len = 1 + sum of (len(context) - 1) over every Splice node found in the sub-graph
        # between the producer of the matched node's input and the Parameter consumers; if
        # invert_sub_graph_between_nodes raises Error the node is left unchanged
        context_len = 1
        try:
            subgraph = invert_sub_graph_between_nodes(
                graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return

        for n in subgraph:
            n_node = Node(graph, n)
            if n_node.kind == 'op' and n_node.op == 'Splice':
                context_len += len(n_node.context) - 1

        # no Splice widens the context, so there is no warm-up period to guard against
        if context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage
        # Select inputs: 0 - condition (connected at the end), 1 - original data, 2 - zeros
        select_node = Select(graph, {
            'name': 'select_' + node.name
        }).create_node()
        zero_else = Const(graph, {
            'name': 'zero_else',
            'value': np.zeros(in_node_shape)
        }).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        # check if we have already appropriate iteration counter
        # (ReadValue -> Crop(offset 1) -> Concat(with Const) -> Assign, plus Crop(offset 0, dim 1) -> Select)
        existing_counters = find_pattern_matches(
            graph,
            nodes=[('mem_in', dict(op='ReadValue')),
                   ('mem_in_data', dict(shape=int64_array([context_len]))),
                   ('crop_mem_in',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([1]),
                         dim=int64_array([context_len - 1]))),
                   ('crop_mem_in_data', dict()),
                   ('concat', dict(op='Concat', axis=1)),
                   ('concat_data', dict()), ('const_1', dict(op='Const')),
                   ('const_1_data', dict()), ('mem_out', dict(op='Assign')),
                   ('crop_out',
                    dict(op='Crop',
                         axis=int64_array([1]),
                         offset=int64_array([0]),
                         dim=int64_array([1]))), ('crop_out_data', dict()),
                   ('select', dict(op='Select'))],
            edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                   ('crop_mem_in', 'crop_mem_in_data'),
                   ('crop_mem_in_data', 'concat', {
                       'in': 0
                   }), ('const_1', 'const_1_data'),
                   ('const_1_data', 'concat', {
                       'in': 1
                   }), ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                   ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                   ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            # reuse the existing counter: its appended constant and the output of its offset-0 Crop
            ones = Node(graph, inverse_dict(counter_match)['const_1'])
            input_port = Node(
                graph,
                inverse_dict(counter_match)['crop_out']).out_port(0)
        else:
            # NOTE(review): presumably a zero tensor with context_len elements along axis 1 and the
            # batch size taken from in_node_port -- confirm with create_zero_value_with_batch_from_input
            init_value_mem_out = create_zero_value_with_batch_from_input(
                in_node_port, context_len, np.int32)
            mem_out = ReadValue(
                graph, {
                    'name': 'iteration_number',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
            # drop the first element along axis 1 ...
            cut_first = Crop(
                graph, {
                    'name': 'cut_first',
                    'axis': int64_array([1]),
                    'offset': int64_array([1]),
                    'dim': int64_array([context_len - 1])
                }).create_node()
            cut_first.in_port(0).connect(mem_out.out_port(0))
            # ... and append a 1 at the end, so the window shifts by one step per iteration
            ones = Const(graph, {
                'name': 'ones',
                'value': np.ones([1, 1], dtype=np.int32)
            }).create_node()
            concat = Concat(graph, {
                'name': 'concat_ones',
                'in_ports_count': 2,
                'axis': 1
            }).create_node()
            concat.in_port(0).connect(cut_first.out_port(0))
            concat.in_port(1).connect(ones.out_port(0))
            # store the shifted window back into the same variable ('iteration_' + node.name)
            mem_in = Assign(
                graph, {
                    'name': 'iteration_number_out',
                    'variable_id': 'iteration_' + node.name
                }).create_node()
            mem_in.in_port(0).connect(concat.out_port(0))
            res = Result(graph, {}).create_node()
            mem_in.out_port(0).connect(res.in_port(0))
            # element 0 along axis 1 of the updated window becomes 1 only after enough iterations
            cut_last = Crop(
                graph, {
                    'name': 'cut_last',
                    'axis': int64_array([1]),
                    'offset': int64_array([0]),
                    'dim': int64_array([1])
                }).create_node()
            cut_last.in_port(0).connect(concat.out_port(0))
            input_port = cut_last.out_port(0)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here; the Select passes zero_else instead
        cast_in = Equal(graph, {
            'name': input_port.node.name + '/cast_to_bool'
        }).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)
        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
예제 #15
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace a Crop node (match['crop']) with an equivalent StridedSlice.

        Three Crop flavours are handled:
          * Type 1 - two inputs and an 'offset' attribute: begin = offset on the cropped axes and
            end = shape(second input) + begin; the second input is disconnected from the Crop;
          * Type 2 - 'dim' and 'offset' attributes: constant begin = offset and end = offset + dim;
          * Type 3 - 'crop_begin' and 'crop_end' attributes: begin = crop_begin and
            end = shape(first input) - crop_end.
        begin_mask/end_mask are 1 only for the axes listed in node.axis; the other masks are all zero.

        :param graph: graph being transformed; modified in place
        :param match: pattern match; match['crop'] is the Crop node to replace
        :raises Exception: if the Crop matches none of the supported attribute combinations
        """
        node = match['crop']
        assert node.has_valid('axis')
        node_axis = self.list_to_ndarray(node.axis)

        in_shape = node.in_port(0).data.get_shape()
        shape_rank = in_shape.size
        axis_mask = int64_array([1 if i in node_axis else 0 for i in range(shape_rank)])
        begin_mask = axis_mask.copy()
        end_mask = axis_mask.copy()

        # the unused masks share the int64 dtype of begin_mask/end_mask (np.zeros defaults to float64)
        ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
                                  'end_mask': end_mask,
                                  'new_axis_mask': np.zeros(len(end_mask), dtype=np.int64),
                                  'shrink_axis_mask': np.zeros(len(end_mask), dtype=np.int64),
                                  'ellipsis_mask': np.zeros(len(end_mask), dtype=np.int64)}).create_node()

        if len(node.in_nodes()) == 2 and node.has_valid('offset'):
            # Crop Type 1
            # normalize the offset like Types 2 and 3 do, so a scalar offset is handled as well
            node_offset = self.list_to_ndarray(node.offset)
            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_offset),
                                  'name': ss.name + '/begin'}).create_node()
            shape = Shape(graph, {'name': ss.name + '/shape_of_crop'}).create_node()
            end = Add(graph, {'name': ss.name + '/end'}).create_node()
            # the second input only provides the crop size, so it is moved to the Shape node
            node.in_port(1).get_connection().get_source().connect(shape.in_port(0))
            node.in_port(1).disconnect()
            shape.out_port(0).connect(end.in_port(0))
            begin.out_port(0).connect(end.in_port(1))
        elif node.has_valid('dim') and node.has_valid('offset'):
            # Crop Type 2
            node_dim = self.list_to_ndarray(node.dim)
            node_offset = self.list_to_ndarray(node.offset)
            assert node_dim.size == node_offset.size == node_axis.size

            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_offset),
                                  'name': ss.name + '/begin'}).create_node()
            end_values = np.array([node_offset[i] + node_dim[i] for i in range(len(node_dim))])
            end = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, end_values),
                                'name': ss.name + '/end'}).create_node()
        elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
            # Crop Type 3
            node_crop_begin = self.list_to_ndarray(node.crop_begin)
            node_crop_end = self.list_to_ndarray(node.crop_end)
            assert len(node_crop_begin) == len(node_crop_end) == len(node_axis)

            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_crop_begin),
                                  'name': ss.name + '/begin'}).create_node()
            shape = Shape(graph, {'name': ss.name + '/shape'}).create_node()

            # end = shape(input) + (-crop_end)
            end = Add(graph, {'name': ss.name + '/end'}).create_node()
            const = Const(graph, {'value': -1 * self.mask_normalizer(shape_rank, node_axis, node_crop_end),
                                  'name': ss.name + '/const'}).create_node()

            node.in_port(0).get_connection().get_source().connect(shape.in_port(0))
            shape.out_port(0).connect(end.in_port(0))
            const.out_port(0).connect(end.in_port(1))

        else:
            raise Exception("Unknown type of Crop")

        source = node.in_port(0).get_connection().get_source()

        stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
                               'name': ss.name + '/stride'}).create_node()

        source.connect(ss.in_port(0))
        begin.out_port(0).connect(ss.in_port(1))
        end.out_port(0).connect(ss.in_port(2))
        stride.out_port(0).connect(ss.in_port(3))

        node.in_port(0).disconnect()
        node.out_port(0).get_connection().set_source(ss.out_port(0))

        ss['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace an LSTMSequence node (match['lstm']) with a TensorIterator that runs one LSTMCell
        per step along lstm.sequence_dim.

        The body is: input Squeeze (removes sequence_dim from X) -> LSTMCell -> output Unsqueeze
        (restores sequence_dim). Cell outputs 0 and 1 are fed back to cell inputs 1 and 2
        (h_init, c_init) through back edges. Iteration goes forward (stride 1) for direction
        'forward' and backward (stride -1, start -1, end 0) for 'reverse'.

        If the LSTMSequence has 3 outputs, the final hidden and cell states are exposed as
        TensorIterator outputs 4 and 5, next to the sequence output 3.

        :param graph: graph being transformed; the LSTMSequence node is removed from it
        :param match: pattern match; match['lstm'] is the LSTMSequence node
        """
        lstm = match['lstm']

        # Build TensorIterator body first
        body = Graph(name=lstm.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # only inputs 1 and 2 (weights and biases) keep their constant values inside the body
        inputs = [
            Op._create_data_node(
                body, lstm.name + '/inport/' + str(inp), {
                    'shape':
                    lstm.in_node(inp).shape.copy(),
                    'value':
                    lstm.in_node(inp).value.copy() if lstm.in_node(inp).value
                    is not None and inp in [1, 2] else None
                }) for inp in [0, 4, 5, 1, 2]
        ]  # X, h_init, c_init, WR, B

        inputs[0].shape[lstm.sequence_dim] = 1
        input_squeeze = Squeeze(
            body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
        squeeze_dim_data = Const(body, {
            'name': lstm.name + '/input_squeeze_dim',
            'value': [lstm.sequence_dim]
        }).create_node_with_data()
        inputs[0] = input_squeeze.create_node_with_data(
            [inputs[0], squeeze_dim_data],
            edge_attrs=[{
                'internal_port_id': 0
            }])

        # 2. Output unsqueeze Reshape
        # output 1 falls back to the shape of input 4 when the LSTMSequence has no output 1
        outputs = [
            Op._create_data_node(
                body, lstm.name + '/outport/' + str(out), {
                    'shape':
                    lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
                    else lstm.in_node(4).shape.copy()
                }) for out in [0, 1]
        ]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        outputs[0].shape = np.delete(outputs[0].shape, lstm.sequence_dim)
        # use the same '/' separator as every other node created in the body
        output_unsqueeze = Unsqueeze(
            body, dict(name=lstm.name + '/output_unsqueeze',
                       internal_layer_id=2))
        unsqueeze_dim_data = Const(
            body, {
                'name': lstm.name + '/output_unsqueeze_dim',
                'value': [lstm.sequence_dim]
            }).create_node_with_data()

        # 3. LSTMCell
        lstm_cell_op = LSTMCell(
            body,
            dict(hidden_size=lstm.hidden_size,
                 activations=lstm.activations,
                 activation_alpha=lstm.activation_alpha,
                 activation_beta=lstm.activation_beta,
                 clip=lstm.clip,
                 input_forget=lstm.input_forget,
                 name=lstm.name + '/LSTMCell',
                 internal_layer_id=1))
        lstm_cell_node = lstm_cell_op.create_node_with_data(
            inputs,
            data_nodes=outputs,
            edge_attrs=[{}, {
                'internal_port_id': 1
            }, {
                'internal_port_id': 2
            }, {
                'bin': 'weights'
            }, {
                'bin': 'biases'
            }])
        # cell output edges 0/1 get internal ids 4/5, used by the back edges and output_port_map
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
        lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
        lstm_cell_node[0] = output_unsqueeze.create_node_with_data(
            [lstm_cell_node[0], unsqueeze_dim_data])
        lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, lstm_cell_node[0].id, 0, False)

        # 4. TensorIterator layer creating
        assert lstm.direction in ['forward', 'reverse']
        if lstm.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert lstm.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,
            'axis': lstm.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding h_state, c_state to outputs
        if len(lstm.out_nodes()) == 3:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }, {
                'external_port_id': 5,
                'internal_layer_id': 1,
                'internal_port_id': 5,
            }])

        ti_op = TensorIterator(
            graph, {
                'name':
                lstm.name + '/TensorIterator',
                'body':
                body,
                'in_ports_count':
                3,
                'out_ports_count':
                len(lstm.out_nodes()),
                'input_port_map': [
                    {
                        'external_port_id': 0,
                        'internal_layer_id': 0,
                        'internal_port_id': 0,
                        'axis': lstm.sequence_dim,
                        'stride': stride,
                        'start': start,
                        'end': end,
                        'part_size': 1,
                    },
                    {
                        'external_port_id': 1,
                        'internal_layer_id': 1,
                        'internal_port_id': 1,
                    },
                    {
                        'external_port_id': 2,
                        'internal_layer_id': 1,
                        'internal_port_id': 2,
                    },
                ],
                'output_port_map':
                output_port_map,
                'back_edges': [
                    {
                        'from_layer': 1,
                        'from_port': 4,
                        'to_layer': 1,
                        'to_port': 1,
                    },
                    {
                        'from_layer': 1,
                        'from_port': 5,
                        'to_layer': 1,
                        'to_port': 2,
                    },
                ]
            })

        assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
            "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

        outs = ti_op.create_node_with_data(
            [lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
            data_nodes=[
                lstm.out_node(i) for i in range(len(lstm.out_nodes()))
            ],
            edge_attrs=[{
                'external_port_id': 0
            }, {
                'external_port_id': 1
            }, {
                'external_port_id': 2
            }])

        # create_node_with_data returns a single data node when there is only one output
        if not isinstance(outs, list):
            outs = [outs]

        graph.remove_node(lstm.id)
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id

        ti = outs[0].in_node()
        TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
        TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
        TensorIterator.normalize_internal_ids(ti)
예제 #17
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Normalize the im_info input (port 2) of a Proposal node.

        * An im_info of shape [1, 6] is cut to [1, 3] with a StridedSlice, and a warning is logged.
        * An im_info of shape [1, 3] or [1, 4] (including the one just cut) is reshaped to 1D.
        * If output port 1 exists and is connected, node['version'] is set to 'extension'.

        :param graph: graph being transformed; modified in place
        :param match: pattern match; match['proposal'] is the Proposal node
        """
        node = match['proposal']
        assert len(node.in_ports()) == 3, "Proposal op must have exactly 3 input ports"
        im_info_shape = node.in_port(2).data.get_shape()
        assert im_info_shape is not None

        if np.array_equal(im_info_shape, [1, 6]):
            # NOTE(review): the message says the first 4 values are used, but the slice below keeps
            # only the first 3 (end = [1, 3]); confirm which one is intended.
            log.error('The model contains Proposal layer "{}" with input of shape [1, 6]. Inference Engine '
                      'implementation of the Proposal layer uses only 4 first values (indices 0, 1, 2 and 3). '
                      'Elements with indices 4 and 5 will be ignored.'.format(node.soft_get('name', node.id)),
                      extra={'is_warning': True})
            begin = Const(graph, {'value': np.array([0, 0], dtype=np.int32)}).create_node()
            end = Const(graph, {'value': np.array([1, 3], dtype=np.int32)}).create_node()
            stride = Const(graph, {'value': np.array([1, 1], dtype=np.int32)}).create_node()

            # every mask has one entry per sliced dimension (2), matching begin/end/stride
            cropped_im_info = StridedSlice(graph, {'name': 'cropped_im_info',
                                                   'begin_mask': int64_array([1, 1]),
                                                   'end_mask': int64_array([1, 1]),
                                                   'new_axis_mask': int64_array([0, 0]),
                                                   'shrink_axis_mask': int64_array([0, 0]),
                                                   'ellipsis_mask': int64_array([0, 0]),
                                                   'override_output_shape': True,
                                                   }).create_node()

            node.in_port(2).get_connection().insert_node(cropped_im_info)
            begin.out_port(0).connect(cropped_im_info.in_port(1))
            end.out_port(0).connect(cropped_im_info.in_port(2))
            stride.out_port(0).connect(cropped_im_info.in_port(3))

            # update the im_info_shape so the next 'if' statement becomes true
            im_info_shape = int64_array([1, 3])

        if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
            # flatten [1, N] into a 1D tensor of N elements; the target shape is an explicit int64 constant
            reshape = Reshape(graph, dict(name="im_info/Reshape")).create_node()
            const = Const(graph, dict(value=int64_array([im_info_shape[1]]))).create_node()
            node.in_port(2).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(node.in_port(2))

        if node.has_port('out', 1) and not node.out_port(1).disconnected():
            node['version'] = 'extension'
예제 #18
0
def parse_specifier(string, graph, layer_node_map):
    """
    Parse a descriptor expression and return the id of the graph node producing its value.

    ``string`` is bytes: either a plain layer name or ``Spec(arg, ...)`` where Spec is one of
    Append, Offset, Sum, IfDefined, ReplaceIndex or Scale. Arguments are parsed recursively.
    Created nodes are memoized in ``layer_node_map`` under a name derived from the descriptor, so an
    identical descriptor parsed again maps to the already created node.

    :param string: descriptor as bytes
    :param graph: graph that receives the new nodes and edges
    :param layer_node_map: dict from descriptor-derived layer names to node ids; updated in place
    :return: id of the node producing the descriptor value; None for an unrecognized specifier
    :raises Error: if an Offset specifier has more than 2 arguments
    """
    pos = string.find(b'(')
    if pos == -1:
        # node name; str(bytes) gives "b'...'", so the b prefix, the quotes and a literal "\n"
        # escape are removed to get the bare name
        input_name = str(string.split(b' ')[0]).strip('b').replace(
            "\'", '').replace('\\n', '')

        if input_name not in layer_node_map:
            node_name = graph.unique_id(prefix=input_name)
            graph.add_node(node_name, parameters=[], op="", kind='op')
            layer_node_map[input_name] = node_name
        else:
            node_name = layer_node_map[input_name]
        return node_name

    spec = string[:pos]
    args = get_args_for_specifier(string[pos:])
    if spec == b'Append':
        nodes = [parse_specifier(arg, graph, layer_node_map) for arg in args]
        layer_name = 'Append_' + ''.join(node + '_' for node in nodes)

        if layer_name not in layer_node_map:
            concat_name = graph.unique_id(prefix=layer_name)
            graph.add_node(concat_name,
                           parameters=None,
                           op='concat',
                           kind='op')
            layer_node_map[layer_name] = concat_name
            Node(graph,
                 concat_name).add_sequence_of_ports('in', range(len(nodes)))
            for i, node in enumerate(nodes):
                # every consumer gets its own, newly added output port on the producer
                out_port = len(Node(graph, node).out_nodes())
                Node(graph, node).add_output_port(out_port)
                graph.create_edge(
                    Node(graph, node), Node(graph, concat_name), out_port, i,
                    create_edge_attrs(node, concat_name, node, i, out_port))
        else:
            concat_name = layer_node_map[layer_name]
        return concat_name
    elif spec == b'Offset':
        node = parse_specifier(args[0], graph, layer_node_map)
        t = int(args[1])
        if len(args) > 2:
            raise Error("ModelOptimizer supports only 2 arguments for Offset")
        # negative offsets get an extra '_' so that e.g. +3 and -3 produce different layer names
        layer_name = 'Offset_' + node + '_'
        if t < 0:
            layer_name = layer_name + '_' + str(-t)
        else:
            layer_name = layer_name + str(t)

        if layer_name not in layer_node_map:
            memory_name = graph.unique_id(prefix=layer_name)
            layer_node_map[layer_name] = memory_name
            memory_name_2 = memory_name + '_out'
            graph.add_node(memory_name,
                           parameters=dict(t=t,
                                           pair_name=memory_name_2,
                                           has_default=False),
                           op='MemoryOffset',
                           kind='op')
            out_port = len(Node(graph, node).out_nodes())
            in_port = len(Node(graph, memory_name).in_nodes())
            Node(graph, memory_name).add_input_port(in_port)
            Node(graph, node).add_output_port(out_port, skip_if_exist=True)
            graph.create_edge(
                Node(graph, node), Node(graph, memory_name), out_port, in_port,
                create_edge_attrs(node, memory_name, node, in_port, out_port))
        else:
            memory_name = layer_node_map[layer_name]
        return memory_name
    elif spec == b'Sum':
        nodes = [parse_specifier(arg, graph, layer_node_map) for arg in args]
        layer_name = 'Sum_' + ''.join(node + '_' for node in nodes)

        if layer_name in layer_node_map:
            # This Sum was already created and wired for exactly these inputs. Wiring it again
            # would register the same input ports and edges a second time.
            return layer_node_map[layer_name]

        sum_name = graph.unique_id(prefix=layer_name)
        graph.add_node(sum_name, parameters=None, op='Add', kind='op')
        layer_node_map[layer_name] = sum_name

        for i, node in enumerate(nodes):
            out_port = len(Node(graph, node).out_nodes())
            Node(graph, node).add_output_port(out_port, skip_if_exist=True)
            Node(graph, sum_name).add_input_port(i)
            # NOTE(review): unlike Append/Offset, out_port is not passed to create_edge_attrs here,
            # so the edge uses its default output port index; confirm this is intended.
            graph.add_edge(node, sum_name,
                           **create_edge_attrs(node, sum_name, node, i))

        return sum_name
    elif spec == b'IfDefined':
        node_id = parse_specifier(args[0], graph, layer_node_map)
        node = Node(graph, node_id)
        # an IfDefined around an Offset means the MemoryOffset may use its default value
        if node.op == 'MemoryOffset':
            node['parameters']['has_default'] = True
        return node_id
    elif spec == b'ReplaceIndex':
        node = parse_specifier(args[0], graph, layer_node_map)
        return node
    elif spec == b'Scale':
        node_name = parse_specifier(args[1], graph, layer_node_map)
        scale_value = float(args[0])
        layer_name = '{}/Mul/{}'.format(node_name, scale_value)

        if layer_name not in layer_node_map:
            scale_name = graph.unique_id(prefix=layer_name)
            scale_node = Mul(graph, {'name': scale_name}).create_node()

            layer_node_map[layer_name] = scale_name

            scale_const_name = 'Const_{}'.format(scale_value)
            const_node = Const(graph, {
                'name': scale_const_name,
                'value': float_array([scale_value])
            }).create_node()

            node = Node(graph, node_name)
            graph.create_edge(
                const_node, scale_node, 0, 0,
                create_edge_attrs(const_node.id, scale_node.id, const_node.id))
            out_port = len(node.out_nodes())
            graph.create_edge(
                node, scale_node, out_port, 1,
                create_edge_attrs(node_name, scale_node.id, node_name, 1,
                                  out_port))
        else:
            scale_name = layer_node_map[layer_name]

        return scale_name
    # unrecognized specifier
    return None
예제 #19
0
def get_range_node_of_idxs(rank: Node, begin: int, end: int,
                           include_begin: bool = True, include_end: bool = False) -> Node:
    """
    Build a Range sub-graph that outputs the 1D sequence of indices from ``begin`` to ``end``.

    Both bounds are first canonicalized against ``rank`` via get_canonical_axis_index_node. The
    start point is included and the end point excluded by default. Flipping either flag is done by
    adding a constant 1 to the corresponding bound before it reaches the Range node.

    :param rank: node with a 0D output holding the rank of the tensor of interest
    :param begin: start index, expected in [-rank; rank - 1]
    :param end: end index, expected in [-rank; +rank]
    :param include_begin: keep the start point in the output when True, drop it when False
    :param include_end: keep the end point in the output when True, drop it when False
    :return: Range node whose output is the 1D index sequence
    """
    graph = rank.graph
    prefix = rank.soft_get('name', rank.id)

    def shift_by_one(bound: Node, suffix: str) -> Node:
        # bound + 1, built in the graph as Const([1]) feeding an Add
        one = Const(graph, {'value': int64_array([1]), 'name': prefix + suffix + '/value'}).create_node()
        plus_one = Add(graph, {'name': prefix + suffix}).create_node()
        bound.out_port(0).connect(plus_one.in_port(0))
        one.out_port(0).connect(plus_one.in_port(1))
        return plus_one

    first = get_canonical_axis_index_node(rank, begin)
    last = get_canonical_axis_index_node(rank, end)

    if not include_begin:
        first = shift_by_one(first, '/exclude_begin')
    if include_end:
        last = shift_by_one(last, '/including_end')

    step = Const(graph, {'name': prefix + '/delta', 'value': int64_array([1])}).create_node()
    range_node = Range(graph, {'name': prefix + '/range_idxs'}).create_node()

    # Range inputs: 0 - start, 1 - stop, 2 - step
    for port_idx, producer in enumerate((first, last, step)):
        producer.out_port(0).connect(range_node.in_port(port_idx))

    return range_node
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Collapse a matched NearestNeighborUpsampling sub-graph into a single opset4 Interpolate.

        The input height and width are read from the values of inputs 1 and 3 of match['pack_1'];
        the scale factors come from dimensions -4 and -2 of match['mul_const'].shape. If any of
        these cannot be read, a warning is logged and the graph is left unchanged.

        Otherwise an Interpolate fed by match['op'] replaces match['reshape_2']. It receives Const
        inputs for the target sizes (port 1), the scales (port 2) and the spatial axes (port 3;
        [2, 3] for NCHW layout, [1, 2] otherwise). All other matched nodes except match['op'] are
        then removed.
        """
        log.debug('Matched NearestNeighborUpsampling pattern: {}'.format(
            [matched.id for matched in match.values()]))

        try:
            pack = match['pack_1']
            in_height = pack.in_node(1).value.item()
            in_width = pack.in_node(3).value.item()

            scale_shape = match['mul_const'].shape
            scale_h = scale_shape[-4]
            scale_w = scale_shape[-2]
        except Exception:
            # any failure while reading these constants means the pattern cannot be converted
            log.warning(
                'Failed to determine scaling parameters from the topology. Do not apply pattern.'
            )
            return

        old_reshape = match['reshape_2']
        interp_attrs = {
            'mode': 'nearest',
            'antialias': 0,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'coordinate_transformation_mode': 'half_pixel',
            'nearest_mode': 'round_prefer_floor',
            'cube_coeff': -0.75,
            'version': 'opset4',
            'name': old_reshape.name + '/Resample',
            'shape_calculation_mode': 'scales',
            'in_ports_count': 4,
        }
        upsample_src = match['op']
        resample_node = Interpolate(graph, interp_attrs).create_node([upsample_src])
        base_name = resample_node.name

        if graph.graph['layout'] == 'NCHW':
            spatial_axes = int64_array([2, 3])
        else:
            spatial_axes = int64_array([1, 2])
        axes_node = Const(graph, {'name': base_name + '/axes', 'value': spatial_axes}).create_node()
        sizes_node = Const(graph, {
            'value': np.array([in_height * scale_h, in_width * scale_w]),
            'name': base_name + '/target_shape',
        }).create_node()
        scales_node = Const(graph, {
            'value': np.array([scale_h, scale_w], dtype=np.float32),
            'name': base_name + '/scales',
        }).create_node()

        old_reshape.replace_node(resample_node)

        resample_node.add_input_port(1, skip_if_exist=True)
        assert resample_node.in_port(1).disconnected()
        for port_idx, const_node in ((1, sizes_node), (2, scales_node), (3, axes_node)):
            const_node.out_port(0).connect(resample_node.in_port(port_idx))

        kept_id = upsample_src.id
        graph.remove_nodes_from([n.id for n in match.values() if n.id != kept_id])
    def find_and_replace_pattern(self, graph: Graph):
        """
        Keep nodes having the 'reinterp_shape' attribute working in NHWC layout by inserting layout
        Transposes around them. Does nothing unless graph.graph['layout'] is 'NHWC'.

        First pass: when self.is_nchw_to_nhwc_transpose_needed(node) is true, a Transpose with the
        NC(D)HW -> N(D)HWC permutation is inserted on input 0 of the node.
        Second pass: when self.is_nhwc_to_nchw_transpose_needed(node) is true, a Transpose with the
        N(D)HWC -> NC(D)HW permutation is inserted on output 0 of the node.
        In both passes the new Transpose gets need_shape_inference = False, and the relevant ports
        are marked as being in the correct layout.

        :param graph: graph being transformed; modified in place
        """
        if graph.graph['layout'] != 'NHWC':
            # we check it here because this transformation is called explicitly from the pipeline
            return

        # reshape from 4D-5D -> ND. Insert Transpose(NC(D)HW->N(D)HWC) before Reshape
        for reinterp_shape_node_id in graph.get_nodes_with_attributes(
                reinterp_shape=True):
            reinterp_shape_node = Node(graph, reinterp_shape_node_id)
            assert 0 in reinterp_shape_node.in_nodes(
            ), 'Node {} does not have 0 input. \n{}'.format(
                reinterp_shape_node_id, graph.dump_graph_for_graphviz())
            # only the rank of the input shape is used, to pick the permutation
            input_shape = reinterp_shape_node.in_node(0).shape
            if self.is_nchw_to_nhwc_transpose_needed(reinterp_shape_node):
                order_const = Const(
                    graph, {
                        'value':
                        PermuteAttrs().get_nchw_to_nhwc_permutation(
                            len(input_shape)).perm
                    }).create_node()
                permute_node = Transpose(
                    graph, {
                        'name':
                        reinterp_shape_node.in_port(0).get_source().node.name +
                        '/Transpose'
                    }).create_node()
                reinterp_shape_node.in_port(0).get_connection().insert_node(
                    permute_node)
                order_const.out_port(0).connect(permute_node.in_port(1))
                order_const.infer(order_const)

                # do not infer the Transpose node because it should have input data node in NCHW layout (but currently
                # it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
                # (which is true at this moment)
                permute_node['need_shape_inference'] = False
                # mark the Transpose output data node having correct layout so it's shape will not be permuted
                mark_output_as_in_correct_layout(permute_node, 0)

                # keep the reinterp_shape_node in NHWC layout
                mark_input_as_in_correct_layout(reinterp_shape_node, 0)
                mark_input_as_in_correct_layout(reinterp_shape_node, 1)

        # reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
        for reinterp_shape_node_id in graph.get_nodes_with_attributes(
                reinterp_shape=True):
            reinterp_shape_node = Node(graph, reinterp_shape_node_id)
            assert 0 in reinterp_shape_node.out_nodes(
            ), 'Node {} does not have 0 output. \n{}'.format(
                reinterp_shape_node_id, graph.dump_graph_for_graphviz())
            # only the rank of the output shape is used, to pick the permutation
            output_shape = reinterp_shape_node.out_node(0).shape
            if self.is_nhwc_to_nchw_transpose_needed(reinterp_shape_node):
                order_const = Const(
                    graph, {
                        'value':
                        PermuteAttrs().get_nhwc_to_nchw_permutation(
                            len(output_shape)).perm
                    }).create_node()
                permute_node = Transpose(
                    graph, {
                        'name': reinterp_shape_node.id + '/Transpose'
                    }).create_node()
                reinterp_shape_node.out_port(0).get_connection().insert_node(
                    permute_node)
                order_const.out_port(0).connect(permute_node.in_port(1))
                # NOTE(review): unlike the first pass, order_const.infer(order_const) is not called
                # here; confirm this asymmetry is intended.

                # the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
                # will convert it to the NCHW
                mark_input_as_in_correct_layout(permute_node, 0)
                mark_input_as_in_correct_layout(permute_node, 1)
                # do not set Transpose output data node 'correct_data_layout' attribute so the data node shape will be
                # permuted

                # keep the reinterp_shape_node in NHWC layout
                mark_output_as_in_correct_layout(reinterp_shape_node, 0)
                mark_input_as_in_correct_layout(reinterp_shape_node, 1)

                # do not re-infer the Transpose node because it output data node should be in NHWC layout to make the
                # rest of the graph consistent
                permute_node['need_shape_inference'] = False
예제 #22
0
def _fused_batch_norm_decomposition(graph: Graph,
                                    tinput: Port,
                                    toutput: Port,
                                    gamma: Port,
                                    beta: Port,
                                    mean: np.ndarray,
                                    variance: np.ndarray,
                                    can_be_fused=True):
    """
    This is common function for TF, Caffe and MXNet
    It creates Mul->Add->Mul->Add sub graph

    The data entering ``tinput`` is rerouted through
    ``Mul(const mean) -> Add(const variance) -> Mul(gamma) -> Add(beta)`` and the
    last Add becomes the source for the consumers of ``toutput``.

    :param graph: graph to modify
    :param tinput: batch-norm input port receiving the data tensor; its connection is moved to the first Mul
    :param toutput: batch-norm output port whose consumers are re-sourced from the last Add
    :param gamma: port providing the scale; its connection is moved to the second Mul
    :param beta: port providing the shift; its connection is moved to the second Add
    :param mean: value of the constant the first Mul multiplies by (presumably a precomputed
                 coefficient rather than the raw mean -- confirm with callers)
    :param variance: value of the constant the first Add adds (presumably a precomputed
                     coefficient rather than the raw variance -- confirm with callers)
    :param can_be_fused: ``can_be_fused`` attribute assigned to every created Mul/Add node
    """
    batch_norm_name = tinput.get_connection().get_destination().node.name

    # Create first Mul & Add operations
    mul1_node = Mul(
        graph, dict(name=batch_norm_name + "/mean",
                    can_be_fused=can_be_fused)).create_node()
    add1_node = Add(
        graph,
        dict(name=batch_norm_name + "/variance",
             can_be_fused=can_be_fused)).create_node()

    const_mul1_node = Const(graph, dict(name="data_mul_",
                                        value=np.array(mean))).create_node()
    const_add1_node = Const(graph,
                            dict(name="data_add_",
                                 value=np.array(variance))).create_node()

    # Broadcast gamma from a scalar-like constant to its declared shape.
    # ndarray.resize() is in-place and returns None, so it cannot be chained with
    # .fill(); build the broadcast array with np.full instead. A missing value or
    # shape (non-constant gamma) leaves gamma untouched.
    gamma_value = gamma.data.get_value()
    gamma_shape = gamma.data.get_shape()
    if gamma_value is not None and gamma_shape is not None and len(gamma_shape) > 0:
        gamma_value = np.asarray(gamma_value)
        if gamma_value.ndim == 0 or gamma_shape[0] != gamma_value.shape[0]:
            gamma.data.set_value(
                np.full(gamma_shape, gamma_value.flat[0], dtype=gamma_value.dtype))

    # Create second Mul & Add
    mul2_node = Mul(
        graph, dict(name=batch_norm_name + "/gamma",
                    can_be_fused=can_be_fused)).create_node()
    add2_node = Add(
        graph, dict(name=batch_norm_name + "/beta",
                    can_be_fused=can_be_fused)).create_node()

    # Connect edges Mul1->Add1->Mul2->Add2
    tinput.get_connection().set_destination(mul1_node.in_port(0))
    mul1_node.in_port(1).get_connection().set_source(
        const_mul1_node.out_port(0))

    add1_node.in_port(0).get_connection().set_source(mul1_node.out_port(0))
    add1_node.in_port(1).get_connection().set_source(
        const_add1_node.out_port(0))

    mul2_node.in_port(0).get_connection().set_source(add1_node.out_port(0))
    gamma.get_connection().set_destination(mul2_node.in_port(1))

    add2_node.in_port(0).get_connection().set_source(mul2_node.out_port(0))
    beta.get_connection().set_destination(add2_node.in_port(1))

    toutput.get_connection().set_source(add2_node.out_port(0))
예제 #23
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """
        Decompose the matched GroupNorm node into Reshape -> MVN -> Reshape -> Mul(gamma) -> Add(beta).

        The input is reshaped to [batch * num_groups, features / num_groups, <spatial dims>], where the
        batch/features/spatial values come from the ShapeOf of the original input via the
        node_to_get_*_value helpers. MVN (across_channels=1, normalize_variance=1, eps from the node) is then
        applied, the result is reshaped back to the original input shape, multiplied by gamma and shifted by beta.

        :param graph: graph being transformed
        :param match: pattern match; match['op'] is the GroupNorm node (its num_groups and eps attributes are read)
        :raises AssertionError: if gamma (input 1) or beta (input 2) has no constant value
        """
        group_norm_node = match['op']
        # rank of the GroupNorm input; used below to build the [1, -1, 1, ...] shape for gamma/beta
        group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
        initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
        initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node)
        initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)
        # NOTE: despite the name, this constant holds num_groups (the number of groups), not the group size
        group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                        'name': group_norm_node.name + '/GroupSize'}).create_node()

        # calculate "features // group_size" value
        # i.e. features * (1.0 / num_groups), computed as a floating-point multiplication.
        # NOTE(review): the product is not guaranteed to be exactly integral for every num_groups (float
        # rounding could give e.g. 31.999... instead of 32); confirm how the shape value is converted downstream.
        reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                                   'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        # batch * num_groups: the leading dimension of the shape fed to MVN
        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        # -> [batch * num_groups, features / num_groups, <spatial dims>]
        new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                          initial_spatial_dims_node])

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        # the original GroupNorm data input now feeds the Reshape
        group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        # the constant values are reshaped in place on their existing data nodes
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                               'across_channels': 1,
                               'normalize_variance': 1,
                               'eps': group_norm_node.eps}).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

        # gamma connection is moved from the GroupNorm node to the Mul
        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

        # beta connection is moved from the GroupNorm node to the Add
        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

        # consumers of the GroupNorm output are re-sourced from the final Add
        group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
예제 #24
0
 def replace_sub_graph(self, graph: Graph, match: dict):
     """
     Materialize a valid ``reps`` attribute of the matched node as a Const on input port 1.

     Nothing is done when the node has no valid ``reps`` attribute.

     :param graph: graph being transformed
     :param match: pattern match; match['op'] is the node to process
     """
     tile_node = match['op']
     if not tile_node.has_valid('reps'):
         return

     reps_const_attrs = dict(value=int64_array(tile_node.reps),
                             symbol_dict={'name': tile_node.id + '/tile_array'})
     reps_const = Const(graph, reps_const_attrs).create_node()
     # input port 1 is re-sourced from the new constant
     tile_node.in_port(1).get_connection().set_source(reps_const.out_port(0))
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Build a DetectionOutput sub-graph from the inputs of the matched sub-graph.

        Inputs taken from ``match``: input 0 (multiplied by the priors' widths/heights and reshaped into the
        regression input), input 1 (reshaped to [0, -1] as the classes input) and input 2 (prior boxes). The
        priors are sliced to the first batch, multiplied by ``self.placeholder_scales`` (broadcast to the
        priors shape) and extended with the 'variance' custom attribute via ``self.append_variances``.

        :param graph: graph being transformed
        :param match: sub-graph match; ``custom_replacement_desc.custom_attributes`` must contain 'variance'
        :return: dict with the created DetectionOutput node under 'detection_output_node'
        :raises AssertionError: if the graph does not have exactly one Parameter node
        :raises Error: if the 'variance' custom attribute is not available
        """
        reshape_classes_node = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            dict(name='do_reshape_classes'),
            match.single_input_node(1)[0])

        initial_priors_node = match.single_input_node(2)[0]
        priors_name = initial_priors_node.soft_get('name',
                                                   initial_priors_node.id)
        # model calculates identical prior boxes for each batch, so we take first slice of them
        begin = Const(graph, {
            'value': np.array([0, 0, 0], dtype=np.int32)
        }).create_node()
        end = Const(graph, {
            'value': np.array([1, 0, 0], dtype=np.int32)
        }).create_node()
        stride = Const(graph, {
            'value': np.array([1, 1, 1], dtype=np.int32)
        }).create_node()

        priors_node = StridedSlice(
            graph, {
                'name': priors_name + '/0_batch_slice',
                'begin_mask': np.array([1, 1, 1], dtype=np.int32),
                'end_mask': np.array([1, 0, 0], dtype=np.int32),
                'new_axis_mask': np.array([0], dtype=np.int32),
                'shrink_axis_mask': np.array([0], dtype=np.int32),
                'ellipsis_mask': np.array([0], dtype=np.int32)
            }).create_node()

        initial_priors_node.out_port(0).connect(priors_node.in_port(0))
        begin.out_port(0).connect(priors_node.in_port(1))
        end.out_port(0).connect(priors_node.in_port(2))
        stride.out_port(0).connect(priors_node.in_port(3))

        placeholders = graph.get_op_nodes(type='Parameter')
        assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                       "{} placeholders".format(self.replacement_id, len(placeholders))
        placeholder = placeholders[0]

        # scale prior boxes to the [0, 1] interval
        node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
        priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

        broadcast = Broadcast(graph, {
            'name': 'scales_broadcast'
        }).create_node()
        shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
        priors_node.out_port(0).connect(shape_of_priors.in_port(0))
        broadcast.in_port(1).connect(shape_of_priors.out_port(0))
        broadcast.in_port(0).connect(
            node_with_scales_for_prior_boxes.out_port(0))

        priors_scale_node.in_port(0).connect(priors_node.out_port(0))
        priors_scale_node.in_port(1).connect(broadcast.out_port(0))

        # Only the failures of this lookup are translated into a user-facing Error; anything else
        # (including KeyboardInterrupt/SystemExit) propagates unchanged.
        try:
            variance = match.custom_replacement_desc.custom_attributes[
                'variance']
        except (AttributeError, KeyError, TypeError) as err:
            raise Error(
                'There is no variance attribute in {} replacement config file `custom_attributes`'
                ''.format(self.replacement_id)) from err

        priors = self.append_variances(priors_scale_node, variance)

        # calculate prior boxes widths and heights
        split_node = create_op_with_const_inputs(graph, VariadicSplit, {
            1: int64_array(2),
            2: int64_array([1, 1, 1, 1])
        }, {'out_ports_count': 4}, priors_scale_node)

        priors_width_node = Sub(
            graph, dict(name=split_node.name + '/sub_2-0_')).create_node([
                (split_node, 2), (split_node, 0)
            ])
        priors_height_node = Sub(graph, dict(name=split_node.name +
                                             '/sub_3-1_')).create_node([
                                                 (split_node, 3),
                                                 (split_node, 1)
                                             ])

        # concat weights and heights into a single tensor and multiple with the box coordinates regression values
        # WA with 3 Concats instead of 1 for keeping model reshapable; the single-Concat equivalent would be
        # Concat(axis=-1) over [priors_width_node, priors_height_node, priors_width_node, priors_height_node]
        concat_1 = Concat(graph, {
            'name': 'concat_width_height',
            'axis': -1,
            'in_ports_count': 2
        }).create_node([priors_width_node, priors_height_node])
        concat_2 = Concat(graph, {
            'name': 'concat_width_height_width',
            'axis': -1,
            'in_ports_count': 2
        }).create_node([concat_1, priors_width_node])
        concat_width_height_node = Concat(graph, {
            'name': 'concat_priors_width_height',
            'axis': -1,
            'in_ports_count': 2
        }).create_node([concat_2, priors_height_node])

        applied_width_height_regressions_node = Mul(graph, {
            'name': 'final_regressions'
        }).create_node(
            [concat_width_height_node,
             match.single_input_node(0)[0]])

        # reshape to 2D tensor as Inference Engine Detection Output layer expects
        reshape_regression_node = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, -1]),
            dict(name='reshape_regression'),
            applied_width_height_regressions_node)

        detection_output_op = DetectionOutput(
            graph, match.custom_replacement_desc.custom_attributes)
        detection_output_node = detection_output_op.create_node(
            [reshape_regression_node, reshape_classes_node, priors],
            dict(name=detection_output_op.attrs['type'],
                 clip_after_nms=1,
                 normalized=1,
                 variance_encoded_in_target=0,
                 background_label_id=1000))

        return {'detection_output_node': detection_output_node}
    def find_and_replace_pattern(self, graph: Graph):
        """
        Insert layout-reordering Gathers around constant shape paths.

        1. After every ShapeOf whose input is not marked as being in the correct layout and whose output is a
           1D tensor of 4 or 5 elements, insert a Gather reordering the shape values from NC* to N*C order.
        2. For every end of the constant shape paths found by ``self.search_of_constant_path_end`` starting at
           those ShapeOf nodes, insert a Gather applying the inverse (N*C -> NC*) reordering on each input port
           fed from a node on such a path.

        :param graph: graph being transformed
        """
        # 1. Inserting Gather to N*C format on constant shape paths
        #   - Search for Shape ops
        #   - Inserting Gather after them in case of [4] or [5] output shape

        shape_ops = graph.get_op_nodes(op='ShapeOf')
        constant_shape_paths = set()

        for shape in shape_ops:
            output_port = shape.in_port(0).get_source()
            if is_output_data_in_correct_layout(output_port.node, output_port.idx):
                continue
            shape_of_shape_op_output = shape.out_node().shape

            if np.array_equal(shape_of_shape_op_output, [4]):
                index = np.array([0, 2, 3, 1])  # NCHW -> NHWC
            elif np.array_equal(shape_of_shape_op_output, [5]):
                index = np.array([0, 2, 3, 4, 1])  # NCDHW -> NDHWC
            else:
                continue

            const = Const(graph, {'value': index}).create_node()
            gather = Gather(graph, {'name': shape.name + '/GatherNCHWtoNHWC'}).create_node()

            # re-source the ShapeOf consumers from the Gather, then feed the Gather from the ShapeOf
            shape.out_port(0).get_connection().set_source(gather.out_port(0))
            shape.out_port(0).connect(gather.in_port(0))
            const.out_port(0).connect(gather.in_port(1))

            constant_shape_paths.add(gather.id)

        # 2. Inserting Gather to NC* format
        #   - Search from Shape ops found in previous step for nodes without value that are n-th children of Shape op
        #       * MO can not propagate value, there is data path
        #   - Inserting Gather on ports which comes from operations in `constant_shape_paths` list

        constant_shape_ends = []

        for shape in shape_ops:
            constant_shape_ends.extend(self.search_of_constant_path_end(graph, node_name=shape.id,
                                                                        visited=constant_shape_paths))

        for end in constant_shape_ends:
            node = Node(graph, end)
            in_ports = [in_port for in_port in node.in_ports().values()
                        if in_port.get_source().node.id in constant_shape_paths]

            for in_port in in_ports:
                port_shape = in_port.data.get_shape()

                # These must be the inverses of the step-1 permutations.
                if np.array_equal(port_shape, [4]):
                    index = np.array([0, 3, 1, 2])  # NHWC -> NCHW
                elif np.array_equal(port_shape, [5]):
                    index = np.array([0, 4, 1, 2, 3])  # NDHWC -> NCDHW
                else:
                    continue

                const = Const(graph, {'value': np.array(index)}).create_node()
                gather = Gather(graph, {'name': node.name + '/GatherNHWCtoNCHW'}).create_node()

                in_port.get_source().connect(gather.in_port(0))
                in_port.get_connection().set_source(gather.out_port(0))
                const.out_port(0).connect(gather.in_port(1))
예제 #27
0
 def replace_op(self, graph: Graph, node: Node):
     """
     Replace the matched (reciprocal) node with ``Pow(x, -1.0)``.

     :param graph: graph being transformed
     :param node: matched node; its input-0 connection is moved to the new Pow node
     :return: single-element list with the id of the created Pow node (presumably used by the
              caller to re-source the matched node's consumers)
     """
     name = node.soft_get('name', node.id)
     # Floating-point exponent: numpy refuses integer inputs raised to negative integer
     # powers, and a float exponent matches the floating-point data a reciprocal operates on.
     const = Const(graph, dict(value=np.array(-1.0), name=name + '/reciprocal_pow_const_')).create_node()
     reciprocal = Pow(graph, {'name': name + '/reciprocal_pow_'}).create_node()
     node.in_port(0).get_connection().set_destination(reciprocal.in_port(0))
     const.out_port(0).connect(reciprocal.in_port(1))
     return [reciprocal.id]
예제 #28
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Reroute both inputs of the matched node through ShapeOf -> StridedSlice and append an Unsqueeze.

        For each of the node's two inputs, the original source now feeds a ShapeOf whose output is sliced by a
        StridedSlice (begin=[2], end=[4], stride=[1], begin_mask=[1], end_mask=[0]; presumably this keeps the
        spatial part of the shape -- confirm against the StridedSlice mask convention). The node is renamed to
        '<name>/naked_not_unsqueezed' and followed by an Unsqueeze over axis 0 that takes the original name and
        the original consumers. Also sets ``keep_shape_ops`` in the graph's cmd_params.

        :param graph: graph being transformed
        :param match: pattern match; match['pb'] is the node to process (must have exactly two input ports)
        """
        node = match['pb']
        name = node.soft_get('name', node.id)

        graph.graph['cmd_params'].keep_shape_ops = True

        assert len(node.in_ports()) == 2

        # slice parameters shared by both StridedSlice nodes
        begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
        end = Const(graph, {'value': np.array([4], dtype=np.int32)}).create_node()
        stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()

        # Identical rewiring for input 0 and input 1 (same node creation order as handling them one by one).
        for port_idx in (0, 1):
            shape = Shape(graph, {'name': name + '/{}_port'.format(port_idx)}).create_node()
            ss = StridedSlice(graph, {'name': name + '/ss_{}_port'.format(port_idx),
                                      'begin_mask': np.array([1], dtype=np.int32),
                                      'end_mask': np.array([0], dtype=np.int32),
                                      'new_axis_mask': np.array([0], dtype=np.int32),
                                      'shrink_axis_mask': np.array([0], dtype=np.int32),
                                      'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

            shape.out_port(0).connect(ss.in_port(0))
            begin.out_port(0).connect(ss.in_port(1))
            end.out_port(0).connect(ss.in_port(2))
            stride.out_port(0).connect(ss.in_port(3))

            # the original producer now feeds ShapeOf; the node receives the sliced shape instead
            source = node.in_port(port_idx).get_connection().get_source()
            node.in_port(port_idx).disconnect()
            source.connect(shape.in_port(0))
            ss.out_port(0).connect(node.in_port(port_idx))

            ss['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}

        node['need_shape_inference'] = True
        node['override_output_shape'] = True
        node['V10_infer'] = True
        unsqueeze = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]), {'name': name + '/unsqueeze'})
        naked_priorbox_name = name + '/naked_not_unsqueezed'
        rename_nodes([(node, naked_priorbox_name), (unsqueeze, name)])

        # consumers of the node are re-sourced from the Unsqueeze, which is fed by the node
        node.out_port(0).get_connection().set_source(unsqueeze.out_port(0))
        node.out_port(0).connect(unsqueeze.in_port(0))
예제 #29
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Move the output scale of a binarizing FakeQuantize (levels == 2) past the consuming operator.

        Applies only when the FakeQuantize output bounds (inputs 3 and 4) are constants of either the 0/+a
        (or -a/0) form or the -a/+a form. Both bounds are multiplied by 1 / scale (Pow(scale, -1) followed by
        Mul), the FakeQuantize is re-inferred, and a compensating Mul by the scale reshaped to [-1, 1, 1] is
        inserted after ``operator``. Finally ``operator`` gets can_be_fused=False.

        :param graph: graph being transformed
        :param match: pattern match with 'operator', 'quantize' (the FakeQuantize) and 'quantized'
                      (presumably its output data node) entries
        """
        assert match['operator'].has('multiplication_transparent_ports')

        port = match['operator'].input_ports_with(match['quantized'])
        assert len(port) >= 1
        if len(port) > 1:
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it consumed more'
                ' than once'.format(match['quantized'].name))
            return

        port = port[0]
        applicable = [
            pair for pair in match['operator'].multiplication_transparent_ports
            if pair[0] == port
        ]
        if len(applicable) == 0:
            return

        # Look at 3-rd and 4-th inputs of FakeQuantize -- they have constants that should be passed through.
        # Assume that the constant that should be passed through is a scalar.
        quantize = match['quantize']
        output_low = quantize.in_node(3)
        output_high = quantize.in_node(4)

        quantize_name = quantize.soft_get('name', quantize.id)

        # Both bounds must be constant: their values are compared below.
        if not output_low.has_valid('value') or not output_high.has_valid(
                'value'):
            return

        output_low = output_low.value
        output_high = output_high.value

        # This pass is applicable for binarization only. Other intX variants are not relevant.
        if quantize.levels != 2:
            return

        # Recognize two cases: 0/+1 and -1/+1.
        zp1 = np.all(output_low == 0) or np.all(output_high == 0)
        m1p1 = np.all(-output_low == output_high)
        if (not zp1 and not m1p1) or (zp1 and m1p1):
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it does\'t has one of'
                ' 0/+1 or -1/+1 forms.'.format(match['quantized'].name))
            return

        # Checked before any node is created so that bailing out leaves no dangling nodes in the graph.
        if len(match['quantized'].out_nodes()) > 1:
            log.debug(
                'BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1'
            )
            return

        # TODO: Extract real scalar from 3rd and 4th inputs; reusing original tensors is dangerous because
        #       it may have incompatible shape.

        # the non-zero bound is used as the scale term
        mult_term = quantize.in_node(3) if np.all(
            output_high == 0) else quantize.in_node(4)

        new_shape = Const(
            graph, {
                'name': quantize_name + '/Reshape/Shape',
                'value': int64_array([-1, 1, 1])
            }).create_node_with_data()
        reshape = Reshape(graph, {
            'name': quantize_name + '/Reshape'
        }).create_node_with_data([mult_term, new_shape])

        # Patch inflow path (by diving by mult_term)
        # Put a new Pow/Mul combination here:
        #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator

        power_of_exponent = Const(graph, {
            'name': quantize_name + '/DivNormalize/Power',
            'value': np.array(-1.0)
        }).create_node_with_data()
        div_op = Pow(graph, {'name': quantize_name + '/DivNormalize'})
        div_output = div_op.create_node_with_data(
            [mult_term, power_of_exponent])

        for i in [3, 4]:
            quantize.insert_node_with_data_before(
                quantize.in_node(i),
                Mul,
                dict(name=quantize_name + '/MulNormalize'),
                additional_inputs=[div_output],
            )

        match['quantized'].value = None  # reset value because it will be recomputed
        quantize.infer(quantize)

        # Put a complimentary new Mul node here:   operator -->---(here)-----> operator.out_node()

        match['operator'].insert_node_with_data_after(
            match['operator'].out_node(),
            Mul,
            dict(name=match['operator'].name + '/MulNormalize'),
            [reshape],
        )

        # Disable 'operator' fusion with linear ops, otherwise it will annihilate changes that we just made
        match['operator']['can_be_fused'] = False
예제 #30
0
 def extract(cls, node):
     """
     Turn ``node`` into a Const described by the array stored in its ``value`` attribute.

     :param node: node whose ``value`` attribute holds the constant (an object exposing ``.dtype``)
     :return: ``cls.enabled``
     """
     const_value = node.value
     Const.update_node_stat(node, dict(data_type=const_value.dtype, value=const_value))
     return cls.enabled