def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
  # Lower remat as a single-trip while loop. The trip count hides behind an
  # integer rng_uniform over [1, 2) (which always yields 1), so the compiler
  # cannot prove the body runs exactly once and cannot CSE the recomputation.
  input_types = map(aval_to_ir_types, avals_in)
  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  int32_scalar_type = aval_to_ir_type(core.ShapedArray((), np.dtype(np.int32)))
  loop_carry_types = [(int32_scalar_type,)] + input_types + output_types
  flat_loop_carry_types = util.flatten(loop_carry_types)
  counter_init = ir_constants(np.array(0, np.int32))
  flat_args = flatten_lowering_ir_args(
      (counter_init,) + args
      + tuple(_dummy_like_aval(aval) for aval in avals_out))
  loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
  init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)
  one = ir_constant(np.array(1, np.int32))
  while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

  # Loop condition: counter < rng_uniform(1, 2), i.e. counter < 1.
  cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
  with ir.InsertionPoint(cond_block):
    bool_scalar_type = aval_to_ir_type(core.ShapedArray((), np.dtype(np.bool_)))
    two = ir_constant(np.array(2, np.int32))
    shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
    rng = mhlo.RngUniformOp(one, two, shape).result
    i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                               i32_attr(0))
    cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                         ir.StringAttr.get("SIGNED")).result
    mhlo.ReturnOp([cmp])

  # Loop body: run the remat'd jaxpr once and bump the counter.
  body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
  with ir.InsertionPoint(body_block):
    flat_body_args = [
        mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                               i32_attr(i)).result
        for i, input_type in enumerate(flat_loop_carry_types)]
    body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
    ((i,),), y, _ = util.split_list(body_args, [1, len(avals_in)])
    body_ctx = ctx.replace(
        name_stack=xla.extend_name_stack(ctx.name_stack,
                                         xla.wrap_name(name, 'remat')))
    z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
    i_next = mhlo.AddOp(i, one).result
    new_carry = mhlo.TupleOp(
        loop_carry_tuple_type,
        [i_next, *util.flatten(y), *util.flatten(z)])
    mhlo.ReturnOp([new_carry.result])

  # The outputs sit after the counter and the carried inputs in the loop tuple.
  outputs = [
      mhlo.GetTupleElementOp(output_type, while_op.result,
                             i32_attr(1 + len(avals_in) + i)).result
      for i, output_type in enumerate(flat_output_types)]
  return util.unflatten(outputs, map(len, output_types))
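# --- User-level sketch (not part of the lowering above; hedged) ---
# _remat_using_while is one of the lowerings behind jax.checkpoint / jax.remat:
# the decorated function's body is recomputed on the backward pass, and the
# single-trip while loop emitted above keeps XLA from de-duplicating that
# recomputation with the forward pass. Minimal example using only public APIs:
import jax
import jax.numpy as jnp

@jax.checkpoint  # alias: jax.remat
def f(x):
  return jnp.sin(jnp.sin(x))

# Differentiating under jit exercises the rematerialization lowering.
grad_f = jax.jit(jax.grad(lambda x: jnp.sum(f(x))))
print(grad_f(jnp.ones(3)))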
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
                          avals_in, avals_out, *args, **params):
  # Lower `prim` by building its XLA HLO subcomputation, converting that to an
  # MLIR module, and splicing the module in as a private callee.
  xla_computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                 prim, *avals_in, **params)
  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
  submodule = ir.Module.parse(submodule_str)
  callee_name = None
  for op in submodule.body.operations:
    ctx.module.body.append(op)
    if op.name.value == "main":
      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
    else:
      ctx.symbol_table.insert(op)

  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  output_type = (ir.TupleType.get_tuple(flat_output_types)
                 if prim.multiple_results else flat_output_types[0])

  call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
                    flatten_lowering_ir_args(args)).result
  if not prim.multiple_results:
    return [call]
  flat_results = [
      mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
      for i, typ in enumerate(flat_output_types)]
  return util.unflatten(flat_results, map(len, output_types))
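# --- Pattern sketch (hedged; standalone illustration, not library code) ---
# The flatten/unflatten bookkeeping used above in miniature: per-aval type
# tuples are flattened for the call and regrouped per aval afterwards. This
# uses jax._src.util, the same internal helper module the code above refers
# to as `util`:
from jax._src import util

output_types = [("f32",), ("i32", "i32")]          # one tuple of types per aval
flat_output_types = util.flatten(output_types)     # ["f32", "i32", "i32"]
regrouped = util.unflatten(flat_output_types, map(len, output_types))
assert regrouped == [["f32"], ["i32", "i32"]]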
def code_gen(ctx: mlir.ModuleContext, args_op: Sequence[ir.Value]
             ) -> Sequence[ir.Value]:
  # Embed the XLA computation for the traced TF function as an MHLO submodule
  # and call it, passing the captured TF inputs as constants. The free
  # variables (captured_inputs, xla_comp, function_flat_tf, result_shape,
  # result_shapes, result_avals) come from the enclosing lowering scope.
  captured_ops = tuple(
      mlir.ir_constant(np.asarray(inp), canonicalize_types=False)
      for inp in captured_inputs)
  submodule = mlir.xla_computation_to_mhlo_module(xla_comp)
  symtab = ir.SymbolTable(submodule.operation)
  callee_result_types = symtab["main"].type.results
  fn = mlir.merge_mhlo_modules(ctx.module,
                               f"call_tf_{function_flat_tf.name}",
                               submodule)
  call = func_dialect.CallOp(callee_result_types,
                             ir.FlatSymbolRefAttr.get(fn),
                             tuple(args_op) + captured_ops)
  if result_shape.is_tuple():
    flat_results = [mhlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
                    for i in range(len(result_shapes))]
  else:
    flat_results = call.results

  outputs = []
  for op, res_aval, res_shape in zip(flat_results, result_avals, result_shapes):
    # Convert results whose XLA dtype differs from the expected JAX aval dtype.
    if res_aval.dtype != res_shape.numpy_dtype():
      op = mhlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
    outputs.append(op)
  return outputs
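# --- User-level sketch of the entry point this code generator serves ---
# code_gen above splices the XLA computation of a traced TF function into the
# JAX MHLO module and calls it. A minimal, hedged example of the public
# jax2tf.call_tf API, assuming TensorFlow is installed:
import tensorflow as tf
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

def tf_sin(x):
  return tf.math.sin(x)

jax_sin = jax2tf.call_tf(tf_sin)  # callable from JAX; lowered via code_gen
print(jax.jit(jax_sin)(jnp.ones(3, jnp.float32)))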
def fallback(ctx: LoweringRuleContext, *args, **params):
  # XLA-fallback lowering rule: build prim's HLO subcomputation, convert it to
  # MLIR, merge it into the module as a private function, and call it.
  module_ctx = ctx.module_context
  xla_computation = xla.primitive_subcomputation(
      module_ctx.platform, module_ctx.axis_env, prim, *ctx.avals_in, **params)
  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
  submodule = ir.Module.parse(submodule_str)
  callee_name = None
  for op in submodule.body.operations:
    op = typing.cast(FuncOpType, op)
    module_ctx.module.body.append(op)
    if op.name.value == "main":
      op.attributes["sym_name"] = ir.StringAttr.get(f"xla_fallback_{prim.name}")
      callee_name = ir.StringAttr(module_ctx.symbol_table.insert(op)).value
      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
    else:
      module_ctx.symbol_table.insert(op)

  output_types = map(aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  output_type = (ir.TupleType.get_tuple(flat_output_types)
                 if prim.multiple_results else flat_output_types[0])

  call = func_dialect.CallOp([output_type],
                             ir.FlatSymbolRefAttr.get(callee_name),
                             flatten_lowering_ir_args(args)).result
  if not prim.multiple_results:
    return [call]
  if jax._src.lib.mlir_api_version < 6:
    flat_results = [
        mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
        for i, typ in enumerate(flat_output_types)]
  else:
    # Newer MLIR bindings infer the element type of GetTupleElementOp.
    flat_results = [
        mhlo.GetTupleElementOp(call, i32_attr(i)).result
        for i in range(len(flat_output_types))]
  return util.unflatten(flat_results, map(len, output_types))
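# --- Registration sketch (hedged; names below are illustrative only) ---
# `fallback` closes over `prim`, so in the surrounding code it is presumably
# produced by a per-primitive factory and installed as the MLIR lowering rule
# for primitives that only have an XLA translation. A minimal sketch of that
# shape, assuming jax.interpreters.mlir.register_lowering; `square_p` and the
# stub body are hypothetical:
from jax import core
from jax.interpreters import mlir as jax_mlir

square_p = core.Primitive("square")

def xla_fallback_lowering(prim: core.Primitive):
  def fallback(ctx, *args, **params):
    raise NotImplementedError  # the real body is the function above
  return fallback

jax_mlir.register_lowering(square_p, xla_fallback_lowering(square_p))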