Exemplo n.º 1
0
 def __init__(self, pytorch_version_tuple):
     """Store the minimum required PyTorch version and flag whether the
     installed one is older.

     Args:
         pytorch_version_tuple: required minimum version as a tuple of ints,
             e.g. ``(1, 7)``.
     """
     self.min_version = pytorch_version_tuple
     if has_pkg_res:
         # `ver` (presumably pkg_resources' version parser — confirm) compares
         # the full version strings rather than a truncated tuple.
         required = ".".join(map(str, self.min_version))
         self.version_too_old = ver(torch.__version__) < ver(required)
     else:
         self.version_too_old = get_torch_version_tuple() < self.min_version
Exemplo n.º 2
0
 def __init__(self, pytorch_version_tuple):
     """Store the exclusive maximum PyTorch version and flag whether the
     installed one is at or beyond it.

     Args:
         pytorch_version_tuple: first disallowed version as a tuple of ints,
             e.g. ``(1, 7)``.
     """
     self.max_version = pytorch_version_tuple
     if has_pkg_res:
         # `ver` (presumably pkg_resources' version parser — confirm) compares
         # the full version strings rather than a truncated tuple.
         limit = ".".join(map(str, self.max_version))
         self.version_too_new = ver(torch.__version__) >= ver(limit)
     else:
         self.version_too_new = get_torch_version_tuple() >= self.max_version
Exemplo n.º 3
0
def load_module(module_name: str,
                defines: Optional[dict] = None,
                verbose_build: bool = False,
                build_timeout: int = 300):
    """
    Load (JIT-compiling if necessary) a C++/CUDA extension module.

    Args:
        module_name: Name of the module to load.
            Must match the name of the relevant source directory in the `_extensions` directory.
        defines: Dictionary containing names and values of compilation defines.
        verbose_build: Set to true to enable build logging.
        build_timeout: Time in seconds before the build will throw an exception to prevent hanging.

    Raises:
        ValueError: when no source directory matches ``module_name``.
    """

    # The named module must have a source directory under `_extensions`.
    module_dir = path.join(dir_path, module_name)
    if not path.exists(module_dir):
        raise ValueError(f"No extension module named {module_name}")

    # Tag the build with OS / Python / torch versions so differently-built
    # extensions do not collide in the compilation cache.
    platform_str = f"_{platform.system()}_{platform.python_version()}_"
    platform_str += "".join(f"{v}" for v in get_torch_version_tuple()[:2])

    if defines is not None:
        # Bake define values into the module name as well, for the same reason.
        module_name = "_".join([module_name] +
                               [f"{v}" for v in defines.values()])

    # Collect C++ sources; add CUDA sources and the CUDA version tag when available.
    sources = glob(path.join(module_dir, "**", "*.cpp"), recursive=True)
    if torch.cuda.is_available():
        sources += glob(path.join(module_dir, "**", "*.cu"), recursive=True)
        platform_str += f"_{torch.version.cuda}"

    # Translate the defines dict into compiler flags.
    define_args = [f"-D {key}={defines[key]}"
                   for key in defines] if defines else []

    # Ninja may be blocked by something out of our control; raise rather than
    # hang forever if the build takes longer than expected.
    with timeout(
            build_timeout,
            "Build appears to be blocked. Is there a stopped process building the same extension?"
    ):
        # Importing lazily; the import itself triggers some JIT config in pytorch.
        load, _ = optional_import(
            "torch.utils.cpp_extension",
            name="load")
        # Either runs the build or returns the already-built .so object.
        return load(
            name=module_name + platform_str.replace(".", "_"),
            sources=sources,
            extra_cflags=define_args,
            extra_cuda_cflags=define_args,
            verbose=verbose_build,
        )
Exemplo n.º 4
0
    def __call__(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        world_size = idist.get_world_size()
        if self.save_rank >= world_size:
            raise ValueError(
                "target rank is greater than the distributed group size.")

        _images = self._filenames
        if world_size > 1:
            # Serialize the filename list into one delimited string so it can
            # be all-gathered across ranks as a single object.
            joined = self.deli.join(_images)
            # NOTE(review): the check is `> (1, 6, 0)` while the message says
            # "< 1.7.0" — 1.6.x point releases above 1.6.0 would pass; kept as-is.
            if get_torch_version_tuple() > (1, 6, 0):
                joined = self.deli.join(idist.all_gather(joined))
            else:
                raise RuntimeError(
                    "MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0."
                )
            _images = joined.split(self.deli)

        # Only the designated rank writes the report files.
        if idist.get_rank() != self.save_rank:
            return

        # Keep only the requested metrics ("*" selects everything).
        _metrics = {}
        if self.metrics is not None and len(engine.state.metrics) > 0:
            for key, value in engine.state.metrics.items():
                if key in self.metrics or "*" in self.metrics:
                    _metrics[key] = value

        # Same filtering for the per-image metric details.
        _metric_details = {}
        if self.metric_details is not None and len(
                engine.state.metric_details) > 0:
            _metric_details = {
                key: value
                for key, value in engine.state.metric_details.items()
                if key in self.metric_details or "*" in self.metric_details
            }

        write_metrics_reports(
            save_dir=self.save_dir,
            images=_images,
            metrics=_metrics,
            metric_details=_metric_details,
            summary_ops=self.summary_ops,
            deli=self.deli,
            output_type=self.output_type,
        )
Exemplo n.º 5
0
import random
import string
import unittest
from copy import deepcopy
from typing import Optional, Union

import torch
from parameterized import parameterized

from monai.data.meta_tensor import MetaTensor
from monai.transforms import FromMetaTensord, ToMetaTensord
from monai.utils.enums import PostFix
from monai.utils.module import get_torch_version_tuple
from tests.utils import TEST_DEVICES, assert_allclose

PT_VER_MAJ, PT_VER_MIN = get_torch_version_tuple()

# One single-element sub-list per dtype so `parameterized` expands each dtype
# into its own test case.
DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64],
          [torch.int32]]
# Cartesian product of devices and dtypes, flattened into one tuple per case.
TESTS = [(*_device, *_dtype) for _device in TEST_DEVICES for _dtype in DTYPES]


def rand_string(min_len=5, max_len=10):
    """Return a random string of ASCII letters and punctuation whose length
    is drawn uniformly from ``[min_len, max_len]``."""
    alphabet = string.ascii_letters + string.punctuation
    length = random.randint(min_len, max_len)
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)