def abstract_atom(graph: HyperGraph, loc):
    """Split a terminal atom out of ``graph`` into its own small hypergraph.

    A fresh ``HyperGraph`` is built containing deep copies of all edges of the
    atom at ``loc``, a parent (nonterminal) node over those edges, and a
    terminal copy of the atom carrying the atom's ``data``. In ``graph`` the
    atom is then replaced by a nonterminal ``Node`` connected to the atom's
    original edges.

    Args:
        graph: hypergraph to modify in place.
        loc: key of the atom in ``graph.node``.

    Returns:
        The new hypergraph holding the parent placeholder and the atom copy.

    Raises:
        AssertionError: if the node at ``loc`` is not terminal.
    """
    atom = graph.node[loc]
    assert atom.is_terminal, "We only can abstract terminal atoms!"

    # Deep-copy the edges so the new graph shares no edge objects with `graph`.
    new_edges = [
        copy.deepcopy(graph.edges[edge_id]) for edge_id in atom.edge_ids
    ]

    new_graph = HyperGraph()
    new_graph.add_edges(new_edges)

    atom_copy = Node(edge_ids=list(new_graph.edges.keys()),
                     edges=list(new_graph.edges.values()),
                     is_terminal=True,
                     data=atom.data)
    # Give the parent its own list objects: if Node keeps the lists it is
    # given, sharing atom_copy's lists would make an in-place edit of one
    # node's edge_ids silently change the other's as well.
    parent = Node(edge_ids=list(atom_copy.edge_ids),
                  edges=list(atom_copy.edges))
    new_graph.add_parent_node(parent)
    new_graph.add_node(atom_copy)
    new_graph.validate()

    # Replace the original atom with a nonterminal matching the parent of the
    # new graph.
    replacement_node = Node(
        edge_ids=atom.edge_ids,
        edges=[graph.edges[edge_id] for edge_id in atom.edge_ids])
    del graph.node[loc]
    graph.add_node(
        replacement_node)  # this makes sure it's appended at the end
    graph.validate()

    return new_graph
def abstract_ring_atom(graph: HyperGraph, loc):
    """Reduce a terminal atom to the connections it needs in its ring.

    Each neighbor of the atom at ``loc`` is classified as:

    * internal: the neighbor is not in ``graph.child_ids()`` (e.g. the
      parent), or it has more than one edge;
    * child: any other neighbor, i.e. a child node with a single edge.

    The atom is only abstracted when there are at least two internal
    neighbors and at least one child neighbor. In that case:

    * a new ``HyperGraph`` is built whose parent node holds deep copies of the
      internal edges and whose terminal child is a copy of the atom;
    * the atom's edges to child neighbors and the child neighbor nodes
      themselves are moved out of ``graph`` into the new graph;
    * in ``graph`` the atom is replaced by a nonterminal ``Node`` that keeps
      only the internal edges.

    NOTE(review): an older comment described the kept structure as
    "terminals, parent nonterminal, and nonterminals connecting to more than
    one of the former", but a terminal child with a single edge is classified
    as a child neighbor here. Confirm this is intended.

    Args:
        graph: hypergraph to modify in place.
        loc: key of the atom in ``graph.node``.

    Returns:
        ``(new_graph, child_inds)``, where ``child_inds`` are the positions in
        ``graph.child_ids()`` (taken before the child nodes are removed) of
        the moved child neighbors. Returns ``(None, None)`` when the atom is
        not abstracted; ``graph`` is then left unchanged.

    Raises:
        AssertionError: if the node at ``loc`` is not terminal.
    """
    graph.validate()
    atom = graph.node[loc]
    assert atom.is_terminal, "We only can abstract terminal atoms!"

    neighbors = [(edge_id, graph.other_end_of_edge(loc, edge_id))
                 for edge_id in atom.edge_ids]

    # Partition the (edge_id, neighbor_id) pairs in a single pass, evaluating
    # child_ids() once rather than once per neighbor.
    child_id_set = set(graph.child_ids())
    internal_neighbors = []
    child_neighbors = []
    for edge_id, neighbor_id in neighbors:
        if (neighbor_id not in child_id_set
                or len(graph.node[neighbor_id].edge_ids) > 1):
            internal_neighbors.append((edge_id, neighbor_id))
        else:
            child_neighbors.append((edge_id, neighbor_id))

    # Need at least two internal connections and at least one child to move.
    if not 2 <= len(internal_neighbors) < len(neighbors):
        return None, None

    internal_edge_ids = [edge_id for edge_id, _ in internal_neighbors]

    # Parent placeholder of the new graph: copies of the internal edges.
    new_edges = [
        copy.deepcopy(graph.edges[edge_id]) for edge_id in internal_edge_ids
    ]
    new_graph = HyperGraph()
    new_graph.add_edges(new_edges)
    parent = Node(edge_ids=list(new_graph.edges.keys()), edges=new_edges)
    # Pairs old internal edge ids with the new ids; assumes add_edges keeps
    # insertion order so the two sequences line up.
    old_new_edge_map = dict(zip(internal_edge_ids, parent.edge_ids))
    child_edge_ids = [
        old_new_edge_map.get(edge_id, edge_id) for edge_id in atom.edge_ids
    ]

    # Move the atom's remaining (child) edges from `graph` into `new_graph`,
    # keeping their ids. NOTE(review): a child edge id equal to one of the new
    # parent edge ids would be skipped here; confirm ids cannot collide.
    new_parent_edge_ids = set(old_new_edge_map.values())
    for edge_id in child_edge_ids:
        if edge_id not in new_parent_edge_ids:
            new_graph.edges[edge_id] = graph.edges.pop(edge_id)
    new_graph.add_parent_node(parent)
    child = Node(
        data=atom.data,
        edge_ids=child_edge_ids,
        edges=[new_graph.edges[edge_id] for edge_id in child_edge_ids],
        is_terminal=True)
    new_graph.add_node(child)

    # Record where the child neighbors sit in graph.child_ids() before they
    # are removed, so the caller can amend the tree.
    child_neighbor_nodes = [node_id for _, node_id in child_neighbors]
    child_inds = [
        ind for ind, child_id in enumerate(graph.child_ids())
        if child_id in child_neighbor_nodes
    ]

    # The new graph takes over the child neighbor nodes.
    for child_id in child_neighbor_nodes:
        new_graph.add_node(graph.node[child_id])
        del graph.node[child_id]
    new_graph.validate()

    # Replace the atom with a nonterminal holding just the internal edges.
    replacement_node = Node(
        edge_ids=internal_edge_ids,
        edges=[graph.edges[edge_id] for edge_id in internal_edge_ids])
    del graph.node[loc]
    graph.add_node(replacement_node)
    graph.validate()

    return new_graph, child_inds
