def run_meek_rule_four(self, a, graph):
    """Apply Meek's orientation rule R4 around node ``a``.

    R4 orients an undirected edge ``a --- d`` as ``a --> d`` when there is a
    chain ``a --- b --> c --> d`` with ``c`` adjacent to ``a`` and ``b`` not
    adjacent to ``d``.  The rule is skipped entirely when ``self.knowledge``
    is None.

    Args:
        a: node whose undirected edges are candidates for orientation.
        graph: graph being oriented; passed to ``self.direct``.
    """
    if self.knowledge is None:
        return
    adjacent_nodes = graph_util.adjacent_nodes(graph, a)
    # b, c and d are all drawn from a's neighbourhood, so the R4 requirement
    # "c is adjacent to a" holds by construction.
    for (b, c, d) in itertools.permutations(adjacent_nodes, 3):
        if (graph_util.has_undir_edge(graph, a, b)
                and graph_util.has_dir_edge(graph, b, c)
                and graph_util.has_dir_edge(graph, c, d)
                and graph_util.has_undir_edge(graph, d, a)
                # R4 is only sound when b and d are non-adjacent; otherwise
                # d --> a can be just as consistent as a --> d.
                and not graph_util.adjacent(graph, b, d)
                and self.is_arrowpoint_allowed(a, d)):
            self.direct(a, d, graph)
def delete(self, x, y, H):
    """Apply the delete operator for the edge between ``x`` and ``y``.

    Any edge between ``x`` and ``y`` is removed in both directions.  Then,
    for each node ``h`` in ``H`` (the neighbours of ``y`` that are adjacent
    to ``x``) that has no directed edge into ``y`` or ``x``:

    * ``y --- h`` is oriented as ``y --> h``;
    * ``x --- h`` is oriented as ``x --> h`` when it is undirected.

    Finally ``(x, y)`` is recorded in ``self.removed_edges``.
    """
    graph = self.graph

    # Drop the x/y edge whichever way it currently points.
    for tail, head in ((x, y), (y, x)):
        graph_util.remove_dir_edge(graph, tail, head)

    for h in H:
        points_into_pair = (graph_util.has_dir_edge(graph, h, y)
                            or graph_util.has_dir_edge(graph, h, x))
        if points_into_pair:
            continue
        graph_util.undir_to_dir(graph, y, h)
        if graph_util.has_undir_edge(graph, x, h):
            graph_util.undir_to_dir(graph, x, h)

    self.removed_edges.add((x, y))
def bes(self):
    """Run the backward equivalence search (BES) phase.

    BES removes edges from the graph generated by FGES, as added edges can
    now have a negative bump in light of the additions made to the graph
    after those edges were added.

    Candidates are popped from the front of ``self.sorted_arrows`` until it
    is empty.  A candidate ``x -> y`` is skipped when any of these hold:

    * its stored ``na_y_x`` no longer matches the graph;
    * ``x`` and ``y`` are no longer adjacent;
    * the graph has ``y --> x``;
    * ``valid_delete`` rejects it.

    For each accepted candidate:

    * the edge is deleted;
    * Meek rules are re-applied to ``{x, y}``;
    * ``total_score`` is increased by the bump;
    * affected nodes are re-scored via ``reevaluate_backward``.

    When ``checkpoint_frequency`` is positive, a checkpoint is created once
    more than that many seconds have passed since ``last_checkpoint``.
    """
    while self.sorted_arrows:
        if self.checkpoint_frequency > 0:
            if time.time() - self.last_checkpoint > self.checkpoint_frequency:
                self.create_checkpoint()
                self.last_checkpoint = time.time()

        candidate = self.sorted_arrows.pop(0)
        x, y = candidate.a, candidate.b

        # Guard clauses: discard candidates computed against an outdated graph.
        if not candidate.na_y_x == graph_util.get_na_y_x(self.graph, x, y):
            continue
        if not graph_util.adjacent(self.graph, x, y):
            continue
        if graph_util.has_dir_edge(self.graph, y, x):
            continue
        if not self.valid_delete(x, y, candidate.h_or_t, candidate.na_y_x):
            continue

        H = candidate.h_or_t
        bump = candidate.bump

        self.delete(x, y, H)
        MeekRules(knowledge=self.knowledge).orient_implied_subset(
            self.graph, {x, y})
        self.total_score += bump
        self.clear_arrow(x, y)

        if self.verbose:
            print("BES: Removed arrow " + str(x) + " -> " + str(y)
                  + " with bump -" + str(bump))

        visited = self.reapply_orientation(x, y, H)

        # Nodes whose neighbourhood changed since it was last stored.
        to_process = {
            node for node in visited
            if self.stored_neighbors[node] != graph_util.neighbors(self.graph, node)
        }
        to_process.update((x, y))
        to_process.update(graph_util.get_common_adjacents(self.graph, x, y))

        # TODO: Store graph
        self.reevaluate_backward(to_process)
