Example #1
0
    def test_good_factor(self):
        """Build a SingleTopology from a valid atom mapping and differentiate every parameterizer."""
        supplier = Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)
        ligands = [mol for mol in supplier]
        mol_a, mol_b = ligands[1], ligands[4]

        ff = Forcefield.load_from_file("smirnoff_1_1_0_sc.py")

        # one (mol_a atom index, mol_b atom index) pair per mapped core atom
        core = np.array([
            (0, 0), (2, 2), (1, 1), (6, 6), (5, 5), (4, 4), (3, 3),
            (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21),
            (32, 30), (26, 25), (27, 26),
            (7, 7), (8, 8), (9, 9), (10, 10),
            (29, 11), (11, 12), (12, 13), (14, 15),
            (31, 29), (13, 14), (23, 24), (30, 28), (28, 27), (21, 22),
        ])

        st = topology.SingleTopology(mol_a, mol_b, core, ff)

        # test that the vjps work: each parameterizer is paired with the
        # forcefield parameter arrays it is differentiated against
        vjp_cases = (
            (st.parameterize_harmonic_bond, (ff.hb_handle.params,)),
            (st.parameterize_harmonic_angle, (ff.ha_handle.params,)),
            (st.parameterize_periodic_torsion,
             (ff.pt_handle.params, ff.it_handle.params)),
            (st.parameterize_nonbonded,
             (ff.q_handle.params, ff.lj_handle.params)),
        )
        for parameterize, primals in vjp_cases:
            # only successful construction matters; the outputs are discarded
            jax.vjp(parameterize, *primals, has_aux=True)
Example #2
0
    def test_nonbonded_optimal_map(self):
        """Similar test as test_nonbonded, ie. assert that coordinates and nonbonded parameters
        can be averaged in benzene -> phenol transformation. However, use the maximal mapping possible."""

        # map benzene H to phenol O, leaving a dangling phenol H
        core = np.array(
            [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]],
            dtype=np.int32)
        n_core = len(core)
        # The single unmapped atom of B is placed right after the core atoms in
        # the combined system, i.e. at index 7 (the coordinate check below
        # relies on the same layout).
        dangling_idx = n_core

        st = topology.SingleTopology(self.mol_a, self.mol_b, core, self.ff)
        x_a = get_romol_conf(self.mol_a)
        x_b = get_romol_conf(self.mol_b)

        # test interpolation of coordinates.
        x_src, x_dst = st.interpolate_params(x_a, x_b)
        x_avg = np.mean([x_src, x_dst], axis=0)

        assert x_avg.shape == (st.get_num_atoms(), 3)

        np.testing.assert_array_equal((x_a[:n_core] + x_b[:n_core]) / 2,
                                      x_avg[:n_core])  # core parts
        np.testing.assert_array_equal(x_b[-1], x_avg[dangling_idx])  # dangling H

        # the auxiliary output (the potential) is not inspected by this test
        params, vjp_fn, _ = jax.vjp(st.parameterize_nonbonded,
                                    self.ff.q_handle.params,
                                    self.ff.lj_handle.params,
                                    has_aux=True)

        # test that vjp_fn works
        vjp_fn(np.random.rand(*params.shape))

        assert params.shape == (2 * st.get_num_atoms(), 3)  # qlj

        # test interpolation of parameters
        bt_a = topology.BaseTopology(self.mol_a, self.ff)
        qlj_a, _ = bt_a.parameterize_nonbonded(self.ff.q_handle.params,
                                               self.ff.lj_handle.params)
        bt_b = topology.BaseTopology(self.mol_b, self.ff)
        qlj_b, _ = bt_b.parameterize_nonbonded(self.ff.q_handle.params,
                                               self.ff.lj_handle.params)

        # params is actually interpolated, so its 2x number of base params:
        # first half is the source end state, second half the destination.
        n_base_params = len(params) // 2
        params_src = params[:n_base_params]
        params_dst = params[n_base_params:]

        # core testing
        np.testing.assert_array_equal(qlj_a[:n_core], params_src[:n_core])
        np.testing.assert_array_equal(qlj_b[:n_core], params_dst[:n_core])

        # r-group atoms in A are all part of the core. so no testing is needed.

        # test r-group in B.
        # This previously indexed row 8, which is past the end of the 8-row
        # halves; JAX clamps out-of-range indices to the last row, so the check
        # silently read row 7. Index the dangling atom explicitly instead.
        np.testing.assert_array_equal(qlj_b[7], params_dst[dangling_idx])
        np.testing.assert_array_equal(np.array([0, qlj_b[7][1], 0]),
                                      params_src[dangling_idx])
