Exemplo n.º 1
0
def order(op_node: Node, port_info: str, input_port: int):
    """
    Translate an ordered list of axis indices on the `input_port` input of op_node to the new layout.

    Two chained Gather operations are placed between the current producer of that input and op_node:

        1. The first Gather receives the original input as its 0-port data and the direct
           permutation (`permutation.perm`) as its 1-port indices, so the entries of the
           input are reordered to the new layout.
        2. The second Gather receives the inverse permutation (`permutation.inv`) as its
           0-port data and the output of the first Gather as its 1-port indices, so every
           index value is translated from one layout to the other.

    For example:
        NHWC Transpose operation has 0-port input with data of shape [1, 2, 3, 4] and
        1-port input with new order indices [0, 1, 3, 2].

        After translating such operation to NCHW layout:
            0-port input shape = [1, 4, 2, 3]

        1 phase (after first Gather insertion):
            1-port input order indices = [0, 2, 1, 3]
        2 phase (after second Gather insertion):
            1-port input order indices = [0, 3, 2, 1]

    Nothing is changed when the permutation found for `port_info` is empty.

    :param op_node: operation whose input has to be translated
    :param port_info: passed to `get_node_with_permutation` to locate the data node holding the permutation
    :param input_port: index of the op_node input to translate
    """
    graph = op_node.graph
    perm_holder = get_node_with_permutation(op_node, port_info)
    assert perm_holder.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            perm_holder.id, op_node.id, port_info)
    permutation = perm_holder.permutation
    if len(permutation.perm) == 0:
        return

    source_data = op_node.in_node(input_port)

    # Phase 1: reorder the entries of the input with the direct permutation.
    direct_perm_attrs = {'value': permutation.perm, 'need_shape_inference': True}
    direct_perm_data = Const(graph, direct_perm_attrs).create_node_with_data()
    reorder_attrs = {'name': op_node.name + '/OrderGather_1', 'need_shape_inference': True}
    reordered_data = Gather(graph, reorder_attrs).create_node_with_data([source_data, direct_perm_data])

    # Phase 2: look every reordered index up in the inverse permutation.
    inverse_perm_attrs = {'value': permutation.inv, 'need_shape_inference': True}
    inverse_perm_data = Const(graph, inverse_perm_attrs).create_node_with_data()
    translate_attrs = {'name': op_node.name + '/OrderGather_2', 'need_shape_inference': True}
    translated_data = Gather(graph, translate_attrs).create_node_with_data([inverse_perm_data, reordered_data])

    # Re-route the op_node input: the new edge reuses a copy of the old edge attributes.
    edge_attrs = graph.get_edge_data(source_data.id, op_node.id, key=0).copy()
    graph.add_edge(translated_data.id, op_node.id, **edge_attrs)
    graph.remove_edge(source_data.id, op_node.id)
    op_node['need_shape_inference'] = True
