Example #1
0
    def __getstate__(self):
        """Return the picklable state of this sharded tensor.

        Besides the shards, metadata, sharding spec and ``_init_rrefs``, the
        state carries a ``ShardedTensor.ProcessGroupState`` snapshot of the
        rank and world size, both within ``self._process_group`` and in the
        global group, so that ``__setstate__`` can compare them at load time.

        Returns:
            ``(local_shards, metadata, pg_state, sharding_spec, init_rrefs)``.
        """
        group = self._process_group
        # Query order matches the ProcessGroupState field order:
        # local rank, global rank, local world size, global world size.
        rank_in_group = distributed_c10d.get_rank(group)
        rank_global = distributed_c10d.get_rank()
        size_in_group = distributed_c10d.get_world_size(group)
        size_global = distributed_c10d.get_world_size()
        pg_state = ShardedTensor.ProcessGroupState(
            rank_in_group, rank_global, size_in_group, size_global)

        return (
            self._local_shards,
            self._metadata,
            pg_state,
            self._sharding_spec,
            self._init_rrefs,
        )
Example #2
0
def allreduce(data: torch.Tensor):
    """All-reduce ``data`` in place and return this rank's chunk of the result.

    ``data`` is flattened with ``view(-1)``, so it must be viewable as 1-D
    (e.g. contiguous). It is reduced in place with ``c10d.all_reduce``'s
    default reduce op. The flattened result is conceptually split into
    ``world_size`` chunks of ``numel // world_size`` elements, and the chunk
    at index ``rank`` is returned. Any trailing elements beyond
    ``chunk_sz * world_size`` are reduced but never returned.

    Args:
        data: Tensor to reduce; modified in place.

    Returns:
        A 1-D view into ``data`` holding this rank's chunk.
    """
    flat = data.view(-1)
    c10d.all_reduce(flat)
    rank = c10d.get_rank()
    world_size = c10d.get_world_size()
    # Integer floor division: avoids the float round-trip of int(a / b),
    # which loses precision once numel exceeds 2**53.
    chunk_sz = flat.numel() // world_size
    return flat.narrow(0, rank * chunk_sz, chunk_sz)
Example #3
0
def reduce_scatter(data: torch.Tensor):
    """Reduce-scatter ``data`` in place and return this rank's chunk.

    ``data`` is flattened with ``view(-1)`` and treated as ``world_size``
    chunks of ``numel // world_size`` elements. The output buffer handed to
    ``c10d._reduce_scatter_base`` is this rank's own chunk of the input
    (offset ``rank * chunk_sz``), so the reduced values overwrite that slice
    of ``data`` directly. This output-aliases-input layout matches NCCL's
    in-place reduce-scatter convention
    (``recvbuff == sendbuff + rank * recvcount``).

    NOTE(review): ``numel`` is presumably expected to be divisible by
    ``world_size``. Otherwise the input is larger than
    ``world_size * chunk_sz``, and the collective will likely reject it.
    TODO confirm.

    Args:
        data: Tensor to reduce; modified in place.

    Returns:
        A 1-D view into ``data`` holding this rank's reduced chunk.
    """
    flat = data.view(-1)
    rank = c10d.get_rank()
    world_size = c10d.get_world_size()
    # Integer floor division instead of int(numel / world_size).
    chunk_sz = flat.numel() // world_size
    chunk = flat.narrow(0, rank * chunk_sz, chunk_sz)
    c10d._reduce_scatter_base(chunk, flat)
    return chunk
Example #4
0
    def __setstate__(self, state):
        """Rebuild this ShardedTensor from the tuple returned by ``__getstate__``.

        The default process group must already be initialized. The tensor is
        attached to the module-level ``_CURRENT_PROCESS_GROUP`` when one is
        set, otherwise to the default group. The rank / world-size snapshot
        saved in ``pg_state`` is then compared with the live values before
        ``self._post_init()`` is called.

        Args:
            state: ``(local_shards, metadata, pg_state, sharding_spec,
                init_rrefs)``. ``pg_state`` must expose ``local_rank``,
                ``global_rank``, ``local_world_size`` and
                ``global_world_size``.

        Raises:
            RuntimeError: if the default process group is not initialized,
                or if any saved rank or world size differs from the current
                one.
        """
        self._sharded_tensor_id = None
        if not distributed_c10d.is_initialized():
            raise RuntimeError(
                'Need to initialize default process group using '
                '"init_process_group" before loading ShardedTensor')

        (self._local_shards,
         self._metadata,
         pg_state,
         self._sharding_spec,
         self._init_rrefs) = state

        # The module-level override is only read here, so it resolves to the
        # global without a ``global`` statement.
        override_pg = _CURRENT_PROCESS_GROUP
        if override_pg is not None:
            self._process_group = override_pg
        else:
            self._process_group = distributed_c10d._get_default_group()

        # Each entry is (message label, pg_state attribute, live-value thunk).
        # Live values are fetched lazily, one check at a time, so validation
        # stops at the first mismatch in the order: local rank, global rank,
        # local world size, global world size.
        pg = self._process_group
        checks = (
            ('Local rank', 'local_rank',
             lambda: distributed_c10d.get_rank(pg)),
            ('Global rank', 'global_rank',
             lambda: distributed_c10d.get_rank()),
            ('Local world size', 'local_world_size',
             lambda: distributed_c10d.get_world_size(pg)),
            ('Global world size', 'global_world_size',
             lambda: distributed_c10d.get_world_size()),
        )
        for label, attr, live_value in checks:
            current = live_value()
            saved = getattr(pg_state, attr)
            if saved != current:
                raise RuntimeError(
                    f'{label} at save time was {saved}, '
                    f'but at load time was {current}')

        self._post_init()
