Exemple #1
0
def test_exclusions():
    """Check exclusion pairs and scale factors for tetrafluoroethylene.

    The test asserts three things:

    * each returned pair is stored as ``(lower, higher)``;
    * all 15 atom pairs of the 6-atom molecule are reported;
    * each pair's scale equals ``scale12``, ``scale13`` or ``scale14`` according
      to its bond separation.
    """
    # F-C(F)=C(F)-F: six atoms, carbons at indices 1 and 3.
    tfe = Chem.MolFromSmiles("FC(F)=C(F)F")
    idxs, factors = nonbonded.generate_exclusion_idxs(
        tfe, scale12=0.0, scale13=0.2, scale14=0.5)

    # Exclusion pairs must be in canonical (lower, higher) order.
    for (lo, hi), _ in zip(idxs, factors):
        assert lo < hi

    # No two atoms are more than three bonds apart, so every pair appears.
    ref_idxs = np.array([
        [0, 1], [0, 2], [0, 3], [0, 4], [0, 5],
        [1, 2], [1, 3], [1, 4], [1, 5],
        [2, 3], [2, 4], [2, 5],
        [3, 4], [3, 5],
        [4, 5],
    ])
    np.testing.assert_equal(idxs, ref_idxs)

    ref_scales = [
        0.0, 0.2, 0.2, 0.5, 0.5,
        0.0, 0.0, 0.2, 0.2,
        0.2, 0.5, 0.5,
        0.0, 0.0,
        0.2,
    ]
    np.testing.assert_equal(factors, ref_scales)
Exemple #2
0
    def parameterize_nonbonded(self, ff_q_params, ff_lj_params):
        """Build the Nonbonded potential and its per-atom parameters for ``self.mol``.

        Parameters
        ----------
        ff_q_params, ff_lj_params:
            Forcefield-level charge and Lennard-Jones parameters, handed to
            ``partial_parameterize`` of ``self.ff.q_handle`` / ``self.ff.lj_handle``.

        Returns
        -------
        tuple
            ``(params, nb)``. ``params`` stacks ``[q, lj_0, lj_1]`` column-wise,
            one row per charge entry. ``nb`` is the ``potentials.Nonbonded`` term.
        """
        mol = self.mol
        charges = self.ff.q_handle.partial_parameterize(ff_q_params, mol)
        lj = self.ff.lj_handle.partial_parameterize(ff_lj_params, mol)

        excl_idxs, excl_scales = nonbonded.generate_exclusion_idxs(
            mol, scale12=_SCALE_12, scale13=_SCALE_13, scale14=_SCALE_14)
        # Charge and LJ columns currently share the same per-pair scale factor.
        excl_scales = np.stack([excl_scales, excl_scales], axis=1)

        num_atoms = len(charges)
        nb = potentials.Nonbonded(
            excl_idxs,
            excl_scales,
            np.zeros(num_atoms, dtype=np.int32),  # lambda_plane_idxs
            np.ones(num_atoms, dtype=np.int32),   # lambda_offset_idxs
            _BETA,
            _CUTOFF,  # solve for this analytically later
        )

        params = jnp.concatenate(
            [jnp.reshape(charges, (-1, 1)), jnp.reshape(lj, (-1, 2))], axis=1)
        return params, nb
