def _execute_replicated(name: str, compiled: XlaExecutable,
                        output_buffer_counts: Optional[Sequence[int]],
                        result_handlers, kept_var_idx, *args):
    """Run ``compiled`` on every local device and wrap the resulting buffers.

    Args:
      name: passed to ``check_special`` together with the output buffers.
      compiled: the executable; its ``local_devices()`` determine where the
        inputs are placed.
      output_buffer_counts: how many flat buffers belong to each result, or
        None when the single handler consumes all buffers.
      result_handlers: callables that turn groups of buffers into results.
      kept_var_idx: container of argument positions that are actually passed
        to the executable; all other ``args`` are dropped.
      *args: the call arguments.

    Returns:
      A tuple with one handled result per entry of ``result_handlers`` (or a
      1-tuple when ``output_buffer_counts`` is None).
    """
    per_device_inputs = []
    for device in compiled.local_devices():
        placed = (device_put(arg, device)
                  for idx, arg in enumerate(args) if idx in kept_var_idx)
        per_device_inputs.append(util.flatten(placed))

    # The executable is handed the transpose of the per-device lists: one
    # tuple per flat input position.
    transposed_inputs = list(zip(*per_device_inputs))
    replicated_outputs = compiled.execute_sharded_on_local_devices(
        transposed_inputs)

    # Only the first element of each returned entry is kept (presumably the
    # buffer of the first replica -- confirm against the runtime API).
    out_bufs = [entry[0] for entry in replicated_outputs]
    check_special(name, out_bufs)

    if output_buffer_counts is None:
        sole_handler = result_handlers[0]
        return (sole_handler(*out_bufs),)

    grouped = util.unflatten(out_bufs, output_buffer_counts)
    return tuple(
        handler(*bufs) for handler, bufs in unsafe_zip(result_handlers, grouped))
