Beispiel #1
0
def cond(pred, true_operand, true_fun, false_operand, false_fun):
  """Choose between two traced branch functions based on `pred`.

  Each branch function is traced on its own operand into a jaxpr. The two
  branches must produce outputs with the same pytree structure, and their
  output partial values must be joinable by `pe.join_pvals`. The
  (revised) jaxprs, their packed constants, the flat operands and `pred`
  are then bound to the `cond_p` primitive.

  Args:
    pred: predicate; passed as the first operand of `cond_p.bind`.
    true_operand: pytree operand that is flattened and traced with
      `true_fun`.
    true_fun: callable for the "true" branch.
    false_operand: pytree operand that is flattened and traced with
      `false_fun`.
    false_fun: callable for the "false" branch.

  Returns:
    The output of `cond_p`, merged with the joined output pval and
    unflattened into the branches' shared output tree structure.

  Raises:
    TypeError: if the branch outputs have different tree structures, or if
      their output pvals cannot be joined.
  """
  def _trace_branch(branch_fun, branch_operand):
    # Flatten the operand, wrap the function so it consumes/produces flat
    # values, then trace it against the abstractified flat operand.
    flat_operand, operand_tree = pytree_to_flatjaxtuple(branch_operand)
    flat_fun, out_tree_thunk = pytree_fun_to_flatjaxtuple_fun(
        lu.wrap_init(branch_fun), (operand_tree,))
    jaxpr, out_pval, consts = pe.trace_to_jaxpr(
        flat_fun, (_abstractify(flat_operand),))
    # The output tree is returned as a thunk; it is only queried after
    # tracing has run.
    return flat_operand, jaxpr, consts, out_pval, out_tree_thunk

  (t_op, t_jaxpr, t_consts,
   t_pval, t_tree) = _trace_branch(true_fun, true_operand)
  (f_op, f_jaxpr, f_consts,
   f_pval, f_tree) = _trace_branch(false_fun, false_operand)

  out_tree = t_tree()
  if out_tree != f_tree():
    raise TypeError(
        "true_fun and false_fun outputs must have identical structure")

  try:
    joined_pval = pe.join_pvals(t_pval, f_pval)
  except TypeError:
    template = ("could not merge true_fun and false_fun output pvals: "
                "{} and {}.")
    raise TypeError(template.format(t_pval, f_pval))

  # NOTE(review): presumably rewrites each branch so its output agrees with
  # the joined pval — confirm against _revise_cond_jaxpr.
  t_jaxpr, t_consts = _revise_cond_jaxpr(joined_pval, t_pval, t_jaxpr,
                                         t_consts)
  f_jaxpr, f_consts = _revise_cond_jaxpr(joined_pval, f_pval, f_jaxpr,
                                         f_consts)
  aval_out, _unused = joined_pval

  raw_out = cond_p.bind(
      pred,
      t_op, core.pack(t_consts),
      f_op, core.pack(f_consts),
      aval_out=aval_out, true_jaxpr=t_jaxpr, false_jaxpr=f_jaxpr)
  return tree_unflatten(out_tree, pe.merge_pvals(raw_out, joined_pval))
Beispiel #2
0
def check_trace_eval(f, pvals, vals, expected_out_pval):
    """Check that tracing `f` and evaluating the jaxpr matches calling `f`.

    Traces `f` with `pvals` via `api.trace_to_jaxpr` and asserts that the
    resulting output pval equals `expected_out_pval`. The jaxpr is then
    evaluated on `vals` and merged back with the output pval. The result is
    compared against the direct evaluation `f(*vals)` using `onp.allclose`.

    Args:
        f: function under test.
        pvals: partial values used for tracing.
        vals: concrete argument values for evaluation.
        expected_out_pval: the output pval tracing is expected to produce.

    Raises:
        AssertionError: if the output pval differs from the expected one, or
            the traced-and-evaluated result is not close to `f(*vals)`.
    """
    jaxpr, consts, out_pval, _unused = api.trace_to_jaxpr(f, pvals)
    assert expected_out_pval == out_pval, (expected_out_pval, out_pval)

    jaxpr_result = core.eval_jaxpr(jaxpr, consts, (), *vals)
    traced = pe.merge_pvals(jaxpr_result, out_pval)
    direct = f(*vals)

    # The message is formatted only when the assertion fails.
    assert onp.allclose(traced, direct), (
        '\neval:         {}\ntrace + eval: {}'.format(direct, traced))