def _dummy_like_aval(aval: core.AbstractValue) -> Sequence[ir.Value]:
  """Returns dummy IR values for an aval: zeros for arrays, a fresh token
  for tokens, and no values for units."""
  if isinstance(aval, core.ShapedArray):
    return [full_like_aval(0, aval)]
  elif isinstance(aval, core.AbstractToken):
    return mhlo.CreateTokenOp(aval_to_ir_type(aval)).results
  elif isinstance(aval, core.AbstractUnit):
    return ()
  else:
    raise TypeError(f"Unsupported abstract value {aval}")
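# Hedged illustration (not called anywhere): the three aval kinds
# _dummy_like_aval handles. The avals themselves can be constructed without
# an active MLIR context; only the lowering needs one. Names match jax.core
# of this era.
def _example_dummy_aval_kinds():
  return [
      core.ShapedArray((2, 3), np.dtype(np.float32)),  # -> zeros constant
      core.abstract_token,                             # -> mhlo.CreateTokenOp
      core.abstract_unit,                              # -> no values
  ]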
def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each result (optional).
    use_sharding_annotations: if True, use mhlo.sharding annotations on
      parameters and return values to express sharding. If False, use
      mhlo.custom_call operators with sharding annotations.
      TODO(b/228598865): remove this option when mhlo.sharding annotations
      are propagated on non-entry functions during MHLO->HLO conversion.
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.

  Returns the func op.
  """
  def aval_to_types(aval):
    if replace_units_with_dummy and aval is core.abstract_unit:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    elif replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  ctx.symbol_table.insert(func_op)

  ir_arg_shardings = None
  if arg_shardings is not None:
    ir_arg_shardings = util.flatten(
        [[sharding] * len(types)
         for sharding, types in zip(arg_shardings, input_types)])

  ir_result_shardings = None
  if result_shardings is not None:
    ir_result_shardings = util.flatten(
        [[sharding] * len(types)
         for sharding, types in zip(result_shardings, output_types)])

  if (replicated_args is not None or ir_arg_shardings is not None
      or input_output_aliases is not None):
    arg_attrs: List[Dict[str, ir.Attribute]] = [
        {} for _ in range(len(flat_input_types))]

    if replicated_args is not None:
      replicated_ir_args = [
          [replicated] * len(types)
          for replicated, types in zip(replicated_args, input_types)]
      for attrs, replicated in zip(arg_attrs,
                                   util.flatten(replicated_ir_args)):
        if replicated:
          attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()

    if use_sharding_annotations and ir_arg_shardings is not None:
      for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
        if sharding is not None:
          attrs["mhlo.sharding"] = ir.StringAttr.get(
              sharding.SerializeToString())

    if input_output_aliases is not None:
      output_ids = util.unflatten(list(range(len(flat_output_types))),
                                  map(len, output_types))
      aliases: List[Optional[int]] = []
      for types, alias in zip(input_types, input_output_aliases):
        if alias is None:
          aliases.extend([None] * len(types))
        else:
          aliases.extend(output_ids[alias])
      for attrs, alias in zip(arg_attrs, aliases):
        if alias is not None:
          attrs["tf.aliasing_output"] = i32_attr(alias)

    func_op.arg_attrs = ir.ArrayAttr.get(
        [ir.DictAttr.get(attrs) for attrs in arg_attrs])

  if use_sharding_annotations and ir_result_shardings is not None:
    func_op.result_attrs = ir.ArrayAttr.get([
        ir.DictAttr.get({} if sharding is None else
                        {"mhlo.sharding": ir.StringAttr.get(
                            sharding.SerializeToString())})
        for sharding in ir_result_shardings])

  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    flat_args = entry_block.arguments
    if not use_sharding_annotations and ir_arg_shardings is not None:
      flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)

    unflattened_args = util.unflatten(flat_args, map(len, input_types))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_units_with_dummy and aval is core.abstract_unit:
        args.append([])
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              xla.wrap_name(name, 'jit'))
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                             jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                             *args)
    outs = []
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_units_with_dummy and aval is core.abstract_unit:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    flat_outputs = util.flatten(outs)
    if not use_sharding_annotations and ir_result_shardings is not None:
      flat_outputs = map(wrap_with_sharding_op, flat_outputs,
                         ir_result_shardings)

    func_dialect.ReturnOp(flat_outputs)

  return func_op
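# Hedged usage sketch (not exercised by this module): the public route to
# the MHLO that lower_jaxpr_to_fun produces. Assumes a jax version of this
# era where jit(...).lower(...) and Lowered.compiler_ir(dialect="mhlo")
# exist; the function f is purely illustrative.
def _example_lowering_via_public_api():
  import jax

  def f(x):
    return x * 2.0 + 1.0

  lowered = jax.jit(f).lower(np.float32(3.0))
  # The returned module's entry function was built by lower_jaxpr_to_fun.
  return lowered.compiler_ir(dialect="mhlo")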
def _python_scalar_handler(dtype, val, canonicalize_dtypes):
  return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes)

for ptype, dtype in dtypes.python_scalar_dtypes.items():
  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

def _device_array_constant_handler(val, canonicalize_types):
  return _ndarray_constant_handler(val.device_buffer.to_py(),
                                   canonicalize_types)

for t in device_array.device_array_types:
  register_constant_handler(t, _device_array_constant_handler)

register_constant_handler(
    core.Token,
    lambda _, __: [mhlo.CreateTokenOp(mhlo.TokenType.get()).result])
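# Hedged sketch: registering a constant handler for a custom container type
# by converting to an ndarray and delegating to the existing handler,
# mirroring _device_array_constant_handler above. _BoxedScalar is a
# hypothetical type, purely for illustration.
class _BoxedScalar:
  def __init__(self, value):
    self.value = np.float32(value)

def _boxed_scalar_constant_handler(val, canonicalize_types):
  return _ndarray_constant_handler(np.asarray(val.value), canonicalize_types)

# Illustration only; not registered so module behavior is unchanged:
# register_constant_handler(_BoxedScalar, _boxed_scalar_constant_handler)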
# Source locations

def _source_info_to_location(
    primitive: core.Primitive, params: Dict,
    source_info: source_info_util.SourceInfo,
    name_stack: Union[str, source_info_util.NameStack] = "") -> ir.Location:
  eqn_str = str(name_stack) + core.str_eqn_compact(primitive.name, params)
  frame = source_info_util.user_frame(source_info)
  if frame is None:
    loc = ir.Location.unknown()
  else:
    loc = ir.Location.file(xla._get_canonical_source_file(frame),
                           frame.line_num, 1)
  loc = ir.Location.name(eqn_str, childLoc=loc)
  return loc
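# Hedged sketch of the nested-location structure built above: a name
# location wrapping a file location. Assumes the MLIR python bindings
# bundled with jaxlib; defined for illustration, not run at import time.
def _example_nested_location():
  from jaxlib.mlir import ir as mlir_ir

  with mlir_ir.Context():
    file_loc = mlir_ir.Location.file("example.py", 12, 1)
    # Mirrors ir.Location.name(eqn_str, childLoc=loc) above.
    return mlir_ir.Location.name("jit(f)/add", childLoc=file_loc)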