Example #1
def compile_or_get_cached(backend, computation, compile_options):
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc

  if isinstance(computation, ir.Module):
    sym_name = computation.operation.attributes['sym_name']
    module_name = ir.StringAttr(sym_name).value
    computation = mlir.module_to_string(computation)
  else:
    module_name = computation.name()

  # Persistent compilation cache only implemented on TPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  if cc.is_initialized() and backend.platform == 'tpu':
    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit for %s.', module_name)
      return cached_executable
    else:
      compiled = backend_compile(backend, computation, compile_options)
      cc.put_executable(module_name, computation, compile_options, compiled,
                        backend)
      return compiled

  if FLAGS.jax_dump_ir_to:
    ir_str = (computation if isinstance(computation, str)
              else computation.as_hlo_text())
    _dump_ir_to_file(module_name, ir_str)
  return backend_compile(backend, computation, compile_options)
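The example above follows a get-or-compile-and-put pattern against the persistent cache. Below is a minimal, self-contained sketch of that pattern using a plain dict in place of the cache; make_key and expensive_compile are hypothetical stand-ins, not the jax.experimental.compilation_cache API.

# Illustrative only: a dict-backed stand-in for the persistent cache.
_cache = {}

def make_key(computation, options):
    # Hypothetical cache key derived from the program and its options.
    return (computation, tuple(sorted(options.items())))

def expensive_compile(computation, options):
    # Hypothetical stand-in for backend_compile.
    return f"executable<{computation}:{options.get('opt', 0)}>"

def compile_or_get_cached_sketch(computation, options):
    key = make_key(computation, options)
    cached = _cache.get(key)           # analogous to cc.get_executable(...)
    if cached is not None:
        return cached                  # cache hit: skip compilation
    compiled = expensive_compile(computation, options)
    _cache[key] = compiled             # analogous to cc.put_executable(...)
    return compiled

print(compile_or_get_cached_sketch("add", {"opt": 2}))  # compiles
print(compile_or_get_cached_sketch("add", {"opt": 2}))  # cache hit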
Example #2
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext, avals_in,
                          avals_out, *args, **params):
    xla_computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                   prim, *avals_in, **params)
    submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
        xla_computation)
    submodule = ir.Module.parse(submodule_str)
    callee_name = None
    # Merge the lowered submodule into the enclosing module. The entry point
    # ("main") is renamed by the symbol table and made private so the call
    # below can reference it; all other functions are inserted as-is.
    for op in submodule.body.operations:
        ctx.module.body.append(op)
        if op.name.value == "main":
            callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
            op.attributes["sym_visibility"] = ir.StringAttr.get("private")
        else:
            ctx.symbol_table.insert(op)

    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    output_type = (ir.TupleType.get_tuple(flat_output_types)
                   if prim.multiple_results else flat_output_types[0])

    call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
                      flatten_lowering_ir_args(args)).result
    if not prim.multiple_results:
        return [call]
    flat_results = [
        mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
        for i, typ in enumerate(flat_output_types)
    ]
    return util.unflatten(flat_results, map(len, output_types))
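The flatten/unflatten bookkeeping at the end is what maps the per-output lists of IR types to a flat list of results and back. A self-contained sketch of that pairing (illustrative reimplementations, not JAX's util module):

import itertools

def flatten(xss):
    # Concatenate a list of lists into one flat list.
    return list(itertools.chain.from_iterable(xss))

def unflatten(xs, lens):
    # Regroup a flat list into sublists of the given lengths.
    it = iter(xs)
    groups = [[next(it) for _ in range(n)] for n in lens]
    assert next(it, None) is None, "lengths must consume all elements"
    return groups

output_types = [["f32"], ["f32", "i32"]]     # two outputs, second is a pair
flat = flatten(output_types)                 # ["f32", "f32", "i32"]
results = ["r0", "r1", "r2"]                 # flat results from the call op
print(unflatten(results, [len(t) for t in output_types]))
# [['r0'], ['r1', 'r2']]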
Example #3
    def fallback(ctx: LoweringRuleContext, *args, **params):
        module_ctx = ctx.module_context
        xla_computation = xla.primitive_subcomputation(module_ctx.platform,
                                                       module_ctx.axis_env,
                                                       prim, *ctx.avals_in,
                                                       **params)
        submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
            xla_computation)
        submodule = ir.Module.parse(submodule_str)
        callee_name = None
        for op in submodule.body.operations:
            op = typing.cast(FuncOpType, op)
            module_ctx.module.body.append(op)
            if op.name.value == "main":
                op.attributes["sym_name"] = ir.StringAttr.get(
                    f"xla_fallback_{prim.name}")
                callee_name = ir.StringAttr(
                    module_ctx.symbol_table.insert(op)).value
                op.attributes["sym_visibility"] = ir.StringAttr.get("private")
            else:
                module_ctx.symbol_table.insert(op)

        output_types = map(aval_to_ir_types, ctx.avals_out)
        flat_output_types = util.flatten(output_types)
        output_type = (ir.TupleType.get_tuple(flat_output_types)
                       if prim.multiple_results else flat_output_types[0])

        call = func_dialect.CallOp([output_type],
                                   ir.FlatSymbolRefAttr.get(callee_name),
                                   flatten_lowering_ir_args(args)).result
        if not prim.multiple_results:
            return [call]
        # Older MLIR Python bindings require the result type to be passed
        # explicitly to GetTupleElementOp; newer ones infer it.
        if jax._src.lib.mlir_api_version < 6:
            flat_results = [
                mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
                for i, typ in enumerate(flat_output_types)
            ]
        else:
            flat_results = [
                mhlo.GetTupleElementOp(call, i32_attr(i)).result
                for i in range(len(flat_output_types))
            ]

        return util.unflatten(flat_results, map(len, output_types))
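Both fallback lowerings rely on the symbol table to hand back a uniquified name when the submodule's "main" is inserted into the parent module; the call op then targets that name. A hypothetical, pure-Python sketch of that renaming behavior (the real ir.SymbolTable does this inside MLIR):

class ToySymbolTable:
    # Hypothetical stand-in for ir.SymbolTable: insert() returns a name that
    # is unique within the module, renaming on collision.
    def __init__(self):
        self.names = set()

    def insert(self, name):
        candidate, i = name, 0
        while candidate in self.names:
            i += 1
            candidate = f"{name}_{i}"
        self.names.add(candidate)
        return candidate

table = ToySymbolTable()
print(table.insert("xla_fallback_sin"))   # xla_fallback_sin
print(table.insert("xla_fallback_sin"))   # xla_fallback_sin_1  <- callee_name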
Example #4
def compile_or_get_cached(backend, computation, compile_options):
    # Avoid import cycle between jax and jax.experimental
    from jax.experimental.compilation_cache import compilation_cache as cc

    if isinstance(computation, ir.Module):
        sym_name = computation.operation.attributes['sym_name']
        module_name = ir.StringAttr(sym_name).value
        # Convert the ir.Module to its string representation (the default),
        # unless the back end explicitly flags that it can handle a module
        # directly (avoiding the overhead of back-and-forth conversions).
        if getattr(backend, "needs_str_ir", True):
            computation = mlir.module_to_string(computation)
    else:
        module_name = computation.name()

    # Persistent compilation cache only implemented on TPU.
    # TODO(skye): add warning when initializing cache on unsupported default platform
    if cc.is_initialized() and backend.platform == 'tpu':
        cached_executable = cc.get_executable(computation, compile_options,
                                              backend)
        if cached_executable is not None:
            logging.info('Persistent compilation cache hit for %s.',
                         module_name)
            return cached_executable
        else:
            compiled = backend_compile(backend, computation, compile_options)
            cc.put_executable(module_name, computation, compile_options,
                              compiled, backend)
            return compiled

    if FLAGS.jax_dump_ir_to:
        if isinstance(computation, xc.XlaComputation):
            ir_str = computation.as_hlo_text()
        elif isinstance(computation, ir.Module):
            ir_str = mlir.module_to_string(computation)
        else:
            assert isinstance(computation, str)
            ir_str = computation
        _dump_ir_to_file(module_name, ir_str)
    return backend_compile(backend, computation, compile_options)
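The getattr(backend, "needs_str_ir", True) check is a defensive capability probe: back ends that do not declare the attribute keep the old string-serialization path. A minimal sketch of the pattern with hypothetical backend objects:

class LegacyBackend:
    platform = "cpu"            # no needs_str_ir attribute

class ModuleAwareBackend:
    platform = "cpu"
    needs_str_ir = False        # can consume the ir.Module directly

def prepare(backend, module):
    # Serialize unless the backend explicitly opts out.
    if getattr(backend, "needs_str_ir", True):
        return f"serialized({module})"
    return module

print(prepare(LegacyBackend(), "module"))        # serialized(module)
print(prepare(ModuleAwareBackend(), "module"))   # module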
Example #5
def lower_jaxpr_to_fun(ctx: LoweringContext,
                       name: str,
                       jaxpr: core.ClosedJaxpr,
                       *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
    """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
  Returns the name of the function.
  """
    def aval_to_types(aval):
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        # Regroup the flat block arguments into one list of IR values per
        # input aval before lowering the jaxpr body.
        unflattened_args = util.unflatten(entry_block.arguments,
                                          map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        std.ReturnOp(util.flatten(outs))

    return symbol_name
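Inside the entry block, the zip over (aval, argument group) decides, per input, whether to pass the real IR values through, an empty group (units), or a freshly created dummy token. A hedged, pure-Python sketch of that substitution logic with hypothetical aval markers (standing in for core.abstract_unit and core.abstract_token):

UNIT, TOKEN = object(), object()

def substitute_args(avals, arg_groups, replace_units=True, replace_tokens=True):
    out = []
    for aval, group in zip(avals, arg_groups):
        if replace_units and aval is UNIT:
            out.append([])                    # units lower to nothing
        elif replace_tokens and aval is TOKEN:
            out.append(["new_token"])         # stand-in for CreateTokenOp
        else:
            out.append(group)                 # pass the real IR values through
    return out

avals = [UNIT, TOKEN, "f32[3]"]
arg_groups = [["dummy_bool"], ["dummy_bool"], ["%arg0"]]
print(substitute_args(avals, arg_groups))
# [[], ['new_token'], ['%arg0']]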