def test_graph_embedder_on_complete_hypergraphs(self):
    """Padding positions in GraphEmbedder output must be exactly zero.

    Embeds a batch of 5 ZINC molecule hypergraphs. For each graph ``g``,
    positions ``len(g)`` up to (but excluding) the size of the largest graph
    in the batch do not correspond to real nodes, so their embeddings must be
    all zeros.
    """
    ge = GraphEmbedder(target_dim=512, grammar=gi.grammar)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = ge(mol_graphs)
    # zip() below silently stops at the shorter input, so make sure there is
    # one embedding per graph; otherwise some graphs would go unchecked
    assert len(out) == len(mol_graphs)
    # loop-invariant: compute the largest graph size once, not once per graph
    max_len = max(len(g) for g in mol_graphs)
    for eg, g in zip(out, mol_graphs):
        for i in range(len(g), max_len):
            # embedded values should only be nonzero for actual nodes
            assert eg[i].abs().max() == 0
예제 #2
0
 def test_graph_encoder_determinism(self):
     """Two forward passes of a dropout-free GraphEncoder must agree.

     With ``drop_rate=0.0`` the encoder is expected to be a pure function of
     its input, so encoding the same batch twice must give outputs whose
     largest element-wise absolute difference is below 1e-6.
     """
     encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
     graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
     first_pass = encoder(graphs)
     second_pass = encoder(graphs)
     largest_gap = (first_pass - second_pass).abs().max()
     assert largest_gap < 1e-6, "Encoder should be deterministic with zero dropout!"
예제 #3
0
    def test_graph_encoder_batch_independence(self):
        """Encoding a graph alone must match encoding it inside a batch.

        The first graph's output from a 5-graph batch is truncated along
        dimension 1 (presumably the node/sequence axis) to the length produced
        when that graph is encoded on its own. The two must then agree to
        within 1e-5.
        """
        encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        batched = encoder(graphs)
        solo = encoder(graphs[:1])

        solo_len = solo.size(1)
        gap = (batched[:1, :solo_len] - solo).abs().max()
        assert gap < 1e-5, "Encoder should have no crosstalk between batches"
 def test_graph_encoder_with_head(self):
     """Smoke test: a GraphEncoder wrapped in a MultipleOutputHead runs.

     Builds a codec from the module-level ``tmp_file`` grammar path and
     ``max_seq_length``, then attaches a two-output head: 'node' with 1
     output and 'action' with ``codec.feature_len()`` outputs. Finally it runs
     one forward pass on 5 ZINC molecule hypergraphs.

     NOTE(review): nothing is asserted about the result; the test only checks
     that the forward pass does not raise. TODO add output-shape checks.
     """
     codec = get_codec(molecules=True,
                       grammar='hypergraph:' + tmp_file,
                       max_seq_length=max_seq_length)
     encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
     graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
     heads = {
         'node': 1,  # to be used to select next node to expand
         'action': codec.feature_len(),  # to select the action for chosen node
     }
     model = MultipleOutputHead(model=encoder,
                                output_spec=heads,
                                drop_rate=0.1).to(device)
     model(graphs)
예제 #5
0
    def test_with_first_sequence_element_head(self):
        """FirstSequenceElementHead output is 2-D: (number of graphs, d_model)."""
        d_model = 512
        head = FirstSequenceElementHead(
            GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0))
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]

        result = head(graphs)
        batch_size = len(graphs)
        assert result.size(0) == batch_size
        assert result.size(1) == d_model
        assert len(result.size()) == 2
예제 #6
0
    def test_with_multihead_attenion_aggregating_head(self):
        """MultiheadAttentionAggregatingHead output is 2-D: (number of graphs, d_model)."""
        d_model = 512
        head = MultiheadAttentionAggregatingHead(
            GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0))
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]

        result = head(graphs)
        batch_size = len(graphs)
        assert result.size(0) == batch_size
        assert result.size(1) == d_model
        assert len(result.size()) == 2
