Exemple #1
0
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
    """Lowers a remat call as the body of an ``mhlo.WhileOp``.

    The loop carry is a single flat tuple laid out as
    ``(counter, *flat_inputs, *flat_outputs)``; the output slots start out as
    dummy values built by ``_dummy_like_aval``. Each iteration lowers
    ``call_jaxpr`` on the inputs, stores its results in the output slots,
    carries the inputs through unchanged and increments the counter. The loop
    condition compares the counter (signed ``LT``) against a value produced by
    ``mhlo.RngUniformOp(1, 2)``.
    NOTE(review): the RNG-based bound presumably hides the trip count from XLA
    so the recomputation is not optimized away -- confirm.

    Args:
      ctx: lowering context; its name stack is extended with a ``remat``
        entry for the loop body.
      avals_in: abstract values of the inputs in ``args``.
      avals_out: abstract values of the call's results.
      *args: IR values for the inputs, flattened via
        ``flatten_lowering_ir_args``.
      name: name of the call, used for the name-stack entry.
      call_jaxpr: the jaxpr lowered inside the loop body.

    Returns:
      One list of IR values per entry of ``avals_out``.
    """
    input_types = map(aval_to_ir_types, avals_in)
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    int32_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.int32)))
    loop_carry_types = [(int32_scalar_type, )] + input_types + output_types
    flat_loop_carry_types = util.flatten(loop_carry_types)
    counter_init = ir_constants(np.array(0, np.int32))
    flat_args = flatten_lowering_ir_args((counter_init, ) + args + tuple(
        _dummy_like_aval(aval) for aval in avals_out))
    loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
    init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

    one = ir_constant(np.array(1, np.int32))
    while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

    # Loop condition: counter < RngUniform(1, 2).
    cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(cond_block):
        bool_scalar_type = aval_to_ir_type(
            core.ShapedArray((), np.dtype(np.bool_)))
        two = ir_constant(np.array(2, np.int32))
        shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
        rng = mhlo.RngUniformOp(one, two, shape).result
        i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                                   i32_attr(0))
        cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                             ir.StringAttr.get("SIGNED")).result
        mhlo.ReturnOp([cmp])

    # Loop body: recompute the outputs from the (unchanged) inputs.
    body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(body_block):
        flat_body_args = [
            mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                                   i32_attr(i)).result
            for i, input_type in enumerate(flat_loop_carry_types)
        ]
        # Regroup per aval: [[counter], *inputs, *outputs].
        body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
        ((i, ), ), y, _ = util.split_list(body_args, [1, len(avals_in)])
        body_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'remat')))
        z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
        i_next = mhlo.AddOp(i, one).result
        new_carry = mhlo.TupleOp(loop_carry_tuple_type,
                                 [i_next, *util.flatten(y), *util.flatten(z)])
        mhlo.ReturnOp([new_carry.result])

    # The carry tuple is flat, so the outputs start after the counter slot and
    # every *flat* input IR value. An aval may lower to zero or several IR
    # types, so len(avals_in) is not a valid offset here.
    num_flat_inputs = len(util.flatten(input_types))
    outputs = [
        mhlo.GetTupleElementOp(output_type, while_op.result,
                               i32_attr(1 + num_flat_inputs + i)).result
        for i, output_type in enumerate(flat_output_types)
    ]
    return util.unflatten(outputs, map(len, output_types))
