Beispiel #1
0
                                                  tmp,
                                                  left_side=True,
                                                  transpose_a=False,
                                                  lower=True)),
                             precision=lax.Precision.HIGHEST)
    return L, L_dot


def cholesky_batching_rule(batched_args, batch_dims):
    """Batching rule for ``cholesky_p``.

    Args:
      batched_args: a 1-tuple holding the (batched) operand.
      batch_dims: a 1-tuple holding the operand's batch axis.

    Returns:
      A pair ``(result, 0)``: the Cholesky factor computed after moving the
      batch axis to the front, and the output batch axis (always 0).
    """
    (operand,), (batch_axis,) = batched_args, batch_dims
    # Put the batch dimension first so the result's batch axis is 0.
    operand_leading_batch = batching.moveaxis(operand, batch_axis, 0)
    return cholesky(operand_leading_batch), 0


# Cholesky primitive, built with standard_unop; the `_float | _complex` set
# lists the operand dtypes it accepts (real or complex floating point).
cholesky_p = standard_unop(_float | _complex, 'cholesky')
# Register the JVP (forward-mode derivative) rule for the primitive.
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule
# Register the batching rule, which moves the batch axis to the front and
# reports the output batch axis as 0.
batching.primitive_batchers[cholesky_p] = cholesky_batching_rule


def _nan_like(c, operand):
    """Build an XLA value filled with NaN that matches ``operand``'s shape.

    Args:
      c: the computation builder; ``c.GetShape`` supplies the operand's
        element type and dimensions.
      operand: the XLA operand whose shape and dtype the result mirrors.

    Returns:
      An op that broadcasts a scalar NaN of ``operand``'s element type to
      ``operand``'s dimensions. For complex element types, both the real and
      imaginary components are NaN.
    """
    operand_shape = c.GetShape(operand)
    elem_dtype = operand_shape.element_type()
    is_complex = np.issubdtype(elem_dtype, onp.complexfloating)
    # nan * (1 + 1j) evaluates to nan + nanj, so both components are NaN.
    scalar_value = onp.nan * (1. + 1j) if is_complex else onp.nan
    scalar_nan = xb.constant(c, onp.array(scalar_value, dtype=elem_dtype))
    return xops.Broadcast(scalar_nan, operand_shape.dimensions())


def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
Beispiel #2
0
                           sigma_dot,
                           left_side=False,
                           transpose_a=True,
                           lower=True)
    L_dot = lax.dot(
        L,
        phi(
            triangular_solve(L,
                             tmp,
                             left_side=True,
                             transpose_a=False,
                             lower=True)))
    return L, L_dot


# Cholesky primitive built with standard_unop. This variant lists only
# `_float` as the accepted operand dtype set.
# NOTE(review): complex operands appear unsupported in this version — confirm.
cholesky_p = standard_unop(_float, 'cholesky')
# Register the JVP (forward-mode derivative) rule. No batching rule is
# registered alongside it here.
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule

# Dtype rule for triangular_solve: binop_dtype_rule specialized with
# `_input_dtype` (presumably the result takes the operands' dtype — confirm)
# and one accepted-dtype set per operand (a and b). Each set allows float or
# complex. The name string is passed through, presumably for error messages.
triangular_solve_dtype_rule = partial(binop_dtype_rule, _input_dtype,
                                      (_float | _complex, _float | _complex),
                                      'triangular_solve')


def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
    if a.ndim < 2:
        msg = "triangular_solve requires a.ndim to be at least 2, got {}."
        raise TypeError(msg.format(a.ndim))
    if a.shape[-1] != a.shape[-2]:
        msg = (
            "triangular_solve requires the last two dimensions of a to be equal "
            "in size, got a.shape of {}.")