Esempio n. 1
0
def _get_rank_helper(group, backend):
    """
    Look up the rank id of the calling process in a communication group.

    Processes whose role is parameter server or scheduler get 0 back
    without querying any backend. Otherwise the query goes to the HCCL or
    NCCL (via ``mpi``) backend.

    Args:
        group (str): Name of the communication group to query.
        backend (str): Communication backend to use, e.g. "hccl".

    Raises:
        ValueError: If ``backend`` is neither HCCL nor NCCL.

    Returns:
        Integer. The local rank id of the calling process.
    """
    # Non-worker roles: answer immediately, no backend call is made.
    if _is_role_pserver() or _is_role_sched():
        return 0

    if backend == Backend.HCCL:
        # The world group is queried through the argument-less overload;
        # any other group is passed explicitly.
        if group == HCCL_WORLD_COMM_GROUP:
            return hccl.get_rank_id()
        return hccl.get_rank_id(group)

    if backend == Backend.NCCL:
        return mpi.get_rank_id(group)

    raise ValueError("Invalid backend: '{}'".format(backend))
Esempio n. 2
0
def _get_rank_helper(group, backend):
    """
    Look up the rank id of the calling process in a communication group.

    When ``MS_ROLE`` is "MS_PSERVER" or "MS_SCHED" the function returns 0
    without querying any backend. Otherwise the query goes to the HCCL or
    NCCL (via ``mpi``) backend.

    Args:
        group (str): Name of the communication group to query.
        backend (str): Communication backend to use, e.g. "hccl".

    Raises:
        ValueError: If ``backend`` is neither HCCL nor NCCL.
        RuntimeError: If the NCCL backend is asked about any group other
            than the NCCL world group.

    Returns:
        Integer. The local rank id of the calling process.
    """
    # Non-worker roles: answer immediately, no backend call is made.
    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
        return 0

    if backend == Backend.HCCL:
        in_world_group = group == HCCL_WORLD_COMM_GROUP
        return hccl.get_rank_id() if in_world_group else hccl.get_rank_id(group)

    if backend == Backend.NCCL:
        # Only the world group can be queried through mpi here.
        if group == NCCL_WORLD_COMM_GROUP:
            return mpi.get_rank_id()
        raise RuntimeError(
            "Nccl doesn't support get_rank_id by user group now.")

    raise ValueError("Invalid backend: '{}'".format(backend))
Esempio n. 3
0
def test_net_reduce_scatter():
    """
    Check a ReduceScatter followed by an AllGather across ranks 0..2.

    Every rank feeds the same input ``arange(12) * 0.1``. The expected values
    below imply that ``Net`` sums across 3 ranks and hands each rank a
    contiguous 4-element slice (assumption inferred from the expectations,
    not from ``Net`` itself -- TODO confirm). Each rank checks its slice of
    ``arange(12) * 0.3``. Gathering the slices with ``AllGatherNet`` must
    then reproduce the full ``arange(12) * 0.3``.
    """
    tolerance = 1.0e-6
    slice_len = 4
    x = np.arange(12).astype(np.float32) * 0.1

    reducescatter = Net()
    rankid = mpi.get_rank_id()
    print("self rankid:", rankid)
    output = reducescatter(Tensor(x, mstype.float32))
    print("output:\n", output)

    # Expectations exist only for ranks 0..2. Fail with a clear message on
    # any other rank instead of an UnboundLocalError on expect_result. The
    # check runs after the collective call so the other ranks are not left
    # waiting on this one.
    assert rankid in (0, 1, 2), "unexpected rank id: {}".format(rankid)
    start = slice_len * rankid
    expect_result = np.arange(start, start + slice_len).astype(np.float32) * 0.3
    diff = abs(output.asnumpy() - expect_result)
    assert np.all(diff < tolerance)

    allgather = AllGatherNet()
    allgather_output = allgather(output)
    print("allgather result:\n", allgather_output)
    expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
    diff = abs(allgather_output.asnumpy() - expect_allgather_result)
    assert np.all(diff < tolerance)