def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an ``xmap`` call as one equation in the current dynamic jaxpr.

  The body ``f`` is traced on the input avals with their named axes removed
  (per ``params['in_axes']``), inside an axis environment extended with
  ``params['global_axis_sizes']``. Output avals are rebuilt by re-inserting
  the named axes at their local sizes, where each size comes from
  ``to_local`` on that axis' resource count. A single ``xmap_p`` equation
  over ``[*consts, *tracers]`` is then appended to ``self.frame``.

  Returns:
    A list of ``DynamicJaxprTracer``, one per output aval.
  """
  from jax.interpreters.partial_eval import (
      trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
      convert_constvars_jaxpr, new_jaxpr_eqn)
  assert primitive is xmap_p
  sizes_global = params['global_axis_sizes']
  in_axes = params['in_axes']

  # Trace the body on avals with the named axes stripped out.
  body_in_avals = [_delete_aval_axes(tracer.aval, axes)
                   for tracer, axes in zip(tracers, in_axes)]
  with core.extend_axis_env_nd(sizes_global.items()):
    jaxpr, body_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, body_in_avals)
  out_axes = params['out_axes_thunk']()

  # Put the named axes back, using each axis' local size.
  resource_count = _get_axis_resource_count(params['axis_resources'],
                                            params['resource_env'])
  sizes_local = {}
  for name, size in sizes_global.items():
    sizes_local[name] = resource_count[name].to_local(size)
  out_avals = [_insert_aval_axes(aval, axes, sizes_local)
               for aval, axes in zip(body_out_avals, out_axes)]
  _check_out_avals_vs_out_axes(out_avals, out_axes, sizes_global)

  src = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, aval, src) for aval in out_avals]
  # Look up vars in a fixed order: inputs, then consts, then outputs.
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)

  # Closed-over consts become leading operands of the equation. Each gets an
  # ``AxisNamePos(user_repr='{}')`` in_axes entry and is never donated.
  n_consts = len(consts)
  new_params = dict(
      params,
      in_axes=(AxisNamePos(user_repr='{}'),) * n_consts + in_axes,
      out_axes=out_axes,
      donated_invars=(False,) * n_consts + params['donated_invars'],
      call_jaxpr=convert_constvars_jaxpr(jaxpr))
  # The thunk was forced above; the staged equation stores concrete out_axes.
  del new_params['out_axes_thunk']
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive, new_params,
                      src)
  self.frame.eqns.append(eqn)
  return out_tracers
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an ``xmap`` call as one equation in the current dynamic jaxpr.

  The body ``f`` is traced on the input avals with their named axes removed
  (per ``params['in_axes']``), inside an axis environment extended with
  ``params['axis_sizes']``. The named axes are re-inserted into the output
  avals using those same sizes. Any updater registered for the primitive in
  ``call_param_updaters`` is applied to the new params. A single ``xmap_p``
  equation over ``[*consts, *tracers]`` is then appended to ``self.frame``.

  Returns:
    A list of ``DynamicJaxprTracer``, one per output aval.
  """
  from jax.interpreters.partial_eval import (
      trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
      convert_constvars_jaxpr, call_param_updaters, new_jaxpr_eqn)
  assert primitive is xmap_p
  sizes = params['axis_sizes']
  in_axes = params['in_axes']

  # Trace the body on avals with the named axes stripped out.
  body_in_avals = [_delete_aval_axes(tracer.aval, axes)
                   for tracer, axes in zip(tracers, in_axes)]
  with core.extend_axis_env_nd(sizes.items()):
    jaxpr, body_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, body_in_avals)
  out_axes = params['out_axes_thunk']()
  out_avals = [_insert_aval_axes(aval, axes, sizes)
               for aval, axes in zip(body_out_avals, out_axes)]

  src = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, aval, src) for aval in out_avals]
  # Look up vars in a fixed order: inputs, then consts, then outputs.
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)

  # Closed-over consts become leading operands with a ``None`` in_axes entry.
  new_params = dict(params,
                    in_axes=(None,) * len(consts) + in_axes,
                    out_axes=out_axes,
                    call_jaxpr=convert_constvars_jaxpr(jaxpr))
  # The thunk was forced above; the staged equation stores concrete out_axes.
  del new_params['out_axes_thunk']
  updater = call_param_updaters.get(primitive)
  if updater:
    # One ``True`` per original (non-const) operand. NOTE(review): the flag's
    # exact meaning is defined by the registered updater.
    new_params = updater(new_params, [True] * len(tracers))
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive, new_params,
                      src)
  self.frame.eqns.append(eqn)
  return out_tracers
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
              ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
  """Dead-code-eliminate a remat equation.

  Args:
    used_outputs: one flag per output of ``eqn``; True if the output is used.
    eqn: the remat equation. Its body is ``eqn.params['jaxpr']``.

  Returns:
    ``(used_inputs, new_eqn)``. ``used_inputs`` is the input-usage list that
    ``pe.dce_jaxpr`` reports for the pruned body. ``new_eqn`` is the
    equation restricted to the used inputs/outputs, carrying the pruned body
    and its effects. It is ``None`` when no input or output is used and the
    pruned body has no effects.
  """
  pruned, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
  # Drop the equation entirely only if it is fully dead and effect-free.
  if not (any(used_inputs) or any(used_outputs) or pruned.effects):
    return used_inputs, None
  kept_in = [var for var, keep in zip(eqn.invars, used_inputs) if keep]
  kept_out = [var for var, keep in zip(eqn.outvars, used_outputs) if keep]
  replacement = pe.new_jaxpr_eqn(
      kept_in, kept_out, eqn.primitive, dict(eqn.params, jaxpr=pruned),
      pruned.effects, eqn.source_info)
  return used_inputs, replacement
