Esempio n. 1
0
    def test_overfit_host_acd(self):
        """Minimize the host_acd system, then run Langevin dynamics on it.

        The system is deserialized from ``examples/host_acd.xml``. A single
        ``LangevinOptimizer_f64`` drives both phases:

        1. 10000 steps with minimization coefficients (``ca = 0.5``, the
           dynamics ``cb``, and zero noise coefficients);
        2. 10000 steps after switching the optimizer to the Langevin
           coefficients for 100 K and friction 75.

        The energy is printed every 100 steps of each phase; nothing is
        asserted.
        """
        raw_potentials, coords, (params, param_groups), masses = serialize.deserialize_system('examples/host_acd.xml')

        potentials = [p(*args) for p, args in raw_potentials]

        dt = 0.001
        ca, cb, cc = langevin_coefficients(
            temperature=100.0,
            dt=dt,
            friction=75,
            masses=masses
        )

        # Minimization coefficients: velocity coefficient 0.5, the dynamics
        # force coefficient, and no noise.
        m_dt, m_ca, m_cb, m_cc = dt, 0.5, cb, np.zeros_like(masses)

        opt = custom_ops.LangevinOptimizer_f64(m_dt, m_ca, m_cb, m_cc)

        # Indices of the parameters in group 7 (the original note says this
        # tests "getting charges"); passed to the context as dp_idxs.
        dp_idxs = np.argwhere(param_groups == 7).reshape(-1)

        ctxt = custom_ops.Context_f64(
            potentials,
            opt,
            params,
            coords,  # x0
            np.zeros_like(coords),  # v0
            dp_idxs,
        )

        def run(num_steps):
            # Advance the context, printing the energy every 100 steps.
            for i in range(num_steps):
                ctxt.step()
                if i % 100 == 0:
                    print(i, ctxt.get_E())

        # Phase 1: minimize the system.
        run(10000)

        # Phase 2: switch the optimizer over to Langevin dynamics.
        opt.set_dt(dt)
        opt.set_coeff_a(ca)
        opt.set_coeff_b(cb)
        opt.set_coeff_c(cc)

        run(10000)
Esempio n. 2
0
    def test_set_and_get(self):
        """
        Check that a Context round-trips the state it was built with.

        The context must hand back exactly the initial coordinates,
        velocities and box, and ``set_x_t`` must replace the coordinates.
        """
        np.random.seed(4321)

        num_atoms, dim = 8, 3
        x0 = np.random.rand(num_atoms, dim).astype(dtype=np.float64) * 2

        E = 2

        def random_flags():
            # One random 0/1 int32 flag per atom.
            return np.random.randint(low=0, high=2, size=num_atoms, dtype=np.int32)

        lambda_plane_idxs = random_flags()
        lambda_offset_idxs = random_flags()

        params, _, test_nrg = prepare_nb_system(
            x0, E, lambda_plane_idxs, lambda_offset_idxs, p_scale=3.0, cutoff=1.0,
        )

        masses = np.random.rand(num_atoms)
        v0 = np.random.rand(*x0.shape)

        temperature, dt, friction = 300, 2e-3, 0.0
        ca, cbs, ccs = langevin_coefficients(temperature, dt, friction, masses)

        box = 3.0 * np.eye(3)
        intg = custom_ops.LangevinIntegrator(dt, ca, cbs, ccs, 1234)

        bound = test_nrg.bind(params).bound_impl(precision=np.float64)
        ctxt = custom_ops.Context(x0, v0, box, intg, [bound])

        # Initial state must be reported back unchanged.
        for getter, expected in ((ctxt.get_x_t, x0), (ctxt.get_v_t, v0), (ctxt.get_box, box)):
            np.testing.assert_equal(getter(), expected)

        # Overwriting the coordinates must be visible through the getter.
        new_x = np.random.rand(num_atoms, 3)
        ctxt.set_x_t(new_x)
        np.testing.assert_equal(ctxt.get_x_t(), new_x)
Esempio n. 3
0
    def __init__(self, temperature, dt, friction, masses, seed):
        """Store Langevin integrator settings and precomputed coefficients.

        The three coefficients come from ``langevin_coefficients``; the
        second one is stored with its sign flipped in ``self.cbs``.
        """
        self.dt = dt
        self.seed = seed

        coeff_a, coeff_b, coeff_c = langevin_coefficients(temperature, dt, friction, masses)
        self.ca, self.cbs, self.ccs = coeff_a, -1 * coeff_b, coeff_c