def fallback(ctx: LoweringRuleContext, *args, **params):
    """Lower ``prim`` by importing its XLA subcomputation into the module.

    The primitive's XLA computation is converted to MLIR text, parsed, and
    every operation of the parsed module is appended to the module being
    built. The operation named "main" is renamed to ``xla_fallback_<prim>``,
    made private, and then called with the flattened ``args``.

    Returns:
      ``[result]`` for a single-result primitive; otherwise the elements of
      the returned tuple, regrouped per output aval.
    """
    module_ctx = ctx.module_context
    xla_computation = xla.primitive_subcomputation(
        module_ctx.platform, module_ctx.axis_env, prim, *ctx.avals_in, **params)
    submodule = ir.Module.parse(
        xc._xla.mlir.xla_computation_to_mlir_module(xla_computation))

    callee_name = None
    for op in submodule.body.operations:
        op = typing.cast(FuncOpType, op)
        module_ctx.module.body.append(op)
        if op.name.value != "main":
            module_ctx.symbol_table.insert(op)
            continue
        # The entry point gets a primitive-specific name; the name actually
        # used for the call is whatever the symbol table hands back.
        op.attributes["sym_name"] = ir.StringAttr.get(
            f"xla_fallback_{prim.name}")
        callee_name = ir.StringAttr(module_ctx.symbol_table.insert(op)).value
        op.attributes["sym_visibility"] = ir.StringAttr.get("private")

    output_types = map(aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    if prim.multiple_results:
        output_type = ir.TupleType.get_tuple(flat_output_types)
    else:
        output_type = flat_output_types[0]

    call_op = func_dialect.CallOp([output_type],
                                  ir.FlatSymbolRefAttr.get(callee_name),
                                  flatten_lowering_ir_args(args))
    call = call_op.result
    if not prim.multiple_results:
        return [call]

    # Older MLIR API versions require the element type to be passed
    # explicitly when extracting a tuple element.
    needs_explicit_type = jax._src.lib.mlir_api_version < 6
    flat_results = []
    for i, typ in enumerate(flat_output_types):
        if needs_explicit_type:
            element = mhlo.GetTupleElementOp(typ, call, i32_attr(i))
        else:
            element = mhlo.GetTupleElementOp(call, i32_attr(i))
        flat_results.append(element.result)
    return util.unflatten(flat_results, map(len, output_types))
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation):
    """Lower a generic reduce-window to an ``mhlo.ReduceWindowOp``.

    ``args`` is split in the middle: the first half are the operands and the
    second half the corresponding init values. The reduction body is obtained
    by lowering ``jaxpr`` (with ``consts``) into a new block appended to the
    op's first region.

    Returns:
      The results of the created ``ReduceWindowOp``.
    """
    half = len(args) // 2
    operands, init_values = util.split_list(args, [half])
    _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]

    result_types = map(mlir.aval_to_ir_type, ctx.avals_out)
    dimensions_attr = mlir.dense_int_elements(window_dimensions)
    strides_attr = mlir.dense_int_elements(window_strides)
    base_dilation_attr = mlir.dense_int_elements(base_dilation)
    window_dilation_attr = mlir.dense_int_elements(window_dilation)
    padding_attr = ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64))
    rw = mhlo.ReduceWindowOp(result_types, operands, init_values,
                             dimensions_attr, strides_attr, base_dilation_attr,
                             window_dilation_attr, padding_attr)

    # The reducer block takes the scalar types twice over: one set of block
    # arguments per side of the binary reduction.
    reducer_arg_types = scalar_types * 2
    reducer = rw.regions[0].blocks.append(*reducer_arg_types)
    with ir.InsertionPoint(reducer):
        # jaxpr_subcomp expects each argument as a list of IR values.
        wrapped_args = [[block_arg] for block_arg in reducer.arguments]
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts,
                                       *wrapped_args)
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
    """Lower a sharded_jit call to a function call with sharding ops around it.

    Inputs with a non-None entry in ``in_parts`` are wrapped in sharding ops,
    ``call_jaxpr`` is lowered to a function named ``sharded_jit_<name>``, the
    function is called, and outputs with a non-None entry in
    ``out_parts_thunk()`` are wrapped in sharding ops as well.

    ``nparts``, ``local_in_parts``, ``local_out_parts_thunk`` and
    ``local_nparts`` are accepted but not used by this rule.

    Returns:
      A list with one sequence of IR values per output.
    """

    def with_sharding(values, sharding):
        # A None sharding leaves the values untouched.
        if sharding is None:
            return values
        return [
            mlir.wrap_with_sharding_op(value, xla.sharding_to_proto(sharding))
            for value in values
        ]

    # We assume any extra leading in_nodes are constants and replicate them.
    num_extra_nodes = len(in_nodes) - len(in_parts)
    assert num_extra_nodes >= 0
    in_parts = (None,) * num_extra_nodes + in_parts

    wrapped_inputs = safe_map(mlir.wrap_singleton_ir_values, in_nodes)
    args = [
        with_sharding(values, sharding)
        for values, sharding in safe_zip(wrapped_inputs, in_parts)
    ]

    sub_ctx = ctx.module_context.replace(
        name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
    fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                 core.ClosedJaxpr(call_jaxpr, ()))

    output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    call = std.CallOp(flat_output_types,
                      ir.FlatSymbolRefAttr.get(fn.name.value),
                      mlir.flatten_lowering_ir_args(args))
    out_nodes = util.unflatten(call.results, safe_map(len, output_types))

    out_parts = out_parts_thunk()
    return [
        with_sharding(values, sharding)
        for values, sharding in safe_zip(out_nodes, out_parts)
    ]
def _execute_compiled(name: str, compiled: XlaExecutable,
                      input_handler: Optional[Callable],
                      output_buffer_counts: Optional[Sequence[int]],
                      result_handlers, effects: List[core.Effect],
                      kept_var_idx, *args):
    """Run ``compiled`` on its single local device and wrap the outputs.

    Args:
      name: passed to ``check_special`` together with the output buffers.
      compiled: the executable; it must have exactly one local device.
      input_handler: optional callable applied to ``args`` before placement.
      output_buffer_counts: how many flat buffers belong to each result, or
        None when the single handler consumes all buffers.
      result_handlers: callables that turn groups of buffers into results.
      effects: when non-empty, tokens are added to the inputs via
        ``_add_tokens`` and the returned token handler post-processes the
        grouped outputs.
      kept_var_idx: container of argument positions that are actually passed
        to the executable.
      *args: the call arguments.

    Returns:
      A tuple of handled results.
    """
    # Unpacking enforces that there is exactly one local device.
    (device,) = compiled.local_devices()
    if input_handler:
        args = input_handler(args)

    kept_on_device = (device_put(arg, device)
                      for idx, arg in enumerate(args) if idx in kept_var_idx)
    input_bufs_flat = flatten(kept_on_device)
    if effects:
        input_bufs_flat, token_handler = _add_tokens(effects, device,
                                                     input_bufs_flat)

    out_bufs_flat = compiled.execute(input_bufs_flat)
    check_special(name, out_bufs_flat)

    if output_buffer_counts is None:
        sole_handler = result_handlers[0]
        return (sole_handler(*out_bufs_flat),)

    out_bufs = unflatten(out_bufs_flat, output_buffer_counts)
    if effects:
        out_bufs = token_handler(out_bufs)
    return tuple(
        handler(*bufs) for handler, bufs in unsafe_zip(result_handlers, out_bufs))