Exemple #3
0
    def parameterize_nonbonded(self, ff_q_params, ff_lj_params):
        """Build the interpolated Nonbonded potential for the combined A/B system.

        Nonbonded potentials combine through parameter interpolation, not energy
        interpolation. They may or may not operate through 4D decoupling,
        depending on the atom mapping. A unique atom is kept at full strength and
        is not switched off.

        Parameters
        ----------
        ff_q_params, ff_lj_params:
            Forcefield-level charge and Lennard-Jones parameters. They are passed
            to ``partial_parameterize`` of ``self.ff.q_handle`` /
            ``self.ff.lj_handle`` for both ``self.mol_a`` and ``self.mol_b``.

        Returns
        -------
        tuple
            ``(qlj_params, potential)``:

            * ``qlj_params`` is the source and destination parameter arrays
              returned by ``self.interpolate_nonbonded_params``, concatenated
              along axis 0.
            * ``potential`` is a ``potentials.InterpolatedPotential`` wrapping
              the combined ``potentials.Nonbonded`` term.

        Raises
        ------
        AssertionError
            If mol_a and mol_b map the same combined atom pair to different
            scale factors.
        ValueError
            If an entry of ``self.c_flags`` is not 0, 1 or 2.
        """
        q_params_a = self.ff.q_handle.partial_parameterize(ff_q_params, self.mol_a)
        q_params_b = self.ff.q_handle.partial_parameterize(ff_q_params, self.mol_b)
        lj_params_a = self.ff.lj_handle.partial_parameterize(ff_lj_params, self.mol_a)
        lj_params_b = self.ff.lj_handle.partial_parameterize(ff_lj_params, self.mol_b)

        # Per-atom rows of [q, lj_0, lj_1] for each end state.
        qlj_params_a = jnp.concatenate([
            jnp.reshape(q_params_a, (-1, 1)),
            jnp.reshape(lj_params_a, (-1, 2))
        ], axis=1)
        qlj_params_b = jnp.concatenate([
            jnp.reshape(q_params_b, (-1, 1)),
            jnp.reshape(lj_params_b, (-1, 2))
        ], axis=1)

        qlj_params_src, qlj_params_dst = self.interpolate_nonbonded_params(qlj_params_a, qlj_params_b)
        qlj_params = jnp.concatenate([qlj_params_src, qlj_params_dst])

        exclusion_idxs_a, scale_factors_a = nonbonded.generate_exclusion_idxs(
            self.mol_a,
            scale12=_SCALE_12,
            scale13=_SCALE_13,
            scale14=_SCALE_14
        )
        exclusion_idxs_b, scale_factors_b = nonbonded.generate_exclusion_idxs(
            self.mol_b,
            scale12=_SCALE_12,
            scale13=_SCALE_13,
            scale14=_SCALE_14
        )

        # (ytz): use the same scale factors for LJ and charges for now. This isn't
        # quite correct, as LJ and Coulomb scaling may differ between forcefields.
        scale_factors_a = np.stack([scale_factors_a, scale_factors_a], axis=1)
        scale_factors_b = np.stack([scale_factors_b, scale_factors_b], axis=1)

        # Translate each molecule's exclusions into combined-system indices.
        # self.a_to_c / self.b_to_c are fancy-indexed with an index pair here,
        # so presumably they are integer numpy arrays (TODO confirm).
        # A pair reached from both molecules must carry identical scale factors.
        combined_exclusion_dict = {}
        for mol_to_c, exclusion_idxs, scale_factors in (
            (self.a_to_c, exclusion_idxs_a, scale_factors_a),
            (self.b_to_c, exclusion_idxs_b, scale_factors_b),
        ):
            for ij, scale in zip(exclusion_idxs, scale_factors):
                key = tuple(sorted(mol_to_c[ij]))
                if key in combined_exclusion_dict:
                    np.testing.assert_array_equal(combined_exclusion_dict[key], scale)
                else:
                    combined_exclusion_dict[key] = scale

        combined_exclusion_idxs = np.array(list(combined_exclusion_dict.keys()))
        combined_scale_factors = np.array(list(combined_exclusion_dict.values()))

        # (ytz): we don't need exclusions between R_A and R_B, because they never
        # see each other under this decoupling scheme: they are always kept a
        # cutoff apart from each other.

        # plane_idxs: RA = Core = 0, RB = -1
        # offset_idxs: Core = 0, RA = RB = +1
        # c_flags group -> (lambda_plane_idx, lambda_offset_idx)
        group_to_lambda = {
            0: (0, 0),   # core
            1: (0, 1),   # R_A
            2: (-1, 1),  # R_B
        }

        num_atoms = self.get_num_atoms()
        combined_lambda_plane_idxs = np.zeros(num_atoms, dtype=np.int32)
        combined_lambda_offset_idxs = np.zeros(num_atoms, dtype=np.int32)

        for atom, group in enumerate(self.c_flags):
            try:
                plane_idx, offset_idx = group_to_lambda[group]
            except KeyError:
                # An explicit raise (unlike `assert`) survives `python -O`.
                # Otherwise the atom would silently be treated as core.
                raise ValueError(
                    f"unknown atom group {group!r} for combined atom {atom}; expected 0, 1 or 2"
                ) from None
            combined_lambda_plane_idxs[atom] = plane_idx
            combined_lambda_offset_idxs[atom] = offset_idx

        nb = potentials.Nonbonded(
            combined_exclusion_idxs,
            combined_scale_factors,
            combined_lambda_plane_idxs,
            combined_lambda_offset_idxs,
            _BETA,
            _CUTOFF,  # solve for this analytically later
        )

        return qlj_params, potentials.InterpolatedPotential(nb, num_atoms, qlj_params.size)
Exemple #4
0
    def parameterize_nonbonded(self, ff_q_params, ff_lj_params):
        """Place ``mol_a`` and ``mol_b`` side by side in a single Nonbonded term.

        Atoms of ``mol_a`` take combined indices ``[0, NA)``. Atoms of ``mol_b``
        take ``[NA, NA + NB)``.

        Exclusions come from two sources:

        * intramolecular exclusions from
          ``nonbonded.generate_exclusion_idxs`` (``_SCALE_12/13/14``);
        * every cross pair (one atom from each molecule), added as an exclusion
          with scale factors ``[1.0, 1.0]``.

        Returns
        -------
        tuple
            ``(params, nb)``. ``params`` has per-atom rows ``[q, lj_0, lj_1]``
            (mol_a first). ``nb`` is the ``potentials.Nonbonded`` term.
        """
        ff = self.ff
        mol_a, mol_b = self.mol_a, self.mol_b

        q_params = jnp.concatenate([
            ff.q_handle.partial_parameterize(ff_q_params, mol_a),
            ff.q_handle.partial_parameterize(ff_q_params, mol_b),
        ])
        lj_params = jnp.concatenate([
            ff.lj_handle.partial_parameterize(ff_lj_params, mol_a),
            ff.lj_handle.partial_parameterize(ff_lj_params, mol_b),
        ])

        excl_a, scales_a = nonbonded.generate_exclusion_idxs(
            mol_a, scale12=_SCALE_12, scale13=_SCALE_13, scale14=_SCALE_14)
        excl_b, scales_b = nonbonded.generate_exclusion_idxs(
            mol_b, scale12=_SCALE_12, scale13=_SCALE_13, scale14=_SCALE_14)

        NA = mol_a.GetNumAtoms()
        NB = mol_b.GetNumAtoms()

        # Every (a-atom, b-atom) pair, row-major over a-atoms.
        cross_idxs = np.array([[i, NA + j] for i in range(NA) for j in range(NB)])
        cross_scales = np.array([[1.0, 1.0] for _ in range(NA * NB)])

        exclusion_idxs = np.concatenate(
            [excl_a, excl_b + NA, cross_idxs]).astype(np.int32)
        # Charge and LJ columns share each intramolecular scale factor.
        scale_factors = np.concatenate([
            np.stack([scales_a, scales_a], axis=1),
            np.stack([scales_b, scales_b], axis=1),
            cross_scales,
        ]).astype(np.float64)

        num_atoms = NA + NB
        nb = potentials.Nonbonded(
            exclusion_idxs,
            scale_factors,
            np.zeros(num_atoms, dtype=np.int32),  # lambda_plane_idxs
            np.ones(num_atoms, dtype=np.int32),   # lambda_offset_idxs
            _BETA,
            _CUTOFF,  # solve for this analytically later
        )

        params = jnp.concatenate(
            [jnp.reshape(q_params, (-1, 1)), jnp.reshape(lj_params, (-1, 2))], axis=1)
        return params, nb
