def mark_undead_nodes(graph, undead_types: list):
    """
    Mark nodes that must survive the dead nodes elimination phase by setting their 'is_undead' attribute to True.

    The following nodes are marked:
      * nodes returned by graph.get_op_nodes() whose type (the 'type' attribute, falling back to 'op' via soft_get)
        is 'Result' or one of `undead_types`;
      * children of already-undead nodes, visited in BFS order starting from the nodes above, when the child is an
        op node or a data node with a non-None 'value';
      * input nodes (is_input=True). They are marked after the propagation step, so being an input does not by itself
        make the node's children undead.
    All other nodes get 'is_undead' = False.

    :param graph: graph to operate on; it is modified in place.
    :param undead_types: node types that should be marked as undead. Any iterable of type names is accepted
        (list, tuple, set, ...).
    :return: None. The result is stored in the 'is_undead' attribute of every node.
    """
    from mo.utils.graph import bfs_search

    nx.set_node_attributes(G=graph, name='is_undead', values=False)

    # list(...) instead of `undead_types + [...]` so that tuples/sets of types are accepted as well
    undead_types_with_result = list(undead_types) + ['Result']
    undead_nodes = [node.id for node in graph.get_op_nodes()
                    if node.soft_get('type', node.soft_get('op')) in undead_types_with_result]
    nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in undead_nodes})

    # propagate 'undead' attribute to children nodes of undead nodes if the node produces constant value.
    # With no undead start nodes nothing can be propagated, so skip the traversal entirely (bfs_search would
    # otherwise fall back to walking the whole graph from its source nodes).
    # NOTE(review): `Graph.node` was removed in networkx 2.4; switch to `graph.nodes[...]` if newer networkx
    # versions must be supported.
    if undead_nodes:
        for node_name in bfs_search(graph, undead_nodes):
            if not graph.node[node_name]['is_undead']:
                continue
            for _, dst_node_name in graph.out_edges(node_name):
                node_attrs = graph.node[dst_node_name]
                kind = node_attrs.get('kind')
                if kind == 'op' or (kind == 'data' and node_attrs['value'] is not None):
                    graph.node[dst_node_name]['is_undead'] = True

    # mark input nodes as undead
    inputs = graph.get_nodes_with_attributes(is_input=True)
    nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in inputs})
def test_bfs_search_specific_start_nodes(self):
    """
    Check that BFS starts from the user defined nodes and doesn't go in backward edge direction.
    """
    graph = Graph()
    graph.add_nodes_from(list(range(1, 7)))
    graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5), (6, 1)])
    # nodes 2 and 6 are reachable from node 1 only against the edge direction, so they must not be visited
    order = bfs_search(graph, [1])
    # assertEqual (rather than assertTrue(a == b)) shows the actual order when the test fails
    self.assertEqual(order, [1, 3, 4, 5])
def test_bfs_search_default_start_nodes(self):
    """
    Check that BFS automatically determines input nodes and starts searching from them.
    """
    graph = Graph()
    graph.add_nodes_from(list(range(1, 6)))
    graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5)])
    order = bfs_search(graph)
    # nodes 1 and 2 have no incoming edges, so either of them may be visited first;
    # assertIn reports the actual order when the test fails
    self.assertIn(order, [[1, 2, 3, 4, 5], [2, 1, 3, 4, 5]])
