예제 #1
0
    def test_get_set_params_as_vector(self):
        """Check that a value written through the parameter-vector API reads back.

        Loads a cached hypergraph grammar, wraps it in a
        ``CondtionalProbabilityModel``, overwrites the first entry of the
        model's parameter vector with 1, pushes the vector back with
        ``set_params_from_vector`` and asserts that the first entry of a fresh
        ``get_params_as_vector`` call equals the written value.
        """
        # Alternative cache file: 'hyper_grammar.pickle'
        cache_file = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'
        model = CondtionalProbabilityModel(HypergraphGrammar.load(cache_file))

        params = model.get_params_as_vector()
        params[0] = 1
        model.set_params_from_vector(params)

        round_tripped = model.get_params_as_vector()
        assert round_tripped[0] == params[0]
    def test_cycle_length(self):
        """Print diagnostics about the graphs of the cached grammar's rules.

        For every non-None rule in ``hyper_grammar.pickle``, converts the rule
        to a networkx graph with ``rule.to_nx()`` and prints:

        * the size of the largest cycle in its minimum cycle basis, when that
          size exceeds 7;
        * its number of connected components, when there is more than one.

        This is diagnostic only: nothing is asserted.
        """
        grammar = HypergraphGrammar.load('hyper_grammar.pickle')
        for rule in grammar.rules:
            if rule is None:
                continue
            # Use a separate name so the loaded grammar is not shadowed by each
            # rule's graph (the original reused `g` for both).
            rule_graph = rule.to_nx()

            cycles = nx.minimum_cycle_basis(rule_graph)
            if cycles:
                maxlen = max(len(c) for c in cycles)
                if maxlen > 7:
                    print(maxlen)

            cc = nx.number_connected_components(rule_graph)
            if cc > 1:
                print(cc)
예제 #3
0
def get_codec(molecules, grammar, max_seq_length):
    """Build the sequence codec selected by ``grammar``.

    Args:
        molecules: If truthy, use the molecule (SMILES / ZINC) character list
            or grammar; otherwise use the equation one. Not consulted for the
            ``'new'`` or hypergraph grammars.
        grammar: One of:
            * ``False``: character-level ``CharacterCodec``;
            * ``True`` or ``'classic'``: classic ``CFGrammarCodec`` with a
              ``GrammarMaskGenerator``;
            * ``'new'``: ``CFGrammarCodec`` over the new ZINC grammar with a
              ``GrammarMaskGeneratorNew``;
            * a string containing ``'hypergraph'`` of the form
              ``'hypergraph:<cache file>'``: a ``HypergraphGrammar`` loaded
              from that cache file, with a ``HypergraphMaskGenerator``.
        max_seq_length: Maximum sequence length handed to the codec and, where
            one is attached, to its mask generator.

    Returns:
        The configured codec. It is checked to expose ``PAD_INDEX``.

    Raises:
        ValueError: If ``grammar`` is not a recognized value, or a hypergraph
            spec does not name a cache file after the first ``':'``.
    """
    if grammar is True:
        grammar = 'classic'

    if grammar is False:
        # Character-based models.
        if molecules:
            charlist = [
                'C', '(', ')', 'c', '1', '2', 'o', '=', 'O', 'N', '3', 'F',
                '[', '@', 'H', ']', 'n', '-', '#', 'S', 'l', '+', 's', 'B',
                'r', '/', '4', '\\', '5', '6', '7', 'I', 'P', '8', ' '
            ]
        else:
            charlist = [
                'x', '+', '(', ')', '1', '2', '3', '*', '/', 's', 'i', 'n',
                'e', 'p', ' '
            ]
        codec = CharacterCodec(max_len=max_seq_length, charlist=charlist)
    elif grammar == 'classic':
        if molecules:
            codec = CFGrammarCodec(max_len=max_seq_length,
                                   grammar=grammar_zinc,
                                   tokenizer=zinc_tokenizer)
        else:
            codec = CFGrammarCodec(max_len=max_seq_length,
                                   grammar=grammar_eq,
                                   tokenizer=eq_tokenizer)
        codec.mask_gen = GrammarMaskGenerator(max_seq_length, codec.grammar)
    elif grammar == 'new':
        # Note: `molecules` is not consulted here; only the ZINC grammar exists.
        codec = CFGrammarCodec(max_len=max_seq_length,
                               grammar=grammar_zinc_new,
                               tokenizer=zinc_tokenizer_new)
        codec.mask_gen = GrammarMaskGeneratorNew(max_seq_length, codec.grammar)
    elif isinstance(grammar, str) and 'hypergraph' in grammar:
        # Split on the first ':' only, so cache paths that themselves contain
        # ':' (e.g. Windows drive letters) are kept intact. A missing or empty
        # path previously surfaced as IndexError or a confusing load failure.
        _, sep, grammar_cache = grammar.partition(':')
        if not sep or not grammar_cache:
            raise ValueError(
                "Invalid cached hypergraph grammar file in " + repr(grammar) +
                "; expected 'hypergraph:<cache file>'")
        codec = HypergraphGrammar.load(grammar_cache)
        codec.MAX_LEN = max_seq_length
        codec.normalize_conditional_frequencies()
        codec.calc_terminal_distance()  # just in case it wasn't initialized yet
        codec.mask_gen = HypergraphMaskGenerator(max_seq_length, codec)
    else:
        # Previously an unknown value fell through and crashed below with
        # UnboundLocalError on `codec`.
        raise ValueError("Unknown grammar: " + repr(grammar))

    assert hasattr(codec, 'PAD_INDEX')
    return codec
예제 #4
0
    settings = get_settings(molecules=True, grammar='new')
    thresh = 100000
    # Read in the strings
    f = open(settings['source_data'], 'r')
    L = []
    for line in f:
        line = line.strip()
        L.append(line)
        if len(L) > thresh:
            break
    f.close()

    fn = "rule_hypergraphs.pickle"
    max_rules = 50
    if os.path.isfile(fn):
        rm = HypergraphGrammar.load(fn)
    else:
        rm = HypergraphGrammar(cache_file=fn)
        bad_smiles = []
        for num, smile in enumerate(L[:100]):
            try:
                smile = MolToSmiles(MolFromSmiles(smile))
                print(smile)
                mol = MolFromSmiles(smile)
                actions = rm.strings_to_actions([smile])
                re_smile = rm.decode_from_actions(actions)[0]
                mol = MolFromSmiles(smile)
                if re_smile != smile:
                    print("SMILES reconstruction wasn't perfect for " + smile)
                print(re_smile)