def ForwardMode(self, request, context):
    """Integrate a pickled system forward and return its du/dlambda traces, energies and frames.

    The request is expected to carry ``precision`` ('single' or 'double'),
    ``system`` (a pickled object exposing ``gradients``, ``integrator``, ``x0``
    and ``v0``), ``n_frames`` and ``inference``. ``context`` is not used here
    (presumably the gRPC servicer context).

    Returns:
        service_pb2.ForwardReply whose ``du_dls``, ``energies`` and ``frames``
        fields are pickled. They hold the stepper's du/dl traces, the stepper's
        energies, and every ``max(1, num_frames // n_frames)``-th coordinate
        frame starting at frame 0. When ``n_frames <= 0``, ``frames`` is an
        empty ``(0, *x0.shape)`` array.

    Raises:
        AssertionError: if a previous forward state is still held in ``self.state``.
        ValueError: if ``request.precision`` is neither 'single' nor 'double'.
    """
    # NOTE: `assert` is skipped under `python -O`; this guard only runs without -O.
    assert self.state is None

    precision = {'single': np.float32, 'double': np.float64}.get(request.precision)
    if precision is None:
        raise ValueError("Unknown precision")

    # SECURITY: pickle.loads can execute arbitrary code embedded in the payload;
    # this is only safe when request.system comes from a trusted client.
    system = pickle.loads(request.system)

    gradients = []
    force_names = []
    for grad_name, grad_args in system.gradients:
        force_names.append(grad_name)
        # Each gradient is built by the op of the same name in `ops`.
        op_fn = getattr(ops, grad_name)
        gradients.append(op_fn(*grad_args, precision=precision))

    integrator = system.integrator
    # NOTE(review): stepper and context are always the *_f64 variants; `precision`
    # only reaches the individual ops above — confirm this is intended.
    stepper = custom_ops.AlchemicalStepper_f64(gradients, integrator.lambs)
    ctxt = custom_ops.ReversibleContext_f64(
        stepper,
        system.x0,
        system.v0,
        integrator.cas,
        integrator.cbs,
        integrator.ccs,
        integrator.dts,
        integrator.seed,
    )

    ctxt.forward_mode()

    full_du_dls = stepper.get_du_dl()  # [FxT]
    energies = stepper.get_energies()

    if request.n_frames > 0:
        xs = ctxt.get_all_coords()
        # Keep every `interval`-th frame, starting from frame 0.
        interval = max(1, xs.shape[0] // request.n_frames)
        frames = xs[::interval]
    else:
        frames = np.zeros((0, *system.x0.shape), dtype=system.x0.dtype)

    reply = service_pb2.ForwardReply(
        du_dls=pickle.dumps(full_du_dls),
        energies=pickle.dumps(energies),
        frames=pickle.dumps(frames),
    )

    # Store state for backward-mode use, unless this is an inference-only request.
    if request.inference is False:
        self.state = (ctxt, gradients, force_names, stepper, system)

    return reply
def test_reverse_mode_lambda(self):
    """
    This test ensures that we can reverse-mode differentiate
    observables that are dU_dlambdas of each state. We provide
    adjoints with respect to each computed dU/dLambda.

    A pure-JAX reference integrator (`integrate_once_through`) produces a loss
    over the per-step dU/dlambda values and its gradient w.r.t. the bond and
    restraint parameters. The custom-op ReversibleContext is then run forward,
    seeded with the dU/dlambda adjoint of the same loss, and run backward. The
    resulting parameter tangents are compared against the JAX reference.

    NOTE: the fixtures come from consecutive np.random draws after seeding, so
    the order of the random calls below is significant.
    """
    np.random.seed(4321)

    # N particles, B bonds, A angles, T torsions, D spatial dimensions.
    N = 5
    B = 5
    A = 0
    T = 0
    D = 3

    x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

    precision = np.float64

    (bond_params, ref_bond), test_bond = prepare_bonded_system(
        x0,
        B,
        A,
        T,
        precision
    )

    (restr_params, ref_restr), test_restr = prepare_restraints(
        x0,
        B,
        precision
    )

    E = 2

    # Random 0/1 lambda plane/offset assignments per particle.
    lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)
    lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)

    (charge_params, lj_params), ref_nb_fn, test_nb_ctor = prepare_nonbonded_system(
        x0,
        E,
        lambda_plane_idxs,
        lambda_offset_idxs,
        p_scale=10.0,
        cutoff=1000.0,
        precision=precision
    )

    test_nb = test_nb_ctor()

    masses = np.random.rand(N)

    v0 = np.random.rand(x0.shape[0], x0.shape[1])
    # NOTE(review): no-op — masses has N entries, so this leaves N unchanged.
    N = len(masses)

    num_steps = 5
    lambda_schedule = np.random.rand(num_steps)
    cas = np.random.rand(num_steps)
    cbs = np.random.rand(len(masses)) / 10
    # NOTE(review): ccs is all zeros, presumably disabling the stochastic term so the
    # reference integrator below can be deterministic — confirm against the context.
    ccs = np.zeros_like(cbs)

    step_sizes = np.random.rand(num_steps)

    def loss_fn(du_dls):
        # Mean of squared du/dlambda values over the first axis.
        return jnp.sum(du_dls * du_dls) / du_dls.shape[0]

    def sum_loss_fn(du_dls):
        # Sums over axis 0 first. Assumes stepper.get_du_dl() is shaped
        # (num_forces, num_steps) — TODO confirm — so this reduces per-force
        # traces to a total per step before applying the same loss as loss_fn.
        du_dls = np.sum(du_dls, axis=0)
        return jnp.sum(du_dls * du_dls) / du_dls.shape[0]

    def integrate_once_through(
        x_t,
        v_t,
        bond_params,
        restr_params,
        charge_params,
        lj_params):
        # Reference trajectory in JAX; returns loss_fn over the per-step du/dlambda values.

        ref_bond_impl = functools.partial(ref_bond, params=bond_params)
        ref_restr_impl = functools.partial(ref_restr, params=restr_params)
        ref_nb_impl = functools.partial(ref_nb_fn, charge_params=charge_params, lj_params=lj_params)

        def ref_total_nrg_fn(*args):
            # Total energy = sum of bond, restraint and nonbonded energies.
            nrgs = []
            for fn in [ref_bond_impl, ref_restr_impl, ref_nb_impl]:
                nrgs.append(fn(*args))
            return jnp.sum(nrgs)

        # Gradients w.r.t. coordinates (arg 0) and lambda (arg 1).
        dU_dx_fn = jax.grad(ref_total_nrg_fn, argnums=(0, ))
        dU_dl_fn = jax.grad(ref_total_nrg_fn, argnums=(1, ))

        all_du_dls = []
        for step in range(num_steps):
            lamb = lambda_schedule[step]
            du_dl = dU_dl_fn(x_t, lamb)[0]
            all_du_dls.append(du_dl)
            dt = step_sizes[step]
            # Broadcast the per-particle cbs over the D coordinate columns.
            cb_tmp = np.expand_dims(cbs, axis=-1)
            # Velocity update from the current dU/dx, then advance positions with the new velocity.
            v_t = cas[step] * v_t + cb_tmp * dU_dx_fn(x_t, lamb)[0]
            x_t = x_t + v_t * dt
        # note that we do not calculate the du_dl of the last frame.

        all_du_dls = jnp.stack(all_du_dls)
        return loss_fn(all_du_dls)

    # when we have multiple parameters, we need to set this up correctly
    ref_loss = integrate_once_through(
        x0,
        v0,
        bond_params,
        restr_params,
        charge_params,
        lj_params
    )

    # Reference gradients of the loss w.r.t. bond_params (arg 2) and restr_params (arg 3).
    grad_fn = jax.grad(integrate_once_through, argnums=(2, 3))
    ref_dl_dp_bond, ref_dl_dp_restr = grad_fn(
        x0,
        v0,
        bond_params,
        restr_params,
        charge_params,
        lj_params
    )

    stepper = custom_ops.AlchemicalStepper_f64(
        [test_bond, test_restr, test_nb],
        lambda_schedule
    )

    seed = 1234

    ctxt = custom_ops.ReversibleContext_f64(
        stepper,
        x0,
        v0,
        cas,
        cbs,
        ccs,
        step_sizes,
        seed
    )

    # run 5 steps forward
    ctxt.forward_mode()
    test_du_dls = stepper.get_du_dl()
    test_loss = sum_loss_fn(test_du_dls)
    # Adjoint of the loss w.r.t. the du/dlambda values, fed into the backward pass.
    loss_grad_fn = jax.grad(sum_loss_fn, argnums=(0, ))
    du_dl_adjoint = loss_grad_fn(test_du_dls)[0]

    # limit of precision is due to the settings in fixed_point.hpp
    # NOTE(review): superseded by the assert_allclose below; dead code.
    # np.testing.assert_almost_equal(test_loss, ref_loss, decimal=7)
    np.testing.assert_allclose(test_loss, ref_loss, rtol=1e-6)
    stepper.set_du_dl_adjoint(du_dl_adjoint)
    # The loss does not depend on the final coordinates, so their adjoint is zero.
    ctxt.set_x_t_adjoint(np.zeros_like(x0))
    ctxt.backward_mode()

    test_dl_dp = test_bond.get_du_dp_tangents()
    np.testing.assert_allclose(test_dl_dp, ref_dl_dp_bond, rtol=1e-6)

    test_dl_dp = test_restr.get_du_dp_tangents()
    np.testing.assert_allclose(test_dl_dp, ref_dl_dp_restr, rtol=1e-6)