Exemple #5
0
def create_system(guest_mol, host_pdb, handlers, stage, core_atoms,
                  restr_force, restr_alpha, restr_count):
    """
    Initialize a self-encompassing System object that we can serialize and simulate.

    The host is parameterized with the OpenMM amber99sbildn / amber99_obc
    forcefields. The guest is parameterized with ``handlers``, and its atoms are
    appended after the host atoms.

    Parameters
    ----------

    guest_mol: rdkit.ROMol
        guest molecule. Positions of conformer 0 are divided by 10 to convert
        them to md_units.

    host_pdb: openmm.PDBFile
        host structure. Its topology is parameterized and its positions are
        used as the host coordinates.

    handlers: list of forcefield handlers
        handlers used to parameterize the guest molecule.

    stage: int (0, 1 or 2)
        selects the lambda assignment of the guest atoms:
        0 -> all plane/offset idxs are 0; 1 -> guest offset idxs are 1;
        2 -> guest plane idxs are 1. Also forwarded to setup_core_restraints.

    core_atoms:
        forwarded to setup_core_restraints.

    restr_force, restr_alpha, restr_count:
        restraint settings forwarded to setup_core_restraints.

    Returns
    -------
    4-tuple
        x0, combined_masses, final_gradients, final_vjp_fns

    Raises
    ------
    ValueError
        if ``stage`` is not 0, 1 or 2.
    """
    # Validate up front. Previously an unknown stage left the lambda arrays
    # unbound and only failed much later.
    if stage not in (0, 1, 2):
        raise ValueError(f"stage must be 0, 1 or 2, got {stage!r}")

    guest_masses = np.array([a.GetMass() for a in guest_mol.GetAtoms()],
                            dtype=np.float64)

    amber_ff = app.ForceField('amber99sbildn.xml', 'amber99_obc.xml')
    host_system = amber_ff.createSystem(host_pdb.topology,
                                        nonbondedMethod=app.NoCutoff,
                                        constraints=None,
                                        rigidWater=False)

    host_fns, host_masses = openmm_deserializer.deserialize_system(host_system)

    num_host_atoms = len(host_masses)
    num_guest_atoms = guest_mol.GetNumAtoms()

    # Parallel lists: (name, args) gradient terms, and the vjp_fn(s) for each
    # term (None for host-only terms).
    final_gradients = []
    final_vjp_fns = []

    # Host nonbonded-related terms are held back so they can be merged with the
    # guest's. Every other host term is passed through unchanged.
    for item in host_fns:

        if item[0] == 'LennardJones':
            host_lj_params = item[1]
        elif item[0] == 'Charges':
            host_charge_params = item[1]
        elif item[0] == 'GBSA':
            host_gb_params = item[1][0]
            host_gb_props = item[1][1:]
        elif item[0] == 'Exclusions':
            host_exclusions = item[1]
        else:
            final_gradients.append((item[0], item[1]))
            final_vjp_fns.append(None)

    guest_exclusion_idxs, guest_scales = nonbonded.generate_exclusion_idxs(
        guest_mol, scale12=1.0, scale13=1.0, scale14=0.5)

    # Guest atoms follow the host atoms, so shift guest indices.
    guest_exclusion_idxs += num_host_atoms
    guest_lj_exclusion_scales = guest_scales
    guest_charge_exclusion_scales = guest_scales

    host_exclusion_idxs = host_exclusions[0]
    host_lj_exclusion_scales = host_exclusions[1]
    host_charge_exclusion_scales = host_exclusions[2]

    combined_exclusion_idxs = np.concatenate(
        [host_exclusion_idxs, guest_exclusion_idxs])
    combined_lj_exclusion_scales = np.concatenate(
        [host_lj_exclusion_scales, guest_lj_exclusion_scales])
    combined_charge_exclusion_scales = np.concatenate(
        [host_charge_exclusion_scales, guest_charge_exclusion_scales])

    for handle in handlers:
        results = handle.parameterize(guest_mol)

        if isinstance(handle, bonded.HarmonicBondHandler):
            bond_idxs, (bond_params, bond_vjp_fn) = results
            bond_idxs += num_host_atoms
            final_gradients.append(("HarmonicBond", (bond_idxs, bond_params)))
            final_vjp_fns.append(bond_vjp_fn)
        elif isinstance(handle, bonded.HarmonicAngleHandler):
            angle_idxs, (angle_params, angle_vjp_fn) = results
            angle_idxs += num_host_atoms
            final_gradients.append(
                ("HarmonicAngle", (angle_idxs, angle_params)))
            final_vjp_fns.append(angle_vjp_fn)
        elif isinstance(handle, bonded.ProperTorsionHandler):
            torsion_idxs, (torsion_params, torsion_vjp_fn) = results
            torsion_idxs += num_host_atoms
            final_gradients.append(
                ("PeriodicTorsion", (torsion_idxs, torsion_params)))
            final_vjp_fns.append(torsion_vjp_fn)
        elif isinstance(handle, bonded.ImproperTorsionHandler):
            torsion_idxs, (torsion_params, torsion_vjp_fn) = results
            torsion_idxs += num_host_atoms
            final_gradients.append(
                ("PeriodicTorsion", (torsion_idxs, torsion_params)))
            final_vjp_fns.append(torsion_vjp_fn)
        elif isinstance(handle, nonbonded.LennardJonesHandler):
            guest_lj_params, guest_lj_vjp_fn = results
            combined_lj_params, combined_lj_vjp_fn = concat_with_vjps(
                host_lj_params, guest_lj_params, None, guest_lj_vjp_fn)
        elif isinstance(handle, nonbonded.SimpleChargeHandler):
            guest_charge_params, guest_charge_vjp_fn = results
            combined_charge_params, combined_charge_vjp_fn = concat_with_vjps(
                host_charge_params, guest_charge_params, None,
                guest_charge_vjp_fn)
        elif isinstance(handle, nonbonded.GBSAHandler):
            guest_gb_params, guest_gb_vjp_fn = results
            combined_gb_params, combined_gb_vjp_fn = concat_with_vjps(
                host_gb_params, guest_gb_params, None, guest_gb_vjp_fn)
        elif isinstance(handle, nonbonded.AM1BCCHandler):
            # ill defined behavior if both SimpleChargeHandler and AM1Handler is present
            guest_charge_params, guest_charge_vjp_fn = results
            combined_charge_params, combined_charge_vjp_fn = concat_with_vjps(
                host_charge_params, guest_charge_params, None,
                guest_charge_vjp_fn)
        elif isinstance(handle, nonbonded.AM1CCCHandler):
            guest_charge_params, guest_charge_vjp_fn = results
            combined_charge_params, combined_charge_vjp_fn = concat_with_vjps(
                host_charge_params, guest_charge_params, None,
                guest_charge_vjp_fn)
        else:
            raise Exception("Unknown Handler", handle)

    N_C = num_host_atoms + num_guest_atoms
    N_A = num_host_atoms

    # Host atoms always sit at plane/offset 0. Only the guest slice
    # (indices >= N_A) changes, depending on the stage.
    combined_lambda_plane_idxs = np.zeros(N_C, dtype=np.int32)
    combined_lambda_offset_idxs = np.zeros(N_C, dtype=np.int32)
    if stage == 1:
        combined_lambda_offset_idxs[N_A:] = 1
    elif stage == 2:
        combined_lambda_plane_idxs[N_A:] = 1

    cutoff = 100000.0

    final_gradients.append(
        ('Nonbonded',
         (np.asarray(combined_charge_params), np.asarray(combined_lj_params),
          combined_exclusion_idxs, combined_charge_exclusion_scales,
          combined_lj_exclusion_scales, combined_lambda_plane_idxs,
          combined_lambda_offset_idxs, cutoff)))
    final_vjp_fns.append((combined_charge_vjp_fn, combined_lj_vjp_fn))

    final_gradients.append(
        ('GBSA', (np.asarray(combined_charge_params),
                  np.asarray(combined_gb_params), combined_lambda_plane_idxs,
                  combined_lambda_offset_idxs, *host_gb_props, cutoff,
                  cutoff)))
    final_vjp_fns.append((combined_charge_vjp_fn, combined_gb_vjp_fn))

    host_conf = np.array([
        [to_md_units(x), to_md_units(y), to_md_units(z)]
        for x, y, z in host_pdb.positions
    ])

    conformer = guest_mol.GetConformer(0)
    mol_a_conf = np.array(conformer.GetPositions(), dtype=np.float64)
    mol_a_conf = mol_a_conf / 10  # convert to md_units

    x0 = np.concatenate([host_conf, mol_a_conf])  # combined geometry

    # Host atoms named 'CA' are passed to the restraint builder as backbone_atoms.
    backbone_atoms = [
        a.index
        for residue in host_pdb.getTopology().residues()
        for a in residue.atoms()
        if a.name == 'CA'
    ]

    final_gradients.append(
        setup_core_restraints(restr_force,
                              restr_alpha,
                              restr_count,
                              x0,
                              num_host_atoms,
                              core_atoms,
                              backbone_atoms,
                              stage=stage))

    final_vjp_fns.append(None)

    combined_masses = np.concatenate([host_masses, guest_masses])

    return x0, combined_masses, final_gradients, final_vjp_fns
