Пример #1
0
    def __init__(
        self,
        keys: KeysCollection,
        transform: InvertibleTransform,
        orig_keys: KeysCollection,
        meta_keys: Optional[KeysCollection] = None,
        orig_meta_keys: Optional[KeysCollection] = None,
        meta_key_postfix: str = DEFAULT_POST_FIX,
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        allow_missing_keys: bool = False,
    ) -> None:
        """
        Store the configuration needed to invert ``transform`` on the items selected by ``keys``.

        Every per-key option below accepts either a single value, which is repeated for all
        ``keys``, or a sequence that is matched to ``keys`` position by position.

        Args:
            keys: key (or list of keys) of the data in the dict that the inverse of ``transform``
                is applied to, in-place.
            transform: the transform that was applied to ``orig_keys``; its inverse is applied
                to ``keys``. Must be an ``InvertibleTransform``.
            orig_keys: key (or list of keys, one per entry of ``keys``) of the original input data.
                The transform trace information is expected at ``{orig_keys}_transforms``.
            meta_keys: key(s) under which the inverted meta data dictionary is output.
                The meta data is a dictionary optionally containing: filename, original_shape.
                When given, it must contain exactly one entry per key.
                If None, the default key ``{key}_{meta_key_postfix}`` is used.
            orig_meta_keys: key(s) of the meta data of the original input data.
                The meta data is a dictionary optionally containing: filename, original_shape.
                If None, the default key ``{orig_key}_{meta_key_postfix}`` is used.
                This meta data dict is also included in the inverted dict, stored in ``meta_keys``.
            meta_key_postfix: postfix used to build the default meta keys described above:
                ``{orig_key}_{meta_key_postfix}`` when ``orig_meta_keys`` is None and
                ``{key}_{meta_key_postfix}`` when ``meta_keys`` is None.
                Defaults to ``DEFAULT_POST_FIX``.
            nearest_interp: whether to use ``nearest`` interpolation mode when inverting the
                spatial transforms, default to ``True``. If ``False``, the interpolation mode
                of the original transform is used.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first,
                default to ``True``.
            device: if converted to Tensor, the device the inverted results are moved to before
                ``post_func``, default to ``"cpu"``. A string or ``torch.device``.
            post_func: callable used to post-process the inverted data.
            allow_missing_keys: don't raise exception if key is missing.

        Raises:
            ValueError: when ``transform`` is not an ``InvertibleTransform``, or when
                ``meta_keys`` is given with a different length than ``keys``.

        """
        super().__init__(keys, allow_missing_keys)
        if not isinstance(transform, InvertibleTransform):
            raise ValueError("transform is not invertible, can't invert transform for the data.")
        self.transform = transform

        # every per-key setting is broadcast to one entry per key
        num_keys = len(self.keys)
        self.orig_keys = ensure_tuple_rep(orig_keys, num_keys)

        # explicit meta keys are taken as given (not broadcast), hence the length check below
        if meta_keys is None:
            self.meta_keys = ensure_tuple_rep(None, num_keys)
        else:
            self.meta_keys = ensure_tuple(meta_keys)
        if len(self.meta_keys) != num_keys:
            raise ValueError("meta_keys should have the same length as keys.")

        self.orig_meta_keys = ensure_tuple_rep(orig_meta_keys, num_keys)
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, num_keys)
        self.nearest_interp = ensure_tuple_rep(nearest_interp, num_keys)
        self.to_tensor = ensure_tuple_rep(to_tensor, num_keys)
        self.device = ensure_tuple_rep(device, num_keys)
        self.post_func = ensure_tuple_rep(post_func, num_keys)
        self._totensor = ToTensor()
Пример #2
0
 def __init__(self, keys: KeysCollection) -> None:
     """
     Register the keys to process and create the ``ToTensor`` converter used for them.

     Args:
         keys: keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
     """
     super().__init__(keys)
     # a single converter instance is kept on the transform as ``self.converter``
     self.converter = ToTensor()
Пример #3
0
 def __init__(self, keys):
     """
     Register the keys to process and create the ``ToTensor`` converter used for them.

     Args:
         keys (hashable items): keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
     """
     super().__init__(keys)
     # a single converter instance is kept on the transform as ``self.converter``
     self.converter = ToTensor()
Пример #4
0
 def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
     """
     Register the keys to process and create the ``ToTensor`` converter used for them.

     Args:
         keys: keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
         allow_missing_keys: don't raise exception if key is missing.
     """
     super().__init__(keys, allow_missing_keys)
     # a single converter instance is kept on the transform as ``self.converter``
     self.converter = ToTensor()
