Example #1
File: mlir.py Project: rsepassi/jax
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
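    # Lower the remat call as an mhlo.WhileOp that runs its body exactly
    # once: the condition compares a counter (starting at 0) against
    # rng_uniform(1, 2), which always yields 1, but XLA cannot prove that
    # statically, so the rematerialized computation is not optimized away.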
    input_types = map(aval_to_ir_types, avals_in)
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    int32_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.int32)))
    loop_carry_types = [(int32_scalar_type, )] + input_types + output_types
    flat_loop_carry_types = util.flatten(loop_carry_types)
    counter_init = ir_constants(np.array(0, np.int32))
    flat_args = flatten_lowering_ir_args((counter_init, ) + args + tuple(
        _dummy_like_aval(aval) for aval in avals_out))
    loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
    init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

    one = ir_constant(np.array(1, np.int32))
    while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

    # Loop condition
    cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(cond_block):
        bool_scalar_type = aval_to_ir_type(
            core.ShapedArray((), np.dtype(np.bool_)))
        two = ir_constant(np.array(2, np.int32))
        shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
        rng = mhlo.RngUniformOp(one, two, shape).result
        i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                                   i32_attr(0))
        cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                             ir.StringAttr.get("SIGNED")).result
        mhlo.ReturnOp([cmp])

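    # Loop body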
    body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(body_block):
        flat_body_args = [
            mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                                   i32_attr(i)).result
            for i, input_type in enumerate(flat_loop_carry_types)
        ]
        body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
        ((i, ), ), y, _ = util.split_list(body_args, [1, len(avals_in)])
        body_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'remat')))
        z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
        i_next = mhlo.AddOp(i, one).result
        new_carry = mhlo.TupleOp(loop_carry_tuple_type,
                                 [i_next, *util.flatten(y), *util.flatten(z)])
        mhlo.ReturnOp([new_carry.result])

    outputs = [
        mhlo.GetTupleElementOp(output_type, while_op.result,
                               i32_attr(1 + len(avals_in) + i)).result
        for i, output_type in enumerate(flat_output_types)
    ]
    return util.unflatten(outputs, map(len, output_types))
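The carry bookkeeping above rests on three small helpers from jax.util. Below is a minimal sketch of their semantics, re-implemented purely for illustration (the real helpers also validate their inputs):

def flatten(xss):
    # Concatenate a list of lists into one flat list.
    return [x for xs in xss for x in xs]

def unflatten(xs, lens):
    # Inverse of flatten: split xs back into pieces of the given lengths.
    out, i = [], 0
    for n in lens:
        out.append(list(xs[i:i + n]))
        i += n
    assert i == len(xs), "lens must account for every element of xs"
    return out

def split_list(xs, ns):
    # Split xs at the cumulative offsets in ns; whatever remains becomes a
    # final piece, so split_list(xs, [1, k]) yields three lists.
    out, i = [], 0
    for n in ns:
        out.append(xs[i:i + n])
        i += n
    out.append(xs[i:])
    return out

Read with these in hand, the body block's destructuring ((i, ), ), y, _ pulls out the one-element counter component, the lowered inputs, and the (ignored) dummy output slots.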
Example #2
File: mlir.py Project: rsepassi/jax
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext, avals_in,
                          avals_out, *args, **params):
    xla_computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                   prim, *avals_in, **params)
    submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
        xla_computation)
    submodule = ir.Module.parse(submodule_str)
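    # Splice the submodule's functions into the outer module. Inserting an
    # op into the symbol table uniquifies its name; the callee ("main") is
    # additionally marked private.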
    callee_name = None
    for op in submodule.body.operations:
        ctx.module.body.append(op)
        if op.name.value == "main":
            callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
            op.attributes["sym_visibility"] = ir.StringAttr.get("private")
        else:
            ctx.symbol_table.insert(op)

    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    output_type = (ir.TupleType.get_tuple(flat_output_types)
                   if prim.multiple_results else flat_output_types[0])

    call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
                      flatten_lowering_ir_args(args)).result
    if not prim.multiple_results:
        return [call]
    flat_results = [
        mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
        for i, typ in enumerate(flat_output_types)
    ]
    return util.unflatten(flat_results, map(len, output_types))
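Because this rule takes the primitive as its leading argument, registering it means partially applying it first. A hypothetical wiring, assuming the register_lowering hook from the same module and a placeholder primitive my_prim:

from functools import partial

# my_prim stands in for a primitive that has no native MHLO lowering;
# register_lowering is the lowering-rule registry in jax.interpreters.mlir.
register_lowering(my_prim, partial(xla_fallback_lowering, my_prim))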
Example #3
  def code_gen(ctx: mlir.ModuleContext, args_op: Sequence[ir.Value]
              ) -> Sequence[ir.Value]:
    captured_ops = tuple(mlir.ir_constant(np.asarray(inp),
                                          canonicalize_types=False)
                         for inp in captured_inputs)
    submodule = mlir.xla_computation_to_mhlo_module(xla_comp)
    symtab = ir.SymbolTable(submodule.operation)
    callee_result_types = symtab["main"].type.results
    fn = mlir.merge_mhlo_modules(ctx.module, f"call_tf_{function_flat_tf.name}",
                                 submodule)
    call = func_dialect.CallOp(callee_result_types,
                               ir.FlatSymbolRefAttr.get(fn),
                               tuple(args_op) + captured_ops)
    if result_shape.is_tuple():
      flat_results = [mhlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
                      for i in range(len(result_shapes))]
    else:
      flat_results = call.results

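    # Cast any result whose lowered dtype differs from the expected JAX
    # aval's dtype.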
    outputs = []
    for op, res_aval, res_shape in zip(flat_results, result_avals,
                                       result_shapes):
      if res_aval.dtype != res_shape.numpy_dtype():
        op = mhlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
      outputs.append(op)
    return outputs
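The tuple-vs-flat result handling here mirrors the fallback rules above; factored into a helper, it might look like this sketch (untuple_results is hypothetical, not a JAX API):

def untuple_results(call, n, is_tuple):
    # Older-style calls return one tuple-typed value whose elements must be
    # projected out with GetTupleElementOp; newer calls return flat results.
    if is_tuple:
        return [mhlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
                for i in range(n)]
    return list(call.results)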
Example #4
    def fallback(ctx: LoweringRuleContext, *args, **params):
        module_ctx = ctx.module_context
        xla_computation = xla.primitive_subcomputation(module_ctx.platform,
                                                       module_ctx.axis_env,
                                                       prim, *ctx.avals_in,
                                                       **params)
        submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
            xla_computation)
        submodule = ir.Module.parse(submodule_str)
        callee_name = None
        for op in submodule.body.operations:
            op = typing.cast(FuncOpType, op)
            module_ctx.module.body.append(op)
            if op.name.value == "main":
                op.attributes["sym_name"] = ir.StringAttr.get(
                    f"xla_fallback_{prim.name}")
                callee_name = ir.StringAttr(
                    module_ctx.symbol_table.insert(op)).value
                op.attributes["sym_visibility"] = ir.StringAttr.get("private")
            else:
                module_ctx.symbol_table.insert(op)

        output_types = map(aval_to_ir_types, ctx.avals_out)
        flat_output_types = util.flatten(output_types)
        output_type = (ir.TupleType.get_tuple(flat_output_types)
                       if prim.multiple_results else flat_output_types[0])

        call = func_dialect.CallOp([output_type],
                                   ir.FlatSymbolRefAttr.get(callee_name),
                                   flatten_lowering_ir_args(args)).result
        if not prim.multiple_results:
            return [call]
        if jax._src.lib.mlir_api_version < 6:
            flat_results = [
                mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
                for i, typ in enumerate(flat_output_types)
            ]
        else:
            flat_results = [
                mhlo.GetTupleElementOp(call, i32_attr(i)).result
                for i in range(len(flat_output_types))
            ]

        return util.unflatten(flat_results, map(len, output_types))
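Unlike Example #2, this variant is a closure: an enclosing xla_fallback_lowering(prim) returns fallback, so no partial application is needed at registration time. A sketch of the surrounding structure, with the body elided and my_prim again a placeholder:

def xla_fallback_lowering(prim: core.Primitive):
    def fallback(ctx: LoweringRuleContext, *args, **params):
        ...  # body as shown above
    return fallback

register_lowering(my_prim, xla_fallback_lowering(my_prim))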