Exemplo n.º 1
0
def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
    """Lower remat to an XLA Conditional whose predicate is always true.

    Routing the rematerialized computation through a Conditional:
    1. Circumvents common subexpression elimination.
    2. In the common case of `jax.grad(jax.remat(f))`, ensures the remat blocks
       occur after the primal blocks, because the cotangent is an input to the
       Conditional.

    Returns the destructured outputs of the Conditional (ops in `ctx.builder`).
    """
    builder = ctx.builder

    # Fake predicate: a uniform sample drawn between 0 and 1 is compared
    # against 2, so the true branch is always selected at runtime, but the
    # compiler cannot fold the condition away.
    low = xops.Constant(builder, np.array(0, dtype=np.float32))
    high = xops.Constant(builder, np.array(1, dtype=np.float32))
    scalar_f32 = xc.Shape.array_shape(xc.PrimitiveType.F32, [])
    sample = xops.RngUniform(low, high, scalar_f32)
    always_true = xops.Lt(
        sample, xops.Constant(builder, np.array(2, dtype=np.float32)))

    # Both branches receive the same tuple of operands.
    operands = xops.Tuple(builder, in_nodes)
    operand_shape = builder.get_shape(operands)

    # True branch: the actual rematerialized computation.
    remat_builder = xc.XlaBuilder("remat_call_subcomputation")
    packed = parameter(remat_builder, 0, operand_shape, replicated=[])
    unpacked = xla_destructure(remat_builder, packed)
    remat_ctx = ctx.replace(
        builder=remat_builder,
        name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))
    remat_outs = jaxpr_subcomp(remat_ctx, call_jaxpr, (), *unpacked)
    result_shapes = [remat_builder.get_shape(op) for op in remat_outs]
    true_comp = remat_builder.build(xops.Tuple(remat_builder, remat_outs))

    # False branch (never taken): zeros with the same shapes as the outputs.
    dummy_builder = xc.XlaBuilder("remat_call_dummy_subcomputation")
    parameter(dummy_builder, 0, operand_shape, replicated=[])
    zeros = [_zeros(dummy_builder, shape) for shape in result_shapes]
    false_comp = dummy_builder.build(xops.Tuple(dummy_builder, zeros))

    cond = xops.Conditional(always_true, operands, true_comp, operands,
                            false_comp)
    return xla_destructure(builder, cond)
