Пример #1
0
    def test_periodic_torsion(self, n_particles=64, n_torsions=25, dim=3):
        """Randomly connect quadruples of particles, then validate the resulting PeriodicTorsion force.

        Each torsion uses four distinct, randomly chosen particles. Forces from
        ``potentials.PeriodicTorsion`` are compared against the reference
        ``bonded.periodic_torsion`` at float32 and float64 precision: first
        without lambda terms at lambda = 0.0, then with random integer lambda
        multipliers/offsets drawn from [-5, 5) at lambda = 0.35.
        """
        np.random.seed(125)

        coords = self.get_random_coords(n_particles, dim)

        particle_pool = np.arange(n_particles)
        torsion_params = np.random.rand(n_torsions, 3).astype(np.float64)
        torsion_idxs = np.array(
            [np.random.choice(particle_pool, size=4, replace=False) for _ in range(n_torsions)],
            dtype=np.int32,
        )

        box = np.eye(3) * 100

        # tolerances specific to the periodic torsion force
        rtol_by_precision = {np.float32: 2e-5, np.float64: 1e-9}

        def check_every_precision(lamb, lambda_terms=()):
            # lambda_terms is either empty or (lamb_mult, lamb_offset); the same
            # arrays go positionally to the test potential and by keyword to the reference
            ref_kwargs = dict(zip(("lamb_mult", "lamb_offset"), lambda_terms))
            for precision, rtol in rtol_by_precision.items():
                test_potential = potentials.PeriodicTorsion(torsion_idxs, *lambda_terms)
                ref_potential = functools.partial(
                    bonded.periodic_torsion, torsion_idxs=torsion_idxs, **ref_kwargs
                )
                self.compare_forces(
                    coords, torsion_params, box, lamb, ref_potential, test_potential, rtol, precision=precision
                )

        check_every_precision(0.0)

        # drawn only after the lambda-free comparison so the seeded RNG is consumed in a fixed order
        lamb_mult = np.random.randint(-5, 5, size=n_torsions, dtype=np.int32)
        lamb_offset = np.random.randint(-5, 5, size=n_torsions, dtype=np.int32)
        check_every_precision(0.35, (lamb_mult, lamb_offset))
Пример #2
0
    def parameterize_periodic_torsion(self, proper_params, improper_params):
        """
        Parameterize all periodic torsions in the system.

        Proper and improper torsions are parameterized separately and then
        merged into a single PeriodicTorsion potential, propers first. Where a
        potential reports no lambda multipliers, zeros are used; where it
        reports no lambda offsets, ones are used (one entry per torsion). Both
        combined lambda arrays are cast to int32.

        Returns
        -------
        tuple
            (combined_params, combined_potential)
        """
        proper_params, proper_potential = self.parameterize_proper_torsion(proper_params)
        improper_params, improper_potential = self.parameterize_improper_torsion(improper_params)

        combined_params = jnp.concatenate([proper_params, improper_params])
        combined_idxs = np.concatenate([proper_potential.get_idxs(), improper_potential.get_idxs()])

        def lambda_terms(params, potential):
            # substitute zero multipliers / unit offsets for whichever term is missing
            count = len(params)
            mult = potential.get_lambda_mult()
            offset = potential.get_lambda_offset()
            return (
                np.zeros(count) if mult is None else mult,
                np.ones(count) if offset is None else offset,
            )

        proper_mult, proper_offset = lambda_terms(proper_params, proper_potential)
        improper_mult, improper_offset = lambda_terms(improper_params, improper_potential)

        combined_lambda_mult = np.concatenate([proper_mult, improper_mult]).astype(np.int32)
        combined_lambda_offset = np.concatenate([proper_offset, improper_offset]).astype(np.int32)

        return combined_params, potentials.PeriodicTorsion(
            combined_idxs, combined_lambda_mult, combined_lambda_offset)
