def cond(pred, true_operand, true_fun, false_operand, false_fun):
    """Trace both branch functions and bind them through ``cond_p``.

    Each branch function is traced to a jaxpr against an abstractified,
    flattened version of its operand. The two output structures must match
    and the two output partial values must be joinable; each branch jaxpr is
    then revised against the joined partial value before binding.

    Args:
      pred: passed through unchanged as the first argument of
        ``cond_p.bind`` (presumably the value that selects the branch --
        confirm against the primitive's implementation).
      true_operand: pytree operand handed to ``true_fun``.
      true_fun: function traced for the first branch.
      false_operand: pytree operand handed to ``false_fun``.
      false_fun: function traced for the second branch.

    Returns:
      The bound result merged with the joined partial value and unflattened
      with the output tree structure of ``true_fun``.

    Raises:
      TypeError: if the two branches produce different output tree
        structures, or if ``pe.join_pvals`` rejects their output partial
        values with a TypeError.
    """

    def _trace_branch(branch_fun, branch_operand):
        # Flatten the operand, wrap the function so it works on the flat
        # representation, and trace it to a jaxpr.
        flat_operand, operand_tree = pytree_to_flatjaxtuple(branch_operand)
        flat_fun, result_tree = pytree_fun_to_flatjaxtuple_fun(
            lu.wrap_init(branch_fun), (operand_tree,))
        branch_jaxpr, branch_pval, branch_consts = pe.trace_to_jaxpr(
            flat_fun, (_abstractify(flat_operand),))
        return (flat_operand, branch_jaxpr, branch_consts, branch_pval,
                result_tree)

    # The true branch is traced before the false branch.
    (true_op, true_jaxpr, true_consts,
     true_pval, true_tree) = _trace_branch(true_fun, true_operand)
    (false_op, false_jaxpr, false_consts,
     false_pval, false_tree) = _trace_branch(false_fun, false_operand)

    # The output trees are thunks; both must describe the same structure.
    if true_tree() != false_tree():
        raise TypeError(
            "true_fun and false_fun outputs must have identical structure")

    try:
        joined_pval = pe.join_pvals(true_pval, false_pval)
    except TypeError:
        template = (
            "could not merge true_fun and false_fun output pvals: {} and {}.")
        raise TypeError(template.format(true_pval, false_pval))

    # Revise each branch's jaxpr and constants against the joined pval.
    true_jaxpr, true_consts = _revise_cond_jaxpr(
        joined_pval, true_pval, true_jaxpr, true_consts)
    false_jaxpr, false_consts = _revise_cond_jaxpr(
        joined_pval, false_pval, false_jaxpr, false_consts)

    joined_aval, _ = joined_pval
    bind_params = dict(
        aval_out=joined_aval,
        true_jaxpr=true_jaxpr,
        false_jaxpr=false_jaxpr,
    )
    bound = cond_p.bind(
        pred,
        true_op, core.pack(true_consts),
        false_op, core.pack(false_consts),
        **bind_params)

    merged = pe.merge_pvals(bound, joined_pval)
    return tree_unflatten(true_tree(), merged)
def check_trace_eval(f, pvals, vals, expected_out_pval):
    """Assert that tracing ``f`` and evaluating the jaxpr agrees with ``f``.

    ``f`` is traced with ``api.trace_to_jaxpr`` using ``pvals``. The resulting
    output partial value must equal ``expected_out_pval``. The jaxpr is then
    evaluated on ``vals``, merged with the output partial value, and compared
    with a direct call ``f(*vals)`` using ``onp.allclose``.

    Args:
      f: function under test.
      pvals: partial values passed to ``api.trace_to_jaxpr``.
      vals: concrete arguments, unpacked both into ``core.eval_jaxpr`` and
        into the direct call of ``f``.
      expected_out_pval: output partial value the trace is expected to yield.

    Raises:
      AssertionError: if the traced output partial value differs from the
        expected one, or if the two evaluation paths are not close.
    """
    traced_jaxpr, traced_consts, traced_out_pval, _ = api.trace_to_jaxpr(
        f, pvals)
    assert expected_out_pval == traced_out_pval, (
        expected_out_pval, traced_out_pval)

    # Evaluate via the traced jaxpr, then merge with the output pval.
    via_trace = pe.merge_pvals(
        core.eval_jaxpr(traced_jaxpr, traced_consts, (), *vals),
        traced_out_pval)

    # Evaluate the function directly for comparison.
    direct = f(*vals)

    assert onp.allclose(via_trace, direct), (
        '\neval: {}\ntrace + eval: {}'.format(direct, via_trace))