Exemplo n.º 2
0
def _xla_call_translation_rule(ctx,
                               avals_in,
                               avals_out,
                               *in_nodes,
                               name,
                               backend=None,
                               call_jaxpr,
                               donated_invars,
                               inline=None,
                               device=None):
    """Translation rule for a jit call: lower `call_jaxpr` to an XLA Call.

    The body is lowered into a fresh sub-computation named ``jit_<name>`` and
    invoked from ``ctx.builder`` with ``in_nodes`` as operands. ``device``,
    ``donated_invars`` and ``inline`` are accepted but ignored; ``backend``
    is checked against ``ctx.platform``.

    Returns:
      A list of XLA ops in ``ctx.builder``, one per output of ``call_jaxpr``.
    """
    del device, donated_invars, inline  # Ignored.
    outer = ctx.builder
    check_backend_matches(backend, ctx.platform)

    body = xc.XlaBuilder(f"jit_{name}")
    params = []
    for idx, node in enumerate(in_nodes):
        params.append(parameter(body, idx, outer.get_shape(node)))
    body_ctx = ctx.replace(
        builder=body,
        name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
    results = jaxpr_subcomp(body_ctx, call_jaxpr, (), *params)

    operands = list(in_nodes)
    if len(results) != 1:
        computation = body.Build(xops.Tuple(body, results))
        return xla_destructure(outer, xops.Call(outer, computation, operands))
    # A single result is returned directly instead of wrapping it in a tuple.
    computation = body.Build(results[0])
    return [xops.Call(outer, computation, operands)]
Exemplo n.º 3
0
def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
                                  in_parts, out_parts_thunk, nparts, backend,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  """Lowers a sharded_jit call (old-style rule) to an XLA Call.

  Parameters of the ``sharded_jit_<name>`` sub-computation are annotated with
  ``in_parts`` and its results with ``out_parts_thunk()``. Leading
  ``in_nodes`` beyond ``len(in_parts)`` are assumed to be constants and get a
  ``None`` (replicated) sharding. ``nparts`` and the ``local_*`` arguments
  are not used here.

  Returns:
    The single XLA Call op (in ``c``) whose result is the output tuple.
  """
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # Extra leading inputs are assumed to be constants; replicate them.
  extra = len(in_nodes) - len(in_parts)
  assert extra >= 0
  arg_shardings = (None,) * extra + in_parts

  # xb.set_sharding (not xb.with_sharding): inlined calls shouldn't have
  # shardings set directly on their inputs or outputs.
  params = [
      xb.set_sharding(subc, xb.parameter(subc, idx, c.GetShape(node)), spec)
      for idx, (node, spec) in enumerate(safe_zip(in_nodes, arg_shardings))
  ]

  results = xla.jaxpr_subcomp(
      subc, call_jaxpr, backend, axis_env, (),
      extend_name_stack(name_stack, wrap_name(name, "sharded_jit")), *params)
  out_shardings = out_parts_thunk()
  assert len(out_shardings) == len(results)
  annotated = [xb.set_sharding(subc, res, spec)
               for res, spec in safe_zip(results, out_shardings)]

  computation = subc.build(xops.Tuple(subc, annotated))
  return xops.Call(c, computation, list(in_nodes))
Exemplo n.º 4
0
def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
                           jaxpr, in_axis_resources, out_axis_resources,
                           resource_env, donated_invars, positional_semantics):
    """Lowers a pjit call (old-style rule) to an XLA Call with sharding protos.

    ``jaxpr`` is a closed jaxpr; its consts are materialized as constants
    inside the ``pjit_<name>`` sub-computation. Parameters and results are
    annotated with sharding protos derived from ``in_axis_resources`` /
    ``out_axis_resources`` over ``resource_env.physical_mesh``.
    ``donated_invars`` and ``positional_semantics`` are not used here.

    Returns:
      The single XLA Call op (in ``c``) whose result is the output tuple.
    """
    mesh = resource_env.physical_mesh
    subc = xc.XlaBuilder(f"pjit_{name}")

    args = []
    for idx, (node, resources) in enumerate(
            safe_zip(in_nodes, in_axis_resources)):
        # N.B. inlined calls shouldn't have shardings set directly on the
        # inputs or outputs (set_sharding_proto adds an identity operation).
        param = xb.parameter(subc, idx, c.GetShape(node))
        # The proto is computed from the outer operand in the outer builder.
        in_proto = get_sharding_proto(c, node, resources, mesh)
        args.append(xb.set_sharding_proto(subc, param, in_proto))

    # TODO: Think about how to avoid duplicating constants with the outer jaxpr
    consts = xla._xla_consts(subc, jaxpr.consts)
    inner_name_stack = extend_name_stack(name_stack, wrap_name(name, "pjit"))
    results = xla.jaxpr_subcomp(subc, jaxpr.jaxpr, backend, axis_env, consts,
                                inner_name_stack, *args)

    annotated = []
    for res, resources in safe_zip(results, out_axis_resources):
        out_proto = get_sharding_proto(subc, res, resources, mesh)
        annotated.append(xb.set_sharding_proto(subc, res, out_proto))

    computation = subc.build(xops.Tuple(subc, annotated))
    return xops.Call(c, computation, list(in_nodes))
