def test_overfit_host_acd(self):
    """Minimize the host-ACD example system, then continue with new coefficients.

    The system is read from ``examples/host_acd.xml``. A
    ``LangevinOptimizer_f64`` is first built with coefficient a = 0.5 and an
    all-zero coefficient c and stepped 10000 times. Its dt and coefficients are
    then replaced with the values returned by ``langevin_coefficients`` and the
    context is stepped another 10000 times. The energy is printed every 100
    steps; the test makes no assertions.
    """
    raw_potentials, coords, (params, param_groups), masses = serialize.deserialize_system(
        'examples/host_acd.xml')

    # Every entry is a (callable, argument-tuple) pair that builds one potential.
    potentials = [build(*build_args) for build, build_args in raw_potentials]

    num_atoms = coords.shape[0]  # not referenced below

    dt = 0.001
    ca, cb, cc = langevin_coefficients(
        temperature=100.0,
        dt=dt,
        friction=75,
        masses=masses,
    )

    # minimization coefficients
    m_dt = dt
    m_ca = 0.5
    m_cb = cb
    m_cc = np.zeros_like(masses)

    friction = 1.0  # not referenced below

    opt = custom_ops.LangevinOptimizer_f64(m_dt, m_ca, m_cb, m_cc)

    # test getting charges
    dp_idxs = np.argwhere(param_groups == 7).reshape(-1)

    ctxt = custom_ops.Context_f64(
        potentials,
        opt,
        params,
        coords,  # x0
        np.zeros_like(coords),  # v0
        # np.arange(len(params))
        dp_idxs,
    )

    def advance_and_report(n_steps):
        # Step the context, printing the energy on every 100th step.
        for step in range(n_steps):
            ctxt.step()
            if step % 100 == 0:
                print(step, ctxt.get_E())

    # minimize the system
    advance_and_report(10000)

    # Swap the optimizer over to the coefficients computed above.
    opt.set_dt(dt)
    opt.set_coeff_a(ca)
    opt.set_coeff_b(cb)
    opt.set_coeff_c(cc)

    # tbd reservoir sampler
    advance_and_report(10000)
def test_set_and_get(self):
    """Check the Context getters and the coordinate setter.

    A freshly built Context must report the coordinates, velocities and box it
    was constructed with, and ``get_x_t`` must return whatever was last passed
    to ``set_x_t``.
    """
    np.random.seed(4321)

    N = 8
    D = 3
    x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

    E = 2
    lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)
    lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)

    params, _, test_nrg = prepare_nb_system(
        x0,
        E,
        lambda_plane_idxs,
        lambda_offset_idxs,
        p_scale=3.0,
        cutoff=1.0,
    )

    masses = np.random.rand(N)
    v0 = np.random.rand(*x0.shape)

    temperature = 300
    dt = 2e-3
    friction = 0.0
    ca, cbs, ccs = langevin_coefficients(temperature, dt, friction, masses)

    box = 3.0 * np.eye(3)
    intg = custom_ops.LangevinIntegrator(dt, ca, cbs, ccs, 1234)
    bound_potentials = [test_nrg.bind(params).bound_impl(precision=np.float64)]
    ctxt = custom_ops.Context(x0, v0, box, intg, bound_potentials)

    # The initial state must round-trip unchanged.
    for observed, expected in (
        (ctxt.get_x_t(), x0),
        (ctxt.get_v_t(), v0),
        (ctxt.get_box(), box),
    ):
        np.testing.assert_equal(observed, expected)

    # Overwriting the coordinates must be visible through the getter.
    new_x = np.random.rand(N, D)
    ctxt.set_x_t(new_x)
    np.testing.assert_equal(ctxt.get_x_t(), new_x)
def __init__(self, temperature, dt, friction, masses, seed):
    """Store the timestep, seed and Langevin coefficients for this integrator.

    The coefficients come from ``langevin_coefficients(temperature, dt,
    friction, masses)``; the second one is stored with its sign flipped.
    """
    coeff_a, coeff_b, coeff_c = langevin_coefficients(temperature, dt, friction, masses)

    self.dt = dt
    self.seed = seed
    self.ca = coeff_a
    self.cbs = -coeff_b  # sign-flipped b coefficient
    self.ccs = coeff_c
