Example #1
def scan(x, op, *, comm=None, token=None):
    """Perform a scan operation.

    Arguments:
        x: Array or scalar input to send.
        op (mpi4py.MPI.Op): The reduction operator (e.g. :obj:`mpi4py.MPI.SUM`).
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Result of the scan operation.
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    op = wrap_as_hashable(op)
    comm = wrap_as_hashable(comm)
    return tuple(mpi_scan_p.bind(x, token, op=op, comm=comm))
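A minimal usage sketch for the scan wrapper above. The import path (mpi4jax) and launching the script with mpirun are assumptions, not something stated in this example:

# Hedged usage sketch: inclusive prefix sum across ranks.
from mpi4py import MPI
import jax
import jax.numpy as jnp
from mpi4jax import scan  # assumed import path

@jax.jit
def prefix_sum(x):
    # Rank r receives the elementwise sum of inputs from ranks 0..r.
    res, token = scan(x, op=MPI.SUM)
    return res

local = jnp.full((3,), float(MPI.COMM_WORLD.Get_rank()))
print(prefix_sum(local))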
Example #2
def allreduce(x, op, *, comm=None, token=None):
    """Perform an allreduce operation.

    .. note::

       This primitive can be differentiated via :func:`jax.grad` and related functions
       if ``op`` is :obj:`mpi4py.MPI.SUM`.

    Arguments:
        x: Array or scalar input.
        op (mpi4py.MPI.Op): The reduction operator (e.g. :obj:`mpi4py.MPI.SUM`).
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Result of the allreduce operation.
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    op = wrap_as_hashable(op)
    comm = wrap_as_hashable(comm)
    return tuple(
        mpi_allreduce_p.bind(x, token, op=op, comm=comm, transpose=False))
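A usage sketch illustrating the differentiability note above (import path assumed):

from mpi4py import MPI
import jax
import jax.numpy as jnp
from mpi4jax import allreduce  # assumed import path

def global_mean(x):
    # Sum the local contributions over all ranks, then average.
    summed, _ = allreduce(x, op=MPI.SUM)
    return jnp.sum(summed) / MPI.COMM_WORLD.Get_size()

# Differentiable because op is MPI.SUM (see the note in the docstring).
grad_fn = jax.jit(jax.grad(global_mean))
print(grad_fn(jnp.ones(4)))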
Example #3
def Sendrecv(
    sendbuf,
    recvbuf,
    source,
    dest,
    sendtag=0,
    recvtag=_MPI.ANY_TAG,
    comm=_MPI.COMM_WORLD,
    status=None,
    token=None,
):
    if token is None:
        token = create_token(sendbuf)

    return mpi_sendrecv_p.bind(
        sendbuf,
        recvbuf,
        token,
        source=source,
        dest=dest,
        sendtag=sendtag,
        recvtag=recvtag,
        comm=comm,
        status=status,
    )
Example #4
 def f(x):
     token = lax.create_token(x)
     y, token = lax.infeed(token,
                           shape=jax.ShapedArray((3, 4), jnp.float32))
     token = lax.outfeed(token, y + np.float32(1))
     return x - 1 if config.omnistaging_enabled else lax.tie_in(
         token, x - 1)
Example #5
def Sendrecv(
    sendbuf,
    recvbuf,
    source,
    dest,
    sendtag=0,
    recvtag=_MPI.ANY_TAG,
    comm=_MPI.COMM_WORLD,
    status=None,
    token=None,
):
    if token is None:
        token = create_token(sendbuf)

    comm = wrap_as_hashable(comm)

    if status is not None:
        status = wrap_as_hashable(status)

    return mpi_sendrecv_p.bind(
        sendbuf,
        recvbuf,
        token,
        source=source,
        dest=dest,
        sendtag=sendtag,
        recvtag=recvtag,
        comm=comm,
        status=status,
    )
Example #6
def reduce(x, op, root, comm=None, token=None):
    """Perform a reduce operation.

    Arguments:
        x: Array or scalar input to send.
        op (mpi4py.MPI.Op): The reduction operator (e.g. :obj:`mpi4py.MPI.SUM`).
        root (int): Rank of the root MPI process.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Result of the reduce operation on root process, otherwise
              unmodified input.
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    rank = comm.Get_rank()

    op = wrap_as_hashable(op)
    comm = wrap_as_hashable(comm)
    res, token = mpi_reduce_p.bind(x, token, op=op, root=root, comm=comm)

    if rank != root:
        return x, token

    return res, token
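A short usage sketch for the reduce wrapper (import path assumed):

from mpi4py import MPI
import jax
import jax.numpy as jnp
from mpi4jax import reduce  # assumed import path

@jax.jit
def sum_to_root(x):
    # Global sum on the root rank, unmodified input everywhere else.
    res, _ = reduce(x, op=MPI.SUM, root=0)
    return res

print(sum_to_root(jnp.arange(4.0)))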
Example #7
def _dummy_remat_result(aval: core.AbstractValue):
    """A result that will be discarded"""
    if aval is core.abstract_token:
        return lax.create_token()
    else:
        return lax.broadcast(np.array(0, dtype=aval.dtype),
                             aval.shape)  # type: ignore
Example #8
def Recv(
    x,
    source=_MPI.ANY_SOURCE,
    tag=_MPI.ANY_TAG,
    comm=_MPI.COMM_WORLD,
    status=None,
    token=None,
):
    """
    Recv(x, source=_MPI.ANY_SOURCE, tag=_MPI.ANY_TAG, comm=_MPI.COMM_WORLD, status=None, token=None)

    Receives the input `x` from the source rank `source` with tag `tag`, using the
    communicator `comm`, which defaults to the world communicator.
    An optional token can be passed, which is used to force JAX to execute
    MPI operations in the correct order.
    This is particularly important if you are performing several Send/Recv
    operations, which might otherwise deadlock.

    Arguments:
        x: Array or scalar input with the desired shape and dtype.
        source: Rank of the source MPI process.
        tag: Tag of this message.
        comm: The communicator (defaults to MPI.COMM_WORLD).
        status: Status object, can be used for introspection (default=None).
        token: Token to force a sequential order in the operations (default=None).

    Returns:
        res: The received array or scalar.
        new_token: A new, modified token that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    out = mpi_recv_p.bind(x, token, source=source, tag=tag, comm=comm, status=status)
    return out
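A paired usage sketch with the Send wrapper from Example #10, assuming both wrappers are in scope and the job runs on at least two ranks:

from mpi4py import MPI
import jax.numpy as jnp

rank = MPI.COMM_WORLD.Get_rank()
buf = jnp.zeros((3,))  # defines shape and dtype of the expected message

if rank == 0:
    token = Send(buf + 1.0, dest=1, tag=7)   # returns only a new token
elif rank == 1:
    res, token = Recv(buf, source=0, tag=7)  # returns (received data, new token)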
Example #9
 def f(x):
   token = lax.create_token(x)
   (y,), token = lax.infeed(
       token, shape=(jax.ShapedArray((3, 4), jnp.float32),))
   (z,), _ = lax.infeed(
       token, shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
   return x + y + z
Example #10
def Send(x, dest, tag=0, comm=_MPI.COMM_WORLD, token=None):
    """
    Send(x, dest, tag=0, comm=_MPI.COMM_WORLD, token=None)

    Sends the input `x` to the target rank `dest` with tag `tag`, using the
    communicator `comm`, which defaults to the world communicator.
    An optional token can be passed, which is used to force JAX to execute
    MPI operations in the correct order.
    This is particularly important if you are performing several Send/Recv
    operations, which might otherwise deadlock.

    Arguments:
        x: Array or scalar input.
        dest: Rank of the target MPI process.
        tag: Tag of this message.
        comm: The communicator (defaults to MPI.COMM_WORLD).
        token: Token to force a sequential order in the operations (default=None).

    Returns:
        new_token: A new, modified token that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    out = mpi_send_p.bind(x, token, dest=dest, tag=tag, comm=comm)
    return out
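A sketch of the ordering behaviour described in the docstring: chaining the returned token forces two sends to the same destination to execute in program order (values are illustrative):

import jax.numpy as jnp

a = jnp.ones((2,))
b = jnp.zeros((2,))

token = Send(a, dest=1, tag=0)
# Passing the previous token guarantees this Send runs after the first one.
token = Send(b, dest=1, tag=1, token=token)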
Example #11
def sendrecv(
    sendbuf,
    recvbuf,
    source,
    dest,
    sendtag=0,
    recvtag=_MPI.ANY_TAG,
    comm=None,
    status=None,
    token=None,
):
    """Perform a sendrecv operation.

    .. warning::

        Unlike mpi4py's sendrecv, this returns a *new* array with the received data.

    Arguments:
        sendbuf: Array or scalar input to send.
        recvbuf: Array or scalar input with the correct shape and dtype. This can
           contain arbitrary data and will not be overwritten.
        source (int): Rank of the source MPI process.
        dest (int): Rank of the destination MPI process.
        sendtag (int): Tag of this message for sending.
        recvtag (int): Tag of this message for receiving.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        status (mpi4py.MPI.Status): Status object, can be used for introspection.
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(sendbuf)

    if comm is None:
        comm = get_default_comm()

    comm = wrap_as_hashable(comm)

    if status is not None:
        status = wrap_as_hashable(status)

    return mpi_sendrecv_p.bind(
        sendbuf,
        recvbuf,
        token,
        source=source,
        dest=dest,
        sendtag=sendtag,
        recvtag=recvtag,
        comm=comm,
        status=status,
    )
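A usage sketch implementing a ring shift with the sendrecv wrapper above (import path assumed):

from mpi4py import MPI
import jax
import jax.numpy as jnp
from mpi4jax import sendrecv  # assumed import path

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
left, right = (rank - 1) % size, (rank + 1) % size

@jax.jit
def ring_shift(x):
    # Send to the right neighbour while receiving from the left one.
    recvbuf = jnp.zeros_like(x)  # only shape and dtype matter, the data is not read
    res, _ = sendrecv(x, recvbuf, source=left, dest=right)
    return res

print(ring_shift(jnp.full((2,), float(rank))))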
Example #12
 def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
     # Create symbolic token for threading infeed data.
     token = lax.create_token(step)
     # Run on-device loop.
     optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
         device_train_loop_cond, device_train_loop_body,
         (optimizer, dropout_rngs, metrics, token, step, epoch))
     return optimizer, dropout_rngs, metrics, step
Example #13
 def device_train_loop(optimizer, state, metrics, step, loop):
     token = lax.create_token(step)
     optimizer, state, metrics, _, step, _ = lax.while_loop(
         device_train_loop_cond, device_train_loop_body,
         (optimizer, state, metrics, token, step, loop))
     state = sync_batchnorm_stats(state)
     metrics = allreduce_metrics(metrics)
     return optimizer, state, metrics, step
Example #14
 def host_loop_eval_step(model, state, metrics):
     token = lax.create_token(metrics['samples'])
     batch, token = lax.infeed(
         token,
         shape=(jax.ShapedArray(eval_input_shape, model_dtype),
                jax.ShapedArray((device_eval_batch_size, ), jnp.int32)))
     metrics = eval_step(model, state, batch, metrics, image_format,
                         space_to_depth)
     return metrics
Example #15
 def device_train_loop(optimizer, dropout_rng, total_loss, lm_loss,
                       sentence_loss, step, epoch, num_steps_per_epoch):
   """Device training loop."""
   token = lax.create_token(step)
   (optimizer, total_loss, lm_loss, sentence_loss, dropout_rng,
    _, step, epoch, num_steps_per_epoch) = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, token,
         step, epoch, num_steps_per_epoch))
   return optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, step
Example #16
    def f_for_jit(x):
      token = lax.create_token(x)
      (y,), token = lax.infeed(
          token, shape=(jax.ShapedArray(x.shape, np.float32),))
      (z,), token = lax.infeed(
          token, shape=(jax.ShapedArray(x.shape, np.float32),))
      (w,), token = lax.infeed(
          token, shape=(jax.ShapedArray(x.shape, np.float32),))

      return x + y + z + w
Example #17
def scatter(
    x,
    root,
    *,
    comm=None,
    token=None,
):
    """Perform a scatter operation.

    .. warning::

        Unlike mpi4py's scatter, this returns a *new* array with the received data.

    .. warning::

        The expected shape of the first input varies between ranks. On the root process,
        it is ``(nproc, *input_shape)``. On all other processes, it is ``input_shape``.

    Arguments:
        x: Array or scalar input with the correct shape and dtype. On the root process,
           this contains the data to send, and its first axis must have size ``nproc``.
           On non-root processes, this may contain arbitrary data and will not be
           overwritten.
        root (int): Rank of the root MPI process.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    rank = comm.Get_rank()
    if rank == root:
        size = comm.Get_size()
        if x.shape[0] != size:
            raise ValueError("Scatter input must have shape (nproc, ...)")

    comm = wrap_as_hashable(comm)

    return tuple(mpi_scatter_p.bind(
        x,
        token,
        root=root,
        comm=comm,
    ))
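A sketch of the shape convention described in the warnings above: the root passes (nproc, *shape), all other ranks pass shape (import path assumed):

from mpi4py import MPI
import jax.numpy as jnp
from mpi4jax import scatter  # assumed import path

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

if rank == 0:
    x = jnp.arange(size * 3.0).reshape(size, 3)  # one row per process
else:
    x = jnp.zeros((3,))  # placeholder with the shape/dtype of a single row

chunk, token = scatter(x, root=0)  # every rank receives one row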
Example #18
 def host_loop_train_step(optimizer, state, metrics):
     token = lax.create_token(optimizer.state[0].step)
     batch, token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   train_input_shape, model_dtype),
                                      jax.ShapedArray((device_batch_size, ),
                                                      jnp.int32)))
     optimizer, state, metrics = train_step(optimizer, state, batch,
                                            metrics, learning_rate_fn,
                                            image_format, space_to_depth)
     return optimizer, state, metrics
Example #19
def gather(
    x,
    root,
    *,
    comm=None,
    token=None,
):
    """Perform a gather operation.

    .. warning::

       ``x`` must have the same shape and dtype on all processes.

    .. warning::

        The shape of the returned data varies between ranks. On the root process,
        it is ``(nproc, *input_shape)``. On all other processes the output is
        identical to the input.

    Arguments:
        x: Array or scalar input to send.
        root (int): Rank of the root MPI process.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data on root process, otherwise unmodified input.
            - A new, modified token that depends on this operation.
    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    rank = comm.Get_rank()
    comm = wrap_as_hashable(comm)

    res, token = mpi_gather_p.bind(
        x,
        token,
        root=root,
        comm=comm,
    )

    if rank != root:
        return (x, token)

    return (res, token)
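A usage sketch for gather, mirroring the shape warning above (import path assumed):

from mpi4py import MPI
import jax.numpy as jnp
from mpi4jax import gather  # assumed import path

rank = MPI.COMM_WORLD.Get_rank()

res, token = gather(jnp.full((2,), float(rank)), root=0)
# On rank 0, res has shape (nproc, 2); on all other ranks it is the unmodified input.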
Example #20
 def f_for_pjit(x):
     token = lax.create_token(x)
     # A replicated infeed
     (y, ), token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   x.shape, np.float32), ),
                               partitions=(None, ))
     # An infeed sharded on first axis
     (z, ), token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   x.shape, np.float32), ),
                               partitions=(P(nr_devices, 1), ))
     # An infeed sharded on second axis
     (w, ), token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   x.shape, np.float32), ),
                               partitions=(P(1, nr_devices), ))
     return x + y + z + w
Example #21
def barrier(*, comm=None, token=None):
    """Perform a barrier operation.

    Arguments:
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Token:
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token()

    if comm is None:
        comm = get_default_comm()

    comm = wrap_as_hashable(comm)
    return mpi_barrier_p.bind(token, comm=comm)
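A sketch of ordering a barrier after another operation by threading the token, which is what the token argument is for (import path assumed):

from mpi4py import MPI
import jax.numpy as jnp
from mpi4jax import allreduce, barrier  # assumed import path

x = jnp.ones(3)
res, token = allreduce(x, op=MPI.SUM)
# Passing the token guarantees the barrier executes after the allreduce.
token = barrier(token=token)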
Example #22
def alltoall(
    x,
    *,
    comm=None,
    token=None,
):
    """Perform an alltoall operation.

    Arguments:
        x: Array input to send. First axis must have size ``nproc``.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    size = comm.Get_size()
    if x.shape[0] != size:
        raise ValueError("Alltoall input must have shape (nproc, ...)")

    comm = wrap_as_hashable(comm)

    return tuple(mpi_alltoall_p.bind(
        x,
        token,
        comm=comm,
    ))
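A usage sketch for alltoall, following the (nproc, ...) input convention enforced above (import path assumed):

from mpi4py import MPI
import jax.numpy as jnp
from mpi4jax import alltoall  # assumed import path

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

x = jnp.full((size, 2), float(rank))  # row i is destined for rank i
res, token = alltoall(x)              # afterwards, row i of res came from rank i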
Example #23
def allgather(
    x,
    *,
    comm=None,
    token=None,
):
    """Perform an allgather operation.

    .. warning::

       ``x`` must have the same shape and dtype on all processes.

    Arguments:
        x: Array or scalar input to send.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    comm = wrap_as_hashable(comm)

    return tuple(mpi_allgather_p.bind(
        x,
        token,
        comm=comm,
    ))
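A usage sketch for allgather (import path assumed):

from mpi4py import MPI
import jax.numpy as jnp
from mpi4jax import allgather  # assumed import path

rank = MPI.COMM_WORLD.Get_rank()
res, token = allgather(jnp.array([float(rank)]))
# res has shape (nproc, 1) and is identical on every rank.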
Example #24
def bcast(x, root, comm=None, token=None):
    """Perform a bcast (broadcast) operation.

    .. warning::

        Unlike mpi4py's bcast, this returns a *new* array with the received data.

    Arguments:
        x: Array or scalar input. Data is only read on root process. On non-root
           processes, this is used to determine the shape and dtype of the result.
        root (int): The process to use as source.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Tuple[DeviceArray, Token]:
            - Received data.
            - A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    rank = comm.Get_rank()

    comm = wrap_as_hashable(comm)
    res, token = mpi_bcast_p.bind(x, token, root=root, comm=comm)

    if rank == root:
        return x, token

    return res, token
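A usage sketch for bcast, illustrating that non-root ranks only supply shape and dtype (import path assumed):

from mpi4py import MPI
import jax.numpy as jnp
from mpi4jax import bcast  # assumed import path

rank = MPI.COMM_WORLD.Get_rank()

if rank == 0:
    x = jnp.arange(5.0)  # the data to broadcast
else:
    x = jnp.zeros((5,))  # only shape and dtype are used on non-root ranks

res, token = bcast(x, root=0)  # every rank now holds the rank-0 data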
Example #25
def Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None):
    """
    Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None)

    Performs the Allreduce operation `op` on the input `x` using the
    communicator `comm`, which defaults to the world communicator.
    An optional token can be passed, which is used to force JAX to execute
    MPI operations in the correct order.

    Arguments:
        x: Array or scalar input.
        op: The reduction operation `MPI.Op` (e.g. MPI.SUM).
        comm: The communicator (defaults to MPI.COMM_WORLD).
        token: Token to force a sequential order in the operations (default=None).

    Returns:
        res: Result of the allreduce operation.
        new_token: A new, modified token that depends on this operation.
            This token can be ignored if `res` already forces the required data dependency.
    """
    if token is None:
        token = create_token(x)

    return mpi_allreduce_p.bind(x, token, op=op, comm=comm)
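A short sketch of the note in the docstring above: the returned token can be discarded when the result itself carries the data dependency (values are illustrative):

from mpi4py import MPI
import jax.numpy as jnp

x = jnp.ones(4)
# The token is ignored here because res already depends on the operation.
res, _ = Allreduce(x, op=MPI.SUM)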
Example #26
def send(x, dest, tag=0, comm=None, token=None):
    """Perform a send operation.

    Arguments:
        x: Array or scalar input to send.
        dest (int): Rank of the destination MPI process.
        tag (int): Tag of this message.
        comm (mpi4py.MPI.Comm): The MPI communicator to use (defaults to
            a clone of :obj:`COMM_WORLD`).
        token (Token): XLA token to use to ensure correct execution order.
            If not given, a new token is generated.

    Returns:
        Token: A new, modified token that depends on this operation.

    """
    if token is None:
        token = create_token(x)

    if comm is None:
        comm = get_default_comm()

    comm = wrap_as_hashable(comm)
    return mpi_send_p.bind(x, token, dest=dest, tag=tag, comm=comm)
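A sketch of a deadlock-free two-rank exchange that threads the returned token to fix the operation order. It assumes a matching recv wrapper with the same calling convention as the Recv shown in Example #8; that import is an assumption:

from mpi4py import MPI
import jax.numpy as jnp
from mpi4jax import recv, send  # assumed import path; recv assumed to mirror Recv above

rank = MPI.COMM_WORLD.Get_rank()
x = jnp.full((3,), float(rank))
buf = jnp.zeros_like(x)  # shape and dtype of the expected message

if rank == 0:
    token = send(x, dest=1)
    other, token = recv(buf, source=1, token=token)
elif rank == 1:
    other, token = recv(buf, source=0)
    token = send(x, dest=0, token=token)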
Example #27
 def f(x):
     token = lax.create_token(x)
     (y, z), token = lax.infeed(token,
                                infeed_shapes,
                                partitions=infeed_parts)
     return x @ y.T + z
Example #28
 def device_train_loop(optimizer, state, metrics, step, epoch):
     token = lax.create_token(step)
     optimizer, state, metrics, _, step, _ = lax.while_loop(
         device_train_loop_cond, device_train_loop_body,
         (optimizer, state, metrics, token, step, epoch))
     return optimizer, state, metrics, step
Example #29
 def f(x):
   token = lax.create_token(x)
   y, token = lax.infeed(
       token, shape=jax.ShapedArray((3, 4), jnp.float32))
   token = lax.outfeed(token, y + np.float32(1))
   return x - 1
Example #30
 def f(x):
   token = lax.create_token(x)
   token = lax.outfeed(token, x, partitions=(None,))
   token = lax.outfeed(token, x, partitions=(P(nr_devices, 1),))
   token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
   return x