def pose_dock(
    guests_sdfile,
    host_pdbfile,
    transition_type,
    n_steps,
    transition_steps,
    max_lambda,
    outdir,
    random_rotation=False,
    constant_atoms=(),
):
    """Runs short simulations in which the guests phase in or out over time

    Parameters
    ----------
    guests_sdfile: path to input sdf with guests to pose/dock
    host_pdbfile: path to host pdb file to dock into
    transition_type: "insertion" or "deletion"
    n_steps: how many total steps of simulation to do (recommended: <= 1000)
    transition_steps: how many steps to insert/delete the guest over
        (recommended: <= 500) (must be <= n_steps)
    max_lambda: lambda value the guest should insert from or delete to
        (recommended: 1.0 for work calculation, 0.25 to stay close to original
        pose) (must be =1 for work calculation to be applicable)
    outdir: where to write output (will be created if it does not already exist)
    random_rotation: whether to apply a random rotation to each guest before
        inserting (only allowed when transition_type is "insertion")
    constant_atoms: atom numbers from the host_pdbfile to hold mostly fixed
        across the simulation (1-indexed, like PDB files). Each listed atom
        has 50000 added to its mass.

    Output
    ------
    A pdb & sdf file every 100 steps (outdir/<guest_name>_<step>.pdb)
    stdout every 100 steps noting the step number, lambda value, and energy
    stdout for each guest noting the work of transition
    stdout for each guest noting how long it took to run

    Raises
    ------
    AssertionError
        If transition_steps > n_steps, if transition_type is not one of
        "insertion"/"deletion", or if random_rotation is requested for a
        deletion. These checks run before any file is read or written.

    Note
    ----
    If any norm of force per atom exceeds 20000 kJ/(mol*nm) [MAX_NORM_FORCE
    defined in docking/report.py], the simulation for that guest will stop and
    the work will not be calculated.
    Records in guests_sdfile that the SD supplier cannot parse (it yields None
    for them) are skipped with a message on stdout.
    """
    assert transition_steps <= n_steps
    assert transition_type in ("insertion", "deletion")
    if random_rotation:
        assert transition_type == "insertion"

    os.makedirs(outdir, exist_ok=True)

    host_mol = Chem.MolFromPDBFile(host_pdbfile, removeHs=False)
    amber_ff = app.ForceField("amber99sbildn.xml", "tip3p.xml")
    host_file = PDBFile(host_pdbfile)
    host_system = amber_ff.createSystem(
        host_file.topology,
        nonbondedMethod=app.NoCutoff,
        constraints=None,
        rigidWater=False,
    )
    host_conf = np.array(
        [
            [to_md_units(x), to_md_units(y), to_md_units(z)]
            for x, y, z in host_file.positions
        ]
    )
    num_host_atoms = len(host_conf)

    final_potentials = []
    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        host_system, cutoff=1.2
    )
    host_nb_bp = None
    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert host_nb_bp is None
            host_nb_bp = bp
        else:
            final_potentials.append(bp)

    # TODO (ytz): we should really fix this later on. This padding was done to
    # address the particles that are too close to the boundary.
    padding = 0.1
    box_lengths = np.amax(host_conf, axis=0) - np.amin(host_conf, axis=0)
    box_lengths = box_lengths + padding
    box = np.eye(3, dtype=np.float64) * box_lengths

    # Read the guest force field text once and close the file promptly; the
    # handlers themselves are still rebuilt per guest below.
    ff_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "ff/params/smirnoff_1_1_0_ccc.py",
    )
    with open(ff_path) as ff_file:
        ff_text = ff_file.read()

    if random_rotation:
        from scipy.stats import special_ortho_group

    suppl = Chem.SDMolSupplier(guests_sdfile, removeHs=False)
    for mol_idx, guest_mol in enumerate(suppl):
        if guest_mol is None:
            # The supplier yields None for records it could not parse.
            print(f"Skipping unreadable molecule #{mol_idx} in {guests_sdfile}")
            continue

        start_time = time.time()
        guest_name = guest_mol.GetProp("_Name")
        ff = Forcefield(deserialize_handlers(ff_text))
        guest_base_topology = topology.BaseTopology(guest_mol, ff)

        # combine
        hgt = topology.HostGuestTopology(host_nb_bp, guest_base_topology)
        # setup the parameter handlers for the ligand
        bonded_tuples = [
            [hgt.parameterize_harmonic_bond, ff.hb_handle],
            [hgt.parameterize_harmonic_angle, ff.ha_handle],
            [hgt.parameterize_proper_torsion, ff.pt_handle],
            [hgt.parameterize_improper_torsion, ff.it_handle],
        ]
        bps = list(final_potentials)
        # instantiate the vjps while parameterizing (forward pass)
        for fn, handle in bonded_tuples:
            params, potential = fn(handle.params)
            bps.append(potential.bind(params))
        nb_params, nb_potential = hgt.parameterize_nonbonded(
            ff.q_handle.params, ff.lj_handle.params
        )
        bps.append(nb_potential.bind(nb_params))

        guest_masses = [a.GetMass() for a in guest_mol.GetAtoms()]
        masses = np.concatenate([host_masses, guest_masses])

        # Make the requested atoms very heavy so they stay mostly in place.
        for atom_num in constant_atoms:
            masses[atom_num - 1] += 50000

        conformer = guest_mol.GetConformer(0)
        mol_conf = np.array(conformer.GetPositions(), dtype=np.float64)
        mol_conf = mol_conf / 10  # convert to md_units

        if random_rotation:
            # Rotate about the guest's centroid so its position is unchanged.
            center = np.mean(mol_conf, axis=0)
            mol_conf -= center
            mol_conf = np.matmul(mol_conf, special_ortho_group.rvs(3))
            mol_conf += center

        x0 = np.concatenate([host_conf, mol_conf])  # combined geometry
        v0 = np.zeros_like(x0)

        seed = 2021
        intg = LangevinIntegrator(300, 1.5e-3, 1.0, masses, seed).impl()

        precision = np.float32
        impls = [b.bound_impl(precision) for b in bps]

        ctxt = custom_ops.Context(x0, v0, box, intg, impls)

        # collect a du_dl calculation once every other step
        subsample_freq = 2
        du_dl_obs = custom_ops.FullPartialUPartialLambda(impls, subsample_freq)
        ctxt.add_observable(du_dl_obs)

        if transition_type == "insertion":
            new_lambda_schedule = np.concatenate(
                [
                    np.linspace(max_lambda, 0.0, transition_steps),
                    np.zeros(n_steps - transition_steps),
                ]
            )
        elif transition_type == "deletion":
            new_lambda_schedule = np.concatenate(
                [
                    np.linspace(0.0, max_lambda, transition_steps),
                    np.ones(n_steps - transition_steps) * max_lambda,
                ]
            )
        else:
            # Only reachable when asserts are disabled (python -O).
            raise RuntimeError(
                'invalid `transition_type` (must be one of ["insertion", "deletion"])'
            )

        calc_work = True
        force_check_steps = (0, int(n_steps / 2), n_steps - 1)
        for step, lamb in enumerate(new_lambda_schedule):
            ctxt.step(lamb)
            if step % 100 == 0:
                report.report_step(
                    ctxt, step, lamb, box, bps, impls, guest_name, n_steps, "pose_dock"
                )
                coords = ctxt.get_x_t()
                # multiply by 10 to undo the /10 conversion to md_units
                host_coords = coords[:num_host_atoms] * 10
                guest_coords = coords[num_host_atoms:] * 10
                report.write_frame(
                    host_coords,
                    host_mol,
                    guest_coords,
                    guest_mol,
                    guest_name,
                    outdir,
                    step,
                    "pd",
                )
            if step in force_check_steps and report.too_much_force(
                ctxt, lamb, box, bps, impls
            ):
                calc_work = False
                break

        # Note: this condition only applies for ABFE, not RBFE
        du_dls = du_dl_obs.full_du_dl()
        if abs(du_dls[0]) > 0.001 or abs(du_dls[-1]) > 0.001:
            print("Error: du_dl endpoints are not ~0")
            calc_work = False

        if calc_work:
            # du_dl was sampled every `subsample_freq` steps, so integrate it
            # against the matching subsample of the lambda schedule.
            work = np.trapz(du_dls, new_lambda_schedule[::subsample_freq])
            print(f"guest_name: {guest_name}\twork: {work:.2f}")
        end_time = time.time()
        print(f"{guest_name} took {(end_time - start_time):.2f} seconds")
