def trace_jaxpr(fun, operand):
    """Trace ``fun`` on the flattened, abstracted form of ``operand``.

    ``operand`` is flattened with ``pytree_to_flatjaxtuple``; ``fun`` is
    wrapped so that it accepts that flat form; the wrapped function is then
    passed to ``pe.trace_to_jaxpr`` together with ``lax._abstractify`` applied
    to the flat operand.

    Returns:
        A 5-tuple ``(op_flat, jaxpr, consts, pvout, out_tree)``: the flattened
        operand, the three results of ``pe.trace_to_jaxpr`` (in a different
        order than the tracer returns them), and the output-tree object
        produced while wrapping ``fun``.
    """
    flat_operand, operand_tree = pytree_to_flatjaxtuple(operand)

    # Adapt `fun` to the flat calling convention expected by the tracer.
    wrapped = lu.wrap_init(fun)
    flat_fun, result_tree = pytree_fun_to_flatjaxtuple_fun(
        wrapped, (operand_tree,))

    abstract_operand = lax._abstractify(flat_operand)
    traced = pe.trace_to_jaxpr(flat_fun, (abstract_operand,))
    jaxpr, out_pval, constants = traced

    return flat_operand, jaxpr, constants, out_pval, result_tree
def _tscan(f, a, bs, fields=(0, )):
    """
    Variant of jax.lax.scan that takes an extra `fields` argument naming
    which fields of `a`'s structure are needed. By default only the first
    field is selected; the remaining fields are filled with None.

    `f`, `a` and `bs` play the same roles as in jax.lax.scan; `fields` is
    flattened and forwarded to the `tscan_p` primitive unchanged.
    """
    # Note: code is copied and modified from lax.scan implementation in
    # [JAX](https://github.com/google/jax) to support the additional `fields`
    # arg. Original code has the following copyright:
    #
    # Copyright 2018 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License")

    # Flatten every pytree argument into a flat jaxtuple; only the tree
    # structures of `a` and `bs` are needed afterwards.
    flat_a, tree_a = pytree_to_flatjaxtuple(a)
    flat_bs, tree_b = pytree_to_flatjaxtuple(bs)
    flat_fields, _ = pytree_to_flatjaxtuple(fields)
    wrapped = wrap_init(f)
    flat_f, out_tree = pytree_fun_to_flatjaxtuple_fun(wrapped, (tree_a, tree_b))

    # Abstract values for the flattened inputs.
    aval_a, _ = lax._abstractify(flat_a)
    aval_bs, _ = lax._abstractify(flat_bs)

    # One abstract element per entry of `bs`, with the leading axis dropped
    # (presumably the axis being scanned over, as in lax.scan).
    element_avals = [
        ShapedArray(stacked.shape[1:], stacked.dtype) for stacked in aval_bs
    ]
    aval_b = core.AbstractTuple(element_avals)

    # Pair each abstract value with `core.unit` as a PartialVal, then trace
    # the flat function to obtain its jaxpr.
    pval_a = partial_eval.PartialVal((aval_a, core.unit))
    pval_b = partial_eval.PartialVal((aval_b, core.unit))
    traced = partial_eval.trace_to_jaxpr(flat_f, (pval_a, pval_b))
    jaxpr, out_pval, raw_consts = traced
    out_aval, _ = out_pval
    packed_consts = core.pack(raw_consts)

    result = tscan_p.bind(
        flat_a,
        flat_bs,
        flat_fields,
        packed_consts,
        aval_out=out_aval,
        jaxpr=jaxpr,
    )
    # Rebuild the caller-facing pytree from the flat primitive output.
    return tree_unflatten(out_tree(), result)