Exemplo n.º 5
0
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  """Lowers a sharded_jit call to an XLA Call with sharding annotations.

  ``call_jaxpr`` is lowered into a ``sharded_jit_<name>`` sub-computation
  whose parameters are annotated with ``in_parts`` and whose results are
  annotated with ``out_parts_thunk()``. Leading ``in_nodes`` beyond
  ``len(in_parts)`` are assumed to be constants and get a ``None``
  (replicated) sharding. ``avals_in``, ``avals_out``, ``nparts`` and the
  ``local_*`` arguments are not used here.

  Returns:
    A list of XLA ops in ``ctx.builder``, one per output of ``call_jaxpr``.
  """
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined calls
    # shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

  # Extend (rather than reset) the caller's name stack so that op metadata
  # inside the sharded_jit keeps the enclosing scopes.
  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=extend_name_stack(ctx.name_stack,
                                   wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
Exemplo n.º 6
0
Arquivo: xla.py Projeto: romanngg/jax
def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
                             prim: core.Primitive,
                             avals_in: Sequence[core.AbstractValue],
                             avals_out: Sequence[core.AbstractValue],
                             **params):
    """Builds a standalone XLA computation applying the single primitive `prim`.

    One XLA parameter is created for each XLA shape of each aval in
    ``avals_in``. The translation rule is looked up first among the
    backend-specific rules for ``platform`` (when ``platform`` is not None),
    then among the generic rules.

    Returns:
      The built XLA computation. Its root is a tuple of the results when
      ``prim.multiple_results`` is set; otherwise it is the single result.

    Raises:
      NotImplementedError: if no translation rule is registered for ``prim``.
    """
    c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
    counts = it.count()
    xla_args = [
        parameter(c, next(counts), xla_shape) for a in avals_in
        for xla_shape in aval_to_xla_shapes(a)
    ]
    if (platform is not None
            and prim in _backend_specific_translations[platform]):
        rule = _backend_specific_translations[platform][prim]
    elif prim in _translations:
        rule = _translations[prim]
    else:
        # Without this branch `rule` would be unbound below, surfacing as an
        # opaque UnboundLocalError instead of naming the missing primitive.
        raise NotImplementedError(
            f"XLA translation rule for primitive '{prim.name}' not found")

    ctx = TranslationContext(builder=c,
                             platform=platform,
                             axis_env=axis_env,
                             name_stack=new_name_stack())
    ans = rule(ctx, avals_in, avals_out, *xla_args, **params)

    if prim.multiple_results:
        return c.build(xops.Tuple(c, ans))
    else:
        x, = ans
        return c.build(x)