def cached_lowering(ctx, *args, **params):
    """Call a cached, out-of-line lowering of ``f`` for this primitive.

    The cache lives on the module context and is keyed on the primitive, the
    input and output avals, and the parameters. On a miss, the lowering rule
    is emitted as a function once and stored; every call then emits a call
    to that function.

    Returns:
      The call's results regrouped per output aval, or whatever ``f`` returns
      directly when the key cannot be hashed.
    """
    assert ctx.primitive is not None
    avals_in_key = tuple(ctx.avals_in)
    avals_out_key = tuple(ctx.avals_out)
    key = (ctx.primitive, avals_in_key, avals_out_key, tuple(params.items()))
    try:
        func = ctx.module_context.cached_primitive_lowerings.get(key)
    except TypeError:
        # If the parameters aren't hashable, give up on caching.
        # TODO(phawkins): switch to requiring hashability, when XLA fallback
        # computations have been ported to MHLO.
        return f(ctx, *args, **params)

    if func is None:
        func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
        ctx.module_context.cached_primitive_lowerings[key] = func

    output_types = map(aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    callee = ir.FlatSymbolRefAttr.get(func.name.value)
    operands = flatten_lowering_ir_args(args)
    call = func_dialect.CallOp(flat_output_types, callee, operands)
    return util.unflatten(call.results, map(len, output_types))
def _execute_replicated(name: str, compiled: XlaExecutable,
                        input_handler: Optional[Callable],
                        output_buffer_counts: Sequence[int],
                        result_handler: Callable,
                        has_unordered_effects: bool,
                        ordered_effects: List[core.Effect],
                        kept_var_idx, *args):
    """Run ``compiled`` on every local device and hand the outputs to one handler.

    Args:
      name: passed to ``check_special`` together with the output buffers.
      compiled: the executable; its ``local_devices()`` determine where the
        inputs are placed.
      input_handler: must be falsy; input handlers are not supported here.
      output_buffer_counts: how many flat buffers belong to each result.
      result_handler: called as ``result_handler(None, out_bufs)`` with the
        grouped buffers.
      has_unordered_effects: must be falsy; effects are not supported here.
      ordered_effects: must be empty; effects are not supported here.
      kept_var_idx: container of argument positions that are actually passed
        to the executable.
      *args: the call arguments.

    Returns:
      Whatever ``result_handler`` returns.

    Raises:
      NotImplementedError: if any effect or an input handler is present.
    """
    effectful = has_unordered_effects or ordered_effects
    if effectful:
        # TODO(sharadmv): support jit-of-pmap with effects
        raise NotImplementedError(
            "Cannot execute replicated computation with effects.")
    if input_handler:
        raise NotImplementedError  # TODO(mattjj, dougalm)

    input_bufs = []
    for device in compiled.local_devices():
        on_device = (device_put(arg, device)
                     for idx, arg in enumerate(args) if idx in kept_var_idx)
        input_bufs.append(flatten(on_device))

    # Transpose from per-device lists to per-input tuples.
    input_bufs_flip = list(unsafe_zip(*input_bufs))
    out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(
        input_bufs_flip)

    # Keep only the first element of each returned entry.
    out_flat = [replicas[0] for replicas in out_bufs_flat_rep]
    check_special(name, out_flat)
    return result_handler(None, unflatten(out_flat, output_buffer_counts))
def _execute_replicated(name: str, compiled: XlaExecutable,
                        input_handler: Optional[Callable],
                        output_buffer_counts: Optional[Sequence[int]],
                        result_handlers, effects: List[core.Effect],
                        kept_var_idx, *args):
    """Run ``compiled`` on every local device and wrap the resulting buffers.

    Args:
      name: passed to ``check_special`` together with the output buffers.
      compiled: the executable; its ``local_devices()`` determine where the
        inputs are placed.
      input_handler: must be falsy; input handlers are not supported here.
      output_buffer_counts: how many flat buffers belong to each result, or
        None when the single handler consumes all buffers.
      result_handlers: callables that turn groups of buffers into results.
      effects: must be empty; effects are not supported here.
      kept_var_idx: container of argument positions that are actually passed
        to the executable.
      *args: the call arguments.

    Returns:
      A tuple of handled results.

    Raises:
      NotImplementedError: if ``effects`` or ``input_handler`` is truthy.
    """
    if effects:
        raise NotImplementedError(
            'Cannot execute replicated computation with effects.')
    if input_handler:
        raise NotImplementedError  # TODO(mattjj, dougalm)

    def place_on(device):
        # Put every kept argument on `device` and flatten the buffers.
        return flatten(
            device_put(arg, device)
            for idx, arg in enumerate(args) if idx in kept_var_idx)

    input_bufs = [place_on(device) for device in compiled.local_devices()]
    # Transpose from per-device lists to per-input tuples.
    input_bufs_flip = list(unsafe_zip(*input_bufs))
    out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(
        input_bufs_flip)

    # Keep only the first element of each returned entry.
    out_bufs_flat = [replicas[0] for replicas in out_bufs_flat_rep]
    check_special(name, out_bufs_flat)

    if output_buffer_counts is None:
        first_handler = result_handlers[0]
        return (first_handler(*out_bufs_flat),)

    out_bufs = unflatten(out_bufs_flat, output_buffer_counts)
    return tuple(
        handler(*bufs) for handler, bufs in unsafe_zip(result_handlers, out_bufs))
def lower_jaxpr_to_xla_module(
        fn_name: str, jaxpr: core.ClosedJaxpr, platform: str,
        axis_env: AxisEnv,
        name_stack: Union[source_info_util.NameStack, str],
        tuple_args: bool, donated_invars: Sequence[bool],
        replicated_args: Optional[Sequence[bool]],
        arg_partitions: Optional[Any], out_partitions: Optional[Any],
        partitions_are_protos: bool = False) -> xc.XlaComputation:
    """Lowers a closed jaxpr to a top-level XLA module.

    Args:
      fn_name: name given to the XLA builder.
      jaxpr: the closed jaxpr to lower.
      platform: target platform name; buffer donation is only set up for
        "gpu" and "tpu".
      axis_env: axis environment for the translation context.
      name_stack: name stack for the translation context.
      tuple_args: forwarded to argument construction and alias setup.
      donated_invars: per-argument donation flags.
      replicated_args: optional per-argument replication flags.
      arg_partitions: optional partitioning of the arguments.
      out_partitions: optional partitioning of the outputs; when given, the
        output is always built as a tuple under that sharding.
      partitions_are_protos: whether the partitions are already protos.

    Returns:
      The built XLA computation. A warning is issued when donation flags are
      still set after alias setup.
    """
    builder = xc.XlaBuilder(fn_name)
    xla_consts = _xla_consts(builder, jaxpr.consts)
    xla_args, donated_invars = _xla_callable_args(
        builder, jaxpr.in_avals, tuple_args,
        donated_invars=donated_invars,
        replicated=replicated_args,
        partitions=arg_partitions,
        partitions_proto=partitions_are_protos)
    ctx = TranslationContext(builder, platform, axis_env, name_stack)
    out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)

    # Replace tokens with a dummy array value, because the runtime cannot
    # handle token arguments.
    out_aval_lens = [len(aval_to_xla_shapes(aval)) for aval in jaxpr.out_avals]
    grouped_nodes = util.unflatten(out_nodes, out_aval_lens)
    per_output_nodes = []
    for aval, nodes in zip(jaxpr.out_avals, grouped_nodes):
        if aval is core.abstract_token:
            per_output_nodes.append([_make_token_return_value(builder)])
        else:
            per_output_nodes.append(nodes)
    out_nodes = util.flatten(per_output_nodes)

    # There is a non-zero cost to building an output tuple, particularly on
    # TPU. Avoid it if the output arity is 1.
    if out_partitions is not None:
        build_out_tuple = partial(xops.Tuple, builder, out_nodes)
        apply_sharding = (with_sharding_proto if partitions_are_protos
                          else with_sharding)
        output = apply_sharding(builder, out_partitions, build_out_tuple)
    elif len(out_nodes) == 1:
        output = out_nodes[0]
    else:
        output = xc.ops.Tuple(builder, out_nodes)

    platforms_with_donation = ("gpu", "tpu")
    donation_supported = platform in platforms_with_donation
    if donation_supported:
        donated_invars = set_up_aliases(builder, xla_args,
                                        builder.GetShape(output),
                                        donated_invars, tuple_args)

    if any(donated_invars):
        # TODO(tomhennigan): At call time we should mark these buffers as
        # deleted.
        unused_donations = [
            str(builder.GetShape(arg))
            for arg, still_donated in zip(xla_args, donated_invars)
            if still_donated
        ]
        msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
        if not donation_supported:
            msg = f"Donation is not implemented for {platform}.\n{msg}"
        warnings.warn(
            f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}"
        )
    return builder.build(output)
