def barrier(self, *args, **kwargs) -> None:
    if not distributed_available():
        return
    if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
        torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
    else:
        torch.distributed.barrier()
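For context, a minimal runnable sketch of the underlying call this method wraps, assuming a single-process gloo group (the NCCL branch above only adds `device_ids` so the collective is pinned to this rank's GPU):

import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
dist.barrier()  # every rank blocks here until all ranks have arrived
dist.destroy_process_group()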
def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
    """Function used to reload the iterator state once the workers have been created."""
    dataloader_to_iter_on = dataloader
    if isinstance(dataloader, CycleIterator):
        dataloader = dataloader_to_iter_on.loader

    # dataset states are collected across all ranks
    rank = torch.distributed.get_rank() if distributed_available() else 0
    state_dict = state_dict[rank]

    _reload_dataloader_state_dict(dataloader, state_dict)

    # We finally spawned the workers if any.
    it = iter(dataloader_to_iter_on)

    # restore caching state
    state = MergedIteratorState.from_state_dict(state_dict)

    if isinstance(dataloader_to_iter_on, CycleIterator):
        it._loader_iter.state = state
    else:
        it.state = state
    return it
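The `state_dict[rank]` lookup assumes the checkpoint stores one dataset state per process, keyed by its integer rank. A minimal sketch of that indexing (the payload keys are hypothetical):

import torch.distributed as dist

state_dict = {
    0: {"num_batches_fetched": 12},  # hypothetical per-rank dataset state
    1: {"num_batches_fetched": 11},
}
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
print(state_dict[rank])  # each process restores only its own slice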
def __iter__(self):
    r"""
    Return the iterable by nesting different generators, each of which performs a different
    filtering based on the process id in distributed training and on the worker id when
    parallel loading is also used in the dataloader.

    1) utils.batch_filter simply ensures that at least `world_size` elements are read at a time
    2) utils.filter_generator, in distributed training, keeps one element every `world_size`
    3) utils.filter_generator, across parallel workers, keeps one element every `num_workers`
    """
    reader = iter(self.adapter)

    # add distributed training logic
    if distributed_available():
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        reader = utils.batch_filter(reader, size=world_size)
        reader = utils.filter_generator(reader, step=world_size, offset=rank)

    # add parallel processing with workers logic
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        reader = utils.filter_generator(reader, step=worker_info.num_workers, offset=worker_info.id)

    # pre-process data and return
    for line in reader:
        if self.do_preprocessing:
            line = self.adapter.preprocess_line(line)
        yield line
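`utils.batch_filter` and `utils.filter_generator` are project-local helpers. A minimal sketch of the round-robin sharding the docstring describes, assuming `filter_generator` keeps one element every `step`, starting at `offset`:

from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def filter_generator(it: Iterable[T], step: int, offset: int) -> Iterator[T]:
    # keep one element every `step`, starting at `offset` (assumed semantics)
    for i, item in enumerate(it):
        if i % step == offset:
            yield item

# world_size=2, num_workers=2: rank 0 / worker 1 ends up with indices 2, 6, 10
per_rank = filter_generator(iter(range(12)), step=2, offset=0)
per_worker = filter_generator(per_rank, step=2, offset=1)
print(list(per_worker))  # [2, 6, 10]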
def barrier(self, name: Optional[str] = None) -> None:
    if not distributed_available():
        return
    if torch.distributed.get_backend() == "nccl":
        torch.distributed.barrier(device_ids=self._determine_device_ids())
    else:
        torch.distributed.barrier()
def training_epoch_end(self, outputs):
    ids = torch.cat([o['ids'] for o in outputs], dim=0)

    # in distributed mode, collect ids from every process (gpu)
    if distributed_available():
        gather_ids = [torch.zeros_like(ids) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(gather_ids, ids)
        ids = torch.cat(gather_ids, dim=0)

    if has_len(self.trainer.datamodule.train_dataset):
        received = torch.zeros(len(self.trainer.datamodule.train_dataset)).to(dtype=bool)
    else:
        received = torch.zeros(len(list(self.trainer.datamodule.train_dataset))).to(dtype=bool)
    received[ids] = True

    if self.check_ids:
        # assert that no duplicate element was received
        assert len(set(ids.tolist())) == len(ids.tolist()), (
            f"Received {len(ids.tolist())} ids but only"
            f" {len(set(ids.tolist()))} are unique: {ids}"
        )
        # assert that all elements were received
        assert all(received), (
            f"({self.trainer.max_steps}) Did not receive all {len(received)} ids: {received}"
        )
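A runnable sketch of the gather step above, assuming a single-process gloo group: each rank contributes its local `ids`, and after `all_gather` every rank can concatenate the tensors from all processes (the tensor must have the same shape on every rank).

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
ids = torch.tensor([0, 3, 7])
gathered = [torch.zeros_like(ids) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, ids)  # fills `gathered` with every rank's tensor
print(torch.cat(gathered, dim=0))  # tensor([0, 3, 7]) with world_size=1
dist.destroy_process_group()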
def broadcast(self, obj: object, src: int = 0) -> object:
    if not distributed_available():
        return obj
    obj = [obj]
    if self.global_rank != src:
        obj = [None]
    torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
    return obj[0]
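The one-element-list dance exists because `broadcast_object_list` mutates the list in place: the source rank supplies the real object while the other ranks pass a `[None]` placeholder that gets filled in. A single-process gloo sketch (the payload is hypothetical):

import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)
payload = [{"best_model_path": "checkpoints/last.ckpt"}]  # hypothetical object
dist.broadcast_object_list(payload, src=0)  # non-src ranks would pass [None]
print(payload[0])
dist.destroy_process_group()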
def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
    cache = None
    if on_step and result_metric.meta.on_step:
        cache = result_metric._forward_cache
    elif not on_step and result_metric.meta.on_epoch:
        if result_metric._computed is None:
            should = result_metric.meta.sync.should
            if not result_metric.meta.sync.should and distributed_available():
                # ensure sync happens for FT since during a failure, the metrics are synced and saved to the
                # checkpoint, so during restart, metrics on rank 0 are from the accumulated ones from the previous
                # run, and on other ranks, they are 0. So we need to make sure they are synced in further training
                # to ensure correct calculation.
                if _fault_tolerant_training():
                    result_metric.meta.sync.should = True
                else:
                    warning_cache.warn(
                        f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
                        " when logging on epoch level in distributed setting to accumulate the metric across"
                        " devices.",
                        category=PossibleUserWarning,
                    )
            result_metric.compute()
            result_metric.meta.sync.should = should
        cache = result_metric._computed
    if cache is not None:
        if not isinstance(cache, Tensor):
            raise ValueError(
                f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
                f" Found {cache}"
            )
        if not result_metric.meta.enable_graph:
            return cache.detach()
    return cache
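The crux above is the save/force/restore handling of `meta.sync.should` around the `compute()` call. A toy sketch of that pattern with hypothetical stand-in classes (not the Lightning API):

from contextlib import contextmanager
from dataclasses import dataclass, field

@dataclass
class _Sync:
    should: bool = False

@dataclass
class _Meta:
    sync: _Sync = field(default_factory=_Sync)

@contextmanager
def forced_sync(meta: _Meta):
    previous = meta.sync.should      # remember the user's configuration
    meta.sync.should = True          # force a sync for this one compute()
    try:
        yield
    finally:
        meta.sync.should = previous  # restore it afterwards

meta = _Meta()
with forced_sync(meta):
    assert meta.sync.should  # compute() would run, and sync, here
assert not meta.sync.should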
def barrier(self, *args, **kwargs):
    if distributed_available():
        self.join()
def barrier(self, *args: Any, **kwargs: Any) -> None:
    if distributed_available():
        self.join()
def broadcast(self, obj: object, src: int = 0) -> object:
    if not distributed_available():
        return obj
    return self.dist.broadcast(obj)