def equality_transitive_closure(graph: Graph, equality_label: int, join_nodes: Optional[Set[Node]] = None,
                                valid_combinations: Optional[Set[int]] = None):
    """Saturate ``graph`` in place by composing edges through equality edges.

    Every edge labelled ``equality_label`` is processed from a FIFO worklist.
    For an equality edge ``src -> dst`` being processed:

    * each edge ``x -[l]-> src`` yields a derived edge ``x -[l]-> dst``;
    * each edge ``dst -[l]-> y`` yields a derived edge ``src -[l]-> y``.

    Derived edges not seen before are added to ``graph``. Derived edges that
    are themselves equality edges are queued so they act as pivots in turn.
    The loop ends when the worklist is empty.

    Args:
        graph: Graph to extend; mutated through ``graph.add_nodes_and_edges``.
        equality_label: Edge label that marks equality edges.
        join_nodes: If given, composing through the equality edge's ``src``
            only happens when ``src`` is in this set. Likewise, composing
            through its ``dst`` only happens when ``dst`` is in this set.
            ``None`` means no restriction.
        valid_combinations: If given, only edges whose label is in this set
            are composed with an equality edge. ``None`` means all labels.
            Equality edges are filtered too, so leaving ``equality_label``
            out of this set stops equality edges from chaining.

    Returns:
        None; the result is the mutation of ``graph``.
    """
    worklist = collections.deque(
        e for e in graph.iter_edges(label=equality_label))
    # Edges already queued or derived; used below to avoid re-adding and
    # re-queueing the same edge.
    seen_edges = set(worklist)
    while len(worklist) > 0:
        edge_item = worklist.popleft()
        added = set()
        # (x -[l]-> src) composed with (src == dst) gives (x -[l]-> dst).
        # NOTE(review): assumes Edge's positional order is (src, dst, label),
        # as in the keyword form Edge(src=..., dst=..., label=...) - confirm.
        if join_nodes is None or edge_item.src in join_nodes:
            for e in graph.iter_edges(dst=edge_item.src):
                if valid_combinations is None or e.label in valid_combinations:
                    added.add(Edge(e.src, edge_item.dst, e.label))
        # (src == dst) composed with (dst -[l]-> y) gives (src -[l]-> y).
        if join_nodes is None or edge_item.dst in join_nodes:
            for e in graph.iter_edges(src=edge_item.dst):
                if valid_combinations is None or e.label in valid_combinations:
                    added.add(Edge(edge_item.src, e.dst, e.label))
        # Drop anything already queued or derived earlier.
        # Pre-existing non-equality edges are not in seen_edges, so they can
        # reach add_nodes_and_edges again.
        # NOTE(review): presumably add_nodes_and_edges tolerates duplicates.
        added -= seen_edges
        if len(added) > 0:
            graph.add_nodes_and_edges(edges=added)
            for e in added:
                # Only derived equality edges become new pivots. Derived edges
                # with other labels are recorded as seen but never queued.
                if e.label == equality_label:
                    worklist.append(e)
                seen_edges.add(e)