def lower_jaxpr_to_fun(
        ctx: ModuleContext,
        name: str,
        jaxpr: core.ClosedJaxpr,
        *,
        public: bool = False,
        replace_units_with_dummy: bool = False,
        replace_tokens_with_dummy: bool = False,
        replicated_args: Optional[Sequence[bool]] = None,
        arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
        result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
        use_sharding_annotations: bool = True,
        input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
    """Lowers jaxpr and its callees to an IR function.

    Assumes that an MLIR context, location, and insertion point are set.

    Args:
      ctx: the lowering context.
      name: the function name. The name will be uniquified by the symbol
        table, so it is ok to use the same name multiple times.
      jaxpr: the jaxpr to lower.
      public: if true, the function's visibility is set to "public".
      replace_units_with_dummy: if true, unit arguments/return values are
        replaced with scalar (shape ``()``) bool arrays.
      replace_tokens_with_dummy: if true, token arguments/return values are
        replaced with scalar (shape ``()``) bool arrays.
      replicated_args: if present, annotates arguments as replicated.
      arg_shardings: sharding annotations for each argument (optional). A
        None entry leaves that argument without a sharding.
      result_shardings: sharding annotations for each result (optional). A
        None entry leaves that result without a sharding.
      use_sharding_annotations: if True, use mhlo.sharding annotations on
        parameters and return values to express sharding. If False, use
        mhlo.custom_call operators with sharding annotations.
        TODO(b/228598865): remove this option when mhlo.sharding annotations
        are propagated on non-entry functions during MHLO->HLO conversion.
      input_output_aliases: optional sequence that maps argument numbers to
        the corresponding output that should alias them.

    Returns:
      The newly created function op.
    """

    def aval_to_types(aval):
        # Units and tokens are optionally presented as scalar bools in the
        # function signature.
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    ctx.symbol_table.insert(func_op)

    # Expand the per-aval shardings to one entry per flat IR value.
    ir_arg_shardings = None
    if arg_shardings is not None:
        ir_arg_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(arg_shardings, input_types)])
    ir_result_shardings = None
    if result_shardings is not None:
        ir_result_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(result_shardings, output_types)])

    if (replicated_args is not None or ir_arg_shardings is not None
            or input_output_aliases is not None):
        arg_attrs: List[Dict[str, ir.Attribute]] = [
            {} for _ in range(len(flat_input_types))
        ]

        if replicated_args is not None:
            replicated_ir_args = [
                [replicated] * len(types)
                for replicated, types in zip(replicated_args, input_types)
            ]
            for attrs, replicated in zip(arg_attrs,
                                         util.flatten(replicated_ir_args)):
                if replicated:
                    attrs["mhlo.is_same_data_across_replicas"] = (
                        ir.UnitAttr.get())

        if use_sharding_annotations and ir_arg_shardings is not None:
            for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
                if sharding is not None:
                    attrs["mhlo.sharding"] = ir.StringAttr.get(
                        sharding.SerializeToString())

        if input_output_aliases is not None:
            # output_ids[k] lists the flat output positions of the k-th
            # output aval.
            output_ids = util.unflatten(list(range(len(flat_output_types))),
                                        map(len, output_types))
            aliases: List[Optional[int]] = []
            for types, alias in zip(input_types, input_output_aliases):
                if alias is None:
                    aliases.extend([None] * len(types))
                else:
                    aliases.extend(output_ids[alias])

            for attrs, alias in zip(arg_attrs, aliases):
                if alias is not None:
                    attrs["tf.aliasing_output"] = i32_attr(alias)

        func_op.arg_attrs = ir.ArrayAttr.get(
            [ir.DictAttr.get(attrs) for attrs in arg_attrs])

    if use_sharding_annotations and ir_result_shardings is not None:
        func_op.result_attrs = ir.ArrayAttr.get([
            ir.DictAttr.get({} if sharding is None else {
                "mhlo.sharding":
                    ir.StringAttr.get(sharding.SerializeToString())
            }) for sharding in ir_result_shardings
        ])

    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        flat_args = entry_block.arguments
        if not use_sharding_annotations and ir_arg_shardings is not None:
            # A None sharding means "no annotation", exactly as in the
            # attribute-based path above, so such values are passed through
            # instead of being handed to wrap_with_sharding_op.
            flat_args = [
                arg if sharding is None
                else wrap_with_sharding_op(arg, sharding)
                for arg, sharding in zip(flat_args, ir_arg_shardings)
            ]

        unflattened_args = util.unflatten(flat_args, map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                # The dummy argument is dropped.
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                # The dummy argument is replaced by a fresh token.
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)

        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                                 *args)

        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)

        flat_outputs = util.flatten(outs)
        if not use_sharding_annotations and ir_result_shardings is not None:
            # Same rule as for the arguments: only results with an actual
            # sharding get a sharding op.
            flat_outputs = [
                out if sharding is None
                else wrap_with_sharding_op(out, sharding)
                for out, sharding in zip(flat_outputs, ir_result_shardings)
            ]

        func_dialect.ReturnOp(flat_outputs)

    return func_op