def calculate_frame_time(graph: Graph):
    """
    Compute the 'frame_time' (frame delay) attribute of op nodes reachable from the main Kaldi input.

    Every node first gets frame_time = -1 ("not calculated"). Then the nodes returned by bfs_search starting from the
    main input are visited in that order, data nodes are skipped, and each op node with a negative frame_time gets:
      * Splice: frame_time of its input-0 producer + len(context) - 1;
      * Crop whose input-0 producer is a Splice: frame_time of the Splice's input-0 producer plus
        context[offset[0] // dim[0]] - context[0] of that Splice;
      * other Crop: frame_time of its input-0 producer;
      * any other op: the maximum frame_time over its connected inputs, but at least 0.
    Nodes not returned by bfs_search keep frame_time = -1.

    :param graph: graph to operate on; it is modified in place (the function has no return statement).
    :raises Error: if the number of Parameter nodes is not 1 or 2, or if there are 2 and neither is named 'ivector'.
    """
    # there are either one or two inputs in Kaldi. Only main input can change delay in network.
    # Usually ivector input has name 'ivector'.
    inputs = graph.get_op_nodes(op='Parameter')
    if len(inputs) == 1:
        inp_name = inputs[0].name
    elif len(inputs) == 2:
        # the main input is the one that is NOT named 'ivector'
        if inputs[0].name == 'ivector':
            inp_name = inputs[1].name
        elif inputs[1].name == 'ivector':
            inp_name = inputs[0].name
        else:
            raise Error("There are 2 inputs for Kaldi model but we can't find out which one is ivector. " +
                        "Use name \'ivector\' for the corresponding input")
    else:
        raise Error("There are {} inputs for Kaldi model but we expect only 1 or 2".format(len(inputs)))

    # sort nodes to calculate delays
    nodes = list(bfs_search(graph, [inp_name]))
    # -1 means "frame_time not calculated yet"
    nx.set_node_attributes(G=graph, name='frame_time', values=-1)

    for n in nodes:
        node = Node(graph, n)

        # just ignore data nodes
        if node.kind != 'op':
            continue

        # calculate frame_time (delay) that was not calculated
        if node.frame_time < 0:
            # Splice increases frame delay
            if node.op == "Splice":
                node.frame_time = node.in_port(0).get_source().node.frame_time + len(node.context) - 1
            # crop often used to get concrete time frame, set frame_time correctly for this case
            elif node.op == 'Crop':
                if node.in_port(0).get_connection().get_source().node.op == 'Splice':
                    splice_node = node.in_port(0).get_source().node
                    # only a single-offset / single-dim crop is supported here
                    assert len(node.offset) == 1
                    assert len(node.dim) == 1
                    # the crop offset (in units of dim) selects one Splice context entry; the delay is taken
                    # relative to the first context entry
                    new_delay = splice_node.context[node.offset[0] // node.dim[0]] - splice_node.context[0]
                    node.frame_time = splice_node.in_port(0).get_source().node.frame_time + new_delay
                else:
                    node.frame_time = node.in_port(0).get_source().node.frame_time
            # for node with several inputs frame_time = maximum of delays from branches
            else:
                # find out maximum of delay and check that we have at least one branch with another delay
                # (the loop takes the maximum over connected inputs; a node without connected inputs keeps 0)
                node.frame_time = 0
                for inp in node.in_ports():
                    if node.in_port(inp).disconnected():
                        continue
                    in_node = node.in_port(inp).get_source().node
                    if in_node.frame_time > node.frame_time:
                        node.frame_time = in_node.frame_time
def mark_undead_nodes(graph: Graph, undead_types: list):
    """
    Mark nodes that must survive the dead nodes elimination phase by setting their 'is_undead' attribute to True.

    The following nodes are marked:
      * output nodes, i.e. nodes with op='OpOutput';
      * nodes whose 'type' attribute is one of `undead_types`;
      * children of already-undead nodes, visited in BFS order starting from the nodes above, when the child is an
        op node or a data node with a non-None 'value';
      * input nodes (is_input=True). They are marked after the propagation step, so being an input does not by itself
        make the node's children undead.
    All other nodes get 'is_undead' = False.

    :param graph: graph to operate on; it is modified in place.
    :param undead_types: node types that should be marked as undead.
    :return: None. The result is stored in the 'is_undead' attribute of every node.
    """
    nx.set_node_attributes(G=graph, name='is_undead', values=False)

    # mark output nodes as undead
    outputs = graph.get_nodes_with_attributes(op='OpOutput')
    nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in outputs})

    # mark specifically defined with node type set of nodes
    # (`node_type` rather than `type` so the builtin is not shadowed)
    for node_type in undead_types:
        node_of_specific_type = graph.get_nodes_with_attributes(type=node_type)
        nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in node_of_specific_type})

    undead_nodes = graph.get_nodes_with_attributes(is_undead=True)

    # propagate 'undead' attribute to children nodes of undead nodes if the node produces constant value.
    # With no undead start nodes nothing can be propagated, so skip the traversal entirely (bfs_search would
    # otherwise fall back to walking the whole graph from its source nodes).
    # NOTE(review): `Graph.node` was removed in networkx 2.4; switch to `graph.nodes[...]` if newer networkx
    # versions must be supported.
    if undead_nodes:
        for node_name in bfs_search(graph, undead_nodes):
            if not graph.node[node_name]['is_undead']:
                continue
            for _, dst_node_name in graph.out_edges(node_name):
                node_attrs = graph.node[dst_node_name]
                kind = node_attrs.get('kind')
                if kind == 'op' or (kind == 'data' and node_attrs['value'] is not None):
                    graph.node[dst_node_name]['is_undead'] = True

    # mark input nodes as undead
    inputs = graph.get_nodes_with_attributes(is_input=True)
    nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in inputs})
