Exemple #1
0
def _xla_call_lower(ctx, *args,
                    backend=None, name, call_jaxpr, donated_invars, inline=None,
                    device=None):
  """Lowering rule that forwards a jitted call to `_call_lowering`.

  Args:
    ctx: rule context; its `module_context`, `avals_in` and `avals_out` are
      passed through to `_call_lowering`.
    *args: IR values for the call operands, forwarded unchanged.
    backend: forwarded to `_call_lowering`.
    name: name of the called function; used to build both the `jit_<name>`
      function name and the wrapped name-stack entry.
    call_jaxpr: the jaxpr being called.
    donated_invars: accepted but unused.
    inline: accepted but unused.
    device: accepted but unused.

  Returns:
    Whatever `_call_lowering` returns.
  """
  del device, donated_invars, inline  # Ignored.
  fn_name = f"jit_{name}"
  stack_name = xla.wrap_name(name, "jit")
  return _call_lowering(fn_name, stack_name, call_jaxpr, backend,
                        ctx.module_context, ctx.avals_in, ctx.avals_out,
                        *args)
Exemple #2
0
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
    """Lowers `call_jaxpr` as the body of an `mhlo.WhileOp`.

    The loop carries a single flat tuple laid out as
    `(counter, *flat input values, *flat output values)`. The output slots are
    seeded with `_dummy_like_aval` placeholders. The condition region keeps
    looping while the counter is less than the result of an
    `mhlo.RngUniformOp` drawn with bounds 1 and 2; the body region evaluates
    `call_jaxpr` on the carried inputs, adds one to the counter and stores the
    results in the output slots.

    Args:
      ctx: lowering context; its `name_stack` is extended for the loop body
        and the resulting context is handed to `jaxpr_subcomp`.
      avals_in: abstract values of the inputs, one per entry of `args`.
      avals_out: abstract values of the outputs.
      *args: IR values for the inputs.
      name: name used for the 'remat' name-stack entry of the body.
      call_jaxpr: the jaxpr evaluated inside the loop body.

    Returns:
      The output values read back from the while result, regrouped so that
      there is one group per entry of `avals_out`.
    """
    input_types = map(aval_to_ir_types, avals_in)
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    int32_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.int32)))
    loop_carry_types = [(int32_scalar_type, )] + input_types + output_types
    flat_loop_carry_types = util.flatten(loop_carry_types)
    counter_init = ir_constants(np.array(0, np.int32))
    flat_args = flatten_lowering_ir_args((counter_init, ) + args + tuple(
        _dummy_like_aval(aval) for aval in avals_out))
    loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
    init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

    # `one` is created outside the regions and reused as both the RNG lower
    # bound in the condition and the counter increment in the body.
    one = ir_constant(np.array(1, np.int32))
    while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

    # Loop condition
    # NOTE(review): presumably the bound comes from an RNG op so the compiler
    # cannot determine the trip count statically -- confirm.
    cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(cond_block):
        bool_scalar_type = aval_to_ir_type(
            core.ShapedArray((), np.dtype(np.bool_)))
        two = ir_constant(np.array(2, np.int32))
        shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
        rng = mhlo.RngUniformOp(one, two, shape).result
        i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                                   i32_attr(0))
        cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                             ir.StringAttr.get("SIGNED")).result
        mhlo.ReturnOp([cmp])

    # Loop body
    body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(body_block):
        # Unpack every element of the carried tuple, then regroup the flat
        # values per aval: [counter], inputs, previous outputs (discarded).
        flat_body_args = [
            mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                                   i32_attr(i)).result
            for i, input_type in enumerate(flat_loop_carry_types)
        ]
        body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
        ((i, ), ), y, _ = util.split_list(body_args, [1, len(avals_in)])
        body_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'remat')))
        z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
        i_next = mhlo.AddOp(i, one).result
        new_carry = mhlo.TupleOp(loop_carry_tuple_type,
                                 [i_next, *util.flatten(y), *util.flatten(z)])
        mhlo.ReturnOp([new_carry.result])

    # The carried tuple is flat, so the first output slot sits after the
    # counter and every *flattened* input type. Counting avals instead would
    # pick the wrong slots whenever an input aval maps to zero or several IR
    # types.
    first_output_index = 1 + len(util.flatten(input_types))
    outputs = [
        mhlo.GetTupleElementOp(output_type, while_op.result,
                               i32_attr(first_output_index + i)).result
        for i, output_type in enumerate(flat_output_types)
    ]
    return util.unflatten(outputs, map(len, output_types))