Esempio n. 4
0
    def __init__(self, steps, dt, temperature, friction, masses, lamb, seed, equilibrium_steps=2000):
        """Build per-step integrator schedules for a fixed-lambda run.

        The timestep is ramped linearly from 0 to ``dt`` over the first
        ``equilibrium_steps`` steps and then held at ``dt``. Every schedule
        (``dts``, ``cas``, ``lambs``) has length ``steps``.

        Args:
            steps: total number of integration steps.
            dt: target timestep reached at the end of the ramp.
            temperature, friction, masses: forwarded to
                ``langevin_coefficients``.
            lamb: lambda value, held constant for all steps.
            seed: random seed, stored as-is.
            equilibrium_steps: length of the dt ramp (default 2000). It is
                clamped to ``steps`` so that short runs do not request a
                negative-length array.
        """
        # np.ones(steps - equilibrium_steps) raises "negative dimensions are
        # not allowed" when steps < equilibrium_steps; clamp the ramp instead.
        equilibrium_steps = min(equilibrium_steps, steps)

        ca, cbs, ccs = langevin_coefficients(temperature, dt, friction, masses)

        self.dts = np.concatenate([
            np.linspace(0, dt, equilibrium_steps),
            np.ones(steps - equilibrium_steps) * dt
        ])
        self.cas = np.ones(steps) * ca
        self.cbs = -cbs  # stored negated
        self.ccs = ccs
        self.lambs = np.zeros(steps) + lamb
        self.seed = seed
Esempio n. 5
0
def setup_system():
    """Debug harness: run a 2-D Lennard-Jones Langevin trajectory and report a pressure loss.

    Builds a randomly jittered 5x5 grid of particles in the unit square,
    integrates it for ``num_steps`` steps with ``integrate_once_through``
    (saving a scatter plot to ``barostat_frames/`` every 100 steps) and
    prints the loss ``|expected_pressure - mean(p_int)|``.

    NOTE(review): the unconditional ``assert 0`` after the first run always
    raises AssertionError, so the forward-mode (``jax.jvp``) epoch loop
    below it is unreachable. This looks like a deliberate debugging stop;
    confirm before removing it.
    """

    xs = np.linspace(0, 1.0, 5, endpoint=False)
    ys = np.linspace(0, 1.0, 5, endpoint=False)

    # 25 grid points (x, y), each perturbed by up to 0.05 in each coordinate.
    conf = np.transpose([np.tile(xs, len(ys)), np.repeat(ys, len(xs))])
    conf += np.random.rand(*conf.shape)/20

    # NOTE(review): D and p_ext (below) are not used in this function.
    D = conf.shape[-1]

    sigma = 0.1/1.122
    eps = 1.0

    # lj_params = np.ones_like(conf)
    # lj_params[:, 0] = sigma
    lj_params = np.array([sigma, eps])
    masses = np.ones(conf.shape[0])*1

    dt = 1.5e-3

    ca, cb, cc = langevin_coefficients(
        temperature=300.0,
        dt=dt,
        friction=1.0,
        masses=masses
    )
    # cb is negated so the update below can add cb * gradient directly; the
    # trailing axis lets both coefficients broadcast over coordinates.
    cb = -np.expand_dims(cb, axis=-1)
    cc = np.expand_dims(cc, axis=-1)
    
    # minimization
    # ca = np.zeros_like(ca)
    # cc = np.zeros_like(cc)

    # print(ca, cb, cc)
    num_steps = 2000
    volume = 1.0 # or area in our case
    # box_length = np.sqrt(volume)

    p_ext = -25.0

    # Gradient of the energy w.r.t. coordinates (arg 0) and the third
    # argument, which receives vol_xt below.
    grad_fn = jax.grad(lennard_jones, argnums=(0,2))
    grad_fn = jax.jit(grad_fn)
    nrg_fn = jax.jit(lennard_jones)

    def integrate_once_through(
        x_t,
        v_t,
        vol_xt,
        vol_vt,
        lj_params,
        xt_noise_buf,
        vol_noise_buf):
        # Integrate num_steps Langevin steps from (x_t, v_t) using the
        # pre-drawn noise in xt_noise_buf, and return the pressure loss.
        # The volume update is commented out, so vol_vt and vol_noise_buf
        # are unused and vol_xt stays constant.

        p_ints = []

        # coords = []
        # volumes = []

        print("initial coords", x_t)

        for step in range(num_steps):

            # NOTE(review): uses the outer constant `volume`, not vol_xt.
            box_length = np.sqrt(volume)

            x_t = recenter(x_t, box_length)
            x_t = recenter_to_first_atom(x_t)

            # `force` is jax.grad of the energy, i.e. dU/dx; the negated cb
            # supplies the sign. p_int is dU/d(vol_xt) as returned by grad_fn.
            # NOTE(review): confirm the intended pressure sign convention.
            force, p_int = grad_fn(x_t, lj_params, vol_xt)
            # p_ints.append(p_int)

            if step % 1000 == 0:
                e = nrg_fn(x_t, lj_params, vol_xt)
                print("step", step, "vol_xt", vol_xt, "u", e, "p_int", p_int)

            # Record the pressure sample every 50 steps.
            if step % 50 == 0:
                p_ints.append(p_int)

            # Save a scatter-plot frame every 100 steps.
            if step % 100 == 0:
                e = nrg_fn(x_t, lj_params, vol_xt)
                # plt.xlim(0, box_length)
                # plt.ylim(0, box_length)
                plt.scatter(x_t[:, 0], x_t[:, 1])
                plt.savefig('barostat_frames/'+str(step))
                plt.clf()

            # vol_noise = vol_noise_buf[step]
            # vol_vt = 0.5*vol_vt - 0.01*(p_int - p_ext) + vol_noise
            # vol_xt = vol_xt + vol_vt*1.5e-3

            noise = xt_noise_buf[step]
            v_t = ca*v_t + cb*force + cc*noise
            x_t = x_t + v_t*dt

        print("final coords", x_t)

        # volumes = jnp.array(volumes)
        # expected_volume = 1.15
        # print("expected", expected_volume, "observed", jnp.mean(volumes))

        p_ints = jnp.array(p_ints)
        expected_pressure = -100.0
        computed_pressure = jnp.mean(p_ints)

        print("EP", expected_pressure, "CP", computed_pressure)

        loss = jnp.abs(expected_pressure - computed_pressure)

        return loss

    x0 = np.copy(conf)
    v0 = np.zeros_like(x0)

    vol_xt = volume
    vol_vt = np.zeros_like(vol_xt)

    xt_noise_buffer = np.random.randn(num_steps, *conf.shape)
    vol_noise_buffer = np.random.randn(num_steps)

    # NOTE(review): integrate_once_through returns the loss, not the final
    # coordinates; the name x_final is misleading.
    x_final = integrate_once_through(
        x0,
        v0,
        vol_xt,
        vol_vt,
        lj_params,
        xt_noise_buffer,
        vol_noise_buffer
    )

    # Debug stop: always raises, so everything below is unreachable.
    assert 0

    # NOTE(review): lj_params is never updated in this loop, so every epoch
    # re-evaluates the same parameters with fresh noise.
    for epoch in range(100):

        print(epoch, lj_params)


        xt_noise_buffer = np.random.randn(num_steps, *conf.shape)
        vol_noise_buffer = np.random.randn(num_steps)


        primals = (
            x0,
            v0, 
            vol_xt,
            vol_vt,
            lj_params,
            xt_noise_buffer,
            vol_noise_buffer
        )



        # Tangent 1.0 on lj_params[0] (sigma) and 0 elsewhere: the jvp output
        # is the forward-mode derivative of the loss w.r.t. sigma.
        tangents = (
            np.zeros_like(x0),
            np.zeros_like(v0),
            np.zeros_like(vol_xt),
            np.zeros_like(vol_vt),
            # np.zeros_like(lj_params),
            np.array([1.0, 0.0]),
            np.zeros_like(xt_noise_buffer),
            np.zeros_like(vol_noise_buffer)
        )

        x_primals_out, x_tangents_out = jax.jvp(integrate_once_through, primals, tangents)
        
        # Clip the derivative to [-0.01, 0.01].
        sig_grad = np.clip(x_tangents_out, -0.01, 0.01)

        print("loss", x_primals_out, "raw_grad", x_tangents_out, "clip grad", sig_grad)