def __init__(self, steps, dt, temperature, friction, masses, lamb, seed):
    """Build per-step schedules for a Langevin run with a timestep ramp-up.

    The timestep rises linearly from 0 to ``dt`` over the first 2000 steps and
    then stays at ``dt``. When ``steps`` is smaller than 2000 the schedule is
    the first ``steps`` entries of that ramp, so ``dts``, ``cas`` and ``lambs``
    always have exactly ``steps`` entries. (Previously such a call raised
    inside ``np.ones`` because of a negative length.)

    Attributes set:
        dts:   per-step timesteps, length ``steps``
        cas:   per-step a coefficient (constant), length ``steps``
        cbs:   sign-flipped b coefficient from ``langevin_coefficients``
        ccs:   c coefficient from ``langevin_coefficients``
        lambs: per-step lambda (constant ``lamb``), length ``steps``
        seed:  the seed passed in
    """
    equilibrium_steps = 2000

    ca, cbs, ccs = langevin_coefficients(temperature, dt, friction, masses)

    # Truncate the ramp for short runs and never request a negative-length plateau.
    ramp = np.linspace(0, dt, equilibrium_steps)[:steps]
    plateau = np.ones(max(steps - equilibrium_steps, 0)) * dt

    self.dts = np.concatenate([ramp, plateau])
    self.cas = np.ones(steps) * ca
    self.cbs = -cbs
    self.ccs = ccs
    self.lambs = np.zeros(steps) + lamb
    self.seed = seed
def setup_system():
    """Scratch driver: 2-D Lennard-Jones Langevin run that reports a pressure loss.

    Places 25 particles on a jittered 5x5 grid, integrates them for
    ``num_steps`` Langevin steps and prints the absolute difference between the
    mean sampled ``p_int`` and a hard-coded expected pressure. Scatter plots
    are written under ``barostat_frames/`` every 100 steps.

    NOTE: the ``assert 0`` after the first integration stops execution, so the
    forward-mode (``jax.jvp``) loop below it is currently never reached.
    """
    grid_x = np.linspace(0, 1.0, 5, endpoint=False)
    grid_y = np.linspace(0, 1.0, 5, endpoint=False)
    conf = np.transpose([np.tile(grid_x, len(grid_y)), np.repeat(grid_y, len(grid_x))])
    conf += np.random.rand(*conf.shape)/20  # small random jitter on every site

    D = conf.shape[-1]  # not referenced below

    sigma = 0.1/1.122
    eps = 1.0

    # lj_params = np.ones_like(conf)
    # lj_params[:, 0] = sigma
    lj_params = np.array([sigma, eps])

    masses = np.ones(conf.shape[0])*1
    dt = 1.5e-3

    ca, cb, cc = langevin_coefficients(
        temperature=300.0,
        dt=dt,
        friction=1.0,
        masses=masses
    )
    # Add a trailing axis so the per-particle coefficients broadcast over coordinates.
    cb = -np.expand_dims(cb, axis=-1)
    cc = np.expand_dims(cc, axis=-1)

    # minimization
    # ca = np.zeros_like(ca)
    # cc = np.zeros_like(cc)
    # print(ca, cb, cc)

    num_steps = 2000
    volume = 1.0  # or area in our case
    # box_length = np.sqrt(volume)
    p_ext = -25.0  # only referenced by the disabled volume update below

    # Gradient with respect to coordinates (arg 0) and volume (arg 2).
    grad_fn = jax.jit(jax.grad(lennard_jones, argnums=(0,2)))
    nrg_fn = jax.jit(lennard_jones)

    def integrate_once_through(
        x_t,
        v_t,
        vol_xt,
        vol_vt,
        lj_params,
        xt_noise_buf,
        vol_noise_buf):
        # Run the full trajectory and return |expected - mean sampled p_int|.
        p_ints = []
        # coords = []
        # volumes = []

        print("initial coords", x_t)

        for step in range(num_steps):
            box_length = np.sqrt(volume)
            x_t = recenter(x_t, box_length)
            x_t = recenter_to_first_atom(x_t)

            force, p_int = grad_fn(x_t, lj_params, vol_xt)
            # p_ints.append(p_int)

            if step % 1000 == 0:
                e = nrg_fn(x_t, lj_params, vol_xt)
                print("step", step, "vol_xt", vol_xt, "u", e, "p_int", p_int)

            if step % 50 == 0:
                p_ints.append(p_int)

            if step % 100 == 0:
                e = nrg_fn(x_t, lj_params, vol_xt)
                # plt.xlim(0, box_length)
                # plt.ylim(0, box_length)
                plt.scatter(x_t[:, 0], x_t[:, 1])
                plt.savefig('barostat_frames/'+str(step))
                plt.clf()

            # vol_noise = vol_noise_buf[step]
            # vol_vt = 0.5*vol_vt - 0.01*(p_int - p_ext) + vol_noise
            # vol_xt = vol_xt + vol_vt*1.5e-3

            noise = xt_noise_buf[step]
            v_t = ca*v_t + cb*force + cc*noise
            x_t = x_t + v_t*dt

        print("final coords", x_t)

        # volumes = jnp.array(volumes)
        # expected_volume = 1.15
        # print("expected", expected_volume, "observed", jnp.mean(volumes))

        p_ints = jnp.array(p_ints)
        expected_pressure = -100.0
        computed_pressure = jnp.mean(p_ints)

        print("EP", expected_pressure, "CP", computed_pressure)

        return jnp.abs(expected_pressure - computed_pressure)

    x0 = np.copy(conf)
    v0 = np.zeros_like(x0)
    vol_xt = volume
    vol_vt = np.zeros_like(vol_xt)

    xt_noise_buffer = np.random.randn(num_steps, *conf.shape)
    vol_noise_buffer = np.random.randn(num_steps)

    # Despite the historical name, this holds the scalar loss of the first run.
    x_final = integrate_once_through(
        x0,
        v0,
        vol_xt,
        vol_vt,
        lj_params,
        xt_noise_buffer,
        vol_noise_buffer
    )

    assert 0

    for epoch in range(100):
        print(epoch, lj_params)

        xt_noise_buffer = np.random.randn(num_steps, *conf.shape)
        vol_noise_buffer = np.random.randn(num_steps)

        primals = (
            x0,
            v0,
            vol_xt,
            vol_vt,
            lj_params,
            xt_noise_buffer,
            vol_noise_buffer
        )

        # Unit tangent on the first LJ parameter only; everything else is held fixed.
        tangents = (
            np.zeros_like(x0),
            np.zeros_like(v0),
            np.zeros_like(vol_xt),
            np.zeros_like(vol_vt),
            # np.zeros_like(lj_params),
            np.array([1.0, 0.0]),
            np.zeros_like(xt_noise_buffer),
            np.zeros_like(vol_noise_buffer)
        )

        x_primals_out, x_tangents_out = jax.jvp(integrate_once_through, primals, tangents)

        sig_grad = np.clip(x_tangents_out, -0.01, 0.01)
        print("loss", x_primals_out, "raw_grad", x_tangents_out, "clip grad", sig_grad)
