예제 #1
0
파일: checkify.py 프로젝트: xueeinstein/jax
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
    """Checkify rule for ``lax.while_p``.

    The error state ``(err, code, payload)`` is threaded through the loop as
    three extra carried values placed in front of the original carry. Before
    the loop is bound, a checkified ``cond`` is applied once to the initial
    carry, so an error raised by the first predicate evaluation is captured.
    The loop itself then runs the body produced by
    ``checkify_while_body_jaxpr`` and a ``cond`` produced by
    ``ignore_errors_jaxpr``.

    NOTE(review): ``checkify_while_body_jaxpr`` also receives ``cond_jaxpr`` and
    the cond consts, presumably so that later cond applications are checked
    inside the body. Confirm this against that helper.

    Args:
      error: incoming ``Error``. Its ``err``/``code``/``payload`` seed the eager
        first cond check, and its ``msgs`` are merged into the result.
      enabled_errors: forwarded to ``checkify_jaxpr`` and
        ``checkify_while_body_jaxpr``.
      *in_flat: while operands laid out as
        ``[*cond_consts, *body_consts, *carry]``.
      cond_nconsts, body_nconsts: sizes of the two const groups in ``in_flat``.
      cond_jaxpr, body_jaxpr: the original loop predicate and body.

    Returns:
      ``(out, Error)``: the loop's original outputs, and an ``Error`` built from
      the final carried error state together with the merged messages.
    """
    c_consts, b_consts, init_carry = split_list(in_flat,
                                                [cond_nconsts, body_nconsts])
    n_carry = len(init_carry)

    def consts_first(nconsts):
        # Mask over binders [err, code, payload, *consts, *carry] selecting the
        # consts, so move_binders_to_front puts them ahead of the error binders.
        return [False] * 3 + [True] * nconsts + [False] * n_carry

    # Eagerly apply a checkified cond to the initial carry so that an error in
    # the very first predicate evaluation is recorded before the loop starts.
    checked_cond, cond_msgs = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
    err0, code0, payload0, _pred = core.jaxpr_as_fun(checked_cond)(
        error.err, error.code, error.payload, *c_consts, *init_carry)
    del checked_cond  # only needed for the eager first-cond check above

    body_checked, body_msgs = checkify_while_body_jaxpr(
        cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
    body_for_loop = pe.move_binders_to_front(body_checked,
                                             consts_first(body_nconsts))

    cond_ignoring = ignore_errors_jaxpr(cond_jaxpr, error)
    cond_for_loop = pe.move_binders_to_front(cond_ignoring,
                                             consts_first(cond_nconsts))

    # while_p operand order: cond consts, body consts, then the carry, which now
    # starts with the three error values.
    loop_operands = [*c_consts, *b_consts, err0, code0, payload0, *init_carry]
    final_err, final_code, final_payload, *out = lax.while_p.bind(
        *loop_operands,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_for_loop,
        body_nconsts=body_nconsts,
        body_jaxpr=body_for_loop)

    # Later dicts win on key collisions: incoming < body < cond.
    merged_msgs = {**error.msgs, **body_msgs, **cond_msgs}
    return out, Error(final_err, final_code, merged_msgs, final_payload)
예제 #2
0
파일: checkify.py 프로젝트: xueeinstein/jax
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
    """Checkify rule for ``lax.scan_p``.

    Threads the error state ``(err, code, payload)`` through the scan as three
    extra loop-carried values placed in front of the original carry, and runs a
    checkified version of the scan body.

    Args:
      error: incoming ``Error``. Its ``err``/``code``/``payload`` seed the
        carried error state, and its ``msgs`` are merged into the result.
      enabled_errors: forwarded unchanged to ``checkify_jaxpr``.
      *in_flat: scan operands laid out as ``[*consts, *carry, *xs]``.
      reverse, length, unroll: forwarded unchanged to ``lax.scan_p``.
      jaxpr: the original scan body.
      num_consts, num_carry: sizes of the consts and carry groups in
        ``in_flat``.
      linear: one flag per operand in ``in_flat``.

    Returns:
      ``(outs, Error)``: the scan's original outputs, and an ``Error`` built
      from the final carried error state together with the merged messages.
    """
    consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
    checked_jaxpr_, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
    # The checkified body's binders are [err, code, payload, *consts, *carry,
    # *xs] (as this mask assumes). Move the consts to the front so the layout
    # becomes [*consts, err, code, payload, *carry, *xs], matching scan_p's
    # consts -> carry -> xs operand order with the error state leading the carry.
    tomove = [False] * 3 + [True] * len(consts) + [False] * (len(carry) + len(xs))
    checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
    new_in_flat = [*consts, error.err, error.code, error.payload, *carry, *xs]
    # `linear` holds one flag per original operand, so the three (non-linear)
    # error flags must go at the same position as the error operands, right
    # after the consts. Prepending them would shift each const's flag onto the
    # wrong operand.
    n_consts = len(consts)
    new_linear = (*linear[:n_consts], False, False, False, *linear[n_consts:])
    err, code, payload, *outs = lax.scan_p.bind(*new_in_flat,
                                                reverse=reverse,
                                                length=length,
                                                jaxpr=checked_jaxpr,
                                                num_consts=n_consts,
                                                num_carry=len(carry) + 3,
                                                linear=new_linear,
                                                unroll=unroll)
    new_msgs = {**error.msgs, **msgs_}
    return outs, Error(err, code, new_msgs, payload)