def test_acnn():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')

    pocket_mol, pocket_coords = load_molecule(
        'tmp2/example_mols/example.pdb', remove_hs=True)
    ligand_mol, ligand_coords = load_molecule(
        'tmp2/example_mols/example.pdbqt', remove_hs=True)

    remove_dir('tmp1')
    remove_dir('tmp2')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g1 = ACNN_graph_construction_and_featurization(ligand_mol,
                                                   pocket_mol,
                                                   ligand_coords,
                                                   pocket_coords)

    model = ACNN()
    model.to(device)
    g1.to(device)
    assert model(g1).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([g1, g1])
    bg.to(device)
    assert model(bg).shape == torch.Size([2, 1])

    model = ACNN(hidden_sizes=[1, 2],
                 weight_init_stddevs=[1, 1],
                 dropouts=[0.1, 0.],
                 features_to_use=torch.tensor([6., 8.]),
                 radial=[[12.0], [0.0, 2.0], [4.0]])
    model.to(device)
    g1.to(device)
    assert model(g1).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([g1, g1])
    bg.to(device)
    assert model(bg).shape == torch.Size([2, 1])
예제 #2
0
def test_load_molecule():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')

    load_molecule('tmp2/example_mols/example.sdf')
    load_molecule('tmp2/example_mols/example.mol2',
                  use_conformation=False,
                  sanitize=True)
    load_molecule('tmp2/example_mols/example.pdbqt', calc_charges=True)
    mol, _ = load_molecule('tmp2/example_mols/example.pdb', remove_hs=True)
    assert mol.GetNumAtoms() == mol.GetNumHeavyAtoms()

    remove_dir('tmp1')
    remove_dir('tmp2')
예제 #3
0
def test_acnn_graph_construction_and_featurization():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')

    pocket_mol, pocket_coords = load_molecule('tmp2/example_mols/example.pdb',
                                              remove_hs=True)
    ligand_mol, ligand_coords = load_molecule(
        'tmp2/example_mols/example.pdbqt', remove_hs=True)
    pocket_mol_with_h, pocket_coords_with_h = load_molecule(
        'tmp2/example_mols/example.pdb', remove_hs=False)

    remove_dir('tmp1')
    remove_dir('tmp2')

    # Test default case
    g = ACNN_graph_construction_and_featurization(ligand_mol, pocket_mol,
                                                  ligand_coords, pocket_coords)
    assert set(g.ntypes) == set(['protein_atom', 'ligand_atom'])
    assert set(g.etypes) == set(
        ['protein', 'ligand', 'complex', 'complex', 'complex', 'complex'])
    assert g.number_of_nodes('protein_atom') == 286
    assert g.number_of_nodes('ligand_atom') == 21

    assert g.number_of_edges('protein') == 3432
    assert g.number_of_edges('ligand') == 252
    assert g.number_of_edges(
        ('protein_atom', 'complex', 'protein_atom')) == 3349
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 131
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 121
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 83

    assert 'atomic_number' in g.nodes['protein_atom'].data
    assert 'atomic_number' in g.nodes['ligand_atom'].data
    assert torch.allclose(g.nodes['protein_atom'].data['mask'],
                          torch.ones(g.number_of_nodes('protein_atom'), 1))
    assert torch.allclose(g.nodes['ligand_atom'].data['mask'],
                          torch.ones(g.number_of_nodes('ligand_atom'), 1))
    assert 'distance' in g.edges['protein'].data
    assert 'distance' in g.edges['ligand'].data
    assert 'distance' in g.edges[('protein_atom', 'complex',
                                  'protein_atom')].data
    assert 'distance' in g.edges[('ligand_atom', 'complex',
                                  'ligand_atom')].data
    assert 'distance' in g.edges[('protein_atom', 'complex',
                                  'ligand_atom')].data
    assert 'distance' in g.edges[('ligand_atom', 'complex',
                                  'protein_atom')].data

    # Test max_num_ligand_atoms and max_num_protein_atoms
    max_num_ligand_atoms = 30
    max_num_protein_atoms = 300
    g = ACNN_graph_construction_and_featurization(
        ligand_mol,
        pocket_mol,
        ligand_coords,
        pocket_coords,
        max_num_ligand_atoms=max_num_ligand_atoms,
        max_num_protein_atoms=max_num_protein_atoms)
    assert g.number_of_nodes('ligand_atom') == max_num_ligand_atoms
    assert g.number_of_nodes('protein_atom') == max_num_protein_atoms
    ligand_mask = torch.zeros(max_num_ligand_atoms, 1)
    ligand_mask[:ligand_mol.GetNumAtoms(), :] = 1.
    assert torch.allclose(ligand_mask, g.nodes['ligand_atom'].data['mask'])
    protein_mask = torch.zeros(max_num_protein_atoms, 1)
    protein_mask[:pocket_mol.GetNumAtoms(), :] = 1.
    assert torch.allclose(protein_mask, g.nodes['protein_atom'].data['mask'])

    # Test neighbor_cutoff
    neighbor_cutoff = 6.
    g = ACNN_graph_construction_and_featurization(
        ligand_mol,
        pocket_mol,
        ligand_coords,
        pocket_coords,
        neighbor_cutoff=neighbor_cutoff)
    assert g.number_of_edges('protein') == 3405
    assert g.number_of_edges('ligand') == 193
    assert g.number_of_edges(
        ('protein_atom', 'complex', 'protein_atom')) == 3331
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 123
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 119
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 82

    # Test max_num_neighbors
    g = ACNN_graph_construction_and_featurization(ligand_mol,
                                                  pocket_mol,
                                                  ligand_coords,
                                                  pocket_coords,
                                                  max_num_neighbors=6)
    assert g.number_of_edges('protein') == 1716
    assert g.number_of_edges('ligand') == 126
    assert g.number_of_edges(
        ('protein_atom', 'complex', 'protein_atom')) == 1691
    assert g.number_of_edges(('ligand_atom', 'complex', 'ligand_atom')) == 86
    assert g.number_of_edges(('protein_atom', 'complex', 'ligand_atom')) == 40
    assert g.number_of_edges(('ligand_atom', 'complex', 'protein_atom')) == 25

    # Test strip_hydrogens
    g = ACNN_graph_construction_and_featurization(pocket_mol_with_h,
                                                  pocket_mol_with_h,
                                                  pocket_coords_with_h,
                                                  pocket_coords_with_h,
                                                  strip_hydrogens=True)
    assert g.number_of_nodes('ligand_atom') != pocket_mol_with_h.GetNumAtoms()
    assert g.number_of_nodes('protein_atom') != pocket_mol_with_h.GetNumAtoms()
    non_h_atomic_numbers = []
    for i in range(pocket_mol_with_h.GetNumAtoms()):
        atom = pocket_mol_with_h.GetAtomWithIdx(i)
        if atom.GetSymbol() != 'H':
            non_h_atomic_numbers.append(atom.GetAtomicNum())
    non_h_atomic_numbers = torch.tensor(non_h_atomic_numbers).float().reshape(
        -1, 1)
    assert torch.allclose(non_h_atomic_numbers,
                          g.nodes['ligand_atom'].data['atomic_number'])
    assert torch.allclose(non_h_atomic_numbers,
                          g.nodes['protein_atom'].data['atomic_number'])