Example #3
0
def _setup_hif2a_ligand_pair(ff="timemachine/ff/params/smirnoff_1_1_0_ccc.py"):
    """Return a RelativeFreeEnergy for a ligand pair with a hand-written atom map.

    The pair is taken from entries 1 and 4 of tests/data/ligands_40.sdf and the
    forcefield is loaded from the path given by ``ff``.

    TODO: replace this with a testsystem class similar to those used in openmmtools
    """
    path_to_ligand = str(root.joinpath("tests/data/ligands_40.sdf"))
    forcefield = Forcefield.load_from_file(ff)

    supplier = Chem.SDMolSupplier(path_to_ligand, removeHs=False)
    ligands = [mol for mol in supplier]
    mol_a, mol_b = ligands[1], ligands[4]

    # one (mol_a atom index, mol_b atom index) pair per mapped core atom
    core = np.array([
        (0, 0), (2, 2), (1, 1), (6, 6), (5, 5), (4, 4), (3, 3),
        (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21),
        (32, 30), (26, 25), (27, 26),
        (7, 7), (8, 8), (9, 9), (10, 10),
        (29, 11), (11, 12), (12, 13), (14, 15),
        (31, 29), (13, 14), (23, 24), (30, 28), (28, 27), (21, 22),
    ])

    return free_energy.RelativeFreeEnergy(
        topology.SingleTopology(mol_a, mol_b, core, forcefield))
Example #4
0
def generate_star(
    hub,
    spokes,
    forcefield,
    transformation_size_threshold: int = 2,
    core_strategy: str = "geometry",
    label_property: str = "IC50[uM](SPA)",
    core_kwargs: Optional[Dict[str, Any]] = None,
):
    """Build hub -> spoke relative transformations and keep only the small ones.

    For every spoke, the core strategy named by ``core_strategy`` proposes an
    atom mapping against ``hub`` and a RelativeFreeEnergy is built from it.
    Spokes that raise AtomMappingError are collected separately.

    Returns
    -------
    (easy_transformations, error_transformations)
        easy_transformations: the RelativeFreeEnergy objects whose
            rbfe_transformation_estimate is <= transformation_size_threshold
        error_transformations: (hub, spoke, core) tuples for failed edges;
            core is None when the failure happened before a core was produced
    """
    core_kwargs = {} if core_kwargs is None else core_kwargs

    mapped_edges = []
    error_transformations = []

    # TODO: reduce overlap between get_core_by_smarts and get_core_by_mcs
    # TODO: replace big ol' list of get_core_by_*(mol_a, mol_b, **kwargs) functions with something... classy
    for spoke in spokes:
        core = None
        try:
            core = core_strategies[core_strategy](hub, spoke, core_kwargs)
            edge = RelativeFreeEnergy(
                topology.SingleTopology(hub, spoke, core, forcefield),
                label=_compute_label(hub, spoke, label_property),
            )
        except AtomMappingError:
            # note: some of transformations may fail the factorizability assertion here:
            # https://github.com/proteneer/timemachine/blob/2eb956f9f8ce62287cc531188d1d1481832c5e96/fe/topology.py#L381-L431
            error_transformations.append((hub, spoke, core))
        else:
            mapped_edges.append(edge)

    print(
        f"total # of edges that encountered atom mapping errors: {len(error_transformations)}"
    )

    # filter to keep just the edges with very small number of atoms changing
    easy_transformations = []
    for edge in mapped_edges:
        if rbfe_transformation_estimate(edge) <= transformation_size_threshold:
            easy_transformations.append(edge)

    return easy_transformations, error_transformations