Пример #5
0
    def __call__(self, batch: Any):
        """
        Pad the array items of a batch to a common size, then collate the batch.

        For every key (list of dicts) or position (list of sequences), the shapes of the
        array items are compared with the first dimension excluded. When they differ, every
        item is padded with ``SpatialPad`` up to the per-dimension maximum. Items that are not
        ``torch.Tensor`` / ``np.ndarray`` are left untouched.

        Args:
            batch: batch of data to pad-collate; a non-empty list whose elements are either
                all dicts or all sequences.

        Returns:
            The result of ``list_data_collate`` applied to the padded batch.
        """
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)

        # Snapshot the keys before looping: the batch elements are handed to
        # `replace_element` / `push_transform` inside the loop, and iterating a live
        # dict view raises RuntimeError if an entry is added to that dict meanwhile.
        keys_or_indices = list(batch[0]) if is_list_of_dicts else range(len(batch[0]))

        # loop over items inside of each element in a batch
        for key_or_idx in keys_or_indices:
            # collect the shape (without the first dimension) of each element,
            # stopping at the first element that is not an array
            shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                    break
                shapes.append(elem[key_or_idx].shape[1:])
            # len > 0 if objects were arrays, else skip as no padding to be done
            if not shapes:
                continue

            # calculate max size of each dimension
            shapes_arr = np.array(shapes)
            max_shape = shapes_arr.max(axis=0)
            # If all same size, skip
            if np.all(shapes_arr.min(axis=0) == max_shape):
                continue

            # Do we need to convert output to Tensor?
            output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

            # Use `SpatialPad` to match sizes; the padding method and mode are taken
            # from `self.method` / `self.mode`.
            padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs)
            transform = Compose([padder, ToTensor()]) if output_to_tensor else padder

            for idx, batch_i in enumerate(batch):
                im = batch_i[key_or_idx]
                # remember the pre-padding size so it can be passed to `push_transform`
                orig_size = im.shape[1:]
                padded = transform(im)
                batch = replace_element(padded, batch, idx, key_or_idx)

                # If we have a dictionary of data, record the applied padding on it
                if is_list_of_dicts:
                    self.push_transform(batch[idx], key_or_idx, orig_size=orig_size)

        # After padding, use default list collator
        return list_data_collate(batch)
Пример #6
0
    def __init__(
        self,
        keys: KeysCollection,
        transform: InvertibleTransform,
        loader: TorchDataLoader,
        orig_keys: Union[str, Sequence[str]],
        meta_key_postfix: str = "meta_dict",
        collate_fn: Optional[Callable] = no_collation,
        postfix: str = "inverted",
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device],
                      Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        num_workers: Optional[int] = 0,
        allow_missing_keys: bool = False,
    ) -> None:
        """
        Store the configuration and build the ``BatchInverseTransform`` used to invert
        ``transform`` on the items selected by ``keys``.

        ``orig_keys``, ``nearest_interp``, ``to_tensor``, ``device`` and ``post_func`` accept
        either a single value, which is repeated for all ``keys``, or a sequence that is
        matched to ``keys`` position by position.

        Args:
            keys: the key of expected data in the dict, invert transforms on it.
                It can also be a list of keys, like: ["pred", "pred_class2"].
            transform: the previous callable transform that was applied on the input data.
            loader: data loader used to run transforms and generate the batch of data.
            orig_keys: the key of the original input data in the dict. The applied transform
                information of this input data is used to invert the expected data with ``keys``.
            meta_key_postfix: use ``{orig_key}_{meta_key_postfix}`` to fetch the meta data from
                the dict, default is ``meta_dict``; the meta data is a dictionary object.
                For example, to handle orig_key ``image``, read/write ``affine`` matrices from
                the metadata ``image_meta_dict`` dictionary's ``affine`` field.
            collate_fn: how to collate data after inverse transformations. The default does no
                collation, so the output is a list of PyTorch Tensor or numpy array without
                batch dim.
            postfix: the inverted result is saved into the dict with key ``{key}_{postfix}``.
            nearest_interp: whether to use ``nearest`` interpolation mode when inverting the
                spatial transforms, default to ``True``. If ``False``, the interpolation mode
                of the original transform is used.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first,
                default to ``True``.
            device: if converted to Tensor, the device the inverted results are moved to before
                ``post_func``, default to ``"cpu"``. A string or ``torch.device``.
            post_func: callable used to post-process the inverted data.
            num_workers: number of workers when running the data loader for inverse transforms,
                default to 0 as only one iteration is run and multi-processing may be even slower.
                Set to ``None`` to use the ``num_workers`` of the input transform data loader.
            allow_missing_keys: don't raise exception if key is missing.

        """
        super().__init__(keys, allow_missing_keys)
        self.transform = transform
        self.inverter = BatchInverseTransform(
            transform=transform, loader=loader, collate_fn=collate_fn, num_workers=num_workers
        )

        # per-key settings are broadcast to one entry per key
        num_keys = len(self.keys)
        self.orig_keys = ensure_tuple_rep(orig_keys, num_keys)

        # the two postfixes are shared by all keys and stored as plain strings
        self.meta_key_postfix = meta_key_postfix
        self.postfix = postfix

        self.nearest_interp = ensure_tuple_rep(nearest_interp, num_keys)
        self.to_tensor = ensure_tuple_rep(to_tensor, num_keys)
        self.device = ensure_tuple_rep(device, num_keys)
        self.post_func = ensure_tuple_rep(post_func, num_keys)
        self._totensor = ToTensor()
