def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts,
                         body_nconsts):
  cond_const_tracers, body_const_tracers, init_tracers = split_list(
      tracers, [cond_nconsts, body_nconsts])
  init_avals = safe_map(lambda x: x.aval, init_tracers)
  cond_const_vals, body_const_vals, init_vals = tree_map(
      lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))
  body_fun = jaxpr_as_fun(body_jaxpr)
  cond_fun = jaxpr_as_fun(cond_jaxpr)

  def cond(*carry):
    return cond_fun(*it.chain(cond_const_vals, carry))

  def body(*carry):
    return body_fun(*it.chain(body_const_vals, carry))

  new_cond = callback_transform(cond, trace.callback,
                                strip_calls=trace.strip_calls)  # type: ignore
  new_body = callback_transform(body, trace.callback,
                                strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(init_avals)
  new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(
      new_cond, in_tree, tuple(init_avals))
  new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(
      new_body, in_tree, tuple(init_avals))
  out = lcf.while_p.bind(
      *it.chain(new_cond_consts, new_body_consts, init_vals),
      cond_nconsts=len(new_cond_consts),
      body_nconsts=len(new_body_consts),
      cond_jaxpr=new_cond_jaxpr,
      body_jaxpr=new_body_jaxpr)
  return safe_map(trace.pure, out)
def start_tracing_body(self):
  """Called upon starting the tracing of the loop body."""
  # TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
  self.trace = self.scope.start_subtrace()
  # The entire state is carried.
  self.carried_state_names = sorted(self.scope._mutable_state.keys())
  for key in self.carried_state_names:
    init_val = self.scope._mutable_state[key]
    flat_init_vals, init_tree = tree_util.tree_flatten(init_val)
    flat_init_avals = safe_map(_BodyTracer.abstractify, flat_init_vals)
    flat_init_pvals = safe_map(pe.PartialVal.unknown, flat_init_avals)
    flat_init_vars = safe_map(self.trace.new_arg, flat_init_pvals)
    self.carried_state_vars[key] = flat_init_vars
    # Set the scope._mutable_state to new tracing variables.
    self.scope._mutable_state[key] = init_tree.unflatten(flat_init_vars)
    self.scope._mutable_state_aval[key] = init_tree.unflatten(flat_init_avals)
    # Make a copy of the initial state by unflattening the flat_init_vals.
    self.carried_state_initial[key] = init_tree.unflatten(flat_init_vals)

  index_var_aval = _BodyTracer.abstractify(0)
  index_var_pval = pe.PartialVal.unknown(index_var_aval)
  self._index_var = self.trace.new_arg(index_var_pval)
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  # TODO: This is the final part of partial_eval.trace_to_subjaxpr. Share.
  instantiate = [instantiate] * len(out_tracers)
  out_tracers = safe_map(trace.full_raise,
                         safe_map(core.full_lower, out_tracers))
  out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                         instantiate, out_tracers)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
  assert not env  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  return closed_jaxpr, consts
def __setattr__(self, key, value):
  """Update scope data to be functionalized.

  Called for *all* attribute setting.
  """
  if key in ["_active_ranges", "_mutable_state", "_mutable_state_aval",
             "_count_subtraces"]:
    object.__setattr__(self, key, value)
  else:
    if self._active_ranges:
      if key not in self._mutable_state:
        raise ValueError(
            "New mutable state '{}' cannot be created inside a loop."
            .format(key))
      assert key in self._mutable_state_aval
      old_aval = self._mutable_state_aval[key]
      flat_values, flat_tree = tree_util.tree_flatten(value)
      new_aval = flat_tree.unflatten(
          safe_map(_BodyTracer.abstractify, flat_values))
      if old_aval != new_aval:
        msg = (f"Mutable state '{key}' is updated with new abstract value "
               f"{new_aval}, which is different from previous one {old_aval}")
        raise TypeError(msg)
    self._mutable_state[key] = value
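# Usage sketch for the check above (a hedged illustration; it assumes the
# public `jax.experimental.loops.Scope` API defined in this module). Scope
# state must be created before a range is entered; inside an active range only
# existing fields may be reassigned, and only with a matching abstract value.
#
#   from jax.experimental import loops
#
#   with loops.Scope() as s:
#     s.out = 0.           # ok: created outside any active range
#     for i in s.range(3):
#       s.out += 1.        # ok: same abstract value as before
#       # s.extra = 1.     # would raise ValueError: new mutable state in a loop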
def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
                                                fun_jaxpr, num_consts,
                                                **params):
  main = trace.main
  vals = [t.val for t in tracers]
  new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback,
                                    strip_calls=trace.strip_calls)
  if primitive == cd.custom_jvp_call_jaxpr_p:
    thunk_name = 'jvp_jaxpr_thunk'
  elif primitive == cd.custom_vjp_call_jaxpr_p:
    thunk_name = 'fwd_jaxpr_thunk'
    params['bwd'] = callback_subtrace(params['bwd'], main)
  else:
    raise NotImplementedError(primitive)
  thunk = params.pop(thunk_name)

  @pe._memoize
  def new_thunk():
    thunk_jaxpr = core.ClosedJaxpr(*thunk())
    closed_jaxpr = callback_jaxpr(thunk_jaxpr, trace.callback,
                                  trace.strip_calls)
    return closed_jaxpr.jaxpr, closed_jaxpr.literals

  params[thunk_name] = new_thunk
  new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals
  closed_fun_jaxpr = core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
  new_num_consts = len(new_consts) + num_consts
  out = primitive.bind(*it.chain(new_consts, vals),
                       fun_jaxpr=closed_fun_jaxpr,
                       num_consts=new_num_consts, **params)
  return safe_map(trace.pure, out)
def _scan_callback_rule(trace, *tracers, reverse, length, num_consts,
                        num_carry, jaxpr, linear, unroll):
  const_tracers, carry_tracers, xs_tracers = split_list(
      tracers, [num_consts, num_carry])
  carry_avals, xs_avals = tree_map(lambda x: x.aval,
                                   (carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_map(
      lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
  x_tracers = [t[0] for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]
  body_fun = jaxpr_as_fun(jaxpr)

  def new_body(*vals):
    out = body_fun(*vals)
    out_carry, y = split_list(out, [num_carry])
    return out_carry, y

  new_body = callback_transform(
      new_body, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(carry_avals + xs_avals)
  new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
      new_body, in_tree, tuple(carry_avals + x_avals))
  vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
  out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length,
                             num_consts=len(new_consts), num_carry=num_carry,
                             jaxpr=new_jaxpr, linear=linear, unroll=unroll)
  return safe_map(trace.pure, out_vals)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                            out_trees):
  vals_in = [t.val for t in tracers]
  fun = callback_subtrace(fun, self.main)
  fwd = callback_subtrace(fwd, self.main)
  bwd = callback_subtrace(bwd, self.main)
  out = primitive.bind(fun, fwd, bwd, *vals_in, out_trees=out_trees)
  return safe_map(self.pure, out)
def end_tracing_body(self):
  """Called when we are done tracing one iteration of the body."""
  # We will turn the body of the loop into a function that takes some values
  # for the scope state (carried_state_names) and returns the values for the
  # same state fields after one execution of the body. For some of the ranges,
  # e.g., scope.range, the function will also take the index_var as the last
  # parameter.
  in_tracers = tuple(itertools.chain(
      *[self.carried_state_vars[ms] for ms in self.carried_state_names]))
  if self.loop_builder.can_use_index_var():
    in_tracers += (self._index_var,)

  # Make the jaxpr for the body of the loop.
  # TODO: See which mutable state was changed in the one iteration.
  # For now, we assume all state changes.
  body_out_tracers = []
  for key in self.carried_state_names:
    new_val = self.scope._mutable_state[key]
    flat_new_values, flat_new_tree = tree_util.tree_flatten(new_val)
    body_out_tracers.extend(flat_new_values)
    assert key in self.scope._mutable_state_aval
    old_aval = self.scope._mutable_state_aval[key]
    new_aval = flat_new_tree.unflatten(
        safe_map(_BodyTracer.abstractify, flat_new_values))
    if old_aval != new_aval:
      msg = (f"Mutable state '{key}' had at the end of the loop body new "
             f"abstract value {new_aval}, which is different from initial "
             f"one {old_aval}")
      raise TypeError(msg)

  try:
    # If the body actually uses the index variable, and is not allowed to
    # (e.g., cond_range and while_range), then in_tracers will not contain
    # the tracer for the index_var, and trace_to_jaxpr_finalize will throw
    # an assertion error.
    body_closed_jaxpr, body_const_vals = _BodyTracer.trace_to_jaxpr_finalize(
        in_tracers=in_tracers, out_tracers=body_out_tracers, trace=self.trace)
  except UnexpectedTracerError as e:
    if "Tracer not among input tracers" in str(e):
      raise ValueError("Body of cond_range or while_range should not use the "
                       "index variable returned by iterator.") from e
    raise

  # End the subtrace for the loop body, before we trace the condition.
  self.scope.end_subtrace()

  carried_init_val = tuple([self.carried_state_initial[ms]
                            for ms in self.carried_state_names])
  carried_init_vals, carried_tree = tree_util.tree_flatten(carried_init_val)
  assert len(carried_init_vals) == len(body_out_tracers)

  carried_out_vals = self.loop_builder.build_output_vals(
      self.scope, self.carried_state_names, carried_tree, carried_init_vals,
      body_closed_jaxpr, body_const_vals)
  carried_mutable_state_unflattened = tree_util.tree_unflatten(
      carried_tree, carried_out_vals)

  # Update the mutable state with the values of the changed vars, after the
  # loop.
  for ms, mv in zip(self.carried_state_names,
                    carried_mutable_state_unflattened):
    self.scope._mutable_state[ms] = mv
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
  # Trace the conditional function. cond_func takes 0 arguments, but
  # for lax.while we need a conditional function that takes the
  # carried_state_names. _initial_style_jaxpr will start its own trace and
  # will create tracers for all the carried state. We must put these values
  # in the scope._mutable_state before we trace the conditional function.
  def cond_func_wrapped(*args):
    assert len(args) == len(carried_state_names)
    for ms, init_ms in zip(carried_state_names, args):
      scope._mutable_state[ms] = init_ms
    res = self.cond_func()
    # The conditional function is not allowed to modify the scope state.
    for ms, init_ms in zip(carried_state_names, args):
      if not (scope._mutable_state[ms] is init_ms):
        raise ValueError(f"Conditional function modifies scope.{ms} field.")
    return res

  init_avals = safe_map(_BodyTracer.abstractify, init_vals)
  cond_jaxpr, cond_consts, cond_tree = (
      lax_control_flow._initial_style_jaxpr(cond_func_wrapped, carried_tree,
                                            tuple(init_avals)))
  # TODO: share these checks with lax_control_flow.while
  if not tree_util.treedef_is_leaf(cond_tree):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got pytree {cond_tree}.")
  if not all(safe_map(core.typecompat, cond_jaxpr.out_avals,
                      [core.ShapedArray((), np.bool_)])):
    raise TypeError(
        f"cond_fun must return a boolean scalar, but got output type(s) "
        f"{cond_jaxpr.out_avals}.")

  return lax_control_flow.while_p.bind(
      *cond_consts, *body_const_vals, *init_vals,
      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
      body_nconsts=len(body_const_vals), body_jaxpr=body_closed_jaxpr)
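# Usage sketch tying `cond_func` to the while builder above (a hedged
# illustration; it assumes the `Scope.while_range` API from this module). The
# condition closure may read scope state but must not assign to it;
# build_output_vals wraps it into a cond_fun over the carried state and binds
# lax's while_p.
#
#   from jax.experimental import loops
#
#   with loops.Scope() as s:
#     s.x = 0.
#     for _ in s.while_range(lambda: s.x < 10.):
#       s.x += 2.
#     # s.x now holds the carried-out value (10.)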
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for ns, sharding in safe_zip(
      safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
    if sharding is not None:
      args.append(
          [mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
           for n in ns])
    else:
      args.append(ns)

  sub_ctx = ctx.module_context.replace(
      name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
  fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                               core.ClosedJaxpr(call_jaxpr, ()))

  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  call = std.CallOp(flat_output_types,
                    ir.FlatSymbolRefAttr.get(fn.name.value),
                    mlir.flatten_lowering_ir_args(args))
  out_nodes = util.unflatten(call.results, safe_map(len, output_types))

  out_parts = out_parts_thunk()
  outputs = []
  for ns, sharding in safe_zip(out_nodes, out_parts):
    if sharding is not None:
      outputs.append(
          [mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
           for n in ns])
    else:
      outputs.append(ns)
  return outputs
def build_output_vals(self, scope, carried_state_names, carried_tree,
                      init_vals, body_closed_jaxpr, body_const_vals):
  # Simulate a pass-through false branch.
  in_vals, in_tree = tree_util.tree_flatten(
      (body_const_vals, tree_util.tree_unflatten(carried_tree, init_vals)))
  in_avals = safe_map(_BodyTracer.abstractify, in_vals)
  pass_through_closed_jaxpr, pass_through_const_vals, _ = (
      lax_control_flow._initial_style_jaxpr(
          lambda *args: args[1], in_tree, tuple(in_avals)))
  assert len(pass_through_const_vals) == 0

  args = [*body_const_vals, *init_vals]
  return lax_control_flow.cond_p.bind(
      self.index, *args,
      branches=(pass_through_closed_jaxpr, body_closed_jaxpr),
      linear=(False,) * len(args))
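# Usage sketch for the conditional builder above (a hedged illustration; it
# assumes the `Scope.cond_range` API from this module). The "loop" body runs
# either once (predicate true) or not at all; the pass-through jaxpr built
# here is the false branch that returns the carried state unchanged.
#
#   from jax.experimental import loops
#
#   with loops.Scope() as s:
#     s.y = 1.
#     for _ in s.cond_range(s.y > 0.):
#       s.y = s.y * 2.
#     # s.y == 2. if the predicate held, otherwise still 1.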
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(abstract_args, in_parts,
                                         local_in_parts)]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(global_out_avals, out_parts,
                                         local_out_parts)]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  axis_env = xla.AxisEnv(nrep, (), ())
  unordered_effects = [eff for eff in jaxpr.effects
                       if eff not in core.ordered_effects]
  ordered_effects = [eff for eff in jaxpr.effects
                     if eff in core.ordered_effects]
  module, _ = mlir.lower_jaxpr_to_module(
      f"spjit_{fun.__name__}",
      core.ClosedJaxpr(jaxpr, consts),
      unordered_effects,
      ordered_effects,
      platform=platform,
      axis_context=mlir.ReplicaAxisContext(axis_env),
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
      donated_args=[False] * len(in_parts),
      arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
      result_shardings=safe_map(xla.sharding_to_proto, out_parts))
  built = xc._xla.mlir.mlir_module_to_xla_computation(
      mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
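# Usage sketch for the callable built above (a hedged illustration; it assumes
# the `jax.experimental.sharded_jit` wrappers that route to _sharded_callable;
# the 2x1 partitioning below is illustrative only and needs two local devices).
#
#   from functools import partial
#   import jax.numpy as jnp
#   from jax.experimental.sharded_jit import sharded_jit, PartitionSpec as P
#
#   @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=P(2, 1))
#   def f(x, y):
#     return x + y
#
#   x = jnp.ones((8, 8))
#   f(x, x)  # compiled once, then executed across the two partitions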
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
  vals_in = [t.val for t in tracers]
  fun = callback_subtrace(fun, self.main)
  jvp = callback_subtrace(jvp, self.main)
  out = primitive.bind(fun, jvp, *vals_in)
  return safe_map(self.pure, out)