Ejemplo n.º 1
0
def test_combine_recipe():
    """Check that combining two recipes concatenates nonbonded exclusions.

    A ligand recipe (aspirin, built from RDKit) and a protein recipe (built
    from the system returned by ``builders.build_protein_system``) are
    combined in both orders. For each order the test asserts that the
    combined exclusion indices equal the left recipe's indices followed by
    the right recipe's indices shifted by the number of left-hand atoms
    (``len(left_recipe.masses)``). It then constructs an implementation of
    every combined bound potential at float32 precision as a smoke check.
    """
    # Use a context manager so the force-field file handle is closed promptly.
    with open('ff/params/smirnoff_1_1_0_ccc.py') as ff_file:
        ff_handlers = deserialize_handlers(ff_file.read())

    aspirin = Chem.AddHs(Chem.MolFromSmiles("CC(=O)OC1=CC=CC=C1C(=O)O"))
    AllChem.EmbedMolecule(aspirin)
    ligand_recipe = md.Recipe.from_rdkit(aspirin, ff_handlers)

    fname = 'tests/data/hif2a_nowater_min.pdb'
    # Only the system (first element of the 6-tuple) is needed here.
    openmm_system, _, _, _, _, _ = builders.build_protein_system(fname)
    protein_recipe = md.Recipe.from_openmm(openmm_system)

    for left_recipe, right_recipe in [[protein_recipe, ligand_recipe],
                                      [ligand_recipe, protein_recipe]]:

        combined_recipe = left_recipe.combine(right_recipe)

        # The test treats the last bound potential of each recipe as the
        # nonbonded term.
        left_nonbonded_potential = left_recipe.bound_potentials[-1]
        right_nonbonded_potential = right_recipe.bound_potentials[-1]
        combined_nonbonded_potential = combined_recipe.bound_potentials[-1]

        left_idxs = left_nonbonded_potential.get_exclusion_idxs()
        right_idxs = right_nonbonded_potential.get_exclusion_idxs()
        combined_idxs = combined_nonbonded_potential.get_exclusion_idxs()

        n_left = len(left_recipe.masses)

        # Right-hand atom indices must be offset by the left-hand atom count.
        np.testing.assert_array_equal(
            np.concatenate([left_idxs, right_idxs + n_left]), combined_idxs)

        for bp in combined_recipe.bound_potentials:
            bp.bound_impl(precision=np.float32)
