Ejemplo n.º 1
0
 def test_get_set_meta_fns(self):
     """Round-trip each global tracking flag through its setter and getter, False then True."""
     flag_accessors = (
         (set_track_meta, get_track_meta),
         (set_track_transforms, get_track_transforms),
     )
     for setter, getter in flag_accessors:
         for value in (False, True):
             setter(value)
             self.assertEqual(getter(), value)
Ejemplo n.º 2
0
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Approximate the contour of a mask image by filtering it with a Laplace kernel.

        Args:
            img: image data with shape [channels, height, width[, depth]]; numpy arrays
                and torch tensors are accepted and the result is converted back to the
                input's type.

        Raises:
            ValueError: When ``image`` ndim is not one of [3, 4], i.e. the data has
                neither 2 nor 3 spatial dimensions.

        Returns:
            Data with the same shape as ``img`` whose values are clamped to [0, 1]:
                1. it is a binary "edge / not edge" classification per pixel;
                2. to keep the original shape of the mask image, padding is used by default;
                3. the detection is only approximate: because of limitations inherent to
                   the Laplace kernel, edges come out with some thickness instead of
                   being ideally thin.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        # plain tensor with a leading batch dimension, as expected by `apply_filter`
        batched: torch.Tensor = convert_to_tensor(img, track_meta=False).unsqueeze(0)
        n_spatial = batched.ndim - 2  # minus the batch and channel dimensions

        if n_spatial == 2:
            kernel = torch.full((3, 3), -1.0, dtype=torch.float32)
            kernel[1, 1] = 8.0
        elif n_spatial == 3:
            kernel = torch.full((3, 3, 3), -1.0, dtype=torch.float32)
            kernel[1, 1, 1] = 26.0
        else:
            raise ValueError(
                f"{self.__class__} can only handle 2D or 3D images.")

        edges = apply_filter(batched, kernel).clamp_(min=0.0, max=1.0)
        return convert_to_dst_type(edges.squeeze(0), img)[0]
Ejemplo n.º 3
0
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Keep only the values listed in ``self.applied_labels``; every other element becomes 0.

        Args:
            img: Pytorch tensor or numpy array of any shape.

        Raises:
            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

        Returns:
            Pytorch tensor or numpy array of the same shape as the input.
        """
        if isinstance(img, np.ndarray):
            return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0))
        if not isinstance(img, torch.Tensor):
            raise NotImplementedError(
                f"{self.__class__} can not handle data of type {type(img)}.")

        img = convert_to_tensor(img, track_meta=get_track_meta())
        plain = convert_to_tensor(img, track_meta=False)
        if not hasattr(torch, "isin"):
            # `isin` is new in torch 1.10.0: fall back to the numpy path via recursion
            filtered: NdarrayOrTensor = self(
                plain.detach().cpu().numpy())  # type: ignore
            return convert_to_dst_type(filtered, img)[0]  # type: ignore

        labels = torch.as_tensor(self.applied_labels, device=plain.device)
        keep = torch.isin(plain, labels)
        filtered = torch.where(keep, plain, torch.tensor(0.0).to(plain))
        return convert_to_dst_type(filtered, dst=img)[0]
Ejemplo n.º 4
0
    def ensure_torch_and_prune_meta(im: NdarrayTensor,
                                    meta: dict,
                                    simple_keys: bool = False):
        """
        Convert an image to `torch.Tensor` and, when metadata is tracked, attach ``meta``.

        Args:
            im: Input image (`np.ndarray` or `torch.Tensor`).
            meta: Metadata dictionary, or ``None``. When ``simple_keys`` is True (and
                metadata is tracked) this dict is modified in place: an ``"affine"``
                entry is converted to `torch.Tensor` and ``remove_extra_metadata``
                is applied to it.
            simple_keys: whether to keep only a simple subset of metadata keys.

        Returns:
            A `MetaTensor` wrapping the converted image and ``meta`` by default.
            A plain `torch.Tensor` if `get_track_meta()` is `False` or ``meta`` is ``None``.
        """
        tensor = convert_to_tensor(im)  # potentially ascontiguousarray

        if get_track_meta() and meta is not None:
            if simple_keys:
                if "affine" in meta:
                    meta["affine"] = convert_to_tensor(meta["affine"])  # bc-breaking
                remove_extra_metadata(meta)  # bc-breaking
            return MetaTensor(tensor, meta=meta)

        # metadata not tracked (or absent): hand back the bare tensor
        return tensor
Ejemplo n.º 5
0
    def __call__(self,
                 img: NdarrayOrTensor,
                 randomize: bool = True,
                 device: Optional[torch.device] = None) -> NdarrayOrTensor:
        """
        Resample ``img`` on ``self.grid`` displaced by the field returned from ``self.sfield()``.

        Args:
            img: image to deform; a leading dimension is added (``img[None]``) before resampling.
            randomize: if True, call ``self.randomize()`` first; otherwise reuse the previous state.
            device: device the image is moved to for resampling; defaults to ``self.device``.

        Returns:
            ``img`` as a tensor when the transform is skipped, otherwise the resampled
            image converted back to the type of ``img``.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        if randomize:
            self.randomize()
        if not self._do_transform:
            return img

        target_device = self.device if device is None else device

        displacement = self.sfield().to(self.grid_dtype)
        # grid_sample expects the coordinate axis last
        sample_grid = moveaxis(self.grid + displacement, 1, -1)  # type: ignore

        batched = convert_to_tensor(img[None], torch.float32, target_device)
        resampled = grid_sample(
            input=batched,
            grid=sample_grid,
            mode=look_up_option(self.grid_mode, GridSampleMode),
            align_corners=self.grid_align_corners,
            padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode),
        )
        return convert_to_dst_type(resampled.squeeze(0), img)[0]
