  if not transpose_a:
    transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
  else:
    transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                 else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
  return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)


triangular_solve_p = standard_primitive(
    triangular_solve_shape_rule, triangular_solve_dtype_rule,
    'triangular_solve', translation_rule=_triangular_solve_translation_rule)
ad.defjvp2(triangular_solve_p,
           triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule


def _triangular_solve_cpu_translation_rule(c, a, b, left_side, lower,
                                           transpose_a, conjugate_a,
                                           unit_diagonal):
  shape = c.GetShape(a)
  dtype = shape.element_type().type
  if conjugate_a and not transpose_a:
    a = xops.Conj(a)
    conjugate_a = False
  if len(shape.dimensions()) == 2 and onp.dtype(dtype) in _cpu_lapack_types:
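# Illustrative usage sketch (not part of the original module): solving a small
# lower-triangular system through the `triangular_solve` wrapper bound to this
# primitive; the keyword names assumed here mirror the options listed above.
def _example_triangular_solve_usage():
  import jax.numpy as jnp

  a = jnp.array([[2.0, 0.0],
                 [1.0, 3.0]])   # lower-triangular operand
  b = jnp.array([[4.0],
                 [9.0]])        # right-hand side of shape (m, n)
  # left_side=True solves a @ x = b; transpose_a/conjugate_a select the
  # TRANSPOSE/ADJOINT modes handled in the translation rule above.
  x = triangular_solve(a, b, left_side=True, lower=True)
  assert jnp.allclose(jnp.dot(a, x), b)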
# Batching rule: move any mapped axis of the key and the shape parameter to the
# front, then bind the primitive once over the whole batch.
def _gamma_batching_rule(batched_args, batch_dims):
  k, a = batched_args
  bk, ba = batch_dims
  size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
              if i is not None)
  k = batching.bdim_at_front(k, bk, size)
  a = batching.bdim_at_front(a, ba, size)
  return random_gamma_p.bind(k, a), (0,)


random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.multiple_results = True
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(
    lambda key, a: (abstract_arrays.raise_to_shaped(a),))
ad.defjvp2(random_gamma_p, None,
           lambda tangent, ans, key, a: (tangent * _gamma_grad(ans[0], a),))
xla.translations[random_gamma_p] = xla.lower_fun(_gamma_impl)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule


def gamma(key, a, shape=None, dtype=onp.float64):
  """Sample Gamma random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a``. The default (None)
      produces a result shape equal to ``a.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
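# Illustrative usage sketch (not part of the original module): the batching rule
# registered above is what lets `vmap` map this sampler over a batch of shape
# parameters; the example assumes `gamma` is exposed as `jax.random.gamma`.
def _example_gamma_vmap_usage():
  import jax
  import jax.numpy as jnp

  key = jax.random.PRNGKey(0)
  a = jnp.array([0.5, 1.0, 2.0])
  # Broadcasting inside a single call and vmapping over the batch dimension
  # both reach random_gamma_p; the batching rule moves the mapped axis to the
  # front and binds the primitive once.
  direct = jax.random.gamma(key, a)
  mapped = jax.vmap(lambda ai: jax.random.gamma(key, ai))(a)
  assert direct.shape == mapped.shape == (3,)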
    return lax.maybe_tracer_tuple_to_abstract_tuple(aval)


def _standard_gamma_translate(c, key, alpha, jaxpr, aval, consts):
    xla_computation = xla.jaxpr_computation(jaxpr, consts, (), c.GetShape(key),
                                            c.GetShape(alpha))
    return c.Call(xla_computation, (key, alpha))


# define primitive
standard_gamma_p = Primitive('standard_gamma')
standard_gamma_p.def_impl(partial(xla.apply_primitive, standard_gamma_p))
standard_gamma_p.def_abstract_eval(_standard_gamma_abstract_eval)
xla.translations[standard_gamma_p] = _standard_gamma_translate
ad.defjvp2(standard_gamma_p, None,
           lambda tangent, sample, key, alpha, **kwargs:
           tangent * _standard_gamma_grad(sample, alpha))


@partial(jit, static_argnums=(2, 3))
def standard_gamma(key, alpha, shape=(), dtype=np.float32):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    jaxpr, out, consts = partial_eval.trace_unwrapped_to_jaxpr(
        _standard_gamma_impl, tuple(lax._abstractify(o) for o in (key, alpha)))
    aval, _ = out
    return standard_gamma_p.bind(key, alpha, jaxpr=jaxpr,
# XXX work around the issue: batching rule for 'reduce_window' not implemented
# when using @custom_transforms decorator
def _cumprod_impl(x):
    return np.cumprod(x, axis=-1)


cumprod_p = core.Primitive('cumprod')
cumprod_p.def_impl(_cumprod_impl)
cumprod_p.def_abstract_eval(partial(partial_eval.abstract_eval_fun, _cumprod_impl))
xla.translations[cumprod_p] = partial(xla.lower_fun, _cumprod_impl)
# XXX this implementation does not address the case x=0, hence the result in
# that case will be nan
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
ad.defjvp2(cumprod_p, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
batching.defvectorized(cumprod_p)


def cumprod(x):
    return cumprod_p.bind(x)


def promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    else:
        shapes = [np.shape(arg) for arg in args]
        num_dims = len(lax.broadcast_shapes(shape, *shapes))
        return [
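# Illustrative numerical check (not part of the original module) of the JVP rule
# registered above: for ans = cumprod(x), the tangent along direction g is
# cumsum(g / x) * ans, matching jax.jvp on jax.numpy's cumprod away from x = 0.
def _example_cumprod_jvp_check():
    import jax
    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    g = jnp.array([0.1, 0.2, 0.3, 0.4])
    ans, tangent = jax.jvp(lambda v: jnp.cumprod(v, axis=-1), (x,), (g,))
    expected = jnp.cumsum(g / x, axis=-1) * ans
    assert jnp.allclose(tangent, expected)
    # A zero entry in x turns the division into nan, which is exactly the
    # caveat noted in the comment above.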
def _standard_gamma_grad(sample, alpha):
    # Gradient of a Gamma(alpha, 1) sample with respect to alpha, computed
    # elementwise over the flattened inputs and reshaped back.
    samples = np.reshape(sample, -1)
    alphas = np.reshape(alpha, -1)
    grads = vmap(_standard_gamma_grad_one)(samples, alphas)
    return grads.reshape(alpha.shape)


@custom_transforms
def _standard_gamma_p(key, alpha):
    return _standard_gamma_impl(key, alpha)


ad.defjvp2(_standard_gamma_p.primitive, None,
           lambda tangent, sample, key, alpha, **kwargs:
           tangent * _standard_gamma_grad(sample, alpha))
batching.defvectorized(_standard_gamma_p.primitive)


@partial(jit, static_argnums=(2, 3))
def _standard_gamma(key, alpha, shape=(), dtype=np.float32):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    return _standard_gamma_p(key, alpha)


def standard_gamma(key, alpha, shape=(), dtype=np.float32):
    return _standard_gamma(key, alpha, shape, dtype)
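# Illustrative sketch (not part of the original module): the defjvp2 rule above
# makes the sampler differentiable in alpha, so gradients can flow through the
# sampling step. Assumes the same pathwise gradient is available through the
# public `jax.random.gamma`.
def _example_gamma_reparam_grad():
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)

    def mean_sample(alpha):
        # Monte Carlo estimate of E[X] for X ~ Gamma(alpha, 1).
        return jnp.mean(jax.random.gamma(key, alpha, shape=(1000,)))

    # Since E[X] = alpha, the pathwise gradient estimate should sit near 1.
    g = jax.grad(mean_sample)(2.0)
    assert jnp.isfinite(g)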