def dock_and_equilibrate(
    host_pdbfile,
    guests_sdfile,
    max_lambda,
    insertion_steps,
    eq_steps,
    outdir,
    fewer_outfiles=False,
    constant_atoms=(),
):
    """Solvates a host, inserts guest(s) into solvated host, equilibrates

    Parameters
    ----------
    host_pdbfile: path to host pdb file to dock into
    guests_sdfile: path to input sdf with guests to pose/dock
    max_lambda: lambda value the guest should insert from or delete to
        (recommended: 1.0 for work calculation, 0.25 to stay close to original
        pose) (must be =1 for work calculation to be applicable)
    insertion_steps: how many steps to insert the guest over (recommended: 501)
    eq_steps: how many steps of equilibration to do after insertion
        (recommended: 15001)
    outdir: where to write output (will be created if it does not already exist)
    fewer_outfiles: if True, will only write frames for the equilibration, not
        insertion
    constant_atoms: atom numbers from the host_pdbfile to hold mostly fixed
        across the simulation (1-indexed, like PDB files). Each listed atom
        has 50000 added to its mass.

    Output
    ------
    A pdb & sdf file every 100 steps of insertion
        (outdir/<guest_name>/<guest_name>_<step>.[pdb/sdf])
    A pdb & sdf file every 1000 steps of equilibration
        (outdir/<guest_name>/<guest_name>_<step>.[pdb/sdf])
    stdout every 100(0) steps noting the step number, lambda value, and energy
    stdout for each guest noting the work of transition
    stdout for each guest noting how long it took to run

    Note
    ----
    If any norm of force per atom exceeds 20000 kJ/(mol*nm) [MAX_NORM_FORCE
    defined in docking/report.py], the current phase for that guest stops
    early. When this happens during insertion the work is not calculated,
    but the equilibration phase is still attempted afterwards.
    Records in guests_sdfile that the SD supplier cannot parse (it yields None
    for them) are skipped with a message on stdout.
    """
    os.makedirs(outdir, exist_ok=True)

    print(
        f"""
    HOST_PDBFILE = {host_pdbfile}
    GUESTS_SDFILE = {guests_sdfile}
    OUTDIR = {outdir}
    MAX_LAMBDA = {max_lambda}
    INSERTION_STEPS = {insertion_steps}
    EQ_STEPS = {eq_steps}
    """
    )

    # Prepare host
    # TODO: handle extra (non-transitioning) guests?
    print("Solvating host...")
    # TODO: return topology from builders.build_protein_system
    (
        solvated_host_system,
        solvated_host_coords,
        _,
        _,
        host_box,
        solvated_topology,
    ) = builders.build_protein_system(host_pdbfile)

    # sometimes water boxes are sad. Should be minimized first; this is a workaround
    host_box += np.eye(3) * 0.1
    print("host box", host_box)

    # Round-trip the solvated host through a temporary PDB file to obtain a
    # molecule object for frame writing; the file is removed right after.
    solvated_host_pdb = os.path.join(outdir, "solvated_host.pdb")
    writer = pdb_writer.PDBWriter([solvated_topology], solvated_host_pdb)
    writer.write_frame(solvated_host_coords)
    writer.close()
    solvated_host_mol = Chem.MolFromPDBFile(solvated_host_pdb, removeHs=False)
    os.remove(solvated_host_pdb)
    num_host_atoms = len(solvated_host_coords)

    final_host_potentials = []
    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        solvated_host_system, cutoff=1.2
    )
    host_nb_bp = None
    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert host_nb_bp is None
            host_nb_bp = bp
        else:
            final_host_potentials.append(bp)

    # Read the guest force field text once and close the file promptly; the
    # handlers themselves are still rebuilt per guest below.
    ff_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "ff/params/smirnoff_1_1_0_ccc.py",
    )
    with open(ff_path) as ff_file:
        ff_text = ff_file.read()

    # Run the procedure
    print("Getting guests...")
    suppl = Chem.SDMolSupplier(guests_sdfile, removeHs=False)
    for mol_idx, guest_mol in enumerate(suppl):
        if guest_mol is None:
            # The supplier yields None for records it could not parse.
            print(f"Skipping unreadable molecule #{mol_idx} in {guests_sdfile}")
            continue

        start_time = time.time()
        guest_name = guest_mol.GetProp("_Name")
        guest_conformer = guest_mol.GetConformer(0)
        orig_guest_coords = np.array(guest_conformer.GetPositions(), dtype=np.float64)
        orig_guest_coords = orig_guest_coords / 10  # convert to md_units

        ff = Forcefield(deserialize_handlers(ff_text))
        guest_base_top = topology.BaseTopology(guest_mol, ff)

        # combine host & guest
        hgt = topology.HostGuestTopology(host_nb_bp, guest_base_top)
        # setup the parameter handlers for the ligand
        bonded_tuples = [
            [hgt.parameterize_harmonic_bond, ff.hb_handle],
            [hgt.parameterize_harmonic_angle, ff.ha_handle],
            [hgt.parameterize_proper_torsion, ff.pt_handle],
            [hgt.parameterize_improper_torsion, ff.it_handle],
        ]
        combined_bps = list(final_host_potentials)
        # instantiate the vjps while parameterizing (forward pass)
        for fn, handle in bonded_tuples:
            params, potential = fn(handle.params)
            combined_bps.append(potential.bind(params))
        nb_params, nb_potential = hgt.parameterize_nonbonded(
            ff.q_handle.params, ff.lj_handle.params
        )
        combined_bps.append(nb_potential.bind(nb_params))

        guest_masses = [a.GetMass() for a in guest_mol.GetAtoms()]
        combined_masses = np.concatenate([host_masses, guest_masses])

        x0 = np.concatenate([solvated_host_coords, orig_guest_coords])
        v0 = np.zeros_like(x0)
        print(
            "SYSTEM",
            f"guest_name: {guest_name}",
            f"num_atoms: {len(x0)}",
        )

        # Make the requested atoms very heavy so they stay mostly in place.
        for atom_num in constant_atoms:
            combined_masses[atom_num - 1] += 50000

        seed = 2021
        intg = LangevinIntegrator(300.0, 1.5e-3, 1.0, combined_masses, seed).impl()

        u_impls = [bp.bound_impl(precision=np.float32) for bp in combined_bps]

        ctxt = custom_ops.Context(x0, v0, host_box, intg, u_impls)

        # collect a du_dl calculation once every other step
        subsample_freq = 2
        du_dl_obs = custom_ops.FullPartialUPartialLambda(u_impls, subsample_freq)
        ctxt.add_observable(du_dl_obs)

        # insert guest
        insertion_lambda_schedule = np.linspace(max_lambda, 0.0, insertion_steps)
        insertion_force_checks = (0, int(insertion_steps / 2), insertion_steps - 1)
        insertion_label_width = len(str(insertion_steps))

        calc_work = True
        for step, lamb in enumerate(insertion_lambda_schedule):
            ctxt.step(lamb)
            if step % 100 == 0:
                report.report_step(
                    ctxt,
                    step,
                    lamb,
                    host_box,
                    combined_bps,
                    u_impls,
                    guest_name,
                    insertion_steps,
                    "INSERTION",
                )
                if not fewer_outfiles:
                    coords = ctxt.get_x_t()
                    # multiply by 10 to undo the /10 conversion to md_units
                    host_coords = coords[:num_host_atoms] * 10
                    guest_coords = coords[num_host_atoms:] * 10
                    report.write_frame(
                        host_coords,
                        solvated_host_mol,
                        guest_coords,
                        guest_mol,
                        guest_name,
                        outdir,
                        str(step).zfill(insertion_label_width),
                        "ins",
                    )
            if step in insertion_force_checks and report.too_much_force(
                ctxt, lamb, host_box, combined_bps, u_impls
            ):
                calc_work = False
                break

        # Note: this condition only applies for ABFE, not RBFE
        du_dls = du_dl_obs.full_du_dl()
        if abs(du_dls[0]) > 0.001 or abs(du_dls[-1]) > 0.001:
            print("Error: du_dl endpoints are not ~0")
            calc_work = False

        if calc_work:
            # du_dl was sampled every `subsample_freq` steps, so integrate it
            # against the matching subsample of the lambda schedule.
            work = np.trapz(du_dls, insertion_lambda_schedule[::subsample_freq])
            print(f"guest_name: {guest_name}\tinsertion_work: {work:.2f}")

        # equilibrate
        eq_force_checks = (0, int(eq_steps / 2), eq_steps - 1)
        eq_label_width = len(str(eq_steps))
        for step in range(eq_steps):
            ctxt.step(0.00)
            if step % 1000 == 0:
                report.report_step(
                    ctxt,
                    step,
                    0.00,
                    host_box,
                    combined_bps,
                    u_impls,
                    guest_name,
                    eq_steps,
                    "EQUILIBRATION",
                )
                coords = ctxt.get_x_t()
                host_coords = coords[:num_host_atoms] * 10
                guest_coords = coords[num_host_atoms:] * 10
                report.write_frame(
                    host_coords,
                    solvated_host_mol,
                    guest_coords,
                    guest_mol,
                    guest_name,
                    outdir,
                    str(step).zfill(eq_label_width),
                    "eq",
                )
            if step in eq_force_checks and report.too_much_force(
                ctxt, 0.00, host_box, combined_bps, u_impls
            ):
                break

        end_time = time.time()
        print(f"{guest_name} took {(end_time - start_time):.2f} seconds")