Exemplo n.º 7
0
def lower_jaxpr_to_xla_module(
        fn_name: str,
        jaxpr: core.ClosedJaxpr,
        platform: str,
        axis_env: AxisEnv,
        name_stack: str,
        tuple_args: bool,
        donated_invars: Sequence[bool],
        replicated_args: Optional[Sequence[bool]],
        arg_partitions: Optional[Any],
        out_partitions: Optional[Any],
        partitions_are_protos: bool = False) -> xc.XlaComputation:
    """Lowers a closed jaxpr to a top-level XLA module.

    Args:
      fn_name: name given to the XLA builder.
      jaxpr: the closed jaxpr; its consts are lowered as XLA constants.
      platform: target platform. Input/output aliasing for donated buffers is
        only attempted for "gpu" and "tpu".
      axis_env: axis environment for the TranslationContext.
      name_stack: initial name stack for the TranslationContext.
      tuple_args: forwarded to ``_xla_callable_args`` and ``set_up_aliases``.
      donated_invars: per-argument donation flags.
      replicated_args: forwarded to ``_xla_callable_args`` as ``replicated``.
      arg_partitions: forwarded to ``_xla_callable_args`` as ``partitions``.
      out_partitions: when not None, the outputs are always packed into a
        tuple built under this sharding.
      partitions_are_protos: whether the partitions are sharding protos.

    Returns:
      The built XlaComputation.

    Emits a warning via ``warnings.warn`` listing donated buffers that could
    not be used.
    """
    builder = xc.XlaBuilder(fn_name)
    consts = _xla_consts(builder, jaxpr.consts)
    args, donated_invars = _xla_callable_args(
        builder,
        jaxpr.in_avals,
        tuple_args,
        donated_invars=donated_invars,
        replicated=replicated_args,
        partitions=arg_partitions,
        partitions_proto=partitions_are_protos)
    ctx = TranslationContext(builder, platform, axis_env, name_stack)
    flat_results = jaxpr_subcomp(ctx, jaxpr.jaxpr, consts, *args)

    # The runtime cannot handle token values, so each token output is
    # replaced by a dummy array value.
    sizes = [len(aval_to_xla_shapes(aval)) for aval in jaxpr.out_avals]
    grouped = util.unflatten(flat_results, sizes)
    results = []
    for aval, nodes in zip(jaxpr.out_avals, grouped):
        if aval is core.abstract_token:
            results.append(_make_token_return_value(builder))
        else:
            results.extend(nodes)

    if out_partitions is not None:
        make_tuple = partial(xops.Tuple, builder, results)
        shard = with_sharding_proto if partitions_are_protos else with_sharding
        output = shard(builder, out_partitions, make_tuple)
    elif len(results) == 1:
        # Building an output tuple has a non-zero cost (notably on TPU), so a
        # single output is returned directly.
        output = results[0]
    else:
        output = xc.ops.Tuple(builder, results)

    if platform in ("gpu", "tpu"):
        donated_invars = set_up_aliases(builder, args,
                                        builder.GetShape(output),
                                        donated_invars, tuple_args)
    if any(donated_invars):
        # TODO(tomhennigan): At call time we should mark these buffers as deleted.
        leftover = [str(builder.GetShape(arg))
                    for arg, donated in zip(args, donated_invars) if donated]
        warnings.warn("Some donated buffers were not usable: {}".format(
            ", ".join(leftover)))
    return builder.build(output)
Exemplo n.º 8
0
def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
    """Lower remat to a single-iteration While loop.

    The loop state is ``(counter, *in_nodes, *outputs)`` with the counter
    starting at 0. The loop condition compares the counter with a random
    int32 drawn from [1, 2), so the compiler cannot see the trip count. The
    body increments the counter, passes the inputs through, and recomputes
    ``call_jaxpr``. Returns the output portion of the final loop state.
    """
    outer = ctx.builder
    remat_name_stack = extend_name_stack(ctx.name_stack,
                                         wrap_name(name, 'remat'))

    # Lower the body once in a throwaway builder only to learn output shapes.
    probe_operands = xops.Tuple(outer, in_nodes)
    probe = xc.XlaBuilder("remat_dummy_subcomputation")
    probe_param = parameter(probe,
                            0,
                            outer.get_shape(probe_operands),
                            replicated=[])
    probe_args = xla_destructure(probe, probe_param)
    probe_ctx = ctx.replace(builder=probe, name_stack=remat_name_stack)
    probe_outs = jaxpr_subcomp(probe_ctx, call_jaxpr, (), *probe_args)
    result_shapes = [probe.get_shape(op) for op in probe_outs]

    # Initial loop state: counter = 0, the inputs, and zero-filled outputs.
    start = xops.Constant(outer, np.array(0, dtype=np.int32))
    zero_outs = [_zeros(outer, shape) for shape in result_shapes]
    state = xops.Tuple(outer, [start] + list(in_nodes) + zero_outs)
    state_shape = outer.get_shape(state)
    n_carried = len(in_nodes) + 1  # counter + inputs

    cond = xc.XlaBuilder("remat_cond_subcomputation")
    cond_state = parameter(cond, 0, state_shape, replicated=[])
    counter = xops.GetTupleElement(cond_state, 0)
    bound = xops.RngUniform(
        xops.Constant(cond, np.array(1, dtype=np.int32)),
        xops.Constant(cond, np.array(2, dtype=np.int32)),
        xc.Shape.array_shape(xc.PrimitiveType.S32, []))
    cond_comp = cond.build(xops.Lt(counter, bound))

    body = xc.XlaBuilder("remat_body_subcomputation")
    body_state = parameter(body, 0, state_shape, replicated=[])
    counter, *carried = xla_destructure(body, body_state)[:n_carried]
    next_counter = xops.Add(counter,
                            xops.Constant(body, np.array(1, dtype=np.int32)))
    body_ctx = ctx.replace(builder=body, name_stack=remat_name_stack)
    fresh_outs = jaxpr_subcomp(body_ctx, call_jaxpr, (), *carried)
    body_comp = body.build(
        xops.Tuple(body, [next_counter] + carried + list(fresh_outs)))

    final_state = xops.While(cond_comp, body_comp, state)
    return xla_destructure(outer, final_state)[n_carried:]