Exemple #6
0
def create_system(guest_mol, host_pdb, handlers, random_state=None):
    """
    Initialize a self-encompassing System object that we can serialize and simulate.

    The guest conformer is randomly rotated about its centroid before being
    appended after the host atoms. Pass ``random_state`` to make that rotation
    reproducible.

    Parameters
    ----------

    guest_mol: rdkit.ROMol
        guest molecule. Positions of conformer 0 are divided by 10 to convert
        them to md_units.

    host_pdb: openmm.PDBFile
        host system from OpenMM

    handlers: list of forcefield handlers
        forcefield handlers used to parameterize the small molecule

    random_state: None, int or numpy random generator, optional
        forwarded to ``scipy.stats.special_ortho_group.rvs`` when drawing the
        guest rotation. The default ``None`` keeps the previous, unseeded
        behavior.

    Returns
    -------
    3-tuple
        x0, combined_masses, final_gradients

    """

    guest_masses = np.array([a.GetMass() for a in guest_mol.GetAtoms()],
                            dtype=np.float64)

    amber_ff = app.ForceField('amber99sbildn.xml', 'amber99_obc.xml')
    host_system = amber_ff.createSystem(host_pdb.topology,
                                        nonbondedMethod=app.NoCutoff,
                                        constraints=None,
                                        rigidWater=False)

    host_fns, host_masses = openmm_deserializer.deserialize_system(host_system)

    num_host_atoms = len(host_masses)
    num_guest_atoms = guest_mol.GetNumAtoms()

    # (name, args) gradient terms
    final_gradients = []

    # Host nonbonded-related terms are held back to be merged with the guest's.
    # Every other host term is passed through unchanged.
    for item in host_fns:

        if item[0] == 'LennardJones':
            host_lj_params = item[1]
        elif item[0] == 'Charges':
            host_charge_params = item[1]
        elif item[0] == 'GBSA':
            host_gb_params = item[1][0]
            host_gb_props = item[1][1:]
        elif item[0] == 'Exclusions':
            host_exclusions = item[1]
        else:
            final_gradients.append((item[0], item[1]))

    guest_exclusion_idxs, guest_scales = nonbonded.generate_exclusion_idxs(
        guest_mol, scale12=1.0, scale13=1.0, scale14=0.5)

    # Guest atoms follow the host atoms, so shift guest indices.
    guest_exclusion_idxs += num_host_atoms
    guest_lj_exclusion_scales = guest_scales
    guest_charge_exclusion_scales = guest_scales

    host_exclusion_idxs = host_exclusions[0]
    host_lj_exclusion_scales = host_exclusions[1]
    host_charge_exclusion_scales = host_exclusions[2]

    combined_exclusion_idxs = np.concatenate(
        [host_exclusion_idxs, guest_exclusion_idxs])
    combined_lj_exclusion_scales = np.concatenate(
        [host_lj_exclusion_scales, guest_lj_exclusion_scales])
    combined_charge_exclusion_scales = np.concatenate(
        [host_charge_exclusion_scales, guest_charge_exclusion_scales])

    for handle in handlers:
        results = handle.parameterize(guest_mol)

        if isinstance(handle, bonded.HarmonicBondHandler):
            bond_idxs, (bond_params, _) = results
            bond_idxs += num_host_atoms
            final_gradients.append(("HarmonicBond", (bond_idxs, bond_params)))
        elif isinstance(handle, bonded.HarmonicAngleHandler):
            angle_idxs, (angle_params, _) = results
            angle_idxs += num_host_atoms
            final_gradients.append(
                ("HarmonicAngle", (angle_idxs, angle_params)))
        elif isinstance(handle, bonded.ProperTorsionHandler):
            torsion_idxs, (torsion_params, _) = results
            torsion_idxs += num_host_atoms
            final_gradients.append(
                ("PeriodicTorsion", (torsion_idxs, torsion_params)))
        elif isinstance(handle, bonded.ImproperTorsionHandler):
            torsion_idxs, (torsion_params, _) = results
            torsion_idxs += num_host_atoms
            final_gradients.append(
                ("PeriodicTorsion", (torsion_idxs, torsion_params)))
        elif isinstance(handle, nonbonded.LennardJonesHandler):
            guest_lj_params, _ = results
            combined_lj_params = np.concatenate(
                [host_lj_params, guest_lj_params])
        elif isinstance(handle, nonbonded.SimpleChargeHandler):
            guest_charge_params, _ = results
            combined_charge_params = np.concatenate(
                [host_charge_params, guest_charge_params])
        elif isinstance(handle, nonbonded.GBSAHandler):
            guest_gb_params, _ = results
            combined_gb_params = np.concatenate(
                [host_gb_params, guest_gb_params])
        elif isinstance(handle, nonbonded.AM1BCCHandler):
            guest_charge_params, _ = results
            combined_charge_params = np.concatenate(
                [host_charge_params, guest_charge_params])
        elif isinstance(handle, nonbonded.AM1CCCHandler):
            guest_charge_params, _ = results
            combined_charge_params = np.concatenate(
                [host_charge_params, guest_charge_params])
        else:
            raise Exception("Unknown Handler", handle)

    host_conf = []
    for x, y, z in host_pdb.positions:
        host_conf.append([to_md_units(x), to_md_units(y), to_md_units(z)])
    host_conf = np.array(host_conf)

    conformer = guest_mol.GetConformer(0)
    mol_a_conf = np.array(conformer.GetPositions(), dtype=np.float64)
    mol_a_conf = mol_a_conf / 10  # convert to md_units

    # Randomly rotate the guest about its centroid. Right-multiplying by a
    # matrix from SO(3) applies a proper rotation.
    from scipy.stats import special_ortho_group

    center = np.mean(mol_a_conf, axis=0)
    mol_a_conf -= center
    mol_a_conf = np.matmul(
        mol_a_conf, special_ortho_group.rvs(3, random_state=random_state))
    mol_a_conf += center

    x0 = np.concatenate([host_conf, mol_a_conf])  # combined geometry

    N_C = num_host_atoms + num_guest_atoms

    cutoff = 100000.0

    # Host atoms keep lambda offset 0. Guest atoms get offset 1.
    combined_lambda_plane_idxs = np.zeros(N_C, dtype=np.int32)
    combined_lambda_offset_idxs = np.zeros(N_C, dtype=np.int32)
    combined_lambda_offset_idxs[num_host_atoms:] = 1

    final_gradients.append(
        ('Nonbonded',
         (np.asarray(combined_charge_params), np.asarray(combined_lj_params),
          combined_exclusion_idxs, combined_charge_exclusion_scales,
          combined_lj_exclusion_scales, combined_lambda_plane_idxs,
          combined_lambda_offset_idxs, cutoff)))

    final_gradients.append(
        ('GBSA', (np.asarray(combined_charge_params),
                  np.asarray(combined_gb_params), combined_lambda_plane_idxs,
                  combined_lambda_offset_idxs, *host_gb_props, cutoff,
                  cutoff)))

    combined_masses = np.concatenate([host_masses, guest_masses])

    return x0, combined_masses, final_gradients