def _broadcast_staging_rule(trace, tracers, params):
  """Stage ``broadcast_p`` with a dynamic leading dimension ``d``.

  ``tracers`` is ``(x, d)``. The staged output has aval
  ``AbsArray((d, *x.shape), BaseType(dtype))``, where ``dtype`` is the
  element dtype of ``x``. If ``d`` has a known constant value, this raises
  ``NotImplementedError`` (not supported yet).
  """
  x, d = tracers
  if trace.get_const(d) is not None:
    raise NotImplementedError  # TODO
  aval = x.aval
  if isinstance(aval, AbsArray):
    dtype = aval._eltTy._dtype
  else:
    dtype = aval.dtype
  out_tracer = pe.DynamicJaxprTracer(
      trace, AbsArray((d, *x.shape), BaseType(dtype)), None)
  in_vars = [trace.getvar(x), trace.getvar(d)]
  out_vars = [trace.makevar(out_tracer)]
  trace.frame.eqns.append(
      pe.new_jaxpr_eqn(in_vars, out_vars, broadcast_p, {}, None))
  return out_tracer
def _iota_staging_rule(trace, tracers, params):
  """Stage ``iota_p`` for either a constant or a dynamic size operand.

  Constant size ``n``:
    ``n`` must be an ``int``; anything else raises ``NotImplementedError``.
    The output aval is ``ShapedArray((n,), int32)`` and the equation has no
    inputs and params ``{'size': n}``.

  Dynamic size:
    The size tracer's aval must be an ``AbsArray``; otherwise this raises
    ``TypeError``. The output's last dimension is the size tracer itself when
    its aval is rank 0. Otherwise it is a ``DimIndexingExpr`` over the tracer,
    placed after the size aval's own shape. The equation takes the size
    tracer as its single input.
  """
  tracer, = tracers
  int32 = np.dtype('int32')
  n = trace.get_const(tracer)
  if n is None:
    size_aval = tracer.aval
    if not isinstance(size_aval, AbsArray):
      raise TypeError
    if not size_aval.shape:
      out_aval = AbsArray((tracer,), BaseType(int32))
    else:
      expr = DimIndexingExpr(tracer, tuple(range(len(size_aval.shape))))
      out_aval = AbsArray((*size_aval.shape, expr), BaseType(int32))
    out_tracer = pe.DynamicJaxprTracer(trace, out_aval, None)
    # Keep the original order: create the output var before looking up the input.
    outvar = trace.makevar(out_tracer)
    eqn = pe.new_jaxpr_eqn([trace.getvar(tracer)], [outvar], iota_p, {}, None)
  else:
    if type(n) is not int:
      raise NotImplementedError  # TODO batched version?
    out_tracer = pe.DynamicJaxprTracer(
        trace, core.ShapedArray((n,), int32), None)
    eqn = pe.new_jaxpr_eqn([], [trace.makevar(out_tracer)], iota_p,
                           dict(size=n), None)
  trace.frame.eqns.append(eqn)
  return out_tracer
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for ``scan_p``.

  Splits the scan body into a known part (``jaxpr_1``) and an unknown part
  (``jaxpr_2``). The known part is bound immediately on the known inputs.
  The unknown part is recorded as a staged ``scan_p`` equation whose inputs
  are the instantiated tracers (known inputs replaced by unit literals)
  followed by the residuals produced by the known scan.

  Args:
    trace: the partial-evaluation trace.
    *tracers: operands laid out as ``[*consts, *init_carry, *xs]``.
    **kwargs: scan params ``forward``, ``length``, ``num_consts``,
      ``num_carry``, ``jaxpr`` and ``linear``.

  Returns:
    One ``pe.JaxprTracer`` per scan output (carries, then ys). Known outputs
    hold the value computed by the known scan. Unknown outputs hold their
    aval and are produced by the staged equation.

  Raises:
    FixedPointError: if the carry-unknown flags do not stabilize within 1000
      rounds.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  num_ys = len(jaxpr.out_avals) - num_carry

  # A tracer is unknown when its PartialVal carries an abstract value.
  unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Re-run partial eval until the carry-unknown flags produced by the body
  # match the ones fed in (carries are instantiated so they stay aligned).
  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out = out_uk[:num_carry]
    if carry_uk_out == carry_uk:
      break
    carry_uk = carry_uk_out
  else:
    raise FixedPointError

  # Known scan: unknown operands are fed as units.
  in_consts = [core.unit if uk else t.pval[1] for uk, t in zip(unknowns, tracers)]
  # Staged scan: known operands are fed as unit literals.
  new_tracers = [trace.instantiate_const(t) if uk
                 else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]

  # ys are stacked along a new leading axis of size ``length``.
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1,
      num_consts=num_consts, num_carry=num_carry, linear=linear_1)
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  out_consts = out_carry + ys
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), dict(forward=forward, length=length, jaxpr=jaxpr_2,
                                  num_consts=num_consts, num_carry=num_carry,
                                  linear=linear_2))
  for t in out_tracers:
    t.recipe = eqn
  return out_tracers
