示例#1
0
    def __init__(
        self,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ):
        """Set up the metric's options, wrapped ``update``/``compute`` and empty state registries.

        Args:
            compute_on_step: Stored as ``self.compute_on_step``. Presumably controls
                whether each ``forward`` call also returns a computed value — TODO confirm.
            dist_sync_on_step: Stored as ``self.dist_sync_on_step``. Presumably controls
                whether state is synchronized across processes on every step — TODO confirm.
            process_group: Stored as ``self.process_group``; presumably the distributed
                process group used when syncing (``None`` meaning the default group).
            dist_sync_fn: Stored as ``self.dist_sync_fn``; presumably an optional custom
                gather function used when syncing — verify against its consumers.
        """
        # NOTE(review): the base class is not visible here; given the torch usage it is
        # presumably torch.nn.Module, whose __init__ must run before attributes are set.
        super().__init__()

        # Report usage of this metric class through PyTorch's internal
        # ``torch._C._log_api_usage_once`` hook.
        # see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
        # torch/nn/modules/module.py#L227)
        torch._C._log_api_usage_once(
            f"torchmetrics.metric.{self.__class__.__name__}")

        # Version flag from ``_compare_version``; presumably True when the installed
        # pytorch_lightning is >= 1.3.0.
        self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version(
            "pytorch_lightning", operator.ge, "1.3.0")

        # Store the constructor options as-is.
        self.dist_sync_on_step = dist_sync_on_step
        self.compute_on_step = compute_on_step
        self.process_group = process_group
        self.dist_sync_fn = dist_sync_fn
        # NOTE(review): starts True; the code that reads/clears it is not visible here.
        self._to_sync = True

        # Capture the signature of the *unwrapped* ``update`` before it is replaced on
        # the next line, so introspection sees the subclass's own parameters (the
        # wrapper's signature may differ).
        self._update_signature = inspect.signature(self.update)
        # Shadow the bound methods with wrapped versions from _wrap_update/_wrap_compute
        # (defined elsewhere in the class).
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)
        # Presumably caches for the last compute() result and last forward output; both start empty.
        self._computed = None
        self._forward_cache = None
        # Presumably flipped to True once ``update`` has run — confirm in _wrap_update.
        self._update_called = False

        # initialize state: per-state registries, filled elsewhere (not visible here);
        # presumably keyed by state name -> default value / persistence flag / reduction fn.
        self._defaults = {}
        self._persistent = {}
        self._reductions = {}
示例#2
0
文件: metric.py 项目: hlin09/metrics
    def __init__(
        self,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ):
        """Set up the metric's options, wrapped ``update``/``compute`` and empty state registries.

        Args:
            compute_on_step: Stored as ``self.compute_on_step``. Presumably controls
                whether each ``forward`` call also returns a computed value — TODO confirm.
            dist_sync_on_step: Stored as ``self.dist_sync_on_step``. Presumably controls
                whether state is synchronized across processes on every step — TODO confirm.
            process_group: Stored as ``self.process_group``; presumably the distributed
                process group used when syncing (``None`` meaning the default group).
            dist_sync_fn: Stored as ``self.dist_sync_fn``; presumably an optional custom
                gather function used when syncing — verify against its consumers.
        """
        # NOTE(review): the base class is not visible here; presumably torch.nn.Module,
        # whose __init__ must run before attributes are assigned.
        super().__init__()
        # Version flag from ``_compare_version``; presumably True when the installed
        # pytorch_lightning is >= 1.3.0.
        self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")

        # Store the constructor options as-is.
        self.dist_sync_on_step = dist_sync_on_step
        self.compute_on_step = compute_on_step
        self.process_group = process_group
        self.dist_sync_fn = dist_sync_fn
        # NOTE(review): starts True; the code that reads/clears it is not visible here.
        self._to_sync = True

        # Capture the signature of the *unwrapped* ``update`` before it is replaced on
        # the next line (the wrapper's signature may differ).
        self._update_signature = inspect.signature(self.update)
        # Shadow the bound methods with wrapped versions from _wrap_update/_wrap_compute
        # (defined elsewhere in the class).
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)
        # Presumably caches for the last compute() result and last forward output; both start empty.
        self._computed = None
        self._forward_cache = None

        # initialize state: per-state registries, filled elsewhere (not visible here);
        # presumably keyed by state name -> default value / persistence flag / reduction fn.
        self._defaults = {}
        self._persistent = {}
        self._reductions = {}
示例#3
0
import operator
import random

import numpy
import torch

from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6, _compare_version

# ``condition``/``reason`` keyword mappings keyed on the ``_TORCH_LOWER_*`` flags;
# presumably unpacked into ``pytest.mark.skipif(**...)`` by tests — TODO confirm.
_MARK_TORCH_MIN_1_4 = {"condition": _TORCH_LOWER_1_4, "reason": "required PT >= 1.4"}
_MARK_TORCH_MIN_1_5 = {"condition": _TORCH_LOWER_1_5, "reason": "required PT >= 1.5"}
_MARK_TORCH_MIN_1_6 = {"condition": _TORCH_LOWER_1_6, "reason": "required PT >= 1.6"}

# Version flag from ``_compare_version``; presumably True when the installed
# pytorch_lightning is at least 1.3.0.
_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")


def seed_all(seed):
    """Seed Python, NumPy and PyTorch random number generators with ``seed``.

    The seeders are applied in this order: ``random.seed``, ``numpy.random.seed``,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all`` (every CUDA device).
    """
    seeders = (random.seed, numpy.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
示例#4
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import operator
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict, Optional, Union

from torchmetrics.utilities.imports import _compare_version

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.imports import _package_available

# ``rich`` counts as available only if it is installed *and* its version passes
# ``_compare_version(..., operator.ge, "10.2.2")``; the version check runs only
# when the package is installed (short-circuit ``and``).
_RICH_AVAILABLE: bool = (
    _package_available("rich")
    and _compare_version("rich", operator.ge, "10.2.2")
)

# Fallback bindings: ``Task`` and ``Style`` exist as ``None`` here and are rebound
# by the ``rich`` imports below only when ``_RICH_AVAILABLE`` is true.
Task = None
Style = None
if _RICH_AVAILABLE:
    from rich import get_console, reconfigure
    from rich.console import RenderableType
    from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
    from rich.progress_bar import ProgressBar
    from rich.style import Style
    from rich.text import Text

    class CustomBarColumn(BarColumn):
        """Overrides ``BarColumn`` to provide support for dataloaders that do not define a size (infinite size)
        such as ``IterableDataset``."""
        def render(self, task: "Task") -> ProgressBar:
            """Gets a progress bar widget for a task."""