Example #1
    def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
        iterator = iter(self._dataloader)
        if self._device is None:
            yield from iterator
            return

        for item in iterator:
            yield move_data_to_device(item, self._device)
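All of the examples in this collection rely on Lightning's move_data_to_device utility, which recursively walks containers such as dicts, lists and tuples and calls .to(device) on every tensor it finds. A minimal standalone sketch of that behaviour, assuming the import path exposed by pytorch_lightning 1.x (the batch contents are purely illustrative):

import torch
from pytorch_lightning.utilities import move_data_to_device

# a nested batch: the utility handles the nesting, we only name the target device
batch = {"features": torch.randn(4, 8), "targets": [torch.tensor(0), torch.tensor(1)]}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch_on_device = move_data_to_device(batch, device)
assert batch_on_device["features"].device == device
assert all(t.device == device for t in batch_on_device["targets"])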
Example #2
    def _wrapping_function(
        self,
        process_idx: int,
        trainer: Optional["pl.Trainer"],
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: SimpleQueue,
    ) -> None:
        self._strategy._worker_setup(process_idx)
        results = function(*args, **kwargs)

        if trainer is not None:
            results = self._collect_rank_zero_results(trainer, results)

        if self._strategy.local_rank == 0:
            return_queue.put(move_data_to_device(results, "cpu"))

        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        self._strategy.barrier("end-process")

        # Ensure that the rank 0 process is the one exiting last
        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if self._strategy.local_rank == 0:
            time.sleep(2)
Example #3
    def save_checkpoint(self,
                        checkpoint: Dict[str, Any],
                        path: _PATH,
                        storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: not used in ``XLACheckpointIO.save_checkpoint``

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`.")
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)

        checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
        # write the checkpoint dictionary to the provided path
        atomic_save(checkpoint, path)
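The TypeError above points users towards a custom CheckpointIO when they need storage_options. A minimal sketch of one way to do that, assuming pytorch_lightning >= 1.5 where TorchCheckpointIO lives under pytorch_lightning.plugins.io; the class name and the idea of stashing storage_options inside the checkpoint dict are purely illustrative:

from typing import Any, Dict, Optional

from pytorch_lightning.plugins.io import TorchCheckpointIO


class MetadataCheckpointIO(TorchCheckpointIO):
    # hypothetical plugin that accepts storage_options instead of rejecting it
    def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
        if storage_options is not None:
            # keep the user-provided metadata alongside the model/trainer state
            checkpoint = {**checkpoint, "storage_options": storage_options}
        # delegate the actual state-dump and file-write to the base implementation
        super().save_checkpoint(checkpoint, path)

Such a plugin would then be handed to the Trainer via Trainer(plugins=[MetadataCheckpointIO()]).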
Example #4
    def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the
        predict step.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        # extract batch_indices and store them
        self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []

        self.trainer._call_callback_hooks("on_predict_batch_start", batch, batch_idx, dataloader_idx)
        self.trainer._call_lightning_module_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())

        self.batch_progress.increment_processed()

        if predictions is None:
            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

        self.trainer._call_callback_hooks("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
        self.trainer._call_lightning_module_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        if self.should_store_predictions:
            self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))
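The CPU-moved predictions collected above are what Trainer.predict eventually returns to the caller. A minimal end-to-end sketch, assuming a recent pytorch_lightning 1.x Trainer API; LinearModel and the random dataset are illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class LinearModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        (x,) = batch
        return self.layer(x)


dataset = TensorDataset(torch.randn(16, 8))
trainer = pl.Trainer(accelerator="cpu", devices=1, logger=False)
# one entry per batch, already on the CPU thanks to the prediction loop above
predictions = trainer.predict(LinearModel(), dataloaders=DataLoader(dataset, batch_size=4))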
Example #5
 def on_save(self, checkpoint: dict) -> dict:
     """
     Move XLA tensors to CPU before saving.
     Recommended in the XLA API guide:
     https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
     """
     return move_data_to_device(checkpoint, torch.device("cpu"))
Example #6
    def _move_batch_to_device(
        self,
        batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Move a batch of data to the proper device."""
        # TODO: does this actually speed anything up?
        try:
            # assume we have implicit data
            ((users, pos_items), neg_items) = batch

            users = users.to(self.device)
            pos_items = pos_items.to(self.device)
            neg_items = neg_items.to(self.device)

            return ((users, pos_items), neg_items)
        except (AttributeError, ValueError):
            try:
                # now assume we have explicit data
                users, pos_items, ratings = batch

                users = users.to(self.device)
                pos_items = pos_items.to(self.device)
                ratings = ratings.to(self.device)

                return users, pos_items, ratings
            except (AttributeError, ValueError):
                # we have an unexpected data format, fallback to PyTorch Lightning
                return move_data_to_device(batch, self.device)
Example #7
 def _wrapped_function(self, process_idx: int, function: Callable,
                       args: Any, kwargs: Any,
                       return_queue: SimpleQueue) -> None:
     self._worker_setup(process_idx)
     result = function(*args, **kwargs)
     if self.local_rank == 0:
         return_queue.put(move_data_to_device(result, "cpu"))
Example #8
def test_v1_8_0_deprecated_torchtext_batch():
    with pytest.deprecated_call(match="is deprecated and Lightning will remove support for it in v1.8"):
        data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
        batch = next(iter(data_iterator))
        _ = move_data_to_device(batch=batch, device=torch.device("cpu"))
Example #9
    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device.

        The input and the output are of the same type.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.
        """
        return move_data_to_device(batch, device=device or self.root_device)
Example #10
    def batch_to_device(self, batch: Any, device: torch.device) -> Any:
        """Moves the batch to the correct device.
        The returned batch is of the same type as the input batch, just having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
        """
        model = self.lightning_module
        if model is not None:
            return model.transfer_batch_to_device(batch, device)
        return move_data_to_device(batch, device)
Example #11
    def _wrapped_function(self, process_idx: int, function: Callable,
                          args: Any, kwargs: Any,
                          return_queue: Optional[SimpleQueue]) -> None:
        self._worker_setup(process_idx)
        result = function(*args, **kwargs)
        if return_queue is not None and self.local_rank == 0:
            return_queue.put(move_data_to_device(result, "cpu"))

        self.barrier("end-process")
        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if self.local_rank == 0:
            time.sleep(2)
Example #12
def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
    data_iter = iter(data_iterator)
    batch = next(data_iter)
    batch_on_device = move_data_to_device(batch, device)

    if include_lengths:
        # tensor with data
        assert batch_on_device.text[0].device == device
        # tensor with length of data
        assert batch_on_device.text[1].device == device
    else:
        assert batch_on_device.text.device == device
Example #13
def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
    data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3, include_lengths=include_lengths)
    data_iter = iter(data_iterator)
    batch = next(data_iter)

    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
        batch_on_device = move_data_to_device(batch, device)

    if include_lengths:
        # tensor with data
        assert batch_on_device.text[0].device == device
        # tensor with length of data
        assert batch_on_device.text[1].device == device
    else:
        assert batch_on_device.text.device == device
Example #14
    def _wrapped_function(self, process_idx: int, function: Callable,
                          args: Any, kwargs: Any,
                          return_queue: SimpleQueue) -> None:
        self._worker_setup(process_idx)
        result = function(*args, **kwargs)
        if self.local_rank == 0:
            return_queue.put(move_data_to_device(result, "cpu"))

        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        self.barrier("end-process")

        # Ensure that the rank 0 process is the one exiting last
        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if self.local_rank == 0:
            time.sleep(2)
Example #15
def test_wrongly_implemented_transferable_data_type(should_return):
    class TensorObject:
        def __init__(self, tensor: torch.Tensor, should_return: bool = True):
            self.tensor = tensor
            self.should_return = should_return

        def to(self, device):
            self.tensor.to(device)
            # simulate a user forgetting to return self
            if self.should_return:
                return self

    tensor = torch.tensor(0.1)
    obj = TensorObject(tensor, should_return)
    assert obj == move_data_to_device(obj, torch.device("cpu"))
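The test above works because move_data_to_device falls back to the original object when a custom to() returns None. For contrast, a correctly implemented device-transferable type follows the tensor convention and returns the moved object; a minimal sketch (the class and its fields are hypothetical):

import torch


class DeviceAwareBatch:
    def __init__(self, features: torch.Tensor, targets: torch.Tensor) -> None:
        self.features = features
        self.targets = targets

    def to(self, device: torch.device) -> "DeviceAwareBatch":
        # Tensor.to is out-of-place, so reassign the attributes and return self
        self.features = self.features.to(device)
        self.targets = self.targets.to(device)
        return self

With a type like this, move_data_to_device returns the same object with its tensors actually on the target device.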
Example #16
    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device.

        The returned batch is of the same type as the input batch, just
        having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.
        """
        model = self.lightning_module
        device = device or self.root_device
        if model is not None:
            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
        return move_data_to_device(batch, device)
Example #17
    def _wrapping_function(
        self,
        process_idx: int,
        trainer: Optional["pl.Trainer"],
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: SimpleQueue,
    ) -> None:
        self._strategy._worker_setup(process_idx)
        results = function(*args, **kwargs)

        if trainer is not None:
            results = self._collect_rank_zero_results(trainer, results)

        if self._strategy.local_rank == 0:
            return_queue.put(move_data_to_device(results, "cpu"))
Example #18
    def batch_to_device(self,
                        batch: Any,
                        device: Optional[torch.device] = None,
                        dataloader_idx: Optional[int] = None) -> Any:
        """Moves the batch to the correct device.
        The returned batch is of the same type as the input batch, just having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.
        """
        model = self.lightning_module
        if model is not None and not isinstance(self.training_type_plugin,
                                                DataParallelPlugin):
            # no need to transfer batch to device in DP mode
            return model._apply_batch_transfer_handler(batch, device,
                                                       dataloader_idx)

        return move_data_to_device(batch, device)
Example #19
 def batch_to_device(self, batch: Any, device: torch.device):
     model = self.trainer.get_model()
     if model is not None:
         return model.transfer_batch_to_device(batch, device)
     return move_data_to_device(batch, device)
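Several of the batch_to_device variants above delegate to the attached LightningModule, so device transfer can be customised per model. A minimal sketch of overriding that hook, assuming the transfer_batch_to_device signature of recent pytorch_lightning 1.x releases; the keep_on_cpu flag convention is purely illustrative:

import torch
import pytorch_lightning as pl
from pytorch_lightning.utilities import move_data_to_device


class MyModel(pl.LightningModule):
    def transfer_batch_to_device(self, batch, device: torch.device, dataloader_idx: int):
        if isinstance(batch, dict) and batch.get("keep_on_cpu", False):
            # leave metadata-only batches where they are
            return batch
        # default behaviour: recursively move every tensor in the batch
        return move_data_to_device(batch, device)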