from monai.config import IgniteInfo from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.networks.utils import eval_mode, train_mode from monai.transforms import Transform from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.module import look_up_option if TYPE_CHECKING: from ignite.engine import Engine, EventEnum from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") __all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"] class Evaluator(Workflow): """ Base class for all kinds of evaluators, inherits from Workflow. Args: device: an object representing the device on which to run. val_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader. epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
import logging from typing import TYPE_CHECKING, Optional from monai.utils import exact_version, optional_import if TYPE_CHECKING: from ignite.engine import Engine, Events import nni as NNi else: Engine, _ = optional_import("ignite.engine", "0.4.7", exact_version, "Engine") Events, _ = optional_import("ignite.engine", "0.4.7", exact_version, "Events") NNi, _ = optional_import("nni") class NNIReporterHandler: """ NNIReporter Args: """ def __init__( self, metric_name: str, max_epochs: int, logger_name: Optional[str] = None, ) -> None: self.metric_name = metric_name self.logger_name = logger_name self.max_epochs = max_epochs
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union import numpy as np import torch from monai.config import NdarrayTensor from monai.transforms import rescale_array from monai.utils import optional_import PIL, _ = optional_import("PIL") GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image") if TYPE_CHECKING: from tensorboard.compat.proto.summary_pb2 import Summary from torch.utils.tensorboard import SummaryWriter else: Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary") SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") __all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"] def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary: """Function to actually create the animated gif.
# limitations under the License. import json from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union import numpy as np import torch from monai.config import IndexSelection, KeysCollection from monai.networks.layers import GaussianFilter from monai.transforms import Resize, SpatialCrop from monai.transforms.transform import MapTransform, Randomizable, Transform from monai.transforms.utils import generate_spatial_bounding_box, is_positive from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import from monai.utils.enums import PostFix measure, _ = optional_import("skimage.measure", "0.14.2", min_version) distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") DEFAULT_POST_FIX = PostFix.meta() # Transforms to support Training for Deepgrow models class FindAllValidSlicesd(Transform): """ Find/List all valid slices in the label. Label is assumed to be a 4D Volume with shape CDHW, where C=1. Args: label: key to the label source. sids: key to store slices indices having valid label map. """
from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets import DenseNet121, Densenet169, DenseNet264, densenet201 from monai.utils import optional_import from tests.utils import skip_if_quick, test_script_save if TYPE_CHECKING: import torchvision has_torchvision = True else: torchvision, has_torchvision = optional_import("torchvision") device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [ # 4-channel 3D, batch 2 { "pretrained": False, "spatial_dims": 3, "in_channels": 2, "out_channels": 3, "norm": ("instance", { "eps": 1e-5 }) }, (2, 2, 32, 64, 48), (2, 3),
# See the License for the specific language governing permissions and # limitations under the License. import unittest import torch from ignite.engine import Events from parameterized import parameterized from monai.engines import SupervisedEvaluator from monai.handlers import StatsHandler, from_engine from monai.handlers.nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler from monai.utils import optional_import _, has_nvtx = optional_import( "torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") TENSOR_0 = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]]) TENSOR_1 = torch.tensor([[[[0.0], [-2.0]], [[-3.0], [4.0]]]]) TENSOR_1_EXPECTED = torch.tensor([[[1.0], [0.5]], [[0.25], [5.0]]]) TEST_CASE_0 = [[{"image": TENSOR_0}], TENSOR_0[0] + 1.0] TEST_CASE_1 = [[{"image": TENSOR_1}], TENSOR_1_EXPECTED] class TestHandlerDecollateBatch(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) @unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import numpy as np import torch from parameterized import parameterized from monai.data.utils import list_data_collate from monai.inferers import SlidingWindowInferer, sliding_window_inference from monai.utils import optional_import from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda _, has_tqdm = optional_import("tqdm") TEST_CASES = [ [(2, 3, 16), (4, ), 3, 0.25, "constant", torch.device("cpu:0")], # 1D small roi [(2, 3, 16, 15, 7, 9), 4, 3, 0.25, "constant", torch.device("cpu:0")], # 4D small roi [(1, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi [(2, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi [(3, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi [(2, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, "constant",
# See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, Optional, TYPE_CHECKING if TYPE_CHECKING: import ignite.engine import logging import warnings import torch from monai.utils import exact_version, is_scalar, optional_import Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events") DEFAULT_KEY_VAL_FORMAT = "{}: {:.4f} " DEFAULT_TAG = "Loss" class StatsHandler(object): """ StatsHandler defines a set of Ignite Event-handlers for all the log printing logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers. Default behaviors: - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``. - When ITERATION_COMPLETED, logs ``self.output_transform(engine.state.output)`` using ``self.logger``.
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from typing import Tuple, Union import numpy as np import torch from monai.transforms.croppad.array import SpatialCrop from monai.transforms.utils import generate_spatial_bounding_box from monai.utils import MetricReduction, optional_import binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") __all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance"] def ignore_background( y_pred: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor], ): """ This function is used to remove background (the first channel) for `y_pred` and `y`. Args: y_pred: predictions. As for classification tasks, `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
# http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union import torch from monai.handlers.utils import evenly_divisible_all_gather from monai.metrics import do_metric_reduction from monai.utils import MetricReduction, exact_version, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.4", exact_version, "reinit__is_reduced") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class IterationMetric( Metric): # type: ignore[valid-type, misc] # due to optional_import """ Class for metrics that should be computed on every iteration and compute final results when epoch completed. Similar to the `EpochMetric` in ignite:
se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154, ) from monai.utils import optional_import from tests.utils import test_pretrained_networks, test_script_save if TYPE_CHECKING: import pretrainedmodels has_cadene_pretrain = True else: pretrainedmodels, has_cadene_pretrain = optional_import("pretrainedmodels") device = "cuda" if torch.cuda.is_available() else "cpu" NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2} TEST_CASE_1 = [senet154, NET_ARGS] TEST_CASE_2 = [se_resnet50, NET_ARGS] TEST_CASE_3 = [se_resnet101, NET_ARGS] TEST_CASE_4 = [se_resnet152, NET_ARGS] TEST_CASE_5 = [se_resnext50_32x4d, NET_ARGS] TEST_CASE_6 = [se_resnext101_32x4d, NET_ARGS] TEST_CASE_PRETRAINED_1 = [ se_resnet50, { "spatial_dims": 2, "in_channels": 3,
from torch.utils.data import DataLoader from monai_ex.engines.utils import CustomKeys as Keys from monai_ex.engines.utils import default_prepare_batch_ex from monai.engines import Evaluator, SupervisedEvaluator, EnsembleEvaluator from monai.engines.utils import IterationEvents, default_metric_cmp_fn from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform from monai.utils import ForwardMode, ensure_tuple, exact_version, optional_import from monai.visualize.class_activation_maps import ModelWithHooks if TYPE_CHECKING: from ignite.engine import Engine, EventEnum from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.7", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.7", exact_version, "Metric") EventEnum, _ = optional_import("ignite.engine", "0.4.7", exact_version, "EventEnum") class SiameseEvaluator(Evaluator): """ Siamese evaluation method with image and label(optional), inherits from evaluator and Workflow. Args: device: an object representing the device on which to run. val_data_loader: Ignite engine use data_loader to run, must be torch.DataLoader. network: use the network to run model forward. epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.
from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union import torch from torch.utils.data import Dataset as _TorchDataset from monai.data.utils import pickle_hashing from monai.transforms import Compose, Randomizable, Transform, apply_transform from monai.utils import MAX_SEED, get_seed, min_version, optional_import if TYPE_CHECKING: from tqdm import tqdm has_tqdm = True else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") lmdb, _ = optional_import("lmdb") class Dataset(_TorchDataset): """ A generic dataset with a length property and an optional callable data transform when fetching a data sample. For example, typical input data can be a list of dictionaries:: [{ { { 'img': 'image1.nii.gz', 'img': 'image2.nii.gz', 'img': 'image3.nii.gz', 'seg': 'label1.nii.gz', 'seg': 'label2.nii.gz', 'seg': 'label3.nii.gz', 'extra': 123 'extra': 456 'extra': 789 }, }, }]
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence import torch from monai.config import IgniteInfo from monai.metrics import CumulativeIterationMetric from monai.utils import min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") reinit__is_reduced, _ = optional_import("ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") class IgniteMetric( Metric): # type: ignore[valid-type, misc] # due to optional_import """
Rotated, Spacingd, SpatialCropd, SpatialPadd, Zoomd, allow_missing_keys_mode, ) from monai.utils import first, get_seed, optional_import, set_determinism from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine if TYPE_CHECKING: has_nib = True else: _, has_nib = optional_import("nibabel") KEYS = ["image", "label"] TESTS: List[Tuple] = [] # For pad, start with odd/even images and add odd/even amounts for name in ("1D even", "1D odd"): for val in (3, 4): for t in ( partial(SpatialPadd, spatial_size=val, method="symmetric"), partial(SpatialPadd, spatial_size=val, method="end"), partial(BorderPadd, spatial_border=[val, val + 1]), partial(DivisiblePadd, k=val), partial(ResizeWithPadOrCropd, spatial_size=20 + val), partial(CenterSpatialCropd, roi_size=10 + val),
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import random import warnings from typing import Callable, List, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import IndexSelection from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) def rand_choice(prob: float = 0.5) -> bool: """ Returns True if a randomly chosen number is less than or equal to `prob`, by default this is a 50/50 chance. """ return bool(random.random() <= prob) def img_bounds(img: np.ndarray) -> np.ndarray: """ Returns the minimum and maximum indices of non-zero lines in axis 0 of `img`, followed by that for axis 1. """ ax0 = np.any(img, axis=0) ax1 = np.any(img, axis=1)
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import os from typing import TYPE_CHECKING, Dict, Optional import numpy as np from monai.config import DtypeLike, IgniteInfo from monai.utils import deprecated, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") @deprecated( since="0.8", msg_suffix= "use `monai.handler.ProbMapProducer` (with `monai.data.wsi_dataset.SlidingPatchWSIDataset`) instead.", ) class ProbMapProducer: """ Event handler triggered on completing every iteration to save the probability map
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, Optional, Sequence import torch from monai.metrics import HausdorffDistanceMetric from monai.utils import MetricReduction, exact_version, optional_import NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") class HausdorffDistance( Metric): # type: ignore[valid-type, misc] # due to optional_import """ Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations. """ def __init__( self, include_background: bool = False,
import os import unittest from unittest import skipUnless import numpy as np from parameterized import parameterized from monai.apps.pathology.metrics import LesionFROC from monai.utils import optional_import _, has_cucim = optional_import("cucim") _, has_skimage = optional_import("skimage.measure") _, has_sp = optional_import("scipy.ndimage") PILImage, has_pil = optional_import("PIL.Image") def save_as_tif(filename, array): array = array[::-1, ...] # Upside-down img = PILImage.fromarray(array) if not filename.endswith(".tif"): filename += ".tif" img.save(os.path.join("tests", "testing_data", filename)) def around(val, interval=3): return slice(val - interval, val + interval) # mask and prediction image size HEIGHT = 101 WIDTH = 800
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, Tuple import torch import torch.nn from torch.nn.parallel import DataParallel, DistributedDataParallel from torch.optim.optimizer import Optimizer from monai.engines.utils import get_devices_spec from monai.utils import exact_version, optional_import create_supervised_trainer, _ = optional_import("ignite.engine", "0.4.2", exact_version, "create_supervised_trainer") create_supervised_evaluator, _ = optional_import( "ignite.engine", "0.4.2", exact_version, "create_supervised_evaluator") _prepare_batch, _ = optional_import("ignite.engine", "0.4.2", exact_version, "_prepare_batch") if TYPE_CHECKING: from ignite.engine import Engine from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
from torch.optim import Optimizer from torch.utils.data import DataLoader from monai.networks.utils import eval_mode from monai.optimizers.lr_scheduler import ExponentialLR, LinearLR from monai.utils import StateCacher, copy_to_device, optional_import if TYPE_CHECKING: import matplotlib.pyplot as plt has_matplotlib = True import tqdm has_tqdm = True else: plt, has_matplotlib = optional_import("matplotlib.pyplot") tqdm, has_tqdm = optional_import("tqdm") __all__ = ["LearningRateFinder"] class DataLoaderIter: def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None: if not isinstance(data_loader, DataLoader): raise ValueError( f"Loader has unsupported type: {type(data_loader)}. Expected type was `torch.utils.data.DataLoader`" ) self.data_loader = data_loader self._iterator = iter(data_loader) self.image_extractor = image_extractor self.label_extractor = label_extractor
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import TYPE_CHECKING, Callable, Dict, Iterable, Optional, Sequence, Union import torch import torch.distributed as dist from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from monai.engines.utils import default_prepare_batch from monai.transforms import apply_transform from monai.utils import ensure_tuple, exact_version, optional_import IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine from ignite.metrics import Metric else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") class Workflow(IgniteEngine ): # type: ignore[valid-type, misc] # due to optional_import """
# See the License for the specific language governing permissions and # limitations under the License. import hashlib import logging import os import shutil import tarfile import zipfile from typing import Optional from urllib.error import ContentTooShortError, HTTPError, URLError from urllib.request import Request, urlopen, urlretrieve from monai.utils import optional_import, progress_bar gdown, has_gdown = optional_import("gdown", "3.6") def check_md5(filepath: str, md5_value: Optional[str] = None) -> bool: """ check MD5 signature of specified file. Args: filepath: path of source file to verify MD5. md5_value: expected MD5 value of the file. """ if md5_value is not None: md5 = hashlib.md5() try: with open(filepath, "rb") as f:
import torch.nn.functional as F from torch import nn from torch.autograd import Function from monai.networks.layers.convutils import gaussian_1d, same_padding from monai.networks.layers.factories import Conv from monai.utils import ( PT_BEFORE_1_7, ChannelMatching, InvalidPyTorchVersionError, SkipMode, ensure_tuple_rep, optional_import, ) _C, _ = optional_import("monai._C") if not PT_BEFORE_1_7: fft, _ = optional_import("torch.fft") __all__ = [ "SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering", "SavitzkyGolayFilter", "HilbertTransform", "ChannelPad", ]
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Sequence, Union, cast import torch import torch.nn.functional as F from torch import nn from torch.autograd import Function from monai.networks.layers.convutils import gaussian_1d, same_padding from monai.utils import ensure_tuple_rep, optional_import _C, _ = optional_import("monai._C") __all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape"] class SkipConnection(nn.Module): """ Concats the forward pass input with the result from the given submodule. """ def __init__(self, submodule, cat_dim: int = 1) -> None: super().__init__() self.submodule = submodule self.cat_dim = cat_dim def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat([x, self.submodule(x)], self.cat_dim)
# you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union from monai.handlers.utils import string_list_all_gather, write_metrics_reports from monai.utils import ImageMetaKey as Key from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class MetricsSaver: """ ignite handler to save metrics values and details into expected files. Args: save_dir: directory to save the metrics and metric details. metrics: expected final metrics to save into files, can be: None, "*" or list of strings.
# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from typing import TYPE_CHECKING, Dict, Optional from monai.utils import exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") Checkpoint, _ = optional_import("ignite.handlers", "0.4.2", exact_version, "Checkpoint") BaseSaveHandler, _ = optional_import("ignite.handlers.checkpoint", "0.4.2", exact_version, "BaseSaveHandler") if TYPE_CHECKING: from ignite.engine import Engine from ignite.handlers import DiskSaver else: Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") DiskSaver, _ = optional_import("ignite.handlers", "0.4.2", exact_version, "DiskSaver")
# You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from typing import TYPE_CHECKING, Dict, Optional import torch from monai.utils import exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") Checkpoint, _ = optional_import("ignite.handlers", "0.4.2", exact_version, "Checkpoint") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") class CheckpointLoader: """ CheckpointLoader acts as an Ignite handler to load checkpoint data from file. It can load variables for network, optimizer, lr_scheduler, etc. If saving checkpoint after `torch.nn.DataParallel`, need to save `model.module` instead as PyTorch recommended and then use this loader to load the model.
# http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import warnings from typing import TYPE_CHECKING, Any, Callable, Optional import torch from monai.utils import exact_version, is_scalar, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") DEFAULT_KEY_VAL_FORMAT = "{}: {:.4f} " DEFAULT_TAG = "Loss" class StatsHandler: """ StatsHandler defines a set of Ignite Event-handlers for all the log printing logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers.
# http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional, Union import torch from monai.handlers.utils import evenly_divisible_all_gather from monai.metrics import compute_roc_auc from monai.utils import Average, exact_version, optional_import idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") EpochMetric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "EpochMetric") class ROCAUC(EpochMetric ): # type: ignore[valid-type, misc] # due to optional_import """ Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`. Args: to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. other_act: callable function to replace `softmax` as activation layer if needed, Defaults to ``None``. for example: `other_act = lambda x: torch.log_softmax(x)`.