def test_graph_embedder_on_complete_hypergraphs(self):
    """Embedding rows past a graph's own node count must be exactly zero."""
    embedder = GraphEmbedder(target_dim=512, grammar=gi.grammar)
    graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    embedded = embedder(graphs)
    longest = max(len(g) for g in graphs)
    for emb, graph in zip(embedded, graphs):
        # embedded values should only be nonzero for actual nodes;
        # every padding position beyond this graph's real nodes stays zero
        for pos in range(len(graph), longest):
            assert emb[pos].abs().max() == 0
def test_graph_encoder_determinism(self):
    """With dropout disabled, two forward passes over the same input must agree."""
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    first_pass = encoder(graphs)
    second_pass = encoder(graphs)
    # Any divergence beyond float noise means hidden nondeterminism in the encoder.
    assert (first_pass - second_pass).abs().max() < 1e-6, \
        "Encoder should be deterministic with zero dropout!"
def test_graph_encoder_batch_independence(self):
    """Encoding one graph alone must match its slice of the full-batch output."""
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    batched = encoder(graphs)
    solo = encoder(graphs[:1])
    # Compare only up to the solo output's sequence length — the full batch
    # may be padded to the longest graph in it.
    diff = (batched[:1, :solo.size(1)] - solo).abs().max()
    assert diff < 1e-5, "Encoder should have no crosstalk between batches"
def test_graph_encoder_with_head(self):
    """Smoke test: GraphEncoder wrapped in a MultipleOutputHead runs forward.

    NOTE(review): this test makes no assertions on the output — it only
    verifies the forward pass completes without raising.
    """
    codec = get_codec(molecules=True,
                      grammar='hypergraph:' + tmp_file,
                      max_seq_length=max_seq_length)
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    head = MultipleOutputHead(
        model=encoder,
        output_spec={'node': 1,  # to be used to select next node to expand
                     'action': codec.feature_len()},  # to select the action for chosen node
        drop_rate=0.1).to(device)
    out = head(graphs)
def test_with_first_sequence_element_head(self):
    """FirstSequenceElementHead must aggregate to a (batch, d_model) tensor."""
    d_model = 512
    encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)
    aggregated = FirstSequenceElementHead(encoder)
    graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = aggregated(graphs)
    # One row per input graph, one d_model-wide vector each, nothing else.
    assert out.size(0) == len(graphs)
    assert out.size(1) == d_model
    assert len(out.size()) == 2
def test_with_multihead_attenion_aggregating_head(self):
    """MultiheadAttentionAggregatingHead must aggregate to (batch, d_model)."""
    d_model = 512
    encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)
    aggregated = MultiheadAttentionAggregatingHead(encoder)
    graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = aggregated(graphs)
    # One row per input graph, one d_model-wide vector each, nothing else.
    assert out.size(0) == len(graphs)
    assert out.size(1) == d_model
    assert len(out.size()) == 2
def test_full_discriminator_parts_tuple_head(self):
    """A 2-way discriminator head yields (batch, 2) logits, batch-independent."""
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    aggregated = FirstSequenceElementHead(encoder)
    discriminator = MultipleOutputHead(aggregated, [2], drop_rate=0).to(device)
    graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    batched = discriminator(graphs)[0]
    solo = discriminator(graphs[:1])[0]
    # Shape: one 2-logit row per input graph.
    assert batched.size(0) == len(graphs)
    assert batched.size(1) == 2
    assert len(batched.size()) == 2
    # First graph's logits must not depend on the rest of the batch.
    assert torch.max((batched[0, :] - solo[0, :]).abs()) < 1e-5
def test_encoder_batch_independence(self):
    """Aggregated encoding of one graph must match its row in the full batch."""
    d_model = 512
    encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)
    aggregated = FirstSequenceElementHead(encoder)
    graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    batched = aggregated(graphs)
    solo = aggregated(graphs[:1])
    # Shape: one d_model vector per input graph.
    assert batched.size(0) == len(graphs)
    assert batched.size(1) == d_model
    assert len(batched.size()) == 2
    # First graph's vector must not depend on the rest of the batch.
    assert torch.max((batched[0] - solo[0]).abs()) < 1e-5