示例#1
0
    if a.shape[-1] != b.shape[common_dim]:
        msg = "Incompatible shapes for arguments to triangular_solve: {} and {}."
        raise TypeError(msg.format(a.shape, b.shape))
    return b.shape


def triangular_solve_transpose_rule(cotangent, a, b, left_side, lower,
                                    transpose_a, conjugate_a):
    """Transpose rule for ``triangular_solve`` with respect to ``b``.

    Calls ``triangular_solve`` again on the same ``a`` with ``cotangent`` as
    the second operand and the ``transpose_a`` flag negated.

    Args:
      cotangent: cotangent of the primitive's output.
      a: first operand of ``triangular_solve``; must not be ``None``.
      b: must be ``None``.
        NOTE(review): presumably ``None`` marks the operand being
        transposed -- confirm against the AD machinery.
      left_side: forwarded unchanged to ``triangular_solve``.
      lower: forwarded unchanged to ``triangular_solve``.
      transpose_a: negated before being forwarded.
      conjugate_a: forwarded unchanged to ``triangular_solve``.

    Returns:
      A two-element list ``[None, cotangent_b]``: no cotangent for ``a`` and
      the computed cotangent for ``b``.
    """
    # Only the configuration "a is given, b is absent" is supported.
    assert not (a is None or b is not None)
    flipped_transpose = not transpose_a
    b_cotangent = triangular_solve(a, cotangent, left_side, lower,
                                   flipped_transpose, conjugate_a)
    return [None, b_cotangent]


# Build the `triangular_solve` primitive from its shape rule, dtype rule and
# name.
triangular_solve_p = standard_primitive(triangular_solve_shape_rule,
                                        triangular_solve_dtype_rule,
                                        'triangular_solve')
# JVP rules: `None` is passed in the first rule slot; the second rule calls
# `triangular_solve` with the tangent `g_b` in place of `b`, forwarding the
# primitive's keyword parameters unchanged.
# NOTE(review): presumably the rule slots correspond positionally to the
# operands (a, b), so no derivative w.r.t. `a` is defined here -- confirm.
ad.defjvp(triangular_solve_p, None,
          lambda g_b, a, b, **kwargs: triangular_solve(a, g_b, **kwargs))
# Register the transpose rule for this primitive.
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule


def qr_impl(operand, full_matrices):
    """Apply the ``qr_p`` primitive and pack its two results.

    Args:
      operand: the value handed to ``xla.apply_primitive`` for ``qr_p``.
      full_matrices: forwarded to the primitive as a keyword parameter.

    Returns:
      ``core.pack((q, r))`` where ``q`` and ``r`` are the two values produced
      by applying the primitive.
    """
    # Unpacking into exactly two names also checks that the primitive
    # produced exactly two results.
    q_factor, r_factor = xla.apply_primitive(
        qr_p, operand, full_matrices=full_matrices)
    packed = core.pack((q_factor, r_factor))
    return packed


def qr_translation_rule(c, operand, full_matrices):
    """Translation rule for QR: delegate to the builder's ``QR`` method.

    Args:
      c: object exposing a ``QR(operand, full_matrices=...)`` method.
        NOTE(review): presumably an XLA computation builder -- confirm.
      operand: passed through to ``c.QR``.
      full_matrices: passed through to ``c.QR`` as a keyword argument.

    Returns:
      Whatever ``c.QR`` returns.
    """
    qr_op = c.QR(operand, full_matrices=full_matrices)
    return qr_op

示例#2
0
文件: lax_linalg.py 项目: yotarok/jax
    if conjugate_a and not transpose_a:
        a = xops.Conj(a)
        conjugate_a = False
    if not transpose_a:
        transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
        transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT
                     if conjugate_a else
                     xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                transpose)


# Build the `triangular_solve` primitive from its shape rule, dtype rule and
# name, supplying an explicit translation rule.
triangular_solve_p = standard_primitive(
    triangular_solve_shape_rule,
    triangular_solve_dtype_rule,
    'triangular_solve',
    translation_rule=_triangular_solve_translation_rule)
