def test_get_set_meta_fns(self):
    """Round-trip the global meta/transform tracking flags through their setters and getters."""
    for flag in (False, True):
        set_track_meta(flag)
        self.assertEqual(get_track_meta(), flag)
    for flag in (False, True):
        set_track_transforms(flag)
        self.assertEqual(get_track_transforms(), flag)
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Extract an approximate contour (edge map) of a label image with a Laplace kernel.

    Args:
        img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]]

    Raises:
        ValueError: When ``image`` ndim is not one of [3, 4].

    Returns:
        A torch tensor with the same shape as img, note:
            1. it's the binary classification result of whether a pixel is edge or not.
            2. in order to keep the original shape of mask image, we use padding as default.
            3. the edge detection is just approximate because it defects inherent to Laplace kernel,
               ideally the edge should be thin enough, but now it has a thickness.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    work: torch.Tensor = convert_to_tensor(img, track_meta=False)
    n_spatial = work.ndim - 1  # channel-first layout: remaining dims are spatial
    batched = work.unsqueeze(0)  # adds a batch dim for the filtering op
    if n_spatial == 2:
        laplace = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)
    elif n_spatial == 3:
        # 3D Laplace: all -1 with the centre weight balancing the 26 neighbours
        laplace = torch.full((3, 3, 3), -1.0, dtype=torch.float32)
        laplace[1, 1, 1] = 26.0
    else:
        raise ValueError(f"{self.__class__} can only handle 2D or 3D images.")
    edges = apply_filter(batched, laplace)
    edges.clamp_(min=0.0, max=1.0)
    result, *_ = convert_to_dst_type(edges.squeeze(0), img)
    return result
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Filter the image on the `applied_labels`, zeroing out every other value.

    Args:
        img: Pytorch tensor or numpy array of any shape.

    Raises:
        NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

    Returns:
        Pytorch tensor or numpy array of the same shape as the input.
    """
    if not isinstance(img, (np.ndarray, torch.Tensor)):
        raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.")
    if not isinstance(img, torch.Tensor):
        # numpy path: keep values that are in `applied_labels`, zero the rest
        return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0))
    img = convert_to_tensor(img, track_meta=get_track_meta())
    raw = convert_to_tensor(img, track_meta=False)
    if hasattr(torch, "isin"):  # `isin` is new in torch 1.10.0
        wanted = torch.as_tensor(self.applied_labels, device=raw.device)
        kept = torch.where(torch.isin(raw, wanted), raw, torch.tensor(0.0).to(raw))
        return convert_to_dst_type(kept, dst=img)[0]
    # older torch: reuse the numpy branch, then convert back to the input type
    kept_np: NdarrayOrTensor = self(raw.detach().cpu().numpy())  # type: ignore
    return convert_to_dst_type(kept_np, img)[0]  # type: ignore
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool = False):
    """
    Convert the image to `torch.Tensor` and, if metadata is tracked, wrap it as `MetaTensor`.
    If `affine` is in the `meta` dictionary, convert that to `torch.Tensor`, too.
    Remove any superfluous metadata.

    Args:
        im: Input image (`np.ndarray` or `torch.Tensor`)
        meta: Metadata dictionary.
        simple_keys: whether to keep only a simple subset of metadata keys.

    Returns:
        By default, a `MetaTensor` is returned.
        However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
    """
    tensor = convert_to_tensor(im)  # potentially ascontiguousarray

    # no metadata tracking requested (or none supplied): plain `torch.Tensor`
    if not get_track_meta() or meta is None:
        return tensor

    if simple_keys:
        # ensure affine is of type `torch.Tensor` (bc-breaking)
        if "affine" in meta:
            meta["affine"] = convert_to_tensor(meta["affine"])
        remove_extra_metadata(meta)  # bc-breaking

    return MetaTensor(tensor, meta=meta)
def __call__(self, img: NdarrayOrTensor, randomize: bool = True, device: Optional[torch.device] = None) -> NdarrayOrTensor:
    """
    Deform `img` by resampling it through a grid offset by the random smooth field.
    If `randomize` is False, the previously sampled field is reused.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    if randomize:
        self.randomize()

    if not self._do_transform:
        return img

    if device is None:
        device = self.device

    offsets = self.sfield()
    warp = self.grid + offsets.to(self.grid_dtype)
    warp = moveaxis(warp, 1, -1)  # type: ignore  # channels-last as grid_sample expects

    src = convert_to_tensor(img[None], torch.float32, device)
    resampled = grid_sample(
        input=src,
        grid=warp,
        mode=look_up_option(self.grid_mode, GridSampleMode),
        align_corners=self.grid_align_corners,
        padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode),
    )

    result, *_ = convert_to_dst_type(resampled.squeeze(0), img)
    return result
