Beispiel #1
0
def benchmark_dhfr():
    """Benchmark stepping throughput on the solvated DHFR test system.

    Loads ``tests/data/5dfr_solv_equil.pdb``, builds a system with the
    amber99sbildn / tip3p force fields, deserializes it into potentials and
    masses, and steps a ``custom_ops.Context`` for 50000 steps at lambda = 0.
    Every 1000 steps the running throughput is printed in ns/day. At the end
    the total wall time and the averaged du/dp of each potential are printed.

    Side effects: constructs a ``PDBWriter`` on ``dhfr.pdb`` (frame writing is
    currently commented out) and prints to stdout. Takes no arguments and
    returns nothing.
    """

    pdb_path = 'tests/data/5dfr_solv_equil.pdb'
    host_pdb = app.PDBFile(pdb_path)
    protein_ff = app.ForceField('amber99sbildn.xml', 'tip3p.xml')
    host_system = protein_ff.createSystem(
        host_pdb.topology,
        nonbondedMethod=app.NoCutoff,
        constraints=None,
        rigidWater=False
    )
    host_coords = host_pdb.positions
    box = host_pdb.topology.getPeriodicBoxVectors()
    # divide out the unit so the box becomes a plain numeric array
    box = np.asarray(box/box.unit)

    host_fns, host_masses = openmm_deserializer.deserialize_system(
        host_system,
        cutoff=1.0
    )

    # convert every coordinate component to MD units
    host_conf = np.array([
        [to_md_units(x), to_md_units(y), to_md_units(z)]
        for x, y, z in host_coords
    ])

    seed = 1234
    dt = 1.5e-3  # treated as picoseconds by the ns/day conversion below

    # NOTE(review): 300 and 1.0 are presumably temperature and friction --
    # confirm against the LangevinIntegrator signature.
    intg = LangevinIntegrator(
        300,
        dt,
        1.0,
        np.array(host_masses),
        seed
    ).impl()

    # get the bound implementation of every potential, in single precision
    bps = [potential.bound_impl(precision=np.float32) for potential in host_fns]

    x0 = host_conf
    v0 = np.zeros_like(host_conf)

    ctxt = custom_ops.Context(
        x0,
        v0,
        box,
        intg,
        bps
    )

    # initialize observables: one du/dp averager per bound potential
    # NOTE(review): the second argument (100) is presumably the sampling
    # interval in steps -- confirm.
    obs = []
    for bp in bps:
        du_dp_obs = custom_ops.AvgPartialUPartialParam(bp, 100)
        ctxt.add_observable(du_dp_obs)
        obs.append(du_dp_obs)

    lamb = 0.0
    num_steps = 50000
    seconds_per_day = 86400

    # perf_counter is monotonic and high resolution, so the elapsed time
    # after the first step cannot be zero or negative.
    start = time.perf_counter()

    writer = PDBWriter([host_pdb.topology], "dhfr.pdb")

    try:
        for step in range(num_steps):
            ctxt.step(lamb)
            if step % 1000 == 0:

                delta = time.perf_counter() - start
                # step is zero-based and ctxt.step() has already run for it,
                # so step + 1 steps have completed at this point.
                steps_per_second = (step + 1)/delta
                steps_per_day = steps_per_second*seconds_per_day
                ps_per_day = dt*steps_per_day
                ns_per_day = ps_per_day*1e-3

                print(step, "ns/day", ns_per_day)
                # coords = recenter(ctxt.get_x_t(), box)
                # writer.write_frame(coords*10)

        print("total time", time.perf_counter() - start)
    finally:
        # close the writer even if stepping raised
        writer.close()

    # bond angle torsions nonbonded
    for potential, du_dp_obs in zip(host_fns, obs):
        dp = du_dp_obs.avg_du_dp()
        print(potential, dp.shape)
        print(dp)
Beispiel #2
0
    # note: these 5000 steps are "equilibration", before we attach a reporter /
    #   "observable" to the context and start running "production"
    for step in range(5000):
        if step % 500 == 0:
            writer.write_frame(ctxt.get_x_t() * 10)
        ctxt.step(final_lamb)

    # print("equilibrium energy", ctxt._get_u_t_minus_1())

    # TODO: what was the second argument -- reporting interval in steps?
    du_dl_obs = custom_ops.AvgPartialUPartialLambda(u_impls, 5)

    du_dps = []
    for ui in u_impls:
        du_dp_obs = custom_ops.AvgPartialUPartialParam(ui, 5)
        ctxt.add_observable(du_dp_obs)
        du_dps.append(du_dp_obs)

    ctxt.add_observable(du_dl_obs)

    for _ in range(20000):
        if step % 500 == 0:
            writer.write_frame(ctxt.get_x_t() * 10)
        ctxt.step(final_lamb)

    writer.close()

    # print("final energy", ctxt._get_u_t_minus_1())

    # print vector jacobian products back into the forcefield derivative
