def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
  """Lowers a remat call to an XLA Conditional whose predicate is always true.

  Hiding the rematerialized computation behind a Conditional:
    1. circumvents common subexpression elimination, and
    2. in the common `jax.grad(jax.remat(f))` case, ensures the remat blocks
       occur after the primal blocks, because the cotangent is an input to the
       Conditional.

  Args:
    ctx: translation context; ``ctx.builder`` is the enclosing XlaBuilder.
    in_nodes: XLA ops that are the inputs of the remat call.
    name: name of the call, pushed onto the name stack wrapped as 'remat'.
    call_jaxpr: jaxpr of the body to rematerialize.

  Returns:
    The destructured outputs of the Conditional.
  """
  outer = ctx.builder

  def f32_const(value):
    return xops.Constant(outer, np.array(value, dtype=np.float32))

  # Fake condition which always selects the true branch: a uniform draw from
  # [0, 1) is always < 2, but XLA cannot prove that statically.
  low, high = f32_const(0), f32_const(1)
  scalar_f32 = xc.Shape.array_shape(xc.PrimitiveType.F32, [])
  draw = xops.RngUniform(low, high, scalar_f32)
  always_true = xops.Lt(draw, f32_const(2))

  # Both branches receive the same operand: a tuple of all inputs.
  operand = xops.Tuple(outer, in_nodes)
  operand_shape = outer.get_shape(operand)

  # True branch: actually evaluates the jaxpr.
  body_builder = xc.XlaBuilder("remat_call_subcomputation")
  body_param = parameter(body_builder, 0, operand_shape, replicated=[])
  body_ctx = ctx.replace(
      builder=body_builder,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
  body_outs = jaxpr_subcomp(body_ctx, call_jaxpr, (),
                            *xla_destructure(body_builder, body_param))
  result_shapes = [body_builder.get_shape(out) for out in body_outs]
  true_computation = body_builder.build(xops.Tuple(body_builder, body_outs))

  # False branch: never taken at runtime; returns zeros of the same shapes.
  dummy_builder = xc.XlaBuilder("remat_call_dummy_subcomputation")
  parameter(dummy_builder, 0, operand_shape, replicated=[])
  zeros = [_zeros(dummy_builder, shape) for shape in result_shapes]
  false_computation = dummy_builder.build(xops.Tuple(dummy_builder, zeros))

  conditional = xops.Conditional(always_true, operand, true_computation,
                                 operand, false_computation)
  return xla_destructure(outer, conditional)
def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
                               backend=None, call_jaxpr, donated_invars,
                               inline=None, device=None):
  """Translates a jit call into an XLA `Call` of a `jit_{name}` subcomputation.

  ``device``, ``donated_invars`` and ``inline`` are ignored; ``backend`` is
  only checked against ``ctx.platform``.

  Returns:
    A list of XLA ops, one per output of ``call_jaxpr``.
  """
  del device, donated_invars, inline  # Ignored.
  outer = ctx.builder
  check_backend_matches(backend, ctx.platform)

  callee = xc.XlaBuilder(f"jit_{name}")
  params = [parameter(callee, idx, outer.get_shape(node))
            for idx, node in enumerate(in_nodes)]
  callee_ctx = ctx.replace(
      builder=callee,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
  results = jaxpr_subcomp(callee_ctx, call_jaxpr, (), *params)

  # A single result is returned directly, which avoids building a one-element
  # tuple only to destructure it again.
  single = len(results) == 1
  root = results[0] if single else xops.Tuple(callee, results)
  call = xops.Call(outer, callee.Build(root), list(in_nodes))
  return [call] if single else xla_destructure(outer, call)
def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack, in_parts,
                                  out_parts_thunk, nparts, backend, name,
                                  call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  """Old-style translation rule for `sharded_jit`: emits an XLA Call.

  Builds a `sharded_jit_{name}` subcomputation whose parameters and results
  carry the shardings from ``in_parts`` and ``out_parts_thunk()``. It returns a
  single `Call` op on the caller's builder ``c``. ``nparts`` and the
  ``local_*`` arguments are not used by this rule.
  """
  builder = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  extra = len(in_nodes) - len(in_parts)
  assert extra >= 0
  in_parts = (None,) * extra + in_parts

  # We use xb.set_sharding instead of xb.with_sharding because inlined calls
  # shouldn't have shardings set directly on the inputs or outputs.
  params = [
      xb.set_sharding(builder, xb.parameter(builder, idx, c.GetShape(node)),
                      part)
      for idx, (node, part) in enumerate(safe_zip(in_nodes, in_parts))
  ]

  sub_name_stack = extend_name_stack(name_stack,
                                     wrap_name(name, "sharded_jit"))
  results = xla.jaxpr_subcomp(builder, call_jaxpr, backend, axis_env, (),
                              sub_name_stack, *params)

  out_parts = out_parts_thunk()
  assert len(out_parts) == len(results)
  results = [xb.set_sharding(builder, res, part)
             for res, part in safe_zip(results, out_parts)]

  computation = builder.build(xops.Tuple(builder, results))
  return xops.Call(c, computation, list(in_nodes))
def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
                           jaxpr, in_axis_resources, out_axis_resources,
                           resource_env, donated_invars, positional_semantics):
  """Translation rule for `pjit`: lowers ``jaxpr`` into an XLA Call.

  Parameters and results of the `pjit_{name}` subcomputation are annotated
  with sharding protos from `get_sharding_proto`, computed against
  ``resource_env.physical_mesh``. ``donated_invars`` and
  ``positional_semantics`` are not used by this rule.
  """
  mesh = resource_env.physical_mesh
  builder = xc.XlaBuilder(f"pjit_{name}")

  def annotate(op, proto):
    return xb.set_sharding_proto(builder, op, proto)

  params = []
  for idx, (node, resources) in enumerate(
      safe_zip(in_nodes, in_axis_resources)):
    # N.B. inlined calls shouldn't have shardings set directly on the inputs or
    # outputs (set_sharding_proto adds an identity operation).
    param = xb.parameter(builder, idx, c.GetShape(node))
    params.append(annotate(param, get_sharding_proto(c, node, resources, mesh)))

  # TODO: Think about how to avoid duplicating constants with the outer jaxpr
  consts = xla._xla_consts(builder, jaxpr.consts)
  sub_name_stack = extend_name_stack(name_stack, wrap_name(name, "pjit"))
  results = xla.jaxpr_subcomp(builder, jaxpr.jaxpr, backend, axis_env, consts,
                              sub_name_stack, *params)
  results = [
      annotate(res, get_sharding_proto(builder, res, resources, mesh))
      for res, resources in safe_zip(results, out_axis_resources)
  ]

  computation = builder.build(xops.Tuple(builder, results))
  return xops.Call(c, computation, list(in_nodes))
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  """Translation rule for `sharded_jit`: emits an XLA Call to a subcomputation.

  Parameters and results of the `sharded_jit_{name}` subcomputation carry the
  shardings from ``in_parts`` and ``out_parts_thunk()``. Leading inputs beyond
  ``len(in_parts)`` are treated as constants and given a ``None`` sharding.
  ``nparts`` and the ``local_*`` arguments are not used by this rule.

  Returns:
    The destructured outputs of the Call, as a list of XLA ops.
  """
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined calls
    # shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

  # Nest under the caller's name stack, as the other call-like translation
  # rules do. Starting a fresh stack would drop the enclosing scopes from the
  # op metadata of everything inside the sharded_jit.
  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=extend_name_stack(ctx.name_stack,
                                   wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)

  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
                             prim: core.Primitive,
                             avals_in: Sequence[core.AbstractValue],
                             avals_out: Sequence[core.AbstractValue], **params):
  """Builds a standalone XLA computation that applies the single primitive ``prim``.

  A backend-specific translation rule for ``platform`` takes precedence over the
  generic one in ``_translations``.

  Args:
    platform: platform name used to look up backend-specific rules (may be
      ``None``, in which case only generic rules are consulted).
    axis_env: axis environment stored in the translation context.
    prim: the primitive to translate.
    avals_in: abstract values of the inputs; each may expand to several XLA
      parameters via `aval_to_xla_shapes`.
    avals_out: abstract values of the outputs, passed through to the rule.
    **params: primitive parameters forwarded to the translation rule.

  Returns:
    The built XlaComputation. Its root is a tuple when ``prim.multiple_results``
    is true, otherwise the rule's single output.

  Raises:
    NotImplementedError: if no translation rule is registered for ``prim``.
  """
  # Resolve the rule first. Previously a missing rule left `rule` unbound,
  # surfacing as an opaque UnboundLocalError.
  if (platform is not None and
      prim in _backend_specific_translations[platform]):
    rule = _backend_specific_translations[platform][prim]
  elif prim in _translations:
    rule = _translations[prim]
  else:
    raise NotImplementedError(f"XLA translation rule for {prim} not found")

  c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
  counts = it.count()
  xla_args = [parameter(c, next(counts), xla_shape)
              for a in avals_in for xla_shape in aval_to_xla_shapes(a)]

  ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
                           name_stack=new_name_stack())
  ans = rule(ctx, avals_in, avals_out, *xla_args, **params)

  if prim.multiple_results:
    return c.build(xops.Tuple(c, ans))
  x, = ans
  return c.build(x)
