def run(args, *, equil_steps=20000, prod_steps=50000, du_dl_interval=10):
    """Equilibrate and sample one lambda window, returning the average dU/dlambda.

    Parameters
    ----------
    args : tuple
        ``(lamb, intg, bound_potentials, masses, x0, box, gpu_idx, stage)``
        packed into a single positional argument (presumably so the function
        can be dispatched with ``Pool.map`` -- confirm against the caller).
        ``masses`` is accepted for compatibility but not used here.
    equil_steps : int, optional
        Steps run at ``lamb`` before the dU/dlambda observable is attached.
        Defaults to 20000, the previously hard-coded value.
    prod_steps : int, optional
        Steps run at ``lamb`` with the observable attached. Defaults to 50000.
    du_dl_interval : int, optional
        Second argument to ``custom_ops.AvgPartialUPartialLambda`` (presumably
        the sampling interval in steps -- TODO confirm). Defaults to 10.

    Returns
    -------
    The value of ``du_dl_obs.avg_du_dl()`` after production. It is also
    printed together with ``lamb``.

    Raises
    ------
    AssertionError
        If any final coordinate has magnitude > 100, is NaN, or is infinite.
    """
    lamb, intg, bound_potentials, _masses, x0, box, gpu_idx, stage = args

    # Pin this worker to one GPU. Presumably this must happen before the
    # CUDA-backed implementations below are created.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_idx)

    u_impls = [bp.bound_impl(precision=np.float32) for bp in bound_potentials]

    # Important that we reseed here, so that each call gets its own integrator seed.
    # NOTE(review): with a fork-based multiprocessing pool, every worker inherits
    # the parent's global NumPy RNG state, so the first seed drawn in each worker
    # may be identical. Consider np.random.default_rng() or seeds passed in via args.
    intg.seed = np.random.randint(np.iinfo(np.int32).max)
    intg_impl = intg.impl()

    v0 = np.zeros_like(x0)
    ctxt = custom_ops.Context(x0, v0, box, intg_impl, u_impls)

    # Secondary minimization, needed only for stage 1: sweep lambda 0.35 -> lamb.
    if stage == 1:
        for lam_t in np.linspace(0.35, lamb, 500):
            ctxt.step(lam_t)

    # Equilibration, with no observable attached yet.
    for _ in range(equil_steps):
        ctxt.step(lamb)

    # Production: accumulate <dU/dlambda>.
    du_dl_obs = custom_ops.AvgPartialUPartialLambda(u_impls, du_dl_interval)
    ctxt.add_observable(du_dl_obs)
    for _ in range(prod_steps):
        ctxt.step(lamb)

    avg_du_dl = du_dl_obs.avg_du_dl()
    print(lamb, avg_du_dl)

    # Sanity-check the final frame (fetched once), in the original check order.
    x_final = ctxt.get_x_t()
    assert not np.any(np.abs(x_final) > 100), "final coordinates exceed |x| > 100"
    assert not np.any(np.isnan(x_final)), "final coordinates contain NaN"
    assert not np.any(np.isinf(x_final)), "final coordinates contain inf"

    return avg_du_dl
