Exemplo n.º 1
0
    def test_hypergraph_mask_gen(self):
        """Compare effective logits of a log-frequency-biased policy and a plain policy.

        Infers a hypergraph grammar from five ZINC SMILES (cache file
        ``tmp.pickle``) and builds two codecs from it, keeping their mask
        generators. ``mask_gen1`` gets ``priors = False`` and is paired with a
        policy biased by ``grammar.get_log_frequencies()``. ``mask_gen2`` gets
        ``priors = True`` and is paired with an unbiased policy. The test asserts
        that the two policies' effective logits for the same all-ones model
        output differ by less than 1e-6 everywhere.

        NOTE(review): the per-generator prior logits collected in ``lp`` are only
        used to shape the dummy model output. ``this_lp`` is never passed to
        ``effective_logits``, so generator-side priors do not enter the
        comparison. Confirm whether
        ``policy.effective_logits(dummy_model_output + this_lp)`` was intended.
        Beware that masked (-inf) entries would then produce NaN differences.
        NOTE(review): ``tmp.pickle`` is left on disk after the test.
        """
        molecules = True
        grammar_cache = 'tmp.pickle'
        grammar = 'hypergraph:' + grammar_cache
        # create a grammar cache inferred from our sample molecules
        g = HypergraphGrammar(cache_file=grammar_cache)
        # drop any stale cache so the grammar reflects only the molecules below
        if os.path.isfile(g.cache_file):
            os.remove(g.cache_file)
        g.strings_to_actions(get_zinc_smiles(5))
        # two independent codecs over the same grammar; presumably get_codec
        # loads the cache written by strings_to_actions above -- TODO confirm
        mask_gen1 = get_codec(molecules, grammar, 30).mask_gen
        mask_gen2 = get_codec(molecules, grammar, 30).mask_gen
        # presumably: generator 1 leaves priors to the policy bias below,
        # generator 2 folds them into its own logits -- confirm
        mask_gen1.priors = False
        mask_gen2.priors = True
        policy1 = SoftmaxRandomSamplePolicy(
            bias=mask_gen1.grammar.get_log_frequencies())
        policy2 = SoftmaxRandomSamplePolicy()
        lp = []
        for mg in [mask_gen1, mask_gen2]:
            mg.reset()
            # a single None action; presumably starts one fresh sequence -- confirm
            mg.apply_action([None])
            logit_priors = mg.action_prior_logits()  # that includes any priors
            # ``device`` is not defined in this block; presumably a module-level torch device
            lp.append(
                torch.from_numpy(logit_priors).to(device=device,
                                                  dtype=torch.float32))

        # all-ones stand-in for model output with lp[0]'s shape, dtype and device
        dummy_model_output = torch.ones_like(lp[0])
        eff_logits = []
        # NOTE(review): this_lp is unused here (see docstring)
        for this_lp, policy in zip(lp, [policy1, policy2]):
            eff_logits.append(policy.effective_logits(dummy_model_output))

        assert torch.max((eff_logits[0] - eff_logits[1]).abs()) < 1e-6
Exemplo n.º 2
0
 def test_hypergraph_grammar_codec(self):
     """Run ``check_codec`` on ``smiles1`` with a hypergraph-grammar codec.

     The grammar is first inferred from the ``smiles`` sample set using
     ``tmp.pickle`` as its cache file. The codec spec passed on is
     ``'hypergraph:tmp.pickle'`` in molecule mode.
     """
     cache_path = 'tmp.pickle'
     # infer grammar rules from the sample molecules
     HypergraphGrammar(cache_file=cache_path).strings_to_actions(smiles)
     self.check_codec(smiles1, True, 'hypergraph:' + cache_path)
