def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
  assert call_primitive.multiple_results
  primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
  nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
  nz_tangents = [type(t) is not Zero for t in tangents]
  if 'name' in params and not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'jvp'))
  f_jvp = jvp_subtrace(f, self.main)
  f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
  if isinstance(call_primitive, core.MapPrimitive):
    in_axes = params['in_axes']
    tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
    out_axes_thunk = params['out_axes_thunk']
    # The new thunk depends deterministically on the old thunk and the wrapped
    # function. Any caching already has to include the wrapped function as part
    # of the key, so we only use the previous thunk for equality checks.
    # NOTE: This assumes that the output tangents being zero is a deterministic
    # function of which input tangents were zero.
    @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out())
                           if nz))
    params = dict(params, in_axes=(*in_axes, *tangent_in_axes),
                  out_axes_thunk=new_out_axes_thunk)
  f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
  update_params = call_param_updaters.get(call_primitive)
  new_params = update_params(params, nz_tangents) if update_params else params
  result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
  return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
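# Usage sketch (illustrative, not part of the original source): the rule above
# runs when a JVP trace crosses a call primitive such as the one underlying
# jax.jit. Symbolically-Zero tangents are filtered out before the bind, which
# is what the nz_tangents bookkeeping implements.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return jnp.sin(x) * x

# Differentiating through the jitted call exercises JVPTrace.process_call.
primal_out, tangent_out = jax.jvp(f, (2.0,), (1.0,))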
def remat_impl(*args,
               call_jaxpr: Optional[core.Jaxpr] = None,
               jaxpr: Optional[core.Jaxpr] = None,
               prevent_cse: bool, differentiated: bool,
               policy,
               is_gpu_platform: bool = False,
               concrete: bool = False,
               name: str = "checkpoint"):
  # Support either "jaxpr" (for remat2) or "call_jaxpr" (for remat).
  # name is not passed for remat2; it defaults to "checkpoint".
  # TODO: remove call_jaxpr once we drop the remat call primitive.
  if jaxpr is None:
    jaxpr = call_jaxpr
  assert jaxpr is not None
  assert not jaxpr.constvars

  del concrete, policy  # Unused.
  if differentiated and prevent_cse:
    if config.jax_remat_opt_barrier:
      translation_rule = _remat_translation_using_opt_barrier
    elif is_gpu_platform:
      translation_rule = _remat_translation_using_while
    else:
      translation_rule = _remat_translation_using_cond
  else:
    translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)

  return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(
      *args, jaxpr=jaxpr)
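# Usage sketch (illustrative): remat_impl is reached when a checkpointed
# function is evaluated. Under jax.grad the residuals are rematerialized, and
# with differentiated and prevent_cse both set, the lowering above guards the
# recomputation against common-subexpression elimination.
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(jnp.sin(x))

g = jax.grad(jax.checkpoint(f))  # jax.remat is an alias for jax.checkpoint
print(g(1.0))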
def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
                               backend=None, call_jaxpr, donated_invars,
                               inline=None, device=None):
  del device, donated_invars, inline  # Ignored.
  c = ctx.builder
  check_backend_matches(backend, ctx.platform)
  subc = xc.XlaBuilder(f"jit_{name}")
  args = [parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
  out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)

  if len(out_nodes) == 1:
    subc = subc.Build(out_nodes[0])
    return [xops.Call(c, subc, list(in_nodes))]
  else:
    subc = subc.Build(xops.Tuple(subc, out_nodes))
    return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
  """Lower remat to a Conditional whose predicate is always true. This:

  1. Circumvents common subexpression elimination.
  2. In the common case of `jax.grad(jax.remat(f))`, ensures the remat blocks
     occur after the primal blocks, because the cotangent is an input to the
     Conditional.
  """
  # Fake condition which always selects the True branch.
  c = ctx.builder
  rng = xops.RngUniform(xops.Constant(c, np.array(0, dtype=np.float32)),
                        xops.Constant(c, np.array(1, dtype=np.float32)),
                        xc.Shape.array_shape(xc.PrimitiveType.F32, []))
  pred = xops.Lt(rng, xops.Constant(c, np.array(2, dtype=np.float32)))

  true_op = xops.Tuple(c, in_nodes)
  remat_subc = xc.XlaBuilder("remat_call_subcomputation")
  input_op = parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
  args = xla_destructure(remat_subc, input_op)
  sub_ctx = ctx.replace(
      builder=remat_subc,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
  out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  out_node_shapes = [remat_subc.get_shape(o) for o in out_nodes]
  remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes))

  false_op = true_op
  dummy_subc = xc.XlaBuilder("remat_call_dummy_subcomputation")
  parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
  out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
  dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))

  return xla_destructure(
      c, xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc))
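# Illustrative sketch (assumed names, simplified): the predicate above is
# `rng_uniform(0, 1) < 2`, which is always true, so the true branch (the
# rematerialized computation) always executes, yet XLA cannot fold it away as
# a common subexpression because the predicate is not a compile-time constant.
# The same shape expressed with jax.lax.cond:
import jax
import jax.numpy as jnp
from jax import lax

def remat_like(f, x):
  pred = jax.random.uniform(jax.random.PRNGKey(0)) < 2.0  # always True
  return lax.cond(pred, f, jnp.zeros_like, x)

print(remat_like(jnp.sin, jnp.array(1.0)))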
def process_call(self, call_primitive, f, tracers, params):
  assert call_primitive.multiple_results
  if config.jax_experimental_name_stack:
    params = dict(params, name=params.get('name', f.__name__))
  else:
    params = dict(params,
                  name=wrap_name(params.get('name', f.__name__), 'vmap'))
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  if all(bdim is not_mapped for bdim in dims):
    return call_primitive.bind(f, *vals, **params)
  else:
    f_, dims_out = batch_subtrace(f, self.main, dims)
    ax_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
    f_ = _update_annotation(f_, f.in_type, ax_size, self.axis_name, dims)
    vals_out = call_primitive.bind(f_, *vals, **params)
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src)
            for v, d in zip(vals_out, dims_out())]
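# Usage sketch (illustrative): this rule fires when a vmap trace reaches a
# call primitive. If no argument carries a batch dimension the call is bound
# unchanged; otherwise batch_subtrace pushes the batching into the callee.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return x * 2.0

ys = jax.vmap(f)(jnp.arange(4.0))  # exercises BatchTrace.process_call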
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts, name,
                                  call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined
    # calls shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
  """Lower remat to a single-iteration while loop."""
  c = ctx.builder
  # Dummy subc for getting subcomp shapes.
  dummy_inputs = xops.Tuple(c, in_nodes)
  dummy_subc = xc.XlaBuilder("remat_dummy_subcomputation")
  dummy_input_op = parameter(dummy_subc, 0, c.get_shape(dummy_inputs),
                             replicated=[])
  dummy_args = xla_destructure(dummy_subc, dummy_input_op)
  dummy_ctx = ctx.replace(
      builder=dummy_subc,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
  dummy_subcomp_outs = jaxpr_subcomp(dummy_ctx, call_jaxpr, (), *dummy_args)
  out_node_shapes = [dummy_subc.get_shape(o) for o in dummy_subcomp_outs]

  i_init = xops.Constant(c, np.array(0, dtype=np.int32))
  zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
  inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)

  cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
  input_op = parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
  i = xops.GetTupleElement(input_op, 0)
  rng = xops.RngUniform(
      xops.Constant(cond_subc, np.array(1, dtype=np.int32)),
      xops.Constant(cond_subc, np.array(2, dtype=np.int32)),
      xc.Shape.array_shape(xc.PrimitiveType.S32, []))
  cond_subc = cond_subc.build(xops.Lt(i, rng))

  body_subc = xc.XlaBuilder("remat_body_subcomputation")
  input_op = parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
  i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes) + 1]
  i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32)))
  body_ctx = ctx.replace(
      builder=body_subc,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
  subcomp_outs = jaxpr_subcomp(body_ctx, call_jaxpr, (), *args)
  out_nodes = [i_next] + args + list(subcomp_outs)
  body_subc = body_subc.build(xops.Tuple(body_subc, out_nodes))

  outs = xops.While(cond_subc, body_subc, inputs)
  return xla_destructure(c, outs)[len(in_nodes) + 1:]
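# Illustrative sketch (assumed names, simplified): the loop above runs exactly
# once because the condition compares i against rng_uniform(1, 2), which for
# integers is always 1, so `i < 1` holds only on the first iteration. The same
# single-trip structure in jax.lax terms:
import jax.numpy as jnp
from jax import lax

def run_once(f, x):
  def body(carry):
    i, _ = carry
    return i + 1, f(x)
  # Condition is true only for i == 0, so the body executes exactly once.
  _, out = lax.while_loop(lambda c: c[0] < 1, body, (0, jnp.zeros_like(x)))
  return out

print(run_once(jnp.sin, jnp.array(1.0)))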
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  return tree_unflatten(out_tree(), out_flat)
def process_call(self, call_primitive, f, tracers, params):
  assert call_primitive.multiple_results
  heads, tails = unzip2((t.head, t.tail) for t in tracers)
  nonzero_tails, in_tree_def = tree_flatten(tails)
  f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.main),
                                        len(heads), in_tree_def)
  name = params.get('name', f.__name__)
  new_params = dict(
      params, name=wrap_name(name, 'doubledouble'),
      donated_invars=(False,) * (len(heads) + len(nonzero_tails)))
  result = call_primitive.bind(f_double, *heads, *nonzero_tails, **new_params)
  heads_out, tails_out = tree_unflatten(out_tree_def(), result)
  return [DoublingTracer(self, h, t) for h, t in zip(heads_out, tails_out)]
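# Usage sketch (illustrative; assumes the experimental doubledouble transform
# of this era, jax.experimental.doubledouble, since removed from JAX): the
# rule above threads (head, tail) double-double pairs through a call
# primitive. Roughly:
#
#   from jax.experimental.doubledouble import doubledouble
#
#   @doubledouble
#   def f(x):
#     return x * x - 1e16
#
#   y = f(1e8 + 1.0)  # evaluated in roughly double-double precision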
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
  assert call_primitive.multiple_results
  params = dict(params,
                name=wrap_name(params.get('name', f.__name__), 'vmap'))
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  if all(bdim is not_mapped for bdim in dims):
    return call_primitive.bind(f, *vals, **params)
  else:
    f, dims_out = batch_subtrace(f, self.main, dims)
    vals_out = call_primitive.bind(f, *vals, **params)
    return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  if not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  if config.jax_dynamic_shapes:
    in_type = [(core.raise_to_shaped(core.get_aval(x)), True)
               for x in all_args]
    fun = lu.annotate(fun, tuple(in_type))
  out_flat = primitive.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
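# Usage sketch (illustrative): call_transpose is reached on the backward pass
# when reverse-mode AD transposes a linearized call, e.g. the call primitive
# introduced by jax.jit. Zero cotangents are pruned via the
# `type(x) is not Zero` checks above.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return jnp.sum(jnp.sin(x))

dx = jax.grad(f)(jnp.arange(3.0))  # backward pass transposes the jit call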
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, nz_arg_cts = nonzero_outputs(fun)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Preserve axis for primal arguments, skip tangents (represented as
  # undefined primals).
  in_axes, out_axes = params['in_axes'], params['out_axes']
  new_in_axes = (*[axis for axis, x in zip(in_axes, args)
                   if not is_undefined_primal(x)],
                 *[axis for axis, x in zip(out_axes, ct)
                   if type(x) is not Zero])
  # The interim strategy we use below (until avals-with-names) only works
  # when all outputs are mapped.
  assert all(out_axis is not None for out_axis in out_axes), out_axes
  # NOTE: This assumes that the output cotangents being zero is a
  # deterministic function of which input cotangents were zero.
  @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct)))
  def out_axes_thunk():
    return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
                    in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)
  del new_params['out_axes']
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(in_axes) == len(arg_cts)
  def unmap_zero(zero, in_axis):
    return (zero if in_axis is None else
            Zero(core.unmapped_aval(params['axis_size'], params['axis_name'],
                                    in_axis, zero.aval)))
  arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
             arg_ct if in_axis is not None else
             arg_ct.sum(0)
             for arg_ct, in_axis in zip(arg_cts, in_axes))
  return tuple(arg_cts)
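# Illustrative sketch of the fan-out/fan-in-sum duality noted above: an input
# that is replicated (fanned out) along the mapped axis in the primal picks
# up a sum over that axis in the transpose.
import jax
import jax.numpy as jnp

def fan_out(x):
  # Broadcasting an unmapped input along the axis is the linear "fan-out".
  return jnp.broadcast_to(x, (4,))

(x_bar,) = jax.linear_transpose(fan_out, jnp.array(1.0))(jnp.ones(4))
print(x_bar)  # 4.0: the transpose of fan-out is fan-in-sum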
def process_call(self, call_primitive, f, tracers, params):
  assert call_primitive.multiple_results
  params = dict(params,
                name=wrap_name(params.get('name', f.__name__), 'mask'))
  vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
  if not any(is_polymorphic(s) for s in shapes):
    return call_primitive.bind(f, *vals, **params)
  else:
    logical_env, padded_env = shape_envs
    env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
    logical_env_vals = tuple(logical_env[k] for k in env_keys)
    # Make padded_env hashable
    padded_env = (env_keys, padded_env_vals)
    f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env)
    if 'donated_invars' in params:
      params = dict(params,
                    donated_invars=((False,) * len(logical_env_vals) +
                                    params['donated_invars']))
    vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
    return [MaskTracer(self, v, s) for v, s in zip(vals_out, shapes_out())]
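# Usage sketch (illustrative; assumes the old experimental masking transform,
# since removed from JAX): MaskTrace.process_call threads shape-environment
# values through a call primitive when any argument has a polymorphic shape.
# Roughly:
#
#   from jax import mask  # old experimental API
#   import jax.numpy as jnp
#
#   padded_sum = mask(jnp.sum, in_shapes=['n'], out_shape='')
#   y = padded_sum([jnp.arange(5)], dict(n=3))  # sums only the first 3 elements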
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
  assert call_primitive.multiple_results
  if config.jax_experimental_name_stack:
    params = dict(params, name=params.get('name', f.__name__))
  else:
    params = dict(params,
                  name=wrap_name(params.get('name', f.__name__), 'vmap'))
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  if all(bdim is not_mapped for bdim in dims):
    return call_primitive.bind(f, *vals, **params)
  else:
    f, dims_out = batch_subtrace(f, self.main, dims)
    vals_out = call_primitive.bind(f, *vals, **params)
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src)
            for v, d in zip(vals_out, dims_out())]
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for ns, sharding in safe_zip(
      safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
    if sharding is not None:
      args.append([
          mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
          for n in ns
      ])
    else:
      args.append(ns)

  sub_ctx = ctx.module_context.replace(
      name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
  fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                               core.ClosedJaxpr(call_jaxpr, ()))

  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  call = std.CallOp(flat_output_types,
                    ir.FlatSymbolRefAttr.get(fn.name.value),
                    mlir.flatten_lowering_ir_args(args))
  out_nodes = util.unflatten(call.results, safe_map(len, output_types))

  out_parts = out_parts_thunk()
  outputs = []
  for ns, sharding in safe_zip(out_nodes, out_parts):
    if sharding is not None:
      outputs.append([
          mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
          for n in ns
      ])
    else:
      outputs.append(ns)
  return outputs
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(abstract_args, in_parts,
                                         local_in_parts)
  ]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(global_out_avals, out_parts,
                                         local_out_parts)
  ]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  axis_env = xla.AxisEnv(nrep, (), ())
  unordered_effects = [eff for eff in jaxpr.effects
                       if eff not in core.ordered_effects]
  ordered_effects = [eff for eff in jaxpr.effects
                     if eff in core.ordered_effects]
  module, _ = mlir.lower_jaxpr_to_module(
      f"spjit_{fun.__name__}",
      core.ClosedJaxpr(jaxpr, consts),
      unordered_effects,
      ordered_effects,
      platform=platform,
      axis_context=mlir.ReplicaAxisContext(axis_env),
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
      donated_args=[False] * len(in_parts),
      arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
      result_shardings=safe_map(xla.sharding_to_proto, out_parts))
  built = xc._xla.mlir.mlir_module_to_xla_computation(
      mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)
  ]
  input_indices = [
      pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
      for aval, spec in zip(abstract_args, input_specs)
  ]
  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(
      nrep,
      local_nparts,  # type: ignore
      local_out_parts,
      local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
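# Usage sketch (illustrative; assumes the experimental sharded_jit API of this
# era, since removed from JAX): _sharded_callable is the execution path behind
# sharded_jit, which partitions a computation across devices according to
# PartitionSpecs. Roughly:
#
#   from jax.interpreters.sharded_jit import sharded_jit, PartitionSpec as P
#   import jax.numpy as jnp
#
#   f = sharded_jit(lambda x: x + 1,
#                   in_parts=P(2, 1),    # shard the input over 2 devices
#                   out_parts=P(2, 1))   # shard the output the same way
#   y = f(jnp.ones((8, 8)))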
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, always_lower: bool, keep_unused: bool,
                       *arg_specs):
  """Lower into XLA.

  Args:
    always_lower: If `True`, even trivial programs (not doing any computation
      such as lambda x: x) will be lowered into an XLA program.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
  """
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))
  abstract_args, arg_devices = util.unzip2(arg_specs)
  if fun.in_type is not None:
    abstract_args, which_explicit = util.unzip2(fun.in_type)
  else:
    which_explicit = None
  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
        fun, abstract_args, pe.debug_info_final(fun, "jit"), which_explicit)
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")

  # TODO(mattjj): handle argument pruning w/ dynamic shapes
  if fun.in_type is None and not keep_unused:
    jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
    abstract_args, arg_devices = util.unzip2(
        [a for i, a in enumerate(arg_specs) if i in kept_var_idx])
    donated_invars = [x for i, x in enumerate(donated_invars)
                      if i in kept_var_idx]
    del kept_const_idx
  else:
    kept_var_idx = set(range(len(abstract_args)))
  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
      not _backend_supports_unbounded_dynamic_shapes(backend)):
    jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

  # Computations that only produce constants and/or only rearrange their
  # inputs, which are often produced from partial evaluation, don't need
  # compilation, and don't need to evaluate their arguments.
  if not jaxpr.eqns and not always_lower:
    return XlaComputation(
        name, None, True, None, None, jaxpr=jaxpr, consts=consts,
        device=device, in_avals=abstract_args, out_avals=out_avals,
        has_unordered_effects=False, ordered_effects=[],
        kept_var_idx=kept_var_idx, keepalive=None)

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for args {abstract_args}."
    logging.log(log_priority, msg)

  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  # pass long arg lists as tuple for TPU
  tuple_args = len(abstract_args) > 100
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module_name = f"jit_{fun.__name__}"
  unordered_effects = [eff for eff in closed_jaxpr.effects
                       if eff not in core.ordered_effects]
  ordered_effects = [eff for eff in closed_jaxpr.effects
                     if eff in core.ordered_effects]
  module, keepalive = mlir.lower_jaxpr_to_module(
      module_name, closed_jaxpr, unordered_effects, ordered_effects,
      backend.platform, mlir.ReplicaAxisContext(axis_env), name_stack,
      donated_invars)
  return XlaComputation(
      name, module, False, donated_invars, which_explicit, nreps=nreps,
      device=device, backend=backend, tuple_args=tuple_args,
      in_avals=abstract_args, out_avals=out_avals,
      has_unordered_effects=bool(unordered_effects),
      ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
      keepalive=keepalive)
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(abstract_args, in_parts,
                                         local_in_parts)
  ]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform
  if platform not in ["tpu", "gpu"]:
    # TODO(skye): fall back to regular jit?
    raise ValueError(f"sharded_jit not supported for {platform}")

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(global_out_avals, out_parts,
                                         local_out_parts)
  ]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  ctx = xla.TranslationContext(
      c, platform, axis_env,
      extend_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
  out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)
  ]
  input_indices = [
      pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
      for aval, spec in zip(abstract_args, input_specs)
  ]
  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(
      nrep,
      local_nparts,  # type: ignore
      local_out_parts,
      local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)