Example #1
0
File: mlir.py Project: rsepassi/jax
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext, avals_in,
                          avals_out, *args, **params):
    """Lowers `prim` by building its XLA computation and calling it from MLIR.

    The computation returned by `xla.primitive_subcomputation` is converted to
    MLIR module text and parsed. Every top-level operation of the parsed module
    is appended to `ctx.module` and inserted into `ctx.symbol_table`; the
    operation named "main" is additionally marked private and becomes the
    callee of a `std.CallOp` emitted with the flattened `args`.

    Args:
      prim: The primitive to lower; `prim.multiple_results` selects between a
        tuple-typed and a single-typed call result.
      ctx: Lowering context; `platform`, `axis_env`, `module` and
        `symbol_table` are read here.
      avals_in: Input avals, forwarded to `xla.primitive_subcomputation`.
      avals_out: Output avals, converted to IR types with `aval_to_ir_types`.
      *args: Operand IR values, flattened with `flatten_lowering_ir_args`.
      **params: Primitive parameters, forwarded to
        `xla.primitive_subcomputation`.

    Returns:
      `[result]` when `prim.multiple_results` is false; otherwise the tuple
      elements of the call result regrouped per entry of `avals_out`.

    Raises:
      ValueError: If the parsed module contains no operation named "main".
    """
    xla_computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                   prim, *avals_in, **params)
    submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
        xla_computation)
    submodule = ir.Module.parse(submodule_str)
    callee_name = None
    for op in submodule.body.operations:
        ctx.module.body.append(op)
        if op.name.value == "main":
            # Use the name handed back by the symbol table rather than "main"
            # itself (presumably insertion may rename on a clash -- confirm).
            callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
            op.attributes["sym_visibility"] = ir.StringAttr.get("private")
        else:
            ctx.symbol_table.insert(op)
    if callee_name is None:
        # Without an entry point there is nothing to call; fail here with a
        # clear message instead of passing None as the callee symbol below.
        raise ValueError(
            f"XLA fallback lowering of {prim} produced a module without a "
            "'main' function")

    # Built as a real list: it is traversed twice below (once to flatten, once
    # to recover the per-output group sizes).
    output_types = [aval_to_ir_types(aval) for aval in avals_out]
    flat_output_types = util.flatten(output_types)
    output_type = (ir.TupleType.get_tuple(flat_output_types)
                   if prim.multiple_results else flat_output_types[0])

    call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
                      flatten_lowering_ir_args(args)).result
    if not prim.multiple_results:
        return [call]
    # The callee returns one tuple; unpack it element by element.
    flat_results = [
        mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
        for i, typ in enumerate(flat_output_types)
    ]
    return util.unflatten(flat_results, [len(types) for types in output_types])
Example #2
0
File: mlir.py Project: rsepassi/jax
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, *args):
    """Lowers a call to `call_jaxpr` as an out-of-line function plus a call op.

    The jaxpr is wrapped in a `core.ClosedJaxpr` with no constants and lowered
    with `lower_jaxpr_to_fun` under a name stack extended by `stack_name`; a
    `std.CallOp` targeting the returned symbol is then emitted with the
    flattened `args`.

    Args:
      fn_name: Name passed to `lower_jaxpr_to_fun` for the emitted function.
      stack_name: Component appended to `ctx.name_stack` for the sub-context.
      call_jaxpr: The jaxpr to lower as the callee.
      backend: Checked against `ctx.platform` by `xla.check_backend_matches`.
      ctx: Lowering context; `platform` and `name_stack` are read and
        `replace` is used to derive the sub-context.
      avals_in: Unused by this function.
      avals_out: Output avals, converted to IR types with `aval_to_ir_types`.
      *args: Operand IR values, flattened with `flatten_lowering_ir_args`.

    Returns:
      The call results regrouped so that each entry corresponds to one entry
      of `avals_out`.
    """
    xla.check_backend_matches(backend, ctx.platform)
    # Built as a real list: it is traversed twice below (once to flatten, once
    # to recover the per-output group sizes).
    output_types = [aval_to_ir_types(aval) for aval in avals_out]
    flat_output_types = util.flatten(output_types)
    sub_ctx = ctx.replace(
        name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
    symbol_name = lower_jaxpr_to_fun(sub_ctx, fn_name,
                                     core.ClosedJaxpr(call_jaxpr, ()))
    call = std.CallOp(flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name),
                      flatten_lowering_ir_args(args))
    return util.unflatten(call.results,
                          [len(types) for types in output_types])