Exemplo n.º 3
0
    def test_hypergraph_rpe_parser_bad_smiles(self):
        """Parse every entry of ``bad_smiles`` into a normalized hypergraph tree.

        The test passes if all molecules parse. On an ``AssertionError`` or
        ``IndexError``, the offending SMILES is printed and the error re-raised.
        """
        grammar = HypergraphGrammar()

        parsed = []
        for bad in bad_smiles:
            try:
                mol = MolFromSmiles(bad)
                parsed.append(grammar.normalize_tree(hypergraph_parser(mol)))
            except (AssertionError, IndexError):
                print('Failed for {}'.format(bad))
                raise
 def test_hypergraph_mask_gen(self):
     """Generate from a hypergraph codec built on a freshly inferred grammar.

     Any stale ``tmp.pickle`` cache is deleted first, so the grammar comes
     only from ``smiles``. ``generate_and_validate`` then drives the codec.
     """
     cache_path = 'tmp.pickle'
     fresh = HypergraphGrammar(cache_file=cache_path)
     if os.path.isfile(fresh.cache_file):
         os.remove(fresh.cache_file)
     fresh.strings_to_actions(smiles)
     spec = 'hypergraph:' + cache_path
     self.generate_and_validate(get_codec(True, spec, max_seq_length))
Exemplo n.º 5
0
    def test_mask_gen(self):
        """Sample random legal actions from ``HypergraphMaskGenerator`` and decode them.

        A grammar is inferred from ``smiles``. A batch of 10 sequences is then
        generated with ``max_rules = 50``: the mask generator is repeatedly
        asked for masks, and one random allowed action is picked per mask,
        until the generator raises ``StopIteration``. The test passes if this
        completes and the action sequences decode. Decoded SMILES are printed
        for inspection.
        """
        grammar = HypergraphGrammar()
        # seed the grammar with the rules found in the sample molecules
        grammar.strings_to_actions(smiles)
        grammar.calc_terminal_distance()
        batch_size = 10
        max_rules = 50
        mask_gen = HypergraphMaskGenerator(max_rules, grammar)

        steps = []
        chosen = [None] * batch_size
        while True:
            try:
                masks = mask_gen(chosen)
            except StopIteration:
                break
            # pick one random index among the nonzero (allowed) mask entries
            chosen = [random.choice(np.nonzero(m)[0]) for m in masks]
            steps.append(chosen)

        # transpose (steps x sequences) into one action list per sequence
        per_sequence = np.array(steps).T
        # the test is that we get this far, producing valid molecules
        for decoded in grammar.decode_from_actions(per_sequence):
            print(decoded)
Exemplo n.º 6
0
    def test_get_set_params_as_vector(self):
        """A parameter vector written back into the model is returned on re-read."""
        # alternative grammar file: 'hyper_grammar.pickle'
        loaded = HypergraphGrammar.load(
            'hyper_grammar_guac_10k_with_clique_collapse.pickle')
        model = CondtionalProbabilityModel(loaded)

        params = model.get_params_as_vector()
        params[0] = 1
        model.set_params_from_vector(params)

        round_tripped = model.get_params_as_vector()
        assert round_tripped[0] == params[0]
Exemplo n.º 7
0
    def test_hypergraph_rpe_parser(self):
        """Parsing with popular rule pairs collapsed still recovers every SMILES.

        The 10 most popular hypergraph rule pairs are extracted from the
        normalized parse trees of ``smiles``. A ``HypergraphRPEParser`` built
        from them re-parses each molecule, and each collapsed tree is turned
        back into SMILES. The result must equal the input list.
        """
        grammar = HypergraphGrammar()
        grammar.strings_to_actions(smiles)

        normalized = []
        for s in smiles:
            normalized.append(
                grammar.normalize_tree(hypergraph_parser(MolFromSmiles(s))))

        popular_pairs = extract_popular_hypergraph_pairs(grammar, normalized, 10)
        rpe_parser = HypergraphRPEParser(grammar, popular_pairs)

        # parse everything first, then rebuild molecules from the collapsed trees
        collapsed = list(map(rpe_parser.parse, smiles))
        recovered = [MolToSmiles(to_mol(graph_from_graph_tree(t)))
                     for t in collapsed]

        self.assertEqual(smiles, recovered)