Пример #3
0
 def parameterize_periodic_torsion(self, proper_params, improper_params):
     """
     Parameterize all periodic torsions in the system.

     Proper and improper torsions are parameterized separately and then
     merged into a single PeriodicTorsion potential, propers first.

     A potential may report ``None`` for its lambda multipliers/offsets (e.g.
     one built as ``PeriodicTorsion(idxs)`` without them); concatenating
     ``None`` would raise, so missing multipliers default to zeros and missing
     offsets default to ones, one entry per torsion. Both combined lambda
     arrays are cast to int32.

     Returns
     -------
     tuple
         (combined_params, combined_potential)
     """
     proper_params, proper_potential = self.parameterize_proper_torsion(
         proper_params)
     improper_params, improper_potential = self.parameterize_improper_torsion(
         improper_params)

     combined_params = jnp.concatenate([proper_params, improper_params])
     combined_idxs = np.concatenate(
         [proper_potential.get_idxs(),
          improper_potential.get_idxs()])

     def lambda_terms(params, potential):
         # fall back to zero multipliers / unit offsets when the potential has none
         count = len(params)
         mult = potential.get_lambda_mult()
         offset = potential.get_lambda_offset()
         if mult is None:
             mult = np.zeros(count, dtype=np.int32)
         if offset is None:
             offset = np.ones(count, dtype=np.int32)
         return mult, offset

     proper_mult, proper_offset = lambda_terms(proper_params, proper_potential)
     improper_mult, improper_offset = lambda_terms(improper_params, improper_potential)

     combined_lambda_mult = np.concatenate(
         [proper_mult, improper_mult]).astype(np.int32)
     combined_lambda_offset = np.concatenate(
         [proper_offset, improper_offset]).astype(np.int32)

     combined_potential = potentials.PeriodicTorsion(
         combined_idxs, combined_lambda_mult, combined_lambda_offset)
     return combined_params, combined_potential
Пример #4
0
 def parameterize_improper_torsion(self, ff_params):
     """
     Parameterize the improper torsions of ``self.mol``.

     Returns the parameters produced by ``self.ff.it_handle`` together with a
     PeriodicTorsion potential over the matching torsion indices (built
     without lambda multipliers/offsets).
     """
     improper_handle = self.ff.it_handle
     improper_params, improper_idxs = improper_handle.partial_parameterize(ff_params, self.mol)
     return improper_params, potentials.PeriodicTorsion(improper_idxs)
