示例#1
0
  def mat_power(mat_m, p):
    """Raise a square matrix to the power p, where p is 1, 2, 4 or 8.

    The power is selected at run time with ``lax.switch`` over a fixed table
    of loop-unrolled implementations. The branch index is ``round(log2(p))``,
    so 1 -> 0, 2 -> 1, 4 -> 2, 8 -> 3.

    Other values of ``p`` are not supported. They are rounded to the nearest
    power-of-two exponent; for example, p == 3 gives index 2, i.e. mat_m^4.
    Out-of-range indices are clamped by ``lax.switch``.

    Args:
      mat_m: a square matrix
      p: one of 1, 2, 4 or 8

    Returns:
      mat_m^p
    """
    # The multiplication loop is unrolled in each branch for performance.
    branch_index = jnp.asarray(jnp.round(jnp.log2(p)), jnp.int32)
    unrolled_pows = [
        _unrolled_mat_pow_1,
        _unrolled_mat_pow_2,
        _unrolled_mat_pow_4,
        _unrolled_mat_pow_8,
    ]
    return lax.switch(branch_index, unrolled_pows, mat_m)
示例#2
0
    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
        """Advance the CG-Steihaug loop by one iteration.

        The function takes a conjugate-gradient step along ``iterp.d``. It then
        uses ``lax.switch`` to dispatch to ``noop``, ``step1``, ``step2`` or
        ``step3``, depending on which stopping test (if any) fired first.

        The returned state always carries the new iterate, residual and search
        direction. Its ``step``, ``hits_boundary`` and ``converged`` fields come
        from the selected branch's result.

        Args:
          iterp: the state produced by the previous iteration.

        Returns:
          The state for the next iteration.
        """
        direction = iterp.d
        residual = iterp.r

        # Curvature along the search direction. ``hessvp`` comes from the
        # enclosing scope; presumably a Hessian-vector product.
        hess_dir = hessvp(direction)
        curvature = _dot(direction, hess_dir)

        # Step length and tentative next iterate.
        res_sq = _dot(residual, residual)
        alpha = res_sq / curvature
        z_next = iterp.z + alpha * direction

        # Residual at the tentative iterate.
        r_next = residual + alpha * hess_dir
        r_next_sq = _dot(r_next, r_next)

        # Slot 0 is a permanent False. argmax returns the index of the first
        # True entry, or 0 when every entry is False, so ``noop`` runs when no
        # stopping test fires.
        stop_tests = jnp.array([
            False,
            curvature <= 0,
            jnpla.norm(z_next, ord=norm) >= trust_radius,
            jnp.sqrt(r_next_sq) < tolerance,
        ])
        branch = jnp.argmax(stop_tests)
        outcome = lax.switch(branch, [noop, step1, step2, step3],
                             (iterp, z_next))

        # Next conjugate direction: beta = |r_next|^2 / |r|^2.
        beta = r_next_sq / res_sq
        d_next = -r_next + beta * direction

        return _CGSteihaugState(
            z=z_next,
            r=r_next,
            d=d_next,
            step=outcome.step,
            hits_boundary=outcome.hits_boundary,
            converged=outcome.converged,
        )
示例#3
0
 def test_check_jaxpr_cond_correct(self):
     jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(
         1.).jaxpr
     core.check_jaxpr(jaxpr)
示例#4
0
 def f():
     """Run a three-way lax.switch with index 1, selecting the ``err`` branch.

     Both outer branches ignore their operand and return ``()``. The operand
     passed to the switch is also ``()``.
     """
     def empty_branch(_):
         return ()

     return lax.switch(1, [empty_branch, err, empty_branch], ())