def find_and_replace_pattern(self, graph: Graph):
        """Compute per-node frame delays and align inputs of multi-input nodes.

        The pass returns immediately when the graph has no ``MemoryOffset``
        node with ``t > 0`` or when no topological order of the graph exists.
        Otherwise every node gets a temporary ``frame_time`` attribute holding
        its accumulated delay:

        * ``MemoryOffset``: delay of the source of input port 0 plus ``t``;
        * node with several input edges: the maximum reported by
          ``find_max_frame_time``; when that helper signals a mismatch
          between branches, ``align_frame_time`` is called for the node;
        * node with exactly one input edge: delay of the source of port 0;
        * node without input edges: ``0``.

        The temporary ``frame_time`` attribute is removed from all nodes
        before the method returns normally.

        :param graph: graph to process; it is modified in place.
        """
        # nothing to align unless at least one MemoryOffset really delays frames
        graph_nodes = (Node(graph, n) for n in graph)
        if not any(node.op == 'MemoryOffset' and node.t > 0 for node in graph_nodes):
            return

        try:
            nodes = list(nx.topological_sort(graph))
        except (nx.NetworkXError, nx.NetworkXUnfeasible):
            # no topological order exists (e.g. the graph contains a cycle),
            # so delays cannot be propagated; unrelated errors are not hidden
            return

        # -1 marks a delay that has not been calculated yet
        nx.set_node_attributes(G=graph, name='frame_time', values=-1)

        for n in nodes:
            node = Node(graph, n)

            # calculate frame_time (delay) that was not calculated
            if node.frame_time >= 0:
                continue

            # MemoryOffset with t>0 increases frame delay
            if node.op == "MemoryOffset":
                node.frame_time = node.in_port(0).get_source().node.frame_time + node.t
                continue

            num_in_edges = len(node.in_edges())
            if num_in_edges > 1:
                # for node with several inputs frame_time = maximum of delays from branches
                # other branches should be synced by adding MemoryOffset(branch frame_time  - max)
                # After that MemoryOffset with maximum delay should be deleted (t becomes 0)
                in_frame_time_max, should_align = find_max_frame_time(node)
                if should_align:
                    align_frame_time(graph, node, in_frame_time_max)
                node.frame_time = in_frame_time_max
            elif num_in_edges == 1:
                node.frame_time = node.in_port(0).get_source().node.frame_time
            else:
                # for all input nodes (without inputs) frame_time is 0
                node.frame_time = 0

        # drop the temporary attribute so it does not leak out of this pass
        for n in graph:
            node = Node(graph, n)
            if 'frame_time' in node:
                del node['frame_time']
Example #2
0
    def calculate_frame_time(graph: Graph):
        """Assign a ``frame_time`` (delay) value to op nodes of the graph.

        Nodes are visited in the order produced by ``bfs_search`` started from
        the node returned by ``check_inputs``. Data nodes are skipped and only
        op nodes whose ``frame_time`` is still negative are updated.

        Two negative values are used as markers in this function:

        * ``-1`` - the delay is not known yet (Splice/Crop nodes are skipped
          while their source still has this value);
        * ``-2`` (``max_frame_time``) - the node is on a shape-calculation path
          and is excluded from delay calculation.

        NOTE(review): ``frame_time`` is read before it is written here, so it
        is presumably initialised (to -1) by the caller - confirm.

        :param graph: graph whose op nodes get their ``frame_time`` updated.
        """
        # there are either one or two inputs in Kaldi. Only main input can change delay in network.
        # Usually ivector input has name 'ivector'.
        # special frame_time value that marks nodes on a shape path (see ShapeOf below)
        max_frame_time = -2
        # NOTE(review): `inputs` is not read anywhere in the visible part of this function
        inputs = graph.get_op_nodes(op='Parameter')
        inp = check_inputs(graph)
        inp_name = inp.soft_get('name', inp.id)

        # sort nodes to calculate delays
        nodes = list(bfs_search(graph, [inp_name]))

        for n in nodes:
            node = Node(graph, n)

            # just ignore data nodes
            if node.kind != 'op':
                continue

            # calculate frame_time (delay) that was not calculated
            if node.frame_time < 0:
                # Splice increases frame delay
                if node.op == "Splice":
                    # source delay is not known yet, leave this node untouched
                    if node.in_port(0).get_source().node.frame_time == -1:
                        continue
                    node.frame_time = node.in_port(
                        0).get_source().node.frame_time + len(node.context) - 1
                # crop often used to get concrete time frame, set frame_time correctly for this case
                elif node.op == 'Crop':
                    # source delay is not known yet, leave this node untouched
                    if node.in_port(0).get_source().node.frame_time == -1:
                        continue
                    if node.in_port(0).get_connection().get_source(
                    ).node.op == 'Splice':
                        splice_node = node.in_port(0).get_source().node
                        assert len(node.offset) == 1
                        assert len(node.dim) == 1
                        # delay of the selected context element relative to the first one;
                        # presumably offset // dim is the index of the cropped frame
                        # inside the Splice output - confirm
                        new_delay = splice_node.context[
                            node.offset[0] //
                            node.dim[0]] - splice_node.context[0]
                        # based on the delay of the Splice input, not of the Splice itself
                        node.frame_time = splice_node.in_port(
                            0).get_source().node.frame_time + new_delay
                    else:
                        # Crop after any other op keeps the delay of its source
                        node.frame_time = node.in_port(
                            0).get_source().node.frame_time
                elif node.op == 'ShapeOf':
                    # exclude shape path from time delay calculation using special value
                    node.frame_time = max_frame_time
                elif node.op == 'Broadcast':
                    # finished shape path
                    node.frame_time = node.in_port(
                        0).get_source().node.frame_time
                # for node with several inputs frame_time = maximum of delays from branches
                else:
                    # find out maximum of delay and check that we have at least one branch with another delay
                    # start from -1 when the node has input ports, from 0 when it has none
                    node.frame_time = -1 if len(node.in_ports()) != 0 else 0
                    min_in_frame_time = -1
                    # NOTE(review): the loop variable `inp` rebinds the `inp` assigned
                    # from check_inputs() at the top of the function
                    for inp in node.in_ports():
                        if node.in_port(inp).disconnected():
                            continue
                        in_node = node.in_port(inp).get_source().node
                        # track the smallest frame_time among connected inputs
                        if in_node.frame_time < min_in_frame_time:
                            min_in_frame_time = in_node.frame_time
                        # track the largest frame_time, ignoring the "unknown" value -1
                        if in_node.frame_time > node.frame_time and in_node.frame_time != -1:
                            node.frame_time = in_node.frame_time
                    # if all inputs have special value for frame time, node have special value for frame time too
                    # because it is on shape path
                    # NOTE(review): the check below compares the smallest input frame_time,
                    # so it is true as soon as that minimum equals max_frame_time, not only
                    # when every input has the special value - confirm which is intended
                    if min_in_frame_time == max_frame_time:
                        node.frame_time = max_frame_time