Exemple #7
0
def combine_potentials(ff_handlers, guest_mol, host_system, precision):
    """
    This function is responsible for figuring out how to take two separate hamiltonians
    and combining them into one sensible alchemical system.

    Parameters
    ----------

    ff_handlers: list of forcefield handlers
        Small molecule forcefield handlers. Must include an AM1CCCHandler and a
        LennardJonesHandler. Unrecognized handlers are skipped with a warning.

    guest_mol: Chem.ROMol
        RDKit molecule

    host_system: openmm.System
        Host system to be deserialized. Must contain exactly one Nonbonded term.

    precision: np.float32 or np.float64
        Numerical precision of the functional form

    Returns
    -------
    tuple
        Returns a list of lib.potentials objects, combined masses, and a list of
        their corresponding vjp_fns back into the forcefield

    Raises
    ------
    ValueError
        If the host system does not contain exactly one Nonbonded term, or if
        ``ff_handlers`` lacks a charge (AM1CCC) or Lennard-Jones handler.

    """

    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        host_system, precision, cutoff=1.0)

    host_nb_bp = None

    combined_potentials = []
    combined_vjp_fns = []

    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term.
            # An explicit raise (unlike `assert`) survives `python -O`.
            if host_nb_bp is not None:
                raise ValueError("host system contains more than one Nonbonded term")
            host_nb_bp = bp
        else:
            combined_potentials.append(bp)
            combined_vjp_fns.append([])

    if host_nb_bp is None:
        raise ValueError("host system contains no Nonbonded term")

    guest_masses = np.array([a.GetMass() for a in guest_mol.GetAtoms()],
                            dtype=np.float64)

    num_host_atoms = len(host_masses)

    combined_masses = np.concatenate([host_masses, guest_masses])

    charge_handle = None
    lj_handle = None

    for handle in ff_handlers:
        results = handle.parameterize(guest_mol)
        if isinstance(handle, bonded.HarmonicBondHandler):
            bond_idxs, (bond_params, vjp_fn) = results
            bond_idxs += num_host_atoms
            combined_potentials.append(
                potentials.HarmonicBond(bond_idxs,
                                        precision=precision).bind(bond_params))
            combined_vjp_fns.append([(handle, vjp_fn)])
        elif isinstance(handle, bonded.HarmonicAngleHandler):
            angle_idxs, (angle_params, vjp_fn) = results
            angle_idxs += num_host_atoms
            combined_potentials.append(
                potentials.HarmonicAngle(
                    angle_idxs, precision=precision).bind(angle_params))
            combined_vjp_fns.append([(handle, vjp_fn)])
        elif isinstance(handle, bonded.ProperTorsionHandler):
            torsion_idxs, (torsion_params, vjp_fn) = results
            torsion_idxs += num_host_atoms
            combined_potentials.append(
                potentials.PeriodicTorsion(
                    torsion_idxs, precision=precision).bind(torsion_params))
            combined_vjp_fns.append([(handle, vjp_fn)])
        elif isinstance(handle, bonded.ImproperTorsionHandler):
            torsion_idxs, (torsion_params, vjp_fn) = results
            torsion_idxs += num_host_atoms
            combined_potentials.append(
                potentials.PeriodicTorsion(
                    torsion_idxs, precision=precision).bind(torsion_params))
            combined_vjp_fns.append([(handle, vjp_fn)])
        elif isinstance(handle, nonbonded.AM1CCCHandler):
            charge_handle = handle
            guest_charge_params, guest_charge_vjp_fn = results
        elif isinstance(handle, nonbonded.LennardJonesHandler):
            guest_lj_params, guest_lj_vjp_fn = results
            lj_handle = handle
        else:
            print("Warning: skipping handler", handle)

    # Fail with a clear message instead of an UnboundLocalError below.
    if charge_handle is None:
        raise ValueError("ff_handlers must include an AM1CCCHandler for guest charges")
    if lj_handle is None:
        raise ValueError("ff_handlers must include a LennardJonesHandler")

    # process nonbonded terms
    combined_nb_params, (charge_vjp_fn, lj_vjp_fn) = nonbonded_vjps(
        guest_charge_params, guest_charge_vjp_fn, guest_lj_params,
        guest_lj_vjp_fn, host_nb_bp.params)

    # these vjp_fns take in adjoints of combined_params and returns derivatives
    # appropriate to the underlying handler
    combined_vjp_fns.append([(charge_handle, charge_vjp_fn),
                             (lj_handle, lj_vjp_fn)])

    # tbd change scale 14 for electrostatics
    guest_exclusion_idxs, guest_scale_factors = nonbonded.generate_exclusion_idxs(
        guest_mol, scale12=1.0, scale13=1.0, scale14=0.5)

    # allow the ligand to be alchemically decoupled
    # a value of one indicates that we allow the atom to be adjusted by the lambda value
    guest_lambda_offset_idxs = np.ones(len(guest_masses), dtype=np.int32)

    # use same scale factors until we modify 1-4s for electrostatics
    guest_scale_factors = np.stack([guest_scale_factors, guest_scale_factors],
                                   axis=1)

    combined_lambda_offset_idxs = np.concatenate(
        [host_nb_bp.get_lambda_offset_idxs(), guest_lambda_offset_idxs])
    # Guest atoms follow the host atoms, so shift guest exclusion indices.
    combined_exclusion_idxs = np.concatenate([
        host_nb_bp.get_exclusion_idxs(), guest_exclusion_idxs + num_host_atoms
    ])
    combined_scales = np.concatenate(
        [host_nb_bp.get_scale_factors(), guest_scale_factors])
    combined_beta = 2.0

    combined_cutoff = 1.0  # nonbonded cutoff

    combined_potentials.append(
        potentials.Nonbonded(combined_exclusion_idxs,
                             combined_scales,
                             combined_lambda_offset_idxs,
                             combined_beta,
                             combined_cutoff,
                             precision=precision).bind(combined_nb_params))

    return combined_potentials, combined_masses, combined_vjp_fns