Example #3
0
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
    """Lowers a sharded_jit call to a function call with sharding wrappers.

    Inputs whose partitioning is not None are wrapped with
    `mlir.wrap_with_sharding_op`, `call_jaxpr` is lowered to a function named
    `sharded_jit_{name}`, a `std.CallOp` to it is emitted, and outputs whose
    partitioning (from `out_parts_thunk()`) is not None are wrapped likewise.

    `nparts`, `local_in_parts`, `local_out_parts_thunk` and `local_nparts` are
    accepted but not used by this function.

    Returns:
      One list of IR values per output, grouped like `ctx.avals_out`.
    """

    def _wrap_sharded(values, sharding):
        # Entries without a sharding are passed through as the same object.
        if sharding is None:
            return values
        return [
            mlir.wrap_with_sharding_op(v, xla.sharding_to_proto(sharding))
            for v in values
        ]

    # We assume any extra leading in_nodes are constants and replicate them.
    extra_leading = len(in_nodes) - len(in_parts)
    assert extra_leading >= 0
    padded_in_parts = (None, ) * extra_leading + in_parts

    wrapped_inputs = safe_map(mlir.wrap_singleton_ir_values, in_nodes)
    args = [
        _wrap_sharded(values, sharding)
        for values, sharding in safe_zip(wrapped_inputs, padded_in_parts)
    ]

    stack = extend_name_stack(wrap_name(name, "sharded_jit"))
    sub_ctx = ctx.module_context.replace(name_stack=stack)
    closed_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
    fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}", closed_jaxpr)

    output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    callee = ir.FlatSymbolRefAttr.get(fn.name.value)
    call = std.CallOp(flat_output_types, callee,
                      mlir.flatten_lowering_ir_args(args))
    group_sizes = safe_map(len, output_types)
    out_nodes = util.unflatten(call.results, group_sizes)

    out_parts = out_parts_thunk()
    return [
        _wrap_sharded(values, sharding)
        for values, sharding in safe_zip(out_nodes, out_parts)
    ]
Example #4
0
    def cached_lowering(ctx, *args, **params):
        """Calls a per-module cached function emitted from lowering rule `f`.

        The cache in `ctx.module_context.cached_primitive_lowerings` is keyed
        on the primitive, the input and output avals, and the parameter
        items. On a miss the rule is emitted once as a function with
        `_emit_lowering_rule_as_fun` and stored; a `std.CallOp` to the cached
        function is then emitted with the flattened `args`. If the key cannot
        be hashed, `f` is invoked directly and no caching takes place.

        Returns:
          The call results regrouped so that each entry corresponds to one
          entry of `ctx.avals_out`, or whatever `f` returns on the uncached
          path.
        """
        assert ctx.primitive is not None
        key = (ctx.primitive, tuple(ctx.avals_in), tuple(ctx.avals_out),
               tuple(params.items()))
        try:
            func = ctx.module_context.cached_primitive_lowerings.get(key)
        except TypeError:
            # If the parameters aren't hashable, give up on caching.
            # TODO(phawkins): switch to requiring hashability, when XLA fallback
            # computations have been ported to MHLO.
            return f(ctx, *args, **params)
        if func is None:
            func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
            ctx.module_context.cached_primitive_lowerings[key] = func

        # Built as a real list: it is traversed twice below (once to flatten,
        # once to recover the per-output group sizes).
        output_types = [aval_to_ir_types(aval) for aval in ctx.avals_out]
        flat_output_types = util.flatten(output_types)
        call = std.CallOp(flat_output_types,
                          ir.FlatSymbolRefAttr.get(func.name.value),
                          flatten_lowering_ir_args(args))
        return util.unflatten(call.results,
                              [len(types) for types in output_types])