Ejemplo n.º 1
0
 def pack(a, b):
   """Packs `a` and `b` into one double-width integer tensor.

   Both operands are bit-cast to `word_type`, widened with `ConvertOp` to
   `double_word_type`, and combined as `(a << nbits) | b`, so `a` occupies
   the upper bits and `b` the lower bits.

   NOTE(review): this presumably relies on `word_type` being unsigned (so
   widening zero-extends `b`) and on `nbits` being the bit width of
   `word_type` — confirm against the enclosing function.

   Returns:
     The `mhlo.OrOp` holding the packed values (the op itself, not `.result`).
   """
   shapes = [ir.RankedTensorType(v.type).shape for v in (a, b)]
   # Stage 1: reinterpret each operand's bits as `word_type`. Both bitcasts
   # are emitted before either widening conversion.
   words = [
       mhlo.BitcastConvertOp(ir.RankedTensorType.get(shape, word_type), v)
       for v, shape in zip((a, b), shapes)
   ]
   # Stage 2: widen each word to `double_word_type`.
   high, low = [
       mhlo.ConvertOp(ir.RankedTensorType.get(shape, double_word_type), w)
       for w, shape in zip(words, shapes)
   ]
   # Move `a` into the upper half; the shift amount is broadcast to a's shape
   # because MHLO elementwise ops take operands of matching shape.
   shift_amount = _broadcast(const(double_word_dtype, nbits), shapes[0])
   high = mhlo.ShiftLeftOp(high, shift_amount)
   return mhlo.OrOp(high, low)
Ejemplo n.º 2
0
 def fst(t):
     """Unpacks the upper half of a packed double word as `etype`.

     Logically shifts `t` right by `nbits` and narrows the result to
     `word_type` with `ConvertOp`. It then bit-casts that to `etype`,
     keeping the shape of `t`.

     NOTE(review): the shift amount is passed as `const(double_word_dtype,
     nbits)` without broadcasting to the shape of `t`. This presumably relies
     on `t` being rank-0 (or on `const` yielding a matching shape). Confirm
     this with the caller.

     Returns:
       The `ir.Value` result of the final `BitcastConvertOp`.
     """
     shape = ir.RankedTensorType(t.type).shape
     upper = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
     narrowed = mhlo.ConvertOp(ir.RankedTensorType.get(shape, word_type), upper)
     out_type = ir.RankedTensorType.get(shape, etype)
     return mhlo.BitcastConvertOp(out_type, narrowed).result
Ejemplo n.º 3
0
  def code_gen(ctx: mlir.ModuleContext, args_op: Sequence[ir.Value]
              ) -> Sequence[ir.Value]:
    """Emits a call to the compiled XLA computation inside `ctx.module`.

    Steps:
      1. Each entry of `captured_inputs` becomes an IR constant (without
         dtype canonicalization).
      2. `xla_comp` is converted to an MHLO module, whose `main` result types
         are read.
      3. That module is merged into `ctx.module` under the name
         `call_tf_<function_flat_tf.name>`.
      4. A `func.call` is emitted with `args_op` followed by the captured
         constants.

    Tuple results are unpacked with `GetTupleElementOp`. Each result whose
    dtype differs from the corresponding `result_avals` dtype is converted
    to that aval's IR type.

    Returns:
      A list of result values, one per entry of `result_shapes` (zip
      truncates to the shortest of results, avals and shapes).
    """
    # Captured inputs are passed as trailing call arguments.
    captured_ops = tuple(
        mlir.ir_constant(np.asarray(value), canonicalize_types=False)
        for value in captured_inputs)
    callee_module = mlir.xla_computation_to_mhlo_module(xla_comp)
    # Read the callee's result types before merging it into ctx.module.
    main_fn = ir.SymbolTable(callee_module.operation)["main"]
    result_types = main_fn.type.results
    callee_name = mlir.merge_mhlo_modules(
        ctx.module, f"call_tf_{function_flat_tf.name}", callee_module)
    call = func_dialect.CallOp(result_types,
                               ir.FlatSymbolRefAttr.get(callee_name),
                               tuple(args_op) + captured_ops)

    if not result_shape.is_tuple():
      flat_results = call.results
    else:
      flat_results = [
          mhlo.GetTupleElementOp(call, mlir.i32_attr(index)).result
          for index, _ in enumerate(result_shapes)
      ]

    def _match_aval_dtype(value, aval, shape):
      # Convert only when the XLA result dtype differs from the JAX aval.
      if aval.dtype == shape.numpy_dtype():
        return value
      return mhlo.ConvertOp(mlir.aval_to_ir_type(aval), value).result

    return [
        _match_aval_dtype(value, aval, shape)
        for value, aval, shape in zip(flat_results, result_avals,
                                      result_shapes)
    ]
Ejemplo n.º 4
0
def convert_mhlo(x, aval_in, aval_out):
    """Variant of convert that has XLA HLO semantics.

    Converts `x` from `aval_in` to `aval_out`. A conversion to boolean is
    lowered as the comparison `x != 0` rather than truncating integer values
    (b/209440332); every other target dtype uses `mhlo.ConvertOp`.

    Args:
      x: the value to convert.
      aval_in: abstract value describing `x`; its dtype selects the
        comparison type for boolean targets.
      aval_out: abstract value describing the desired result.

    Returns:
      The `.result` of the emitted comparison or conversion op.
    """
    if aval_out.dtype != np.dtype(np.bool_):
        return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result

    # Pick the comparison flavour that matches the input's dtype family.
    in_dtype = aval_in.dtype
    if dtypes.issubdtype(in_dtype, np.inexact):
        comparison_kind = "FLOAT"
    elif dtypes.issubdtype(in_dtype, np.signedinteger):
        comparison_kind = "SIGNED"
    else:
        comparison_kind = "UNSIGNED"
    zero = full_like_aval(0, aval_in)
    return compare_mhlo(x, zero, "NE", comparison_kind).result
Ejemplo n.º 5
0
 def snd(t):
     """Unpacks the lower half of a packed double word as `etype`.

     Narrows `t` to `word_type` with `ConvertOp` and bit-casts that to
     `etype`, keeping the shape of `t`.

     NOTE(review): this presumably relies on the narrowing `ConvertOp`
     keeping the low-order bits of each element. Confirm against the HLO
     convert semantics for the dtypes in use.

     Returns:
       The `ir.Value` result of the final `BitcastConvertOp`.
     """
     shape = ir.RankedTensorType(t.type).shape
     lower = mhlo.ConvertOp(ir.RankedTensorType.get(shape, word_type), t)
     out_type = ir.RankedTensorType.get(shape, etype)
     return mhlo.BitcastConvertOp(out_type, lower).result