def test_avg_potential_param_sizes_is_zero(self):
    """Smoke test: step a context whose only potential is bound to zero-length params.

    The test makes no assertions; it passes as long as building the context
    and stepping it complete without raising.
    """
    np.random.seed(814)

    n_atoms, n_dims = 8, 3
    n_iters = 3
    timestep = 2e-3

    # The order of the np.random draws below is deliberate: it keeps the
    # seeded values identical from run to run.
    coords = np.random.rand(n_atoms, n_dims).astype(dtype=np.float64) * 2
    particle_masses = np.random.rand(n_atoms)
    velocities = np.random.rand(*coords.shape)

    coeff_a = np.random.rand()
    coeff_bs = -np.random.rand(len(particle_masses)) / 1
    coeff_cs = np.zeros_like(coeff_bs)
    lambda_value = np.random.rand()

    periodic_box = np.eye(3) * 1.5
    integrator = custom_ops.LangevinIntegrator(
        timestep, coeff_a, coeff_bs, coeff_cs, 814
    )

    # Construct a 'bad' centroid restraint
    first_idxs = np.random.randint(n_atoms, size=5, dtype=np.int32)
    second_idxs = np.random.randint(n_atoms, size=5, dtype=np.int32)
    restraint = potentials.CentroidRestraint(first_idxs, second_idxs, 10.0, 0.0)

    # Bind to empty params
    unbound = restraint.bind(np.zeros(0))
    bound_potential = unbound.bound_impl(precision=np.float64)

    context = custom_ops.Context(
        coords, velocities, periodic_box, integrator, [bound_potential]
    )
    for _ in range(n_iters):
        context.step(lambda_value)
def test_set_and_get(self):
    """Test the getters and the coordinate setter of the context.

    A freshly built context must hand back the coordinates, velocities and
    box it was constructed with, and ``set_x_t`` followed by ``get_x_t`` must
    round-trip a new coordinate array.
    """
    np.random.seed(4321)

    n_atoms = 8
    n_dims = 3
    e_arg = 2

    # The order of the np.random draws is kept fixed so the seeded values
    # are reproducible.
    coords = np.random.rand(n_atoms, n_dims).astype(dtype=np.float64) * 2
    plane_idxs = np.random.randint(low=0, high=2, size=n_atoms, dtype=np.int32)
    offset_idxs = np.random.randint(low=0, high=2, size=n_atoms, dtype=np.int32)

    nb_params, _, nb_potential = prepare_nb_system(
        coords,
        e_arg,
        plane_idxs,
        offset_idxs,
        p_scale=3.0,
        cutoff=1.0,
    )

    particle_masses = np.random.rand(n_atoms)
    velocities = np.random.rand(*coords.shape)

    temperature = 300
    timestep = 2e-3
    friction = 0.0
    coeff_a, coeff_bs, coeff_cs = langevin_coefficients(
        temperature, timestep, friction, particle_masses
    )

    periodic_box = np.eye(3) * 3.0
    integrator = custom_ops.LangevinIntegrator(
        timestep, coeff_a, coeff_bs, coeff_cs, 1234
    )

    bound_potential = nb_potential.bind(nb_params).bound_impl(precision=np.float64)
    context = custom_ops.Context(
        coords, velocities, periodic_box, integrator, [bound_potential]
    )

    # Each getter must return exactly what the context was constructed with.
    initial_state = (
        (context.get_x_t, coords),
        (context.get_v_t, velocities),
        (context.get_box, periodic_box),
    )
    for getter, expected in initial_state:
        np.testing.assert_equal(getter(), expected)

    # Setting new coordinates must be visible through the getter.
    replacement_coords = np.random.rand(n_atoms, 3)
    context.set_x_t(replacement_coords)
    np.testing.assert_equal(context.get_x_t(), replacement_coords)
def impl(self):
    """Return a ``custom_ops.LangevinIntegrator`` built from this object's settings.

    The constructor receives, in order, ``self.dt``, ``self.ca``, ``self.cbs``,
    ``self.ccs`` and ``self.seed``.
    """
    ctor_args = (self.dt, self.ca, self.cbs, self.ccs, self.seed)
    return custom_ops.LangevinIntegrator(*ctor_args)