예제 #7
0
    def test_full_discriminator_parts_tuple_head(self):
        """Encoder + first-element head + 2-way output head behaves as a discriminator.

        The first output of the discriminator must be (number of graphs, 2).
        The first graph's row must not depend on the other graphs in the batch
        (checked to within 1e-5).
        """
        encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
        discriminator = MultipleOutputHead(FirstSequenceElementHead(encoder),
                                           [2],
                                           drop_rate=0).to(device)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]

        batch_scores = discriminator(graphs)[0]
        solo_scores = discriminator(graphs[:1])[0]

        assert batch_scores.size(0) == len(graphs)
        assert batch_scores.size(1) == 2
        assert len(batch_scores.size()) == 2
        first_row_gap = (batch_scores[0, :] - solo_scores[0, :]).abs()
        assert torch.max(first_row_gap) < 1e-5
    def forward(self, x):
        """Run ``self.discriminator`` on molecules given as SMILES strings.

        Args:
            x: either a list/tuple of SMILES strings, or a dict (including
                OrderedDict and other dict subclasses) holding them under the
                ``'smiles'`` key.

        Returns:
            The discriminator's output mapping, augmented with the input:
            - for a list/tuple input, the SMILES are stored under ``'smiles'``;
            - for a dict input, all of ``x``'s entries are merged in, so they
              overwrite any keys the discriminator output shares with ``x``.

        Raises:
            ValueError: if ``x`` is neither a list/tuple nor a dict.
        """
        # isinstance (rather than an exact type() match) also accepts
        # subclasses such as defaultdict; OrderedDict is itself a dict subclass.
        is_sequence = isinstance(x, (list, tuple))
        if is_sequence:
            smiles = x
        elif isinstance(x, dict):
            smiles = x['smiles']
        else:
            raise ValueError("Unknown input type: " + str(x))

        mol_graphs = [HyperGraph.from_smiles(s) for s in smiles]
        out = self.discriminator(mol_graphs)
        if is_sequence:
            out['smiles'] = smiles
        else:
            out.update(x)

        return out
예제 #9
0
    def test_encoder_batch_independence(self):
        """Aggregated encoder output is (number of graphs, d_model) and batch-independent.

        The first graph's aggregated vector must be the same (to within 1e-5)
        whether it is encoded alone or inside a batch of 5.
        """
        d_model = 512
        head = FirstSequenceElementHead(
            GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0))
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]

        batched = head(graphs)
        solo = head(graphs[:1])

        assert batched.size(0) == len(graphs)
        assert batched.size(1) == d_model
        assert len(batched.size()) == 2
        gap = (batched[0] - solo[0]).abs()
        assert torch.max(gap) < 1e-5
예제 #10
0
    def rule_to_index(self, rule: HyperGraph, no_new_rules=False):
        """Look up ``rule`` among the stored rules, registering it if it is new.

        Only stored rules filed under the same parent node are compared, using
        ``hypergraphs_are_equivalent``. The parent node is keyed by
        ``str(rule.parent_node())``. Each call also increments
        ``self.candidate_counter``.

        Args:
            rule: candidate rule hypergraph.
            no_new_rules: if True, an unmatched rule raises instead of being added.

        Returns:
            ``(rule_id, mapping)``:
            - on a match, the stored rule's id and the (non-None) mapping
              returned by ``hypergraphs_are_equivalent``;
            - for a newly added rule, ``len(self.rules) - 1`` and an identity
              mapping over the keys of ``rule.node``.

        Raises:
            ValueError: if nothing matches and ``no_new_rules`` is True.
        """
        self.candidate_counter += 1
        parent_key = str(rule.parent_node())
        # note: the (possibly empty) bucket is created even if we later raise
        same_parent_ids = self.id_by_parent.setdefault(parent_key, [])

        for candidate_id in same_parent_ids:
            node_map = hypergraphs_are_equivalent(self.rules[candidate_id], rule)
            if node_map is not None:
                return candidate_id, node_map

        if no_new_rules:
            raise ValueError("Unknown rule hypergraph " + str(rule))

        # presumably add_rule appends to self.rules (the returned id relies on
        # that) — TODO confirm
        self.add_rule(rule)
        identity = {node_id: node_id for node_id in rule.node.keys()}
        return (len(self.rules) - 1), identity
