def scan(x, op, *, comm=None, token=None):
    """Perform a scan operation.

    Arguments:
        x: Array or scalar input to send.
        op (mpi4py.MPI.Op): The reduction operator (e.g. :obj:`mpi4py.MPI.SUM`).
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Result of the scan operation.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    op = wrap_as_hashable(op)
    comm = wrap_as_hashable(comm)

    return tuple(mpi_scan_p.bind(x, token, op=op, comm=comm))
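# Illustrative usage sketch (added, not part of the excerpt above): an inclusive prefix
# sum across ranks via `scan`. Assumes the process was launched under MPI
# (e.g. `mpirun -n 4 ...`) and that `scan` and its helpers are in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_scan():
    x = jnp.ones(3)
    # On rank r, `prefix` holds the elementwise sum of `x` over ranks 0..r.
    prefix, token = scan(x, op=MPI.SUM)
    return prefix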
def allreduce(x, op, *, comm=None, token=None):
    """Perform an allreduce operation.

    .. note::

        This primitive can be differentiated via :func:`jax.grad` and related
        functions if ``op`` is :obj:`mpi4py.MPI.SUM`.

    Arguments:
        x: Array or scalar input.
        op (mpi4py.MPI.Op): The reduction operator (e.g. :obj:`mpi4py.MPI.SUM`).
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Result of the allreduce operation.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    op = wrap_as_hashable(op)
    comm = wrap_as_hashable(comm)

    return tuple(
        mpi_allreduce_p.bind(x, token, op=op, comm=comm, transpose=False))
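# Illustrative usage sketch (added, not part of the excerpt above): summing a per-rank
# array over all ranks inside a jitted function. Assumes an MPI launch
# (`mpirun -n <nproc> ...`) and that `allreduce` from the excerpt is in scope.
from mpi4py import MPI
import jax
import jax.numpy as jnp


@jax.jit
def example_global_sum(x):
    # Every rank receives the elementwise sum of `x` across all ranks.
    summed, _token = allreduce(x, op=MPI.SUM)
    return summed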
def Sendrecv(
    sendbuf,
    recvbuf,
    source,
    dest,
    sendtag=0,
    recvtag=_MPI.ANY_TAG,
    comm=_MPI.COMM_WORLD,
    status=None,
    token=None,
):
    if token is None:
        token = create_token(sendbuf)

    return mpi_sendrecv_p.bind(
        sendbuf,
        recvbuf,
        token,
        source=source,
        dest=dest,
        sendtag=sendtag,
        recvtag=recvtag,
        comm=comm,
        status=status,
    )
def f(x):
    token = lax.create_token(x)
    y, token = lax.infeed(
        token, shape=jax.ShapedArray((3, 4), jnp.float32))
    token = lax.outfeed(token, y + np.float32(1))
    return x - 1 if config.omnistaging_enabled else lax.tie_in(token, x - 1)
def Sendrecv(
    sendbuf,
    recvbuf,
    source,
    dest,
    sendtag=0,
    recvtag=_MPI.ANY_TAG,
    comm=_MPI.COMM_WORLD,
    status=None,
    token=None,
):
    if token is None:
        token = create_token(sendbuf)

    comm = wrap_as_hashable(comm)
    if status is not None:
        status = wrap_as_hashable(status)

    return mpi_sendrecv_p.bind(
        sendbuf,
        recvbuf,
        token,
        source=source,
        dest=dest,
        sendtag=sendtag,
        recvtag=recvtag,
        comm=comm,
        status=status,
    )
def reduce(x, op, root, comm=None, token=None):
    """Perform a reduce operation.

    Arguments:
        x: Array or scalar input to send.
        op (mpi4py.MPI.Op): The reduction operator (e.g. :obj:`mpi4py.MPI.SUM`).
        root (int): Rank of the root MPI process.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Result of the reduce operation on root process, otherwise
              unmodified input.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    rank = comm.Get_rank()

    op = wrap_as_hashable(op)
    comm = wrap_as_hashable(comm)

    res, token = mpi_reduce_p.bind(x, token, op=op, root=root, comm=comm)

    if rank != root:
        return x, token

    return res, token
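# Illustrative usage sketch (added, not part of the excerpt above): only the root rank
# receives the reduced value; all other ranks get their input back unchanged, as
# documented above. Assumes an MPI launch and that `reduce` is in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_reduce_to_root():
    x = jnp.arange(4.0)
    res, token = reduce(x, op=MPI.SUM, root=0)
    # On rank 0: res == nproc * x; on every other rank: res is x itself.
    return res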
def _dummy_remat_result(aval: core.AbstractValue):
    """A result that will be discarded"""
    if aval is core.abstract_token:
        return lax.create_token()
    else:
        return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape)  # type: ignore
def Recv(
    x,
    source=_MPI.ANY_SOURCE,
    tag=_MPI.ANY_TAG,
    comm=_MPI.COMM_WORLD,
    status=None,
    token=None,
):
    """
    Recv(x, source=_MPI.ANY_SOURCE, tag=_MPI.ANY_TAG, comm=_MPI.COMM_WORLD, status=None, token=None)

    Receives the input ``x`` from the source rank ``source`` using the
    communicator ``comm``, which defaults to the world communicator, with the
    tag ``tag``.

    An optional token can be passed, which is used to force jax to execute MPI
    operations in the correct order. This is particularly important if you are
    performing multiple Send/Recv operations, which might otherwise deadlock.

    Arguments:
        x: Array or scalar input with the desired shape and dtype.
        source: Rank of the source MPI process.
        tag: Tag of this message.
        comm: The communicator (defaults to MPI.COMM_WORLD).
        status: Optional MPI status object, can be used for introspection.
        token: Token to force a sequential order in the operations (default=None).

    Returns:
        res: The received array or scalar.
        new_token: A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    out = mpi_recv_p.bind(
        x, token, source=source, tag=tag, comm=comm, status=status)
    return out
def f(x):
    token = lax.create_token(x)
    (y,), token = lax.infeed(
        token, shape=(jax.ShapedArray((3, 4), jnp.float32),))
    (z,), _ = lax.infeed(
        token, shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
    return x + y + z
def Send(x, dest, tag=0, comm=_MPI.COMM_WORLD, token=None):
    """
    Send(x, dest, tag=0, comm=_MPI.COMM_WORLD, token=None)

    Sends the input ``x`` to the target rank ``dest`` using the communicator
    ``comm``, which defaults to the world communicator, with the tag ``tag``.

    An optional token can be passed, which is used to force jax to execute MPI
    operations in the correct order. This is particularly important if you are
    performing multiple Send/Recv operations, which might otherwise deadlock.

    Arguments:
        x: Array or scalar input.
        dest: Rank of the target MPI process.
        tag: Tag of this message.
        comm: The communicator (defaults to MPI.COMM_WORLD).
        token: Token to force a sequential order in the operations (default=None).

    Returns:
        new_token: A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    out = mpi_send_p.bind(x, token, dest=dest, tag=tag, comm=comm)
    return out
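# Illustrative usage sketch (added, not part of the excerpts above): a two-rank
# ping-pong built from the `Send`/`Recv` pair, threading the returned token so the
# operations cannot be reordered. Assumes exactly two MPI ranks and that both
# functions are in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_ping_pong():
    rank = MPI.COMM_WORLD.Get_rank()
    buf = jnp.zeros(3)
    if rank == 0:
        token = Send(buf + 1.0, dest=1, tag=0)
        reply, _ = Recv(buf, source=1, tag=1, token=token)
        return reply  # -> [2., 2., 2.]
    else:
        msg, token = Recv(buf, source=0, tag=0)
        Send(msg * 2.0, dest=0, tag=1, token=token)
        return None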
def sendrecv(
    sendbuf,
    recvbuf,
    source,
    dest,
    sendtag=0,
    recvtag=_MPI.ANY_TAG,
    comm=None,
    status=None,
    token=None,
):
    """Perform a sendrecv operation.

    .. warning::

        Unlike mpi4py's sendrecv, this returns a *new* array with the received data.

    Arguments:
        sendbuf: Array or scalar input to send.
        recvbuf: Array or scalar input with the correct shape and dtype. This can
            contain arbitrary data and will not be overwritten.
        source (int): Rank of the source MPI process.
        dest (int): Rank of the destination MPI process.
        sendtag (int): Tag of this message for sending.
        recvtag (int): Tag of this message for receiving.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        status (mpi4py.MPI.Status): Status object, can be used for introspection.
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(sendbuf)

    if comm is None:
        comm = get_default_comm()

    comm = wrap_as_hashable(comm)
    if status is not None:
        status = wrap_as_hashable(status)

    return mpi_sendrecv_p.bind(
        sendbuf,
        recvbuf,
        token,
        source=source,
        dest=dest,
        sendtag=sendtag,
        recvtag=recvtag,
        comm=comm,
        status=status,
    )
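# Illustrative usage sketch (added, not part of the excerpt above): a ring shift in
# which every rank sends its data to the next rank and receives from the previous one
# in a single call. Assumes an MPI launch and that `sendrecv` is in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_ring_shift():
    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
    x = jnp.full(3, float(rank))
    recvbuf = jnp.empty_like(x)  # only shape/dtype matter; contents are ignored
    shifted, token = sendrecv(
        x, recvbuf, source=(rank - 1) % size, dest=(rank + 1) % size)
    return shifted  # holds the previous rank's data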
def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
    # Create symbolic token for threading infeed data.
    token = lax.create_token(step)
    # Run on-device loop.
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond,
        device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step
def device_train_loop(optimizer, state, metrics, step, loop):
    token = lax.create_token(step)
    optimizer, state, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond,
        device_train_loop_body,
        (optimizer, state, metrics, token, step, loop))
    state = sync_batchnorm_stats(state)
    metrics = allreduce_metrics(metrics)
    return optimizer, state, metrics, step
def host_loop_eval_step(model, state, metrics):
    token = lax.create_token(metrics['samples'])
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(eval_input_shape, model_dtype),
               jax.ShapedArray((device_eval_batch_size,), jnp.int32)))
    metrics = eval_step(model, state, batch, metrics, image_format,
                        space_to_depth)
    return metrics
def device_train_loop(optimizer, dropout_rng, total_loss, lm_loss,
                      sentence_loss, step, epoch, num_steps_per_epoch):
    """Device training loop."""
    token = lax.create_token(step)
    (optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, _, step,
     epoch, num_steps_per_epoch) = lax.while_loop(
         device_train_loop_cond,
         device_train_loop_body,
         (optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, token,
          step, epoch, num_steps_per_epoch))
    return optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, step
def f_for_jit(x):
    token = lax.create_token(x)
    (y,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    (z,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    (w,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    return x + y + z + w
def scatter(
    x,
    root,
    *,
    comm=None,
    token=None,
):
    """Perform a scatter operation.

    .. warning::

        Unlike mpi4py's scatter, this returns a *new* array with the received data.

    .. warning::

        The expected shape of the first input varies between ranks. On the root
        process, it is ``(nproc, *input_shape)``. On all other processes, it is
        ``input_shape``.

    Arguments:
        x: Array or scalar input with the correct shape and dtype. On the root
            process, this contains the data to send, and its first axis must have
            size ``nproc``. On non-root processes, this may contain arbitrary data
            and will not be overwritten.
        root (int): Rank of the root MPI process.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    rank = comm.Get_rank()
    if rank == root:
        size = comm.Get_size()
        if x.shape[0] != size:
            raise ValueError("Scatter input must have shape (nproc, ...)")

    comm = wrap_as_hashable(comm)

    return tuple(mpi_scatter_p.bind(
        x,
        token,
        root=root,
        comm=comm,
    ))
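# Illustrative usage sketch (added, not part of the excerpt above): the root rank
# distributes one row of a (nproc, 3) array to each rank; non-root ranks only pass a
# dummy buffer of the per-rank shape. Assumes an MPI launch and `scatter` in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_scatter_rows():
    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
    if rank == 0:
        data = jnp.arange(size * 3, dtype=jnp.float32).reshape(size, 3)
    else:
        data = jnp.zeros(3, dtype=jnp.float32)  # shape/dtype of one chunk
    chunk, token = scatter(data, root=0)
    return chunk  # rank r receives row r of the root's array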
def host_loop_train_step(optimizer, state, metrics):
    token = lax.create_token(optimizer.state[0].step)
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    return optimizer, state, metrics
def gather(
    x,
    root,
    *,
    comm=None,
    token=None,
):
    """Perform a gather operation.

    .. warning::

        ``x`` must have the same shape and dtype on all processes.

    .. warning::

        The shape of the returned data varies between ranks. On the root process,
        it is ``(nproc, *input_shape)``. On all other processes the output is
        identical to the input.

    Arguments:
        x: Array or scalar input to send.
        root (int): Rank of the root MPI process.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data on root process, otherwise unmodified input.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    rank = comm.Get_rank()
    comm = wrap_as_hashable(comm)

    res, token = mpi_gather_p.bind(
        x,
        token,
        root=root,
        comm=comm,
    )

    if rank != root:
        return (x, token)

    return (res, token)
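# Illustrative usage sketch (added, not part of the excerpt above): the inverse of
# scatter. The root rank ends up with a (nproc, 3) array stacking every rank's
# contribution, while the other ranks get their input back. Assumes an MPI launch and
# that `gather` is in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_gather_rows():
    rank = MPI.COMM_WORLD.Get_rank()
    x = jnp.full(3, float(rank))
    res, token = gather(x, root=0)
    # rank 0: res.shape == (nproc, 3); other ranks: res is x unchanged
    return res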
def f_for_pjit(x):
    token = lax.create_token(x)
    # A replicated infeed
    (y,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(None,))
    # An infeed sharded on first axis
    (z,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(P(nr_devices, 1),))
    # An infeed sharded on second axis
    (w,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(P(1, nr_devices),))
    return x + y + z + w
def barrier(*, comm=None, token=None):
    """Perform a barrier operation.

    Arguments:
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Token:
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token()

    if comm is None:
        comm = get_default_comm()

    comm = wrap_as_hashable(comm)
    return mpi_barrier_p.bind(token, comm=comm)
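# Illustrative usage sketch (added, not part of the excerpt above): forcing all ranks
# to synchronize, and threading the returned token into later operations so the
# barrier is not reordered away. Assumes `barrier` from the excerpt is in scope.
def example_synchronize(token=None):
    new_token = barrier(token=token)
    return new_token  # pass this to subsequent token-taking MPI calls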
def alltoall(
    x,
    *,
    comm=None,
    token=None,
):
    """Perform an alltoall operation.

    Arguments:
        x: Array input to send. First axis must have size ``nproc``.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    size = comm.Get_size()
    if x.shape[0] != size:
        raise ValueError("Alltoall input must have shape (nproc, ...)")

    comm = wrap_as_hashable(comm)

    return tuple(mpi_alltoall_p.bind(
        x,
        token,
        comm=comm,
    ))
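# Illustrative usage sketch (added, not part of the excerpt above): every rank sends
# slice i of its input to rank i and receives one slice from every rank, so the first
# axis must equal the number of processes. Assumes an MPI launch and `alltoall` in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_alltoall():
    size = MPI.COMM_WORLD.Get_size()
    x = jnp.zeros((size, 2)) + MPI.COMM_WORLD.Get_rank()
    exchanged, token = alltoall(x)
    # exchanged[i] now holds the slice contributed by rank i (here: a row of i's).
    return exchanged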
def allgather(
    x,
    *,
    comm=None,
    token=None,
):
    """Perform an allgather operation.

    .. warning::

        ``x`` must have the same shape and dtype on all processes.

    Arguments:
        x: Array or scalar input to send.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    comm = wrap_as_hashable(comm)

    return tuple(mpi_allgather_p.bind(
        x,
        token,
        comm=comm,
    ))
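# Illustrative usage sketch (added, not part of the excerpt above): like `gather`, but
# every rank receives the stacked (nproc, ...) result. Assumes an MPI launch and that
# `allgather` from the excerpt is in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_allgather():
    rank = MPI.COMM_WORLD.Get_rank()
    x = jnp.full(2, float(rank))
    gathered, token = allgather(x)
    return gathered  # shape (nproc, 2) on every rank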
def bcast(x, root, comm=None, token=None):
    """Perform a bcast (broadcast) operation.

    .. warning::

        Unlike mpi4py's bcast, this returns a *new* array with the received data.

    Arguments:
        x: Array or scalar input. Data is only read on root process. On non-root
            processes, this is used to determine the shape and dtype of the result.
        root (int): The process to use as source.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    rank = comm.Get_rank()
    comm = wrap_as_hashable(comm)

    res, token = mpi_bcast_p.bind(x, token, root=root, comm=comm)

    if rank == root:
        return x, token

    return res, token
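# Illustrative usage sketch (added, not part of the excerpt above): rank 0 broadcasts
# its array to all other ranks; non-root ranks only supply a buffer with the right
# shape and dtype. Assumes an MPI launch and that `bcast` from the excerpt is in scope.
from mpi4py import MPI
import jax.numpy as jnp


def example_broadcast():
    rank = MPI.COMM_WORLD.Get_rank()
    x = jnp.arange(4.0) if rank == 0 else jnp.zeros(4)
    res, token = bcast(x, root=0)
    return res  # every rank now holds [0., 1., 2., 3.]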
def Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None):
    """
    Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None)

    Performs the Allreduce operation ``op`` on the input ``x`` using the
    communicator ``comm``, which defaults to the world communicator.

    An optional token can be passed, which is used to force jax to execute MPI
    operations in the correct order.

    Arguments:
        x: Array or scalar input.
        op: The reduction operation `MPI.Op` (e.g. MPI.SUM).
        comm: The communicator (defaults to MPI.COMM_WORLD).
        token: Token to force a sequential order in the operations (default=None).

    Returns:
        res: Result of the allreduce operation.
        new_token: A new, modified token, that depends on this operation. It can
            be ignored if ``res`` already forces the required data dependency.
    """
    if token is None:
        token = create_token(x)

    return mpi_allreduce_p.bind(x, token, op=op, comm=comm)
def send(x, dest, tag=0, comm=None, token=None):
    """Perform a send operation.

    Arguments:
        x: Array or scalar input to send.
        dest (int): Rank of the destination MPI process.
        tag (int): Tag of this message.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Token: A new, modified token, that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    comm = wrap_as_hashable(comm)
    return mpi_send_p.bind(x, token, dest=dest, tag=tag, comm=comm)
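# Illustrative usage sketch (added, not part of the excerpt above): a plain
# point-to-point send. The only return value is a new token, which should be threaded
# into the next MPI call on this rank; the matching lowercase receive function is not
# shown in this excerpt. Assumes an MPI launch and that `send` is in scope.
import jax.numpy as jnp


def example_send(dest):
    token = send(jnp.ones(3), dest=dest, tag=7)
    return token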
def f(x):
    token = lax.create_token(x)
    (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
    return x @ y.T + z
def device_train_loop(optimizer, state, metrics, step, epoch):
    token = lax.create_token(step)
    optimizer, state, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond,
        device_train_loop_body,
        (optimizer, state, metrics, token, step, epoch))
    return optimizer, state, metrics, step
def f(x):
    token = lax.create_token(x)
    y, token = lax.infeed(
        token, shape=jax.ShapedArray((3, 4), jnp.float32))
    token = lax.outfeed(token, y + np.float32(1))
    return x - 1
def f(x):
    token = lax.create_token(x)
    token = lax.outfeed(token, x, partitions=(None,))
    token = lax.outfeed(token, x, partitions=(P(nr_devices, 1),))
    token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
    return x