Ejemplo n.º 6
0
    def __call__(self,
                 img: NdarrayOrTensor,
                 randomize: bool = True) -> NdarrayOrTensor:
        """
        Apply the transform to `img`. If `randomize` is True, draw a new smooth field first; otherwise reuse the previous one.

        The intensities are rescaled to roughly [0, 1], raised element-wise to the power
        of the field, then mapped back to the original value range.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        if randomize:
            self.randomize()
        if not self._do_transform:
            return img

        lo = img.min()
        hi = img.max()
        span = hi - lo

        exponent, *_ = convert_to_dst_type(self.sfield(), img)

        # everything below here is computed using the destination type (numpy, tensor, etc.)
        unit = (img - lo) / (span + 1e-10)  # epsilon guards against a constant image
        return unit**exponent * span + lo
Ejemplo n.º 7
0
    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Mapping[Hashable, NdarrayOrTensor]:
        """
        Apply ``self.trans`` to each selected key, cycling through ``self.mode`` per key.

        The random state is drawn once for the whole dictionary. When the draw skips the
        transform, the selected values are only converted to tensors.
        """
        self.randomize()
        out = dict(data)

        if self._do_transform:
            for position, key in enumerate(self.key_iterator(out)):
                self.trans.set_mode(self.mode[position % len(self.mode)])
                # False: reuse the state drawn above instead of re-randomizing per key
                out[key] = self.trans(out[key], False)
        else:
            for key in self.key_iterator(out):
                out[key] = convert_to_tensor(out[key], track_meta=get_track_meta())

        return out
Ejemplo n.º 8
0
    def __call__(
        self,
        img: NdarrayOrTensor,
        sigmoid: Optional[bool] = None,
        softmax: Optional[bool] = None,
        other: Optional[Callable] = None,
    ) -> NdarrayOrTensor:
        """
        Apply activation functions to ``img``.

        Args:
            img: data to activate; it is converted to a float tensor internally and the
                result is converted back to the type of ``img``.
            sigmoid: whether to execute sigmoid function on model output before transform.
                Applied if either this argument or ``self.sigmoid`` is truthy.
            softmax: whether to execute softmax function (over dim 0) on model output
                before transform. Applied if either this argument or ``self.softmax`` is truthy.
            other: callable function to execute other activation layers, for example:
                `other = torch.tanh`. Defaults to ``self.other`` when None.

        Raises:
            ValueError: When ``sigmoid=True`` and ``softmax=True``. Incompatible values.
            TypeError: When ``other`` is not an ``Optional[Callable]``.
        """
        if sigmoid and softmax:
            raise ValueError(
                "Incompatible values: sigmoid=True and softmax=True.")
        if not (other is None or callable(other)):
            raise TypeError(
                f"other must be None or callable but is {type(other).__name__}."
            )

        img = convert_to_tensor(img, track_meta=get_track_meta())
        # activations must operate on a float tensor
        activated, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)

        if sigmoid or self.sigmoid:
            activated = torch.sigmoid(activated)
        if softmax or self.softmax:
            activated = torch.softmax(activated, dim=0)

        extra_fn = other if other is not None else self.other
        if extra_fn is not None:
            activated = extra_fn(activated)

        return convert_to_dst_type(activated, img)[0]
Ejemplo n.º 9
0
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Zero out foreground voxels of the applied labels that fall outside the mask returned
        by ``get_largest_connected_component_mask`` (presumably the largest connected
        component — confirm against that helper).

        Labels are processed one at a time when ``self.independent`` is set, otherwise
        their union is treated as a single foreground.

        Args:
            img: shape must be (C, spatial_dim1[, spatial_dim2, ...]).

        Returns:
            An array with shape (C, spatial_dim1[, spatial_dim2, ...]), converted back to
            the type of ``img``.
        """
        # more than one channel is treated as one-hot unless `self.is_onehot` overrides it
        is_onehot = img.shape[
            0] > 1 if self.is_onehot is None else self.is_onehot
        if self.applied_labels is not None:
            applied_labels = self.applied_labels
        else:
            # default: labels reported by `get_unique_labels`, with 0 discarded
            applied_labels = tuple(get_unique_labels(img, is_onehot,
                                                     discard=0))
        img = convert_to_tensor(img, track_meta=get_track_meta())
        # NOTE(review): the writes below modify `img_` in place; if `convert_to_tensor`
        # returns an alias of the input rather than a copy, the caller's data is modified
        # too — confirm.
        img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
        if self.independent:
            # each label gets its own connected-component mask
            for i in applied_labels:
                # one-hot: channel i is label i; otherwise channel 0 holds label values
                foreground = img_[i] > 0 if is_onehot else img_[0] == i
                mask = get_largest_connected_component_mask(
                    foreground, self.connectivity)
                if is_onehot:
                    img_[i][foreground != mask] = 0
                else:
                    img_[0][foreground != mask] = 0
            return convert_to_dst_type(img_, dst=img)[0]
        if not is_onehot:  # not one-hot, union of labels
            labels, *_ = convert_to_dst_type(applied_labels,
                                             dst=img_,
                                             wrap_sequence=True)
            # broadcast every voxel against every label, then keep channel 0
            foreground = (img_[..., None] == labels).any(-1)[0]
            mask = get_largest_connected_component_mask(
                foreground, self.connectivity)
            img_[0][foreground != mask] = 0
            return convert_to_dst_type(img_, dst=img)[0]
        # one-hot, union of labels: a voxel is foreground if any applied channel equals 1
        foreground = (img_[applied_labels, ...] == 1).any(0)
        mask = get_largest_connected_component_mask(foreground,
                                                    self.connectivity)
        for i in applied_labels:
            img_[i][foreground != mask] = 0
        return convert_to_dst_type(img_, dst=img)[0]