예제 #11
0
 def test_hypergraph_via_nx_graph_roundtrip(self):
     """SMILES -> mol -> HyperGraph -> nx graph -> mol -> SMILES reproduces ``smiles1``."""
     hypergraph = HyperGraph.from_mol(MolFromSmiles(smiles1))
     roundtrip_smiles = MolToSmiles(to_mol(hypergraph.to_nx()))
     assert roundtrip_smiles == smiles1
    def junction_tree_stage(my_start_atoms, parent_bond_inds=None):
        """Build one junction-tree node for ``my_start_atoms`` and recurse into its children.

        The node claims the start atoms plus every still-unclaimed atom that
        shares a ring (from ``atom_rings``) with one of the start atoms. The
        claimed atoms are removed from ``atoms_left``. Then:
        - every bond touching a claimed atom is collected via ``get_neighbors``;
        - the bonds are split into internal, parent and child bonds;
        - child bonds lying on a common ring (from ``bond_rings``) are grouped;
        - the function recurses once per group.

        Relies on the enclosing scope for ``mol``, ``atom_rings``,
        ``bond_rings``, ``get_neighbors`` and ``atoms_left``. ``atoms_left`` is a
        set-like collection of unclaimed atom indices and is mutated here.

        Args:
            my_start_atoms: list of atom indices this node starts from.
            parent_bond_inds: bond indices connecting this node to its parent.
                ``None`` (the default) means no parent bonds.

        Returns:
            A dict with keys ``'atoms'``, ``'bonds'``, ``'parent_bonds'``,
            ``'node'`` (built by ``HyperGraph.from_tree_node``). If any child
            groups exist it also has ``'children'``: an OrderedDict mapping a
            tuple of child bond indices to the child's dict.
        """
        # TODO: assert that start atoms are in parent bonds, if any
        if parent_bond_inds is None:
            # avoid a shared mutable default argument
            parent_bond_inds = []

        # am I part of any rings? Only the start atoms are checked against each
        # ring, so rings reachable only through newly added atoms are not followed.
        my_atoms = copy.copy(my_start_atoms)
        for ring in atom_rings:
            for atom_idx in my_start_atoms:
                if atom_idx in ring:
                    for other_idx in ring:
                        # don't insert ring atoms that were already parsed
                        if other_idx not in my_atoms and other_idx in atoms_left:
                            my_atoms.append(other_idx)

        # this is a check that we never assign the same atom to two nodes
        for idx in my_atoms:
            assert idx in atoms_left
            atoms_left.discard(idx)

        # determine all my bonds; presumably get_neighbors yields
        # bond_idx -> (my_atom, other_atom) pairs — TODO confirm
        my_bonds = OrderedDict()
        for atom_idx in my_atoms:
            my_bonds.update(get_neighbors(mol, atom_idx))

        # classify bonds: internal (both ends mine), parent (named by the caller),
        # child (everything else)
        internal_bonds = OrderedDict([
            (key, value) for key, value in my_bonds.items()
            if value[0] in my_atoms and value[1] in my_atoms
        ])
        parent_bonds = OrderedDict([(key, value)
                                    for key, value in my_bonds.items()
                                    if key in parent_bond_inds])
        child_bonds = OrderedDict([
            (key, value) for key, value in my_bonds.items()
            if key not in internal_bonds and key not in parent_bonds
        ])

        me = {
            'atoms': my_atoms,
            'bonds': my_bonds,
            'parent_bonds': parent_bonds
        }

        child_bond_groups = []
        processed = set()  # set: O(1) membership tests
        # now go over the neighbor atoms checking if they're part of another ring
        for child_bond in child_bonds:
            if child_bond in processed:  # already processed
                continue

            processed.add(child_bond)
            this_group = [child_bond]
            for bond_ring in bond_rings:
                if child_bond in bond_ring:
                    # what other bonds are part of the same ring?
                    # NOTE(review): this does not check `processed`, so a bond
                    # already placed in an earlier group may be added here again
                    # if it shares a different ring with child_bond — confirm
                    # that cannot happen.
                    for other_bond_idx in child_bonds:
                        if other_bond_idx in bond_ring and other_bond_idx not in this_group:
                            this_group.append(other_bond_idx)
                            processed.add(other_bond_idx)

            child_bond_groups.append(this_group)

        # and now the recursive call: each group's bonds become the child's
        # parent bonds, and their far-end atoms become its start atoms
        children = OrderedDict()
        for group in child_bond_groups:
            next_start_atoms = [child_bonds[g][1] for g in group]
            children[tuple(group)] = junction_tree_stage(next_start_atoms, group)

        if len(children):
            me['children'] = children

        me['node'] = HyperGraph.from_tree_node(mol, me)

        return me