def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
    """
    Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise
    reusing the previous.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    if randomize:
        self.randomize()

    if not self._do_transform:
        return img

    lo = img.min()
    hi = img.max()
    span = hi - lo

    rfield, *_ = convert_to_dst_type(self.sfield(), img)

    # everything below here is to be computed using the destination type (numpy, tensor, etc.)
    unit = (img - lo) / (span + 1e-10)  # rescale to unit values
    adjusted = unit**rfield  # contrast is changed by raising image data to a power, in this case the field
    return (adjusted * span) + lo  # rescale back to the original image value range
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
    """
    Apply the wrapped transform to every key, cycling through `self.mode` per key.
    When the random draw says "no transform", values are still converted to tensors.
    """
    self.randomize()
    d = dict(data)

    if not self._do_transform:
        # not transforming: ensure tensor conversion (with meta tracking) still happens
        for key in self.key_iterator(d):
            d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
        return d

    n_modes = len(self.mode)
    for idx, key in enumerate(self.key_iterator(d)):
        self.trans.set_mode(self.mode[idx % n_modes])
        d[key] = self.trans(d[key], False)  # False: reuse the field already randomized above

    return d
def __call__(
    self,
    img: NdarrayOrTensor,
    sigmoid: Optional[bool] = None,
    softmax: Optional[bool] = None,
    other: Optional[Callable] = None,
) -> NdarrayOrTensor:
    """
    Args:
        sigmoid: whether to execute sigmoid function on model output before transform.
            Defaults to ``self.sigmoid``.
        softmax: whether to execute softmax function on model output before transform.
            Defaults to ``self.softmax``.
        other: callable function to execute other activation layers, for example:
            `other = torch.tanh`. Defaults to ``self.other``.

    Raises:
        ValueError: When ``sigmoid=True`` and ``softmax=True``. Incompatible values.
        TypeError: When ``other`` is not an ``Optional[Callable]``.
        ValueError: When ``self.other=None`` and ``other=None``. Incompatible values.
    """
    if sigmoid and softmax:
        raise ValueError("Incompatible values: sigmoid=True and softmax=True.")
    if other is not None and not callable(other):
        raise TypeError(f"other must be None or callable but is {type(other).__name__}.")

    # convert to float as activation must operate on float tensor
    img = convert_to_tensor(img, track_meta=get_track_meta())
    img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)

    if sigmoid or self.sigmoid:
        img_t = torch.sigmoid(img_t)
    if softmax or self.softmax:
        img_t = torch.softmax(img_t, dim=0)

    # the per-call `other` overrides the instance-level callable
    activation = other if other is not None else self.other
    if activation is not None:
        img_t = activation(img_t)

    result, *_ = convert_to_dst_type(img_t, img)
    return result
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Keep only the largest connected component(s) of the selected labels.

    Args:
        img: shape must be (C, spatial_dim1[, spatial_dim2, ...]).

    Returns:
        An array with shape (C, spatial_dim1[, spatial_dim2, ...]).
    """
    # one-hot is inferred from a multi-channel input unless explicitly configured
    is_onehot = img.shape[0] > 1 if self.is_onehot is None else self.is_onehot
    if self.applied_labels is not None:
        applied_labels = self.applied_labels
    else:
        # default: all unique labels in the image, ignoring background (0)
        applied_labels = tuple(get_unique_labels(img, is_onehot, discard=0))
    img = convert_to_tensor(img, track_meta=get_track_meta())
    # `img_` is a plain tensor working copy; edits are copied back to `img`'s type at return
    img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
    if self.independent:
        # treat each label separately: zero out everything outside its own largest component
        for i in applied_labels:
            foreground = img_[i] > 0 if is_onehot else img_[0] == i
            mask = get_largest_connected_component_mask(foreground, self.connectivity)
            if is_onehot:
                img_[i][foreground != mask] = 0
            else:
                img_[0][foreground != mask] = 0
        return convert_to_dst_type(img_, dst=img)[0]
    if not is_onehot:  # not one-hot, union of labels
        labels, *_ = convert_to_dst_type(applied_labels, dst=img_, wrap_sequence=True)
        # foreground: pixels matching any applied label (broadcast compare then any over labels)
        foreground = (img_[..., None] == labels).any(-1)[0]
        mask = get_largest_connected_component_mask(foreground, self.connectivity)
        img_[0][foreground != mask] = 0
        return convert_to_dst_type(img_, dst=img)[0]
    # one-hot, union of labels: largest component of the merged foreground,
    # applied to every selected channel
    foreground = (img_[applied_labels, ...] == 1).any(0)
    mask = get_largest_connected_component_mask(foreground, self.connectivity)
    for i in applied_labels:
        img_[i][foreground != mask] = 0
    return convert_to_dst_type(img_, dst=img)[0]
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """
    Fill the holes in the provided image.

    Note:
        The value 0 is assumed as background label.

    Args:
        img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].

    Raises:
        NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

    Returns:
        Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    # the filling itself runs on a numpy view of the data
    as_np, *_ = convert_data_type(img, np.ndarray)
    filled: np.ndarray = fill_holes(as_np, self.applied_labels, self.connectivity)
    result, *_ = convert_to_dst_type(filled, img)
    return result
def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
    """
    Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise
    reusing the previous.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    if randomize:
        self.randomize()

    if not self._do_transform:
        return img

    rfield, *_ = convert_to_dst_type(self.sfield(), img)

    # everything below here is to be computed using the destination type (numpy, tensor, etc.)
    return img * rfield