Exemple #2
0
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext, avals_in,
                          avals_out, *args, **params):
    """Lowers ``prim`` by importing its classic XLA translation as MLIR.

    The primitive's XLA subcomputation is converted to an MLIR module. Every
    function of that module is moved into ``ctx.module`` and registered in
    ``ctx.symbol_table``. The one named ``main`` is made private and invoked
    through a ``CallOp``.

    Returns:
      A list with one list of IR values per entry of ``avals_out``.
    """
    computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                               prim, *avals_in, **params)
    imported = ir.Module.parse(
        xc._xla.mlir.xla_computation_to_mlir_module(computation))

    entry_symbol = None
    for sub_op in imported.body.operations:
        ctx.module.body.append(sub_op)
        # Read the name before insertion, which may rename the symbol.
        is_entry = sub_op.name.value == "main"
        inserted = ctx.symbol_table.insert(sub_op)
        if is_entry:
            entry_symbol = ir.StringAttr(inserted).value
            sub_op.attributes["sym_visibility"] = ir.StringAttr.get("private")

    types_per_out = map(aval_to_ir_types, avals_out)
    flat_types = util.flatten(types_per_out)
    # Multi-result primitives return a single tuple value from the callee.
    if prim.multiple_results:
        result_type = ir.TupleType.get_tuple(flat_types)
    else:
        result_type = flat_types[0]

    call_result = std.CallOp([result_type],
                             ir.FlatSymbolRefAttr.get(entry_symbol),
                             flatten_lowering_ir_args(args)).result
    if not prim.multiple_results:
        return [call_result]

    elements = [
        mhlo.GetTupleElementOp(elem_type, call_result, i32_attr(pos)).result
        for pos, elem_type in enumerate(flat_types)
    ]
    return util.unflatten(elements, map(len, types_per_out))
Exemple #3
0
def _optimization_barrier_lowering_rule(ctx, *args):
    """Lowers an optimization barrier to one ``mhlo.OptimizationBarrierOp``.

    All operands are passed to a single op. Its results are regrouped to match
    the IR types of ``ctx.avals_in``.
    """
    types_per_aval = _map(mlir.aval_to_ir_types, ctx.avals_in)
    barrier = mhlo.OptimizationBarrierOp(util.flatten(types_per_aval),
                                         mlir.flatten_lowering_ir_args(args))
    group_sizes = _map(len, types_per_aval)
    return util.unflatten(barrier.results, group_sizes)
Exemple #4
0
def _execute_replicated(name: str, compiled: XlaExecutable,
                        input_handler: Optional[Callable],
                        output_buffer_counts: Optional[Sequence[int]],
                        result_handlers, has_unordered_effects: bool,
                        ordered_effects: List[core.Effect], kept_var_idx,
                        *args):
    """Runs ``compiled`` on every local device and applies the result handlers.

    Only arguments whose positions are in ``kept_var_idx`` are transferred,
    once per local device. From each output's per-device buffers only the
    first is kept.

    Raises:
      NotImplementedError: if the computation has effects or an
        ``input_handler`` is given.
    """
    # TODO(sharadmv): support jit-of-pmap with effects
    if has_unordered_effects or ordered_effects:
        raise NotImplementedError(
            "Cannot execute replicated computation with effects.")
    if input_handler:
        raise NotImplementedError  # TODO(mattjj, dougalm)

    def _transfer_to(device):
        return flatten(device_put(arg, device)
                       for pos, arg in enumerate(args) if pos in kept_var_idx)

    per_device = [_transfer_to(device) for device in compiled.local_devices()]
    # Transpose device-major lists into one tuple (across devices) per buffer.
    per_buffer = list(unsafe_zip(*per_device))
    replica_outs = compiled.execute_sharded_on_local_devices(per_buffer)
    first_replica = [bufs[0] for bufs in replica_outs]
    check_special(name, first_replica)

    if output_buffer_counts is None:
        return (result_handlers[0](*first_replica), )
    grouped = unflatten(first_replica, output_buffer_counts)
    return tuple(handler(*bufs)
                 for handler, bufs in unsafe_zip(result_handlers, grouped))
