def test_simple():
    """Check paths and pointers for the four-atom molecule 'CC(O)C'.

    The first call uses the default path length and expects a path for
    every atom pair and no pointers. The second call limits paths to
    length 1 and expects only the bonded pairs as paths, with every
    remaining pair (in both orders) pointing at atom 1.
    """
    mol = Chem.MolFromSmiles('CC(O)C')

    # Default settings: every pair has an explicit path, no pointers needed.
    all_paths, all_pointers, _ = get_shortest_paths(mol)
    expected_paths = {
        (0, 1): (0, 1),
        (0, 2): (0, 1, 2),
        (0, 3): (0, 1, 3),
        (1, 2): (1, 2),
        (1, 3): (1, 3),
        (2, 3): (2, 1, 3),
    }
    assert_dict_equal(all_paths, expected_paths)
    assert_dict_equal(all_pointers, {})

    # Paths capped at a single bond: only directly bonded pairs remain.
    short_paths, short_pointers, _ = get_shortest_paths(
        mol, max_path_length=1)
    expected_short_paths = {
        (0, 1): (0, 1),
        (1, 2): (1, 2),
        (1, 3): (1, 3),
    }
    # Pointer entries are expected for both orderings of each distant pair.
    distant_pairs = [(0, 2), (0, 3), (2, 3)]
    expected_pointers = {
        key: 1
        for a, b in distant_pairs
        for key in ((a, b), (b, a))
    }
    assert_dict_equal(short_paths, expected_short_paths)
    assert_dict_equal(short_pointers, expected_pointers)
def test_path_input():
    """Smoke-test building per-molecule path inputs and merging them.

    This test contains no assertions; it only verifies that
    ``get_path_input`` and ``merge_path_inputs`` run without raising and
    that the merge returns a pair.

    NOTE(review): shortest paths are precomputed with a length of 5 while
    ``args.max_path_length`` is 3 -- confirm the mismatch is intentional.
    """
    args = get_args()
    args.max_path_length = 3
    args.self_attn = True
    args.ring_embed = True

    mols = [Chem.MolFromSmiles(s) for s in ('C1CC1CO', 'o2c1ccccc1cc2OC')]
    n_atoms = [mol.GetNumAtoms() for mol in mols]
    shortest_paths = [get_shortest_paths(mol, 5) for mol in mols]

    # Build the inputs one molecule at a time, each as a batch of size one.
    single_batches = [
        path_utils.get_path_input(
            [mol], [paths], atom_count, args, output_tensor=False)
        for mol, paths, atom_count in zip(mols, shortest_paths, n_atoms)
    ]

    # Remove the batch dimension from every input/mask before merging.
    inputs, masks = [], []
    for batch_input, batch_mask in single_batches:
        inputs.append(batch_input.squeeze(0))
        masks.append(batch_mask.squeeze(0))

    merged_input, merged_mask = path_utils.merge_path_inputs(
        inputs, masks, max(n_atoms), args)
def test_simple_aro_ring():
    """Every atom pair of 'c1ccccc1' must map to one ring entry (6, True).

    The ring dictionary is expected to hold exactly the pairs (i, j) with
    i <= j, including each atom paired with itself.
    """
    mol = Chem.MolFromSmiles('c1ccccc1')
    atom_total = mol.GetNumAtoms()

    _, _, ring_info = get_shortest_paths(mol, max_path_length=3)

    # All pairs with first index <= second index, self-pairs included.
    expected_keys = [
        (first, second)
        for first in range(atom_total)
        for second in range(first, atom_total)
    ]
    for key in expected_keys:
        assert key in ring_info
        assert ring_info[key] == [(6, True)]
    assert len(ring_info) == len(expected_keys)
def test_simple_ring():
    """Every atom pair of 'C1NCCCC1' must map to one ring entry (6, False).

    The ring dictionary is expected to hold exactly the pairs (i, j) with
    i <= j, including each atom paired with itself.
    """
    mol = Chem.MolFromSmiles('C1NCCCC1')
    atom_total = mol.GetNumAtoms()

    _, _, ring_info = get_shortest_paths(mol, max_path_length=5)

    # All pairs with first index <= second index, self-pairs included.
    expected_keys = [
        (first, second)
        for first in range(atom_total)
        for second in range(first, atom_total)
    ]
    for key in expected_keys:
        assert key in ring_info
        assert ring_info[key] == [(6, False)]
    assert len(ring_info) == len(expected_keys)
def test_fused_ring():
    """Check path and ring lookups on the fused bicycle 'o2c1ccccc1cc2'.

    With paths limited to length 3, the pair (0, 1) must have a path and
    (0, 4) must not. Pair (1, 6) must list both a 5- and a 6-membered
    ring, while pair (1, 3) must list only the 6-membered ring.
    """
    mol = Chem.MolFromSmiles('o2c1ccccc1cc2')

    paths, _, ring_info = get_shortest_paths(mol, max_path_length=3)

    # Path availability under the length limit.
    assert (0, 1) in paths
    assert (0, 4) not in paths

    # Ring entries are keyed with the smaller index first.
    assert (6, 1) not in ring_info
    assert (1, 6) in ring_info
    assert ring_info[(1, 6)] == [(5, True), (6, True)]

    # A pair contained in only one of the two rings.
    assert (1, 3) in ring_info
    assert ring_info[(1, 3)] == [(6, True)]
def test_complex():
    """Check path/pointer membership on a larger substituted ring molecule.

    With paths limited to length 3, pairs within that distance must appear
    in the path dictionary under their ordered key only, and pairs at
    distance 4 must appear in the pointer dictionary under both orderings.
    """
    mol = Chem.MolFromSmiles('COCCCNc1nc(NC(C)C)nc(SC)n1')

    paths, pointers, _ = get_shortest_paths(mol, max_path_length=3)

    # Pairs within the length limit have explicit paths.
    assert (5, 8) in paths
    assert (13, 15) in paths
    assert (13, 16) in paths

    # Reversed keys are never stored; pairs of path length 4 are excluded.
    assert (8, 5) not in paths  # only ordered pairs are included
    assert (5, 9) not in paths  # path length is 4
    assert (8, 16) not in paths  # path length is 4

    # A pair that has a path must not also have a pointer.
    assert (8, 5) not in pointers
    # Pairs beyond the limit get pointers in both directions.
    assert (5, 9) in pointers
    assert (8, 16) in pointers
    assert (9, 5) in pointers
    assert (16, 8) in pointers