Example #1
0
def mark_undead_nodes(graph, undead_types: list):
    """
    Set the 'is_undead' attribute on the nodes that must survive the dead nodes elimination phase.

    Every node starts with 'is_undead' = False. Operation nodes whose 'type' (looked up with soft_get, passing the
    'op' attribute as the default) is listed in undead_types or equals 'Result' are marked first. The mark is then
    propagated, in the order produced by bfs_search, to successors that are operation nodes or data nodes with a
    value that is not None. Finally all nodes having is_input=True are marked.
    :param graph: graph to operate on; its node attributes are modified in place.
    :param undead_types: list of node types that should be marked as undead.
    :return: None.
    """
    from mo.utils.graph import bfs_search

    nx.set_node_attributes(G=graph, name='is_undead', values=False)

    # 'Result' operations are always protected in addition to the requested types
    protected_types = undead_types + ['Result']
    undead_nodes = [
        op_node.id
        for op_node in graph.get_op_nodes()
        if op_node.soft_get('type', op_node.soft_get('op')) in protected_types
    ]
    nx.set_node_attributes(G=graph, name='is_undead', values=dict.fromkeys(undead_nodes, True))

    # propagate the 'undead' attribute to successors of undead nodes
    for node_name in bfs_search(graph, undead_nodes):
        if not graph.node[node_name]['is_undead']:
            continue
        for _, child_name in graph.out_edges(node_name):
            child_attrs = graph.node[child_name]
            if 'kind' not in child_attrs:
                continue
            child_kind = child_attrs['kind']
            # operations are always marked; data nodes only when they hold a value
            if child_kind == 'op' or (child_kind == 'data' and child_attrs['value'] is not None):
                child_attrs['is_undead'] = True

    # mark input nodes as undead
    input_nodes = graph.get_nodes_with_attributes(is_input=True)
    nx.set_node_attributes(G=graph, name='is_undead', values=dict.fromkeys(input_nodes, True))
Example #2
0
    def test_bfs_search_specific_start_nodes(self):
        """
        Check that BFS starts from the user defined nodes and doesn't go in backward edge direction.
        """
        graph = Graph()
        graph.add_nodes_from(range(1, 7))
        graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5), (6, 1)])

        # Node 6 only has an edge into the start node and node 2 only feeds node 3,
        # so neither of them is expected in the result.
        order = bfs_search(graph, [1])
        # assertEqual reports both values on failure, unlike assertTrue(a == b)
        self.assertEqual(order, [1, 3, 4, 5])
Example #3
0
    def test_bfs_search_default_start_nodes(self):
        """
        Check that BFS automatically determines input nodes and starts searching from them.
        """
        graph = Graph()
        graph.add_nodes_from(range(1, 6))
        graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5)])

        order = bfs_search(graph)
        # Nodes 1 and 2 have no incoming edges, so either of them may come first.
        # assertIn prints the actual order on failure, unlike assertTrue(a == b or a == c).
        self.assertIn(order, ([1, 2, 3, 4, 5], [2, 1, 3, 4, 5]))