def is_violated_by(self, graph):
    """Return True if ``graph`` conflicts with this knowledge, else False.

    The graph is in violation when either:

    * some required edge is not present as a directed edge; or
    * some edge of ``graph`` that is not undirected is forbidden.

    Undirected edges are never reported as forbidden.
    """
    missing_required = any(
        not graph_util.has_dir_edge(graph, req[0], req[1])
        for req in self.required_edges
    )
    if missing_required:
        return True
    return any(
        self.is_forbidden(e[0], e[1])
        for e in graph.edges
        if not graph_util.has_undir_edge(graph, e[0], e[1])
    )
def r1_helper(self, node_a, node_b, node_c, graph):
    """Apply Meek rule R1 to the triple (node_a, node_b, node_c).

    ``node_b --- node_c`` is oriented as ``node_b --> node_c`` when all of
    the following hold:

    * ``node_a`` and ``node_c`` are non-adjacent;
    * ``node_a --> node_b`` is in the graph, or ``(node_a, node_b)`` is in
      ``self.oriented``;
    * ``node_b --- node_c`` is undirected;
    * ``graph_util.is_unshielded_non_collider`` accepts the triple;
    * ``self.is_arrowpoint_allowed(node_b, node_c)`` holds;
    * none of the pairs among (a, c) and (b, c), in either direction, is
      already in ``self.oriented``.
    """
    if graph_util.adjacent(graph, node_a, node_c):
        return
    into_b = (graph_util.has_dir_edge(graph, node_a, node_b)
              or (node_a, node_b) in self.oriented)
    if not into_b:
        return
    if not graph_util.has_undir_edge(graph, node_b, node_c):
        return
    if not graph_util.is_unshielded_non_collider(
            graph, node_a, node_b, node_c):
        return
    if not self.is_arrowpoint_allowed(node_b, node_c):
        return

    touched = ((node_a, node_c), (node_c, node_a),
               (node_b, node_c), (node_c, node_b))
    if any(pair in self.oriented for pair in touched):
        return
    self.direct(node_b, node_c, graph)
def reevaluate_backward(self, to_process):
    """Refresh cached neighbours and recompute backward arrows for nodes.

    For each node in ``to_process``, its current neighbours are stored in
    ``self.stored_neighbors``.  Then, for every adjacent node ``other``:

    * if ``other --> node``, both cached arrows between the pair are cleared
      and ``calculate_arrows_backward(other, node)`` is recomputed;
    * if ``other --- node`` is undirected, both arrows are cleared and the
      backward arrows are recomputed in both directions.

    Pairs matching neither case are left untouched.
    """
    for node in to_process:
        self.stored_neighbors[node] = graph_util.neighbors(self.graph, node)
        for other in graph_util.adjacent_nodes(self.graph, node):
            directed_in = graph_util.has_dir_edge(self.graph, other, node)
            if not directed_in and not graph_util.has_undir_edge(
                    self.graph, other, node):
                continue
            self.clear_arrow(other, node)
            self.clear_arrow(node, other)
            self.calculate_arrows_backward(other, node)
            if not directed_in:
                # Undirected edge: re-score the opposite direction as well.
                self.calculate_arrows_backward(node, other)
def r2_helper(self, a, b, c, graph):
    """Apply Meek rule R2 to the triple (a, b, c).

    If ``a --> b --> c`` and ``a --- c`` is undirected, orient ``a --> c``.
    The orientation is applied only when ``self.is_arrowpoint_allowed(a, c)``
    permits it.
    """
    chain = (graph_util.has_dir_edge(graph, a, b)
             and graph_util.has_dir_edge(graph, b, c))
    if not chain or not graph_util.has_undir_edge(graph, a, c):
        return
    if self.is_arrowpoint_allowed(a, c):
        self.direct(a, c, graph)