def test_hypergraph_mask_gen_no_priors(self):
        """Sample molecules from a mask generator built with ``priors=False``.

        Builds a fresh grammar from 20 molecules, then repeatedly feeds the
        previous action to ``HypergraphMaskGenerator.apply_action`` and samples
        the next one from ``action_prior_logits()`` with a
        ``SoftmaxRandomSamplePolicy``. Sampling stops when the generator raises
        StopIteration. Every decoded SMILES string must be accepted by
        ``MolFromSmiles``.
        """
        tmp_file = 'tmp2.pickle'
        # Delete any cached grammar so the test starts from a clean state.
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # Now create a clean new one and build the grammar from 20 molecules.
        gi = GrammarInitializer(tmp_file)
        gi.init_grammar(20)
        gi.grammar.check_attributes()

        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=False)

        batch_size = 2
        all_actions = []
        # The first apply_action call receives one placeholder per batch element.
        next_action = [None] * batch_size
        policy = SoftmaxRandomSamplePolicy()
        while True:
            try:
                mask_gen.apply_action(next_action)
                cond_priors = mask_gen.action_prior_logits()
                cond_priors_pytorch = torch.from_numpy(cond_priors).to(device=device, dtype=torch.float32)
                next_action = policy(cond_priors_pytorch).cpu().detach().numpy()
                all_actions.append(next_action)
            except StopIteration:
                # The generator signals the end of generation with StopIteration.
                break
        # Presumably (step, batch) -> (batch, step) before decoding; confirm
        # against actions_to_strings.
        all_actions = np.array(all_actions).T
        all_smiles = gi.grammar.actions_to_strings(all_actions)
        for smile in all_smiles:
            self.assertIsNotNone(MolFromSmiles(smile))
# Example #2
    def test_grammar_initializer(self):
        """Check that a cached grammar can be reloaded and re-initialized.

        Builds a grammar from 10 molecules and reloads it from
        ``gi.own_filename``. It then calls ``init_grammar(20)`` on the reloaded
        object and sanity-checks the resulting rule frequencies.
        """
        # nuke any cached data
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # run a first run for 10 molecules
        gi.init_grammar(10)
        # load the resulting object and continue with 20 molecules
        gi2 = GrammarInitializer.load(gi.own_filename)
        gi2.init_grammar(20)

        freqs = gi2.grammar.get_log_frequencies()
        assert len(freqs) == len(gi2.grammar)
        assert all(f >= 0 for f in freqs)

        # Each conditional distribution must sum to 1 (within tol). Their total
        # should therefore match the total number of nonterminals over all rules.
        tol = 1e-5
        cond_count = 0.0
        for cf in gi2.grammar.conditional_frequencies.values():
            this_count = sum(cf.values())
            assert abs(this_count - 1.0) < tol
            cond_count += this_count
        nt_count = sum(len(rule.nonterminal_ids())
                       for rule in gi2.grammar.rules
                       if rule is not None)
        # cond_count is a float sum of terms that are each only ~1.0, so exact
        # equality with the integer nt_count is fragile. Allow the per-term
        # tolerance to accumulate instead.
        assert abs(cond_count - nt_count) <= tol * max(nt_count, 1), \
            "Something went wrong when counting the frequencies..."
        gi2.grammar.check_attributes()
# Example #3
    def test_decoder_with_environment_new(self):
        """Smoke-test DecoderWithEnvironmentNew with a greedy dummy stepper.

        For each batch element, the dummy stepper does two things:
        - picks the node at the row-wise argmax of ``node_mask``;
        - picks the argmax of that node's entry in ``full_logit_priors``.
        The test only checks that decoding runs to completion.
        """
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # build the grammar from 20 molecules
        gi.init_grammar(20)
        gi.grammar.check_attributes()
        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
        batch_size = 2

        env = GraphEnvironment(mask_gen,
                               reward_fun=lambda x: 2*np.ones(len(x)),
                               batch_size=batch_size)

        def dummy_stepper(state):
            graphs, node_mask, full_logit_priors = state
            next_node = np.argmax(node_mask, axis=1)
            next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
            next_action = (next_node, next_action_)
            # NOTE(review): state is the 3-tuple unpacked above, so len(state)
            # is always 3 whatever the batch size. Confirm whether batch_size
            # was intended here.
            return next_action, np.zeros(len(state))

        # Attributes set on the stepper for DecoderWithEnvironmentNew
        # (presumably read by the decoder; confirm against its implementation).
        dummy_stepper.output_shape = [None, None, None]
        dummy_stepper.init_encoder_output = lambda x: None

        decoder = DecoderWithEnvironmentNew(dummy_stepper, env)
        decoder()
        print('done!')