def run_simulation(potentials, params, param_groups, conf, masses, dp_idxs, n_samples=200, n_steps=1000):
    """Minimize a system and return the energy and derivatives of the final frame.

    The merged potentials are driven by a ``LangevinOptimizer_f32`` configured
    with coefficient a = 0.5 and an all-zero coefficient c. Stepping stops as
    soon as the standard deviation of the last ``window_size`` energies drops
    below ``1.046 / 2``.

    Args:
        potentials: potentials passed to ``forcefield.merge_potentials``.
        params: parameter array; cast to float32 for the context.
        param_groups: unused by the current implementation.
        conf: initial coordinates; cast to float32 for the context.
        masses: per-particle masses.
        dp_idxs: parameter indices to differentiate; cast to int32.
        n_samples: unused (the sampling phase is currently disabled).
        n_steps: unused (the sampling phase is currently disabled).

    Returns:
        A one-element list ``[[E, dE_dx, dx_dp, dE_dp, 0]]`` read from the
        context after minimization.

    Raises:
        RuntimeError: if the energy has not converged after ``max_iter`` steps.
    """
    potentials = forcefield.merge_potentials(potentials)

    dt = 0.0005
    # Only the b coefficient is needed while the dynamics phase is disabled.
    _, cb, _ = langevin_coefficients(temperature=25.0, dt=dt, friction=50, masses=masses)

    m_dt, m_ca, m_cb, m_cc = dt, 0.5, cb, np.zeros_like(masses)

    opt = custom_ops.LangevinOptimizer_f32(
        m_dt,
        m_ca,
        m_cb.astype(np.float32),
        m_cc.astype(np.float32),
    )

    v0 = np.zeros_like(conf)
    dp_idxs = dp_idxs.astype(np.int32)

    ctxt = custom_ops.Context_f32(
        potentials,
        opt,
        params.astype(np.float32),
        conf.astype(np.float32),  # x0
        v0.astype(np.float32),  # v0
        dp_idxs,
    )

    # Minimize the system and carry the gradient over.
    # NOTE(review): an earlier comment gave the convergence criterion as
    # ".25 kcal"; the threshold actually applied is 1.046 / 2 -- confirm units.
    max_iter = 25000
    window_size = 150
    convergence_threshold = 1.046 / 2

    minimization_energies = []
    for i in range(max_iter):
        ctxt.step()
        E = ctxt.get_E()
        minimization_energies.append(E)

        if len(minimization_energies) > window_size:
            window_std = np.std(minimization_energies[-window_size:])
            if window_std < convergence_threshold:
                break
    else:
        # Reached only when the loop finished without `break`, so a run that
        # converges on its very last iteration is no longer reported as failed.
        raise RuntimeError(
            "Energy minimization failed to converge in {} steps".format(max_iter)
        )

    print("Minimization converged in", i, "steps to", E)

    # NOTE: dynamics with reservoir sampling (n_samples / n_steps) is disabled;
    # only the minimized frame is returned.
    R = [[
        ctxt.get_E(),
        ctxt.get_dE_dx(),
        ctxt.get_dx_dp(),
        ctxt.get_dE_dp(),
        0,
    ]]

    return R