Пример #5
0
def combine_potentials(ff_handlers, guest_mol, host_system, precision):
    """
    This function is responsible for figuring out how to take two separate hamiltonians
    and combining them into one sensible alchemical system.

    The host is deserialized from ``host_system`` and the guest is
    parameterized by ``ff_handlers``. Guest atoms are placed after the host
    atoms, so every guest index is shifted by the number of host atoms. Every
    guest atom gets a lambda offset index of one so the ligand can be
    alchemically decoupled; host atoms keep their own offsets.

    Parameters
    ----------

    ff_handlers: list of forcefield handlers
        Small molecule forcefield handlers. Must include an AM1CCCHandler and a
        LennardJonesHandler; unrecognized handlers are skipped with a warning.

    guest_mol: Chem.ROMol
        RDKit molecule

    host_system: openmm.System
        Host system to be deserialized. Must contain exactly one nonbonded term.

    precision: np.float32 or np.float64
        Numerical precision of the functional form

    Returns
    -------
    tuple
        (combined_potentials, combined_masses, combined_vjp_fns):
        the bound lib.potentials objects, the host masses followed by the guest
        masses, and, parallel to combined_potentials, a list of
        (handler, vjp_fn) pairs mapping each potential's parameter adjoints back
        into the forcefield (an empty list for host-only terms).

    Raises
    ------
    ValueError
        If the host does not contain exactly one nonbonded term, or if the
        charge or Lennard-Jones handler is missing from ``ff_handlers``.
    """

    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        host_system, precision, cutoff=1.0)

    combined_potentials = []
    combined_vjp_fns = []

    host_nb_bp = None
    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            if host_nb_bp is not None:
                raise ValueError("host_system must contain only one nonbonded term")
            host_nb_bp = bp
        else:
            combined_potentials.append(bp)
            # host-only terms have no forcefield handler to push adjoints into
            combined_vjp_fns.append([])

    if host_nb_bp is None:
        raise ValueError("host_system must contain a nonbonded term")

    guest_masses = np.array([a.GetMass() for a in guest_mol.GetAtoms()],
                            dtype=np.float64)
    num_host_atoms = len(host_masses)
    combined_masses = np.concatenate([host_masses, guest_masses])

    # guest bonded handlers and the potential each one feeds; checked in order,
    # exactly like the equivalent isinstance/elif chain
    bonded_potential_types = (
        (bonded.HarmonicBondHandler, potentials.HarmonicBond),
        (bonded.HarmonicAngleHandler, potentials.HarmonicAngle),
        (bonded.ProperTorsionHandler, potentials.PeriodicTorsion),
        (bonded.ImproperTorsionHandler, potentials.PeriodicTorsion),
    )

    charge_handle = None
    lj_handle = None

    for handle in ff_handlers:
        results = handle.parameterize(guest_mol)
        potential_type = next(
            (pot_type for handler_type, pot_type in bonded_potential_types
             if isinstance(handle, handler_type)),
            None,
        )
        if potential_type is not None:
            idxs, (params, vjp_fn) = results
            # shift into combined (host + guest) indexing; out-of-place so the
            # array returned by the handler is not mutated
            idxs = idxs + num_host_atoms
            combined_potentials.append(
                potential_type(idxs, precision=precision).bind(params))
            combined_vjp_fns.append([(handle, vjp_fn)])
        elif isinstance(handle, nonbonded.AM1CCCHandler):
            charge_handle = handle
            guest_charge_params, guest_charge_vjp_fn = results
        elif isinstance(handle, nonbonded.LennardJonesHandler):
            lj_handle = handle
            guest_lj_params, guest_lj_vjp_fn = results
        else:
            print("Warning: skipping handler", handle)

    if charge_handle is None:
        raise ValueError("ff_handlers must include an AM1CCCHandler")
    if lj_handle is None:
        raise ValueError("ff_handlers must include a LennardJonesHandler")

    # process nonbonded terms
    combined_nb_params, (charge_vjp_fn, lj_vjp_fn) = nonbonded_vjps(
        guest_charge_params, guest_charge_vjp_fn, guest_lj_params,
        guest_lj_vjp_fn, host_nb_bp.params)

    # these vjp_fns take in adjoints of combined_params and returns derivatives
    # appropriate to the underlying handler
    combined_vjp_fns.append([(charge_handle, charge_vjp_fn),
                             (lj_handle, lj_vjp_fn)])

    # tbd change scale 14 for electrostatics
    guest_exclusion_idxs, guest_scale_factors = nonbonded.generate_exclusion_idxs(
        guest_mol, scale12=1.0, scale13=1.0, scale14=0.5)

    # allow the ligand to be alchemically decoupled
    # a value of one indicates that we allow the atom to be adjusted by the lambda value
    guest_lambda_offset_idxs = np.ones(len(guest_masses), dtype=np.int32)

    # use same scale factors until we modify 1-4s for electrostatics
    guest_scale_factors = np.stack([guest_scale_factors, guest_scale_factors],
                                   axis=1)

    combined_lambda_offset_idxs = np.concatenate(
        [host_nb_bp.get_lambda_offset_idxs(), guest_lambda_offset_idxs])
    combined_exclusion_idxs = np.concatenate([
        host_nb_bp.get_exclusion_idxs(), guest_exclusion_idxs + num_host_atoms
    ])
    combined_scales = np.concatenate(
        [host_nb_bp.get_scale_factors(), guest_scale_factors])
    combined_beta = 2.0

    combined_cutoff = 1.0  # nonbonded cutoff

    combined_potentials.append(
        potentials.Nonbonded(
            combined_exclusion_idxs,
            combined_scales,
            combined_lambda_offset_idxs,
            combined_beta,
            combined_cutoff,
            precision=precision,
        ).bind(combined_nb_params))

    return combined_potentials, combined_masses, combined_vjp_fns
