Esempio n. 1
0
File: api_test.py Progetto: yyht/jax
    def test_defvjp_all(self):
        """Check value and gradient of a function with a rule set by defvjp_all.

        The registered rule returns sin(x) as the output and maps an incoming
        cotangent g to (g * x,). The test expects value sin(3.) and gradient
        3. at x = 3.
        """
        @api.custom_transforms
        def foo(x):
            return np.sin(x)

        def foo_vjp(x):
            # Pullback intentionally differs from the true derivative of sin.
            def pullback(g):
                return (g * x,)
            return np.sin(x), pullback

        api.defvjp_all(foo, foo_vjp)

        point = 3.
        primal_out, derivative = api.value_and_grad(foo)(point)
        self.assertAllClose(primal_out, onp.sin(point), check_dtypes=False)
        self.assertAllClose(derivative, 3., check_dtypes=False)
Esempio n. 2
0
  def test_custom_vjp_zeros(self):
    """Smoke test: differentiate one output of a two-output custom-VJP function.

    Only the first output of f is differentiated. There are no assertions;
    the test passes as long as the grad call does not raise.
    """
    @api.custom_transforms
    def f(x, y):
      return 2 * x, 3 * y

    def f_vjp(x, y):
      def pullback(cts):
        # Index (rather than unpack) the incoming cotangents.
        return 4 * cts[0], 5 * cts[1]
      primals_out = (2 * x, 3 * y)
      return primals_out, pullback

    api.defvjp_all(f, f_vjp)

    def first_output(x, y):
      return f(x, y)[0]

    # Result is discarded on purpose: success means the call did not crash.
    api.grad(first_output)(1., 2.)