def calculate_frame_time(graph: Graph):
    """
    Compute the 'frame_time' (frame delay) attribute of op nodes reachable from the main Kaldi input.

    The main input is chosen by check_inputs. The nodes returned by bfs_search starting from it are visited in that
    order, data nodes are skipped, and only op nodes whose frame_time is negative are (re)computed.
    Special frame_time values used here:
      * -1: not calculated yet. Splice/Crop nodes whose input-0 producer still has -1 are skipped.
      * -2 (`max_frame_time`): the node lies on a shape-calculation path. Such a path starts at ShapeOf and ends at
        Broadcast, which takes the frame_time of its input-0 producer. These nodes are excluded from the delay
        calculation.

    NOTE(review): frame_time is read (`node.frame_time < 0`) before this function assigns it, so it must already be
    initialised on every visited node — presumably elsewhere; confirm.

    :param graph: graph to operate on; it is modified in place (the function has no return statement).
    """
    # there are either one or two inputs in Kaldi. Only main input can change delay in network.
    # Usually ivector input has name 'ivector'.
    # special frame_time value marking nodes on a shape-calculation path (see the ShapeOf / Broadcast branches)
    max_frame_time = -2
    # NOTE(review): `inputs` is not used below; the main input is taken from check_inputs instead
    inputs = graph.get_op_nodes(op='Parameter')
    inp = check_inputs(graph)
    inp_name = inp.soft_get('name', inp.id)

    # sort nodes to calculate delays
    nodes = list(bfs_search(graph, [inp_name]))

    for n in nodes:
        node = Node(graph, n)

        # just ignore data nodes
        if node.kind != 'op':
            continue

        # calculate frame_time (delay) that was not calculated
        if node.frame_time < 0:
            # Splice increases frame delay
            if node.op == "Splice":
                # producer's delay is not known yet -> leave this node uncalculated
                if node.in_port(0).get_source().node.frame_time == -1:
                    continue
                node.frame_time = node.in_port(0).get_source().node.frame_time + len(node.context) - 1
            # crop often used to get concrete time frame, set frame_time correctly for this case
            elif node.op == 'Crop':
                # producer's delay is not known yet -> leave this node uncalculated
                if node.in_port(0).get_source().node.frame_time == -1:
                    continue
                if node.in_port(0).get_connection().get_source().node.op == 'Splice':
                    splice_node = node.in_port(0).get_source().node
                    # only a single-offset / single-dim crop is supported here
                    assert len(node.offset) == 1
                    assert len(node.dim) == 1
                    # the crop offset (in units of dim) selects one Splice context entry; the delay is taken
                    # relative to the first context entry
                    new_delay = splice_node.context[node.offset[0] // node.dim[0]] - splice_node.context[0]
                    node.frame_time = splice_node.in_port(0).get_source().node.frame_time + new_delay
                else:
                    node.frame_time = node.in_port(0).get_source().node.frame_time
            elif node.op == 'ShapeOf':
                # exclude shape path from time delay calculation using special value
                node.frame_time = max_frame_time
            elif node.op == 'Broadcast':
                # finished shape path
                node.frame_time = node.in_port(0).get_source().node.frame_time
            # for node with several inputs frame_time = maximum of delays from branches
            else:
                # find out maximum of delay and check that we have at least one branch with another delay
                # (start from -1 when the node has input ports, 0 when it has none; -1 inputs are ignored)
                node.frame_time = -1 if len(node.in_ports()) != 0 else 0
                min_in_frame_time = -1
                # NOTE: the loop variable reuses the name `inp`; the main-input node bound to it above is no
                # longer needed at this point
                for inp in node.in_ports():
                    if node.in_port(inp).disconnected():
                        continue
                    in_node = node.in_port(inp).get_source().node
                    if in_node.frame_time < min_in_frame_time:
                        min_in_frame_time = in_node.frame_time
                    if in_node.frame_time > node.frame_time and in_node.frame_time != -1:
                        node.frame_time = in_node.frame_time
                # if all inputs have special value for frame time, node have special value for frame time too
                # because it is on shape path
                # NOTE(review): min_in_frame_time is the minimum over connected inputs, so this fires as soon as the
                # smallest input value is -2 (e.g. ONE input on the shape path), not only when all inputs are —
                # confirm which behaviour is intended.
                if min_in_frame_time == max_frame_time:
                    node.frame_time = max_frame_time