def test_hypergraph_mask_gen_no_priors(self):
    tmp_file = 'tmp2.pickle'
    gi = GrammarInitializer(tmp_file)
    gi.delete_cache()
    # now create a clean new one
    gi = GrammarInitializer(tmp_file)
    # do a first grammar-building pass over 20 molecules
    gi.init_grammar(20)
    gi.grammar.check_attributes()
    mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=False)
    all_actions = []
    next_action = [None for _ in range(2)]
    policy = SoftmaxRandomSamplePolicy()
    while True:
        try:
            mask_gen.apply_action(next_action)
            cond_priors = mask_gen.action_prior_logits()
            cond_priors_pytorch = torch.from_numpy(cond_priors).to(device=device, dtype=torch.float32)
            next_action = policy(cond_priors_pytorch).cpu().detach().numpy()
            all_actions.append(next_action)
        except StopIteration:
            break
    # stack per-step actions into one action sequence per batch element
    all_actions = np.array(all_actions).T
    all_smiles = gi.grammar.actions_to_strings(all_actions)
    # every decoded string must be valid SMILES
    for smile in all_smiles:
        self.assertIsNotNone(MolFromSmiles(smile))
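# A minimal sketch of what a softmax random-sample policy like the one used
# above typically does (an illustration, not the actual
# SoftmaxRandomSamplePolicy implementation): softmax the masked prior logits,
# then draw one action index per batch element from the resulting
# categorical distribution.
def softmax_sample_sketch(logits):
    # logits: FloatTensor of shape [batch_size, num_actions]
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)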
def test_grammar_initializer(self):
    # nuke any cached data
    tmp_file = 'tmp2.pickle'
    gi = GrammarInitializer(tmp_file)
    gi.delete_cache()
    # now create a clean new one
    gi = GrammarInitializer(tmp_file)
    # do a first grammar-building pass over 10 molecules
    first_10 = gi.init_grammar(10)
    # load the resulting object and extend it with 20 more molecules
    gi2 = GrammarInitializer.load(gi.own_filename)
    gi2.init_grammar(20)
    freqs = gi2.grammar.get_log_frequencies()
    assert len(freqs) == len(gi2.grammar)
    assert all(f >= 0 for f in freqs)
    cond_count = 0
    for cf in gi2.grammar.conditional_frequencies.values():
        # each conditional frequency table must be a normalized distribution
        this_count = sum(cf.values())
        assert abs(this_count - 1.0) < 1e-5
        cond_count += this_count
    # there should be one conditional distribution per nonterminal occurrence
    nt_count = 0
    for rule in gi2.grammar.rules:
        if rule is not None:
            nt_count += len(rule.nonterminal_ids())
    assert cond_count == nt_count, "Something went wrong when counting the frequencies..."
    gi2.grammar.check_attributes()
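# Toy illustration of the invariant asserted above: each conditional
# frequency table is a normalized distribution, so its values sum to one.
# The keys and values below are made up purely for illustration.
toy_conditional_freqs = {('rule_0', 0): {0: 0.25, 1: 0.75},
                         ('rule_1', 2): {3: 1.0}}
for cf in toy_conditional_freqs.values():
    assert abs(sum(cf.values()) - 1.0) < 1e-5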
def test_decoder_with_environment_new(self):
    tmp_file = 'tmp2.pickle'
    gi = GrammarInitializer(tmp_file)
    gi.delete_cache()
    # now create a clean new one
    gi = GrammarInitializer(tmp_file)
    # do a first grammar-building pass over 20 molecules
    gi.init_grammar(20)
    gi.grammar.check_attributes()
    mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
    batch_size = 2
    env = GraphEnvironment(mask_gen,
                           reward_fun=lambda x: 2 * np.ones(len(x)),
                           batch_size=batch_size)

    def dummy_stepper(state):
        # greedily pick the first admissible node, then the action with the
        # highest prior logit for that node
        graphs, node_mask, full_logit_priors = state
        next_node = np.argmax(node_mask, axis=1)
        next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
        next_action = (next_node, next_action_)
        # state is a 3-tuple, so this returns 3 dummy log-probs
        return next_action, np.zeros(len(state))

    dummy_stepper.output_shape = [None, None, None]
    dummy_stepper.init_encoder_output = lambda x: None
    decoder = DecoderWithEnvironmentNew(dummy_stepper, env)
    out = decoder()
    print('done!')
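# The stepper protocol DecoderWithEnvironmentNew appears to expect, inferred
# from the dummy above (an assumption, not documented API): a callable
# state -> (action, log_probs) that also exposes an `output_shape` attribute
# and an `init_encoder_output(encoder_output)` hook.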
def test_graph_environment_step(self):
    tmp_file = 'tmp2.pickle'
    gi = GrammarInitializer(tmp_file)
    gi.delete_cache()
    # now create a clean new one
    gi = GrammarInitializer(tmp_file)
    # do a first grammar-building pass over 20 molecules
    gi.init_grammar(20)
    gi.grammar.check_attributes()
    mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
    batch_size = 2
    env = GraphEnvironment(mask_gen,
                           reward_fun=lambda x: np.zeros(len(x)),
                           batch_size=batch_size)
    graphs, node_mask, full_logit_priors = env.reset()
    while True:
        try:
            # greedy rollout: first admissible node, then its best-prior action
            next_node = np.argmax(node_mask, axis=1)
            next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
            next_action = (next_node, next_action_)
            (graphs, node_mask, full_logit_priors), reward, done, info = env.step(next_action)
        except StopIteration:
            break
    print(info)
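# The greedy rollout pattern in the loop above recurs across these tests; a
# hypothetical helper factoring it out (not part of the library) might look
# like this: pick the first admissible node per batch element, then the
# action with the highest prior logit for that node.
def greedy_action(node_mask, full_logit_priors):
    next_node = np.argmax(node_mask, axis=1)
    next_action_ = [np.argmax(full_logit_priors[b, next_node[b]])
                    for b in range(len(next_node))]
    return next_node, next_action_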
def make_grammar(tmp_file='tmp2.pickle'):
    gi = GrammarInitializer(tmp_file)
    gi.delete_cache()
    # now create a clean new one
    gi = GrammarInitializer(tmp_file)
    # do a first grammar-building pass over 20 molecules
    gi.init_grammar(20)
    gi.grammar.check_attributes()
    return gi.grammar
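# Typical usage of the helper above, avoiding the initialization boilerplate
# repeated in the tests:
#   grammar = make_grammar()
#   mask_gen = HypergraphMaskGenerator(30, grammar, priors=True)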
def test_hypergraph_mask_gen_step(self):
    tmp_file = 'tmp2.pickle'
    gi = GrammarInitializer(tmp_file)
    gi.delete_cache()
    # now create a clean new one
    gi = GrammarInitializer(tmp_file)
    # do a first grammar-building pass over 20 molecules
    gi.init_grammar(20)
    gi.grammar.check_attributes()
    mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
    batch_size = 2
    # (None, [None, ...]) is the convention for the very first step
    next_action = (None, [None for _ in range(batch_size)])
    while True:
        try:
            graphs, node_mask, full_logit_priors = mask_gen.step(next_action)
            next_node = np.argmax(node_mask, axis=1)
            next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
            next_action = (next_node, next_action_)
        except StopIteration:
            break
def reward_aromatic_rings(smiles):
    # reward = number of aromatic rings + 0.5, or -1 for invalid molecules
    rings = num_aromatic_rings(smiles)
    return [-1 if num is None else num + 0.5 for num in rings]


batch_size = 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
# max_steps = 277  # settings['max_seq_length']
invalid_value = -3.5
scorer = NormalizedScorer(invalid_value=invalid_value)
reward_fun = scorer  # alternatives: lambda x: np.ones(len(x)), reward_aromatic_rings
# later we will run this ahead of time
gi = GrammarInitializer(grammar_cache)
# to rebuild the grammar cache from scratch, uncomment:
# gi.delete_cache()
# gi = GrammarInitializer(grammar_cache)
# max_steps_smiles = gi.init_grammar(1000)
max_steps = 30
model, gen_fitter, disc_fitter = train_policy_gradient(molecules,
                                                       grammar,
                                                       EPOCHS=100,
                                                       BATCH_SIZE=batch_size,
                                                       reward_fun_on=reward_fun,
                                                       max_steps=max_steps,
                                                       lr_on=0.3e-5,
                                                       lr_discrim=5e-4,
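# The reward callback contract assumed by train_policy_gradient in this
# script: a function from a batch of SMILES strings to one float per
# molecule. A trivial stand-in, mirroring the commented-out alternative
# above, useful for smoke-testing the training loop:
# def constant_reward(smiles_batch):
#     return np.ones(len(smiles_batch))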
import os
from unittest import TestCase

from generative_playground.models.decoder.graph_decoder import GraphEncoder
from generative_playground.codec.codec import get_codec
from generative_playground.codec.hypergraph_grammar import GrammarInitializer
from generative_playground.models.heads import MultipleOutputHead
from generative_playground.codec.hypergraph import HyperGraph
from generative_playground.molecules.data_utils.zinc_utils import get_zinc_molecules
from generative_playground.utils.gpu_utils import device

# make sure there's a cached grammar for us to use
tmp_file = 'tmp.pickle'
if os.path.isfile(tmp_file):
    os.remove(tmp_file)
if os.path.isfile('init_' + tmp_file):
    os.remove('init_' + tmp_file)
gi = GrammarInitializer(tmp_file)
gi.init_grammar(10)

z_size = 200
batch_size = 2
max_seq_length = 30


class TestDecoders(TestCase):
    def generic_decoder_test(self, decoder_type, grammar):
        codec = get_codec(molecules=True,
                          grammar=grammar,
                          max_seq_length=max_seq_length)
        decoder, pre_decoder = get_decoder(decoder_type=decoder_type,
                                           max_seq_length=max_seq_length,
                                           grammar=grammar,
def reward_aromatic_rings(smiles):
    # reward = number of aromatic rings + 0.5, or -1 for invalid molecules
    rings = num_aromatic_rings(smiles)
    return [-1 if num is None else num + 0.5 for num in rings]


batch_size = 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
# max_steps = 277  # settings['max_seq_length']
invalid_value = -3.5
# scorer = NormalizedScorer(invalid_value=invalid_value)
# reward_fun = scorer  # alternatives: lambda x: np.ones(len(x)), reward_aromatic_rings
# later we will run this ahead of time
gi = GrammarInitializer(grammar_cache, grammar_class=HypergraphRPEGrammar)
if False:  # flip to True to rebuild the RPE grammar cache from scratch
    gi.delete_cache()
    num_mols = 1000
    max_steps_smiles = gi.init_grammar(num_mols)
    gi.save()
    smiles = get_zinc_smiles(num_mols)
    gi.grammar.extract_rpe_pairs(smiles, 10)
    # NOTE: collapsed_trees is not defined in this excerpt; it presumably
    # comes from the RPE extraction step above
    gi.grammar.count_rule_frequencies(collapsed_trees)
    gi.save()
max_steps = 30
model, gen_fitter, disc_fitter = train_policy_gradient(molecules,
                                                       grammar,
                                                       EPOCHS=100,
                                                       BATCH_SIZE=batch_size,