Ejemplo n.º 2
0
def dock_and_equilibrate(host_pdbfile,
                         guests_sdfile,
                         max_lambda,
                         insertion_steps,
                         eq_steps,
                         outdir,
                         fewer_outfiles=False,
                         constant_atoms=None):
    """Solvates a host, inserts guest(s) into solvated host, equilibrates

    Parameters
    ----------

    host_pdbfile: path to host pdb file to dock into
    guests_sdfile: path to input sdf with guests to pose/dock
    max_lambda: lambda value the guest should insert from or delete to
        (recommended: 1.0 for work calculation, 0.25 to stay close to original pose)
        (must be =1 for work calculation to be applicable)
    insertion_steps: how many steps to insert the guest over (recommended: 501)
    eq_steps: how many steps of equilibration to do after insertion (recommended: 15001)
    outdir: where to write output (will be created if it does not already exist)
    fewer_outfiles: if True, will only write frames for the equilibration, not insertion
    constant_atoms: atom numbers from the host_pdbfile to hold mostly fixed across the simulation
        (1-indexed, like PDB files); each listed atom has 50000 added to its mass.
        None (the default) or an empty sequence means no atom is modified.

    Output
    ------

    A pdb & sdf file every 100 steps of insertion (outdir/<guest_name>/<guest_name>_<step>.[pdb/sdf])
    A pdb & sdf file every 1000 steps of equilibration (outdir/<guest_name>/<guest_name>_<step>.[pdb/sdf])
    stdout every 100(0) steps noting the step number, lambda value, and energy
    stdout for each guest noting the work of transition
    stdout for each guest noting how long it took to run

    Note
    ----
    If any norm of force per atom exceeds 20000 kJ/(mol*nm) [MAX_NORM_FORCE defined in docking/report.py],
    the simulation for that guest will stop and the work will not be calculated.
    """

    # `None` replaces the former mutable `[]` default; both mean "no constant atoms".
    if constant_atoms is None:
        constant_atoms = []

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(outdir, exist_ok=True)

    print(f"""
    HOST_PDBFILE = {host_pdbfile}
    GUESTS_SDFILE = {guests_sdfile}
    OUTDIR = {outdir}
    MAX_LAMBDA = {max_lambda}
    INSERTION_STEPS = {insertion_steps}
    EQ_STEPS = {eq_steps}
    """)

    # Prepare host
    # TODO: handle extra (non-transitioning) guests?
    print("Solvating host...")
    # TODO: return topology from builders.build_protein_system
    (
        solvated_host_system,
        solvated_host_coords,
        _,
        _,
        host_box,
        solvated_topology,
    ) = builders.build_protein_system(host_pdbfile)

    # sometimes water boxes are sad. Should be minimized first; this is a workaround
    host_box += np.eye(3) * 0.1
    print("host box", host_box)

    # Round-trip the solvated host through a temporary PDB file to obtain an
    # RDKit molecule for frame writing; the file itself is not kept.
    solvated_host_pdb = os.path.join(outdir, "solvated_host.pdb")
    writer = pdb_writer.PDBWriter([solvated_topology], solvated_host_pdb)
    writer.write_frame(solvated_host_coords)
    writer.close()
    solvated_host_mol = Chem.MolFromPDBFile(solvated_host_pdb, removeHs=False)
    os.remove(solvated_host_pdb)

    # Split the host potentials into the single nonbonded term (combined with
    # the guest later) and everything else (reused unchanged for every guest).
    final_host_potentials = []
    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        solvated_host_system, cutoff=1.2)
    host_nb_bp = None
    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert host_nb_bp is None
            host_nb_bp = bp
        else:
            final_host_potentials.append(bp)

    # Loop-invariant values, computed once.
    n_host_atoms = len(solvated_host_coords)
    ff_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "ff/params/smirnoff_1_1_0_ccc.py",
    )

    # Run the procedure
    print("Getting guests...")
    suppl = Chem.SDMolSupplier(guests_sdfile, removeHs=False)
    for guest_mol in suppl:
        start_time = time.time()
        guest_name = guest_mol.GetProp("_Name")
        guest_conformer = guest_mol.GetConformer(0)
        orig_guest_coords = np.array(guest_conformer.GetPositions(),
                                     dtype=np.float64)
        orig_guest_coords = orig_guest_coords / 10  # convert to md_units

        # Fresh handlers are deserialized for every guest; the context manager
        # closes the file handle instead of leaking one per iteration.
        with open(ff_path) as ff_file:
            guest_ff_handlers = deserialize_handlers(ff_file.read())
        ff = Forcefield(guest_ff_handlers)
        guest_base_top = topology.BaseTopology(guest_mol, ff)

        # combine host & guest
        hgt = topology.HostGuestTopology(host_nb_bp, guest_base_top)
        # setup the parameter handlers for the ligand
        bonded_tuples = [[hgt.parameterize_harmonic_bond, ff.hb_handle],
                         [hgt.parameterize_harmonic_angle, ff.ha_handle],
                         [hgt.parameterize_proper_torsion, ff.pt_handle],
                         [hgt.parameterize_improper_torsion, ff.it_handle]]
        combined_bps = list(final_host_potentials)
        # instantiate the vjps while parameterizing (forward pass)
        for fn, handle in bonded_tuples:
            params, potential = fn(handle.params)
            combined_bps.append(potential.bind(params))
        nb_params, nb_potential = hgt.parameterize_nonbonded(
            ff.q_handle.params, ff.lj_handle.params)
        combined_bps.append(nb_potential.bind(nb_params))
        guest_masses = [a.GetMass() for a in guest_mol.GetAtoms()]
        combined_masses = np.concatenate([host_masses, guest_masses])

        # Host atoms come first, guest atoms last.
        x0 = np.concatenate([solvated_host_coords, orig_guest_coords])
        v0 = np.zeros_like(x0)
        print(
            "SYSTEM",
            f"guest_name: {guest_name}",
            f"num_atoms: {len(x0)}",
        )

        # constant_atoms is 1-indexed; a large added mass keeps those atoms
        # mostly fixed.
        for atom_num in constant_atoms:
            combined_masses[atom_num - 1] += 50000

        seed = 2021
        intg = LangevinIntegrator(300.0, 1.5e-3, 1.0, combined_masses,
                                  seed).impl()

        u_impls = [bp.bound_impl(precision=np.float32) for bp in combined_bps]

        ctxt = custom_ops.Context(x0, v0, host_box, intg, u_impls)

        # collect a du_dl calculation once every other step
        subsample_freq = 2
        du_dl_obs = custom_ops.FullPartialUPartialLambda(
            u_impls, subsample_freq)
        ctxt.add_observable(du_dl_obs)

        # insert guest
        insertion_lambda_schedule = np.linspace(max_lambda, 0.0,
                                                insertion_steps)
        calc_work = True
        for step, lamb in enumerate(insertion_lambda_schedule):
            ctxt.step(lamb)
            if step % 100 == 0:
                report.report_step(ctxt, step, lamb, host_box, combined_bps,
                                   u_impls, guest_name, insertion_steps,
                                   "INSERTION")
                if not fewer_outfiles:
                    # Fetch the coordinates once; "* 10" undoes the "/ 10"
                    # md-unit conversion applied to the input coordinates.
                    x_t = ctxt.get_x_t()
                    host_coords = x_t[:n_host_atoms] * 10
                    guest_coords = x_t[n_host_atoms:] * 10
                    report.write_frame(
                        host_coords,
                        solvated_host_mol,
                        guest_coords,
                        guest_mol,
                        guest_name,
                        outdir,
                        str(step).zfill(len(str(insertion_steps))),
                        "ins",
                    )
            # Force check only at the first, middle and last insertion step.
            if step in (0, int(insertion_steps / 2), insertion_steps - 1):
                if report.too_much_force(ctxt, lamb, host_box, combined_bps,
                                         u_impls):
                    calc_work = False
                    break

        # Note: this condition only applies for ABFE, not RBFE
        full_du_dl = du_dl_obs.full_du_dl()
        if abs(full_du_dl[0]) > 0.001 or abs(full_du_dl[-1]) > 0.001:
            print("Error: du_dl endpoints are not ~0")
            calc_work = False

        if calc_work:
            # du_dl was sampled every `subsample_freq` steps, so integrate
            # against the matching subsample of the lambda schedule.
            work = np.trapz(full_du_dl,
                            insertion_lambda_schedule[::subsample_freq])
            print(f"guest_name: {guest_name}\tinsertion_work: {work:.2f}")

        # equilibrate
        for step in range(eq_steps):
            ctxt.step(0.00)
            if step % 1000 == 0:
                report.report_step(ctxt, step, 0.00, host_box, combined_bps,
                                   u_impls, guest_name, eq_steps,
                                   'EQUILIBRATION')
                x_t = ctxt.get_x_t()
                host_coords = x_t[:n_host_atoms] * 10
                guest_coords = x_t[n_host_atoms:] * 10
                report.write_frame(
                    host_coords,
                    solvated_host_mol,
                    guest_coords,
                    guest_mol,
                    guest_name,
                    outdir,
                    str(step).zfill(len(str(eq_steps))),
                    "eq",
                )
            # Force check only at the first, middle and last equilibration step.
            if step in (0, int(eq_steps / 2), eq_steps - 1):
                if report.too_much_force(ctxt, 0.00, host_box, combined_bps,
                                         u_impls):
                    break

        end_time = time.time()
        print(f"{guest_name} took {(end_time - start_time):.2f} seconds")
