def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Bind ``lax.while_p`` with the error state threaded through the loop.

  The flat operands are split into cond constants, body constants and the
  initial carry. The error triple (``err``, ``code``, ``payload``) is placed
  between the constants and the carry of the new loop, and comes back as the
  first three outputs.

  Args:
    error: incoming error; its ``err``, ``code``, ``payload`` and ``msgs``
      attributes are read.
    enabled_errors: forwarded unchanged to ``checkify_jaxpr`` and
      ``checkify_while_body_jaxpr``.
    *in_flat: flat operands, ordered as ``cond_nconsts`` cond constants,
      ``body_nconsts`` body constants, then the carry.
    cond_nconsts: number of leading operands that are cond constants.
    cond_jaxpr: the loop predicate.
    body_nconsts: number of operands (after the cond constants) that are
      body constants.
    body_jaxpr: the loop body.

  Returns:
    A pair ``(out, new_error)``: ``out`` is the list of loop outputs that
    follow the three error values, and ``new_error`` is an ``Error`` built
    from those values with the merged message table.
  """
  cond_consts, body_consts, init_carry = split_list(
      in_flat, [cond_nconsts, body_nconsts])
  n_carry = len(init_carry)

  def hoist_consts(jaxpr, nconsts):
    # The mask selects the `nconsts` binders that follow the first three
    # binders (the error values); the trailing carry binders stay in place.
    mask = [False] * 3 + [True] * nconsts + [False] * n_carry
    return pe.move_binders_to_front(jaxpr, mask)

  # Check if the first cond application will error.
  first_cond_jaxpr, msgs_cond = checkify_jaxpr(
      cond_jaxpr, error, enabled_errors)
  cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(first_cond_jaxpr)(
      error.err, error.code, error.payload, *cond_consts, *init_carry)
  del first_cond_jaxpr

  # Error-checked body; cond_jaxpr and the cond constants are handed to the
  # helper as well.
  raw_body_jaxpr, msgs_body = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, error, enabled_errors, cond_consts)
  checked_body_jaxpr = hoist_consts(raw_body_jaxpr, body_nconsts)

  # Predicate produced by ignore_errors_jaxpr, with its constants moved to
  # the front in the same way as the body's.
  raw_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
  compat_cond_jaxpr = hoist_consts(raw_cond_jaxpr, cond_nconsts)

  # The error values seeding the loop are the ones obtained from the first
  # cond application above, not the incoming ones.
  loop_operands = [
      *cond_consts,
      *body_consts,
      cond_err,
      cond_code,
      cond_payload,
      *init_carry,
  ]
  loop_params = dict(
      cond_nconsts=cond_nconsts,
      cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=body_nconsts,
      body_jaxpr=checked_body_jaxpr,
  )
  err, code, payload, *out = lax.while_p.bind(*loop_operands, **loop_params)

  # On key collisions later tables win: body over incoming, cond over body.
  new_msgs = {
      **error.msgs,
      **msgs_body,
      **msgs_cond,
  }
  return out, Error(err, code, new_msgs, payload)
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Bind ``lax.scan_p`` with the error state added to the scan carry.

  The flat operands are split into constants, carry and scanned inputs. The
  error triple (``err``, ``code``, ``payload``) is inserted between the
  constants and the carry, the carry count grows by three, and the updated
  triple comes back as the first three outputs.

  Args:
    error: incoming error; its ``err``, ``code``, ``payload`` and ``msgs``
      attributes are read.
    enabled_errors: forwarded unchanged to ``checkify_jaxpr``.
    *in_flat: flat operands, ordered as ``num_consts`` constants,
      ``num_carry`` carry values, then the scanned inputs.
    reverse: forwarded unchanged to ``lax.scan_p.bind``.
    length: forwarded unchanged to ``lax.scan_p.bind``.
    jaxpr: the scan body.
    num_consts: number of leading operands that are constants.
    num_carry: number of operands after the constants that are carry values.
    linear: sequence of flags; three ``False`` entries are prepended.
    unroll: forwarded unchanged to ``lax.scan_p.bind``.

  Returns:
    A pair ``(outs, new_error)``: ``outs`` is the list of scan outputs that
    follow the three error values, and ``new_error`` is an ``Error`` built
    from those values with the merged message table.
  """
  consts, init_carry, scanned = split_list(in_flat, [num_consts, num_carry])
  n_consts = len(consts)
  n_carry = len(init_carry)

  raw_checked_jaxpr, body_msgs = checkify_jaxpr(jaxpr, error, enabled_errors)

  # The mask selects the constant binders that follow the first three
  # binders (the error values); carry and scanned-input binders stay put.
  binder_mask = (
      [False] * 3 + [True] * n_consts + [False] * (n_carry + len(scanned)))
  checked_jaxpr = pe.move_binders_to_front(raw_checked_jaxpr, binder_mask)

  # NOTE(review): the three extra flags are prepended here, while the error
  # values are placed *after* the constants in `scan_operands` below. If
  # `linear` is positional per operand, the flags are shifted relative to the
  # operands whenever there are constants and a flag is True -- confirm
  # whether this is intended.
  new_linear = (False, False, False, *linear)

  scan_operands = [
      *consts,
      error.err,
      error.code,
      error.payload,
      *init_carry,
      *scanned,
  ]
  scan_params = dict(
      reverse=reverse,
      length=length,
      jaxpr=checked_jaxpr,
      num_consts=n_consts,
      num_carry=n_carry + 3,  # The error triple rides along in the carry.
      linear=new_linear,
      unroll=unroll,
  )
  err, code, payload, *outs = lax.scan_p.bind(*scan_operands, **scan_params)

  # On key collisions the body's messages override the incoming ones.
  new_msgs = {**error.msgs, **body_msgs}
  return outs, Error(err, code, new_msgs, payload)