num_host_atoms = host_coords.shape[0] final_potentials = [] final_vjp_and_handles = [] # keep the bonded terms in the host the same. # but we keep the nonbonded term for a subsequent modification for bp in host_bps: if isinstance(bp, potentials.Nonbonded): host_p = bp else: final_potentials.append(bp) final_vjp_and_handles.append(None) gdt = topology.DualTopology(romol_a, romol_b, ff) hgt = topology.HostGuestTopology(host_p, gdt) # setup the parameter handlers for the ligand tuples = [ [hgt.parameterize_harmonic_bond, [ff.hb_handle]], [hgt.parameterize_harmonic_angle, [ff.ha_handle]], [hgt.parameterize_proper_torsion, [ff.pt_handle]], [hgt.parameterize_improper_torsion, [ff.it_handle]], [hgt.parameterize_nonbonded, [ff.q_handle, ff.lj_handle]], ] # instantiate the vjps while parameterizing (forward pass) for fn, handles in tuples: params, vjp_fn, potential = jax.vjp(fn, *[h.params for h in handles], has_aux=True)
def calculate_rigorous_work(
    host_pdbfile, guests_sdfile, outdir, fewer_outfiles=False, no_outfiles=False
):
    """Runs a solvated-host leg and a water-only leg for every guest

    The host is solvated, a separate water box is built whose width is the
    smallest dimension of the solvated host box, and then, for each guest in
    guests_sdfile, `run_leg` is called once with the solvated host and once
    with the water box.

    Parameters
    ----------
    host_pdbfile: path to host pdb file
    guests_sdfile: path to input sdf with the guests to process
    outdir: where to write output (will be created if it does not already exist)
    fewer_outfiles: passed through unchanged to `run_leg`
    no_outfiles: passed through unchanged to `run_leg`; additionally, if True,
        the intermediate solvated_host.pdb and water_box.pdb files written to
        outdir are deleted once they have been read back

    Output
    ------
    stdout listing the input arguments and the module-level schedule constants
    stdout for each guest noting how long the host leg and the water leg took
    outdir/solvated_host.pdb and outdir/water_box.pdb (unless no_outfiles)
    whatever `run_leg` itself writes or prints

    Note
    ----
    Records in guests_sdfile that the SD supplier cannot parse (it yields None
    for them) are skipped with a message on stdout.
    """
    os.makedirs(outdir, exist_ok=True)

    print(
        f"""
    HOST_PDBFILE = {host_pdbfile}
    GUESTS_SDFILE = {guests_sdfile}
    OUTDIR = {outdir}
    INSERTION_MAX_LAMBDA = {INSERTION_MAX_LAMBDA}
    DELETION_MAX_LAMBDA = {DELETION_MAX_LAMBDA}
    MIN_LAMBDA = {MIN_LAMBDA}
    TRANSITION_STEPS = {TRANSITION_STEPS}
    EQ1_STEPS = {EQ1_STEPS}
    EQ2_STEPS = {EQ2_STEPS}
    """
    )

    # Prepare host
    # TODO: handle extra (non-transitioning) guests?
    print("Solvating host...")
    (
        solvated_host_system,
        solvated_host_coords,
        _,
        _,
        host_box,
        solvated_topology,
    ) = builders.build_protein_system(host_pdbfile)

    # sometimes water boxes are sad. Should be minimized first; this is a workaround
    host_box += np.eye(3) * 0.1
    print("host box", host_box)

    # Round-trip through a PDB file to obtain a molecule object for the host.
    solvated_host_pdb = os.path.join(outdir, "solvated_host.pdb")
    writer = pdb_writer.PDBWriter([solvated_topology], solvated_host_pdb)
    writer.write_frame(solvated_host_coords)
    writer.close()
    solvated_host_mol = Chem.MolFromPDBFile(solvated_host_pdb, removeHs=False)
    if no_outfiles:
        os.remove(solvated_host_pdb)

    final_host_potentials = []
    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        solvated_host_system, cutoff=1.2
    )
    host_nb_bp = None
    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert host_nb_bp is None
            host_nb_bp = bp
        else:
            final_host_potentials.append(bp)

    # Prepare water box
    print("Generating water box...")
    # TODO: water box probably doesn't need to be this big
    box_lengths = host_box[np.diag_indices(3)]
    water_box_width = min(box_lengths)
    (
        water_system,
        orig_water_coords,
        water_box,
        water_topology,
    ) = builders.build_water_system(water_box_width)

    # sometimes water boxes are sad. should be minimized first; this is a workaround
    water_box += np.eye(3) * 0.1
    print("water box", water_box)

    # it's okay if the water box here and the solvated protein box don't align -- they have PBCs
    water_pdb = os.path.join(outdir, "water_box.pdb")
    writer = pdb_writer.PDBWriter([water_topology], water_pdb)
    writer.write_frame(orig_water_coords)
    writer.close()
    water_mol = Chem.MolFromPDBFile(water_pdb, removeHs=False)
    if no_outfiles:
        os.remove(water_pdb)

    final_water_potentials = []
    water_potentials, water_masses = openmm_deserializer.deserialize_system(
        water_system, cutoff=1.2
    )
    water_nb_bp = None
    for bp in water_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert water_nb_bp is None
            water_nb_bp = bp
        else:
            final_water_potentials.append(bp)

    # Read the guest force field text once and close the file promptly; the
    # handlers themselves are still rebuilt per guest below.
    ff_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "ff/params/smirnoff_1_1_0_ccc.py",
    )
    with open(ff_path) as ff_file:
        ff_text = ff_file.read()

    # Run the procedure
    print("Getting guests...")
    suppl = Chem.SDMolSupplier(guests_sdfile, removeHs=False)
    for mol_idx, guest_mol in enumerate(suppl):
        if guest_mol is None:
            # The supplier yields None for records it could not parse.
            print(f"Skipping unreadable molecule #{mol_idx} in {guests_sdfile}")
            continue

        start_time = time.time()
        guest_name = guest_mol.GetProp("_Name")
        guest_conformer = guest_mol.GetConformer(0)
        orig_guest_coords = np.array(guest_conformer.GetPositions(), dtype=np.float64)
        orig_guest_coords = orig_guest_coords / 10  # convert to md_units

        ff = Forcefield(deserialize_handlers(ff_text))
        guest_base_top = topology.BaseTopology(guest_mol, ff)
        guest_masses = [a.GetMass() for a in guest_mol.GetAtoms()]

        # combine host & guest
        combined_bps = _bind_guest_into_environment(
            host_nb_bp, final_host_potentials, guest_base_top, ff
        )
        combined_masses = np.concatenate([host_masses, guest_masses])

        run_leg(
            solvated_host_coords,
            orig_guest_coords,
            combined_bps,
            combined_masses,
            host_box,
            guest_name,
            "host",
            solvated_host_mol,
            guest_mol,
            outdir,
            fewer_outfiles,
            no_outfiles,
        )
        end_time = time.time()
        # The host-leg time includes force field parsing and parameterization.
        print(
            f"{guest_name} host leg time:",
            f"{(end_time - start_time):.2f}",
            "seconds",
        )

        # combine water & guest
        combined_bps = _bind_guest_into_environment(
            water_nb_bp, final_water_potentials, guest_base_top, ff
        )
        combined_masses = np.concatenate([water_masses, guest_masses])

        start_time = time.time()
        run_leg(
            orig_water_coords,
            orig_guest_coords,
            combined_bps,
            combined_masses,
            water_box,
            guest_name,
            "water",
            water_mol,
            guest_mol,
            outdir,
            fewer_outfiles,
            no_outfiles,
        )
        end_time = time.time()
        print(
            f"{guest_name} water leg time:",
            f"{(end_time - start_time):.2f}",
            "seconds",
        )