Exemple #5
0
def lower_jaxpr_to_xla_module(
        fn_name: str,
        jaxpr: core.ClosedJaxpr,
        platform: str,
        axis_env: AxisEnv,
        name_stack: str,
        tuple_args: bool,
        donated_invars: Sequence[bool],
        replicated_args: Optional[Sequence[bool]],
        arg_partitions: Optional[Any],
        out_partitions: Optional[Any],
        partitions_are_protos: bool = False) -> xc.XlaComputation:
    """Lowers a closed jaxpr to a top-level XLA module.

    Args:
      fn_name: name of the ``XlaBuilder`` / resulting computation.
      jaxpr: the closed jaxpr to lower; its consts become XLA constants.
      platform: target platform. Buffer aliasing for donations is only set up
        for "gpu" and "tpu".
      axis_env, name_stack: forwarded to the ``TranslationContext``.
      tuple_args, donated_invars, replicated_args, arg_partitions,
      partitions_are_protos: forwarded to ``_xla_callable_args``.
      out_partitions: if not None, the output tuple is built under
        ``with_sharding`` (or ``with_sharding_proto``).

    Returns:
      The built ``XlaComputation``.

    A warning is emitted when some donated inputs end up unused (on platforms
    other than gpu/tpu, no aliasing is attempted at all).
    """
    builder = xc.XlaBuilder(fn_name)
    consts = _xla_consts(builder, jaxpr.consts)
    params, donated_invars = _xla_callable_args(
        builder,
        jaxpr.in_avals,
        tuple_args,
        donated_invars=donated_invars,
        replicated=replicated_args,
        partitions=arg_partitions,
        partitions_proto=partitions_are_protos)
    translation_ctx = TranslationContext(builder, platform, axis_env,
                                         name_stack)
    flat_outs = jaxpr_subcomp(translation_ctx, jaxpr.jaxpr, consts, *params)

    # The runtime cannot handle token values, so each token output is swapped
    # for a dummy array value.
    group_sizes = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
    grouped_outs = util.unflatten(flat_outs, group_sizes)
    patched = []
    for aval, nodes in zip(jaxpr.out_avals, grouped_outs):
        if aval is core.abstract_token:
            patched.append([_make_token_return_value(builder)])
        else:
            patched.append(nodes)
    out_nodes = util.flatten(patched)

    if out_partitions is not None:
        make_tuple = partial(xops.Tuple, builder, out_nodes)
        annotate = (with_sharding_proto
                    if partitions_are_protos else with_sharding)
        output = annotate(builder, out_partitions, make_tuple)
    elif len(out_nodes) == 1:
        # Building an output tuple has a non-zero cost, particularly on TPU.
        output = out_nodes[0]
    else:
        output = xc.ops.Tuple(builder, out_nodes)

    if platform in ("gpu", "tpu"):
        donated_invars = set_up_aliases(builder, params,
                                        builder.GetShape(output),
                                        donated_invars, tuple_args)
    if any(donated_invars):
        # TODO(tomhennigan): At call time we should mark these buffers as deleted.
        unused = ", ".join(
            str(builder.GetShape(param))
            for param, donated in zip(params, donated_invars) if donated)
        warnings.warn(f"Some donated buffers were not usable: {unused}")
    return builder.build(output)