def test_fwd_mode(self):
    """Compare forward simulation by ``custom_ops.Context`` to a reference integrator.

    A reference trajectory is generated in Python from ``ref_nrg_fn`` and its
    jax gradients. Starting from the same state, the test then checks:

    * the coordinates before each ``Context.step`` and the du/dx reported
      after it,
    * the du/dl values, coordinates and boxes returned by
      ``Context.multiple_steps``,
    * the energies at several lambda windows and the coordinates returned by
      ``Context.multiple_steps_U``, including a call with no lambda windows.
    """
    np.random.seed(4321)

    N = 8
    D = 3

    x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

    E = 2

    lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)
    lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)

    params, ref_nrg_fn, test_nrg = prepare_nb_system(
        x0,
        E,
        lambda_plane_idxs,
        lambda_offset_idxs,
        p_scale=3.0,
        # cutoff=0.5,
        cutoff=1.0,
    )

    masses = np.random.rand(N)

    v0 = np.random.rand(x0.shape[0], x0.shape[1])

    num_steps = 5
    temperature = 300
    dt = 2e-3
    friction = 0.0
    ca, cbs, ccs = langevin_coefficients(temperature, dt, friction, masses)

    # not convenient to simulate identical trajectories otherwise
    assert (ccs == 0).all()

    lamb = np.random.rand()
    lambda_windows = np.array([lamb + 0.05, lamb, lamb - 0.05])

    def integrate_once_through(x_t, v_t, box, params):
        """Reference integrator; records per-step quantities before each update."""
        dU_dx_fn = jax.grad(ref_nrg_fn, argnums=(0, ))
        dU_dp_fn = jax.grad(ref_nrg_fn, argnums=(1, ))
        dU_dl_fn = jax.grad(ref_nrg_fn, argnums=(3, ))

        all_du_dls = []
        all_du_dps = []
        all_xs = []
        all_du_dxs = []
        all_us = []
        all_lambda_us = []
        for _ in range(num_steps):
            u = ref_nrg_fn(x_t, params, box, lamb)
            all_us.append(u)
            # argnums is a 1-tuple, so each gradient comes back as a 1-tuple.
            du_dl = dU_dl_fn(x_t, params, box, lamb)[0]
            all_du_dls.append(du_dl)
            du_dp = dU_dp_fn(x_t, params, box, lamb)[0]
            all_du_dps.append(du_dp)
            du_dx = dU_dx_fn(x_t, params, box, lamb)[0]
            all_du_dxs.append(du_dx)
            all_xs.append(x_t)

            # Energies of the current frame evaluated at every lambda window.
            all_lambda_us.append(
                [ref_nrg_fn(x_t, params, box, lamb_u) for lamb_u in lambda_windows]
            )

            noise = np.random.randn(*v_t.shape)

            v_mid = v_t + np.expand_dims(cbs, axis=-1) * du_dx

            v_t = ca * v_mid + np.expand_dims(ccs, axis=-1) * noise
            # Rebind rather than use ``+=``: an in-place add on a NumPy array
            # would modify the caller's x0 and make every entry already
            # stored in all_xs refer to the same, updated buffer.
            x_t = x_t + 0.5 * dt * (v_mid + v_t)

            # note that we do not calculate the du_dl of the last frame.

        return all_xs, all_du_dxs, all_du_dps, all_du_dls, all_us, all_lambda_us

    box = np.eye(3) * 3.0

    # when we have multiple parameters, we need to set this up correctly
    (
        ref_all_xs,
        ref_all_du_dxs,
        ref_all_du_dps,
        ref_all_du_dls,
        ref_all_us,
        ref_all_lambda_us,
    ) = integrate_once_through(x0, v0, box, params)

    intg = custom_ops.LangevinIntegrator(dt, ca, cbs, ccs, 1234)

    bp = test_nrg.bind(params).bound_impl(precision=np.float64)
    bps = [bp]

    ctxt = custom_ops.Context(x0, v0, box, intg, bps)

    for step in range(num_steps):
        print("comparing step", step)
        # Coordinates are compared before stepping, du/dx after stepping.
        test_x_t = ctxt.get_x_t()
        np.testing.assert_allclose(test_x_t, ref_all_xs[step])
        ctxt.step(lamb)
        test_du_dx_t = ctxt._get_du_dx_t_minus_1()
        # test_u_t = ctxt._get_u_t_minus_1()
        # np.testing.assert_allclose(test_u_t, ref_all_us[step])
        np.testing.assert_allclose(test_du_dx_t, ref_all_du_dxs[step])

    # test the multiple_steps method
    ctxt_2 = custom_ops.Context(x0, v0, box, intg, bps)

    lambda_schedule = np.ones(num_steps) * lamb

    du_dl_interval = 3
    x_interval = 2
    start_box = ctxt_2.get_box()
    test_du_dls, test_xs, test_boxes = ctxt_2.multiple_steps(
        lambda_schedule, du_dl_interval, x_interval)
    end_box = ctxt_2.get_box()

    np.testing.assert_allclose(test_du_dls, ref_all_du_dls[::du_dl_interval])

    np.testing.assert_allclose(test_xs, ref_all_xs[::x_interval])
    # The box must be unchanged by stepping, both in the context and in
    # every returned frame.
    np.testing.assert_array_equal(start_box, end_box)
    for test_box in test_boxes:
        np.testing.assert_array_equal(start_box, test_box)
    self.assertEqual(test_boxes.shape[0], test_xs.shape[0])
    self.assertEqual(test_boxes.shape[1], D)
    self.assertEqual(test_boxes.shape[2], test_xs.shape[2])

    # test the multiple_steps_U method
    ctxt_3 = custom_ops.Context(x0, v0, box, intg, bps)

    u_interval = 3

    test_us, test_xs, test_boxes = ctxt_3.multiple_steps_U(
        lamb, num_steps, lambda_windows, u_interval, x_interval)

    np.testing.assert_array_almost_equal(ref_all_lambda_us[::u_interval], test_us)

    np.testing.assert_array_almost_equal(ref_all_xs[::x_interval], test_xs)

    # With no lambda windows the energy array keeps its frame axis but has
    # zero columns.
    test_us, test_xs, test_boxes = ctxt_3.multiple_steps_U(
        lamb, num_steps, np.array([], dtype=np.float64), u_interval, x_interval)

    assert test_us.shape == (2, 0)
