def parameterize_nonbonded(self, ff_q_params, ff_lj_params):
    """Parameterize the combined nonbonded term for mol_a and mol_b.

    Atoms of mol_a occupy indices [0, NA) and atoms of mol_b occupy
    [NA, NA + NB). Intramolecular exclusions of each molecule are kept (mol_b's
    shifted by NA), and every (mol_a atom, mol_b atom) pair is added as an
    extra exclusion with scale factors (1.0, 1.0).

    Parameters
    ----------
    ff_q_params:
        Charge forcefield parameters, passed to
        ``self.ff.q_handle.partial_parameterize`` for each molecule.
    ff_lj_params:
        Lennard-Jones forcefield parameters, passed to
        ``self.ff.lj_handle.partial_parameterize`` for each molecule.

    Returns
    -------
    tuple
        ``(qlj_params, potentials.Nonbonded)``. ``qlj_params`` has one row per
        atom: the charge column followed by the LJ parameters reshaped to
        (-1, 2). The potential's lambda plane/offset idxs are left as None.
    """
    q_params_a = self.ff.q_handle.partial_parameterize(ff_q_params, self.mol_a)
    q_params_b = self.ff.q_handle.partial_parameterize(ff_q_params, self.mol_b)
    lj_params_a = self.ff.lj_handle.partial_parameterize(ff_lj_params, self.mol_a)
    lj_params_b = self.ff.lj_handle.partial_parameterize(ff_lj_params, self.mol_b)

    q_params = jnp.concatenate([q_params_a, q_params_b])
    lj_params = jnp.concatenate([lj_params_a, lj_params_b])

    exclusion_idxs_a, scale_factors_a = nonbonded.generate_exclusion_idxs(
        self.mol_a, scale12=_SCALE_12, scale13=_SCALE_13, scale14=_SCALE_14)
    exclusion_idxs_b, scale_factors_b = nonbonded.generate_exclusion_idxs(
        self.mol_b, scale12=_SCALE_12, scale13=_SCALE_13, scale14=_SCALE_14)

    NA = self.mol_a.GetNumAtoms()
    NB = self.mol_b.GetNumAtoms()

    # Every cross pair (i in mol_a, j in mol_b) is excluded, ordered i-major
    # (same order as a nested loop over i then j). Built with an explicit
    # (NA * NB, 2) shape so the concatenation below also works when either
    # molecule has no atoms.
    mutual_exclusions = np.stack(
        [
            np.repeat(np.arange(NA), NB),
            np.tile(np.arange(NB) + NA, NA),
        ],
        axis=1,
    )
    mutual_scale_factors = np.ones((NA * NB, 2), dtype=np.float64)

    combined_exclusion_idxs = np.concatenate(
        [exclusion_idxs_a, exclusion_idxs_b + NA, mutual_exclusions]).astype(np.int32)

    # intramolecular scale factors are duplicated into two identical columns
    combined_scale_factors = np.concatenate([
        np.stack([scale_factors_a, scale_factors_a], axis=1),
        np.stack([scale_factors_b, scale_factors_b], axis=1),
        mutual_scale_factors,
    ]).astype(np.float64)

    # left unset here; not assigned anywhere in this method
    combined_lambda_plane_idxs = None
    combined_lambda_offset_idxs = None

    beta = _BETA
    cutoff = _CUTOFF  # solve for this analytically later

    qlj_params = jnp.concatenate(
        [jnp.reshape(q_params, (-1, 1)), jnp.reshape(lj_params, (-1, 2))],
        axis=1)

    return qlj_params, potentials.Nonbonded(
        combined_exclusion_idxs,
        combined_scale_factors,
        combined_lambda_plane_idxs,
        combined_lambda_offset_idxs,
        beta,
        cutoff,
    )
