def find_and_replace_pattern(self, graph: Graph):
    """
    Propagate frame delays introduced by MemoryOffset nodes with positive ``t``
    and align branches that meet with different delays.

    A temporary ``frame_time`` attribute is computed for every node in
    topological order:
      * MemoryOffset: frame_time of its input plus ``t``;
      * node with several inputs: maximum of its inputs' frame_time; if the
        inputs disagree, ``align_frame_time`` is called to sync the branches;
      * node with one input: frame_time of that input;
      * node without inputs: 0.
    The temporary attribute is always removed from the graph afterwards.

    :param graph: graph to transform in place.
    :return: None. Returns early (graph untouched) when there is no MemoryOffset
             with ``t > 0`` or the graph cannot be topologically sorted.
    """
    # Nothing to align unless some MemoryOffset shifts frames forward.
    if not any(node.op == 'MemoryOffset' and node.t > 0
               for node in (Node(graph, n) for n in graph)):
        return

    try:
        nodes = list(nx.topological_sort(graph))
    except nx.NetworkXException:
        # topological_sort raises NetworkXUnfeasible for a graph with a cycle
        # (and NetworkXError for an undirected one): delays cannot be propagated
        # in a single pass, so leave the graph as it is. Non-networkx errors
        # (including KeyboardInterrupt) are no longer swallowed.
        return

    nx.set_node_attributes(G=graph, name='frame_time', values=-1)

    try:
        for n in nodes:
            node = Node(graph, n)

            # -1 marks "not calculated yet"; skip nodes that already have a value.
            if node.frame_time >= 0:
                continue

            in_edges_count = len(node.in_edges())
            if node.op == "MemoryOffset":
                # MemoryOffset with t>0 increases the frame delay of its input.
                node.frame_time = node.in_port(0).get_source().node.frame_time + node.t
            elif in_edges_count > 1:
                # frame_time is the maximum of the branches' delays. Other branches are
                # synced by adding MemoryOffset(branch frame_time - max); after that the
                # MemoryOffset with the maximum delay should be deleted (t becomes 0).
                in_frame_time_max, should_align = find_max_frame_time(node)
                if should_align:
                    align_frame_time(graph, node, in_frame_time_max)
                node.frame_time = in_frame_time_max
            elif in_edges_count == 1:
                node.frame_time = node.in_port(0).get_source().node.frame_time
            else:
                # Input nodes (without inputs) start with zero delay.
                node.frame_time = 0
    finally:
        # Remove the temporary attribute even if propagation failed. Iterate the
        # current graph rather than `nodes`, so nodes added while aligning are
        # cleaned as well.
        for n in graph:
            node = Node(graph, n)
            if 'frame_time' in node:
                del node['frame_time']
Ejemplo n.º 2
0
    def special_port_to_real_port(node: Node, special_port_id: int, direction: str = 'in'):
        """
        Map a special port id to the real port index of ``node``.

        Edges are matched on ``external_port_id`` when the node has a valid
        ``body`` attribute, otherwise on ``internal_port_id``.

        :param node: operation node whose edges are searched.
        :param special_port_id: value of the special port id attribute to look for.
        :param direction: 'in' to search input edges, 'out' to search output edges.
        :return: the real port index of the single matching edge.
        """
        assert node.kind == 'op'
        assert direction in ['in', 'out']

        if node.has_valid('body'):
            id_attr = 'external_port_id'
        else:
            id_attr = 'internal_port_id'

        port_edges = node.in_edges() if direction == 'in' else node.out_edges()

        matching_ports = [
            port for port, edge_attrs in port_edges.items()
            if id_attr in edge_attrs and edge_attrs[id_attr] == special_port_id
        ]
        # Exactly one edge must carry the requested special port id.
        assert len(matching_ports) == 1
        return matching_ports[0]
