def test_gbsa_handler():
    """Check that a GBSAHandler survives a serialize/deserialize round trip.

    The handler is built from SMIRKS patterns that each carry two parameters,
    plus a dict of global properties. It is serialized and deserialized, and
    the single recovered handler must have the same SMIRKS, parameters and
    properties as the original.
    """
    patterns = [
        ['[*:1]', 99., 99.],
        ['[#1:1]', 99., 99.],
        ['[#1:1]~[#7]', 99., 99.],
        ['[#6:1]', 0.1, 0.2],
        ['[#7:1]', 0.3, 0.4],
        ['[#8:1]', 0.5, 0.6],
        ['[#9:1]', 0.7, 0.8],
        ['[#14:1]', 99., 99.],
        ['[#15:1]', 99., 99.],
        ['[#16:1]', 99., 99.],
        ['[#17:1]', 99., 99.],
    ]

    # Handler-wide properties (not per-pattern parameters).
    props = dict(
        solvent_dielectric=78.3,  # matches OBC2,
        solute_dielectric=1.0,
        probe_radius=0.14,
        surface_tension=28.3919551,
        dielectric_offset=0.009,
        # GBOBC1
        alpha=0.8,
        beta=0.0,
        gamma=2.909125,
    )

    smirks = [row[0] for row in patterns]
    params = np.array([row[1:] for row in patterns])

    handler = nonbonded.GBSAHandler(smirks, params, props)
    restored = deserialize_handlers(bin_to_str(handler.serialize()))

    assert len(restored) == 1

    roundtrip = restored[0]
    np.testing.assert_equal(roundtrip.smirks, handler.smirks)
    np.testing.assert_equal(roundtrip.params, handler.params)
    assert roundtrip.props == handler.props
def test_am1ccc():
    """Check that an AM1CCCHandler survives a serialize/deserialize round trip.

    Each pattern is a (smirks, value) pair. Every value is multiplied by
    sqrt(138.935456) before it is passed to the handler; ``props`` is None.
    """
    patterns = [
        ['[#6X4:1]-[#1:2]', 0.46323257920556493],
        ['[#6X3$(*=[#8,#16]):1]-[#6a:2]', 0.24281402370571598],
        ['[#6X3$(*=[#8,#16]):1]-[#8X1,#8X2:2]', 1.0620166764992722],
        ['[#6X3$(*=[#8,#16]):1]=[#8X1$(*=[#6X3]-[#8X2]):2]', 2.227759732057297],
        ['[#6X3$(*=[#8,#16]):1]=[#8X1,#8X2:2]', 2.8182928673804217],
        ['[#6a:1]-[#8X1,#8X2:2]', 0.5315976926761063],
        ['[#6a:1]-[#1:2]', 0.0],
        ['[#6a:1]:[#6a:2]', 0.0],
        # NOTE(review): this row duplicates the previous one -- confirm intent.
        ['[#6a:1]:[#6a:2]', 0.0],
        ['[#8X1,#8X2:1]-[#1:2]', -2.3692047944101415],
        ['[#16:1]-[#8:2]', 99.],
    ]

    scale = np.sqrt(138.935456)
    smirks = [pattern for pattern, _ in patterns]
    params = np.array([value * scale for _, value in patterns])
    props = None

    handler = nonbonded.AM1CCCHandler(smirks, params, props)
    restored = deserialize_handlers(bin_to_str(handler.serialize()))

    assert len(restored) == 1

    roundtrip = restored[0]
    np.testing.assert_equal(roundtrip.smirks, handler.smirks)
    np.testing.assert_equal(roundtrip.params, handler.params)
    assert roundtrip.props == handler.props
Example #3
0
def test_combine_recipe():
    """Check that combining two recipes concatenates their exclusion indices.

    A ligand recipe (aspirin, parameterized from smirnoff_1_1_0_ccc.py) and a
    protein recipe (built from an OpenMM system) are combined in both orders.
    For each order, the exclusion indices of the combined recipe's last bound
    potential must equal the left recipe's exclusions followed by the right
    recipe's exclusions shifted by the number of left particles. Every bound
    potential of the combined recipe must also build a float32 implementation.
    """
    with open('ff/params/smirnoff_1_1_0_ccc.py') as ff_file:
        ff_handlers = deserialize_handlers(ff_file.read())

    aspirin = Chem.AddHs(Chem.MolFromSmiles("CC(=O)OC1=CC=CC=C1C(=O)O"))
    AllChem.EmbedMolecule(aspirin)
    ligand_recipe = md.Recipe.from_rdkit(aspirin, ff_handlers)

    fname = 'tests/data/hif2a_nowater_min.pdb'
    openmm_system, _, _, _, _, _ = builders.build_protein_system(fname)
    protein_recipe = md.Recipe.from_openmm(openmm_system)

    for left_recipe, right_recipe in [[protein_recipe, ligand_recipe],
                                      [ligand_recipe, protein_recipe]]:

        combined_recipe = left_recipe.combine(right_recipe)

        # NOTE(review): the last bound potential is treated as the nonbonded
        # term in every recipe -- confirm this ordering is guaranteed.
        left_nonbonded_potential = left_recipe.bound_potentials[-1]
        right_nonbonded_potential = right_recipe.bound_potentials[-1]
        combined_nonbonded_potential = combined_recipe.bound_potentials[-1]

        left_idxs = left_nonbonded_potential.get_exclusion_idxs()
        right_idxs = right_nonbonded_potential.get_exclusion_idxs()
        combined_idxs = combined_nonbonded_potential.get_exclusion_idxs()

        # Right-hand particles follow the left-hand ones in the combined
        # recipe, so their exclusion indices are offset by n_left.
        n_left = len(left_recipe.masses)

        np.testing.assert_array_equal(
            np.concatenate([left_idxs, right_idxs + n_left]), combined_idxs)

        for bp in combined_recipe.bound_potentials:
            bp.bound_impl(precision=np.float32)