Example #5
0
    def test_bad_factor(self):
        """A mapping that results in a non-cancellable endpoint must raise AtomMappingError."""
        supplier = Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)
        ligands = [mol for mol in supplier]
        mol_a, mol_b = ligands[0], ligands[1]

        ff = Forcefield.load_from_file("smirnoff_1_1_0_sc.py")

        # one (mol_a atom index, mol_b atom index) pair per mapped core atom
        core = np.array([
            (4, 1), (5, 2), (6, 3), (7, 4), (8, 5), (9, 6), (10, 7),
            (11, 8), (12, 9), (13, 10), (15, 11), (16, 12),
            (18, 14), (34, 31), (17, 13), (23, 23),
            (33, 30), (32, 28), (31, 27), (30, 26),
            (19, 15), (20, 16), (21, 17),
        ])

        # constructing the topology is what is expected to fail
        self.assertRaises(topology.AtomMappingError, topology.SingleTopology,
                          mol_a, mol_b, core, ff)
Example #6
0
def test_relative_free_energy():
    """Smoke-test a single topology relative free energy calculation.

    Checks that we can properly build a single topology host guest system and
    that we can run a few steps in a stable way. The vacuum edge is run first,
    then both the complex and the solvent stages. Only the magnitude of the
    resulting dG estimates is bounded; no reference value is compared.
    """

    suppl = Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)
    all_mols = [x for x in suppl]
    mol_a = all_mols[1]
    mol_b = all_mols[4]

    # one (mol_a atom index, mol_b atom index) pair per mapped core atom
    core = np.array([
        [0, 0],
        [2, 2],
        [1, 1],
        [6, 6],
        [5, 5],
        [4, 4],
        [3, 3],
        [15, 16],
        [16, 17],
        [17, 18],
        [18, 19],
        [19, 20],
        [20, 21],
        [32, 30],
        [26, 25],
        [27, 26],
        [7, 7],
        [8, 8],
        [9, 9],
        [10, 10],
        [29, 11],
        [11, 12],
        [12, 13],
        [14, 15],
        [31, 29],
        [13, 14],
        [23, 24],
        [30, 28],
        [28, 27],
        [21, 22],
    ])

    complex_system, complex_coords, _, _, complex_box, _ = builders.build_protein_system(
        "tests/data/hif2a_nowater_min.pdb")

    # build the water system.
    solvent_system, solvent_coords, solvent_box, _ = builders.build_water_system(
        4.0)

    ff = Forcefield.load_from_file("smirnoff_1_1_0_ccc.py")

    ff_params = ff.get_ordered_params()

    seed = 2021

    lambda_schedule = np.linspace(0, 1.0, 4)
    equil_steps = 1000
    prod_steps = 1000

    temperature = 300.0
    pressure = 1.0

    single_topology = topology.SingleTopology(mol_a, mol_b, core, ff)
    rfe = free_energy.RelativeFreeEnergy(single_topology)

    def _edge_dG(unbound_potentials, sys_params, masses, coords, box):
        """Run the lambda schedule for one prepared edge and return its dG estimate.

        Shared by the vacuum and host paths so that each run creates exactly
        one CUDAPoolClient (the vacuum path used to create two and discard the
        first).
        """
        x0 = coords
        v0 = np.zeros_like(coords)
        client = CUDAPoolClient(1)

        # the first potential is treated as the harmonic bond term; its bond
        # list defines the atom groups handed to the barostat
        harmonic_bond_potential = unbound_potentials[0]
        group_idxs = get_group_indices(get_bond_list(harmonic_bond_potential))

        integrator = LangevinIntegrator(temperature, 1.5e-3, 1.0, masses, seed)

        barostat = MonteCarloBarostat(x0.shape[0], pressure, temperature,
                                      group_idxs, 25, seed)

        model = estimator.FreeEnergyModel(
            unbound_potentials,
            client,
            box,
            x0,
            v0,
            integrator,
            lambda_schedule,
            equil_steps,
            prod_steps,
            barostat,
        )

        # deltaG returns the estimate first; the remaining output is unused here
        return estimator.deltaG(model, sys_params)[0]

    def vacuum_model(ff_params):
        unbound_potentials, sys_params, masses, coords = rfe.prepare_vacuum_edge(
            ff_params)

        # large box for the vacuum edge
        box = np.eye(3, dtype=np.float64) * 100

        return _edge_dG(unbound_potentials, sys_params, masses, coords, box)

    dG = vacuum_model(ff_params)
    assert np.abs(dG) < 1000.0

    def binding_model(ff_params):

        dGs = []

        for host_system, host_coords, host_box in [
            (complex_system, complex_coords, complex_box),
            (solvent_system, solvent_coords, solvent_box),
        ]:

            # minimize the host to avoid clashes
            host_coords = minimizer.minimize_host_4d([mol_a], host_system,
                                                     host_coords, ff, host_box)

            unbound_potentials, sys_params, masses, coords = rfe.prepare_host_edge(
                ff_params, host_system, host_coords)

            dGs.append(
                _edge_dG(unbound_potentials, sys_params, masses, coords,
                         host_box))

        # complex stage minus solvent stage
        return dGs[0] - dGs[1]

    dG = binding_model(ff_params)
    assert np.abs(dG) < 1000.0
