예제 #1
0
def join(device=-1):
    """Signal that this rank has finished processing its data.

    Ranks that have not yet called join() keep taking part in allreduce
    operations. The calling Python thread blocks until every rank has joined.

    Arguments:
        device: Id of the device on which temporary zero tensors are created
            (default -1, meaning CPU).

    Returns:
        Id of the rank that joined last.

    Raises:
        NotImplementedError: When ``_v2_api`` is falsy (per the error message,
            PyTorch < 1.0).
    """
    if _v2_api:
        return mpi_lib.horovod_torch_join(device)
    raise NotImplementedError("Join Op is not supported for PyTorch < 1.0")
예제 #2
0
def join(device=-1):
    """Signal that this rank has finished processing its data.

    All ranks that did not call join() continue to process allreduce
    operations. This function blocks the Python thread until all ranks join.

    Arguments:
        device: An id of the device to create temporary zero tensors
            (default -1, CPU).

    Returns:
        Id of the rank that joined last.

    Raises:
        HorovodInternalError: If the native join call raises RuntimeError; the
            original error is attached as ``__cause__``.
    """
    try:
        return mpi_lib.horovod_torch_join(device)
    except RuntimeError as e:
        # Chain explicitly so tracebacks report the native failure as the
        # direct cause rather than as an incidental context.
        raise HorovodInternalError(e) from e
예제 #3
0
파일: mpi_ops.py 프로젝트: rongou/horovod
def join(device: int = -1) -> int:
    """Signal that this rank has finished processing its data.

    All ranks that did not call join() continue to process allreduce
    operations. This function blocks the Python thread until all ranks join.

    Arguments:
        device: An id of the device to create temporary zero tensors
            (default -1, CPU).

    Returns:
        Id of the rank that joined last.

    Raises:
        HorovodInternalError: If the native join call raises RuntimeError; the
            original error is attached as ``__cause__``.
    """
    # Scalar int tensor on CPU initialised to -1 and handed to the native op;
    # presumably the op writes the last-joined rank id into it — confirm.
    output = torch.tensor(-1, dtype=torch.int, device=torch.device("cpu"))
    try:
        handle = mpi_lib.horovod_torch_join(output, device)
    except RuntimeError as e:
        # Chain explicitly so tracebacks report the native failure as the cause.
        raise HorovodInternalError(e) from e

    # Associate the output tensor with the async handle; presumably
    # synchronize() looks it up here to return it — verify against synchronize.
    _handle_map[handle] = (None, output)

    return synchronize(handle).item()