def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
    """Overrides the model's :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` method
    when a single optimizer and, optionally, a single lr_scheduler argument group were added to the parser as
    'AUTOMATIC'."""
    parser = self._parser(subcommand)

    def get_automatic(
        class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]
    ) -> List[str]:
        # collect the argument group keys whose classes match `class_type` and were linked as 'AUTOMATIC'
        automatic = []
        for key, (base_class, link_to) in register.items():
            if not isinstance(base_class, tuple):
                base_class = (base_class,)
            if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class):
                automatic.append(key)
        return automatic

    optimizers = get_automatic(Optimizer, parser._optimizers)
    lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers)

    if len(optimizers) == 0:
        return

    if len(optimizers) > 1 or len(lr_schedulers) > 1:
        raise MisconfigurationException(
            f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
            f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers + lr_schedulers}. In this case the "
            "user is expected to link the argument groups and implement `configure_optimizers`, see "
            "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
            "#optimizers-and-learning-rate-schedulers"
        )

    optimizer_class = parser._optimizers[optimizers[0]][0]
    optimizer_init = self._get(self.config_init, optimizers[0])
    if not isinstance(optimizer_class, tuple):
        optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
    if not optimizer_init:
        # optimizers were registered automatically but not passed by the user
        return

    lr_scheduler_init = None
    if lr_schedulers:
        lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
        lr_scheduler_init = self._get(self.config_init, lr_schedulers[0])
        if not isinstance(lr_scheduler_class, tuple):
            lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

    if is_overridden("configure_optimizers", self.model):
        _warn(
            f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
            f"`{self.__class__.__name__}.configure_optimizers`."
        )

    optimizer = instantiate_class(self.model.parameters(), optimizer_init)
    lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
    fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
    update_wrapper(fn, self.configure_optimizers)  # necessary for `is_overridden`
    # override the existing method on this model instance only
    self.model.configure_optimizers = MethodType(fn, self.model)
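# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the instance-patching technique used at
# the end of the method above: a partially applied `configure_optimizers` is
# bound to a single model instance with `types.MethodType`, and
# `functools.update_wrapper` copies the original function's metadata onto the
# `partial` so that name-based checks such as `is_overridden` keep working.
# `ToyModel` and `_configure_optimizers_template` are hypothetical names used
# for illustration only; they are not part of the Lightning API.
from functools import partial, update_wrapper
from types import MethodType


class ToyModel:
    def configure_optimizers(self):  # the method that gets replaced
        return None


def _configure_optimizers_template(self, optimizer=None, lr_scheduler=None):
    # return the optimizer alone, or an (optimizers, schedulers) pair
    return optimizer if lr_scheduler is None else ([optimizer], [lr_scheduler])


model = ToyModel()
fn = partial(_configure_optimizers_template, optimizer="fake-optimizer")
update_wrapper(fn, _configure_optimizers_template)  # preserve `__name__` etc.
model.configure_optimizers = MethodType(fn, model)  # patch this instance only
assert model.configure_optimizers() == "fake-optimizer"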
"""Needs to be run outside of `pytest` as it captures all the warnings."""

import os
from contextlib import redirect_stderr
from io import StringIO

from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.warnings import WarningCache

standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"
if standalone:
    stderr = StringIO()
    # recording
    with redirect_stderr(stderr):
        _warn("test1")
        _warn("test2", category=DeprecationWarning)
        rank_zero_warn("test3")
        rank_zero_warn("test4", category=DeprecationWarning)
        rank_zero_deprecation("test5")
        cache = WarningCache()
        cache.warn("test6")
        cache.deprecation("test7")

    output = stderr.getvalue()
    assert "test_warnings.py:31: UserWarning: test1" in output
    assert "test_warnings.py:32: DeprecationWarning: test2" in output
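# ---------------------------------------------------------------------------
# A minimal sketch of the capture technique used above, independent of
# Lightning: the default `warnings` handler formats each warning as
# "<file>:<line>: <Category>: <message>" and writes it to `sys.stderr`, which
# `redirect_stderr` swaps for an in-memory buffer. This is exactly what the
# file/line-number assertions above rely on.
import warnings
from contextlib import redirect_stderr
from io import StringIO

buffer = StringIO()
with redirect_stderr(buffer):
    warnings.simplefilter("always")  # ensure the warning is not deduplicated
    warnings.warn("captured", UserWarning)

assert "UserWarning: captured" in buffer.getvalue()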