def get_graph_model(codec, drop_rate, model_type, output_type='values', num_bins=51):
    if 'conditional' in model_type:
        if 'sparse' in model_type:
            model = CondtionalProbabilityModelSparse(codec.grammar)
        else:
            model = ConditionalModelBlended(codec.grammar)
        model.init_encoder_output = lambda x: None
        return model
    # for all other models, start with a GraphEncoder and attach the right head
    encoder = GraphEncoder(grammar=codec.grammar,
                           d_model=512,
                           drop_rate=drop_rate,
                           model_type=model_type)
    if output_type == 'values':
        model = MultipleOutputHead(model=encoder,
                                   output_spec={'node': 1,  # to be used to select the next node to expand
                                                'action': codec.feature_len()},  # to select the action for the chosen node
                                   drop_rate=drop_rate)
        model = OneValuePerNodeRuleTransform(model)
    elif 'distributions' in output_type:
        model = MultipleOutputHead(model=encoder,
                                   output_spec={'node': 1,  # to be used to select the next node to expand
                                                'action': codec.feature_len() * num_bins},  # a distribution over num_bins per action
                                   drop_rate=drop_rate)
        if 'thompson' in output_type:
            model = DistributionPerNodeRuleTransformThompson(model, num_bins=num_bins)
        elif 'softmax' in output_type:
            model = DistributionPerNodeRuleTransformSoftmax(model, num_bins=num_bins, T=10)
    else:
        # fail loudly instead of hitting a NameError on the return below
        raise ValueError("Unknown output_type: " + output_type)
    model.init_encoder_output = lambda x: None
    return model
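# A minimal usage sketch for get_graph_model (not from the source): it reuses
# only names that appear elsewhere in this file (get_codec, tmp_file,
# max_seq_length, HyperGraph, get_zinc_molecules). The model_type and
# output_type strings here are assumptions inferred from the branches above.
def _example_get_graph_model_usage():
    codec = get_codec(molecules=True,
                      grammar='hypergraph:' + tmp_file,
                      max_seq_length=max_seq_length)
    model = get_graph_model(codec,
                            drop_rate=0.0,
                            model_type='transformer',  # assumption: any string without 'conditional'
                            output_type='distributions_thompson')  # exercises the Thompson head
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    return model(mol_graphs)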
def test_graph_encoder_determinism(self):
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = encoder(mol_graphs)
    out2 = encoder(mol_graphs)
    assert (out - out2).abs().max() < 1e-6, \
        "Encoder should be deterministic with zero dropout!"
def __init__(self, grammar, drop_rate=0.0, d_model=512):
    super().__init__()
    encoder = GraphEncoder(grammar=grammar, d_model=d_model, drop_rate=drop_rate)
    encoder_aggregated = FirstSequenceElementHead(encoder)
    self.discriminator = MultipleOutputHead(encoder_aggregated,
                                            {'p_zinc': 2},  # two logits: ZINC-like vs not
                                            drop_rate=drop_rate).to(device)
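# Hedged sketch (assumption, not from the source): how the discriminator built
# above might be invoked. `ZincDiscriminator` is a hypothetical name for the
# enclosing class; only the `discriminator` attribute is taken from the code.
def _example_discriminator_usage():
    disc = ZincDiscriminator(grammar=gi.grammar, drop_rate=0.0)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    # with a dict output_spec, the head presumably returns one tensor per key
    return disc.discriminator(mol_graphs)['p_zinc']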
def test_graph_encoder_batch_independence(self):
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = encoder(mol_graphs)
    out2 = encoder(mol_graphs[:1])
    assert (out[:1, :out2.size(1)] - out2).abs().max() < 1e-5, \
        "Encoder should have no crosstalk between batch elements"
def __init__(self, grammar, output_spec, drop_rate=0.0, d_model=512):
    super().__init__()
    encoder = GraphEncoder(grammar=grammar, d_model=d_model, drop_rate=drop_rate)
    self.model = MultipleOutputHead(encoder, output_spec, drop_rate=drop_rate).to(device)
    # we don't support using this model in VAE-style models yet
    self.init_encoder_output = lambda x: None
    self.output_shape = self.model.output_shape
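# Hedged sketch: instantiating the generic wrapper above with the same
# output_spec used in get_graph_model. `GraphModelWithHead` is a hypothetical
# name for the enclosing class.
def _example_output_spec_usage(codec):
    model = GraphModelWithHead(grammar=codec.grammar,
                               output_spec={'node': 1,
                                            'action': codec.feature_len()},
                               drop_rate=0.1)
    print(model.output_shape)  # per-output shapes exposed by MultipleOutputHead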
def test_graph_encoder_with_head(self):
    codec = get_codec(molecules=True,
                      grammar='hypergraph:' + tmp_file,
                      max_seq_length=max_seq_length)
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    model = MultipleOutputHead(model=encoder,
                               output_spec={'node': 1,  # to be used to select the next node to expand
                                            'action': codec.feature_len()},  # to select the action for the chosen node
                               drop_rate=0.1).to(device)
    # smoke test: just checks that the forward pass runs
    out = model(mol_graphs)
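# Hedged follow-up (assumption): with a dict output_spec, MultipleOutputHead
# appears to return one tensor per key (compare the list-spec indexing in
# test_full_discriminator_parts_tuple_head below), so the smoke test could be
# tightened along these lines:
#
#     assert set(out.keys()) == {'node', 'action'}
#     assert out['action'].size(-1) == codec.feature_len()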
def test_with_first_sequence_element_head(self):
    d_model = 512
    encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)
    encoder_aggregated = FirstSequenceElementHead(encoder)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = encoder_aggregated(mol_graphs)
    assert out.size(0) == len(mol_graphs)
    assert out.size(1) == d_model
    assert len(out.size()) == 2
def test_with_multihead_attention_aggregating_head(self):
    d_model = 512
    encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)
    encoder_aggregated = MultiheadAttentionAggregatingHead(encoder)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = encoder_aggregated(mol_graphs)
    assert out.size(0) == len(mol_graphs)
    assert out.size(1) == d_model
    assert len(out.size()) == 2
def test_full_discriminator_parts_tuple_head(self):
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
    encoder_aggregated = FirstSequenceElementHead(encoder)
    discriminator = MultipleOutputHead(encoder_aggregated, [2], drop_rate=0).to(device)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = discriminator(mol_graphs)[0]
    out2 = discriminator(mol_graphs[:1])[0]
    assert out.size(0) == len(mol_graphs)
    assert out.size(1) == 2
    assert len(out.size()) == 2
    assert torch.max((out[0, :] - out2[0, :]).abs()) < 1e-5
def test_encoder_batch_independence(self):
    d_model = 512
    encoder = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)
    encoder_aggregated = FirstSequenceElementHead(encoder)
    mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
    out = encoder_aggregated(mol_graphs)
    out2 = encoder_aggregated(mol_graphs[:1])
    assert out.size(0) == len(mol_graphs)
    assert out.size(1) == d_model
    assert len(out.size()) == 2
    assert torch.max((out[0] - out2[0]).abs()) < 1e-5
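# The batch-independence tests above share one pattern: run the model on a full
# batch and on a one-element sub-batch, then check the shared element agrees. A
# hedged generic helper capturing that pattern (the name is illustrative):
def assert_batch_independent(model, inputs, tol=1e-5):
    out_full = model(inputs)      # forward pass on the whole batch
    out_sub = model(inputs[:1])   # forward pass on the first element alone
    # the first element's output should not depend on its batch-mates
    assert torch.max((out_full[0] - out_sub[0]).abs()) < tol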