Example #7
0
def do_relative_docking(host_pdbfile, mol_a, mol_b, core, num_switches,
                        transition_steps):
    """Runs non-equilibrium switching jobs:
    1. Solvates a protein, minimizes w.r.t guest_A, equilibrates & spins off switching jobs
       (deleting guest_A while inserting guest_B) every 1000th step, calculates work.
    2. Does the same thing in solvent instead of protein
    Does num_switches switching jobs per leg.

    Parameters
    ----------

    host_pdbfile (str): path to host pdb file
    mol_a (rdkit mol): the starting ligand to swap from
    mol_b (rdkit mol): the ending ligand to swap to
    core (np.array[[int, int], [int, int], ...]): the common core atoms between mol_a and mol_b
    num_switches (int): number of switching trajectories to run per compound pair per leg
    transition_steps (int): length of each switching trajectory

    Returns
    -------

    {str: list}: map of leg label to work values of switching mol_a to mol_b in that leg,
                 {'protein': [work values], 'solvent': [work_values]}.
                 An empty dict is returned if the atom mapping raises AtomMappingError.

    Output
    ------

    stdout noting the step number, lambda value, and energy at various steps
    stdout noting the work of transition, if applicable
    stdout noting how long each leg took to run

    Note
    ----
    The work will not be calculated if any norm of force per atom exceeds 20000 kJ/(mol*nm)
       [MAX_NORM_FORCE defined in docking/report.py]
    The simulations won't run if the atom maps are not factorizable
    """

    # Prepare host
    # TODO: handle extra (non-transitioning) guests?
    print("Solvating host...")
    (
        solvated_host_system,
        solvated_host_coords,
        _,
        _,
        host_box,
        _,
    ) = builders.build_protein_system(host_pdbfile)

    # Prepare water box
    print("Generating water box...")
    # TODO: water box probably doesn't need to be this big
    box_lengths = host_box[np.diag_indices(3)]
    water_box_width = min(box_lengths)
    (
        water_system,
        water_coords,
        water_box,
        _,
    ) = builders.build_water_system(water_box_width)

    # it's okay if the water box here and the solvated protein box don't align -- they have PBCs

    guest_name_a = mol_a.GetProp("_Name")
    guest_name_b = mol_b.GetProp("_Name")
    combined_name = guest_name_a + "-->" + guest_name_b

    ff = Forcefield.load_from_file("smirnoff_1_1_0_ccc.py")

    # The atom mapping does not depend on the host, so validate it once, before
    # any (expensive) host minimization is attempted.
    try:
        single_topology = topology.SingleTopology(mol_a, mol_b, core, ff)
        rfe = free_energy.RelativeFreeEnergy(single_topology)
    except topology.AtomMappingError as e:
        print(f"NON-FACTORIZABLE PAIR: {combined_name}")
        print(e)
        return {}

    # Run the procedure
    all_works = {}
    for system, coords, box, label in zip(
        [solvated_host_system, water_system],
        [solvated_host_coords, water_coords],
        [host_box, water_box],
        ["protein", "solvent"],
    ):
        # timed per leg, so the second report does not include the first leg
        start_time = time.time()

        # minimize w.r.t. both mol_a and mol_b?
        min_coords = minimizer.minimize_host_4d([mol_a], system, coords, ff,
                                                box)

        try:
            ups, sys_params, combined_masses, combined_coords = rfe.prepare_host_edge(
                ff.get_ordered_params(), system, min_coords)
        except topology.AtomMappingError as e:
            print(f"NON-FACTORIZABLE PAIR: {combined_name}")
            print(e)
            return {}

        # bind each unbound potential to its parameters
        combined_bps = [up.bind(sp) for up, sp in zip(ups, sys_params)]

        all_works[label] = run_leg(
            combined_coords,
            combined_bps,
            combined_masses,
            box,
            combined_name,
            label,
            num_switches,
            transition_steps,
        )
        end_time = time.time()
        print(
            f"{combined_name} {label} leg time:",
            "%.2f" % (end_time - start_time),
            "seconds",
        )
    return all_works