Exemple #3
0
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, *arg_specs):
  """Traces `fun` to a jaxpr and lowers it, returning an `XlaComputation`.

  Args:
    fun: the wrapped function to trace.
    device: optional device to run on; mutually exclusive with `backend`.
    backend: optional backend; mutually exclusive with `device`.
    name: name used in diagnostics and for the 'jit' name-stack entry.
    donated_invars: one entry per argument; pruned in step with the arguments
      that the traced jaxpr does not use.
    *arg_specs: one `(abstract value, device)` pair per argument.

  Returns:
    An `XlaComputation`. When the pruned jaxpr has no equations it is built
    without a lowered module (it carries the jaxpr and constants instead);
    otherwise it wraps the lowered module.

  Raises:
    ValueError: if both `device` and `backend` are given, or if the
      computation needs more replicas than the backend has devices.
    UnexpectedTracerError: if tracing yields a constant that is a tracer.
    NotImplementedError: if more than one process is in use and the jaxpr
      needs several replicas or contains a pmap.
  """
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  abstract_args, arg_devices = util.unzip2(arg_specs)
  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
        fun, abstract_args, pe.debug_info_final(fun, "jit"))
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")
  # Drop unused inputs, then filter everything indexed by argument position
  # so it stays aligned with the pruned jaxpr.
  jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
  consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
  pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
  abstract_args, arg_devices = util.unzip2(pruned_arg_specs)
  donated_invars = [
      x for i, x in enumerate(donated_invars) if i in kept_var_idx
  ]
  # `prefetch` is called only for its side effect; an explicit loop runs it
  # eagerly no matter how `map` is bound in this module.
  for value in itertools.chain(consts, jaxpr_literals(jaxpr)):
    prefetch(value)
  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to evaluate their arguments.
  if not jaxpr.eqns:
    return XlaComputation(
        name, None, True, None, jaxpr=jaxpr, consts=consts, device=device,
        in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    # Long argument lists are summarised by their length only.
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for args {abstract_args}."
    logging.log(log_priority, msg)

  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
         "jit-of-pmap can lead to inefficient data movement, as the outer jit "
         "does not preserve sharded data representations and instead collects "
         "input and output arrays onto a single device. "
         "Consider removing the outer jit unless you know what you're doing. "
         "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  # pass long arg lists as tuple for TPU
  tuple_args = len(abstract_args) > 100
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = xla.new_name_stack(xla.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module: Union[str, xc.XlaComputation]
  module_name = f"jit_{fun.__name__}"
  # Pick the lowering path based on the `jax_enable_mlir` config flag.
  if config.jax_enable_mlir:
    module = mlir.lower_jaxpr_to_module(
        module_name, closed_jaxpr, backend.platform,
        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
  else:
    module = xla.lower_jaxpr_to_xla_module(
        module_name, closed_jaxpr, backend.platform, axis_env,
        name_stack, tuple_args, donated_invars, replicated_args=None,
        arg_partitions=None, out_partitions=None)
  return XlaComputation(
      name, module, False, donated_invars, nreps=nreps, device=device,
      backend=backend, tuple_args=tuple_args, in_avals=abstract_args,
      out_avals=out_avals, kept_var_idx=kept_var_idx)
Exemple #4
0
def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    *,
    public: bool = False,
    replace_units_with_dummy: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None
) -> FuncOpType:
    """Builds an IR function for a closed jaxpr and registers it in the module.

    An MLIR context, location and insertion point are assumed to be active.

    Args:
      ctx: the module lowering context; supplies the insertion point, the
        symbol table and the name stack.
      name: requested function name. The function is inserted into the symbol
        table, so reusing a name is fine.
      jaxpr: the closed jaxpr to lower.
      public: selects "public" rather than "private" symbol visibility.
      replace_units_with_dummy: when set, unit arguments and results appear in
        the signature as bool arrays of shape ().
      replace_tokens_with_dummy: when set, token arguments and results appear
        in the signature as bool arrays of shape ().
      replicated_args: optional per-argument flags; flagged arguments get the
        "mhlo.is_same_data_across_replicas" attribute.
      arg_shardings: optional per-argument shardings.
      result_shardings: optional per-result shardings.
      use_sharding_annotations: if True, shardings are attached as
        "mhlo.sharding" attributes on parameters and results; if False, values
        are passed through `wrap_with_sharding_op` instead.
        TODO(b/228598865): remove this option when mhlo.sharding annotations
        are propagated on non-entry functions during MHLO->HLO conversion.
      input_output_aliases: optional per-argument output index (or None) that
        the argument should alias.

    Returns:
      The created function op.
    """
    def _is_dummied(aval):
        # Truthy when `aval` is a unit/token the caller asked to swap out.
        return ((replace_units_with_dummy and aval is core.abstract_unit)
                or (replace_tokens_with_dummy
                    and aval is core.abstract_token))

    def _ir_types_for(aval):
        if _is_dummied(aval):
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    def _per_ir_type(per_value, types_per_value):
        # Repeat each per-value entry once for every IR type of that value.
        return util.flatten([[entry] * len(types)
                             for entry, types in zip(per_value,
                                                     types_per_value)])

    input_types = map(_ir_types_for, jaxpr.in_avals)
    output_types = map(_ir_types_for, jaxpr.out_avals)
    flat_input_types = util.flatten(input_types)
    flat_output_types = util.flatten(output_types)

    func_op = FuncOp(
        name, ir.FunctionType.get(flat_input_types, flat_output_types),
        ip=ctx.ip)
    visibility = "public" if public else "private"
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(visibility)
    ctx.symbol_table.insert(func_op)

    flat_arg_shardings = (None if arg_shardings is None else _per_ir_type(
        arg_shardings, input_types))
    flat_result_shardings = (None if result_shardings is None else
                             _per_ir_type(result_shardings, output_types))

    # Argument attributes are only emitted when at least one source exists.
    needs_arg_attrs = not (replicated_args is None
                           and flat_arg_shardings is None
                           and input_output_aliases is None)
    if needs_arg_attrs:
        arg_attrs: List[Dict[str, ir.Attribute]] = [
            {} for _ in flat_input_types
        ]

        if replicated_args is not None:
            flat_replicated = _per_ir_type(replicated_args, input_types)
            for attrs, is_replicated in zip(arg_attrs, flat_replicated):
                if is_replicated:
                    attrs["mhlo.is_same_data_across_replicas"] = (
                        ir.UnitAttr.get())

        if use_sharding_annotations and flat_arg_shardings is not None:
            for attrs, sharding in zip(arg_attrs, flat_arg_shardings):
                if sharding is None:
                    continue
                attrs["mhlo.sharding"] = ir.StringAttr.get(
                    sharding.SerializeToString())

        if input_output_aliases is not None:
            # Flat output positions, grouped per output value.
            output_ids = util.unflatten(list(range(len(flat_output_types))),
                                        map(len, output_types))
            aliases: List[Optional[int]] = []
            for types, alias in zip(input_types, input_output_aliases):
                aliases.extend([None] * len(types) if alias is None else
                               output_ids[alias])

            for attrs, alias in zip(arg_attrs, aliases):
                if alias is not None:
                    attrs["tf.aliasing_output"] = i32_attr(alias)

        func_op.arg_attrs = ir.ArrayAttr.get(
            [ir.DictAttr.get(attrs) for attrs in arg_attrs])

    if use_sharding_annotations and flat_result_shardings is not None:
        result_attrs = []
        for sharding in flat_result_shardings:
            if sharding is None:
                entry = {}
            else:
                entry = {
                    "mhlo.sharding":
                    ir.StringAttr.get(sharding.SerializeToString())
                }
            result_attrs.append(ir.DictAttr.get(entry))
        func_op.result_attrs = ir.ArrayAttr.get(result_attrs)

    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        flat_args = entry_block.arguments
        if flat_arg_shardings is not None and not use_sharding_annotations:
            flat_args = map(wrap_with_sharding_op, flat_args,
                            flat_arg_shardings)

        grouped_args = util.unflatten(flat_args, map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, grouped_args):
            # Dummy parameters are not forwarded: units become empty groups
            # and tokens are freshly created.
            if replace_units_with_dummy and aval is core.abstract_unit:
                arg = []
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                arg = mhlo.CreateTokenOp(mhlo.TokenType.get()).results
            args.append(arg)

        callee_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'jit')))
        const_vals = map(ir_constants, jaxpr.consts)
        out_vals = jaxpr_subcomp(callee_ctx, jaxpr.jaxpr, const_vals, *args)

        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if _is_dummied(aval):
                out = ir_constants(np.zeros((), np.bool_))
            outs.append(out)

        flat_outputs = util.flatten(outs)
        if (flat_result_shardings is not None
                and not use_sharding_annotations):
            flat_outputs = map(wrap_with_sharding_op, flat_outputs,
                               flat_result_shardings)

        func_dialect.ReturnOp(flat_outputs)

    return func_op
