def test_grad_ob_aux_return(self):
    """bm.grad on a callable Base object with ``has_aux=True`` and
    ``return_value=True`` should yield the triple ``(grads, loss, aux)``.

    Checks both the multi-variable form (``grad_vars`` as a list -> grads
    is a sequence) and the single-variable form (grads returned directly).
    """
    class Test(bp.Base):
        def __init__(self):
            super(Test, self).__init__()
            self.a = bm.TrainVar(bm.ones(10))
            self.b = bm.TrainVar(bm.random.randn(10))
            self.c = bm.TrainVar(bm.random.uniform(size=10))

        def __call__(self):
            # Loss is a plain sum, so d(loss)/d(var) == 1 elementwise.
            return bm.sum(self.a + self.b + self.c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(0)

    # Case 1: grad_vars given as a list -> grads is a sequence of arrays.
    t = Test()
    f_grad = bm.grad(t, grad_vars=[t.a, t.b], dyn_vars=t.vars(),
                     has_aux=True, return_value=True)
    grads, returns, aux = f_grad()
    for g in grads:
        assert (g == 1.).all()
    assert returns == bm.sum(t.a + t.b + t.c)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)

    # Case 2: a single grad_var -> grads is returned directly, not wrapped.
    t = Test()
    f_grad = bm.grad(t, grad_vars=t.a, dyn_vars=t.vars(),
                     has_aux=True, return_value=True)
    grads, returns, aux = f_grad()
    assert (grads == 1.).all()
    assert returns == bm.sum(t.a + t.b + t.c)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
def test_grad_ob_argnums_aux_return(self):
    """bm.grad mixing object variables and positional arguments.

    Covers four combinations: (grad_vars present or absent) x (argnums as a
    bare int or as a list). With grad_vars present the gradient result is a
    ``(var_grads, arg_grads)`` pair; without it only arg grads are returned.
    ``has_aux=True`` plus ``return_value=True`` adds loss and aux outputs.
    """
    class Test(bp.Base):
        def __init__(self):
            super(Test, self).__init__()
            self.a = bm.TrainVar(bm.ones(10))
            self.b = bm.TrainVar(bm.random.randn(10))
            self.c = bm.TrainVar(bm.random.uniform(size=10))

        def __call__(self, d):
            # d enters with coefficient 2, so d(loss)/d(d) == 2 elementwise.
            return bm.sum(self.a + self.b + self.c + 2 * d), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(0)

    # Case 1: grad_vars + argnums as an int -> (var_grads, arg_grads) pair,
    # arg_grads is a single array.
    t = Test()
    f_grad = bm.grad(t, grad_vars=t.vars(), argnums=0,
                     has_aux=True, return_value=True)
    d = bm.random.random(10)
    (var_grads, arg_grads), loss, aux = f_grad(d)
    for g in var_grads.values():
        assert (g == 1.).all()
    assert (arg_grads == 2.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
    assert loss == t(d)[0]

    # Case 2: grad_vars + argnums as a list -> arg_grads is a sequence.
    t = Test()
    f_grad = bm.grad(t, grad_vars=t.vars(), argnums=[0],
                     has_aux=True, return_value=True)
    d = bm.random.random(10)
    (var_grads, arg_grads), loss, aux = f_grad(d)
    for g in var_grads.values():
        assert (g == 1.).all()
    assert (arg_grads[0] == 2.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
    assert loss == t(d)[0]

    # Case 3: no grad_vars, argnums as an int -> arg grads only.
    t = Test()
    f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=0,
                     has_aux=True, return_value=True)
    d = bm.random.random(10)
    arg_grads, loss, aux = f_grad(d)
    assert (arg_grads == 2.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
    assert loss == t(d)[0]

    # Case 4: no grad_vars, argnums as a list -> arg grads as a sequence.
    t = Test()
    f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=[0],
                     has_aux=True, return_value=True)
    d = bm.random.random(10)
    arg_grads, loss, aux = f_grad(d)
    assert (arg_grads[0] == 2.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
    assert loss == t(d)[0]
def test_jacfwd_and_aux_nested(self):
    """Differentiating *through* an inner ``_jacfwd(..., has_aux=True)`` call.

    The inner jacfwd returns ``aux`` untouched by differentiation, so an outer
    jacfwd of a function that uses only the aux output must match the jacfwd
    of the plain underlying expression. Checked with and without jit, for both
    Python floats and bm arrays.
    """
    # Outer function consumes only the aux output of the inner jacfwd.
    def f(x):
        jac, aux = _jacfwd(lambda x: (x ** 3, [x ** 3]), has_aux=True)(x)
        return aux[0]

    f2 = lambda x: x ** 3

    self.assertEqual(_jacfwd(f)(4.), _jacfwd(f2)(4.))
    self.assertEqual(bm.jit(_jacfwd(f))(4.), _jacfwd(f2)(4.))
    self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(4.), _jacfwd(f2)(4.))

    self.assertEqual(_jacfwd(f)(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.)))
    self.assertEqual(bm.jit(_jacfwd(f))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.)))
    self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.)))

    # Same check, but the aux output feeds into further computation.
    def f(x):
        jac, aux = _jacfwd(lambda x: (x ** 3, [x ** 3]), has_aux=True)(x)
        return aux[0] * bm.sin(x)

    f2 = lambda x: x ** 3 * bm.sin(x)

    self.assertEqual(_jacfwd(f)(4.), _jacfwd(f2)(4.))
    self.assertEqual(bm.jit(_jacfwd(f))(4.), _jacfwd(f2)(4.))
    self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(4.), _jacfwd(f2)(4.))

    self.assertEqual(_jacfwd(f)(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.)))
    self.assertEqual(bm.jit(_jacfwd(f))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.)))
    self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.)))
