예제 #1
  def finalize(f):
    """Decorate `f` with a custom gradient.

    In JAX mode, `f` is wrapped with `custom_jvp` (rule `jvp_fn`) when
    `jvp_fn` is provided, and otherwise with `jax.custom_vjp` using
    `vjp_fwd`/`vjp_bwd`. Outside JAX mode, `f` is wrapped with
    `tf.custom_gradient`, driven by `vjp_fwd`/`vjp_bwd`.

    Args:
      f: The Python callable to decorate.

    Returns:
      The decorated callable.
    """

    if JAX_MODE:

      # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

      # For JAX, we prefer to specify a custom JVP, as JAX can use a function
      # transform to transpose a JVP (must be linear in the tangents) to a VJP.
      if jvp_fn is not None:
        f_jvp = custom_jvp(f, nondiff_argnums=nondiff_argnums)
        # Register the rule; a custom_jvp with no rule cannot be used.
        f_jvp.defjvp(jvp_fn)
        return f_jvp

      # No JVP rule was supplied: fall back to a custom VJP.
      from jax import custom_vjp  # pylint: disable=g-import-not-at-top
      f_vjp = custom_vjp(f, nondiff_argnums=nondiff_argnums)
      f_vjp.defvjp(vjp_fwd, vjp_bwd)
      return f_vjp

    # TF custom gradients support only custom VJPs.
    @tf.custom_gradient
    def f_wrapped(*args, **kwargs):
      val, aux = vjp_fwd(*args, **kwargs)

      def vjp_bwd_wrapped(*g):
        result = vjp_bwd(aux, tf.nest.pack_sequence_as(val, g))
        # One gradient per positional input is returned; the
        # non-differentiable positions are filled with None.
        for i in nondiff_argnums:
          result = tuple(result[:i]) + (None,) + tuple(result[i:])
        return result

      return val, vjp_bwd_wrapped

    return f_wrapped
예제 #2
    def custom_gradient(self, f: Callable, gradient: Callable) -> Callable:
        """Wrap `f` so that JAX differentiates it with a custom JVP rule.

        Args:
            f: The primal function.
            gradient: Called with the input tangents unpacked positionally.
                Its result is used as the output tangent. For reverse mode
                (`jax.grad`) it must be linear in the tangents.
                NOTE(review): it receives only the tangents, not the primals.
                Confirm that this matches what callers expect.

        Returns:
            A `jax.custom_jvp`-wrapped version of `f` with the rule registered.
        """
        jax_fun = jax.custom_jvp(f)

        def jax_grad(primals, tangents):
            grad = gradient(*tangents)
            # Evaluate the primal with the original positional arguments,
            # not with the primals tuple as a single argument.
            return jax_fun(*primals), grad

        # Without this registration, jax_grad is dead code and the returned
        # custom_jvp function has no JVP rule at all.
        jax_fun.defjvp(jax_grad)
        return jax_fun
예제 #3
파일: shims.py 프로젝트: NeilGirdhar/tjax
 def __init__(self,
              fun: Callable[Concatenate[U, P], R_co],
              static_argnums: Tuple[int, ...] = ()):
     """Wrap `fun` in `jax.custom_jvp`, treating `static_argnums` as nondiff.

     Args:
         fun: the function to decorate.
         static_argnums: The indices of the static arguments. They are
             forwarded to `jax.custom_jvp` as `nondiff_argnums`.
     """
     # Normalize to a sorted tuple so the order in which the indices are
     # given does not matter.
     static_argnums = tuple(sorted(static_argnums))
     self.jvp = jax.custom_jvp(fun, nondiff_argnums=static_argnums)
