Example #1
import torch
import torch.nn.functional as F

# kernel helpers used below (kornia.filters.kernels is the assumed module path)
from kornia.filters.kernels import (get_spatial_gradient_kernel3d,
                                    normalize_kernel3d)


def spatial_gradient3d(input: torch.Tensor,
                       mode: str = 'diff',
                       order: int = 1,
                       normalized: bool = True) -> torch.Tensor:
    r"""Computes the first and second order volume derivative in x, y and d using a diff
    operator.

    Args:
        input (torch.Tensor): input features tensor with shape :math:`(B, C, D, H, W)`.
        mode (str): derivatives modality, can be: `sobel` or `diff`. Default: `diff`.
        order (int): the order of the derivatives. Default: 1.
        normalized (bool): whether to normalize the kernel before filtering. Default: True.

    Return:
        torch.Tensor: the spatial gradients of the input feature map.

    Shape:
        - Input: :math:`(B, C, D, H, W)`. D, H and W are the spatial dimensions; the gradients are computed w.r.t. them.
        - Output: :math:`(B, C, 3, D, H, W)` for `order=1` or :math:`(B, C, 6, D, H, W)` for `order=2`

    Examples:
        >>> input = torch.rand(1, 4, 2, 4, 4)
        >>> output = spatial_gradient3d(input)
        >>> output.shape
        torch.Size([1, 4, 3, 2, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))

    if not len(input.shape) == 5:
        raise ValueError(
            "Invalid input shape, we expect BxCxDxHxW. Got: {}".format(
                input.shape))
    # allocate kernel
    kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order)
    if normalized:
        kernel = normalize_kernel3d(kernel)

    # prepare kernel
    b, c, d, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.to(input).detach()
    tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1)

    # convolve input tensor with grad kernel
    kernel_flip: torch.Tensor = tmp_kernel.flip(-3)

    # Pad with "replicate for spatial dims, but with zeros for channel
    spatial_pad = [
        kernel.size(2) // 2,
        kernel.size(2) // 2,
        kernel.size(3) // 2,
        kernel.size(3) // 2,
        kernel.size(4) // 2,
        kernel.size(4) // 2
    ]
    out_ch: int = 6 if order == 2 else 3  # 3 derivative channels for order 1, 6 for order 2
    return F.conv3d(F.pad(input, spatial_pad, 'replicate'),
                    kernel_flip,
                    padding=0,
                    groups=c).view(b, c, out_ch, d, h, w)
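
With `order=2` the routine stacks six derivative channels instead of three. A minimal sketch of the expected shapes, assuming the function and the kornia kernel helpers are importable as above:

>>> input = torch.rand(1, 4, 2, 4, 4)
>>> output = spatial_gradient3d(input, order=2)
>>> output.shape
torch.Size([1, 4, 6, 2, 4, 4])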
Example #2
def __init__(self,
             mode: str = 'diff',
             order: int = 1) -> None:
    super(SpatialGradient3d, self).__init__()
    self.order: int = order
    self.mode: str = mode
    # build the derivative kernel once at construction time
    self.kernel = get_spatial_gradient_kernel3d(mode, order)
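
A minimal usage sketch of the module form, assuming SpatialGradient3d's forward simply applies spatial_gradient3d with the stored mode and order (as in kornia.filters):

>>> input = torch.rand(1, 4, 2, 4, 4)
>>> gradient = SpatialGradient3d(mode='diff', order=1)
>>> output = gradient(input)
>>> output.shape
torch.Size([1, 4, 3, 2, 4, 4])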