def mat_power(mat_m, p):
    """Raise a square matrix to the power ``p``, for ``p`` in {1, 2, 4, 8}.

    Args:
      mat_m: a square matrix.
      p: a positive integer; expected to be 1, 2, 4 or 8.

    Returns:
      ``mat_m`` raised to the power ``p``.
    """
    # The loop is unrolled for performance: one dedicated helper per supported
    # power, ordered so that position k holds the helper named for 2**k.
    unrolled_branches = [
        _unrolled_mat_pow_1,
        _unrolled_mat_pow_2,
        _unrolled_mat_pow_4,
        _unrolled_mat_pow_8,
    ]

    # round(log2(p)) picks the position in the list above; lax.switch needs an
    # integer index, hence the int32 conversion.
    log2_p = jnp.round(jnp.log2(p))
    branch_index = jnp.asarray(log2_p, jnp.int32)

    return lax.switch(branch_index, unrolled_branches, mat_m)
def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
    """Advance the iteration by one step and return the new state.

    Args:
      iterp: current state; its ``z``, ``r`` and ``d`` fields are read here.

    Returns:
      A new ``_CGSteihaugState`` holding the updated ``z``, ``r`` and ``d``,
      with ``step``, ``hits_boundary`` and ``converged`` taken from whichever
      branch (``noop``, ``step1``, ``step2`` or ``step3``) was selected.
    """
    direction = iterp.d
    residual = iterp.r

    # Hessian-vector product along the current direction, and d . (B d).
    hess_dir = hessvp(direction)
    curvature = _dot(direction, hess_dir)

    # Step length along the direction and the resulting candidate iterate.
    residual_sq = _dot(residual, residual)
    step_len = residual_sq / curvature
    z_candidate = iterp.z + step_len * direction

    # Residual after taking that step.
    residual_new = residual + step_len * hess_dir
    residual_new_sq = _dot(residual_new, residual_new)

    # argmax returns the position of the first True entry. The leading
    # constant False is a placeholder: when none of the real conditions hold,
    # every entry is False, argmax yields 0 and ``noop`` runs.
    nonpositive_curvature = curvature <= 0
    outside_radius = jnpla.norm(z_candidate, ord=norm) >= trust_radius
    residual_small = jnp.sqrt(residual_new_sq) < tolerance
    conditions = jnp.array([
        False,
        nonpositive_curvature,
        outside_radius,
        residual_small,
    ])
    branch_index = jnp.argmax(conditions)
    outcome = lax.switch(
        branch_index,
        [noop, step1, step2, step3],
        (iterp, z_candidate),
    )

    # Build the search direction for the next iteration.
    ratio = residual_new_sq / residual_sq
    direction_new = -residual_new + ratio * direction

    return _CGSteihaugState(
        z=z_candidate,
        r=residual_new,
        d=direction_new,
        step=outcome.step,
        hits_boundary=outcome.hits_boundary,
        converged=outcome.converged,
    )
def test_check_jaxpr_cond_correct(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))( 1.).jaxpr core.check_jaxpr(jaxpr)
def f():
    """Call ``lax.switch`` with index 1 over three branches and an empty operand.

    The middle branch is ``err``; the first and last are separate no-op
    callables that ignore their argument and return an empty tuple.
    """

    def leading_noop(_):
        return ()

    def trailing_noop(_):
        return ()

    return lax.switch(1, [leading_noop, err, trailing_noop], ())