def _xla_call_lower(ctx, *args,
                    backend=None, name, call_jaxpr, donated_invars,
                    inline=None, device=None):
  del device, donated_invars, inline  # Ignored.
  return _call_lowering(f"jit_{name}", xla.wrap_name(name, "jit"), call_jaxpr,
                        backend, ctx.module_context, ctx.avals_in,
                        ctx.avals_out, *args)
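
# Editorial sketch (not part of this module): the xla_call primitive that the
# rule above lowers is what `jax.jit` stages out, so any jitted call exercises
# it. The example below uses only the public `jax` API; the `_example_*` name
# is hypothetical.
def _example_jit_lowering():
  import jax
  import jax.numpy as jnp

  @jax.jit
  def f(x):
    return jnp.sin(x) * 2.0

  # Tracing f yields an xla_call eqn, lowered via rules like _xla_call_lower.
  return f(jnp.arange(3.0))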

def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
  input_types = map(aval_to_ir_types, avals_in)
  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  int32_scalar_type = aval_to_ir_type(core.ShapedArray((), np.dtype(np.int32)))
  loop_carry_types = [(int32_scalar_type,)] + input_types + output_types
  flat_loop_carry_types = util.flatten(loop_carry_types)
  counter_init = ir_constants(np.array(0, np.int32))
  flat_args = flatten_lowering_ir_args(
      (counter_init,) + args +
      tuple(_dummy_like_aval(aval) for aval in avals_out))
  loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
  init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

  one = ir_constant(np.array(1, np.int32))
  while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

  # Loop condition
  cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
  with ir.InsertionPoint(cond_block):
    bool_scalar_type = aval_to_ir_type(core.ShapedArray((), np.dtype(np.bool_)))
    two = ir_constant(np.array(2, np.int32))
    shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
    rng = mhlo.RngUniformOp(one, two, shape).result
    i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                               i32_attr(0))
    cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                         ir.StringAttr.get("SIGNED")).result
    mhlo.ReturnOp([cmp])

  body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
  with ir.InsertionPoint(body_block):
    flat_body_args = [
        mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                               i32_attr(i)).result
        for i, input_type in enumerate(flat_loop_carry_types)]
    body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
    ((i,),), y, _ = util.split_list(body_args, [1, len(avals_in)])
    body_ctx = ctx.replace(
        name_stack=xla.extend_name_stack(ctx.name_stack,
                                         xla.wrap_name(name, 'remat')))
    z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
    i_next = mhlo.AddOp(i, one).result
    new_carry = mhlo.TupleOp(
        loop_carry_tuple_type, [i_next, *util.flatten(y), *util.flatten(z)])
    mhlo.ReturnOp([new_carry.result])

  outputs = [
      mhlo.GetTupleElementOp(output_type, while_op.result,
                             i32_attr(1 + len(avals_in) + i)).result
      for i, output_type in enumerate(flat_output_types)]
  return util.unflatten(outputs, map(len, output_types))
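
# Why a while loop with an rng_uniform condition: rng_uniform over [1, 2) in
# int32 always yields 1, so `i < rng` holds exactly once and the body runs a
# single time. Because the trip count is opaque to XLA, the compiler cannot
# hoist or CSE the body, which forces the rematerialized computation to
# actually re-run in the backward pass.
#
# Editorial sketch of the user-level remat that ends up lowered through this
# rule (public `jax` API only; the `_example_*` name is hypothetical):
def _example_remat_usage():
  import jax
  import jax.numpy as jnp

  @jax.checkpoint  # a.k.a. jax.remat
  def g(x):
    return jnp.sin(jnp.sin(x))

  # The backward pass recomputes g's intermediates instead of storing them.
  return jax.grad(lambda x: g(x).sum())(jnp.ones(3))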

def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, *arg_specs):
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  abstract_args, arg_devices = util.unzip2(arg_specs)
  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
        fun, abstract_args, pe.debug_info_final(fun, "jit"))
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")
  jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
  consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
  pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
  abstract_args, arg_devices = util.unzip2(pruned_arg_specs)
  donated_invars = [
      x for i, x in enumerate(donated_invars) if i in kept_var_idx]
  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  # Computations that only produce constants and/or only rearrange their
  # inputs, which are often produced from partial evaluation, don't need
  # compilation, and don't need to evaluate their arguments.
  if not jaxpr.eqns:
    return XlaComputation(
        name, None, True, None, jaxpr=jaxpr, consts=consts, device=device,
        in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for args {abstract_args}."
    logging.log(log_priority, msg)

  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  # pass long arg lists as tuple for TPU
  tuple_args = len(abstract_args) > 100
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = xla.new_name_stack(xla.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module: Union[str, xc.XlaComputation]
  module_name = f"jit_{fun.__name__}"
  if config.jax_enable_mlir:
    module = mlir.lower_jaxpr_to_module(
        module_name, closed_jaxpr, backend.platform,
        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
  else:
    module = xla.lower_jaxpr_to_xla_module(
        module_name, closed_jaxpr, backend.platform, axis_env, name_stack,
        tuple_args, donated_invars, replicated_args=None, arg_partitions=None,
        out_partitions=None)
  return XlaComputation(
      name, module, False, donated_invars, nreps=nreps, device=device,
      backend=backend, tuple_args=tuple_args, in_avals=abstract_args,
      out_avals=out_avals, kept_var_idx=kept_var_idx)

def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each result (optional).
    use_sharding_annotations: if True, use mhlo.sharding annotations on
      parameters and return values to express sharding. If False, use
      mhlo.custom_call operators with sharding annotations.
      TODO(b/228598865): remove this option when mhlo.sharding annotations are
      propagated on non-entry functions during MHLO->HLO conversion.
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.

  Returns the MLIR func op.
  """
  def aval_to_types(aval):
    if replace_units_with_dummy and aval is core.abstract_unit:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    elif replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  ctx.symbol_table.insert(func_op)
  ir_arg_shardings = None
  if arg_shardings is not None:
    ir_arg_shardings = util.flatten(
        [[sharding] * len(types)
         for sharding, types in zip(arg_shardings, input_types)])
  ir_result_shardings = None
  if result_shardings is not None:
    ir_result_shardings = util.flatten(
        [[sharding] * len(types)
         for sharding, types in zip(result_shardings, output_types)])

  if (replicated_args is not None or ir_arg_shardings is not None
      or input_output_aliases is not None):
    arg_attrs: List[Dict[str, ir.Attribute]] = [
        {} for _ in range(len(flat_input_types))]

    if replicated_args is not None:
      replicated_ir_args = [
          [replicated] * len(types)
          for replicated, types in zip(replicated_args, input_types)]
      for attrs, replicated in zip(arg_attrs,
                                   util.flatten(replicated_ir_args)):
        if replicated:
          attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()

    if use_sharding_annotations and ir_arg_shardings is not None:
      for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
        if sharding is not None:
          attrs["mhlo.sharding"] = ir.StringAttr.get(
              sharding.SerializeToString())

    if input_output_aliases is not None:
      output_ids = util.unflatten(list(range(len(flat_output_types))),
                                  map(len, output_types))
      aliases: List[Optional[int]] = []
      for types, alias in zip(input_types, input_output_aliases):
        if alias is None:
          aliases.extend([None] * len(types))
        else:
          aliases.extend(output_ids[alias])
      for attrs, alias in zip(arg_attrs, aliases):
        if alias is not None:
          attrs["tf.aliasing_output"] = i32_attr(alias)

    func_op.arg_attrs = ir.ArrayAttr.get(
        [ir.DictAttr.get(attrs) for attrs in arg_attrs])

  if use_sharding_annotations and ir_result_shardings is not None:
    func_op.result_attrs = ir.ArrayAttr.get([
        ir.DictAttr.get({} if sharding is None else {
            "mhlo.sharding": ir.StringAttr.get(sharding.SerializeToString())
        }) for sharding in ir_result_shardings])

  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    flat_args = entry_block.arguments
    if not use_sharding_annotations and ir_arg_shardings is not None:
      flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)

    unflattened_args = util.unflatten(flat_args, map(len, input_types))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_units_with_dummy and aval is core.abstract_unit:
        args.append([])
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              xla.wrap_name(name, 'jit'))
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                             jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                             *args)
    outs = []
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_units_with_dummy and aval is core.abstract_unit:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    flat_outputs = util.flatten(outs)
    if not use_sharding_annotations and ir_result_shardings is not None:
      flat_outputs = map(wrap_with_sharding_op, flat_outputs,
                         ir_result_shardings)

    func_dialect.ReturnOp(flat_outputs)

  return func_op
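
# Editorial sketch of where input_output_aliases come from at the user level
# (public API only; `_example_*` is a hypothetical name): buffer donation via
# `donate_argnums` is what ultimately populates the `tf.aliasing_output`
# argument attributes set above:
def _example_buffer_donation():
  import functools
  import jax
  import jax.numpy as jnp

  @functools.partial(jax.jit, donate_argnums=0)
  def update(params, grads):
    return params - 0.1 * grads

  # On backends that support donation, params' buffer may be reused in place
  # for the output.
  return update(jnp.ones(4), jnp.ones(4))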

def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
                       jaxpr: core.ClosedJaxpr, *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].

  Returns the name of the function.
  """
  def aval_to_types(aval):
    if replace_units_with_dummy and aval is core.abstract_unit:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    elif replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    unflattened_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_units_with_dummy and aval is core.abstract_unit:
        args.append([])
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              xla.wrap_name(name, 'jit'))
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                             jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                             *args)
    outs = []
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_units_with_dummy and aval is core.abstract_unit:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    std.ReturnOp(util.flatten(outs))

  return symbol_name
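
# Editorial sketch of where token avals come from at the user level (public
# API only; `_example_*` is a hypothetical name, and the exact
# `lax.create_token` signature is an assumption): tokens thread ordering
# between side-effecting ops such as infeed/outfeed, and token arguments or
# results are what the replace_tokens_with_dummy paths above substitute:
def _example_token_plumbing():
  import jax.numpy as jnp
  from jax import lax

  # The operand only anchors the token in the dataflow graph; its value is
  # ignored. The resulting token's aval triggers the dummy-replacement logic.
  token = lax.create_token(jnp.zeros(()))
  return token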