Exemplo n.º 2
0
    def extract(node):
        """
        Store the `axis` attribute of the node as Gather attributes.

        `axis` is requested through `onnx_attr` with type code 'i' and default 0, and is
        wrapped into a numpy int64 array before being passed to `Gather.update_node_stat`.

        :param node: node to extract attributes for
        :return: the `enabled` flag of the enclosing class
        """
        axis_value = onnx_attr(node, 'axis', 'i', default=0)
        Gather.update_node_stat(node, {'axis': np.array(axis_value, dtype=np.int64)})
        return __class__.enabled
    def replace_op(self, graph: Graph, node: Node):
        """
        Build a Transpose -> Gather -> Transpose chain for `node`.

        The gather indices are read from the binary stream `node.parameters`; the same
        [1, 0] order constant feeds both Transpose operations and the Gather axis is 0.

        :param graph: graph to create the new nodes in
        :param node: node being replaced
        :return: list with the id of the last created node (the output Transpose)
        """
        pb = node.parameters
        weights_size = read_binary_integer32_token(pb)
        # 1 is subtracted from every value that was read.
        # NOTE(review): presumably converts 1-based indices to 0-based ones - confirm.
        weights = read_blob(pb, weights_size, dtype=np.int32) - 1
        const_attrs = {
            'name': 'indexes/{}'.format(node.id),
            'value': np.array(weights),
            'shape': [weights_size],
            'data_type': np.int32
        }
        indexes_node = Const(graph).create_node(attrs=const_attrs)

        # Order [1, 0] shared by the input and the output Transpose.
        perm_in_1 = Const(
            graph, {
                'value': np.array([1, 0], dtype=np.int64),
                'shape': [2],
                'data_type': np.int64
            }).create_node()
        axis_const = Const(graph, {'value': int64_array(0)}).create_node()
        # NOTE(review): the Transpose is created with node.in_node(0) as an input and its
        # input port 0 is then connected again to the source of node.in_port(0) - confirm
        # that both steps are intended.
        perm1_node = Transpose(graph, {
            'name': 'input_permute'
        }).create_node([node.in_node(0)])
        perm1_node.in_port(0).connect(node.in_port(0).get_source())
        perm1_node.in_port(1).connect(perm_in_1.out_port(0))

        # Gather inputs: 0 - transposed data, 1 - indices read from the stream, 2 - axis.
        gather_node = Gather(graph, {}).create_node()
        gather_node.in_port(0).connect(perm1_node.out_port(0))
        gather_node.in_port(1).connect(indexes_node.out_port(0))
        gather_node.in_port(2).connect(axis_const.out_port(0))

        # The second Transpose applies the same [1, 0] order to the Gather output.
        perm2_node = Transpose(graph, {'name': 'output_permute'}).create_node()
        perm2_node.in_port(0).connect(gather_node.out_port(0))
        perm2_node.in_port(1).connect(perm_in_1.out_port(0))

        return [perm2_node.id]
