Example #1
 def barrier(self, *args, **kwargs) -> None:
     if not distributed_available():
         return
     if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
         torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
     else:
         torch.distributed.barrier()
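
Every example in this list guards its collective call with distributed_available(). As a point of reference, a minimal sketch of a guard with that contract could look like the following; this is an illustrative assumption, not necessarily how the library defines the helper (the real one may also cover TPU process groups).

import torch.distributed as dist

def distributed_available() -> bool:
    # True only when torch.distributed is built into this install and a default
    # process group has been initialized, so collective calls are safe to make.
    return dist.is_available() and dist.is_initialized()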
Example #2
        def create_loader_iters(dataloader: DataLoader,
                                state_dict: Dict) -> Iterator:
            """Function used to reload the iterator state before once the workers are created."""

            dataloader_to_iter_on = dataloader
            if isinstance(dataloader, CycleIterator):
                dataloader = dataloader_to_iter_on.loader

            # dataset states are collected across all ranks
            rank = torch.distributed.get_rank() if distributed_available() else 0
            state_dict = state_dict[rank]

            _reload_dataloader_state_dict(dataloader, state_dict)

            # Creating the iterator here finally spawns the workers, if any.
            it = iter(dataloader_to_iter_on)

            # restore caching state
            state = MergedIteratorState.from_state_dict(state_dict)

            if isinstance(dataloader_to_iter_on, CycleIterator):
                it._loader_iter.state = state
            else:
                it.state = state
            return it
Example #3
    def __iter__(self):
        r"""
        Return the iterable by nesting different generators, each of which does a different
        filtering based on the process id when in distributed training and on the worker id
        if using also parallel loading in the dataloader.

        1) utils.batch_filter simply ensures that at least `world_size` elements are read at a time
        2) utils.filter_generator on distributed training to keep one element every `world_size`
        3) utils.filter_generator on parallel workers processing to keep one element every `num_workers`
        """
        reader = iter(self.adapter)

        # add distributed training logic
        if distributed_available():

            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()

            reader = utils.batch_filter(reader, size=world_size)
            reader = utils.filter_generator(reader,
                                            step=world_size,
                                            offset=rank)

        # add parallel processing with workers logic
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            reader = utils.filter_generator(reader,
                                            step=worker_info.num_workers,
                                            offset=worker_info.id)

        # pre-process data and return
        for line in reader:
            if self.do_preprocessing:
                line = self.adapter.preprocess_line(line)
            yield line
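
The docstring above describes two nested step/offset filters: keep one element out of every `world_size` per rank, then one out of every `num_workers` per dataloader worker. A small self-contained sketch of that sharding pattern is below; `filter_generator` here is a hypothetical stand-in written only to illustrate the idea, not the actual `utils` implementation, and `batch_filter` is omitted.

from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def filter_generator(source: Iterable[T], step: int, offset: int) -> Iterator[T]:
    # Keep one element out of every `step`, starting at position `offset`.
    for i, item in enumerate(source):
        if i % step == offset:
            yield item

# Two nested filters: first shard by rank, then by dataloader worker.
items = range(12)
per_rank = filter_generator(items, step=2, offset=0)        # world_size=2, rank=0
per_worker = filter_generator(per_rank, step=2, offset=1)   # num_workers=2, worker_id=1
print(list(per_worker))  # [2, 6, 10]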
Example #4
 def barrier(self, name: Optional[str] = None) -> None:
     if not distributed_available():
         return
     if torch.distributed.get_backend() == "nccl":
         torch.distributed.barrier(device_ids=self._determine_device_ids())
     else:
         torch.distributed.barrier()
Example #5
    def training_epoch_end(self, outputs):
        ids = torch.cat([o['ids'] for o in outputs], dim=0)

        # in distributed mode collect ids from every process (gpu)
        if distributed_available():
            gather_ids = [
                torch.zeros_like(ids)
                for _ in range(torch.distributed.get_world_size())
            ]
            torch.distributed.all_gather(gather_ids, ids)
            ids = torch.cat(gather_ids, dim=0)

        if has_len(self.trainer.datamodule.train_dataset):
            received = torch.zeros(len(self.trainer.datamodule.train_dataset)).to(dtype=bool)
        else:
            received = torch.zeros(len(list(self.trainer.datamodule.train_dataset))).to(dtype=bool)
        received[ids] = True

        if self.check_ids:
            # assert no duplicate element received
            assert len(set(ids.tolist())) == len(ids.tolist()), (
                f"Received {len(ids.tolist())} ids but only"
                f" {len(set(ids.tolist()))} are unique: {ids}")
            # assert all elements received
            assert all(received), (
                f"({self.trainer.max_steps}) Did not receive all {len(received)} ids: {received}"
            )
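
The gather block above follows the common pattern of pre-allocating one placeholder tensor per rank and letting `all_gather` fill them in rank order. A minimal sketch of that pattern in isolation is below; it creates a single-process `gloo` group purely so the snippet runs stand-alone, which is an illustrative assumption rather than how the trainer actually sets up distributed training.

import torch
import torch.distributed as dist

# Single-process "gloo" group so this runs without a launcher (illustration only).
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                        world_size=1, rank=0)

ids = torch.tensor([3, 1, 4])
# One placeholder per rank; all_gather fills them in rank order.
gathered = [torch.zeros_like(ids) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, ids)
all_ids = torch.cat(gathered, dim=0)
print(all_ids)  # tensor([3, 1, 4]) when world_size == 1

dist.destroy_process_group()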
Example #6
 def broadcast(self, obj: object, src: int = 0) -> object:
     if not distributed_available():
         return obj
     obj = [obj]
     if self.global_rank != src:
         obj = [None]
     torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
     return obj[0]
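
`broadcast_object_list` fills the passed list in place, so every rank except `src` supplies a placeholder and reads back the object from the `src` rank. A hedged sketch of the same pattern as a free function follows, assuming a process group is already initialized; `broadcast_from_src` is a hypothetical helper name used only for illustration.

import torch.distributed as dist

def broadcast_from_src(obj, src: int = 0):
    # Non-src ranks pass a placeholder; broadcast_object_list overwrites it
    # in place with the object pickled on the src rank.
    buffer = [obj if dist.get_rank() == src else None]
    dist.broadcast_object_list(buffer, src=src)
    return buffer[0]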
Example #7
    def _get_cache(result_metric: _ResultMetric,
                   on_step: bool) -> Optional[Tensor]:
        cache = None
        if on_step and result_metric.meta.on_step:
            cache = result_metric._forward_cache
        elif not on_step and result_metric.meta.on_epoch:
            if result_metric._computed is None:
                should = result_metric.meta.sync.should
                if not result_metric.meta.sync.should and distributed_available():
                    # ensure sync happens for FT since during a failure, the metrics are synced and saved to the
                    # checkpoint, so during restart, metrics on rank 0 are from the accumulated ones from the previous
                    # run, and on other ranks, they are 0. So we need to make sure they are synced in further training
                    # to ensure correct calculation.
                    if _fault_tolerant_training():
                        result_metric.meta.sync.should = True
                    else:
                        warning_cache.warn(
                            f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
                            " when logging on epoch level in distributed setting to accumulate the metric across"
                            " devices.",
                            category=PossibleUserWarning,
                        )
                result_metric.compute()
                result_metric.meta.sync.should = should

            cache = result_metric._computed

        if cache is not None:
            if not isinstance(cache, Tensor):
                raise ValueError(
                    f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
                    f" Found {cache}")
            if not result_metric.meta.enable_graph:
                return cache.detach()

        return cache
Example #8
 def barrier(self, *args, **kwargs):
     if distributed_available():
         self.join()
Example #9
 def barrier(self, *args: Any, **kwargs: Any) -> None:
     if distributed_available():
         self.join()
Example #10
 def broadcast(self, obj: object, src: int = 0) -> object:
     if not distributed_available():
         return obj
     return self.dist.broadcast(obj)