Ejemplo n.º 10
0
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Fill the holes in the provided image using ``fill_holes`` on a numpy copy.

        Note:
            The value 0 is assumed as background label.

        Args:
            img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].

        Returns:
            Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]],
            converted back to the type of ``img``.

        Note:
            No explicit type check is done here. NOTE(review): an unsupported input type
            presumably fails inside the conversion helpers — confirm which exception is raised.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        as_numpy, *_ = convert_data_type(img, np.ndarray)
        filled: np.ndarray = fill_holes(as_numpy, self.applied_labels, self.connectivity)
        return convert_to_dst_type(filled, img)[0]
Ejemplo n.º 11
0
    def __call__(self,
                 img: NdarrayOrTensor,
                 randomize: bool = True) -> NdarrayOrTensor:
        """
        Apply the transform to `img`. If `randomize` is True, draw a new smooth field first; otherwise reuse the previous one.

        The image is multiplied element-wise by the field.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        if randomize:
            self.randomize()
        if not self._do_transform:
            return img

        # the field is converted to the destination type (numpy, tensor, etc.) before multiplying
        scale, *_ = convert_to_dst_type(self.sfield(), img)
        return img * scale
Ejemplo n.º 12
0
    def update_meta(rets: Sequence, func, args, kwargs):
        """Update the metadata from the output of `__torch_function__`.

        The output could be a single object, or a sequence of them. Hence, they get
        converted to a sequence if necessary and then processed by iterating across them.

        For each element: if it is not a `MetaTensor` it is left untouched; if neither
        metadata nor transform tracking is enabled it is converted to `torch.Tensor`;
        otherwise metadata is copied from the `MetaObj` inputs and, for batched data,
        split according to how the batch was indexed (``__getitem__``) or unbound.

        Args:
            rets: the output(s) of the torch function, as a sequence.
            func: the torch function that was applied (e.g. ``torch.Tensor.__getitem__``).
            args: positional arguments that were passed to ``func``.
            kwargs: keyword arguments that were passed to ``func``.

        Returns:
            A tuple if ``rets`` is a tuple, otherwise a list, with one entry per element of ``rets``.
        """
        out = []
        # per-item batch metadata; decollated lazily, only when a batch is actually split
        metas = None
        for idx, ret in enumerate(rets):
            # if not `MetaTensor`, nothing to do.
            if not isinstance(ret, MetaTensor):
                pass
            # if not tracking, convert to `torch.Tensor`.
            elif not (get_track_meta() or get_track_transforms()):
                ret = ret.as_tensor()
            # else, handle the `MetaTensor` metadata.
            else:
                meta_args = MetaObj.flatten_meta_objs(
                    list(args) + list(kwargs.values()))
                ret._copy_meta(meta_args)

                # If we have a batch of data, then we need to be careful if a slice of
                # the data is returned. Depending on how the data are indexed, we return
                # some or all of the metadata, and the return object may or may not be a
                # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
                if ret.is_batch:
                    # if indexing e.g., `batch[0]`
                    if func == torch.Tensor.__getitem__:
                        # separate name so the enumerate index `idx` is not clobbered
                        batch_idx = args[1]
                        if isinstance(batch_idx, Sequence):
                            batch_idx = batch_idx[0]
                        # `batch[:, -1]`, `batch[..., -1]` and `batch[None]` do not select
                        # along the batch dimension: keep the metadata as-is.
                        if batch_idx not in (slice(None, None, None), Ellipsis, None):
                            if metas is None:
                                metas = decollate_batch(ret.meta)
                            meta = metas[batch_idx]
                            # slicing (e.g. `batch[0:2]`, and also `batch[0:1]`) keeps
                            # a batch dimension, so re-collate the selected elements.
                            if isinstance(meta, list) and len(meta) > 0:
                                ret.meta = list_data_collate(meta)
                            # if using e.g., `batch[0]` or `batch[0, 1]`, then return single
                            # element from batch, and set `is_batch` to `False`.
                            else:
                                ret.meta = meta
                                ret.is_batch = False
                    # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                    # But we only want to split the batch if the `unbind` is along the 0th
                    # dimension.
                    elif func == torch.Tensor.unbind:
                        if len(args) > 1:
                            dim = args[1]
                        elif "dim" in kwargs:
                            dim = kwargs["dim"]
                        else:
                            dim = 0
                        if dim == 0:
                            if metas is None:
                                metas = decollate_batch(ret.meta)
                            ret.meta = metas[idx]
                            ret.is_batch = False

                ret.affine = ret.affine.to(ret.device)
            out.append(ret)
        # if the input was a tuple, then return it as a tuple
        return tuple(out) if isinstance(rets, tuple) else out
Ejemplo n.º 13
0
    def __call__(
            self,
            img: NdarrayOrTensor,
            argmax: Optional[bool] = None,
            to_onehot: Optional[int] = None,
            threshold: Optional[float] = None,
            rounding: Optional[str] = None,
            n_classes: Optional[int] = None,  # deprecated
            num_classes: Optional[int] = None,  # deprecated
            logit_thresh: Optional[float] = None,  # deprecated
            threshold_values: Optional[bool] = None,  # deprecated
    ) -> NdarrayOrTensor:
        """
        Discretize ``img`` via optional argmax, one-hot conversion, thresholding and rounding,
        applied in that order. The result is returned as float data of the same type as ``img``.

        Args:
            img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
                will automatically add it.
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].

        Raises:
            AssertionError: When the resolved ``to_onehot`` is not an integer.

        .. deprecated:: 0.6.0
            ``n_classes`` is deprecated, use ``to_onehot`` instead.

        .. deprecated:: 0.7.0
            ``num_classes`` is deprecated, use ``to_onehot`` instead.
            ``logit_thresh`` is deprecated, use ``threshold`` instead.
            ``threshold_values`` is deprecated, use ``threshold`` instead.

        """
        if isinstance(to_onehot, bool):
            warnings.warn(
                "`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead."
            )
            to_onehot = num_classes if to_onehot else None
        if isinstance(threshold, bool):
            warnings.warn(
                "`threshold_values=True/False` is deprecated, please use `threshold=value` instead."
            )
            threshold = logit_thresh if threshold else None
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t, *_ = convert_data_type(img, torch.Tensor)
        if argmax or self.argmax:
            img_t = torch.argmax(img_t, dim=0, keepdim=True)

        to_onehot = self.to_onehot if to_onehot is None else to_onehot
        if to_onehot is not None:
            if not isinstance(to_onehot, int):
                raise AssertionError(
                    "the number of classes for One-Hot must be an integer.")
            img_t = one_hot(img_t, num_classes=to_onehot, dim=0)

        threshold = self.threshold if threshold is None else threshold
        if threshold is not None:
            img_t = img_t >= threshold  # bool tensor

        rounding = self.rounding if rounding is None else rounding
        if rounding is not None:
            look_up_option(rounding, ["torchrounding"])
            if not img_t.is_floating_point():
                # `torch.round` may reject non-floating dtypes such as the bool tensor
                # produced by thresholding. The result is cast to float below anyway, so
                # rounding in float yields identical values.
                img_t = img_t.to(torch.float)
            img_t = torch.round(img_t)

        img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
        return img
Ejemplo n.º 14
0
    def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
        """
        Update the metadata from the output of `MetaTensor.__torch_function__`.

        The output of `torch.Tensor.__torch_function__` could be a single object or a
        sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
        list if not already, and then we loop across each element, processing metadata
        as necessary. For each element, if not of type `MetaTensor`, then nothing to do.

        Args:
            rets: the output from `torch.Tensor.__torch_function__`, which has been
                converted to a list in `MetaTensor.__torch_function__` if it wasn't
                already a `Sequence`.
            func: the torch function that was applied. Examples might be `torch.squeeze`
                or `torch.Tensor.__add__`. We need this since the metadata need to be
                treated differently if a batch of data is considered. For example,
                slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
                dimension of a batch of data should return a ith tensor with the ith
                metadata.
            args: positional arguments that were passed to `func`.
            kwargs: keyword arguments that were passed to `func`.

        Returns:
            A sequence with the same number of elements as `rets`. For each element, if
            the input type was not `MetaTensor`, then no modifications will have been
            made. If global parameters have been set to false (e.g.,
            `not get_track_meta()`), then any `MetaTensor` will be converted to
            `torch.Tensor`. Else, metadata will be propagated as necessary (see
            :py:func:`MetaTensor._copy_meta`). A tuple is returned if `rets` is a tuple,
            otherwise a list.
        """
        out = []
        # per-item metadata for `unbind`, decollated at most once per call
        metas = None
        # the outputs are treated as batched if any input MetaObj carries `is_batch=True`
        is_batch = any(
            x.is_batch
            for x in MetaObj.flatten_meta_objs(args, kwargs.values())
            if hasattr(x, "is_batch"))
        for idx, ret in enumerate(rets):
            # if not `MetaTensor`, nothing to do.
            if not isinstance(ret, MetaTensor):
                pass
            # if not tracking, convert to `torch.Tensor`.
            elif not get_track_meta():
                ret = ret.as_tensor()
            # else, handle the `MetaTensor` metadata.
            else:
                meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())
                ret.is_batch = is_batch
                # for batched inputs, `copy_attr=False` is passed; see `copy_meta_from`
                # for what that skips
                ret.copy_meta_from(meta_args, copy_attr=not is_batch)
                # the following is not implemented but the network arch may run into this case:
                # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
                #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")

                # If we have a batch of data, then we need to be careful if a slice of
                # the data is returned. Depending on how the data are indexed, we return
                # some or all of the metadata, and the return object may or may not be a
                # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
                if is_batch:
                    # if indexing e.g., `batch[0]`
                    if func == torch.Tensor.__getitem__:
                        # args[0] is the batch being indexed, args[1] the index expression;
                        # for a tuple index only its first (batch-dim) entry matters here
                        batch_idx = args[1]
                        if isinstance(batch_idx, Sequence):
                            batch_idx = batch_idx[0]
                        # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                        # first element will be `slice(None, None, None)` and `Ellipsis`,
                        # respectively. Don't need to do anything with the metadata.
                        # `None` (new axis) is skipped as well; only the first output
                        # (idx == 0) gets its metadata replaced.
                        if batch_idx not in (slice(None, None, None), Ellipsis,
                                             None) and idx == 0:
                            ret_meta = decollate_batch(args[0],
                                                       detach=False)[batch_idx]
                            if isinstance(ret_meta,
                                          list):  # e.g. batch[0:2], re-collate
                                ret_meta = list_data_collate(ret_meta)
                            else:  # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
                                ret_meta.is_batch = False
                            # take over every instance attribute of the selected item(s)
                            ret.__dict__ = ret_meta.__dict__.copy()
                    # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                    # But we only want to split the batch if the `unbind` is along the 0th
                    # dimension.
                    elif func == torch.Tensor.unbind:
                        if len(args) > 1:
                            dim = args[1]
                        elif "dim" in kwargs:
                            dim = kwargs["dim"]
                        else:
                            dim = 0
                        if dim == 0:
                            if metas is None:
                                metas = decollate_batch(args[0], detach=False)
                            # output idx is batch item idx: copy that item's attributes
                            ret.__dict__ = metas[idx].__dict__.copy()
                            ret.is_batch = False

            out.append(ret)
        # if the input was a tuple, then return it as a tuple
        return tuple(out) if isinstance(rets, tuple) else out