# Example #4
    def test_graph_environment_step(self):
        """Step a GraphEnvironment until it is exhausted, using greedy actions.

        At each step the node is the row-wise argmax of ``node_mask``. The
        action is the argmax of that node's prior logits. The loop ends when
        ``env.step`` raises StopIteration, and the last ``info`` (or None if no
        step completed) is printed.
        """
        tmp_file = 'tmp2.pickle'
        gi = GrammarInitializer(tmp_file)
        gi.delete_cache()
        # now create a clean new one
        gi = GrammarInitializer(tmp_file)
        # build the grammar from 20 molecules
        gi.init_grammar(20)
        gi.grammar.check_attributes()
        mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
        batch_size = 2
        env = GraphEnvironment(mask_gen,
                               reward_fun=lambda x: np.zeros(len(x)),
                               batch_size=batch_size)
        graphs, node_mask, full_logit_priors = env.reset()
        # Pre-bind info: if the very first step raises StopIteration, the
        # print below would otherwise fail with a NameError.
        info = None
        while True:
            next_node = np.argmax(node_mask, axis=1)
            next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
            next_action = (next_node, next_action_)
            try:
                (graphs, node_mask, full_logit_priors), reward, done, info = env.step(next_action)
            except StopIteration:
                break

        print(info)
# Example #5
def make_grammar(tmp_file='tmp2.pickle', num_mols=20):
    """Build a fresh grammar, discarding any cached copy first.

    Args:
        tmp_file: cache file name handed to ``GrammarInitializer``; its cache
            is deleted before the grammar is rebuilt.
        num_mols: number of molecules passed to ``init_grammar`` (default 20).

    Returns:
        ``gi.grammar`` after ``check_attributes()`` has been called on it.
    """
    # Delete any existing cache so the grammar is built from scratch.
    GrammarInitializer(tmp_file).delete_cache()
    # now create a clean new one
    gi = GrammarInitializer(tmp_file)
    gi.init_grammar(num_mols)
    gi.grammar.check_attributes()
    return gi.grammar
 def test_hypergraph_mask_gen_step(self):
     """Drive HypergraphMaskGenerator.step until it raises StopIteration.

     Starts from a placeholder action: no node, plus one None per batch
     element. Each later action picks the row-wise argmax of ``node_mask``
     and the argmax of that node's prior logits.
     """
     tmp_file = 'tmp2.pickle'
     gi = GrammarInitializer(tmp_file)
     gi.delete_cache()
     # now create a clean new one
     gi = GrammarInitializer(tmp_file)
     # build the grammar from 20 molecules
     gi.init_grammar(20)
     gi.grammar.check_attributes()
     mask_gen = HypergraphMaskGenerator(30, gi.grammar, priors=True)
     batch_size = 2
     next_action = (None, [None] * batch_size)
     while True:
         try:
             _graphs, node_mask, full_logit_priors = mask_gen.step(next_action)
         except StopIteration:
             break
         next_node = np.argmax(node_mask, axis=1)
         next_action_ = [np.argmax(full_logit_priors[b, next_node[b]]) for b in range(batch_size)]
         next_action = (next_node, next_action_)
# Example #7

# Training configuration for the hypergraph-grammar policy-gradient run.
batch_size = 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
# max_steps = 277  # settings['max_seq_length']
invalid_value = -3.5
# scorer = NormalizedScorer(invalid_value=invalid_value)
# reward_fun = scorer #lambda x: np.ones(len(x)) # lambda x: reward_aromatic_rings(x)#
# NOTE(review): reward_fun is passed to train_policy_gradient below, but its
# definition here is commented out. Confirm it is defined elsewhere.

# later will run this ahead of time
gi = GrammarInitializer(grammar_cache, grammar_class=HypergraphRPEGrammar)
# Set to True to rebuild the grammar cache from scratch instead of using the
# existing cache file.
rebuild_grammar_cache = False
if rebuild_grammar_cache:
    gi.delete_cache()
    num_mols = 1000
    max_steps_smiles = gi.init_grammar(num_mols)
    gi.save()
    smiles = get_zinc_smiles(num_mols)
    gi.grammar.extract_rpe_pairs(smiles, 10)
    # NOTE(review): collapsed_trees is not defined anywhere in this script, so
    # enabling this branch raises NameError here. It presumably needs the
    # parse trees from grammar initialization; confirm the intended source.
    gi.grammar.count_rule_frequencies(collapsed_trees)
    gi.save()

max_steps = 30
model, gen_fitter, disc_fitter = train_policy_gradient(molecules,
                                                       grammar,
                                                       EPOCHS=100,
                                                       BATCH_SIZE=batch_size,
                                                       reward_fun_on=reward_fun,
                                                       max_steps=max_steps,