def update_meta(rets: Sequence, func, args, kwargs):
    """Update the metadata from the output of `__torch_function__`.

    The output could be a single object, or a sequence of them. Hence, they get converted to a
    sequence if necessary and then processed by iterating across them.

    For each element, if not of type `MetaTensor`, then nothing to do

    Args:
        rets: sequence of outputs from the torch function call.
        func: the torch function that produced `rets` (e.g. `torch.Tensor.__getitem__`).
        args: positional arguments passed to `func`.
        kwargs: keyword arguments passed to `func`.
    """
    out = []
    metas = None  # lazily decollated per-element metadata, shared across the loop
    for idx, ret in enumerate(rets):
        # if not `MetaTensor`, nothing to do.
        if not isinstance(ret, MetaTensor):
            pass
        # if not tracking, convert to `torch.Tensor`.
        elif not (get_track_meta() or get_track_transforms()):
            ret = ret.as_tensor()
        # else, handle the `MetaTensor` metadata.
        else:
            meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
            ret._copy_meta(meta_args)
            # If we have a batch of data, then we need to be careful if a slice of
            # the data is returned. Depending on how the data are indexed, we return
            # some or all of the metadata, and the return object may or may not be a
            # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
            if ret.is_batch:
                # only decollate metadata once
                if metas is None:
                    metas = decollate_batch(ret.meta)
                # if indexing e.g., `batch[0]`
                if func == torch.Tensor.__getitem__:
                    # NOTE(review): this rebinds the loop variable `idx`; harmless only if
                    # __getitem__ returns a single element — confirm against callers.
                    idx = args[1]
                    if isinstance(idx, Sequence):
                        idx = idx[0]
                    # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                    # first element will be `slice(None, None, None)` and `Ellipsis`,
                    # respectively. Don't need to do anything with the metadata.
                    if idx not in (slice(None, None, None), Ellipsis):
                        meta = metas[idx]
                        # if using e.g., `batch[0:2]`, then `is_batch` should still be
                        # `True`. Also re-collate the remaining elements.
                        if isinstance(meta, list) and len(meta) > 1:
                            ret.meta = list_data_collate(meta)
                        # if using e.g., `batch[0]` or `batch[0, 1]`, then return single
                        # element from batch, and set `is_batch` to `False`.
                        else:
                            ret.meta = meta
                            ret.is_batch = False
                # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                # But we only want to split the batch if the `unbind` is along the 0th
                # dimension.
                elif func == torch.Tensor.unbind:
                    if len(args) > 1:
                        dim = args[1]
                    elif "dim" in kwargs:
                        dim = kwargs["dim"]
                    else:
                        dim = 0
                    if dim == 0:
                        ret.meta = metas[idx]
                        ret.is_batch = False
            # keep the affine on the same device as the data
            ret.affine = ret.affine.to(ret.device)
        out.append(ret)
    # if the input was a tuple, then return it as a tuple
    return tuple(out) if isinstance(rets, tuple) else out
