Example #1
0
def norm_mean_std(data: torch.Tensor,
                  mean: Union[float, Sequence],
                  std: Union[float, Sequence],
                  per_channel: bool = True,
                  out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Normalize data by subtracting ``mean`` and dividing by ``std``

    Args:
        data: input data. The per channel option treats the first dimension
            as the channel dimension (e.g. [C,H,W] or [C,H,W,D]).
        mean: value(s) subtracted from the data. With :attr:`per_channel`
            a scalar is reused for every channel and a sequence is indexed
            by the channel index.
        std: value(s) the shifted data is divided by. Same conventions as
            :attr:`mean`.
        per_channel: normalize every channel separately
        out: if provided, result is saved into out

    Returns:
        torch.Tensor: normalized data (this is ``out`` if it was provided)
    """
    if not per_channel:
        result = (data - mean) / std
        if out is None:
            return result
        # Write into the given buffer instead of rebinding the local name,
        # otherwise the caller's tensor would never receive the result.
        out[...] = result
        return out

    if out is None:
        out = torch.zeros_like(data)

    num_channels = data.shape[0]
    if check_scalar(mean):
        mean = [mean] * num_channels
    if check_scalar(std):
        std = [std] * num_channels
    for _c in range(num_channels):
        out[_c] = (data[_c] - mean[_c]) / std[_c]

    return out
Example #2
0
def random_crop(
    data: torch.Tensor, size: Union[int, Sequence[int]], dist: Union[int, Sequence[int]] = 0
) -> torch.Tensor:
    """
    Cut a randomly positioned patch/volume out of the input tensor

    Args:
        data: input tensor. All dimensions after the first two are cropped.
        size: extent of the patch/volume. A scalar is reused for every
            cropped dimension.
        dist: margin per cropped dimension (scalar is reused for every
            dimension). It is subtracted from the upper bound of the range
            the corner is sampled from. By default zero.

    Returns:
        torch.Tensor: cropped output

    Raises:
        TypeError: if ``size + dist`` reaches or exceeds the extent of the
            input in any cropped dimension
    """
    num_cropped_dims = data.ndim - 2

    # bring dist and size into the form "one python int per cropped dimension"
    if check_scalar(dist):
        dist = [dist] * num_cropped_dims
    if isinstance(dist[0], torch.Tensor):
        dist = [int(i) for i in dist]
    if check_scalar(size):
        size = [size] * num_cropped_dims
    if not isinstance(size[0], int):
        size = [int(s) for s in size]

    extents = data.shape[2:]
    for img_dim, crop_dim, dist_dim in zip(extents, size, dist):
        if crop_dim + dist_dim >= img_dim:
            raise TypeError(f"Crop can not be realized with given size {size} and dist {dist}.")

    # one random draw per cropped dimension, in dimension order
    corner = []
    for img_dim, crop_dim, dist_dim in zip(extents, size, dist):
        upper_bound = img_dim - crop_dim - dist_dim
        corner.append(torch.randint(0, upper_bound, (1,)).item())
    return crop(data, corner, size)
Example #3
0
    def __init__(self,
                 in_channels: int,
                 kernel_size: Union[int, Sequence],
                 dim: int = 2,
                 stride: Union[int, Sequence] = 1,
                 padding: Union[int, Sequence] = 0,
                 padding_mode: str = 'zero',
                 keys: Sequence = ('data', ),
                 grad: bool = False,
                 **kwargs):
        """
        Store the convolution settings and build kernel and convolution

        Args:
            in_channels: number of input channels; also used as the number
                of groups
            kernel_size: size of kernel. A scalar is repeated ``dim`` times.
            dim: number of spatial dimensions
            stride: stride of convolution. A scalar is repeated ``dim``
                times.
            padding: padding size for input. A scalar is repeated
                ``2 * dim`` times.
            padding_mode: padding mode for input. Only stored here;
                presumably a mode of :func:`torch.nn.functional.pad`
                (TODO confirm against the code which applies the padding)
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            kwargs: keyword arguments passed to superclass
        """
        super().__init__(grad=grad, **kwargs)
        self.in_channels = in_channels
        self.keys = keys
        self.padding_mode = padding_mode

        # expand scalar settings to one entry per spatial dimension
        # (padding gets two entries per spatial dimension)
        self.kernel_size = ([kernel_size] * dim
                            if check_scalar(kernel_size) else kernel_size)
        self.stride = [stride] * dim if check_scalar(stride) else stride
        self.padding = ([padding] * dim * 2
                        if check_scalar(padding) else padding)

        # the kernel is created after all settings are stored on the
        # instance; presumably create_kernel reads them (TODO confirm)
        self.register_buffer('weight', self.create_kernel())
        self.groups = in_channels
        self.conv = self.get_conv(dim)
Example #4
0
 def __init__(self,
              augment_fn: augment_axis_callable,
              dims: Sequence,
              keys: Sequence = ('data', ),
              prob: Union[float, Sequence] = 0.5,
              grad: bool = False,
              **kwargs):
     """
     Store the augmentation function together with its per-axis settings

     Args:
         augment_fn: function for augmentation
         dims: possible axes
         keys: keys which should be augmented
         prob: probability per entry of :attr:`dims`. A scalar value is
             repeated once for every entry of :attr:`dims`.
         grad: enable gradient computation inside transformation
         kwargs: additional keyword arguments; stored in ``self.kwargs``
             (intended for augment_fn)
     """
     super().__init__(grad=grad)
     self.kwargs = kwargs
     self.keys = keys
     self.dims = dims
     self.augment_fn = augment_fn
     # a scalar probability is shared by all axes
     self.prob = (prob, ) * len(dims) if check_scalar(prob) else prob
Example #5
0
    def __init__(self,
                 *transforms: Union[AbstractTransform,
                                    Sequence[AbstractTransform]],
                 dropout: Union[float, Sequence[float]] = 0.5,
                 shuffle: bool = False,
                 random_sampler: ContinuousParameter = None,
                 transform_call: Callable[[Any, Callable], Any] = dict_call,
                 **kwargs):
        """
        Set up a composition with one dropout probability per transform

        Args:
            *transforms: one or multiple transformations, forwarded to the
                superclass
            dropout: dropout probability. A scalar is repeated once per
                transform; a sequence needs to contain one probability for
                each transform.
            shuffle: forwarded to the superclass (apply transforms in
                random order)
            random_sampler: a continuous parameter sampler, registered
                under the name ``prob`` with one value per transform.
                If None, ``UniformParameter(0., 1.)`` is used.
            transform_call: function which determines how transforms are
                called; forwarded to the superclass
            **kwargs: additional keyword arguments passed to superclass

        Raises:
            TypeError: if the number of dropout probabilities differs from
                the number of transforms
        """
        super().__init__(*transforms,
                         transform_call=transform_call,
                         shuffle=shuffle,
                         **kwargs)

        sampler = (UniformParameter(0., 1.)
                   if random_sampler is None else random_sampler)
        self.register_sampler('prob', sampler,
                              size=(len(self.transforms), ))

        # a scalar probability is shared by all transforms
        dropout = ([dropout] * len(self.transforms)
                   if check_scalar(dropout) else dropout)
        self.dropout = dropout
        if len(dropout) == len(self.transforms):
            return
        raise TypeError(f"If dropout is a sequence it must specify the "
                        f"dropout probability for each transform, "
                        f"found {len(dropout)} probabilities "
                        f"and {len(self.transforms)} transforms.")
Example #6
0
def mirror(data: torch.Tensor, dims: Union[int,
                                           Sequence[int]]) -> torch.Tensor:
    """
    Flip the data along the given dimensions

    Args:
        data: input data
        dims: dimensions to mirror. Every entry is shifted by two before
            flipping, i.e. ``0`` addresses the third dimension of
            :attr:`data`. Because of that shift, entries are expected to
            be non-negative.

    Returns:
        torch.Tensor: tensor with mirrored dimensions
    """
    requested = (dims, ) if check_scalar(dims) else dims
    # skip the two leading dimensions (batch and channel)
    leading_dims = 2
    flip_axes = [leading_dims + axis for axis in requested]
    return data.flip(flip_axes)
Example #7
0
def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]]) -> torch.Tensor:
    """
    Cut a patch out of the center of the input

    Args:
        data: input tensor. All dimensions after the first two are cropped.
        size: size of patch. A scalar is reused for every cropped dimension.

    Returns:
        torch.Tensor: output tensor cropped from input tensor
    """
    if check_scalar(size):
        size = [size] * (data.ndim - 2)
    if not isinstance(size[0], int):
        size = [int(s) for s in size]

    # corner = half of the space left over after removing the patch,
    # rounded to the nearest integer
    corner = []
    for img_dim, crop_dim in zip(data.shape[2:], size):
        corner.append(int(round((img_dim - crop_dim) / 2.0)))
    return crop(data, corner, size)
Example #8
0
def resize_native(
    data: torch.Tensor,
    size: Optional[Union[int, Sequence[int]]] = None,
    scale_factor: Optional[Union[float, Sequence[float]]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    preserve_range: bool = False,
):
    """
    Resize the input with :func:`torch.nn.functional.interpolate` to either
    the given :attr:`size` or by the given :attr:`scale_factor`

    Args:
        data: input tensor of shape batch x channels x height x width x [depth]
        size: spatial output size (excluding batch size and number of channels)
        scale_factor: multiplier for spatial size. Scalar values are
            converted to a python float before interpolation.
        mode: interpolation mode, e.g. ``nearest``, ``linear``,
            ``bilinear``, ``bicubic``, ``trilinear``, ``area``
            (for more information see :func:`torch.nn.functional.interpolate`)
        align_corners: forwarded to
            :func:`torch.nn.functional.interpolate`
        preserve_range: clamp the output in-place to the minimum and
            maximum value of the input tensor

    Returns:
        torch.Tensor: interpolated tensor

    See Also:
        :func:`torch.nn.functional.interpolate`
    """
    if check_scalar(scale_factor):
        # pytorch internally checks for an iterable and single value tensors
        # are still iterable, so hand over a plain float instead
        scale_factor = float(scale_factor)

    interpolated = torch.nn.functional.interpolate(
        data,
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
    )
    if not preserve_range:
        return interpolated

    lower, upper = data.min(), data.max()
    interpolated.clamp_(lower, upper)
    return interpolated
Example #9
0
    def __init__(self,
                 *transforms,
                 dropout: Union[float, Sequence[float]] = 0.5,
                 random_mode: str = "random",
                 random_args: Sequence = (),
                 random_module: str = "random",
                 transform_call: Callable[[Any, Callable], Any] = dict_call,
                 **kwargs):
        """
        Set up a composition with one dropout probability per transform

        Args:
            *transforms: one or multiple transformations, forwarded to the
                superclass
            dropout: dropout probability. A scalar is repeated once per
                transform; a sequence needs to contain one probability for
                each transform.
            random_mode: forwarded to the superclass; presumably the name
                of the random function inside :attr:`random_module`
            random_args: positional arguments passed for random function
            random_module: module from where the random function should
                be imported
            transform_call: function which determines how transforms are
                called; forwarded to the superclass
            **kwargs: additional keyword arguments passed to superclass

        Raises:
            ValueError: if the number of dropout probabilities differs
                from the number of transforms
        """
        # rand_seq is always disabled for this composition
        super().__init__(*transforms,
                         random_mode=random_mode,
                         random_args=random_args,
                         random_module=random_module,
                         rand_seq=False,
                         transform_call=transform_call,
                         **kwargs)

        # a scalar probability is shared by all transforms
        dropout = ([dropout] * len(self.transforms)
                   if check_scalar(dropout) else dropout)
        if len(dropout) != len(self.transforms):
            raise ValueError(
                f"If dropout is a sequence it must specify the dropout probability "
                f"for each transform, found {len(dropout)} probabilities "
                f"and {len(self.transforms)} transforms.")
        self.dropout = dropout
Example #10
0
    def __init__(self,
                 in_channels: int,
                 kernel_size: Union[int, Sequence],
                 std: Union[int, Sequence],
                 dim: int = 2,
                 stride: Union[int, Sequence] = 1,
                 padding: Union[int, Sequence] = 0,
                 padding_mode: str = 'reflect',
                 keys: Sequence = ('data', ),
                 grad: bool = False,
                 **kwargs):
        """
        Store the standard deviation and delegate the rest to the superclass

        Args:
            in_channels: number of input channels
            kernel_size: size of kernel
            std: standard deviation of gaussian. A scalar is repeated
                ``dim`` times.
            dim: number of spatial dimensions
            stride: stride of convolution
            padding: padding size for input
            padding_mode: padding mode for input; forwarded to the
                superclass
            keys: keys which should be augmented
            grad: enable gradient computation inside transformation
            **kwargs: keyword arguments passed to superclass
        """
        # std is stored before the superclass constructor runs; presumably
        # the superclass needs it while building the kernel (TODO confirm)
        self.std = [std] * dim if check_scalar(std) else std

        forwarded = dict(in_channels=in_channels,
                         kernel_size=kernel_size,
                         dim=dim,
                         stride=stride,
                         padding=padding,
                         padding_mode=padding_mode,
                         keys=keys,
                         grad=grad)
        super().__init__(**forwarded, **kwargs)
Example #11
0
    def forward(self, **data) -> dict:
        """
        Apply transformation

        A single gamma value is determined per call and used for all keys:

        * scalar ``self.gamma``: the value is used directly
        * range ``(min, max)`` which lies completely on one side of one
          (``max < 1`` or ``min > 1``): sampled uniformly from the range
        * otherwise: with a probability of 50% each, gamma is sampled
          uniformly from ``[min, 1]`` or from ``[1, max]``

        Args:
            data: dict with tensors

        Returns:
            dict: augmented data
        """
        if check_scalar(self.gamma):
            _gamma = self.gamma
        else:
            lower, upper = self.gamma[0], self.gamma[1]
            if upper < 1 or lower > 1:
                # The range does not contain one, so splitting it at one
                # makes no sense. For lower > 1 a split would sample from
                # uniform(lower, 1), i.e. below the configured minimum.
                _gamma = random.uniform(lower, upper)
            elif random.random() < 0.5:
                _gamma = random.uniform(lower, 1)
            else:
                _gamma = random.uniform(1, upper)

        for _key in self.keys:
            data[_key] = self.augment_fn(data[_key], _gamma, **self.kwargs)
        return data
Example #12
0
 def __init__(self,
              gamma: Union[float, Sequence] = (0.5, 2),
              keys: Sequence = ('data', ),
              grad: bool = False,
              **kwargs):
     """
     Configure the transform with ``gamma_correction`` as augmentation
     function

     Args:
         gamma: either a scalar or a sequence with exactly two entries
             ``(min, max)``. The value is stored unchanged; how a gamma
             is drawn from the range is decided when the transform is
             applied.
         keys: keys to normalize
         grad: enable gradient computation inside transformation
         **kwargs: additional keyword arguments; stored in
             ``self.kwargs`` (they are not passed to the superclass)

     Raises:
         TypeError: if gamma is not a scalar and does not have exactly
             two entries
     """
     super().__init__(augment_fn=gamma_correction, keys=keys, grad=grad)
     self.kwargs = kwargs
     self.gamma = gamma

     if check_scalar(self.gamma):
         return
     if len(self.gamma) != 2:
         raise TypeError(
             f"Gamma needs to be scalar or a Sequence with two entries "
             f"(min, max), found {self.gamma}")