Exemple #6
0
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, *args):
    """Lowers a call by emitting ``call_jaxpr`` as a function and calling it.

    The callee is lowered with ``lower_jaxpr_to_fun`` under a name stack
    extended by ``stack_name``. Raises if ``backend`` does not match
    ``ctx.platform`` (as checked by ``xla.check_backend_matches``).

    Returns:
      One list of IR values per entry of ``avals_out``.
    """
    xla.check_backend_matches(backend, ctx.platform)
    types_per_out = map(aval_to_ir_types, avals_out)
    callee_ctx = ctx.replace(
        name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
    callee = lower_jaxpr_to_fun(callee_ctx, fn_name,
                                core.ClosedJaxpr(call_jaxpr, ()))
    call_op = std.CallOp(util.flatten(types_per_out),
                         ir.FlatSymbolRefAttr.get(callee),
                         flatten_lowering_ir_args(args))
    return util.unflatten(call_op.results, map(len, types_per_out))
Exemple #7
0
def _execute_compiled(name: str, compiled: XlaExecutable,
                      output_buffer_counts: Optional[Sequence[int]],
                      result_handlers, kept_var_idx, *args):
  """Runs a single-device executable and applies the result handlers.

  Only arguments whose positions are in ``kept_var_idx`` are transferred to
  the executable's (single) local device. If ``output_buffer_counts`` is None,
  every output buffer goes to ``result_handlers[0]``. Otherwise the buffers
  are grouped by those counts, one group per handler.
  """
  (device,) = compiled.local_devices()
  kept = (arg for pos, arg in enumerate(args) if pos in kept_var_idx)
  in_bufs = util.flatten(device_put(arg, device) for arg in kept)
  results = compiled.execute(in_bufs)
  check_special(name, results)
  if output_buffer_counts is None:
    return (result_handlers[0](*results),)
  groups = util.unflatten(results, output_buffer_counts)
  return tuple(
      handle(*group) for handle, group in unsafe_zip(result_handlers, groups))
Exemple #8
0
def _execute_compiled(name: str, compiled: XlaExecutable,
                      input_handler: Optional[Callable],
                      output_buffer_counts: Optional[Sequence[int]],
                      result_handlers, kept_var_idx, *args):
  """Runs a single-device executable and wraps its outputs.

  ``input_handler`` (if given) preprocesses ``args``. Only arguments whose
  positions are in ``kept_var_idx`` are transferred. Outputs are either all
  passed to ``result_handlers[0]`` (when ``output_buffer_counts`` is None) or
  grouped by ``output_buffer_counts``, one group per handler.
  """
  (device,) = compiled.local_devices()
  if input_handler:
    args = input_handler(args)
  transferred = flatten(
      device_put(arg, device)
      for pos, arg in enumerate(args) if pos in kept_var_idx)
  results_flat = compiled.execute(transferred)
  check_special(name, results_flat)
  if output_buffer_counts is None:
    return (result_handlers[0](*results_flat),)
  per_output = unflatten(results_flat, output_buffer_counts)
  return tuple(
      handler(*bufs) for handler, bufs in unsafe_zip(result_handlers, per_output))
Exemple #9
0
def _execute_compiled(name: str, compiled: XlaExecutable,
                      input_handler: Optional[Callable],
                      output_buffer_counts: Sequence[int],
                      result_handler: Callable, has_unordered_effects: bool,
                      ordered_effects: List[core.Effect], kept_var_idx, *args):
    """Runs a single-device executable, threading effect tokens if needed.

    ``input_handler`` (if given) returns ``(args, env)``; ``env`` is None
    otherwise and is forwarded to ``result_handler``. When any effects are
    present, token buffers are added to the inputs by ``_add_tokens`` and
    removed from the grouped outputs by the returned token handler.
    """
    (device,) = compiled.local_devices()
    if input_handler:
        args, env = input_handler(args)
    else:
        env = None
    in_flat = flatten(
        device_put(arg, device)
        for pos, arg in enumerate(args) if pos in kept_var_idx)

    uses_tokens = bool(has_unordered_effects or ordered_effects)
    if uses_tokens:
        in_flat, token_handler = _add_tokens(has_unordered_effects,
                                             ordered_effects, device, in_flat)
    out_flat = compiled.execute(in_flat)
    check_special(name, out_flat)
    grouped = unflatten(out_flat, output_buffer_counts)
    if uses_tokens:
        grouped = token_handler(grouped)
    return result_handler(env, grouped)
Exemple #10
0
def _execute_replicated(name: str, compiled: XlaExecutable,
                        output_buffer_counts: Optional[Sequence[int]],
                        result_handlers, kept_var_idx, *args):
  """Runs ``compiled`` on all local devices and handles the first replica.

  Kept arguments (positions in ``kept_var_idx``) are transferred to each
  local device. From each output's per-device buffers only the first is
  passed on to the result handlers.
  """
  per_device = []
  for device in compiled.local_devices():
    per_device.append(util.flatten(
        device_put(arg, device)
        for pos, arg in enumerate(args) if pos in kept_var_idx))
  # One tuple per flat buffer, holding that buffer for every device.
  per_buffer = list(zip(*per_device))
  replica_outs = compiled.execute_sharded_on_local_devices(per_buffer)
  first = [bufs[0] for bufs in replica_outs]
  check_special(name, first)
  if output_buffer_counts is None:
    return (result_handlers[0](*first),)
  groups = util.unflatten(first, output_buffer_counts)
  return tuple(
      handler(*bufs) for handler, bufs in unsafe_zip(result_handlers, groups))
Exemple #11
0
    def fallback(ctx: LoweringRuleContext, *args, **params):
        """Lowers ``prim`` (from the enclosing scope) via its XLA translation.

        The XLA subcomputation is imported as an MLIR module and its functions
        are moved into the current module. Its ``main`` function is renamed to
        ``xla_fallback_<prim.name>``, made private and invoked with a CallOp.

        Returns:
          One list of IR values per entry of ``ctx.avals_out``.
        """
        mod_ctx = ctx.module_context
        computation = xla.primitive_subcomputation(mod_ctx.platform,
                                                   mod_ctx.axis_env, prim,
                                                   *ctx.avals_in, **params)
        imported = ir.Module.parse(
            xc._xla.mlir.xla_computation_to_mlir_module(computation))

        entry_symbol = None
        for sub_op in imported.body.operations:
            sub_op = typing.cast(FuncOpType, sub_op)
            mod_ctx.module.body.append(sub_op)
            if sub_op.name.value != "main":
                mod_ctx.symbol_table.insert(sub_op)
                continue
            # Rename the entry point before insertion; the symbol table may
            # still uniquify the new name.
            sub_op.attributes["sym_name"] = ir.StringAttr.get(
                f"xla_fallback_{prim.name}")
            entry_symbol = ir.StringAttr(
                mod_ctx.symbol_table.insert(sub_op)).value
            sub_op.attributes["sym_visibility"] = ir.StringAttr.get("private")

        types_per_out = map(aval_to_ir_types, ctx.avals_out)
        flat_types = util.flatten(types_per_out)
        if prim.multiple_results:
            result_type = ir.TupleType.get_tuple(flat_types)
        else:
            result_type = flat_types[0]

        call_result = func_dialect.CallOp(
            [result_type], ir.FlatSymbolRefAttr.get(entry_symbol),
            flatten_lowering_ir_args(args)).result
        if not prim.multiple_results:
            return [call_result]

        # With mlir_api_version < 6 GetTupleElementOp takes the element type
        # as its first argument; newer versions infer it.
        if jax._src.lib.mlir_api_version < 6:
            elements = [
                mhlo.GetTupleElementOp(elem_type, call_result,
                                       i32_attr(pos)).result
                for pos, elem_type in enumerate(flat_types)
            ]
        else:
            elements = [
                mhlo.GetTupleElementOp(call_result, i32_attr(pos)).result
                for pos in range(len(flat_types))
            ]

        return util.unflatten(elements, map(len, types_per_out))
Exemple #12
0
def _emit_lowering_rule_as_fun(lowering_rule,
                               ctx: LoweringRuleContext) -> builtin.FuncOp:
  """Wraps a lowering rule's output in a new private function.

  The function is named after ``ctx.primitive``. It takes the flattened IR
  types of ``ctx.avals_in``, returns those of ``ctx.avals_out``, and is
  registered in the module's symbol table. ``lowering_rule`` is invoked inside
  the entry block and its outputs are returned.
  """
  in_types = map(aval_to_ir_types, ctx.avals_in)
  out_types = map(aval_to_ir_types, ctx.avals_out)
  signature = ir.FunctionType.get(util.flatten(in_types),
                                  util.flatten(out_types))
  assert ctx.primitive is not None
  module_ctx = ctx.module_context
  fn = builtin.FuncOp(ctx.primitive.name, signature, ip=module_ctx.ip)
  fn.attributes["sym_visibility"] = ir.StringAttr.get("private")
  module_ctx.symbol_table.insert(fn)
  block = fn.add_entry_block()
  with ir.InsertionPoint(block):
    grouped_args = util.unflatten(block.arguments, map(len, in_types))
    results = lowering_rule(ctx, *_unwrap_singleton_ir_values(grouped_args))
    std.ReturnOp(util.flatten(map(wrap_singleton_ir_values, results)))
  return fn
Exemple #13
0
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
    """Lowers ``sharded_jit``: annotate operands, call the body, annotate results.

    Operands and results that have a sharding are wrapped with
    ``mlir.wrap_with_sharding_op``; the body is lowered as a separate function
    and invoked with a ``CallOp``. ``nparts`` and the ``local_*`` parameters are
    not used here.

    Returns:
      One list of IR values per entry of ``ctx.avals_out``.
    """

    def annotate(groups, shardings):
        # Wrap every IR value of a group whose sharding is set; others pass
        # through unchanged.
        annotated = []
        for nodes, sharding in safe_zip(groups, shardings):
            if sharding is None:
                annotated.append(nodes)
                continue
            annotated.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in nodes
            ])
        return annotated

    # We assume any extra leading in_nodes are constants and replicate them.
    num_consts = len(in_nodes) - len(in_parts)
    assert num_consts >= 0
    in_parts = (None, ) * num_consts + in_parts

    args = annotate(safe_map(mlir.wrap_singleton_ir_values, in_nodes),
                    in_parts)

    # NOTE(review): extend_name_stack is called with one argument here, while
    # other call sites pass the parent stack first -- confirm which signature
    # is in scope in this module.
    sub_ctx = ctx.module_context.replace(
        name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
    callee = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                     core.ClosedJaxpr(call_jaxpr, ()))

    types_per_out = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    call = std.CallOp(util.flatten(types_per_out),
                      ir.FlatSymbolRefAttr.get(callee.name.value),
                      mlir.flatten_lowering_ir_args(args))
    out_nodes = util.unflatten(call.results, safe_map(len, types_per_out))
    return annotate(out_nodes, out_parts_thunk())