Beispiel #3
0
    def test_fwd_mode(self):
        """
        Compare a custom_ops.Context trajectory and its observables against
        a reference integration driven by jax.grad of the reference energy.

        The context is stepped num_steps times at one fixed lambda. After
        each step the test asserts that the context's du/dx and positions
        match the reference. Afterwards it asserts that the du/dlambda and
        du/dparam observables (constructed with 1 and with 2 as their second
        argument) match means taken over the reference trajectory.
        """

        # Fixed seed. Every np.random call below draws from this one global
        # stream, so adding, removing or reordering draws (even unused ones)
        # changes all later inputs.
        np.random.seed(4321)

        # N and D give the shape of x0; B, A and T are not read again in
        # this test.
        N = 8
        B = 5
        A = 0
        T = 0
        D = 3

        x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

        E = 2

        # one index in {0, 1} per row of x0 for each of the two arrays
        lambda_plane_idxs = np.random.randint(low=0,
                                              high=2,
                                              size=N,
                                              dtype=np.int32)
        lambda_offset_idxs = np.random.randint(low=0,
                                               high=2,
                                               size=N,
                                               dtype=np.int32)

        # returns the parameters, a reference energy function (differentiated
        # with jax.grad below) and the potential under test
        params, ref_nrg_fn, test_nrg = prepare_nb_system(
            x0,
            E,
            lambda_plane_idxs,
            lambda_offset_idxs,
            p_scale=3.0,
            # cutoff=0.5,
            cutoff=1.5)

        # masses is only used for its length below
        masses = np.random.rand(N)

        v0 = np.random.rand(x0.shape[0], x0.shape[1])
        # same value as before: masses was created with N entries
        N = len(masses)

        num_steps = 5
        # lambda_schedule is not read again, but the draw still advances
        # the RNG stream that ca, cbs and lamb are taken from
        lambda_schedule = np.random.rand(num_steps)
        ca = np.random.rand()
        cbs = -np.random.rand(len(masses)) / 1
        # NOTE(review): ccs is presumably the noise coefficient; the reference
        # update below has no noise term -- confirm against LangevinIntegrator
        ccs = np.zeros_like(cbs)

        dt = 2e-3
        # a single lambda value is used for every step
        lamb = np.random.rand()

        # loss_fn and sum_loss_fn are defined here but not called anywhere
        # in this test
        def loss_fn(du_dls):
            return jnp.sum(du_dls * du_dls) / du_dls.shape[0]

        def sum_loss_fn(du_dls):
            du_dls = np.sum(du_dls, axis=0)
            return jnp.sum(du_dls * du_dls) / du_dls.shape[0]

        def integrate_once_through(x_t, v_t, box, params):
            # Reference trajectory: num_steps updates using gradients of
            # ref_nrg_fn, recording per-step quantities in lists.

            # gradients w.r.t. positions (arg 0), params (arg 1), lambda (arg 3)
            dU_dx_fn = jax.grad(ref_nrg_fn, argnums=(0, ))
            dU_dp_fn = jax.grad(ref_nrg_fn, argnums=(1, ))
            dU_dl_fn = jax.grad(ref_nrg_fn, argnums=(3, ))

            all_du_dls = []
            all_du_dps = []
            all_xs = []
            all_du_dxs = []
            all_us = []
            for step in range(num_steps):
                # energy and derivatives are evaluated at the positions
                # before this step's update
                u = ref_nrg_fn(x_t, params, box, lamb)
                all_us.append(u)
                du_dl = dU_dl_fn(x_t, params, box, lamb)[0]
                all_du_dls.append(du_dl)
                du_dp = dU_dp_fn(x_t, params, box, lamb)[0]
                all_du_dps.append(du_dp)
                du_dx = dU_dx_fn(x_t, params, box, lamb)[0]
                all_du_dxs.append(du_dx)
                # velocity update first, then positions with the new velocity
                v_t = ca * v_t + np.expand_dims(cbs, axis=-1) * du_dx
                x_t = x_t + v_t * dt
                # positions are recorded after the update
                all_xs.append(x_t)
                # note that we do not calculate the du_dl of the last frame.

            return all_xs, all_du_dxs, all_du_dps, all_du_dls, all_us

        box = np.eye(3) * 1.5

        # when we have multiple parameters, we need to set this up correctly
        ref_all_xs, ref_all_du_dxs, ref_all_du_dps, ref_all_du_dls, ref_all_us = integrate_once_through(
            x0, v0, box, params)

        intg = custom_ops.LangevinIntegrator(dt, ca, cbs, ccs, 1234)

        # double-precision bound implementation of the potential under test
        bp = test_nrg.bind(params).bound_impl(precision=np.float64)
        bps = [bp]

        ctxt = custom_ops.Context(x0, v0, box, intg, bps)

        # NOTE(review): the second constructor argument (1 or 2) is presumably
        # the sampling interval in steps; the reference means below use
        # [::2] slices for the observables built with 2 -- confirm
        test_obs = custom_ops.AvgPartialUPartialParam(bp, 1)
        test_obs_f2 = custom_ops.AvgPartialUPartialParam(bp, 2)

        test_obs_du_dl = custom_ops.AvgPartialUPartialLambda(bps, 1)
        test_obs_f2_du_dl = custom_ops.AvgPartialUPartialLambda(bps, 2)
        test_obs_f3_du_dl = custom_ops.FullPartialUPartialLambda(bps, 2)

        obs = [
            test_obs, test_obs_f2, test_obs_du_dl, test_obs_f2_du_dl,
            test_obs_f3_du_dl
        ]

        for o in obs:
            ctxt.add_observable(o)

        for step in range(num_steps):
            print("comparing step", step)
            ctxt.step(lamb)
            test_x_t = ctxt.get_x_t()
            # test_v_t is fetched but not asserted on
            test_v_t = ctxt.get_v_t()
            test_du_dx_t = ctxt._get_du_dx_t_minus_1()
            # test_u_t = ctxt._get_u_t_minus_1()
            # np.testing.assert_allclose(test_u_t, ref_all_us[step])
            np.testing.assert_allclose(test_du_dx_t, ref_all_du_dxs[step])
            np.testing.assert_allclose(test_x_t, ref_all_xs[step])

        # reference means over every step and over every second step
        ref_avg_du_dls = np.mean(ref_all_du_dls, axis=0)
        ref_avg_du_dls_f2 = np.mean(ref_all_du_dls[::2], axis=0)

        np.testing.assert_allclose(test_obs_du_dl.avg_du_dl(), ref_avg_du_dls)
        np.testing.assert_allclose(test_obs_f2_du_dl.avg_du_dl(),
                                   ref_avg_du_dls_f2)

        # with num_steps = 5 this expects ceil(5 / 2) = 3 stored values
        full_du_dls = test_obs_f3_du_dl.full_du_dl()
        assert len(full_du_dls) == np.ceil(num_steps / 2)
        np.testing.assert_allclose(np.mean(full_du_dls), ref_avg_du_dls_f2)

        # ref_avg_du_dps_f2 is computed but not compared against test_obs_f2
        ref_avg_du_dps = np.mean(ref_all_du_dps, axis=0)
        ref_avg_du_dps_f2 = np.mean(ref_all_du_dps[::2], axis=0)

        # the fixed point accumulator makes it hard to converge some of these
        # if the derivative is super small - in which case they probably don't matter
        # anyways
        # the third positional argument of assert_allclose is rtol; each
        # parameter column is checked with its own tolerance
        np.testing.assert_allclose(test_obs.avg_du_dp()[:, 0],
                                   ref_avg_du_dps[:, 0], 1.5e-6)
        np.testing.assert_allclose(test_obs.avg_du_dp()[:, 1],
                                   ref_avg_du_dps[:, 1], 1.5e-6)
        np.testing.assert_allclose(test_obs.avg_du_dp()[:, 2],
                                   ref_avg_du_dps[:, 2], 5e-5)