def __init__(self, U_fn, O_fn, temperature):
    """Set up a three-particle 1-D Langevin sampler with both end particles pinned.

    Args:
        U_fn: potential, ``(x, p) -> R^1``. Stored on the instance.
            NOTE(review): forces are taken from the module-level
            ``lennard_jones`` rather than from ``U_fn`` -- confirm intent.
        O_fn: observable evaluated on sampled frames as ``O_fn(x, lj_params)``.
        temperature: used for ``self.kT`` and for the thermostat. The
            thermostat used to be hard-coded to 300.0 regardless of this
            argument, which made the samples inconsistent with ``self.kT``.

    After construction ``self.integrator(x_t, v_t, lj_params)`` runs the
    trajectory and returns the means of ``O``, ``dU/dp`` and ``O * dU/dp``
    over every 10th frame after step 2000.
    """
    self.kT = BOLTZ * temperature
    self.U_fn = U_fn  # (x, p) -> R^1
    self.O_fn = O_fn  # (R^1 -> R^N)

    xs = np.linspace(0, 1.0, 3, endpoint=True)
    conf = np.expand_dims(xs, axis=1)
    self.conf = conf

    masses = np.ones(conf.shape[0])
    dt = 1.5e-3

    ca, cb, cc = langevin_coefficients(temperature=temperature, dt=dt, friction=1.0, masses=masses)
    cb = -np.expand_dims(cb, axis=-1)
    cc = np.expand_dims(cc, axis=-1)

    # setup bounded particles: the first and last particle get neither a force
    # nor a noise contribution.
    cb[0] = 0.0
    cb[-1] = 0.0
    cc[0] = 0.0
    cc[-1] = 0.0

    num_steps = 50000

    # Gradient with respect to coordinates (arg 0) and parameters (arg 1).
    grad_fn = jax.jit(jax.grad(lennard_jones, argnums=(0, 1)))

    def integrate_once_through(x_t, v_t, lj_params):
        Os = []
        dU_dps = []
        O_dot_dU_dps = []

        for step in range(num_steps):
            dU_dx, dU_dp = grad_fn(x_t, lj_params)

            # Skip the first 2000 steps, then keep every 10th frame.
            if step % 10 == 0 and step > 2000:
                obs = self.O_fn(x_t, lj_params)
                Os.append(obs)
                dU_dps.append(dU_dp)
                O_dot_dU_dps.append(obs * dU_dp)

            noise = np.random.randn(*x_t.shape)
            v_t = ca * v_t + cb * dU_dx + cc * noise
            x_t = x_t + v_t * dt

        Os = np.asarray(Os)
        dU_dps = np.asarray(dU_dps)
        O_dot_dU_dps = np.asarray(O_dot_dU_dps)

        return np.mean(Os, axis=0), np.mean(dU_dps, axis=0), np.mean(O_dot_dU_dps, axis=0)

    self.integrator = integrate_once_through
