def test_simple():
    """Check shortest-path bookkeeping for ``CC(O)C`` with and without a length cap."""
    mol = Chem.MolFromSmiles('CC(O)C')

    # No cap: every ordered pair (i < j) gets an explicit path and no
    # pointers are produced.
    paths_dict, pointer_dict, _ = get_shortest_paths(mol)
    expected_paths = {
        (0, 1): (0, 1),
        (0, 2): (0, 1, 2),
        (0, 3): (0, 1, 3),
        (1, 2): (1, 2),
        (1, 3): (1, 3),
        (2, 3): (2, 1, 3),
    }
    assert_dict_equal(paths_dict, expected_paths)
    assert_dict_equal(pointer_dict, {})

    # max_path_length=1: only single-bond paths stay explicit; the two-bond
    # pairs move into pointer_dict, listed in both orientations.
    paths_dict, pointer_dict, _ = get_shortest_paths(mol, max_path_length=1)
    expected_paths = {(0, 1): (0, 1), (1, 2): (1, 2), (1, 3): (1, 3)}
    expected_pointers = {
        (0, 2): 1, (2, 0): 1,
        (0, 3): 1, (3, 0): 1,
        (2, 3): 1, (3, 2): 1,
    }
    assert_dict_equal(paths_dict, expected_paths)
    assert_dict_equal(pointer_dict, expected_pointers)
def test_path_input():
    """Build per-molecule path inputs and merge them into one padded batch.

    NOTE(review): this test contains no assertions; it only verifies that
    ``path_utils.get_path_input`` and ``path_utils.merge_path_inputs`` run
    without raising. TODO: assert on the merged outputs once their expected
    shapes/contents are pinned down.
    """
    args = get_args()
    args.max_path_length = 3
    args.self_attn = True
    args.ring_embed = True

    smiles_list = ['C1CC1CO', 'o2c1ccccc1cc2OC']
    mol_list = [Chem.MolFromSmiles(smi) for smi in smiles_list]
    atom_counts = [mol.GetNumAtoms() for mol in mol_list]
    path_info = [get_shortest_paths(mol, 5) for mol in mol_list]

    # Phase 1: compute each molecule's path input as a batch of one.
    raw_outputs = []
    for mol, paths, n_atom in zip(mol_list, path_info, atom_counts):
        raw_outputs.append(path_utils.get_path_input(
            [mol], [paths], n_atom, args, output_tensor=False))

    # Phase 2: drop the leading batch dimension from every input and mask.
    inputs = [single_input.squeeze(0) for single_input, _ in raw_outputs]
    masks = [single_mask.squeeze(0) for _, single_mask in raw_outputs]

    path_input, path_mask = path_utils.merge_path_inputs(
        inputs, masks, max(atom_counts), args)
def test_simple_aro_ring():
    """Every atom pair in ``c1ccccc1`` should map to the ring entry ``[(6, True)]``."""
    mol = Chem.MolFromSmiles('c1ccccc1')
    num_atoms = mol.GetNumAtoms()
    _, _, rings_dict = get_shortest_paths(mol, max_path_length=3)

    # Ordered pairs (a <= b), including each atom paired with itself.
    expected_pairs = [
        (a, b) for a in range(num_atoms) for b in range(a, num_atoms)
    ]
    for pair in expected_pairs:
        assert pair in rings_dict
        assert rings_dict[pair] == [(6, True)]
    # No entries beyond those pairs.
    assert len(rings_dict) == len(expected_pairs)
def test_simple_ring():
    """Every atom pair in ``C1NCCCC1`` should map to the ring entry ``[(6, False)]``."""
    mol = Chem.MolFromSmiles('C1NCCCC1')
    num_atoms = mol.GetNumAtoms()
    _, _, rings_dict = get_shortest_paths(mol, max_path_length=5)

    # Ordered pairs (a <= b), including each atom paired with itself.
    expected_pairs = [
        (a, b) for a in range(num_atoms) for b in range(a, num_atoms)
    ]
    for pair in expected_pairs:
        assert pair in rings_dict
        assert rings_dict[pair] == [(6, False)]
    # No entries beyond those pairs.
    assert len(rings_dict) == len(expected_pairs)
def test_fused_ring():
    """Check path and ring entries on the fused 5/6 system ``o2c1ccccc1cc2``."""
    mol = Chem.MolFromSmiles('o2c1ccccc1cc2')
    paths_dict, _, rings_dict = get_shortest_paths(mol, max_path_length=3)

    # Atoms 0-1 are bonded; atoms 0 and 4 are 4 bonds apart, past the cap of 3.
    assert (0, 1) in paths_dict
    assert (0, 4) not in paths_dict

    # Ring entries are stored under the ordered pair only.
    assert (6, 1) not in rings_dict
    assert (1, 6) in rings_dict
    # Atoms 1 and 6 open/close ring label 1 in the SMILES, so they belong to
    # both the five- and six-membered rings.
    assert rings_dict[(1, 6)] == [(5, True), (6, True)]
    assert (1, 3) in rings_dict
    assert rings_dict[(1, 3)] == [(6, True)]
def test_complex():
    """Check the explicit-path vs pointer split on a larger molecule with cap 3."""
    mol = Chem.MolFromSmiles('COCCCNc1nc(NC(C)C)nc(SC)n1')
    paths_dict, pointer_dict, _ = get_shortest_paths(mol, max_path_length=3)

    # Ordered pairs within 3 bonds get explicit paths.
    for pair in [(5, 8), (13, 15), (13, 16)]:
        assert pair in paths_dict

    # (8, 5): reversed order is never stored; (5, 9) and (8, 16): path length is 4.
    for pair in [(8, 5), (5, 9), (8, 16)]:
        assert pair not in paths_dict

    # The reversed in-range pair must not leak into pointers either.
    assert (8, 5) not in pointer_dict

    # Out-of-range pairs become pointers, in both orientations.
    for pair in [(5, 9), (8, 16), (9, 5), (16, 8)]:
        assert pair in pointer_dict