Example #1
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
    backend = kwargs.pop('backend', None)
    cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
        kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
    cond_consts, body_consts, init_vals = split_list(
        args, [cond_nconsts, body_nconsts])
    batched = bool(cond_jaxpr.out_avals[0].shape)

    # Since jaxprs don't have tuples and have multiple return values, but we need
    # the HLO While loop to take a single tuple input and output a single boolean
    # (for the cond computation) or a single tuple output (for the body
    # computation), we build XLA computations that handle the tuple munging before
    # generating a Call into the computations formed from the jaxprs.

    init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

    cond_c = xb.make_computation_builder("cond_computation")
    cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
    cond_carry_elts = [
        cond_c.GetTupleElement(cond_carry, i) for i in range(len(args))
    ]
    x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
    pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
                              _map(cond_c.Constant, cond_jaxpr.literals), (),
                              *(x + z))
    if batched:
        scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
        or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
        pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                             list(range(cond_jaxpr.out_avals[0].ndim)))

    body_c = xb.make_computation_builder("body_computation")
    body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
    body_carry_elts = [
        body_c.GetTupleElement(body_carry, i) for i in range(len(args))
    ]
    x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
    new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
                              _map(body_c.Constant, body_jaxpr.literals), (),
                              *(y + z))
    if batched:
        body_pred, = xla.jaxpr_subcomp(
            body_c, cond_jaxpr.jaxpr, backend, axis_env,
            _map(body_c.Constant, cond_jaxpr.literals), (), *(x + z))
        new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
    assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z)  # no broadcast
    new_carry = body_c.Tuple(*itertools.chain(x, y, new_z))

    ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
    ans_elts = [c.GetTupleElement(ans, i) for i in range(len(args))]
    _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
    return c.Tuple(*z)
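
For context, a minimal usage sketch (my addition; it assumes, as the cond_jaxpr/body_jaxpr parameters suggest, that this is the lowering rule for lax.while_loop): the rule packs the cond constants, body constants, and loop state into a single tuple carry for the HLO While, then unpacks only the loop-state slots from the result.

import jax
from jax import lax

def count_to_ten(i0):
    # cond_fun and body_fun become cond_jaxpr and body_jaxpr in the rule above.
    return lax.while_loop(lambda i: i < 10, lambda i: i + 1, i0)

print(jax.jit(count_to_ten)(0))  # 10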
Example #2
def make_computation(name, jaxpr, op_shape):
  # Build an XLA computation that unpacks its single tuple parameter, runs the
  # jaxpr on the unpacked elements, and returns the outputs repacked as a tuple.
  # (`backend`, `axis_env`, and `_map` are free variables from the enclosing
  # scope; this is a nested helper excerpted from a larger translation rule.)
  c = xb.make_computation_builder(name)
  op = c.ParameterWithShape(op_shape)
  ops = [c.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
  outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
                           _map(c.Constant, jaxpr.literals), (), *ops)
  return c.Build(c.Tuple(*outs))
Example #3
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined calls
    # shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
Example #4
def remat_translation(ctx, avals_in, avals_out, *in_nodes,
                      jaxpr, prevent_cse, differentiated, policy):
  del policy  # Unused.
  if differentiated and prevent_cse:
    if ctx.platform == "gpu":
      return xla._remat_using_while(ctx, in_nodes, "checkpoint", jaxpr)
    else:
      return xla._remat_using_cond(ctx, in_nodes, "checkpoint", jaxpr)
  else:
    return xla.jaxpr_subcomp(ctx, jaxpr, (), *in_nodes)
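
A minimal usage sketch (my addition, not from the source): jax.checkpoint (also exposed as jax.remat) is the user-facing entry point that reaches this rule. With the default prevent_cse=True, the differentiated case rematerializes under a While on GPU and a Cond elsewhere; otherwise the jaxpr is simply inlined via xla.jaxpr_subcomp.

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(jnp.sin(x))

# Intermediates of f are recomputed on the backward pass instead of being stored.
grad_f = jax.grad(jax.checkpoint(f))
print(grad_f(2.0))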
Example #5
def _named_call_translation_rule(comp_builder: 'xla.xb._JaxComputationBuilder',
                                 axis_env: xla.AxisEnv,
                                 in_nodes: 'Sequence[xla.xc._xla.XlaOp]',
                                 name_stack: str,
                                 backend: Optional[Any],
                                 name: str,
                                 call_jaxpr: core.Jaxpr):
  """Compile and add a custom name to the XLA metadata."""
  subcomp_builder = xla.xb.make_computation_builder(f'named_call_{name}')
  args = [xla.xb.parameter(subcomp_builder, i, comp_builder.GetShape(n))
          for i, n in enumerate(in_nodes)]
  out_nodes = xla.jaxpr_subcomp(subcomp_builder, call_jaxpr,
                                backend, axis_env, (),
                                jax.util.extend_name_stack(name_stack, name),
                                *args)
  subcomp = subcomp_builder.Build(xla.xops.Tuple(subcomp_builder, out_nodes))
  return xla.xops.Call(comp_builder, subcomp, list(in_nodes))
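
A minimal usage sketch (my addition; it assumes the public jax.named_call wrapper, which lowers through a rule of this shape): the wrapped function's ops are grouped under the given name in the XLA metadata, which makes them easier to find in profiles.

import jax
import jax.numpy as jnp

def dense(x):
  return jnp.tanh(x @ x.T)

named_dense = jax.named_call(dense, name='dense_block')

@jax.jit
def model(x):
  return named_dense(x)

print(model(jnp.ones((2, 3))))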
Example #6
File: named_call.py, Project: ykumards/flax
def _named_call_translation_rule(c,
                                 axis_env,
                                 in_nodes,
                                 name_stack,
                                 *,
                                 name='core_call',
                                 backend,
                                 call_jaxpr):
    subc = xla.xb.make_computation_builder(name)
    args = [
        xla.xb.parameter(subc, i, c.GetShape(n))
        for i, n in enumerate(in_nodes)
    ]
    out_nodes = xla.jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
                                  jax.util.extend_name_stack(name_stack, name),
                                  *args)
    subc = subc.Build(xla.xops.Tuple(subc, out_nodes))
    return xla.xops.Call(c, subc, list(in_nodes))
Example #7
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform
    if platform not in ["tpu", "gpu"]:
        # TODO(skye): fall back to regular jit?
        raise ValueError(f"sharded_jit not supported for {platform}")

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
    xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
    xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
    axis_env = xla.AxisEnv(nrep, (), ())
    ctx = xla.TranslationContext(
        c, platform, axis_env,
        extend_name_stack(wrap_name(name, "sharded_jit")))
    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
    out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
    built = c.Build(out_tuple)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d.id for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
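
A rough usage sketch (my addition; it assumes the experimental sharded_jit API of this era, roughly jax.interpreters.sharded_jit with pxla.PartitionSpec, so treat the import paths and exact signature as assumptions rather than a confirmed interface):

from jax.interpreters.sharded_jit import sharded_jit
from jax.interpreters.pxla import PartitionSpec as P
import jax.numpy as jnp

# Split the leading axis of the input and output across two partitions.
# Requires a TPU or GPU backend with at least two devices, per the platform
# and device-count checks in _sharded_callable above.
f = sharded_jit(lambda x: x * 2, in_parts=P(2, 1), out_parts=P(2, 1))
y = f(jnp.ones((8, 8)))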