예제 #1
0
  def test_grad_ob_aux_return(self):
    # Object whose call returns a scalar loss plus a constant auxiliary tuple.
    class Test(bp.Base):
      def __init__(self):
        super().__init__()
        self.a = bm.TrainVar(bm.ones(10))
        self.b = bm.TrainVar(bm.random.randn(10))
        self.c = bm.TrainVar(bm.random.uniform(size=10))

      def __call__(self):
        loss = bm.sum(self.a + self.b + self.c)
        return loss, (bm.sin(100), bm.exp(0.1))

    def check_value_and_aux(model, value, aux):
      # The returned value must match a direct evaluation of the loss, and the
      # auxiliary outputs must be passed through unchanged.
      assert value == bm.sum(model.a + model.b + model.c)
      assert aux[0] == bm.sin(100)
      assert aux[1] == bm.exp(0.1)

    bm.random.seed(0)

    # Case 1: a list of grad_vars -> one gradient per listed variable.
    model = Test()
    grad_fn = bm.grad(model, grad_vars=[model.a, model.b], dyn_vars=model.vars(),
                      has_aux=True, return_value=True)
    var_grads, value, aux = grad_fn()
    for g in var_grads:
      assert (g == 1.).all()
    check_value_and_aux(model, value, aux)

    # Case 2: a single grad_var -> a single gradient array.
    model = Test()
    grad_fn = bm.grad(model, grad_vars=model.a, dyn_vars=model.vars(),
                      has_aux=True, return_value=True)
    single_grad, value, aux = grad_fn()
    assert (single_grad == 1.).all()
    check_value_and_aux(model, value, aux)
예제 #2
0
  def test_grad_ob_argnums_aux_return(self):
    # Three trainable vectors; the call adds an extra input ``d`` scaled by 2 and
    # returns the summed loss together with a constant auxiliary tuple.
    class Test(bp.Base):
      def __init__(self):
        super().__init__()
        self.a = bm.TrainVar(bm.ones(10))
        self.b = bm.TrainVar(bm.random.randn(10))
        self.c = bm.TrainVar(bm.random.uniform(size=10))

      def __call__(self, d):
        loss = bm.sum(self.a + self.b + self.c + 2 * d)
        return loss, (bm.sin(100), bm.exp(0.1))

    def check_aux_and_loss(model, x, loss, aux):
      assert aux[0] == bm.sin(100)
      assert aux[1] == bm.exp(0.1)
      assert loss == model(x)[0]

    bm.random.seed(0)

    # grad_vars + argnums: the gradient part is a (variable grads, argument grads) pair.
    for argnums in (0, [0]):
      model = Test()
      grad_fn = bm.grad(model, grad_vars=model.vars(), argnums=argnums,
                        has_aux=True, return_value=True)
      x = bm.random.random(10)
      (var_grads, arg_grads), loss, aux = grad_fn(x)
      for g in var_grads.values():
        assert (g == 1.).all()
      # A list-valued ``argnums`` yields a sequence of argument gradients.
      d_grad = arg_grads[0] if isinstance(argnums, list) else arg_grads
      assert (d_grad == 2.).all()
      check_aux_and_loss(model, x, loss, aux)

    # dyn_vars + argnums only: the gradient part holds just the argument grads.
    for argnums in (0, [0]):
      model = Test()
      grad_fn = bm.grad(model, dyn_vars=model.vars(), argnums=argnums,
                        has_aux=True, return_value=True)
      x = bm.random.random(10)
      arg_grads, loss, aux = grad_fn(x)
      d_grad = arg_grads[0] if isinstance(argnums, list) else arg_grads
      assert (d_grad == 2.).all()
      check_aux_and_loss(model, x, loss, aux)
예제 #3
0
  def test_jacfwd_and_aux_nested(self):
    def aux_cubed(x):
      # Only the auxiliary output (x ** 3) of the inner jacfwd is used.
      _, aux = _jacfwd(lambda y: (y ** 3, [y ** 3]), has_aux=True)(x)
      return aux[0]

    def aux_cubed_times_sin(x):
      _, aux = _jacfwd(lambda y: (y ** 3, [y ** 3]), has_aux=True)(x)
      return aux[0] * bm.sin(x)

    # Each function is paired with a closed-form reference with the same derivative.
    cases = (
      (aux_cubed, lambda x: x ** 3),
      (aux_cubed_times_sin, lambda x: x ** 3 * bm.sin(x)),
    )
    # Plain jacfwd, jit applied outside, and jit applied both inside and outside.
    wrappers = (
      lambda g: _jacfwd(g),
      lambda g: bm.jit(_jacfwd(g)),
      lambda g: bm.jit(_jacfwd(bm.jit(g))),
    )
    # Python float input and array input.
    input_makers = (lambda: 4., lambda: bm.asarray(4.))

    for func, reference in cases:
      for make_input in input_makers:
        for wrap in wrappers:
          self.assertEqual(wrap(func)(make_input()), _jacfwd(reference)(make_input()))
예제 #4
0
  def test_grad_pure_func_aux2(self):
    def call(a, b, c):
      total = bm.sum(a + b + c)
      return total, (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    grad_fn = bm.grad(call, argnums=[0, 1, 2], has_aux=True)
    ones = bm.ones(10)
    normal = bm.random.randn(10)
    uniform = bm.random.uniform(size=10)
    grads, aux = grad_fn(ones, normal, uniform)
    # d sum(a + b + c) / d{a, b, c} is a vector of ones for every argument.
    assert all((g == 1.).all() for g in grads)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
예제 #5
0
  def test_grad_func_return_aux1(self):
    def call(a, b, c):
      loss = bm.sum(a + b + c)
      extra = (bm.sin(100), bm.exp(0.1))
      return loss, extra

    bm.random.seed(1)
    x_ones = bm.ones(10)
    x_normal = bm.random.randn(10)
    x_uniform = bm.random.uniform(size=10)
    # Default argnums: only the first argument is differentiated.
    grad_fn = bm.grad(call, return_value=True, has_aux=True)
    first_grad, value, aux = grad_fn(x_ones, x_normal, x_uniform)
    assert (first_grad == 1.).all()
    assert value == bm.sum(x_ones + x_normal + x_uniform)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
예제 #6
0
 def call(a, b, c):
   """Return ``sum(a + b + c)`` plus the constant auxiliary pair (sin(100), exp(0.1))."""
   summed = bm.sum(a + b + c)
   return summed, (bm.sin(100), bm.exp(0.1))
예제 #7
0
 def f(x):
   """Return the auxiliary output ``x ** 3`` of an inner ``_jacfwd`` call, times ``sin(x)``."""
   _, aux_out = _jacfwd(lambda y: (y ** 3, [y ** 3]), has_aux=True)(x)
   cubed = aux_out[0]
   return cubed * bm.sin(x)
예제 #8
0
 def __call__(self, d):
   """Return ``sum(a + b + c + 2 * d)`` and the constant auxiliary pair (sin(100), exp(0.1))."""
   total = bm.sum(self.a + self.b + self.c + 2 * d)
   return total, (bm.sin(100), bm.exp(0.1))