Exemplo n.º 1
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch

from ponai.inferers import SimpleInferer
from ponai.utils import exact_version, optional_import
from ponai.engines.utils import CommonKeys as Keys
from ponai.engines.utils import default_prepare_batch
from ponai.engines.workflow import Workflow

Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric")


class Trainer(Workflow):
    """
    Base class for all kinds of trainers, inherits from Workflow.

    """
    def run(self) -> None:
        """
        Execute training based on Ignite Engine.
        If this function is called multiple times, it will continue running from the previous state.

        """
        if self._is_done(self.state):
Exemplo n.º 2
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import warnings
from typing import Optional, Callable

import torch
import numpy as np

from ponai.config import IndexSelection
from ponai.utils import ensure_tuple, ensure_tuple_size, fall_back_tuple, optional_import, min_version

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)


def rand_choice(prob: float = 0.5) -> bool:
    """
    Make a random yes/no decision.

    A single value is drawn with `random.random()` and compared against `prob`;
    the result is True when the drawn value is less than or equal to `prob`.
    With the default `prob` of 0.5 this is a 50/50 chance.
    """
    draw = random.random()
    return True if draw <= prob else False


def img_bounds(img):
    """
    Returns the minimum and maximum indices of non-zero lines in axis 0 of `img`, followed by that for axis 1.
    """
    ax0 = np.any(img, axis=0)
    ax1 = np.any(img, axis=1)
Exemplo n.º 3
0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Callable

from ponai.utils import exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")


class MetricLogger:
    def __init__(self,
                 loss_transform: Callable = lambda x: x,
                 metric_transform: Callable = lambda x: x):
        self.loss_transform = loss_transform
        self.metric_transform = metric_transform
        self.loss: list = []
        self.metrics: defaultdict = defaultdict(list)

    def attach(self, engine: Engine):
        """
        Register this object on `engine` as a handler for `Events.ITERATION_COMPLETED`.

        Args:
            engine: object exposing `add_event_handler`, to which this logger is added.

        Returns:
            Whatever `engine.add_event_handler` returns.

        """
        register = engine.add_event_handler
        return register(Events.ITERATION_COMPLETED, self)
Exemplo n.º 4
0
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable
import torch

from ponai.utils import exact_version, optional_import
from ponai.engines.utils import get_devices_spec

create_supervised_trainer, _ = optional_import("ignite.engine", "0.3.0",
                                               exact_version,
                                               "create_supervised_trainer")
create_supervised_evaluator, _ = optional_import(
    "ignite.engine", "0.3.0", exact_version, "create_supervised_evaluator")
_prepare_batch, _ = optional_import("ignite.engine", "0.3.0", exact_version,
                                    "_prepare_batch")


def _default_transform(_x, _y, _y_pred, loss):
    """
    Reduce a training step's outputs to the scalar loss value.

    Only `loss` is used; `_x`, `_y` and `_y_pred` are accepted and ignored.
    Returns the result of `loss.item()`.
    """
    scalar_loss = loss.item()
    return scalar_loss


def _default_eval_transform(x, y, y_pred):
    """
    Pair up the prediction with its target for evaluation.

    `x` is accepted but not used. Returns the tuple `(y_pred, y)`,
    i.e. the prediction first and the target second.
    """
    pred_and_target = (y_pred, y)
    return pred_and_target

#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Union

import numpy as np
import torch

from ponai.transforms import rescale_array
from ponai.utils import optional_import

PIL, _ = optional_import("PIL")
GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image")
summary_pb2, _ = optional_import("tensorboard.compat.proto.summary_pb2")
SummaryWriter, _ = optional_import("torch.utils.tensorboard",
                                   name="SummaryWriter")


