Example #1
File: dispatch.py  Project: jbampton/jax
 def from_xla_computation(
     name: str,
     xla_computation,
     nreps: int,
     device: Optional[Device],
     backend,
     tuple_args: bool,
     in_avals,
     out_avals,
     kept_var_idx) -> 'XlaCompiledComputation':
   sticky_device = device
   result_handlers = map(partial(aval_to_result_handler, sticky_device),
                         out_avals)
   options = xb.get_compile_options(
       num_replicas=nreps,
       num_partitions=1,
       device_assignment=(sticky_device,) if sticky_device else None)
   options.parameter_is_tupled_arguments = tuple_args
   with log_elapsed_time(f"Finished XLA compilation of {name} "
                         "in {elapsed_time} sec"):
     compiled = compile_or_get_cached(backend, xla_computation, options)
   buffer_counts = (None if len(out_avals) == 1 else
                    [aval_to_num_buffers(aval) for aval in out_avals])
   execute = _execute_compiled if nreps == 1 else _execute_replicated
   unsafe_call = partial(execute, name, compiled, buffer_counts,
                         result_handlers, kept_var_idx)
   return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
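The snippet above builds CompileOptions for a single-replica, single-partition computation, optionally pinned to a "sticky" device. Below is a minimal standalone sketch of the same call shape; it is not from the original file. It assumes a working JAX install, that xla_bridge is importable at the path shown (the module path has moved between JAX versions), and it passes the device's integer id (as the tests later in this listing do) rather than the Device object that Example #1 passes.
# Minimal sketch of the call pattern above, outside of dispatch.py.
# Assumptions: the xla_bridge import path matches your JAX version; an
# integer device id is used in device_assignment instead of a Device object.
import jax
from jax._src import xla_bridge as xb  # older versions: jax.lib.xla_bridge

device = jax.devices()[0]  # stand-in for the "sticky" device
options = xb.get_compile_options(
    num_replicas=1,
    num_partitions=1,
    device_assignment=(device.id,) if device is not None else None)
options.parameter_is_tupled_arguments = True  # mirrors tuple_args above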
Example #2
 def test_set_device_assignment_with_partition(self):
   compile_options = xb.get_compile_options(
       num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
   expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
                                 "0 2 \nComputation 1: 1 3 \n")
   self.assertEqual(compile_options.device_assignment.__repr__(),
                    expected_device_assignment)
Example #3
 def test_set_device_assignment_no_partition(self):
   compile_options = xb.get_compile_options(
       num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
   expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
                                 "0 1 2 3 \n")
   self.assertEqual(compile_options.device_assignment.__repr__(),
                    expected_device_assignment)
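Read together, the two tests show how the device_assignment argument maps onto the resulting DeviceAssignment repr: rows index replicas and columns index partitions ("Computations" in the repr). The following sketch simply re-creates the partitioned case from the first test and prints its repr; it assumes the same xla_bridge import caveat as above.
# Sketch only: rebuild the assignment from the partitioned test and print it.
from jax._src import xla_bridge as xb  # older versions: jax.lib.xla_bridge

options = xb.get_compile_options(
    num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
print(repr(options.device_assignment))
# Per the first test above (modulo trailing whitespace), this prints:
# Computations: 2 Replicas: 2
# Computation 0: 0 2
# Computation 1: 1 3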
Example #4
File: dispatch.py  Project: John1Tang/jax
 def from_xla_computation(
     name: str,
     xla_computation: Optional[ir.Module],
     explicit_args: Optional[Sequence[bool]],
     nreps: int,
     device: Optional[Device],
     backend: Backend,
     tuple_args: bool,
     in_avals: Sequence[core.AbstractValue],
     out_avals: Sequence[core.AbstractValue],
     kept_var_idx: Set[int]) -> XlaCompiledComputation:
   sticky_device = device
   input_handler = _input_handler(explicit_args, in_avals)
   result_handlers = map(partial(aval_to_result_handler, sticky_device),
                         out_avals)
   options = xb.get_compile_options(
       num_replicas=nreps, num_partitions=1,
       device_assignment=(sticky_device,) if sticky_device else None)
   options.parameter_is_tupled_arguments = tuple_args
   with log_elapsed_time(f"Finished XLA compilation of {name} "
                         "in {elapsed_time} sec"):
     compiled = compile_or_get_cached(backend, xla_computation, options)
   buffer_counts = (None if len(out_avals) == 1 and not config.jax_dynamic_shapes
                    else [aval_to_num_buffers(aval) for aval in out_avals])
   execute = _execute_compiled if nreps == 1 else _execute_replicated
   unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,
                         result_handlers, kept_var_idx)
   return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
Example #5
File: dispatch.py  Project: romanngg/jax
 def from_xla_computation(
         name: str, xla_computation: Optional[ir.Module],
         in_type: Optional[pe.InputType], out_type: Optional[pe.OutputType],
         nreps: int, device: Optional[Device], backend: Backend,
         tuple_args: bool, in_avals: Sequence[core.AbstractValue],
         out_avals: Sequence[core.AbstractValue],
         has_unordered_effects: bool, ordered_effects: List[core.Effect],
         kept_var_idx: Set[int],
         keepalive: Optional[Any]) -> XlaCompiledComputation:
     sticky_device = device
     input_handler = _input_handler(backend, in_type, out_type)
     result_handler = _result_handler(backend, sticky_device, out_type)
     options = xb.get_compile_options(
         num_replicas=nreps,
         num_partitions=1,
         device_assignment=(sticky_device, ) if sticky_device else None)
     options.parameter_is_tupled_arguments = tuple_args
     with log_elapsed_time(f"Finished XLA compilation of {name} "
                           "in {elapsed_time} sec"):
         compiled = compile_or_get_cached(backend, xla_computation, options)
     buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
     if ordered_effects or has_unordered_effects:
         num_output_tokens = len(ordered_effects) + has_unordered_effects
         buffer_counts = ([1] * num_output_tokens) + buffer_counts
     execute = _execute_compiled if nreps == 1 else _execute_replicated
     unsafe_call = partial(
         execute,
         name,
         compiled,
         input_handler,
         buffer_counts,  # type: ignore  # noqa: F811
         result_handler,
         has_unordered_effects,
         ordered_effects,
         kept_var_idx)
     return XlaCompiledComputation(compiled, in_avals, kept_var_idx,
                                   unsafe_call, keepalive)
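This variant prepends one single-buffer output per effect token ahead of the regular output buffers. A tiny sketch of just that bookkeeping follows, using made-up effect names and buffer counts; it is plain Python, independent of JAX.
# Sketch of the token bookkeeping above, with placeholder values.
ordered_effects = ["io_effect_1", "io_effect_2"]  # hypothetical effects
has_unordered_effects = True
buffer_counts = [1, 3]  # e.g. one output with 1 buffer, one with 3

if ordered_effects or has_unordered_effects:
    # One token buffer per ordered effect, plus one if any unordered effects.
    num_output_tokens = len(ordered_effects) + has_unordered_effects
    buffer_counts = [1] * num_output_tokens + buffer_counts

assert buffer_counts == [1, 1, 1, 1, 3]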
Example #6
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    axis_env = xla.AxisEnv(nrep, (), ())
    unordered_effects = [
        eff for eff in jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in jaxpr.effects if eff in core.ordered_effects
    ]
    module, _ = mlir.lower_jaxpr_to_module(
        f"spjit_{fun.__name__}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects,
        ordered_effects,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env),
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
        donated_args=[False] * len(in_parts),
        arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
        result_shardings=safe_map(xla.sharding_to_proto, out_parts))
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Example #7
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [pxla.get_global_aval(arg, parts, lparts)
                          for arg, parts, lparts
                          in safe_zip(abstract_args, in_parts, local_in_parts)]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(fun, global_abstract_args)

  if xb.get_backend().platform not in ["tpu", "gpu"]:
    # TODO(skye): fall back to regular jit?
    raise ValueError("sharded_jit not supported for " +
                     xb.get_backend().platform)

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
        "Specify 'local_nparts' when using cross-process sharded_jit "
        "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [pxla.get_local_aval(out, parts, lparts)
                     for out, parts, lparts
                     in safe_zip(global_out_avals, out_parts, local_out_parts)]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority,
              "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  c = xb.make_computation_builder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xb.constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  out_nodes = xla.jaxpr_subcomp(
      c, jaxpr, None, axis_env, xla_consts,
      extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
  out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = xla.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
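Both _sharded_callable versions build the device assignment as a single row of nparts entries (nrep is fixed at 1) and hand it to get_compile_options positionally. A numpy-only sketch of that reshape, with placeholder integer ids standing in for the device entries (d.id in Example #7):
# Sketch of the device_assignment construction above, using placeholder ids.
import numpy as np

nrep, nparts = 1, 2
device_ids = [0, 1]  # one entry per partition
device_assignment = np.array([device_ids]).reshape(-1, nparts)
# device_assignment == array([[0, 1]]): one row per replica, one column per
# partition, as passed to xb.get_compile_options(nrep, nparts, device_assignment).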