Exemplo n.º 1
0
Arquivo: api_test.py Projeto: yyht/jax
        def foo(x):
            """Call a custom-transforms function that closes over ``x``.

            ``bar`` multiplies its argument by the closed-over ``x`` and is
            invoked with ``x`` itself, so the forward computation is ``x * x``.
            """
            @api.custom_transforms
            def bar(y):
                # `x` comes from the enclosing `foo`, not from bar's arguments.
                return x * y

            # The registered VJP rule ignores the cotangent `g` and the primal
            # output `ans`, and refers to the closed-over `x`.
            # NOTE(review): this looks like a deliberately invalid rule used to
            # exercise an error path for differentiating w.r.t. a variable from
            # outside `bar`'s scope; the enclosing test is not visible here —
            # confirm before renaming anything, since an expected error message
            # may embed the function name `bar`.
            api.defvjp(bar, lambda g, ans, y: x * y)
            return bar(x)
Exemplo n.º 2
0
Arquivo: api_test.py Projeto: yyht/jax
    def test_defvjp_higher_order(self):
        """Differentiate twice through a function with a custom VJP rule.

        The rule registered for ``sin(2x)`` is ``g * cos(x)``, which is the
        derivative rule of plain ``sin``. Differentiating through that rule once
        more should therefore agree with the second derivative of ``np.sin``.
        """
        def vjp_rule(cotangent, _ans, point):
            return cotangent * np.cos(point)

        @api.custom_transforms
        def scaled_sin(point):
            return np.sin(2. * point)

        api.defvjp(scaled_sin, vjp_rule)

        x0 = 2.
        second_custom = api.grad(api.grad(scaled_sin))
        second_reference = api.grad(api.grad(np.sin))
        self.assertAllClose(second_custom(x0), second_reference(x0),
                            check_dtypes=False)
Exemplo n.º 3
0
  def test_defvjp_use_ans(self):
    """A custom VJP rule can read the primal output passed as ``ans``."""
    @api.custom_transforms
    def sin_of_product(a, b):
      return np.sin(a * b)

    def b_rule(g, ans, a, b):
      # Adds cos(ans), so the expected gradient depends on the primal output.
      return g * a * b + np.cos(ans)

    # No rule (None) for the first argument; b_rule for the second.
    api.defvjp(sin_of_product, None, b_rule)

    x, y = 3., 4.
    value, grad_y = api.value_and_grad(sin_of_product, 1)(x, y)
    self.assertAllClose(value, onp.sin(x * y), check_dtypes=False)
    self.assertAllClose(grad_y, x * y + onp.cos(onp.sin(x * y)),
                        check_dtypes=False)
Exemplo n.º 4
0
Arquivo: api_test.py Projeto: yyht/jax
    def test_defvjp(self):
        """A ``None`` VJP rule yields a zero gradient for that argument."""
        @api.custom_transforms
        def sin_xy(a, b):
            return np.sin(a * b)

        def rule_for_b(cotangent, _ans, a, b):
            return cotangent * a * b

        # The first positional argument has no rule (None); the second uses rule_for_b.
        api.defvjp(sin_xy, None, rule_for_b)

        args = (3., 4.)

        # value_and_grad differentiates w.r.t. argument 0 by default, whose
        # rule is None, so its gradient is expected to be zero.
        primal, grad_a = api.value_and_grad(sin_xy)(*args)
        self.assertAllClose(primal, onp.sin(3. * 4.), check_dtypes=False)
        self.assertAllClose(grad_a, 0., check_dtypes=False)

        grad_a, grad_b = api.grad(sin_xy, (0, 1))(*args)
        self.assertAllClose(grad_a, 0., check_dtypes=False)
        self.assertAllClose(grad_b, 3. * 4., check_dtypes=False)