Exemplo n.º 8
0
    def test_hypergraph_rpe(self):
        """Substituting the top rule pair shrinks the tree but keeps the molecule.

        The test checks three things. Extracting popular pairs from the parse
        tree of ``smiles1`` adds rules to the grammar. Applying the first pair
        yields a tree with fewer rules. The collapsed tree still decodes to
        ``smiles1``.
        """
        grammar = HypergraphGrammar()
        grammar.strings_to_actions(smiles)
        parse_tree = grammar.normalize_tree(hypergraph_parser(MolFromSmiles(smiles1)))

        grammar_size_before = len(grammar.rules)
        pairs = extract_popular_hypergraph_pairs(grammar, [parse_tree], 10)
        grammar_size_after = len(grammar.rules)

        tree_size_before = len(parse_tree.rules())
        collapsed = apply_hypergraph_substitution(grammar, parse_tree, pairs[0])
        tree_size_after = len(collapsed.rules())

        roundtrip = MolToSmiles(to_mol(graph_from_graph_tree(collapsed)))

        self.assertEqual(smiles1, roundtrip)
        self.assertGreater(grammar_size_after, grammar_size_before)
        self.assertLess(tree_size_after, tree_size_before)
Exemplo n.º 9
0
    def test_parser_roundtrip_via_indices(self):
        """Encode ``smiles`` to action indices, decode twice, and compare exactly.

        The cache is deleted first so the grammar is built from ``smiles`` alone.
        """
        # TODO: cheating a bit here, reconstruction does fail for some smiles:
        # chirality gets inverted sometimes, so the encode/decode loop runs
        # twice before comparing with the original.
        g = HypergraphGrammar()
        g.delete_cache()
        actions = g.strings_to_actions(smiles)
        re_smiles = g.decode_from_actions(actions)
        re_actions = g.strings_to_actions(re_smiles)
        rere_smiles = g.decode_from_actions(re_actions)

        # zip() stops at the shorter list, so a dropped molecule would otherwise pass
        assert len(rere_smiles) == len(smiles), (len(smiles), len(rere_smiles))
        for old, new in zip(smiles, rere_smiles):
            assert old == new, (old, new)
    def test_cycle_length(self):
        """Report grammar rules with long cycles or several connected components.

        Loads ``hyper_grammar.pickle``. For every non-``None`` rule, it prints
        the length of the longest cycle in the rule graph's minimum cycle basis
        when that exceeds 7. It also prints the number of connected components
        when there is more than one. Nothing is asserted; the output is for
        manual inspection.
        """
        grammar = HypergraphGrammar.load('hyper_grammar.pickle')
        for rule in grammar.rules:
            if rule is None:
                continue
            # use a distinct name so the grammar is not shadowed by each rule graph
            rule_graph = rule.to_nx()
            cycles = nx.minimum_cycle_basis(rule_graph)
            if cycles:
                longest = max(len(c) for c in cycles)
                if longest > 7:
                    print(longest)

            components = nx.number_connected_components(rule_graph)
            if components > 1:
                print(components)
