def find_and_replace_pattern(self, graph: Graph):
    """Synchronize branch delays introduced by MemoryOffset nodes with a positive offset.

    Walks the graph in topological order and computes a temporary ``frame_time``
    (accumulated delay) for every node:
      * a MemoryOffset adds its ``t`` to the delay of its input;
      * a node with several inputs takes the maximum input delay, and
        ``align_frame_time`` is called when ``find_max_frame_time`` reports that
        the branches must be aligned;
      * a node with exactly one input inherits the delay of that input;
      * a node without inputs starts at delay 0.
    The temporary ``frame_time`` attribute is removed from all nodes at the end.

    The graph is left untouched when it has no MemoryOffset with ``t > 0`` or when
    networkx cannot produce a topological order for it.

    :param graph: graph transformed in place
    """
    # Fast exit: the pass is only needed when some MemoryOffset has t > 0.
    if not any(node.op == 'MemoryOffset' and node.t > 0
               for node in (Node(graph, n) for n in graph)):
        return

    try:
        nodes = list(nx.topological_sort(graph))
    except nx.NetworkXException:
        # networkx raises NetworkXUnfeasible (a NetworkXException) when no topological
        # order exists, e.g. for a cyclic graph; the pass is skipped in that case.
        return

    # -1 marks "delay not calculated yet".
    nx.set_node_attributes(G=graph, name='frame_time', values=-1)

    for n in nodes:
        node = Node(graph, n)
        if node.frame_time >= 0:
            # frame_time (delay) was already calculated
            continue

        if node.op == "MemoryOffset":
            # MemoryOffset with t > 0 increases the frame delay of its branch
            node.frame_time = node.in_port(0).get_source().node.frame_time + node.t
        elif len(node.in_edges()) > 1:
            # For a node with several inputs frame_time = maximum of the delays of its branches.
            # Other branches should be synced by adding MemoryOffset(branch frame_time - max);
            # after that the MemoryOffset with the maximum delay should be deleted (t becomes 0).
            # find_max_frame_time returns the maximum delay and whether at least one branch
            # has a different delay.
            in_frame_time_max, should_align = find_max_frame_time(node)
            if should_align:
                align_frame_time(graph, node, in_frame_time_max)
            node.frame_time = in_frame_time_max
        elif len(node.in_edges()) == 1:
            node.frame_time = node.in_port(0).get_source().node.frame_time
        else:
            # nodes without inputs have frame_time 0
            node.frame_time = 0

    # Drop the temporary attribute; nodes created during alignment may not have it.
    for n in graph:
        node = Node(graph, n)
        if 'frame_time' in node:
            del node['frame_time']
def special_port_to_real_port(node: Node, special_port_id: int, direction: str = 'in'):
    """Return the edge port key whose special port id equals ``special_port_id``.

    Nodes that have a valid ``body`` attribute are matched on the edge attribute
    ``external_port_id``; all other nodes are matched on ``internal_port_id``.

    :param node: an op node (``node.kind == 'op'`` is asserted)
    :param special_port_id: value of the external/internal port id to look for
    :param direction: 'in' to search ``node.in_edges()``, 'out' to search ``node.out_edges()``
    :return: the key of the single matching entry of the edges mapping
        (presumably the port index — confirm against ``Node.in_edges``)
    :raises AssertionError: if the preconditions fail or not exactly one edge matches
    """
    assert node.kind == 'op'
    assert direction in ['in', 'out']

    if node.has_valid('body'):
        id_attr = 'external_port_id'
    else:
        id_attr = 'internal_port_id'

    candidate_edges = node.in_edges() if direction == 'in' else node.out_edges()
    matching_ports = [
        port for port, edge_attrs in candidate_edges.items()
        if id_attr in edge_attrs and edge_attrs[id_attr] == special_port_id
    ]

    assert len(matching_ports) == 1
    return matching_ports[0]