Beispiel #4
0
    def Simulate(self, request, context):
        """
        Run the simulation carried by ``request`` and return pickled results.

        Fields read from ``request``: ``precision`` ('single' or 'double'),
        ``simulation`` (pickled object exposing ``potentials``,
        ``integrator``, ``x``, ``v`` and ``box``), ``lamb``, ``prep_steps``,
        ``prod_steps``, ``observe_du_dl_freq``, ``observe_du_dp_freq`` and
        ``n_frames``. ``context`` is not used.

        Returns a ``service_pb2.SimulateReply`` whose fields are pickled:
        ``avg_du_dls`` and ``avg_du_dps`` (None when the matching frequency
        is not positive), ``energies`` (sampled every 100 production steps)
        and ``frames`` (an array of sampled positions).

        Raises:
            Exception: if ``request.precision`` is neither 'single' nor
                'double'.
        """

        if request.precision == 'single':
            precision = np.float32
        elif request.precision == 'double':
            precision = np.float64
        else:
            raise Exception("Unknown precision")

        # NOTE(review): `precision` is validated above but never passed on;
        # bound_impl() below is called without it. Confirm whether the
        # requested precision should be forwarded.

        # SECURITY: pickle.loads can execute arbitrary code from its input.
        # This is only acceptable if every client of this service is trusted.
        simulation = pickle.loads(request.simulation)

        # get the bound implementation of every potential
        bps = [potential.bound_impl() for potential in simulation.potentials]

        intg = simulation.integrator.impl()

        ctxt = custom_ops.Context(simulation.x, simulation.v, simulation.box,
                                  intg, bps)

        lamb = request.lamb

        # preparation: move lambda linearly from 1.0 to the target value
        for minimize_lamb in np.linspace(1.0, lamb, request.prep_steps):
            ctxt.step(minimize_lamb)

        energies = []
        frames = []

        # observables are attached only after the preparation steps
        if request.observe_du_dl_freq > 0:
            du_dl_obs = custom_ops.AvgPartialUPartialLambda(
                bps, request.observe_du_dl_freq)
            ctxt.add_observable(du_dl_obs)

        if request.observe_du_dp_freq > 0:
            du_dps = []
            # for name, bp in zip(names, bps):
            # if name == 'LennardJones' or name == 'Electrostatics':
            for bp in bps:
                du_dp_obs = custom_ops.AvgPartialUPartialParam(
                    bp, request.observe_du_dp_freq)
                ctxt.add_observable(du_dp_obs)
                du_dps.append(du_dp_obs)

        # The frame interval does not change during the run, so compute it
        # once. max(1, ...) keeps it valid when n_frames exceeds prod_steps.
        save_frames = request.n_frames > 0
        if save_frames:
            interval = max(1, request.prod_steps // request.n_frames)

        # dynamics
        for step in range(request.prod_steps):
            if step % 100 == 0:
                u = ctxt._get_u_t_minus_1()
                energies.append(u)

            if save_frames and step % interval == 0:
                frames.append(ctxt.get_x_t())

            ctxt.step(lamb)

        frames = np.array(frames)

        if request.observe_du_dl_freq > 0:
            avg_du_dls = du_dl_obs.avg_du_dl()
        else:
            avg_du_dls = None

        if request.observe_du_dp_freq > 0:
            avg_du_dps = [obs.avg_du_dp() for obs in du_dps]
        else:
            avg_du_dps = None

        return service_pb2.SimulateReply(
            avg_du_dls=pickle.dumps(avg_du_dls),
            avg_du_dps=pickle.dumps(avg_du_dps),
            energies=pickle.dumps(energies),
            frames=pickle.dumps(frames),
        )
Beispiel #5
0
    def _simulate(lamb, box, x0, v0, final_potentials, integrator, equil_steps, prod_steps):
        """
        Equilibrate and then run production at a fixed lambda, collecting
        du/dlambda statistics and averaged du/dparam gradients.

        final_potentials is a list of lists of potentials. Each potential is
        bound in single precision and classified as nonbonded or bonded;
        InterpolatedPotential / LambdaPotential wrappers are classified by
        the potential returned from get_u_fn().

        Returns ((bonded_mean, bonded_std), (nonbonded_mean, nonbonded_std),
        grads), where grads has the same nesting as final_potentials and
        holds each potential's avg_du_dp().
        """
        every_impl = []
        bonded = []
        nonbonded = []

        # du/dp observables, created now but attached after equilibration
        param_observers = []

        for group in final_potentials:
            group_observers = []

            for pot in group:
                bound = pot.bound_impl(np.float32)

                # look through wrappers to decide which bucket to use
                inner = pot
                if isinstance(pot, (potentials.InterpolatedPotential, potentials.LambdaPotential)):
                    inner = pot.get_u_fn()

                bucket = nonbonded if isinstance(inner, potentials.Nonbonded) else bonded
                bucket.append(bound)

                every_impl.append(bound)
                group_observers.append(custom_ops.AvgPartialUPartialParam(bound, 5))

            param_observers.append(group_observers)

        # context components: positions, velocities, box, integrator, energy fxns
        ctxt = custom_ops.Context(x0, v0, box, integrator.impl(), every_impl)

        # equilibration: no observables attached yet
        for _ in range(equil_steps):
            ctxt.step(lamb)

        bonded_du_dl = custom_ops.FullPartialUPartialLambda(bonded, 5)
        nonbonded_du_dl = custom_ops.FullPartialUPartialLambda(nonbonded, 5)

        # attach the du/dl observables first, then the du/dp ones
        ctxt.add_observable(bonded_du_dl)
        ctxt.add_observable(nonbonded_du_dl)
        for group_observers in param_observers:
            for observer in group_observers:
                ctxt.add_observable(observer)

        # production
        for _ in range(prod_steps):
            ctxt.step(lamb)

        bonded_trace = bonded_du_dl.full_du_dl()
        nonbonded_trace = nonbonded_du_dl.full_du_dl()

        bonded_stats = (np.mean(bonded_trace), np.std(bonded_trace))
        nonbonded_stats = (np.mean(nonbonded_trace), np.std(nonbonded_trace))

        # keep the structure of grads the same as that of final_potentials
        # so the vjps can be formed properly.
        grads = [
            [observer.avg_du_dp() for observer in group_observers]
            for group_observers in param_observers
        ]

        return bonded_stats, nonbonded_stats, grads