def fidelity_check_hierarchy_allgather(local_rank, inter_shard_group,
                                       intra_shard_group):
    """Check ``_hierarchy_allgather`` against ``c10d._all_gather_base_coalesced``.

    Every rank seeds torch with the same value and generates ``n`` random
    fp16 tensors of ``partition_nelem * world_size`` elements on
    ``cuda:{local_rank}``. Assuming identical devices, every rank therefore
    holds the same full tensors, which serve as the expected all-gather
    result.

    Input/output buffers are built twice with ``_get_outputs_inputs``:
    - once for the coalesced all-gather;
    - once for ``_hierarchy_allgather``, which uses the intra- and inter-shard
      groups.

    For each gathered output, ``torch.allclose`` against the reference
    tensor is printed.

    Args:
        local_rank: CUDA device index used for the test tensors. It is also
            forwarded to ``_hierarchy_allgather``.
        inter_shard_group: Process group forwarded to
            ``_hierarchy_allgather``.
        intra_shard_group: Process group forwarded to
            ``_hierarchy_allgather``.
    """
    torch.manual_seed(10)
    n = 10
    dtype = torch.half
    partition_nelem = 10_000_000  # 10M elements per rank
    world_size = c10d.get_world_size()
    rank = c10d.get_rank()
    device = f'cuda:{local_rank}'

    nelem = partition_nelem * world_size

    # Same seed on every rank, so every rank generates the same full tensors.
    test_tensors = [
        torch.rand(nelem, dtype=dtype, device=device) for _ in range(n)
    ]

    # Baseline: coalesced all-gather.
    outputs_for_vanila, inputs_for_vanila = _get_outputs_inputs(
        test_tensors, rank, world_size)
    c10d._all_gather_base_coalesced(outputs_for_vanila, inputs_for_vanila)
    for i, j in zip(outputs_for_vanila, test_tensors):
        print(f"vanila all close {torch.allclose(i, j)}")

    # Hierarchical all-gather, issued and waited on from the current stream.
    outputs_for_hierarchy, inputs_for_hierarchy = _get_outputs_inputs(
        test_tensors, rank, world_size)
    handle = _hierarchy_allgather(outputs_for_hierarchy,
                                  inputs_for_hierarchy, intra_shard_group,
                                  inter_shard_group, world_size,
                                  local_rank)
    handle.wait()

    for i, j in zip(outputs_for_hierarchy, test_tensors):
        print(f"hierarchy all close {torch.allclose(i, j)}")
Example #6
0
def main():
    """Compare ``reduce_scatter`` against ``allreduce`` on the NCCL backend.

    Command-line options:
        --data-type: ``bfloat16`` (default), ``float16``/``fp16`` or
            ``float32``/``fp32``.
        --local_rank: CUDA device index for this process. When omitted, the
            ``LOCAL_RANK`` environment variable is used, falling back to the
            global rank.
        --report-rank: Rank that prints the comparison. Defaults to
            ``min(7, world_size - 1)``, i.e. rank 7 whenever it exists.

    Each rank builds two identical tensors from ``data[rank]``. ``data`` is a
    module-level value defined elsewhere in the file; presumably it holds one
    sequence of numbers per rank (TODO confirm). One tensor goes through
    ``reduce_scatter`` and the other through ``allreduce``. The reporting rank
    then prints the min/max element-wise difference and the norm of the
    difference between the two resulting chunks.

    Raises:
        RuntimeError: if ``--data-type`` is not supported.
    """
    import os

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--data-type', default='bfloat16', type=str)
    arg_parser.add_argument('--local_rank', type=int)
    arg_parser.add_argument('--report-rank', type=int, default=None)
    args = arg_parser.parse_args()

    dtypes = {
        'bfloat16': torch.bfloat16,
        'float16': torch.half,
        'fp16': torch.half,
        'float32': torch.float32,
        'fp32': torch.float32,
    }
    data_type = dtypes.get(args.data_type)
    if data_type is None:
        # Validate before joining the process group, so a typo fails fast.
        raise RuntimeError(f"unsupported data type {args.data_type}")

    c10d.init_process_group(backend='nccl')
    rank = c10d.get_rank()
    world_size = c10d.get_world_size()

    # The CUDA device must be indexed by the *local* rank; the global rank
    # is only a valid device index on a single node.
    if args.local_rank is not None:
        local_rank = args.local_rank
    else:
        local_rank = int(os.environ.get('LOCAL_RANK', rank))
    torch.cuda.set_device(local_rank)
    device = f'cuda:{local_rank}'

    if args.report_rank is None:
        report_rank = min(7, world_size - 1)
    else:
        report_rank = args.report_rank

    before_sync_data = data[rank]
    for_reduce_scatter = torch.tensor(before_sync_data,
                                      device=device,
                                      dtype=data_type)
    for_allreduce = torch.tensor(before_sync_data,
                                 device=device,
                                 dtype=data_type)

    rc_part = reduce_scatter(for_reduce_scatter)
    ar_part = allreduce(for_allreduce)

    if rank == report_rank:
        rc_values = rc_part.to(torch.float32).cpu().numpy().tolist()
        ar_values = ar_part.to(torch.float32).cpu().numpy().tolist()
        errors = [i - j for i, j in zip(rc_values, ar_values)]
        print(f'error range [{np.min(errors)}, {np.max(errors)}]')
        diff_norm = torch.norm(rc_part - ar_part)
        print(f'norm diff {diff_norm}')
def print_at_rank0(msg):
    """Print ``msg`` only on the process whose ``c10d.get_rank()`` is 0."""
    if c10d.get_rank() != 0:
        return
    print(msg)