Exemple #14
0
def _cond_lowering(ctx, index, *args, branches, linear):
    """Lowers ``cond`` to an ``mhlo.CaseOp`` with one region per branch.

    Ordered effects of any branch are threaded through the case as tokens.
    Each branch returns its output tokens ahead of its regular outputs, and
    the tokens produced by the case are recorded with ``ctx.set_tokens_out``.

    Args:
      ctx: the lowering rule context (provides ``tokens_in``, ``avals_out``
        and ``module_context``).
      index: IR value passed as ``index=`` to the CaseOp to select a branch.
      *args: operands; branch regions use them as implicit captures.
      branches: closed jaxprs, one per branch.
      linear: unused.

    Returns:
      One list of IR values per entry of ``ctx.avals_out`` (tokens excluded).
    """
    del linear  # Unused.
    # Union of every branch's effects; only the ordered ones need tokens.
    joined_effects = core.join_effects(*(branch.effects
                                         for branch in branches))
    ordered_effects = [
        eff for eff in joined_effects if eff in core.ordered_effects
    ]
    num_tokens = len(ordered_effects)
    tokens_in = ctx.tokens_in.subset(ordered_effects)
    # Case results are laid out as [*tokens, *outputs].
    output_token_types = [mlir.token_type() for _ in ordered_effects]
    output_types = [
        *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)
    ]
    flat_output_types = util.flatten(output_types)

    # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
    # have no arguments; the computation within the block uses implicit
    # captures.
    case_op = mhlo.CaseOp(flat_output_types,
                          index=index,
                          num_branches=len(branches))
    name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
    for i, jaxpr in enumerate(branches):
        branch = case_op.regions[i].blocks.append()
        with ir.InsertionPoint(branch):
            sub_ctx = ctx.module_context.replace(
                name_stack=xla.extend_name_stack(name_stack,
                                                 f'branch_{i}_fun'))
            out_vals, tokens_out = mlir.jaxpr_subcomp(
                sub_ctx, jaxpr.jaxpr, tokens_in,
                _map(mlir.ir_constants, jaxpr.consts),
                *_map(mlir.wrap_singleton_ir_values, args))
            # Every branch must return tokens in the same (ordered_effects)
            # order so that all regions agree on the result layout.
            out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
            out_vals = [*out_tokens, *out_vals]
            mhlo.ReturnOp(util.flatten(out_vals))

    # Regroup the flat case results, then split tokens from real outputs.
    tokens_and_outputs = util.unflatten(case_op.results,
                                        _map(len, output_types))
    tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
    ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
    return outputs