Ejemplo n.º 3
0
def build_graph_with_attrs(nodes_with_attrs: list,
                           edges_with_attrs: list,
                           new_nodes_with_attrs: list = None,
                           new_edges_with_attrs: list = None,
                           update_edge_attrs: dict = None,
                           update_nodes_attributes: list = None,
                           nodes_with_edges_only: bool = False,
                           add_nodes_from_edges: bool = False):
    """
    Build the Graph with specific nodes and edges. Updating edge and node attributes is also supported.
    The caller's node/edge attribute dicts are not modified.
    :param nodes_with_attrs: list of tuples ('node_name', {node_attrs})
    :param edges_with_attrs: list of tuples like (start node, end node, (optional) {attrs of the edge}).
    :param new_nodes_with_attrs: same format as nodes_with_attrs; must contain only nodes not listed there.
    :param new_edges_with_attrs: same format as edges_with_attrs; must contain only edges not listed there.
    :param update_edge_attrs: optional dictionary like {('from_node', 'to_node', key): {edge_attrs}}.
    :param update_nodes_attributes: optional list of tuples which specifies nodes names and their attributes to be
    updated. The first element is a node name to update attribute and the second element is a dictionary with attribute
    name and its value.
    :param nodes_with_edges_only: keep only nodes that have at least one incoming or outgoing edge.
    :param add_nodes_from_edges: whether edges may reference nodes that are not listed in the node lists.
    :return: generated graph.
    """
    # None defaults instead of []: a mutable default list would be shared between calls.
    new_nodes_with_attrs = [] if new_nodes_with_attrs is None else new_nodes_with_attrs
    new_edges_with_attrs = [] if new_edges_with_attrs is None else new_edges_with_attrs

    if not_all_new([node[0] for node in nodes_with_attrs],
                   [node[0] for node in new_nodes_with_attrs]):
        raise Error(
            'Some nodes from new_nodes_with_attrs are already in nodes.'
            ' Please, add to new_nodes_with_attrs only NEW nodes.')

    if not_all_new([(edge[0], edge[1]) for edge in edges_with_attrs],
                   [(edge[0], edge[1]) for edge in new_edges_with_attrs]):
        raise Error(
            'Some edges from new_edges_with_attrs are already in edges.'
            ' Please, add to new_edges_with_attrs only NEW edges.')

    # Check that all nodes from list of edges are in nodes
    all_nodes = nodes_with_attrs + new_nodes_with_attrs
    all_edges = edges_with_attrs + new_edges_with_attrs
    all_nodes_names = [node[0] for node in all_nodes]
    if not add_nodes_from_edges and not all_edges_in_nodes(
            nodes=all_nodes_names, edges=all_edges):
        raise Error(
            "Some nodes from list of edges is not in nodes. Please, add all necessary nodes."
        )

    graph = Graph()

    # Create dict for nodes with attrs. Copy each attrs dict before adding the
    # default 'name' so the caller's dicts are left untouched (a dict shared by
    # several nodes would otherwise keep the first node's name).
    nodes_attrs = {}
    for node_name, attrs in all_nodes:
        node_attrs = dict(attrs)
        node_attrs.setdefault('name', node_name)
        nodes_attrs[node_name] = node_attrs

    if nodes_with_edges_only:
        # filter nodes to keep only ones with edges connected
        filtered_nodes = {}
        for edge in all_edges:
            node_1, node_2 = edge[0], edge[1]
            filtered_nodes[node_1] = nodes_attrs[node_1]
            filtered_nodes[node_2] = nodes_attrs[node_2]
        nodes_attrs = filtered_nodes

    # Create all nodes
    for node, attrs in nodes_attrs.items():
        graph.add_node(node, **deepcopy(attrs))

    # Connect nodes with edges (also unpack edge params)
    for edge in all_edges:
        node_1, node_2 = edge[0], edge[1]
        edge_attrs = edge[2] if len(edge) == 3 else {}
        graph.add_edge(node_1, node_2, **edge_attrs)

    # Update attributes of edges
    if update_edge_attrs:
        # it will work in 2.x networkx only
        for edge, attr in update_edge_attrs.items():
            for k, v in attr.items():
                nx.set_edge_attributes(G=graph, name=k, values={edge: v})

    # Update attributes of nodes
    if update_nodes_attributes is not None:
        for node_name, new_attrs in update_nodes_attributes:
            assert node_name in graph.nodes()
            for attr, value in new_attrs.items():
                # `graph.nodes[...]` works on every networkx 2.x; `graph.node` was removed in 2.4.
                graph.nodes[node_name][attr] = value

    for node_id in graph.nodes():
        node = Node(graph, node_id)
        # data=True yields each (parallel) edge's own attribute dict, so ports of
        # multi-edges between the same pair of nodes are validated/assigned per edge
        # instead of reusing the key-0 edge's attributes for all of them.
        check_and_update_ports(node, [attrs for _, _, attrs in graph.in_edges(node_id, data=True)], True)
        check_and_update_ports(node, [attrs for _, _, attrs in graph.out_edges(node_id, data=True)], False)

    for node in graph.get_op_nodes():
        # Add in_ports attribute
        for idx in range(len(node.in_edges())):
            node.add_input_port(idx=idx)

        # Add out_ports attribute
        for idx in range(len(node.out_edges())):
            node.add_output_port(idx=idx)
    return graph