Пример #6
0
    def from_rdkit(cls, mol, ff_handlers, alpha=2.0, cutoff=1.0):
        """
        Initialize a system from an RDKit ROMol.

        Parameters
        ----------
        mol: Chem.ROMol
            RDKit ROMol. Should have graphical hydrogens in the topology.

        ff_handlers: list of forcefield handlers.
            openforcefield small molecule handlers. Must include an
            AM1CCCHandler and a LennardJonesHandler; unrecognized handlers are
            skipped with a warning.

        alpha: float, optional
            Nonbonded alpha, kept the same as the Ewald alpha. Defaults to 2.0
            (the previously hard-coded value).

        cutoff: float, optional
            Nonbonded cutoff. Defaults to 1.0 (the previously hard-coded value).

        Returns
        -------
        cls
            ``cls(masses, bound_potentials)``: one bound potential per bonded
            handler, in handler order, followed by a single Nonbonded potential.

        Raises
        ------
        ValueError
            If ``ff_handlers`` lacks the charge or Lennard-Jones handler.
        """
        masses = np.array([a.GetMass() for a in mol.GetAtoms()], dtype=np.float64)

        # bonded handlers and the potential each one feeds; checked in order,
        # exactly like the equivalent isinstance/elif chain
        bonded_potential_types = (
            (bonded.HarmonicBondHandler, potentials.HarmonicBond),
            (bonded.HarmonicAngleHandler, potentials.HarmonicAngle),
            (bonded.ProperTorsionHandler, potentials.PeriodicTorsion),
            (bonded.ImproperTorsionHandler, potentials.PeriodicTorsion),
        )

        bound_potentials = []
        charge_params = None
        lj_params = None

        for handle in ff_handlers:
            results = handle.parameterize(mol)
            potential_type = next(
                (pot_type for handler_type, pot_type in bonded_potential_types
                 if isinstance(handle, handler_type)),
                None,
            )
            if potential_type is not None:
                params, idxs = results
                bound_potentials.append(potential_type(idxs).bind(params))
            elif isinstance(handle, nonbonded.AM1CCCHandler):
                charge_params = results
            elif isinstance(handle, nonbonded.LennardJonesHandler):
                lj_params = results
            else:
                print("WARNING: skipping handler", handle)

        if charge_params is None:
            raise ValueError("ff_handlers must include an AM1CCCHandler")
        if lj_params is None:
            raise ValueError("ff_handlers must include a LennardJonesHandler")

        # all-zero lambda plane/offset indices (presumably: no atom is lambda-dependent)
        lambda_plane_idxs = np.zeros(len(masses), dtype=np.int32)
        lambda_offset_idxs = np.zeros(len(masses), dtype=np.int32)

        exclusion_idxs, scale_factors = nonbonded.generate_exclusion_idxs(
            mol,
            scale12=1.0,
            scale13=1.0,
            scale14=0.5
        )

        # use same scale factors until we modify 1-4s for electrostatics
        scale_factors = np.stack([scale_factors, scale_factors], axis=1)

        # one row per atom: the charge followed by the two LJ parameters
        qlj_params = jnp.concatenate([
            jnp.reshape(charge_params, (-1, 1)),
            jnp.reshape(lj_params, (-1, 2))
        ], axis=1)

        bound_potentials.append(potentials.Nonbonded(
            exclusion_idxs,
            scale_factors,
            lambda_plane_idxs,
            lambda_offset_idxs,
            alpha,
            cutoff).bind(qlj_params))

        return cls(masses, bound_potentials)