def partial_eval_jaxpr(jaxpr, in_unknowns):
  """Split a ``DJaxpr`` into a known part and an unknown part.

  An equation goes into the known part when none of its inputs are unknown.
  Otherwise it goes into the unknown part and all its outputs become unknown.
  Every known var read by an unknown equation becomes a residual. Each
  residual is output by ``jaxpr1`` and input to ``jaxpr2``, once per
  distinct var.

  Args:
    jaxpr: the ``DJaxpr`` to split.
    in_unknowns: one bool per ``jaxpr.in_binders``; True marks an unknown
      input.

  Returns:
    ``(jaxpr1, jaxpr2, out_unknowns, num_res)``.

    - ``jaxpr1``: the known equations. Its outputs are the known
      ``jaxpr.outs`` followed by the residuals.
    - ``jaxpr2``: the unknown equations. Its inputs are the residuals
      followed by the unknown in_binders.
    - ``out_unknowns``: flags which ``jaxpr.outs`` are unknown.
    - ``num_res``: the number of residuals.

  Raises:
    NotImplementedError: if an equation input or output is a ``core.Literal``.
  """
  env: Dict[Var, bool] = {}
  res = []
  res_seen = set()

  def read(v):
    if type(v) is core.Literal:
      raise NotImplementedError  # TODO
    else:
      return env[v]

  def write(unk, v):
    env[v] = unk

  def new_res(v):
    # A known var read by several unknown equations is ONE residual.
    # Appending it repeatedly would bind the same var twice among jaxpr2's
    # inputs and emit a duplicated output from jaxpr1.
    if v not in res_seen:
      res_seen.add(v)
      res.append(v)
    return v

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks = map(read, eqn.invars)
    if any(unks):
      invars = [v if unk else new_res(v) for unk, v in zip(unks, eqn.invars)]
      eqns2.append(pe.new_jaxpr_eqn(invars, eqn.outvars, eqn.primitive,
                                    eqn.params, None))
      map(partial(write, True), eqn.outvars)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.outvars)
  out_unknowns = map(read, jaxpr.outs)
  out_dim_unknowns = map(read, jaxpr.out_dims)  # when linearizing, all known

  invars1, invars2 = partition_list(in_unknowns, jaxpr.in_binders)
  outvars1, outvars2 = partition_list(out_unknowns, jaxpr.outs)
  out_dims1, out_dims2 = partition_list(out_dim_unknowns, jaxpr.out_dims)

  outvars1 = outvars1 + res
  invars2 = res + invars2

  # TODO forward the correct residuals here (all dimvars used in types)
  in_dimvars2 = out_dims1 + jaxpr.in_dim_binders

  jaxpr1 = DJaxpr(jaxpr.in_dim_binders, invars1, out_dims1, outvars1, eqns1)
  jaxpr2 = DJaxpr(in_dimvars2, invars2, out_dims2, outvars2, eqns2)

  return jaxpr1, jaxpr2, out_unknowns, len(res)