def _bind_guest_into_environment(env_nb_bp, env_bonded_bps, guest_top, ff):
    """Return the environment's bonded potentials plus the bound guest terms.

    `env_nb_bp` is the environment's nonbonded potential and `env_bonded_bps`
    its remaining potentials (copied, not modified). The guest's bond, angle,
    proper/improper torsion and nonbonded terms are parameterized through a
    HostGuestTopology and appended in that order.
    """
    env_guest_top = topology.HostGuestTopology(env_nb_bp, guest_top)
    # setup the parameter handlers for the ligand
    bonded_tuples = [
        [env_guest_top.parameterize_harmonic_bond, ff.hb_handle],
        [env_guest_top.parameterize_harmonic_angle, ff.ha_handle],
        [env_guest_top.parameterize_proper_torsion, ff.pt_handle],
        [env_guest_top.parameterize_improper_torsion, ff.it_handle],
    ]
    bound_potentials = list(env_bonded_bps)
    # instantiate the vjps while parameterizing (forward pass)
    for fn, handle in bonded_tuples:
        params, potential = fn(handle.params)
        bound_potentials.append(potential.bind(params))
    nb_params, nb_potential = env_guest_top.parameterize_nonbonded(
        ff.q_handle.params, ff.lj_handle.params
    )
    bound_potentials.append(nb_potential.bind(nb_params))
    return bound_potentials