Exemplo n.º 4
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Replace the matched 'op' node with a Gather that takes the axis as a third input.

        The first two inputs and the output of the matched node are moved to the new Gather,
        and the `axis` attribute of the matched node becomes a Const connected to port 2.

        :param graph: graph being transformed
        :param match: pattern match; the matched node is stored under the 'op' key
        """
        matched = match['op']
        base_name = matched.soft_get('name', matched.id)
        assert matched.has_valid('axis')

        axis_attrs = {'name': base_name + '/axis', 'value': int64_array(matched.axis)}
        axis_node = Const(graph, axis_attrs).create_node()
        gather_node = Gather(graph, {'name': base_name}).create_node()

        # Inputs 0 and 1 keep their port numbers on the new node.
        for port_idx in (0, 1):
            matched.in_port(port_idx).get_connection().set_destination(gather_node.in_port(port_idx))
        axis_node.out_port(0).connect(gather_node.in_port(2))

        # Consumers of the matched node now read from the new Gather.
        matched.out_port(0).get_connection().set_source(gather_node.out_port(0))
Exemplo n.º 5
0
 def replace_sub_graph(self, graph: Graph, match: dict):
     """
     Replace the matched 'op' node with a Gather created with axis=0.

     The two inputs are swapped on the way: input 0 of the matched node goes to Gather
     port 1 and input 1 goes to Gather port 0. The output connection is moved as is.

     :param graph: graph being transformed
     :param match: pattern match; the matched node is stored under the 'op' key
     """
     matched = match['op']
     embedding_name = matched.id + '/embedding_'
     gather_attrs = dict(name=embedding_name,
                         axis=0,
                         symbol_dict={'name': embedding_name})
     gather_node = Gather(graph, gather_attrs).create_node()

     first_input = matched.in_port(0).get_connection()
     first_input.set_destination(gather_node.in_port(1))
     second_input = matched.in_port(1).get_connection()
     second_input.set_destination(gather_node.in_port(0))

     matched.out_port(0).get_connection().set_source(gather_node.out_port(0))
Exemplo n.º 6
0
    def test_gather_infer(self):
        """
        Check that `Gather.infer` fills the 'gather_output' node with the expected shape and value.

        The graph comes from `self._create_graph()`; the expected result is an all-ones
        array of shape [2, 15].
        """
        graph = self._create_graph()

        gather_node = Node(graph, 'gather_node')
        Gather.infer(gather_node)

        exp_shape = int64_array([2, 15])
        exp_value = np.ones(exp_shape)
        res_shape = graph.node['gather_output']['shape']
        res_value = graph.node['gather_output']['value']

        self.assertTrue(np.array_equal(exp_shape, res_shape),
                        'shapes do not match expected: {} and given: {}'.format(exp_shape, res_shape))

        # The failure message reports the compared values; it used to repeat the shape
        # message and print the shapes, which hid the actual mismatch.
        self.assertTrue(np.array_equal(res_value, exp_value),
                        'values do not match expected: {} and given: {}'.format(exp_value, res_value))
Exemplo n.º 7
0
def shape(op_node: Node, port_info: str, input_port: int):
    """
    Insert a Gather that reorders the `input_port` input of op_node with the direct permutation.

    The Gather takes the original input as data (port 0), `permutation.perm` as indices
    (port 1) and a constant 0 as axis (port 2). Nothing is changed when the permutation
    found for `port_info` is empty.

    :param op_node: operation whose input has to be reordered
    :param port_info: passed to `get_node_with_permutation` to locate the data node holding the permutation
    :param input_port: index of the op_node input to reorder
    """
    graph = op_node.graph
    perm_holder = get_node_with_permutation(op_node, port_info)
    assert perm_holder.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            perm_holder.id, op_node.id, port_info)
    permutation = perm_holder.permutation
    if len(permutation.perm) == 0:
        return

    source_data = op_node.in_node(input_port)
    gather_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'

    indices_attrs = {'value': permutation.perm, 'name': gather_name + '/const', 'need_shape_inference': True}
    indices_data = Const(graph, indices_attrs).create_node_with_data()
    axis_attrs = {'value': int64_array(0), 'name': gather_name + '/axis'}
    axis_data = Const(graph, axis_attrs).create_node_with_data()
    gather_attrs = {'name': gather_name, 'need_shape_inference': True}
    gather_data = Gather(graph, gather_attrs).create_node_with_data([source_data, indices_data, axis_data])

    # Re-route the op_node input: the new edge reuses a copy of the old edge attributes.
    edge_attrs = graph.get_edge_data(source_data.id, op_node.id, key=0).copy()
    graph.add_edge(gather_data.id, op_node.id, **edge_attrs)
    graph.remove_edge(source_data.id, op_node.id)

    # Shape inference has to be re-run for op_node so that its output shape value is overridden;
    # this resolves the shape collision for nodes with 'correct_data_layout' output port attrs.
    op_node['need_shape_inference'] = True
Exemplo n.º 8
0
 def replace_op(self, graph: Graph, node: Node):
     """
     Create a Gather fed by the inputs of `node` in swapped order plus a constant axis 0.

     :param graph: graph to create the new nodes in
     :param node: node being replaced; input 1 holds the weight and input 0 the input ids
     :return: list with the id of the created Gather
     """
     axis_node = Const(graph, {'value': 0}).create_node()
     weight, input_ids = node.in_node(1), node.in_node(0)
     gather_node = Gather(graph, dict(name=node.name)).create_node([weight, input_ids, axis_node])
     return [gather_node.id]
Exemplo n.º 9
0
def get_shape_values_by_indices_node(shape_node: Node, indices_node: Node) -> Node:
    """
    Build a Gather that picks the elements of 'shape_node' selected by 'indices_node'.

    The Gather is wired as: port 0 - output of shape_node, port 1 - output of indices_node,
    port 2 - a new constant axis equal to 0.

    :param shape_node: the node of 1D output shape to get elements from
    :param indices_node: the node of 1D output shape with the list of element indices to get
    :return: node producing required elements of the node
    """
    graph = shape_node.graph
    base_name = shape_node.name
    axis_node = Const(graph, {'value': int64_array(0), 'name': base_name + '/Axis'}).create_node()
    result = Gather(graph, {'name': base_name + '/Gather'}).create_node()

    producers = (shape_node, indices_node, axis_node)
    for port_idx, producer in enumerate(producers):
        producer.out_port(0).connect(result.in_port(port_idx))
    return result
    def replace_pattern(graph: Graph, match: dict):
        """
        Insert a permuting Gather in front of input 2 of the matched 'op' node.

        The node is left untouched unless it has a connected input port 2 and the
        'shape_input' attribute set, its 'layout' attribute is valid and does not start
        with 'NC', and the graph layout is 'NCHW'. The Gather indices are the NHWC to NCHW
        permutation for the rank of input 0 and the Gather axis is 0.

        :param graph: graph being transformed
        :param match: pattern match; the matched node is stored under the 'op' key
        """
        node = match['op']
        shape_input_absent = not node.has_port('in', 2) or node.in_port(2).disconnected()
        if shape_input_absent or not node.has_and_set('shape_input'):
            return

        needs_permute = node.has_valid('layout') and not node.layout.startswith('NC') \
            and graph.graph['layout'] == 'NCHW'
        if not needs_permute:
            return

        rank = len(node.in_port(0).data.get_shape())
        permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(rank)

        source_data = node.in_node(2)
        gather_name = node.soft_get('name', node.id) + '/ShapeGather'

        indices_attrs = {'value': permutation.perm, 'name': gather_name + '/Const', 'need_shape_inference': True}
        indices_data = Const(graph, indices_attrs).create_node_with_data()
        axis_attrs = {'value': int64_array(0), 'name': gather_name + '/Axis'}
        axis_data = Const(graph, axis_attrs).create_node_with_data()
        gather_attrs = {'name': gather_name, 'need_shape_inference': True}
        gather_data = Gather(graph, gather_attrs).create_node_with_data([source_data, indices_data, axis_data])

        # Re-route the node input: the new edge reuses a copy of the old edge attributes.
        edge_attrs = graph.get_edge_data(source_data.id, node.id, key=0).copy()
        graph.add_edge(gather_data.id, node.id, **edge_attrs)
        graph.remove_edge(source_data.id, node.id)
Exemplo n.º 11
0
def reorder_inputs_for_shape_or_slice(op_node: Node, input_port: int,
                                      permute_indices_for_gather: list):
    """
    Insert a Gather that reorders the `input_port` input of op_node with the given indices.

    axis and slice permutations are almost the same the only difference is that for slice in general
    case permutation depends from slice_rank not from input_rank or output_rank

    :param op_node: operation whose input has to be reordered
    :param input_port: index of the op_node input to reorder
    :param permute_indices_for_gather: indices used as the 1-port input of the inserted Gather
    """
    graph = op_node.graph
    source_data = op_node.in_node(input_port)
    gather_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'

    indices_attrs = {
        'value': permute_indices_for_gather,
        'name': gather_name + '/const',
        'need_shape_inference': True,
    }
    indices_data = Const(graph, indices_attrs).create_node_with_data()
    axis_attrs = {'value': int64_array(0), 'name': gather_name + '/axis'}
    axis_data = Const(graph, axis_attrs).create_node_with_data()
    gather_attrs = {'name': gather_name, 'need_shape_inference': True}
    gather_data = Gather(graph, gather_attrs).create_node_with_data([source_data, indices_data, axis_data])

    # Re-route the op_node input: the new edge reuses a copy of the old edge attributes.
    edge_attrs = graph.get_edge_data(source_data.id, op_node.id, key=0).copy()
    graph.add_edge(gather_data.id, op_node.id, **edge_attrs)
    graph.remove_edge(source_data.id, op_node.id)

    # Shape inference has to be re-run for op_node so that its output shape value is overridden;
    # this resolves the shape collision for nodes with 'correct_data_layout' output port attrs.
    op_node['need_shape_inference'] = True
    def placeholder_scales(self, placeholder: Node):
        """
        Helper function to get scales for prior boxes out of input image size:
                [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]

        The sub-graph built here is:
            Shape -> StridedSlice [1:3] -> Cast to float32 -> Pow(-1) -> Gather([1, 0], axis 0) -> Concat (x2)

        :param placeholder: node whose output 0 is connected to the created Shape operation;
                            it must carry a 4-element np.ndarray 'shape' attribute
        :return: the Concat node joining two copies of the reversed reciprocal sizes
        """
        graph = placeholder.graph
        name = placeholder.soft_get('name', placeholder.id)

        # The placeholder has to carry a 4-element numpy 'shape' attribute.
        shape_value = placeholder.soft_get('shape', None)
        assert shape_value is not None, \
            "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
        assert isinstance(shape_value, np.ndarray), \
            "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
        assert shape_value.size == 4, \
            "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

        shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
        shape.in_port(0).connect(placeholder.out_port(0))

        # Slice elements 1 and 2 out of the shape (begin=1, end=3, stride=1).
        # NOTE(review): presumably these are height and width of an NHWC input - confirm.
        begin = Const(graph, {'value': int64_array([1])}).create_node()
        end = Const(graph, {'value': int64_array([3])}).create_node()
        stride = Const(graph, {'value': int64_array([1])}).create_node()
        spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]),
                                       'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                       'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()

        spatial.in_port(0).connect(shape.out_port(0))
        spatial.in_port(1).connect(begin.out_port(0))
        spatial.in_port(2).connect(end.out_port(0))
        spatial.in_port(3).connect(stride.out_port(0))

        # Raise the sliced values to the power of -1 to get their reciprocals.
        power = Const(graph, {'value': float32_array([-1.])}).create_node()
        spatial_scale = Pow(graph, {}).create_node()

        spatial_scale.in_port(0).connect(spatial.out_port(0))
        spatial_scale.in_port(1).connect(power.out_port(0))

        # Power `type_infer` requires inputs to have equal data type
        convert_to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
        spatial_scale.in_port(0).get_connection().insert_node(convert_to_fp32)

        # Swap the two reciprocals with a Gather over indices [1, 0] along axis 0.
        order = Const(graph, {'value': int64_array([1, 0])}).create_node()
        axis_const = Const(graph, {'value': int64_array(0)}).create_node()
        reverse = Gather(graph, {}).create_node()

        reverse.in_port(0).connect(spatial_scale.out_port(0))
        reverse.in_port(1).connect(order.out_port(0))
        axis_const.out_port(0).connect(reverse.in_port(2))

        # Concatenate the swapped pair with itself along axis 0.
        priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
        priors_scale_node.add_input_port(0, skip_if_exist=True)
        priors_scale_node.add_input_port(1, skip_if_exist=True)

        priors_scale_node.in_port(0).connect(reverse.out_port(0))
        priors_scale_node.in_port(1).connect(reverse.out_port(0))
        return priors_scale_node
Exemplo n.º 13
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched 'GatherNd' node with Reshape + Gather when its indices allow it.

        The transformation is skipped when the indices input has no constant value or when
        `self.indices_check` returns None for the indices and the input shape.

        :param graph: graph being transformed
        :param match: pattern match; the matched node is stored under the 'GatherNd' key
        """
        gather = match['GatherNd']
        input_shape = gather.in_node(0).shape
        indices = gather.in_node(1).value
        if indices is None:
            # We can't do such special pass without indices value
            return

        # 0. All needed checks that we can replace GatherNd by Gather
        gather_idx = self.indices_check(indices, input_shape)
        if gather_idx is None:
            log.warning(
                'Node {} with op=GatherNd  can\'t be normalized to op=Gather.'.
                format(gather.name))
            return

        # 1. Add Reshape and connect
        # The target shape is -1 followed by the input dimensions that come after the
        # first indices.shape[-1] ones.
        new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
        reshape = Reshape(graph, {
            'name': gather.name + '/Reshape_for_GatherNd/'
        }).create_node()
        reshape_const_node = Const(graph, {
            'name': reshape.name + '/Dim',
            'value': new_shape
        }).create_node()
        gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_const_node.out_port(0))

        # 2. Change indices from Nd to 1d:
        # Only the component at position gather_idx of the last axis is kept and flattened.
        new_indices = np.reshape(
            np.take(indices, indices=[gather_idx], axis=-1), [-1])
        new_indices_const = Const(graph, dict(value=new_indices)).create_node()

        # 3. Create new Gather operation and reconnect all inputs/outputs
        # The axis is passed as the 'axis' attribute; only ports 0 and 1 are connected here.
        new_gather = Gather(graph, {
            'name': gather.name + '/NewGather/',
            'axis': 0
        }).create_node()
        reshape.out_port(0).connect(new_gather.in_port(0))
        new_indices_const.out_port(0).connect(new_gather.in_port(1))

        gather.out_port(0).get_connection().set_source(new_gather.out_port(0))

        # 4. Remove old Gather node
        graph.remove_node(gather.id)