Exemple #8
0
    def from_rdkit(cls, mol, ff_handlers):
        """
        Build an instance from an RDKit molecule and a list of forcefield handlers.

        Handlers are processed as follows:

        * bond, angle, proper-torsion and improper-torsion handlers each
          contribute one bound potential;
        * the AM1CCC charge handler and the Lennard-Jones handler are combined
          into one bound ``potentials.Nonbonded`` term, appended last;
        * any other handler is skipped with a printed warning.

        Parameters
        ----------
        mol: Chem.ROMol
            RDKit ROMol. Should have graphical hydrogens in the topology.

        ff_handlers: list of forcefield handlers.
            openforcefield small molecule handlers.

        Returns
        -------
        ``cls(masses, bound_potentials)``, where ``masses`` holds the float64
        atomic masses of ``mol``.
        """
        masses = np.array([atom.GetMass() for atom in mol.GetAtoms()], dtype=np.float64)
        num_atoms = len(masses)

        bound_potentials = []
        for handler in ff_handlers:
            result = handler.parameterize(mol)
            if isinstance(handler, bonded.HarmonicBondHandler):
                params, idxs = result
                bound_potentials.append(potentials.HarmonicBond(idxs).bind(params))
            elif isinstance(handler, bonded.HarmonicAngleHandler):
                params, idxs = result
                bound_potentials.append(potentials.HarmonicAngle(idxs).bind(params))
            elif isinstance(handler, (bonded.ProperTorsionHandler, bonded.ImproperTorsionHandler)):
                # Proper and improper torsions share the periodic functional form.
                params, idxs = result
                bound_potentials.append(potentials.PeriodicTorsion(idxs).bind(params))
            elif isinstance(handler, nonbonded.AM1CCCHandler):
                charge_params = result
            elif isinstance(handler, nonbonded.LennardJonesHandler):
                lj_params = result
            else:
                print("WARNING: skipping handler", handler)

        exclusion_idxs, scale_factors = nonbonded.generate_exclusion_idxs(
            mol, scale12=1.0, scale13=1.0, scale14=0.5)
        # use same scale factors until we modify 1-4s for electrostatics
        scale_factors = np.stack([scale_factors, scale_factors], axis=1)

        # (ytz) fix this later to not be so hard coded
        alpha = 2.0  # same as ewald alpha
        cutoff = 1.0  # nonbonded cutoff

        qlj_params = jnp.concatenate(
            [jnp.reshape(charge_params, (-1, 1)), jnp.reshape(lj_params, (-1, 2))],
            axis=1,
        )

        nonbonded_term = potentials.Nonbonded(
            exclusion_idxs,
            scale_factors,
            np.zeros(num_atoms, dtype=np.int32),  # lambda_plane_idxs
            np.zeros(num_atoms, dtype=np.int32),  # lambda_offset_idxs
            alpha,
            cutoff,
        ).bind(qlj_params)
        bound_potentials.append(nonbonded_term)

        return cls(masses, bound_potentials)