Exemplo n.º 11
0
def get_codec(molecules, grammar, max_seq_length):
    """Build the codec (and, for grammar codecs, its mask generator).

    Args:
        molecules: If true, use the molecule (SMILES) vocabulary or grammar;
            otherwise the equation one. Only the character-level and
            ``'classic'`` branches read it.
        grammar: Which encoding to use:
            ``False`` - character-level ``CharacterCodec``;
            ``True`` or ``'classic'`` - ``CFGrammarCodec`` with the classic
            ZINC/equation grammar;
            ``'new'`` - ``CFGrammarCodec`` with the newer ZINC grammar;
            ``'hypergraph:<cache_file>'`` - ``HypergraphGrammar`` loaded from
            ``<cache_file>``.
        max_seq_length: Maximum sequence length passed to the codec and its
            mask generator.

    Returns:
        The codec. All grammar-based codecs also get a ``mask_gen`` attribute.

    Raises:
        ValueError: If ``grammar`` is not a recognized value, or a hypergraph
            spec does not name a cache file.
    """
    if grammar is True:
        grammar = 'classic'

    if grammar is False:
        # character-based models
        if molecules:
            charlist = [
                'C', '(', ')', 'c', '1', '2', 'o', '=', 'O', 'N', '3', 'F',
                '[', '@', 'H', ']', 'n', '-', '#', 'S', 'l', '+', 's', 'B',
                'r', '/', '4', '\\', '5', '6', '7', 'I', 'P', '8', ' '
            ]
        else:
            charlist = [
                'x', '+', '(', ')', '1', '2', '3', '*', '/', 's', 'i', 'n',
                'e', 'p', ' '
            ]

        codec = CharacterCodec(max_len=max_seq_length, charlist=charlist)
    elif grammar == 'classic':
        if molecules:
            codec = CFGrammarCodec(max_len=max_seq_length,
                                   grammar=grammar_zinc,
                                   tokenizer=zinc_tokenizer)
        else:
            codec = CFGrammarCodec(max_len=max_seq_length,
                                   grammar=grammar_eq,
                                   tokenizer=eq_tokenizer)
        codec.mask_gen = GrammarMaskGenerator(max_seq_length, codec.grammar)
    elif grammar == 'new':
        codec = CFGrammarCodec(max_len=max_seq_length,
                               grammar=grammar_zinc_new,
                               tokenizer=zinc_tokenizer_new)
        codec.mask_gen = GrammarMaskGeneratorNew(max_seq_length, codec.grammar)
    elif isinstance(grammar, str) and 'hypergraph' in grammar:
        # expected form 'hypergraph:<cache_file>'; a missing or empty cache
        # file name used to surface as IndexError or a confusing load failure
        parts = grammar.split(':')
        grammar_cache = parts[1] if len(parts) > 1 else ''
        if not grammar_cache:
            raise ValueError(
                "Invalid cached hypergraph grammar file in {!r}; expected "
                "'hypergraph:<cache_file>'".format(grammar))
        codec = HypergraphGrammar.load(grammar_cache)
        codec.MAX_LEN = max_seq_length
        codec.normalize_conditional_frequencies()
        codec.calc_terminal_distance()  # just in case it wasn't initialized yet
        codec.mask_gen = HypergraphMaskGenerator(max_seq_length, codec)
    else:
        # previously this fell through and failed with UnboundLocalError on `codec`
        raise ValueError("Unknown grammar: {!r}".format(grammar))
    assert hasattr(codec, 'PAD_INDEX')
    return codec
Exemplo n.º 12
0
    def test_graph_from_graph_tree_idempotent(self):
        """``graph_from_graph_tree`` can be applied twice to the same parse tree.

        Regression test: the second call used to fail because
        ``remove_nonterminals`` mutated the parent ``tree.node`` state. It was
        solved by copying there. Both calls must recover ``smiles1``.
        """
        grammar = HypergraphGrammar()
        grammar.strings_to_actions(smiles)
        grammar.calc_terminal_distance()

        parse_tree = grammar.normalize_tree(hypergraph_parser(MolFromSmiles(smiles1)))

        # two conversions of the same tree, in order
        graphs = [graph_from_graph_tree(parse_tree) for _ in range(2)]
        mols = [to_mol(graph) for graph in graphs]
        recovered = [MolToSmiles(mol) for mol in mols]

        self.assertEqual(smiles1, recovered[0])
        self.assertEqual(recovered[0], recovered[1])
Exemplo n.º 13
0
    settings = get_settings(molecules=True, grammar='new')
    thresh = 100000
    # Read in the strings
    f = open(settings['source_data'], 'r')
    L = []
    for line in f:
        line = line.strip()
        L.append(line)
        if len(L) > thresh:
            break
    f.close()

    fn = "rule_hypergraphs.pickle"
    max_rules = 50
    if os.path.isfile(fn):
        rm = HypergraphGrammar.load(fn)
    else:
        rm = HypergraphGrammar(cache_file=fn)
        bad_smiles = []
        for num, smile in enumerate(L[:100]):
            try:
                smile = MolToSmiles(MolFromSmiles(smile))
                print(smile)
                mol = MolFromSmiles(smile)
                actions = rm.strings_to_actions([smile])
                re_smile = rm.decode_from_actions(actions)[0]
                mol = MolFromSmiles(smile)
                if re_smile != smile:
                    print("SMILES reconstruction wasn't perfect for " + smile)
                print(re_smile)