def _image3_animated_gif(tag: str,
                         image: Union[np.ndarray, torch.Tensor],
                         scale_factor: float = 1.0):
    """Function to actually create the animated gif.

    Args:
        tag: Data identifier
        image: 3D image tensors expected to be in `HWD` format
        scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will
Exemplo n.º 6
0
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence, Union, List

import torch

from ponai.metrics import compute_roc_auc
from ponai.utils import exact_version, optional_import, Average

Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric")


class ROCAUC(Metric):
    """
    Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC).
    accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`.

    Args:
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        softmax: whether to add softmax function to `y_pred` before computation. Defaults to False.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification. Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
Exemplo n.º 7
0
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Optional

import torch

from ponai.utils import exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Events")
Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
Checkpoint, _ = optional_import("ignite.handlers", "0.3.0", exact_version, "Checkpoint")


class CheckpointLoader:
    """
    CheckpointLoader acts as an Ignite handler to load checkpoint data from file.
    It can load variables for network, optimizer, lr_scheduler, etc.
    If saving checkpoint after `torch.nn.DataParallel`, need to save `model.module` instead
    as PyTorch recommended and then use this loader to load the model.

    Args:
        load_path: the file path of checkpoint, it should be a PyTorch `pth` file.
        load_dict (dict): target objects that load checkpoint to. examples::
Exemplo n.º 8
0
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence, Union

import torch

from ponai.metrics import DiceMetric
from ponai.utils import exact_version, optional_import, MetricReduction

NotComputableError, _ = optional_import("ignite.exceptions", "0.3.0",
                                        exact_version, "NotComputableError")
Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.3.0",
                                        exact_version, "reinit__is_reduced")
sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.3.0",
                                     exact_version, "sync_all_reduce")


class MeanDice(Metric):
    """
    Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations.
    """
    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
Exemplo n.º 9
0
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import numpy as np

from ponai.transforms import Resize
from ponai.utils import ensure_tuple_rep, min_version, optional_import, InterpolateMode

Image, _ = optional_import("PIL", name="Image")


def write_png(
    data,
    file_name: str,
    output_spatial_shape=None,
    mode: Union[InterpolateMode, str] = InterpolateMode.BICUBIC,
    scale=None,
):
    """
    Write numpy data into png files to disk.
    Spatially it supports HW for 2D.(H,W) or (H,W,3) or (H,W,4).
    If `scale` is None, expect the input data in `np.uint8` or `np.uint16` type.
    It's based on the Image module in PIL library:
    https://pillow.readthedocs.io/en/stable/reference/Image.html
Exemplo n.º 10
0
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import os
import warnings
import math
from itertools import starmap, product
import torch
from torch.utils.data._utils.collate import default_collate
import numpy as np
from ponai.utils import ensure_tuple_size, ensure_tuple_rep, optional_import, NumpyPadMode, BlendMode
from ponai.networks.layers.simplelayers import GaussianFilter

nib, _ = optional_import("nibabel")


def get_random_patch(dims,
                     patch_size,
                     rand_state: Optional[np.random.RandomState] = None):
    """
    Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size` or as
    close to it as possible within the given dimension. It is expected that `patch_size` is a valid patch for a source
    of shape `dims` as returned by `get_valid_patch_size`.

    Args:
        dims (tuple of int): shape of source array
        patch_size (tuple of int): shape of patch size to generate
        rand_state (np.random.RandomState): a random state object to generate random numbers from
Exemplo n.º 11
0
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from urllib.request import urlretrieve
from urllib.error import URLError
import hashlib
import tarfile
import zipfile
from ponai.utils import progress_bar, optional_import

gdown, has_gdown = optional_import("gdown", "3.6")


def check_md5(filepath: str, md5_value: str = None):
    """
    check MD5 signature of specified file.

    Args:
        filepath: path of source file to verify MD5.
        md5_value: expected MD5 value of the file.

    """
    if md5_value is not None:
        md5 = hashlib.md5()
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
Exemplo n.º 12
0
"""
A collection of "vanilla" transforms for IO functions
https://github.com/Project-ponai/ponai/wiki/ponai_Design
"""

from typing import Optional

import numpy as np

from torch.utils.data._utils.collate import np_str_obj_array_pattern

from ponai.data.utils import correct_nifti_header_if_necessary
from ponai.transforms.compose import Transform
from ponai.utils import optional_import, ensure_tuple

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")


class LoadNifti(Transform):
    """
    Load Nifti format file or files from provided path. If loading a list of
    files, stack them together and add a new dimension as first dimension, and
    use the meta data of the first image to represent the stacked result. Note
    that the affine transform of all the images should be same if ``image_only=False``.
    """
    def __init__(self,
                 as_closest_canonical: bool = False,
                 image_only: bool = False,
                 dtype: Optional[np.dtype] = np.float32):
        """