Example #1
def _execute_replicated(name: str, compiled: XlaExecutable,
                        output_buffer_counts: Optional[Sequence[int]],
                        result_handlers, kept_var_idx, *args):
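    # Replicated execution path: each local device gets its own copy of the
    # kept arguments, and outputs are read back from the first device only,
    # since every replica computes the same values.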
    input_bufs = [
        util.flatten(
            device_put(x, device) for i, x in enumerate(args)
            if i in kept_var_idx) for device in compiled.local_devices()
    ]
    out_bufs = [
        buf[0] for buf in compiled.execute_sharded_on_local_devices(
            list(zip(*input_bufs)))
    ]
    check_special(name, out_bufs)
    if output_buffer_counts is None:
        return (result_handlers[0](*out_bufs), )
    return tuple(
        handler(*bs) for handler, bs in unsafe_zip(
            result_handlers, util.unflatten(out_bufs, output_buffer_counts)))
Example #2
    def fallback(ctx: LoweringRuleContext, *args, **params):
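        # XLA-fallback lowering: build the primitive's HLO subcomputation,
        # round-trip it to an MHLO module, splice that module's functions into
        # the current module (renaming "main" to a private symbol), and call
        # the spliced-in entry point.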
        module_ctx = ctx.module_context
        xla_computation = xla.primitive_subcomputation(module_ctx.platform,
                                                       module_ctx.axis_env,
                                                       prim, *ctx.avals_in,
                                                       **params)
        submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(
            xla_computation)
        submodule = ir.Module.parse(submodule_str)
        callee_name = None
        for op in submodule.body.operations:
            op = typing.cast(FuncOpType, op)
            module_ctx.module.body.append(op)
            if op.name.value == "main":
                op.attributes["sym_name"] = ir.StringAttr.get(
                    f"xla_fallback_{prim.name}")
                callee_name = ir.StringAttr(
                    module_ctx.symbol_table.insert(op)).value
                op.attributes["sym_visibility"] = ir.StringAttr.get("private")
            else:
                module_ctx.symbol_table.insert(op)

        output_types = map(aval_to_ir_types, ctx.avals_out)
        flat_output_types = util.flatten(output_types)
        output_type = (ir.TupleType.get_tuple(flat_output_types)
                       if prim.multiple_results else flat_output_types[0])

        call = func_dialect.CallOp([output_type],
                                   ir.FlatSymbolRefAttr.get(callee_name),
                                   flatten_lowering_ir_args(args)).result
        if not prim.multiple_results:
            return [call]
        if jax._src.lib.mlir_api_version < 6:
            flat_results = [
                mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
                for i, typ in enumerate(flat_output_types)
            ]
        else:
            flat_results = [
                mhlo.GetTupleElementOp(call, i32_attr(i)).result
                for i in range(len(flat_output_types))
            ]

        return util.unflatten(flat_results, map(len, output_types))
Example #3
def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, window_dimensions,
                                 window_strides, padding, base_dilation,
                                 window_dilation):
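    # Emit an mhlo.ReduceWindowOp whose reducer region is built by lowering
    # the reduction jaxpr as a subcomputation over scalar block arguments.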
    operands, init_values = util.split_list(args, [len(args) // 2])
    _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
    scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
    rw = mhlo.ReduceWindowOp(
        map(mlir.aval_to_ir_type, ctx.avals_out), operands, init_values,
        mlir.dense_int_elements(window_dimensions),
        mlir.dense_int_elements(window_strides),
        mlir.dense_int_elements(base_dilation),
        mlir.dense_int_elements(window_dilation),
        ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
    reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
    with ir.InsertionPoint(reducer):
        out_nodes = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, consts,
                                       *([a] for a in reducer.arguments))
        mhlo.ReturnOp(util.flatten(out_nodes))
    return rw.results
Example #4
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
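    # Lower a sharded_jit call: annotate inputs with their shardings, lower
    # the body jaxpr to a private function, call it, and annotate the outputs.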
    # We assume any extra leading in_nodes are constants and replicate them.
    num_extra_nodes = len(in_nodes) - len(in_parts)
    assert num_extra_nodes >= 0
    in_parts = (None, ) * num_extra_nodes + in_parts

    args = []
    for ns, sharding in safe_zip(
            safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
        if sharding is not None:
            args.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            args.append(ns)

    sub_ctx = ctx.module_context.replace(
        name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
    fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                 core.ClosedJaxpr(call_jaxpr, ()))

    output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    call = std.CallOp(flat_output_types,
                      ir.FlatSymbolRefAttr.get(fn.name.value),
                      mlir.flatten_lowering_ir_args(args))
    out_nodes = util.unflatten(call.results, safe_map(len, output_types))

    out_parts = out_parts_thunk()
    outputs = []
    for ns, sharding in safe_zip(out_nodes, out_parts):
        if sharding is not None:
            outputs.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            outputs.append(ns)
    return outputs
Example #5
def _execute_compiled(name: str, compiled: XlaExecutable,
                      input_handler: Optional[Callable],
                      output_buffer_counts: Optional[Sequence[int]],
                      result_handlers, effects: List[core.Effect],
                      kept_var_idx, *args):
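    # Single-device execution path: transfer the kept arguments, thread effect
    # tokens through the executable if needed, and rebuild the outputs with
    # the result handlers.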
    device, = compiled.local_devices()
    args = input_handler(args) if input_handler else args
    input_bufs_flat = flatten(
        device_put(x, device) for i, x in enumerate(args) if i in kept_var_idx)
    if effects:
        input_bufs_flat, token_handler = _add_tokens(effects, device,
                                                     input_bufs_flat)
    out_bufs_flat = compiled.execute(input_bufs_flat)
    check_special(name, out_bufs_flat)
    if output_buffer_counts is None:
        return (result_handlers[0](*out_bufs_flat), )
    out_bufs = unflatten(out_bufs_flat, output_buffer_counts)
    if effects:
        out_bufs = token_handler(out_bufs)
    return tuple(h(*bs) for h, bs in unsafe_zip(result_handlers, out_bufs))
Example #6
    def cached_lowering(ctx, *args, **params):
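        # Memoize this primitive's lowering per (primitive, avals, params) key:
        # emit it once as a standalone function, then reuse it via calls.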
        assert ctx.primitive is not None
        key = (ctx.primitive, tuple(ctx.avals_in), tuple(ctx.avals_out),
               tuple(params.items()))
        try:
            func = ctx.module_context.cached_primitive_lowerings.get(key)
        except TypeError:
            # If the parameters aren't hashable, give up on caching.
            # TODO(phawkins): switch to requiring hashability, when XLA fallback
            # computations have been ported to MHLO.
            return f(ctx, *args, **params)
        if func is None:
            func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
            ctx.module_context.cached_primitive_lowerings[key] = func

        output_types = map(aval_to_ir_types, ctx.avals_out)
        flat_output_types = util.flatten(output_types)
        call = func_dialect.CallOp(flat_output_types,
                                   ir.FlatSymbolRefAttr.get(func.name.value),
                                   flatten_lowering_ir_args(args))
        return util.unflatten(call.results, map(len, output_types))
Example #7
def _execute_replicated(name: str, compiled: XlaExecutable,
                        input_handler: Optional[Callable],
                        output_buffer_counts: Sequence[int],
                        result_handler: Callable, has_unordered_effects: bool,
                        ordered_effects: List[core.Effect], kept_var_idx,
                        *args):
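    # Replicated execution path; effects and input handlers are not supported
    # here yet. Outputs are read back from the first device only, since all
    # replicas compute the same values.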
    if has_unordered_effects or ordered_effects:
        # TODO(sharadmv): support jit-of-pmap with effects
        raise NotImplementedError(
            "Cannot execute replicated computation with effects.")
    if input_handler: raise NotImplementedError  # TODO(mattjj, dougalm)
    input_bufs = [
        flatten(
            device_put(x, device) for i, x in enumerate(args)
            if i in kept_var_idx) for device in compiled.local_devices()
    ]
    input_bufs_flip = list(unsafe_zip(*input_bufs))
    out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(
        input_bufs_flip)
    out_flat = [bufs[0] for bufs in out_bufs_flat_rep]
    check_special(name, out_flat)
    out_bufs = unflatten(out_flat, output_buffer_counts)
    return result_handler(None, out_bufs)
Example #8
def _execute_replicated(name: str, compiled: XlaExecutable,
                        input_handler: Optional[Callable],
                        output_buffer_counts: Optional[Sequence[int]],
                        result_handlers, effects: List[core.Effect],
                        kept_var_idx, *args):
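    # Replicated execution path: run on every local device, read the outputs
    # back from the first device only, and rebuild them with the result
    # handlers.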
    if effects:
        raise NotImplementedError(
            'Cannot execute replicated computation with effects.')
    if input_handler: raise NotImplementedError  # TODO(mattjj, dougalm)
    input_bufs = [
        flatten(
            device_put(x, device) for i, x in enumerate(args)
            if i in kept_var_idx) for device in compiled.local_devices()
    ]
    input_bufs_flip = list(unsafe_zip(*input_bufs))
    out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(
        input_bufs_flip)
    out_bufs_flat = [bufs[0] for bufs in out_bufs_flat_rep]
    check_special(name, out_bufs_flat)
    if output_buffer_counts is None:
        return (result_handlers[0](*out_bufs_flat), )
    out_bufs = unflatten(out_bufs_flat, output_buffer_counts)
    return tuple(h(*bs) for h, bs in unsafe_zip(result_handlers, out_bufs))
Example #9
def lower_jaxpr_to_xla_module(
        fn_name: str,
        jaxpr: core.ClosedJaxpr,
        platform: str,
        axis_env: AxisEnv,
        name_stack: Union[source_info_util.NameStack, str],
        tuple_args: bool,
        donated_invars: Sequence[bool],
        replicated_args: Optional[Sequence[bool]],
        arg_partitions: Optional[Any],
        out_partitions: Optional[Any],
        partitions_are_protos: bool = False) -> xc.XlaComputation:
    """Lowers a closed jaxpr to a top-level XLA module."""
    c = xc.XlaBuilder(fn_name)
    xla_consts = _xla_consts(c, jaxpr.consts)
    xla_args, donated_invars = _xla_callable_args(
        c,
        jaxpr.in_avals,
        tuple_args,
        donated_invars=donated_invars,
        replicated=replicated_args,
        partitions=arg_partitions,
        partitions_proto=partitions_are_protos)
    ctx = TranslationContext(c, platform, axis_env, name_stack)
    out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)
    # Replace tokens with a dummy array value, because the runtime cannot
    # handle token arguments.
    out_aval_lens = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
    out_nodes = util.flatten(
        [[_make_token_return_value(c)] if a is core.abstract_token else v
         for a, v in zip(jaxpr.out_avals,
                         util.unflatten(out_nodes, out_aval_lens))])

    # There is a non-zero cost to building an output tuple, particularly on TPU.
    # Avoid it if the output arity is 1.
    if out_partitions is None:
        output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(
            c, out_nodes)
    else:
        build_out_tuple = partial(xops.Tuple, c, out_nodes)
        if partitions_are_protos:
            output = with_sharding_proto(c, out_partitions, build_out_tuple)
        else:
            output = with_sharding(c, out_partitions, build_out_tuple)

    platforms_with_donation = ("gpu", "tpu")
    if platform in platforms_with_donation:
        donated_invars = set_up_aliases(c, xla_args, c.GetShape(output),
                                        donated_invars, tuple_args)
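    # set_up_aliases consumes the donations it can express as aliases and
    # returns the leftovers; anything still marked donated was unusable.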
    if any(donated_invars):
        # TODO(tomhennigan): At call time we should mark these buffers as deleted.
        unused_donations = [
            str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d
        ]
        msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
        if platform not in platforms_with_donation:
            msg = f"Donation is not implemented for {platform}.\n{msg}"
        warnings.warn(
            f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}"
        )
    return c.build(output)
Example #10
def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
    """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each argument (optional).
    use_sharding_annotations: if True, use mhlo.sharding annotations on
      parameters and return values to express sharding. If False, use
      mhlo.custom_call operators with sharding annotations.
      TODO(b/228598865): remove this option when mhlo.sharding annotations are
      propagated on non-entry functions during MHLO->HLO conversion.
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.
  Returns the name of the function.
  """
    def aval_to_types(aval):
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    ctx.symbol_table.insert(func_op)
    ir_arg_shardings = None
    if arg_shardings is not None:
        ir_arg_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(arg_shardings, input_types)])
    ir_result_shardings = None
    if result_shardings is not None:
        ir_result_shardings = util.flatten(
            [[sharding] * len(types)
             for sharding, types in zip(result_shardings, output_types)])

    if (replicated_args is not None or ir_arg_shardings is not None
            or input_output_aliases is not None):
        arg_attrs: List[Dict[str, ir.Attribute]] = [
            {} for _ in range(len(flat_input_types))
        ]

        if replicated_args is not None:
            replicated_ir_args = [
                [replicated] * len(types)
                for replicated, types in zip(replicated_args, input_types)
            ]
            for attrs, replicated in zip(arg_attrs,
                                         util.flatten(replicated_ir_args)):
                if replicated:
                    attrs["mhlo.is_same_data_across_replicas"] = (
                        ir.UnitAttr.get())

        if use_sharding_annotations and ir_arg_shardings is not None:
            for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
                if sharding is not None:
                    attrs["mhlo.sharding"] = ir.StringAttr.get(
                        sharding.SerializeToString())

        if input_output_aliases is not None:
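            # Expand each argument's alias into the flat indices of the output
            # it aliases; non-aliased arguments contribute a None per IR value.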
            output_ids = util.unflatten(list(range(len(flat_output_types))),
                                        map(len, output_types))
            aliases: List[Optional[int]] = []
            for types, alias in zip(input_types, input_output_aliases):
                if alias is None:
                    aliases.extend([None] * len(types))
                else:
                    aliases.extend(output_ids[alias])

            for attrs, alias in zip(arg_attrs, aliases):
                if alias is not None:
                    attrs["tf.aliasing_output"] = i32_attr(alias)

        func_op.arg_attrs = ir.ArrayAttr.get(
            [ir.DictAttr.get(attrs) for attrs in arg_attrs])

    if use_sharding_annotations and ir_result_shardings is not None:
        func_op.result_attrs = ir.ArrayAttr.get([
            ir.DictAttr.get({} if sharding is None else {
                "mhlo.sharding":
                ir.StringAttr.get(sharding.SerializeToString())
            }) for sharding in ir_result_shardings
        ])

    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        flat_args = entry_block.arguments
        if not use_sharding_annotations and ir_arg_shardings is not None:
            flat_args = map(wrap_with_sharding_op, flat_args, ir_arg_shardings)

        unflattened_args = util.unflatten(flat_args, map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        flat_outputs = util.flatten(outs)
        if not use_sharding_annotations and ir_result_shardings is not None:
            flat_outputs = map(wrap_with_sharding_op, flat_outputs,
                               ir_result_shardings)

        func_dialect.ReturnOp(flat_outputs)

    return func_op
Example #11
def flatten_lowering_ir_args(
    xs: Sequence[Union[ir.Value, Sequence[ir.Value]]]
) -> Sequence[Sequence[ir.Value]]:
    return util.flatten(map(wrap_singleton_ir_values, xs))
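
Nearly every rule above pairs util.flatten with util.unflatten: logical values that lower to several IR values are flattened before a call and regrouped by length afterwards. A minimal sketch of that round trip, assuming jax's internal jax._src.util module (the import path is an assumption and varies across jax versions):

from jax._src import util  # internal module; exact path may differ

# Each logical value may correspond to several low-level values; keep the
# per-group lengths so the flat list can be regrouped after the call.
grouped = [[1], [2, 3], [4, 5, 6]]
lengths = [len(g) for g in grouped]

flat = util.flatten(grouped)                # [1, 2, 3, 4, 5, 6]
assert util.unflatten(flat, lengths) == grouped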
Example #12
File: mlir.py Project: rsepassi/jax
def lower_jaxpr_to_fun(ctx: LoweringContext,
                       name: str,
                       jaxpr: core.ClosedJaxpr,
                       *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
    """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
  Returns the name of the function.
  """
    def aval_to_types(aval):
        if replace_units_with_dummy and aval is core.abstract_unit:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        elif replace_tokens_with_dummy and aval is core.abstract_token:
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(aval_to_types, jaxpr.in_avals)
    output_types = map(aval_to_types, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)
    ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
    func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(
        "public" if public else "private")
    symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        unflattened_args = util.unflatten(entry_block.arguments,
                                          map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, unflattened_args):
            if replace_units_with_dummy and aval is core.abstract_unit:
                args.append([])
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
            else:
                args.append(arg)
        callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                                  xla.wrap_name(name, 'jit'))
        out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                 jaxpr.jaxpr, map(ir_constants,
                                                  jaxpr.consts), *args)
        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if replace_units_with_dummy and aval is core.abstract_unit:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                outs.append(ir_constants(np.zeros((), np.bool_)))
            else:
                outs.append(out)
        std.ReturnOp(util.flatten(outs))

    return symbol_name