def minimize_host_4d(romol, host_system, host_coords, ff, box):
    """
    Relax a host around a ligand that is inserted via 4D decoupling, using a
    Langevin thermostat. The ligand atoms are given masses 100000 times their
    real value so that they stay essentially fixed; only the host coordinates
    are returned.

    Parameters
    ----------
    romol: ROMol
        Ligand to be inserted. It must be embedded.

    host_system: openmm.System
        OpenMM System representing the host

    host_coords: np.ndarray
        N x 3 coordinates of the host. units of nanometers.

    ff: ff.Forcefield
        Wrapper class around a list of handlers

    box: np.ndarray [3,3]
        Box matrix for periodic boundary conditions. units of nanometers.

    Returns
    -------
    np.ndarray
        The host coordinates after 1000 steps in which lambda goes from 1.0
        down to 0.

    """
    host_bps, host_masses = openmm_deserializer.deserialize_system(host_system, cutoff=1.2)

    # keep the ligand rigid by making its atoms extremely heavy
    heavy_ligand_masses = [atom.GetMass() * 100000 for atom in romol.GetAtoms()]
    all_masses = np.concatenate([host_masses, heavy_ligand_masses])

    all_coords = np.concatenate([host_coords, get_romol_conf(romol)])
    n_host = host_coords.shape[0]

    # Host bonded terms are kept as they are; the host nonbonded term is set
    # aside so it can be merged with the ligand's through the topology.
    bound_potentials = []
    for host_bp in host_bps:
        if isinstance(host_bp, potentials.Nonbonded):
            host_p = host_bp
            continue
        bound_potentials.append(host_bp)

    ligand_top = topology.BaseTopology(romol, ff)
    hgt = topology.HostGuestTopology(host_p, ligand_top)

    # each parameterize function paired with the handles whose params it takes
    parameterizers = (
        (hgt.parameterize_harmonic_bond, (ff.hb_handle,)),
        (hgt.parameterize_harmonic_angle, (ff.ha_handle,)),
        (hgt.parameterize_proper_torsion, (ff.pt_handle,)),
        (hgt.parameterize_improper_torsion, (ff.it_handle,)),
        (hgt.parameterize_nonbonded, (ff.q_handle, ff.lj_handle)),
    )
    for parameterize, handles in parameterizers:
        handle_params = [handle.params for handle in handles]
        params, potential = parameterize(*handle_params)
        bound_potentials.append(potential.bind(params))

    seed = 2020
    integrator = LangevinIntegrator(300.0, 1.5e-3, 1.0, all_masses, seed).impl()

    start_velocities = np.zeros_like(all_coords)
    u_impls = [bp.bound_impl(precision=np.float32) for bp in bound_potentials]

    # context components: positions, velocities, box, integrator, energy fxns
    ctxt = custom_ops.Context(all_coords, start_velocities, box, integrator, u_impls)

    lambda_schedule = np.linspace(1.0, 0, 1000)
    for lamb in lambda_schedule:
        ctxt.step(lamb)

    final_coords = ctxt.get_x_t()
    return final_coords[:n_host]