def ForwardMode(self, request, context):
    """Integrate a pickled system forward and return its du/dlambda traces, energies and frames.

    The request is expected to carry ``precision`` ('single' or 'double'),
    ``system`` (a pickled object exposing ``gradients``, ``integrator``, ``x0``
    and ``v0``), ``n_frames``, ``inference`` and ``key``. ``context`` is not
    used here (presumably the gRPC servicer context). Device work is
    serialized through ``self.mutex``.

    Returns:
        service_pb2.ForwardReply with pickled fields:
          * ``du_dls``: one entry per force. A ``None`` entry means that
            force's trace was entirely zero; otherwise the trace itself.
          * ``energies``: the stepper's energies.
          * ``frames``: every ``max(1, num_frames // n_frames)``-th coordinate
            frame starting at frame 0. When ``n_frames <= 0``, an empty
            ``(0, *x0.shape)`` array.

    Raises:
        ValueError: if ``request.precision`` is neither 'single' nor 'double'.
    """
    precision = {'single': np.float32, 'double': np.float64}.get(request.precision)
    if precision is None:
        raise ValueError("Unknown precision")

    # SECURITY: pickle.loads can execute arbitrary code embedded in the payload;
    # this is only safe when request.system comes from a trusted client.
    system = pickle.loads(request.system)

    gradients = []
    force_names = []
    for grad_name, grad_args in system.gradients:
        force_names.append(grad_name)
        # Each gradient is built by the op of the same name in `ops`.
        op_fn = getattr(ops, grad_name)
        gradients.append(op_fn(*grad_args, precision=precision))

    integrator = system.integrator
    # NOTE(review): stepper and context are always the *_f64 variants; `precision`
    # only reaches the individual ops above — confirm this is intended.
    stepper = custom_ops.AlchemicalStepper_f64(gradients, integrator.lambs)
    ctxt = custom_ops.ReversibleContext_f64(
        stepper,
        system.x0,
        system.v0,
        integrator.cas,
        integrator.cbs,
        integrator.ccs,
        integrator.dts,
        integrator.seed,
    )

    # Ensure only one GPU job runs at a given time: all context/stepper calls
    # happen under the lock; post-processing below works on the returned data.
    with self.mutex:
        ctxt.forward_mode()
        full_du_dls = stepper.get_du_dl()  # [FxT]
        energies = stepper.get_energies()
        xs = ctxt.get_all_coords() if request.n_frames > 0 else None

    # Replace each entirely-zero trace with None to shrink the reply. Testing
    # `not np.any(trace)` checks that *every* entry is zero. Empty traces are
    # passed through unchanged.
    stripped_du_dls = [
        None if np.size(force_du_dls) > 0 and not np.any(force_du_dls) else force_du_dls
        for force_du_dls in full_du_dls
    ]

    if xs is not None:
        # Keep every `interval`-th frame, starting from frame 0.
        interval = max(1, xs.shape[0] // request.n_frames)
        frames = xs[::interval]
    else:
        frames = np.zeros((0, *system.x0.shape), dtype=system.x0.dtype)

    # Store state for backward-mode use, keyed by request.key, unless this is
    # an inference-only request.
    if request.inference is False:
        self.states[request.key] = (ctxt, gradients, force_names, stepper, system)

    return service_pb2.ForwardReply(
        du_dls=pickle.dumps(stripped_du_dls),
        energies=pickle.dumps(energies),
        frames=pickle.dumps(frames),
    )
# Lambda schedule: ramp linearly from 1.0 down to 0.0 (both endpoints included)
# over the first `lowering_steps` steps, then hold lambda at 0.0 for the
# remaining `n_steps - lowering_steps` steps.
# NOTE(review): requires n_steps >= lowering_steps; otherwise np.zeros receives
# a negative size and raises ValueError.
lowering_steps = 10000
new_lambda_schedule = np.concatenate([
    np.linspace(1.0, 0.0, lowering_steps),
    np.zeros(n_steps - lowering_steps)
])

# The custom ramp above replaces the integrator's own schedule (integrator.lambs).
stepper = custom_ops.AlchemicalStepper_f64(gradients,
                                           new_lambda_schedule  # integrator.lambs
                                           )

# Start the run from zero initial velocities.
v0 = np.zeros_like(x0)
ctxt = custom_ops.ReversibleContext_f64(stepper, x0, v0, integrator.cas, integrator.cbs, integrator.ccs, integrator.dts, integrator.seed)

# PDB text of the combined RDKit molecule, wrapped as a file-like object for PDBWriter
# (presumably used as the topology template — confirm against PDBWriter).
combined_pdb_str = StringIO(Chem.MolToPDBBlock(combined_pdb))
out_file = "pose_dock.pdb"
pdb_writer = PDBWriter(combined_pdb_str, out_file)
# NOTE(review): the frame-writing loop and pdb_writer.close() below are commented
# out, so in the visible code the writer is created but never written to or
# closed — confirm whether a later section handles it.
# The x*10 below presumably converts nm to Angstrom — confirm.
# pdb_writer.write_header()
# frames = ctxt.get_all_coords()
# for frame_idx, x in enumerate(frames):
#     # if frame_idx % 100 == 0:
#     pdb_writer.write(x*10)
#     break
# pdb_writer.close()