def build_graph_with_attrs(nodes_with_attrs: list, edges_with_attrs: list, new_nodes_with_attrs: list = None,
                           new_edges_with_attrs: list = None, update_edge_attrs: dict = None,
                           update_nodes_attributes: list = None, nodes_with_edges_only: bool = False,
                           add_nodes_from_edges: bool = False):
    """
    Build the Graph with specific nodes and edges. Updating edge and node attributes is also supported.

    :param nodes_with_attrs: list of tuples ('node_name', {node_attrs})
    :param edges_with_attrs: list of tuples like (start node, end node, (optional) {attrs of the edge}).
    :param new_nodes_with_attrs: same format as nodes_with_attrs; must contain only nodes that are not
        in nodes_with_attrs. Defaults to no extra nodes.
    :param new_edges_with_attrs: same format as edges_with_attrs; must contain only edges that are not
        in edges_with_attrs. Defaults to no extra edges.
    :param update_edge_attrs: optional dictionary like {('from_node', 'to_node', key): {edge_attrs}}.
    :param update_nodes_attributes: optional list of tuples which specifies nodes names and their
        attributes to be updated. The first element is a node name to update attributes of and the
        second element is a dictionary with attribute names and their values.
    :param nodes_with_edges_only: keep only nodes which have at least one incoming or outgoing edge.
    :param add_nodes_from_edges: whether nodes that are listed in the edges but not in the nodes are allowed.
    :return: generated graph.
    :raises Error: if new nodes/edges duplicate existing ones, or if an edge references a node that is
        not listed while add_nodes_from_edges is False.
    """
    # None defaults avoid sharing one mutable list object between calls.
    if new_nodes_with_attrs is None:
        new_nodes_with_attrs = []
    if new_edges_with_attrs is None:
        new_edges_with_attrs = []

    if not_all_new([node[0] for node in nodes_with_attrs], [node[0] for node in new_nodes_with_attrs]):
        raise Error('Some nodes from new_nodes_with_attrs are already in nodes.'
                    ' Please, add to new_nodes_with_attrs only NEW nodes.')

    if not_all_new([(edge[0], edge[1]) for edge in edges_with_attrs],
                   [(edge[0], edge[1]) for edge in new_edges_with_attrs]):
        raise Error('Some edges from new_edges_with_attrs are already in edges.'
                    ' Please, add to new_edges_with_attrs only NEW edges.')

    # Check that all nodes from list of edges are in nodes
    all_nodes = nodes_with_attrs + new_nodes_with_attrs
    all_edges = edges_with_attrs + new_edges_with_attrs
    all_nodes_names = [node[0] for node in all_nodes]
    if not add_nodes_from_edges and not all_edges_in_nodes(nodes=all_nodes_names, edges=all_edges):
        raise Error("Some nodes from list of edges is not in nodes. Please, add all necessary nodes.")

    graph = Graph()

    # Map node name -> attrs.
    # NOTE: a missing 'name' attribute is written into the caller's own attrs dict.
    nodes_attrs = {}
    for node_name, attrs in all_nodes:
        nodes_attrs[node_name] = attrs
        if 'name' not in attrs:
            attrs['name'] = node_name

    if nodes_with_edges_only:
        # keep only nodes that are an endpoint of at least one edge (first-seen order)
        nodes_attrs = {name: nodes_attrs[name] for edge in all_edges for name in (edge[0], edge[1])}

    # Create all nodes; deep-copy so graph nodes do not share attribute objects with the caller
    for node, attrs in nodes_attrs.items():
        graph.add_node(node, **deepcopy(attrs))

    # Connect nodes with edges (also unpack edge params)
    for edge in all_edges:
        edge_attrs = edge[2] if len(edge) == 3 else {}
        graph.add_edge(edge[0], edge[1], **edge_attrs)

    # Update attributes of edges; keys are (from, to, key) multigraph edge tuples (networkx 2.x)
    if update_edge_attrs:
        for edge, attr in update_edge_attrs.items():
            for k, v in attr.items():
                nx.set_edge_attributes(G=graph, name=k, values={edge: v})

    # Update attributes of nodes
    if update_nodes_attributes is not None:
        for node_name, new_attrs in update_nodes_attributes:
            assert (node_name in graph.nodes())
            for attr, value in new_attrs.items():
                # graph.nodes[...] works on all networkx 2.x; graph.node was removed in 2.4
                graph.nodes[node_name][attr] = value

    for node_id in graph.nodes():
        node = Node(graph, node_id)
        # Use each edge's own attribute dict: with parallel edges between the same pair of nodes,
        # get_edge_data(u, v)[0] would yield the data of edge key 0 for every one of them.
        check_and_update_ports(node, [edge_data for _, _, edge_data in graph.in_edges(node_id, data=True)], True)
        check_and_update_ports(node, [edge_data for _, _, edge_data in graph.out_edges(node_id, data=True)], False)

    for node in graph.get_op_nodes():
        # Add in_ports / out_ports attributes: one port per connected edge
        for idx in range(len(node.in_edges())):
            node.add_input_port(idx=idx)
        for idx in range(len(node.out_edges())):
            node.add_output_port(idx=idx)

    return graph
