def get_ddp_sampler(dataset: Dataset, epoch: int):
    """Build a per-epoch ``DistributedSampler`` when running under DDP.

    Args:
        dataset: The dataset the sampler should draw indices from.
        epoch: Epoch number forwarded to ``DistributedSampler.set_epoch`` so
            each epoch gets its own sampling order.

    Returns:
        A ``DistributedSampler`` over ``dataset`` with its epoch set, if the
        default process group is initialised (``is_initialized()`` is true).
        Otherwise ``None``, so callers can pass the result straight to a
        ``DataLoader``'s ``sampler`` argument.
    """
    # Without an initialised process group there is nothing to shard across.
    if not is_initialized():
        return None

    ddp_sampler = DistributedSampler(dataset)
    ddp_sampler.set_epoch(epoch)
    return ddp_sampler
def __setstate__(self, state):
    """Restore this sharded tensor from pickled ``state``.

    ``state`` is unpacked as the 5-tuple
    ``(local_shards, metadata, pg_state, sharding_spec, init_rrefs)``.
    The process group is taken from the module-level
    ``_CURRENT_PROCESS_GROUP`` when it is set, and from the default group
    otherwise. The local/global rank and world size recorded in ``pg_state``
    must match the current ones. Finally ``self._post_init()`` is called.

    Raises:
        RuntimeError: if the default process group has not been initialised,
            or if any saved rank / world size differs from its value at load
            time.
    """
    self._sharded_tensor_id = None

    if not distributed_c10d.is_initialized():
        raise RuntimeError(
            'Need to initialize default process group using '
            '"init_process_group" before loading ShardedTensor')

    (self._local_shards,
     self._metadata,
     pg_state,
     self._sharding_spec,
     self._init_rrefs) = state

    # NOTE(review): _CURRENT_PROCESS_GROUP is a module-level override,
    # presumably installed by a loading context; it is only read here.
    self._process_group = (
        distributed_c10d._get_default_group()
        if _CURRENT_PROCESS_GROUP is None
        else _CURRENT_PROCESS_GROUP
    )

    # Validate the process-group topology against what was saved, one
    # quantity at a time (each query runs only if the previous check passed).
    rank_in_pg = distributed_c10d.get_rank(self._process_group)
    if pg_state.local_rank != rank_in_pg:
        raise RuntimeError(
            f'Local rank at save time was {pg_state.local_rank}, but at '
            f'load time was {rank_in_pg}')

    rank_in_world = distributed_c10d.get_rank()
    if pg_state.global_rank != rank_in_world:
        raise RuntimeError(
            f'Global rank at save time was {pg_state.global_rank}, but at '
            f'load time was {rank_in_world}')

    size_of_pg = distributed_c10d.get_world_size(self._process_group)
    if pg_state.local_world_size != size_of_pg:
        raise RuntimeError(
            f'Local world size at save time was {pg_state.local_world_size}, '
            f'but at load time was {size_of_pg}')

    size_of_world = distributed_c10d.get_world_size()
    if pg_state.global_world_size != size_of_world:
        raise RuntimeError(
            f'Global world size at save time was {pg_state.global_world_size}, '
            f'but at load time was {size_of_world}')

    self._post_init()