def test_proper_torsion():
    """Check that a ProperTorsionHandler survives a serialize/deserialize round trip.

    Proper torsions have a variable number of terms per SMIRKS pattern, so
    ``params`` is a ragged list: each entry is a list of 3-element rows.
    """
    patterns = [
        ['[*:1]-[#6X3:2]=[#6X3:3]-[*:4]', [[99., 99., 99.]]],
        ['[*:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[99., 99., 99.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[1., 2., 3.], [4., 5., 6.]]],
        ['[#35:1]-[#6X3:2]=[#6X3:3]-[#35:4]',
         [[7., 8., 9.], [1., 3., 5.], [4., 4., 4.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#9:4]', [[7., 8., 9.]]],
    ]

    smirks = [pattern for pattern, _ in patterns]
    params = [terms for _, terms in patterns]
    props = None

    handler = bonded.ProperTorsionHandler(smirks, params, props)
    restored = deserialize_handlers(bin_to_str(handler.serialize()))

    assert len(restored) == 1

    roundtrip = restored[0]
    np.testing.assert_equal(roundtrip.smirks, handler.smirks)
    np.testing.assert_equal(roundtrip.params, handler.params)
    assert roundtrip.props == handler.props
def test_improper_torsion():
    """Check that an ImproperTorsionHandler survives a serialize/deserialize round trip.

    Each pattern is a SMIRKS string followed by exactly three parameters.
    The parameters are packed into an (n_patterns, 3) array.
    """
    patterns = [
        ('[*:1]~[#6X3:2](~[*:3])~[*:4]',
         1.5341333333333333, 3.141592653589793, 2.0),
        ('[*:1]~[#6X3:2](~[#8X1:3])~[#8:4]', 99., 99., 99.),
        ('[*:1]~[#7X3$(*~[#15,#16](!-[*])):2](~[*:3])~[*:4]', 99., 99., 99.),
        ('[*:1]~[#7X3$(*~[#6X3]):2](~[*:3])~[*:4]',
         1.3946666666666667, 3.141592653589793, 2.0),
        ('[*:1]~[#7X3$(*~[#7X2]):2](~[*:3])~[*:4]', 99., 99., 99.),
        ('[*:1]~[#7X3$(*@1-[*]=,:[*][*]=,:[*]@1):2](~[*:3])~[*:4]',
         99., 99., 99.),
        ('[*:1]~[#6X3:2](=[#7X2,#7X3+1:3])~[#7:4]', 99., 99., 99.),
    ]

    smirks = [row[0] for row in patterns]
    params = np.array([row[1:] for row in patterns])
    handler = bonded.ImproperTorsionHandler(smirks, params, None)

    restored = deserialize_handlers(bin_to_str(handler.serialize()))

    assert len(restored) == 1

    roundtrip = restored[0]
    np.testing.assert_equal(roundtrip.smirks, handler.smirks)
    np.testing.assert_equal(roundtrip.params, handler.params)
    assert roundtrip.props == handler.props
Example #6
0
def test_recipe_from_rdkit():
    """Smoke-test md.Recipe.from_rdkit on the first molecules of ligands_40.sdf.

    Each molecule's index and SMILES are printed before it is parameterized,
    so a failing molecule can be identified from the output. The loop stops
    after molecule index 3, i.e. four molecules are processed.
    """
    with open('ff/params/smirnoff_1_1_0_ccc.py') as ff_file:
        ff_handlers = deserialize_handlers(ff_file.read())

    suppl = Chem.SDMolSupplier('tests/data/ligands_40.sdf', removeHs=False)
    for mol_idx, mol in enumerate(suppl):
        print(mol_idx, Chem.MolToSmiles(mol))
        md.Recipe.from_rdkit(mol, ff_handlers)
        if mol_idx > 2:
            break
def test_simple_charge_handler():
    """Check that a SimpleChargeHandler survives a serialize/deserialize round trip.

    Each pattern is a (smirks, value) pair. The values form a 1-D parameter
    array and ``props`` is None.
    """
    patterns = [
        ('[#1:1]', 99.),
        ('[#1:1]-[#6X4]', 99.),
        ('[#1:1]-[#6X4]-[#7,#8,#9,#16,#17,#35]', 99.),
        ('[#1:1]-[#6X4](-[#7,#8,#9,#16,#17,#35])-[#7,#8,#9,#16,#17,#35]', 99.),
        ('[#1:1]-[#6X4](-[#7,#8,#9,#16,#17,#35])(-[#7,#8,#9,#16,#17,#35])-[#7,#8,#9,#16,#17,#35]',
         99.),
        ('[#1:1]-[#6X4]~[*+1,*+2]', 99.),
        ('[#1:1]-[#6X3]', 99.),
        ('[#1:1]-[#6X3]~[#7,#8,#9,#16,#17,#35]', 99.),
        ('[#1:1]-[#6X3](~[#7,#8,#9,#16,#17,#35])~[#7,#8,#9,#16,#17,#35]', 99.),
        ('[#1:1]-[#6X2]', 99.),
        ('[#1:1]-[#7]', 99.),
        ('[#1:1]-[#8]', 99.),
        ('[#1:1]-[#16]', 99.),
        ('[#6:1]', 0.7),
        ('[#6X2:1]', 99.),
        ('[#6X4:1]', 0.1),
        ('[#8:1]', 99.),
        ('[#8X2H0+0:1]', 0.5),
        ('[#8X2H1+0:1]', 99.),
        ('[#7:1]', 0.3),
        ('[#16:1]', 99.),
        ('[#15:1]', 99.),
        ('[#9:1]', 1.0),
        ('[#17:1]', 99.),
        ('[#35:1]', 99.),
        ('[#53:1]', 99.),
        ('[#3+1:1]', 99.),
        ('[#11+1:1]', 99.),
        ('[#19+1:1]', 99.),
        ('[#37+1:1]', 99.),
        ('[#55+1:1]', 99.),
        ('[#9X0-1:1]', 99.),
        ('[#17X0-1:1]', 99.),
        ('[#35X0-1:1]', 99.),
        ('[#53X0-1:1]', 99.),
    ]

    smirks = [pattern for pattern, _ in patterns]
    params = np.array([value for _, value in patterns])
    props = None

    handler = nonbonded.SimpleChargeHandler(smirks, params, props)
    restored = deserialize_handlers(bin_to_str(handler.serialize()))

    assert len(restored) == 1

    roundtrip = restored[0]
    np.testing.assert_equal(roundtrip.smirks, handler.smirks)
    np.testing.assert_equal(roundtrip.params, handler.params)
    assert roundtrip.props == handler.props
def test_am1bcc():
    """Check that an empty AM1BCCHandler survives a serialize/deserialize round trip.

    The handler has no SMIRKS patterns, no parameters and ``props`` of None.
    Exactly one handler must be recovered, and it must match the original.
    """
    smirks = []
    params = []
    props = None

    am1 = nonbonded.AM1BCCHandler(smirks, params, props)
    obj = am1.serialize()
    all_handlers = deserialize_handlers(bin_to_str(obj))

    assert len(all_handlers) == 1

    # Compare the deserialized copy against the original handler, not
    # against itself.
    new_am1 = all_handlers[0]
    np.testing.assert_equal(new_am1.smirks, am1.smirks)
    np.testing.assert_equal(new_am1.params, am1.params)
    assert new_am1.props == am1.props
Example #9
0
    def setUp(self, *args, **kwargs):
        """Load the benzene/phenol molecule pair and build the force field.

        Sets ``self.mol_a`` and ``self.mol_b`` to the first two molecules in
        tests/data/benzene_phenol_sparse.sdf (hydrogens kept). Sets
        ``self.ff`` to a Forcefield built from
        ff/params/smirnoff_1_1_0_recharge.py. ``*args``/``**kwargs`` are
        accepted for signature compatibility and are not used.
        """
        # Chain to the base-class fixture hook. Calling __init__ here would
        # re-initialize the TestCase instead.
        super().setUp()

        suppl = Chem.SDMolSupplier('tests/data/benzene_phenol_sparse.sdf',
                                   removeHs=False)
        all_mols = [x for x in suppl]

        self.mol_a = all_mols[0]
        self.mol_b = all_mols[1]

        # atom type free
        with open('ff/params/smirnoff_1_1_0_recharge.py') as ff_file:
            ff_handlers = deserialize_handlers(ff_file.read())

        self.ff = Forcefield(ff_handlers)
Example #10
0
    def test_bad_factor(self):
        """A mapping that leaves a non-cancellable endpoint must be rejected.

        SingleTopology must raise topology.AtomMappingError for the core below,
        which maps the first two ligands of tests/data/ligands_40.sdf.
        """
        suppl = Chem.SDMolSupplier('tests/data/ligands_40.sdf', removeHs=False)
        all_mols = [x for x in suppl]
        mol_a = all_mols[0]
        mol_b = all_mols[1]

        with open('ff/params/smirnoff_1_1_0_recharge.py') as ff_file:
            ff_handlers = deserialize_handlers(ff_file.read())
        ff = Forcefield(ff_handlers)

        # Rows are presumably (mol_a atom index, mol_b atom index) pairs --
        # confirm against SingleTopology.
        core = np.array([[4, 1], [5, 2], [6, 3], [7, 4], [8, 5], [9, 6],
                         [10, 7], [11, 8], [12, 9], [13, 10], [15, 11],
                         [16, 12], [18, 14], [34, 31], [17, 13], [23, 23],
                         [33, 30], [32, 28], [31, 27], [30, 26], [19, 15],
                         [20, 16], [21, 17]])

        with self.assertRaises(topology.AtomMappingError):
            topology.SingleTopology(mol_a, mol_b, core, ff)
Example #11
0
def test_am1_differences():
    """Compare AM1CCC charges with AM1BCC charges for a single debug molecule.

    The molecule is parameterized three ways: plain AM1, the AM1CCCHandler from
    ff/params/smirnoff_1_1_0_ccc.py, and AM1BCC. If the summed absolute
    CCC-vs-BCC difference exceeds 0.1, the test prints a per-atom table and
    fails. In the table, atoms whose CCC and BCC values differ by more than
    0.1 are marked with '*'.

    Raises
    ------
    AssertionError
        If the force field has no AM1CCCHandler, if embedding fails, or if
        the CCC and BCC charges disagree beyond the tolerance.
    """
    with open("ff/params/smirnoff_1_1_0_ccc.py") as ff_file:
        ff_handlers = deserialize_handlers(ff_file.read())

    # Pick the AM1CCC handler explicitly. A bare for/break would silently
    # fall back to the last handler if none matched.
    ccc = next(
        (h for h in ff_handlers if isinstance(h, nonbonded.AM1CCCHandler)),
        None)
    assert ccc is not None, "smirnoff_1_1_0_ccc.py has no AM1CCCHandler"

    smi = "Clc1c(Cl)c(Cl)c(-c2c(Cl)c(Cl)c(Cl)c(Cl)c2Cl)c(Cl)c1Cl"
    mol = Chem.AddHs(Chem.MolFromSmiles(smi))
    mol.SetProp("_Name", "Debug")
    assert AllChem.EmbedMolecule(mol) == 0

    am1 = nonbonded.AM1Handler([], [], None)
    bcc = nonbonded.AM1BCCHandler([], [], None)

    am1_params = am1.parameterize(mol)
    ccc_params = ccc.parameterize(mol)
    bcc_params = bcc.parameterize(mol)

    total_diff = np.sum(np.abs(ccc_params - bcc_params))
    if total_diff > 0.1:
        print(mol.GetProp("_Name"), Chem.MolToSmiles(mol))
        print("  AM1    CCC    BCC  S ?")
        for atom_idx, atom in enumerate(mol.GetAtoms()):
            a = am1_params[atom_idx]
            b = bcc_params[atom_idx]
            c = ccc_params[atom_idx]
            marker = " *" if np.abs(b - c) > 0.1 else " "
            print(f"{a:6.2f} {c:6.2f} {b:6.2f} {atom.GetSymbol()}{marker}")

        raise AssertionError(
            f"AM1CCC and AM1BCC charges differ: summed |ccc - bcc| = "
            f"{total_diff:.3f} > 0.1")
Example #12
0
    def test_good_factor(self):
        """A valid core mapping builds a SingleTopology whose vjps can be formed.

        Maps ligand 1 onto ligand 4 of tests/data/ligands_40.sdf. It then
        checks that jax.vjp can be constructed (with ``has_aux=True``) for
        each bonded parameterization and for the nonbonded parameterization.
        """
        suppl = Chem.SDMolSupplier('tests/data/ligands_40.sdf', removeHs=False)
        all_mols = [x for x in suppl]
        mol_a = all_mols[1]
        mol_b = all_mols[4]

        with open('ff/params/smirnoff_1_1_0_recharge.py') as ff_file:
            ff_handlers = deserialize_handlers(ff_file.read())
        ff = Forcefield(ff_handlers)

        core = np.array([[0, 0], [2, 2], [1, 1], [6, 6], [5, 5], [4, 4],
                         [3, 3], [15, 16], [16, 17], [17, 18], [18, 19],
                         [19, 20], [20, 21], [32, 30], [26, 25], [27, 26],
                         [7, 7], [8, 8], [9, 9], [10, 10], [29, 11], [11, 12],
                         [12, 13], [14, 15], [31, 29], [13, 14], [23, 24],
                         [30, 28], [28, 27], [21, 22]])

        st = topology.SingleTopology(mol_a, mol_b, core, ff)

        # test that the vjps work
        bonded_terms = [
            (st.parameterize_harmonic_bond, ff.hb_handle),
            (st.parameterize_harmonic_angle, ff.ha_handle),
            (st.parameterize_proper_torsion, ff.pt_handle),
            (st.parameterize_improper_torsion, ff.it_handle),
        ]
        for parameterize, handle in bonded_terms:
            jax.vjp(parameterize, handle.params, has_aux=True)

        # the nonbonded term takes both charge and LJ parameters
        jax.vjp(st.parameterize_nonbonded,
                ff.q_handle.params,
                ff.lj_handle.params,
                has_aux=True)
def test_harmonic_bond():
    """Check that a HarmonicBondHandler survives a serialize/deserialize round trip.

    Each pattern is a SMIRKS string followed by two parameters. The
    parameters are packed into an (n_patterns, 2) array; ``props`` is None.
    """
    patterns = [
        ('[#6X4:1]-[#6X4:2]', 0.1, 0.2),
        ('[#6X4:1]-[#6X3:2]', 99., 99.),
        ('[#6X4:1]-[#6X3:2]=[#8X1+0]', 99., 99.),
        ('[#6X3:1]-[#6X3:2]', 99., 99.),
        ('[#6X3:1]:[#6X3:2]', 99., 99.),
        ('[#6X3:1]=[#6X3:2]', 99., 99.),
        ('[#6:1]-[#7:2]', 0.1, 0.2),
        ('[#6X3:1]-[#7X3:2]', 99., 99.),
        ('[#6X4:1]-[#7X3:2]-[#6X3]=[#8X1+0]', 99., 99.),
        ('[#6X3:1](=[#8X1+0])-[#7X3:2]', 99., 99.),
        ('[#6X3:1]-[#7X2:2]', 99., 99.),
        ('[#6X3:1]:[#7X2,#7X3+1:2]', 99., 99.),
        ('[#6X3:1]=[#7X2,#7X3+1:2]', 99., 99.),
        ('[#6:1]-[#8:2]', 99., 99.),
        ('[#6X3:1]-[#8X1-1:2]', 99., 99.),
        ('[#6X4:1]-[#8X2H0:2]', 0.3, 0.4),
        ('[#6X3:1]-[#8X2:2]', 99., 99.),
        ('[#6X3:1]-[#8X2H1:2]', 99., 99.),
        ('[#6X3a:1]-[#8X2H0:2]', 99., 99.),
        ('[#6X3:1](=[#8X1])-[#8X2H0:2]', 99., 99.),
        ('[#6:1]=[#8X1+0,#8X2+1:2]', 99., 99.),
        ('[#6X3:1](~[#8X1])~[#8X1:2]', 99., 99.),
        ('[#6X3:1]~[#8X2+1:2]~[#6X3]', 99., 99.),
        ('[#6X2:1]-[#6:2]', 99., 99.),
        ('[#6X2:1]-[#6X4:2]', 99., 99.),
        ('[#6X2:1]=[#6X3:2]', 99., 99.),
        ('[#6:1]#[#7:2]', 99., 99.),
        ('[#6X2:1]#[#6X2:2]', 99., 99.),
        ('[#6X2:1]-[#8X2:2]', 99., 99.),
        ('[#6X2:1]-[#7:2]', 99., 99.),
        ('[#6X2:1]=[#7:2]', 99., 99.),
        ('[#16:1]=[#6:2]', 99., 99.),
        ('[#6X2:1]=[#16:2]', 99., 99.),
        ('[#7:1]-[#7:2]', 99., 99.),
        ('[#7X3:1]-[#7X2:2]', 99., 99.),
        ('[#7X2:1]-[#7X2:2]', 99., 99.),
        ('[#7:1]:[#7:2]', 99., 99.),
        ('[#7:1]=[#7:2]', 99., 99.),
        ('[#7+1:1]=[#7-1:2]', 99., 99.),
        ('[#7:1]#[#7:2]', 99., 99.),
        ('[#7:1]-[#8X2:2]', 99., 99.),
        ('[#7:1]~[#8X1:2]', 99., 99.),
        ('[#8X2:1]-[#8X2:2]', 99., 99.),
        ('[#16:1]-[#6:2]', 99., 99.),
        ('[#16:1]-[#1:2]', 99., 99.),
        ('[#16:1]-[#16:2]', 99., 99.),
        ('[#16:1]-[#9:2]', 99., 99.),
        ('[#16:1]-[#17:2]', 99., 99.),
        ('[#16:1]-[#35:2]', 99., 99.),
        ('[#16:1]-[#53:2]', 99., 99.),
        ('[#16X2,#16X1-1,#16X3+1:1]-[#6X4:2]', 99., 99.),
        ('[#16X2,#16X1-1,#16X3+1:1]-[#6X3:2]', 99., 99.),
        ('[#16X2:1]-[#7:2]', 99., 99.),
        ('[#16X2:1]-[#8X2:2]', 99., 99.),
        ('[#16X2:1]=[#8X1,#7X2:2]', 99., 99.),
        ('[#16X4,#16X3!+1:1]-[#6:2]', 99., 99.),
        ('[#16X4,#16X3:1]~[#7:2]', 99., 99.),
        ('[#16X4,#16X3:1]-[#8X2:2]', 99., 99.),
        ('[#16X4,#16X3:1]~[#8X1:2]', 99., 99.),
        ('[#15:1]-[#1:2]', 99., 99.),
        ('[#15:1]~[#6:2]', 99., 99.),
        ('[#15:1]-[#7:2]', 99., 99.),
        ('[#15:1]=[#7:2]', 99., 99.),
        ('[#15:1]~[#8X2:2]', 99., 99.),
        ('[#15:1]~[#8X1:2]', 99., 99.),
        ('[#16:1]-[#15:2]', 99., 99.),
        ('[#15:1]=[#16X1:2]', 99., 99.),
        ('[#6:1]-[#9:2]', 99., 99.),
        ('[#6X4:1]-[#9:2]', 0.6, 0.7),
        ('[#6:1]-[#17:2]', 99., 99.),
        ('[#6X4:1]-[#17:2]', 99., 99.),
        ('[#6:1]-[#35:2]', 99., 99.),
        ('[#6X4:1]-[#35:2]', 99., 99.),
        ('[#6:1]-[#53:2]', 99., 99.),
        ('[#6X4:1]-[#53:2]', 99., 99.),
        ('[#7:1]-[#9:2]', 99., 99.),
        ('[#7:1]-[#17:2]', 99., 99.),
        ('[#7:1]-[#35:2]', 99., 99.),
        ('[#7:1]-[#53:2]', 99., 99.),
        ('[#15:1]-[#9:2]', 99., 99.),
        ('[#15:1]-[#17:2]', 99., 99.),
        ('[#15:1]-[#35:2]', 99., 99.),
        ('[#15:1]-[#53:2]', 99., 99.),
        ('[#6X4:1]-[#1:2]', 99., 99.),
        ('[#6X3:1]-[#1:2]', 99., 99.),
        ('[#6X2:1]-[#1:2]', 99., 99.),
        ('[#7:1]-[#1:2]', 99., 99.),
        ('[#8:1]-[#1:2]', 99., 99.1),
    ]

    smirks = [pattern for pattern, _, _ in patterns]
    params = np.array([[p0, p1] for _, p0, p1 in patterns])
    props = None
    handler = bonded.HarmonicBondHandler(smirks, params, props)

    restored = deserialize_handlers(bin_to_str(handler.serialize()))

    assert len(restored) == 1

    roundtrip = restored[0]
    np.testing.assert_equal(roundtrip.smirks, handler.smirks)
    np.testing.assert_equal(roundtrip.params, handler.params)

    assert roundtrip.props == handler.props
Example #14
0
def main(args, stage):
    """Run one ``rbfe`` stage for a ligand pair in a water box.

    Both ligands are parameterized with smirnoff_1_1_0_ccc.py and combined
    with a 4.0 water box. The host plus ligand A is minimized in a worker
    process. Then 10 epochs of runs over the lambda schedule are executed,
    and each epoch prints the trapezoid integral of the per-window averages
    returned by ``run`` (labelled "dG").

    Parameters
    ----------
    args : namespace
        Only ``args.num_gpus`` is read. It sets the worker-pool size and is
        the modulus used to spread lambda windows across devices.
    stage : int
        0 runs ``rbfe.stage_0`` with five lambda windows, all at 0.0.
        1 runs ``rbfe.stage_1`` with 60 windows from 0.0 to 1.2.

    Raises
    ------
    ValueError
        If ``stage`` is not 0 or 1.
    """
    if stage not in (0, 1):
        raise ValueError(f"stage must be 0 or 1, got {stage!r}")

    # NOTE(review): despite their names, "benzene" (ligand A) and "phenol"
    # (ligand B) are both built from the same SMILES string below.
    benzene = Chem.AddHs(Chem.MolFromSmiles("C1=CC=C2C=CC=CC2=C1"))  # a
    phenol = Chem.AddHs(Chem.MolFromSmiles("C1=CC=C2C=CC=CC2=C1"))  # b

    AllChem.EmbedMolecule(benzene)
    AllChem.EmbedMolecule(phenol)

    with open('ff/params/smirnoff_1_1_0_ccc.py') as ff_file:
        ff_handlers = deserialize_handlers(ff_file.read())
    r_benzene = Recipe.from_rdkit(benzene, ff_handlers)
    r_phenol = Recipe.from_rdkit(phenol, ff_handlers)

    r_combined = r_benzene.combine(r_phenol)

    n_a = benzene.GetNumAtoms()

    # Map atoms 0..9 of ligand A onto atoms 0..9 of ligand B. Ligand B's
    # atoms follow ligand A's in r_combined, so column 1 is offset by n_a.
    core_pairs = np.array([[i, i] for i in range(10)], dtype=np.int32)
    core_pairs[:, 1] += n_a

    a_idxs = np.arange(n_a)
    b_idxs = np.arange(phenol.GetNumAtoms()) + n_a

    core_k = 20.0

    if stage == 0:
        centroid_k = 200.0
        rbfe.stage_0(r_combined, b_idxs, core_pairs, centroid_k, core_k)
        lambda_schedule = np.array([0.0, 0.0, 0.0, 0.0, 0.0])
    else:
        rbfe.stage_1(r_combined, a_idxs, b_idxs, core_pairs, core_k)
        lambda_schedule = np.linspace(0.0, 1.2, 60)

    system, host_coords, box, _ = builders.build_water_system(4.0)

    r_host = Recipe.from_openmm(system)
    r_final = r_host.combine(r_combined)

    # minimize coordinates of host + ligand A
    ha_coords = np.concatenate([host_coords, get_romol_conf(benzene)])

    # We need to run this in a subprocess, since the cuda runtime must not be
    # initialized in the master thread (it is not fork safe). The pool is
    # created before any further setup, and the context manager shuts the
    # workers down once the map has returned.
    with Pool(args.num_gpus) as pool:
        r_minimize = minimize_setup(r_host, r_benzene)
        # pool.map returns a list with one entry per submitted job
        ha_coords = pool.map(
            minimize,
            [(r_minimize.bound_potentials, r_minimize.masses, ha_coords, box)],
            chunksize=1)[0]

    with Pool(args.num_gpus) as pool:
        x0 = np.concatenate([ha_coords, get_romol_conf(phenol)])

        masses = np.concatenate(
            [r_host.masses, r_benzene.masses, r_phenol.masses])

        seed = np.random.randint(np.iinfo(np.int32).max)

        intg = LangevinIntegrator(300.0, 1.5e-3, 1.0, masses, seed)

        # production run at various values of lambda
        for epoch in range(10):
            # lamb_idx % num_gpus is presumably the device index used by
            # `run` -- confirm against its signature.
            run_args = [
                (lamb, intg, r_final.bound_potentials, r_final.masses, x0, box,
                 lamb_idx % args.num_gpus, stage)
                for lamb_idx, lamb in enumerate(lambda_schedule)
            ]

            avg_du_dls = pool.map(run, run_args, chunksize=1)

            print("stage", stage, "epoch", epoch, "dG",
                  np.trapz(avg_du_dls, lambda_schedule))
def test_lennard_jones_handler():
    """Check that a LennardJonesHandler survives a serialize/deserialize round trip.

    Each pattern is a SMIRKS string followed by two parameters. The
    parameters are packed into an (n_patterns, 2) array. Exactly one handler
    must be recovered, and it must match the original.
    """
    patterns = [
        ['[#1:1]', 99., 999.],
        ['[#1:1]-[#6X4]', 99., 999.],
        ['[#1:1]-[#6X4]-[#7,#8,#9,#16,#17,#35]', 99., 999.],
        [
            '[#1:1]-[#6X4](-[#7,#8,#9,#16,#17,#35])-[#7,#8,#9,#16,#17,#35]',
            99., 999.
        ],
        [
            '[#1:1]-[#6X4](-[#7,#8,#9,#16,#17,#35])(-[#7,#8,#9,#16,#17,#35])-[#7,#8,#9,#16,#17,#35]',
            99., 999.
        ],
        ['[#1:1]-[#6X4]~[*+1,*+2]', 99., 999.],
        ['[#1:1]-[#6X3]', 99., 999.],
        ['[#1:1]-[#6X3]~[#7,#8,#9,#16,#17,#35]', 99., 999.],
        [
            '[#1:1]-[#6X3](~[#7,#8,#9,#16,#17,#35])~[#7,#8,#9,#16,#17,#35]',
            99., 999.
        ],
        ['[#1:1]-[#6X2]', 99., 999.],
        ['[#1:1]-[#7]', 99., 999.],
        ['[#1:1]-[#8]', 99., 999.],
        ['[#1:1]-[#16]', 99., 999.],
        ['[#6:1]', 0.7, 0.8],
        ['[#6X2:1]', 99., 999.],
        ['[#6X4:1]', 0.1, 0.2],
        ['[#8:1]', 99., 999.],
        ['[#8X2H0+0:1]', 0.5, 0.6],
        ['[#8X2H1+0:1]', 99., 999.],
        ['[#7:1]', 0.3, 0.4],
        ['[#16:1]', 99., 999.],
        ['[#15:1]', 99., 999.],
        ['[#9:1]', 1.0, 1.1],
        ['[#17:1]', 99., 999.],
        ['[#35:1]', 99., 999.],
        ['[#53:1]', 99., 999.],
        ['[#3+1:1]', 99., 999.],
        ['[#11+1:1]', 99., 999.],
        ['[#19+1:1]', 99., 999.],
        ['[#37+1:1]', 99., 999.],
        ['[#55+1:1]', 99., 999.],
        ['[#9X0-1:1]', 99., 999.],
        ['[#17X0-1:1]', 99., 999.],
        ['[#35X0-1:1]', 99., 999.],
        ['[#53X0-1:1]', 99., 999.],
    ]

    smirks = [x[0] for x in patterns]
    params = np.array([[x[1], x[2]] for x in patterns])
    props = None

    ljh = nonbonded.LennardJonesHandler(smirks, params, props)
    obj = ljh.serialize()
    all_handlers = deserialize_handlers(bin_to_str(obj))

    assert len(all_handlers) == 1

    # Compare the deserialized copy against the original handler, not
    # against itself.
    new_ljh = all_handlers[0]
    np.testing.assert_equal(new_ljh.smirks, ljh.smirks)
    np.testing.assert_equal(new_ljh.params, ljh.params)
    assert new_ljh.props == ljh.props
Example #16
0
    for address in worker_address_list:
        print("connecting to", address)
        channel = grpc.insecure_channel(address,
            options = [
                ('grpc.max_send_message_length', 500 * 1024 * 1024),
                ('grpc.max_receive_message_length', 500 * 1024 * 1024)
            ]
        )

        stub = service_pb2_grpc.WorkerStub(channel)
        stubs.append(stub)

    ff_raw = open(forcefield, "r").read()

    ff_handlers = deserialize_handlers(ff_raw)

    box_width = 3.0
    host_system, host_coords, box, _ = water_box.prep_system(box_width)

    lambda_schedule = np.array([float(x) for x in general_cfg['lambda_schedule'].split(',')])

    num_steps = int(general_cfg['n_steps'])

    for epoch in range(100):

        print("Starting Epoch", epoch, datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S"))

        epoch_dir = os.path.join(general_cfg["out_dir"], "epoch_"+str(epoch))

        if not os.path.exists(epoch_dir):
Example #17
0
def pose_dock(
    guests_sdfile,
    host_pdbfile,
    transition_type,
    n_steps,
    transition_steps,
    max_lambda,
    outdir,
    random_rotation=False,
    constant_atoms=None,
):
    """Runs short simulations in which the guests phase in or out over time

    Parameters
    ----------

    guests_sdfile: path to input sdf with guests to pose/dock
    host_pdbfile: path to host pdb file to dock into
    transition_type: "insertion" or "deletion"
    n_steps: how many total steps of simulation to do (recommended: <= 1000)
    transition_steps: how many steps to insert/delete the guest over (recommended: <= 500)
        (must be <= n_steps)
    max_lambda: lambda value the guest should insert from or delete to
        (recommended: 1.0 for work calculation, 0.25 to stay close to original pose)
        (must be =1 for work calculation to be applicable)
    outdir: where to write output (will be created if it does not already exist)
    random_rotation: whether to apply a random rotation to each guest before inserting
        (only allowed when transition_type is "insertion")
    constant_atoms: atom numbers from the host_pdbfile to hold mostly fixed across the simulation
        (1-indexed, like PDB files). Each listed atom's mass is increased by 50000.
        Defaults to None, meaning no atoms are held.

    Output
    ------

    A pdb & sdf file every 100 steps (outdir/<guest_name>_<step>.pdb)
    stdout every 100 steps noting the step number, lambda value, and energy
    stdout for each guest noting the work of transition
    stdout for each guest noting how long it took to run

    Note
    ----
    If any norm of force per atom exceeds 20000 kJ/(mol*nm) [MAX_NORM_FORCE defined in docking/report.py],
    the simulation for that guest will stop and the work will not be calculated.
    """
    assert transition_steps <= n_steps
    assert transition_type in ("insertion", "deletion")
    if random_rotation:
        assert transition_type == "insertion"

    # Avoid a shared mutable default; None means "no constant atoms".
    if constant_atoms is None:
        constant_atoms = []

    os.makedirs(outdir, exist_ok=True)

    host_mol = Chem.MolFromPDBFile(host_pdbfile, removeHs=False)
    amber_ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
    host_file = PDBFile(host_pdbfile)
    host_system = amber_ff.createSystem(
        host_file.topology,
        nonbondedMethod=app.NoCutoff,
        constraints=None,
        rigidWater=False,
    )
    host_conf = []
    for x, y, z in host_file.positions:
        host_conf.append([to_md_units(x), to_md_units(y), to_md_units(z)])
    host_conf = np.array(host_conf)

    final_potentials = []
    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        host_system, cutoff=1.2)
    host_nb_bp = None
    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert host_nb_bp is None
            host_nb_bp = bp
        else:
            final_potentials.append(bp)

    # TODO (ytz): we should really fix this later on. This padding was done to
    # address the particles that are too close to the boundary.
    padding = 0.1
    box_lengths = np.amax(host_conf, axis=0) - np.amin(host_conf, axis=0)
    box_lengths = box_lengths + padding
    box = np.eye(3, dtype=np.float64) * box_lengths

    # Read the guest force-field file once. The handlers are still
    # deserialized per guest, so each guest gets its own Forcefield.
    guest_ff_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "ff/params/smirnoff_1_1_0_ccc.py",
    )
    with open(guest_ff_path) as ff_file:
        guest_ff_raw = ff_file.read()

    suppl = Chem.SDMolSupplier(guests_sdfile, removeHs=False)
    for guest_mol in suppl:
        start_time = time.time()
        guest_name = guest_mol.GetProp("_Name")
        guest_ff_handlers = deserialize_handlers(guest_ff_raw)
        ff = Forcefield(guest_ff_handlers)
        guest_base_topology = topology.BaseTopology(guest_mol, ff)

        # combine
        hgt = topology.HostGuestTopology(host_nb_bp, guest_base_topology)
        # setup the parameter handlers for the ligand
        bonded_tuples = [[hgt.parameterize_harmonic_bond, ff.hb_handle],
                         [hgt.parameterize_harmonic_angle, ff.ha_handle],
                         [hgt.parameterize_proper_torsion, ff.pt_handle],
                         [hgt.parameterize_improper_torsion, ff.it_handle]]
        these_potentials = list(final_potentials)
        # instantiate the vjps while parameterizing (forward pass)
        for fn, handle in bonded_tuples:
            params, potential = fn(handle.params)
            these_potentials.append(potential.bind(params))
        nb_params, nb_potential = hgt.parameterize_nonbonded(
            ff.q_handle.params, ff.lj_handle.params)
        these_potentials.append(nb_potential.bind(nb_params))
        bps = these_potentials

        guest_masses = [a.GetMass() for a in guest_mol.GetAtoms()]
        masses = np.concatenate([host_masses, guest_masses])

        # PDB atom numbers are 1-indexed
        for atom_num in constant_atoms:
            masses[atom_num - 1] += 50000

        conformer = guest_mol.GetConformer(0)
        mol_conf = np.array(conformer.GetPositions(), dtype=np.float64)
        mol_conf = mol_conf / 10  # convert to md_units

        if random_rotation:
            center = np.mean(mol_conf, axis=0)
            mol_conf -= center
            from scipy.stats import special_ortho_group

            mol_conf = np.matmul(mol_conf, special_ortho_group.rvs(3))
            mol_conf += center

        x0 = np.concatenate([host_conf, mol_conf])  # combined geometry
        v0 = np.zeros_like(x0)

        seed = 2021
        intg = LangevinIntegrator(300, 1.5e-3, 1.0, masses, seed).impl()

        impls = []
        precision = np.float32
        for b in bps:
            p_impl = b.bound_impl(precision)
            impls.append(p_impl)

        ctxt = custom_ops.Context(x0, v0, box, intg, impls)

        # collect a du_dl calculation once every other step
        subsample_freq = 2
        du_dl_obs = custom_ops.FullPartialUPartialLambda(impls, subsample_freq)
        ctxt.add_observable(du_dl_obs)

        if transition_type == "insertion":
            new_lambda_schedule = np.concatenate([
                np.linspace(max_lambda, 0.0, transition_steps),
                np.zeros(n_steps - transition_steps),
            ])
        elif transition_type == "deletion":
            new_lambda_schedule = np.concatenate([
                np.linspace(0.0, max_lambda, transition_steps),
                np.ones(n_steps - transition_steps) * max_lambda,
            ])
        else:
            raise RuntimeError(
                'invalid `transition_type` (must be one of ["insertion", "deletion"])'
            )

        calc_work = True
        for step, lamb in enumerate(new_lambda_schedule):
            ctxt.step(lamb)
            if step % 100 == 0:
                report.report_step(ctxt, step, lamb, box, bps, impls,
                                   guest_name, n_steps, 'pose_dock')
                # host atoms come first in x0; *10 undoes the /10 conversion
                x_t = ctxt.get_x_t()
                host_coords = x_t[:len(host_conf)] * 10
                guest_coords = x_t[len(host_conf):] * 10
                report.write_frame(host_coords, host_mol, guest_coords,
                                   guest_mol, guest_name, outdir, step, 'pd')
            if step in (0, int(n_steps / 2), n_steps - 1):
                if report.too_much_force(ctxt, lamb, box, bps, impls):
                    calc_work = False
                    break

        full_du_dl = du_dl_obs.full_du_dl()

        # Note: this condition only applies for ABFE, not RBFE
        if abs(full_du_dl[0]) > 0.001 or abs(full_du_dl[-1]) > 0.001:
            print("Error: du_dl endpoints are not ~0")
            calc_work = False

        if calc_work:
            work = np.trapz(full_du_dl, new_lambda_schedule[::subsample_freq])
            print(f"guest_name: {guest_name}\twork: {work:.2f}")
        end_time = time.time()
        print(f"{guest_name} took {(end_time - start_time):.2f} seconds")
Example #18
0
        type=int,
        help="number of absolute lambda windows",
        required=True
    )

    cmd_args = parser.parse_args()

    multiprocessing.set_start_method('spawn') # CUDA runtime is not forkable
    pool = multiprocessing.Pool(cmd_args.num_gpus)

    suppl = Chem.SDMolSupplier('tests/data/benzene_fluorinated.sdf', removeHs=False)
    all_mols = [x for x in suppl]
    mol_a = all_mols[0]
    mol_b = all_mols[1]

    ff_handlers = deserialize_handlers(open('ff/params/smirnoff_1_1_0_ccc.py').read())
    ff = Forcefield(ff_handlers)

    # the water system first.
    solvent_system, solvent_coords, solvent_box, omm_topology = builders.build_water_system(4.0)
    solvent_box += np.eye(3)*0.1 # BFGS this later

    print("Minimizing the host structure to remove clashes.")
    minimized_solvent_coords = minimizer.minimize_host_4d(mol_a, solvent_system, solvent_coords, ff, solvent_box)

    absolute_lambda_schedule = np.concatenate([
        np.linspace(0.0, 0.333, cmd_args.num_absolute_windows - cmd_args.num_absolute_windows//3, endpoint=False),
        np.linspace(0.333, 1.0, cmd_args.num_absolute_windows//3),
    ])

    abs_dGs = []
Example #19
0
def calculate_rigorous_work(
    host_pdbfile, guests_sdfile, outdir, fewer_outfiles=False, no_outfiles=False
):
    """Run the host leg and the water leg of the work calculation for each guest.

    For every guest in ``guests_sdfile`` the guest is combined with (1) the
    solvated host built from ``host_pdbfile`` and (2) a pure water box, and
    ``run_leg`` is called once for each of these environments.

    Parameters
    ----------
    host_pdbfile: path to the host PDB file; it is solvated with
        ``builders.build_protein_system``
    guests_sdfile: path to an SDF file of guests (hydrogens are kept);
        records RDKit cannot parse are skipped with a message
    outdir: output directory, created if it does not already exist
    fewer_outfiles: forwarded unchanged to ``run_leg``
    no_outfiles: if True, the intermediate ``solvated_host.pdb`` and
        ``water_box.pdb`` files are deleted after being read back;
        also forwarded to ``run_leg``

    Raises
    ------
    ValueError: if the host or the water system has no Nonbonded potential.
    """
    os.makedirs(outdir, exist_ok=True)

    print(
        f"""
    HOST_PDBFILE = {host_pdbfile}
    GUESTS_SDFILE = {guests_sdfile}
    OUTDIR = {outdir}

    INSERTION_MAX_LAMBDA = {INSERTION_MAX_LAMBDA}
    DELETION_MAX_LAMBDA = {DELETION_MAX_LAMBDA}
    MIN_LAMBDA = {MIN_LAMBDA}
    TRANSITION_STEPS = {TRANSITION_STEPS}
    EQ1_STEPS = {EQ1_STEPS}
    EQ2_STEPS = {EQ2_STEPS}
    """
    )

    # Prepare host
    # TODO: handle extra (non-transitioning) guests?
    print("Solvating host...")
    (
        solvated_host_system,
        solvated_host_coords,
        _,
        _,
        host_box,
        solvated_topology,
    ) = builders.build_protein_system(host_pdbfile)

    # sometimes water boxes are sad. Should be minimized first; this is a workaround
    host_box += np.eye(3) * 0.1
    print("host box", host_box)

    # Round-trip through a PDB file to obtain an RDKit mol for the host.
    solvated_host_pdb = os.path.join(outdir, "solvated_host.pdb")
    writer = pdb_writer.PDBWriter([solvated_topology], solvated_host_pdb)
    writer.write_frame(solvated_host_coords)
    writer.close()
    solvated_host_mol = Chem.MolFromPDBFile(solvated_host_pdb, removeHs=False)
    if no_outfiles:
        os.remove(solvated_host_pdb)
    host_potentials, host_masses = openmm_deserializer.deserialize_system(solvated_host_system, cutoff=1.2)
    host_nb_bp, final_host_potentials = _split_nonbonded_potential(host_potentials)
    if host_nb_bp is None:
        raise ValueError("host system has no Nonbonded potential")

    # Prepare water box
    print("Generating water box...")
    # TODO: water box probably doesn't need to be this big
    box_lengths = host_box[np.diag_indices(3)]
    water_box_width = min(box_lengths)
    (
        water_system,
        orig_water_coords,
        water_box,
        water_topology,
    ) = builders.build_water_system(water_box_width)

    # sometimes water boxes are sad. should be minimized first; this is a workaround
    water_box += np.eye(3) * 0.1
    print("water box", water_box)

    # it's okay if the water box here and the solvated protein box don't align -- they have PBCs
    water_pdb = os.path.join(outdir, "water_box.pdb")
    writer = pdb_writer.PDBWriter([water_topology], water_pdb)
    writer.write_frame(orig_water_coords)
    writer.close()
    water_mol = Chem.MolFromPDBFile(water_pdb, removeHs=False)
    if no_outfiles:
        os.remove(water_pdb)

    water_potentials, water_masses = openmm_deserializer.deserialize_system(water_system, cutoff=1.2)
    water_nb_bp, final_water_potentials = _split_nonbonded_potential(water_potentials)
    if water_nb_bp is None:
        raise ValueError("water system has no Nonbonded potential")

    # The guest forcefield file is identical for every guest: read it once
    # (and close it) instead of re-opening it inside the loop.
    ff_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "ff/params/smirnoff_1_1_0_ccc.py",
    )
    with open(ff_path) as ff_file:
        ff_source = ff_file.read()

    # Run the procedure
    print("Getting guests...")
    suppl = Chem.SDMolSupplier(guests_sdfile, removeHs=False)
    for guest_idx, guest_mol in enumerate(suppl):
        # SDMolSupplier yields None for records it fails to parse.
        if guest_mol is None:
            print(f"Skipping guest record {guest_idx}: RDKit could not parse it")
            continue
        start_time = time.time()
        guest_name = guest_mol.GetProp("_Name")
        guest_conformer = guest_mol.GetConformer(0)
        orig_guest_coords = np.array(guest_conformer.GetPositions(), dtype=np.float64)
        orig_guest_coords = orig_guest_coords / 10  # convert to md_units
        # Handlers are deserialized per guest so no handler state is shared
        # between guests.
        ff = Forcefield(deserialize_handlers(ff_source))
        guest_base_top = topology.BaseTopology(guest_mol, ff)
        guest_masses = [a.GetMass() for a in guest_mol.GetAtoms()]

        # combine host & guest
        combined_bps = _build_guest_leg_potentials(
            host_nb_bp, guest_base_top, ff, final_host_potentials
        )
        combined_masses = np.concatenate([host_masses, guest_masses])
        run_leg(
            solvated_host_coords,
            orig_guest_coords,
            combined_bps,
            combined_masses,
            host_box,
            guest_name,
            "host",
            solvated_host_mol,
            guest_mol,
            outdir,
            fewer_outfiles,
            no_outfiles,
        )
        end_time = time.time()
        print(
            f"{guest_name} host leg time:", "%.2f" % (end_time - start_time), "seconds"
        )

        # combine water & guest
        combined_bps = _build_guest_leg_potentials(
            water_nb_bp, guest_base_top, ff, final_water_potentials
        )
        combined_masses = np.concatenate([water_masses, guest_masses])
        # Water-leg timing excludes parameterization (only run_leg is timed).
        start_time = time.time()
        run_leg(
            orig_water_coords,
            orig_guest_coords,
            combined_bps,
            combined_masses,
            water_box,
            guest_name,
            "water",
            water_mol,
            guest_mol,
            outdir,
            fewer_outfiles,
            no_outfiles,
        )
        end_time = time.time()
        print(
            f"{guest_name} water leg time:", "%.2f" % (end_time - start_time), "seconds"
        )


def _split_nonbonded_potential(all_potentials):
    """Separate the single Nonbonded potential from the other potentials.

    Returns ``(nonbonded, others)``: the lone ``potentials.Nonbonded`` term
    (None if there is none) and a list of every other potential, in input order.
    """
    nonbonded_bp = None
    other_bps = []
    for bp in all_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert nonbonded_bp is None
            nonbonded_bp = bp
        else:
            other_bps.append(bp)
    return nonbonded_bp, other_bps


def _build_guest_leg_potentials(nb_bp, guest_base_top, ff, base_potentials):
    """Return bound potentials for an environment (host or water) plus a guest.

    Order: the environment's non-nonbonded potentials, then the guest's
    harmonic bond, harmonic angle, proper torsion and improper torsion terms,
    then the combined environment+guest nonbonded term.
    """
    hgt = topology.HostGuestTopology(nb_bp, guest_base_top)
    # setup the parameter handlers for the ligand
    bonded_tuples = [
        [hgt.parameterize_harmonic_bond, ff.hb_handle],
        [hgt.parameterize_harmonic_angle, ff.ha_handle],
        [hgt.parameterize_proper_torsion, ff.pt_handle],
        [hgt.parameterize_improper_torsion, ff.it_handle],
    ]
    combined_bps = list(base_potentials)
    # instantiate the vjps while parameterizing (forward pass)
    for fn, handle in bonded_tuples:
        params, potential = fn(handle.params)
        combined_bps.append(potential.bind(params))
    nb_params, nb_potential = hgt.parameterize_nonbonded(ff.q_handle.params, ff.lj_handle.params)
    combined_bps.append(nb_potential.bind(nb_params))
    return combined_bps
Example #20
0
def dock_and_equilibrate(host_pdbfile,
                         guests_sdfile,
                         max_lambda,
                         insertion_steps,
                         eq_steps,
                         outdir,
                         fewer_outfiles=False,
                         constant_atoms=None):
    """Solvates a host, inserts guest(s) into solvated host, equilibrates

    Parameters
    ----------

    host_pdbfile: path to host pdb file to dock into
    guests_sdfile: path to input sdf with guests to pose/dock; records RDKit
        cannot parse are skipped with a message
    max_lambda: lambda value the guest should insert from or delete to
        (recommended: 1.0 for work calculation, 0.25 to stay close to original pose)
        (must be =1 for work calculation to be applicable)
    insertion_steps: how many steps to insert the guest over (recommended: 501)
    eq_steps: how many steps of equilibration to do after insertion (recommended: 15001)
    outdir: where to write output (will be created if it does not already exist)
    fewer_outfiles: if True, will only write frames for the equilibration, not insertion
    constant_atoms: atom numbers from the host_pdbfile to hold mostly fixed across the simulation
        (1-indexed, like PDB files); each listed atom's mass is increased by 50000.
        Defaults to no atoms.

    Output
    ------

    A pdb & sdf file every 100 steps of insertion (outdir/<guest_name>/<guest_name>_<step>.[pdb/sdf])
    A pdb & sdf file every 1000 steps of equilibration (outdir/<guest_name>/<guest_name>_<step>.[pdb/sdf])
    stdout every 100(0) steps noting the step number, lambda value, and energy
    stdout for each guest noting the work of transition
    stdout for each guest noting how long it took to run

    Raises
    ------
    ValueError: if an entry of constant_atoms is not a 1-indexed atom number
        of the solvated host.

    Note
    ----
    Forces are checked at the first, middle and last step of insertion and of
    equilibration. If any norm of force per atom exceeds 20000 kJ/(mol*nm)
    [MAX_NORM_FORCE defined in docking/report.py], the simulation for that
    guest will stop and the work will not be calculated.
    """
    if constant_atoms is None:
        constant_atoms = []

    os.makedirs(outdir, exist_ok=True)

    print(f"""
    HOST_PDBFILE = {host_pdbfile}
    GUESTS_SDFILE = {guests_sdfile}
    OUTDIR = {outdir}
    MAX_LAMBDA = {max_lambda}
    INSERTION_STEPS = {insertion_steps}
    EQ_STEPS = {eq_steps}
    """)

    # Prepare host
    # TODO: handle extra (non-transitioning) guests?
    print("Solvating host...")
    # TODO: return topology from builders.build_protein_system
    (
        solvated_host_system,
        solvated_host_coords,
        _,
        _,
        host_box,
        solvated_topology,
    ) = builders.build_protein_system(host_pdbfile)

    # sometimes water boxes are sad. Should be minimized first; this is a workaround
    host_box += np.eye(3) * 0.1
    print("host box", host_box)

    # Validate constant_atoms up front: an out-of-range entry (e.g. 0) would
    # otherwise index from the end of the mass array and silently reweight a
    # guest atom.
    n_host_atoms = len(solvated_host_coords)
    for atom_num in constant_atoms:
        if not 1 <= atom_num <= n_host_atoms:
            raise ValueError(
                f"constant_atoms entry {atom_num} is outside the 1-indexed "
                f"host atom range 1..{n_host_atoms}"
            )

    # Round-trip through a temporary PDB file to obtain an RDKit mol for the host.
    solvated_host_pdb = os.path.join(outdir, "solvated_host.pdb")
    writer = pdb_writer.PDBWriter([solvated_topology], solvated_host_pdb)
    writer.write_frame(solvated_host_coords)
    writer.close()
    solvated_host_mol = Chem.MolFromPDBFile(solvated_host_pdb, removeHs=False)
    os.remove(solvated_host_pdb)
    final_host_potentials = []
    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        solvated_host_system, cutoff=1.2)
    host_nb_bp = None
    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert host_nb_bp is None
            host_nb_bp = bp
        else:
            final_host_potentials.append(bp)

    # The guest forcefield file is identical for every guest: read it once
    # (and close it) instead of re-opening it inside the loop.
    ff_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "ff/params/smirnoff_1_1_0_ccc.py",
    )
    with open(ff_path) as ff_file:
        ff_source = ff_file.read()

    # Run the procedure
    print("Getting guests...")
    suppl = Chem.SDMolSupplier(guests_sdfile, removeHs=False)
    for guest_idx, guest_mol in enumerate(suppl):
        # SDMolSupplier yields None for records it fails to parse.
        if guest_mol is None:
            print(f"Skipping guest record {guest_idx}: RDKit could not parse it")
            continue
        start_time = time.time()
        guest_name = guest_mol.GetProp("_Name")
        guest_conformer = guest_mol.GetConformer(0)
        orig_guest_coords = np.array(guest_conformer.GetPositions(),
                                     dtype=np.float64)
        orig_guest_coords = orig_guest_coords / 10  # convert to md_units
        # Handlers are deserialized per guest so no handler state is shared.
        ff = Forcefield(deserialize_handlers(ff_source))
        guest_base_top = topology.BaseTopology(guest_mol, ff)

        # combine host & guest
        hgt = topology.HostGuestTopology(host_nb_bp, guest_base_top)
        # setup the parameter handlers for the ligand
        bonded_tuples = [[hgt.parameterize_harmonic_bond, ff.hb_handle],
                         [hgt.parameterize_harmonic_angle, ff.ha_handle],
                         [hgt.parameterize_proper_torsion, ff.pt_handle],
                         [hgt.parameterize_improper_torsion, ff.it_handle]]
        combined_bps = list(final_host_potentials)
        # instantiate the vjps while parameterizing (forward pass)
        for fn, handle in bonded_tuples:
            params, potential = fn(handle.params)
            combined_bps.append(potential.bind(params))
        nb_params, nb_potential = hgt.parameterize_nonbonded(
            ff.q_handle.params, ff.lj_handle.params)
        combined_bps.append(nb_potential.bind(nb_params))
        guest_masses = [a.GetMass() for a in guest_mol.GetAtoms()]
        combined_masses = np.concatenate([host_masses, guest_masses])

        x0 = np.concatenate([solvated_host_coords, orig_guest_coords])
        v0 = np.zeros_like(x0)
        print(
            "SYSTEM",
            f"guest_name: {guest_name}",
            f"num_atoms: {len(x0)}",
        )

        # Heavier atoms move less; host atoms come first in combined_masses,
        # so 1-indexed PDB atom numbers map to index atom_num - 1.
        for atom_num in constant_atoms:
            combined_masses[atom_num - 1] += 50000

        seed = 2021
        intg = LangevinIntegrator(300.0, 1.5e-3, 1.0, combined_masses,
                                  seed).impl()

        u_impls = [bp.bound_impl(precision=np.float32) for bp in combined_bps]

        ctxt = custom_ops.Context(x0, v0, host_box, intg, u_impls)

        # collect a du_dl calculation once every other step
        subsample_freq = 2
        du_dl_obs = custom_ops.FullPartialUPartialLambda(
            u_impls, subsample_freq)
        ctxt.add_observable(du_dl_obs)

        # insert guest
        insertion_lambda_schedule = np.linspace(max_lambda, 0.0,
                                                insertion_steps)
        calc_work = True
        insertion_aborted = False
        for step, lamb in enumerate(insertion_lambda_schedule):
            ctxt.step(lamb)
            if step % 100 == 0:
                report.report_step(ctxt, step, lamb, host_box, combined_bps,
                                   u_impls, guest_name, insertion_steps,
                                   "INSERTION")
                if not fewer_outfiles:
                    x_t = ctxt.get_x_t()
                    report.write_frame(
                        x_t[:n_host_atoms] * 10,
                        solvated_host_mol,
                        x_t[n_host_atoms:] * 10,
                        guest_mol,
                        guest_name,
                        outdir,
                        str(step).zfill(len(str(insertion_steps))),
                        "ins",
                    )
            if step in (0, int(insertion_steps / 2), insertion_steps - 1):
                if report.too_much_force(ctxt, lamb, host_box, combined_bps,
                                         u_impls):
                    calc_work = False
                    insertion_aborted = True
                    break

        # Fetch the observed du_dl trace once and reuse it below.
        du_dls = du_dl_obs.full_du_dl()
        # Note: this condition only applies for ABFE, not RBFE
        if len(du_dls) == 0:
            print("Error: no du_dl samples were collected")
            calc_work = False
        elif abs(du_dls[0]) > 0.001 or abs(du_dls[-1]) > 0.001:
            print("Error: du_dl endpoints are not ~0")
            calc_work = False

        if calc_work:
            work = np.trapz(du_dls,
                            insertion_lambda_schedule[::subsample_freq])
            print(f"guest_name: {guest_name}\tinsertion_work: {work:.2f}")

        # equilibrate -- skipped when insertion was aborted for excessive
        # force, as promised in the docstring Note.
        if not insertion_aborted:
            for step in range(eq_steps):
                ctxt.step(0.00)
                if step % 1000 == 0:
                    report.report_step(ctxt, step, 0.00, host_box, combined_bps,
                                       u_impls, guest_name, eq_steps,
                                       'EQUILIBRATION')
                    x_t = ctxt.get_x_t()
                    report.write_frame(
                        x_t[:n_host_atoms] * 10,
                        solvated_host_mol,
                        x_t[n_host_atoms:] * 10,
                        guest_mol,
                        guest_name,
                        outdir,
                        str(step).zfill(len(str(eq_steps))),
                        "eq",
                    )
                if step in (0, int(eq_steps / 2), eq_steps - 1):
                    if report.too_much_force(ctxt, 0.00, host_box, combined_bps,
                                             u_impls):
                        break

        end_time = time.time()
        print(f"{guest_name} took {(end_time - start_time):.2f} seconds")
Example #21
0
def convergence(args):
    """Estimate the mean dU/dlambda of a restrained ligand pair at one lambda.

    The first two ligands of ``tests/data/ligands_40.sdf`` are recentred and
    simulated together. The energy is the sum of harmonic bond and harmonic
    angle terms, a PMI restraint, and a shape-overlap restraint scaled by
    ``lamb``. A jax-differentiated energy drives a Langevin integrator for
    100000 steps. dU/dlambda is sampled every 5 steps once the step index
    exceeds 10000.

    Parameters
    ----------
    args: tuple (epoch, lamb, lamb_idx)
        ``lamb`` is the lambda value passed to every energy term.
        ``epoch`` and ``lamb_idx`` are only used to name the SDF file
        ``frames_heavy_<epoch>_<lamb_idx>.sdf``. Presumably a single tuple
        argument is used so this can be mapped over a process pool.

    Returns
    -------
    The mean of the collected dU/dlambda samples.
    """
    epoch, lamb, lamb_idx = args

    suppl = Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)
    ligands = list(suppl)

    ligand_a = ligands[0]
    ligand_b = ligands[1]

    coords_a = recenter(get_conf(ligand_a, idx=0))
    coords_b = recenter(get_conf(ligand_b, idx=0))

    coords = np.concatenate([coords_a, coords_b])

    # Heavy-atom indices into the combined system: ligand_b's atoms follow
    # ligand_a's, so its indices are offset by ligand_a's atom count.
    a_idxs = get_heavy_atom_idxs(ligand_a)
    b_idxs = get_heavy_atom_idxs(ligand_b)
    b_idxs += ligand_a.GetNumAtoms()

    forcefield = 'ff/params/smirnoff_1_1_0_ccc.py'
    with open(forcefield, "r") as ff_file:
        ff_handlers = deserialize_handlers(ff_file.read())

    combined_mol = Chem.CombineMols(ligand_a, ligand_b)

    # NOTE: only bond and angle terms are used; proper/improper torsion
    # handlers are deliberately skipped.
    nrg_fns = []
    for handler in ff_handlers:
        if isinstance(handler, handlers.HarmonicBondHandler):
            bond_idxs, (bond_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_bond,
                    params=bond_params,
                    box=None,
                    bond_idxs=bond_idxs
                )
            )
        elif isinstance(handler, handlers.HarmonicAngleHandler):
            angle_idxs, (angle_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_angle,
                    params=angle_params,
                    box=None,
                    angle_idxs=angle_idxs
                )
            )

    # ligand_a's masses are scaled by 10000, which slows its motion under the
    # Langevin update below (presumably to hold it nearly fixed -- TODO confirm).
    masses_a = onp.array([a.GetMass() for a in ligand_a.GetAtoms()]) * 10000
    masses_b = onp.array([a.GetMass() for a in ligand_b.GetAtoms()])

    combined_masses = np.concatenate([masses_a, masses_b])

    pmi_restraint_fn = functools.partial(pmi_restraints_new,
        params=None,
        box=None,
        lamb=None,
        masses=combined_masses,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        angle_force=100.0,
        com_force=100.0
    )

    prefactor = 2.7 # unitless
    shape_lamb = (4*np.pi)/(3*prefactor) # unitless
    kappa = np.pi/(np.power(shape_lamb, 2/3)) # unitless
    # Gaussian width. The original note read "1 angstrom std, 95% coverage by
    # 2 angstroms"; NOTE(review): 0.15 is 1.5 angstrom if coordinates are in
    # nm -- confirm the intended value/units.
    sigma = 0.15
    alpha = kappa/(sigma*sigma)

    alphas = np.zeros(combined_mol.GetNumAtoms())+alpha
    weights = np.zeros(combined_mol.GetNumAtoms())+prefactor

    shape_restraint_fn = functools.partial(
        shape.harmonic_overlap,
        box=None,
        lamb=None,
        params=None,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        alphas=alphas,
        weights=weights,
        k=150.0
    )

    def restraint_fn(conf, lamb):
        # PMI restraint is always on; the shape restraint is scaled by lamb.
        # (An earlier variant interpolated: (1-lamb)*pmi + lamb*shape.)
        return pmi_restraint_fn(conf) + lamb*shape_restraint_fn(conf)

    nrg_fns.append(restraint_fn)

    def nrg_fn(conf, lamb):
        """Total energy: sum of every term in nrg_fns at (conf, lamb)."""
        s = []
        for u in nrg_fns:
            s.append(u(conf, lamb=lamb))
        return np.sum(s)

    # Gradient w.r.t. both conf and lamb (used on sampling steps) and w.r.t.
    # conf only (used on every other step).
    grad_fn = jax.jit(jax.grad(nrg_fn, argnums=(0, 1)))
    du_dx_fn = jax.jit(jax.grad(nrg_fn, argnums=(0)))

    x_t = coords
    v_t = np.zeros_like(x_t)

    # Nothing is written to this file unless the debug block in the loop is
    # re-enabled; the writer is closed once the run finishes.
    w = Chem.SDWriter('frames_heavy_'+str(epoch)+'_'+str(lamb_idx)+'.sdf')

    dt = 1.5e-3
    ca, cb, cc = langevin_coefficients(300.0, dt, 1.0, combined_masses)
    cb = -1*onp.expand_dims(cb, axis=-1)
    cc = onp.expand_dims(cc, axis=-1)

    n_steps = 100000
    burn_in_steps = 10000  # samples are taken only when step > burn_in_steps
    sample_every = 5
    du_dls = []

    # re-seed since forking
    onp.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))

    try:
        for step in range(n_steps):

            # if step % 1000 == 0:
            #     u = nrg_fn(x_t, lamb)
            #     print("step", step, "nrg", onp.asarray(u), "avg_du_dl",  onp.mean(du_dls))
            #     mol = make_conformer(combined_mol, x_t[:ligand_a.GetNumAtoms()], x_t[ligand_a.GetNumAtoms():])
            #     w.write(mol)
            #     w.flush()

            if step % sample_every == 0 and step > burn_in_steps:
                du_dx, du_dl = grad_fn(x_t, lamb)
                du_dls.append(du_dl)
            else:
                du_dx = du_dx_fn(x_t, lamb)

            v_t = ca*v_t + cb*du_dx + cc*onp.random.normal(size=x_t.shape)
            x_t = x_t + v_t*dt
    finally:
        w.close()

    return np.mean(onp.mean(du_dls))