def __init__(self, U_A_fn, U_B_fn, temperature):
    """Set up a two-state Langevin sampler on a 5x5 grid of particles.

    Args:
        U_A_fn: reference-state potential, called as ``U(x, lj_params, volume)``.
        U_B_fn: target-state potential, same call signature.
        temperature: used for ``self.kT`` and for the thermostat. The
            thermostat used to be hard-coded to 300.0 regardless of this
            argument, which made the samples inconsistent with the ``kT``
            used in the free-energy estimate.

    After construction ``self.integrator(x_0, v_0, lj_params)`` runs one
    trajectory under ``U_B`` and one under ``U_A`` (both from the same initial
    state) and returns ``(delta_G, <dU_B/dp> - <dU_A/dp>)``. ``delta_G`` is
    ``-kT * log(mean(exp(-(U_B - U_A) / kT)))`` over the frames sampled from
    the ``U_A`` trajectory.
    """
    self.kT = BOLTZ * temperature
    self.U_A_fn = jax.jit(U_A_fn)  # (x, p) -> R^1
    self.U_B_fn = jax.jit(U_B_fn)  # (x, p) -> R^1

    xs = np.linspace(0, 1.0, 5, endpoint=False)
    ys = np.linspace(0, 1.0, 5, endpoint=False)
    conf = np.transpose([np.tile(xs, len(ys)), np.repeat(ys, len(xs))])
    self.conf = conf

    masses = np.ones(conf.shape[0])
    dt = 1.5e-3

    ca, cb, cc = langevin_coefficients(temperature=temperature, dt=dt, friction=1.0, masses=masses)
    cb = -np.expand_dims(cb, axis=-1)
    cc = np.expand_dims(cc, axis=-1)

    num_steps = 100000

    dU_A_dx_fn = jax.jit(jax.grad(self.U_A_fn, argnums=(0, )))
    dU_B_dx_fn = jax.jit(jax.grad(self.U_B_fn, argnums=(0, )))
    dU_A_dp_fn = jax.jit(jax.grad(self.U_A_fn, argnums=(1, )))
    dU_B_dp_fn = jax.jit(jax.grad(self.U_B_fn, argnums=(1, )))

    def langevin_step(x_t, v_t, dU_dx):
        # One velocity/position update with freshly drawn Gaussian noise.
        noise = np.random.randn(*x_t.shape)
        v_t = ca * v_t + cb * dU_dx + cc * noise
        x_t = x_t + v_t * dt
        return x_t, v_t

    def is_sampled(step):
        # Skip the first 10000 steps, then keep every 10th frame.
        return step % 10 == 0 and step > 10000

    def integrate_once_through(x_0, v_0, lj_params):
        volume = 1.0

        # simulate "target state", denominator of ratio
        x_t = np.copy(x_0)
        v_t = np.copy(v_0)
        dUB_dps = []

        for step in range(num_steps):
            dU_dx = dU_B_dx_fn(x_t, lj_params, volume)[0]

            if is_sampled(step):
                dUB_dps.append(dU_B_dp_fn(x_t, lj_params, volume)[0])

            x_t, v_t = langevin_step(x_t, v_t, dU_dx)

        # simulate "reference state", numerator of ratio
        x_t = np.copy(x_0)
        v_t = np.copy(v_0)
        dUA_dps = []
        deltaUs = []

        for step in range(num_steps):
            dU_dx = dU_A_dx_fn(x_t, lj_params, volume)[0]

            if is_sampled(step):
                dUA_dps.append(dU_A_dp_fn(x_t, lj_params, volume)[0])

                U_A = self.U_A_fn(x_t, lj_params, volume)
                U_B = self.U_B_fn(x_t, lj_params, volume)
                delta_U = U_B - U_A
                deltaUs.append(-delta_U / self.kT)

            x_t, v_t = langevin_step(x_t, v_t, dU_dx)

        avg_dUA_dps = np.mean(dUA_dps, axis=0)
        avg_dUB_dps = np.mean(dUB_dps, axis=0)

        print(avg_dUA_dps)
        print(avg_dUB_dps)

        # log-mean-exp of -dU/kT, computed stably through logsumexp.
        delta_G = -self.kT * (logsumexp(deltaUs) - np.log(len(deltaUs)))

        return delta_G, avg_dUB_dps - avg_dUA_dps

    self.integrator = integrate_once_through