def lower_jaxpr_to_xla_module(
    fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
    name_stack: str, tuple_args: bool, donated_invars: Sequence[bool],
    replicated_args: Optional[Sequence[bool]],
    arg_partitions: Optional[Any],
    out_partitions: Optional[Any],
    partitions_are_protos: bool = False
    ) -> xc.XlaComputation:
  """Lowers a closed jaxpr to a top-level XLA module.

  Args:
    fn_name: name given to the top-level XlaBuilder.
    jaxpr: the closed jaxpr; its consts are embedded via `_xla_consts`.
    platform: target platform. Input/output aliasing for donated buffers is
      only set up for "gpu" and "tpu".
    axis_env: axis environment for the translation context.
    name_stack: initial name stack for the translation context.
    tuple_args: forwarded to `_xla_callable_args` and `set_up_aliases`
      (presumably: pass all arguments as one tuple parameter; TODO confirm).
    donated_invars: per-argument donation flags.
    replicated_args: forwarded to `_xla_callable_args` as ``replicated``.
    arg_partitions: forwarded to `_xla_callable_args` as ``partitions``.
    out_partitions: if None, the output is left unannotated; otherwise the
      output tuple is wrapped with sharding annotations.
    partitions_are_protos: selects `with_sharding_proto` over `with_sharding`
      for outputs, and is forwarded as ``partitions_proto`` for arguments.

  Returns:
    The built XlaComputation.

  Warns:
    UserWarning (via `warnings.warn`) listing the shapes of donated arguments
    that remain flagged as donated after aliasing. On platforms other than
    gpu/tpu no aliasing is attempted, so every donation is reported.
  """
  c = xc.XlaBuilder(fn_name)
  xla_consts = _xla_consts(c, jaxpr.consts)
  xla_args, donated_invars = _xla_callable_args(
      c, jaxpr.in_avals, tuple_args, donated_invars=donated_invars,
      replicated=replicated_args, partitions=arg_partitions,
      partitions_proto=partitions_are_protos)
  ctx = TranslationContext(c, platform, axis_env, name_stack)
  out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)
  # Replace tokens with a dummy array value, because the runtime cannot
  # handle token arguments.
  # Each output aval may lower to several XLA values, so outputs are regrouped
  # per aval before substitution and flattened again afterwards.
  out_aval_lens = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
  out_nodes = util.flatten(
      [[_make_token_return_value(c)] if a is core.abstract_token else v
       for a, v in zip(jaxpr.out_avals,
                       util.unflatten(out_nodes, out_aval_lens))])

  # There is a non-zero cost to building an output tuple, particularly on TPU.
  # Avoid it if the output arity is 1.
  if out_partitions is None:
    output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(
        c, out_nodes)
  else:
    # Sharded outputs are always returned as a tuple; the builder is passed
    # lazily so the sharding helper can wrap the Tuple op's construction.
    build_out_tuple = partial(xops.Tuple, c, out_nodes)
    if partitions_are_protos:
      output = with_sharding_proto(c, out_partitions, build_out_tuple)
    else:
      output = with_sharding(c, out_partitions, build_out_tuple)

  if platform in ("gpu", "tpu"):
    # Presumably returns the flags of donations that could not be aliased to
    # an output; those left True are reported below. TODO confirm.
    donated_invars = set_up_aliases(
        c, xla_args, c.GetShape(output), donated_invars, tuple_args)
  if any(donated_invars):
    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
    unused_donations = [
        str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d
    ]
    warnings.warn("Some donated buffers were not usable: {}".format(
        ", ".join(unused_donations)))
  return c.build(output)