def _remove_largest_clique(cliques, tree):
    """Collapse the largest clique of ``tree.node`` into a new child subtree.

    The largest clique is selected; on ties, the last one in ``cliques`` wins
    because ``sorted`` is stable. Then:

    * its nodes are deleted from the parent hypergraph, and the matching
      subtrees are popped from ``tree``;
    * a nonterminal carrying the clique's external edges (edges still
      attached to a remaining node) is added in their place, and edges
      internal to the clique are dropped from the parent;
    * the clique nodes are moved into a new ``HyperGraph`` whose parent node
      mirrors that nonterminal. That graph, together with the popped
      subtrees, is appended to ``tree`` as a new ``HypergraphTree``.

    Args:
        cliques: collection of cliques, each a container of node ids of
            ``tree.node``.
        tree: tree to modify in place. ``tree[i]`` is paired with the i-th
            entry of ``tree.node.child_ids()``.

    Raises:
        AssertionError: if the new child graph, without its parent node, does
            not yield exactly one entry from ``make_max_clique_graph``.
    """
    hypergraph = tree.node
    clique = sorted(cliques, key=len)[-1]

    # Gather the clique's nodes and their subtrees (tree[i] <-> i-th child).
    clique_nodes = []
    clique_children = []
    clique_idxs = []
    for i, node_id in enumerate(hypergraph.child_ids()):
        if node_id in clique:
            clique_nodes.append(hypergraph.node[node_id])
            clique_children.append(tree[i])
            clique_idxs.append(i)

    # Detach the clique from the parent; pop from the back so the remaining
    # indices stay valid.
    for node_id in clique:
        del hypergraph.node[node_id]
    for idx in sorted(clique_idxs, reverse=True):
        tree.pop(idx)

    # Every edge touching a clique node.
    clique_edges = set()
    for node in clique_nodes:
        clique_edges.update(node.edge_ids)

    # External edges: clique edges still referenced by a remaining node.
    external_edge_ids = []
    external_edges = []
    for node in hypergraph.node.values():
        for edge_id, edge in zip(node.edge_ids, node.edges):
            if edge_id in clique_edges:
                external_edge_ids.append(edge_id)
                external_edges.append(edge)

    new_nt = Node(edge_ids=external_edge_ids, edges=external_edges)
    hypergraph.add_node(new_nt)
    # Edges entirely inside the clique no longer belong to the parent graph.
    for edge_id in clique_edges - set(external_edge_ids):
        hypergraph.edges.pop(edge_id)
    hypergraph.validate()

    # New child graph; its parent node mirrors new_nt using deep copies, so
    # it shares no objects with the parent graph.
    new_nt_child = Node(edge_ids=copy.deepcopy(external_edge_ids),
                        edges=copy.deepcopy(external_edges))
    new_child = HyperGraph()
    new_child.add_parent_node(new_nt_child)

    # Register each clique edge in new_child exactly once, and rewrite the
    # clique nodes' edge ids (and the matching slots of the new parent) to
    # the ids new_child assigned.
    edge_id_map = {}
    for node in clique_nodes:
        new_child.add_node(node)
        for i, (edge_id, edge) in enumerate(zip(node.edge_ids, node.edges)):
            # Test membership, not truthiness: a falsy mapped id (e.g. 0)
            # must not cause the edge to be added a second time.
            if edge_id in edge_id_map:
                new_id = edge_id_map[edge_id]
            else:
                new_id = new_child.add_edge(edge)
                edge_id_map[edge_id] = new_id

            node.edge_ids[i] = new_id

            if edge_id in external_edge_ids:
                idx = external_edge_ids.index(edge_id)
                new_nt_child.edge_ids[idx] = new_id

    new_child.validate()
    # Sanity check: without its parent node, the child graph must yield a
    # single entry from make_max_clique_graph (presumably one clique).
    child_graph = new_child.to_nx()
    child_graph.remove_node(new_child.parent_node_id)
    assert len(make_max_clique_graph(child_graph)) == 1

    tree.append(HypergraphTree(new_child, clique_children))