def get_graph_model(codec, drop_rate, model_type, output_type='values', num_bins=51):
    """Build a graph model for ``codec``'s grammar.

    Args:
        codec: object providing ``grammar`` and ``feature_len()``.
        drop_rate: dropout rate forwarded to the encoder and the output head
            (not used by the conditional models).
        model_type: if it contains ``'conditional'``, a conditional model is
            returned (``CondtionalProbabilityModelSparse`` when it also
            contains ``'sparse'``, else ``ConditionalModelBlended``);
            otherwise it is forwarded to ``GraphEncoder``.
        output_type: ``'values'`` wraps the head in
            ``OneValuePerNodeRuleTransform``; a string containing
            ``'distributions'`` builds a head with ``num_bins`` times as many
            action outputs, wrapped in the Thompson transform if it contains
            ``'thompson'`` or the softmax transform (T=10) if it contains
            ``'softmax'``. Ignored for conditional models.
        num_bins: number of bins per action output in the distributional case.

    Returns:
        The assembled model, with ``init_encoder_output`` set to a no-op.

    Raises:
        ValueError: if ``output_type`` is neither ``'values'`` nor a string
            containing ``'distributions'`` (non-conditional models only).
    """
    if 'conditional' in model_type:
        if 'sparse' in model_type:
            model = CondtionalProbabilityModelSparse(codec.grammar)
        else:
            model = ConditionalModelBlended(codec.grammar)
        model.init_encoder_output = lambda x: None
        return model

    # Fail fast, before building the encoder: an unrecognised output_type used
    # to fall through every branch below and crash with UnboundLocalError.
    if output_type != 'values' and 'distributions' not in output_type:
        raise ValueError(
            "Unsupported output_type {!r}: expected 'values' or a string "
            "containing 'distributions'".format(output_type))

    # For all other models, start with a GraphEncoder and attach the right head.
    encoder = GraphEncoder(grammar=codec.grammar,
                           d_model=512,
                           drop_rate=drop_rate,
                           model_type=model_type)
    # 'node' is used to select the next node to expand; 'action' selects the
    # action for the chosen node (num_bins outputs each in the distributional case).
    action_size = codec.feature_len()
    if output_type != 'values':
        action_size = action_size * num_bins
    model = MultipleOutputHead(model=encoder,
                               output_spec={'node': 1, 'action': action_size},
                               drop_rate=drop_rate)
    if output_type == 'values':
        model = OneValuePerNodeRuleTransform(model)
    elif 'thompson' in output_type:
        model = DistributionPerNodeRuleTransformThompson(model, num_bins=num_bins)
    elif 'softmax' in output_type:
        model = DistributionPerNodeRuleTransformSoftmax(model, num_bins=num_bins, T=10)
    # NOTE(review): 'distributions' without 'thompson'/'softmax' returns the
    # bare MultipleOutputHead with no distribution transform -- confirm intended.
    model.init_encoder_output = lambda x: None
    return model
예제 #2
0
 def test_graph_encoder_determinism(self):
     """Two forward passes with drop_rate=0.0 must agree to within 1e-6."""
     graph_encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
     graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
     first_pass = graph_encoder(graphs)
     second_pass = graph_encoder(graphs)
     worst_gap = (first_pass - second_pass).abs().max()
     assert worst_gap < 1e-6, "Encoder should be deterministic with zero dropout!"
예제 #3
0
 def __init__(self, grammar, drop_rate=0.0, d_model=512):
     """Build ``self.discriminator``: GraphEncoder -> FirstSequenceElementHead ->
     MultipleOutputHead with a single 'p_zinc' output of size 2, moved to ``device``.

     Args:
         grammar: grammar passed to ``GraphEncoder``.
         drop_rate: dropout used by both the encoder and the output head.
         d_model: model dimension passed to ``GraphEncoder``.
     """
     super().__init__()
     aggregated = FirstSequenceElementHead(
         GraphEncoder(grammar=grammar, d_model=d_model, drop_rate=drop_rate)
     )
     head = MultipleOutputHead(aggregated, {'p_zinc': 2}, drop_rate=drop_rate)
     self.discriminator = head.to(device)
예제 #4
0
    def test_graph_encoder_batch_independence(self):
        """Encoding the first graph alone must match its slice of a batched encoding.

        Only the first ``solo.size(1)`` positions along dim 1 are compared; the
        batched output may be longer along that dim (presumably padding).
        """
        graph_encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        batched = graph_encoder(graphs)
        solo = graph_encoder(graphs[:1])
        first_slice = batched[:1, :solo.size(1)]
        assert (first_slice - solo).abs().max() < 1e-5, \
            "Encoder should have no crosstalk between batches"