Esempio n. 6
0
def run_simulation(potentials,
                   params,
                   param_groups,
                   conf,
                   masses,
                   dp_idxs,
                   n_samples=200,
                   n_steps=1000):
    """Minimize a system with the float32 custom-op optimizer and read back derivatives.

    The potentials are merged and a ``Context_f32`` is minimized with a
    ``LangevinOptimizer_f32`` (velocity coefficient 0.5, no noise). The
    minimization is declared converged once the standard deviation of the
    last ``window_size`` energies falls below ``std_threshold``.

    Only minimization is performed; ``param_groups``, ``n_samples`` and
    ``n_steps`` are accepted for interface compatibility but are not used.

    Returns:
        A one-item list ``[[E, dE_dx, dx_dp, dE_dp, 0]]`` read from the
        context after minimization.

    Raises:
        Exception: if minimization does not converge within ``max_iter`` steps.
    """
    from collections import deque

    potentials = forcefield.merge_potentials(potentials)

    dt = 0.0005
    # Only the force coefficient is needed for minimization.
    _, cb, _ = langevin_coefficients(temperature=25.0,
                                     dt=dt,
                                     friction=50,
                                     masses=masses)

    m_dt, m_ca, m_cb, m_cc = dt, 0.5, cb, np.zeros_like(masses)

    opt = custom_ops.LangevinOptimizer_f32(m_dt, m_ca, m_cb.astype(np.float32),
                                           m_cc.astype(np.float32))

    v0 = np.zeros_like(conf)
    dp_idxs = dp_idxs.astype(np.int32)

    ctxt = custom_ops.Context_f32(
        potentials,
        opt,
        params.astype(np.float32),
        conf.astype(np.float32),  # x0
        v0.astype(np.float32),  # v0
        dp_idxs)

    # Convergence: the std of the last `window_size` energies drops below
    # std_threshold. NOTE(review): the original comment described this as
    # ".25 kcal"; confirm the intended threshold and units.
    max_iter = 25000
    window_size = 150
    std_threshold = 1.046 / 2
    recent_energies = deque(maxlen=window_size)
    for i in range(max_iter):
        ctxt.step()
        E = ctxt.get_E()
        recent_energies.append(E)
        # i >= window_size means more than window_size energies have been
        # produced, so the window is full and past its warm-up.
        if i >= window_size and np.std(recent_energies) < std_threshold:
            print("Minimization converged in", i, "steps to", E)
            break
    else:
        # Only reached when the loop never broke. The old `i == max_iter - 1`
        # test also fired when convergence happened on the last iteration.
        raise Exception(
            f"Energy minimization failed to converge in {max_iter} steps")

    return [[
        ctxt.get_E(),
        ctxt.get_dE_dx(),
        ctxt.get_dx_dp(),
        ctxt.get_dE_dp(),
        0,
    ]]
