def find_and_replace_pattern(self, graph: Graph):
    """Synchronize the frame delays introduced by MemoryOffset nodes with positive ``t``.

    Every node receives a temporary ``frame_time`` attribute (accumulated frame
    delay). Nodes are visited in topological order, and the attribute is
    computed as follows:

    * ``MemoryOffset``: the delay of its input plus the node's ``t``;
    * several inputs: the maximum delay over the branches, as reported by
      ``find_max_frame_time``; when that helper reports that the branches
      disagree, ``align_frame_time`` is called to sync them;
    * exactly one input: the delay of that input;
    * no inputs: 0.

    The temporary ``frame_time`` attribute is removed from all nodes at the end.
    The pass does nothing when no MemoryOffset has ``t > 0``. It also does nothing
    when networkx cannot topologically sort the graph.

    :param graph: graph that is modified in place.
    """
    # Cheap pre-check: only MemoryOffset nodes with t > 0 introduce a delay to align.
    has_positive_offset = any(
        node.op == 'MemoryOffset' and node.t > 0
        for node in (Node(graph, n) for n in graph)
    )
    if not has_positive_offset:
        return

    try:
        nodes = list(nx.topological_sort(graph))
    except nx.NetworkXException:
        # networkx raises NetworkXUnfeasible for a cyclic graph and NetworkXError
        # for an undirected one. Such a graph has no topological order, so the
        # pass is skipped. Catching only networkx errors keeps that best-effort
        # behaviour without hiding unrelated failures (a bare ``except:`` would
        # also swallow KeyboardInterrupt/SystemExit and genuine bugs).
        return

    # -1 marks "delay not calculated yet".
    nx.set_node_attributes(G=graph, name='frame_time', values=-1)

    for n in nodes:
        node = Node(graph, n)
        if node.frame_time >= 0:
            # Delay already calculated for this node.
            continue

        num_inputs = len(node.in_edges())
        if node.op == 'MemoryOffset':
            # MemoryOffset with t > 0 increases the frame delay of its branch.
            node.frame_time = node.in_port(0).get_source().node.frame_time + node.t
        elif num_inputs > 1:
            # Several branches meet here: the node's delay is the maximum branch delay.
            # Per the original design notes, the other branches are synced by
            # align_frame_time (presumably by inserting MemoryOffset(branch - max)
            # nodes, after which the maximal MemoryOffset becomes t == 0). That
            # helper is not visible here, so confirm against its definition.
            in_frame_time_max, should_align = find_max_frame_time(node)
            if should_align:
                align_frame_time(graph, node, in_frame_time_max)
            node.frame_time = in_frame_time_max
        elif num_inputs == 1:
            # A single producer: inherit its delay unchanged.
            node.frame_time = node.in_port(0).get_source().node.frame_time
        else:
            # Nodes without inputs start the delay count at 0.
            node.frame_time = 0

    # frame_time is only scratch state for this pass, so strip it from every node.
    for n in graph:
        node = Node(graph, n)
        if 'frame_time' in node:
            del node['frame_time']
def calculate_frame_time(graph: Graph):
    """Compute the ``frame_time`` (frame delay) attribute of op nodes reachable from the main input.

    Nodes are visited in breadth-first order starting at the main input returned
    by ``check_inputs``. Only op nodes whose ``frame_time`` is still negative are
    updated, according to the op type:

    * ``Splice``: the producer's delay plus ``len(context) - 1``;
    * ``Crop``: the delay of the selected context frame when fed by a Splice,
      otherwise the producer's delay;
    * ``ShapeOf``: the special shape-path marker (-2), which excludes the shape
      sub-graph from delay calculation;
    * ``Broadcast``: the delay of input 0 (this ends a shape path);
    * any other op: the maximum delay over its connected inputs, ignoring inputs
      at -1; a node without input ports gets 0.

    Splice and Crop nodes whose producer still has ``frame_time == -1`` are
    skipped and left uncalculated.

    NOTE(review): assumes ``frame_time`` was pre-initialised on every node by the
    caller (presumably to -1). That initialisation is not visible here.

    :param graph: graph whose nodes receive ``frame_time`` in place.
    """
    # Special frame_time value that marks nodes on a shape-computation path
    # (starting at ShapeOf). Such nodes are excluded from delay calculation.
    shape_path_frame_time = -2

    # Kaldi models have either one or two inputs. Only the main input can change
    # the delay in the network (the ivector input is usually named 'ivector').
    main_input = check_inputs(graph)
    main_input_name = main_input.soft_get('name', main_input.id)

    # Visit nodes breadth-first from the main input to calculate delays.
    nodes = list(bfs_search(graph, [main_input_name]))

    for n in nodes:
        node = Node(graph, n)

        # Data nodes carry no delay, and already-calculated nodes are kept as is.
        if node.kind != 'op' or node.frame_time >= 0:
            continue

        op = node.op
        if op == 'Splice':
            # Splice increases the frame delay by the width of its context.
            src_frame_time = node.in_port(0).get_source().node.frame_time
            if src_frame_time == -1:
                continue  # producer not calculated yet
            node.frame_time = src_frame_time + len(node.context) - 1
        elif op == 'Crop':
            # Crop is often used to extract one concrete time frame.
            src = node.in_port(0).get_source().node
            if src.frame_time == -1:
                continue  # producer not calculated yet
            if src.op == 'Splice':
                # The crop offset selects one frame out of the spliced context,
                # so the delay is measured from the Splice's own input.
                assert len(node.offset) == 1
                assert len(node.dim) == 1
                new_delay = src.context[node.offset[0] // node.dim[0]] - src.context[0]
                node.frame_time = src.in_port(0).get_source().node.frame_time + new_delay
            else:
                node.frame_time = src.frame_time
        elif op == 'ShapeOf':
            # Exclude the shape path from time-delay calculation using the marker.
            node.frame_time = shape_path_frame_time
        elif op == 'Broadcast':
            # The shape path ends here; the delay comes from the data input (port 0).
            node.frame_time = node.in_port(0).get_source().node.frame_time
        else:
            # The delay is the maximum over the connected inputs, ignoring
            # uncalculated inputs (-1). A node without input ports starts at 0.
            node.frame_time = -1 if len(node.in_ports()) != 0 else 0
            min_in_frame_time = -1
            for port_idx in node.in_ports():
                port = node.in_port(port_idx)
                if port.disconnected():
                    continue
                in_frame_time = port.get_source().node.frame_time
                min_in_frame_time = min(min_in_frame_time, in_frame_time)
                if in_frame_time != -1 and in_frame_time > node.frame_time:
                    node.frame_time = in_frame_time
            # A connected input carrying the shape-path marker drives the minimum
            # down to it, and the node is then marked as shape path too.
            # NOTE(review): an earlier comment said this applies when *all* inputs
            # are on the shape path. The check fires when *any* input is, so
            # confirm which behaviour is intended.
            if min_in_frame_time == shape_path_frame_time:
                node.frame_time = shape_path_frame_time