def __getstate__(self):
    pg_state = ShardedTensor.ProcessGroupState(
        distributed_c10d.get_rank(self._process_group),
        distributed_c10d.get_rank(),
        distributed_c10d.get_world_size(self._process_group),
        distributed_c10d.get_world_size(),
    )
    return self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs
def allreduce(data: torch.Tensor):
    """All-reduce the flattened tensor, then return this rank's chunk of the result."""
    c10d.all_reduce(data.view(-1))
    rank = c10d.get_rank()
    world_size = c10d.get_world_size()
    chunk_sz = data.numel() // world_size
    return data.view(-1).narrow(0, rank * chunk_sz, chunk_sz)
def reduce_scatter(data: torch.Tensor):
    """Reduce-scatter the flattened tensor and return this rank's reduced chunk."""
    data = data.view(-1)
    rank = c10d.get_rank()
    world_size = c10d.get_world_size()
    chunk_sz = data.numel() // world_size
    offset = rank * chunk_sz
    chunk = data.narrow(0, offset, chunk_sz)
    c10d._reduce_scatter_base(chunk, data)
    return chunk
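# Illustration (not from the original module): allreduce() and reduce_scatter()
# above share one chunking convention -- the flattened tensor of N elements is
# split into world_size equal chunks and rank r keeps the slice starting at
# r * (N // world_size). Both helpers therefore assume N is divisible by the
# world size. The hypothetical helper below just spells out that slice.
def _local_chunk_sketch(flat, rank, world_size):
    """E.g. torch.arange(16) with rank=2, world_size=4 -> tensor([ 8,  9, 10, 11])."""
    chunk_sz = flat.numel() // world_size
    return flat.narrow(0, rank * chunk_sz, chunk_sz)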
def __setstate__(self, state):
    self._sharded_tensor_id = None
    if not distributed_c10d.is_initialized():
        raise RuntimeError(
            'Need to initialize default process group using '
            '"init_process_group" before loading ShardedTensor')

    self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs = state

    # Setup process group
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is None:
        self._process_group = distributed_c10d._get_default_group()
    else:
        self._process_group = _CURRENT_PROCESS_GROUP

    # Validate process group.
    local_rank = distributed_c10d.get_rank(self._process_group)
    if pg_state.local_rank != local_rank:
        raise RuntimeError(
            f'Local rank at save time was {pg_state.local_rank}, but at '
            f'load time was {local_rank}')

    global_rank = distributed_c10d.get_rank()
    if pg_state.global_rank != global_rank:
        raise RuntimeError(
            f'Global rank at save time was {pg_state.global_rank}, but at '
            f'load time was {global_rank}')

    local_world_size = distributed_c10d.get_world_size(self._process_group)
    if pg_state.local_world_size != local_world_size:
        raise RuntimeError(
            f'Local world size at save time was {pg_state.local_world_size}, '
            f'but at load time was {local_world_size}')

    global_world_size = distributed_c10d.get_world_size()
    if pg_state.global_world_size != global_world_size:
        raise RuntimeError(
            f'Global world size at save time was {pg_state.global_world_size}, '
            f'but at load time was {global_world_size}')

    self._post_init()
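# Minimal sketch (an assumption, not part of the original file) of how
# __getstate__/__setstate__ are exercised: an ordinary pickle round-trip on a
# rank that has already initialized the default process group. `st` stands for
# a ShardedTensor built elsewhere; __setstate__ then re-validates that the
# rank/world-size layout at load time matches what was captured at save time.
def _pickle_roundtrip_sketch(st):
    import io
    import pickle

    buf = io.BytesIO()
    pickle.dump(st, buf)     # invokes __getstate__
    buf.seek(0)
    return pickle.load(buf)  # invokes __setstate__ and its validation checks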
def fidelity_check_hierarchy_allgather(local_rank, inter_shard_group, intra_shard_group):
    """Compare the vanilla coalesced all-gather against the hierarchical all-gather."""
    torch.manual_seed(10)
    n = 10
    dtype = torch.half
    partition_nelem = 10e6  # 10M elements per rank
    world_size = c10d.get_world_size()
    rank = c10d.get_rank()
    nelem = int(partition_nelem * world_size)

    # All ranks use the same seed, so they generate identical test tensors.
    test_tensors = []
    for _ in range(n):
        test_tensors.append(
            torch.rand(nelem, dtype=dtype, device=f'cuda:{local_rank}'))

    # Baseline: vanilla coalesced all-gather.
    outputs_for_vanilla, inputs_for_vanilla = _get_outputs_inputs(
        test_tensors, rank, world_size)
    c10d._all_gather_base_coalesced(outputs_for_vanilla, inputs_for_vanilla)
    for i, j in zip(outputs_for_vanilla, test_tensors):
        print(f"vanilla all close {torch.allclose(i, j)}")

    # torch.cuda.synchronize()
    comm_stream = torch.cuda.Stream()
    outputs_for_hierarchy, inputs_for_hierarchy = _get_outputs_inputs(
        test_tensors, rank, world_size)
    # with torch.cuda.stream(comm_stream):
    # with torch.cuda.stream(torch.cuda.default_stream()):
    with torch.cuda.stream(torch.cuda.current_stream()):
        handle = _hierarchy_allgather(outputs_for_hierarchy, inputs_for_hierarchy,
                                      intra_shard_group, inter_shard_group,
                                      world_size, local_rank)
        handle.wait()
    for i, j in zip(outputs_for_hierarchy, test_tensors):
        print(f"hierarchy all close {torch.allclose(i, j)}")
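# Sketch (hypothetical helper, not from the original module) of one way the
# process groups passed to fidelity_check_hierarchy_allgather could be built:
# ranks on the same node form an intra-shard group, and ranks sharing the same
# local index across nodes form an inter-shard group. Every rank must call
# new_group() for every group, including groups it does not belong to.
def _build_hierarchy_groups_sketch(gpus_per_node):
    world_size = c10d.get_world_size()
    rank = c10d.get_rank()
    num_nodes = world_size // gpus_per_node

    intra_shard_group = None
    for node in range(num_nodes):
        ranks = list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
        group = c10d.new_group(ranks=ranks)
        if rank in ranks:
            intra_shard_group = group

    inter_shard_group = None
    for local in range(gpus_per_node):
        ranks = list(range(local, world_size, gpus_per_node))
        group = c10d.new_group(ranks=ranks)
        if rank in ranks:
            inter_shard_group = group

    return intra_shard_group, inter_shard_group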
def main():
    """Compare reduce_scatter() against allreduce() on the same per-rank input."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--data-type', default='bfloat16', type=str)
    arg_parser.add_argument('--local_rank', type=int)
    args = arg_parser.parse_args()

    c10d.init_process_group(backend='nccl')

    if args.data_type == 'bfloat16':
        data_type = torch.bfloat16
    elif args.data_type in ('float16', 'fp16'):
        data_type = torch.half
    elif args.data_type in ('float32', 'fp32'):
        data_type = torch.float32
    else:
        raise RuntimeError(f"unsupported data type {args.data_type}")

    rank = c10d.get_rank()
    # NOTE: `data` (per-rank test inputs) is assumed to be defined elsewhere in this module.
    before_sync_data = data[rank]
    for_reduce_scatter = torch.tensor(before_sync_data,
                                      device=f'cuda:{rank}',
                                      dtype=data_type)
    for_allreduce = torch.tensor(before_sync_data,
                                 device=f'cuda:{rank}',
                                 dtype=data_type)
    rc_part = reduce_scatter(for_reduce_scatter)
    ar_part = allreduce(for_allreduce)

    errors = []
    if rank == 7:
        for i, j in zip(
                rc_part.to(torch.float32).cpu().numpy().tolist(),
                ar_part.to(torch.float32).cpu().numpy().tolist()):
            # print(i, j, i - j)
            errors.append(i - j)
        print(f'error range [{np.min(errors)}, {np.max(errors)}]')
        diff_norm = torch.norm(rc_part - ar_part)
        print(f'norm diff {diff_norm}')
def print_at_rank0(msg):
    if c10d.get_rank() == 0:
        print(msg)
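# Launch note (assumption: this script is started with the
# torch.distributed.launch helper, which passes --local_rank to each process,
# and does not already define an entry point elsewhere). A typical 8-GPU
# invocation of main() above would look like:
#
#   python -m torch.distributed.launch --nproc_per_node=8 this_script.py --data-type fp16
if __name__ == '__main__':
    main()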