Exemple #15
0
def _execute_compiled(name: str, compiled: XlaExecutable,
                      input_handler: Optional[Callable],
                      output_buffer_counts: Optional[Sequence[int]],
                      result_handlers, has_unordered_effects: bool,
                      ordered_effects: List[core.Effect], kept_var_idx, *args):
    """Runs a single-device executable, adding effect tokens when needed.

    ``input_handler`` (if given) preprocesses ``args``. When any effects are
    present, token buffers are added by ``_add_tokens``. Their removal via the
    returned token handler happens only on the grouped path, i.e. when
    ``output_buffer_counts`` is not None.
    """
    (device,) = compiled.local_devices()
    if input_handler:
        args = input_handler(args)
    in_flat = flatten(
        device_put(arg, device)
        for pos, arg in enumerate(args) if pos in kept_var_idx)

    uses_tokens = bool(has_unordered_effects or ordered_effects)
    if uses_tokens:
        in_flat, token_handler = _add_tokens(has_unordered_effects,
                                             ordered_effects, device, in_flat)
    out_flat = compiled.execute(in_flat)
    check_special(name, out_flat)
    if output_buffer_counts is None:
        return (result_handlers[0](*out_flat), )

    grouped = unflatten(out_flat, output_buffer_counts)
    if uses_tokens:
        grouped = token_handler(grouped)
    return tuple(handler(*bufs)
                 for handler, bufs in unsafe_zip(result_handlers, grouped))