def extract_paths(graph: Graph, input_entities: List[Entity], output_entity: Entity): path_dict: Dict[Entity, List[Path]] = {ent: [] for ent in input_entities} for node in itertools.chain(*(graph.iter_nodes(entity=ent) for ent in input_entities)): # Find all the paths from node to an output node, without any other input or output nodes in between. # An entry is the set of visited nodes, the current node to explore, and the current set of edges. entry: Tuple[Set[Node], Node, List[Edge]] = ({node}, node, []) worklist = collections.deque([entry]) paths: List[Path] = [] while len(worklist) > 0: visited, cur_node, edges = worklist.popleft() for edge in graph.iter_edges(src=cur_node): dst = edge.dst if dst in visited or dst.entity in input_entities: continue if dst.entity is output_entity: paths.append((visited | {dst}, edges + [edge])) else: worklist.append((visited | {dst}, dst, edges + [edge])) path_dict[node.entity].extend(paths) return path_dict
def _get_explanation_expr_str(
        graph: Graph, node: Node, node_label_dict: Dict[int, str],
        edge_label_dict: Dict[int, str]) -> Optional[str]:
    """Render the expression that produced ``node`` as a string.

    Incoming edges of ``node`` are grouped by edge-label name, looked up in
    ``edge_label_dict``. Edges whose name starts with ``"CUM"`` or equals
    ``"COLUMN"`` or ``"ROW"`` are ignored. For the remaining edges:

    * an operand whose source node is labelled ``"INTERM"`` (per
      ``node_label_dict``) is rendered recursively;
    * any other operand is rendered as ``str(source.value)``.

    There is no visited set, so this assumes the graph is acyclic along these
    edges; a cycle would recurse without bound.

    Returns:
        One of the following:

        * ``None`` if ``node`` has no relevant incoming edge;
        * the first ``"EQUAL"`` operand if there is one (which may itself be
          ``None``);
        * otherwise ``"(OP(a, b, ...))"``, where ``OP`` is the upper-cased
          name of the first label group encountered.

        If an operand of that group could not be rendered (a recursive call
        returned ``None``), the whole expression is ``None``.
    """
    args = collections.defaultdict(list)
    for edge in graph.iter_edges(dst=node):
        label = edge_label_dict[edge.label]
        if label.startswith("CUM") or label in ("COLUMN", "ROW"):
            continue
        if node_label_dict[edge.src.label] == "INTERM":
            args[label].append(
                _get_explanation_expr_str(graph, edge.src, node_label_dict, edge_label_dict))
        else:
            args[label].append(str(edge.src.value))
    if not args:
        return None
    if "EQUAL" in args:
        return args["EQUAL"][0]
    # Dicts keep insertion order, so this is the label of the first relevant
    # incoming edge.
    key = next(iter(args))
    operands = args[key]
    # An unrenderable sub-expression makes this one unrenderable too. Joining
    # a None operand would otherwise raise TypeError in str.join.
    if any(operand is None for operand in operands):
        return None
    arg_str = ", ".join(operands)
    return f"({key.upper()}({arg_str}))"
def _get_involved_nodes(graph: Graph, node: Node, node_label_dict: Dict[int, str],
                        edge_label_dict: Dict[int, str]) -> Set[Node]:
    """Collect every node that transitively feeds into ``node``.

    Walks incoming edges backwards from ``node``. Edges are ignored when their
    label name (looked up in ``edge_label_dict``) starts with ``"CUM"`` or
    equals ``"COLUMN"`` or ``"ROW"``.

    Each node is expanded at most once. Shared ancestors are therefore not
    re-walked, cycles terminate, and long chains do not hit the recursion
    limit. ``node`` itself is included only if it is reachable from itself
    through a cycle.

    Args:
        graph: Graph to walk; only iterated here.
        node: Node whose ancestors are collected.
        node_label_dict: Not read by this function; accepted so the signature
            mirrors ``_get_explanation_expr_str``.
        edge_label_dict: Maps edge label ids to label names.

    Returns:
        The set of source nodes reachable backwards from ``node``.
    """
    result: Set[Node] = set()
    stack = [node]
    while stack:
        current = stack.pop()
        for edge in graph.iter_edges(dst=current):
            label = edge_label_dict[edge.label]
            if label.startswith("CUM") or label in ("COLUMN", "ROW"):
                continue
            src = edge.src
            # Membership in result doubles as the visited set.
            if src not in result:
                result.add(src)
                stack.append(src)
    return result
def create_symbolic_copy(graph: Graph) -> Tuple[Graph, GraphMapping]:
    """Build a structural copy of ``graph`` whose values are all symbolic.

    The copy is built as follows:

    * every entity gets a fresh ``Entity`` with value ``SYMBOLIC_VALUE``;
    * every node gets a fresh ``Node`` with the same label, attached to its
      entity's symbolic counterpart, and value ``SYMBOLIC_VALUE``;
    * every edge is re-created between the corresponding new nodes with the
      same label.

    The input graph is only iterated here.

    Returns:
        A ``(new_graph, mapping)`` pair. ``mapping.m_ent`` and
        ``mapping.m_node`` map each original entity/node to its symbolic
        counterpart.
    """
    mapping = GraphMapping()
    ent_map = mapping.m_ent
    node_map = mapping.m_node

    for old_entity in graph.iter_entities():
        ent_map[old_entity] = Entity(value=SYMBOLIC_VALUE)

    for old_node in graph.iter_nodes():
        node_map[old_node] = Node(
            label=old_node.label,
            entity=ent_map[old_node.entity],
            value=SYMBOLIC_VALUE,
        )

    symbolic_nodes = set(node_map.values())
    symbolic_edges = {
        Edge(src=node_map[old_edge.src], dst=node_map[old_edge.dst], label=old_edge.label)
        for old_edge in graph.iter_edges()
    }
    new_graph = Graph.from_nodes_and_edges(nodes=symbolic_nodes, edges=symbolic_edges)
    return new_graph, mapping