def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an ``xmap`` application out as a single equation in the current frame.

  The body ``f`` is traced once on the input avals with their ``in_axes``
  removed (via ``_delete_aval_axes``), inside an axis environment extended
  with ``params['axis_sizes']``. The output avals are then rebuilt by
  re-inserting ``out_axes`` (via ``_insert_aval_axes``). One ``xmap`` equation
  is appended to ``self.frame.eqns``, and fresh output tracers are returned.

  NOTE(review): this file defines ``_dynamic_jaxpr_process_xmap`` twice; if
  both definitions are at module level, the later one replaces this one.

  Args:
    self: the staging trace (assumed to be a ``DynamicJaxprTrace``, judging by
      the imports below -- confirm at the registration site).
    primitive: must be ``xmap_p``.
    f: the body to trace.
    tracers: input tracers; expected to line up with ``params['in_axes']``.
    params: ``xmap`` parameters. Reads ``axis_sizes``, ``in_axes`` (a tuple)
      and ``out_axes_thunk``.

  Returns:
    A list of new ``DynamicJaxprTracer`` objects, one per traced body output.
  """
  from jax.interpreters.partial_eval import (
      trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
      convert_constvars_jaxpr, call_param_updaters, new_jaxpr_eqn)
  assert primitive is xmap_p

  arg_avals = [tracer.aval for tracer in tracers]
  axis_sizes = params['axis_sizes']
  body_in_avals = [_delete_aval_axes(aval, axes)
                   for aval, axes in zip(arg_avals, params['in_axes'])]
  # Trace the body with the named axes (and their sizes) in scope.
  with core.extend_axis_env_nd(axis_sizes.items()):
    body_jaxpr, body_out_avals, body_consts = trace_to_subjaxpr_dynamic(
        f, self.main, body_in_avals)

  # The output axes come from a thunk that is only forced after tracing.
  out_axes = params['out_axes_thunk']()
  result_avals = [_insert_aval_axes(aval, axes, axis_sizes)
                  for aval, axes in zip(body_out_avals, out_axes)]
  source_info = source_info_util.current()
  results = [DynamicJaxprTracer(self, aval, source_info)
             for aval in result_avals]

  # These trace methods presumably register variables/constants with the
  # current frame, so their call order is kept fixed: inputs, then
  # constants, then outputs.
  arg_vars = map(self.getvar, tracers)
  const_vars = map(self.getvar, map(self.instantiate_const, body_consts))
  result_vars = map(self.makevar, results)

  # Closed-over constants become extra leading inputs of the equation
  # (see the [*const_vars, *arg_vars] ordering below). Their in_axes entry
  # is None.
  # NOTE(review): the other definition of this function in this file uses an
  # empty AxisNamePos for constants instead of None. Confirm consumers of
  # in_axes accept None.
  eqn_params = {key: value for key, value in params.items()
                if key != 'out_axes_thunk'}
  eqn_params.update(
      in_axes=(None,) * len(body_consts) + params['in_axes'],
      out_axes=out_axes,
      call_jaxpr=convert_constvars_jaxpr(body_jaxpr))

  updater = call_param_updaters.get(primitive)
  if updater:
    eqn_params = updater(eqn_params, [True] * len(tracers))

  self.frame.eqns.append(
      new_jaxpr_eqn([*const_vars, *arg_vars], result_vars, primitive,
                    eqn_params, source_info))
  return results
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an ``xmap`` application out as a single equation in the current frame.

  The body ``f`` is traced once on the input avals with their ``in_axes``
  removed (via ``_delete_aval_axes``), inside an axis environment extended
  with ``params['global_axis_sizes']``.

  The output avals are rebuilt by re-inserting ``out_axes`` (via
  ``_insert_aval_axes``). This uses axis sizes obtained by mapping each
  global size through ``_get_axis_resource_count(...)[axis].to_local``
  (presumably the per-process size -- confirm). The result is checked with
  ``_check_out_avals_vs_out_axes``.

  One ``xmap`` equation is appended to ``self.frame.eqns``, and fresh output
  tracers are returned.

  Args:
    self: the staging trace (assumed to be a ``DynamicJaxprTrace``, judging by
      the imports below -- confirm at the registration site).
    primitive: must be ``xmap_p``.
    f: the body to trace.
    tracers: input tracers; expected to line up with ``params['in_axes']``.
    params: ``xmap`` parameters. Reads ``global_axis_sizes``, ``in_axes``
      (a tuple), ``out_axes_thunk``, ``axis_resources``, ``resource_env`` and
      ``donated_invars`` (a tuple).

  Returns:
    A list of new ``DynamicJaxprTracer`` objects, one per traced body output.
  """
  from jax.interpreters.partial_eval import (
      trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
      convert_constvars_jaxpr, new_jaxpr_eqn)
  assert primitive is xmap_p

  arg_avals = [tracer.aval for tracer in tracers]
  global_axis_sizes = params['global_axis_sizes']
  body_in_avals = [_delete_aval_axes(aval, axes)
                   for aval, axes in zip(arg_avals, params['in_axes'])]
  # Trace the body with the named axes (at their global sizes) in scope.
  with core.extend_axis_env_nd(global_axis_sizes.items()):
    body_jaxpr, body_out_avals, body_consts = trace_to_subjaxpr_dynamic(
        f, self.main, body_in_avals)

  # The output axes come from a thunk that is only forced after tracing.
  out_axes = params['out_axes_thunk']()

  # Outputs are materialized with sizes produced by .to_local() on each
  # global axis size, rather than with the global sizes themselves.
  resource_counts = _get_axis_resource_count(params['axis_resources'],
                                             params['resource_env'])
  local_axis_sizes = {
      axis: resource_counts[axis].to_local(size)
      for axis, size in global_axis_sizes.items()}
  result_avals = [_insert_aval_axes(aval, axes, local_axis_sizes)
                  for aval, axes in zip(body_out_avals, out_axes)]
  _check_out_avals_vs_out_axes(result_avals, out_axes, global_axis_sizes)

  source_info = source_info_util.current()
  results = [DynamicJaxprTracer(self, aval, source_info)
             for aval in result_avals]

  # These trace methods presumably register variables/constants with the
  # current frame, so their call order is kept fixed: inputs, then
  # constants, then outputs.
  arg_vars = map(self.getvar, tracers)
  const_vars = map(self.getvar, map(self.instantiate_const, body_consts))
  result_vars = map(self.makevar, results)

  # Closed-over constants become extra leading inputs of the equation (see
  # the [*const_vars, *arg_vars] ordering below). Each gets an empty
  # AxisNamePos (no mapped axes) and is never donated, so in_axes and
  # donated_invars stay aligned with the equation's inputs.
  num_consts = len(body_consts)
  eqn_params = {key: value for key, value in params.items()
                if key != 'out_axes_thunk'}
  eqn_params.update(
      in_axes=(AxisNamePos(user_repr='{}'),) * num_consts + params['in_axes'],
      out_axes=out_axes,
      donated_invars=(False,) * num_consts + params['donated_invars'],
      call_jaxpr=convert_constvars_jaxpr(body_jaxpr))

  self.frame.eqns.append(
      new_jaxpr_eqn([*const_vars, *arg_vars], result_vars, primitive,
                    eqn_params, source_info))
  return results