Esempio n. 7
0
    def __init__(self, U_fn, O_fn, temperature):
        """Set up a 1-D Langevin sampler for an observable and its parameter derivatives.

        Three unit-mass particles sit at x = 0, 0.5 and 1. The force and noise
        coefficients of the first and last particle are zeroed, so only the
        middle particle feels forces or noise.
        ``self.integrator(x_t, v_t, lj_params)`` runs ``num_steps`` Langevin
        steps on ``lennard_jones`` and returns the time averages of
        ``O``, ``dU/dp`` and ``O * dU/dp``. Frames are sampled every 10 steps
        after a 2000-step burn-in.

        Args:
            U_fn: energy function ``(x, p) -> R^1``. NOTE(review): stored but
                not used by the integrator, which differentiates
                ``lennard_jones`` directly.
            O_fn: observable, called as ``O_fn(x, p)``.
            temperature: used both for ``self.kT`` and for the thermostat.
        """
        self.kT = BOLTZ * temperature
        self.U_fn = U_fn
        self.O_fn = O_fn

        xs = np.linspace(0, 1.0, 3, endpoint=True)
        self.conf = np.expand_dims(xs, axis=1)

        masses = np.ones(self.conf.shape[0])
        dt = 1.5e-3

        # Thermostat at the same temperature that defines self.kT; this was
        # hard-coded to 300.0, which disagreed with kT for any other value.
        ca, cb, cc = langevin_coefficients(temperature=temperature, dt=dt, friction=1.0, masses=masses)
        cb = -np.expand_dims(cb, axis=-1)
        cc = np.expand_dims(cc, axis=-1)

        # Pin the end particles: no force and no noise.
        cb[0] = 0.0
        cb[-1] = 0.0
        cc[0] = 0.0
        cc[-1] = 0.0

        num_steps = 50000
        burn_in = 2000
        sample_every = 10

        grad_fn = jax.jit(jax.grad(lennard_jones, argnums=(0, 1)))

        def integrate_once_through(x_t, v_t, lj_params):
            # Run one trajectory; return time averages of O, dU/dp and O*dU/dp.
            Os = []
            dU_dps = []
            O_dot_dU_dps = []

            for step in range(num_steps):
                dU_dx, dU_dp = grad_fn(x_t, lj_params)

                if step % sample_every == 0 and step > burn_in:
                    obs = self.O_fn(x_t, lj_params)
                    Os.append(obs)
                    dU_dps.append(dU_dp)
                    O_dot_dU_dps.append(obs * dU_dp)

                noise = np.random.randn(*x_t.shape)
                v_t = ca * v_t + cb * dU_dx + cc * noise
                x_t = x_t + v_t * dt

            return (
                np.mean(np.asarray(Os), axis=0),
                np.mean(np.asarray(dU_dps), axis=0),
                np.mean(np.asarray(O_dot_dU_dps), axis=0),
            )

        self.integrator = integrate_once_through
