Example #1
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    axis = node.in_port(1).data.get_value()
    size_splits = node.in_port(2).data.get_value()

    # total size of the already connected outputs along the split axis
    output_shape = sum([node.out_node(port).shape[axis] for port in node.out_nodes()])

    # nothing to do if the connected outputs already cover the whole input
    if output_shape == node.in_port(0).data.get_shape()[axis]:
        return

    # otherwise declare one output port per requested split chunk ...
    if not node.has_valid('out_ports_count'):
        node['out_ports_count'] = len(size_splits)

    # ... and let Op.normalize_outputs create the missing ports
    Op.normalize_outputs(node)
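
In the Model Optimizer, a replace_pattern like the one above is normally the action of a pattern-based graph transformation. The sketch below shows a minimal wrapper class, assuming the snippet targets a VariadicSplit node (suggested by the axis and size_splits inputs); the class name and the pattern dictionary are illustrative assumptions, not part of the original source, and the import paths follow the 'mo' package layout.

# Minimal, hypothetical wrapper for the replace_pattern above.
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern


class VariadicSplitOutputsNormalize(MiddleReplacementPattern):
    enabled = True

    def pattern(self):
        # match a single VariadicSplit operation; the 'op' key is what
        # replace_pattern reads back as match['op']
        return dict(
            nodes=[('op', dict(kind='op', op='VariadicSplit'))],
            edges=[],
        )

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        ...  # body as in Example #1 above
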
Example #2
# Imports as they typically appear for this function in the Model Optimizer
# 'mo' package (module paths may differ between OpenVINO versions). The helpers
# and dictionaries used below (restore_correct_ports, copy_input_blobs,
# collect_node_outputs, restore_tensor_names, custom_ops, preprocessing_op_nodes,
# postprocessing_op_nodes) are defined in the same module as this function.
import logging as log

from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.ops.op import Op
from mo.utils.ir_reader.extender import Extender


def copy_graph_with_ops(graph: Graph) -> Graph:
    """
    Copy the graph and apply extenders to the appropriate nodes.
    :param graph: Graph to copy
    :return: Copied graph with applied extenders
    """
    new_graph = Graph()
    new_graph.stage = 'back'
    new_graph.graph = graph.graph

    node_connections = dict()
    mapping_of_old_idx_into_new = dict()

    restore_correct_ports(graph)

    # Node preprocessing stage in the source graph.
    # First process only Const nodes, because the other preprocessing steps
    # assume that Const nodes have already been preprocessed.
    for op in graph.get_op_nodes(type='Const'):
        preprocessing_op_nodes[op.type](op)

    for op in graph.get_op_nodes():
        if op.soft_get('type') != 'Const' and op.soft_get('type') in preprocessing_op_nodes:
            preprocessing_op_nodes[op.type](op)

    # Create a new copy of graph with correct attributes (shape & type infer, backend attrs etc.)
    for op in graph.get_op_nodes():

        # Save input shapes restored from IR
        op['old_input_shapes'] = list()
        for n in op.in_nodes():
            op.old_input_shapes.append(int64_array(op.in_node(n).shape))

        # Apply extenders to nodes in source graph
        if op.type in Extender.registered_ops:
            Extender.get_extender_class_by_name(op.type).extend(op)
        else:
            log.debug('Extender for node {} with type={} not found, please note.'.format(op.name, op.type))

        # Add node with necessary type and extended attrs in new graph
        op_type = op.soft_get('type_to_create', op.type)

        if op_type in custom_ops:
            node = custom_ops[op_type](new_graph, op.attrs()).create_node()
        else:
            if op_type not in Op.registered_ops:
                log.warning(
                    'Operation {} is not found in MO operations, please check it! '
                    'Simple shape infer function is used'.format(op_type))
                node = Op(new_graph, op.attrs()).create_node()
                assert 'type' in node, 'Operation {} has no `type` attribute.'.format(node.soft_get('name'))
                node['op'] = node.type
                node['infer'] = Extender.use_shapes_from_ir
                if 'ir_data_attrs' in op:
                    node['IE'] = [(
                        'layer',
                        [('id', lambda node: node.node), 'name', 'type', 'version'],
                        [('data', list(op.ir_data_attrs.keys()), []), '@ports', '@consts'],
                    )]

            else:
                node = Op.get_op_class_by_name(op_type)(
                    new_graph, op.attrs()).create_node()

            # Fill out_ports_count attribute
            if 'out_ports_count' not in node and node.soft_get('type') != 'Result':
                node['out_ports_count'] = len(op.out_edges())

        # This attribute is no longer needed and we can delete it
        if 'ir_data_attrs' in node:
            del node['ir_data_attrs']

        if op.has_and_set('need_copy_input_blobs'):
            copy_input_blobs(op, node)

        # Collect node connections
        mapping_of_old_idx_into_new[op.id] = node.id
        node_connections[op.id] = collect_node_outputs(op)

    # Restore connections in new graph
    for input_node_idx, its_outputs in list(node_connections.items()):
        for out_port_idx, out_port_dest in its_outputs.items():
            for dest_in_port_idx, dest_node_idx in out_port_dest:
                src = Node(new_graph, mapping_of_old_idx_into_new[input_node_idx])
                dst = Node(new_graph, mapping_of_old_idx_into_new[dest_node_idx])
                src.out_port(out_port_idx).connect(dst.in_port(dest_in_port_idx))

    # Nodes postprocessing stage in new graph
    for op in new_graph.get_op_nodes():
        # Normalize outputs of the restored operations: connect temporary Result operations to
        # disconnected output ports, which is required for correct shape inference. These Result
        # operations are removed during IR emitting. TopK outputs are normalized by the dedicated
        # TopKNormalizer.normalize_outputs function instead.
        if op.soft_get('type') != 'TopK':
            Op.normalize_outputs(op)

        # Set the correct_data_type attribute on Const data nodes for correct processing of the restored values
        if op.soft_get('type') == 'Const':
            assert len(op.out_nodes()) == 1 and op.out_node(0).soft_get('kind') == 'data',\
                'Const node {} is not properly connected to a data node'.format(op.soft_get('name'))
            op.out_node(0)['correct_data_type'] = True

            if op.has_and_set('rt_info'):
                op.out_node(0)['rt_info'] = op.rt_info

        # postprocessing for operations of some special types
        if op.soft_get('type') in postprocessing_op_nodes:
            postprocessing_op_nodes[op.type](op)

        restore_tensor_names(op)

    # clean up the graph and run shape inference
    new_graph.clean_up()

    return new_graph
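
A hypothetical usage sketch (not part of the original listing): copy_graph_with_ops is typically applied to a graph restored from a serialized IR by the Model Optimizer's IR reader. The IREngine import path and the file names below are assumptions and may differ between OpenVINO versions.

# Hypothetical usage: parse an IR and rebuild an editable graph from it.
from mo.utils.ir_engine.ir_engine import IREngine

ir = IREngine('model.xml', 'model.bin')          # read the serialized IR
restored_graph = copy_graph_with_ops(ir.graph)   # apply extenders, restore connections
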
Example #3
def find_and_replace_pattern(self, graph: Graph):
    # normalize outputs of every Split node in the graph
    for split_node in graph.get_op_nodes(op='Split'):
        Op.normalize_outputs(split_node)
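
Unlike Example #1, this transformation does not declare a pattern: find_and_replace_pattern walks the graph directly. Below is a minimal sketch of the surrounding class, assuming a FrontReplacementPattern-style transformation; the class name is an illustrative assumption and the import paths follow the 'mo' package layout.

# Hypothetical wrapper class for the snippet above; the base class provides the
# find_and_replace_pattern entry point called by the transformation pipeline.
from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Graph
from mo.ops.op import Op


class SplitOutputsNormalize(FrontReplacementPattern):
    enabled = True

    def find_and_replace_pattern(self, graph: Graph):
        # attach any missing output ports (and temporary Result nodes for
        # disconnected ones) to every Split operation in the graph
        for split_node in graph.get_op_nodes(op='Split'):
            Op.normalize_outputs(split_node)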