def test_parser_roundtrip_via_indices(self):
    """Check that SMILES strings survive an encode/decode round trip.

    Each string in the module-level ``smiles`` is converted to grammar
    actions and decoded back, twice in a row; the result of the second
    pass must equal the original string exactly.
    """
    # TODO: cheating a bit here, reconstruction does fail for some smiles
    # chirality gets inverted sometimes, so need to run the loop twice to
    # reconstruct the original
    g = HypergraphGrammar()
    g.delete_cache()

    # First pass: strings -> actions -> strings.
    actions = g.strings_to_actions(smiles)
    re_smiles = g.decode_from_actions(actions)
    # Second pass over the output of the first one (see TODO above).
    re_actions = g.strings_to_actions(re_smiles)
    rere_smiles = g.decode_from_actions(re_actions)

    # zip() stops at the shorter input, so without this check a decoder that
    # dropped molecules (or returned nothing at all) would pass vacuously.
    assert len(rere_smiles) == len(smiles), (
        "round trip returned a different number of molecules"
    )
    for old, new in zip(smiles, rere_smiles):
        assert old == new
def test_mask_gen(self):
    """Smoke-test random sampling driven by ``HypergraphMaskGenerator``.

    A batch of action sequences is built by repeatedly asking the mask
    generator for the next masks and picking, per batch element, a random
    index where the mask is non-zero. The loop ends when the generator
    raises ``StopIteration``. There are no assertions: the test passes if
    sampling and decoding complete without raising.
    """
    grammar = HypergraphGrammar()
    # Parsing the molecules initializes the grammar with their rules.
    grammar.strings_to_actions(smiles)
    grammar.calc_terminal_distance()

    batch_size = 10
    max_rules = 50
    mask_gen = HypergraphMaskGenerator(max_rules, grammar)

    sampled_steps = []
    # The first call receives one ``None`` per batch element.
    current = [None] * batch_size
    while True:
        try:
            masks = mask_gen(current)
        except StopIteration:
            break
        # One randomly chosen non-zero mask position per batch element.
        current = [random.choice(np.nonzero(mask)[0]) for mask in masks]
        sampled_steps.append(current)

    # Steps were collected step-major; transpose to one row per batch element.
    action_matrix = np.array(sampled_steps).T

    # the test is that we get that far, producing valid molecules
    decoded = grammar.decode_from_actions(action_matrix)
    for smile in decoded:
        print(smile)
f.close() fn = "rule_hypergraphs.pickle" max_rules = 50 if os.path.isfile(fn): rm = HypergraphGrammar.load(fn) else: rm = HypergraphGrammar(cache_file=fn) bad_smiles = [] for num, smile in enumerate(L[:100]): try: smile = MolToSmiles(MolFromSmiles(smile)) print(smile) mol = MolFromSmiles(smile) actions = rm.strings_to_actions([smile]) re_smile = rm.decode_from_actions(actions)[0] mol = MolFromSmiles(smile) if re_smile != smile: print("SMILES reconstruction wasn't perfect for " + smile) print(re_smile) except Exception as e: print(e) bad_smiles.append(smile) rm.calc_terminal_distance() # now let's write a basic for-loop to create molecules batch_size = 10 all_actions = [] next_action = [None for _ in range(batch_size)]