Esempio n. 8
0
    def __init__(self, U_A_fn, U_B_fn, temperature):
        """Set up a two-state free-energy estimator driven by Langevin dynamics.

        A 5x5 grid of unit-mass particles in the unit square is simulated
        first under ``U_B`` and then under ``U_A``. Each trajectory starts
        from copies of ``(x_0, v_0)``.

        ``self.integrator(x_0, v_0, lj_params)`` returns:

        * the exponential-averaging (Zwanzig) estimate
          ``-kT * log <exp(-(U_B - U_A) / kT)>_A``, computed from the ``U_A``
          trajectory;
        * the time-averaged ``dU_B/dp`` minus the time-averaged ``dU_A/dp``.

        Frames are sampled every 10 steps after a 10000-step burn-in.

        Args:
            U_A_fn, U_B_fn: energy functions called as ``U(x, params, volume)``.
            temperature: used both for ``self.kT`` and for the thermostat.
        """
        self.kT = BOLTZ * temperature
        self.U_A_fn = jax.jit(U_A_fn)  # (x, p, volume) -> R^1
        self.U_B_fn = jax.jit(U_B_fn)  # (x, p, volume) -> R^1

        xs = np.linspace(0, 1.0, 5, endpoint=False)
        ys = np.linspace(0, 1.0, 5, endpoint=False)
        self.conf = np.transpose([np.tile(xs, len(ys)), np.repeat(ys, len(xs))])

        masses = np.ones(self.conf.shape[0])
        dt = 1.5e-3

        # The exponential average divides by self.kT, which is only valid if
        # samples are drawn at that same temperature. The thermostat was
        # hard-coded to 300.0.
        ca, cb, cc = langevin_coefficients(temperature=temperature,
                                           dt=dt,
                                           friction=1.0,
                                           masses=masses)
        cb = -np.expand_dims(cb, axis=-1)
        cc = np.expand_dims(cc, axis=-1)

        num_steps = 100000
        burn_in = 10000
        sample_every = 10

        dU_A_dx_fn = jax.jit(jax.grad(self.U_A_fn, argnums=(0, )))
        dU_B_dx_fn = jax.jit(jax.grad(self.U_B_fn, argnums=(0, )))

        dU_A_dp_fn = jax.jit(jax.grad(self.U_A_fn, argnums=(1, )))
        dU_B_dp_fn = jax.jit(jax.grad(self.U_B_fn, argnums=(1, )))

        def _simulate(x_0, v_0, lj_params, volume, dU_dx_fn, on_sample):
            # Run one Langevin trajectory from copies of (x_0, v_0) and call
            # on_sample(x_t) on every sampled frame (before that step's update).
            x_t = np.copy(x_0)
            v_t = np.copy(v_0)
            for step in range(num_steps):
                dU_dx = dU_dx_fn(x_t, lj_params, volume)[0]
                if step % sample_every == 0 and step > burn_in:
                    on_sample(x_t)
                noise = np.random.randn(*x_t.shape)
                v_t = ca * v_t + cb * dU_dx + cc * noise
                x_t = x_t + v_t * dt

        def integrate_once_through(x_0, v_0, lj_params):
            volume = 1.0

            # Target state B: collect dU_B/dp.
            dUB_dps = []

            def sample_B(x_t):
                dUB_dps.append(dU_B_dp_fn(x_t, lj_params, volume)[0])

            _simulate(x_0, v_0, lj_params, volume, dU_B_dx_fn, sample_B)

            # Reference state A: collect dU_A/dp and -(U_B - U_A) / kT.
            dUA_dps = []
            deltaUs = []

            def sample_A(x_t):
                dUA_dps.append(dU_A_dp_fn(x_t, lj_params, volume)[0])
                U_A = self.U_A_fn(x_t, lj_params, volume)
                U_B = self.U_B_fn(x_t, lj_params, volume)
                delta_U = U_B - U_A
                deltaUs.append(-delta_U / self.kT)

            _simulate(x_0, v_0, lj_params, volume, dU_A_dx_fn, sample_A)

            avg_dUA_dps = np.mean(dUA_dps, axis=0)
            avg_dUB_dps = np.mean(dUB_dps, axis=0)

            print(avg_dUA_dps)
            print(avg_dUB_dps)
            # -kT * log(mean(exp(deltaUs))), evaluated stably via logsumexp.
            delta_G = -self.kT * (logsumexp(deltaUs) - np.log(len(deltaUs)))

            return delta_G, avg_dUB_dps - avg_dUA_dps

        self.integrator = integrate_once_through
