Example No. 1
    def __call__(self, filename):
        """
        Args:
            filename (str, list, tuple, file): file path, file-like object, or a list of such files.
        """
        filename = ensure_tuple(filename)
        img_array = list()
        compatible_meta = None
        for name in filename:
            img = Image.open(name)
            data = np.asarray(img)
            if self.dtype:
                data = data.astype(self.dtype)
            img_array.append(data)
            meta = dict()
            meta["filename_or_obj"] = name
            meta["spatial_shape"] = data.shape[:2]
            meta["format"] = img.format
            meta["mode"] = img.mode
            meta["width"] = img.width
            meta["height"] = img.height
            meta["info"] = img.info

            if self.image_only:
                continue

            if not compatible_meta:
                compatible_meta = meta
            else:
                assert np.allclose(
                    meta["spatial_shape"], compatible_meta["spatial_shape"]
                ), "all the images in the list should have same spatial shape."

        img_array = np.stack(img_array,
                             axis=0) if len(img_array) > 1 else img_array[0]
        return img_array if self.image_only else (img_array, compatible_meta)
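A minimal, runnable sketch of the stacking convention used above, with two tiny in-memory PIL images standing in for files on disk (Pillow and NumPy assumed installed); it is illustrative only, not part of the loader itself.

import numpy as np
from PIL import Image

# two 4x3 grayscale images created in memory stand in for the loaded files
imgs = [Image.new("L", (4, 3), color=c) for c in (0, 255)]
arrays = [np.asarray(im) for im in imgs]
# same convention as the loader: stack along a new first dimension when given a list
stacked = np.stack(arrays, axis=0) if len(arrays) > 1 else arrays[0]
print(stacked.shape)  # (2, 3, 4): number of images, then (height, width)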
Example No. 2
    def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None):
        """
        Args:
            img: target data content to save into a file. The image should be channel-first, shape: `[C,H,W,[D]]`.
            meta_data: key-value pairs of metadata corresponding to the data.
        """
        subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
        patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None
        filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index)
        if meta_data and len(ensure_tuple(meta_data.get("spatial_shape", ()))) == len(img.shape):
            self.data_kwargs["channel_dim"] = None

        err = []
        for writer_cls in self.writers:
            try:
                writer_obj = writer_cls(**self.init_kwargs)
                writer_obj.set_data_array(data_array=img, **self.data_kwargs)
                writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs)
                writer_obj.write(filename, **self.write_kwargs)
                self.writer_obj = writer_obj
            except Exception as e:
                err.append(traceback.format_exc())
                logging.getLogger(self.__class__.__name__).debug(e, exc_info=True)
                logging.getLogger(self.__class__.__name__).info(
                    f"{writer_cls.__class__.__name__}: unable to write {filename}.\n"
                )
            else:
                self._data_index += 1
                return img
        msg = "\n".join([f"{e}" for e in err])
        raise RuntimeError(
            f"{self.__class__.__name__} cannot find a suitable writer for {filename}.\n"
            "    Please install the writer libraries, see also the installation instructions:\n"
            "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
            f"   The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
        )
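The try/except/else loop above is a writer-fallback pattern: attempt each registered writer in order, keep every traceback, and only raise when all of them fail. A stripped-down sketch of the same pattern with stand-in backend callables (not MONAI writer classes):

import traceback

def save_with_fallback(filename, data, backends):
    """Try each backend callable in turn; raise only if every one fails."""
    errors = []
    for backend in backends:
        try:
            backend(filename, data)
        except Exception:
            errors.append(traceback.format_exc())
        else:
            return data  # first successful backend wins
    raise RuntimeError(f"no backend could write {filename}:\n" + "\n".join(errors))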
Example No. 3
    def __init__(self, keys: KeysCollection, times: int, names: KeysCollection):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            times: expected copy times, for example, if keys is "img", times is 3,
                it will add 3 copies of "img" data to the dictionary.
            names: the names corresponding to the newly copied data,
                the length should match `len(keys) x times`. for example, if keys is ["img", "seg"]
                and times is 2, names can be: ["img_1", "seg_1", "img_2", "seg_2"].

        Raises:
            ValueError: times must be greater than 0.
            ValueError: length of names does not match `len(keys) x times`.

        """
        super().__init__(keys)
        if times < 1:
            raise ValueError("times must be greater than 0.")
        self.times = times
        names = ensure_tuple(names)
        if len(names) != (len(self.keys) * times):
            raise ValueError("length of names does not match `len(keys) x times`.")
        self.names = names
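A quick check of the length rule enforced above, with illustrative key names: copying two keys twice requires `len(keys) x times` = 4 new names.

keys = ("img", "seg")
times = 2
names = ("img_1", "seg_1", "img_2", "seg_2")
# the constructor above raises ValueError unless this holds
assert len(names) == len(keys) * times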
Example No. 4
    def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray],
             **kwargs):
        """
        Read whole slide image objects from given file or list of files.

        Args:
            data: file name or a list of file names to read.
            kwargs: additional args that override `self.kwargs` for existing keys.
                For more details look at https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h

        Returns:
            whole slide image object or list of such objects

        """
        wsi_list: List = []

        filenames: Sequence[PathLike] = ensure_tuple(data)
        kwargs_ = self.kwargs.copy()
        kwargs_.update(kwargs)
        for filename in filenames:
            wsi = CuImage(filename, **kwargs_)
            wsi_list.append(wsi)

        return wsi_list if len(filenames) > 1 else wsi_list[0]
Example No. 5
 def __init__(
     self,
     keys: KeysCollection,
     source_key: str,
     select_fn: Callable = lambda x: x > 0,
     channel_indexes: Optional[IndexSelection] = None,
     margin: int = 0,
 ) -> None:
     """
     Args:
         keys: keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
         source_key: data source to generate the bounding box of foreground, can be image or label, etc.
         select_fn: function to select expected foreground, default is to select values > 0.
          channel_indexes: if defined, select foreground only on the specified channels
              of the image. If None, select foreground on the whole image.
         margin: add margin to all dims of the bounding box.
     """
     super().__init__(keys)
     self.source_key = source_key
     self.select_fn = select_fn
     self.channel_indexes = ensure_tuple(
         channel_indexes) if channel_indexes is not None else None
     self.margin = margin
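For intuition, a small NumPy sketch of what `select_fn` and `margin` mean for a foreground bounding box; this is illustrative only and not the MONAI cropping implementation.

import numpy as np

img = np.zeros((8, 8))
img[2:5, 3:6] = 1.0                      # a small foreground patch
mask = (lambda x: x > 0)(img)            # the default select_fn
coords = np.argwhere(mask)
margin = 1
start = np.maximum(coords.min(axis=0) - margin, 0)
end = np.minimum(coords.max(axis=0) + 1 + margin, img.shape)
print(start, end)  # [1 2] [6 7] -> the box to crop, expanded by the margin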
Example No. 6
    def get_data(self, img):
        """
        Extract the data array and metadata from the loaded data and return them.
        This function returns two objects: the first is a NumPy array of image data, the second is a dict of metadata.
        It constructs `spatial_shape` and stores it in the meta dict.
        If loading a list of files, they are stacked together with a new first dimension,
        and the metadata of the first image is used to represent the stacked result.

        Args:
            img: a PIL Image object loaded from a file or a list of PIL Image objects.

        """
        img_array: List[np.ndarray] = []
        compatible_meta: Dict = {}

        for i in ensure_tuple(img):
            header = self._get_meta_dict(i)
            header["spatial_shape"] = self._get_spatial_shape(i)
            img_array.append(np.asarray(i))
            _copy_compatible_dict(header, compatible_meta)

        img_array_ = np.stack(img_array,
                              axis=0) if len(img_array) > 1 else img_array[0]
        return img_array_, compatible_meta
Example No. 7
    def __call__(
        self,
        batchdata: Dict[str, torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        non_blocking: bool = False,
    ):
        image, label = default_prepare_batch(batchdata, device, non_blocking)
        args = list()
        kwargs = dict()

        def _get_data(key: str):
            data = batchdata[key]
            return data.to(device=device,
                           non_blocking=non_blocking) if isinstance(
                               data, torch.Tensor) else data

        if isinstance(self.extra_keys, (str, list, tuple)):
            for k in ensure_tuple(self.extra_keys):
                args.append(_get_data(k))
        elif isinstance(self.extra_keys, dict):
            for k, v in self.extra_keys.items():
                kwargs.update({k: _get_data(v)})

        return image, label, tuple(args), kwargs
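A sketch of how `extra_keys` shapes the returned `args`/`kwargs` (tiny stand-in tensors, hypothetical key names): a list or tuple of keys becomes positional arguments, a dict becomes keyword arguments.

import torch

batchdata = {
    "image": torch.zeros(2, 1, 4, 4),
    "label": torch.ones(2),
    "weight": torch.tensor([0.5, 0.5]),
}
# extra_keys given as a list/tuple -> extra positional args, in order:
#     extra_keys=["weight"]      =>  args == (batchdata["weight"],), kwargs == {}
# extra_keys given as a dict     -> extra keyword args, values name the batch keys:
#     extra_keys={"w": "weight"} =>  args == (), kwargs == {"w": batchdata["weight"]}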
Example No. 8
    def _find_classes_or_functions(
            self, modnames: Union[Sequence[str], str]) -> Dict[str, List]:
        """
        Find all the classes and functions in the modules with specified `modnames`.

        Args:
            modnames: names of the target modules to find all the classes and functions.

        """
        table: Dict[str, List] = {}
        # all the MONAI modules are already loaded by `load_submodules`
        for modname in ensure_tuple(modnames):
            try:
                # scan all the classes and functions in the module
                module = import_module(modname)
                for name, obj in inspect.getmembers(module):
                    if (inspect.isclass(obj) or inspect.isfunction(obj)
                        ) and obj.__module__ == modname:
                        if name not in table:
                            table[name] = []
                        table[name].append(modname)
            except ModuleNotFoundError:
                pass
        return table
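The same scan can be run on a standard-library module to see what the filter keeps: only classes and functions whose `__module__` matches the module being inspected, so re-exports are skipped. A runnable sketch:

import inspect
from importlib import import_module

module = import_module("json")
found = sorted(
    name
    for name, obj in inspect.getmembers(module)
    if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == "json"
)
print(found)  # e.g. ['detect_encoding', 'dump', 'dumps', 'load', 'loads']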
Example No. 9
    def get_data(self, img):
        """
        Extract the data array and metadata from the loaded image and return them.
        This function returns two objects: the first is a NumPy array of image data, the second is a dict of metadata.
        It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in the meta dict.
        If loading a list of files, they are stacked together with a new first dimension,
        and the metadata of the first image is used to represent the stacked result.

        Args:
            img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.

        """
        img_array: List[np.ndarray] = list()
        compatible_meta: Dict = None

        for i in ensure_tuple(img):
            header = self._get_meta_dict(i)
            header["original_affine"] = self._get_affine(i)
            header["affine"] = header["original_affine"].copy()
            if self.as_closest_canonical:
                i = nib.as_closest_canonical(i)
                header["affine"] = self._get_affine(i)
            header["as_closest_canonical"] = self.as_closest_canonical
            header["spatial_shape"] = self._get_spatial_shape(i)
            img_array.append(self._get_array_data(i))

            if compatible_meta is None:
                compatible_meta = header
            else:
                if not np.allclose(header["affine"], compatible_meta["affine"]):
                    raise RuntimeError("affine matrix of all images should be same.")
                if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
                    raise RuntimeError("spatial_shape of all images should be same.")

        img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
        return img_array_, compatible_meta
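A standalone sketch of the compatibility check above: every header in the list must carry the same `affine` and `spatial_shape` as the first one (NumPy only, dummy headers).

import numpy as np

headers = [
    {"affine": np.eye(4), "spatial_shape": (64, 64, 32)},
    {"affine": np.eye(4), "spatial_shape": (64, 64, 32)},
]
first = headers[0]
for h in headers[1:]:
    if not np.allclose(h["affine"], first["affine"]):
        raise RuntimeError("affine matrix of all images should be same.")
    if not np.allclose(h["spatial_shape"], first["spatial_shape"]):
        raise RuntimeError("spatial_shape of all images should be same.")
print("headers are compatible")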
Example No. 10
    def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
        """
        Read image data from specified file or files.
        Note that the returned object is PIL image or list of PIL image.

        Args:
            data: file name or a list of file names to read.
            kwargs: additional args for `Image.open` API in `read()`, will override `self.kwargs` for existing keys.
                More details about available args:
                https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open

        """
        img_: List[PILImage.Image] = []

        filenames: Sequence[str] = ensure_tuple(data)
        kwargs_ = self.kwargs.copy()
        kwargs_.update(kwargs)
        for name in filenames:
            img = PILImage.open(name, **kwargs_)
            if callable(self.converter):
                img = self.converter(img)
            img_.append(img)

        return img_ if len(filenames) > 1 else img_[0]
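A runnable sketch of the single-vs-list return convention used by these readers, with throwaway PNG files written to a temporary directory (Pillow assumed installed):

import os
import tempfile
from PIL import Image

with tempfile.TemporaryDirectory() as tmp:
    paths = [os.path.join(tmp, f"img_{i}.png") for i in range(2)]
    for p in paths:
        Image.new("RGB", (8, 8)).save(p)
    imgs = [Image.open(p) for p in paths]
    # the reader's convention: a single filename returns one image, a list returns a list
    result = imgs if len(paths) > 1 else imgs[0]
    print(len(result))  # 2
    for im in imgs:
        im.close()  # release file handles before the temporary directory is removed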
Example No. 11
    def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
        """
        Read image data from specified file or files.
        Note that the returned object is CuImage or list of CuImage objects.

        Args:
            data: file name or a list of file names to read.

        """
        if (self.reader_lib == "openslide") and (not has_osl):
            raise ImportError("No module named 'openslide'")
        elif (self.reader_lib == "cuclaraimage") and (not has_cux):
            raise ImportError("No module named 'cuimage'")

        img_: List = []

        filenames: Sequence[str] = ensure_tuple(data)
        for name in filenames:
            img = self.wsi_reader(name)
            if self.reader_lib == "openslide":
                img.shape = (img.dimensions[1], img.dimensions[0], 3)
            img_.append(img)

        return img_ if len(filenames) > 1 else img_[0]
Example No. 12
def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List:
    """
    Select cross validation data based on data partitions and specified fold index.
    if a list of fold indices is provided, concatenate the partitions of these folds.

    Args:
        partitions: a sequence of datasets, each item is an iterable.
        folds: the indices of the partitions to be combined.

    Returns:
        A list of combined datasets.

    Example::

        >>> partitions = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
        >>> select_cross_validation_folds(partitions, 2)
        [5, 6]
        >>> select_cross_validation_folds(partitions, [1, 2])
        [3, 4, 5, 6]
        >>> select_cross_validation_folds(partitions, [-1, 2])
        [9, 10, 5, 6]
    """
    data_list = [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]]
    return data_list
Example No. 13
    def read(self, data: Union[Sequence[str], str], **kwargs):
        """
        Read image data from specified file or files.
        Note that the returned object is ITK image object or list of ITK image objects.

        Args:
            data: file name or a list of file names to read,
            kwargs: additional args for `itk.imread` API, will override `self.kwargs` for existing keys.
                More details about available args:
                https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itkExtras.py

        """
        img_: List[Image] = list()

        filenames: Sequence[str] = ensure_tuple(data)
        kwargs_ = self.kwargs.copy()
        kwargs_.update(kwargs)
        for name in filenames:
            if os.path.isdir(name):
                # read DICOM series of 1 image in a folder, refer to: https://github.com/RSIP-Vision/medio
                names_generator = itk.GDCMSeriesFileNames.New()
                names_generator.SetUseSeriesDetails(True)
                names_generator.AddSeriesRestriction("0008|0021")  # Series Date
                names_generator.SetDirectory(name)
                series_uid = names_generator.GetSeriesUIDs()

                if len(series_uid) == 0:
                    raise FileNotFoundError(f"no DICOMs in: {name}.")
                if len(series_uid) > 1:
                    raise OSError(f"the directory: {name} contains more than one DICOM series.")

                series_identifier = series_uid[0]
                name = names_generator.GetFileNames(series_identifier)

            img_.append(itk.imread(name, **kwargs_))
        return img_ if len(filenames) > 1 else img_[0]
Example No. 14
 def __init__(self, npz_keys: Optional[KeysCollection] = None, **kwargs):
     super().__init__()
     if npz_keys is not None:
         npz_keys = ensure_tuple(npz_keys)
     self.npz_keys = npz_keys
     self.kwargs = kwargs
Example No. 15
def generate_param_groups(
    network: torch.nn.Module,
    layer_matches: Sequence[Callable],
    match_types: Sequence[str],
    lr_values: Sequence[float],
    include_others: bool = True,
):
    """
    Utility function to generate parameter groups with different LR values for optimizer.
    The output parameter groups have the same order as `layer_match` functions.

    Args:
        network: source network to generate parameter groups from.
        layer_matches: a list of callable functions to select or filter out network layer groups,
            for "select" type, the input will be the `network`, for "filter" type,
            the input will be every item of `network.named_parameters()`.
            for "select", the parameters will be
            `select_func(network).parameters()`.
            for "filter", the parameters will be
            `map(lambda x: x[1], filter(filter_func, network.named_parameters()))`
        match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions,
            can be "select" or "filter".
        lr_values: a list of LR values corresponding to the `layer_matches` functions.
        include_others: whether to include the remaining layers as the last group, default to True.

    It's mainly used to set different LR values for different network elements, for example:

    .. code-block:: python

        net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])
        print(net)  # print out network components to select expected items
        print(net.named_parameters())  # print out all the named parameters to filter out expected items
        params = generate_param_groups(
            network=net,
            layer_matches=[lambda x: x.model[0], lambda x: "2.0.conv" in x[0]],
            match_types=["select", "filter"],
            lr_values=[1e-2, 1e-3],
        )
        # the groups will be a list of dictionaries:
        # [{'params': <generator object Module.parameters at 0x7f9090a70bf8>, 'lr': 0.01},
        #  {'params': <filter object at 0x7f9088fd0dd8>, 'lr': 0.001},
        #  {'params': <filter object at 0x7f9088fd0da0>}]
        optimizer = torch.optim.Adam(params, 1e-4)

    """
    layer_matches = ensure_tuple(layer_matches)
    match_types = ensure_tuple_rep(match_types, len(layer_matches))
    lr_values = ensure_tuple_rep(lr_values, len(layer_matches))

    def _get_select(f):
        def _select():
            return f(network).parameters()

        return _select

    def _get_filter(f):
        def _filter():
            # should eventually generate a list of network parameters
            return map(lambda x: x[1], filter(f, network.named_parameters()))

        return _filter

    params = []
    _layers = []
    for func, ty, lr in zip(layer_matches, match_types, lr_values):
        if ty.lower() == "select":
            layer_params = _get_select(func)
        elif ty.lower() == "filter":
            layer_params = _get_filter(func)
        else:
            raise ValueError(f"unsupported layer match type: {ty}.")

        params.append({"params": layer_params(), "lr": lr})
        _layers.extend(list(map(id, layer_params())))

    if include_others:
        params.append({
            "params":
            filter(lambda p: id(p) not in _layers, network.parameters())
        })

    return params
Example No. 16
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor],
                                   Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(
        max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs,
                   pad=pad_size,
                   mode=look_up_option(padding_mode, PytorchPadMode),
                   value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims,
                                       overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size,
                                                    mode=mode,
                                                    sigma_scale=sigma_scale,
                                                    device=device)
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device,
                                       compute_dtype)[0]  # type: ignore
    # handle non-positive weights
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32),
                                 min=min_non_zero).to(compute_dtype)

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each patch
    for slice_g in tqdm(range(0, total_slices,
                              sw_batch_size)) if progress else range(
                                  0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size,
                                         total_slices))
        unravel_slice = [
            [slice(int(idx / num_win),
                   int(idx / num_win) + 1),
             slice(None)] + list(slices[idx % num_win]) for idx in slice_range
        ]
        window_data = torch.cat([
            convert_data_type(inputs[win_slice], torch.Tensor)[0]
            for win_slice in unravel_slice
        ]).to(sw_device)
        seg_prob_out = predictor(window_data, *args,
                                 **kwargs)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out, )
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(
                    seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                    zip(image_size, seg_prob.shape[2:],
                        window_data.shape[2:])):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d)
                    for image_size_d, zoom_scale_d in zip(
                        image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(
                    torch.zeros(output_shape,
                                dtype=compute_dtype,
                                device=device))
                count_map_list.append(
                    torch.zeros([1, 1] + output_shape[2:],
                                dtype=compute_dtype,
                                device=device))
                _initialized_ss += 1

            # resizing the importance_map
            resizer = Resize(spatial_size=seg_prob.shape[2:],
                             mode="nearest",
                             anti_aliasing=False)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(
                    original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis -
                                                                         2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (
                            not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start),
                                                    int(zoomed_end), None)
                importance_map_zoom = resizer(
                    importance_map.unsqueeze(0))[0].to(compute_dtype)
                # store results and weights
                output_image_list[ss][
                    original_idx_zoom] += importance_map_zoom * seg_prob[
                        idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(
                        count_map_list[ss][original_idx_zoom].shape))

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] /
                                 count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn(
                "Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d
            for seg_prob_map_shape_d, roi_size_d in zip(
                output_i.shape[2:], roi_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(
                pad_size[sp * 2],
                image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(
                    round(slice_dim.start *
                          zoom_scale[num_spatial_dims - sp - 1])),
                int(
                    round(slice_dim.stop *
                          zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[
        0] if is_tensor_output else final_output  # type: ignore
    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output,
                                           inputs)[0]  # type: ignore
    return final_output
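A hedged usage sketch (assuming MONAI and PyTorch are installed): running the function with a trivial identity "predictor" over a single-channel 3D volume, which is enough to see the window/blend behavior.

import torch
from monai.inferers import sliding_window_inference

volume = torch.rand(1, 1, 96, 96, 96)        # NCHW[D] input: batch, channel, then spatial dims
output = sliding_window_inference(
    inputs=volume,
    roi_size=(64, 64, 64),
    sw_batch_size=2,
    predictor=lambda x: x,                   # stand-in for a trained network
    overlap=0.5,
    mode="gaussian",
)
print(output.shape)  # torch.Size([1, 1, 96, 96, 96])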
Example No. 17
    def forward(self,
                src,
                theta,
                spatial_size: Optional[Union[Sequence[int], int]] = None):
        """
        ``theta`` must be an affine transformation matrix with shape
        3x3 or Nx3x3 or Nx2x3 or 2x3 for spatial 2D transforms,
        4x4 or Nx4x4 or Nx3x4 or 3x4 for spatial 3D transforms,
        where `N` is the batch size. `theta` will be converted into float Tensor for the computation.

        Args:
            src (array_like): image in spatial 2D or 3D (N, C, spatial_dims),
                where N is the batch dim, C is the number of channels.
            theta (array_like): Nx3x3, Nx2x3, 3x3, 2x3 for spatial 2D inputs,
                Nx4x4, Nx3x4, 3x4, 4x4 for spatial 3D inputs. When the batch dimension is omitted,
                `theta` will be repeated N times, N is the batch dim of `src`.
            spatial_size: output spatial shape, the full output shape will be
                `[N, C, *spatial_size]` where N and C are inferred from the `src`.

        Raises:
            TypeError: When ``theta`` is not a ``torch.Tensor``.
            ValueError: When ``theta`` is not one of [Nxdxd, dxd].
            ValueError: When ``theta`` is not one of [Nx3x3, Nx4x4].
            TypeError: When ``src`` is not a ``torch.Tensor``.
            ValueError: When ``src`` spatially is not one of [2D, 3D].
            ValueError: When affine and image batch dimension differ.

        """
        # validate `theta`
        if not torch.is_tensor(theta):
            raise TypeError(
                f"theta must be torch.Tensor but is {type(theta).__name__}.")
        if theta.ndim not in (2, 3):
            raise ValueError(f"theta must be Nxdxd or dxd, got {theta.shape}.")
        if theta.ndim == 2:
            theta = theta[None]  # adds a batch dim.
        theta = theta.clone()  # no in-place change of theta
        theta_shape = tuple(theta.shape[1:])
        if theta_shape in ((2, 3), (3, 4)):  # needs padding to dxd
            pad_affine = torch.tensor([0, 0, 1] if theta_shape[0] ==
                                      2 else [0, 0, 0, 1])
            pad_affine = pad_affine.repeat(theta.shape[0], 1, 1).to(theta)
            pad_affine.requires_grad = False
            theta = torch.cat([theta, pad_affine], dim=1)
        if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)):
            raise ValueError(
                f"theta must be Nx3x3 or Nx4x4, got {theta.shape}.")

        # validate `src`
        if not torch.is_tensor(src):
            raise TypeError(
                f"src must be torch.Tensor but is {type(src).__name__}.")
        sr = src.ndim - 2  # input spatial rank
        if sr not in (2, 3):
            raise ValueError(
                f"Unsupported src dimension: {sr}, available options are [2, 3]."
            )

        # set output shape
        src_size = tuple(src.shape)
        dst_size = src_size  # default to the src shape
        if self.spatial_size is not None:
            dst_size = src_size[:2] + self.spatial_size
        if spatial_size is not None:
            dst_size = src_size[:2] + ensure_tuple(spatial_size)

        # reverse and normalise theta if needed
        if not self.normalized:
            theta = to_norm_affine(affine=theta,
                                   src_size=src_size[2:],
                                   dst_size=dst_size[2:],
                                   align_corners=self.align_corners)
        if self.reverse_indexing:
            rev_idx = torch.as_tensor(range(sr - 1, -1, -1), device=src.device)
            theta[:, :sr] = theta[:, rev_idx]
            theta[:, :, :sr] = theta[:, :, rev_idx]
        if (theta.shape[0] == 1) and src_size[0] > 1:
            # adds a batch dim to `theta` in order to match `src`
            theta = theta.repeat(src_size[0], 1, 1)
        if theta.shape[0] != src_size[0]:
            raise ValueError(
                f"affine and image batch dimension must match, got affine={theta.shape[0]} image={src_size[0]}."
            )

        grid = nn.functional.affine_grid(theta=theta[:, :sr],
                                         size=list(dst_size),
                                         align_corners=self.align_corners)
        dst = nn.functional.grid_sample(
            input=src.contiguous(),
            grid=grid,
            mode=self.mode.value,
            padding_mode=self.padding_mode.value,
            align_corners=self.align_corners,
        )
        return dst
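A hedged usage sketch (assuming MONAI and PyTorch are installed): applying an identity affine to a small 2D batch with `monai.networks.layers.AffineTransform`, whose `forward` is shown above.

import torch
from monai.networks.layers import AffineTransform

xform = AffineTransform(normalized=True)      # theta given directly in normalized coordinates
src = torch.arange(12.0).reshape(1, 1, 3, 4)  # N=1, C=1, spatial 3x4
theta = torch.eye(2, 3)[None]                 # identity affine, shape 1x2x3
out = xform(src, theta, spatial_size=(3, 4))
print(out.shape)  # torch.Size([1, 1, 3, 4])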
Example No. 18
    def __init__(
        self,
        device: torch.device,
        max_epochs: int,
        amp: bool,
        data_loader: DataLoader,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        post_transform: Optional[Callable] = None,
        key_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        handlers: Optional[Sequence] = None,
    ) -> None:
        if iteration_update is not None:
            super().__init__(iteration_update)
        else:
            super().__init__(self._iteration)
        # FIXME:
        if amp:
            self.logger.info("Will add AMP support when PyTorch v1.6 released.")
        if not isinstance(device, torch.device):
            raise ValueError("device must be PyTorch device object.")
        if not isinstance(data_loader, DataLoader):
            raise ValueError("data_loader must be PyTorch DataLoader.")

        # set all sharable data for the workflow based on Ignite engine.state
        self.state = State(
            seed=0,
            iteration=0,
            epoch=0,
            max_epochs=max_epochs,
            epoch_length=-1,
            output=None,
            batch=None,
            metrics={},
            dataloader=None,
            device=device,
            amp=amp,
            key_metric_name=None,  # we can set many metrics, only use key_metric to compare and save the best model
            best_metric=-1,
            best_metric_epoch=-1,
        )
        self.data_loader = data_loader
        self.prepare_batch = prepare_batch

        if post_transform is not None:

            @self.on(Events.ITERATION_COMPLETED)
            def run_post_transform(engine: Engine):
                assert post_transform is not None
                engine.state.output = apply_transform(post_transform, engine.state.output)

        if key_metric is not None:

            if not isinstance(key_metric, dict):
                raise ValueError("key_metric must be a dict object.")
            self.state.key_metric_name = list(key_metric.keys())[0]
            metrics = key_metric
            if additional_metrics is not None and len(additional_metrics) > 0:
                if not isinstance(additional_metrics, dict):
                    raise ValueError("additional_metrics must be a dict object.")
                metrics.update(additional_metrics)
            for name, metric in metrics.items():
                metric.attach(self, name)

            @self.on(Events.EPOCH_COMPLETED)
            def _compare_metrics(engine: Engine):
                if engine.state.key_metric_name is not None:
                    current_val_metric = engine.state.metrics[engine.state.key_metric_name]
                    if current_val_metric > engine.state.best_metric:
                        self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}")
                        engine.state.best_metric = current_val_metric
                        engine.state.best_metric_epoch = engine.state.epoch

        if handlers is not None:
            handlers_ = ensure_tuple(handlers)
            for handler in handlers_:
                handler.attach(self)
Example No. 19
def write_metrics_reports(
    save_dir: str,
    images: Optional[Sequence[str]],
    metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]],
    metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]],
    summary_ops: Optional[Union[str, Sequence[str]]],
    deli: str = "\t",
    output_type: str = "csv",
):
    """
    Utility function to write the metrics into files. It contains 3 parts:
    1. if `metrics` dict is not None, write overall metrics into file, every line is a metric name and value pair.
    2. if `metric_details` dict is not None,  write raw metric data of every image into file, every line for 1 image.
    3. if `summary_ops` is not None, compute summary based on operations on `metric_details` and write to file.

    Args:
        save_dir: directory to save all the metrics reports.
        images: name or path of every input image corresponding to the metric_details data.
            if None, will use index number as the filename of every input image.
        metrics: a dictionary of (metric name, metric value) pairs.
        metric_details: a dictionary of (metric name, metric raw values) pairs, usually, it comes from metrics computation,
            for example, the raw value can be the mean_dice of every channel of every input image.
        summary_ops: expected computation operations to generate the summary report.
            it can be: None, "*" or list of strings.
            None - don't generate summary report for every expected metric_details
            "*" - generate summary report for every metric_details with all the supported operations.
            list of strings - generate summary report for every metric_details with specified operations, they
            should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`].
            default to None.
        deli: the delimiter character in the file, default to "\t".
        output_type: expected output file type, supported types: ["csv"], default to "csv".

    """
    if output_type.lower() != "csv":
        raise ValueError(f"unsupported output type: {output_type}.")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if metrics is not None and len(metrics) > 0:
        with open(os.path.join(save_dir, "metrics.csv"), "w") as f:
            for k, v in metrics.items():
                f.write(f"{k}{deli}{str(v)}\n")
    if metric_details is not None and len(metric_details) > 0:
        for k, v in metric_details.items():
            if isinstance(v, torch.Tensor):
                v = v.cpu().numpy()
            if v.ndim == 0:
                # reshape to [1, 1] if no batch and class dims
                v = v.reshape((1, 1))
            elif v.ndim == 1:
                # reshape to [N, 1] if no class dim
                v = v.reshape((-1, 1))

            # add the average value of all classes to v
            class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"]
            v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1)

            with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f:
                f.write(f"filename{deli}{deli.join(class_labels)}\n")
                for i, b in enumerate(v):
                    f.write(f"{images[i] if images is not None else str(i)}{deli}{deli.join([str(c) for c in b])}\n")

            if summary_ops is not None:
                supported_ops = OrderedDict(
                    {
                        "mean": np.nanmean,
                        "median": np.nanmedian,
                        "max": np.nanmax,
                        "min": np.nanmin,
                        "90percent": lambda x: np.nanpercentile(x, 10),
                        "std": np.nanstd,
                    }
                )
                ops = ensure_tuple(summary_ops)
                if "*" in ops:
                    ops = tuple(supported_ops.keys())

                with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f:
                    f.write(f"class{deli}{deli.join(ops)}\n")
                    for i, c in enumerate(np.transpose(v)):
                        f.write(f"{class_labels[i]}{deli}{deli.join([f'{supported_ops[k](c):.4f}' for k in ops])}\n")
Example No. 20
    def __init__(
        self,
        latent_shape: Sequence[int],
        start_shape: Sequence[int],
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout: Optional[float] = None,
        bias: bool = True,
    ) -> None:
        """
        Construct the generator network with the number of layers defined by `channels` and `strides`. In the
        forward pass a `nn.Linear` layer relates the input latent vector to a tensor of dimensions `start_shape`,
        this is then fed forward through the sequence of convolutional layers. The number of layers is defined by
        the length of `channels` and `strides` which must match, each layer having the number of output channels
        given in `channels` and an upsample factor given in `strides` (i.e. a transposed convolution with that
        stride size).

        Args:
            latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension)
            start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork
            channels: tuple of integers stating the output channels of each convolutional layer
            strides: tuple of integers stating the stride (upscale factor) of each convolutional layer
            kernel_size: integer or tuple of integers stating size of convolutional kernels
            num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
            act: name or type defining activation layers
            norm: name or type defining normalization layers
            dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
            bias: boolean stating if convolution layers should have a bias component
        """
        super().__init__()

        self.in_channels, *self.start_shape = ensure_tuple(start_shape)
        self.dimensions = len(self.start_shape)

        self.latent_shape = ensure_tuple(latent_shape)
        self.channels = ensure_tuple(channels)
        self.strides = ensure_tuple(strides)
        self.kernel_size = ensure_tuple_rep(kernel_size, self.dimensions)
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout
        self.bias = bias

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(int(np.prod(self.latent_shape)), int(np.prod(start_shape)))
        self.reshape = Reshape(*start_shape)
        self.conv = nn.Sequential()

        echannel = self.in_channels

        # transform tensor of shape `start_shape' into output shape through transposed convolutions and residual units
        for i, (c, s) in enumerate(zip(channels, strides)):
            is_last = i == len(channels) - 1
            layer = self._get_layer(echannel, c, s, is_last)
            self.conv.add_module("layer_%i" % i, layer)
            echannel = c
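A hedged usage sketch (assuming MONAI and PyTorch are installed): a Generator mapping a 64-dimensional latent vector to a single-channel 32x32 image via two stride-2 upsampling layers from an 8x8 start shape.

import torch
from monai.networks.nets import Generator

net = Generator(
    latent_shape=(64,),
    start_shape=(64, 8, 8),     # 64 channels at 8x8 before the transposed convolutions
    channels=(32, 1),
    strides=(2, 2),             # 8 -> 16 -> 32 spatial
)
fake = net(torch.rand(4, 64))   # a batch of 4 latent vectors
print(fake.shape)  # torch.Size([4, 1, 32, 32])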
Example No. 21
 def __init__(
     self,
     keys,
     guidance="guidance",
     axis: int = 0,
     meta_keys: Optional[KeysCollection] = None,
     meta_key_postfix: str = "meta_dict",
     allow_missing_keys: bool = False,
 ):
     super().__init__(keys, allow_missing_keys)
     self.guidance = guidance
     self.axis = axis
     self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
     if len(self.keys) != len(self.meta_keys):
         raise ValueError("meta_keys should have the same length as keys.")
     self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
Example No. 22
 def __init__(  # pytype: disable=annotation-type-mismatch
     self,
     select_labels: Union[Sequence[int], int],
     merge_channels: bool = False) -> None:  # pytype: disable=annotation-type-mismatch
     self.select_labels = ensure_tuple(select_labels)
     self.merge_channels = merge_channels
Example No. 23
def generate_param_groups(
    network: torch.nn.Module,
    layer_matches: Sequence[Callable],
    match_types: Sequence[str],
    lr_values: Sequence[float],
    include_others: bool = True,
):
    """
    Utility function to generate parameter groups with different LR values for optimizer.
    The output parameter groups have the same order as `layer_match` functions.

    Args:
        network: source network to generate parameter groups from.
        layer_matches: a list of callable functions to select or filter out network layer groups,
            for "select" type, the input will be the `network`, for "filter" type,
            the input will be every item of `network.named_parameters()`.
        match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions,
            can be "select" or "filter".
        lr_values: a list of LR values corresponding to the `layer_matches` functions.
        include_others: whether to include the remaining layers as the last group, default to True.

    It's mainly used to set different init LR values for different network elements, for example::

        net = Unet(dimensions=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])
        print(net)  # print out network components to select expected items
        print(net.named_parameters())  # print out all the named parameters to filter out expected items
        params = generate_param_groups(
            network=net,
            layer_matches=[lambda x: x.model[-1], lambda x: "conv.weight" in x],
            match_types=["select", "filter"],
            lr_values=[1e-2, 1e-3],
        )
        optimizer = torch.optim.Adam(params, 1e-4)

    """
    layer_matches = ensure_tuple(layer_matches)
    match_types = ensure_tuple(match_types)
    lr_values = ensure_tuple(lr_values)
    if len(layer_matches) != len(lr_values) or len(layer_matches) != len(
            match_types):
        raise ValueError(
            "length of layer_match callable functions, match types and LR values should be the same."
        )

    params = list()
    _layers = list()
    for func, ty, lr in zip(layer_matches, match_types, lr_values):
        if ty == "select":
            layer_params = func(network).parameters()
        elif ty == "filter":
            layer_params = filter(func, network.named_parameters())
        else:
            raise ValueError(f"unsuppoted layer match type: {ty}.")

        params.append({"params": layer_params, "lr": lr})
        _layers.extend(list(map(id, layer_params)))

    if include_others:
        params.append({
            "params":
            filter(lambda p: id(p) not in _layers, network.parameters())
        })

    return params
Example No. 24
 def __init__(self, excludes: Optional[Union[Sequence[str], str]] = None):
     self.excludes = [] if excludes is None else ensure_tuple(excludes)
     self._components_table: Optional[Dict[str, List]] = None
Example No. 25
 def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = None, prob: float = 0.5) -> None:
     if transforms is None:
         transforms = []
     self.transforms = ensure_tuple(transforms)
     self.set_random_state(seed=get_seed())
     self.prob = prob
Example No. 26
    def __init__(
        self,
        keys: KeysCollection,
        source_key: str,
        spatial_size: Union[Sequence[int], np.ndarray],
        select_fn: Callable = lambda x: x > 0,
        channel_indices: Optional[IndexSelection] = None,
        margin: int = 0,
        meta_keys: Optional[KeysCollection] = None,
        meta_key_postfix="meta_dict",
        start_coord_key: str = "foreground_start_coord",
        end_coord_key: str = "foreground_end_coord",
        original_shape_key: str = "foreground_original_shape",
        cropped_shape_key: str = "foreground_cropped_shape",
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

        self.source_key = source_key
        self.spatial_size = list(spatial_size)
        self.select_fn = select_fn
        self.channel_indices = channel_indices
        self.margin = margin
        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
        if len(self.keys) != len(self.meta_keys):
            raise ValueError("meta_keys should have the same length as keys.")
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
        self.start_coord_key = start_coord_key
        self.end_coord_key = end_coord_key
        self.original_shape_key = original_shape_key
        self.cropped_shape_key = cropped_shape_key
Example No. 27
def write_metrics_reports(
    save_dir: str,
    images: Optional[Sequence[str]],
    metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]],
    metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]],
    summary_ops: Optional[Union[str, Sequence[str]]],
    deli: str = "\t",
    output_type: str = "csv",
):
    """
    Utility function to write the metrics into files. It contains 3 parts:
    1. if `metrics` dict is not None, write overall metrics into file, every line is a metric name and value pair.
    2. if `metric_details` dict is not None,  write raw metric data of every image into file, every line for 1 image.
    3. if `summary_ops` is not None, compute summary based on operations on `metric_details` and write to file.

    Args:
        save_dir: directory to save all the metrics reports.
        images: name or path of every input image corresponding to the metric_details data.
            if None, will use index number as the filename of every input image.
        metrics: a dictionary of (metric name, metric value) pairs.
        metric_details: a dictionary of (metric name, metric raw values) pairs, usually, it comes from metrics
            computation, for example, the raw value can be the mean_dice of every channel of every input image.
        summary_ops: expected computation operations to generate the summary report.
            it can be: None, "*" or list of strings, default to None.
            None - don't generate summary report for every expected metric_details.
            "*" - generate summary report for every metric_details with all the supported operations.
            list of strings - generate summary report for every metric_details with specified operations, they
            should be within list: ["mean", "median", "max", "min", "<int>percentile", "std", "notnans"].
            the number in "<int>percentile" should be [0, 100], like: "15percentile". default: "90percentile".
            for more details, please check: https://numpy.org/doc/stable/reference/generated/numpy.nanpercentile.html.
            note that: for the overall summary, it computes `nanmean` of all classes for each image first,
            then computes the summary. example of the generated summary report::

                class    mean    median    max    5percentile 95percentile  notnans
                class0  6.0000   6.0000   7.0000   5.1000      6.9000       2.0000
                class1  6.0000   6.0000   6.0000   6.0000      6.0000       1.0000
                mean    6.2500   6.2500   7.0000   5.5750      6.9250       2.0000

        deli: the delimiter character in the file, default to "\t".
        output_type: expected output file type, supported types: ["csv"], default to "csv".

    """
    if output_type.lower() != "csv":
        raise ValueError(f"unsupported output type: {output_type}.")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if metrics is not None and len(metrics) > 0:
        with open(os.path.join(save_dir, "metrics.csv"), "w") as f:
            for k, v in metrics.items():
                f.write(f"{k}{deli}{str(v)}\n")
    if metric_details is not None and len(metric_details) > 0:
        for k, v in metric_details.items():
            if isinstance(v, torch.Tensor):
                v = v.cpu().numpy()
            if v.ndim == 0:
                # reshape to [1, 1] if no batch and class dims
                v = v.reshape((1, 1))
            elif v.ndim == 1:
                # reshape to [N, 1] if no class dim
                v = v.reshape((-1, 1))

            # add the average value of all classes to v
            class_labels = ["class" + str(i)
                            for i in range(v.shape[1])] + ["mean"]
            v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)],
                               axis=1)

            with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f:
                f.write(f"filename{deli}{deli.join(class_labels)}\n")
                for i, b in enumerate(v):
                    f.write(
                        f"{images[i] if images is not None else str(i)}{deli}{deli.join([str(c) for c in b])}\n"
                    )

            if summary_ops is not None:
                supported_ops = OrderedDict(
                    {
                        "mean": np.nanmean,
                        "median": np.nanmedian,
                        "max": np.nanmax,
                        "min": np.nanmin,
                        "90percentile": lambda x: np.nanpercentile(x[0], x[1]),
                        "std": np.nanstd,
                        "notnans": lambda x: (~np.isnan(x)).sum(),
                    }
                )
                ops = ensure_tuple(summary_ops)
                if "*" in ops:
                    ops = tuple(supported_ops.keys())

                def _compute_op(op: str, d: np.ndarray):
                    if op.endswith("percentile"):
                        threshold = int(op.split("percentile")[0])
                        return supported_ops["90percentile"]((d, threshold))
                    else:
                        return supported_ops[op](d)

                with open(os.path.join(save_dir, f"{k}_summary.csv"),
                          "w") as f:
                    f.write(f"class{deli}{deli.join(ops)}\n")
                    for i, c in enumerate(np.transpose(v)):
                        f.write(
                            f"{class_labels[i]}{deli}{deli.join([f'{_compute_op(k, c):.4f}' for k in ops])}\n"
                        )
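As a quick illustration of the outputs, here is a hedged usage sketch of the function above; the save directory, file names, and sample tensors are made up for illustration:

import torch

# hypothetical call: one overall metric plus per-image, per-class raw values
write_metrics_reports(
    save_dir="./reports",
    images=["img_0.nii.gz", "img_1.nii.gz"],
    metrics={"mean_dice": 0.82},
    metric_details={"mean_dice": torch.tensor([[0.80, 0.84], [0.79, 0.85]])},
    summary_ops=["mean", "median", "max", "5percentile", "95percentile", "notnans"],
)
# expected files: reports/metrics.csv, reports/mean_dice_raw.csv, reports/mean_dice_summary.csv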
Example No. 28
    def __init__(
        self,
        keys: KeysCollection,
        guidance: str,
        spatial_size,
        margin=20,
        meta_keys: Optional[KeysCollection] = None,
        meta_key_postfix="meta_dict",
        start_coord_key: str = "foreground_start_coord",
        end_coord_key: str = "foreground_end_coord",
        original_shape_key: str = "foreground_original_shape",
        cropped_shape_key: str = "foreground_cropped_shape",
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

        self.guidance = guidance
        self.spatial_size = list(spatial_size)
        self.margin = margin
        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
        if len(self.keys) != len(self.meta_keys):
            raise ValueError("meta_keys should have the same length as keys.")
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
        self.start_coord_key = start_coord_key
        self.end_coord_key = end_coord_key
        self.original_shape_key = original_shape_key
        self.cropped_shape_key = cropped_shape_key
Example No. 29
def convert_tables_to_dicts(
    dfs,
    row_indices: Optional[Sequence[Union[int, str]]] = None,
    col_names: Optional[Sequence[str]] = None,
    col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
    col_groups: Optional[Dict[str, Sequence[str]]] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Utility to join pandas tables, select rows, columns and generate groups.
    Will return a list of dictionaries, every dictionary maps to a row of data in tables.

    Args:
        dfs: data table in pandas Dataframe format. if providing a list of tables, will join them.
        row_indices: indices of the expected rows to load. it should be a list,
            every item can be an int number or a range `[start, end)` for the indices.
            for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
            load all the rows in the file.
        col_names: names of the expected columns to load. if None, load all the columns.
        col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
            it should be a dictionary, every item maps to an expected column, the `key` is the column
            name and the `value` is None or a dictionary to define the default value and data type.
            the supported keys in dictionary are: ["type", "default"], and note that the value of `default`
            should not be `None`. for example::

                col_types = {
                    "subject_id": {"type": str},
                    "label": {"type": int, "default": 0},
                    "ehr_0": {"type": float, "default": 0.0},
                    "ehr_1": {"type": float, "default": 0.0},
                }

        col_groups: args to group the loaded columns to generate a new column,
            it should be a dictionary, every item maps to a group, the `key` will
            be the new column name, the `value` is the names of columns to combine. for example:
            `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
        kwargs: additional arguments for `pandas.merge()` API to join tables.

    """
    df = reduce(lambda l, r: pd.merge(l, r, **kwargs), ensure_tuple(dfs))
    # parse row indices
    rows: List[Union[int, str]] = []
    if row_indices is None:
        rows = slice(df.shape[0])  # type: ignore
    else:
        for i in row_indices:
            if isinstance(i, (tuple, list)):
                if len(i) != 2:
                    raise ValueError("range of row indices must contain 2 values: start and end.")
                rows.extend(list(range(i[0], i[1])))
            else:
                rows.append(i)

    # convert to a list of dictionaries corresponding to every row
    data_ = df.loc[rows] if col_names is None else df.loc[rows, col_names]
    if isinstance(col_types, dict):
        # fill default values for NaN
        defaults = {
            k: v["default"] for k, v in col_types.items() if v is not None and v.get("default") is not None
        }
        if defaults:
            data_ = data_.fillna(value=defaults)
        # convert data types
        types = {k: v["type"] for k, v in col_types.items() if v is not None and "type" in v}
        if types:
            data_ = data_.astype(dtype=types)
    data: List[Dict] = data_.to_dict(orient="records")

    # group columns to generate new column
    if col_groups is not None:
        groups: Dict[str, List] = {}
        for name, cols in col_groups.items():
            groups[name] = df.loc[rows, cols].values
        # invert items of groups to every row of data
        data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)]

    return data
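A small usage sketch of the utility above; the DataFrame contents and column names below are invented for illustration:

import pandas as pd

df = pd.DataFrame(
    {
        "subject_id": ["s1", "s2", "s3"],
        "label": [0, 1, None],
        "ehr_0": [1.5, 2.0, None],
        "ehr_1": [0.4, 1.1, 0.9],
    }
)
records = convert_tables_to_dicts(
    dfs=[df],                 # a single table, so no join is performed
    row_indices=[[0, 2], 2],  # rows 0 and 1 from the range, plus row 2
    col_types={"label": {"type": int, "default": 0}, "ehr_0": {"type": float, "default": 0.0}},
    col_groups={"ehr": ["ehr_0", "ehr_1"]},
)
# e.g. records[0] -> {"subject_id": "s1", "label": 0, "ehr_0": 1.5, "ehr_1": 0.4, "ehr": array([1.5, 0.4])}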
Example No. 30
    def __init__(
        self,
        keys: KeysCollection,
        ref_image: str,
        slice_only: bool = False,
        mode: Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] = InterpolateMode.NEAREST,
        align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
        meta_keys: Optional[str] = None,
        meta_key_postfix: str = "meta_dict",
        start_coord_key: str = "foreground_start_coord",
        end_coord_key: str = "foreground_end_coord",
        original_shape_key: str = "foreground_original_shape",
        cropped_shape_key: str = "foreground_cropped_shape",
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.ref_image = ref_image
        self.slice_only = slice_only
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
        if len(self.keys) != len(self.meta_keys):
            raise ValueError("meta_keys should have the same length as keys.")
        self.meta_key_postfix = meta_key_postfix
        self.start_coord_key = start_coord_key
        self.end_coord_key = end_coord_key
        self.original_shape_key = original_shape_key
        self.cropped_shape_key = cropped_shape_key