def __init__(
    self,
    optimization: Optional[Literal["trace", "script", "onnx", "tensorrt"]] = None,
    server: Literal["fastapi", "ml_server", "torchserve", "sagemaker"] = "fastapi",
    host: str = "127.0.0.1",
    port: int = 8080,
    timeout: int = 10,
    exit_on_failure: bool = True,
):
    """Store the validator configuration after rejecting unsupported setups.

    Args:
        optimization: Optimization mode. Only ``None`` is currently accepted.
        server: Serving backend. Only ``"fastapi"`` is currently accepted.
        host: Host address, stored as ``self.host``.
        port: Port number, stored as ``self.port``.
        timeout: Stored as ``self.timeout`` (unit not visible here; presumably seconds — confirm).
        exit_on_failure: Stored as ``self.exit_on_failure``.

    Raises:
        ModuleNotFoundError: If the requirement check for ``fastapi`` or ``uvicorn`` fails
            (``fastapi`` is checked first).
        NotImplementedError: If ``optimization`` is not ``None`` or ``server`` is not ``"fastapi"``.
    """
    super().__init__()

    # Check the required packages in order; the first missing one raises with its own message.
    for package_name in ("fastapi", "uvicorn"):
        requirement = _RequirementAvailable(package_name)
        if not requirement:
            raise ModuleNotFoundError(requirement.message)

    # TODO: Add support for the other options
    if optimization is not None:
        raise NotImplementedError(f"The optimization {optimization} is currently not supported.")

    # TODO: Add support for testing with those server services
    if server != "fastapi":
        raise NotImplementedError("Only the fastapi server is currently supported.")

    self.optimization = optimization
    self.host = host
    self.port = port
    self.server = server
    self.timeout = timeout
    self.exit_on_failure = exit_on_failure
    # No response has been recorded yet.
    self.resp: Optional[requests.Response] = None
get_default_process_group_backend_for_device, log, ) from pytorch_lightning.utilities.enums import AMPType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache warning_cache = WarningCache() _DEEPSPEED_AVAILABLE: bool = _RequirementAvailable("deepspeed") if _DEEPSPEED_AVAILABLE: import deepspeed def remove_module_hooks(model: torch.nn.Module) -> None: # todo (tchaton) awaiting this feature to move upstream to DeepSpeed for module in model.modules(): module._backward_hooks = OrderedDict() module._is_full_backward_hook = None module._forward_hooks = OrderedDict() module._forward_pre_hooks = OrderedDict() module._state_dict_hooks = OrderedDict() module._load_state_dict_pre_hooks = OrderedDict()
from functools import partial, update_wrapper from types import MethodType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn _JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable( "jsonargparse[signatures]>=4.12.0") if _JSONARGPARSE_SIGNATURES_AVAILABLE: import docstring_parser from jsonargparse import ( ActionConfigFile, ArgumentParser, class_from_function, Namespace, register_unresolvable_import_paths, set_config_read_mode, ) register_unresolvable_import_paths( torch ) # Required until fix https://github.com/pytorch/pytorch/issues/74483
from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn try: import wandb from wandb.sdk.lib import RunDisabled from wandb.wandb_run import Run except ModuleNotFoundError: # needed for test mocks, these tests shall be updated wandb, Run, RunDisabled = None, None, None # type: ignore _WANDB_AVAILABLE = _RequirementAvailable("wandb") _WANDB_GREATER_EQUAL_0_10_22 = _RequirementAvailable("wandb>=0.10.22") _WANDB_GREATER_EQUAL_0_12_10 = _RequirementAvailable("wandb>=0.12.10") class WandbLogger(Logger): r""" Log using `Weights and Biases <https://docs.wandb.ai/integrations/lightning>`_. **Installation and set-up** Install with pip: .. code-block:: bash pip install wandb
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx from pytorch_lightning.loggers.logger import DummyLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed if importlib.util.find_spec("ipywidgets") is not None: from tqdm.auto import tqdm else: from tqdm import tqdm _MATPLOTLIB_AVAILABLE = _RequirementAvailable("matplotlib") if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING: import matplotlib.pyplot as plt log = logging.getLogger(__name__) def _determine_lr_attr_name(trainer: "pl.Trainer", model: "pl.LightningModule") -> str: if isinstance(trainer.auto_lr_find, str): if not lightning_hasattr(model, trainer.auto_lr_find): raise MisconfigurationException( f"`auto_lr_find` was set to {trainer.auto_lr_find}, however" " could not find this as a field in `model` or `model.hparams`." ) return trainer.auto_lr_find
from functools import reduce from typing import Any, Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Union from weakref import ReferenceType import torch from torch import Tensor from pytorch_lightning import __version__ from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params from pytorch_lightning.utilities.model_summary import ModelSummary from pytorch_lightning.utilities.rank_zero import rank_zero_only _NEPTUNE_AVAILABLE = _RequirementAvailable("neptune") _NEPTUNE_GREATER_EQUAL_0_9 = _RequirementAvailable("neptune>=0.9.0") if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9: try: from neptune import new as neptune from neptune.new.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException from neptune.new.run import Run from neptune.new.types import File as NeptuneFile except ModuleNotFoundError: import neptune from neptune.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException from neptune.run import Run from neptune.types import File as NeptuneFile else:
def test_requirement_avaliable():
    """`_RequirementAvailable` tracks the installed torch version and reports unmet requirements."""
    installed_version = torch.__version__

    # A lower bound equal to the installed version is satisfied...
    assert _RequirementAvailable(f"torch>={installed_version}")
    # ...while a strict upper bound at the installed version is not.
    assert not _RequirementAvailable(f"torch<{installed_version}")

    # The string form of an unmet requirement names it as not met.
    unmet_requirement = _RequirementAvailable("-")
    assert "Requirement '-' not met" in str(unmet_requirement)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union from torch import Tensor from torch.nn import Module from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache _DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0") if TYPE_CHECKING: if pl.strategies.deepspeed._DEEPSPEED_AVAILABLE: import deepspeed warning_cache = WarningCache() class DeepSpeedPrecisionPlugin(PrecisionPlugin): """Precision plugin for DeepSpeed integration. Args: precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16). amp_type: The mixed precision backend to use ("native" or "apex"). amp_level: The optimization level to use (O1, O2, etc...). By default it will be set to "O2" if ``amp_type`` is set to "apex".