Exemplo n.º 1
0
 def replace(eqn):
     """Return a copy of ``eqn`` whose input variables are renamed via ``in2out``.

     Only inputs that are ``jc.Var`` instances AND keys of ``in2out`` (a free
     name resolved outside this function) are substituted; every other input
     (e.g. literals, unmapped variables) is passed through unchanged. The
     output variables, primitive, params and source info are copied as-is.
     """
     def _substitute(var):
         # Non-Var inputs are never looked up in the mapping.
         if not isinstance(var, jc.Var):
             return var
         if var not in in2out:
             return var
         return in2out[var]

     renamed = [_substitute(var) for var in eqn.invars]
     return jc.JaxprEqn(renamed, eqn.outvars, eqn.primitive, eqn.params,
                        eqn.source_info)
Exemplo n.º 2
0
def lazy_eval_jaxpr(jaxpr, consts, *args):
    """Build a graph of ``LazyArray`` placeholders for ``jaxpr`` without evaluating it.

    Every equation output variable is bound to a fresh ``LazyArray``. Each
    equation is then re-created as a ``jc.JaxprEqn`` whose inputs and outputs
    are the resolved values (constants, arguments, literal values or other
    ``LazyArray`` objects), and that equation is attached to each of its
    output ``LazyArray`` objects via ``set_eqn``.

    Args:
      jaxpr: the jaxpr to walk.
      consts: values bound, in order, to ``jaxpr.constvars``.
      *args: values bound, in order, to ``jaxpr.invars``.

    Returns:
      A list holding the value bound to each of ``jaxpr.outvars``.

    Raises:
      ValueError: if the number of consts/args differs from the number of
        constvars/invars.
      NotImplementedError: if any equation carries a call sub-jaxpr (as
        reported by ``jc.extract_call_jaxpr``).
    """
    env = {}

    def read(v):
        # Literals carry their value inline; every other var is looked up.
        if type(v) in {jc.Literal, Literal_}:
            return v.val
        return env[v]

    def write_all(variables, values):
        # Explicit loop (not ``map``) so the writes happen even when ``map``
        # is Python 3's lazy builtin; the length check keeps mismatches loud.
        variables, values = list(variables), list(values)
        if len(variables) != len(values):
            raise ValueError(
                "length mismatch: {} variables but {} values".format(
                    len(variables), len(values)))
        for var, val in zip(variables, values):
            env[var] = val

    env[jc.unitvar] = jc.unit
    write_all(jaxpr.constvars, consts)
    write_all(jaxpr.invars, args)

    # First pass: reject call primitives and allocate a LazyArray placeholder
    # for every equation output before any equation is linked.
    for eqn in jaxpr.eqns:
        call_jaxpr, _ = jc.extract_call_jaxpr(eqn.primitive, eqn.params)
        if call_jaxpr:
            raise NotImplementedError(
                "lazy_eval_jaxpr does not support call primitives: {}".format(
                    eqn.primitive))
        for var in eqn.outvars:
            env[var] = LazyArray(var)

    # Second pass: rebuild each equation over resolved values and attach it
    # to the LazyArray placeholders it produces.
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = [read(v) for v in eqn.outvars]
        new_eqn = jc.JaxprEqn(invals, outvals, eqn.primitive, eqn.params,
                              eqn.source_info)
        for arr in outvals:
            arr.set_eqn(new_eqn)

    return [read(v) for v in jaxpr.outvars]
