def test_run_with_const_input(self):
    """MarkNodesWithShapeValues must set `returns_shape_value` on the Const feeding Interpolate port 1.

    The input graph and the reference graph differ only in the attributes of 'sizes_const',
    so both node dictionaries come from one local factory (a fresh dict per call).
    """
    input_shape = (1, 3, 1000, 1000)

    def make_nodes(sizes_const_attrs):
        return {
            **shaped_const_with_data('input', int64_array(input_shape)),
            **regular_op('sizes_const', sizes_const_attrs),
            'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])},
            **regular_op_with_empty_data('interpolate',
                                         {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
            **result('res'),
        }

    edges = [
        *connect('input', '0:interpolate'),
        *connect('sizes_const', '1:interpolate'),
        *connect('interpolate', 'res'),
    ]

    graph = build_graph(make_nodes({'op': 'Const'}), edges)
    Node(graph, 'interpolate').add_input_port(2)
    MarkNodesWithShapeValues().find_and_replace_pattern(graph)

    graph_ref = build_graph(make_nodes({'op': 'Const', 'returns_shape_value': True}), edges)
    flag, resp = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_case1_source(self):
    """At the 'front' stage, set_source(..., "source") moves the input->Op1 edge onto NewOp."""
    def single_edge_to(consumer_name):
        # fresh attribute dict (and debug-info list) for each graph
        return [('input', consumer_name,
                 {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})]

    graph = build_graph(nodes, single_edge_to('Op1'))
    graph_ref = build_graph(nodes, single_edge_to('NewOp'))

    producer = Node(graph, 'input')
    consumer = Node(graph, 'NewOp')
    consumer.add_input_port(0)
    graph.stage = 'front'

    connection = consumer.in_port(0).get_connection()
    connection.set_source(producer.out_port(0), "source")

    flag, resp = compare_graphs(graph, graph_ref, 'NewOp', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs_front(graph, graph_ref)
def normalize_body_graph(loop_node: Node):
    """Bring a Loop node into a normalized form (modified in place).

    - If input port 0 ("trip count") is not connected, connect a Const -1 (infinite) to it.
    - If input port 1 ("execution condition") is not connected, connect a Const True to it.
    - For every `output_port_map` record with a non-None 'axis' (a scan output), insert an
      Unsqueeze over axis 0 in front of the corresponding body Result node.
    - Finally call Loop.normalize_input_output_ports.

    :param loop_node: the Loop node to normalize.
    """
    loop_name = loop_node.soft_get('name', loop_node.id)

    # connect "trip count" input if it is not connected with default value "Infinity" (-1)
    if not loop_node.is_in_port_connected(0):
        loop_node.add_input_port(0, skip_if_exist=True)
        Const(loop_node.graph, {'name': loop_name + '/trip_count', 'value': int64_array(-1)}).\
            create_node().out_port(0).connect(loop_node.in_port(0))

    # connect "execution condition" input if it is not connected with default value True.
    # The builtin `bool` dtype is used: the `np.bool` alias was deprecated in NumPy 1.20 and
    # removed in 1.24 (AttributeError); both spell the same numpy boolean dtype.
    if not loop_node.is_in_port_connected(1):
        loop_node.add_input_port(1, skip_if_exist=True)
        Const(loop_node.graph, {'name': loop_name + '/execution_cond', 'value': np.array(True, dtype=bool)}).\
            create_node().out_port(0).connect(loop_node.in_port(1))

    # scan output need Unsqueeze over axis 0
    for record in loop_node.output_port_map:
        body_node = Loop.get_body_node_by_internal_id(loop_node, record['internal_layer_id'])
        assert body_node is not None
        assert body_node.soft_get('type') == 'Result'

        if record['axis'] is not None:
            unsqueeze = create_op_with_const_inputs(loop_node.body, Unsqueeze, {1: int64_array([0])})
            body_node.in_port(0).get_connection().insert_node(unsqueeze)

    Loop.normalize_input_output_ports(loop_node)
def add_input_data_to_prior_boxes(graph: Graph, input_names: str = ''):
    """
    PriorBox layer has data input unlike mxnet.
    Need to add data input to _contrib_MultiBoxPrior for correct conversion to PriorBox layer.

    Parameters
    ----------
    graph : Graph
       Graph with loaded model.
    input_names : str
       Comma-separated names of candidate input nodes (whitespace around each name is ignored).
       When empty, the name 'data' is used. The first node (in graph iteration order) that has a
       valid 'op' and one of these names is connected to input port 1 of every
       _contrib_MultiBoxPrior node. Nothing is done if no such node exists.
    """
    if input_names:
        wanted_names = {name.strip() for name in input_names.split(',')}
    else:
        wanted_names = {'data'}

    # Only the first matching node is ever used, so stop at it.
    input_node = None
    for node_id in graph.nodes():
        node = Node(graph, node_id)
        if node.has_valid('op') and node.name in wanted_names:
            input_node = node
            break

    if input_node is None:
        return

    for node_id in graph.nodes():
        node = Node(graph, node_id)
        if node.has_valid('op') and node.op == '_contrib_MultiBoxPrior':
            node.add_input_port(idx=1)
            graph.create_edge(input_node, node, out_port=0, in_port=1)
def replace_output_edges(graph: Graph, output_edges_match: dict):
    """
    Re-create the output edges of old sub-graph nodes as output edges of the new sub-graph nodes.
    Old edges are left in place; only new edges are added.
    :param graph: networkX graph to operate on.
    :param output_edges_match: match of output edges between old and new sub-graph
                               ({old "name:port": new "name:port"}, parsed by extract_port).
    :return: None
    """
    for old_name_port, new_name_port in output_edges_match.items():
        old_node_name, old_out_port = __class__.extract_port(old_name_port)
        new_node_name, new_out_port = __class__.extract_port(new_name_port)
        # Iterate over a snapshot: edges are added inside the loop, and adding an edge to the node
        # whose live out-edge view is being iterated would raise "dictionary changed size".
        for _src, dst, edge_attrs in list(graph.out_edges(old_node_name, data=True)):
            if edge_attrs['out'] != old_out_port:
                continue
            new_edge_attrs = edge_attrs.copy()
            new_edge_attrs['out'] = new_out_port
            # Add control_flow ports, as we do not copy control flow ports to new node
            if new_edge_attrs.get('control_flow_edge') is True:
                in_port_id = 'control_flow_{}'.format(new_edge_attrs['in'])
                out_port_id = 'control_flow_{}'.format(new_edge_attrs['out'])
                in_node, out_node = Node(graph, dst), Node(graph, new_node_name)
                out_node.add_output_port(out_port_id, control_flow=True, skip_if_exist=True)
                in_node.add_input_port(in_port_id, control_flow=True, skip_if_exist=True)
            graph.add_edge(new_node_name, dst, **new_edge_attrs)
            log.debug("Created edge from {} to {} with attrs: {}".format(new_node_name, dst, new_edge_attrs))
def re_number_input_port(loop_node: Node, old_port_id: int, new_port_id: int):
    """Move whatever is connected to Loop input `old_port_id` over to input `new_port_id`.

    Input port `new_port_id` is created if it does not exist yet. Afterwards
    Loop.update_port_map_value is asked to replace `old_port_id` with `new_port_id` under the
    'external_port_id' key of `loop_node.input_port_map`.
    """
    loop_node.add_input_port(new_port_id, skip_if_exist=True)

    old_connection = loop_node.in_port(old_port_id).get_connection()
    old_connection.set_destination(loop_node.in_port(new_port_id))

    Loop.update_port_map_value(loop_node.input_port_map, 'external_port_id', old_port_id, new_port_id)
def input_as_const(node: Node, attrs: dict, port: int, bin: str, value: np.ndarray):
    """
    Create a Const node holding `value` (plus the extra `attrs`) and connect it to input `port`
    of `node`; the port is created if it does not exist yet.
    The input port is then given a `bin` attribute equal to the `bin` argument, and 'bin' is
    appended to the port's `in_attrs` list.
    """
    # NOTE: `bin` shadows the builtin, but it is part of the public signature (callers may pass it
    # by keyword), so it is kept.
    graph = node.graph
    const = Const(graph, {'value': value, **attrs}).create_node()
    node.add_input_port(port, skip_if_exist=True)
    const.out_port(0).connect(node.in_port(port))
    node.in_port(port).bin = bin
    node.in_port(port).in_attrs.append('bin')
def find_and_replace_pattern(self, graph: Graph):
    """Give every node explicit ports and re-create its edges through the port API.

    The graph is switched to the 'front' stage. For every node:
      * the port counts are taken from `in_ports_count`/`out_ports_count` when set, otherwise
        from the number of existing input/output edges;
      * `_in_ports`/`_out_ports` are reset and ports 0..count-1 are created;
      * each existing edge is removed and reconnected with `connect`, and its
        `fw_tensor_debug_info` is copied onto the new edge.

    :param graph: graph processed in place.
    :raises Error: if a node has more consumers than `out_ports_count` allows (checked only when
                   `out_ports_count` > 1).
    """
    graph.stage = 'front'
    for node_id in graph.nodes(data=False):
        node = Node(graph, node_id)
        inputs = node.get_sorted_inputs()
        outputs = node.get_sorted_outputs()

        in_ports_count = node.in_ports_count if node.has_valid('in_ports_count') else len(inputs)
        out_ports_count = node.out_ports_count if node.has_valid('out_ports_count') else len(outputs)

        if len(outputs) > out_ports_count > 1:
            # Format the complete message so all three placeholders get their intended values.
            raise Error("Node {} has more children than it should: should be {} but there is {}".format(
                node_id, out_ports_count, len(outputs)))

        node['_in_ports'] = {}
        node['_out_ports'] = {}
        if in_ports_count is not None:
            for idx in range(in_ports_count):
                node.add_input_port(idx=idx)

        if out_ports_count is not None:
            for idx in range(out_ports_count):
                node.add_output_port(idx=idx)

        idx = 0
        for in_node_id, edge_attrs in inputs:
            graph.remove_edge(in_node_id, node_id)
            if len(Node(graph, in_node_id).out_ports()) == 0:
                Node(graph, in_node_id).add_output_port(0)
            in_node = Node(graph, in_node_id)
            in_node.out_port(edge_attrs['out']).connect(node.in_port(idx))
            # need to keep this attribute in edge for correct .mapping file generation and
            # for generation of "names" field in IR
            in_node.out_edge(edge_attrs['out'])['fw_tensor_debug_info'] = edge_attrs['fw_tensor_debug_info']
            # inputs go to consecutive ports; idx is never advanced past the last port
            if idx < in_ports_count - 1:
                idx = idx + 1

        idx = 0
        for out_node_id, edge_attrs in outputs:
            graph.remove_edge(node_id, out_node_id)
            if len(Node(graph, out_node_id).in_ports()) == 0:
                Node(graph, out_node_id).add_input_port(0)
            node.out_port(idx).connect(Node(graph, out_node_id).in_port(edge_attrs['in']))
            # need to keep this attribute in edge for correct .mapping file generation and
            # for generation of "names" field in IR
            node.out_edge(idx)['fw_tensor_debug_info'] = edge_attrs['fw_tensor_debug_info']
            # consumers are served from consecutive ports; idx is never advanced past the last port
            if idx < out_ports_count - 1:
                idx = idx + 1
def test_case2_source(self):
    """set_destination(..., "source") re-targets Op1's incoming connection to NewOp."""
    producer_edge = ('input', 'input_data')
    graph = build_graph(nodes, [producer_edge, ('input_data', 'Op1')])
    graph_ref = build_graph(nodes, [producer_edge, ('input_data', 'NewOp')])

    old_consumer = Node(graph, 'Op1')
    new_consumer = Node(graph, 'NewOp')
    new_consumer.add_input_port(0)

    connection = old_consumer.in_port(0).get_connection()
    connection.set_destination(new_consumer.in_port(0), "source")

    flag, resp = compare_graphs(graph, graph_ref, 'NewOp', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs_middle(graph, graph_ref)
def test_case1_dest(self):
    """set_source(..., "dest") keeps Op1 attached and adds NewOp as another consumer of input_data."""
    graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1')])
    graph_ref = build_graph(nodes, [('input', 'input_data'),
                                    ('input_data', 'Op1'),
                                    ('input_data', 'NewOp')])
    # the expected data node carries no fw_tensor_debug_info
    expected_data = Node(graph_ref, 'input_data')
    del expected_data['fw_tensor_debug_info']

    producer = Node(graph, 'input')
    consumer = Node(graph, 'NewOp')
    consumer.add_input_port(0)

    connection = consumer.in_port(0).get_connection()
    connection.set_source(producer.out_port(0), "dest")

    flag, resp = compare_graphs(graph, graph_ref, 'NewOp', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs_middle(graph, graph_ref)
def create_and_connect_input_data_node(graph: Graph, op_node: Node, attrs: dict = None, edge_attrs: dict = None):
    """Add a new data node to `graph` and connect it to an input of `op_node`.

    :param graph: graph to add the data node to.
    :param op_node: consuming node; asserted to be non-None with kind 'op'.
    :param attrs: attributes overriding the data-node defaults
                  (kind='data', name=<new id>, value/shape/data_type/infer=None).
    :param edge_attrs: attributes of the new edge. It must contain the 'in' key (input port of
                       `op_node`); when it is missing (including the default None) a KeyError is raised.
    :return: the newly created data Node.
    """
    assert op_node is not None and op_node.kind == 'op'
    attrs = {} if attrs is None else attrs
    edge_attrs = {} if edge_attrs is None else edge_attrs

    data_id = graph.unique_id(op_node.id)
    data_attrs = dict(kind='data', name=data_id, value=None, shape=None, data_type=None, infer=None)
    data_attrs.update(attrs)
    graph.add_node(data_id, **add_attrs_props(data_attrs))
    data_node = Node(graph, data_id)

    op_node.add_input_port(edge_attrs['in'], skip_if_exist=True)
    graph.add_edges_from([(data_id, op_node.id, edge_attrs)])
    return data_node
def muladd_to_scaleshift_action(graph: Graph, match: dict):
    """Fuse a matched Mul -> Add pair with constant operands into a single ScaleShift node.

    The fusion is skipped (after logging) when:
      * the Mul output has more than one consumer;
      * either node has `can_be_scaleshift` explicitly set to False;
      * Mul has no constant or no tensor input, or Add has no constant input;
      * the squeezed weights and bias shapes differ (a scalar weight is first broadcast to a 1-D bias);
      * both weights and bias are single-element (left for conversion to Power).
    Note that weights/bias values and shapes are squeezed in place even when the fusion is skipped
    at a later check.

    :param graph: graph to modify.
    :param match: pattern match containing the 'mul', 'add' and 'output' nodes.
    """
    mul = match['mul']
    add = match['add']
    output = match['output']

    # Pass works correctly only in case when node have only 1 output
    if len(mul.out_port(0).get_destinations()) > 1:
        return

    if mul.soft_get('can_be_scaleshift') is False or add.soft_get('can_be_scaleshift') is False:
        return

    mul_weights_id = get_value_id(mul)
    mul_input_id = get_tensor_id(mul)
    add_weights_id = get_value_id(add)

    if mul_weights_id is None:
        log.debug("Mul->Add to ScaleShift: Mul {} has no weights".format(mul.name))
        return
    if mul_input_id is None:
        log.debug("Mul->Add to ScaleShift: Mul {} has no input".format(mul.name))
        return
    if add_weights_id is None:
        log.debug("Mul->Add to ScaleShift: Add {} has no weights".format(add.name))
        return

    input_data = mul.in_node(mul_input_id)
    weights = mul.in_node(mul_weights_id)
    bias = add.in_node(add_weights_id)

    # Transform values
    weights.value = np.squeeze(weights.value)
    weights.shape = int64_array(weights.value.shape)

    bias.value = np.squeeze(bias.value)
    bias.shape = int64_array(bias.value.shape)

    # Broadcast weights if they are scalar
    if weights.value.ndim == 0 and bias.value.ndim == 1:
        weights.value = np.full(bias.shape, weights.value.item(), dtype=weights.value.dtype)
        weights.shape = int64_array(weights.value.shape)

    # Shapes are arrays: `!=` compares element-wise, and using its result in `if` raises for
    # rank >= 2 (ambiguous truth value) and misbehaves for different ranks. Compare as a whole.
    if not np.array_equal(bias.shape, weights.shape):
        log.warning('Mul->Add to ScaleShift conversion stopped {} != {}'.format(weights.shape, bias.shape))
        return

    if bias.value.ndim != weights.value.ndim or bias.value.size != weights.value.size:
        log.debug("Skipping Mul->Add to ScaleShift conversion for nodes {}, {} because of different weights "
                  "and biases".format(mul.name, add.name))
        return

    if bias.value.size == 1 and weights.value.size == 1:
        log.debug("Skipping Mul->Add to ScaleShift conversion for nodes {}, {}. Will be converted to Power"
                  "".format(mul.name, add.name))
        return

    op_name = "ScaleShift"

    log.debug("Fusing Mul->Add to {}. Input nodes: {} and {}, bias.shape = {}, weights.shape = {}"
              "".format(op_name, mul.id, add.id, bias.shape, weights.shape))

    graph.remove_edge(input_data.node, mul.id)
    graph.remove_edge(weights.node, mul.id)
    graph.remove_edge(bias.node, add.id)
    graph.remove_edge(add.node, output.id)

    op_node = graph.unique_id(mul.name + '/Fused{}_'.format(op_name))

    graph.add_node(op_node, **add_attrs_props(dict(kind='op', type=op_name, name=op_node, op=op_name,
                                                   data_type=input_data.data_type)))
    scsh = Node(graph, op_node)
    scsh.add_input_port(0)
    scsh.add_input_port(1)
    scsh.add_input_port(2)
    scsh.add_output_port(0)

    # `graph.nodes[...]` works in every networkx 2.x; `graph.node` was removed in networkx 2.4
    update_ie_fields(graph.nodes[op_node])

    graph.add_edges_from([
        (input_data.node, op_node, {'in': 0}),
        (weights.node, op_node, {'in': 1, 'bin': 'weights'}),
        (bias.node, op_node, {'in': 2, 'bin': 'biases'}),
        (op_node, output.node, {'out': 0})
    ])
def build_graph_with_attrs(nodes_with_attrs: list, edges_with_attrs: list, new_nodes_with_attrs: list = None,
                           new_edges_with_attrs: list = None, update_edge_attrs: dict = None,
                           update_nodes_attributes: list = None, nodes_with_edges_only: bool = False,
                           add_nodes_from_edges: bool = False):
    """
    Build the Graph with specific nodes and edges. Also update of edge and node parameters is supported.
    The attribute dicts passed in are not modified; every node gets a deep copy of its attributes,
    with 'name' defaulting to the node id.
    :param nodes_with_attrs: list of tuples ('node_name', {node_attrs})
    :param edges_with_attrs: list of tuples like (start node, end node, (optional) {attrs of the edge}).
    :param new_nodes_with_attrs: same format as nodes_with_attrs; None means no extra nodes
    :param new_edges_with_attrs: same format as edges_with_attrs; None means no extra edges
    :param update_edge_attrs: optional dictionary like {('from_node', 'to_node', key): {edge_attrs}}.
    :param update_nodes_attributes: optional list of tuples which specifies nodes names and their attributes to be
    updated. The first element is a node name to update attribute and the second element is a dictionary with attribute
    name and its value.
    :param nodes_with_edges_only: add only nodes which have at least one incoming or outgoing edge.
    :param add_nodes_from_edges: whether nodes that are not listed in the node lists but appear in the edges are allowed.
    :return: generated graph.
    """
    # None instead of mutable `[]` defaults
    if new_nodes_with_attrs is None:
        new_nodes_with_attrs = []
    if new_edges_with_attrs is None:
        new_edges_with_attrs = []

    if not_all_new([node[0] for node in nodes_with_attrs], [node[0] for node in new_nodes_with_attrs]):
        raise Error('Some nodes from new_nodes_with_attrs are already in nodes.'
                    ' Please, add to new_nodes_with_attrs only NEW nodes.')

    if not_all_new([(edge[0], edge[1]) for edge in edges_with_attrs],
                   [(edge[0], edge[1]) for edge in new_edges_with_attrs]):
        raise Error('Some edges from new_edges_with_attrs are already in edges.'
                    ' Please, add to new_edges_with_attrs only NEW edges.')

    # Check that all nodes from list of edges are in nodes
    all_nodes = nodes_with_attrs + new_nodes_with_attrs
    all_edges = edges_with_attrs + new_edges_with_attrs
    all_nodes_names = [node[0] for node in all_nodes]
    if not add_nodes_from_edges and not all_edges_in_nodes(nodes=all_nodes_names, edges=all_edges):
        raise Error("Some nodes from list of edges is not in nodes. Please, add all necessary nodes.")

    graph = Graph()

    # Create dict for nodes with attrs; copy so the caller's dicts are not mutated by the 'name' default
    nodes_attrs = {}
    for node_name, attrs in all_nodes:
        node_attrs = dict(attrs)
        node_attrs.setdefault('name', node_name)
        nodes_attrs[node_name] = node_attrs

    if nodes_with_edges_only:
        # filter nodes to keep only ones with edges connected
        filtered_nodes = {}
        for edge in all_edges:
            node_1, node_2 = edge[0], edge[1]
            filtered_nodes[node_1] = nodes_attrs[node_1]
            filtered_nodes[node_2] = nodes_attrs[node_2]
        nodes_attrs = filtered_nodes

    # Create all nodes
    for node, attrs in nodes_attrs.items():
        graph.add_node(node, **deepcopy(attrs))

    # Connect nodes with edges (also unpack edge params)
    for edge in all_edges:
        node_1, node_2 = edge[0], edge[1]
        edge_attrs = edge[2] if len(edge) == 3 else {}
        graph.add_edge(node_1, node_2, **edge_attrs)

    # Update attributes of edges
    if update_edge_attrs:
        # it will work in 2.x networkx only
        for edge, attr in update_edge_attrs.items():
            for k, v in attr.items():
                nx.set_edge_attributes(G=graph, name=k, values={edge: v})

    # Update attributes of nodes
    if update_nodes_attributes is not None:
        for node_name, new_attrs in update_nodes_attributes:
            assert node_name in graph.nodes()
            for attr, value in new_attrs.items():
                graph.nodes[node_name][attr] = value

    # Use the data of every edge: `get_edge_data(u, v)[0]` would return the key-0 edge for each
    # of several parallel edges between the same pair of nodes.
    for node_id in graph.nodes():
        node = Node(graph, node_id)
        check_and_update_ports(node, [attrs for _, _, attrs in graph.in_edges(node_id, data=True)], True)
        check_and_update_ports(node, [attrs for _, _, attrs in graph.out_edges(node_id, data=True)], False)

    for node in graph.get_op_nodes():
        # Add in_ports attribute
        for i in range(len(node.in_edges())):
            node.add_input_port(idx=i)

        # Add out_ports attribute
        for i in range(len(node.out_edges())):
            node.add_output_port(idx=i)
    return graph
def read_node(file_descr, graph, component_layer_map, layer_node_map):
    """Read one node-description line from `file_descr` and add the described node to `graph`.

    The first space-separated token selects the kind of line:
      * b'input-node'     - a Parameter node with shape [1, dim];
      * b'component-node' - an op node bound to a component (recorded in component_layer_map)
                            and connected to its parsed input;
      * b'output-node'    - an Identity node followed by a Result node; only objective=linear
                            is accepted;
      * b'dim-range-node' - a Crop node with offset/dim parameters on axis 1.
    (Presumably lines of a Kaldi nnet3 config section - TODO confirm against the caller.)

    :param file_descr: binary file-like object; one line is consumed.
    :param graph: graph to add nodes and edges to.
    :param component_layer_map: component name -> list of node ids using it (updated in place).
    :param layer_node_map: layer name -> node id (updated in place).
    :return: False if the line read is a bare newline, True otherwise.
    :raises Error: on an unsupported node specifier or objective type.
    """
    s = file_descr.readline()
    if s == b'\n':
        return False
    tokens = s.split(b' ')
    if tokens[0] == b'input-node':
        in_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
        in_name = str(in_name).strip('b').replace('\'', "")
        # `dtype=int` is what the removed `np.int` alias (NumPy 1.24) stood for
        in_shape = np.array([1, s[s.find(b'dim=') + len(b'dim='):].split(b' ')[0]], dtype=int)

        if in_name not in layer_node_map:
            graph.add_node(in_name, name=in_name, kind='op', op='Parameter', parameters=None, shape=in_shape)
            layer_node_map[in_name] = in_name
        else:
            Node(graph, in_name)['op'] = 'Parameter'
            Node(graph, in_name)['shape'] = in_shape
    elif tokens[0] == b'component-node':
        layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
        layer_name = str(layer_name).strip('b').replace('\'', "")

        component_name = s[s.find(b'component=') + len(b'component='):].split(b' ')[0]
        if layer_name not in layer_node_map:
            node_name = graph.unique_id(prefix=layer_name)
            graph.add_node(node_name, parameters=None, op=None, kind='op')
            layer_node_map[layer_name] = node_name
        else:
            node_name = layer_node_map[layer_name]

        if component_name in component_layer_map:
            component_layer_map[component_name].append(node_name)
        else:
            component_layer_map[component_name] = [node_name]

        # parse input
        in_node_id = parse_input_for_node(s[s.find(b'input=') + len(b'input='):], graph, layer_node_map)
        # don't create cyclic edges node to itself to avoid removing later
        if in_node_id != node_name:
            out_port = len(Node(graph, in_node_id).out_nodes())
            in_port = len(Node(graph, node_name).in_nodes())

            Node(graph, node_name).add_input_port(in_port)
            Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)

            graph.add_edge(in_node_id, node_name,
                           **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
    elif tokens[0] == b'output-node':
        layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
        layer_name = str(layer_name).strip('b').replace('\'', "")
        node_name = graph.unique_id(prefix=layer_name)
        graph.add_node(node_name, parameters=None, op='Identity', kind='op')
        out_name = graph.unique_id(prefix=node_name + "_out")
        graph.add_node(out_name, parameters=None, op='Result', kind='op')
        Node(graph, node_name).add_input_port(0)
        Node(graph, node_name).add_output_port(0)
        Node(graph, out_name).add_input_port(0)
        graph.add_edge(node_name, out_name, **create_edge_attrs(node_name, out_name, node_name))

        # parse input
        in_node_id = parse_input_for_node(s[s.find(b'input=') + len(b'input='):], graph, layer_node_map)

        out_port = len(Node(graph, in_node_id).out_nodes())
        Node(graph, in_node_id).add_output_port(out_port)
        graph.create_edge(Node(graph, in_node_id), Node(graph, node_name), out_port, 0,
                          create_edge_attrs(in_node_id, node_name, in_node_id, 0, out_port))

        objective_type = s[s.find(b'objective=') + len(b'objective='):].split(b' ')[0].split(b'\n')[0]
        if objective_type != b'linear':
            raise Error("Unsupported objective-type for output {}".format(node_name))
    elif tokens[0] == b'dim-range-node':
        layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
        layer_name = str(layer_name).strip('b').replace('\'', "")
        offset = int(s[s.find(b'dim-offset=') + len(b'dim-offset='):].split(b' ')[0])
        dim = int(s[s.find(b'dim=') + len(b'dim='):].split(b' ')[0])

        if layer_name in layer_node_map:
            node_name = layer_node_map[layer_name]
            node = Node(graph, node_name)
            node['parameters'] = {'offset': np.array([offset]), 'dim': np.array([dim]), 'axis': np.array([1])}
            node['op'] = 'Crop'
        else:
            node_name = graph.unique_id(prefix=layer_name)
            graph.add_node(node_name,
                           parameters={'offset': np.array([offset]), 'dim': np.array([dim]), 'axis': np.array([1])},
                           op='Crop',
                           kind='op')
            layer_node_map[layer_name] = node_name
            node = Node(graph, node_name)

        in_node_id = parse_input_for_node(s[s.find(b'input-node=') + len(b'input-node='):], graph, layer_node_map)
        out_port = len(Node(graph, in_node_id).out_nodes())
        in_port = len(Node(graph, node_name).in_nodes())

        node.add_input_port(in_port)
        Node(graph, in_node_id).add_output_port(out_port)

        graph.create_edge(Node(graph, in_node_id), node, out_port, in_port,
                          create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))

        # read dim info where possible to simplify shape calculation for MemoryOffset
        # shape calculation for MemoryOffset can't be done through shape of previous layer because
        # it is separated in 2 parts to remove cycle from graph
        for o_n_name, params in node.get_outputs():
            o_n = Node(graph, o_n_name)
            if o_n['op'] == 'MemoryOffset':
                o_n['parameters']['element_size'] = int64_array([1, dim])
    else:
        raise Error("Unsupported node specifier {}".format(tokens[0]))
    return True
def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    """
    Load ParallelComponent of the Kaldi model.
    ParallelComponent contains parallel nested networks.
    A VariadicSplit is inserted before the nested networks, and their outputs are
    concatenated by a Concat layer.

    :param file_descr: descriptor of the model file
    :param graph: graph with the topology.
    :param prev_layer_id: id of the input layers for parallel component layer
    :return: id of the concat layer - last layer of the parallel component layers
    """
    nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))

    split_points = []
    outputs = []
    inputs = []

    for i in range(nnet_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        g = Graph()
        load_kalid_nnet1_model(g, file_descr, 'Nested_net_{}'.format(i))

        # input to nnet1 models is of a rank 1 but we also insert batch_size to 0th axis
        # 1st axis contains input_size of the nested subnetwork
        # we split input from the main network to subnetworks
        input_node = Node(g, 'Parameter')
        split_points.append(input_node['shape'][1])
        g.remove_node(input_node.id)

        # rename nested-network nodes whose ids already exist in the main graph
        mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
        g = nx.relabel_nodes(g, mapping)
        for val in mapping.values():
            # `g.nodes[...]` works in every networkx 2.x; `g.node` was removed in networkx 2.4
            g.nodes[val]['name'] = val
        graph.add_nodes_from(g.nodes(data=True))
        graph.add_edges_from(g.edges(data=True))
        sorted_nodes = tuple(nx.topological_sort(g))

        outputs.append(Node(graph, sorted_nodes[-1]))
        inputs.append(Node(graph, sorted_nodes[0]))

    split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
    attrs = {'out_ports_count': nnet_count, 'size_splits': split_points, 'axis': 1, 'name': split_id}
    variadic_split_node = AttributedVariadicSplit(graph, attrs).create_node()
    prev_layer_node = Node(graph, prev_layer_id)
    prev_layer_node.add_output_port(0)
    graph.create_edge(prev_layer_node, variadic_split_node, 0, 0,
                      create_edge_attrs(prev_layer_id, variadic_split_node.id, prev_layer_id))

    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')
    concat_node = Node(graph, concat_id)

    # Connect each output of variadic_split_node to each subnetwork's inputs in ParallelComponent
    # and each subnetwork's output to concat_node
    for i, (input_node, output_node) in enumerate(zip(inputs, outputs)):
        output_node.add_output_port(0)
        concat_node.add_input_port(i)
        graph.create_edge(output_node, concat_node, 0, i,
                          create_edge_attrs(output_node.id, concat_id, output_node.id, i, 0))
        graph.create_edge(variadic_split_node, input_node, i, 0,
                          create_edge_attrs(variadic_split_node.id, input_node.id, variadic_split_node.id, 0, i))
    return concat_id
def merge_nodes(graph: Graph, nodes_to_merge_names: list, inputs_desc: list = None, outputs_desc: list = None):
    """
    Merges nodes specified in the set 'nodes_to_merge_names' into one mega-node, creating new edges between mega-node
    and inputs/outputs nodes of the mega-node. The added edges contain name of input/output nodes which will be used for
    generation of placeholders and will be saved to the IR xml so IE plug-in know how to map input/output data for the
    layer. Also the function adds protobufs of the nodes of the sub-graph and 'Const' ops consumed by nodes in the
    sub-graph to the node's attribute 'pbs'.
    :param graph: the graph object to operate on.
    :param nodes_to_merge_names: list of nodes names that should be merged into a single node.
    :param inputs_desc: optional list describing input nodes order.
    :param outputs_desc: optional list describing output nodes order.
    :return: the created mega-node.
    """
    if not is_connected_component(graph, nodes_to_merge_names):
        log.warning("The following nodes do not form connected sub-graph: {}".format(nodes_to_merge_names))

    new_node_name = graph.unique_id("TFSubgraphCall_")
    log.info("Create new node with name '{}' for nodes '{}'".format(new_node_name, ', '.join(nodes_to_merge_names)))
    graph.add_node(new_node_name)
    # `graph.nodes[...]` works in every networkx 2.x; `graph.node` was removed in networkx 2.4
    new_node_attrs = graph.nodes[new_node_name]

    new_node_attrs['name'] = new_node_name
    set_tf_custom_call_node_attrs(new_node_attrs)
    new_node = Node(graph, new_node_name)

    # set for O(1) membership tests in the per-edge loops below
    merged_names = set(nodes_to_merge_names)

    added_input_tensors_names = set()  # set of tensors that are were added as input to the sub-graph
    added_new_node_output_tensors = dict()  # key - tensor name, value - out port

    for node_name in nodes_to_merge_names:
        node = Node(graph, node_name)
        add_node_pb_if_not_yet_added(node, new_node)
        for in_node_name, edge_attrs in node.get_inputs():
            # internal edges between nodes of the sub-graph
            if in_node_name in merged_names:
                add_node_pb_if_not_yet_added(Node(graph, in_node_name), new_node)
                continue

            # edge outside of sub-graph into sub-graph
            # we cannot use the 'in_node_name' as a protobuf operation name here
            # because the 'in_node_name' could be a sub-graph matched before.
            input_tensor_name = node.pb.input[edge_attrs['in']]
            if input_tensor_name not in added_input_tensors_names:
                new_in_port = find_input_port(new_node, inputs_desc, node_name, edge_attrs['in'])
                # create the port the new edge is actually attached to (same as the output side below)
                if not new_node.has_port('in', new_in_port):
                    new_node.add_input_port(new_in_port)
                graph.add_edge(in_node_name, new_node_name,
                               **merge_edge_props(
                                   {'in': new_in_port,
                                    'out': edge_attrs['out'],
                                    'internal_input_node_name': input_tensor_name,
                                    'original_dst_node_name': node_name,
                                    'original_dst_port': edge_attrs['in'],
                                    'in_attrs': ['in', 'internal_input_node_name', 'original_dst_node_name',
                                                 'original_dst_port', 'placeholder_name'],
                                    'out_attrs': ['out']},
                                   edge_attrs)
                               )
                log.debug("Creating edge from outside of sub-graph to inside sub-graph: {} -> {}".format(
                    in_node_name, new_node_name))
                added_input_tensors_names.add(input_tensor_name)

        # edge from inside sub-graph to outside sub-graph
        for out_node_name, edge_attrs in node.get_outputs():
            if out_node_name in merged_names:
                continue
            log.debug("Creating edge from inside of sub-graph to outside sub-graph: {} -> {}".format(
                new_node_name, out_node_name))
            out_name = internal_output_name_for_node(node_name, edge_attrs['out'])
            if out_name not in added_new_node_output_tensors:
                added_new_node_output_tensors[out_name] = find_output_port(new_node, outputs_desc, node_name,
                                                                           edge_attrs['out'])
            new_out_port = added_new_node_output_tensors[out_name]
            if not new_node.has_port('out', new_out_port):
                new_node.add_output_port(new_out_port)
            graph.add_edge(new_node_name, out_node_name,
                           **merge_edge_props(
                               {'in': edge_attrs['in'],
                                'out': new_out_port,
                                'internal_output_node_name': out_name,
                                'in_attrs': ['in', 'internal_input_node_name'],
                                'out_attrs': ['out', 'internal_output_node_name']},
                               edge_attrs)
                           )

    # invert to {out port: tensor name} (one name per port) and keep the names
    new_node['output_tensors_names'] = list(
        {port: tensor_name for tensor_name, port in added_new_node_output_tensors.items()}.values())

    # add nodes using the same order as in initial GraphDef so we can dump them to IR in "correct" order
    new_node['nodes_order'] = [node for node in graph.graph['initial_nodes_order'] if node in new_node['pbs'].keys()]

    for n in nodes_to_merge_names:
        if graph.has_node(n):  # check if not deleted by another (similar) pattern
            graph.remove_node(n)
    return Node(graph, new_node_name)