def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule for ``remat_p``.

  Splits ``jaxpr`` according to which inputs are known. The known part is
  evaluated eagerly here with ``core.eval_jaxpr``. Its outputs are the known
  outputs followed by residuals. The unknown part is staged out as a new
  ``remat_p`` equation (with ``differentiated=True``). That equation's inputs
  are the residual tracers followed by the instantiated input tracers,
  filtered by ``in_used_unknown`` after DCE.

  Args:
    trace: trace on which the new tracers and the staged equation are built.
    *tracers: one tracer per input of ``jaxpr``.
    jaxpr: the remat body; must have no constvars (asserted).
    **params: remaining ``remat_p`` params, forwarded to the staged equation.
      If ``params['policy']`` is falsy, it is replaced by a policy that
      returns False for every query (presumably "save nothing"; confirm
      against ``pe._partial_eval_jaxpr_custom``).

  Returns:
    The known-output tracers and the staged unknown-output tracers, combined
    by ``pe._zip_knowns`` according to ``out_unknowns``.
  """
  assert not jaxpr.constvars
  policy = params['policy'] or (lambda *_, **__: False)
  # unzip into jaxpr_known and jaxpr_unknown
  in_unknowns = [not t.is_known() for t in tracers]
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
      pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  # DCE jaxpr_known while keeping every one of its outputs.
  jaxpr_known, in_used_known = pe.dce_jaxpr(
      jaxpr_known, [True] * len(jaxpr_known.outvars))
  # DCE jaxpr_unknown, keeping only instantiated outputs which are unknown.
  _, used_outs_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)
  # compute known outputs and residuals (hoisted out of remat primitive)
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  # Select the known input constants flagged by in_used_known.
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  out_consts_ = iter(out_consts)
  # form known outputs and collect residual tracers
  # One value is consumed per known output (in order); whatever remains in
  # the iterator afterwards is the residuals.
  out_known_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
      for uk in out_unknowns if not uk]
  residuals = list(out_consts_)
  # NOTE(review): jaxpr_known was DCE'd keeping all outputs, so residuals
  # that in_used_unknown drops below are still computed above.
  # set up unknown outputs with a recipe to call remat
  res_tracers = map(trace.new_instantiated_const, residuals)
  in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
  # Keep only the (residual + input) tracers the DCE'd jaxpr_unknown uses.
  _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
  out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, source_info_util.current())
  for t in out_jaxpr_tracers: t.recipe = recipe
  # zip together known and unknown outputs
  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
def _axis_index_bind(*, axis_name):
  """Stage an ``axis_index_p`` equation on the pmap trace for ``axis_name``.

  Looks up ``axis_name`` in the thread-local dynamic axis environment and
  takes that frame's ``pmap_trace``. On that trace it creates a fresh unknown
  ``JaxprTracer`` with a scalar int32 ``ShapedArray`` aval. The tracer's
  recipe is a new ``axis_index_p`` equation that has no inputs.

  Returns:
    The new output tracer.
  """
  pmap_trace = pxla._thread_local_state.dynamic_axis_env[axis_name].pmap_trace
  result = pe.JaxprTracer(
      pmap_trace, pe.PartialVal.unknown(ShapedArray((), np.int32)), None)
  result.recipe = pe.new_eqn_recipe(
      [], [result], axis_index_p, {'axis_name': axis_name},
      source_info_util.current())
  return result
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for ``scan_p`` over flat operand lists.

  First finds which carry components are unknown. Starting from the unknown
  mask of the init operands, it re-runs ``pe.partial_eval_jaxpr`` until the
  output carry mask equals the input carry mask. The known half (``jaxpr_1``)
  is then run immediately with ``scan_p.bind``, with unknown inputs replaced
  by ``core.unit``. The unknown half (``jaxpr_2``) is staged as a new
  ``scan_p`` equation, and the residuals produced by the known scan are
  appended as its trailing inputs.

  Args:
    trace: trace on which new tracers and the staged equation are built.
    *tracers: consts, then init carry, then xs tracers.
    **kwargs: the ``scan_p`` params ``forward``, ``length``, ``num_consts``,
      ``num_carry``, ``jaxpr`` and ``linear``.

  Returns:
    One ``JaxprTracer`` per scan output (carry outputs, then ys).

  Raises:
    FixedPointError: if the carry mask does not stabilize within 1000
      iterations.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  num_ys = len(jaxpr.out_avals) - num_carry

  # A tracer is unknown when the first PartialVal component is not None.
  unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Fixed point on the carry mask: carry outputs are instantiated wherever
  # the corresponding carry input is unknown, and we stop once the output
  # mask reproduces the input mask.
  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out = out_uk[:num_carry]
    if carry_uk_out == carry_uk:
      break
    carry_uk = carry_uk_out
  else:
    raise FixedPointError

  # The known scan receives unit in place of each unknown input. The staged
  # scan receives a unit literal in place of each known input.
  in_consts = [core.unit if uk else t.pval[1]
               for uk, t in zip(unknowns, tracers)]
  new_tracers = [trace.instantiate_const(t) if uk
                 else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1,
      num_consts=num_consts, num_carry=num_carry, linear=linear_1)
  # Known scan outputs: carry, then ys, then residuals for the staged scan.
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  out_consts = out_carry + ys
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), dict(forward=forward, length=length, jaxpr=jaxpr_2,
                                  num_consts=num_consts, num_carry=num_carry,
                                  linear=linear_2))
  for t in out_tracers:
    t.recipe = eqn
  return out_tracers
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule for ``remat_p``.

  Uses ``pe.partial_eval_jaxpr_custom`` to split ``jaxpr`` into a known jaxpr
  and a staged jaxpr. Both are then DCE'd: the staged one keeps only outputs
  that are unknown and instantiated, and the known one keeps every known
  output plus only the residuals the staged jaxpr still uses. The known
  jaxpr is evaluated eagerly here. The staged jaxpr becomes a new ``remat_p``
  equation (``differentiated=True``) whose inputs are the residual tracers
  followed by the staged input tracers it uses.

  Args:
    trace: trace on which the new tracers and the staged equation are built.
    *tracers: one tracer per input of ``jaxpr``.
    jaxpr: the remat body; must have no constvars (asserted).
    **params: remaining ``remat_p`` params, forwarded to the staged equation.
      A falsy ``params['policy']`` falls back to ``nothing_saveable``.

  Returns:
    One tracer or value per output of ``jaxpr``: known values and staged
    tracers combined with ``merge_lists`` according to ``out_unknowns``.
  """
  assert not jaxpr.constvars
  policy = params['policy'] or nothing_saveable
  in_unknowns = [not t.is_known() for t in tracers]
  jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
      pe.partial_eval_jaxpr_custom(
          jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)
  # DCE jaxpr_staged, keeping only instantiated outputs which are unknown
  _, out_inst_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_staged = pe.dce_jaxpr(jaxpr_staged, out_inst_unknown)
  # The staged jaxpr's inputs are num_res residuals followed by the
  # original inputs; split the usage mask accordingly.
  used_res, in_used_staged = split_list(in_used_staged, [num_res])
  # DCE jaxpr_known, keeping all known outputs but discarding dce'd res
  out_used_known = [True] * (len(out_unknowns) - sum(out_unknowns)) + used_res
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
  # From here on num_res counts only the residuals that survived DCE.
  num_res = sum(used_res)
  # compute known outputs and residuals (hoisted out of remat primitive)
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  # Known outputs come first and the surviving residuals are the tail.
  out_knowns, residuals = split_list(out_consts, [len(out_consts) - num_res])
  # set up unknown outputs with a recipe to call remat
  res_tracers = map(trace.new_instantiated_const, residuals)
  _, tracers_staged = partition_list(in_used_staged, tracers)
  # `map` must return a list here, since its results are joined with `+`.
  in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged)
  out_jaxpr_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
      for x in jaxpr_unknown.outvars
  ]
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, jaxpr_unknown.effects,
                             source_info_util.current())
  for t in out_jaxpr_tracers: t.recipe = recipe
  # zip together known and unknown outputs
  return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for ``scan_p`` with tuple-structured operands.

  ``tracers`` is exactly ``(consts, init, xs)``. The unknown structure of
  each operand (``pe.unknown`` of its PartialVal's first component) seeds a
  fixed-point iteration over the carry. On each round the carry structure is
  widened with ``_binary_lattice_join`` until ``pe.partial_eval_jaxpr``
  reports the same carry structure it was given. The init tracer is then
  lifted to that carry structure. The known half (``jaxpr_1``) runs via
  ``scan_p.bind``. The unknown half (``jaxpr_2``) is staged as a ``scan_p``
  equation whose xs operand is paired with the packed residuals from the
  known scan.

  Args:
    trace: trace on which new tracers are created.
    *tracers: the ``(consts, init, xs)`` tracers.
    **kwargs: must contain exactly ``jaxpr``, ``length`` and ``forward``.

  Returns:
    A single ``JaxprTracer`` whose PartialVal pairs ``out_pv`` (from
    ``_put_known_pvs``) with the packed known ``(carry, ys)`` constant, and
    whose recipe is the staged equation.

  Raises:
    FixedPointError: if the carry structure does not stabilize within 1000
      iterations.
  """
  jaxpr = kwargs.pop('jaxpr')
  length = kwargs.pop('length')
  forward = kwargs.pop('forward')
  assert not kwargs  # only jaxpr, length and forward are expected here
  in_pvs, _ = unzip2([t.pval for t in tracers])
  sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

  sc_carry = sc_init
  for _ in range(1000):
    second_components = (sc_consts, sc_carry, sc_xs)
    jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(
        jaxpr, second_components, instantiate=(sc_carry, False))
    sc_carry_out, _ = sc_out
    if sc_carry_out == sc_carry:
      break
    sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
  else:
    raise FixedPointError

  consts_tracer, init_tracer, xs_tracer = tracers
  lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
  lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
  _, in_consts = unzip2([t.pval for t in lifted_tracers])

  carry_aval, y_aval = jaxpr.out_aval
  ys_aval = _promote_aval_rank(length, y_aval)
  out_aval = core.AbstractTuple((carry_aval, ys_aval))
  out_pv = _put_known_pvs(sc_out, out_aval)

  # The known scan also returns the residuals needed by the staged scan.
  out_carry, (ys, residuals) = scan_p.bind(*in_consts, forward=forward,
                                           length=length, jaxpr=jaxpr_1)
  out_const = core.pack((out_carry, ys))
  residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
  d, c, a = lifted_tracers
  # Residuals travel alongside xs in the staged scan's third operand.
  new_tracers = (d, c, (a, residuals_tracer))
  eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                      dict(forward=forward, length=length, jaxpr=jaxpr_2))
  return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
def _dynamic_xla_call_pe(trace, *tracers, jaxpr, num_consts):
  """Partial-evaluation rule for ``dynamic_xla_call_p``.

  The leading ``len(jaxpr.in_dim_binders)`` tracers are input-dimension
  tracers, and all of them must be known. The remaining tracers are split
  into known and unknown. The known part of ``jaxpr`` (``jaxpr1``) is bound
  immediately on the known values and yields known outputs plus residuals.
  The unknown part (``jaxpr2``) is staged as a new ``dynamic_xla_call_p``
  equation whose inputs are the dimension tracers, then the residuals, then
  the unknown tracers.

  Returns:
    One entry per output of ``jaxpr``: a known value where the output is
    known, otherwise a staged tracer, in the original output order.

  Raises:
    NotImplementedError: if any input-dimension tracer is unknown.
  """
  in_dim_tracers, tracers = split_list(tracers, [len(jaxpr.in_dim_binders)])
  if any(not t.pval.is_known() for t in in_dim_tracers):
    raise NotImplementedError
  in_unknowns = [not t.pval.is_known() for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)

  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.get_known() for t in known_tracers]
  in_dim_vals = [t.pval.get_known() for t in in_dim_tracers]
  outs1_res = dynamic_xla_call_p.bind(*in_dim_vals, *known_vals, jaxpr=jaxpr1,
                                      num_consts=num_consts)
  # jaxpr1's trailing num_res outputs are residuals for the staged call.
  outs1, res = split_list(outs1_res, [len(jaxpr1.outs) - num_res])

  # The dimension tracers are JaxprTracers of `trace` with known values, so
  # instantiate them via instantiate_const (as the cond rule does for its
  # known index). new_instantiated_const is reserved for plain values such as
  # the residuals below; previously the tracers themselves were passed to it.
  in_dim_inst_tracers = map(trace.instantiate_const, in_dim_tracers)
  res_tracers = map(trace.new_instantiated_const, res)
  outs2 = [pe.JaxprTracer(trace, pe.PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = pe.new_eqn_recipe(in_dim_inst_tracers + res_tracers + unknown_tracers,
                          outs2, dynamic_xla_call_p,
                          dict(jaxpr=jaxpr2, num_consts=0), None)
  for t in outs2:
    t.recipe = eqn
  outs1_it, outs2_it = iter(outs1), iter(outs2)
  return [next(outs2_it) if uk else next(outs1_it) for uk in out_unknowns]
def _cond_partial_eval(trace, *tracers, branches, linear):
  """Partial-evaluation rule for ``cond_p``.

  If the branch index is unknown, the whole cond is staged out through
  ``trace.default_process_primitive``. Otherwise each branch is split into a
  known jaxpr and an unknown jaxpr. An output counts as unknown when it is
  unknown in any branch. The known branches, padded to a common residual
  set, run now via ``cond_p.bind``. The unknown branches are staged as a new
  ``cond_p`` equation whose inputs are the index, then the residuals, then
  the unknown operands.

  Returns:
    One entry per cond output, built with ``util.merge_lists``: the known
    constants and the staged tracers, selected by the unknown-output mask.
  """
  uk_mask = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = uk_mask

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    # TODO(mattjj): remove this path when old remat is removed
    return trace.default_process_primitive(
        cond_p, tracers, dict(branches=branches, linear=linear))

  # First pass: find each branch's unknown outputs without instantiation.
  per_branch_uks = []
  for jaxpr in branches:
    _, _, uks, _ = pe.partial_eval_jaxpr_nounits(jaxpr, ops_uk,
                                                 instantiate=False)
    per_branch_uks.append(uks)
  out_uks = [any(uks) for uks in zip(*per_branch_uks)]

  # Second pass: split every branch with the joint unknown-output mask.
  splits = [pe.partial_eval_jaxpr_nounits(jaxpr, ops_uk, instantiate=out_uks)
            for jaxpr in branches]
  branches_known = [known for known, _, _, _ in splits]
  branches_unknown = [unknown for _, unknown, _, _ in splits]
  branch_res_avals = [res_avals for _, _, _, res_avals in splits]

  all_res_avals, res_avals_per_branch = _merge_branch_residuals(
      branch_res_avals)
  num_res = len(all_res_avals)
  num_known_outs = len(out_uks) - sum(out_uks)
  branches_known = _join_cond_outputs(
      branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
  branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
      branches_unknown, all_res_avals, res_avals_per_branch)
  assert all(
      all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
      for j in branches_known[1:])

  # Run the known branches now; the trailing num_res outputs are residuals.
  known_vals = [t.pval.get_known() for t in tracers if t.pval.is_known()]
  known_linear = tuple(l for l, uk in zip(linear, ops_uk) if not uk)
  out_consts_res = cond_p.bind(*known_vals, branches=branches_known,
                               linear=known_linear)
  out_consts, res = split_list(out_consts_res,
                               [len(out_consts_res) - num_res])

  index_tracer = trace.instantiate_const(tracers[0])
  ops_tracers = [trace.instantiate_const(t)
                 for t, uk in zip(tracers[1:], ops_uk) if uk]
  res_tracers = _map(trace.new_instantiated_const, res)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
                 for aval in branches_unknown[0].out_avals]
  unknown_linear = tuple(
      [False] * num_res + [l for l, uk in zip(linear, ops_uk) if uk])
  params = dict(branches=branches_unknown, linear=unknown_linear)

  # Attribute the staged equation to the part of the current name stack that
  # lies beyond this trace's own name stack.
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  eqn = pe.new_eqn_recipe([index_tracer, *res_tracers, *ops_tracers],
                          out_tracers, cond_p, params, core.no_effects, source)
  for t in out_tracers:
    t.recipe = eqn
  return util.merge_lists(out_uks, out_consts, out_tracers)