Esempio n. 9
0
    def test_fwd_mode(self):
        """
        Check that the custom-op Context reproduces a reference JAX integrator.

        A small nonbonded system is integrated for ``num_steps`` steps twice:
        once by a pure-Python/JAX reference (``integrate_once_through``) and
        once by ``custom_ops.Context``. Friction is 0, so the noise
        coefficients are zero (asserted below) and both trajectories are
        deterministic. The test compares:

        * per-step coordinates and du/dx from ``Context.step``;
        * ``multiple_steps``: strided du/dlambda values and coordinates, and a
          box that is unchanged and reported once per stored frame;
        * ``multiple_steps_U``: strided energies at each lambda window and
          strided coordinates, plus the output shape for an empty window set.
        """

        np.random.seed(4321)

        N = 8
        D = 3

        x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

        E = 2

        lambda_plane_idxs = np.random.randint(low=0,
                                              high=2,
                                              size=N,
                                              dtype=np.int32)
        lambda_offset_idxs = np.random.randint(low=0,
                                               high=2,
                                               size=N,
                                               dtype=np.int32)

        params, ref_nrg_fn, test_nrg = prepare_nb_system(
            x0,
            E,
            lambda_plane_idxs,
            lambda_offset_idxs,
            p_scale=3.0,
            cutoff=1.0,
        )

        masses = np.random.rand(N)

        v0 = np.random.rand(x0.shape[0], x0.shape[1])

        num_steps = 5
        temperature = 300
        dt = 2e-3
        friction = 0.0
        ca, cbs, ccs = langevin_coefficients(temperature, dt, friction, masses)

        # not convenient to simulate identical trajectories otherwise
        assert (ccs == 0).all()

        lamb = np.random.rand()
        lambda_windows = np.array([lamb + 0.05, lamb, lamb - 0.05])

        def integrate_once_through(x_t, v_t, box, params):
            # Reference integrator: record per-step state, then apply the
            # same velocity/position update as the custom-op integrator.

            dU_dx_fn = jax.grad(ref_nrg_fn, argnums=(0, ))
            dU_dp_fn = jax.grad(ref_nrg_fn, argnums=(1, ))
            dU_dl_fn = jax.grad(ref_nrg_fn, argnums=(3, ))

            all_du_dls = []
            all_du_dps = []
            all_xs = []
            all_du_dxs = []
            all_us = []
            all_lambda_us = []
            for step in range(num_steps):
                u = ref_nrg_fn(x_t, params, box, lamb)
                all_us.append(u)
                du_dl = dU_dl_fn(x_t, params, box, lamb)[0]
                all_du_dls.append(du_dl)
                du_dp = dU_dp_fn(x_t, params, box, lamb)[0]
                all_du_dps.append(du_dp)
                du_dx = dU_dx_fn(x_t, params, box, lamb)[0]
                all_du_dxs.append(du_dx)
                all_xs.append(x_t)

                lus = []
                for lamb_u in lambda_windows:
                    lus.append(ref_nrg_fn(x_t, params, box, lamb_u))

                all_lambda_us.append(lus)
                noise = np.random.randn(*v_t.shape)

                v_mid = v_t + np.expand_dims(cbs, axis=-1) * du_dx

                v_t = ca * v_mid + np.expand_dims(ccs, axis=-1) * noise
                # Rebind rather than `+=`: on the first step x_t is the
                # caller's x0, and an in-place update would mutate it (and
                # alias every entry of all_xs) whenever du_dx is a NumPy array.
                x_t = x_t + 0.5 * dt * (v_mid + v_t)

            # note that we do not calculate the du_dl of the last frame.
            return all_xs, all_du_dxs, all_du_dps, all_du_dls, all_us, all_lambda_us

        box = np.eye(3) * 3.0

        # when we have multiple parameters, we need to set this up correctly
        (
            ref_all_xs,
            ref_all_du_dxs,
            ref_all_du_dps,
            ref_all_du_dls,
            ref_all_us,
            ref_all_lambda_us,
        ) = integrate_once_through(x0, v0, box, params)

        intg = custom_ops.LangevinIntegrator(dt, ca, cbs, ccs, 1234)

        bp = test_nrg.bind(params).bound_impl(precision=np.float64)
        bps = [bp]

        ctxt = custom_ops.Context(x0, v0, box, intg, bps)

        for step in range(num_steps):
            print("comparing step", step)
            test_x_t = ctxt.get_x_t()
            np.testing.assert_allclose(test_x_t, ref_all_xs[step])
            ctxt.step(lamb)
            test_du_dx_t = ctxt._get_du_dx_t_minus_1()
            np.testing.assert_allclose(test_du_dx_t, ref_all_du_dxs[step])

        # test the multiple_steps method
        ctxt_2 = custom_ops.Context(x0, v0, box, intg, bps)

        lambda_schedule = np.ones(num_steps) * lamb

        du_dl_interval = 3
        x_interval = 2
        start_box = ctxt_2.get_box()
        test_du_dls, test_xs, test_boxes = ctxt_2.multiple_steps(
            lambda_schedule, du_dl_interval, x_interval)
        end_box = ctxt_2.get_box()

        np.testing.assert_allclose(test_du_dls,
                                   ref_all_du_dls[::du_dl_interval])

        np.testing.assert_allclose(test_xs, ref_all_xs[::x_interval])
        np.testing.assert_array_equal(start_box, end_box)
        for i in range(test_boxes.shape[0]):
            np.testing.assert_array_equal(start_box, test_boxes[i])
        self.assertEqual(test_boxes.shape[0], test_xs.shape[0])
        self.assertEqual(test_boxes.shape[1], D)
        self.assertEqual(test_boxes.shape[2], test_xs.shape[2])

        # test the multiple_steps_U method
        ctxt_3 = custom_ops.Context(x0, v0, box, intg, bps)

        u_interval = 3

        test_us, test_xs, test_boxes = ctxt_3.multiple_steps_U(
            lamb, num_steps, lambda_windows, u_interval, x_interval)

        np.testing.assert_array_almost_equal(ref_all_lambda_us[::u_interval],
                                             test_us)

        np.testing.assert_array_almost_equal(ref_all_xs[::x_interval], test_xs)

        test_us, test_xs, test_boxes = ctxt_3.multiple_steps_U(
            lamb, num_steps, np.array([], dtype=np.float64), u_interval,
            x_interval)

        # One row per stored frame (same stride as above), zero windows.
        assert test_us.shape == (len(ref_all_lambda_us[::u_interval]), 0)
