def _emit_lowering_rule_as_fun(lowering_rule,
                               ctx: LoweringRuleContext) -> builtin.FuncOp:
  """Emits the contents of a lowering rule as a private function."""
  input_types = map(aval_to_ir_types, ctx.avals_in)
  output_types = map(aval_to_ir_types, ctx.avals_out)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  assert ctx.primitive is not None
  func_op = builtin.FuncOp(ctx.primitive.name, ftype,
                           ip=ctx.module_context.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get("private")
  ctx.module_context.symbol_table.insert(func_op)
  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    # Regroup the flat entry-block arguments by aval so each lowering-rule
    # argument receives exactly the IR values belonging to one aval.
    unflattened_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
    outs = lowering_rule(ctx, *map(_unwrap_singleton_ir_values,
                                   unflattened_args))
    std.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs)))
  return func_op
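# A minimal sketch (an assumption, not the verbatim library code) of the
# singleton helpers used above. Lowered avals are *lists* of ir.Values, but
# rules exchange a bare ir.Value wherever that list has length one:
# _unwrap_singleton_ir_values strips the list on the way into a rule, and
# wrap_singleton_ir_values restores it on the way out so the results can be
# flattened uniformly for the ReturnOp.
def _unwrap_singleton_ir_values(x):
  return x[0] if len(x) == 1 else x

def wrap_singleton_ir_values(x):
  return (x,) if isinstance(x, ir.Value) else tuple(x)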
def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
                       jaxpr: core.ClosedJaxpr, *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with scalar (shape ()) bool arrays.
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with scalar (shape ()) bool arrays.

  Returns the name of the function.
  """
  def aval_to_types(aval):
    if replace_units_with_dummy and aval is core.abstract_unit:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    elif replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    unflattened_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_units_with_dummy and aval is core.abstract_unit:
        # A unit lowers to no IR values; drop the dummy argument.
        args.append([])
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        # Replace the dummy argument with a freshly created token.
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              xla.wrap_name(name, 'jit'))
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                             jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                             *args)
    outs = []
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_units_with_dummy and aval is core.abstract_unit:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    std.ReturnOp(util.flatten(outs))
  return symbol_name
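# A self-contained sketch of the flatten/unflatten round trip the lowering
# relies on. The real helpers live in jax.util; the behavior shown here is an
# assumption, illustrated with plain strings standing in for IR types and
# ir.Values. Each aval lowers to zero or more IR types, the function signature
# is the flat concatenation, and entry-block arguments are regrouped per aval.
from itertools import chain
from typing import Iterable, List, Sequence, TypeVar

T = TypeVar("T")

def flatten(xss: Iterable[Iterable[T]]) -> List[T]:
  return list(chain.from_iterable(xss))

def unflatten(xs: Iterable[T], ns: Sequence[int]) -> List[List[T]]:
  it = iter(xs)
  return [[next(it) for _ in range(n)] for n in ns]

input_types = [["f32"], ["i32", "i32"]]      # two avals; the second is multi-valued
flat = flatten(input_types)                  # ["f32", "i32", "i32"]
grouped = unflatten(["a", "b", "c"], [len(t) for t in input_types])
assert grouped == [["a"], ["b", "c"]]        # block args regrouped per aval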
def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> builtin.FuncOp:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with scalar (shape ()) bool arrays.
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with scalar (shape ()) bool arrays.
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each result (optional).
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.

  Returns the func_op of the lowered function.
  """
  def aval_to_types(aval):
    if replace_units_with_dummy and aval is core.abstract_unit:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    elif replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  ctx.symbol_table.insert(func_op)
  if (replicated_args is not None or arg_shardings is not None
      or input_output_aliases is not None):
    arg_attrs: List[Dict[str, ir.Attribute]] = [
        {} for _ in range(len(flat_input_types))]

    if replicated_args is not None:
      replicated_ir_args = [
          [replicated] * len(types)
          for replicated, types in zip(replicated_args, input_types)]
      for attrs, replicated in zip(arg_attrs,
                                   util.flatten(replicated_ir_args)):
        if replicated:
          attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()

    if arg_shardings is not None:
      ir_arg_shardings = [
          [sharding] * len(types)
          for sharding, types in zip(arg_shardings, input_types)]
      for attrs, sharding in zip(arg_attrs, util.flatten(ir_arg_shardings)):
        if sharding is not None:
          attrs["mhlo.sharding"] = ir.StringAttr.get(
              sharding.SerializeToString())

    if input_output_aliases is not None:
      output_ids = util.unflatten(list(range(len(flat_output_types))),
                                  map(len, output_types))
      aliases: List[Optional[int]] = []
      for types, alias in zip(input_types, input_output_aliases):
        if alias is None:
          aliases.extend([None] * len(types))
        else:
          aliases.extend(output_ids[alias])
      for attrs, alias in zip(arg_attrs, aliases):
        if alias is not None:
          attrs["tf.aliasing_output"] = i32_attr(alias)

    func_op.arg_attrs = ir.ArrayAttr.get(
        [ir.DictAttr.get(attrs) for attrs in arg_attrs])

  if result_shardings is not None:
    ir_result_shardings = util.flatten(
        [[sharding] * len(types)
         for sharding, types in zip(result_shardings, output_types)])
    func_op.result_attrs = ir.ArrayAttr.get([
        ir.DictAttr.get(
            {} if sharding is None else
            {"mhlo.sharding": ir.StringAttr.get(sharding.SerializeToString())})
        for sharding in ir_result_shardings])

  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    unflattened_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_units_with_dummy and aval is core.abstract_unit:
        args.append([])
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              xla.wrap_name(name, 'jit'))
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                             jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                             *args)
    outs = []
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_units_with_dummy and aval is core.abstract_unit:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    std.ReturnOp(util.flatten(outs))
  return func_op
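# A sketch of the input_output_aliases bookkeeping above, with hypothetical
# shapes. Because one aval may lower to several IR values, output positions
# are numbered in the *flattened* signature, regrouped per result, and each
# aliased argument inherits the flat ids of the result it aliases.
input_types = [["i32", "i32"], ["f32"]]      # arg 0 lowers to two IR values
output_types = [["f32"], ["i32", "i32"]]     # result 1 lowers to two IR values
input_output_aliases = [1, None]             # arg 0 aliases result 1

flat_ids = iter(range(sum(len(t) for t in output_types)))
output_ids = [[next(flat_ids) for _ in t] for t in output_types]
assert output_ids == [[0], [1, 2]]           # flat output ids, per result

aliases = []
for types, alias in zip(input_types, input_output_aliases):
  if alias is None:
    aliases.extend([None] * len(types))
  else:
    aliases.extend(output_ids[alias])
assert aliases == [1, 2, None]               # one entry per flat argument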