Ejemplo n.º 4
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Rewire a TF BlockLSTM node in place into the LSTMSequence port layout.

        Unsupported configurations are rejected. The peephole inputs (ports 5, 6, 7)
        and the seq_len_max input (port 0) are cut off, and the remaining input and
        output edges are renumbered.

        :param graph: graph that contains ``node``.
        :param node: BlockLSTM node to translate.
        :return: an empty list, meaning no output edge is replaced.
        :raises Error: if ``use_peephole`` is set, ``cell_clip`` != -1, or one of the
                       unsupported output ports (0, 2, 3, 4, 5) is consumed.
        """
        if node.use_peephole:
            raise Error(
                "BlockLSTM operation is not supported with `use_peephole`==True. Node: {}"
                "".format(node.soft_get('name')))

        if node.cell_clip != -1:
            raise Error(
                "Clipping is not supported for BlockLSTM operation. `cell_clip`={!s} for node: {}"
                "".format(node.cell_clip, node.soft_get('name')))

        log.debug(
            "Start BlockLSTM->LSTMSequence translation for node: {} with parameters:\n"
            "`cell_clip`={!s}, `use_peephole`=={!s}, `forget_bias`={!s}\n"
            "inputs: {},\noutputs:{}".format(
                node.soft_get('name'), node.cell_clip, node.use_peephole,
                node.forget_bias,
                {p: i.id
                 for p, i in node.in_nodes().items()},
                {p: o.id
                 for p, o in node.out_nodes().items()}))

        log.debug(
            "Cutting all inputs for peephole connection (5, 6, 7 input ports) off, as `use_peephole`=False"
        )

        for port, in_data in node.in_nodes().items():
            if port in (5, 6, 7):
                # Remove exactly the edge feeding this port (the graph may hold parallel edges).
                key = self.find_key_by_input_port(in_data, node, port)
                assert key is not None
                graph.remove_edge(in_data.id, node.id, key=key)

        log.debug("Cutting seq_len_max input off")
        seq_len_max = node.in_node(0)
        # Without a key, MultiDiGraph.remove_edge drops the most recently added
        # parallel edge, which may feed a different port. If the port-0 key is not
        # found, key is None and the previous key-less behavior is kept.
        seq_len_key = self.find_key_by_input_port(seq_len_max, node, 0)
        graph.remove_edge(seq_len_max.id, node.id, key=seq_len_key)

        # Reconnecting input edges of LSTMSequence:
        # TF input edges:             Description:                 MO input edges:
        #       1                          input                        0
        #       4                         weights                       1
        #       8                         biases                        2
        #       3               h_prev: initial output of cell          3
        #       2               cs_prev: initial cell state             4
        inputs = node.in_edges()
        assert 1 in inputs, "Sequence input to the BlockLSTM is required (1 port). Node {}".format(
            node.id)
        assert 2 in inputs, "Value of the initial cell state is required (2 port). Node {}".format(
            node.id)
        assert 3 in inputs, "Initial output of cell is required input to BlockLSTM (3 port). Node {}".format(
            node.id)
        assert 4 in inputs, "The weight matrix is required input to BlockLSTM (4 port) . Node {}".format(
            node.id)
        assert 8 in inputs, "The bias vector is required input to BlockLSTM (8 port). Node {}".format(
            node.id)

        inputs[3]['in'] = 3
        inputs[1]['in'] = 0
        inputs[4]['in'] = 1
        inputs[2]['in'] = 4
        inputs[8]['in'] = 2

        log.debug(
            "Checking for unsupported outputs usage (output ports: 0, 2, 3, 4, 5)"
        )
        for port in node.out_nodes():
            if port in (0, 2, 3, 4, 5):
                # Arguments follow the message order: port first, then node id.
                raise Error(
                    "Output port {} of BlockLSTM node {} is not supported".
                    format(port, node.id))

        # Reconnecting output edges of LSTMSequence:
        # TF output edges:             Description:                 MO output edges:
        #       6                     output h vector                     0
        #       1                   cell state before the tanh            1
        outputs = node.out_edges()
        if 6 in outputs:
            outputs[6]['out'] = 0
            node.add_output_port(0, skip_if_exist=True)

        # do not replace any output edge
        return []