Esempio n. 10
0
    def __init__(self, U_fn, O_fn, temperature):
        """Set up a 2-D Lennard-Jones Langevin sampler for an observable and its derivatives.

        A 5x5 grid of unit-mass particles in the unit square is integrated
        with ``lennard_jones`` at ``volume = 1.0``.
        ``self.integrator(x_t, v_t, lj_params)`` runs ``num_steps`` steps and
        returns the time averages of ``O``, ``dU/dp``, ``O * dU/dp`` and
        ``dO/dp``. Frames are sampled every 10 steps after a 10000-step
        burn-in.

        Args:
            U_fn: energy function ``(x, p) -> R^1``. NOTE(review): stored but
                not used by the integrator, which differentiates
                ``lennard_jones`` directly.
            O_fn: observable called as ``O_fn(x, p, volume)``. It must be
                scalar-valued, since it is differentiated with ``jax.grad``.
            temperature: used both for ``self.kT`` and for the thermostat.
        """
        self.kT = BOLTZ*temperature
        self.U_fn = U_fn
        self.O_fn = O_fn
        self.dO_dp_fn = jax.jit(jax.grad(O_fn, argnums=(1,)))

        xs = np.linspace(0, 1.0, 5, endpoint=False)
        ys = np.linspace(0, 1.0, 5, endpoint=False)
        self.conf = np.transpose([np.tile(xs, len(ys)), np.repeat(ys, len(xs))])

        masses = np.ones(self.conf.shape[0])
        dt = 1.5e-3

        # Thermostat at the same temperature that defines self.kT; this was
        # hard-coded to 300.0, which disagreed with kT for any other value.
        ca, cb, cc = langevin_coefficients(
            temperature=temperature,
            dt=dt,
            friction=1.0,
            masses=masses
        )
        cb = -np.expand_dims(cb, axis=-1)
        cc = np.expand_dims(cc, axis=-1)

        num_steps = 200000
        burn_in = 10000
        sample_every = 10

        grad_fn = jax.jit(jax.grad(lennard_jones, argnums=(0,1)))

        def integrate_once_through(
            x_t,
            v_t,
            lj_params):
            # Run one trajectory; return time averages of O, dU/dp, O*dU/dp
            # and dO/dp over the sampled frames.
            Os = []
            dU_dps = []
            O_dot_dU_dps = []
            dO_dps = []

            volume = 1.0

            for step in range(num_steps):
                dU_dx, dU_dp = grad_fn(x_t, lj_params, volume)

                if step % sample_every == 0 and step > burn_in:
                    obs = self.O_fn(x_t, lj_params, volume)
                    Os.append(obs)
                    dU_dps.append(dU_dp)
                    O_dot_dU_dps.append(obs * dU_dp)
                    dO_dps.append(self.dO_dp_fn(x_t, lj_params, volume)[0])

                noise = np.random.randn(*x_t.shape)
                v_t = ca*v_t + cb*dU_dx + cc*noise
                x_t = x_t + v_t*dt

            return (
                np.mean(np.asarray(Os), axis=0),
                np.mean(np.asarray(dU_dps), axis=0),
                np.mean(np.asarray(O_dot_dU_dps), axis=0),
                np.mean(dO_dps, axis=0),
            )

        self.integrator = integrate_once_through