def host_edge(self, lamb, host_system, host_coords, box, equil_steps=10000, prod_steps=100000):
    """
    Run equilibrium decoupling simulation at a given value of lambda in a host environment.

    Parameters
    ----------
    lamb: float [0, 1]
        0 is the fully interacting system, and 1 is the non-interacting system

    host_system: openmm.System
        OpenMM System object to be deserialized. The host can be simply a box of water, or a fully
        solvated protein

    host_coords: np.array of shape [..., 3]
        Host coordinates, in nanometers. It should be properly minimized and not have clashes
        with the ligand coordinates.

    box: np.array [3,3]
        Periodic boundary conditions, in nanometers.

    equil_steps: int
        Number of steps to run equilibration. Statistics are not gathered.

    prod_steps: int
        Number of steps to run production. Statistics are gathered.

    Returns
    -------
    bonded_us, nonbonded_us, grads_and_handles
        The first two values are passed through unchanged from
        `self._simulate` (the original note here describes them as the
        mean and std dev of du_dl for bonded and nonbonded terms).
        `grads_and_handles` is a list of `(gradient, handler_type)` pairs,
        one per ligand parameter handler, obtained by pulling the simulated
        derivatives back through the parameterization vjps.

    Raises
    ------
    ValueError
        If the deserialized host system contains no nonbonded potential.

    """
    ligand_masses_a = [a.GetMass() for a in self.mol_a.GetAtoms()]
    ligand_masses_b = [b.GetMass() for b in self.mol_b.GetAtoms()]

    # extract the 0th conformer
    ligand_coords_a = get_romol_conf(self.mol_a)
    ligand_coords_b = get_romol_conf(self.mol_b)

    host_bps, host_masses = openmm_deserializer.deserialize_system(host_system, cutoff=1.2)

    # final_potentials[i] and final_vjp_and_handles[i] describe the same term;
    # the two lists must stay aligned because they are zipped with the
    # simulated gradients below.
    final_potentials = []
    final_vjp_and_handles = []

    # keep the bonded terms in the host the same.
    # but we keep the nonbonded term for a subsequent modification
    host_p = None
    for bp in host_bps:
        if isinstance(bp, potentials.Nonbonded):
            host_p = bp
        else:
            final_potentials.append([bp])
            # (ytz): no protein ff support for now, so we skip their vjps
            final_vjp_and_handles.append(None)

    if host_p is None:
        raise ValueError("host_system does not contain a nonbonded potential")

    hgt = topology.HostGuestTopology(host_p, self.top)

    # setup the parameter handlers for the ligand
    bonded_tuples = [
        [hgt.parameterize_harmonic_bond, self.ff.hb_handle],
        [hgt.parameterize_harmonic_angle, self.ff.ha_handle],
        [hgt.parameterize_proper_torsion, self.ff.pt_handle],
        [hgt.parameterize_improper_torsion, self.ff.it_handle],
    ]

    # instantiate the vjps while parameterizing (forward pass)
    for fn, handle in bonded_tuples:
        (
            (src_params, dst_params, uni_params),
            vjp_fn,
            (src_potential, dst_potential, uni_potential),
        ) = jax.vjp(fn, handle.params, has_aux=True)
        final_potentials.append(
            [
                src_potential.bind(src_params),
                dst_potential.bind(dst_params),
                uni_potential.bind(uni_params),
            ]
        )
        final_vjp_and_handles.append((vjp_fn, handle))

    nb_params, vjp_fn, nb_potential = jax.vjp(
        hgt.parameterize_nonbonded,
        self.ff.q_handle.params,
        self.ff.lj_handle.params,
        has_aux=True,
    )
    final_potentials.append([nb_potential.bind(nb_params)])
    # (ytz): note the handlers are a tuple, this is checked later
    final_vjp_and_handles.append([vjp_fn, (self.ff.q_handle, self.ff.lj_handle)])

    # Masses and starting coordinates of the ligand part are the mean (over
    # axis 0) of what interpolate_params returns for the two ligands.
    combined_masses = np.concatenate(
        [
            host_masses,
            np.mean(self.top.interpolate_params(ligand_masses_a, ligand_masses_b), axis=0),
        ]
    )
    combined_coords = np.concatenate(
        [
            host_coords,
            np.mean(self.top.interpolate_params(ligand_coords_a, ligand_coords_b), axis=0),
        ]
    )

    # (ytz): us is short form for mean and std dev.
    bonded_us, nonbonded_us, grads = self._simulate(
        lamb,
        box,
        combined_coords,
        np.zeros_like(combined_coords),
        final_potentials,
        self._get_integrator(combined_masses),
        equil_steps,
        prod_steps,
    )

    grads_and_handles = []

    for du_dqs, vjps_and_handles in zip(grads, final_vjp_and_handles):
        if vjps_and_handles is None:
            # host terms carry no vjp, nothing to pull back
            continue

        vjp_fn = vjps_and_handles[0]
        handles = vjps_and_handles[1]

        # we need to get the shapes correct (eg. nonbonded vjp emits an ndarray, not a list.)

        # (ytz): so far nonbonded grads is the only term that map back out to two
        # vjp handlers (charge and lj). the vjp also expects an nd.array, not a list. So we kill
        # two birds with one stone here, but this is quite brittle and should be refactored later on.
        if isinstance(handles, tuple):
            # handle nonbonded terms
            du_dps = vjp_fn(du_dqs[0])
            for du_dp, handler in zip(du_dps, handles):
                grads_and_handles.append((du_dp, type(handler)))
        else:
            du_dp = vjp_fn(du_dqs)
            # bonded terms return a list, so we need to flatten it here
            grads_and_handles.append((du_dp[0], type(handles)))

    return bonded_us, nonbonded_us, grads_and_handles