Exemplo n.º 3
0
def _scan_partial_eval(trace, *tracers, **kwargs):
    """Partial-evaluation rule for ``scan_p``.

    Splits the scan into a part that is bound immediately via ``scan_p.bind``
    using ``jaxpr_1`` on the known inputs, and an equation using ``jaxpr_2``
    that is recorded on the returned tracer together with the residuals
    produced by the first part.

    Args:
      trace: the trace that owns ``tracers``; used for
        ``new_instantiated_const`` and attached to the returned tracer.
      *tracers: exactly three tracers, unpacked as (consts, init, xs).
      **kwargs: must contain exactly ``jaxpr``, ``length`` and ``forward``;
        anything else trips the ``assert not kwargs`` check.

    Returns:
      A ``pe.JaxprTracer`` whose ``PartialVal`` pairs ``out_pv`` with the
      packed ``(out_carry, ys)`` from the bound scan, and whose recipe is the
      equation over ``jaxpr_2``.

    Raises:
      FixedPointError: if the carry status does not stabilise within 1000
        iterations of the loop below.
    """
    jaxpr = kwargs.pop('jaxpr')
    length = kwargs.pop('length')
    forward = kwargs.pop('forward')
    # Only the three parameters popped above are accepted.
    assert not kwargs
    # One ``pe.unknown`` result per operand group (consts, init, xs),
    # presumably flagging which parts are not known at trace time.
    in_pvs, _ = unzip2([t.pval for t in tracers])
    sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

    # Fixed-point iteration on the carry: if partially evaluating the body
    # yields a carry-out status different from the carry-in status, join the
    # two and retry. NOTE(review): loop index ``i`` is unused.
    sc_carry = sc_init
    for i in range(1000):
        second_components = (sc_consts, sc_carry, sc_xs)
        jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(jaxpr,
                                                         second_components,
                                                         instantiate=(sc_carry,
                                                                      False))
        sc_carry_out, sc_ys = sc_out
        if sc_carry_out == sc_carry:
            break
        else:
            sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
    else:
        # for/else: reached only when the loop never hit ``break``.
        raise FixedPointError

    # Lift the init tracer to match the converged ``sc_carry`` (see
    # _lift_tracer), then re-read the partial values of all three operands.
    # NOTE(review): the rebound ``in_pvs`` is not used below.
    consts_tracer, init_tracer, xs_tracer = tracers
    lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
    lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
    in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

    # Output aval: carry aval as-is; per-step y aval promoted by ``length``
    # (presumably the stacked-ys rank). Known parts of it are filled from
    # ``sc_out`` of the final loop iteration.
    carry_aval, y_aval = jaxpr.out_aval
    ys_aval = _promote_aval_rank(length, y_aval)
    out_aval = core.AbstractTuple((carry_aval, ys_aval))
    out_pv = _put_known_pvs(sc_out, out_aval)

    # Run the known part now with jaxpr_1; its result unpacks as
    # (carry, (ys, residuals)).
    out_carry, (ys, residuals) = scan_p.bind(*in_consts,
                                             forward=forward,
                                             length=length,
                                             jaxpr=jaxpr_1)
    out_const = core.pack((out_carry, ys))
    # Residuals become a constant tracer fed to the deferred jaxpr_2 scan,
    # paired with the xs tracer in the third operand slot.
    residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
    d, c, a = lifted_tracers
    new_tracers = (d, c, (a, residuals_tracer))
    # NOTE(review): positional arguments follow this JAX version's
    # ``core.JaxprEqn`` field order; confirm what ``None``, ``()``, ``True``
    # and ``False`` correspond to there.
    eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                        dict(forward=forward, length=length, jaxpr=jaxpr_2))
    return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
Exemplo n.º 4
0
def _pack_eqn(invars, outvar):
    """Build an equation applying ``core.pack_p`` to ``invars``.

    The equation has the single output ``outvar`` and empty params. The
    remaining positional slots get ``()`` and ``False``.
    NOTE(review): the meaning of those slots depends on this version's
    ``core.JaxprEqn`` field order.
    """
    outvars = [outvar]
    params = {}
    return core.JaxprEqn(invars, outvars, core.pack_p, (), False, params)
Exemplo n.º 5
0
def _unpack_eqn(invar, outvars):
    """Build an equation applying ``core.identity_p`` to the single ``invar``.

    ``outvars`` become the equation's outputs and params are empty. The
    remaining positional slots get ``()`` and ``True``.
    NOTE(review): the meaning of those slots depends on this version's
    ``core.JaxprEqn`` field order.
    """
    invars = [invar]
    params = {}
    return core.JaxprEqn(invars, outvars, core.identity_p, (), True, params)
Exemplo n.º 6
0
def _add_any_eqn(tot, a, b):
    """Build an equation applying ``ad_util.add_jaxvals_p`` to ``a`` and ``b``.

    The result is written to the single output ``tot`` and params are empty.
    NOTE(review): the ``()``, ``False``, ``False`` slots follow this
    version's ``core.JaxprEqn`` field order.
    """
    operands = [a, b]
    results = [tot]
    params = {}
    return core.JaxprEqn(operands, results, ad_util.add_jaxvals_p,
                         (), False, False, params)