def replace_op(self, graph: Graph, node: Node):
    """Rewire a TF BlockLSTM node in place to the LSTMSequence input/output port layout.

    Rejects unsupported configurations (peephole connections, cell clipping, usage of
    output ports 0, 2, 3, 4, 5). Then it removes the peephole inputs (ports 5, 6, 7)
    and the ``seq_len_max`` input (port 0), and renumbers the remaining input edges
    and output edge 6.

    :param graph: graph that contains ``node``
    :param node: BlockLSTM node to convert
    :return: an empty list; no output edge is replaced
    :raises Error: if ``use_peephole`` is set, ``cell_clip`` is not -1, or an
        unsupported output port is consumed
    :raises AssertionError: if a required input (ports 1, 2, 3, 4, 8) is missing
    """
    if node.use_peephole:
        raise Error(
            "BlockLSTM operation is not supported with `use_peephole`==True. Node: {}"
            "".format(node.soft_get('name')))

    if node.cell_clip != -1:
        raise Error(
            "Clipping is not supported for BlockLSTM operation. `cell_clip`={!s} for node: {}"
            "".format(node.cell_clip, node.soft_get('name')))

    log.debug(
        "Start BlockLSTM->LSTMSequence translation for node: {} with parameters:\n"
        "`cell_clip`={!s}, `use_peephole`=={!s}, `forget_bias`={!s}\n"
        "inputs: {},\noutputs:{}".format(
            node.soft_get('name'), node.cell_clip, node.use_peephole, node.forget_bias,
            {p: i.id for p, i in node.in_nodes().items()},
            {p: o.id for p, o in node.out_nodes().items()}))

    log.debug("Cutting all inputs for peephole connection (5, 6, 7 input ports) off, as `use_peephole`=False")

    # in_nodes() returns a fresh mapping, so removing edges while iterating it is safe.
    for p, input_data in node.in_nodes().items():
        if p in (5, 6, 7):
            key = self.find_key_by_input_port(input_data, node, p)
            assert key is not None
            graph.remove_edge(input_data.id, node.id, key=key)

    log.debug("Cutting seq_len_max input off")
    # Remove exactly the port-0 edge; without a key networkx drops an arbitrary
    # edge when several edges connect the same pair of nodes.
    seq_len_max = node.in_node(0)
    graph.remove_edge(seq_len_max.id, node.id, key=self.find_key_by_input_port(seq_len_max, node, 0))

    # Reconnecting input edges of LSTMSequence:
    #   TF input edges:  Description:                   MO input edges:
    #         1          input                                0
    #         4          weights                              1
    #         8          biases                               2
    #         3          h_prev: initial output of cell       3
    #         2          cs_prev: initial cell state          4
    inputs = node.in_edges()
    assert 1 in inputs, "Sequence input to the BlockLSTM is required (1 port). Node {}".format(node.id)
    assert 2 in inputs, "Value of the initial cell state is required (2 port). Node {}".format(node.id)
    assert 3 in inputs, "Initial output of cell is required input to BlockLSTM (3 port). Node {}".format(node.id)
    assert 4 in inputs, "The weight matrix is required input to BlockLSTM (4 port) . Node {}".format(node.id)
    assert 8 in inputs, "The bias vector is required input to BlockLSTM (8 port). Node {}".format(node.id)
    inputs[3]['in'] = 3
    inputs[1]['in'] = 0
    inputs[4]['in'] = 1
    inputs[2]['in'] = 4
    inputs[8]['in'] = 2

    log.debug("Checking for unsupported outputs usage (output ports: 0, 2, 3, 4, 5)")
    for port in node.out_nodes():
        if port in (0, 2, 3, 4, 5):
            raise Error("Output port {} of BlockLSTM node {} is not supported".format(port, node.id))

    # Reconnecting output edges of LSTMSequence:
    #   TF output edges:  Description:                  MO output edges:
    #         6           output h vector                     0
    #         1           cell state before the tanh          1
    outputs = node.out_edges()
    if 6 in outputs:
        outputs[6]['out'] = 0
        node.add_output_port(0, skip_if_exist=True)

    # do not replace any output edge
    return []