def Simulate(self, request, context):
    """Run one pickled simulation and return its averaged observables.

    Parameters
    ----------
    request
        Message read for ``precision`` ('single' or 'double'), ``simulation``
        (a pickled object providing ``potentials``, ``integrator``, ``x``,
        ``v`` and ``box``), ``lamb``, ``prep_steps``, ``prod_steps``,
        ``n_frames``, ``observe_du_dl_freq`` and ``observe_du_dp_freq``.
    context
        Not used.

    Returns
    -------
    service_pb2.SimulateReply
        Every field is pickled:

        * ``avg_du_dls``: ``avg_du_dl()`` of the dU/dlambda observable, or
          None if ``observe_du_dl_freq <= 0``.
        * ``avg_du_dps``: one ``avg_du_dp()`` per potential, or None if
          ``observe_du_dp_freq <= 0``.
        * ``energies``: ``ctxt._get_u_t_minus_1()`` read every 100
          production steps.
        * ``frames``: array of ``ctxt.get_x_t()`` snapshots taken every
          ``max(1, prod_steps // n_frames)`` steps. It is empty if
          ``n_frames <= 0``.

    Raises
    ------
    ValueError
        If ``request.precision`` is not 'single' or 'double'. This is a
        subclass of ``Exception``, so existing generic handlers still catch it.
    """
    precisions = {'single': np.float32, 'double': np.float64}
    try:
        precision = precisions[request.precision]
    except KeyError:
        raise ValueError("Unknown precision") from None

    # SECURITY: pickle.loads can execute arbitrary code embedded in the payload;
    # this endpoint must only be reachable by trusted clients.
    simulation = pickle.loads(request.simulation)

    # Build the bound implementations at the *requested* precision. Previously
    # the precision was validated but never passed on.
    bps = [potential.bound_impl(precision=precision) for potential in simulation.potentials]

    intg = simulation.integrator.impl()
    ctxt = custom_ops.Context(simulation.x, simulation.v, simulation.box, intg, bps)

    lamb = request.lamb

    # Preparation: step while sweeping lambda from 1.0 to the target lamb.
    for minimize_lamb in np.linspace(1.0, lamb, request.prep_steps):
        ctxt.step(minimize_lamb)

    du_dl_obs = None
    if request.observe_du_dl_freq > 0:
        du_dl_obs = custom_ops.AvgPartialUPartialLambda(bps, request.observe_du_dl_freq)
        ctxt.add_observable(du_dl_obs)

    du_dps = []
    if request.observe_du_dp_freq > 0:
        for bp in bps:
            du_dp_obs = custom_ops.AvgPartialUPartialParam(bp, request.observe_du_dp_freq)
            ctxt.add_observable(du_dp_obs)
            du_dps.append(du_dp_obs)

    # The frame-sampling interval does not change inside the loop; compute it once.
    if request.n_frames > 0:
        frame_interval = max(1, request.prod_steps // request.n_frames)
    else:
        frame_interval = None

    # Production dynamics.
    energies = []
    frames = []
    for step in range(request.prod_steps):
        if step % 100 == 0:
            energies.append(ctxt._get_u_t_minus_1())
        if frame_interval is not None and step % frame_interval == 0:
            frames.append(ctxt.get_x_t())
        ctxt.step(lamb)
    frames = np.array(frames)

    avg_du_dls = du_dl_obs.avg_du_dl() if du_dl_obs is not None else None
    if request.observe_du_dp_freq > 0:
        avg_du_dps = [obs.avg_du_dp() for obs in du_dps]
    else:
        avg_du_dps = None

    return service_pb2.SimulateReply(
        avg_du_dls=pickle.dumps(avg_du_dls),
        avg_du_dps=pickle.dumps(avg_du_dps),
        energies=pickle.dumps(energies),
        frames=pickle.dumps(frames),
    )
writer.write_frame(ctxt.get_x_t() * 10) ctxt.step(lamb) # print("insertion energy", ctxt._get_u_t_minus_1()) # note: these 5000 steps are "equilibration", before we attach a reporter / # "observable" to the context and start running "production" for step in range(5000): if step % 500 == 0: writer.write_frame(ctxt.get_x_t() * 10) ctxt.step(final_lamb) # print("equilibrium energy", ctxt._get_u_t_minus_1()) # TODO: what was the second argument -- reporting interval in steps? du_dl_obs = custom_ops.AvgPartialUPartialLambda(u_impls, 5) du_dps = [] for ui in u_impls: du_dp_obs = custom_ops.AvgPartialUPartialParam(ui, 5) ctxt.add_observable(du_dp_obs) du_dps.append(du_dp_obs) ctxt.add_observable(du_dl_obs) for _ in range(20000): if step % 500 == 0: writer.write_frame(ctxt.get_x_t() * 10) ctxt.step(final_lamb) writer.close()
def test_fwd_mode(self):
    """
    Check a custom_ops.Context trajectory and its observables against a
    pure-JAX reference loop.

    ``integrate_once_through`` evaluates U, dU/dx, dU/dparams and dU/dlambda
    of ``ref_nrg_fn`` with ``jax.grad``. It then applies the update
    ``v = ca * v + cbs * dU/dx; x = x + v * dt`` for ``num_steps`` steps.
    The same system is run through ``custom_ops.Context`` with a
    ``LangevinIntegrator`` and five observables attached. The test then
    asserts that:

    * per-step coordinates and dU/dx match the reference;
    * ``AvgPartialUPartialLambda`` with intervals 1 and 2 matches the mean
      reference dU/dlambda over all steps and over every other step;
    * ``FullPartialUPartialLambda`` with interval 2 returns
      ceil(num_steps / 2) values whose mean matches the every-other-step
      reference;
    * ``AvgPartialUPartialParam`` with interval 1 matches the mean reference
      dU/dparams, compared column by column.

    No adjoints are supplied, and nothing is differentiated through the
    trajectory here.
    """
    np.random.seed(4321)

    N = 8  # rows of x0 / length of masses
    # NOTE(review): B, A and T are never used in this test.
    B = 5
    A = 0
    T = 0
    D = 3  # columns of x0

    x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

    E = 2  # passed to prepare_nb_system; meaning is defined there

    # Random 0/1 per-row index arrays handed to prepare_nb_system.
    lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)
    lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)

    params, ref_nrg_fn, test_nrg = prepare_nb_system(
        x0,
        E,
        lambda_plane_idxs,
        lambda_offset_idxs,
        p_scale=3.0,
        # cutoff=0.5,
        cutoff=1.5)

    # masses is only used for its length below.
    masses = np.random.rand(N)
    v0 = np.random.rand(x0.shape[0], x0.shape[1])
    N = len(masses)  # same value as before (8)

    num_steps = 5
    # Unused, but do not delete: this draw advances the seeded global RNG, so
    # removing it would change ca, cbs and lamb below.
    lambda_schedule = np.random.rand(num_steps)
    ca = np.random.rand()  # velocity scale in the reference update
    # Negative per-row coefficients: the reference adds cbs * dU/dx to v, which
    # pushes v against the gradient. The "/ 1" is a no-op.
    cbs = -np.random.rand(len(masses)) / 1
    # All zeros. Presumably this is the integrator's noise scale, making the
    # Langevin step match the noise-free reference update (TODO confirm).
    ccs = np.zeros_like(cbs)

    dt = 2e-3
    lamb = np.random.rand()

    # NOTE(review): loss_fn and sum_loss_fn are not used in this test.
    def loss_fn(du_dls):
        return jnp.sum(du_dls * du_dls) / du_dls.shape[0]

    def sum_loss_fn(du_dls):
        du_dls = np.sum(du_dls, axis=0)
        return jnp.sum(du_dls * du_dls) / du_dls.shape[0]

    def integrate_once_through(x_t, v_t, box, params):
        """Reference loop: record U and its derivatives at each pre-update x_t, then step."""

        dU_dx_fn = jax.grad(ref_nrg_fn, argnums=(0, ))
        dU_dp_fn = jax.grad(ref_nrg_fn, argnums=(1, ))
        dU_dl_fn = jax.grad(ref_nrg_fn, argnums=(3, ))

        all_du_dls = []
        all_du_dps = []
        all_xs = []
        all_du_dxs = []
        all_us = []
        for step in range(num_steps):
            # All quantities are evaluated at the x_t from *before* this step's
            # update; dU/dx is later compared to ctxt._get_du_dx_t_minus_1().
            u = ref_nrg_fn(x_t, params, box, lamb)
            all_us.append(u)
            du_dl = dU_dl_fn(x_t, params, box, lamb)[0]
            all_du_dls.append(du_dl)
            du_dp = dU_dp_fn(x_t, params, box, lamb)[0]
            all_du_dps.append(du_dp)
            du_dx = dU_dx_fn(x_t, params, box, lamb)[0]
            all_du_dxs.append(du_dx)
            v_t = ca * v_t + np.expand_dims(cbs, axis=-1) * du_dx
            x_t = x_t + v_t * dt
            all_xs.append(x_t)
            # note that we do not calculate the du_dl of the last frame.
            # (the final x_t appended above is never fed back into U / dU.)

        return all_xs, all_du_dxs, all_du_dps, all_du_dls, all_us

    box = np.eye(3) * 1.5

    # when we have multiple parameters, we need to set this up correctly
    ref_all_xs, ref_all_du_dxs, ref_all_du_dps, ref_all_du_dls, ref_all_us = integrate_once_through(
        x0, v0, box, params)

    intg = custom_ops.LangevinIntegrator(dt, ca, cbs, ccs, 1234)

    bp = test_nrg.bind(params).bound_impl(precision=np.float64)
    bps = [bp]

    ctxt = custom_ops.Context(x0, v0, box, intg, bps)

    # The second argument is presumably the sampling interval in steps. The
    # interval-2 references below use ref_all_*[::2], i.e. steps 0, 2, 4.
    test_obs = custom_ops.AvgPartialUPartialParam(bp, 1)
    # NOTE(review): test_obs_f2 (and ref_avg_du_dps_f2 below) are never asserted
    # against in this test; confirm whether a check is missing.
    test_obs_f2 = custom_ops.AvgPartialUPartialParam(bp, 2)

    test_obs_du_dl = custom_ops.AvgPartialUPartialLambda(bps, 1)
    test_obs_f2_du_dl = custom_ops.AvgPartialUPartialLambda(bps, 2)
    test_obs_f3_du_dl = custom_ops.FullPartialUPartialLambda(bps, 2)

    obs = [
        test_obs, test_obs_f2, test_obs_du_dl, test_obs_f2_du_dl,
        test_obs_f3_du_dl
    ]

    for o in obs:
        ctxt.add_observable(o)

    # Step the Context in lockstep with the reference and compare each step.
    for step in range(num_steps):
        print("comparing step", step)
        ctxt.step(lamb)
        test_x_t = ctxt.get_x_t()
        test_v_t = ctxt.get_v_t()  # NOTE(review): fetched but never compared
        test_du_dx_t = ctxt._get_du_dx_t_minus_1()
        # test_u_t = ctxt._get_u_t_minus_1()
        # np.testing.assert_allclose(test_u_t, ref_all_us[step])
        np.testing.assert_allclose(test_du_dx_t, ref_all_du_dxs[step])
        np.testing.assert_allclose(test_x_t, ref_all_xs[step])

    ref_avg_du_dls = np.mean(ref_all_du_dls, axis=0)
    ref_avg_du_dls_f2 = np.mean(ref_all_du_dls[::2], axis=0)

    np.testing.assert_allclose(test_obs_du_dl.avg_du_dl(), ref_avg_du_dls)
    np.testing.assert_allclose(test_obs_f2_du_dl.avg_du_dl(), ref_avg_du_dls_f2)

    # The full (non-averaged) observable should hold one value per sampled step.
    full_du_dls = test_obs_f3_du_dl.full_du_dl()
    assert len(full_du_dls) == np.ceil(num_steps / 2)
    np.testing.assert_allclose(np.mean(full_du_dls), ref_avg_du_dls_f2)

    ref_avg_du_dps = np.mean(ref_all_du_dps, axis=0)
    ref_avg_du_dps_f2 = np.mean(ref_all_du_dps[::2], axis=0)

    # the fixed point accumulator makes it hard to converge some of these
    # if the derivative is super small - in which case they probably don't matter
    # anyways
    # (The third positional argument of assert_allclose is rtol; column 2 gets
    # a looser tolerance than columns 0 and 1.)
    np.testing.assert_allclose(test_obs.avg_du_dp()[:, 0],
                               ref_avg_du_dps[:, 0], 1.5e-6)
    np.testing.assert_allclose(test_obs.avg_du_dp()[:, 1],
                               ref_avg_du_dps[:, 1], 1.5e-6)
    np.testing.assert_allclose(test_obs.avg_du_dp()[:, 2],
                               ref_avg_du_dps[:, 2], 5e-5)