Exemplo n.º 9
0
def _comparator_builder(op_type, is_max_k):
    """Builds the scalar comparator computation for top_k.

    The comparator declares four scalar parameters: two of ``op_type``
    followed by two int32 (presumably the paired indices; they are not read).
    Its result is ``p0 > p1`` when ``is_max_k`` is true, else ``p0 < p1``.
    """
    direction = 'gt' if is_max_k else 'lt'
    c = xc.XlaBuilder(f'top_k_{direction}_comparator')
    value_shape = xc.Shape.scalar_shape(op_type)
    index_shape = xc.Shape.scalar_shape(np.dtype(np.int32))
    lhs = xla.parameter(c, 0, value_shape)
    rhs = xla.parameter(c, 1, value_shape)
    for position in (2, 3):
        xla.parameter(c, position, index_shape)
    compare = xc.ops.Gt if is_max_k else xc.ops.Lt
    return c.build(compare(lhs, rhs))
Exemplo n.º 10
0
Arquivo: xla.py Projeto: rsepassi/jax
def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                 name="core_call", backend=None, call_jaxpr):
  """Lowers a named call to an XLA Call of a sub-computation called `name`.

  ``backend`` is checked against ``ctx.platform``. The sub-computation's name
  stack is the caller's, extended with ``name``.

  Returns:
    A list of XLA ops in ``ctx.builder``, one per output of ``call_jaxpr``.
  """
  check_backend_matches(backend, ctx.platform)
  outer = ctx.builder
  inner = xc.XlaBuilder(name)
  params = []
  for idx, node in enumerate(in_nodes):
    params.append(parameter(inner, idx, outer.GetShape(node)))
  inner_ctx = ctx.replace(
      builder=inner, name_stack=extend_name_stack(ctx.name_stack, name))
  results = jaxpr_subcomp(inner_ctx, call_jaxpr, (), *params)
  computation = inner.Build(xops.Tuple(inner, results))
  call = xops.Call(outer, computation, list(in_nodes))
  return xla_destructure(outer, call)
Exemplo n.º 11
0
Arquivo: xla.py Projeto: John1Tang/jax
def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
                             prim: core.Primitive,
                             *avals: core.AbstractValue, **params):
  """Builds a standalone XLA computation that binds `prim` on `avals`.

  ``prim.bind`` is lowered with ``lower_fun(..., new_style=True)`` and called
  with the input avals and ``None`` for the output avals. The computation's
  root is a tuple of the results when ``prim.multiple_results`` is set;
  otherwise it is the single result.
  """
  builder = xc.XlaBuilder(f"primitive_computation_{prim.name}")
  lowered = lower_fun(prim.bind, multiple_results=prim.multiple_results,
                      new_style=True)
  args, _donated = _xla_callable_args(builder, avals, tuple_args=False,
                                      filter_tokens=False)
  ctx = TranslationContext(builder=builder, platform=platform,
                           axis_env=axis_env, name_stack=new_name_stack())
  outs = lowered(ctx.replace(builder=builder), avals, None, *args, **params)
  if not prim.multiple_results:
    (root,) = outs
    return builder.build(root)
  return builder.build(xops.Tuple(builder, outs))