def test_grad_pure_func_aux2(self):
    """bm.grad on a pure function with multiple argnums and ``has_aux=True``.

    The loss is a plain sum, so every argument gradient is all-ones; the aux
    pair must be passed through unchanged.
    """
    def call(a, b, c):
        return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    f_grad = bm.grad(call, argnums=[0, 1, 2], has_aux=True)
    grads, aux = f_grad(bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10))
    for g in grads:
        assert (g == 1.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
def test_grad_func_return_aux1(self):
    """bm.grad on a pure function with both ``return_value`` and ``has_aux``.

    The result triple is ``(grads, loss, aux)``: grads of a plain sum are
    all-ones, loss equals the recomputed sum, aux is passed through.
    """
    def call(a, b, c):
        return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    f_grad = bm.grad(call, return_value=True, has_aux=True)
    grads, returns, aux = f_grad(a, b, c)
    assert (grads == 1.).all()
    assert returns == bm.sum(a + b + c)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
def call(a, b, c):
    """Return a (loss, aux) pair: the elementwise sum reduced to a scalar,
    plus a constant auxiliary tuple ``(sin(100), exp(0.1))``."""
    loss = bm.sum(a + b + c)
    aux = (bm.sin(100), bm.exp(0.1))
    return loss, aux
def f(x):
    """Compute ``x**3 * sin(x)`` via the aux output of an inner jacfwd.

    The inner ``_jacfwd(..., has_aux=True)`` call returns ``(jacobian, aux)``;
    only the untouched aux value (``x**3``) is used, so differentiating this
    function must agree with differentiating ``x**3 * sin(x)`` directly.
    """
    jac, aux = _jacfwd(lambda x: (x ** 3, [x ** 3]), has_aux=True)(x)
    return aux[0] * bm.sin(x)
def __call__(self, d):
    """Return a (loss, aux) pair.

    The loss sums the instance variables plus ``2 * d`` (so the gradient
    w.r.t. ``d`` is 2 elementwise); aux is the constant tuple
    ``(sin(100), exp(0.1))``.
    """
    loss = bm.sum(self.a + self.b + self.c + 2 * d)
    aux = (bm.sin(100), bm.exp(0.1))
    return loss, aux