def host_edge(self, lamb, host_system, host_coords, box, equil_steps=10000, prod_steps=100000):
    """
    Run equilibrium decoupling simulation at a given value of lambda in a host environment.

    Parameters
    ----------
    lamb: float [0, 1]
        0 is the fully interacting system, and 1 is the non-interacting system

    host_system: openmm.System
        OpenMM System object to be deserialized. The host can be simply a box of water, or a fully
        solvated protein

    host_coords: np.array of shape [..., 3]
        Host coordinates, in nanometers. It should be properly minimized and not have clashes
        with the ligand coordinates.

    box: np.array [3,3]
        Periodic boundary conditions, in nanometers.

    equil_steps: int
        Number of steps to run equilibration. Statistics are not gathered.

    prod_steps: int
        Number of steps to run production. Statistics are gathered.

    Returns
    -------
    Whatever `self._simulate` returns, unchanged. The original note here
    describes it as a pair of average du_dl values for bonded and nonbonded
    terms -- TODO confirm against `_simulate`.

    Raises
    ------
    ValueError
        If the deserialized host system contains no nonbonded potential.

    """
    ligand_masses = [a.GetMass() for a in self.mol.GetAtoms()]
    ligand_coords = get_romol_conf(self.mol)

    host_bps, host_masses = openmm_deserializer.deserialize_system(host_system, cutoff=1.2)

    # Host bonded terms are kept as they are; the host nonbonded term is set
    # aside so it can be merged with the ligand's through the topology.
    final_potentials = []
    host_p = None
    for bp in host_bps:
        if isinstance(bp, potentials.Nonbonded):
            host_p = bp
        else:
            final_potentials.append([bp])

    if host_p is None:
        raise ValueError("host_system does not contain a nonbonded potential")

    hgt = topology.HostGuestTopology(host_p, self.top)

    # setup the parameter handlers for the ligand
    bonded_tuples = [
        [hgt.parameterize_harmonic_bond, self.ff.hb_handle],
        [hgt.parameterize_harmonic_angle, self.ff.ha_handle],
        [hgt.parameterize_proper_torsion, self.ff.pt_handle],
        [hgt.parameterize_improper_torsion, self.ff.it_handle],
    ]

    # Parameterize (forward pass). The vjp functions are not needed by this
    # method, so only the parameters and potentials are kept.
    for fn, handle in bonded_tuples:
        params, _, potential = jax.vjp(fn, handle.params, has_aux=True)
        final_potentials.append([potential.bind(params)])

    nb_params, _, nb_potential = jax.vjp(
        hgt.parameterize_nonbonded,
        self.ff.q_handle.params,
        self.ff.lj_handle.params,
        has_aux=True,
    )
    final_potentials.append([nb_potential.bind(nb_params)])

    combined_masses = np.concatenate([host_masses, ligand_masses])
    combined_coords = np.concatenate([host_coords, ligand_coords])

    return self._simulate(
        lamb,
        box,
        combined_coords,
        np.zeros_like(combined_coords),
        final_potentials,
        self._get_integrator(combined_masses),
        equil_steps,
        prod_steps,
    )