예제 #4
  def finalize(f):
    """Decorate `f` with a custom gradient.

    In JAX mode, `f` is wrapped with `custom_jvp` (rule `jvp_fn`) when
    `jvp_fn` is provided, and otherwise with `jax.custom_vjp` using
    `vjp_fwd`/`vjp_bwd`. Outside JAX mode, `f` is wrapped with
    `tf.custom_gradient`. Non-differentiable and `None` positional
    arguments are kept out of the TF gradient machinery.

    Args:
      f: The Python callable to decorate.

    Returns:
      The decorated callable.
    """

    if JAX_MODE:

      # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

      # For JAX, we prefer to specify a custom JVP, as JAX can use a function
      # transform to transpose a JVP (must be linear in the tangents) to a VJP.
      if jvp_fn is not None:
        f_jvp = custom_jvp(f, nondiff_argnums=nondiff_argnums)
        # Register the rule; a custom_jvp with no rule cannot be used.
        f_jvp.defjvp(jvp_fn)
        return f_jvp

      # No JVP rule was supplied: fall back to a custom VJP.
      from jax import custom_vjp  # pylint: disable=g-import-not-at-top
      f_vjp = custom_vjp(f, nondiff_argnums=nondiff_argnums)
      f_vjp.defvjp(vjp_fwd, vjp_bwd)
      return f_vjp

    # TF custom gradients support only custom VJPs.
    def none_wrapper(*args, **kwargs):  # custom_gradient can't handle None.
      # Positional args that must not be passed through tf.custom_gradient
      # (the nondiff ones and any that are None), keyed by original position.
      closure = {i: a for i, a in enumerate(args)
                 if i in nondiff_argnums or a is None}
      trimmed_args = [a for i, a in enumerate(args) if i not in closure]

      @tf.custom_gradient
      def f_wrapped(*args, **kwargs):
        args_structure = tf.nest.map_structure(lambda _: 0, args)
        # Re-interleave the closed-over args with the differentiable ones so
        # vjp_fwd sees the original positional layout.
        remaining = iter(args)
        reconstruct_args = [
            closure[i] if i in closure else next(remaining)
            for i in range(len(args) + len(closure))
        ]
        val, aux = vjp_fwd(*reconstruct_args, **kwargs)

        def vjp_bwd_wrapped(*g):
          result = tf.nest.flatten(
              vjp_bwd(aux, tf.nest.pack_sequence_as(val, g)))
          # Non-differentiable positions are filled with None, then every
          # position held in the closure is dropped. The result then lines up
          # with the args actually passed to f_wrapped.
          for i in nondiff_argnums:
            result = tuple(result[:i]) + (None,) + tuple(result[i:])
          result = [a for i, a in enumerate(result) if i not in closure]
          return tf.nest.pack_sequence_as(args_structure, result)

        return val, vjp_bwd_wrapped

      return f_wrapped(*trimmed_args, **kwargs)

    return none_wrapper
예제 #5
    return partial_det[..., -1], x

def _det(a):
    sign, logdet = slogdet(a)
    return sign * jnp.exp(logdet)

def _det_jvp(primals, tangents, grad_type='fast'):
    """JVP rule for the determinant.

    Unpacks the single primal `x` and tangent `g` and delegates to
    `_cofactor_solve` with `which=grad_type`. The primal output is the first
    value it returns. The tangent output is the trace, over the last two
    axes, of the second value.
    """
    (x,) = primals
    (g,) = tangents
    primal_out, cofactor_term = _cofactor_solve(x, g, which=grad_type)
    tangent_out = jnp.trace(cofactor_term, axis1=-1, axis2=-2)
    return primal_out, tangent_out

def _det_with_jvp_rule(grad_type):
    """Build a fresh `custom_jvp` determinant whose JVP rule uses `grad_type`."""
    wrapped = custom_jvp(lambda a: _det(a))
    wrapped.defjvp(partial(_det_jvp, grad_type=grad_type))
    return wrapped


_det_fast = _det_with_jvp_rule('fast')
_det_safe = _det_with_jvp_rule('safe')

def det(a, grad_type='fast'):
    """Compute the determinant of `a` with a selectable custom JVP rule.

    Args:
      a: Input forwarded to the determinant implementation (presumably a
        square matrix or a batch of them; confirm with callers).
      grad_type: 'fast' or 'safe'. This selects which `_cofactor_solve`
        variant the custom JVP uses.

    Returns:
      The determinant computed by the selected custom_jvp function.

    Raises:
      ValueError: if `grad_type` is neither 'fast' nor 'safe'.
    """
    if grad_type == 'fast':
        return _det_fast(a)
    if grad_type == 'safe':
        return _det_safe(a)
    # Reject unknown values explicitly instead of falling through to an
    # implicit None return.
    raise ValueError(
        "Unrecognized grad type for Det: {}".format(grad_type))