Пример #1
0
from monai.config import IgniteInfo
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.networks.utils import eval_mode, train_mode
from monai.transforms import Transform
from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.module import look_up_option

if TYPE_CHECKING:
    from ignite.engine import Engine, EventEnum
    from ignite.metrics import Metric
else:
    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
    Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
    EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")

__all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"]


class Evaluator(Workflow):
    """
    Base class for all kinds of evaluators, inherits from Workflow.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
Пример #2
0
import logging
from typing import TYPE_CHECKING, Optional

from monai.utils import exact_version, optional_import

if TYPE_CHECKING:
    from ignite.engine import Engine, Events
    import nni as NNi
else:
    Engine, _ = optional_import("ignite.engine", "0.4.7", exact_version, "Engine")
    Events, _ = optional_import("ignite.engine", "0.4.7", exact_version, "Events")
    NNi, _ = optional_import("nni")


class NNIReporterHandler:
    """
    NNIReporter

    Args:

    """

    def __init__(
        self,
        metric_name: str,
        max_epochs: int,
        logger_name: Optional[str] = None,
    ) -> None:
        self.metric_name = metric_name
        self.logger_name = logger_name
        self.max_epochs = max_epochs
Пример #3
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union

import numpy as np
import torch

from monai.config import NdarrayTensor
from monai.transforms import rescale_array
from monai.utils import optional_import

PIL, _ = optional_import("PIL")
GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image")

if TYPE_CHECKING:
    from tensorboard.compat.proto.summary_pb2 import Summary
    from torch.utils.tensorboard import SummaryWriter
else:
    Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary")
    SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter")


__all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"]


def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary:
    """Function to actually create the animated gif.
Пример #4
0
# limitations under the License.
import json
from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union

import numpy as np
import torch

from monai.config import IndexSelection, KeysCollection
from monai.networks.layers import GaussianFilter
from monai.transforms import Resize, SpatialCrop
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.transforms.utils import generate_spatial_bounding_box, is_positive
from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import
from monai.utils.enums import PostFix

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")

DEFAULT_POST_FIX = PostFix.meta()


# Transforms to support Training for Deepgrow models
class FindAllValidSlicesd(Transform):
    """
    Find/List all valid slices in the label.
    Label is assumed to be a 4D Volume with shape CDHW, where C=1.

    Args:
        label: key to the label source.
        sids: key to store slices indices having valid label map.
    """
Пример #5
0
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import DenseNet121, Densenet169, DenseNet264, densenet201
from monai.utils import optional_import
from tests.utils import skip_if_quick, test_script_save

if TYPE_CHECKING:
    import torchvision

    has_torchvision = True
else:
    torchvision, has_torchvision = optional_import("torchvision")

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_1 = [  # 4-channel 3D, batch 2
    {
        "pretrained": False,
        "spatial_dims": 3,
        "in_channels": 2,
        "out_channels": 3,
        "norm": ("instance", {
            "eps": 1e-5
        })
    },
    (2, 2, 32, 64, 48),
    (2, 3),
Пример #6
0
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from ignite.engine import Events
from parameterized import parameterized

from monai.engines import SupervisedEvaluator
from monai.handlers import StatsHandler, from_engine
from monai.handlers.nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
from monai.utils import optional_import

_, has_nvtx = optional_import(
    "torch._C._nvtx",
    descriptor="NVTX is not installed. Are you sure you have a CUDA build?")

# Fixture tensors for the parameterized test cases below.
# TENSOR_0 / TENSOR_1 literals have shape (1, 2, 2, 1).
TENSOR_0 = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])

TENSOR_1 = torch.tensor([[[[0.0], [-2.0]], [[-3.0], [4.0]]]])

# Expected result for TENSOR_1; shape (2, 2, 1), i.e. the shape of TENSOR_1[0].
TENSOR_1_EXPECTED = torch.tensor([[[1.0], [0.5]], [[0.25], [5.0]]])

# Each case is [input data (a list of dicts keyed by "image"), expected tensor].
TEST_CASE_0 = [[{"image": TENSOR_0}], TENSOR_0[0] + 1.0]
TEST_CASE_1 = [[{"image": TENSOR_1}], TENSOR_1_EXPECTED]


class TestHandlerDecollateBatch(unittest.TestCase):
    @parameterized.expand([TEST_CASE_0, TEST_CASE_1])
    @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
Пример #7
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.data.utils import list_data_collate
from monai.inferers import SlidingWindowInferer, sliding_window_inference
from monai.utils import optional_import
from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda

_, has_tqdm = optional_import("tqdm")

TEST_CASES = [
    [(2, 3, 16), (4, ), 3, 0.25, "constant",
     torch.device("cpu:0")],  # 1D small roi
    [(2, 3, 16, 15, 7, 9), 4, 3, 0.25, "constant",
     torch.device("cpu:0")],  # 4D small roi
    [(1, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, "constant",
     torch.device("cpu:0")],  # 3D small roi
    [(2, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, "constant",
     torch.device("cpu:0")],  # 3D small roi
    [(3, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, "constant",
     torch.device("cpu:0")],  # 3D small roi
    [(2, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, "constant",
     torch.device("cpu:0")],  # 3D small roi
    [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, "constant",
Пример #8
0
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    import ignite.engine

import logging
import warnings

import torch

from monai.utils import exact_version, is_scalar, optional_import

Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")

# Format for a single "key: value " log fragment; the value is printed with 4 decimal places.
DEFAULT_KEY_VAL_FORMAT = "{}: {:.4f} "
# Default label for the logged value (presumably the iteration loss -- confirm against the handler's usage).
DEFAULT_TAG = "Loss"


class StatsHandler(object):
    """
    StatsHandler defines a set of Ignite Event-handlers for all the log printing logics.
    It's can be used for any Ignite Engine(trainer, validator and evaluator).
    And it can support logging for epoch level and iteration level with pre-defined loggers.

    Default behaviors:
        - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``.
        - When ITERATION_COMPLETED, logs
          ``self.output_transform(engine.state.output)`` using ``self.logger``.
Пример #9
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Tuple, Union

import numpy as np
import torch

from monai.transforms.croppad.array import SpatialCrop
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils import MetricReduction, optional_import

binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion")
distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt")
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")

__all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance"]


def ignore_background(
    y_pred: Union[np.ndarray, torch.Tensor],
    y: Union[np.ndarray, torch.Tensor],
):
    """
    This function is used to remove background (the first channel) for `y_pred` and `y`.
    Args:
        y_pred: predictions. As for classification tasks,
            `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
Пример #10
0
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union

import torch

from monai.handlers.utils import evenly_divisible_all_gather
from monai.metrics import do_metric_reduction
from monai.utils import MetricReduction, exact_version, optional_import

idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.4",
                                        exact_version, "reinit__is_reduced")
if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version,
                                "Engine")


class IterationMetric(
        Metric):  # type: ignore[valid-type, misc] # due to optional_import
    """
    Class for metrics that should be computed on every iteration and compute final results when epoch completed.
    Similar to the `EpochMetric` in ignite:
Пример #11
0
    se_resnet50,
    se_resnet101,
    se_resnet152,
    se_resnext50_32x4d,
    se_resnext101_32x4d,
    senet154,
)
from monai.utils import optional_import
from tests.utils import test_pretrained_networks, test_script_save

if TYPE_CHECKING:
    import pretrainedmodels

    has_cadene_pretrain = True
else:
    pretrainedmodels, has_cadene_pretrain = optional_import("pretrainedmodels")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Constructor keyword arguments shared by every network factory in the cases below.
NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2}
# Each case pairs a network factory with the kwargs used to build it.
TEST_CASE_1 = [senet154, NET_ARGS]
TEST_CASE_2 = [se_resnet50, NET_ARGS]
TEST_CASE_3 = [se_resnet101, NET_ARGS]
TEST_CASE_4 = [se_resnet152, NET_ARGS]
TEST_CASE_5 = [se_resnext50_32x4d, NET_ARGS]
TEST_CASE_6 = [se_resnext101_32x4d, NET_ARGS]

TEST_CASE_PRETRAINED_1 = [
    se_resnet50, {
        "spatial_dims": 2,
        "in_channels": 3,
Пример #12
0
from torch.utils.data import DataLoader

from monai_ex.engines.utils import CustomKeys as Keys
from monai_ex.engines.utils import default_prepare_batch_ex
from monai.engines import Evaluator, SupervisedEvaluator, EnsembleEvaluator
from monai.engines.utils import IterationEvents, default_metric_cmp_fn
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
from monai.utils import ForwardMode, ensure_tuple, exact_version, optional_import
from monai.visualize.class_activation_maps import ModelWithHooks

if TYPE_CHECKING:
    from ignite.engine import Engine, EventEnum
    from ignite.metrics import Metric
else:
    Engine, _ = optional_import("ignite.engine", "0.4.7", exact_version, "Engine")
    Metric, _ = optional_import("ignite.metrics", "0.4.7", exact_version, "Metric")
    EventEnum, _ = optional_import("ignite.engine", "0.4.7", exact_version, "EventEnum")


class SiameseEvaluator(Evaluator):
    """
    Siamese evaluation method with image and label(optional), inherits from evaluator and Workflow.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be torch.DataLoader.
        network: use the network to run model forward.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
Пример #13
0
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union

import torch
from torch.utils.data import Dataset as _TorchDataset

from monai.data.utils import pickle_hashing
from monai.transforms import Compose, Randomizable, Transform, apply_transform
from monai.utils import MAX_SEED, get_seed, min_version, optional_import

if TYPE_CHECKING:
    from tqdm import tqdm

    has_tqdm = True
else:
    tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

lmdb, _ = optional_import("lmdb")


class Dataset(_TorchDataset):
    """
    A generic dataset with a length property and an optional callable data transform
    when fetching a data sample.
    For example, typical input data can be a list of dictionaries::

        [{                            {                            {
             'img': 'image1.nii.gz',      'img': 'image2.nii.gz',      'img': 'image3.nii.gz',
             'seg': 'label1.nii.gz',      'seg': 'label2.nii.gz',      'seg': 'label3.nii.gz',
             'extra': 123                 'extra': 456                 'extra': 789
         },                           },                           }]
Пример #14
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence

import torch

from monai.config import IgniteInfo
from monai.metrics import CumulativeIterationMetric
from monai.utils import min_version, optional_import

idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION,
                           min_version, "distributed")
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION,
                            min_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric",
                                        IgniteInfo.OPT_IMPORT_VERSION,
                                        min_version, "reinit__is_reduced")
if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION,
                                min_version, "Engine")


class IgniteMetric(
        Metric):  # type: ignore[valid-type, misc] # due to optional_import
    """
Пример #15
0
    Rotated,
    Spacingd,
    SpatialCropd,
    SpatialPadd,
    Zoomd,
    allow_missing_keys_mode,
)
from monai.utils import first, get_seed, optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image, make_rand_affine

if TYPE_CHECKING:

    has_nib = True
else:
    _, has_nib = optional_import("nibabel")

KEYS = ["image", "label"]

TESTS: List[Tuple] = []

# For pad, start with odd/even images and add odd/even amounts
for name in ("1D even", "1D odd"):
    for val in (3, 4):
        for t in (
                partial(SpatialPadd, spatial_size=val, method="symmetric"),
                partial(SpatialPadd, spatial_size=val, method="end"),
                partial(BorderPadd, spatial_border=[val, val + 1]),
                partial(DivisiblePadd, k=val),
                partial(ResizeWithPadOrCropd, spatial_size=20 + val),
                partial(CenterSpatialCropd, roi_size=10 + val),
Пример #16
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import warnings
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch

from monai.config import IndexSelection
from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)


def rand_choice(prob: float = 0.5) -> bool:
    """
    Draw one uniform sample from ``random.random()`` (range [0, 1)) and report whether
    it is at most ``prob``. With the default ``prob`` of 0.5 this is a fair coin flip.
    """
    threshold = prob
    draw = random.random()
    hit = draw <= threshold
    # Coerce explicitly so numpy scalars (or other numeric types) still yield a plain bool.
    return bool(hit)


def img_bounds(img: np.ndarray) -> np.ndarray:
    """
    Returns the minimum and maximum indices of non-zero lines in axis 0 of `img`, followed by that for axis 1.
    """
    ax0 = np.any(img, axis=0)
    ax1 = np.any(img, axis=1)
Пример #17
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np

from monai.config import DtypeLike, IgniteInfo
from monai.utils import deprecated, min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION,
                            min_version, "Events")
if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION,
                                min_version, "Engine")


@deprecated(
    since="0.8",
    msg_suffix=
    "use `monai.handler.ProbMapProducer` (with `monai.data.wsi_dataset.SlidingPatchWSIDataset`) instead.",
)
class ProbMapProducer:
    """
    Event handler triggered on completing every iteration to save the probability map
Пример #18
0
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence

import torch

from monai.metrics import HausdorffDistanceMetric
from monai.utils import MetricReduction, exact_version, optional_import

NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2",
                                        exact_version, "NotComputableError")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2",
                                        exact_version, "reinit__is_reduced")
sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2",
                                     exact_version, "sync_all_reduce")


class HausdorffDistance(
        Metric):  # type: ignore[valid-type, misc] # due to optional_import
    """
    Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.
    """
    def __init__(
        self,
        include_background: bool = False,
Пример #19
0
import os
import unittest
from unittest import skipUnless

import numpy as np
from parameterized import parameterized

from monai.apps.pathology.metrics import LesionFROC
from monai.utils import optional_import

_, has_cucim = optional_import("cucim")
_, has_skimage = optional_import("skimage.measure")
_, has_sp = optional_import("scipy.ndimage")
PILImage, has_pil = optional_import("PIL.Image")


def save_as_tif(filename, array):
    """Save ``array`` as a TIFF file under ``tests/testing_data``.

    The first axis of ``array`` is reversed before saving, so the stored image is
    upside-down relative to the input. A ``.tif`` suffix is appended to ``filename``
    unless it already ends with one. The output path is relative to the current
    working directory.
    """
    flipped = array[::-1, ...]  # vertical flip: reverse the first axis
    image = PILImage.fromarray(flipped)
    name = filename if filename.endswith(".tif") else filename + ".tif"
    target = os.path.join("tests", "testing_data", name)
    image.save(target)


def around(val, interval=3):
    """Return a slice covering ``interval`` positions on each side of ``val`` (stop is exclusive)."""
    lower = val - interval
    upper = val + interval
    return slice(lower, upper)


# mask and prediction image size
HEIGHT = 101
WIDTH = 800
Пример #20
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, Tuple

import torch
import torch.nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.optim.optimizer import Optimizer

from monai.engines.utils import get_devices_spec
from monai.utils import exact_version, optional_import

create_supervised_trainer, _ = optional_import("ignite.engine", "0.4.2",
                                               exact_version,
                                               "create_supervised_trainer")
create_supervised_evaluator, _ = optional_import(
    "ignite.engine", "0.4.2", exact_version, "create_supervised_evaluator")
_prepare_batch, _ = optional_import("ignite.engine", "0.4.2", exact_version,
                                    "_prepare_batch")
if TYPE_CHECKING:
    from ignite.engine import Engine
    from ignite.metrics import Metric
else:
    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version,
                                "Engine")
    Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version,
                                "Metric")

Пример #21
0
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from monai.networks.utils import eval_mode
from monai.optimizers.lr_scheduler import ExponentialLR, LinearLR
from monai.utils import StateCacher, copy_to_device, optional_import

if TYPE_CHECKING:
    import matplotlib.pyplot as plt

    has_matplotlib = True
    import tqdm

    has_tqdm = True
else:
    plt, has_matplotlib = optional_import("matplotlib.pyplot")
    tqdm, has_tqdm = optional_import("tqdm")

__all__ = ["LearningRateFinder"]


class DataLoaderIter:
    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
        if not isinstance(data_loader, DataLoader):
            raise ValueError(
                f"Loader has unsupported type: {type(data_loader)}. Expected type was `torch.utils.data.DataLoader`"
            )
        self.data_loader = data_loader
        self._iterator = iter(data_loader)
        self.image_extractor = image_extractor
        self.label_extractor = label_extractor
Пример #22
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Union

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from monai.engines.utils import default_prepare_batch
from monai.transforms import apply_transform
from monai.utils import ensure_tuple, exact_version, optional_import

IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version,
                                  "Engine")
State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State")
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
if TYPE_CHECKING:
    from ignite.engine import Engine
    from ignite.metrics import Metric
else:
    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version,
                                "Engine")
    Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version,
                                "Metric")


class Workflow(IgniteEngine
               ):  # type: ignore[valid-type, misc] # due to optional_import
    """
Пример #23
0
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import logging
import os
import shutil
import tarfile
import zipfile
from typing import Optional
from urllib.error import ContentTooShortError, HTTPError, URLError
from urllib.request import Request, urlopen, urlretrieve

from monai.utils import optional_import, progress_bar

gdown, has_gdown = optional_import("gdown", "3.6")


def check_md5(filepath: str, md5_value: Optional[str] = None) -> bool:
    """
    check MD5 signature of specified file.

    Args:
        filepath: path of source file to verify MD5.
        md5_value: expected MD5 value of the file.

    """
    if md5_value is not None:
        md5 = hashlib.md5()
        try:
            with open(filepath, "rb") as f:
Пример #24
0
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function

from monai.networks.layers.convutils import gaussian_1d, same_padding
from monai.networks.layers.factories import Conv
from monai.utils import (
    PT_BEFORE_1_7,
    ChannelMatching,
    InvalidPyTorchVersionError,
    SkipMode,
    ensure_tuple_rep,
    optional_import,
)

_C, _ = optional_import("monai._C")
if not PT_BEFORE_1_7:
    fft, _ = optional_import("torch.fft")

__all__ = [
    "SkipConnection",
    "Flatten",
    "GaussianFilter",
    "LLTM",
    "Reshape",
    "separable_filtering",
    "SavitzkyGolayFilter",
    "HilbertTransform",
    "ChannelPad",
]
Пример #25
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Sequence, Union, cast

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function

from monai.networks.layers.convutils import gaussian_1d, same_padding
from monai.utils import ensure_tuple_rep, optional_import

_C, _ = optional_import("monai._C")

__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape"]


class SkipConnection(nn.Module):
    """
    Wrap ``submodule`` so that the forward output is the input concatenated with
    ``submodule(input)`` along dimension ``cat_dim`` (input first, submodule output second).

    Args:
        submodule: module applied to the input; its output must be concatenable with the
            input along ``cat_dim``.
        cat_dim: dimension along which the two tensors are concatenated. Defaults to 1.
    """
    def __init__(self, submodule, cat_dim: int = 1) -> None:
        super().__init__()
        self.submodule = submodule
        self.cat_dim = cat_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.submodule(x)
        return torch.cat((x, branch), dim=self.cat_dim)
Пример #26
0
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union

from monai.handlers.utils import string_list_all_gather, write_metrics_reports
from monai.utils import ImageMetaKey as Key
from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version,
                                "Engine")


class MetricsSaver:
    """
    ignite handler to save metrics values and details into expected files.

    Args:
        save_dir: directory to save the metrics and metric details.
        metrics: expected final metrics to save into files, can be: None, "*" or list of strings.
Пример #27
0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Optional

from monai.utils import exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
Checkpoint, _ = optional_import("ignite.handlers", "0.4.2", exact_version,
                                "Checkpoint")
BaseSaveHandler, _ = optional_import("ignite.handlers.checkpoint", "0.4.2",
                                     exact_version, "BaseSaveHandler")

if TYPE_CHECKING:
    from ignite.engine import Engine
    from ignite.handlers import DiskSaver
else:
    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version,
                                "Engine")
    DiskSaver, _ = optional_import("ignite.handlers", "0.4.2", exact_version,
                                   "DiskSaver")

Пример #28
0
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Optional

import torch

from monai.utils import exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
Checkpoint, _ = optional_import("ignite.handlers", "0.4.2", exact_version,
                                "Checkpoint")
if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version,
                                "Engine")


class CheckpointLoader:
    """
    CheckpointLoader acts as an Ignite handler to load checkpoint data from file.
    It can load variables for network, optimizer, lr_scheduler, etc.
    If saving checkpoint after `torch.nn.DataParallel`, need to save `model.module` instead
    as PyTorch recommended and then use this loader to load the model.
Пример #29
0
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch

from monai.utils import exact_version, is_scalar, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version,
                                "Engine")

# Format for a single "key: value " log fragment; the value is printed with 4 decimal places.
DEFAULT_KEY_VAL_FORMAT = "{}: {:.4f} "
# Default label for the logged value (presumably the iteration loss -- confirm against StatsHandler's usage).
DEFAULT_TAG = "Loss"


class StatsHandler:
    """
    StatsHandler defines a set of Ignite Event-handlers for all the log printing logics.
    It's can be used for any Ignite Engine(trainer, validator and evaluator).
    And it can support logging for epoch level and iteration level with pre-defined loggers.
Пример #30
0
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Union

import torch

from monai.handlers.utils import evenly_divisible_all_gather
from monai.metrics import compute_roc_auc
from monai.utils import Average, exact_version, optional_import

idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
EpochMetric, _ = optional_import("ignite.metrics", "0.4.2", exact_version,
                                 "EpochMetric")


class ROCAUC(EpochMetric
             ):  # type: ignore[valid-type, misc]  # due to optional_import
    """
    Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC).
    accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`.

    Args:
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        softmax: whether to add softmax function to `y_pred` before computation. Defaults to False.
        other_act: callable function to replace `softmax` as activation layer if needed, Defaults to ``None``.
            for example: `other_act = lambda x: torch.log_softmax(x)`.