Ejemplo n.º 3
0
def run_epoch(ff, mol_a, mol_b, core):
    """Run one training epoch for the relative free energy of mol_a -> mol_b.

    For a "complex" stage (solvated protein) and a "solvent" stage (water box)
    this evaluates ``rfe.host_edge`` over a lambda schedule on the worker
    pool, integrates the summed bonded and nonbonded du/dl values with the
    trapezoid rule, and forms the prediction ``complex dG - solvent dG``.
    The loss is ``|pred - label|``. Endpoint gradients are chained with
    ``sign(pred - label)``, clipped, and subtracted in place from
    ``ff.q_handle.params`` and ``ff.lj_handle.params``.

    Parameters
    ----------
    ff: forcefield whose ``q_handle`` / ``lj_handle`` params are updated in place
    mol_a, mol_b: the two molecules passed to ``free_energy.RelativeFreeEnergy``
    core: passed through to ``free_energy.RelativeFreeEnergy``

    Notes
    -----
    Reads ``cmd_args``, ``pool``, ``label`` and ``wrap_method``, none of which
    are defined in this function (expected at module level). Returns None.
    """
    # build the protein system.
    # `*_` tolerates builders.build_protein_system returning either five
    # values or five plus trailing extras (e.g. a topology as sixth value).
    complex_system, complex_coords, _, _, complex_box, *_ = builders.build_protein_system('tests/data/hif2a_nowater_min.pdb')
    complex_box += np.eye(3)*0.1 # BFGS this later

    # build the water system.
    solvent_system, solvent_coords, solvent_box, _ = builders.build_water_system(4.0)
    solvent_box += np.eye(3)*0.1 # BFGS this later

    combined_handle_and_grads = {}
    stage_dGs = []

    for stage, host_system, host_coords, host_box, num_host_windows in [
        ("complex", complex_system, complex_coords, complex_box, cmd_args.num_complex_windows),
        ("solvent", solvent_system, solvent_coords, solvent_box, cmd_args.num_solvent_windows)]:

        # Window counts for the three lambda segments; C absorbs the rounding
        # remainder so that A + B + C == num_host_windows.
        A = int(.35*num_host_windows)
        B = int(.30*num_host_windows)
        C = num_host_windows - A - B

        # Empirically, we see the largest variance in std <du/dl> near the endpoints in the nonbonded
        # terms. Bonded terms are roughly linear. So we add more lambda windows at the endpoint to
        # help improve convergence.
        lambda_schedule = np.concatenate([
            np.linspace(0.0,  0.25, A, endpoint=False),
            np.linspace(0.25, 0.75, B, endpoint=False),
            np.linspace(0.75, 1.0,  C, endpoint=True)
        ])

        assert len(lambda_schedule) == num_host_windows

        print("Minimizing the host structure to remove clashes.")
        minimized_host_coords = minimizer.minimize_host_4d(mol_a, host_system, host_coords, ff, host_box)

        rfe = free_energy.RelativeFreeEnergy(mol_a, mol_b, core, ff)

        # One work item per lambda window for the current stage, assigned to
        # GPUs round-robin.
        host_args = []
        for lambda_idx, lamb in enumerate(lambda_schedule):
            gpu_idx = lambda_idx % cmd_args.num_gpus
            host_args.append((gpu_idx, lamb, host_system, minimized_host_coords, host_box, cmd_args.num_equil_steps, cmd_args.num_prod_steps))

        results = pool.map(functools.partial(wrap_method, fn=rfe.host_edge), host_args, chunksize=1)

        ghs = []

        for lamb, (bonded_du_dl, nonbonded_du_dl, grads_and_handles) in zip(lambda_schedule, results):
            ghs.append(grads_and_handles)
            print("final", stage, "lambda", lamb, "bonded:", bonded_du_dl[0], bonded_du_dl[1], "nonbonded:", nonbonded_du_dl[0], nonbonded_du_dl[1])

        # Integrate (bonded + nonbonded) du/dl, first element of each, over lambda.
        dG_host = np.trapz([x[0][0]+x[1][0] for x in results], lambda_schedule)
        stage_dGs.append(dG_host)

        # use gradient information from the endpoints
        for (grad_lhs, handle_type_lhs), (grad_rhs, handle_type_rhs) in zip(ghs[0], ghs[-1]):
            assert handle_type_lhs == handle_type_rhs # ffs are forked so the return handler isn't same object as that of ff
            grad = grad_rhs - grad_lhs
            # complex - solvent: the first stage stores its gradient, the
            # second stage is subtracted from it.
            if handle_type_lhs not in combined_handle_and_grads:
                combined_handle_and_grads[handle_type_lhs] = grad
            else:
                combined_handle_and_grads[handle_type_lhs] -= grad

        print(stage, "pred_dG:", dG_host)

    pred = stage_dGs[0] - stage_dGs[1]

    loss = np.abs(pred - label)

    print("loss", loss, "pred", pred, "label", label)

    # Derivative of |pred - label| with respect to pred.
    dl_dpred = np.sign(pred - label)

    # (ytz): these should be made configurable later on.
    gradient_clip_thresholds = {
        nonbonded.AM1CCCHandler: 0.05,
        nonbonded.LennardJonesHandler: np.array([0.001,0])
    }

    # update parameters in place; handler types without a clip threshold are
    # left untouched.
    for handle_type, grad in combined_handle_and_grads.items():
        if handle_type in gradient_clip_thresholds:
            bounds = gradient_clip_thresholds[handle_type]
            dl_dp = dl_dpred*grad # chain rule
            # lots of room to improve here.
            dl_dp = np.clip(dl_dp, -bounds, bounds) # clip gradients so they're well behaved

            if handle_type == nonbonded.AM1CCCHandler:
                # sanity check as we have other charge methods that exist
                assert handle_type == type(ff.q_handle)
                ff.q_handle.params -= dl_dp
                # Debugging hint: pair ff.q_handle.smirks with dl_dp to see
                # which patterns received a non-zero update.

            elif handle_type == nonbonded.LennardJonesHandler:
                # sanity check again, even though we don't have other lj methods currently
                assert handle_type == type(ff.lj_handle)
                ff.lj_handle.params -= dl_dp