Example #4
0
    def calculate_frame_time(graph: Graph):
        """
        Compute the 'frame_time' attribute (frame delay) for operation nodes reachable from the main input.

        All nodes first get frame_time = -1. Nodes are then visited in the order returned by bfs_search started
        from the main 'Parameter' node; data nodes and operations with a non-negative frame_time are skipped.
        :param graph: graph to operate on; node attributes are updated in place.
        :raises Error: if the number of 'Parameter' nodes is not 1 or 2, or if there are two of them and
                       neither is named 'ivector'.
        """
        # Kaldi models have either one or two inputs and only the main input can change delay in the network.
        # The ivector input is recognised by its name 'ivector'.
        inputs = graph.get_op_nodes(op='Parameter')
        inputs_count = len(inputs)
        if inputs_count not in (1, 2):
            raise Error("There are {} inputs for Kaldi model but we expect only 1 or 2".format(inputs_count))

        if inputs_count == 1:
            inp_name = inputs[0].name
        elif inputs[0].name == 'ivector':
            inp_name = inputs[1].name
        elif inputs[1].name == 'ivector':
            inp_name = inputs[0].name
        else:
            raise Error("There are 2 inputs for Kaldi model but we can't find out which one is ivector. " +
                        "Use name \'ivector\' for the corresponding input")

        # order the nodes for the delay calculation, then reset every delay to "not calculated"
        traversal = list(bfs_search(graph, [inp_name]))
        nx.set_node_attributes(G=graph, name='frame_time', values=-1)

        for node_id in traversal:
            node = Node(graph, node_id)

            # data nodes are ignored
            if node.kind != 'op':
                continue
            # a non-negative delay means the node was already handled
            if node.frame_time >= 0:
                continue

            if node.op == "Splice":
                # Splice widens the delay of its producer by the size of its context
                producer = node.in_port(0).get_source().node
                node.frame_time = producer.frame_time + len(node.context) - 1
                continue

            if node.op == 'Crop':
                # Crop is often used to pick a concrete time frame out of a Splice
                if node.in_port(0).get_connection().get_source().node.op != 'Splice':
                    node.frame_time = node.in_port(0).get_source().node.frame_time
                    continue
                splice_node = node.in_port(0).get_source().node
                assert len(node.offset) == 1
                assert len(node.dim) == 1
                splice_context = splice_node.context
                new_delay = splice_context[node.offset[0] // node.dim[0]] - splice_context[0]
                node.frame_time = splice_node.in_port(0).get_source().node.frame_time + new_delay
                continue

            # any other operation takes the largest delay among its connected inputs, starting from 0
            node.frame_time = 0
            for port_idx in node.in_ports():
                in_port = node.in_port(port_idx)
                if in_port.disconnected():
                    continue
                in_node = in_port.get_source().node
                if in_node.frame_time > node.frame_time:
                    node.frame_time = in_node.frame_time
Example #5
0
def mark_undead_nodes(graph: Graph, undead_types: list):
    """
    Set the 'is_undead' attribute on the nodes that must survive the dead nodes elimination phase.

    Every node starts with 'is_undead' = False. Nodes with op='OpOutput' and nodes whose 'type' is listed in
    undead_types are marked first. The mark is then propagated, in the order produced by bfs_search, to
    successors that are operation nodes or data nodes with a value that is not None. Finally all nodes having
    is_input=True are marked.
    :param graph: graph to operate on; its node attributes are modified in place.
    :param undead_types: list of node types that should be marked as undead.
    :return: None.
    """
    nx.set_node_attributes(G=graph, name='is_undead', values=False)

    # seed nodes: the outputs plus every node of an explicitly requested type
    seed_nodes = list(graph.get_nodes_with_attributes(op='OpOutput'))
    for requested_type in undead_types:
        seed_nodes.extend(graph.get_nodes_with_attributes(type=requested_type))
    nx.set_node_attributes(G=graph, name='is_undead', values=dict.fromkeys(seed_nodes, True))

    undead_nodes = graph.get_nodes_with_attributes(is_undead=True)
    # propagate the 'undead' attribute to successors of undead nodes
    for node_name in bfs_search(graph, undead_nodes):
        if not graph.node[node_name]['is_undead']:
            continue
        for _, child_name in graph.out_edges(node_name):
            child_attrs = graph.node[child_name]
            if 'kind' not in child_attrs:
                continue
            child_kind = child_attrs['kind']
            # operations are always marked; data nodes only when they hold a value
            if child_kind == 'op' or (child_kind == 'data' and child_attrs['value'] is not None):
                child_attrs['is_undead'] = True

    # mark input nodes as undead
    input_nodes = graph.get_nodes_with_attributes(is_input=True)
    nx.set_node_attributes(G=graph, name='is_undead', values=dict.fromkeys(input_nodes, True))
Example #6
0
    def calculate_frame_time(graph: Graph):
        """
        Compute the 'frame_time' attribute (frame delay) for operation nodes reachable from the main input.

        Nodes are visited in the order returned by bfs_search started from the node returned by check_inputs.
        Only operation nodes with a negative frame_time are processed. The value -1 is treated as
        "not calculated yet"; the value -2 marks nodes on a shape path (starting at 'ShapeOf') that is excluded
        from the delay calculation.
        NOTE(review): frame_time is not initialised here; presumably the caller sets it to -1 for all nodes
        before calling this function - confirm.
        :param graph: graph to operate on; node attributes are updated in place.
        """
        # there are either one or two inputs in Kaldi. Only main input can change delay in network.
        # Usually ivector input has name 'ivector'.
        # special frame_time value that marks nodes lying on a shape path
        shape_path_frame_time = -2
        main_input = check_inputs(graph)
        inp_name = main_input.soft_get('name', main_input.id)

        # sort nodes to calculate delays
        nodes = list(bfs_search(graph, [inp_name]))

        for n in nodes:
            node = Node(graph, n)

            # just ignore data nodes
            if node.kind != 'op':
                continue

            # calculate frame_time (delay) that was not calculated
            if node.frame_time < 0:
                # Splice increases frame delay
                if node.op == "Splice":
                    source_node = node.in_port(0).get_source().node
                    # the producer has no delay yet, so this node cannot be resolved now
                    if source_node.frame_time == -1:
                        continue
                    node.frame_time = source_node.frame_time + len(node.context) - 1
                # crop often used to get concrete time frame, set frame_time correctly for this case
                elif node.op == 'Crop':
                    source_node = node.in_port(0).get_source().node
                    # the producer has no delay yet, so this node cannot be resolved now
                    if source_node.frame_time == -1:
                        continue
                    if node.in_port(0).get_connection().get_source().node.op == 'Splice':
                        splice_node = source_node
                        assert len(node.offset) == 1
                        assert len(node.dim) == 1
                        new_delay = splice_node.context[node.offset[0] // node.dim[0]] - splice_node.context[0]
                        node.frame_time = splice_node.in_port(0).get_source().node.frame_time + new_delay
                    else:
                        node.frame_time = source_node.frame_time
                elif node.op == 'ShapeOf':
                    # exclude shape path from time delay calculation using special value
                    node.frame_time = shape_path_frame_time
                elif node.op == 'Broadcast':
                    # finished shape path
                    node.frame_time = node.in_port(0).get_source().node.frame_time
                # for node with several inputs frame_time = maximum of delays from branches
                else:
                    # nodes without input ports start at 0, all others stay "not calculated" until an input
                    # with a calculated delay is found
                    node.frame_time = -1 if len(node.in_ports()) != 0 else 0
                    min_in_frame_time = -1
                    for port_idx in node.in_ports():
                        if node.in_port(port_idx).disconnected():
                            continue
                        in_node = node.in_port(port_idx).get_source().node
                        if in_node.frame_time < min_in_frame_time:
                            min_in_frame_time = in_node.frame_time
                        if in_node.frame_time > node.frame_time and in_node.frame_time != -1:
                            node.frame_time = in_node.frame_time
                    # an input carrying the special value puts this node on the shape path too
                    # NOTE(review): the original comment said "if all inputs have special value", but this
                    # condition is already true when any single connected input has it - confirm the intent.
                    if min_in_frame_time == shape_path_frame_time:
                        node.frame_time = shape_path_frame_time