def _nonzero_staging_rule(trace, tracers, params):
  """Stage ``nonzero_p``, which produces a count var and an index array.

  ``bound`` is the input's last dimension, or that dimension's ``_bound`` when
  it is not an ``int``. Two output tracers are created:

  - the count, with aval ``AbsArray(shape[:-1], BoundedIntTy(bound))``;
  - the int32 index array, whose last dimension is the count tracer itself
    for rank-1 inputs, or a ``DimIndexingExpr`` over it otherwise.

  Both are outputs of the staged equation; only the index tracer is returned.
  Raises ``NotImplementedError`` when the input is an ``AbsArray`` whose
  element type is not a ``BaseType``.
  """
  in_aval = tracers[0].aval
  if isinstance(in_aval, AbsArray) and not isinstance(in_aval._eltTy, BaseType):
    raise NotImplementedError
  last = in_aval.shape[-1]
  bound = last if isinstance(last, int) else last._bound
  lead = in_aval.shape[:-1]
  count_tracer = pe.DynamicJaxprTracer(
      trace, AbsArray(lead, BoundedIntTy(bound)), None)

  if len(in_aval.shape) == 1:
    last_dim = count_tracer
  else:
    last_dim = DimIndexingExpr(count_tracer, tuple(range(len(lead))))
  index_tracer = pe.DynamicJaxprTracer(
      trace, AbsArray((*lead, last_dim), BaseType(np.dtype('int32'))), None)

  invars = map(trace.getvar, tracers)
  outvars = map(trace.makevar, [count_tracer, index_tracer])
  trace.frame.eqns.append(
      pe.new_jaxpr_eqn(invars, outvars, nonzero_p, {}, None))
  return index_tracer