def combine_potentials(ff_handlers, guest_mol, host_system, precision):
    """
    Combine a host system and a guest molecule into one alchemical system.

    The host is deserialized from an OpenMM system and its atoms come first;
    guest atoms follow, so every guest index is offset by the number of host
    atoms. The host's single Nonbonded term and the guest's charge and
    Lennard-Jones parameters are merged into one Nonbonded potential that is
    appended last.

    Parameters
    ----------
    ff_handlers: list of forcefield handlers
        Small molecule forcefield handlers. Must include an AM1CCCHandler and a
        LennardJonesHandler; unrecognized handlers are skipped with a warning.

    guest_mol: Chem.ROMol
        RDKit molecule

    host_system: openmm.System
        Host system to be deserialized

    precision: np.float32 or np.float64
        Numerical precision of the functional form

    Returns
    -------
    tuple
        (combined_potentials, combined_masses, combined_vjp_fns): a list of
        lib.potentials objects; host masses followed by guest masses; and,
        entry-for-entry with combined_potentials, a list of (handler, vjp_fn)
        pairs that map parameter adjoints back into the forcefield (an empty
        list for host-only terms).

    Raises
    ------
    ValueError
        If the host system does not contain exactly one Nonbonded term, or if
        ff_handlers lacks a charge or Lennard-Jones handler.
    """
    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        host_system, precision, cutoff=1.0)

    host_nb_bp = None
    combined_potentials = []
    combined_vjp_fns = []

    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            if host_nb_bp is not None:
                raise ValueError("host system has more than one Nonbonded term")
            host_nb_bp = bp
        else:
            combined_potentials.append(bp)
            combined_vjp_fns.append([])

    if host_nb_bp is None:
        raise ValueError("host system has no Nonbonded term")

    guest_masses = np.array([a.GetMass() for a in guest_mol.GetAtoms()], dtype=np.float64)
    num_host_atoms = len(host_masses)
    combined_masses = np.concatenate([host_masses, guest_masses])

    # bonded handler type -> potential type it produces (first match wins)
    bonded_potential_types = (
        (bonded.HarmonicBondHandler, potentials.HarmonicBond),
        (bonded.HarmonicAngleHandler, potentials.HarmonicAngle),
        (bonded.ProperTorsionHandler, potentials.PeriodicTorsion),
        (bonded.ImproperTorsionHandler, potentials.PeriodicTorsion),
    )

    charge_handle = None
    lj_handle = None

    for handle in ff_handlers:
        results = handle.parameterize(guest_mol)
        potential_type = next(
            (pt for handler_type, pt in bonded_potential_types
             if isinstance(handle, handler_type)),
            None,
        )
        if potential_type is not None:
            idxs, (params, vjp_fn) = results
            # shift guest indices past the host atoms; not `+=` so the array
            # returned by the handler is not modified in place
            idxs = idxs + num_host_atoms
            combined_potentials.append(
                potential_type(idxs, precision=precision).bind(params))
            combined_vjp_fns.append([(handle, vjp_fn)])
        elif isinstance(handle, nonbonded.AM1CCCHandler):
            charge_handle = handle
            guest_charge_params, guest_charge_vjp_fn = results
        elif isinstance(handle, nonbonded.LennardJonesHandler):
            lj_handle = handle
            guest_lj_params, guest_lj_vjp_fn = results
        else:
            print("Warning: skipping handler", handle)

    if charge_handle is None:
        raise ValueError("ff_handlers must include an AM1CCCHandler")
    if lj_handle is None:
        raise ValueError("ff_handlers must include a LennardJonesHandler")

    # process nonbonded terms
    combined_nb_params, (charge_vjp_fn, lj_vjp_fn) = nonbonded_vjps(
        guest_charge_params, guest_charge_vjp_fn,
        guest_lj_params, guest_lj_vjp_fn,
        host_nb_bp.params)

    # these vjp_fns take in adjoints of combined_params and returns derivatives
    # appropriate to the underlying handler
    combined_vjp_fns.append([(charge_handle, charge_vjp_fn), (lj_handle, lj_vjp_fn)])

    # tbd change scale 14 for electrostatics
    guest_exclusion_idxs, guest_scale_factors = nonbonded.generate_exclusion_idxs(
        guest_mol, scale12=1.0, scale13=1.0, scale14=0.5)

    # allow the ligand to be alchemically decoupled
    # a value of one indicates that we allow the atom to be adjusted by the lambda value
    guest_lambda_offset_idxs = np.ones(len(guest_masses), dtype=np.int32)

    # use same scale factors until we modify 1-4s for electrostatics
    guest_scale_factors = np.stack([guest_scale_factors, guest_scale_factors], axis=1)

    combined_lambda_offset_idxs = np.concatenate(
        [host_nb_bp.get_lambda_offset_idxs(), guest_lambda_offset_idxs])
    combined_exclusion_idxs = np.concatenate([
        host_nb_bp.get_exclusion_idxs(),
        guest_exclusion_idxs + num_host_atoms,
    ])
    combined_scales = np.concatenate(
        [host_nb_bp.get_scale_factors(), guest_scale_factors])
    combined_beta = 2.0
    combined_cutoff = 1.0  # nonbonded cutoff

    combined_potentials.append(
        potentials.Nonbonded(
            combined_exclusion_idxs,
            combined_scales,
            combined_lambda_offset_idxs,
            combined_beta,
            combined_cutoff,
            precision=precision,
        ).bind(combined_nb_params))

    return combined_potentials, combined_masses, combined_vjp_fns