def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
  """Lower remat to a single iteration while loop.

  The loop carry is the tuple ``(i, *in_nodes, *outputs)``. The output slots
  are seeded with zeros and overwritten by the single body iteration.

  Args:
    ctx: translation context; ``ctx.builder`` is the enclosing XlaBuilder.
    in_nodes: XLA ops that are the inputs of the remat call.
    name: name of the call, pushed onto the name stack wrapped as 'remat'.
    call_jaxpr: jaxpr of the body to rematerialize.

  Returns:
    The XLA ops for the jaxpr outputs: the carry elements after the counter
    and the passed-through inputs.
  """
  c = ctx.builder
  # Dummy subc for getting subcomp shapes.
  # It is only used to lower the jaxpr once and read its output shapes; it is
  # never built into a computation.
  dummy_inputs = xops.Tuple(c, in_nodes)
  dummy_subc = xc.XlaBuilder("remat_dummy_subcomputation")
  dummy_input_op = parameter(dummy_subc, 0, c.get_shape(dummy_inputs),
                             replicated=[])
  dummy_args = xla_destructure(dummy_subc, dummy_input_op)
  dummy_ctx = ctx.replace(
      builder=dummy_subc,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
  dummy_subcomp_outs = jaxpr_subcomp(dummy_ctx, call_jaxpr, (), *dummy_args)
  out_node_shapes = [dummy_subc.get_shape(o) for o in dummy_subcomp_outs]

  # Initial carry: counter 0, the inputs, and zero placeholders for outputs.
  i_init = xops.Constant(c, np.array(0, dtype=np.int32))
  zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
  inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)

  # Condition: i < rng. XLA's RngUniform samples from [a, b), so for S32 with
  # bounds [1, 2) it always yields 1 and the loop runs exactly once (i: 0 -> 1).
  # XLA cannot prove that statically, which keeps the loop opaque to it.
  cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
  input_op = parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
  i = xops.GetTupleElement(input_op, 0)
  rng = xops.RngUniform(
      xops.Constant(cond_subc, np.array(1, dtype=np.int32)),
      xops.Constant(cond_subc, np.array(2, dtype=np.int32)),
      xc.Shape.array_shape(xc.PrimitiveType.S32, []))
  cond_subc = cond_subc.build(xops.Lt(i, rng))

  # Body: increment the counter, pass the inputs through unchanged, and
  # replace the placeholder output slots with the jaxpr's real outputs.
  body_subc = xc.XlaBuilder("remat_body_subcomputation")
  input_op = parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
  i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes) + 1]
  i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32)))
  body_ctx = ctx.replace(
      builder=body_subc,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
  subcomp_outs = jaxpr_subcomp(body_ctx, call_jaxpr, (), *args)
  out_nodes = [i_next] + args + list(subcomp_outs)
  body_subc = body_subc.build(xops.Tuple(body_subc, out_nodes))
  outs = xops.While(cond_subc, body_subc, inputs)
  # Drop the counter and the passed-through inputs; keep only the outputs.
  return xla_destructure(c, outs)[len(in_nodes) + 1:]
def _comparator_builder(op_type, is_max_k):
  """Builds the scalar comparator computation used for top_k.

  The computation declares four scalar parameters. Parameters 0 and 1 have type
  ``op_type`` and are compared. Parameters 2 and 3 are int32 and are declared
  but never read (presumably the companion index operands; TODO confirm).
  The result is ``p0 > p1`` when ``is_max_k`` is true, else ``p0 < p1``.
  """
  suffix = 'gt' if is_max_k else 'lt'
  c = xc.XlaBuilder(f'top_k_{suffix}_comparator')

  value_shape = xc.Shape.scalar_shape(op_type)
  index_shape = xc.Shape.scalar_shape(np.dtype(np.int32))
  lhs = xla.parameter(c, 0, value_shape)
  rhs = xla.parameter(c, 1, value_shape)
  for position in (2, 3):
    xla.parameter(c, position, index_shape)

  compare = xc.ops.Gt if is_max_k else xc.ops.Lt
  return c.build(compare(lhs, rhs))
def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                 name="core_call", backend=None, call_jaxpr):
  """Translates a named call into an XLA `Call` of a subcomputation named ``name``.

  ``backend`` is only checked against ``ctx.platform``. The outputs are always
  packed into a tuple inside the callee and destructured on the caller side.

  Returns:
    A list of XLA ops, one per output of ``call_jaxpr``.
  """
  check_backend_matches(backend, ctx.platform)
  caller = ctx.builder
  callee = xc.XlaBuilder(name)

  params = []
  for position, node in enumerate(in_nodes):
    params.append(parameter(callee, position, caller.GetShape(node)))

  callee_ctx = ctx.replace(
      builder=callee, name_stack=extend_name_stack(ctx.name_stack, name))
  results = jaxpr_subcomp(callee_ctx, call_jaxpr, (), *params)

  computation = callee.Build(xops.Tuple(callee, results))
  call = xops.Call(caller, computation, list(in_nodes))
  return xla_destructure(caller, call)
