Пример #1
0
  def test_defvjp_all_higher_order_revmode(self):
    """Second-order reverse mode through a primitive with a defvjp_all rule.

    The registered VJP closes over x (g * x**2), so grad(f) is 2 * x**2 and
    differentiating that again gives 4 * x, i.e. 12 at x = 3.
    """
    prim = Primitive('foo')

    def scaled(x):
      return 2. * prim.bind(x)

    def fwd_with_vjp(x):
      def pullback(g):
        return (g * x ** 2,)
      return x**2, pullback

    ad.defvjp_all(prim, fwd_with_vjp)
    second_deriv = api.grad(api.grad(scaled))(3.)
    self.assertAllClose(second_deriv, 2 * 2 * 3., check_dtypes=False)
Пример #2
0
  def test_defvjp_all_const(self):
    """A VJP that ignores its cotangent produces a constant gradient of 12."""
    prim = Primitive('foo')

    def f(x):
      return prim.bind(x)

    def fwd_with_vjp(x):
      return x**2, lambda g: (12.,)

    ad.defvjp_all(prim, fwd_with_vjp)
    value, gradient = api.value_and_grad(f)(3.)
    self.assertAllClose(value, 9., check_dtypes=False)
    self.assertAllClose(gradient, 12., check_dtypes=True)
Пример #3
0
  def test_defvjp_all(self):
    """value_and_grad uses the primal output and VJP registered via defvjp_all."""
    prim = Primitive('foo')

    def f(x):
      return 2. * prim.bind(x)

    def fwd_with_vjp(x):
      # This VJP is not the true derivative of x**2, so the expected gradient
      # below only matches if the custom rule is actually used.
      return x**2, lambda g: (4 * g * np.sin(x),)

    ad.defvjp_all(prim, fwd_with_vjp)
    value, gradient = api.value_and_grad(f)(3.)
    self.assertAllClose(value, 2 * 3.**2, check_dtypes=False)
    self.assertAllClose(gradient, 4 * 2 * onp.sin(3.), check_dtypes=False)
Пример #4
0
  def test_defvjp_all_multiple_arguments(self):
    """defvjp_all on a two-argument primitive, differentiated w.r.t. x and (x, y)."""
    # Differentiating with respect to only the first argument also exercises
    # passing symbolic zero tangents for the second one.
    prim = Primitive('foo')

    def f(x, y):
      return prim.bind(x, y)

    def fwd_with_vjp(x, y):
      primal_out = x**2 + y**3

      def pullback(g):
        return (g + x + y, g * x * 9.)

      return primal_out, pullback

    ad.defvjp_all(prim, fwd_with_vjp)

    value, grad_x = api.value_and_grad(f)(3., 4.)
    self.assertAllClose(value, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(grad_x, 1. + 3. + 4., check_dtypes=False)

    both_grads = api.grad(f, (0, 1))(3., 4.)
    self.assertAllClose(both_grads, (1. + 3. + 4., 1. * 3. * 9.),
                        check_dtypes=False)
Пример #5
0
    def decorator(fenics_function: Callable) -> Callable:
        """Expose ``fenics_function`` to JAX as a primitive with batching and VJP rules.

        ``fenics_templates`` is read from the enclosing scope and forwarded,
        together with ``fenics_function`` and the call arguments, to
        ``solve_eval`` (forward evaluation) and ``vjp_solve_eval`` (reverse
        mode).

        Args:
            fenics_function: The FEniCS function being wrapped.

        Returns:
            ``jax_solve_eval``, which binds the ``jax_solve_eval`` primitive to
            its positional arguments and carries ``fenics_function``'s metadata
            via ``functools.wraps``.
        """

        @functools.wraps(fenics_function)
        def jax_solve_eval(*args):
            return jax_solve_eval_p.bind(*args)

        jax_solve_eval_p = Primitive("jax_solve_eval")
        # Concrete evaluation uses only the first element of solve_eval's result.
        jax_solve_eval_p.def_impl(lambda *args: solve_eval(
            fenics_function, fenics_templates, *args)[0])

        # The output's abstract value is obtained by running solve_eval on the
        # given arguments and converting its first result to a ShapedArray.
        jax_solve_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                solve_eval(fenics_function, fenics_templates, *args)[0]))

        def jax_solve_eval_batch(vector_arg_values, batch_axes):
            """Batching rule: evaluate the primitive once per batch element.

            Arguments whose batch axis is ``None`` (not mapped) are passed
            unchanged to every call; mapped arguments are indexed along their
            own batch axis. Results are stacked along a new leading axis, so
            the output batch axis is always 0.

            Raises:
                ValueError: if no argument is mapped.
            """
            mapped = [(arg, axis)
                      for arg, axis in zip(vector_arg_values, batch_axes)
                      if axis is not None]
            if not mapped:
                raise ValueError(
                    "jax_solve_eval batching rule needs at least one mapped "
                    "argument")
            first_arg, first_axis = mapped[0]
            batch_size = first_arg.shape[first_axis]

            def element(arg, axis, i):
                # Unmapped arguments are shared by every batch element.
                if axis is None:
                    return arg
                # Plain indexing keeps the array type; for axis 0 this is arg[i].
                return arg[(slice(None),) * (axis % arg.ndim) + (i,)]

            # Compute the function row-by-row and stack the results.
            res = np.asarray([
                jax_solve_eval(*(element(arg, axis, i)
                                 for arg, axis in zip(vector_arg_values,
                                                      batch_axes)))
                for i in range(batch_size)
            ])
            return res, 0

        jax.batching.primitive_batchers[
            jax_solve_eval_p] = jax_solve_eval_batch

        def djax_solve_eval(*args):
            return djax_solve_eval_p.bind(*args)

        # Helper primitive whose impl runs vjp_solve_eval; it is registered
        # below as the combined forward/VJP rule of jax_solve_eval_p.
        # NOTE(review): multiple_results is not set on this primitive -- confirm
        # that is intended for vjp_solve_eval's return value.
        djax_solve_eval_p = Primitive("djax_solve_eval")
        djax_solve_eval_p.def_impl(lambda *args: vjp_solve_eval(
            fenics_function, fenics_templates, *args))

        defvjp_all(jax_solve_eval_p, djax_solve_eval)
        return jax_solve_eval
