def _get_rank_helper(group, backend):
    """
    Look up the rank id of the calling process for ``group`` on ``backend``.

    Processes running in the parameter-server or scheduler role short-circuit
    to rank 0 without querying any communication backend.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If backend is invalid.

    Returns:
        Integer. The rank id reported by the backend for ``group``.
        NOTE(review): earlier docs described this as the *local* rank id;
        confirm against the hccl/mpi ``get_rank_id`` semantics.
    """
    # Parameter-server / scheduler roles never reach a backend query.
    if _is_role_pserver() or _is_role_sched():
        return 0

    if backend == Backend.HCCL:
        # The world group is queried through the no-argument form.
        if group == HCCL_WORLD_COMM_GROUP:
            return hccl.get_rank_id()
        return hccl.get_rank_id(group)

    if backend == Backend.NCCL:
        # NCCL rank lookups are routed through the mpi helper for every group.
        return mpi.get_rank_id(group)

    raise ValueError(f"Invalid backend: '{backend}'")
def _get_rank_helper(group, backend):
    """
    Look up the rank id of the calling process for ``group`` on ``backend``.

    When ``MS_ROLE`` is "MS_PSERVER" or "MS_SCHED", rank 0 is returned
    without querying any communication backend.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If backend is invalid.
        RuntimeError: If backend is NCCL and ``group`` is not the NCCL world
            group (user-defined groups are not supported for NCCL here).

    Returns:
        Integer. The rank id reported by the backend for ``group``.
    """
    # Parameter-server / scheduler roles never reach a backend query.
    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
        return 0

    if backend == Backend.HCCL:
        # The world group is queried through the no-argument form.
        return (hccl.get_rank_id() if group == HCCL_WORLD_COMM_GROUP
                else hccl.get_rank_id(group))

    if backend == Backend.NCCL:
        if group == NCCL_WORLD_COMM_GROUP:
            return mpi.get_rank_id()
        raise RuntimeError(
            "Nccl doesn't support get_rank_id by user group now.")

    raise ValueError("Invalid backend: '{}'".format(backend))
def test_net_reduce_scatter():
    """
    Check ReduceScatter followed by AllGather across the participating ranks.

    Every rank feeds the same input ``arange(12) * 0.1``. The expected values
    (``arange(12) * 0.3``, i.e. 3 * 0.1) and the handled rank ids 0..2 mean the
    test assumes exactly three ranks. Each rank should keep one contiguous
    4-element slice of the reduced tensor, and AllGather should reassemble the
    full 12-element reduced tensor on every rank.
    """
    num_ranks = 3       # implied by the 0.3 (= 3 * 0.1) expected scale
    chunk = 4           # elements each rank keeps after ReduceScatter
    tolerance = 1.0e-6

    x = np.arange(num_ranks * chunk).astype(np.float32) * 0.1
    reducescatter = Net()
    rankid = mpi.get_rank_id()
    print("self rankid:", rankid)
    # Run the collective before any rank-specific checks so that every rank
    # participates even if this one is about to fail.
    output = reducescatter(Tensor(x, mstype.float32))
    print("output:\n", output)

    # Previously an unexpected rank left `expect_result` unbound and raised a
    # confusing UnboundLocalError; fail with an explicit message instead.
    assert 0 <= rankid < num_ranks, "unexpected rank id: {}".format(rankid)
    start = rankid * chunk
    expect_result = np.arange(start, start + chunk).astype(np.float32) * 0.3
    diff = np.abs(output.asnumpy() - expect_result)
    assert np.all(diff < tolerance)

    allgather = AllGatherNet()
    allgather_output = allgather(output)
    print("allgather result:\n", allgather_output)
    expect_allgather_result = (
        np.arange(num_ranks * chunk).astype(np.float32) * 0.3)
    diff = np.abs(allgather_output.asnumpy() - expect_allgather_result)
    assert np.all(diff < tolerance)