Exemple #5
0
def lower_jaxpr_to_fun(ctx: LoweringContext,
                       name: str,
                       jaxpr: core.ClosedJaxpr,
                       *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
    """Builds an IR function for a closed jaxpr and returns its symbol name.

    An MLIR context, location and insertion point are assumed to be active.

    Args:
      ctx: the lowering context; supplies the insertion point, the symbol
        table and the name stack.
      name: requested function name. The symbol table decides the final
        name, so reusing a name is fine.
      jaxpr: the closed jaxpr to lower.
      public: selects "public" rather than "private" symbol visibility.
      replace_units_with_dummy: when set, unit arguments and results appear in
        the signature as bool arrays of shape ().
      replace_tokens_with_dummy: when set, token arguments and results appear
        in the signature as bool arrays of shape ().

    Returns:
      The name under which the symbol table registered the function.
    """
    def _is_dummied(aval):
        # Truthy when `aval` is a unit/token the caller asked to swap out.
        return ((replace_units_with_dummy and aval is core.abstract_unit)
                or (replace_tokens_with_dummy
                    and aval is core.abstract_token))

    def _ir_types_for(aval):
        if _is_dummied(aval):
            aval = core.ShapedArray((), np.dtype(np.bool_))
        return aval_to_ir_types(aval)

    input_types = map(_ir_types_for, jaxpr.in_avals)
    output_types = map(_ir_types_for, jaxpr.out_avals)
    signature = ir.FunctionType.get(util.flatten(input_types),
                                    util.flatten(output_types))
    func_op = builtin.FuncOp(name, signature, ip=ctx.ip)
    visibility = "public" if public else "private"
    func_op.attributes["sym_visibility"] = ir.StringAttr.get(visibility)
    inserted = ctx.symbol_table.insert(func_op)
    symbol_name = ir.StringAttr(inserted).value

    entry_block = func_op.add_entry_block()
    with ir.InsertionPoint(entry_block):
        grouped_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
        args: List[List[ir.Value]] = []
        for aval, arg in zip(jaxpr.in_avals, grouped_args):
            # Dummy parameters are not forwarded: units become empty groups
            # and tokens are freshly created.
            if replace_units_with_dummy and aval is core.abstract_unit:
                arg = []
            elif replace_tokens_with_dummy and aval is core.abstract_token:
                arg = mhlo.CreateTokenOp(mhlo.TokenType.get()).results
            args.append(arg)

        callee_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'jit')))
        const_vals = map(ir_constants, jaxpr.consts)
        out_vals = jaxpr_subcomp(callee_ctx, jaxpr.jaxpr, const_vals, *args)

        outs = []
        for aval, out in zip(jaxpr.out_avals, out_vals):
            if _is_dummied(aval):
                out = ir_constants(np.zeros((), np.bool_))
            outs.append(out)
        std.ReturnOp(util.flatten(outs))

    return symbol_name