def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
                             prim: core.Primitive,
                             *avals: core.AbstractValue, **params):
  """Builds a standalone XLA computation applying ``prim`` to fresh parameters.

  The primitive is lowered through `lower_fun(prim.bind, new_style=True)`.
  Arguments are created from ``avals`` by `_xla_callable_args` without
  tupling or token filtering. ``None`` is passed as the output avals.

  Returns:
    The built XlaComputation. Its root is a tuple when ``prim.multiple_results``
    is true, otherwise the single output.
  """
  builder = xc.XlaBuilder(f"primitive_computation_{prim.name}")
  multi = prim.multiple_results
  lowered = lower_fun(prim.bind, multiple_results=multi, new_style=True)
  arg_ops, _ = _xla_callable_args(builder, avals, tuple_args=False,
                                  filter_tokens=False)
  base_ctx = TranslationContext(builder=builder, platform=platform,
                                axis_env=axis_env, name_stack=new_name_stack())
  outs = lowered(base_ctx.replace(builder=builder), avals, None, *arg_ops,
                 **params)
  if not multi:
    (root,) = outs
    return builder.build(root)
  return builder.build(xops.Tuple(builder, outs))
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  """Traces, lowers and compiles ``fun`` for spatial partitioning.

  Args:
    fun: the function to compile.
    nparts: requested number of partitions; reconciled against the jaxpr by
      `pxla.reconcile_num_partitions`.
    in_parts: global partitioning of each input.
    out_parts_thunk: returns the global partitioning of each output; called
      after tracing.
    local_in_parts: per-process input partitioning; defaults to ``in_parts``.
    local_out_parts_thunk: returns per-process output partitioning, or None to
      reuse ``out_parts``.
    local_nparts: per-process partition count; defaults to ``nparts`` when
      that fits on the local devices.
    name: name used for the computation's name stack.
    *abstract_args: per-process abstract values of the arguments.

  Returns:
    A `partial` of `_execute_spatially_partitioned` bound to the compiled
    executable, an input-sharding handler and an output handler.

  Raises:
    ValueError: unsupported platform (only "tpu"/"gpu"), too few global or
      local devices, or ``local_nparts`` omitted when it cannot be inferred.
    NotImplementedError: when ``nparts`` spans several hosts without using
      every available device.
  """
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  # Tracing happens with global avals (built from the per-process args and
  # partitionings), not with the per-process ones.
  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(
          abstract_args, in_parts, local_in_parts)
  ]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform
  if platform not in ["tpu", "gpu"]:
    # TODO(skye): fall back to regular jit?
    raise ValueError(f"sharded_jit not supported for {platform}")

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  # A partition count that exceeds this host but not the whole system would
  # need a partial multi-host device assignment, which is not supported.
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(
          global_out_avals, out_parts, local_out_parts)
  ]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  # Build the top-level computation: sharded arguments, consts embedded as
  # constants, and a sharded output tuple.
  c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  # Single-argument call: the wrapped name itself becomes the root of the name
  # stack.
  ctx = xla.TranslationContext(
      c, platform, axis_env,
      extend_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
  out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  # Use the first nparts local devices when they suffice; otherwise nparts
  # equals the global device count (checked above) and every device is used.
  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  # Per-process input sharding specs and the index slices they imply; inputs
  # without a spec map to None.
  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)
  ]
  input_indices = [
      pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
      for aval, spec in zip(abstract_args, input_specs)
  ]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(
      nrep, local_nparts,  # type: ignore
      local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
def test_parameter_replication(self):
  """An explicit False replication flag appears as parameter_replication={false}."""
  builder = xc.XlaBuilder("test")
  f32_scalar = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
  # Positional arguments after the shape: an empty name, then replication=False.
  xla.parameter(builder, 0, f32_scalar, "", False)
  hlo_text = builder.Build().as_hlo_text()
  assert "parameter_replication={false}" in hlo_text
def test_parameter_replication_default(self):
  """Without a replication argument, the HLO carries no replication annotation."""
  builder = xc.XlaBuilder("test")
  f32_scalar = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
  xla.parameter(builder, 0, f32_scalar)
  hlo_text = builder.Build().as_hlo_text()
  assert "replication" not in hlo_text