    def test_hypergraph_mask_gen_no_priors(self):
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # initialize the grammar from the first 20 molecules
        gi.init_grammar(20)
        gi.grammar.check_attributes()

        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=False)

        all_actions = []
        next_action = [None for _ in range(2)]
        policy = SoftmaxRandomSamplePolicy()
        while True:
            try:
                mask_gen.apply_action(next_action)
                cond_priors = mask_gen.action_prior_logits()
                cond_priors_pytorch = torch.from_numpy(cond_priors).to(device=device, dtype=torch.float32)
                next_action = policy(cond_priors_pytorch).cpu().detach().numpy()
                all_actions.append(next_action)
            except StopIteration:
                break
        all_actions = np.array(all_actions).T
        all_smiles = gi.grammar.actions_to_strings(all_actions)
        for smile in all_smiles:
            self.assertIsNot(MolFromSmiles(smile), None)
Example 2
    def test_grammar_initializer(self):
        # nuke any cached data
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # first pass: initialize the grammar from 10 molecules
        first_10 = gi.init_grammar(10)
        # load the resulting object
        gi2 = GrammarInitializer.load(gi.own_filename)
        gi2.init_grammar(20)

        freqs = gi2.grammar.get_log_frequencies()
        assert len(freqs) == len(gi2.grammar)
        assert all([f >= 0 for f in freqs])
        cond_count = 0
        for cf in gi2.grammar.conditional_frequencies.values():
            this_count = sum(cf.values())
            assert abs(this_count-1.0) < 1e-5
            cond_count += this_count
        nt_count = 0
        for rule in gi2.grammar.rules:
            if rule is not None:
                nt_count += len(rule.nonterminal_ids())
        assert cond_count == nt_count, "Something went wrong when counting the frequencies..."
        gi2.grammar.check_attributes()
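A quick way to eyeball the learned statistics next to their rules, using only the accessors exercised above; a minimal sketch reusing gi2 from this test, and assuming rules and log frequencies are index-aligned (which the length assert above suggests):

for rule, log_freq in zip(gi2.grammar.rules, gi2.grammar.get_log_frequencies()):
    if rule is not None:
        print(len(rule.nonterminal_ids()), log_freq)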
Example 3
    def test_decoder_with_environment_new(self):
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # initialize the grammar from the first 20 molecules
        gi.init_grammar(20)
        gi.grammar.check_attributes()
        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
        batch_size = 2

        env = GraphEnvironment(mask_gen,
                               reward_fun=lambda x: 2*np.ones(len(x)),
                               batch_size=batch_size)

        def dummy_stepper(state):
            graphs, node_mask, full_logit_priors = state
            next_node = np.argmax(node_mask, axis=1)
            next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
            next_action = (next_node, next_action_)
            # per-sample value placeholder: one entry per batch element
            return next_action, np.zeros(batch_size)

        dummy_stepper.output_shape = [None, None, None]
        dummy_stepper.init_encoder_output = lambda x: None

        decoder = DecoderWithEnvironmentNew(dummy_stepper, env)
        out = decoder()
        print('done!')
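For contrast with the greedy dummy_stepper above, a stepper that samples uniformly from the currently allowed nodes and actions. This is only a sketch: it reuses batch_size from the test and assumes node_mask is a per-node 0/1 mask and that disallowed entries of full_logit_priors are -inf, which is what the argmax usage above implies.

def random_stepper(state):
    graphs, node_mask, full_logit_priors = state
    # pick an allowed node per sample, then an allowed action on that node
    next_node = np.array([np.random.choice(np.flatnonzero(node_mask[b] > 0))
                          for b in range(batch_size)])
    next_action_ = [np.random.choice(np.flatnonzero(full_logit_priors[b, next_node[b]] > -np.inf))
                    for b in range(batch_size)]
    return (next_node, next_action_), np.zeros(batch_size)

random_stepper.output_shape = [None, None, None]
random_stepper.init_encoder_output = lambda x: None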
Example 4
    def test_graph_environment_step(self):
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # initialize the grammar from the first 20 molecules
        gi.init_grammar(20)
        gi.grammar.check_attributes()
        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
        batch_size = 2
        env = GraphEnvironment(mask_gen,
                               reward_fun=lambda x: np.zeros(len(x)),
                               batch_size=batch_size)
        graphs, node_mask, full_logit_priors = env.reset()
        while True:
            try:
                next_node = np.argmax(node_mask, axis=1)
                next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
                next_action = (next_node, next_action_)
                (graphs, node_mask, full_logit_priors), reward, done, info = env.step(next_action)
            except StopIteration:
                break

        print(info)
Example 5
def make_grammar(tmp_file = 'tmp2.pickle'):
    gi = GrammarInitializer(tmp_file)
    gi.delete_cache()
    # now create a clean new one
    gi = GrammarInitializer(tmp_file)
    # initialize the grammar from the first 20 molecules
    gi.init_grammar(20)
    gi.grammar.check_attributes()
    return gi.grammar
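A typical way to use the helper above, mirroring the constructor calls from the earlier examples (just a usage sketch):

grammar = make_grammar()
mask_gen = HypergraphMaskGenerator(30, grammar, priors=True)
env = GraphEnvironment(mask_gen,
                       reward_fun=lambda x: np.zeros(len(x)),
                       batch_size=2)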
    def test_hypergraph_mask_gen_step(self):
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # initialize the grammar from the first 20 molecules
        gi.init_grammar(20)
        gi.grammar.check_attributes()
        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
        batch_size = 2
        next_action = (None, [None for _ in range(batch_size)])
        while True:
            try:
                graphs, node_mask, full_logit_priors = mask_gen.step(next_action)
                next_node = np.argmax(node_mask, axis=1)
                next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
                next_action = (next_node, next_action_)
            except StopIteration:
                break
Example 7
def reward_aromatic_rings(smiles):
    atoms = num_aromatic_rings(smiles)
    return [-1 if num is None else num + 0.5 for num in atoms]
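Reward functions in these snippets appear to take a batch of SMILES strings and return one value per string. Wired into the GraphEnvironment from the earlier examples, the helper above would be used roughly like this (a sketch, reusing mask_gen and batch_size from those examples):

env = GraphEnvironment(mask_gen,
                       reward_fun=reward_aromatic_rings,
                       batch_size=batch_size)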


batch_size = 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
# max_steps = 277  # settings['max_seq_length']
invalid_value = -3.5
scorer = NormalizedScorer(invalid_value=invalid_value)
reward_fun = scorer  # alternatives: lambda x: np.ones(len(x)) or lambda x: reward_aromatic_rings(x)
# later this will be run ahead of time
gi = GrammarInitializer(grammar_cache)
# if True:
#     gi.delete_cache()
#     gi = GrammarInitializer(grammar_cache)
#     max_steps_smiles = gi.init_grammar(1000)

max_steps = 30
model, gen_fitter, disc_fitter = train_policy_gradient(
    molecules,
    grammar,
    EPOCHS=100,
    BATCH_SIZE=batch_size,
    reward_fun_on=reward_fun,
    max_steps=max_steps,
    lr_on=0.3e-5,
    lr_discrim=5e-4,
Example 8
import os
from unittest import TestCase

from generative_playground.models.decoder.graph_decoder import GraphEncoder
from generative_playground.codec.codec import get_codec
from generative_playground.codec.hypergraph_grammar import GrammarInitializer
from generative_playground.models.heads import MultipleOutputHead
from generative_playground.codec.hypergraph import HyperGraph
from generative_playground.molecules.data_utils.zinc_utils import get_zinc_molecules
from generative_playground.utils.gpu_utils import device

# make sure there's a freshly built grammar cache for the tests to use
tmp_file = 'tmp.pickle'
if os.path.isfile(tmp_file):
    os.remove(tmp_file)
if os.path.isfile('init_' + tmp_file):
    os.remove('init_' + tmp_file)

gi = GrammarInitializer(tmp_file)
gi.init_grammar(10)

z_size = 200
batch_size = 2
max_seq_length = 30


class TestDecoders(TestCase):
    def generic_decoder_test(self, decoder_type, grammar):
        codec = get_codec(molecules=True,
                          grammar=grammar,
                          max_seq_length=max_seq_length)
        decoder, pre_decoder = get_decoder(decoder_type=decoder_type,
                                           max_seq_length=max_seq_length,
                                           grammar=grammar,
Example 9
import os
from unittest import TestCase

from generative_playground.codec.hypergraph import HyperGraph
from generative_playground.codec.hypergraph_grammar import GrammarInitializer
from generative_playground.models.embedder.graph_embedder import GraphEmbedder
from generative_playground.molecules.data_utils.zinc_utils import get_zinc_molecules
from generative_playground.models.decoder.graph_decoder import GraphEncoder
from generative_playground.codec.codec import get_codec

# create a grammar from scratch  # TODO: later, load a cached grammar instead
tmp_file = 'tmp.pickle'
# delete the cached files
if os.path.isfile(tmp_file):
    os.remove(tmp_file)
if os.path.isfile('init_' + tmp_file):
    os.remove('init_' + tmp_file)

gi = GrammarInitializer(tmp_file)

# run a first run for 10 molecules
first_10 = gi.init_grammar(10)


class TestGraphEmbedder(TestCase):
    def test_graph_embedder_on_complete_hypergraphs(self):
        ge = GraphEmbedder(target_dim=512, grammar=gi.grammar)
        mol_graphs = [
            HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)
        ]
        out = ge(mol_graphs)
        for eg, g in zip(out, mol_graphs):
            for i in range(len(g), max([len(gg) for gg in mol_graphs])):
                assert eg[i].abs().max(
Example 10
def reward_aromatic_rings(smiles):
    atoms = num_aromatic_rings(smiles)
    return [-1 if num is None else num + 0.5 for num in atoms]


batch_size = 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
# max_steps = 277  # settings['max_seq_length']
invalid_value = -3.5
# scorer = NormalizedScorer(invalid_value=invalid_value)
# reward_fun = scorer  # alternatives: lambda x: np.ones(len(x)) or lambda x: reward_aromatic_rings(x)
# later this will be run ahead of time
gi = GrammarInitializer(grammar_cache, grammar_class=HypergraphRPEGrammar)
if False:
    gi.delete_cache()
    num_mols = 1000
    max_steps_smiles = gi.init_grammar(num_mols)
    gi.save()
    smiles = get_zinc_smiles(num_mols)
    gi.grammar.extract_rpe_pairs(smiles, 10)
    gi.grammar.count_rule_frequencies(collapsed_trees)
    gi.save()

max_steps = 30
model, gen_fitter, disc_fitter = train_policy_gradient(molecules,
                                                       grammar,
                                                       EPOCHS=100,
                                                       BATCH_SIZE=batch_size,