Пример #6
0
    def decorator(fenics_function: Callable) -> Callable:
        """Expose ``fenics_function`` to JAX as a primitive with batching and VJP rules.

        ``fenics_templates`` is read from the enclosing scope and forwarded,
        together with ``fenics_function`` and the call arguments, to
        ``assemble_eval`` (forward evaluation) and ``vjp_assemble_eval``
        (reverse mode).

        Args:
            fenics_function: The FEniCS function being wrapped.

        Returns:
            ``jax_assemble_eval``, which binds the ``jax_assemble_eval``
            primitive to its positional arguments and carries
            ``fenics_function``'s metadata via ``functools.wraps``.
        """
        import functools  # local import keeps this block self-contained

        @functools.wraps(fenics_function)
        def jax_assemble_eval(*args):
            return jax_assemble_eval_p.bind(*args)

        jax_assemble_eval_p = Primitive("jax_assemble_eval")
        # Concrete evaluation uses only the first element of assemble_eval's result.
        jax_assemble_eval_p.def_impl(
            lambda *args: assemble_eval(fenics_function, fenics_templates, *args)[0]
        )

        # The output's abstract value is obtained by running assemble_eval on
        # the given arguments and converting its first result to a ShapedArray.
        jax_assemble_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                assemble_eval(fenics_function, fenics_templates, *args)[0]
            )
        )

        def jax_assemble_eval_batch(vector_arg_values, batch_axes):
            """Batching rule: evaluate the primitive once per batch element.

            Arguments whose batch axis is ``None`` (not mapped) are passed
            unchanged to every call; mapped arguments are indexed along their
            own batch axis. Results are stacked along a new leading axis, so
            the output batch axis is always 0.

            Raises:
                ValueError: if no argument is mapped.
            """
            mapped = [
                (arg, axis)
                for arg, axis in zip(vector_arg_values, batch_axes)
                if axis is not None
            ]
            if not mapped:
                raise ValueError(
                    "jax_assemble_eval batching rule needs at least one "
                    "mapped argument"
                )
            first_arg, first_axis = mapped[0]
            batch_size = first_arg.shape[first_axis]

            def element(arg, axis, i):
                # Unmapped arguments are shared by every batch element.
                if axis is None:
                    return arg
                # Plain indexing keeps the array type; for axis 0 this is arg[i].
                return arg[(slice(None),) * (axis % arg.ndim) + (i,)]

            res = [
                jax_assemble_eval(
                    *(
                        element(arg, axis, i)
                        for arg, axis in zip(vector_arg_values, batch_axes)
                    )
                )
                for i in range(batch_size)
            ]
            return np.asarray(res), 0

        batching.primitive_batchers[jax_assemble_eval_p] = jax_assemble_eval_batch

        def djax_assemble_eval(*args):
            return djax_assemble_eval_p.bind(*args)

        # Helper primitive whose impl runs vjp_assemble_eval; it is registered
        # below as the combined forward/VJP rule of jax_assemble_eval_p.
        # NOTE(review): multiple_results is not set on this primitive -- confirm
        # that is intended for vjp_assemble_eval's return value.
        djax_assemble_eval_p = Primitive("djax_assemble_eval")
        djax_assemble_eval_p.def_impl(
            lambda *args: vjp_assemble_eval(fenics_function, fenics_templates, *args)
        )

        defvjp_all(jax_assemble_eval_p, djax_assemble_eval)

        return jax_assemble_eval