Example #1
    def __setstate__(self, state):
        self._sharded_tensor_id = None
        if not distributed_c10d.is_initialized():
            raise RuntimeError(
                'Need to initialize default process group using '
                '"init_process_group" before loading ShardedTensor')

        self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs = state

        # Setup process group
        if _CURRENT_PROCESS_GROUP is None:
            self._process_group = distributed_c10d._get_default_group()
        else:
            self._process_group = _CURRENT_PROCESS_GROUP

        # Validate process group.
        local_rank = distributed_c10d.get_rank(self._process_group)
        if pg_state.local_rank != local_rank:
            raise RuntimeError(
                f'Local rank at save time was {pg_state.local_rank}, but at '
                f'load time was {local_rank}')

        global_rank = distributed_c10d.get_rank()
        if pg_state.global_rank != global_rank:
            raise RuntimeError(
                f'Global rank at save time was {pg_state.global_rank}, but at '
                f'load time was {global_rank}')

        local_world_size = distributed_c10d.get_world_size(self._process_group)
        if pg_state.local_world_size != local_world_size:
            raise RuntimeError(
                f'Local world size at save time was {pg_state.local_world_size}, '
                f'but at load time was {local_world_size}')

        global_world_size = distributed_c10d.get_world_size()
        if pg_state.global_world_size != global_world_size:
            raise RuntimeError(
                f'Global world size at save time was {pg_state.global_world_size}, '
                f'but at load time was {global_world_size}')

        self._post_init()
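
A quick loading sketch for the guard above: the default process group must exist before the ShardedTensor is unpickled, or the RuntimeError in __setstate__ fires. This assumes a torchrun-style launch with the same rank/world-size layout used at save time; the checkpoint path is a placeholder.

import torch
import torch.distributed as dist

# torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for us.
dist.init_process_group(backend="nccl")

# torch.load unpickles the ShardedTensor; __setstate__ then re-checks local/global
# rank and world size against the values recorded at save time.
sharded = torch.load("sharded_checkpoint.pt")  # hypothetical path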
Example #2

def fidelity_check_hierarchy_allgather(local_rank, inter_shard_group,
                                       intra_shard_group):
    """"""
    torch.manual_seed(10)
    n = 10
    dtype = torch.half
    partition_nelem = 10e6  # 10M elements
    world_size = c10d.get_world_size()
    rank = c10d.get_rank()

    nelem = int(partition_nelem * world_size)

    # All ranks seed identically, so every rank generates the same test tensors.
    test_tensors = []
    for _ in range(n):
        test_tensors.append(
            torch.rand(nelem, dtype=dtype, device=f'cuda:{local_rank}'))

    # get inputs and outputs
    outputs_for_vanilla, inputs_for_vanilla = _get_outputs_inputs(
        test_tensors, rank, world_size)
    c10d._all_gather_base_coalesced(outputs_for_vanilla, inputs_for_vanilla)
    for i, j in zip(outputs_for_vanilla, test_tensors):
        print(f"vanilla all close {torch.allclose(i, j)}")

    # torch.cuda.synchronize()

    comm_stream = torch.cuda.Stream()
    outputs_for_hierarchy, inputs_for_hierarchy = _get_outputs_inputs(
        test_tensors, rank, world_size)
    # with torch.cuda.stream(comm_stream):
    # with torch.cuda.stream(torch.cuda.default_stream()):
    with torch.cuda.stream(torch.cuda.current_stream()):
        handle = _hierarchy_allgather(outputs_for_hierarchy,
                                      inputs_for_hierarchy, intra_shard_group,
                                      inter_shard_group, world_size,
                                      local_rank)
        handle.wait()

    for i, j in zip(outputs_for_hierarchy, test_tensors):
        print(f"hierarchy all close {torch.allclose(i, j)}")
Example #3
def _parse_and_validate_remote_device(pg, remote_device):
    # Resolve a remote device spec into (rank or worker id, device), validating
    # the rank against the process group and worker names against the RPC agent.
    worker_name = remote_device.worker_name()
    rank = remote_device.rank()
    device = remote_device.device()

    # Validate rank, skip validation if rank is not part of process group.
    if not distributed_c10d._rank_not_in_group(pg):
        if rank is not None and (rank < 0 or rank >= distributed_c10d.get_world_size(pg)):
            raise ValueError(f'Invalid rank: {rank}')

    if worker_name is not None:
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(f'RPC framework needs to be initialized for using worker names: {worker_name}')

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == worker_name:
                return worker.id, device

        raise ValueError(f'Invalid worker name: {worker_name}')

    return rank, device
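
A minimal single-process sketch of how this helper is exercised, assuming a gloo-backed default group; a "rank:0/cpu" spec resolves by rank, while a "worker0/cpu" spec would instead be resolved through the RPC agent's worker table.

import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.distributed.remote_device import _remote_device

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
pg = distributed_c10d._get_default_group()

# Rank-based spec: the rank is validated against the process group's world size.
rank, device = _parse_and_validate_remote_device(pg, _remote_device("rank:0/cpu"))
print(rank, device)  # 0 cpu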