def test_fwd_mode(self):
    """Compare Context stepping against a reference integration written with JAX.

    A short noise-free trajectory (friction 0, so the noise coefficients are
    all zero) is integrated in Python using gradients of ``ref_nrg_fn``. The
    test then checks that

    * ``Context.step`` reproduces the coordinates and forces frame by frame,
    * ``Context.multiple_steps`` returns the matching strided du/dl values and
      coordinates and leaves the box unchanged,
    * ``Context.multiple_steps_U`` returns the matching strided energies at
      each lambda window, and a ``(2, 0)`` array for an empty window list.
    """
    np.random.seed(4321)

    N = 8
    D = 3

    x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

    E = 2

    lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)
    lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)

    params, ref_nrg_fn, test_nrg = prepare_nb_system(
        x0,
        E,
        lambda_plane_idxs,
        lambda_offset_idxs,
        p_scale=3.0,
        # cutoff=0.5,
        cutoff=1.0,
    )

    masses = np.random.rand(N)

    v0 = np.random.rand(x0.shape[0], x0.shape[1])

    num_steps = 5
    temperature = 300
    dt = 2e-3
    friction = 0.0
    ca, cbs, ccs = langevin_coefficients(temperature, dt, friction, masses)

    # not convenient to simulate identical trajectories otherwise
    assert (ccs == 0).all()

    lamb = np.random.rand()
    lambda_windows = np.array([lamb + 0.05, lamb, lamb - 0.05])

    def integrate_once_through(x_t, v_t, box, params):
        # Reference integrator; records per-frame quantities before each update.
        dU_dx_fn = jax.grad(ref_nrg_fn, argnums=(0, ))
        dU_dp_fn = jax.grad(ref_nrg_fn, argnums=(1, ))
        dU_dl_fn = jax.grad(ref_nrg_fn, argnums=(3, ))

        all_du_dls = []
        all_du_dps = []
        all_xs = []
        all_du_dxs = []
        all_us = []
        all_lambda_us = []

        for step in range(num_steps):
            u = ref_nrg_fn(x_t, params, box, lamb)
            all_us.append(u)
            du_dl = dU_dl_fn(x_t, params, box, lamb)[0]
            all_du_dls.append(du_dl)
            du_dp = dU_dp_fn(x_t, params, box, lamb)[0]
            all_du_dps.append(du_dp)
            du_dx = dU_dx_fn(x_t, params, box, lamb)[0]
            all_du_dxs.append(du_dx)
            all_xs.append(x_t)

            lus = []
            for lamb_u in lambda_windows:
                lus.append(ref_nrg_fn(x_t, params, box, lamb_u))

            all_lambda_us.append(lus)

            noise = np.random.randn(*v_t.shape)

            v_mid = v_t + np.expand_dims(cbs, axis=-1) * du_dx

            v_t = ca * v_mid + np.expand_dims(ccs, axis=-1) * noise

            # Build a new array instead of using `+=`: an in-place update would
            # overwrite the caller's x0 and make every entry of all_xs alias
            # the same buffer.
            x_t = x_t + 0.5 * dt * (v_mid + v_t)

            # note that we do not calculate the du_dl of the last frame.

        return all_xs, all_du_dxs, all_du_dps, all_du_dls, all_us, all_lambda_us

    box = np.eye(3) * 3.0

    # when we have multiple parameters, we need to set this up correctly
    (
        ref_all_xs,
        ref_all_du_dxs,
        ref_all_du_dps,
        ref_all_du_dls,
        ref_all_us,
        ref_all_lambda_us,
    ) = integrate_once_through(x0, v0, box, params)

    intg = custom_ops.LangevinIntegrator(dt, ca, cbs, ccs, 1234)

    bp = test_nrg.bind(params).bound_impl(precision=np.float64)
    bps = [bp]

    ctxt = custom_ops.Context(x0, v0, box, intg, bps)

    for step in range(num_steps):
        print("comparing step", step)
        test_x_t = ctxt.get_x_t()
        np.testing.assert_allclose(test_x_t, ref_all_xs[step])
        ctxt.step(lamb)
        test_du_dx_t = ctxt._get_du_dx_t_minus_1()
        # test_u_t = ctxt._get_u_t_minus_1()
        # np.testing.assert_allclose(test_u_t, ref_all_us[step])
        np.testing.assert_allclose(test_du_dx_t, ref_all_du_dxs[step])

    # test the multiple_steps method
    ctxt_2 = custom_ops.Context(x0, v0, box, intg, bps)

    lambda_schedule = np.ones(num_steps) * lamb

    du_dl_interval = 3
    x_interval = 2
    start_box = ctxt_2.get_box()
    test_du_dls, test_xs, test_boxes = ctxt_2.multiple_steps(
        lambda_schedule, du_dl_interval, x_interval)
    end_box = ctxt_2.get_box()

    np.testing.assert_allclose(test_du_dls, ref_all_du_dls[::du_dl_interval])

    np.testing.assert_allclose(test_xs, ref_all_xs[::x_interval])
    np.testing.assert_array_equal(start_box, end_box)
    for i in range(test_boxes.shape[0]):
        np.testing.assert_array_equal(start_box, test_boxes[i])
    self.assertEqual(test_boxes.shape[0], test_xs.shape[0])
    self.assertEqual(test_boxes.shape[1], D)
    self.assertEqual(test_boxes.shape[2], test_xs.shape[2])

    # test the multiple_steps_U method
    ctxt_3 = custom_ops.Context(x0, v0, box, intg, bps)

    u_interval = 3

    test_us, test_xs, test_boxes = ctxt_3.multiple_steps_U(
        lamb, num_steps, lambda_windows, u_interval, x_interval)

    np.testing.assert_array_almost_equal(ref_all_lambda_us[::u_interval], test_us)

    np.testing.assert_array_almost_equal(ref_all_xs[::x_interval], test_xs)

    # An empty set of lambda windows must still produce one row per stored frame.
    test_us, test_xs, test_boxes = ctxt_3.multiple_steps_U(
        lamb, num_steps, np.array([], dtype=np.float64), u_interval, x_interval)

    assert test_us.shape == (2, 0)
