def _test_sequence_num_incremented(self, process_group, ranks):
    # verify initial sequence numbers. Use a distinct process group for
    # verification to keep counts as expected with respect to process_group.
    verify_pg = dist.new_group(
        ranks=ranks,
        backend="gloo",
    )
    assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)

    initial_num = (
        self._verify_sequence_number_across_pg(
            pg=process_group, verify_pg=verify_pg
        )
        if not c10d._rank_not_in_group(process_group)
        else -1
    )

    # Verify sequence numbers are appropriately incremented
    for i in range(10):
        t = torch.ones(1, device=torch.cuda.current_device())
        dist.all_reduce(t, group=process_group)
        if not c10d._rank_not_in_group(process_group):
            seq_num = self._verify_sequence_number_across_pg(
                pg=process_group,
                verify_pg=verify_pg,
            )
            self.assertEqual(initial_num + i + 1, seq_num)

    if dist.get_world_size(process_group) > 2:
        # Test when certain ranks don't call collectives
        if dist.get_rank(process_group) not in [0, 2]:
            dist.all_reduce(t, group=process_group, async_op=True)
        # Now ranks 0 and 2 should be lagging by 1.
        if not c10d._rank_not_in_group(process_group):
            seq_num = process_group._get_sequence_number_for_group()
            rank = dist.get_rank(process_group)
            obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
            dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
            rank_to_seq_num = {rank: num for (rank, num) in obj_list}
            self.assertEqual(len(set(rank_to_seq_num.values())), 2)
            self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
            expected_same = {
                rank_to_seq_num[i]
                for i in rank_to_seq_num.keys()
                if i not in [0, 2]
            }
            self.assertEqual(len(expected_same), 1)
            self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])

def _parse_and_validate_remote_device(pg, remote_device):
    worker_name = remote_device.worker_name()
    rank = remote_device.rank()
    device = remote_device.device()

    # Validate rank, skip validation if rank is not part of process group.
    if not distributed_c10d._rank_not_in_group(pg):
        if rank is not None and (rank < 0 or rank >= distributed_c10d.get_world_size(pg)):
            raise ValueError(f'Invalid rank: {rank}')

    if worker_name is not None:
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(
                f'RPC framework needs to be initialized for using worker names: {worker_name}'
            )

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == worker_name:
                return worker.id, device

        raise ValueError(f'Invalid worker name: {worker_name}')

    return rank, device

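# Hedged usage sketch (not part of the original source): how the helper above
# might be exercised once a default process group is initialized. It assumes
# torch.distributed's internal _remote_device helper, which accepts
# "rank:<n>/<device>" placement strings in recent PyTorch releases; the
# function name below is illustrative only.
def _example_parse_remote_device():
    from torch.distributed import distributed_c10d
    from torch.distributed.remote_device import _remote_device

    pg = distributed_c10d._get_default_group()
    # Rank-based placement: expected to return (0, torch.device('cuda:0')).
    rank, device = _parse_and_validate_remote_device(pg, _remote_device("rank:0/cuda:0"))
    return rank, device
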
def irecv(self, tensor, src=None, tag=0):
    # pylint: disable=protected-access
    # Original irecv doesn't support recv from any
    # but original recv does. They are essentially
    # the same except recv has a wait() call
    dist_c10d._check_single_tensor(tensor, "tensor")
    if dist_c10d._rank_not_in_group(self.group):
        return -1

    if self.group == dist_c10d.GroupMember.WORLD:
        dist_c10d._check_default_pg()
        pg = dist_c10d._default_pg
    else:
        pg = self.group

    if src is None:
        work = pg.recv_anysource([tensor], tag)
        # The source rank is only defined once the receive has completed.
        work.wait()
        src_rank = work.source_rank()
        if self.group == dist_c10d.GroupMember.WORLD:
            return src_rank
        else:
            return dist_c10d._get_global_rank(pg, src_rank)
    else:
        if self.group == dist_c10d.GroupMember.WORLD:
            pg.recv([tensor], src, tag).wait()
        else:
            group_src_rank = dist_c10d._get_group_rank(pg, src)
            pg.recv([tensor], group_src_rank, tag).wait()
        return src

def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
    if len(args) != 2:
        raise ValueError(
            f'Expected two arguments for torch.{cmp_fun.__name__}')

    st1 = args[0]
    st2 = args[1]
    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
        raise TypeError(
            f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor'
        )

    # Verify same PG
    if st1._process_group != st2._process_group:
        return False

    if distributed_c10d._rank_not_in_group(st1._process_group) or \
            distributed_c10d._rank_not_in_group(st2._process_group):
        return distributed_c10d._rank_not_in_group(st1._process_group) == \
            distributed_c10d._rank_not_in_group(st2._process_group)

    # Verify metadata
    if st1.metadata() != st2.metadata():
        return _communicate_result(False, st1._process_group)

    # Verify number of local shards
    st1_local_shards = st1.local_shards()
    st2_local_shards = st2.local_shards()
    if len(st1_local_shards) != len(st2_local_shards):
        return _communicate_result(False, st1._process_group)

    # kwargs must be dict-like
    if kwargs is None:
        kwargs = {}

    # Verify each local shard
    for idx in range(len(st1_local_shards)):
        if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata:
            return _communicate_result(False, st1._process_group)
        if not cmp_fun(st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs):
            return _communicate_result(False, st1._process_group)

    return _communicate_result(True, st1._process_group)

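# Hedged usage sketch (not part of the original source): calling binary_cmp
# above directly with torch.allclose as the comparison and a tolerance passed
# through kwargs. st1 and st2 are assumed to be existing ShardedTensors on the
# same process group, and every rank must make this call collectively so the
# result can be communicated; the function name is illustrative only.
def _example_allclose(st1, st2):
    import torch

    return binary_cmp(
        torch.allclose,
        types=(ShardedTensor, ShardedTensor),
        args=(st1, st2),
        kwargs={'rtol': 1e-5, 'atol': 1e-8},
    )
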
def irecv(self, tensor, src=None, tag=0):  # pragma: no cover
    """
    Returns:
        An object you can call .wait() on; .wait()
        will return the source rank.
    """
    # pylint: disable=protected-access
    # Original irecv doesn't support recv from any
    # but original recv does. They are essentially
    # the same except recv has a wait() call
    dist_c10d._check_single_tensor(tensor, "tensor")
    if dist_c10d._rank_not_in_group(self.group):

        class Waiter:
            def wait(self):
                return -1

        return Waiter()

    if self.group == dist_c10d.GroupMember.WORLD:
        dist_c10d._check_default_pg()
        pg = dist_c10d._default_pg
    else:
        pg = self.group

    if src is None:
        work = pg.recv_anysource([tensor], tag)
        if self.group == dist_c10d.GroupMember.WORLD:

            class Waiter:
                def wait(self):
                    nonlocal work
                    work.wait()
                    return work.source_rank()

            return Waiter()
        else:

            class Waiter:
                def wait(self):
                    nonlocal work, pg
                    work.wait()
                    src_rank = work.source_rank()
                    return dist_c10d._get_global_rank(pg, src_rank)

            return Waiter()
    else:
        if self.group == dist_c10d.GroupMember.WORLD:
            work = pg.recv([tensor], src, tag)
        else:
            group_src_rank = dist_c10d._get_group_rank(pg, src)
            work = pg.recv([tensor], group_src_rank, tag)

        class Waiter:
            def wait(self):
                nonlocal src
                work.wait()
                return src

        return Waiter()

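# Hedged usage sketch (not part of the original source): how the Waiter
# returned by the irecv above would typically be consumed. `comm` stands for
# whatever object exposes this irecv method; the names are illustrative only.
def _example_irecv_from_any(comm):
    import torch

    buf = torch.zeros(1)
    waiter = comm.irecv(buf)   # src=None: receive from any rank
    src_rank = waiter.wait()   # blocks until data arrives, then returns the sender's rank
    return src_rank, buf
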
def equal(types, args=(), kwargs=None, process_group=None):
    if len(args) != 2:
        raise ValueError('Expected two arguments for torch.equal')

    st1 = args[0]
    st2 = args[1]
    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
        raise TypeError(
            'Both arguments to torch.equal need to be of type ShardedTensor')

    # Verify same PG
    if st1._process_group != st2._process_group:
        return False

    if distributed_c10d._rank_not_in_group(st1._process_group) or \
            distributed_c10d._rank_not_in_group(st2._process_group):
        return distributed_c10d._rank_not_in_group(st1._process_group) == \
            distributed_c10d._rank_not_in_group(st2._process_group)

    # Verify metadata
    if st1.metadata() != st2.metadata():
        return _communicate_result(False, st1._process_group)

    # Verify number of local shards
    st1_local_shards = st1.local_shards()
    st2_local_shards = st2.local_shards()
    if len(st1_local_shards) != len(st2_local_shards):
        return _communicate_result(False, st1._process_group)

    # Verify each local shard
    for idx in range(len(st1_local_shards)):
        if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata:
            return _communicate_result(False, st1._process_group)
        if not torch.equal(st1_local_shards[idx].tensor, st2_local_shards[idx].tensor):
            return _communicate_result(False, st1._process_group)

    return _communicate_result(True, st1._process_group)

def _test_sequence_num_set_new_group(self, backend):
    store = dist.FileStore(self.file_name, self.world_size)
    dist.init_process_group(
        backend,
        world_size=self.world_size,
        rank=self.rank,
        store=store,
    )

    subgroup = dist.new_group([0, 1])

    if not c10d._rank_not_in_group(subgroup):
        subgroup_seq = subgroup._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size(subgroup))]
        dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
        self.assertEqual(len(set(obj_list)), 1)

def _prepare_init(self, process_group=None):
    self._rpc_initialized = False
    self._sharded_tensor_id = None
    if rpc._is_current_rpc_agent_set():
        # Validate PG and RPC ranks match.
        pg_rank = dist.get_rank()
        rpc_rank = rpc.get_worker_info().id
        if pg_rank != rpc_rank:
            raise ValueError(
                f'Default ProcessGroup and RPC ranks must be '
                f'the same for ShardedTensor, found process group rank: '
                f'{pg_rank} and RPC rank: {rpc_rank}')

    self._process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )

    if distributed_c10d._rank_not_in_group(self._process_group):
        raise ValueError(
            f'Global rank: {dist.get_rank()} not part of process group')

    self._local_shards: List[Shard] = []
    self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
    self._sharding_metadata: List[ShardMetadata] = []

def _parse_and_validate_remote_device(self, device):
    on, local_device = _parse_remote_device(device)

    # Validate rank, skip validation if rank is not part of process group.
    if not distributed_c10d._rank_not_in_group(self._process_group):
        if isinstance(on, int) and (
                on < 0 or on >= dist.get_world_size(self._process_group)):
            raise ValueError(f'Invalid rank: {on}')

    if isinstance(on, str):
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(
                f'RPC framework needs to be initialized for using worker names: {on}'
            )

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == on:
                return worker.id, local_device

        raise ValueError(f'Invalid worker name: {on}')

    return on, local_device

def __init__(
    self,
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
):
    self._rpc_initialized = False
    self._sharded_tensor_id = None
    if rpc._is_current_rpc_agent_set():
        # Validate PG and RPC ranks match.
        pg_rank = dist.get_rank()
        rpc_rank = rpc.get_worker_info().id
        if pg_rank != rpc_rank:
            raise ValueError(
                f'Default ProcessGroup and RPC ranks must be '
                f'the same for ShardedTensor, found process group rank: '
                f'{pg_rank} and RPC rank: {rpc_rank}'
            )

    if layout != torch.strided:
        raise ValueError('Only torch.strided layout is currently supported')

    if memory_format != torch.contiguous_format:
        raise ValueError('Only torch.contiguous_format memory_format is currently supported')

    self._sharding_spec = sharding_spec
    self._dims = list(size)
    self._process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )

    if distributed_c10d._rank_not_in_group(self._process_group):
        raise ValueError(f'Global rank: {dist.get_rank()} not part of process group')

    self._local_shards: List[Shard] = []
    self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
    self._sharding_metadata: List[ShardMetadata] = []

    if isinstance(self._sharding_spec, ChunkShardingSpec):
        self._init_chunked(
            dtype,
            layout,
            requires_grad,
            pin_memory,
            memory_format,
        )
    elif isinstance(self._sharding_spec, EnumerableShardingSpec):
        self._init_enumerable(
            dtype,
            layout,
            requires_grad,
            pin_memory,
            memory_format,
        )
    else:
        raise ValueError(f'Unsupported sharding_spec: {self._sharding_spec}')

    with _sharded_tensor_lock:
        global _sharded_tensor_current_id, _sharded_tensor_map
        self._sharded_tensor_id = _sharded_tensor_current_id
        _sharded_tensor_map[self._sharded_tensor_id] = self
        _sharded_tensor_current_id += 1

    # Initialize RPC if available.
    if rpc._is_current_rpc_agent_set():
        self._init_rpc()

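# Hedged usage sketch (not part of the original source): constructing a
# sharded tensor through the constructor above using a ChunkShardingSpec. The
# import paths and the "rank:<n>/<device>" placement strings follow the early
# torch.distributed._sharded_tensor API and are assumptions; a 2-rank
# initialized default process group is also assumed.
def _example_build_sharded_tensor():
    from torch.distributed._sharded_tensor import ShardedTensor
    from torch.distributed._sharding_spec import ChunkShardingSpec

    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )
    # Each rank holds one 5x20 chunk of the logical 10x20 tensor.
    return ShardedTensor(spec, 10, 20)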