Example No. 1
File: ad.py Project: jbampton/jax
 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
   assert call_primitive.multiple_results
   primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
   nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
   nz_tangents = [type(t) is not Zero for t in tangents]
   if 'name' in params and not config.jax_experimental_name_stack:
     params = dict(params, name=wrap_name(params['name'], 'jvp'))
   f_jvp = jvp_subtrace(f, self.main)
   f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
   if isinstance(call_primitive, core.MapPrimitive):
     in_axes = params['in_axes']
     tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
     out_axes_thunk = params['out_axes_thunk']
     # The new thunk depends deterministically on the old thunk and the wrapped function.
     # Any caching already has to include the wrapped function as part of the key, so we
     # only use the previous thunk for equality checks.
     # NOTE: This assumes that the output tangents being zero is a deterministic
     #       function of which input tangents were zero.
     @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
     def new_out_axes_thunk():
       out_axes = out_axes_thunk()
       return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz))
     params = dict(params,
                   in_axes=(*in_axes, *tangent_in_axes),
                   out_axes_thunk=new_out_axes_thunk)
   f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
   update_params = call_param_updaters.get(call_primitive)
   new_params = update_params(params, nz_tangents) if update_params else params
   result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
   primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
   return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
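
For orientation, a minimal user-level sketch of the path that reaches this process_call: taking a JVP through a call primitive such as jit (the function f below is illustrative, not from the original):

import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x) * x)
# Differentiating through the jit call routes primals and tangents
# through JVPTrace.process_call above.
primal_out, tangent_out = jax.jvp(f, (1.0,), (1.0,))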
Example No. 2
def remat_impl(*args,
               call_jaxpr: Optional[core.Jaxpr] = None,
               jaxpr: Optional[core.Jaxpr] = None,
               prevent_cse: bool,
               differentiated: bool,
               policy,
               is_gpu_platform: bool = False,
               concrete: bool = False,
               name: str = "checkpoint"):
    # Support either "jaxpr" (for remat2) or "call_jaxpr" (for remat).
    # `name` is not passed for remat2; it defaults to "checkpoint".
    # TODO: remove call_jaxpr once we drop the remat call primitive
    if jaxpr is None:
        jaxpr = call_jaxpr
    assert jaxpr is not None
    assert not jaxpr.constvars

    del concrete, policy  # Unused.
    if differentiated and prevent_cse:
        if config.jax_remat_opt_barrier:
            translation_rule = _remat_translation_using_opt_barrier
        elif is_gpu_platform:
            translation_rule = _remat_translation_using_while
        else:
            translation_rule = _remat_translation_using_cond
    else:
        translation_rule = lambda *args, jaxpr: core.eval_jaxpr(
            jaxpr, (), *args)

    return jax.named_call(translation_rule,
                          name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)
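
A minimal user-level sketch of the transform that lands in remat_impl above: rematerialization via jax.checkpoint (also exposed as jax.remat), whose residuals are recomputed on the backward pass (the function g below is illustrative):

import jax
import jax.numpy as jnp

@jax.checkpoint  # a.k.a. jax.remat
def g(x):
    return jnp.sin(jnp.sin(x))

y, g_vjp = jax.vjp(g, 2.0)
(dx,) = g_vjp(1.0)  # the backward pass re-evaluates g's jaxpr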
Example No. 3
def _xla_call_translation_rule(ctx,
                               avals_in,
                               avals_out,
                               *in_nodes,
                               name,
                               backend=None,
                               call_jaxpr,
                               donated_invars,
                               inline=None,
                               device=None):
    del device, donated_invars, inline  # Ignored.
    c = ctx.builder
    check_backend_matches(backend, ctx.platform)
    subc = xc.XlaBuilder(f"jit_{name}")
    args = [parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
    sub_ctx = ctx.replace(builder=subc,
                          name_stack=extend_name_stack(ctx.name_stack,
                                                       wrap_name(name, 'jit')))
    out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)

    if len(out_nodes) == 1:
        subc = subc.Build(out_nodes[0])
        return [xops.Call(c, subc, list(in_nodes))]
    else:
        subc = subc.Build(xops.Tuple(subc, out_nodes))
        return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
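
One way to inspect what translation rules like this one emit is the long-standing jax.xla_computation helper; a hedged sketch (the exact structure and naming of the resulting HLO vary across JAX versions, and f is illustrative):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) + 1.0

# Trace a jitted function and print the XLA HLO produced by the lowering path.
print(jax.xla_computation(jax.jit(f))(1.0).as_hlo_text())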
Example No. 4
def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
    """Lower remat to a Conditional which always returns true. This:
    1. Circumvents common subexpression elimination.
    2. In common case of `jax.grad(jax.remat(f))`, ensures the remat blocks
       occur after the primal blocks, because cotangent is an input to the
       Conditional."""
    # Fake condition which always selects True branch.
    c = ctx.builder
    rng = xops.RngUniform(xops.Constant(c, np.array(0, dtype=np.float32)),
                          xops.Constant(c, np.array(1, dtype=np.float32)),
                          xc.Shape.array_shape(xc.PrimitiveType.F32, []))
    pred = xops.Lt(rng, xops.Constant(c, np.array(2, dtype=np.float32)))

    true_op = xops.Tuple(c, in_nodes)
    remat_subc = xc.XlaBuilder("remat_call_subcomputation")
    input_op = parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
    args = xla_destructure(remat_subc, input_op)
    sub_ctx = ctx.replace(builder=remat_subc,
                          name_stack=extend_name_stack(
                              ctx.name_stack, wrap_name(name, 'remat')))
    out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
    out_node_shapes = [remat_subc.get_shape(o) for o in out_nodes]
    remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes))

    false_op = true_op
    dummy_subc = xc.XlaBuilder("remat_call_dummy_subcomputation")
    parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
    out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
    dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))

    return xla_destructure(
        c, xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc))
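
The common case named in the docstring, as a user-level sketch (the function is illustrative):

import jax
import jax.numpy as jnp

f = jax.remat(lambda x: jnp.sin(jnp.sin(x)))
# The cotangent produced by grad is an input to the Conditional built above,
# which forces the rematerialized blocks to run after the primal blocks.
dfdx = jax.grad(f)(1.0)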
Example No. 5
 def process_call(self, call_primitive, f, tracers, params):
     assert call_primitive.multiple_results
     if config.jax_experimental_name_stack:
         params = dict(params, name=params.get('name', f.__name__))
     else:
         params = dict(params,
                       name=wrap_name(params.get('name', f.__name__),
                                      'vmap'))
     vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
     if all(bdim is not_mapped for bdim in dims):
         return call_primitive.bind(f, *vals, **params)
     else:
         f_, dims_out = batch_subtrace(f, self.main, dims)
         ax_size, = {
             x.shape[d]
             for x, d in zip(vals, dims) if d is not not_mapped
         }
         f_ = _update_annotation(f_, f.in_type, ax_size, self.axis_name,
                                 dims)
         vals_out = call_primitive.bind(f_, *vals, **params)
         src = source_info_util.current()
         return [
             BatchTracer(self, v, d, src)
             for v, d in zip(vals_out, dims_out())
         ]
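
A user-level sketch of what drives this process_call: vmap over a jit call primitive, which takes the batch_subtrace branch because the input is batched (the function f below is illustrative):

import jax
import jax.numpy as jnp

f = jax.jit(lambda x: x * 2.0 + 1.0)
ys = jax.vmap(f)(jnp.arange(4.0))  # batched input -> batch_subtrace path above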
Example No. 6
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined calls
    # shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
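
A standalone sketch (plain Python, illustrative values) of the padding step above that aligns in_parts with any extra leading constant nodes:

in_nodes = ("const", "x", "y")   # one extra leading constant node
in_parts = ((2, 1), None)        # partitionings were given only for x and y
num_extra_nodes = len(in_nodes) - len(in_parts)
assert num_extra_nodes >= 0
in_parts = (None,) * num_extra_nodes + in_parts
assert in_parts == (None, (2, 1), None)  # the constant is replicated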
Example No. 7
def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
    """Lower remat to a single iteration while loop."""
    c = ctx.builder
    # Dummy subc for getting subcomp shapes.
    dummy_inputs = xops.Tuple(c, in_nodes)
    dummy_subc = xc.XlaBuilder("remat_dummy_subcomputation")
    dummy_input_op = parameter(dummy_subc,
                               0,
                               c.get_shape(dummy_inputs),
                               replicated=[])
    dummy_args = xla_destructure(dummy_subc, dummy_input_op)
    dummy_ctx = ctx.replace(builder=dummy_subc,
                            name_stack=extend_name_stack(
                                ctx.name_stack, wrap_name(name, 'remat')))
    dummy_subcomp_outs = jaxpr_subcomp(dummy_ctx, call_jaxpr, (), *dummy_args)
    out_node_shapes = [dummy_subc.get_shape(o) for o in dummy_subcomp_outs]

    i_init = xops.Constant(c, np.array(0, dtype=np.int32))
    zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
    inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)

    cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
    input_op = parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
    i = xops.GetTupleElement(input_op, 0)
    rng = xops.RngUniform(
        xops.Constant(cond_subc, np.array(1, dtype=np.int32)),
        xops.Constant(cond_subc, np.array(2, dtype=np.int32)),
        xc.Shape.array_shape(xc.PrimitiveType.S32, []))
    cond_subc = cond_subc.build(xops.Lt(i, rng))

    body_subc = xc.XlaBuilder("remat_body_subcomputation")
    input_op = parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
    i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes) + 1]
    i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32)))
    body_ctx = ctx.replace(builder=body_subc,
                           name_stack=extend_name_stack(
                               ctx.name_stack, wrap_name(name, 'remat')))
    subcomp_outs = jaxpr_subcomp(body_ctx, call_jaxpr, (), *args)
    out_nodes = [i_next] + args + list(subcomp_outs)
    body_subc = body_subc.build(xops.Tuple(body_subc, out_nodes))
    outs = xops.While(cond_subc, body_subc, inputs)
    return xla_destructure(c, outs)[len(in_nodes) + 1:]
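
The loop structure built above, sketched in plain Python (illustrative stand-ins; the real carry is (i, *inputs, *zeros_like_outs) and the body re-evaluates call_jaxpr):

import random

i, inputs, outputs = 0, ["x"], [0.0]
# The bound is an integer drawn from [1, 2), so it is always 1, but it is
# opaque to the compiler: the body runs exactly once and cannot be CSE'd away.
while i < random.randrange(1, 2):
    outputs = ["f(x) recomputed"]  # stands in for re-running call_jaxpr
    i += 1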
Example No. 8
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
    all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
    fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                              reduce_axes)
    fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
    new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
    update_params = call_transpose_param_updaters.get(primitive)
    if update_params:
        new_params = update_params(new_params, map(is_undefined_primal, args),
                                   [type(x) is not Zero for x in ct])
    out_flat = primitive.bind(fun, *all_args, **new_params)
    return tree_unflatten(out_tree(), out_flat)
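
A user-level sketch of when this call_transpose runs: transposing (via vjp) a linear function wrapped in a call primitive such as jit (the function f below is illustrative):

import jax

f = jax.jit(lambda x: 3.0 * x)
y, f_vjp = jax.vjp(f, 2.0)
(ct,) = f_vjp(1.0)  # transposing the jit call re-runs the jaxpr backwards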
Example No. 9
 def process_call(self, call_primitive, f, tracers, params):
   assert call_primitive.multiple_results
   heads, tails = unzip2((t.head, t.tail) for t in tracers)
   nonzero_tails, in_tree_def = tree_flatten(tails)
   f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.main),
                                         len(heads), in_tree_def)
   name = params.get('name', f.__name__)
   new_params = dict(params, name=wrap_name(name, 'doubledouble'),
                     donated_invars=(False,) * (len(heads) + len(nonzero_tails)))
   result = call_primitive.bind(f_double, *heads, *nonzero_tails, **new_params)
   heads_out, tails_out = tree_unflatten(out_tree_def(), result)
   return [DoublingTracer(self, h, t) for h, t in zip(heads_out, tails_out)]
Example No. 10
 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
     assert call_primitive.multiple_results
     params = dict(params,
                   name=wrap_name(params.get('name', f.__name__), 'vmap'))
     vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
     if all(bdim is not_mapped for bdim in dims):
         return call_primitive.bind(f, *vals, **params)
     else:
         f, dims_out = batch_subtrace(f, self.main, dims)
         vals_out = call_primitive.bind(f, *vals, **params)
         return [
             BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())
         ]
Example No. 11
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  if not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  if config.jax_dynamic_shapes:
    in_type = [(core.raise_to_shaped(core.get_aval(x)), True) for x in all_args]
    fun = lu.annotate(fun, tuple(in_type))
  out_flat = primitive.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
Example No. 12
File: ad.py Project: John1Tang/jax
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
    all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
    fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                              reduce_axes, False)
    fun, nz_arg_cts = nonzero_outputs(fun)
    fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
    # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
    in_axes, out_axes = params['in_axes'], params['out_axes']
    new_in_axes = (*[
        axis for axis, x in zip(in_axes, args) if not is_undefined_primal(x)
    ], *[axis for axis, x in zip(out_axes, ct) if type(x) is not Zero])
    # The interim strategy we use below (until avals-with-names) only works
    # when all outputs are mapped.
    assert all(out_axis is not None for out_axis in out_axes), out_axes
    # NOTE: This assumes that the output cotangents being zero is a deterministic
    #       function of which input cotangents were zero.
    @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero
                                                  for c in ct)))
    def out_axes_thunk():
        return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts())
                     if nz)

    new_params = dict(params,
                      name=wrap_name(params['name'], 'transpose'),
                      in_axes=new_in_axes,
                      out_axes_thunk=out_axes_thunk)
    del new_params['out_axes']
    update_params = call_transpose_param_updaters.get(primitive)
    if update_params:
        new_params = update_params(new_params, map(is_undefined_primal, args),
                                   [type(x) is not Zero for x in ct])
    out_flat = primitive.bind(fun, *all_args, **new_params)
    arg_cts = tree_unflatten(out_tree(), out_flat)

    # The freevars are being fanned out (not mapped). During transpose the
    # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
    assert len(in_axes) == len(arg_cts)

    def unmap_zero(zero, in_axis):
        return (zero if in_axis is None else Zero(
            core.unmapped_aval(params['axis_size'], params['axis_name'],
                               in_axis, zero.aval)))

    arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
               arg_ct if in_axis is not None else arg_ct.sum(0)
               for arg_ct, in_axis in zip(arg_cts, in_axes))
    return tuple(arg_cts)
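
A standalone sketch (plain Python, illustrative values) of how new_in_axes is assembled above: keep the axes of defined primal arguments, then append the axes of nonzero output cotangents:

in_axes = (0, 0, None)
arg_is_undefined = (False, True, True)  # undefined primals mark tangent slots
out_axes = (0, 0)
ct_is_zero = (False, True)

new_in_axes = (*[ax for ax, undef in zip(in_axes, arg_is_undefined) if not undef],
               *[ax for ax, z in zip(out_axes, ct_is_zero) if not z])
assert new_in_axes == (0, 0)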
Example No. 13
 def process_call(self, call_primitive, f, tracers, params):
   assert call_primitive.multiple_results
   params = dict(params, name=wrap_name(params.get('name', f.__name__), 'mask'))
   vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
   if not any(is_polymorphic(s) for s in shapes):
     return call_primitive.bind(f, *vals, **params)
   else:
     logical_env, padded_env = shape_envs
     env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
     logical_env_vals = tuple(logical_env[k] for k in env_keys)
     # Make padded_env hashable
     padded_env = (env_keys, padded_env_vals)
     f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env)
     if 'donated_invars' in params:
       params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
                                             params['donated_invars']))
     vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
     return [MaskTracer(self, v, s) for v, s in zip(vals_out, shapes_out())]
Example No. 14
 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
     assert call_primitive.multiple_results
     if config.jax_experimental_name_stack:
         params = dict(params, name=params.get('name', f.__name__))
     else:
         params = dict(params,
                       name=wrap_name(params.get('name', f.__name__),
                                      'vmap'))
     vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
     if all(bdim is not_mapped for bdim in dims):
         return call_primitive.bind(f, *vals, **params)
     else:
         f, dims_out = batch_subtrace(f, self.main, dims)
         vals_out = call_primitive.bind(f, *vals, **params)
         src = source_info_util.current()
         return [
             BatchTracer(self, v, d, src)
             for v, d in zip(vals_out, dims_out())
         ]
Example No. 15
def _sharded_jit_lowering(ctx, *in_nodes, in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
    # We assume any extra leading in_nodes are constants and replicate them.
    num_extra_nodes = len(in_nodes) - len(in_parts)
    assert num_extra_nodes >= 0
    in_parts = (None, ) * num_extra_nodes + in_parts

    args = []
    for ns, sharding in safe_zip(
            safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
        if sharding is not None:
            args.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            args.append(ns)

    sub_ctx = ctx.module_context.replace(
        name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
    fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                                 core.ClosedJaxpr(call_jaxpr, ()))

    output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    call = std.CallOp(flat_output_types,
                      ir.FlatSymbolRefAttr.get(fn.name.value),
                      mlir.flatten_lowering_ir_args(args))
    out_nodes = util.unflatten(call.results, safe_map(len, output_types))

    out_parts = out_parts_thunk()
    outputs = []
    for ns, sharding in safe_zip(out_nodes, out_parts):
        if sharding is not None:
            outputs.append([
                mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
                for n in ns
            ])
        else:
            outputs.append(ns)
    return outputs
Example No. 16
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    axis_env = xla.AxisEnv(nrep, (), ())
    unordered_effects = [
        eff for eff in jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in jaxpr.effects if eff in core.ordered_effects
    ]
    module, _ = mlir.lower_jaxpr_to_module(
        f"spjit_{fun.__name__}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects,
        ordered_effects,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env),
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
        donated_args=[False] * len(in_parts),
        arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
        result_shardings=safe_map(xla.sharding_to_proto, out_parts))
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
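
A standalone NumPy sketch (illustrative values) of the device_assignment reshape above: one row per replica, nparts columns per row:

import numpy as np

devices = [0, 1, 2, 3]  # stand-ins for device objects / ids
nparts = 2
device_assignment = np.array([[d for d in devices]])
device_assignment = np.reshape(device_assignment, (-1, nparts))
# array([[0, 1],
#        [2, 3]])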
Example No. 17
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       donated_invars, always_lower: bool, keep_unused: bool,
                       *arg_specs):
    """Lower into XLA.

  Args:
    always_lower: If `True`, even trivial programs (not doing any computation
      such as lambda x: x) will be lowered into an XLA program.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
  """
    if device is not None and backend is not None:
        raise ValueError("can't specify both a device and a backend for jit, "
                         "got device={} and backend={}".format(
                             device, backend))
    abstract_args, arg_devices = util.unzip2(arg_specs)
    if fun.in_type is not None:
        abstract_args, which_explicit = util.unzip2(fun.in_type)
    else:
        which_explicit = None
    with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                          "for jit in {elapsed_time} sec"):
        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
            fun, abstract_args, pe.debug_info_final(fun, "jit"),
            which_explicit)
    if any(isinstance(c, core.Tracer) for c in consts):
        raise UnexpectedTracerError("Encountered an unexpected tracer.")
    # TODO(mattjj): handle argument pruning w/ dynamic shapes
    if fun.in_type is None and not keep_unused:
        jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
        consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
        abstract_args, arg_devices = util.unzip2(
            [a for i, a in enumerate(arg_specs) if i in kept_var_idx])
        donated_invars = [
            x for i, x in enumerate(donated_invars) if i in kept_var_idx
        ]
        del kept_const_idx
    else:
        kept_var_idx = set(range(len(abstract_args)))
    map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
    jaxpr = apply_outfeed_rewriter(jaxpr)

    nreps = jaxpr_replicas(jaxpr)
    device = _xla_callable_device(nreps, backend, device, arg_devices)
    backend = xb.get_device_backend(device) if device else xb.get_backend(
        backend)

    if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr)
            and not _backend_supports_unbounded_dynamic_shapes(backend)):
        jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

    # Computations that only produce constants and/or only rearrange their inputs,
    # which are often produced from partial evaluation, don't need compilation,
    # and don't need to evaluate their arguments.
    if not jaxpr.eqns and not always_lower:
        return XlaComputation(name,
                              None,
                              True,
                              None,
                              None,
                              jaxpr=jaxpr,
                              consts=consts,
                              device=device,
                              in_avals=abstract_args,
                              out_avals=out_avals,
                              has_unordered_effects=False,
                              ordered_effects=[],
                              kept_var_idx=kept_var_idx,
                              keepalive=None)

    if not _on_exit:
        log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
        if len(abstract_args) > 10:
            msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
        else:
            msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
        logging.log(log_priority, msg)

    if nreps > 1:
        warnings.warn(
            f"The jitted function {name} includes a pmap. Using "
            "jit-of-pmap can lead to inefficient data movement, as the outer jit "
            "does not preserve sharded data representations and instead collects "
            "input and output arrays onto a single device. "
            "Consider removing the outer jit unless you know what you're doing. "
            "See https://github.com/google/jax/issues/2926.")

    if nreps > xb.device_count(backend):
        raise ValueError(
            f"compiling computation `{name}` that requires {nreps} replicas, but "
            f"only {xb.device_count(backend)} XLA devices are available.")

    if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
        raise NotImplementedError(
            "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
            "extra data movement anyway, so maybe you don't want it after all)."
        )

    # pass long arg lists as tuple for TPU
    tuple_args = len(abstract_args) > 100
    axis_env = xla.AxisEnv(nreps, (), ())
    name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
    closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
    module_name = f"jit_{fun.__name__}"
    unordered_effects = [
        eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in closed_jaxpr.effects if eff in core.ordered_effects
    ]
    module, keepalive = mlir.lower_jaxpr_to_module(
        module_name, closed_jaxpr,
        unordered_effects, ordered_effects, backend.platform,
        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
    return XlaComputation(name,
                          module,
                          False,
                          donated_invars,
                          which_explicit,
                          nreps=nreps,
                          device=device,
                          backend=backend,
                          tuple_args=tuple_args,
                          in_avals=abstract_args,
                          out_avals=out_avals,
                          has_unordered_effects=bool(unordered_effects),
                          ordered_effects=ordered_effects,
                          kept_var_idx=kept_var_idx,
                          keepalive=keepalive)
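
A user-level sketch of the ahead-of-time path that drives lower_xla_callable, assuming the .lower()/.compile() staging API available in this JAX version (the function is illustrative):

import jax
import jax.numpy as jnp

lowered = jax.jit(lambda x: x + 1.0).lower(jnp.float32(0.0))
compiled = lowered.compile()  # backend-compiles the module lowered above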
Example No. 18
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform
    if platform not in ["tpu", "gpu"]:
        # TODO(skye): fall back to regular jit?
        raise ValueError(f"sharded_jit not supported for {platform}")

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
    xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
    xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
    axis_env = xla.AxisEnv(nrep, (), ())
    ctx = xla.TranslationContext(
        c, platform, axis_env,
        extend_name_stack(wrap_name(name, "sharded_jit")))
    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
    out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
    built = c.Build(out_tuple)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d.id for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)