def test_graph_embedder_on_complete_hypergraphs(self):
    """Embeddings must be exactly zero on padding rows.

    Each graph's rows at indices ``len(g)`` up to (but excluding) the length
    of the largest graph in the batch are treated as padding and must be 0.

    NOTE(review): assumes ``ge(...)`` returns a tensor indexed as
    (graph, node, feature) — TODO confirm against GraphEmbedder.
    """
    ge = GraphEmbedder(target_dim=512, grammar=gi.grammar)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = ge(mol_graphs)
    # Loop-invariant: the largest graph bounds the region that must be padding.
    max_len = max(len(gg) for gg in mol_graphs)
    for eg, g in zip(out, mol_graphs):
        for i in range(len(g), max_len):
            # embedded values should only be nonzero for actual nodes
            assert eg[i].abs().max() == 0, (
                f"padding row {i} of a {len(g)}-node graph is nonzero")
# Example #2
0
 def test_graph_encoder_determinism(self):
     """Two forward passes over the same batch must agree when dropout is 0."""
     encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
     graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
     first_pass = encoder(graphs)
     second_pass = encoder(graphs)
     # Largest elementwise deviation between the two runs.
     max_diff = (first_pass - second_pass).abs().max()
     assert max_diff < 1e-6, "Encoder should be deterministic with zero dropout!"
# Example #3
0
    def test_graph_encoder_batch_independence(self):
        """Encoding the first graph alone must match its slice of a 5-graph batch.

        Only the leading ``solo.size(1)`` positions along dim 1 are compared.
        NOTE(review): presumably dim 1 is the padded node axis, so the
        single-graph output can be shorter than the batched one — TODO confirm.
        """
        encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        batched = encoder(graphs)
        solo = encoder(graphs[:1])

        overlap = solo.size(1)
        drift = (batched[:1, :overlap] - solo).abs().max()
        assert drift < 1e-5, "Encoder should have no crosstalk between batches"
 def test_graph_encoder_with_head(self):
     """Run GraphEncoder wrapped in a two-output MultipleOutputHead.

     NOTE(review): the forward result is never asserted on, so this only
     checks that construction and one forward pass do not raise.
     ``tmp_file``, ``max_seq_length`` and ``device`` come from outside this
     method.
     """
     codec = get_codec(molecules=True,
                       grammar='hypergraph:' + tmp_file,
                       max_seq_length=max_seq_length)
     encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
     graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
     heads = {
         'node': 1,  # to be used to select next node to expand
         'action': codec.feature_len(),  # to select the action for chosen node
     }
     model = MultipleOutputHead(model=encoder, output_spec=heads, drop_rate=0.1)
     model = model.to(device)
     out = model(graphs)
# Example #5
0
    def test_with_first_sequence_element_head(self):
        """FirstSequenceElementHead must yield a 2-D (num_graphs, d_model) output."""
        d_model = 512
        encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)

        head = FirstSequenceElementHead(encoder)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        pooled = head(graphs)

        # One row per input graph, d_model features per row, nothing else.
        assert pooled.shape[0] == len(graphs)
        assert pooled.shape[1] == d_model
        assert pooled.dim() == 2
# Example #6
0
    def test_with_multihead_attenion_aggregating_head(self):
        """MultiheadAttentionAggregatingHead must yield a 2-D (num_graphs, d_model) output."""
        d_model = 512
        encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)

        head = MultiheadAttentionAggregatingHead(encoder)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        pooled = head(graphs)

        # One row per input graph, d_model features per row, nothing else.
        assert pooled.shape[0] == len(graphs)
        assert pooled.shape[1] == d_model
        assert pooled.dim() == 2
# Example #7
0
    def test_full_discriminator_parts_tuple_head(self):
        """Check the shape and batch independence of an encoder + pooling + 2-output head stack.

        The first element of the head's output is expected to be
        (num_graphs, 2). The row for graph 0 must not change when the rest of
        the batch is dropped.
        """
        encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)

        pooled_encoder = FirstSequenceElementHead(encoder)
        discriminator = MultipleOutputHead(pooled_encoder, [2], drop_rate=0)
        discriminator = discriminator.to(device)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]

        full_out = discriminator(graphs)[0]
        single_out = discriminator(graphs[:1])[0]

        assert full_out.shape[0] == len(graphs)
        assert full_out.shape[1] == 2
        assert full_out.dim() == 2
        gap = torch.max((full_out[0] - single_out[0]).abs())
        assert gap < 1e-5
# Example #8
0
    def test_encoder_batch_independence(self):
        """Pooled encoding of graph 0 must not depend on the other graphs in the batch.

        Also checks that the pooled output is 2-D with shape (num_graphs, d_model).
        """
        d_model = 512
        encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)

        pooled_encoder = FirstSequenceElementHead(encoder)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        batch_out = pooled_encoder(graphs)
        single_out = pooled_encoder(graphs[:1])

        assert batch_out.shape[0] == len(graphs)
        assert batch_out.shape[1] == d_model
        assert batch_out.dim() == 2
        gap = torch.max((batch_out[0] - single_out[0]).abs())
        assert gap < 1e-5
# Example #9
0
 def test_hypergraph_via_nx_graph_roundtrip(self):
     """Converting Mol -> HyperGraph -> networkx graph -> Mol must preserve the molecule.

     Molecules are compared by RDKit canonical SMILES. The expected value is
     the canonicalised input, not the raw ``smiles1`` literal, so the test does
     not depend on ``smiles1`` already being written in canonical form.
     """
     mol = MolFromSmiles(smiles1)
     # MolFromSmiles returns None (rather than raising) on unparsable input.
     assert mol is not None, "smiles1 could not be parsed by RDKit"
     # Take the reference before any conversion touches the molecule.
     expected = MolToSmiles(mol)
     hg = HyperGraph.from_mol(mol)
     re_mol = to_mol(hg.to_nx())
     re_smiles = MolToSmiles(re_mol)
     assert re_smiles == expected, (
         f"round-trip changed the molecule: {expected} -> {re_smiles}")