Ejemplo n.º 1
0
    def test_defvjp(self):
        """Check that a VJP registered with ``defvjp`` is what ``grad`` uses.

        ``foo`` computes sin(x * y).  The registered rules are ``None`` for
        ``x`` and ``g * x * y`` for ``y``.  The assertions therefore expect:
        - a zero gradient for ``x``;
        - ``x * y`` for ``y``, i.e. the custom rule's value rather than the
          analytic derivative of sin(x * y).
        """
        @api.custom_transforms
        def foo(x, y):
            return np.sin(x * y)

        def vjp_wrt_y(g, x, y):
            return g * x * y

        # No rule (None) for the first argument, custom rule for the second.
        defvjp(foo.primitive, None, vjp_wrt_y)

        x_in, y_in = 3., 4.

        # Differentiate w.r.t. the first argument only.
        value, dx = api.value_and_grad(foo)(x_in, y_in)
        self.assertAllClose(value, onp.sin(x_in * y_in), check_dtypes=False)
        self.assertAllClose(dx, 0., check_dtypes=False)

        # Differentiate w.r.t. both arguments at once.
        dx_both, dy_both = api.grad(foo, (0, 1))(x_in, y_in)
        self.assertAllClose(dx_both, 0., check_dtypes=False)
        self.assertAllClose(dy_both, x_in * y_in, check_dtypes=False)
Ejemplo n.º 2
0
def mybar_impl(w):
    """Eager implementation rule for ``mybar_p``: run ``pymbar.BAR`` on ``w``.

    Args:
        w: indexable pair.  ``w[0]`` and ``w[1]`` are passed to
            ``pymbar.BAR`` as its first and second arguments.  These are
            presumably the forward and reverse works in pymbar's convention;
            TODO confirm against callers.

    Returns:
        The first element of the 2-tuple returned by ``pymbar.BAR``.  The
        second element is discarded; it is presumably the uncertainty
        estimate.
    """
    forward_works = w[0]
    reverse_works = w[1]
    delta_f, _uncertainty = pymbar.BAR(forward_works, reverse_works)
    return delta_f


def mybar_vjp(g, w):
    """Reverse-mode rule for ``mybar_p``.

    It is registered via ``ad.defvjp(mybar_p, mybar_vjp)``.

    Args:
        g: cotangent of the primitive's output.
        w: the primal input that was passed to ``mybar``.

    Returns:
        ``g`` times ``tmbar.dG_dw(w)``.  That factor is presumably the
        analytic derivative of the BAR estimate w.r.t. the works; TODO
        confirm.
    """
    dG_dw = tmbar.dG_dw(w)
    return g * dG_dw


def mybar(x):
    """Apply the ``mybar`` primitive to ``x`` and return its output.

    Going through ``bind`` rather than calling ``mybar_impl`` directly lets
    JAX dispatch to the rules registered on ``mybar_p``.  Those are
    ``mybar_impl`` for evaluation and ``mybar_vjp`` for reverse-mode autodiff.
    """
    out = mybar_p.bind(x)
    return out


# Custom JAX primitive wrapping pymbar's BAR estimator.
# Only an eager implementation and a VJP rule are registered here.  No
# abstract-evaluation rule is visible, so tracing under ``jit`` presumably
# fails.  NOTE(review): confirm whether ``def_abstract_eval`` is set elsewhere.
mybar_p = core.Primitive('mybar')
mybar_p.def_impl(mybar_impl)  # eager evaluation delegates to pymbar.BAR
ad.defvjp(mybar_p, mybar_vjp)  # reverse-mode rule: g * tmbar.dG_dw(w)


def BAR_leg(insertion_du_dls, deletion_du_dls, lambda_schedule):
    """Run ``mybar`` on the integrated insertion and deletion works of one leg.

    Each du/dl trace is integrated over ``lambda_schedule`` with
    ``math_utils.trapz`` to get a work value.  The two works are stacked
    (insertion first, deletion second) and passed to ``mybar``.

    Args:
        insertion_du_dls: du/dl samples for the insertion process.  The shape
            is presumably [C, N], as annotated on ``BAR_loss``; TODO confirm.
        deletion_du_dls: du/dl samples for the deletion process.
        lambda_schedule: lambda values at which the du/dl samples were taken.

    Returns:
        The output of ``mybar`` on ``[insertion_W, deletion_W]``.
    """
    works = [
        math_utils.trapz(du_dls, lambda_schedule)
        for du_dls in (insertion_du_dls, deletion_du_dls)
    ]
    return mybar(jnp.stack(works))


def BAR_loss(
        complex_insertion_du_dls,  # [C, N]
        complex_deletion_du_dls,  # [C, N]
        solvent_insertion_du_dls,  # [C, N]
        solvent_deletion_du_dls,  # [C, N]
        lambda_schedule,