Esempio n. 11
0
def convergence(args):
    """Run restrained Langevin dynamics on a ligand pair and return the mean du/dlambda.

    The first two molecules of ``tests/data/ligands_40.sdf`` are recentered
    and combined. The total energy is the sum of:

    * harmonic bond and angle terms from ``ff/params/smirnoff_1_1_0_ccc.py``;
    * a ``pmi_restraints_new`` restraint on the heavy atoms of the two ligands;
    * a ``shape.harmonic_overlap`` restraint scaled by lambda.

    The system is integrated for 100000 steps at 300 K. du/dlambda is
    recorded every 5 steps after a 10000-step burn-in.

    Args:
        args: tuple ``(epoch, lamb, lamb_idx)``. This single-argument form
            presumably suits a ``Pool.map``-style caller (see the fork
            re-seed below). Only ``lamb`` affects the result.

    Returns:
        The mean of the recorded du/dlambda values.
    """
    _epoch, lamb, _lamb_idx = args

    suppl = Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)
    ligands = [mol for mol in suppl]

    ligand_a = ligands[0]
    ligand_b = ligands[1]

    coords_a = recenter(get_conf(ligand_a, idx=0))
    coords_b = recenter(get_conf(ligand_b, idx=0))
    coords = np.concatenate([coords_a, coords_b])

    num_atoms_a = ligand_a.GetNumAtoms()

    # Heavy-atom indices into the combined molecule (ligand A first, then B).
    a_idxs = get_heavy_atom_idxs(ligand_a)
    b_idxs = get_heavy_atom_idxs(ligand_b) + num_atoms_a

    nrg_fns = []

    forcefield = 'ff/params/smirnoff_1_1_0_ccc.py'
    with open(forcefield, "r") as ff_file:
        ff_raw = ff_file.read()
    ff_handlers = deserialize_handlers(ff_raw)

    combined_mol = Chem.CombineMols(ligand_a, ligand_b)

    # Only bond and angle handlers contribute; torsion handlers are not used.
    for handler in ff_handlers:
        if isinstance(handler, handlers.HarmonicBondHandler):
            bond_idxs, (bond_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_bond,
                    params=bond_params,
                    box=None,
                    bond_idxs=bond_idxs
                )
            )
        elif isinstance(handler, handlers.HarmonicAngleHandler):
            angle_idxs, (angle_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_angle,
                    params=angle_params,
                    box=None,
                    angle_idxs=angle_idxs
                )
            )

    # Ligand A's masses are scaled by 1e4, presumably to keep it nearly
    # stationary while ligand B moves.
    masses_a = onp.array([a.GetMass() for a in ligand_a.GetAtoms()]) * 10000
    masses_b = onp.array([a.GetMass() for a in ligand_b.GetAtoms()])

    combined_masses = np.concatenate([masses_a, masses_b])

    pmi_restraint_fn = functools.partial(pmi_restraints_new,
        params=None,
        box=None,
        lamb=None,
        masses=combined_masses,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        angle_force=100.0,
        com_force=100.0
    )

    prefactor = 2.7 # unitless
    shape_lamb = (4*np.pi)/(3*prefactor) # unitless
    kappa = np.pi/(np.power(shape_lamb, 2/3)) # unitless
    # NOTE(review): the original comment reads "1 angstrom std, 95% coverage
    # by 2 angstroms", which does not obviously match 0.15; confirm units.
    sigma = 0.15
    alpha = kappa/(sigma*sigma)

    alphas = np.zeros(combined_mol.GetNumAtoms())+alpha
    weights = np.zeros(combined_mol.GetNumAtoms())+prefactor

    shape_restraint_fn = functools.partial(
        shape.harmonic_overlap,
        box=None,
        lamb=None,
        params=None,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        alphas=alphas,
        weights=weights,
        k=150.0
    )

    def restraint_fn(conf, lamb):
        # PMI restraint is always on; the shape restraint is scaled by lambda.
        return pmi_restraint_fn(conf) + lamb*shape_restraint_fn(conf)

    nrg_fns.append(restraint_fn)

    def nrg_fn(conf, lamb):
        # Total energy: sum of every term evaluated at this lambda.
        return np.sum([u(conf, lamb=lamb) for u in nrg_fns])

    grad_fn = jax.jit(jax.grad(nrg_fn, argnums=(0, 1)))
    du_dx_fn = jax.jit(jax.grad(nrg_fn, argnums=0))

    x_t = coords
    v_t = np.zeros_like(x_t)

    dt = 1.5e-3
    ca, cb, cc = langevin_coefficients(300.0, dt, 1.0, combined_masses)
    cb = -1*onp.expand_dims(cb, axis=-1)
    cc = onp.expand_dims(cc, axis=-1)

    du_dls = []

    # Re-seed NumPy's global RNG from OS entropy: forked workers would
    # otherwise inherit identical RNG state.
    onp.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))

    num_steps = 100000
    burn_in = 10000
    sample_every = 5
    for step in range(num_steps):
        if step % sample_every == 0 and step > burn_in:
            # Sampling step: also take du/dlambda.
            du_dx, du_dl = grad_fn(x_t, lamb)
            du_dls.append(du_dl)
        else:
            du_dx = du_dx_fn(x_t, lamb)

        v_t = ca*v_t + cb*du_dx + cc*onp.random.normal(size=x_t.shape)
        x_t = x_t + v_t*dt

    return np.mean(onp.mean(du_dls))