Пример #7
0
def _deserialize_harmonic_bond_force(force):
    """Convert an OpenMM HarmonicBondForce into a bound HarmonicBond with (k, length) params."""
    bond_idxs = []
    bond_params = []
    for b_idx in range(force.getNumBonds()):
        src_idx, dst_idx, length, k = force.getBondParameters(b_idx)
        bond_idxs.append([src_idx, dst_idx])
        bond_params.append((value(k), value(length)))

    # reshape keeps the 2-D layout even when the force has no bonds
    bond_idxs = np.array(bond_idxs, dtype=np.int32).reshape(-1, 2)
    bond_params = np.array(bond_params, dtype=np.float64).reshape(-1, 2)
    return potentials.HarmonicBond(bond_idxs).bind(bond_params)


def _deserialize_harmonic_angle_force(force):
    """Convert an OpenMM HarmonicAngleForce into a bound HarmonicAngle with (k, angle) params."""
    angle_idxs = []
    angle_params = []
    for a_idx in range(force.getNumAngles()):
        src_idx, mid_idx, dst_idx, angle, k = force.getAngleParameters(a_idx)
        angle_idxs.append([src_idx, mid_idx, dst_idx])
        angle_params.append((value(k), value(angle)))

    # reshape keeps the 2-D layout even when the force has no angles
    angle_idxs = np.array(angle_idxs, dtype=np.int32).reshape(-1, 3)
    angle_params = np.array(angle_params, dtype=np.float64).reshape(-1, 2)
    return potentials.HarmonicAngle(angle_idxs).bind(angle_params)


def _deserialize_periodic_torsion_force(force):
    """Convert an OpenMM PeriodicTorsionForce into a bound PeriodicTorsion with (k, phase, period) params."""
    torsion_idxs = []
    torsion_params = []
    for t_idx in range(force.getNumTorsions()):
        a_idx, b_idx, c_idx, d_idx, period, phase, k = force.getTorsionParameters(
            t_idx)
        torsion_params.append((value(k), value(phase), period))
        torsion_idxs.append([a_idx, b_idx, c_idx, d_idx])

    # reshape keeps the 2-D layout even when the force has no torsions
    torsion_idxs = np.array(torsion_idxs, dtype=np.int32).reshape(-1, 4)
    torsion_params = np.array(torsion_params, dtype=np.float64).reshape(-1, 3)
    return potentials.PeriodicTorsion(torsion_idxs).bind(torsion_params)


def _lj_removal_scale(new_eps, expected_eps):
    """Return how much of the mixed LJ interaction an exception removes (1 = fully removed)."""
    if expected_eps == 0:
        if new_eps == 0:
            return 1
        raise RuntimeError("Divide by zero in epsilon calculation")
    return 1 - new_eps / expected_eps