# JVP rules: a dedicated rule for the first slot, and for the second slot a
# lambda that calls `triangular_solve` with the tangent `g_b` in place of `b`.
# NOTE(review): presumably the ignored `_` argument is the primal output that
# `defjvp2` passes to each rule -- confirm against `ad.defjvp2`.
ad.defjvp2(triangular_solve_p, triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
# Register the transpose rule for this primitive.
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
# Register the batching rule for this primitive.
batching.primitive_batchers[
    triangular_solve_p] = triangular_solve_batching_rule


def _triangular_solve_cpu_translation_rule(c, a, b, left_side, lower,
                                           transpose_a, conjugate_a,
                                           unit_diagonal):
    shape = c.GetShape(a)
    dtype = shape.element_type().type
示例#3
0
def transform_shape_rule(T, v):
    """Shape rule for ``transform``: the result has the same shape as ``v``.

    Args:
      T: unused; it does not influence the result shape.
      v: operand whose ``shape`` attribute is returned.

    Returns:
      ``v.shape``.
    """
    del T  # Intentionally unused.
    result_shape = v.shape
    return result_shape


def transform_dtype_rule(T, v):
    """Dtype rule for ``transform``: the result has the same dtype as ``v``.

    Args:
      T: unused; it does not influence the result dtype.
      v: operand whose ``dtype`` attribute is returned.

    Returns:
      ``v.dtype``.
    """
    del T  # Intentionally unused.
    result_dtype = v.dtype
    return result_dtype


def transform_translation_rule(c, T, v):
    """Translation rule for ``transform``: contract ``v`` with ``T``.

    Emits a ``DotGeneral`` that contracts the last dimension of ``v`` with
    dimension 0 of ``T`` and uses no batch dimensions.

    Args:
      c: object exposing ``GetShape`` and ``DotGeneral``.
        NOTE(review): presumably an XLA computation builder -- confirm.
      T: right-hand operand of the contraction.
      v: left-hand operand of the contraction.

    Returns:
      Whatever ``c.DotGeneral`` returns.
    """
    rank = len(c.GetShape(v).dimensions())
    contracting_dims = ((rank - 1,), (0,))  # last axis of v, first axis of T
    batch_dims = ((), ())
    return c.DotGeneral(v, T, (contracting_dims, batch_dims))


# Build the `transform` primitive from its shape rule, dtype rule and name;
# the translation rule is passed as the fourth positional argument.
transform_p = lax.standard_primitive(transform_shape_rule,
                                     transform_dtype_rule, 'transform',
                                     transform_translation_rule)


def transform_batching_rule(operands, batch_dims):
    """Batching rule for ``transform_p``.

    Only one layout is supported: ``T`` is not batched and ``v`` is batched
    along its leading axis. In that layout the primitive can be bound
    directly to the batched ``v``, since ``transform_translation_rule``
    contracts the *last* axis of ``v`` and leaves leading axes untouched.

    Args:
      operands: pair ``(T, v)`` of operands.
      batch_dims: pair ``(T_dim, v_dim)`` giving the batch axis of each
        operand, or ``None`` for an unbatched operand.

    Returns:
      A pair ``(result, v_dim)``: the result of binding ``transform_p`` to
      ``(T, v)`` and the batch axis of that result.

    Raises:
      NotImplementedError: if ``T`` is batched, or if ``v`` is not batched
        along axis 0.
    """
    T, v = operands
    T_dim, v_dim = batch_dims

    # Raise explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, which would let an unsupported layout through and
    # could silently contract over the batch axis.
    if T_dim is not None:
        raise NotImplementedError(
            "transform batching rule does not support a batched T "
            "(got batch dimension {}).".format(T_dim))
    if v_dim != 0:
        raise NotImplementedError(
            "transform batching rule requires v to be batched along axis 0 "
            "(got batch dimension {}).".format(v_dim))

    return transform_p.bind(T, v), v_dim


# Register `transform_batching_rule` as the batching rule for `transform_p`.
batching.primitive_batchers[transform_p] = transform_batching_rule