def parameterize_nonbonded(self, ff_q_params, ff_lj_params):
    """Build the joint host-guest nonbonded parameters and potential.

    Host atoms come first and guest atoms follow; guest exclusion indices are
    shifted by ``self.num_host_atoms``. Exclusions, scale factors, and lambda
    plane/offset idxs are the host's followed by the guest's.

    If the guest topology returns a ``potentials.NonbondedInterpolated``, its
    parameters hold two stacked per-atom sets (first half / second half). The
    result is then [host + first half] followed by [host + second half], wrapped
    in a NonbondedInterpolated. Otherwise the result is host + guest parameters
    wrapped in a Nonbonded.

    Returns
    -------
    tuple
        (params, potential) with beta and cutoff taken from the guest
        potential, which are asserted to match the host's.
    """
    n_guest = self.guest_topology.get_num_atoms()
    guest_qlj, guest_p = self.guest_topology.parameterize_nonbonded(
        ff_q_params, ff_lj_params)

    # see if we're doing parameter interpolation: an interpolated guest
    # carries twice as many parameter rows
    is_interpolated = isinstance(guest_p, potentials.NonbondedInterpolated)
    assert guest_qlj.shape[0] == (n_guest * 2 if is_interpolated else n_guest)
    assert guest_qlj.shape[1] == 3

    host = self.host_nonbonded
    assert guest_p.get_beta() == host.get_beta()
    assert guest_p.get_cutoff() == host.get_cutoff()

    exclusions = np.concatenate([
        host.get_exclusion_idxs(),
        guest_p.get_exclusion_idxs() + self.num_host_atoms,
    ])
    scales = np.concatenate([host.get_scale_factors(), guest_p.get_scale_factors()])
    offsets = np.concatenate(
        [host.get_lambda_offset_idxs(), guest_p.get_lambda_offset_idxs()])
    planes = np.concatenate(
        [host.get_lambda_plane_idxs(), guest_p.get_lambda_plane_idxs()])

    if is_interpolated:
        # with parameter interpolation: [host + guest src ; host + guest dst]
        src = jnp.concatenate([host.params, guest_qlj[:n_guest]])
        dst = jnp.concatenate([host.params, guest_qlj[n_guest:]])
        params = jnp.concatenate([src, dst])
        potential_cls = potentials.NonbondedInterpolated
    else:
        # no parameter interpolation
        params = jnp.concatenate([host.params, guest_qlj])
        potential_cls = potentials.Nonbonded

    nb = potential_cls(
        exclusions,
        scales,
        planes,
        offsets,
        guest_p.get_beta(),
        guest_p.get_cutoff(),
    )
    return params, nb
