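def replace(eqn):
    # `in2out` is a variable-renaming table captured from the enclosing scope.
    # Only Vars found in it are rewritten; Literals and unmapped Vars pass
    # through, and the equation's outputs are left untouched.
    invars = [in2out[i] if (isinstance(i, jc.Var) and i in in2out) else i
              for i in eqn.invars]
    return jc.JaxprEqn(invars, eqn.outvars, eqn.primitive, eqn.params,
                       eqn.source_info)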
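`replace` renames an equation's inputs through a lookup table while leaving literals and unmapped variables alone. Here is a minimal, self-contained sketch of that substitution pattern. `Var`, `Literal`, `Eqn` and `substitute_invars` are hypothetical stand-ins, not JAX's own classes, and the equation record is simplified (no `params` or `source_info`).

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass(frozen=True)
class Var:
    # Hypothetical stand-in for a jaxpr variable; frozen so it can key a dict.
    name: str


@dataclass(frozen=True)
class Literal:
    # Hypothetical stand-in for an inline constant; never renamed.
    val: Any


@dataclass
class Eqn:
    # Hypothetical, simplified equation record (no params or source_info).
    invars: List[Any]
    outvars: List[Var]
    primitive: str


def substitute_invars(eqn: Eqn, in2out: Dict[Var, Var]) -> Eqn:
    # Rewrite only Vars present in the table; Literals and unmapped Vars
    # pass through, and a new Eqn is returned rather than mutating `eqn`.
    invars = [in2out[i] if isinstance(i, Var) and i in in2out else i
              for i in eqn.invars]
    return Eqn(invars, eqn.outvars, eqn.primitive)


a, b, c, out = Var("a"), Var("b"), Var("c"), Var("out")
eqn = Eqn([a, Literal(2.0), b], [out], "add3")
renamed = substitute_invars(eqn, {a: c})
print(renamed.invars)  # [Var(name='c'), Literal(val=2.0), Var(name='b')]
```

Building a fresh record instead of mutating the old one mirrors `replace`, which returns a new `JaxprEqn` and leaves the input equation intact.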
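def lazy_eval_jaxpr(jaxpr, consts, *args):
    # NOTE: `map` is assumed to be JAX's eager, length-checked `safe_map`
    # (JAX modules conventionally rebind `map = safe_map`). The Python 3
    # builtin `map` is lazy, so the bare `map(write, ...)` calls below would
    # silently do nothing with it.
    def read(v):
        # Literals carry their value inline; everything else is looked up in
        # env. `Literal_` is a second literal type defined outside this snippet.
        if type(v) in {jc.Literal, Literal_}:
            return v.val
        else:
            return env[v]

    def write(v, val):
        env[v] = val

    env = {}
    write(jc.unitvar, jc.unit)
    map(write, jaxpr.constvars, consts)
    map(write, jaxpr.invars, args)

    # Pass 1: reject nested call primitives and bind a placeholder LazyArray
    # (defined outside this snippet) to every equation output.
    for eqn in jaxpr.eqns:
        call_jaxpr, params = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
        if call_jaxpr:
            raise NotImplementedError
        map(write, eqn.outvars, map(LazyArray, eqn.outvars))

    # Pass 2: rebuild each equation over the values in env (constants, inputs
    # or placeholders) and attach it to its output placeholders, so each
    # LazyArray knows how to compute itself.
    for eqn in jaxpr.eqns:
        invals = map(read, eqn.invars)
        outvals = map(read, eqn.outvars)
        new_eqn = jc.JaxprEqn(invals, outvals, eqn.primitive, eqn.params,
                              eqn.source_info)
        map(lambda arr: arr.set_eqn(new_eqn), outvals)
    return map(read, jaxpr.outvars)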
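`lazy_eval_jaxpr` works in two passes. First it gives every equation output a placeholder, then it wires each placeholder to the equation that produces it, so nothing is computed until a result is demanded. The sketch below reproduces that idea on a toy graph. It assumes single-output equations and plain Python callables in place of primitives. `Var`, `Eqn`, `LazyArray`, `lazy_eval` and `force()` are all hypothetical, and the caching in `force()` is an addition of this sketch, not something the original shows.

```python
import operator
from typing import Any, Callable, Dict, List, Optional


class Var:
    # Hypothetical jaxpr-style variable; hashed by identity.
    def __init__(self, name: str):
        self.name = name


class Eqn:
    # Hypothetical single-output equation: outvar = fn(*invars).
    def __init__(self, invars: List[Var], outvar: Var, fn: Callable[..., Any]):
        self.invars, self.outvar, self.fn = invars, outvar, fn


class LazyArray:
    # Hypothetical placeholder that computes its value on first force()
    # and caches it for later calls.
    def __init__(self) -> None:
        self.inputs: Optional[List[Any]] = None
        self.fn: Optional[Callable[..., Any]] = None
        self._value: Any = None
        self._done = False

    def set_eqn(self, inputs: List[Any], fn: Callable[..., Any]) -> None:
        self.inputs, self.fn = inputs, fn

    def force(self) -> Any:
        if not self._done:
            args = [x.force() if isinstance(x, LazyArray) else x
                    for x in self.inputs]
            self._value = self.fn(*args)
            self._done = True
        return self._value


def lazy_eval(invars, eqns, outvars, *args):
    env: Dict[Var, Any] = dict(zip(invars, args))
    # Pass 1: give every equation output a placeholder first.
    for eqn in eqns:
        env[eqn.outvar] = LazyArray()
    # Pass 2: wire each placeholder to its inputs (plain values or other
    # placeholders). Nothing is computed here.
    for eqn in eqns:
        env[eqn.outvar].set_eqn([env[v] for v in eqn.invars], eqn.fn)
    return [env[v] for v in outvars]


calls = []


def traced_mul(a, b):
    calls.append("mul")  # record when the multiply actually runs
    return a * b


x, y, t, z = Var("x"), Var("y"), Var("t"), Var("z")
eqns = [Eqn([x, y], t, traced_mul), Eqn([t, x], z, operator.add)]
(result,) = lazy_eval([x, y], eqns, [z], 3, 4)
print(calls)           # [] : building the graph evaluates nothing
print(result.force())  # 15 = 3 * 4 + 3
```

Because every output already has a placeholder before any equation is wired, each equation can capture its inputs by reference, which is the same reason the original allocates all `LazyArray`s before calling `set_eqn`.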
def _scan_partial_eval(trace, *tracers, **kwargs):
    jaxpr = kwargs.pop('jaxpr')
    length = kwargs.pop('length')
    forward = kwargs.pop('forward')
    assert not kwargs
    in_pvs, _ = unzip2([t.pval for t in tracers])
    sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

    # Iterate to a fixed point: a carry component that is unknown on output
    # must also be treated as unknown on input, which can make more outputs
    # unknown on the next pass.
    sc_carry = sc_init
    for i in range(1000):
        second_components = (sc_consts, sc_carry, sc_xs)
        jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(
            jaxpr, second_components, instantiate=(sc_carry, False))
        sc_carry_out, sc_ys = sc_out
        if sc_carry_out == sc_carry:
            break
        else:
            sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
    else:
        raise FixedPointError

    # Lift the init tracer so its known/unknown split matches the fixed point.
    consts_tracer, init_tracer, xs_tracer = tracers
    lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
    lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
    in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

    carry_aval, y_aval = jaxpr.out_aval
    ys_aval = _promote_aval_rank(length, y_aval)
    out_aval = core.AbstractTuple((carry_aval, ys_aval))
    out_pv = _put_known_pvs(sc_out, out_aval)

    # Run the known part (jaxpr_1) now; it also emits residuals for jaxpr_2.
    out_carry, (ys, residuals) = scan_p.bind(
        *in_consts, forward=forward, length=length, jaxpr=jaxpr_1)
    out_const = core.pack((out_carry, ys))
    residuals_tracer = trace.new_instantiated_const(core.pack(residuals))

    # Stage out the unknown part (jaxpr_2), scanning over the residuals
    # alongside the original xs.
    d, c, a = lifted_tracers
    new_tracers = (d, c, (a, residuals_tracer))
    eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                        dict(forward=forward, length=length, jaxpr=jaxpr_2))
    return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
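The loop at the top of `_scan_partial_eval` is a fixed-point iteration on a two-point lattice (known < unknown). An unknown carry output forces the matching carry input to be unknown on the next pass, which can in turn make more outputs unknown. The sketch below models only that iteration, with flags as lists of booleans. The `body_unknowns` callback, the helper names, and the choice of elementwise `or` as the join are assumptions standing in for `pe.partial_eval_jaxpr` and `_binary_lattice_join`, not their actual code.

```python
from typing import Callable, List, Tuple

Flags = List[bool]  # True means "unknown at partial-evaluation time"


def binary_lattice_join(a: Flags, b: Flags) -> Flags:
    # Join on the two-point lattice known < unknown: elementwise `or`.
    return [x or y for x, y in zip(a, b)]


def carry_fixed_point(
    body_unknowns: Callable[[Flags, Flags, Flags], Tuple[Flags, Flags]],
    consts: Flags, init: Flags, xs: Flags, max_iters: int = 1000,
) -> Tuple[Flags, Flags]:
    carry = init
    for _ in range(max_iters):
        carry_out, ys = body_unknowns(consts, carry, xs)
        # Keep an unknown carry input unknown on output (roughly the role of
        # instantiate=(sc_carry, False) in the original), so carry_out >= carry.
        carry_out = binary_lattice_join(carry_out, carry)
        if carry_out == carry:
            return carry, ys
        carry = binary_lattice_join(carry, carry_out)
    raise RuntimeError("carry unknowns did not reach a fixed point")


def body_unknowns(consts, carry, xs):
    # Toy loop body: new_c0 = c1 + x, new_c1 = 2 * c0, y = c0.
    c0, c1 = carry
    (x,) = xs
    return [c1 or x, c0], [c0]


carry, ys = carry_fixed_point(body_unknowns, consts=[], init=[False, True], xs=[False])
print(carry, ys)  # [True, True] [True]
```

Starting with only `c1` unknown, the first pass makes `c0` unknown through `new_c0 = c1 + x`, and the second pass confirms nothing else changes. Since the lattice is finite and each pass only moves flags upward, the loop terminates in at most one more iteration than there are carry components.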
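def _pack_eqn(invars, outvar):
    # Positional fields: invars, outvars, primitive, bound_subjaxprs,
    # restructure, destructure, params. Packing sets neither flag.
    return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, False, {})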
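def _unpack_eqn(invar, outvars):
    # restructure=False, destructure=True: the single tuple-valued result of
    # identity_p is split across `outvars`.
    return core.JaxprEqn([invar], outvars, core.identity_p, (), False, True, {})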
def _add_any_eqn(tot, a, b):
    return core.JaxprEqn([a, b], [tot], ad_util.add_jaxvals_p, (), False, False, {})
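`_pack_eqn`, `_unpack_eqn` and `_add_any_eqn` all build the same positional record, and a dropped boolean leaves the call one argument short. The sketch below assumes the seven-field layout that `_scan_partial_eval` and `_add_any_eqn` pass (`invars, outvars, primitive, bound_subjaxprs, restructure, destructure, params`). It rebuilds the three helpers on a hypothetical keyword-only `make_eqn` so each call site names the flag it sets. The local `JaxprEqn` namedtuple and the string primitives are stand-ins, not JAX imports.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the seven positional fields passed to
# core.JaxprEqn in this file; not imported from JAX.
JaxprEqn = namedtuple(
    "JaxprEqn",
    ["invars", "outvars", "primitive", "bound_subjaxprs",
     "restructure", "destructure", "params"])


def make_eqn(invars, outvars, primitive, *, restructure=False,
             destructure=False, params=None):
    # Keyword-only flags make each call site say which flag it sets and
    # leave no room for a silently dropped positional argument.
    return JaxprEqn(list(invars), list(outvars), primitive, (),
                    restructure, destructure, params or {})


def pack_eqn(invars, outvar):
    return make_eqn(invars, [outvar], "pack")


def unpack_eqn(invar, outvars):
    return make_eqn([invar], outvars, "identity", destructure=True)


def add_any_eqn(tot, a, b):
    return make_eqn([a, b], [tot], "add_any")


try:
    JaxprEqn(["x"], ["y"], "pack", (), False, {})  # one positional arg short
except TypeError as e:
    print("6-arg call rejected:", e)
print(unpack_eqn("t", ["a", "b"]))
```

A namedtuple has no defaults here, so the short call fails loudly with a `TypeError` rather than building a malformed equation.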