def abstract_atom(graph: HyperGraph, loc):
    """Move the terminal atom at ``loc`` into a new single-atom hypergraph.

    ``graph`` is modified in place: the atom is deleted and a nonterminal
    attached to exactly the same edges is appended in its place. The returned
    hypergraph contains:
    - deep copies of all of the atom's edges;
    - a parent nonterminal attached to all of those copies;
    - a terminal copy of the atom (same ``data``).
    Both graphs are validated before returning.

    Args:
        graph: hypergraph to modify in place.
        loc: key of the atom in ``graph.node``; must be a terminal node.

    Returns:
        The new hypergraph holding the abstracted atom.

    Raises:
        AssertionError: if the node at ``loc`` is not terminal.
    """
    target = graph.node[loc]
    # every connection of the atom is kept (not trimmed to ring bonds)
    assert target.is_terminal, "We only can abstract terminal atoms!"

    # the new graph works on private deep copies of the atom's edges
    edge_copies = [copy.deepcopy(graph.edges[e]) for e in target.edge_ids]
    rule_graph = HyperGraph()
    rule_graph.add_edges(edge_copies)

    terminal = Node(edge_ids=list(rule_graph.edges.keys()),
                    edges=list(rule_graph.edges.values()),
                    is_terminal=True,
                    data=target.data)
    # NOTE(review): the parent is given the *same* edge_ids/edges list objects
    # as `terminal`; if Node does not copy them, an in-place edit of one
    # node's lists would also change the other — confirm this is intended.
    rule_graph.add_parent_node(Node(edge_ids=terminal.edge_ids,
                                    edges=terminal.edges))
    rule_graph.add_node(terminal)
    rule_graph.validate()

    # swap the atom in the original graph for a nonterminal on the same edges
    stand_in = Node(edge_ids=target.edge_ids,
                    edges=[graph.edges[e] for e in target.edge_ids])
    del graph.node[loc]
    graph.add_node(stand_in)  # this makes sure it's appended at the end
    graph.validate()

    return rule_graph