def prepare_water_system(x, lambda_plane_idxs, lambda_offset_idxs, p_scale, cutoff):
    """Build a randomized water-like nonbonded test system.

    Atoms of ``x`` are grouped in consecutive triples (O, H1, H2). For each
    triple, both O-H pairs are excluded with scales (1.0, 1.0), and the H1-H2
    pair gets two random scales.

    Parameters
    ----------
    x:
        2-D coordinate array; its row count must be a multiple of 3.
    lambda_plane_idxs, lambda_offset_idxs:
        Forwarded to both the test potential and the reference energy.
    p_scale:
        Accepted but not used.
    cutoff:
        Nonbonded cutoff for both the test potential and the reference.

    Returns
    -------
    tuple
        ``(params, ref_total_energy, test_potential)``. ``params`` has columns
        [charge, sigma, sqrt(eps)]: random charges scaled by
        sqrt(ONE_4PI_EPS0), halved sigmas, and square-rooted epsilons.
        ``ref_total_energy`` is ``nonbonded.nonbonded_v3`` with the rescale
        masks and settings bound. ``test_potential`` is a
        ``potentials.Nonbonded``.
    """
    assert x.ndim == 2
    n_atoms = x.shape[0]
    assert n_atoms % 3 == 0

    # random draws in the order charge, sigma, epsilon
    charges = (np.random.rand(n_atoms).astype(np.float64) - 0.5) * np.sqrt(ONE_4PI_EPS0)
    sigmas = np.random.rand(n_atoms).astype(np.float64) / 5.0
    epsilons = np.random.rand(n_atoms).astype(np.float64)
    params = np.stack([charges, sigmas, epsilons], axis=1)
    params[:, 1] = params[:, 1] / 2
    params[:, 2] = np.sqrt(params[:, 2])

    pairs = []
    pair_scales = []
    for water in range(n_atoms // 3):
        oxygen, hyd1, hyd2 = 3 * water, 3 * water + 1, 3 * water + 2
        pairs += [[oxygen, hyd1], [oxygen, hyd2], [hyd1, hyd2]]  # 1-2, 1-2, 1-3
        pair_scales += [[1.0, 1.0], [1.0, 1.0], [np.random.rand(), np.random.rand()]]

    scales = np.array(pair_scales, dtype=np.float64)
    exclusion_idxs = np.array(pairs, dtype=np.int32)

    beta = 2.0

    test_potential = potentials.Nonbonded(
        exclusion_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)

    def _rescale_mask(column):
        # symmetric mask: 1 everywhere except excluded pairs, which get 1 - scale
        mask = np.ones((n_atoms, n_atoms))
        for (i, j), exc in zip(exclusion_idxs, scales[:, column]):
            mask[i][j] = 1 - exc
            mask[j][i] = 1 - exc
        return mask

    charge_rescale_mask = _rescale_mask(0)
    lj_rescale_mask = _rescale_mask(1)

    ref_total_energy = functools.partial(
        nonbonded.nonbonded_v3,
        charge_rescale_mask=charge_rescale_mask,
        lj_rescale_mask=lj_rescale_mask,
        beta=beta,
        cutoff=cutoff,
        lambda_plane_idxs=lambda_plane_idxs,
        lambda_offset_idxs=lambda_offset_idxs,
        runtime_validate=False,
    )

    return params, ref_total_energy, test_potential
def from_rdkit(cls, mol, ff_handlers):
    """
    Initialize a system from an RDKit ROMol.

    Parameters
    ----------
    mol: Chem.ROMol
        RDKit ROMol. Should have graphical hydrogens in the topology.

    ff_handlers: list of forcefield handlers.
        openforcefield small molecule handlers. Must include an AM1CCCHandler
        and a LennardJonesHandler; unrecognized handlers are skipped with a
        warning.

    Returns
    -------
    cls
        ``cls(masses, bound_potentials)``. The bonded terms come first in
        handler order, followed by one Nonbonded term. That term has zero
        lambda plane/offset idxs and exclusions generated with
        scale12=1.0, scale13=1.0, scale14=0.5.

    Raises
    ------
    ValueError
        If ff_handlers lacks a charge or Lennard-Jones handler.
    """
    masses = np.array([a.GetMass() for a in mol.GetAtoms()], dtype=np.float64)

    bound_potentials = []
    charge_params = None
    lj_params = None

    for handle in ff_handlers:
        results = handle.parameterize(mol)
        if isinstance(handle, bonded.HarmonicBondHandler):
            bond_params, bond_idxs = results
            bound_potentials.append(potentials.HarmonicBond(bond_idxs).bind(bond_params))
        elif isinstance(handle, bonded.HarmonicAngleHandler):
            angle_params, angle_idxs = results
            bound_potentials.append(potentials.HarmonicAngle(angle_idxs).bind(angle_params))
        elif isinstance(handle, (bonded.ProperTorsionHandler, bonded.ImproperTorsionHandler)):
            # proper and improper torsions share the same functional form
            torsion_params, torsion_idxs = results
            bound_potentials.append(potentials.PeriodicTorsion(torsion_idxs).bind(torsion_params))
        elif isinstance(handle, nonbonded.AM1CCCHandler):
            charge_params = results
        elif isinstance(handle, nonbonded.LennardJonesHandler):
            lj_params = results
        else:
            print("WARNING: skipping handler", handle)

    if charge_params is None:
        raise ValueError("ff_handlers must include an AM1CCCHandler")
    if lj_params is None:
        raise ValueError("ff_handlers must include a LennardJonesHandler")

    lambda_plane_idxs = np.zeros(len(masses), dtype=np.int32)
    lambda_offset_idxs = np.zeros(len(masses), dtype=np.int32)

    exclusion_idxs, scale_factors = nonbonded.generate_exclusion_idxs(
        mol, scale12=1.0, scale13=1.0, scale14=0.5)

    # use same scale factors until we modify 1-4s for electrostatics
    scale_factors = np.stack([scale_factors, scale_factors], axis=1)

    # (ytz) fix this later to not be so hard coded
    alpha = 2.0  # same as ewald alpha
    cutoff = 1.0  # nonbonded cutoff

    qlj_params = jnp.concatenate([
        jnp.reshape(charge_params, (-1, 1)),
        jnp.reshape(lj_params, (-1, 2)),
    ], axis=1)

    bound_potentials.append(potentials.Nonbonded(
        exclusion_idxs,
        scale_factors,
        lambda_plane_idxs,
        lambda_offset_idxs,
        alpha,
        cutoff).bind(qlj_params))

    return cls(masses, bound_potentials)
def combine(self, other):
    """
    Combine two recipes together. self will keep its original indexing, while
    other will be incremented. This method automatically increments indices in
    the potential functions accordingly.

    For nonbonded terms, the recipe does a straight forward concatenation of the
    lambda idxs. If other has no Nonbonded term, self's Nonbonded term (if any)
    is kept unchanged instead of being dropped.

    Parameters
    ----------
    other: Recipe
        the right hand side recipe to combine with

    Returns
    -------
    Recipe
        combined recipe

    Raises
    ------
    ValueError
        If other has a Nonbonded term but self does not.
    Exception
        If other contains a potential of an unsupported type.
    """
    self_num_atoms = len(self.masses)
    combined_masses = np.concatenate([self.masses, other.masses])

    combined_bound_potentials = []

    # Hold self's Nonbonded term back so it can be merged with other's; all of
    # self's other terms are kept as-is. If self has several Nonbonded terms,
    # only the last one is used.
    self_nb_bp = None
    for bp in self.bound_potentials:
        if isinstance(bp, potentials.Nonbonded):
            self_nb_bp = bp
        else:
            combined_bound_potentials.append(bp)

    merged_nb = False
    for full_obp in other.bound_potentials:
        # always deepcopy to prevent modifying original copy
        full_obp = copy.deepcopy(full_obp)

        # if this is a lambda potential we replace only the .u_fn part
        if isinstance(full_obp, potentials.LambdaPotential):
            full_obp.set_N(full_obp.get_N() + self_num_atoms)
            obp = full_obp.get_u_fn()
        else:
            obp = full_obp

        if isinstance(obp, (potentials.HarmonicBond, potentials.CoreRestraint,
                            potentials.HarmonicAngle, potentials.PeriodicTorsion)):
            idxs = obp.get_idxs()
            idxs += self_num_atoms  # modify inplace (on the deepcopy)
        elif isinstance(obp, potentials.CentroidRestraint):
            a_idxs = obp.get_a_idxs()
            a_idxs += self_num_atoms  # modify inplace
            b_idxs = obp.get_b_idxs()
            b_idxs += self_num_atoms  # modify inplace
            # prepend zero masses for self's atoms
            obp.set_masses(np.concatenate([np.zeros(self_num_atoms), obp.get_masses()]))
        elif isinstance(obp, potentials.Nonbonded):
            if self_nb_bp is None:
                raise ValueError(
                    "cannot combine: other has a Nonbonded term but self does not")
            assert self_nb_bp.get_cutoff() == obp.get_cutoff()
            assert self_nb_bp.get_beta() == obp.get_beta()

            combined_nb_params = jnp.concatenate([self_nb_bp.params, obp.params])
            combined_exclusion_idxs = np.concatenate([
                self_nb_bp.get_exclusion_idxs(),
                obp.get_exclusion_idxs() + self_num_atoms,
            ])
            combined_scale_factors = np.concatenate(
                [self_nb_bp.get_scale_factors(), obp.get_scale_factors()])
            combined_lambda_offset_idxs = np.concatenate(
                [self_nb_bp.get_lambda_offset_idxs(), obp.get_lambda_offset_idxs()])
            combined_lambda_plane_idxs = np.concatenate(
                [self_nb_bp.get_lambda_plane_idxs(), obp.get_lambda_plane_idxs()])

            obp = potentials.Nonbonded(
                combined_exclusion_idxs,
                combined_scale_factors,
                combined_lambda_plane_idxs,
                combined_lambda_offset_idxs,
                self_nb_bp.get_beta(),
                self_nb_bp.get_cutoff(),
            ).bind(combined_nb_params)
            merged_nb = True
        else:
            raise Exception("Unknown functional form")

        # NOTE(review): for a LambdaPotential the merged Nonbonded built above
        # is not written back into full_obp (no setter is used here), so a
        # LambdaPotential-wrapped Nonbonded keeps other's unmerged u_fn.
        if isinstance(full_obp, potentials.LambdaPotential):
            combined_bound_potentials.append(full_obp)
        else:
            combined_bound_potentials.append(obp)

    if self_nb_bp is not None and not merged_nb:
        # other has no Nonbonded term: keep self's rather than silently
        # dropping it. NOTE(review): its parameters cover only self's atoms.
        combined_bound_potentials.append(self_nb_bp)

    return Recipe(combined_masses, combined_bound_potentials)
def _deserialize_harmonic_bond_force(force):
    """Convert an mm.HarmonicBondForce into a bound potentials.HarmonicBond."""
    bond_idxs = []
    bond_params = []
    for b_idx in range(force.getNumBonds()):
        src_idx, dst_idx, length, k = force.getBondParameters(b_idx)
        bond_idxs.append([src_idx, dst_idx])
        bond_params.append((value(k), value(length)))
    # reshape keeps the 2-D shape even for a force with no bonds
    bond_idxs = np.array(bond_idxs, dtype=np.int32).reshape(-1, 2)
    bond_params = np.array(bond_params, dtype=np.float64).reshape(-1, 2)
    return potentials.HarmonicBond(bond_idxs).bind(bond_params)


def _deserialize_harmonic_angle_force(force):
    """Convert an mm.HarmonicAngleForce into a bound potentials.HarmonicAngle."""
    angle_idxs = []
    angle_params = []
    for a_idx in range(force.getNumAngles()):
        src_idx, mid_idx, dst_idx, angle, k = force.getAngleParameters(a_idx)
        angle_idxs.append([src_idx, mid_idx, dst_idx])
        angle_params.append((value(k), value(angle)))
    angle_idxs = np.array(angle_idxs, dtype=np.int32).reshape(-1, 3)
    angle_params = np.array(angle_params, dtype=np.float64).reshape(-1, 2)
    return potentials.HarmonicAngle(angle_idxs).bind(angle_params)


def _deserialize_periodic_torsion_force(force):
    """Convert an mm.PeriodicTorsionForce into a bound potentials.PeriodicTorsion."""
    torsion_idxs = []
    torsion_params = []
    for t_idx in range(force.getNumTorsions()):
        a_idx, b_idx, c_idx, d_idx, period, phase, k = force.getTorsionParameters(t_idx)
        torsion_idxs.append([a_idx, b_idx, c_idx, d_idx])
        torsion_params.append((value(k), value(phase), period))
    torsion_idxs = np.array(torsion_idxs, dtype=np.int32).reshape(-1, 4)
    torsion_params = np.array(torsion_params, dtype=np.float64).reshape(-1, 3)
    return potentials.PeriodicTorsion(torsion_idxs).bind(torsion_params)


def _deserialize_nonbonded_force(force, N, cutoff):
    """Convert an mm.NonbondedForce into a bound potentials.Nonbonded.

    Exceptions become exclusions whose scale factor is the fraction of the LJ
    interaction removed, 1 - new_eps / expected_eps. The same factor is used
    for both the electrostatic and the LJ column; the exception's charge
    product is currently ignored.
    """
    charge_params = []
    lj_params = []
    for a_idx in range(force.getNumParticles()):
        charge, sig, eps = force.getParticleParameters(a_idx)
        # zero eps values are kept as-is (no regularizing epsilon is added)
        charge_params.append(value(charge) * np.sqrt(constants.ONE_4PI_EPS0))
        lj_params.append((value(sig), value(eps)))
    charge_params = np.array(charge_params, dtype=np.float64)
    lj_params = np.array(lj_params, dtype=np.float64).reshape(-1, 2)

    all_sig = lj_params[:, 0]
    all_eps = lj_params[:, 1]

    # validate exclusions/exceptions to make sure they make sense
    exclusion_idxs = []
    scale_factors = []
    for e_idx in range(force.getNumExceptions()):
        src, dst, _new_cp, new_sig, new_eps = force.getExceptionParameters(e_idx)
        new_sig = value(new_sig)
        new_eps = value(new_eps)

        # Lorentz-Berthelot combining rules
        expected_sig = (all_sig[src] + all_sig[dst]) / 2
        expected_eps = np.sqrt(all_eps[src] * all_eps[dst])

        exclusion_idxs.append([src, dst])

        # the lj_scale factor measures how much we *remove*
        if expected_eps == 0:
            if new_eps == 0:
                lj_scale_factor = 1
            else:
                raise RuntimeError("Divide by zero in epsilon calculation")
        else:
            lj_scale_factor = 1 - new_eps / expected_eps
        scale_factors.append(lj_scale_factor)

        # sigma only matters when the exception keeps some LJ interaction
        if new_eps != 0:
            np.testing.assert_almost_equal(expected_sig, new_sig)

    exclusion_idxs = np.array(exclusion_idxs, dtype=np.int32).reshape(-1, 2)
    # use the same scale factors for electrostatics and lj
    scale_factors = np.stack([scale_factors, scale_factors], axis=1)

    lambda_plane_idxs = np.zeros(N, dtype=np.int32)
    lambda_offset_idxs = np.zeros(N, dtype=np.int32)

    nb_params = np.concatenate([np.expand_dims(charge_params, axis=1), lj_params], axis=1)
    # optimizations: store sigma / 2 and sqrt(eps)
    nb_params[:, 1] = nb_params[:, 1] / 2
    nb_params[:, 2] = np.sqrt(nb_params[:, 2])

    beta = 2.0  # erfc correction

    return potentials.Nonbonded(
        exclusion_idxs,
        scale_factors,
        lambda_plane_idxs,
        lambda_offset_idxs,
        beta,
        cutoff,
    ).bind(nb_params)


def deserialize_system(system, cutoff):
    """
    Deserialize an OpenMM System into bound potentials and particle masses.

    Supported forces are HarmonicBondForce, HarmonicAngleForce,
    PeriodicTorsionForce and NonbondedForce; any other force is ignored.

    Parameters
    ----------
    system: openmm.System
        A system object to be deserialized

    cutoff: float
        Nonbonded cutoff passed to each resulting potentials.Nonbonded

    Returns
    -------
    list of lib.Potential, masses
        One bound potential per supported force, in force order, and a list
        with one mass per particle.

    Raises
    ------
    RuntimeError
        If a nonbonded exception has nonzero epsilon between two particles
        whose combined epsilon is zero.

    Note: zero epsilon values are passed through unchanged; no epsilon is added
    to avoid singularities in the Lennard-Jones derivatives.
    """
    masses = [value(system.getParticleMass(p)) for p in range(system.getNumParticles())]
    N = len(masses)

    # this should not be a dict since we may have more than one instance of a
    # given force.
    bps = []
    for force in system.getForces():
        if isinstance(force, mm.HarmonicBondForce):
            bps.append(_deserialize_harmonic_bond_force(force))
        elif isinstance(force, mm.HarmonicAngleForce):
            bps.append(_deserialize_harmonic_angle_force(force))
        elif isinstance(force, mm.PeriodicTorsionForce):
            bps.append(_deserialize_periodic_torsion_force(force))
        elif isinstance(force, mm.NonbondedForce):
            bps.append(_deserialize_nonbonded_force(force, N, cutoff))

    return bps, masses