コード例 #1
0
    def __call__(self, first: Node, second: Node) -> bool:
        """
        Decide whether the Interpolate nodes 'first' and 'second' may be fused into a single node.

        A running set of axes (self.accumulated_axes) is kept along the current chain of candidates. The pair
        is fusable only when:
        - 'second' directly follows 'first' according to is_next;
        - their attributes match according to self._compare_attributes;
        - 'second' does not interpolate over any axis already accumulated (including the axes of 'first').

        Whenever fusion is rejected, the accumulated axes are cleared so that a new chain starts from scratch.

        :param first: the earlier of the two candidate nodes
        :param second: the later of the two candidate nodes
        :return: True if the nodes can be fused, False otherwise
        """
        if is_next(first, second) and self._compare_attributes(first, second):
            first_axes = set(Interpolate.get_axes(first))
            second_axes = set(Interpolate.get_axes(second))

            # Interpolations along different axes do not affect each other, so 'second' can be merged
            # as long as it touches none of the axes collected so far.
            self.accumulated_axes = self.accumulated_axes.union(first_axes)
            if self.accumulated_axes.isdisjoint(second_axes):
                return True

        # Either the pair is not adjacent/compatible, or the axes overlap: reset the chain.
        self.accumulated_axes = set()
        return False
コード例 #2
0
    def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
        """
        Feed the sizes input (port 1) of an Interpolate node from the runtime shape of a sibling Concat input.

        The first Concat source that is not an Interpolate (as returned by get_non_interpolate_concat_sources)
        is routed through ShapeOf -> Gather(interpolation axes, gather axis 0). The result replaces whatever
        was previously connected to port 1 of 'interpolate'.

        Nothing is changed when either set of axes is unknown or negative, when the Concat axis is itself
        interpolated, or when the Concat has no non-Interpolate source.

        :param interpolate: node whose 'type' attribute must be 'Interpolate'
        :param concat: node whose 'type' attribute must be 'Concat'
        """
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'
        interp_axes = Interpolate.get_axes(interpolate)
        concat_axis = self.get_concat_axis(concat)

        # Bail out unless both axes are known, non-negative and the concat axis is not interpolated.
        axes_unusable = (concat_axis is None or interp_axes is None
                         or np.any(interp_axes < 0) or concat_axis < 0
                         or concat_axis in interp_axes)
        if axes_unusable:
            return

        candidate_sources = self.get_non_interpolate_concat_sources(concat)
        if len(candidate_sources) == 0:
            # no Concat input whose shape could be borrowed
            return

        graph = interpolate.graph
        source = candidate_sources[0]

        shape_of_source = Shape(graph, {'name': source.node.soft_get('name', source.node.id) + '/Shape'}).create_node()
        shape_of_source.in_port(0).connect(source)

        gathered_dims = create_op_with_const_inputs(graph, Gather,
                                                    {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                                    {'name': shape_of_source.name + '/Gathered'},
                                                    input_node=shape_of_source)
        sizes_connection = interpolate.in_port(1).get_connection()
        sizes_connection.set_source(gathered_dims.out_port(0))
コード例 #3
0
    def make_interpolate_reshapeable(interpolate, concat):
        """
        Feed the sizes input (port 1) of an Interpolate node from the runtime shape of a sibling Concat input.

        Interpolation axes and the Concat axis are first normalized with get_canonical_axis_index against the
        Interpolate output shape. When the Concat axis is not interpolated and the Concat has at least one
        connected source that is not an Interpolate, the first such source is routed through
        ShapeOf -> Gather(interpolation axes, gather axis 0). The result replaces port 1 of 'interpolate'.

        The node is left untouched when its axes cannot be determined (Interpolate.get_axes returns None).

        :param interpolate: node whose 'type' attribute must be 'Interpolate'
        :param concat: node whose 'type' attribute must be 'Concat'
        """
        assert interpolate.soft_get('type') == 'Interpolate'
        assert concat.soft_get('type') == 'Concat'

        raw_axes = Interpolate.get_axes(interpolate)
        if raw_axes is None:
            # get_axes may return None when the axes are unknown; iterating it below would raise TypeError.
            return

        output_shape = interpolate.out_port(0).data.get_shape()

        interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in raw_axes]
        concat_axis = get_canonical_axis_index(output_shape, concat.axis)
        if concat_axis in interp_axes:
            # the borrowed shape would differ along an interpolated axis
            return

        concat_srcs = [port.get_source() for port in concat.in_ports().values() if not port.disconnected()]
        non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
        if not non_interp_concat_srcs:
            return

        graph = interpolate.graph
        src = non_interp_concat_srcs[0]

        shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
        shape.in_port(0).connect(src)
        gather = create_op_with_const_inputs(graph, Gather,
                                             {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                             {'name': shape.name + '/Gathered'}, shape)
        interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
コード例 #4
0
 def make_interpolate_reshapeable(interpolate):
     """
     Make the sizes input (port 1) of an Interpolate node follow its runtime input shape.

     The static sizes are replaced by ShapeOf(input) -> Gather(axes) -> Mul(multipliers), where the
     multipliers are the static output/input ratio along the interpolation axes. The scaling factors
     therefore stay the same when the input shape changes.

     The node is left untouched when:
     - its axes or either static shape is unknown (None);
     - the output shape is not element-wise divisible by the input shape, and the input shape is not
       element-wise divisible by the output shape either.

     :param interpolate: node whose 'type' attribute must be 'Interpolate'
     """
     assert interpolate.soft_get('type') == 'Interpolate'
     axes = Interpolate.get_axes(interpolate)
     input_shape = interpolate.in_port(0).data.get_shape()
     output_shape = interpolate.out_port(0).data.get_shape()
     # Without known axes and static shapes the multipliers cannot be computed. np.array(None, dtype=np.int32)
     # and np.remainder(None, ...) would raise, and shape[None] would silently add a dimension.
     if axes is None or input_shape is None or output_shape is None:
         return
     # Only rewrite when one shape is an exact multiple of the other along every dimension.
     if not np.all(np.remainder(output_shape, input_shape) == 0) and \
             not np.all(np.remainder(input_shape, output_shape) == 0):
         return
     graph = interpolate.graph
     name = interpolate.soft_get('name', interpolate.id)
     shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
     shape.in_port(0).connect(interpolate.in_port(0).get_source())
     gather = create_op_with_const_inputs(graph, Gather, {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
                                          {'name': shape.name + '/Gathered'}, shape)
     # true division: the multipliers are floating-point ratios
     multipliers = output_shape[axes] / input_shape[axes]
     mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)
     interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
コード例 #5
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Drop the Transpose pair around an Interpolate that works on axes [1, 2] and retarget it to axes [2, 3].

        The match is skipped when the Interpolate axes are unknown or differ from [1, 2]. Otherwise:
        - the axes become [2, 3]: the 'axes' attribute for opset1, the value of input port 3 for other opsets;
        - the input of 'transpose_1' is connected straight to the Interpolate;
        - the consumers of 'transpose_2' are fed by the Interpolate;
        - both Transpose nodes are removed from the graph.

        :param graph: graph containing the matched nodes
        :param match: mapping with 'interpolate', 'transpose_1' and 'transpose_2' nodes
        """
        interpolate, transpose_1, transpose_2 = (match[key] for key in ('interpolate', 'transpose_1', 'transpose_2'))

        axes = Interpolate.get_axes(interpolate)
        if axes is None:
            return
        if not np.array_equal(axes, int64_array([1, 2])):
            return

        # With the Transposes gone, the resize must address the spatial axes of the NCHW layout.
        opset = interpolate.get_opset()
        assert opset in ['opset1', 'opset4'], \
            'Interpolate node with name {} has unsupported opset'.format(interpolate.soft_get('name', interpolate.id))
        if opset != 'opset1':
            interpolate.in_port(3).data.set_value(int64_array([2, 3]))
        else:
            interpolate.axes = int64_array([2, 3])

        upstream = transpose_1.in_port(0).get_connection()
        upstream.set_destination(interpolate.in_port(0))
        downstream = transpose_2.out_port(0).get_connection()
        downstream.set_source(interpolate.out_port(0))

        graph.remove_nodes_from([transpose_1.id, transpose_2.id])
コード例 #6
0
def replace_sequence(seq: List[Node], graph: Graph):
    """
    Replace a sequence of consecutive Interpolate nodes with a single Interpolate node.

    Nothing is done when the sequence has fewer than two nodes, or when its nodes do not all share the same
    'mode'. The fused node:
    - interpolates over the axes of every node in the sequence, sorted in ascending order;
    - takes its data input from the first node;
    - feeds the consumers of the last node;
    - takes over the last node's name, while the last node is renamed with a '/delete' suffix.

    The replaced nodes are not removed from the graph here.

    :param seq: sequence of Interpolate nodes, assumed to be in data-flow order (first node feeds the next)
    :param graph: graph to which nodes of seq belong
    :return: Nothing
    """
    if not seq or len(seq) < 2:
        return

    if len({node.mode for node in seq}) != 1:
        return

    opset = seq[0].get_opset()
    if opset == 'opset1':
        # opset1 carries only target sizes. Collecting them in a dict keeps one size per axis
        # (the last one seen wins).
        axis_to_size = {}
        for interp in seq:
            axes_of_interp = Interpolate.get_axes(interp)
            sizes_of_interp = interp.in_port(1).get_connection().get_source().data.get_value()
            axis_to_size.update(zip(axes_of_interp, sizes_of_interp))

        records = sorted(axis_to_size.items(), key=lambda record: record[0])
        axes_of_node = int64_array([axis for axis, _ in records])
        sizes = shape_array([size for _, size in records])
    else:
        # Other opsets carry (axis, size, scale) triples from ports 1 and 2; they are concatenated as-is.
        records = []
        for interp in seq:
            records.extend(zip(Interpolate.get_axes(interp),
                               interp.in_port(1).get_connection().get_source().data.get_value(),
                               interp.in_port(2).get_connection().get_source().data.get_value()))

        records.sort(key=lambda record: record[0])
        axes_of_node = int64_array([record[0] for record in records])
        sizes = shape_array([record[1] for record in records])
        scales = mo_array([record[2] for record in records])

    fst_interp_node = seq[0]
    last_interp_node = seq[-1]
    last_interp_node_name = last_interp_node.soft_get('name', last_interp_node.id)
    attributes = get_interpolate_attributes(fst_interp_node)

    if opset == 'opset1':
        # opset1 keeps axes as a node attribute; only the sizes are a constant input.
        attributes['axes'] = axes_of_node
        const_inputs = {1: sizes}
    else:
        attributes['in_ports_count'] = 4
        const_inputs = {1: sizes, 2: scales, 3: axes_of_node}
    interp_node = create_op_with_const_inputs(graph, Interpolate, const_inputs, attributes)

    # Splice the fused node in place of the whole chain.
    fst_interp_node.in_port(0).get_connection().set_destination(interp_node.in_port(0))
    last_interp_node.out_port(0).get_connection().set_source(interp_node.out_port(0))

    rename_nodes([(last_interp_node, last_interp_node_name + '/delete'),
                  (interp_node, last_interp_node_name)])