def test_get_set_params_as_vector(self):
    """Round-trip check on the model's parameter vector.

    Mutates one entry of the flattened parameter vector, writes the
    vector back into the model, and verifies that reading the
    parameters out again reflects the mutation.
    """
    # NOTE(review): cache name suggests a 10k-sample GuacaMol grammar with
    # clique collapse — confirm against how the pickle was produced.
    grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  # 'hyper_grammar.pickle'
    grammar = HypergraphGrammar.load(grammar_cache)
    model = CondtionalProbabilityModel(grammar)  # [sic] class name as declared elsewhere
    params = model.get_params_as_vector()
    params[0] = 1
    model.set_params_from_vector(params)
    reread = model.get_params_as_vector()
    assert reread[0] == params[0]
def test_cycle_length(self):
    """Inspect every rule's graph for long cycles and disconnection.

    For each non-None rule in the grammar, prints the length of the
    longest cycle in the minimum cycle basis when it exceeds 7, and
    prints the number of connected components when the rule's graph
    is not connected.
    """
    grammar = HypergraphGrammar.load('hyper_grammar.pickle')
    # Fix: the original rebound `g` (the grammar being iterated) to each
    # rule's graph inside the loop — it only worked because the iterator
    # was already captured. Use distinct names to remove the shadowing.
    for rule in grammar.rules:
        if rule is None:
            continue
        rule_graph = rule.to_nx()
        cycles = nx.minimum_cycle_basis(rule_graph)
        if cycles:
            longest = max(len(c) for c in cycles)
            if longest > 7:
                print(longest)
        components = nx.number_connected_components(rule_graph)
        if components > 1:
            print(components)
def get_codec(molecules, grammar, max_seq_length):
    """Build a sequence codec (and mask generator) for the requested grammar.

    Args:
        molecules: True for SMILES data, False for symbolic equations.
        grammar: selects the codec:
            False                     -> character-level codec
            True or 'classic'         -> classic CFG codec
            'new'                     -> new CFG codec
            'hypergraph:<cache_file>' -> cached hypergraph grammar codec
        max_seq_length: maximum sequence length the codec will handle.

    Returns:
        A codec object; CFG/hypergraph codecs have `mask_gen` attached.

    Raises:
        ValueError: if `grammar` is not one of the recognized specifications,
            or a hypergraph spec is missing its cache-file part.
    """
    if grammar is True:  # legacy flag: True meant the classic CFG
        grammar = 'classic'
    if grammar is False:
        # character-based models
        if molecules:
            charlist = ['C', '(', ')', 'c', '1', '2', 'o', '=', 'O', 'N', '3',
                        'F', '[', '@', 'H', ']', 'n', '-', '#', 'S', 'l', '+',
                        's', 'B', 'r', '/', '4', '\\', '5', '6', '7', 'I', 'P',
                        '8', ' ']
        else:
            charlist = ['x', '+', '(', ')', '1', '2', '3', '*', '/', 's', 'i',
                        'n', 'e', 'p', ' ']
        codec = CharacterCodec(max_len=max_seq_length, charlist=charlist)
    elif grammar == 'classic':
        if molecules:
            codec = CFGrammarCodec(max_len=max_seq_length,
                                   grammar=grammar_zinc,
                                   tokenizer=zinc_tokenizer)
        else:
            codec = CFGrammarCodec(max_len=max_seq_length,
                                   grammar=grammar_eq,
                                   tokenizer=eq_tokenizer)
        codec.mask_gen = GrammarMaskGenerator(max_seq_length, codec.grammar)
    elif grammar == 'new':
        codec = CFGrammarCodec(max_len=max_seq_length,
                               grammar=grammar_zinc_new,
                               tokenizer=zinc_tokenizer_new)
        codec.mask_gen = GrammarMaskGeneratorNew(max_seq_length, codec.grammar)
    elif isinstance(grammar, str) and 'hypergraph' in grammar:
        # Expected form 'hypergraph:<cache_file>'. The original used
        # split(':')[1], which raised a bare IndexError with no colon and
        # silently accepted an empty cache name ('' is not None), making
        # the validation assert unreachable. partition + ValueError makes
        # both failure modes explicit (and survives python -O).
        _, sep, grammar_cache = grammar.partition(':')
        if not sep or not grammar_cache:
            raise ValueError(
                "Invalid cached hypergraph grammar file:" + str(grammar_cache))
        codec = HypergraphGrammar.load(grammar_cache)
        codec.MAX_LEN = max_seq_length
        codec.normalize_conditional_frequencies()
        codec.calc_terminal_distance()  # just in case it wasn't initialized yet
        codec.mask_gen = HypergraphMaskGenerator(max_seq_length, codec)
    else:
        # Previously an unrecognized spec fell through and died with
        # NameError on `codec`; fail with a clear message instead.
        raise ValueError("Unknown grammar specification: " + str(grammar))
    assert hasattr(codec, 'PAD_INDEX')
    return codec
settings = get_settings(molecules=True, grammar='new') thresh = 100000 # Read in the strings f = open(settings['source_data'], 'r') L = [] for line in f: line = line.strip() L.append(line) if len(L) > thresh: break f.close() fn = "rule_hypergraphs.pickle" max_rules = 50 if os.path.isfile(fn): rm = HypergraphGrammar.load(fn) else: rm = HypergraphGrammar(cache_file=fn) bad_smiles = [] for num, smile in enumerate(L[:100]): try: smile = MolToSmiles(MolFromSmiles(smile)) print(smile) mol = MolFromSmiles(smile) actions = rm.strings_to_actions([smile]) re_smile = rm.decode_from_actions(actions)[0] mol = MolFromSmiles(smile) if re_smile != smile: print("SMILES reconstruction wasn't perfect for " + smile) print(re_smile)