Example #8
0
    def test_bonded(self):
        """Check harmonic bond parameters and indices of the combined topology for two mappings.

        Scenario 1 leaves the benzene H and the phenol OH unmapped; scenario 2
        additionally maps the benzene H onto the phenol O. In both cases the
        combined parameters are laid out as [source core, destination core,
        unique terms] and 15 bonds are expected in total.
        """
        # other bonded terms use an identical protocol, so we assume they're correct if the harmonic bond tests pass.
        # leaving benzene H unmapped, and phenol OH unmapped
        core = np.array(
            [
                [0, 0],
                [1, 1],
                [2, 2],
                [3, 3],
                [4, 4],
                [5, 5],
            ],
            dtype=np.int32,
        )

        st = topology.SingleTopology(self.mol_a, self.mol_b, core, self.ff)

        combined_params, vjp_fn, combined_potential = jax.vjp(
            st.parameterize_harmonic_bond,
            self.ff.hb_handle.params,
            has_aux=True)

        # test that vjp_fn works
        vjp_fn(np.random.rand(*combined_params.shape))

        # we expect 15 bonds in total, of which 6 are duplicated.
        # (the closing parenthesis used to sit after "== 15", which asserted
        # on the length of a comparison result instead of the bond count)
        bond_idxs = combined_potential.get_idxs()
        assert len(bond_idxs) == 15

        # the duplicated core bonds must reference the same atom pairs
        src_idxs = {tuple(x) for x in bond_idxs[:6]}
        dst_idxs = {tuple(x) for x in bond_idxs[6:12]}

        np.testing.assert_equal(src_idxs, dst_idxs)

        cc = self.ff.hb_handle.lookup_smirks("[#6X3:1]:[#6X3:2]")
        cH = self.ff.hb_handle.lookup_smirks("[#6X3:1]-[#1:2]")
        cO = self.ff.hb_handle.lookup_smirks("[#6X3:1]-[#8X2H1:2]")
        OH = self.ff.hb_handle.lookup_smirks("[#8:1]-[#1:2]")

        params_src = combined_params[:6]
        params_dst = combined_params[6:12]
        params_uni = combined_params[12:]

        np.testing.assert_array_equal(params_src, [cc, cc, cc, cc, cc, cc])
        np.testing.assert_array_equal(params_dst, [cc, cc, cc, cc, cc, cc])
        np.testing.assert_array_equal(params_uni, [cH, cO, OH])

        # map H to O
        core = np.array(
            [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]],
            dtype=np.int32)

        st = topology.SingleTopology(self.mol_a, self.mol_b, core, self.ff)

        combined_params, vjp_fn, combined_potential = jax.vjp(
            st.parameterize_harmonic_bond,
            self.ff.hb_handle.params,
            has_aux=True)

        # 7 source + 7 destination + 1 unique bond
        assert len(combined_potential.get_idxs()) == 15
        # NOTE(review): the first scenario also asserts that source and
        # destination bond index sets match; the original computed the sets
        # here without comparing them. Consider adding that assertion once it
        # is confirmed to hold for this mapping.

        params_src = combined_params[:7]
        params_dst = combined_params[7:14]
        params_uni = combined_params[14:]

        np.testing.assert_array_equal(params_src, [cc, cc, cc, cc, cc, cc, cH])
        np.testing.assert_array_equal(params_dst, [cc, cc, cc, cc, cc, cc, cO])
        np.testing.assert_array_equal(params_uni, [OH])

        # test that vjp_fn works
        vjp_fn(np.random.rand(*combined_params.shape))