Exemplo n.º 14
0
def node_to_get_shape_value_of_range(shape_node: Node, indices: list):
    """
    Build a Gather that picks the elements of 'shape_node' located at the given indices.

    The indices are wrapped into a new Const connected to Gather port 1; the output of
    shape_node is connected to Gather port 0.

    :param shape_node: the node of 1D output shape to get elements from
    :param indices: the list of element indices to get
    :return: node producing required elements of the node
    """
    graph = shape_node.graph
    base_name = shape_node.name

    indices_attrs = {'value': int64_array(indices), 'name': base_name + '/Indices'}
    indices_const = Const(graph, indices_attrs).create_node()
    result = Gather(graph, {'name': base_name + '/Gather'}).create_node()

    for port_idx, producer in enumerate((shape_node, indices_const)):
        producer.out_port(0).connect(result.in_port(port_idx))

    return result
Exemplo n.º 15
0
    def replace_with_gather(node):
        """
        Substitute `node` with a Gather over its `order` attribute along its `axis` attribute.

        `order` becomes the Gather indices (port 1), `axis` becomes the Gather axis (port 2);
        the output and the first input of `node` are then moved to the new Gather.

        :param node: node to replace; must provide the `axis` and `order` attributes
        """
        graph = node.graph
        base_name = node.soft_get('name', node.id)
        axis_value, order_value = node.axis, node.order

        order_const = Const(graph, {'name': base_name + '/reverse_order', 'value': order_value}).create_node()
        axis_node = Const(graph, {'value': int64_array(axis_value)}).create_node()
        replacement = Gather(graph, {'name': base_name}).create_node()

        # Constant inputs first: indices on port 1 and axis on port 2.
        for port_idx, producer in ((1, order_const), (2, axis_node)):
            replacement.in_port(port_idx).connect(producer.out_port(0))

        # Then take over the output consumers and the data input of the original node.
        node.out_port(0).get_connection().set_source(replacement.out_port(0))
        node.in_port(0).get_connection().set_destination(replacement.in_port(0))