Ejemplo n.º 4
0
def test_recipe_from_openmm():
    """Smoke test: build the protein system and construct a Recipe from it."""
    build_outputs = builders.build_protein_system(
        'tests/data/hif2a_nowater_min.pdb')
    # Unpack exactly six values so an unexpected tuple length still fails;
    # only the first element (the system) is used.
    (system, _, _, _, _, _) = build_outputs
    md.Recipe.from_openmm(system)
Ejemplo n.º 5
0
def calculate_rigorous_work(
    host_pdbfile, guests_sdfile, outdir, fewer_outfiles=False, no_outfiles=False
):
    """Run a "host" leg and a "water" leg for every guest in an SDF file.

    The host is solvated with ``builders.build_protein_system`` and a water
    box is built with ``builders.build_water_system`` using the smallest
    diagonal entry of the host box as its width. Each guest is combined with
    the host potentials and then with the water potentials, and ``run_leg``
    is called once per combination.

    Parameters
    ----------

    host_pdbfile: path to the host pdb file
    guests_sdfile: path to the input sdf with guests
    outdir: where to write output (will be created if it does not already exist)
    fewer_outfiles: forwarded unchanged to ``run_leg``
    no_outfiles: if True, the intermediate solvated_host.pdb and water_box.pdb
        files written to outdir are removed again; also forwarded to ``run_leg``

    Output
    ------

    stdout echoing the inputs and the module-level lambda/step constants
    stdout with the host box and water box
    stdout for each guest noting how long the host leg and the water leg took
    outdir/solvated_host.pdb and outdir/water_box.pdb (unless no_outfiles)
    """

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(outdir, exist_ok=True)

    print(
        f"""
    HOST_PDBFILE = {host_pdbfile}
    GUESTS_SDFILE = {guests_sdfile}
    OUTDIR = {outdir}

    INSERTION_MAX_LAMBDA = {INSERTION_MAX_LAMBDA}
    DELETION_MAX_LAMBDA = {DELETION_MAX_LAMBDA}
    MIN_LAMBDA = {MIN_LAMBDA}
    TRANSITION_STEPS = {TRANSITION_STEPS}
    EQ1_STEPS = {EQ1_STEPS}
    EQ2_STEPS = {EQ2_STEPS}
    """
    )

    # Prepare host
    # TODO: handle extra (non-transitioning) guests?
    print("Solvating host...")
    (
        solvated_host_system,
        solvated_host_coords,
        _,
        _,
        host_box,
        solvated_topology,
    ) = builders.build_protein_system(host_pdbfile)

    # sometimes water boxes are sad. Should be minimized first; this is a workaround
    host_box += np.eye(3) * 0.1
    print("host box", host_box)

    # Round-trip through a PDB file to obtain an RDKit molecule of the host.
    solvated_host_pdb = os.path.join(outdir, "solvated_host.pdb")
    writer = pdb_writer.PDBWriter([solvated_topology], solvated_host_pdb)
    writer.write_frame(solvated_host_coords)
    writer.close()
    solvated_host_mol = Chem.MolFromPDBFile(solvated_host_pdb, removeHs=False)
    if no_outfiles:
        os.remove(solvated_host_pdb)
    host_potentials, host_masses = openmm_deserializer.deserialize_system(solvated_host_system, cutoff=1.2)
    final_host_potentials, host_nb_bp = _split_nonbonded(host_potentials)

    # Prepare water box
    print("Generating water box...")
    # TODO: water box probably doesn't need to be this big
    box_lengths = host_box[np.diag_indices(3)]
    water_box_width = min(box_lengths)
    (
        water_system,
        orig_water_coords,
        water_box,
        water_topology,
    ) = builders.build_water_system(water_box_width)

    # sometimes water boxes are sad. should be minimized first; this is a workaround
    water_box += np.eye(3) * 0.1
    print("water box", water_box)

    # it's okay if the water box here and the solvated protein box don't align -- they have PBCs
    water_pdb = os.path.join(outdir, "water_box.pdb")
    writer = pdb_writer.PDBWriter([water_topology], water_pdb)
    writer.write_frame(orig_water_coords)
    writer.close()
    water_mol = Chem.MolFromPDBFile(water_pdb, removeHs=False)
    if no_outfiles:
        os.remove(water_pdb)

    water_potentials, water_masses = openmm_deserializer.deserialize_system(water_system, cutoff=1.2)
    final_water_potentials, water_nb_bp = _split_nonbonded(water_potentials)

    ff_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "ff/params/smirnoff_1_1_0_ccc.py",
    )

    # Run the procedure
    print("Getting guests...")
    suppl = Chem.SDMolSupplier(guests_sdfile, removeHs=False)
    for guest_mol in suppl:
        # The host-leg timer deliberately includes force-field loading and
        # parameterization, as in the original implementation.
        start_time = time.time()
        guest_name = guest_mol.GetProp("_Name")
        guest_conformer = guest_mol.GetConformer(0)
        orig_guest_coords = np.array(guest_conformer.GetPositions(), dtype=np.float64)
        orig_guest_coords = orig_guest_coords / 10  # convert to md_units

        # Fresh handlers per guest; the context manager closes the file
        # handle instead of leaking one per iteration.
        with open(ff_path) as ff_file:
            guest_ff_handlers = deserialize_handlers(ff_file.read())
        ff = Forcefield(guest_ff_handlers)
        guest_base_top = topology.BaseTopology(guest_mol, ff)

        # combine host & guest
        combined_bps, combined_masses = _combine_with_guest(
            host_nb_bp, final_host_potentials, host_masses, guest_base_top, guest_mol, ff
        )

        run_leg(
            solvated_host_coords,
            orig_guest_coords,
            combined_bps,
            combined_masses,
            host_box,
            guest_name,
            "host",
            solvated_host_mol,
            guest_mol,
            outdir,
            fewer_outfiles,
            no_outfiles,
        )
        end_time = time.time()
        print(
            f"{guest_name} host leg time:", f"{end_time - start_time:.2f}", "seconds"
        )

        # combine water & guest
        combined_bps, combined_masses = _combine_with_guest(
            water_nb_bp, final_water_potentials, water_masses, guest_base_top, guest_mol, ff
        )
        # The water-leg timer starts after parameterization.
        start_time = time.time()
        run_leg(
            orig_water_coords,
            orig_guest_coords,
            combined_bps,
            combined_masses,
            water_box,
            guest_name,
            "water",
            water_mol,
            guest_mol,
            outdir,
            fewer_outfiles,
            no_outfiles,
        )
        end_time = time.time()
        print(
            f"{guest_name} water leg time:", f"{end_time - start_time:.2f}", "seconds"
        )