Example #9
0
    def test_nonbonded(self):
        """Check coordinate and nonbonded parameter interpolation for a partial core mapping."""

        # leaving benzene H unmapped, and phenol OH unmapped
        core = np.array([[idx, idx] for idx in range(6)], dtype=np.int32)

        st = topology.SingleTopology(self.mol_a, self.mol_b, core, self.ff)
        x_a = get_romol_conf(self.mol_a)
        x_b = get_romol_conf(self.mol_b)

        # test interpolation of coordinates.
        x_src, x_dst = st.interpolate_params(x_a, x_b)
        x_avg = np.mean([x_src, x_dst], axis=0)

        assert x_avg.shape == (st.get_num_atoms(), 3)

        core_midpoint = (x_a[:6] + x_b[:6]) / 2
        np.testing.assert_array_equal(core_midpoint, x_avg[:6])  # C
        np.testing.assert_array_equal(x_a[6], x_avg[6])  # H
        np.testing.assert_array_equal(x_b[6:], x_avg[7:])  # OH

        q_params = self.ff.q_handle.params
        lj_params = self.ff.lj_handle.params

        # NOTE: unused result
        st.parameterize_nonbonded(q_params, lj_params)

        params, vjp_fn, _ = jax.vjp(st.parameterize_nonbonded,
                                    q_params,
                                    lj_params,
                                    has_aux=True)

        vjp_fn(np.random.rand(*params.shape))

        assert params.shape == (2 * st.get_num_atoms(), 3)  # qlj

        # test interpolation of parameters
        qlj_a, _ = topology.BaseTopology(
            self.mol_a, self.ff).parameterize_nonbonded(q_params, lj_params)
        qlj_b, _ = topology.BaseTopology(
            self.mol_b, self.ff).parameterize_nonbonded(q_params, lj_params)

        # params is actually interpolated, so its 2x number of base params
        half = len(params) // 2
        params_src, params_dst = params[:half], params[half:]

        # core testing
        np.testing.assert_array_equal(qlj_a[:6], params_src[:6])
        np.testing.assert_array_equal(qlj_b[:6], params_dst[:6])

        # test r-group in A: kept as-is at the source end, and only the
        # middle component survives at the destination end
        np.testing.assert_array_equal(qlj_a[6], params_src[6])
        expected_a_at_dst = np.array([0, qlj_a[6][1], 0])
        np.testing.assert_array_equal(expected_a_at_dst, params_dst[6])

        # test r-group in B: the mirror image of the check above
        np.testing.assert_array_equal(qlj_b[6:], params_dst[7:])
        expected_b_at_src = np.array([[0, qlj_b[6][1], 0],
                                      [0, qlj_b[7][1], 0]])
        np.testing.assert_array_equal(expected_b_at_src, params_src[7:])
