Example No. 1
    def update_replica_device_attributes(self, inputs: Any) -> None:
        """Updates the device information of LightningModule by reading the device from the inputs. In
        :class:`~torch.nn.parallel.DataParallel`, changes to the state during the ``forward`` pass are lost when
        the replicas get discarded. The only way to know the current device is from the inputs passed into the
        model.

        Args:
            inputs: A collection of inputs (typically a tuple). If the inputs don't contain tensors,
                a warning is shown that accessing ``self.device`` will not return the correct device.
        """
        replica_device = None

        def find_tensor_with_device(tensor: Tensor) -> Tensor:
            nonlocal replica_device
            if replica_device is None and tensor.device != torch.device("cpu"):
                replica_device = tensor.device
            return tensor

        apply_to_collection(inputs,
                            dtype=Tensor,
                            function=find_tensor_with_device)

        if replica_device is not None:
            # by calling .to() we force the update to the self.device property
            self.module.to(device=replica_device)
        else:
            rank_zero_warn(
                "Could not determine on which device the inputs are."
                " When using DataParallel (strategy='dp'), be aware that in case you are using self.device"
                " in your code, it will reference only the root device.")
    def _resolve_sampler(self,
                         dataloader: DataLoader,
                         shuffle: bool,
                         mode: Optional[RunningStage] = None) -> Sampler:
        if self._requires_distributed_sampler(dataloader):
            sampler = self._get_distributed_sampler(
                dataloader,
                shuffle,
                mode=mode,
                overfit_batches=self.trainer.overfit_batches,
                **self.trainer.distributed_sampler_kwargs,
            )

            # update docs too once this is resolved
            trainer_fn = self.trainer.state.fn
            if isinstance(sampler, DistributedSampler) and trainer_fn in (
                    TrainerFn.VALIDATING, TrainerFn.TESTING):
                rank_zero_warn(
                    f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
                    " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
                    " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
                    " some samples to make sure all devices have same batch size in case of uneven inputs.",
                    category=PossibleUserWarning,
                )

            return sampler

        return dataloader.sampler
    def _resolve_overfit_batches(
            dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
        all_have_sequential_sampler = True

        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
            nonlocal all_have_sequential_sampler
            all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                dataloader.sampler, SequentialSampler)

        apply_to_collection(dataloader, DataLoader,
                            resolve_has_no_sequential_sampler)

        if not all_have_sequential_sampler:
            rank_zero_warn(
                "You requested to overfit but enabled training dataloader shuffling."
                " We are turning off the training dataloader shuffling for you."
            )

            def replace_sampler(dataloader: DataLoader) -> DataLoader:
                return _update_dataloader(dataloader,
                                          SequentialSampler(
                                              dataloader.dataset),
                                          mode=RunningStage.TRAINING)

            dataloader = apply_to_collection(dataloader, DataLoader,
                                             replace_sampler)

        return dataloader
    def save_checkpoint(self,
                        checkpoint: Dict[str, Any],
                        path: _PATH,
                        storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: not used in ``TorchCheckpointIO.save_checkpoint``

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`.")
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)
        try:
            # write the checkpoint dictionary on the file
            atomic_save(checkpoint, path)
        except AttributeError as err:
            # todo (sean): is this try catch necessary still?
            # https://github.com/Lightning-AI/lightning/pull/431
            key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
            checkpoint.pop(key, None)
            rank_zero_warn(
                f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}"
            )
            atomic_save(checkpoint, path)
    def _check_eval_shuffling(dataloader, mode):
        if _is_dataloader_shuffled(dataloader):
            rank_zero_warn(
                f"Your `{mode.dataloader_prefix}_dataloader`'s sampler has shuffling enabled,"
                " it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.",
                category=PossibleUserWarning,
            )
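Since several of the warnings above are raised with `category=PossibleUserWarning`, a hedged note on silencing them with the standard `warnings` machinery; the import path is an assumption based on pytorch_lightning 1.6+:

import warnings

from pytorch_lightning.utilities.warnings import PossibleUserWarning  # assumed location

# silences the eval-shuffling hint above and other PossibleUserWarnings
warnings.filterwarnings("ignore", category=PossibleUserWarning)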
Example No. 6
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8.

        Called when loading a model checkpoint.
        """
        # Todo: the `callback_states` are dropped with TPUSpawn as they
        # can't be saved using `xm.save`
        # https://github.com/pytorch/xla/issues/2773
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8."
        )
        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")

        if callback_states is None:
            return

        is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
        current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks}
        difference = callback_states.keys() - current_callbacks_keys
        if difference:
            rank_zero_warn(
                "Be aware that when using `ckpt_path`,"
                " callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
                f" Please add the following callbacks: {list(difference)}.",
            )

        for callback in self.callbacks:
            state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
            if state:
                state = deepcopy(state)
                callback.on_load_checkpoint(self, self.lightning_module, state)
Example No. 7
    def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
        r"""
        Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
        Implement this to override the items displayed in the progress bar.

        Here is an example of how to override the defaults:

        .. code-block:: python

            def get_metrics(self, trainer, model):
                # don't show the version number
                items = super().get_metrics(trainer, model)
                items.pop("v_num", None)
                return items

        Return:
            Dictionary with the items to be displayed in the progress bar.
        """
        standard_metrics = get_standard_metrics(trainer, pl_module)
        pbar_metrics = trainer.progress_bar_metrics
        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
        if duplicates:
            rank_zero_warn(
                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
                " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
            )

        return {**standard_metrics, **pbar_metrics}
Example No. 8
def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
    valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
    extra_keys = optim_conf.keys() - valid_keys
    if extra_keys:
        rank_zero_warn(
            f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}",
            category=RuntimeWarning)
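As an illustration, a minimal sketch of a `configure_optimizers` return value that would trip this check because of a misspelled key; the model class is hypothetical:

import torch
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # "lr_schedular" is a typo: only {"optimizer", "lr_scheduler", "frequency", "monitor"}
        # are recognized, so this return value triggers the RuntimeWarning from _validate_optim_conf
        return {"optimizer": optimizer, "lr_schedular": scheduler}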
Example No. 9
def _get_short_description(component: object) -> Optional[str]:
    parse = import_docstring_parse("LightningCLI(run=True)")
    try:
        docstring = parse(component.__doc__)
        return docstring.short_description
    except ValueError:
        rank_zero_warn(f"Failed parsing docstring for {component}")
Example No. 10
    def _instantiate_trainer(self, config: Dict[str, Any],
                             callbacks: List[Callback]) -> Trainer:
        key = "callbacks"
        if key in config:
            if config[key] is None:
                config[key] = []
            elif not isinstance(config[key], list):
                config[key] = [config[key]]
            config[key].extend(callbacks)
            if key in self.trainer_defaults:
                value = self.trainer_defaults[key]
                config[key] += value if isinstance(value, list) else [value]
            if self.save_config_callback and not config.get(
                    "fast_dev_run", False):
                config_callback = self.save_config_callback(
                    self._parser(self.subcommand),
                    self.config.get(str(self.subcommand), self.config),
                    self.save_config_filename,
                    overwrite=self.save_config_overwrite,
                    multifile=self.save_config_multifile,
                )
                config[key].append(config_callback)
        else:
            rank_zero_warn(
                f"The `{self.trainer_class.__qualname__}` class does not expose the `{key}` argument so they will"
                " not be included.")
        return self.trainer_class(**config)
Example No. 11
def _configure_schedulers_manual_opt(
        schedulers: list) -> List[LRSchedulerConfig]:
    """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual
    optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            invalid_keys = {
                "interval", "frequency", "reduce_on_plateau", "monitor",
                "strict"
            }
            keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]

            if keys_to_warn:
                rank_zero_warn(
                    f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
                    " You need to call `lr_scheduler.step()` manually in manual optimization.",
                    category=RuntimeWarning,
                )

            config = LRSchedulerConfig(
                **{
                    key: scheduler[key]
                    for key in scheduler if key not in invalid_keys
                })
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs
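For reference, a minimal sketch of manual optimization where the scheduler is stepped by hand, which is what the warning asks for; the training logic is illustrative, not taken from the source:

import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # enables manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)
        opt.step()
        # schedulers are never stepped automatically in manual optimization
        self.lr_schedulers().step()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]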
Example No. 12
def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]:
    """Load hparams from a file.

        Args:
            config_yaml: Path to config yaml file
            use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
                the hparams will be converted to ``DictConfig`` if possible.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    fs = get_filesystem(config_yaml)
    if not fs.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", category=RuntimeWarning)
        return {}

    with fs.open(config_yaml, "r") as fp:
        hparams = yaml.full_load(fp)

    if _OMEGACONF_AVAILABLE:
        if use_omegaconf:
            try:
                return OmegaConf.create(hparams)
            except (UnsupportedValueType, ValidationError):
                pass
    return hparams
Example No. 13
    def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
        """This function is used to exclude any parameter which already exists in this optimizer.

        Args:
            optimizer: Optimizer used for parameter exclusion
            params: Iterable of parameters used to check against the provided optimizer

        Returns:
            List of parameters not contained in this optimizer param groups
        """
        out_params = []
        removed_params = []
        for param in params:
            if not any(
                    torch.equal(p, param) for group in optimizer.param_groups
                    for p in group["params"]):
                out_params.append(param)
            else:
                removed_params.append(param)

        if removed_params:
            rank_zero_warn(
                "The provided params to be frozen already exist within another group of this optimizer."
                " Those parameters will be skipped.\n"
                "HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
                f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
            )
        return out_params
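A small usage sketch of the filter, assuming (as in recent pytorch_lightning releases) that it is a static method of `BaseFinetuning`; the module shapes are illustrative:

import torch
from torch import nn

from pytorch_lightning.callbacks.finetuning import BaseFinetuning  # assumed location

backbone, head = nn.Linear(4, 4), nn.Linear(4, 2)
optimizer = torch.optim.SGD(backbone.parameters(), lr=0.1)

# head parameters are not in the optimizer yet, so they pass through;
# passing backbone.parameters() again would warn and return an empty list
new_params = BaseFinetuning.filter_on_optimizer(optimizer, list(head.parameters()))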
    def reset(self) -> None:
        """Resets the internal state of the loop for a new run."""
        if self.restarting:
            self.batch_progress.reset_on_restart()
            self.scheduler_progress.reset_on_restart()
            self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

            trainer = self.trainer
            if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float(
                    "inf"):
                expected_steps = math.ceil(trainer.num_training_batches /
                                           trainer.accumulate_grad_batches)
                if self.global_step % expected_steps != 0:
                    rank_zero_warn(
                        "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
                        " results if further training is done. Consider using an end-of-epoch checkpoint or enabling"
                        " fault-tolerant training:"
                        " https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html"
                    )
        else:
            self.batch_progress.reset_on_run()
            self.scheduler_progress.reset_on_run()
            self.batch_loop.optimizer_loop.optim_progress.reset_on_run()
            # when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches
            # seen per epoch, this is useful for tracking when validation is run multiple times per epoch
            self.val_loop.epoch_loop.batch_progress.total.reset()

        self._outputs = []
Example No. 15
def test_v1_8_0_rank_zero_imports():

    import warnings

    from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_info
    from pytorch_lightning.utilities.warnings import LightningDeprecationWarning, rank_zero_deprecation, rank_zero_warn

    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_debug("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_info("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_warn("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_deprecation("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        warnings.warn("foo", LightningDeprecationWarning, stacklevel=5)
Example No. 16
    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        if isinstance(devices, int):
            return [torch.device("cpu")] * devices
        rank_zero_warn(
            f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead."
        )
        return []
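An illustrative trace of the helper above, written as if the function were called standalone:

get_parallel_devices(2)       # -> [torch.device("cpu"), torch.device("cpu")]
get_parallel_devices("auto")  # -> [] after the rank-zero warning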
Example No. 17
    def main_address(self) -> str:
        if "MASTER_ADDR" not in os.environ:
            rank_zero_warn(
                "MASTER_ADDR environment variable is not defined. Set as localhost"
            )
            os.environ["MASTER_ADDR"] = "127.0.0.1"
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        return os.environ["MASTER_ADDR"]
Example No. 18
def init_meta_context() -> Generator:
    rank_zero_warn(
        "Be aware this feature is highly experimental and there are a number of weird edge cases "
        "where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11."
    )
    _set_meta_device()
    yield
    _unset_meta_device()
Example No. 19
def _get_short_description(component: object) -> Optional[str]:
    if component.__doc__ is None:
        return None
    try:
        docstring = docstring_parser.parse(component.__doc__)
        return docstring.short_description
    except (ValueError, docstring_parser.ParseError) as ex:
        rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
Example No. 20
    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: Optional[str] = None,
        offline: Optional[bool] = False,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        version: Optional[str] = None,
        project: Optional[str] = None,
        log_model: Union[str, bool] = False,
        experiment=None,
        prefix: Optional[str] = "",
        **kwargs,
    ):
        if wandb is None:
            raise ModuleNotFoundError(
                "You want to use `wandb` logger which is not installed yet,"
                " install it with `pip install wandb`."  # pragma: no-cover
            )

        if offline and log_model:
            raise MisconfigurationException(
                f"Providing log_model={log_model} and offline={offline} is an invalid configuration"
                " since model checkpoints cannot be uploaded in offline mode.\n"
                "Hint: Set `offline=False` to log your model.")

        if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
            rank_zero_warn(
                f"Providing log_model={log_model} requires wandb version >= 0.10.22"
                " for logging associated model metadata.\n"
                "Hint: Upgrade with `pip install --upgrade wandb`.")

        super().__init__()
        self._offline = offline
        self._log_model = log_model
        self._prefix = prefix
        self._experiment = experiment
        self._logged_model_time = {}
        self._checkpoint_callback = None
        # set wandb init arguments
        anonymous_lut = {True: "allow", False: None}
        self._wandb_init = dict(
            name=name,
            project=project,
            id=version or id,
            dir=save_dir,
            resume="allow",
            anonymous=anonymous_lut.get(anonymous, anonymous),
        )
        self._wandb_init.update(**kwargs)
        # extract parameters
        self._save_dir = self._wandb_init.get("dir")
        self._name = self._wandb_init.get("name")
        self._id = self._wandb_init.get("id")
        # start wandb run (to create an attach_id for distributed modes)
        if _WANDB_GREATER_EQUAL_0_12_10:
            wandb.require("service")
            _ = self.experiment
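Typical usage of this logger, as a hedged sketch with illustrative argument values:

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="my-project", name="baseline", log_model=True)
trainer = pl.Trainer(logger=logger, max_epochs=3)
# note: offline=True together with log_model=True raises the
# MisconfigurationException shown in __init__ above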
Example No. 21
    def main_port(self) -> int:
        if "MASTER_PORT" not in os.environ:
            rank_zero_warn(
                "MASTER_PORT environment variable is not defined. Set as 12910"
            )
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        return int(os.environ["MASTER_PORT"])
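To avoid both the MASTER_ADDR and MASTER_PORT fallbacks above, the rendezvous address and port can be exported before the Trainer spawns any processes; the values below are illustrative:

import os

os.environ.setdefault("MASTER_ADDR", "10.0.0.1")  # address of the rank-0 node (illustrative)
os.environ.setdefault("MASTER_PORT", "29500")     # any free port (illustrative)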
Example No. 22
    def on_train_start(self, trainer: "pl.Trainer", *args: Any,
                       **kwargs: Any) -> None:
        """Called before training, determines unique names for all lr schedulers in the case of multiple of the
        same type or in the case of multiple parameter groups.

        Raises:
            MisconfigurationException:
                If ``Trainer`` has no ``logger``.
        """
        if not trainer.loggers:
            raise MisconfigurationException(
                "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
            )

        if self.log_momentum:

            def _check_no_key(key: str) -> bool:
                if trainer.lr_scheduler_configs:
                    return any(key not in config.scheduler.optimizer.defaults
                               for config in trainer.lr_scheduler_configs)

                return any(key not in optimizer.defaults
                           for optimizer in trainer.optimizers)

            if _check_no_key("momentum") and _check_no_key("betas"):
                rank_zero_warn(
                    "You have set log_momentum=True, but some optimizers do not"
                    " have momentum. This will log a value 0 for the momentum.",
                    category=RuntimeWarning,
                )

        # Find names for schedulers
        names: List[List[str]] = []
        (
            sched_hparam_keys,
            optimizers_with_scheduler,
            optimizers_with_scheduler_types,
        ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs)
        names.extend(sched_hparam_keys)

        # Find names for leftover optimizers
        optimizer_hparam_keys, _ = self._find_names_from_optimizers(
            trainer.optimizers,
            seen_optimizers=optimizers_with_scheduler,
            seen_optimizer_types=optimizers_with_scheduler_types,
        )
        names.extend(optimizer_hparam_keys)

        # Initialize for storing values
        names_flatten = list(itertools.chain.from_iterable(names))
        self.lrs = {name: [] for name in names_flatten}
        self.last_momentum_values = {
            name + "-momentum": None
            for name in names_flatten
        }
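Typical wiring of this callback, sketched with illustrative arguments; a logger must be configured, as the check at the top of `on_train_start` shows:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(logging_interval="step", log_momentum=True)
trainer = pl.Trainer(callbacks=[lr_monitor])  # Trainer's default logger satisfies the check in pytorch_lightning 1.x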
Example No. 23
        def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
            if not self._update_called:
                rank_zero_warn(
                    f"The ``compute`` method of metric {self.__class__.__name__}"
                    " was called before the ``update`` method which may lead to errors,"
                    " as metric states have not yet been updated.", )

            # return cached value
            if self._computed is not None:
                return self._computed
            self._computed = compute(*args, **kwargs)
            return self._computed
Example No. 24
    def log_hyperparams(self, params: Union[Dict[str, Any],
                                            Namespace]) -> None:
        params = _convert_params(params)
        params = _flatten_dict(params)
        for k, v in params.items():
            if len(str(v)) > 250:
                rank_zero_warn(
                    f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}",
                    category=RuntimeWarning)
                continue

            self.experiment.log_param(self.run_id, k, v)
Example No. 25
    def log_graph(self, model: "pl.LightningModule", input_array=None):
        if self._log_graph:
            if input_array is None:
                input_array = model.example_input_array

            if input_array is not None:
                self.experiment.add_graph(
                    model, model._apply_batch_transfer_handler(input_array))
            else:
                rank_zero_warn(
                    "Could not log computational graph since neither the"
                    " `model.example_input_array` attribute is set nor"
                    " `input_array` was given", )
Example No. 26
    def __init__(self, log_dir: str) -> None:
        self.hparams = {}
        self.metrics = []

        self.log_dir = log_dir
        if os.path.exists(self.log_dir) and os.listdir(self.log_dir):
            rank_zero_warn(
                f"Experiment logs directory {self.log_dir} exists and is not empty."
                " Previous log files in this directory will be deleted when the new ones are saved!"
            )
        os.makedirs(self.log_dir, exist_ok=True)

        self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
Example No. 27
def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
    training_step_fx = getattr(trainer.lightning_module, "training_step")
    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher
    elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        if not isinstance(trainer.accelerator, CUDAAccelerator):
            raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
        return InterBatchParallelDataFetcher
    return DataFetcher
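A minimal sketch of the experimental `dataloader_iter` signature that selects `DataLoaderIterDataFetcher` above; the model body is illustrative:

import torch
import pytorch_lightning as pl


class IterStepModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, dataloader_iter):
        # the loop hands over the raw iterator; batches are fetched manually
        batch = next(dataloader_iter)
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)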
Example No. 28
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
    """Removes all unpicklable entries from hparams."""

    hparams_dict = hparams
    if isinstance(hparams, Namespace):
        hparams_dict = hparams.__dict__

    del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]

    for k in del_attrs:
        rank_zero_warn(
            f"attribute '{k}' removed from hparams because it cannot be pickled"
        )
        del hparams_dict[k]
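An illustrative trace of the helper above: unpicklable entries such as lambdas are warned about and deleted in place. The import path is an assumption for pytorch_lightning 1.x:

from argparse import Namespace

from pytorch_lightning.utilities.parsing import clean_namespace  # assumed location

hparams = Namespace(lr=1e-3, act=lambda x: x * 2)
clean_namespace(hparams)  # warns that 'act' cannot be pickled and removes it
# hparams is now Namespace(lr=0.001)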
Example No. 29
    def log_graph(self, model: "pl.LightningModule", input_array=None):
        if self._log_graph:
            if input_array is None:
                input_array = model.example_input_array

            if input_array is not None:
                input_array = model._apply_batch_transfer_handler(input_array)
                model._running_torchscript = True
                self.experiment.add_graph(model, input_array)
                model._running_torchscript = False
            else:
                rank_zero_warn(
                    "Could not log computational graph since the"
                    " `model.example_input_array` attribute is not set"
                    " or `input_array` was not given", )
Example No. 30
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.model.parameters()))
        module = fn(self, *args, **kwargs)
        self.model.on_post_move_to_device()
        post_layer_count = len(list(self.model.parameters()))

        if not pre_layer_count == post_layer_count:
            rank_zero_warn(
                "The model layers do not match after moving to the target device."
                " If your model employs weight sharing on TPU,"
                " please tie your weights using the `on_post_move_to_device` model hook.\n"
                f"Layer count: [Before: {pre_layer_count} After: {post_layer_count}]"
            )

        return module