Exemplo n.º 16
0
def axis(op_node: Node, port_info: str, input_port: int):
    """
    Translate axis indices on the `input_port` input of op_node to the new layout.

    A Gather operation is inserted in front of that input with
        the inverse permutation (`permutation.inv`) as 0-port input data,
        the actual data to translate as 1-port input indexes and
        a constant 0 as 2-port input axis

    For example:
        NHWC Reduce operation has 0-port input with data of shape [1, 2, 3, 4] and
        1-port input with axis indices [0, 1].

        After translating such operation to NCHW layout:
            0-port input shape = [1, 4, 2, 3]
            1-port input axis indices = [0, 2]

    Nothing is changed when the permutation found for `port_info` is empty.

    :param op_node: operation whose input has to be translated
    :param port_info: passed to `get_node_with_permutation` to locate the data node holding the permutation
    :param input_port: index of the op_node input to translate
    """
    graph = op_node.graph

    perm_holder = get_node_with_permutation(op_node, port_info)
    assert perm_holder.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            perm_holder.id, op_node.id, port_info)
    permutation = perm_holder.permutation
    if len(permutation.perm) == 0:
        return

    source_data = op_node.in_node(input_port)
    gather_name = op_node.soft_get('name', op_node.id) + '/AxisGather'

    inverse_attrs = {'value': permutation.inv, 'name': gather_name + '/const', 'need_shape_inference': True}
    inverse_data = Const(graph, inverse_attrs).create_node_with_data()
    axis_attrs = {'value': int64_array(0), 'name': gather_name + '/axis'}
    axis_data = Const(graph, axis_attrs).create_node_with_data()
    gather_attrs = {'name': gather_name, 'need_shape_inference': True}
    # The inverse permutation is the data here; the original input provides the indices.
    gather_data = Gather(graph, gather_attrs).create_node_with_data([inverse_data, source_data, axis_data])

    # Re-route the op_node input: the new edge reuses a copy of the old edge attributes.
    edge_attrs = graph.get_edge_data(source_data.id, op_node.id, key=0).copy()
    graph.add_edge(gather_data.id, op_node.id, **edge_attrs)
    graph.remove_edge(source_data.id, op_node.id)
    op_node['need_shape_inference'] = True