Example #10
0
def estimate_dG(
    transformation: RelativeTransformation,
    num_lambdas: int,
    num_steps_per_lambda: int,
    num_equil_steps: int,
):
    """Estimate a relative free energy as (complex stage dG) - (solvent stage dG).

    For each stage the host is minimized against mol_a, one ``rfe.host_edge``
    task per lambda window is submitted to the module-level ``client``, and the
    mean du/dlambda of each window is integrated over the lambda schedule with
    the trapezoid rule.

    Parameters
    ----------
    transformation: RelativeTransformation
        Supplies the forcefield (``ff``), the two molecules (``mol_a``,
        ``mol_b``) and the atom mapping (``core``).
    num_lambdas: int
        Passed to ``construct_lambda_schedule`` to build the lambda schedule.
    num_steps_per_lambda: int
        Forwarded to ``rfe.host_edge`` for every window.
    num_equil_steps: int
        Forwarded to ``rfe.host_edge`` for every window.

    Returns
    -------
    The complex stage estimate minus the solvent stage estimate.
    """
    # build the protein system.
    complex_system, complex_coords, _, _, complex_box = builders.build_protein_system(
        path_to_protein)

    # build the water system.
    solvent_system, solvent_coords, solvent_box, _ = builders.build_water_system(
        4.0)

    ff = transformation.ff
    mol_a, mol_b = transformation.mol_a, transformation.mol_b
    core = transformation.core

    # TODO: measure performance of complex and solvent separately

    lambda_schedule = construct_lambda_schedule(num_lambdas)

    # The topology does not depend on the host, so it is built once up front;
    # an invalid atom mapping therefore fails before any minimization is run.
    single_topology = topology.SingleTopology(mol_a, mol_b, core, ff)
    rfe = free_energy.RelativeFreeEnergy(single_topology)
    do_work = partial(wrap_method, fxn=rfe.host_edge)

    def _mean_du_dlambda(result):
        """summarize result of rfe.host_edge into mean du/dl

        TODO: refactor where this analysis step occurs
        """
        bonded_du_dl, nonbonded_du_dl, _ = result
        return np.mean(bonded_du_dl + nonbonded_du_dl)

    stage_dGs = []

    for _stage, host_system, host_coords, host_box in [
        ("complex", complex_system, complex_coords, complex_box),
        ("solvent", solvent_system, solvent_coords, solvent_box),
    ]:

        print("Minimizing the host structure to remove clashes.")
        minimized_host_coords = minimizer.minimize_host_4d(
            mol_a, host_system, host_coords, ff, host_box)

        # one task per lambda window
        print("submitting tasks to client!")
        futures = [
            client.submit(
                do_work,
                (lamb, host_system, minimized_host_coords, host_box,
                 num_equil_steps, num_steps_per_lambda),
            ) for lamb in lambda_schedule
        ]

        # block until every window has finished, preserving schedule order
        results = [fut.result() for fut in futures]

        dG_host = np.trapz([_mean_du_dlambda(x) for x in results],
                           lambda_schedule)
        stage_dGs.append(dG_host)

    # complex stage minus solvent stage
    pred = stage_dGs[0] - stage_dGs[1]
    return pred
Example #11
0
    # relative free energy, compare two different core approaches

    # Identity mapping over every atom index: row i pairs atom i of mol_a with
    # atom i of mol_b (np.stack needs both index ranges to have equal length).
    core_full = np.stack(
        [np.arange(mol_a.GetNumAtoms()),
         np.arange(mol_b.GetNumAtoms())],
        axis=1)

    # Same identity mapping, but leaving the last atom index of each molecule
    # out of the core.
    core_part = np.stack([
        np.arange(mol_a.GetNumAtoms() - 1),
        np.arange(mol_b.GetNumAtoms() - 1)
    ],
                         axis=1)

    # Repeat the setup below once per candidate core.
    for core_idx, core in enumerate([core_full, core_part]):
        single_topology = topology.SingleTopology(mol_a, mol_b, core, ff)

        rfe = free_energy.RelativeFreeEnergy(single_topology)

        # Evenly spaced lambda windows on [0, 1]; the window counts come from
        # the command-line arguments.
        vacuum_lambda_schedule = np.linspace(0.0, 1.0,
                                             cmd_args.num_vacuum_windows)
        solvent_lambda_schedule = np.linspace(0.0, 1.0,
                                              cmd_args.num_solvent_windows)

        # vacuum leg
        # One argument tuple per lambda window; gpu indices are assigned
        # round-robin over cmd_args.num_gpus.
        vacuum_args = []
        for lambda_idx, lamb in enumerate(vacuum_lambda_schedule):
            gpu_idx = lambda_idx % cmd_args.num_gpus
            vacuum_args.append((gpu_idx, lamb, cmd_args.num_equil_steps,
                                cmd_args.num_prod_steps))