def __init__(self, U_fn, O_fn, temperature):
    """Set up a 2-D Langevin sampler on a 5x5 grid that averages an observable.

    Args:
        U_fn: potential, ``(x, p) -> R^1``. Stored on the instance.
            NOTE(review): forces are taken from the module-level
            ``lennard_jones`` rather than from ``U_fn`` -- confirm intent.
        O_fn: observable, called as ``O_fn(x, lj_params, volume)``; it is also
            differentiated with respect to its second argument.
        temperature: used for ``self.kT`` and for the thermostat. The
            thermostat used to be hard-coded to 300.0 regardless of this
            argument, which made the samples inconsistent with ``self.kT``.

    After construction ``self.integrator(x_t, v_t, lj_params)`` runs the
    trajectory and returns the means of ``O``, ``dU/dp``, ``O * dU/dp`` and
    ``dO/dp`` over every 10th frame after step 10000.
    """
    self.kT = BOLTZ*temperature
    self.U_fn = U_fn  # (x, p) -> R^1
    self.O_fn = O_fn  # (R^1 -> R^N)
    self.dO_dp_fn = jax.jit(jax.grad(O_fn, argnums=(1,)))

    xs = np.linspace(0, 1.0, 5, endpoint=False)
    ys = np.linspace(0, 1.0, 5, endpoint=False)
    conf = np.transpose([np.tile(xs, len(ys)), np.repeat(ys, len(xs))])
    self.conf = conf

    masses = np.ones(conf.shape[0])
    dt = 1.5e-3

    ca, cb, cc = langevin_coefficients(
        temperature=temperature,
        dt=dt,
        friction=1.0,
        masses=masses
    )
    cb = -np.expand_dims(cb, axis=-1)
    cc = np.expand_dims(cc, axis=-1)

    num_steps = 200000

    # Gradient with respect to coordinates (arg 0) and parameters (arg 1).
    grad_fn = jax.jit(jax.grad(lennard_jones, argnums=(0,1)))

    def integrate_once_through(
        x_t,
        v_t,
        lj_params):

        Os = []
        dU_dps = []
        O_dot_dU_dps = []
        dO_dps = []

        volume = 1.0

        for step in range(num_steps):
            dU_dx, dU_dp = grad_fn(x_t, lj_params, volume)

            # Skip the first 10000 steps, then keep every 10th frame.
            if step % 10 == 0 and step > 10000:
                obs = self.O_fn(x_t, lj_params, volume)
                Os.append(obs)
                dU_dps.append(dU_dp)
                O_dot_dU_dps.append(obs * dU_dp)

                dO_dp = self.dO_dp_fn(x_t, lj_params, volume)[0]
                dO_dps.append(dO_dp)

            noise = np.random.randn(*x_t.shape)
            v_t = ca*v_t + cb*dU_dx + cc*noise
            x_t = x_t + v_t*dt

        Os = np.asarray(Os)
        dU_dps = np.asarray(dU_dps)
        O_dot_dU_dps = np.asarray(O_dot_dU_dps)

        return (
            np.mean(Os, axis=0),
            np.mean(dU_dps, axis=0),
            np.mean(O_dot_dU_dps, axis=0),
            np.mean(dO_dps, axis=0),
        )

    self.integrator = integrate_once_through
