Example #1
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, *args):
    """Lowers `call_jaxpr` to a function and emits a func.call invoking it."""
    xla.check_backend_matches(backend, ctx.platform)
    # Note: in this module `map` is JAX's list-returning safe_map, so
    # `output_types` can be traversed again for the final unflatten.
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    sub_ctx = ctx.replace(
        name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
    symbol_name = lower_jaxpr_to_fun(sub_ctx, fn_name,
                                     core.ClosedJaxpr(call_jaxpr, ())).name.value
    call = func_dialect.CallOp(flat_output_types,
                               ir.FlatSymbolRefAttr.get(symbol_name),
                               flatten_lowering_ir_args(args))
    # Regroup the flat call results into one sequence per output aval.
    return util.unflatten(call.results, map(len, output_types))
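
All four examples share the same result bookkeeping: one output aval may lower to several IR values, so the per-aval type tuples are flattened before constructing func_dialect.CallOp and the flat call results are regrouped afterwards. Below is a minimal, self-contained sketch of that round trip using plain Python stand-ins (not the real jax.util helpers):

def flatten(xss):
    # Like jax.util.flatten: concatenate a list of sequences.
    return [x for xs in xss for x in xs]

def unflatten(xs, lens):
    # Like jax.util.unflatten: regroup a flat list into chunks of given lengths.
    it = iter(xs)
    return [[next(it) for _ in range(n)] for n in lens]

output_types = [("f32",), ("i32", "i32")]      # one tuple of IR types per aval
flat = flatten(output_types)                    # what CallOp receives
results = [f"r{i}" for i in range(len(flat))]   # stand-in for call.results
assert unflatten(results, [len(t) for t in output_types]) == [["r0"], ["r1", "r2"]]
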
Example #2
    def fallback(ctx: LoweringRuleContext, *args, **params):
        # `prim` is the primitive being lowered; it is closed over from the
        # enclosing function that builds this fallback rule.
        module_ctx = ctx.module_context
        xla_computation = xla.primitive_subcomputation(module_ctx.platform,
                                                       module_ctx.axis_env,
                                                       prim, *ctx.avals_in,
                                                       **params)
        submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
            xla_computation)
        submodule = ir.Module.parse(submodule_str)
        callee_name = None
        # Splice every function from the fallback module into the current
        # module, renaming its "main" to a unique private symbol we can call.
        for op in submodule.body.operations:
            op = typing.cast(FuncOpType, op)
            module_ctx.module.body.append(op)
            if op.name.value == "main":
                op.attributes["sym_name"] = ir.StringAttr.get(
                    f"xla_fallback_{prim.name}")
                callee_name = ir.StringAttr(
                    module_ctx.symbol_table.insert(op)).value
                op.attributes["sym_visibility"] = ir.StringAttr.get("private")
            else:
                module_ctx.symbol_table.insert(op)

        output_types = map(aval_to_ir_types, ctx.avals_out)
        flat_output_types = util.flatten(output_types)
        # A primitive with multiple results returns them packed in one tuple.
        output_type = (ir.TupleType.get_tuple(flat_output_types)
                       if prim.multiple_results else flat_output_types[0])

        call = func_dialect.CallOp([output_type],
                                   ir.FlatSymbolRefAttr.get(callee_name),
                                   flatten_lowering_ir_args(args)).result
        if not prim.multiple_results:
            return [call]
        # Destructure the tuple result. Older MLIR Python bindings
        # (mlir_api_version < 6) require the element type to be passed
        # explicitly; newer ones infer it from the tuple operand.
        if jax._src.lib.mlir_api_version < 6:
            flat_results = [
                mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
                for i, typ in enumerate(flat_output_types)
            ]
        else:
            flat_results = [
                mhlo.GetTupleElementOp(call, i32_attr(i)).result
                for i in range(len(flat_output_types))
            ]

        return util.unflatten(flat_results, map(len, output_types))
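
Design note: this is the XLA fallback path. The primitive is first lowered through the legacy XLA HLO builder (xla.primitive_subcomputation), the resulting computation is converted to MHLO text, re-parsed, and spliced into the current module under a fresh private symbol. Multiple results come back as a single tuple value, which is why the code destructures them with mhlo.GetTupleElementOp; the branch on jax._src.lib.mlir_api_version only accounts for a change in that op's Python binding signature.
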
Example #3
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
    # We assume any extra leading in_nodes are constants and replicate them.
    num_extra_nodes = len(in_nodes) - len(in_parts)
    assert num_extra_nodes >= 0
    in_parts = (None,) * num_extra_nodes + in_parts

    args = []
    for ns, sharding in safe_zip(
            safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
        if sharding is not None:
            args.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            args.append(ns)

    sub_ctx = ctx.module_context.replace(
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
    fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                 core.ClosedJaxpr(call_jaxpr, ()))

    output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    call = func_dialect.CallOp(flat_output_types,
                               ir.FlatSymbolRefAttr.get(fn.name.value),
                               mlir.flatten_lowering_ir_args(args))
    out_nodes = util.unflatten(call.results, safe_map(len, output_types))

    # Mirror the input handling: wrap each output with its partition spec.
    out_parts = out_parts_thunk()
    outputs = []
    for ns, sharding in safe_zip(out_nodes, out_parts):
        if sharding is not None:
            outputs.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            outputs.append(ns)
    return outputs
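
The wrap-if-partitioned loop above appears twice, once for inputs and once for outputs. Here is a minimal stand-alone sketch of the pattern; wrap is a hypothetical stand-in for mlir.wrap_with_sharding_op combined with xla.sharding_to_proto:

def wrap(value, sharding):
    # Hypothetical stand-in for the real sharding-annotation helper.
    return f"sharded({value}, {sharding})"

def apply_partitions(groups, partitions):
    out = []
    for values, sharding in zip(groups, partitions):
        # None means "leave unpartitioned": pass the values through untouched.
        out.append([wrap(v, sharding) for v in values]
                   if sharding is not None else values)
    return out

print(apply_partitions([["a"], ["b", "c"]], [None, "P(2)"]))
# [['a'], ['sharded(b, P(2))', 'sharded(c, P(2))']]
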
Example #4
    def cached_lowering(ctx, *args, **params):
        # `f` is the lowering rule being wrapped; it is closed over from the
        # enclosing decorator-style function that defines this closure.
        assert ctx.primitive is not None
        key = (ctx.primitive, tuple(ctx.avals_in), tuple(ctx.avals_out),
               tuple(params.items()))
        try:
            func = ctx.module_context.cached_primitive_lowerings.get(key)
        except TypeError:
            # If the parameters aren't hashable, give up on caching.
            # TODO(phawkins): switch to requiring hashability, when XLA fallback
            # computations have been ported to MHLO.
            return f(ctx, *args, **params)
        if func is None:
            func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
            ctx.module_context.cached_primitive_lowerings[key] = func

        output_types = map(aval_to_ir_types, ctx.avals_out)
        flat_output_types = util.flatten(output_types)
        call = func_dialect.CallOp(flat_output_types,
                                   ir.FlatSymbolRefAttr.get(func.name.value),
                                   flatten_lowering_ir_args(args))
        return util.unflatten(call.results, map(len, output_types))
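
The caching trick here is the try/except around the dictionary lookup: Python raises TypeError when a dict key is unhashable, which cleanly routes unhashable params to the uncached path. A minimal sketch of the same pattern using only the standard library (make_cached and square are illustrative names, not JAX APIs):

def make_cached(f):
    cache = {}
    def wrapper(*args, **params):
        key = (args, tuple(params.items()))
        try:
            hit = cache.get(key)
        except TypeError:          # unhashable key: fall back to calling f
            return f(*args, **params)
        if hit is None:
            hit = cache[key] = f(*args, **params)
        return hit
    return wrapper

@make_cached
def square(x):
    return x * x

print(square(3), square(3))  # the second call is served from the cache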