def flatten_lowering_ir_args(
        xs: Sequence[Union[ir.Value, Sequence[ir.Value]]]
) -> Sequence[ir.Value]:
    """Flatten lowering-rule arguments into a single sequence of IR values.

    Each element of ``xs`` is first passed through
    ``wrap_singleton_ir_values`` and the resulting sequences are then
    concatenated with ``util.flatten``.

    Args:
      xs: arguments, each either a single IR value or a sequence of them.

    Returns:
      The flattened values. The result is used directly as an operand list
      by callers, so the annotation is a flat sequence rather than a nested
      one.
    """
    return util.flatten(map(wrap_singleton_ir_values, xs))
def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
                       jaxpr: core.ClosedJaxpr, *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
    """Lowers jaxpr and its callees to an IR function.

    Assumes that an MLIR context, location, and insertion point are set.

    Args:
      ctx: the lowering context.
      name: the function name. The symbol table may uniquify it, so the same
        name can be used multiple times.
      jaxpr: the jaxpr to lower.
      public: if true, the function's visibility is set to "public",
        otherwise "private".
      replace_units_with_dummy: if true, unit arguments/return values appear
        in the signature as scalar (shape ``()``) bool arrays.
      replace_tokens_with_dummy: if true, token arguments/return values
        appear in the signature as scalar (shape ``()``) bool arrays.

    Returns:
      The name under which the function was inserted into the symbol table.
    """

    def is_dummy_unit(aval):
        return replace_units_with_dummy and aval is core.abstract_unit

    def is_dummy_token(aval):
        return replace_tokens_with_dummy and aval is core.abstract_token

    def aval_to_types(aval):
        # Replaced units and tokens are both typed as scalar bools.
        if is_dummy_unit(aval) or is_dummy_token(aval):
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)

    func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
    visibility = "public" if public else "private"
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(visibility)
    symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value

    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        unflattened_args = util.unflatten(entry_block.arguments,
                                          map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if is_dummy_unit(aval):
                # The dummy argument is dropped.
                args.append([])
            elif is_dummy_token(aval):
                # The dummy argument is replaced by a fresh token.
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)

        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        callee_ctx = ctx.replace(name_stack=callee_name_stack)
        out_vals = jaxpr_subcomp(callee_ctx, jaxpr.jaxpr,
                                 map(ir_constants, jaxpr.consts), *args)

        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if is_dummy_unit(aval) or is_dummy_token(aval):
                # Replaced outputs are returned as a scalar False constant.
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        std.ReturnOp(util.flatten(outs))

    return symbol_name