예제 #5
0
    def __init__(self, grammar, output_spec, drop_rate=0.0, d_model=512):
        """Wrap a GraphEncoder in a MultipleOutputHead built from ``output_spec``.

        Args:
            grammar: grammar passed to ``GraphEncoder``.
            output_spec: output specification forwarded to ``MultipleOutputHead``.
            drop_rate: dropout used by both the encoder and the head.
            d_model: model dimension passed to ``GraphEncoder``.
        """
        super().__init__()
        graph_encoder = GraphEncoder(grammar=grammar, d_model=d_model, drop_rate=drop_rate)
        head = MultipleOutputHead(graph_encoder, output_spec, drop_rate=drop_rate)
        self.model = head.to(device)

        # VAE-style use of this model isn't supported yet, so the hook is a no-op.
        self.init_encoder_output = lambda x: None
        self.output_shape = self.model.output_shape
 def test_graph_encoder_with_head(self):
     """Smoke test: a GraphEncoder under a node/action MultipleOutputHead runs on ZINC molecules.

     NOTE(review): the head's output is never checked, so this only catches
     exceptions. The encoder uses ``gi.grammar`` while the codec that sizes the
     'action' output is built from ``'hypergraph:' + tmp_file``; confirm these
     are the same grammar.
     """
     codec = get_codec(molecules=True,
                       grammar='hypergraph:' + tmp_file,
                       max_seq_length=max_seq_length)
     graph_encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)
     graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
     spec = {
         'node': 1,                      # to be used to select next node to expand
         'action': codec.feature_len(),  # to select the action for chosen node
     }
     head = MultipleOutputHead(model=graph_encoder, output_spec=spec, drop_rate=0.1)
     head.to(device)(graphs)
예제 #7
0
    def test_with_first_sequence_element_head(self):
        """FirstSequenceElementHead must yield a 2-D output of shape (len(graphs), d_model)."""
        d_model = 512
        head = FirstSequenceElementHead(
            GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0))
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        result = head(graphs)
        assert result.size(0) == len(graphs)
        assert result.size(1) == d_model
        assert result.dim() == 2
예제 #8
0
    def test_with_multihead_attenion_aggregating_head(self):
        """MultiheadAttentionAggregatingHead must yield a 2-D output of shape (len(graphs), d_model)."""
        d_model = 512
        base = GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0)
        aggregator = MultiheadAttentionAggregatingHead(base)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        pooled = aggregator(graphs)
        assert pooled.size(0) == len(graphs)
        assert pooled.size(1) == d_model
        assert pooled.dim() == 2
예제 #9
0
    def test_full_discriminator_parts_tuple_head(self):
        """A list-spec ([2]) head on an aggregated encoder gives (batch, 2) outputs
        that do not depend on what else is in the batch."""
        aggregated = FirstSequenceElementHead(
            GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0))
        discriminator = MultipleOutputHead(aggregated, [2], drop_rate=0).to(device)
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        batched = discriminator(graphs)[0]
        solo = discriminator(graphs[:1])[0]
        assert batched.size(0) == len(graphs)
        assert batched.size(1) == 2
        assert batched.dim() == 2
        # The first graph must score the same alone as it does inside the full batch.
        assert torch.max((batched[0] - solo[0]).abs()) < 1e-5
예제 #10
0
    def test_encoder_batch_independence(self):
        """The aggregated encoding of the first graph must not depend on its batch-mates."""
        d_model = 512
        head = FirstSequenceElementHead(
            GraphEncoder(grammar=gi.grammar, d_model=d_model, drop_rate=0.0))
        graphs = [HyperGraph.from_mol(m) for m in get_zinc_molecules(5)]
        batched = head(graphs)
        solo = head(graphs[:1])
        assert batched.size(0) == len(graphs)
        assert batched.size(1) == d_model
        assert batched.dim() == 2
        gap = (batched[0] - solo[0]).abs()
        assert torch.max(gap) < 1e-5
예제 #11
0
def get_graph_model(codec, drop_rate, model_type, output_type='values', num_bins=51):
    """Build a GraphEncoder-based model for ``codec``'s grammar with the requested head.

    Args:
        codec: object providing ``grammar`` and ``feature_len()``.
        drop_rate: dropout rate forwarded to the encoder and the output head.
        model_type: forwarded to ``GraphEncoder``.
        output_type: ``'values'`` wraps the head in
            ``OneValuePerNodeRuleTransform``; a string containing
            ``'distributions'`` builds a head with ``num_bins`` times as many
            action outputs, wrapped in the Thompson transform if it contains
            ``'thompson'`` or the softmax transform (T=10) if it contains
            ``'softmax'``.
        num_bins: number of bins per action output in the distributional case.

    Returns:
        The assembled model, with ``init_encoder_output`` set to a no-op.

    Raises:
        ValueError: if ``output_type`` is neither ``'values'`` nor a string
            containing ``'distributions'``.
    """
    # Fail fast, before building the encoder: an unrecognised output_type used
    # to fall through every branch below and crash with UnboundLocalError.
    if output_type != 'values' and 'distributions' not in output_type:
        raise ValueError(
            "Unsupported output_type {!r}: expected 'values' or a string "
            "containing 'distributions'".format(output_type))

    encoder = GraphEncoder(grammar=codec.grammar,
                           d_model=512,
                           drop_rate=drop_rate,
                           model_type=model_type)
    # 'node' is used to select the next node to expand; 'action' selects the
    # action for the chosen node (num_bins outputs each in the distributional case).
    action_size = codec.feature_len()
    if output_type != 'values':
        action_size = action_size * num_bins
    model = MultipleOutputHead(model=encoder,
                               output_spec={'node': 1, 'action': action_size},
                               drop_rate=drop_rate)
    if output_type == 'values':
        model = OneValuePerNodeRuleTransform(model)
    elif 'thompson' in output_type:
        model = DistributionPerNodeRuleTransformThompson(model, num_bins=num_bins)
    elif 'softmax' in output_type:
        model = DistributionPerNodeRuleTransformSoftmax(model, num_bins=num_bins, T=10)
    # NOTE(review): 'distributions' without 'thompson'/'softmax' returns the
    # bare MultipleOutputHead with no distribution transform -- confirm intended.
    model.init_encoder_output = lambda x: None
    return model