def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
  backend = kwargs.pop('backend', None)
  cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
      kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
  cond_consts, body_consts, init_vals = split_list(
      args, [cond_nconsts, body_nconsts])
  batched = bool(cond_jaxpr.out_avals[0].shape)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.
  init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

  cond_c = xb.make_computation_builder("cond_computation")
  cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
  cond_carry_elts = [cond_c.GetTupleElement(cond_carry, i)
                     for i in range(len(args))]
  x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
  pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
                            _map(cond_c.Constant, cond_jaxpr.literals), (),
                            *(x + z))
  if batched:
    scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
    or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
    pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                         list(range(cond_jaxpr.out_avals[0].ndim)))

  body_c = xb.make_computation_builder("body_computation")
  body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
  body_carry_elts = [body_c.GetTupleElement(body_carry, i)
                     for i in range(len(args))]
  x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
  new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
                            _map(body_c.Constant, body_jaxpr.literals), (),
                            *(y + z))
  if batched:
    body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
                                   _map(body_c.Constant, cond_jaxpr.literals),
                                   (), *(x + z))
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
    assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z)  # no broadcast
  new_carry = body_c.Tuple(*itertools.chain(x, y, new_z))

  ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
  ans_elts = [c.GetTupleElement(ans, i) for i in range(len(args))]
  _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return c.Tuple(*z)
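# Illustrative example (not from the original source): the kind of user
# program this rule lowers. A batched (non-scalar) predicate arises when
# lax.while_loop is vmapped, which is why the rule reduces the predicate
# with a logical-or and applies per-element updates via _pred_bcast_select.
import jax
import jax.numpy as jnp
from jax import lax

def count_down(x):
  # Loop until the carry reaches zero.
  return lax.while_loop(lambda v: v > 0, lambda v: v - 1, x)

print(count_down(5))                               # 0
print(jax.vmap(count_down)(jnp.array([3, 0, 7])))  # [0 0 0]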
def make_computation(name, jaxpr, op_shape):
  # Note: `backend` and `axis_env` are free variables here; this helper is
  # defined as a closure inside a translation rule that supplies them.
  c = xb.make_computation_builder(name)
  op = c.ParameterWithShape(op_shape)
  ops = [c.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
  outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
                           _map(c.Constant, jaxpr.literals), (), *ops)
  return c.Build(c.Tuple(*outs))
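# Hedged sketch of a typical call site (illustrative names, not verbatim
# source): the while-loop translation rule builds the cond and body
# computations over a single tuple-shaped carry and hands them to While.
#
#   cond_c = make_computation("cond_computation", cond_jaxpr, carry_shape)
#   body_c = make_computation("body_computation", body_jaxpr, carry_shape)
#   ans = xops.While(cond_c, body_c, init_carry)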
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts, name,
                                  call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined
    # calls shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
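# Hedged sketch: rules with this (ctx, avals_in, avals_out, *in_nodes, ...)
# calling convention were hooked up via xla.register_translation. The
# primitive name `sharded_call_p` below is an assumption about the
# surrounding module, shown only to illustrate the registration pattern.
#
#   xla.register_translation(sharded_call_p, _sharded_jit_translation_rule)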
def remat_translation(ctx, avals_in, avals_out, *in_nodes,
                      jaxpr, prevent_cse, differentiated, policy):
  del policy  # Unused.
  if differentiated and prevent_cse:
    if ctx.platform == "gpu":
      return xla._remat_using_while(ctx, in_nodes, "checkpoint", jaxpr)
    else:
      return xla._remat_using_cond(ctx, in_nodes, "checkpoint", jaxpr)
  else:
    return xla.jaxpr_subcomp(ctx, jaxpr, (), *in_nodes)
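# Illustrative only: the `differentiated and prevent_cse` branch wraps the
# rematerialized subcomputation (in a trivial while loop or a conditional)
# so XLA's common-subexpression elimination cannot undo the recompute. A
# user-level program that exercises this path:
import jax
import jax.numpy as jnp

@jax.checkpoint  # a.k.a. jax.remat
def f(x):
  return jnp.sin(jnp.sin(x))

# The backward pass recomputes the forward values instead of storing them.
print(jax.grad(f)(1.0))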
def _named_call_translation_rule(comp_builder: 'xla.xb._JaxComputationBuilder',
                                 axis_env: xla.AxisEnv,
                                 in_nodes: 'Sequence[xla.xc._xla.XlaOp]',
                                 name_stack: str,
                                 backend: Optional[Any],
                                 name: str,
                                 call_jaxpr: core.Jaxpr):
  """Compile and add a custom name to the XLA metadata."""
  subcomp_builder = xla.xb.make_computation_builder(f'named_call_{name}')
  args = [xla.xb.parameter(subcomp_builder, i, comp_builder.GetShape(n))
          for i, n in enumerate(in_nodes)]
  out_nodes = xla.jaxpr_subcomp(subcomp_builder, call_jaxpr, backend, axis_env,
                                (),
                                jax.util.extend_name_stack(name_stack, name),
                                *args)
  subcomp = subcomp_builder.Build(xla.xops.Tuple(subcomp_builder, out_nodes))
  return xla.xops.Call(comp_builder, subcomp, list(in_nodes))
def _named_call_translation_rule(c, axis_env, in_nodes, name_stack, *,
                                 name='core_call', backend, call_jaxpr):
  subc = xla.xb.make_computation_builder(name)
  args = [xla.xb.parameter(subc, i, c.GetShape(n))
          for i, n in enumerate(in_nodes)]
  out_nodes = xla.jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
                                jax.util.extend_name_stack(name_stack, name),
                                *args)
  subc = subc.Build(xla.xops.Tuple(subc, out_nodes))
  return xla.xops.Call(c, subc, list(in_nodes))
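# Hedged sketch: call translation rules of this shape were registered in a
# dict keyed on the named-call primitive; the dict name and `named_call_p`
# below are assumptions about the surrounding module.
#
#   xla.call_translations[named_call_p] = _named_call_translation_rule
#
# From the user side, the public wrapper attaches the name that ends up in
# the XLA metadata:
import jax
import jax.numpy as jnp

f = jax.named_call(jnp.sin, name="my_sin")
print(jax.jit(f)(1.0))  # the HLO for sin is tagged with "my_sin"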
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(abstract_args, in_parts,
                                         local_in_parts)
  ]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform
  if platform not in ["tpu", "gpu"]:
    # TODO(skye): fall back to regular jit?
    raise ValueError(f"sharded_jit not supported for {platform}")

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(global_out_avals, out_parts,
                                         local_out_parts)
  ]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  ctx = xla.TranslationContext(
      c, platform, axis_env,
      extend_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
  out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)
  ]
  input_indices = [
      pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
      for aval, spec in zip(abstract_args, input_specs)
  ]
  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
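# Illustrative usage of the API this callable backs. `sharded_jit` and
# `PartitionSpec` lived in jax.interpreters.sharded_jit in the JAX versions
# this code is from (the API was later superseded by pjit); treat the import
# path and this sketch as historical assumptions rather than current API.
from functools import partial
from jax.interpreters.sharded_jit import sharded_jit, PartitionSpec as P
import jax.numpy as jnp

@partial(sharded_jit, in_parts=P(2, 1), out_parts=P(2, 1))
def f(x):
  return x + 1

# On a 2-device TPU/GPU backend, x is split along its leading axis across
# the two partitions, f is compiled once for the whole partitioned program,
# and the output is reassembled with the same partitioning.
x = jnp.zeros((8, 8))
y = f(x)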