def replace_pattern(graph: Graph, match: dict):
    """
    Normalize outputs of the matched node when its connected outputs do not cover the input along the split axis

    :param graph: graph the matched node belongs to (not used directly in this function)
    :param match: pattern match; the node to process is taken from the 'op' key
    :return: None
    """
    split_node = match['op']
    split_axis = split_node.in_port(1).data.get_value()
    split_sizes = split_node.in_port(2).data.get_value()

    # Sum of the sizes along the split axis over the output nodes returned by out_nodes()
    covered_extent = sum(split_node.out_node(port_idx).shape[split_axis] for port_idx in split_node.out_nodes())
    input_extent = split_node.in_port(0).data.get_shape()[split_axis]

    # Nothing to do when the existing outputs already add up to the input size.
    # NOTE: the equality form is kept on purpose; do not invert it into `!=`.
    if covered_extent == input_extent:
        return

    # Number of output ports defaults to the number of requested split sizes
    if not split_node.has_valid('out_ports_count'):
        split_node['out_ports_count'] = len(split_sizes)

    Op.normalize_outputs(split_node)
def copy_graph_with_ops(graph: Graph) -> Graph:
    """
    Build a new graph from the given one, re-creating every operation node with extenders applied

    The source graph is modified along the way: its ports are restored, preprocessing callbacks are run
    and extenders are applied to its nodes before each node is re-created in the new graph.

    :param graph: source graph to copy
    :return: new graph with re-created operations, restored connections and postprocessing applied
    """
    copied_graph = Graph()
    copied_graph.stage = 'back'
    copied_graph.graph = graph.graph

    outputs_by_old_id = {}
    new_id_by_old_id = {}

    restore_correct_ports(graph)

    # Preprocessing of the source graph. Const nodes are handled first, because the other
    # preprocessing callbacks assume that Const nodes have already been preprocessed.
    for const_op in graph.get_op_nodes(type='Const'):
        preprocessing_op_nodes[const_op.type](const_op)

    for src_op in graph.get_op_nodes():
        src_type = src_op.soft_get('type')
        if src_type == 'Const' or src_type not in preprocessing_op_nodes:
            continue
        preprocessing_op_nodes[src_op.type](src_op)

    # Re-create every operation in the new graph with correct attributes
    # (shape & type infer, backend attrs etc.)
    for src_op in graph.get_op_nodes():
        # Remember the input shapes that were restored from IR
        src_op['old_input_shapes'] = [int64_array(src_op.in_node(idx).shape) for idx in src_op.in_nodes()]

        # Let the registered extender (if there is one) update the source node
        if src_op.type in Extender.registered_ops:
            Extender.get_extender_class_by_name(src_op.type).extend(src_op)
        else:
            log.debug('Extender for node {} with type={} not found, please note.'.format(
                src_op.name, src_op.type))

        # Pick the way to create the node: custom op, registered MO op, or a generic Op as a fallback
        type_to_create = src_op.soft_get('type_to_create', src_op.type)
        if type_to_create in custom_ops:
            new_node = custom_ops[type_to_create](copied_graph, src_op.attrs()).create_node()
        elif type_to_create in Op.registered_ops:
            new_node = Op.get_op_class_by_name(type_to_create)(copied_graph, src_op.attrs()).create_node()
        else:
            log.warning('Operation {} is not found in MO operations, please check it! '
                        'Simple shape infer function is used'.format(type_to_create))
            new_node = Op(copied_graph, src_op.attrs()).create_node()
            assert 'type' in new_node, 'Operation {} have no `type` attribute.'.format(new_node.soft_get('name'))
            new_node['op'] = new_node.type
            new_node['infer'] = Extender.use_shapes_from_ir
            if 'ir_data_attrs' in src_op:
                data_attr_names = list(src_op.ir_data_attrs.keys())
                new_node['IE'] = [(
                    'layer',
                    [('id', lambda node: node.node), 'name', 'type', 'version'],
                    [('data', data_attr_names, []), '@ports', '@consts'],
                )]

        # Result nodes are the only ones that do not get an out_ports_count attribute here
        if 'out_ports_count' not in new_node and new_node.soft_get('type') != 'Result':
            new_node['out_ports_count'] = len(src_op.out_edges())

        # ir_data_attrs is not needed on the new node any more
        if 'ir_data_attrs' in new_node:
            del new_node['ir_data_attrs']

        if src_op.has_and_set('need_copy_input_blobs'):
            copy_input_blobs(src_op, new_node)

        # Remember the id mapping and the outgoing connections of the source node
        new_id_by_old_id[src_op.id] = new_node.id
        outputs_by_old_id[src_op.id] = collect_node_outputs(src_op)

    # Re-connect the new nodes the same way the source nodes were connected
    for old_src_id, outputs_by_port in outputs_by_old_id.items():
        producer = Node(copied_graph, new_id_by_old_id[old_src_id])
        for src_port_idx, destinations in outputs_by_port.items():
            for dst_port_idx, old_dst_id in destinations:
                consumer = Node(copied_graph, new_id_by_old_id[old_dst_id])
                producer.out_port(src_port_idx).connect(consumer.in_port(dst_port_idx))

    # Postprocessing of the nodes in the new graph
    for new_op in copied_graph.get_op_nodes():
        restored_type = new_op.soft_get('type')

        # Normalizing outputs connects temporary Result operations to disconnected output ports, which is
        # needed for correct shape inference; these Result operations will be removed during IR emitting.
        # TopK is skipped here: its outputs should be normalized with TopKNormalizer.normalize_outputs.
        if restored_type != 'TopK':
            Op.normalize_outputs(new_op)

        # Mark the data node of every Const so that restored values are processed correctly
        if restored_type == 'Const':
            assert len(new_op.out_nodes()) == 1 and new_op.out_node(0).soft_get('kind') == 'data',\
                'Const node {} not properly corrected to appropriate data node'.format(new_op.soft_get('name'))
            const_data = new_op.out_node(0)
            const_data['correct_data_type'] = True

            if new_op.has_and_set('rt_info'):
                const_data['rt_info'] = new_op.rt_info

        # Type-specific postprocessing
        if restored_type in postprocessing_op_nodes:
            postprocessing_op_nodes[new_op.type](new_op)

        restore_tensor_names(new_op)

    # Clean the graph up before shape inference
    copied_graph.clean_up()

    return copied_graph
def find_and_replace_pattern(self, graph: Graph):
    """
    Call Op.normalize_outputs for every operation in the graph selected by op='Split'

    :param graph: graph to process
    :return: None
    """
    split_ops = graph.get_op_nodes(op='Split')
    for current_split in split_ops:
        Op.normalize_outputs(current_split)