def _deserialize_nonbonded_force(force, num_particles, cutoff):
    """Convert an OpenMM NonbondedForce into a bound Nonbonded potential.

    Exceptions become exclusions whose LJ scale factor measures how much of
    the mixed interaction is removed; the same factors are reused for
    electrostatics.
    """
    charge_params = []
    lj_params = []
    for a_idx in range(force.getNumParticles()):
        charge, sig, eps = force.getParticleParameters(a_idx)
        # scale by sqrt(ONE_4PI_EPS0) so a product of two charges carries the Coulomb prefactor
        charge_params.append(value(charge) * np.sqrt(constants.ONE_4PI_EPS0))
        # NOTE: zero eps values are deliberately not bumped by 1e-3 (an earlier
        # workaround for a singularity in the LJ parameter derivatives) because
        # that override does not work for water.
        lj_params.append((value(sig), value(eps)))

    charge_params = np.array(charge_params, dtype=np.float64)
    # reshape keeps the (n, 2) layout so the column slices below work for n == 0
    lj_params = np.array(lj_params, dtype=np.float64).reshape(-1, 2)

    all_sig = lj_params[:, 0]
    all_eps = lj_params[:, 1]

    exclusion_idxs = []
    scale_factors = []

    # validate exclusions/exceptions to make sure they make sense
    for e_idx in range(force.getNumExceptions()):
        # tbd charge scale factors (new_cp is not used yet)
        src, dst, new_cp, new_sig, new_eps = force.getExceptionParameters(e_idx)
        new_sig = value(new_sig)
        new_eps = value(new_eps)

        # Lorentz-Berthelot mixing: arithmetic-mean sigma, geometric-mean epsilon
        expected_sig = (all_sig[src] + all_sig[dst]) / 2
        expected_eps = np.sqrt(all_eps[src] * all_eps[dst])

        exclusion_idxs.append([src, dst])
        scale_factors.append(_lj_removal_scale(new_eps, expected_eps))

        # sigma is only meaningful when the exception keeps some LJ interaction
        if new_eps != 0:
            np.testing.assert_almost_equal(expected_sig, new_sig)

    exclusion_idxs = np.array(exclusion_idxs, dtype=np.int32).reshape(-1, 2)

    lambda_plane_idxs = np.zeros(num_particles, dtype=np.int32)
    lambda_offset_idxs = np.zeros(num_particles, dtype=np.int32)

    nb_params = np.concatenate(
        [np.expand_dims(charge_params, axis=1), lj_params], axis=1)

    # optimizations: store sigma / 2 and sqrt(eps); presumably so the kernel
    # can mix pairs as sig_i + sig_j and eps_i * eps_j
    nb_params[:, 1] = nb_params[:, 1] / 2
    nb_params[:, 2] = np.sqrt(nb_params[:, 2])

    beta = 2.0  # erfc correction

    # use the same scale factors for electrostatics and lj
    scale_factors = np.stack([scale_factors, scale_factors], axis=1)

    return potentials.Nonbonded(exclusion_idxs, scale_factors,
                                lambda_plane_idxs, lambda_offset_idxs,
                                beta, cutoff).bind(nb_params)


def deserialize_system(system, cutoff):
    """
    Deserialize an OpenMM System into bound potentials and particle masses.

    Parameters
    ----------
    system: openmm.System
        A system object to be deserialized

    cutoff: float
        Nonbonded cutoff handed to the Nonbonded potential (units presumably
        nm, following OpenMM — confirm).

    Returns
    -------
    list of lib.Potential, masses
        One bound potential per supported force (HarmonicBondForce,
        HarmonicAngleForce, PeriodicTorsionForce, NonbondedForce), in the
        order the forces appear in ``system``; other force types are ignored.
        ``masses`` is a Python list with one entry per particle.

    Note: zero epsilon values are passed through unchanged. Adding a small
    epsilon (1e-3) to avoid a singularity in the Lennard-Jones derivatives is
    disabled because it does not work for water.

    Raises
    ------
    RuntimeError
        If an exception keeps a nonzero epsilon for a pair whose mixed
        epsilon is zero.
    AssertionError
        If an exception's sigma disagrees with the mixed sigma.
    """
    masses = [value(system.getParticleMass(p))
              for p in range(system.getNumParticles())]
    num_particles = len(masses)

    # this should not be a dict since we may have more than one instance of a given
    # force.
    bps = []

    for force in system.getForces():
        if isinstance(force, mm.HarmonicBondForce):
            bps.append(_deserialize_harmonic_bond_force(force))
        elif isinstance(force, mm.HarmonicAngleForce):
            bps.append(_deserialize_harmonic_angle_force(force))
        elif isinstance(force, mm.PeriodicTorsionForce):
            bps.append(_deserialize_periodic_torsion_force(force))
        elif isinstance(force, mm.NonbondedForce):
            bps.append(_deserialize_nonbonded_force(force, num_particles, cutoff))

    return bps, masses