def __call__(
    self,
    img: NdarrayOrTensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[int] = None,
    threshold: Optional[float] = None,
    rounding: Optional[str] = None,
    n_classes: Optional[int] = None,  # deprecated
    num_classes: Optional[int] = None,  # deprecated
    logit_thresh: Optional[float] = None,  # deprecated
    threshold_values: Optional[bool] = None,  # deprecated
) -> NdarrayOrTensor:
    """
    Args:
        img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
            will automatically add it.
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``self.argmax``.
        to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
            Defaults to ``self.to_onehot``.
        threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
            Defaults to ``self.threshold``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"].

    .. deprecated:: 0.6.0
        ``n_classes`` is deprecated, use ``to_onehot`` instead.

    .. deprecated:: 0.7.0
        ``num_classes`` is deprecated, use ``to_onehot`` instead.
        ``logit_thresh`` is deprecated, use ``threshold`` instead.
        ``threshold_values`` is deprecated, use ``threshold`` instead.
    """
    # backward compatibility: a bool `to_onehot` is the old on/off switch,
    # remapped onto the deprecated `num_classes` argument
    if isinstance(to_onehot, bool):
        warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
        to_onehot = num_classes if to_onehot else None
    # backward compatibility: a bool `threshold` is the old on/off switch,
    # remapped onto the deprecated `logit_thresh` argument
    if isinstance(threshold, bool):
        warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
        threshold = logit_thresh if threshold else None
    img = convert_to_tensor(img, track_meta=get_track_meta())
    img_t, *_ = convert_data_type(img, torch.Tensor)
    if argmax or self.argmax:
        img_t = torch.argmax(img_t, dim=0, keepdim=True)

    to_onehot = self.to_onehot if to_onehot is None else to_onehot
    if to_onehot is not None:
        if not isinstance(to_onehot, int):
            raise AssertionError("the number of classes for One-Hot must be an integer.")
        img_t = one_hot(img_t, num_classes=to_onehot, dim=0)

    threshold = self.threshold if threshold is None else threshold
    if threshold is not None:
        # produces a boolean tensor; converted to float at the end
        img_t = img_t >= threshold

    rounding = self.rounding if rounding is None else rounding
    if rounding is not None:
        look_up_option(rounding, ["torchrounding"])  # validates the option name
        img_t = torch.round(img_t)

    img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
    return img
def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
    """
    Update the metadata from the output of `MetaTensor.__torch_function__`.

    The output of `torch.Tensor.__torch_function__` could be a single object or a
    sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
    list of not already, and then we loop across each element, processing metadata as
    necessary. For each element, if not of type `MetaTensor`, then nothing to do.

    Args:
        rets: the output from `torch.Tensor.__torch_function__`, which has been
            converted to a list in `MetaTensor.__torch_function__` if it wasn't
            already a `Sequence`.
        func: the torch function that was applied. Examples might be `torch.squeeze`
            or `torch.Tensor.__add__`. We need this since the metadata need to be
            treated differently if a batch of data is considered. For example,
            slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
            dimension of a batch of data should return a ith tensor with the ith
            metadata.
        args: positional arguments that were passed to `func`.
        kwargs: keyword arguments that were passed to `func`.

    Returns:
        A sequence with the same number of elements as `rets`. For each element, if
        the input type was not `MetaTensor`, then no modifications will have been
        made. If global parameters have been set to false (e.g., `not
        get_track_meta()`), then any `MetaTensor` will be converted to
        `torch.Tensor`. Else, metadata will be propagated as necessary (see
        :py:func:`MetaTensor._copy_meta`).
    """
    out = []
    metas = None  # lazily decollated metadata for the `unbind` branch
    # a batch result is one where any input MetaObj was itself a batch
    is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
    for idx, ret in enumerate(rets):
        # if not `MetaTensor`, nothing to do.
        if not isinstance(ret, MetaTensor):
            pass
        # if not tracking, convert to `torch.Tensor`.
        elif not get_track_meta():
            ret = ret.as_tensor()
        # else, handle the `MetaTensor` metadata.
        else:
            meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())
            ret.is_batch = is_batch
            # copy per-instance attributes only for non-batch results
            ret.copy_meta_from(meta_args, copy_attr=not is_batch)
            # the following is not implemented but the network arch may run into this case:
            # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
            #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")

            # If we have a batch of data, then we need to be careful if a slice of
            # the data is returned. Depending on how the data are indexed, we return
            # some or all of the metadata, and the return object may or may not be a
            # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
            if is_batch:
                # if indexing e.g., `batch[0]`
                if func == torch.Tensor.__getitem__:
                    batch_idx = args[1]
                    if isinstance(batch_idx, Sequence):
                        batch_idx = batch_idx[0]
                    # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                    # first element will be `slice(None, None, None)` and `Ellipsis`,
                    # respectively. Don't need to do anything with the metadata.
                    if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
                        ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
                        if isinstance(ret_meta, list):  # e.g. batch[0:2], re-collate
                            ret_meta = list_data_collate(ret_meta)
                        else:  # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
                            ret_meta.is_batch = False
                        # adopt the selected element's metadata wholesale
                        ret.__dict__ = ret_meta.__dict__.copy()
                # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                # But we only want to split the batch if the `unbind` is along the 0th
                # dimension.
                elif func == torch.Tensor.unbind:
                    if len(args) > 1:
                        dim = args[1]
                    elif "dim" in kwargs:
                        dim = kwargs["dim"]
                    else:
                        dim = 0
                    if dim == 0:
                        # decollate the input batch once, then hand the idx-th
                        # element's metadata to the idx-th output tensor
                        if metas is None:
                            metas = decollate_batch(args[0], detach=False)
                        ret.__dict__ = metas[idx].__dict__.copy()
                        ret.is_batch = False
        out.append(ret)
    # if the input was a tuple, then return it as a tuple
    return tuple(out) if isinstance(rets, tuple) else out