Exemple #16
0
    def cached_lowering(ctx, *args, **params):
        """Lowers via ``f`` emitted once per signature as a private function.

        The cache key is ``(primitive, avals_in, avals_out, params)`` and is
        stored in ``ctx.module_context.cached_primitive_lowerings``. If the key
        is unhashable, ``f`` is applied directly instead of being cached.
        """
        assert ctx.primitive is not None
        key = (ctx.primitive, tuple(ctx.avals_in), tuple(ctx.avals_out),
               tuple(params.items()))
        cache = ctx.module_context.cached_primitive_lowerings
        try:
            callee = cache.get(key)
        except TypeError:
            # If the parameters aren't hashable, give up on caching.
            # TODO(phawkins): switch to requiring hashability, when XLA fallback
            # computations have been ported to MHLO.
            return f(ctx, *args, **params)
        if callee is None:
            callee = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
            cache[key] = callee

        types_per_out = map(aval_to_ir_types, ctx.avals_out)
        call = func_dialect.CallOp(util.flatten(types_per_out),
                                   ir.FlatSymbolRefAttr.get(callee.name.value),
                                   flatten_lowering_ir_args(args))
        return util.unflatten(call.results, map(len, types_per_out))
Exemple #17
0
def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
    """Applies ``func(var, var_nodes)`` to each var and its slice of ``nodes``.

    ``nodes`` is flat. It is split into consecutive groups, one per var, whose
    sizes equal the number of XLA shapes of that var's aval.
    """
    sizes = [len(aval_to_xla_shapes(var.aval)) for var in vars]
    groups = util.unflatten(nodes, sizes)
    return map(func, vars, groups)