Пример #7
0
    def __init__(
        self,
        keys: KeysCollection,
        transform: InvertibleTransform,
        orig_keys: KeysCollection,
        meta_keys: Optional[KeysCollection] = None,
        orig_meta_keys: Optional[KeysCollection] = None,
        meta_key_postfix: str = "meta_dict",
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device],
                      Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        allow_missing_keys: bool = False,
    ) -> None:
        """
        Store the configuration needed to invert ``transform`` on the items selected by ``keys``.

        Every per-key option below accepts either a single value, which is repeated for all
        ``keys``, or a sequence that is matched to ``keys`` position by position.

        Args:
            keys: the key of expected data in the dict, invert transforms on it, in-place.
                It can also be a list of keys, like: ["pred", "pred_class2"].
            transform: the previous callable transform that was applied on the input data.
                Must be an ``InvertibleTransform``.
            orig_keys: the key of the original input data in the dict. The applied transform
                information of this input data is used to invert the expected data with ``keys``.
            meta_keys: explicitly indicate the key for the inverted meta data dictionary.
                The meta data is a dictionary object which contains: filename, original_shape, etc.
                When given, it must contain exactly one entry per key.
                If None, the meta key is constructed as ``{key}_{meta_key_postfix}``.
            orig_meta_keys: the key of the meta data of the original input data, used to get the
                ``affine``, ``data_shape``, etc.
                The meta data is a dictionary object which contains: filename, original_shape, etc.
                If None, the meta key is constructed as ``{orig_key}_{meta_key_postfix}``.
                The meta data is also inverted and stored in ``meta_keys``.
            meta_key_postfix: if ``orig_meta_keys`` is None, use ``{orig_key}_{meta_key_postfix}``
                to fetch the meta data from the dict; if ``meta_keys`` is None, use
                ``{key}_{meta_key_postfix}``. Default is ``meta_dict``; the meta data is a
                dictionary object. For example, to handle orig_key ``image``, read/write
                ``affine`` matrices from the metadata ``image_meta_dict`` dictionary's
                ``affine`` field. The inverted meta dict is stored with key
                ``{key}_{meta_key_postfix}``.
            nearest_interp: whether to use ``nearest`` interpolation mode when inverting the
                spatial transforms, default to ``True``. If ``False``, the interpolation mode
                of the original transform is used.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first,
                default to ``True``.
            device: if converted to Tensor, the device the inverted results are moved to before
                ``post_func``, default to ``"cpu"``. A string or ``torch.device``.
            post_func: callable used to post-process the inverted data.
            allow_missing_keys: don't raise exception if key is missing.

        Raises:
            ValueError: when ``transform`` is not an ``InvertibleTransform``, or when
                ``meta_keys`` is given with a different length than ``keys``.

        """
        super().__init__(keys, allow_missing_keys)
        if not isinstance(transform, InvertibleTransform):
            raise ValueError("transform is not invertible, can't invert transform for the data.")
        self.transform = transform

        # every per-key setting is broadcast to one entry per key
        num_keys = len(self.keys)
        self.orig_keys = ensure_tuple_rep(orig_keys, num_keys)

        # explicit meta keys are taken as given (not broadcast), hence the length check below
        if meta_keys is None:
            self.meta_keys = ensure_tuple_rep(None, num_keys)
        else:
            self.meta_keys = ensure_tuple(meta_keys)
        if len(self.meta_keys) != num_keys:
            raise ValueError("meta_keys should have the same length as keys.")

        self.orig_meta_keys = ensure_tuple_rep(orig_meta_keys, num_keys)
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, num_keys)
        self.nearest_interp = ensure_tuple_rep(nearest_interp, num_keys)
        self.to_tensor = ensure_tuple_rep(to_tensor, num_keys)
        self.device = ensure_tuple_rep(device, num_keys)
        self.post_func = ensure_tuple_rep(post_func, num_keys)
        self._totensor = ToTensor()