def test_defvjp(self):
    """A VJP registered for only the second argument is used by grad.

    ``None`` is passed as the rule for ``x``; the assertions below expect
    that argument's gradient to come out as zero.
    """
    @api.custom_transforms
    def foo(x, y):
        return np.sin(x * y)

    # The rule g * x * y is deliberately NOT the true derivative of
    # sin(x * y) w.r.t. y (that would be x * cos(x * y)), so checking the
    # y-gradient against x * y shows the custom rule is the one applied.
    defvjp(foo.primitive, None, lambda g, x, y: g * x * y)

    x_in, y_in = 3., 4.

    value, dx_only = api.value_and_grad(foo)(x_in, y_in)
    self.assertAllClose(value, onp.sin(x_in * y_in), check_dtypes=False)
    self.assertAllClose(dx_only, 0., check_dtypes=False)

    dx, dy = api.grad(foo, (0, 1))(x_in, y_in)
    self.assertAllClose(dx, 0., check_dtypes=False)
    self.assertAllClose(dy, x_in * y_in, check_dtypes=False)
def mybar_impl(w): A, _ = pymbar.BAR(w[0], w[1]) return A def mybar_vjp(g, w): return g * tmbar.dG_dw(w) def mybar(x): return mybar_p.bind(x) mybar_p = core.Primitive('mybar') mybar_p.def_impl(mybar_impl) ad.defvjp(mybar_p, mybar_vjp) def BAR_leg(insertion_du_dls, deletion_du_dls, lambda_schedule): insertion_W = math_utils.trapz(insertion_du_dls, lambda_schedule) deletion_W = math_utils.trapz(deletion_du_dls, lambda_schedule) return mybar(jnp.stack([insertion_W, deletion_W])) def BAR_loss( complex_insertion_du_dls, # [C, N] complex_deletion_du_dls, # [C, N] solvent_insertion_du_dls, # [C, N] solvent_deletion_du_dls, # [C, N] lambda_schedule,