def abstract_ring_atom(graph: HyperGraph, loc):
    """Split a terminal ring atom and its single-edge child neighbours into a new rule graph.

    The atom's neighbours are split into two kinds:
    - internal: nodes not in ``graph.child_ids()``, or nodes with more than
      one edge;
    - child: the remaining neighbours, i.e. nodes in ``child_ids()`` with a
      single edge.
    The split is only performed when there are at least two internal
    neighbours and at least one child neighbour. Otherwise ``graph`` is not
    modified and ``(None, None)`` is returned.

    When the split happens:
    - A new hypergraph is built. It contains deep copies of the internal
      edges, a parent nonterminal on those copies, and a terminal copy of the
      atom (same ``data``). The atom's remaining edges and the child-neighbour
      nodes are moved (not copied) into it from ``graph``.
    - In ``graph``, the atom is replaced by a nonterminal attached only to the
      internal edges, and the child-neighbour nodes are removed.

    Args:
        graph: hypergraph to modify in place; validated before and after.
        loc: key of the atom in ``graph.node``; must be a terminal node.

    Returns:
        ``(new_graph, child_inds)``. ``child_inds`` are the positions, in
        ``graph.child_ids()`` before the child nodes are removed, of the child
        nodes that moved to ``new_graph``. Returns ``(None, None)`` when no
        split is made.

    Raises:
        AssertionError: if the node at ``loc`` is not terminal.
    """
    graph.validate()
    atom = graph.node[loc]
    # replace ring atoms with nonterminals that just have the connections necessary for the ring
    assert atom.is_terminal, "We only can abstract terminal atoms!"
    neighbors = [(edge_id, graph.other_end_of_edge(loc, edge_id))
                 for edge_id in atom.edge_ids]
    # determine the nodes in the 'key structure', which is terminals, parent
    # nonterminal, and nonterminals connecting to more than one of the former.
    # child_ids() does not change while classifying, so compute it only once.
    child_id_set = set(graph.child_ids())
    internal_neighbor_ids = [
        x for x in neighbors
        if x[1] not in child_id_set or len(graph.node[x[1]].edge_ids) > 1
    ]
    child_neighbors = [x for x in neighbors if x not in internal_neighbor_ids]
    if not 2 <= len(internal_neighbor_ids) < len(neighbors):
        return None, None

    # create the edges between the parent placeholder and the atom being abstracted
    new_edges = [
        copy.deepcopy(graph.edges[edge_id[0]])
        for edge_id in internal_neighbor_ids
    ]
    new_graph = HyperGraph()
    new_graph.add_edges(new_edges)
    parent = Node(edge_ids=list(new_graph.edges.keys()), edges=new_edges)
    old_new_edge_map = {
        old[0]: new
        for old, new in zip(internal_neighbor_ids, parent.edge_ids)
    }
    # internal edges are renamed to their copies; all other edges keep their id
    child_edge_ids = [
        old_new_edge_map.get(edge_id, edge_id) for edge_id in atom.edge_ids
    ]

    # and move over all the edges from the abstracted atom to its children
    for edge_id in child_edge_ids:
        if edge_id not in old_new_edge_map.values():
            new_graph.edges[edge_id] = graph.edges[edge_id]
            del graph.edges[edge_id]
    new_graph.add_parent_node(parent)
    child = Node(
        data=atom.data,
        edge_ids=child_edge_ids,
        edges=[new_graph.edges[edge_id] for edge_id in child_edge_ids],
        is_terminal=True)

    new_graph.add_node(child)

    # new hypergraph takes over all the children
    child_neighbor_nodes = [x[1] for x in child_neighbors]
    # determine which of the children we want to take over, we need the indices to amend the tree later
    child_inds = [
        ind for ind, child_id in enumerate(graph.child_ids())
        if child_id in child_neighbor_nodes
    ]

    # and now purge these nodes from the original graph and append them to the child, along with any edges
    for child_id in child_neighbor_nodes:
        new_graph.add_node(graph.node[child_id])
        del graph.node[child_id]
    new_graph.validate()

    # replace the atom with a nonterminal with just the connections for the ring
    replace_edge_ids = [edge_id[0] for edge_id in internal_neighbor_ids]
    replacement_node = Node(
        edge_ids=replace_edge_ids,
        edges=[graph.edges[edge_id] for edge_id in replace_edge_ids])
    del graph.node[loc]
    graph.add_node(replacement_node)
    graph.validate()

    return new_graph, child_inds
