Example #1
File: lax_linalg.py  Project: yotarok/jax
    if not transpose_a:
        transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
        transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT
                     if conjugate_a else
                     xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                transpose)


triangular_solve_p = standard_primitive(
    triangular_solve_shape_rule,
    triangular_solve_dtype_rule,
    'triangular_solve',
    translation_rule=_triangular_solve_translation_rule)
ad.defjvp2(triangular_solve_p, triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
batching.primitive_batchers[
    triangular_solve_p] = triangular_solve_batching_rule
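
For context, the rules registered above are what make a triangular solve differentiable and vmappable end to end. A minimal usage sketch against the public jax.scipy.linalg.solve_triangular wrapper (not part of this excerpt; shown only as an assumed entry point):

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

a = jnp.array([[2.0, 0.0],
               [1.0, 3.0]])   # lower-triangular system matrix
b = jnp.array([1.0, 2.0])

x = solve_triangular(a, b, lower=True)
# Reverse-mode differentiation goes through the JVP and transpose rules above.
loss = lambda m: jnp.sum(solve_triangular(m, b, lower=True) ** 2)
print(jax.grad(loss)(a))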


def _triangular_solve_cpu_translation_rule(c, a, b, left_side, lower,
                                           transpose_a, conjugate_a,
                                           unit_diagonal):
    shape = c.GetShape(a)
    dtype = shape.element_type().type

    if conjugate_a and not transpose_a:
        a = xops.Conj(a)
        conjugate_a = False
    if len(shape.dimensions()) == 2 and onp.dtype(dtype) in _cpu_lapack_types:
Example #2
def _gamma_batching_rule(batched_args, batch_dims):
    k, a = batched_args
    bk, ba = batch_dims
    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
                if i is not None)
    k = batching.bdim_at_front(k, bk, size)
    a = batching.bdim_at_front(a, ba, size)
    return random_gamma_p.bind(k, a), (0, )


random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.multiple_results = True
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a:
                                 (abstract_arrays.raise_to_shaped(a), ))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a:
           (tangent * _gamma_grad(ans[0], a), ))
xla.translations[random_gamma_p] = xla.lower_fun(_gamma_impl)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
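
The batching rule registered above is what lets vmap map over this primitive. A small sketch using jax.random.gamma, the public wrapper defined just below (assuming a recent jax.random API):

import jax.numpy as jnp
from jax import random, vmap

keys = random.split(random.PRNGKey(0), 4)
alphas = jnp.linspace(1.0, 4.0, 4)
# One independent Gamma draw per (key, alpha) pair, batched via the rule above.
samples = vmap(random.gamma)(keys, alphas)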


def gamma(key, a, shape=None, dtype=onp.float64):
    """Sample Gamma random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a``. The default (None)
      produces a result shape equal to ``a.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
Example #3
    return lax.maybe_tracer_tuple_to_abstract_tuple(aval)


def _standard_gamma_translate(c, key, alpha, jaxpr, aval, consts):
    xla_computation = xla.jaxpr_computation(jaxpr, consts, (), c.GetShape(key),
                                            c.GetShape(alpha))
    return c.Call(xla_computation, (key, alpha))


# define primitive
standard_gamma_p = Primitive('standard_gamma')
standard_gamma_p.def_impl(partial(xla.apply_primitive, standard_gamma_p))
standard_gamma_p.def_abstract_eval(_standard_gamma_abstract_eval)
xla.translations[standard_gamma_p] = _standard_gamma_translate
ad.defjvp2(
    standard_gamma_p, None, lambda tangent, sample, key, alpha, **kwargs:
    tangent * _standard_gamma_grad(sample, alpha))
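
The None passed above marks the key argument as non-differentiable; the second rule turns the tangent of alpha into tangent * d(sample)/d(alpha). Roughly the same rule written with the public jax.custom_jvp decorator would look like this sketch (hypothetical, assuming this module's _standard_gamma_impl and _standard_gamma_grad helpers are in scope; it is not the project's actual code):

import jax

@jax.custom_jvp
def standard_gamma_sample(key, alpha):
    return _standard_gamma_impl(key, alpha)

@standard_gamma_sample.defjvp
def standard_gamma_sample_jvp(primals, tangents):
    key, alpha = primals
    _, alpha_dot = tangents                  # the key carries no useful tangent
    sample = _standard_gamma_impl(key, alpha)
    return sample, alpha_dot * _standard_gamma_grad(sample, alpha)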


@partial(jit, static_argnums=(2, 3))
def standard_gamma(key, alpha, shape=(), dtype=np.float32):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    jaxpr, out, consts = partial_eval.trace_unwrapped_to_jaxpr(
        _standard_gamma_impl, tuple(lax._abstractify(o) for o in (key, alpha)))
    aval, _ = out
    return standard_gamma_p.bind(key,
                                 alpha,
                                 jaxpr=jaxpr,
Example #4

# XXX work around the issue: batching rule for 'reduce_window' not implemented
# when using @custom_transforms decorator
def _cumprod_impl(x):
    return np.cumprod(x, axis=-1)


cumprod_p = core.Primitive('cumprod')
cumprod_p.def_impl(_cumprod_impl)
cumprod_p.def_abstract_eval(
    partial(partial_eval.abstract_eval_fun, _cumprod_impl))
xla.translations[cumprod_p] = partial(xla.lower_fun, _cumprod_impl)
# XXX this implementation does not address the case x=0, hence the result in that case will be nan
# Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
ad.defjvp2(cumprod_p, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
batching.defvectorized(cumprod_p)


def cumprod(x):
    return cumprod_p.bind(x)
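
To make the caveat in the XXX comment above concrete: because the JVP divides by x, any zero entry produces nan gradients. A self-contained sketch of the same rule under the public jax.custom_jvp API (hypothetical names, not this module's code):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def cumprod_unsafe(x):
    return jnp.cumprod(x, axis=-1)

@cumprod_unsafe.defjvp
def cumprod_unsafe_jvp(primals, tangents):
    (x,), (g,) = primals, tangents
    ans = jnp.cumprod(x, axis=-1)
    # Same rule as registered above; nan whenever x contains a zero.
    return ans, jnp.cumsum(g / x, axis=-1) * ans

grad_sum = jax.grad(lambda x: cumprod_unsafe(x).sum())
print(grad_sum(jnp.array([1.0, 2.0, 3.0])))   # finite: [9. 4. 2.]
print(grad_sum(jnp.array([1.0, 0.0, 3.0])))   # middle entry is nan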


def promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    else:
        shapes = [np.shape(arg) for arg in args]
        num_dims = len(lax.broadcast_shapes(shape, *shapes))
        return [
Example #5

def _standard_gamma_grad(sample, alpha):
    samples = np.reshape(sample, -1)
    alphas = np.reshape(alpha, -1)
    grads = vmap(_standard_gamma_grad_one)(samples, alphas)
    return grads.reshape(alpha.shape)


@custom_transforms
def _standard_gamma_p(key, alpha):
    return _standard_gamma_impl(key, alpha)


ad.defjvp2(
    _standard_gamma_p.primitive, None,
    lambda tangent, sample, key, alpha, **kwargs:
    tangent * _standard_gamma_grad(sample, alpha))
batching.defvectorized(_standard_gamma_p.primitive)


@partial(jit, static_argnums=(2, 3))
def _standard_gamma(key, alpha, shape=(), dtype=np.float32):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    return _standard_gamma_p(key, alpha)


def standard_gamma(key, alpha, shape=(), dtype=np.float32):
    return _standard_gamma(key, alpha, shape, dtype)
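
For completeness, the gradient rule attached above is what makes sampling differentiable with respect to the concentration. A minimal end-to-end sketch against the current public jax.random API (which may differ from the pinned versions these excerpts target):

import jax
import jax.numpy as jnp
from jax import random

def mean_sample(alpha, key=random.PRNGKey(0)):
    # The Gamma sampler is differentiable w.r.t. alpha via its registered JVP rule.
    return jnp.mean(random.gamma(key, alpha * jnp.ones(1000)))

print(jax.grad(mean_sample)(2.0))   # roughly 1.0, since E[Gamma(alpha, 1)] = alpha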