Code Example #1
from typing import Tuple, Union

import torch
import torch.nn.functional as F

import kornia
from kornia.utils import create_meshgrid3d
from kornia.geometry.conversions import normalize_pixel_coordinates3d
# _get_center_kernel3d and _get_window_grid_kernel3d are private helpers defined
# alongside this function in kornia's spatial soft-argmax module (import paths may
# differ slightly across kornia versions).

def conv_soft_argmax3d(input: torch.Tensor,
                       kernel_size: Tuple[int, int, int] = (3, 3, 3),
                       stride: Tuple[int, int, int] = (1, 1, 1),
                       padding: Tuple[int, int, int] = (1, 1, 1),
                       temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
                       normalized_coordinates: bool = False,
                       eps: float = 1e-8,
                       output_value: bool = True,
                       strict_maxima_bonus: float = 0.0) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Function that computes the convolutional spatial Soft-Argmax 3D over the windows
    of a given input heatmap. Function has two outputs: argmax coordinates and the softmaxpooled heatmap values
    themselves.
    On each window, the function computed is:

    .. math::
             ijk(X) = \frac{\sum_{(i,j,k) \in X} (i,j,k) \cdot \exp(x_{ijk} / T)}{\sum_{(i,j,k) \in X} \exp(x_{ijk} / T)}

    .. math::
             val(X) = \frac{\sum_{x \in X} x \cdot \exp(x / T)}{\sum_{x \in X} \exp(x / T)}

    where T is temperature.

    Args:
        input (torch.Tensor): the input heatmap of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        kernel_size (Tuple[int,int,int]): size of the window.
        stride (Tuple[int,int,int]): stride of the window.
        padding (Tuple[int,int,int]): input zero padding.
        temperature (torch.Tensor or float): factor to apply to input. Default is 1.0.
        normalized_coordinates (bool): whether to return the coordinates normalized in the range of [-1, 1].
                                       Otherwise, it will return the coordinates in the range of the input shape.
                                       Default is False.
        eps (float): small value to avoid zero division. Default is 1e-8.
        output_value (bool): if True, the softmax-pooled values are returned together with the coordinates;
                             if False, only the ijk coordinates. Default is True.
        strict_maxima_bonus (float): pixels which are strict maxima will score (1 + strict_maxima_bonus) * value.
                                     This is needed to mimic the behavior of strict NMS in classic local features.
                                     Default is 0.0.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`, :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

         .. math::
             D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
             (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

         .. math::
             H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
             (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

         .. math::
             W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
             (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not len(input.shape) == 5:
        raise ValueError("Invalid input shape, we expect BxCxDxHxW. Got: {}"
                         .format(input.shape))

    if temperature <= 0:
        raise ValueError("Temperature should be positive float or tensor. Got: {}"
                         .format(temperature))

    b, c, d, h, w = input.shape
    kx, ky, kz = kernel_size
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype
    input = input.view(b * c, 1, d, h, w)

    center_kernel: torch.Tensor = _get_center_kernel3d(kx, ky, kz, device).to(dtype)
    window_kernel: torch.Tensor = _get_window_grid_kernel3d(kx, ky, kz, device).to(dtype)

    # applies exponential normalization trick
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # https://github.com/pytorch/pytorch/blob/bcb0bb7e0e03b386ad837015faba6b4b16e3bfb9/aten/src/ATen/native/SoftMax.cpp#L44
    x_max = F.adaptive_max_pool3d(input, (1, 1, 1))

    # max is detached to prevent undesired backprop loops in the graph
    x_exp = ((input - x_max.detach()) / temperature).exp()

    pool_coef: float = float(kx * ky * kz)

    # softmax denominator
    den = pool_coef * F.avg_pool3d(x_exp.view_as(input),
                                   kernel_size,
                                   stride=stride,
                                   padding=padding) + eps

    # We need to output also coordinates
    # Pooled window center coordinates
    grid_global: torch.Tensor = create_meshgrid3d(
        d, h, w, False, device=device).to(dtype).permute(0, 4, 1, 2, 3)

    grid_global_pooled = F.conv3d(grid_global,
                                  center_kernel,
                                  stride=stride,
                                  padding=padding)

    # Coordinates of maxima residual to window center
    # prepare kernel
    coords_max: torch.Tensor = F.conv3d(x_exp,
                                        window_kernel,
                                        stride=stride,
                                        padding=padding)

    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y

    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates3d(coords_max.permute(0, 2, 3, 4, 1), d, h, w)
        coords_max = coords_max.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords_max = coords_max.view(b, c, 3, coords_max.size(2), coords_max.size(3), coords_max.size(4))

    if not output_value:
        return coords_max
    x_softmaxpool = pool_coef * F.avg_pool3d(x_exp.view(input.size()) * input,
                                             kernel_size,
                                             stride=stride,
                                             padding=padding) / den
    if strict_maxima_bonus > 0:
        in_levels: int = input.size(2)
        out_levels: int = x_softmaxpool.size(2)
        skip_levels: int = (in_levels - out_levels) // 2
        strict_maxima: torch.Tensor = F.avg_pool3d(kornia.feature.nms3d(input, kernel_size), 1, stride, 0)
        strict_maxima = strict_maxima[:, :, skip_levels:out_levels - skip_levels]
        x_softmaxpool *= 1.0 + strict_maxima_bonus * strict_maxima
    x_softmaxpool = x_softmaxpool.view(b,
                                       c,
                                       x_softmaxpool.size(2),
                                       x_softmaxpool.size(3),
                                       x_softmaxpool.size(4))
    return coords_max, x_softmaxpool
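
A minimal usage sketch for the version above, assuming the function and the kornia helpers it relies on are in scope. The tensor sizes and hyperparameters below are illustrative only; the printed shapes follow the D_out/H_out/W_out formulas from the docstring.

import torch

# Illustrative input: batch of 2 heatmaps, 4 channels, 5 scale levels, 16x16 spatial.
heatmaps = torch.randn(2, 4, 5, 16, 16)

coords, values = conv_soft_argmax3d(
    heatmaps,
    kernel_size=(3, 3, 3),
    stride=(1, 2, 2),
    padding=(0, 1, 1),
    temperature=0.1,  # a lower temperature pushes the soft-argmax towards a hard argmax
    output_value=True,
)

print(coords.shape)   # torch.Size([2, 4, 3, 3, 8, 8])  -> (N, C, 3, D_out, H_out, W_out)
print(values.shape)   # torch.Size([2, 4, 3, 8, 8])     -> (N, C, D_out, H_out, W_out)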
Code Example #2
# (assumes the same imports and kornia helpers as Code Example #1)
def conv_soft_argmax3d(
    input: torch.Tensor,
    kernel_size: Tuple[int, int, int] = (3, 3, 3),
    stride: Tuple[int, int, int] = (1, 1, 1),
    padding: Tuple[int, int, int] = (1, 1, 1),
    temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
    normalized_coordinates: bool = False,
    eps: float = 1e-8,
    output_value: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Function that computes the convolutional spatial Soft-Argmax 3D over the windows
    of a given input heatmap. Function has two outputs: argmax coordinates and the softmaxpooled heatmap values
    themselves.
    On each window, the function computed is:

    .. math::
             ijk(X) = \frac{\sum_{(i,j,k) \in X} (i,j,k) \cdot \exp(x_{ijk} / T)}{\sum_{(i,j,k) \in X} \exp(x_{ijk} / T)}

    .. math::
             val(X) = \frac{\sum_{x \in X} x \cdot \exp(x / T)}{\sum_{x \in X} \exp(x / T)}

    where T is temperature.

    Args:
        input (torch.Tensor): the input heatmap of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        kernel_size (Tuple[int,int,int]): size of the window.
        stride (Tuple[int,int,int]): stride of the window.
        padding (Tuple[int,int,int]): input zero padding.
        temperature (torch.Tensor or float): factor to apply to input. Default is 1.0.
        normalized_coordinates (bool): whether to return the coordinates normalized in the range of [-1, 1].
                                       Otherwise, it will return the coordinates in the range of the input shape.
                                       Default is False.
        eps (float): small value to avoid zero division. Default is 1e-8.
        output_value (bool): if True, the softmax-pooled values are returned together with the coordinates;
                             if False, only the ijk coordinates. Default is True.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`, :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

         .. math::
             D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
             (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

         .. math::
             H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
             (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

         .. math::
             W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
             (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 5:
        raise ValueError(
            "Invalid input shape, we expect BxCxDxHxW. Got: {}".format(
                input.shape))
    if temperature <= 0:
        raise ValueError(
            "Temperature should be positive float or tensor. Got: {}".format(
                temperature))

    b, c, d, h, w = input.shape
    input = input.view(b * c, 1, d, h, w)
    dev: torch.device = input.device

    center_kernel = _get_center_kernel3d(kernel_size[0], kernel_size[1],
                                         kernel_size[2], dev)
    window_kernel = _get_window_grid_kernel3d(kernel_size[0], kernel_size[1],
                                              kernel_size[2], dev)
    window_kernel = window_kernel.to(input.dtype)

    # NOTE: unlike Code Example #1, no per-map maximum is subtracted before the
    # exponent, so large activations or a small temperature can overflow here.
    x_exp = (input / temperature).exp()

    pool_coef: float = float(kernel_size[0] * kernel_size[1] * kernel_size[2])

    # softmax denominator
    den = pool_coef * F.avg_pool3d(
        x_exp.view(
            input.size()), kernel_size, stride=stride, padding=padding) + eps

    x_softmaxpool = pool_coef * F.avg_pool3d(x_exp.view(input.size()) * input,
                                             kernel_size,
                                             stride=stride,
                                             padding=padding) / den
    x_softmaxpool = x_softmaxpool.view(b, c, x_softmaxpool.size(2),
                                       x_softmaxpool.size(3),
                                       x_softmaxpool.size(4))

    # We need to output also coordinates
    # Pooled window center coordinates
    grid_global: torch.Tensor = create_meshgrid3d(d,
                                                  h,
                                                  w,
                                                  False,
                                                  device=input.device).permute(
                                                      0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)
    grid_global_pooled = F.conv3d(grid_global,
                                  center_kernel.to(input.dtype),
                                  stride=stride,
                                  padding=padding)

    # Coordinates of maxima residual to window center
    # prepare kernel
    coords_max: torch.Tensor = F.conv3d(x_exp,
                                        window_kernel,
                                        stride=stride,
                                        padding=padding)

    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y

    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates3d(
            coords_max.permute(0, 2, 3, 4, 1), d, h, w)
        coords_max = coords_max.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords_max = coords_max.view(b, c, 3, coords_max.size(2),
                                 coords_max.size(3), coords_max.size(4))

    if output_value:
        return coords_max, x_softmaxpool
    return coords_max
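
The main behavioral difference between the two versions is the exponentiation step: Code Example #2 computes (input / temperature).exp() directly, while Code Example #1 first subtracts the per-map maximum (the exp-normalize trick referenced in its comments). A minimal sketch, using illustrative values, of why that shift matters numerically:

import torch

x = torch.tensor([80.0, 81.0, 82.0])  # large activations, e.g. unscaled scores
t = 0.5                               # small temperature sharpens the softmax

naive = (x / t).exp()
print(naive)                          # tensor([inf, inf, inf]) in float32 -> the softmax becomes nan

shifted = ((x - x.max()) / t).exp()   # same softmax weights, but numerically finite
print(shifted / shifted.sum())        # tensor([0.0159, 0.1173, 0.8668])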