def convergence(args):
    """Run restrained Langevin dynamics on a ligand pair and return <dU/dlambda>.

    The first two molecules of ``tests/data/ligands_40.sdf`` are combined into
    one system with harmonic bond and angle terms from the forcefield file, a
    PMI restraint, and a shape-overlap restraint scaled by ``lamb``. The
    system is integrated for 100000 steps; dU/dlambda is recorded on every 5th
    step after step 10000.

    Args:
        args: ``(epoch, lamb, lamb_idx)``. ``epoch`` and ``lamb_idx`` are only
            used to name the SD output file; ``lamb`` is the lambda value.

    Returns:
        The mean of the recorded dU/dlambda values.

    Side effects:
        Creates ``frames_heavy_<epoch>_<lamb_idx>.sdf`` (no frames are written
        while the dump code is disabled) and re-seeds NumPy's global RNG from
        ``os.urandom``.
    """
    epoch, lamb, lamb_idx = args

    suppl = Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)
    ligands = [mol for mol in suppl]

    ligand_a = ligands[0]
    ligand_b = ligands[1]

    coords_a = get_conf(ligand_a, idx=0)
    coords_b = get_conf(ligand_b, idx=0)
    # coords_b = np.matmul(coords_b, special_ortho_group.rvs(3))

    coords_a = recenter(coords_a)
    coords_b = recenter(coords_b)

    coords = np.concatenate([coords_a, coords_b])

    a_idxs = get_heavy_atom_idxs(ligand_a)
    b_idxs = get_heavy_atom_idxs(ligand_b)

    # Full (all-atom) index sets; only needed by the disabled restraint variant below.
    a_full_idxs = np.arange(0, ligand_a.GetNumAtoms())
    b_full_idxs = np.arange(0, ligand_b.GetNumAtoms())

    # Ligand B's atoms follow ligand A's in the combined system.
    b_idxs += ligand_a.GetNumAtoms()
    b_full_idxs += ligand_a.GetNumAtoms()

    nrg_fns = []

    forcefield = 'ff/params/smirnoff_1_1_0_ccc.py'
    # Use a context manager so the file handle is released deterministically.
    with open(forcefield, "r") as ff_file:
        ff_raw = ff_file.read()
    ff_handlers = deserialize_handlers(ff_raw)

    combined_mol = Chem.CombineMols(ligand_a, ligand_b)

    # Only bond and angle terms are used; torsion handlers are skipped.
    for handler in ff_handlers:
        if isinstance(handler, handlers.HarmonicBondHandler):
            bond_idxs, (bond_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_bond,
                    params=bond_params,
                    box=None,
                    bond_idxs=bond_idxs
                )
            )
        elif isinstance(handler, handlers.HarmonicAngleHandler):
            angle_idxs, (angle_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_angle,
                    params=angle_params,
                    box=None,
                    angle_idxs=angle_idxs
                )
            )

    # Ligand A's masses are scaled by 10000; ligand B keeps its atomic masses.
    masses_a = onp.array([a.GetMass() for a in ligand_a.GetAtoms()]) * 10000
    masses_b = onp.array([a.GetMass() for a in ligand_b.GetAtoms()])

    combined_masses = np.concatenate([masses_a, masses_b])

    pmi_restraint_fn = functools.partial(pmi_restraints_new,
        params=None,
        box=None,
        lamb=None,
        # masses=np.ones_like(combined_masses),
        masses=combined_masses,
        # a_idxs=a_full_idxs,
        # b_idxs=b_full_idxs,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        angle_force=100.0,
        com_force=100.0
    )

    prefactor = 2.7  # unitless
    shape_lamb = (4*np.pi)/(3*prefactor)  # unitless
    kappa = np.pi/(np.power(shape_lamb, 2/3))  # unitless
    sigma = 0.15  # 1 angstrom std, 95% coverage by 2 angstroms
    alpha = kappa/(sigma*sigma)

    alphas = np.zeros(combined_mol.GetNumAtoms())+alpha
    weights = np.zeros(combined_mol.GetNumAtoms())+prefactor

    shape_restraint_fn = functools.partial(
        shape.harmonic_overlap,
        box=None,
        lamb=None,
        params=None,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        alphas=alphas,
        weights=weights,
        k=150.0
    )

    def restraint_fn(conf, lamb):
        # The PMI restraint is always on; the shape restraint scales with lambda.
        return pmi_restraint_fn(conf) + lamb*shape_restraint_fn(conf)
        # return (1-lamb)*pmi_restraint_fn(conf) + lamb*shape_restraint_fn(conf)

    nrg_fns.append(restraint_fn)

    def nrg_fn(conf, lamb):
        # Total energy: sum of every registered term at this lambda.
        s = []
        for u in nrg_fns:
            s.append(u(conf, lamb=lamb))
        return np.sum(s)

    # Joint gradient (coordinates and lambda) for sampled frames, and a
    # coordinates-only gradient for all other steps.
    grad_fn = jax.jit(jax.grad(nrg_fn, argnums=(0,1)))
    du_dx_fn = jax.jit(jax.grad(nrg_fn, argnums=(0)))

    x_t = coords
    v_t = np.zeros_like(x_t)

    dt = 1.5e-3
    ca, cb, cc = langevin_coefficients(300.0, dt, 1.0, combined_masses)
    cb = -1*onp.expand_dims(cb, axis=-1)
    cc = onp.expand_dims(cc, axis=-1)

    du_dls = []

    # re-seed since forking
    onp.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))

    w = Chem.SDWriter('frames_heavy_'+str(epoch)+'_'+str(lamb_idx)+'.sdf')
    try:
        for step in range(100000):
            # Frame dumping is currently disabled; to re-enable, write
            # make_conformer(combined_mol, x_t[:n_a], x_t[n_a:]) to `w`
            # every 1000 steps.

            # Skip the first 10000 steps, then record du/dl on every 5th step.
            if step % 5 == 0 and step > 10000:
                du_dx, du_dl = grad_fn(x_t, lamb)
                du_dls.append(du_dl)
            else:
                du_dx = du_dx_fn(x_t, lamb)

            v_t = ca*v_t + cb*du_dx + cc*onp.random.normal(size=x_t.shape)
            x_t = x_t + v_t*dt
    finally:
        # Close the writer even if the integration raises.
        w.close()

    return np.mean(onp.mean(du_dls))