def test_fwd_mode(self):
    """Compare ``custom_ops.Context`` stepping and observables to a reference.

    A reference trajectory is generated in Python from ``ref_nrg_fn`` and its
    jax gradients. The test then steps a context carrying several observables
    and checks:

    * the du/dx and coordinates reported after every step,
    * the averaged and full du/dl observables against means of the reference
      du/dl values (over all steps, and over every second step),
    * the averaged du/dp observable against the mean of the reference du/dp
      values, column by column.
    """
    np.random.seed(4321)

    N = 8
    D = 3

    x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

    E = 2

    lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)
    lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)

    params, ref_nrg_fn, test_nrg = prepare_nb_system(
        x0,
        E,
        lambda_plane_idxs,
        lambda_offset_idxs,
        p_scale=3.0,
        # cutoff=0.5,
        cutoff=1.5)

    masses = np.random.rand(N)

    v0 = np.random.rand(x0.shape[0], x0.shape[1])
    num_steps = 5

    # This draw is not used, but it is kept so that the seeded random stream
    # (and therefore ca, cbs and lamb below) stays the same.
    np.random.rand(num_steps)

    ca = np.random.rand()
    cbs = -np.random.rand(len(masses)) / 1
    ccs = np.zeros_like(cbs)

    dt = 2e-3
    lamb = np.random.rand()

    def integrate_once_through(x_t, v_t, box, params):
        """Reference integrator; positions are recorded after each update."""
        dU_dx_fn = jax.grad(ref_nrg_fn, argnums=(0, ))
        dU_dp_fn = jax.grad(ref_nrg_fn, argnums=(1, ))
        dU_dl_fn = jax.grad(ref_nrg_fn, argnums=(3, ))

        all_du_dls = []
        all_du_dps = []
        all_xs = []
        all_du_dxs = []
        all_us = []
        for _ in range(num_steps):
            u = ref_nrg_fn(x_t, params, box, lamb)
            all_us.append(u)
            # argnums is a 1-tuple, so each gradient comes back as a 1-tuple.
            du_dl = dU_dl_fn(x_t, params, box, lamb)[0]
            all_du_dls.append(du_dl)
            du_dp = dU_dp_fn(x_t, params, box, lamb)[0]
            all_du_dps.append(du_dp)
            du_dx = dU_dx_fn(x_t, params, box, lamb)[0]
            all_du_dxs.append(du_dx)
            v_t = ca * v_t + np.expand_dims(cbs, axis=-1) * du_dx
            x_t = x_t + v_t * dt
            all_xs.append(x_t)
            # note that we do not calculate the du_dl of the last frame.

        return all_xs, all_du_dxs, all_du_dps, all_du_dls, all_us

    box = np.eye(3) * 1.5

    # when we have multiple parameters, we need to set this up correctly
    ref_all_xs, ref_all_du_dxs, ref_all_du_dps, ref_all_du_dls, ref_all_us = integrate_once_through(
        x0, v0, box, params)

    intg = custom_ops.LangevinIntegrator(dt, ca, cbs, ccs, 1234)

    bp = test_nrg.bind(params).bound_impl(precision=np.float64)
    bps = [bp]

    ctxt = custom_ops.Context(x0, v0, box, intg, bps)

    test_obs = custom_ops.AvgPartialUPartialParam(bp, 1)
    test_obs_f2 = custom_ops.AvgPartialUPartialParam(bp, 2)

    test_obs_du_dl = custom_ops.AvgPartialUPartialLambda(bps, 1)
    test_obs_f2_du_dl = custom_ops.AvgPartialUPartialLambda(bps, 2)
    test_obs_f3_du_dl = custom_ops.FullPartialUPartialLambda(bps, 2)

    # NOTE(review): test_obs_f2 is registered with the context but its
    # avg_du_dp() is never compared against a reference in this test.
    obs = [
        test_obs, test_obs_f2, test_obs_du_dl, test_obs_f2_du_dl,
        test_obs_f3_du_dl
    ]

    for o in obs:
        ctxt.add_observable(o)

    for step in range(num_steps):
        print("comparing step", step)
        ctxt.step(lamb)
        test_x_t = ctxt.get_x_t()
        test_du_dx_t = ctxt._get_du_dx_t_minus_1()
        # test_u_t = ctxt._get_u_t_minus_1()
        # np.testing.assert_allclose(test_u_t, ref_all_us[step])
        np.testing.assert_allclose(test_du_dx_t, ref_all_du_dxs[step])
        np.testing.assert_allclose(test_x_t, ref_all_xs[step])

    # The observables built with a second argument of 2 are compared against
    # every second reference step.
    ref_avg_du_dls = np.mean(ref_all_du_dls, axis=0)
    ref_avg_du_dls_f2 = np.mean(ref_all_du_dls[::2], axis=0)

    np.testing.assert_allclose(test_obs_du_dl.avg_du_dl(), ref_avg_du_dls)
    np.testing.assert_allclose(test_obs_f2_du_dl.avg_du_dl(), ref_avg_du_dls_f2)

    full_du_dls = test_obs_f3_du_dl.full_du_dl()
    assert len(full_du_dls) == np.ceil(num_steps / 2)
    np.testing.assert_allclose(np.mean(full_du_dls), ref_avg_du_dls_f2)

    ref_avg_du_dps = np.mean(ref_all_du_dps, axis=0)

    # the fixed point accumulator makes it hard to converge some of these
    # if the derivative is super small - in which case they probably don't matter
    # anyways
    for column, rtol in ((0, 1.5e-6), (1, 1.5e-6), (2, 5e-5)):
        np.testing.assert_allclose(
            test_obs.avg_du_dp()[:, column],
            ref_avg_du_dps[:, column],
            rtol=rtol,
        )