Exemple #9
0
def create_system(guest_mol, host_pdb, handlers, restr_search_radius,
                  restr_force_constant, intg_temperature, stage):
    """
    Initialize a self-encompassing System object that we can serialize and simulate.

    The host is parameterized with OpenMM's amber99sbildn / amber99_obc force
    fields and the guest with ``handlers``. Guest atoms are appended after the
    host atoms, so every guest atom index is offset by the number of host atoms.

    Parameters
    ----------

    guest_mol: rdkit.ROMol
        guest molecule; conformer 0 supplies the guest coordinates.

    host_pdb: openmm.PDBFile
        host system from OpenMM (its topology and positions are used).

    handlers: list of forcefield handlers
        forcefield handlers used to parameterize the guest; each must provide
        ``parameterize(guest_mol)``. A Lennard-Jones handler, a charge handler
        (SimpleCharge / AM1BCC / AM1CCC) and a GBSA handler are required.

    restr_search_radius: float
        how far away we search from the ligand to define the binding pocket atoms.

    restr_force_constant: float
        strength of the harmonic oscillator for the restraint

    intg_temperature: float
        temperature of the integrator in Kelvin

    stage: int (0 or 1)
        a free energy specific variable that determines how we decouple.
        Stage 0 leaves every lambda offset index at 0 and uses restraint
        lambda flag 1 / offset 0. Stage 1 sets the guest atoms' lambda offset
        indices to 1 and uses restraint lambda flag 0 / offset 1.

    Returns
    -------
    x0: np.ndarray
        combined coordinates, host atoms first, then guest atoms.
    combined_masses: np.ndarray
        combined masses in the same atom order as ``x0``.
    ssc:
        standard state correction returned by
        ``standard_state.harmonic_com_ssc``.
    final_gradients: list of (str, tuple)
        (potential name, arguments) pairs for every potential in the system.
    handler_vjp_fns: dict
        maps each handler to the vjp function for its parameters; for
        nonbonded handlers it acts on the combined host + guest parameters.

    Raises
    ------
    ValueError
        if ``stage`` is not 0 or 1, if the host system lacks one of its
        nonbonded terms, or if a required nonbonded handler is missing.
    Exception
        if a handler of an unrecognized type is supplied.
    """

    # Validate with a real exception rather than ``assert``: assertions are
    # stripped under ``python -O``, which would let an invalid stage reach the
    # lambda/restraint setup below and fail there with an unbound-name error,
    # after all of the expensive host parameterization has already run.
    if stage not in (0, 1):
        raise ValueError("stage must be 0 or 1, got {!r}".format(stage))

    guest_masses = np.array([a.GetMass() for a in guest_mol.GetAtoms()],
                            dtype=np.float64)

    amber_ff = app.ForceField('amber99sbildn.xml', 'amber99_obc.xml')
    host_system = amber_ff.createSystem(host_pdb.topology,
                                        nonbondedMethod=app.NoCutoff,
                                        constraints=None,
                                        rigidWater=False)

    host_fns, host_masses = openmm_deserializer.deserialize_system(host_system)

    num_host_atoms = len(host_masses)
    num_guest_atoms = guest_mol.GetNumAtoms()

    # (name, args) pairs for every potential in the combined system
    final_gradients = []

    # The host's nonbonded terms are held back so they can be merged with the
    # guest's below; every other host term is passed through unchanged.
    host_nonbonded_names = ('LennardJones', 'Charges', 'GBSA', 'Exclusions')
    host_nonbonded = {}
    for item in host_fns:
        if item[0] in host_nonbonded_names:
            host_nonbonded[item[0]] = item[1]
        else:
            final_gradients.append((item[0], item[1]))

    # All four host terms end up in the Nonbonded/GBSA potentials below, so a
    # missing one is an error; report it here instead of as an unbound name.
    missing_host_terms = [
        name for name in host_nonbonded_names if name not in host_nonbonded
    ]
    if missing_host_terms:
        raise ValueError("host system is missing nonbonded terms: " +
                         ", ".join(missing_host_terms))

    host_lj_params = host_nonbonded['LennardJones']
    host_charge_params = host_nonbonded['Charges']
    host_gb_params = host_nonbonded['GBSA'][0]
    host_gb_props = host_nonbonded['GBSA'][1:]
    host_exclusions = host_nonbonded['Exclusions']

    guest_exclusion_idxs, guest_scales = nonbonded.generate_exclusion_idxs(
        guest_mol, scale12=1.0, scale13=1.0, scale14=0.5)

    # Shift guest indices past the host atoms. This is a non-mutating add, so
    # arrays handed back by helpers are never modified.
    guest_exclusion_idxs = guest_exclusion_idxs + num_host_atoms
    # the guest uses the same exclusion scales for LJ and charges
    guest_lj_exclusion_scales = guest_scales
    guest_charge_exclusion_scales = guest_scales

    host_exclusion_idxs = host_exclusions[0]
    host_lj_exclusion_scales = host_exclusions[1]
    host_charge_exclusion_scales = host_exclusions[2]

    combined_exclusion_idxs = np.concatenate(
        [host_exclusion_idxs, guest_exclusion_idxs])
    combined_lj_exclusion_scales = np.concatenate(
        [host_lj_exclusion_scales, guest_lj_exclusion_scales])
    combined_charge_exclusion_scales = np.concatenate(
        [host_charge_exclusion_scales, guest_charge_exclusion_scales])

    # We build up a map of handles to a corresponding vjp_fn that takes in adjoints of output parameters
    # for nonbonded terms, the vjp_fn has been modified to take in combined parameters
    handler_vjp_fns = {}

    # Filled in by the matching nonbonded handlers; all three are required.
    combined_lj_params = None
    combined_charge_params = None
    combined_gb_params = None

    for handle in handlers:
        results = handle.parameterize(guest_mol)

        # Bonded handlers return (idxs, (params, vjp_fn)). Indices are shifted
        # into the combined host-first numbering without mutating the arrays
        # returned by the handler.
        if isinstance(handle, bonded.HarmonicBondHandler):
            bond_idxs, (bond_params, handler_vjp_fn) = results
            bond_idxs = bond_idxs + num_host_atoms
            final_gradients.append(("HarmonicBond", (bond_idxs, bond_params)))
        elif isinstance(handle, bonded.HarmonicAngleHandler):
            angle_idxs, (angle_params, handler_vjp_fn) = results
            angle_idxs = angle_idxs + num_host_atoms
            final_gradients.append(
                ("HarmonicAngle", (angle_idxs, angle_params)))
        elif isinstance(handle, (bonded.ProperTorsionHandler,
                                 bonded.ImproperTorsionHandler)):
            # proper and improper torsions share the PeriodicTorsion potential
            torsion_idxs, (torsion_params, handler_vjp_fn) = results
            torsion_idxs = torsion_idxs + num_host_atoms
            final_gradients.append(
                ("PeriodicTorsion", (torsion_idxs, torsion_params)))
        # Nonbonded handlers return (params, vjp_fn); guest parameters are
        # concatenated after the host's fixed (non-differentiated) parameters.
        elif isinstance(handle, nonbonded.LennardJonesHandler):
            guest_lj_params, guest_lj_vjp_fn = results
            combined_lj_params, handler_vjp_fn = concat_with_vjps(
                host_lj_params, guest_lj_params, None, guest_lj_vjp_fn)
        elif isinstance(handle, (nonbonded.SimpleChargeHandler,
                                 nonbonded.AM1BCCHandler,
                                 nonbonded.AM1CCCHandler)):
            guest_charge_params, guest_charge_vjp_fn = results
            combined_charge_params, handler_vjp_fn = concat_with_vjps(
                host_charge_params, guest_charge_params, None,
                guest_charge_vjp_fn)
        elif isinstance(handle, nonbonded.GBSAHandler):
            guest_gb_params, guest_gb_vjp_fn = results
            combined_gb_params, handler_vjp_fn = concat_with_vjps(
                host_gb_params, guest_gb_params, None, guest_gb_vjp_fn)
        else:
            raise Exception("Unknown Handler", handle)

        handler_vjp_fns[handle] = handler_vjp_fn

    missing_handlers = [
        name for name, params in (
            ("LennardJones", combined_lj_params),
            ("charge", combined_charge_params),
            ("GBSA", combined_gb_params),
        ) if params is None
    ]
    if missing_handlers:
        raise ValueError("handlers must include a handler for: " +
                         ", ".join(missing_handlers))

    host_conf = []
    for x, y, z in host_pdb.positions:
        host_conf.append([to_md_units(x), to_md_units(y), to_md_units(z)])
    host_conf = np.array(host_conf)

    conformer = guest_mol.GetConformer(0)
    mol_a_conf = np.array(conformer.GetPositions(), dtype=np.float64)
    mol_a_conf = mol_a_conf / 10  # convert to md_units (RDKit positions are in angstroms)

    x0 = np.concatenate([host_conf, mol_a_conf])  # combined geometry
    v0 = np.zeros_like(x0)

    pocket_atoms = find_protein_pocket_atoms(x0, num_host_atoms,
                                             restr_search_radius)

    N_C = num_host_atoms + num_guest_atoms
    N_A = num_host_atoms

    # very large value: effectively no cutoff, matching the host system
    # built with app.NoCutoff above
    cutoff = 100000.0

    # Stage 1 decouples the guest by giving its atoms lambda offset index 1;
    # stage 0 leaves every atom at 0. stage was validated at the top.
    combined_lambda_plane_idxs = np.zeros(N_C, dtype=np.int32)
    combined_lambda_offset_idxs = np.zeros(N_C, dtype=np.int32)
    if stage == 1:
        combined_lambda_offset_idxs[num_host_atoms:] = 1

    final_gradients.append(
        ('Nonbonded',
         (np.asarray(combined_charge_params), np.asarray(combined_lj_params),
          combined_exclusion_idxs, combined_charge_exclusion_scales,
          combined_lj_exclusion_scales, combined_lambda_plane_idxs,
          combined_lambda_offset_idxs, cutoff)))

    final_gradients.append(
        ('GBSA', (np.asarray(combined_charge_params),
                  np.asarray(combined_gb_params), combined_lambda_plane_idxs,
                  combined_lambda_offset_idxs, *host_gb_props, cutoff,
                  cutoff)))

    ligand_idxs = np.arange(N_A, N_C, dtype=np.int32)

    # restraint lambda schedule; the two stages are mutually exclusive
    if stage == 0:
        lamb_flag, lamb_offset = 1, 0
    else:
        lamb_flag, lamb_offset = 0, 1

    # unweighted center of mass restraints
    avg_xi = np.mean(x0[ligand_idxs], axis=0)
    avg_xj = np.mean(x0[pocket_atoms], axis=0)
    ctr_dij = np.sqrt(np.sum((avg_xi - avg_xj)**2))

    combined_masses = np.concatenate([host_masses, guest_masses])

    # restraints
    final_gradients.append(
        ('CentroidRestraint',
         (ligand_idxs, pocket_atoms, combined_masses, restr_force_constant,
          ctr_dij, lamb_flag, lamb_offset)))

    ssc = standard_state.harmonic_com_ssc(restr_force_constant, ctr_dij,
                                          intg_temperature)

    return x0, combined_masses, ssc, final_gradients, handler_vjp_fns