Exemplo n.º 12
0
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    """Traces, lowers and compiles `fun` for sharded_jit execution.

    ``fun`` is traced on global avals derived from ``abstract_args`` and the
    global/local input partitions. The jaxpr is lowered into one XLA
    computation with sharded inputs and outputs, then compiled with a single
    replica and ``nparts`` partitions.

    Args:
      fun: the wrapped function to trace.
      nparts: requested partition count; resolved via
        ``pxla.reconcile_num_partitions`` and must not end up None.
      in_parts: global partitioning of each argument.
      out_parts_thunk: returns the global output partitioning. It is invoked
        only after tracing.
      local_in_parts: per-process input partitioning. Defaults to ``in_parts``.
      local_out_parts_thunk: returns the per-process output partitioning, or
        None to reuse the global one.
      local_nparts: partitions on this process. Defaults to ``nparts`` when
        that fits on the local devices.
      name: used (wrapped as "sharded_jit") for the computation's name stack.
      *abstract_args: the argument avals as seen by this process. They are
        converted to global avals with ``pxla.get_global_aval``.

    Returns:
      ``partial(_execute_spatially_partitioned, compiled, handle_args,
      handle_outs)``, bound to the compiled executable, the argument-sharding
      function and the output handler.

    Raises:
      ValueError: if the platform is not "tpu"/"gpu"; if more partitions are
        requested than there are devices; if ``local_nparts`` is omitted while
        ``nparts`` exceeds the local device count; or if ``local_nparts``
        exceeds the local device count.
      NotImplementedError: if ``nparts`` lies strictly between the local and
        the global device count.
    """
    # A single replica; all parallelism comes from partitioning.
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    # Convert the per-process argument avals into global avals.
    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    # Trace on the global avals; the resulting output avals are global too.
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform
    if platform not in ["tpu", "gpu"]:
        # TODO(skye): fall back to regular jit?
        raise ValueError(f"sharded_jit not supported for {platform}")

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    # Multi-host runs must use every device; partial cross-host use is
    # rejected here.
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    # The output partitions are thunks and are only evaluated now, after
    # tracing (presumably because they depend on what tracing produced).
    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    # Per-process output avals, consumed by the result handler below.
    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    # Build the XLA computation: constants, sharded parameters, the lowered
    # body, and an output tuple built under the global output sharding.
    c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
    xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
    xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
    axis_env = xla.AxisEnv(nrep, (), ())
    ctx = xla.TranslationContext(
        c, platform, axis_env,
        extend_name_stack(wrap_name(name, "sharded_jit")))
    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
    out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
    built = c.Build(out_tuple)

    # Use the first nparts local devices when they suffice; otherwise (the
    # multi-host case validated above) use every device.
    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    # Device ids laid out as rows of nparts ids (nrep rows).
    device_assignment = np.array([[d.id for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    # How each local argument is split across the local_nparts partitions;
    # a None spec yields None indices.
    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Exemplo n.º 13
0
 def test_parameter_replication(self):
     # Passing False as the fifth argument (presumably `replicated`) should
     # yield an explicit parameter_replication={false} attribute in the HLO.
     builder = xc.XlaBuilder("test")
     scalar_f32 = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
     xla.parameter(builder, 0, scalar_f32, "", False)
     hlo_text = builder.Build().as_hlo_text()
     assert "parameter_replication={false}" in hlo_text
Exemplo n.º 14
0
 def test_parameter_replication_default(self):
     # With no replication argument, no replication attribute should appear
     # in the generated HLO text.
     builder = xc.XlaBuilder("test")
     scalar_f32 = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
     xla.parameter(builder, 0, scalar_f32)
     hlo_text = builder.Build().as_hlo_text()
     assert "replication" not in hlo_text