Example #12
0
    def predict(self, ff_params: list, mol_a: Chem.Mol, mol_b: Chem.Mol,
                core: np.ndarray):
        """
        Predict the ddG of morphing mol_a into mol_b. This function is differentiable w.r.t. ff_params.

        The complex stage and the solvent stage are each run over their own
        lambda schedule; the prediction is the complex dG minus the solvent dG.

        Parameters
        ----------

        ff_params: list of np.ndarray
            This should match the ordered params returned by the forcefield

        mol_a: Chem.Mol
            Starting molecule corresponding to lambda = 0

        mol_b: Chem.Mol
            Ending molecule corresponding to lambda = 1

        core: np.ndarray
            N x 2 list of ints corresponding to the atom mapping of the core.

        Returns
        -------
        float
            delta delta G in kJ/mol
        aux
            list of (stage name, TI results) tuples, one per stage
        """
        stages = [
            ("complex", self.complex_system, self.complex_coords,
             self.complex_box, self.complex_schedule),
            ("solvent", self.solvent_system, self.solvent_coords,
             self.solvent_box, self.solvent_schedule),
        ]

        # fixed simulation settings shared by both stages
        seed = 0
        temperature = 300.0
        pressure = 1.0

        stage_dGs = []
        stage_results = []

        for stage, host_system, host_coords, host_box, lambda_schedule in stages:
            single_topology = topology.SingleTopology(mol_a, mol_b, core,
                                                      self.ff)
            rfe = free_energy.RelativeFreeEnergy(single_topology)
            edge_hash = self._edge_hash(stage, mol_a, mol_b, core)
            have_cached_state = (self.pre_equilibrate
                                 and edge_hash in self._equil_cache)

            if have_cached_state:
                cached_state = self._equil_cache[edge_hash]
                cached_coords = cached_state.coords
                # the cached box replaces the stage's initial box
                host_box = cached_state.box
                n_host = len(host_coords)
                unbound_potentials, sys_params, masses, _ = rfe.prepare_host_edge(
                    ff_params, host_system, host_coords)
                n_a = mol_a.GetNumAtoms()
                # Use Dual Topology to pre equilibrate, so have to get the mean of the two sets of mol,
                # normally done within prepare_host_edge, but the whole system has moved by this stage.
                # Cached layout used here: host atoms, then mol_a, then mol_b.
                ligand_coords = np.mean(
                    single_topology.interpolate_params(
                        cached_coords[n_host:n_host + n_a],
                        cached_coords[n_host + n_a:]),
                    axis=0,
                )
                x0 = np.concatenate([cached_coords[:n_host], ligand_coords])
            else:
                if self.pre_equilibrate:
                    print(
                        "Edge not correctly pre-equilibrated, ensure equilibrate_edges was called"
                    )
                print(
                    f"Minimizing the {stage} host structure to remove clashes."
                )
                # (ytz): this isn't strictly symmetric, and we should modify minimize later on remove
                # the hysteresis by jointly minimizing against a and b at the same time. We may also want
                # to remove the randomness completely from the minimization.
                min_host_coords = minimizer.minimize_host_4d([mol_a, mol_b],
                                                             host_system,
                                                             host_coords,
                                                             self.ff, host_box)

                unbound_potentials, sys_params, masses, x0 = rfe.prepare_host_edge(
                    ff_params, host_system, min_host_coords)

            v0 = np.zeros_like(x0)

            # the first potential is treated as the harmonic bond term
            bond_list = get_bond_list(unbound_potentials[0])
            if self.hmr:
                # repartitioned masses go together with the larger time step
                masses = apply_hmr(masses, bond_list)
                time_step = 2.5e-3
            else:
                time_step = 1.5e-3
            group_idxs = get_group_indices(bond_list)

            integrator = LangevinIntegrator(temperature, time_step, 1.0,
                                            masses, seed)

            barostat = MonteCarloBarostat(x0.shape[0], pressure, temperature,
                                          group_idxs, self.barostat_interval,
                                          seed)

            model = estimator.FreeEnergyModel(
                unbound_potentials,
                self.client,
                host_box,
                x0,
                v0,
                integrator,
                lambda_schedule,
                self.equil_steps,
                self.prod_steps,
                barostat,
            )

            dG, results = estimator.deltaG(model, sys_params)

            stage_dGs.append(dG)
            stage_results.append((stage, results))

        # complex stage minus solvent stage
        complex_dG, solvent_dG = stage_dGs
        pred = complex_dG - solvent_dG

        return pred, stage_results