def _remove_largest_clique(cliques, tree):
    """Collapse the largest clique of ``tree.node`` into a new child subtree.

    Steps:
    - Pick the largest clique in ``cliques``; ``sorted`` is stable, so the last
      of several equally large cliques wins.
    - Delete its nodes from ``tree.node``, and pop the subtrees of ``tree``
      whose positions match those nodes in ``tree.node.child_ids()``.
    - Add a single nonterminal attached to the clique's external edges. These
      are the clique-node edges also used by a remaining node. Edges used only
      inside the clique are dropped from ``tree.node``.
    - Append a new ``HypergraphTree`` whose hypergraph holds the clique nodes,
      their edges, and a parent nonterminal mirroring those external edges.
      The removed subtrees become its children.

    Args:
        cliques: iterable of collections of node ids of ``tree.node``
            (presumably child node ids — TODO confirm).
        tree: subtree container whose ``.node`` is a HyperGraph. It is
            indexed, popped and appended as a list of child subtrees aligned
            with ``tree.node.child_ids()``. Mutated in place.
    """
    hypergraph = tree.node
    clique = sorted(cliques, key=len)[-1]
    clique_nodes = []
    clique_children = []
    clique_idxs = []
    for i, node_id in enumerate(hypergraph.child_ids()):
        if node_id in clique:
            clique_nodes.append(hypergraph.node[node_id])
            clique_children.append(tree[i])
            clique_idxs.append(i)

    # Clean parent: drop the clique nodes and their subtrees, popping from
    # the back so the remaining indices stay valid
    for node_id in clique:
        del hypergraph.node[node_id]
    for idx in sorted(clique_idxs, reverse=True):
        tree.pop(idx)

    # Add new non-terminal
    # Find external edges from clique
    clique_edges = set()
    for node in clique_nodes:
        clique_edges.update(node.edge_ids)

    external_edge_ids = []
    external_edges = []
    for node in hypergraph.node.values():
        for edge_id, edge in zip(node.edge_ids, node.edges):
            if edge_id in clique_edges:
                external_edge_ids.append(edge_id)
                external_edges.append(edge)

    new_nt = Node(edge_ids=external_edge_ids, edges=external_edges)
    hypergraph.add_node(new_nt)
    remaining_clique_edges = clique_edges - set(external_edge_ids)
    for edge_id in remaining_clique_edges:
        hypergraph.edges.pop(edge_id)
    hypergraph.validate()

    # Make new child from clique; the parent node gets its own copies so that
    # remapping its ids below does not touch new_nt in the parent graph
    external_edges_ids_cpy = copy.deepcopy(external_edge_ids)
    external_edges_cpy = copy.deepcopy(external_edges)
    new_nt_child = Node(edge_ids=external_edges_ids_cpy,
                        edges=external_edges_cpy)
    new_child = HyperGraph()
    new_child.add_parent_node(new_nt_child)
    edge_id_map = {}
    for node in clique_nodes:
        new_child.add_node(node)
        for i, (edge_id, edge) in enumerate(zip(node.edge_ids, node.edges)):
            new_id = edge_id_map.get(edge_id)
            # `is None` rather than falsiness: a valid id such as 0 must not
            # cause the same edge to be added to new_child a second time
            if new_id is None:
                new_id = new_child.add_edge(edge)
                edge_id_map[edge_id] = new_id

            node.edge_ids[i] = new_id

            if edge_id in external_edge_ids:
                idx = external_edge_ids.index(edge_id)
                new_nt_child.edge_ids[idx] = new_id

    new_child.validate()
    child_graph = new_child.to_nx()
    child_graph.remove_node(new_child.parent_node_id)
    assert len(make_max_clique_graph(child_graph)) == 1

    tree.append(HypergraphTree(new_child, clique_children))