def _split_nonbonded(bound_potentials):
    """Return (non-nonbonded potentials, the single Nonbonded potential or None)."""
    other_potentials = []
    nonbonded_bp = None
    for bp in bound_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert nonbonded_bp is None
            nonbonded_bp = bp
        else:
            other_potentials.append(bp)
    return other_potentials, nonbonded_bp


def _combine_with_guest(env_nb_bp, env_potentials, env_masses, guest_base_top, guest_mol, ff):
    """Return (bound potentials, masses) for an environment combined with one guest."""
    hgt = topology.HostGuestTopology(env_nb_bp, guest_base_top)
    # setup the parameter handlers for the ligand
    bonded_tuples = [
        [hgt.parameterize_harmonic_bond, ff.hb_handle],
        [hgt.parameterize_harmonic_angle, ff.ha_handle],
        [hgt.parameterize_proper_torsion, ff.pt_handle],
        [hgt.parameterize_improper_torsion, ff.it_handle],
    ]
    # Copy so the shared environment potential list is never mutated.
    combined_bps = list(env_potentials)
    # instantiate the vjps while parameterizing (forward pass)
    for fn, handle in bonded_tuples:
        params, potential = fn(handle.params)
        combined_bps.append(potential.bind(params))
    nb_params, nb_potential = hgt.parameterize_nonbonded(ff.q_handle.params, ff.lj_handle.params)
    combined_bps.append(nb_potential.bind(nb_params))
    # Environment atoms come first, guest atoms last.
    guest_masses = [a.GetMass() for a in guest_mol.GetAtoms()]
    combined_masses = np.concatenate([env_masses, guest_masses])
    return combined_bps, combined_masses