Exemplo n.º 17
0
 def replace_sub_graph(self, graph: Graph, match: dict):
     """
     Replace the matched 'op' node with a Gather whose axis input is a constant 0.

     The two inputs are swapped on the way: input 0 of the matched node goes to Gather
     port 1 and input 1 goes to Gather port 0. The output connection is moved as is.

     :param graph: graph being transformed
     :param match: pattern match; the matched node is stored under the 'op' key
     """
     matched = match['op']
     embedding_name = matched.id + '/embedding_'
     gather_attrs = dict(name=embedding_name,
                         symbol_dict={'name': embedding_name})
     gather_node = Gather(graph, gather_attrs).create_node()
     axis_node = Const(graph, {'value': int64_array(0)}).create_node()

     first_input = matched.in_port(0).get_connection()
     first_input.set_destination(gather_node.in_port(1))
     second_input = matched.in_port(1).get_connection()
     second_input.set_destination(gather_node.in_port(0))
     axis_node.out_port(0).connect(gather_node.in_port(2))

     matched.out_port(0).get_connection().set_source(gather_node.out_port(0))
Exemplo n.º 18
0
 def extract(cls, node):
     """
     Store the integer `batch_dims` attribute read from `node.pb.attr` as a Gather attribute.

     :param node: node to extract attributes for
     :return: the `enabled` flag of `cls`
     """
     batch_dims = node.pb.attr['batch_dims'].i
     Gather.update_node_stat(node, {'batch_dims': batch_dims})
     return cls.enabled
    def find_and_replace_pattern(self, graph: Graph):
        """
        Insert Gather operations after ShapeOf nodes and at the ends of the paths that start from them.

        Step 1 puts a Gather right after every ShapeOf whose output shape is [4] or [5],
        unless the data feeding the ShapeOf is already in the correct layout.
        Step 2 collects path ends with `self.search_of_constant_path_end` and puts a Gather
        on each of their input ports that is fed by a node recorded in `constant_shape_paths`.

        :param graph: graph being transformed
        """
        # 1. Inserting Gather to N*C format on constant shape paths
        #   - Search for Shape ops
        #   - Inserting Gather after them in case of [4] or [5] output shape

        shape_ops = graph.get_op_nodes(op='ShapeOf')
        constant_shape_paths = set()
        # NOTE(review): this list is only appended to and never read inside this method.
        gather_inserted = []

        for shape in shape_ops:
            output_port = shape.in_port(0).get_source()
            if is_output_data_in_correct_layout(output_port.node,
                                                output_port.idx):
                continue
            shape_of_shape_op_output = shape.out_node().shape

            if np.array_equal(shape_of_shape_op_output, [4]):
                index = np.array([0, 2, 3, 1])
            elif np.array_equal(shape_of_shape_op_output, [5]):
                index = np.array([0, 2, 3, 4, 1])
            else:
                continue

            const = Const(graph, {'value': index}).create_node()
            gather = Gather(graph, {
                'name': shape.name + '/GatherNCHWtoNHWC'
            }).create_node()

            # Existing consumers of the ShapeOf output are moved to the Gather output first,
            # then the ShapeOf output is connected to the Gather data input.
            shape.out_port(0).get_connection().set_source(gather.out_port(0))
            shape.out_port(0).connect(gather.in_port(0))
            const.out_port(0).connect(gather.in_port(1))

            constant_shape_paths.add(gather.id)
            gather_inserted.append(gather.id)

        # 2. Inserting Gather to NC* format
        #   - Search from Shape ops found in previous step for nodes without value that are n-th children of Shape op
        #       * MO can not propagate value, there is data path
        #   - Inserting Gather on ports which comes from operations in `constant_shape_paths` list

        constant_shape_ends = []

        # `constant_shape_paths` is handed over as the `visited` argument.
        # NOTE(review): presumably the helper adds the nodes it visits to this set - confirm.
        for shape in shape_ops:
            constant_shape_ends.extend(
                self.search_of_constant_path_end(graph,
                                                 node_name=shape.id,
                                                 visited=constant_shape_paths))

        for end in constant_shape_ends:
            node = Node(graph, end)
            # Only the input ports fed by a node from `constant_shape_paths` are processed.
            in_ports = [
                in_port for in_port in node.in_ports().values()
                if in_port.get_source().node.id in constant_shape_paths
            ]

            for in_port in in_ports:
                shape = in_port.data.get_shape()

                if np.array_equal(shape, [4]):
                    index = np.array([0, 3, 1, 2])
                elif np.array_equal(shape, [5]):
                    # NOTE(review): this is the same index list as the rank-5 one in step 1, whereas
                    # the rank-4 lists of the two steps are inverses of each other
                    # ([0, 2, 3, 1] vs [0, 3, 1, 2]); confirm whether [0, 4, 1, 2, 3] was intended.
                    index = np.array([0, 2, 3, 4, 1])
                else:
                    continue

                const = Const(graph, {'value': np.array(index)}).create_node()
                gather = Gather(graph, {
                    'name': node.name + '/GatherNHWCtoNCHW'
                }).create_node()

                # The producer is connected to the Gather data input and the processed port
                # is then switched to read from the Gather output.
                in_port.get_source().connect(gather.in_port(0))
                in_port.get_connection().set_source(gather.out_port(0))
                const.out_port(0).connect(gather.in_port(1))
Exemplo n.º 20
0
    def extract(node):
        """
        Register the node as a Gather without any extra attributes.

        :param node: node to extract attributes for
        :return: the `enabled` flag of the enclosing class
        """
        Gather.update_node_stat(node, {})
        return __class__.enabled
Exemplo n.º 21
0
 def extract(cls, node):
     """
     Register the node as a Gather without any extra attributes.

     :param node: node to extract attributes for
     :return: the `enabled` flag of `cls`
     """
     no_extra_attrs = {}
     Gather.update_node_stat(node, no_extra_attrs)
     return cls.enabled