def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
    iterator = iter(self._dataloader)
    if self._device is None:
        yield from iterator
        return

    for item in iterator:
        yield move_data_to_device(item, self._device)
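A hedged usage sketch of the iterator above in context, via a hypothetical `_DeviceDataLoader` wrapper. The class name is an assumption, as is the `move_data_to_device` import path, which varies across Lightning versions.

# Hypothetical wrapper illustrating the __iter__ above; not the actual class.
from typing import Optional

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.utilities.apply_func import move_data_to_device  # import path varies by version


class _DeviceDataLoader:
    def __init__(self, dataloader: DataLoader, device: Optional[torch.device]):
        self._dataloader = dataloader
        self._device = device

    def __iter__(self):
        iterator = iter(self._dataloader)
        if self._device is None:
            yield from iterator
            return
        for item in iterator:
            yield move_data_to_device(item, self._device)


loader = _DeviceDataLoader(DataLoader(TensorDataset(torch.arange(10.0))), torch.device("cpu"))
for (batch,) in loader:
    assert batch.device.type == "cpu"  # each batch arrives already on the target device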
def _wrapping_function(
    self,
    process_idx: int,
    trainer: Optional["pl.Trainer"],
    function: Callable,
    args: Any,
    kwargs: Any,
    return_queue: SimpleQueue,
) -> None:
    self._strategy._worker_setup(process_idx)
    results = function(*args, **kwargs)

    if trainer is not None:
        results = self._collect_rank_zero_results(trainer, results)

    if self._strategy.local_rank == 0:
        return_queue.put(move_data_to_device(results, "cpu"))

    # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
    self._strategy.barrier("end-process")

    # Ensure that the rank 0 process is the one exiting last
    # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
    if self._strategy.local_rank == 0:
        time.sleep(2)
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: dict containing model and trainer state
        path: write-target path
        storage_options: not used in ``XLACheckpointIO.save_checkpoint``

    Raises:
        TypeError:
            If ``storage_options`` arg is passed in
    """
    if storage_options is not None:
        raise TypeError(
            "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
            f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
            " to define how you'd like to use `storage_options`."
        )
    fs = get_filesystem(path)
    fs.makedirs(os.path.dirname(path), exist_ok=True)
    checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
    # write the checkpoint dictionary to the provided path
    atomic_save(checkpoint, path)
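A minimal sketch of the CPU-move-before-save pattern used above, with plain `torch.save` standing in for Lightning's `atomic_save` helper; the `move_data_to_device` import path is an assumption and varies across versions.

# Sketch only: mirrors save_checkpoint's CPU move, using torch.save instead of atomic_save.
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device  # import path varies by version

checkpoint = {"state_dict": {"weight": torch.randn(4, 4)}, "epoch": 3}
checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))  # recurses into nested dicts
torch.save(checkpoint, "model.ckpt")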
def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to it.

    Args:
        batch: the current batch to run the prediction on
        batch_idx: the index of the current batch
        dataloader_idx: the index of the dataloader producing the current batch
    """
    # configure step_kwargs
    step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

    # extract batch_indices and store them
    self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []

    self.trainer._call_callback_hooks("on_predict_batch_start", batch, batch_idx, dataloader_idx)
    self.trainer._call_lightning_module_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

    self.batch_progress.increment_started()

    predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())

    self.batch_progress.increment_processed()

    if predictions is None:
        self._warning_cache.warn("predict returned None, if it was on purpose, ignore this warning...")

    self.trainer._call_callback_hooks("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
    self.trainer._call_lightning_module_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

    self.batch_progress.increment_completed()

    if self.should_store_predictions:
        self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))
def on_save(self, checkpoint: dict) -> dict:
    """Move XLA tensors to CPU before saving, as recommended in the XLA guide:

    https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
    """
    return move_data_to_device(checkpoint, torch.device("cpu"))
def _move_batch_to_device(
    self,
    batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Move a batch of data to the proper device."""
    # TODO: does this actually speed anything up?
    try:
        # assume we have implicit data
        ((users, pos_items), neg_items) = batch

        users = users.to(self.device)
        pos_items = pos_items.to(self.device)
        neg_items = neg_items.to(self.device)

        return ((users, pos_items), neg_items)
    except (AttributeError, ValueError):
        try:
            # now assume we have explicit data
            users, pos_items, ratings = batch

            users = users.to(self.device)
            pos_items = pos_items.to(self.device)
            ratings = ratings.to(self.device)

            return users, pos_items, ratings
        except (AttributeError, ValueError):
            # we have an unexpected data format; fall back to PyTorch Lightning
            return move_data_to_device(batch, self.device)
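For batches that match neither the implicit `((users, pos_items), neg_items)` nor the explicit `(users, pos_items, ratings)` layout, the fallback branch hands the whole structure to `move_data_to_device`, which recurses through arbitrary containers. A small sketch of that fallback path (import path assumed; varies by version):

import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device  # import path varies by version

# A dict batch fails both tuple unpackings with ValueError and reaches the fallback.
batch = {"users": torch.arange(8), "items": torch.arange(8)}
batch = move_data_to_device(batch, torch.device("cpu"))
assert batch["users"].device.type == "cpu"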
def _wrapped_function(
    self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue
) -> None:
    self._worker_setup(process_idx)
    result = function(*args, **kwargs)
    if self.local_rank == 0:
        return_queue.put(move_data_to_device(result, "cpu"))
def test_v1_8_0_deprecated_torchtext_batch():
    with pytest.deprecated_call(match="is deprecated and Lightning will remove support for it in v1.8"):
        data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
        batch = next(iter(data_iterator))
        _ = move_data_to_device(batch=batch, device=torch.device("cpu"))
def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
    """Moves the batch to the correct device. The input and the output are of the same type.

    Args:
        batch: The batch of samples to move to the correct device
        device: The target device
        dataloader_idx: The index of the dataloader to which the batch belongs.
    """
    return move_data_to_device(batch, device=device or self.root_device)
def batch_to_device(self, batch: Any, device: torch.device) -> Any:
    """Moves the batch to the correct device.

    The returned batch is of the same type as the input batch, just having all tensors on the correct device.

    Args:
        batch: The batch of samples to move to the correct device
        device: The target device
    """
    model = self.lightning_module
    if model is not None:
        return model.transfer_batch_to_device(batch, device)
    return move_data_to_device(batch, device)
def _wrapped_function(
    self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: Optional[SimpleQueue]
) -> None:
    self._worker_setup(process_idx)
    result = function(*args, **kwargs)
    if return_queue is not None and self.local_rank == 0:
        return_queue.put(move_data_to_device(result, "cpu"))

    self.barrier("end-process")

    # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
    if self.local_rank == 0:
        time.sleep(2)
def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
    data_iter = iter(data_iterator)
    batch = next(data_iter)
    batch_on_device = move_data_to_device(batch, device)

    if include_lengths:
        # tensor with data
        assert batch_on_device.text[0].device == device
        # tensor with length of data
        assert batch_on_device.text[1].device == device
    else:
        assert batch_on_device.text.device == device
def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
    data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3, include_lengths=include_lengths)
    data_iter = iter(data_iterator)
    batch = next(data_iter)

    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
        batch_on_device = move_data_to_device(batch, device)

    if include_lengths:
        # tensor with data
        assert batch_on_device.text[0].device == device
        # tensor with length of data
        assert batch_on_device.text[1].device == device
    else:
        assert batch_on_device.text.device == device
def _wrapped_function(
    self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue
) -> None:
    self._worker_setup(process_idx)
    result = function(*args, **kwargs)
    if self.local_rank == 0:
        return_queue.put(move_data_to_device(result, "cpu"))

    # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
    self.barrier("end-process")

    # Ensure that the rank 0 process is the one exiting last
    # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
    if self.local_rank == 0:
        time.sleep(2)
def test_wrongly_implemented_transferable_data_type(should_return):
    class TensorObject:
        def __init__(self, tensor: torch.Tensor, should_return: bool = True):
            self.tensor = tensor
            self.should_return = should_return

        def to(self, device):
            self.tensor.to(device)
            # simulate a user forgetting to return self
            if self.should_return:
                return self

    tensor = torch.tensor(0.1)
    obj = TensorObject(tensor, should_return)
    assert obj == move_data_to_device(obj, torch.device("cpu"))
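For contrast, a sketch of a correctly implemented transferable type: `move_data_to_device` treats any object with a callable `to()` as movable and uses whatever `to()` returns, so `to()` must return the moved object. The class name is illustrative and the import path is an assumption.

import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device  # import path varies by version


class BatchWrapper:
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    def to(self, device):
        self.tensor = self.tensor.to(device)
        return self  # returning self is what makes the move visible to the caller


moved = move_data_to_device(BatchWrapper(torch.zeros(2)), torch.device("cpu"))
assert moved.tensor.device.type == "cpu"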
def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
    """Moves the batch to the correct device.

    The returned batch is of the same type as the input batch, just having all tensors on the correct device.

    Args:
        batch: The batch of samples to move to the correct device
        device: The target device
        dataloader_idx: The index of the dataloader to which the batch belongs.
    """
    model = self.lightning_module
    device = device or self.root_device
    if model is not None:
        return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
    return move_data_to_device(batch, device)
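`_apply_batch_transfer_handler` above routes through the LightningModule's batch-transfer hooks, so custom batch types can be handled by overriding `transfer_batch_to_device`. A sketch under assumptions: the `dataloader_idx` parameter of the hook exists only in some releases, and the dict handling shown is purely illustrative.

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import move_data_to_device  # import path varies by version


class MyModel(pl.LightningModule):
    def transfer_batch_to_device(self, batch, device, dataloader_idx=0):
        if isinstance(batch, dict):
            # move only selected entries, e.g. keep metadata on CPU
            batch["inputs"] = batch["inputs"].to(device)
            return batch
        # defer everything else to the default recursive move
        return move_data_to_device(batch, device)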
def _wrapping_function(
    self,
    process_idx: int,
    trainer: Optional["pl.Trainer"],
    function: Callable,
    args: Any,
    kwargs: Any,
    return_queue: SimpleQueue,
) -> None:
    self._strategy._worker_setup(process_idx)
    results = function(*args, **kwargs)

    if trainer is not None:
        results = self._collect_rank_zero_results(trainer, results)

    if self._strategy.local_rank == 0:
        return_queue.put(move_data_to_device(results, "cpu"))
def batch_to_device(
    self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
) -> Any:
    """Moves the batch to the correct device.

    The returned batch is of the same type as the input batch, just having all tensors on the correct device.

    Args:
        batch: The batch of samples to move to the correct device
        device: The target device
        dataloader_idx: The index of the dataloader to which the batch belongs.
    """
    model = self.lightning_module
    if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
        # no need to transfer batch to device in DP mode
        return model._apply_batch_transfer_handler(batch, device, dataloader_idx)
    return move_data_to_device(batch, device)
def batch_to_device(self, batch: Any, device: torch.device):
    model = self.trainer.get_model()
    if model is not None:
        return model.transfer_batch_to_device(batch, device)
    return move_data_to_device(batch, device)