Exemple #18
0
def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
    """Lowers jaxpr and its callees to an IR function.

    Assumes that an MLIR context, location, and insertion point are set.

    Args:
      ctx: the module lowering context.
      name: the function name. The name will be uniquified by the symbol
        table, so it is ok to use the same name multiple times.
      jaxpr: the jaxpr to lower.
      public: if true, the function's visibility is set to "public";
        otherwise it is "private".
      replace_units_with_dummy: if true, unit arguments/return values are
        replaced with scalar (shape ``()``) bool arrays.
      replace_tokens_with_dummy: if true, token arguments/return values are
        replaced with scalar (shape ``()``) bool arrays.
      replicated_args: if present, annotates arguments as replicated
        (``mhlo.is_same_data_across_replicas``).
      arg_shardings: sharding annotations for each argument (optional).
      result_shardings: sharding annotations for each result (optional).
      use_sharding_annotations: if True, use mhlo.sharding annotations on
        parameters and return values to express sharding. If False, use
        mhlo.custom_call operators with sharding annotations.
        TODO(b/228598865): remove this option when mhlo.sharding annotations
        are propagated on non-entry functions during MHLO->HLO conversion.
      input_output_aliases: optional sequence that maps argument numbers to
        the corresponding output that should alias them (emitted as
        ``tf.aliasing_output`` argument attributes).

    Returns:
      The created function op (not its name).
    """
    def aval_to_types(aval):
        # Dummy-replaced units/tokens are typed as scalar bool arrays.
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    ctx.symbol_table.insert(func_op)
    # Per-aval shardings are repeated once for every flat IR value of the aval.
    ir_arg_shardings = None
    if arg_shardings is not None:
        ir_arg_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(arg_shardings, input_types)])
    ir_result_shardings = None
    if result_shardings is not None:
        ir_result_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(result_shardings, output_types)])

    # Argument attributes are only materialized when at least one kind of
    # per-argument annotation was requested.
    if (replicated_args is not None or ir_arg_shardings is not None
            or input_output_aliases is not None):
        arg_attrs: List[Dict[str, ir.Attribute]] = [
            {} for _ in range(len(flat_input_types))
        ]

        if replicated_args is not None:
            replicated_ir_args = [
                [replicated] * len(types)
                for replicated, types in zip(replicated_args, input_types)
            ]
            for attrs, replicated in zip(arg_attrs,
                                         util.flatten(replicated_ir_args)):
                if replicated:
                    attrs[
                        "mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get(
                        )

        if use_sharding_annotations and ir_arg_shardings is not None:
            for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
                if sharding is not None:
                    attrs["mhlo.sharding"] = ir.StringAttr.get(
                        sharding.SerializeToString())

        if input_output_aliases is not None:
            # output_ids[k] lists the flat result indices of output k; each
            # flat argument is then paired with the flat index it aliases.
            output_ids = util.unflatten(list(range(len(flat_output_types))),
                                        map(len, output_types))
            aliases: List[Optional[int]] = []
            for types, alias in zip(input_types, input_output_aliases):
                if alias is None:
                    aliases.extend([None] * len(types))
                else:
                    # NOTE(review): assumes the aliased output has as many
                    # flat IR values as this input; otherwise the later zip
                    # with arg_attrs misaligns -- confirm.
                    aliases.extend(output_ids[alias])

            for attrs, alias in zip(arg_attrs, aliases):
                if alias is not None:
                    attrs["tf.aliasing_output"] = i32_attr(alias)

        func_op.arg_attrs = ir.ArrayAttr.get(
            [ir.DictAttr.get(attrs) for attrs in arg_attrs])

    if use_sharding_annotations and ir_result_shardings is not None:
        func_op.result_attrs = ir.ArrayAttr.get([
            ir.DictAttr.get({} if sharding is None else {
                "mhlo.sharding":
                ir.StringAttr.get(sharding.SerializeToString())
            }) for sharding in ir_result_shardings
        ])

    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        flat_args = entry_block.arguments
        # Without attribute annotations, argument shardings are expressed by
        # wrapping each argument value with a sharding op.
        if not use_sharding_annotations and ir_arg_shardings is not None:
            flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)

        unflattened_args = util.unflatten(flat_args, map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                # The dummy parameter is ignored; the unit is passed to the
                # body as an empty list of values.
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                # The dummy parameter is ignored; a fresh token is created.
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            # Dummy-replaced results are returned as scalar False constants.
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        flat_outputs = util.flatten(outs)
        # Result shardings expressed as sharding ops, as for the arguments.
        if not use_sharding_annotations and ir_result_shardings is not None:
            flat_outputs = map(wrap_with_sharding_op, flat_outputs,
                               ir_result_shardings)

        func_dialect.ReturnOp(flat_outputs)

    return func_op
Exemple #19
0
def lower_jaxpr_to_fun(ctx: LoweringContext,
                       name: str,
                       jaxpr: core.ClosedJaxpr,
                       *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
    """Lowers jaxpr and its callees to an IR function.

    Assumes that an MLIR context, location, and insertion point are set.

    Args:
      ctx: the lowering context.
      name: the function name. The name will be uniquified by the symbol
        table, so it is ok to use the same name multiple times.
      jaxpr: the jaxpr to lower.
      public: if true, the function's visibility is set to "public";
        otherwise it is "private".
      replace_units_with_dummy: if true, unit arguments/return values are
        replaced with scalar (shape ``()``) bool arrays.
      replace_tokens_with_dummy: if true, token arguments/return values are
        replaced with scalar (shape ``()``) bool arrays.

    Returns:
      The (possibly uniquified) symbol name of the created function.
    """
    def aval_to_types(aval):
        # Dummy-replaced units/tokens are typed as scalar bool arrays.
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    # insert() returns the final (uniquified) symbol name.
    symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        unflattened_args = util.unflatten(entry_block.arguments,
                                          map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                # The dummy parameter is ignored; the unit is passed to the
                # body as an empty list of values.
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                # The dummy parameter is ignored; a fresh token is created.
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            # Dummy-replaced results are returned as scalar False constants.
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        std.ReturnOp(util.flatten(outs))

    return symbol_name