def pack(a, b):
  """Pack two word-sized tensors into a single double-word tensor.

  Each input is bitcast to ``word_type`` and widened with ``ConvertOp`` to
  ``double_word_type``. ``a`` is then shifted left by ``nbits`` and OR-ed
  with ``b``. If ``nbits`` equals the bit width of ``word_type`` (presumably,
  TODO confirm), ``a`` lands in the high half and ``b`` in the low half.

  NOTE(review): ``b``'s bits only stay isolated in the low half if the
  widening ``ConvertOp`` zero-extends. That suggests ``word_type`` is
  unsigned; confirm against the enclosing definition.

  Args:
    a: ``ir.Value`` of ranked tensor type; becomes the shifted (high) operand.
    b: ``ir.Value`` of ranked tensor type; OR-ed in without shifting.

  Returns:
    The ``mhlo.OrOp`` operation itself (not its ``.result``) that produces the
    packed tensor.
  """
  operands = (a, b)
  shapes = [ir.RankedTensorType(v.type).shape for v in operands]
  # Reinterpret the raw bits as the word type first, then widen both inputs.
  # Ops are created in the same order as before: bitcast a, bitcast b,
  # convert a, convert b.
  words = [mhlo.BitcastConvertOp(ir.RankedTensorType.get(s, word_type), v)
           for v, s in zip(operands, shapes)]
  widened = [mhlo.ConvertOp(ir.RankedTensorType.get(s, double_word_type), w)
             for w, s in zip(words, shapes)]
  # MHLO shifts take a same-shaped amount operand, so broadcast the constant.
  shift_amount = _broadcast(const(double_word_dtype, nbits), shapes[0])
  high = mhlo.ShiftLeftOp(widened[0], shift_amount)
  return mhlo.OrOp(high, widened[1])
def fst(t):
  """Extract the high (shifted) half of a packed double-word tensor.

  Logically shifts ``t`` right by ``nbits``, narrows the result to
  ``word_type`` with ``ConvertOp``, then bitcasts it to ``etype``.

  Args:
    t: ``ir.Value`` of ranked tensor type whose element type is presumably
      ``double_word_type`` (e.g. a value built by ``pack``; TODO confirm).

  Returns:
    ``ir.Value`` with the same shape as ``t`` and element type ``etype``.
  """
  dims = ir.RankedTensorType(t.type).shape
  shift_amount = const(double_word_dtype, nbits)
  if dims:
    # MHLO elementwise ops do not broadcast implicitly. A rank-0 shift amount
    # only matches a rank-0 ``t``, so broadcast it to ``t``'s shape the same
    # way ``pack`` does. Scalar inputs take the original broadcast-free path,
    # so their emitted IR is unchanged.
    shift_amount = _broadcast(shift_amount, dims)
  st = mhlo.ShiftRightLogicalOp(t, shift_amount)
  return mhlo.BitcastConvertOp(
      ir.RankedTensorType.get(dims, etype),
      mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result
def code_gen(ctx: mlir.ModuleContext, args_op: Sequence[ir.Value]
             ) -> Sequence[ir.Value]:
  """Lower a compiled XLA computation as a call into a merged MHLO function.

  The XLA computation ``xla_comp`` is converted to an MHLO module and merged
  into ``ctx.module`` under the name ``call_tf_<function name>``. It is then
  invoked with ``args_op`` followed by the captured inputs, which are
  materialized as constants.

  Free names (``captured_inputs``, ``xla_comp``, ``function_flat_tf``,
  ``result_shape``, ``result_shapes``, ``result_avals``) are presumably bound
  in an enclosing scope not visible here.

  Args:
    ctx: module context whose ``module`` receives the merged submodule.
    args_op: operand values passed first to the call.

  Returns:
    One value per entry of ``result_avals``. A value is converted to the
    aval's dtype when it differs from the XLA result shape's dtype.
  """
  # Captured inputs become trailing constant operands of the call.
  captured_ops = tuple(
      mlir.ir_constant(np.asarray(captured), canonicalize_types=False)
      for captured in captured_inputs)
  submodule = mlir.xla_computation_to_mhlo_module(xla_comp)
  # Read the callee's result types before merging, which moves/renames "main".
  main_func = ir.SymbolTable(submodule.operation)["main"]
  callee_result_types = main_func.type.results
  fn = mlir.merge_mhlo_modules(ctx.module,
                               f"call_tf_{function_flat_tf.name}",
                               submodule)
  operands = tuple(args_op) + captured_ops
  call = func_dialect.CallOp(callee_result_types,
                             ir.FlatSymbolRefAttr.get(fn),
                             operands)
  if not result_shape.is_tuple():
    flat_results = call.results
  else:
    # A tuple-shaped result is unpacked element by element.
    flat_results = [mhlo.GetTupleElementOp(call, mlir.i32_attr(idx)).result
                    for idx in range(len(result_shapes))]
  return [
      mhlo.ConvertOp(mlir.aval_to_ir_type(aval), value).result
      if aval.dtype != xla_shape.numpy_dtype() else value
      for value, aval, xla_shape in zip(flat_results, result_avals,
                                        result_shapes)
  ]
def convert_mhlo(x, aval_in, aval_out):
  """Convert ``x`` from ``aval_in``'s dtype to ``aval_out``'s with HLO semantics.

  A conversion to bool is lowered as the comparison ``x != 0`` rather than an
  integer truncation (b/209440332). The comparison type is chosen from the
  input dtype: FLOAT for inexact, SIGNED for signed integer, UNSIGNED
  otherwise. Every other target dtype goes through ``mhlo.ConvertOp``.

  Returns:
    The resulting ``ir.Value``.
  """
  if aval_out.dtype != np.dtype(np.bool_):
    return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result
  source_dtype = aval_in.dtype
  if dtypes.issubdtype(source_dtype, np.inexact):
    comparison = "FLOAT"
  elif dtypes.issubdtype(source_dtype, np.signedinteger):
    comparison = "SIGNED"
  else:
    comparison = "UNSIGNED"
  zero = full_like_aval(0, aval_in)
  return compare_mhlo(x, zero, "NE", comparison).result
def snd(t):
  """Extract the low (unshifted) half of a packed double-word tensor.

  Narrows ``t`` to ``word_type`` with ``ConvertOp`` and bitcasts the result to
  ``etype``. This relies on the narrowing conversion keeping the low-order
  bits, which is presumably XLA's truncating behaviour; TODO confirm.

  Args:
    t: ``ir.Value`` of ranked tensor type, presumably ``double_word_type``
      elements (e.g. a value built by ``pack``).

  Returns:
    ``ir.Value`` with the same shape as ``t`` and element type ``etype``.
  """
  shape = ir.RankedTensorType(t.type).shape
  narrowed = mhlo.ConvertOp(ir.RankedTensorType.get(shape, word_type), t)
  reinterpreted = mhlo.BitcastConvertOp(
      ir.RankedTensorType.get(shape, etype), narrowed)
  return reinterpreted.result