Exemplo n.º 1
0
def jacobian(warp, bound='circular'):
    """Compute the jacobian of a 'vox' warp.

    This function estimates the field of Jacobian matrices of a deformation
    field using central finite differences: (next-previous)/2.

    Note that for Neumann boundary conditions, symmetric padding is usually
    used (symmetry w.r.t. voxel edge), when computing Jacobian fields,
    reflection padding is more adapted (symmetry w.r.t. voxel centre), so that
    derivatives are zero at the edges of the FOV.

    Note that voxel sizes are not considered here. The flow field should be
    expressed in voxels and so will the Jacobian.

    Args:
        warp (torch.Tensor): flow field (N, W, H, D, 3). The size of the
            last axis gives the number of spatial dimensions (1, 2 or 3).
        bound (str, optional): Boundary conditions, one of
            'circular'/'fft', 'reflect1'/'dct1', 'reflect2'/'dct2',
            'constant'/'zero'/'zeros'. Defaults to 'circular'.

    Returns:
        jac (torch.Tensor): Field of Jacobian matrices (N, W, H, D, 3, 3).
            jac[:,:,:,:,i,j] contains the derivative of the i-th component of
            the deformation field with respect to the j-th axis.

    Raises:
        ValueError: If `bound` is not recognised, or if the last axis of
            `warp` does not have size 1, 2 or 3.

    """
    warp = torch.as_tensor(warp)
    shape = warp.size()
    dim = shape[-1]
    ker = kernels.imgrad(dim, device=warp.device, dtype=warp.dtype)
    ker = kernels.make_separable(ker, dim)
    warp = utils.last2channel(warp)
    if bound in ('circular', 'fft'):
        warp = utils.pad(warp, (1, ) * dim, mode='circular', side='both')
        pad = 0
    elif bound in ('reflect1', 'dct1'):
        warp = utils.pad(warp, (1, ) * dim, mode='reflect1', side='both')
        pad = 0
    elif bound in ('reflect2', 'dct2'):
        warp = utils.pad(warp, (1, ) * dim, mode='reflect2', side='both')
        pad = 0
    elif bound in ('constant', 'zero', 'zeros'):
        # Zero boundary: let the convolution do the (zero) padding itself.
        pad = 1
    else:
        raise ValueError('Unknown bound {}.'.format(bound))
    if dim == 1:
        conv = _F.conv1d
    elif dim == 2:
        conv = _F.conv2d
    elif dim == 3:
        conv = _F.conv3d
    else:
        raise ValueError(
            'Warps must be of dimension 1, 2 or 3. Got {}.'.format(dim))
    jac = conv(warp, ker, padding=pad, groups=dim)
    # Split the channel axis into (component, derivative axis). Only the
    # spatial sizes must follow: `shape` still ends with the component axis
    # of the input (which was moved to the channel position above), so it
    # has to be dropped or the element count of the reshape cannot match.
    spatial_shape = tuple(shape[1:-1])
    jac = jac.reshape((shape[0], dim, dim) + spatial_shape)
    # Move the two matrix axes last: (N, *spatial, dim, dim).
    jac = jac.permute((0, ) + tuple(range(3, 3 + dim)) + (1, 2))
    return jac
Exemplo n.º 2
0
def pad_same(dim, tensor, kernel_size, dilation=1, bound='zero', value=0):
    """Pad a tensor so that a subsequent moving-window operation
    (e.g. a convolution) returns an output with the input's spatial shape.

    Parameters
    ----------
    dim : int
        Number of (trailing) spatial dimensions.
    tensor : (..., *spatial) tensor
        Tensor to pad.
    kernel_size : [sequence of] int
        Size of the moving window.
    dilation : [sequence of] int, default=1
        Dilation of the moving window.
    bound : str, default='zero'
        Boundary mode, forwarded to `utils.pad` as `mode`.
    value : float, default=0
        Fill value, forwarded to `utils.pad`.

    Returns
    -------
    padded : (..., *spatial_out) tensor

    """
    kernel_size = make_list(kernel_size, dim)
    dilation = make_list(dilation, dim)
    spatial_shape = tensor.shape[-dim:]
    spatial_padding = _normalize_padding(
        compute_conv_padding(spatial_shape, kernel_size, 'same', dilation))
    # Zeros for the leading (non-spatial) dimensions.
    # NOTE(review): the count is 2*ndim - dim; if `utils.pad` expects a
    # (low, high) pair per dimension, 2*(ndim - dim) may have been intended
    # -- confirm against `utils.pad` and `_normalize_padding`.
    nb_leading = 2*tensor.dim() - dim
    full_padding = [0] * nb_leading + spatial_padding
    return utils.pad(tensor, full_padding, mode=bound, value=value)
Exemplo n.º 3
0
def unwrap(phase, dim=None, bound='dct2', max_iter=0, tol=1e-5):
    """Laplacian unwrapping of the phase

    Parameters
    ----------
    phase : tensor
        Wrapped phase, in radian
    dim : int, default=phase.dim()
        Number of spatial dimensions
    bound : str, default='dct2'
        Boundary mode used to pad the phase (by half its size on each side)
        before filtering. No padding is applied when `bound` is 'dct' or
        'circulant'.
    max_iter : int, default=0
        Maximum number of unwrapping iterations.
        If 0, return the Laplacian filtered phase, which is not exactly
        equal to the input phase modulo 2 pi.
    tol : float, default=1e-5
        Tolerance for early stopping: iterations stop once the mean
        absolute correction applied to the phase falls below `tol`.


    Returns
    -------
    unwrapped : tensor

    References
    ----------
    .. "Fast phase unwrapping algorithm for interferometric applications"
       Marvin A. Schofield and Yimei Zhu
       Optics Letters (2003)

    """
    # TODO: would be nice to use DCT/DST rather than padding once they
    #       are available in PyTorch.

    dim = dim or phase.dim()
    dims = list(range(-dim, 0))
    shape = bigshape = phase.shape[-dim:]

    # NOTE(review): 'dct' is listed next to 'circulant' as a no-padding
    # mode; 'dft' may have been intended -- confirm.
    needs_padding = bound not in ('dct', 'circulant')
    if needs_padding:
        phase = utils.pad(phase, [d//2 for d in shape], side='both', mode=bound)
        bigshape = phase.shape[-dim:]

    freq = _laplacian_freq(bigshape, **utils.backend(phase))
    phase = fft.ifftshift(phase, dim=dims)
    twopi = 2 * pymath.pi

    if max_iter == 0:
        phase = _laplacian_filter(phase, freq, dims)
    else:
        for n_iter in range(1, max_iter+1):
            # Correction = difference between the filtered and current
            # phase, rounded to the nearest multiple of 2 pi.
            correction = _laplacian_filter(phase, freq, dims)
            correction.sub_(phase).div_(twopi).round_().mul_(twopi)
            phase += correction

            # Use the magnitude of the correction: a signed mean is
            # negative (hence always below `tol`) whenever corrections are
            # mostly negative, which would stop the iterations too early.
            if n_iter < max_iter and correction.abs().mean() < tol:
                break

    phase = fft.fftshift(phase, dim=dims)

    if needs_padding:
        # Crop the padded field of view back to the input shape.
        slicer = [slice(d//2, d+d//2) for d in shape]
        phase = phase[(Ellipsis, *slicer)]
    return phase
Exemplo n.º 4
0
    def forward(self, q, k, v, **overload):
        """Local attention: each query is compared with the keys found in
        a window, and the resulting softmax weights combine the values of
        that same window.

        Parameters
        ----------
        q : (b, c, *spatial)
            Queries
        k : (b, c, *spatial)
            Keys
        v : (b, c, *spatial)
            Values
        overload : dict
            Per-call replacements for `kernel_size`, `stride`, `padding`
            and `padding_mode`.

        Returns
        -------
        x : (b, c, *spatial)

        """
        kernel_size = overload.pop('kernel_size', self.kernel_size)
        # NOTE(review): the fallback for `stride` is `self.kernel_size`,
        # not a dedicated stride attribute -- confirm this is intended.
        stride = overload.pop('stride', self.kernel_size)
        padding = overload.pop('padding', self.padding)
        padding_mode = overload.pop('padding_mode', self.padding_mode)

        dim = q.dim() - 2

        # pad keys and values (queries are left untouched)
        if padding == 'auto':
            k = spatial.pad_same(dim, k, kernel_size, bound=padding_mode)
            v = spatial.pad_same(dim, v, kernel_size, bound=padding_mode)
        elif padding:
            # no padding along the batch and channel dimensions
            padding = [0] * 2 + py.make_list(padding, dim)
            k = utils.pad(k, padding, side='both', mode=padding_mode)
            v = utils.pad(v, padding, side='both', mode=padding_mode)

        kernel_size = py.make_list(kernel_size, dim)
        nb_leading = dim + 2

        def extract_windows(x):
            # unfold into windows, then flatten the window dimensions
            x = utils.unfold(x, kernel_size, stride)
            return x.reshape([*x.shape[:nb_leading], -1])

        # attention weights: softmax of the query/key dot products
        keys = utils.movedim(extract_windows(k), 1, -1)
        queries = utils.movedim(q[..., None], 1, -1)
        weights = math.softmax(linalg.dot(keys, queries), dim=-1)
        weights = weights[:, None]  # add back channel dimension

        # output: weighted combination of the values in each window
        return linalg.dot(weights, extract_windows(v))
Exemplo n.º 5
0
    def forward(self, x, **overload):
        """Apply the convolution, optionally overriding some of its
        parameters for this call only.

        Parameters
        ----------
        x : tensor
            Input tensor, passed to the wrapped convolution.
        overload : dict
            Per-call replacements for `stride`, `padding`, `padding_mode`,
            `output_padding` and `dilation`. A padding of 'auto' is
            replaced by `((kernel_size - 1) * dilation) // 2`.

        Returns
        -------
        x : tensor
            Convolved tensor.

        """
        conv1 = self.conv
        # Work on copies so that per-call parameters are not written to
        # this module's own convolution.
        # NOTE(review): `copy` is shallow, so the clone presumably shares
        # state with `self`; that would be why `self.conv` is re-assigned
        # at the end -- confirm.
        clone = copy(self)
        try:
            clone.conv = copy(conv1)

            stride = overload.get('stride', clone.stride)
            padding = overload.get('padding', clone.padding)
            padding_mode = overload.get('padding_mode', clone.padding_mode)
            output_padding = overload.get('output_padding',
                                          clone.output_padding)
            dilation = overload.get('dilation', clone.dilation)

            kernel_size = make_tuple(clone.kernel_size, self.dim)
            stride = make_tuple(stride, self.dim)
            output_padding = make_tuple(output_padding, self.dim)
            dilation = make_tuple(dilation, self.dim)

            if padding == 'auto':
                padding = [((k - 1) * d) // 2
                           for k, d in zip(kernel_size, dilation)]
            padding = make_tuple(padding, self.dim)

            # perform pre-padding
            if padding_mode not in _native_padding_mode:
                x = utils.pad(x, padding, mode=padding_mode, side='both')
                padding = 0

            # call native convolution
            clone.stride = stride
            clone.padding = padding
            clone.padding_mode = padding_mode
            clone.output_padding = output_padding
            clone.dilation = dilation
            x = clone.conv(x)

            # perform post-padding
            if not clone.transposed and output_padding:
                x = utils.pad(x, output_padding, side='right')

            return x
        finally:
            # Always restore the original convolution, even when padding
            # or the convolution itself raises.
            self.conv = conv1
Exemplo n.º 6
0
def kernel_apply(kspace, patterns, kernel_size, kernels, inplace=False):
    """Fill an accelerated k-space by applying GRAPPA kernels

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    kspace : ([*batch], coils, *freq)
        Accelerated k-space
    patterns : (*freq) tensor[long]
        Code of sampling pattern about each k-space location
    kernel_size : sequence of int
        GRAPPA kernel size
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        Dictionary of GRAPPA kernels (keys are pattern codes)
    inplace : bool, default=False
        If True, write the result into `kspace`; otherwise into a clone.

    Returns
    -------
    kspace : ([*batch], coils, *freq)

    """
    ndim = patterns.dim()
    coils, *_ = kspace.shape[-ndim - 1:]
    batch = kspace.shape[:-ndim - 1]
    kernel_size = py.make_list(kernel_size, ndim)

    output = kspace if inplace else kspace.clone()

    # extract the neighbourhood of each k-space location
    halfwidth = [(k - 1) // 2 for k in kernel_size]
    windows = utils.pad(kspace, halfwidth, side='both')
    windows = utils.unfold(windows, kernel_size, stride=1)

    for code, kernel in kernels.items():
        pattern = code_to_pattern(code, kernel_size, device=windows.device)
        nb_elem = pattern.sum()
        mask = patterns == code
        # neighbourhoods of the locations that have this pattern code,
        # restricted to the neighbours selected by the pattern
        sources = windows[..., mask, :, :][..., pattern]
        sources = sources.transpose(-2, -3)
        sources = sources.reshape([*batch, -1, coils * nb_elem])
        kernel = kernel.reshape([*batch, coils, coils * nb_elem])
        targets = sources.matmul(kernel.transpose(-1, -2))
        output[..., mask] = targets.transpose(-1, -2)

    return output
Exemplo n.º 7
0
def _smooth_for_reg(dat, mat, samp):
    """Smoothing for image registration. FWHM is computed from voxel size
       and sub-sampling amount.

    Parameters
    ----------
    dat : (X, Y, Z) tensor_like
        3D image volume.
    mat : (4, 4) tensor_like
        Affine matrix.
    samp : float
        Amount of sub-sampling (in mm).

    Returns
    -------
    dat : (Nx, Ny, Nz) tensor_like
        Smoothed 3D image volume.

    """
    if samp <= 0:
        return dat
    samp = torch.tensor((samp, ) * 3, dtype=dat.dtype, device=dat.device)
    # Make smoothing kernel
    vx = voxel_size(mat).to(dat.device).type(dat.dtype)
    fwhm = torch.sqrt(
        torch.max(samp**2 - vx**2,
                  torch.zeros(3, device=dat.device, dtype=dat.dtype))) / vx
    smo = smooth(('gauss', ) * 3,
                 fwhm=fwhm,
                 device=dat.device,
                 dtype=dat.dtype,
                 sep=True)
    # Padding amount for subsequent convolution
    size_pad = (smo[0].shape[2], smo[1].shape[3], smo[2].shape[4])
    size_pad = (torch.tensor(size_pad) - 1) // 2
    size_pad = tuple(size_pad.int().tolist())
    # Smooth deformation with Gaussian kernel (by separable convolution)
    dat = pad(dat, size_pad, side='both')
    dat = dat[None, None, ...]
    dat = F.conv3d(dat, smo[0])
    dat = F.conv3d(dat, smo[1])
    dat = F.conv3d(dat, smo[2])[0, 0, ...]

    return dat
Exemplo n.º 8
0
def get_pattern_codes(sampling_mask, kernel_size):
    """Compute the code of the sampling pattern found about each voxel

    Parameters
    ----------
    sampling_mask : (*freq) tensor[bool]
        Mask of sampled locations.
    kernel_size : [sequence of] int
        Size of the neighbourhood in which the pattern is read.

    Returns
    -------
    pattern_mask : (*freq) tensor[long]

    """
    ndim = sampling_mask.dim()
    kernel_size = py.make_list(kernel_size, ndim)
    halfwidth = [(size - 1) // 2 for size in kernel_size]
    # pad by the kernel half-width, then extract one window per voxel
    padded = utils.pad(sampling_mask.long(), halfwidth, side='both')
    windows = utils.unfold(padded, kernel_size, stride=1)
    return pattern_to_code(windows, ndim)
Exemplo n.º 9
0
def pool(dim, tensor, kernel_size=3, stride=None, dilation=1, padding=0,
         bound='constant', reduction='mean', ceil=False, return_indices=False,
         affine=None):
    """Perform a pooling

    Parameters
    ----------
    dim : {1, 2, 3}
        Number of spatial dimensions
    tensor : (*batch, *spatial_in) tensor
        Input tensor
    kernel_size : int or sequence[int], default=3
        Size of the pooling window
    stride : int or sequence[int], default=`kernel_size`
        Strides between output elements.
    dilation : int or sequence[int], default=1
        Strides between elements of the kernel.
    padding : 'same' or int or sequence[int], default=0
        Padding performed before the convolution.
        If 'same', the padding is chosen such that the shape of the
        output tensor is `floor(spatial_in / stride)` (or
        `ceil(spatial_in / stride)` if `ceil` is True).
    bound : str, default='constant'
        Boundary conditions used in the padding.
    reduction : {'mean', 'avg', 'max', 'min', 'median', 'sum'} or callable, default='mean'
        Function to apply to the elements in a window.
        'avg' is an alias of 'mean'.
    ceil : bool, default=False
        Use ceil instead of floor to compute output shape
    return_indices : bool, default=False
        Return input index of the min/max/median element.
        For other types of reduction, return None.
    affine : (..., D+1, D+1) tensor, optional
        Input orientation matrix

    Returns
    -------
    pooled : (*batch, *spatial_out) tensor
    indices : (*batch, *spatial_out, dim) tensor, if `return_indices`
    affine : (..., D+1, D+1) tensor, if `affine`

    Raises
    ------
    ValueError
        If `reduction` is neither callable nor a known reduction name.

    """
    # move everything to the same dtype/device
    tensor = torch.as_tensor(tensor)

    # sanity checks + reshape for torch's conv
    batch = tensor.shape[:-dim]
    spatial_in = tensor.shape[-dim:]
    tensor = tensor.reshape([-1, *spatial_in])

    # compute padding
    kernel_size = make_list(kernel_size, dim)
    stride = make_list(stride or None, dim)
    stride = [st or ks for st, ks in zip(stride, kernel_size)]
    dilation = make_list(dilation or 1, dim)
    padding = compute_conv_padding(spatial_in, kernel_size, padding,
                                   dilation, stride, ceil)
    if ceil:
        # ceil mode cannot be obtained using unfold. we may need to
        # pad the input a bit more
        padding = _pad_for_ceil(spatial_in, kernel_size, padding, stride, dilation)

    # torch's pooling functions are only used for undilated mean/max
    use_torch = (reduction in ('mean', 'avg', 'max') and
                 dim in (1, 2, 3) and
                 dilation == [1] * dim)

    padding0 = padding  # kept to update the affine
    sum_padding = sum([sum(p) if isinstance(p, (list, tuple)) else p
                       for p in padding])
    if ((not use_torch) or (bound != 'zero' and sum_padding > 0)
            or any(isinstance(p, (list, tuple)) for p in padding)):
        # torch implementation -> handles zero-padding
        # our implementation -> needs explicit padding
        padding = _normalize_padding(padding)
        tensor = utils.pad(tensor, padding, bound, side='both',
                           value=_fill_value(reduction, tensor))
        padding = [0] * dim

    return_indices0 = False
    pool_fn = reduction if callable(reduction) else None

    if use_torch:
        # NOTE(review): these wrappers do not forward `return_indices` to
        # torch, so `return_indices=True` with the torch max-pool path
        # looks unsupported -- confirm.
        if reduction in ('mean', 'avg'):
            return_indices0 = return_indices
            return_indices = False
            pool_fn = (F.avg_pool1d if dim == 1 else
                       F.avg_pool2d if dim == 2 else
                       F.avg_pool3d if dim == 3 else None)
            if pool_fn:
                pool_fn0 = pool_fn
                # torch expects a channel dimension: add it, then drop it
                pool_fn = lambda x, *a, **k: pool_fn0(x[:, None], *a, **k,
                                                      padding=padding)[:, 0]
        elif reduction == 'max':
            pool_fn = (F.max_pool1d if dim == 1 else
                       F.max_pool2d if dim == 2 else
                       F.max_pool3d if dim == 3 else None)
            if pool_fn:
                pool_fn0 = pool_fn
                pool_fn = lambda x, *a, **k: pool_fn0(x[:, None], *a, **k,
                                                      padding=padding)[:, 0]

    if not pool_fn:
        # generic (unfold-based) implementation
        if reduction not in ('min', 'max', 'median'):
            return_indices0 = return_indices
            return_indices = False
        # 'avg' is accepted here as well, so that it behaves like 'mean'
        # when the torch path cannot be used (e.g. dilated pooling)
        if reduction in ('mean', 'avg'):
            reduction = lambda x: math.mean(x, dim=-1)
        elif reduction == 'sum':
            reduction = lambda x: math.sum(x, dim=-1)
        elif reduction == 'min':
            reduction = lambda x: math.min(x, dim=-1)
        elif reduction == 'max':
            reduction = lambda x: math.max(x, dim=-1)
        elif reduction == 'median':
            reduction = lambda x: math.median(x, dim=-1)
        elif not callable(reduction):
            raise ValueError(f'Unknown reduction {reduction}')
        pool_fn = lambda *a, **k: _pool(*a, **k, dilation=dilation, reduction=reduction)

    outputs = []
    if return_indices:
        tensor, ind = pool_fn(tensor, kernel_size, stride=stride)
        ind = utils.ind2sub(ind, stride)
        ind = utils.movedim(ind, 0, -1)
        outputs.append(ind)
    else:
        tensor = pool_fn(tensor, kernel_size, stride=stride)
        if return_indices0:
            # indices were requested but this reduction has none
            outputs.append(None)

    # restore the batch dimensions
    spatial_out = tensor.shape[-dim:]
    tensor = tensor.reshape([*batch, *spatial_out])
    outputs = [tensor, *outputs]

    if affine is not None:
        affine, _ = affine_conv(affine, spatial_in,
                                kernel_size=kernel_size, stride=stride,
                                padding=padding0, dilation=dilation)
        outputs.append(affine)

    return outputs[0] if len(outputs) == 1 else tuple(outputs)
Exemplo n.º 10
0
def _make_image(option, dim=None, device=None):
    """
    Load an image and build a Gaussian pyramid (if required).

    The image described by `option` is loaded, then optionally masked,
    re-oriented, rescaled, padded and smoothed before the pyramid is
    built. Pyramid levels are finally soft-quantized or discretized if
    the options ask for it.

    Returns: ImagePyramid
    """

    def split_unit(values):
        # A trailing string is a unit; without one, the unit is voxels.
        if isinstance(values[-1], str):
            *values, unit = values
            return values, unit
        return values, 'vox'

    dat, mask, affine = _load_image(option.files, dim=dim, device=device,
                                    label=option.label)
    dim = dat.dim() - 1

    # --- explicit mask (combined with the mask loaded with the image) ---
    if option.mask:
        loaded_mask = mask
        mask, _, _ = _load_image([option.mask], dim=dim, device=device,
                                 label=option.label)
        if mask.shape[-dim:] != dat.shape[-dim:]:
            raise ValueError('Mask should have the same shape as the image. '
                             f'Got {mask.shape[-dim:]} and {dat.shape[-dim:]}')
        if loaded_mask is not None:
            mask = mask * loaded_mask

    # --- orientation ---
    if option.world:  # overwrite orientation matrix
        affine = io.transforms.map(option.world).fdata().squeeze()
    for transform_file in (option.affine or []):
        transform = io.transforms.map(transform_file).fdata().squeeze()
        affine = spatial.affine_lmdiv(transform, affine)

    # --- intensities ---
    if not option.discretize and any(option.rescale):
        dat = _rescale_image(dat, option.rescale)

    # --- padding ---
    if option.pad:
        pad, unit = split_unit(option.pad)
        if unit == 'mm':
            # convert millimetres to a whole number of voxels
            vx = spatial.voxel_size(affine)
            pad = torch.as_tensor(pad, **utils.backend(vx))
            pad = (pad / vx).floor().int().tolist()
        else:
            pad = [int(p) for p in pad]
        pad = py.make_list(pad, dim)
        if any(pad):
            affine, _ = spatial.affine_pad(affine, dat.shape[-dim:], pad,
                                           side='both')
            dat = utils.pad(dat, pad, side='both', mode=option.bound)
            if mask is not None:
                mask = utils.pad(mask, pad, side='both', mode=option.bound)

    # --- smoothing ---
    if option.fwhm:
        fwhm, unit = split_unit(option.fwhm)
        if unit == 'mm':
            # convert millimetres to voxels
            vx = spatial.voxel_size(affine)
            fwhm = torch.as_tensor(fwhm, **utils.backend(vx)) / vx
        dat = spatial.smooth(dat, dim=dim, fwhm=fwhm, bound=option.bound)

    # --- pyramid ---
    image = objects.ImagePyramid(dat,
                                 levels=option.pyramid,
                                 affine=affine,
                                 dim=dim,
                                 bound=option.bound,
                                 mask=mask,
                                 extrapolate=option.extrapolate,
                                 method=option.pyramid_method)

    # --- quantization (the untouched data is kept as a preview) ---
    if getattr(option, 'soft_quantize', False) and len(image[0].dat) == 1:
        for level in image:
            level.preview = level.dat
            level.dat = _soft_quantize_image(level.dat, option.soft_quantize)
    elif not option.label and option.discretize:
        for level in image:
            level.preview = level.dat
            level.dat = _discretize_image(level.dat, option.discretize)
    return image
Exemplo n.º 11
0
def _hist_2d(img0, img1, mx_int, fwhm):
    """Make 2D histogram, requires:
        * Images same size.
        * Images same min and max intensities (non-negative).

    Parameters
    ----------
    img0 : (X, Y, Z) tensor_like
        First image volume.
    img1 : (X, Y, Z) tensor_like
        Second image volume.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.

    Returns
    -------
    H : (mx_int + 1, mx_int + 1) tensor_like
        Joint intensity histogram (smoothed, non-negative, plus 1e-7).

    Notes
    -----
    Naive method for computing a 2D histogram:
    h = torch.zeros((mx_int + 1, mx_int + 1))
    for n in range(num_vox):
        h[img0[n], img1[n]] += 1

    The vectorised version below uses the linear index
    `img0 + (mx_int + 1) * img1` (MATLAB's sub2ind convention) with
    `Tensor.put_`, which indexes `H` as a flattened tensor.

    """
    fwhm = (fwhm, ) * 2
    # Convert each 'coordinate' of intensities to an index
    # (replicates the sub2ind function of MATLAB)
    img0 = img0.flatten().floor()
    img1 = img1.flatten().floor()
    sub = torch.stack((img0, img1), dim=1)  # (num_vox, 2)
    to_ind = torch.tensor((1, mx_int + 1), dtype=sub.dtype,
                          device=img0.device)[..., None]  # (2, 1)
    ind = torch.tensordot(sub, to_ind, dims=([1], [0]))  # (nvox, 1)
    # Build histogram H by adding up counts according to the indices in ind
    H = torch.zeros(mx_int + 1,
                    mx_int + 1,
                    device=img0.device,
                    dtype=ind.dtype)
    H.put_(ind.long(),
           torch.ones(1, device=img0.device, dtype=ind.dtype).expand_as(ind),
           accumulate=True)
    # The convolutions below need the histogram and the kernels to share a
    # floating point dtype: integer counts are converted to float32, and
    # the kernels are built with the dtype of the histogram (rather than a
    # hard-coded float32, which fails for e.g. float64 images).
    if not H.is_floating_point():
        H = H.to(torch.float32)
    # Smoothing kernel
    smo = smooth(('gauss', ) * 2,
                 fwhm=fwhm,
                 device=img0.device,
                 dtype=H.dtype,
                 sep=True)
    # Pad by the half-width of each kernel so that the smoothed histogram
    # keeps its shape
    p = (smo[0].shape[2], smo[1].shape[3])
    p = (torch.tensor(p) - 1) // 2
    p = tuple(p.int().tolist())
    H = pad(H, p, side='both')
    # Smooth (separable convolution, one kernel per axis)
    H = H[None, None, ...]
    H = F.conv2d(H, smo[0])
    H = F.conv2d(H, smo[1])
    H = H[0, 0, ...]
    # Clamp
    H = H.clamp_min(0.0)
    # Add eps
    H = H + 1e-7
    # # Visualise histogram
    # import matplotlib.pyplot as plt
    # plt.figure(num=1)
    # plt.imshow(H.detach().cpu(),
    #     cmap='coolwarm', interpolation='nearest',
    #     aspect='equal', vmax=0.05*H.max())
    # plt.axis('off')
    # plt.show()

    return H
Exemplo n.º 12
0
def pool(dim,
         tensor,
         kernel_size=3,
         stride=None,
         dilation=1,
         padding=0,
         bound='zero',
         reduction='mean',
         return_indices=False,
         affine=None):
    """Perform a pooling

    Parameters
    ----------
    dim : {1, 2, 3}
        Number of spatial dimensions
    tensor : (*batch, *spatial_in) tensor
        Input tensor
    kernel_size : int or sequence[int], default=3
        Size of the pooling window
    stride : int or sequence[int], default=`kernel_size`
        Strides between output elements.
    dilation : int or sequence[int], default=1
        Strides between elements of the kernel.
    padding : 'auto' or int or sequence[int], default=0
        Padding performed before the convolution.
        If 'auto', the padding is chosen such that the shape of the
        output tensor is `spatial_in // stride`.
    bound : str, default='zero'
        Boundary conditions used in the padding.
    reduction : {'mean', 'avg', 'max', 'min', 'median', 'sum'} or callable, default='mean'
        Function to apply to the elements in a window.
        'avg' is an alias of 'mean'.
    return_indices : bool, default=False
        Return input index of the min/max/median element.
        For other types of reduction, return None.
    affine : (..., D+1, D+1) tensor, optional
        Input orientation matrix

    Returns
    -------
    pooled : (*batch, *spatial_out) tensor
    indices : (*batch, *spatial_out, dim) tensor, if `return_indices`
    affine : (..., D+1, D+1) tensor, if `affine`

    Raises
    ------
    ValueError
        If automatic padding is requested for an even-sized kernel, if a
        dilated average pooling is requested from torch's pooling
        functions, or if `reduction` is neither callable nor a known
        reduction name.

    """
    # move everything to the same dtype/device
    tensor = torch.as_tensor(tensor)

    # sanity checks + reshape for torch's conv
    batch = tensor.shape[:-dim]
    spatial_in = tensor.shape[-dim:]
    tensor = tensor.reshape([-1, *spatial_in])

    # Perform padding
    kernel_size = make_list(kernel_size, dim)
    stride = make_list(stride or None, dim)
    stride = [st or ks for st, ks in zip(stride, kernel_size)]
    dilation = make_list(dilation or 1, dim)
    # copy: 'auto' entries are resolved in place below, and the list
    # passed by the caller should not be modified
    padding = list(make_list(padding, dim))
    padding0 = padding  # save it to update the affine
    for i in range(dim):
        if isinstance(padding[i], str) and padding[i].lower() == 'auto':
            if kernel_size[i] % 2 == 0:
                raise ValueError('Cannot compute automatic padding '
                                 'for even-sized kernels.')
            padding[i] = ((kernel_size[i] - 1) * dilation[i] + 1) // 2

    use_torch = reduction in ('mean', 'avg', 'max') and dim in (1, 2, 3)

    if (not use_torch) or (bound != 'zero' and sum(padding) > 0):
        # torch implementation -> handles zero-padding
        # our implementation -> needs explicit padding
        tensor = utils.pad(tensor, padding, bound, side='both')
        padding = [0] * dim

    # NOTE(review): for reductions without indices, `return_indices0` is
    # set to True regardless of `return_indices`, so `(pooled, None)` is
    # returned even when indices were not requested -- this disagrees with
    # the docstring; confirm against callers before changing it.
    return_indices0 = False
    pool_fn = reduction if callable(reduction) else None

    if reduction in ('mean', 'avg'):
        return_indices0 = True
        return_indices = False
        pool_fn = (F.avg_pool1d if dim == 1 else F.avg_pool2d
                   if dim == 2 else F.avg_pool3d if dim == 3 else None)
        if pool_fn:
            # torch's average pooling has no `dilation` argument:
            # forwarding one raises a TypeError on every call. Fail
            # explicitly for dilated windows and do not forward it.
            if any(d != 1 for d in dilation):
                raise ValueError('Dilated average pooling is not supported.')
            pool_fn0 = pool_fn
            # torch expects a channel dimension: add it, then drop it
            pool_fn = lambda x, *a, **k: pool_fn0(
                x[:, None], *a, **k, padding=padding)[:, 0]
    elif reduction == 'max':
        pool_fn = (F.max_pool1d if dim == 1 else F.max_pool2d
                   if dim == 2 else F.max_pool3d if dim == 3 else None)
        if pool_fn:
            pool_fn0 = pool_fn
            pool_fn = lambda x, *a, **k: pool_fn0(
                x[:, None], *a, **k, padding=padding, dilation=dilation)[:, 0]

    if not pool_fn:
        # generic implementation
        # NOTE(review): `dilation` is not forwarded to `_pool` here --
        # confirm whether it supports dilated windows.
        if reduction not in ('min', 'max', 'median'):
            return_indices0 = True
            return_indices = False
        if reduction in ('mean', 'avg'):
            reduction = lambda x: math.mean(x, dim=-1)
        elif reduction == 'sum':
            reduction = lambda x: math.sum(x, dim=-1)
        elif reduction == 'min':
            reduction = lambda x: math.min(x, dim=-1)
        elif reduction == 'max':
            reduction = lambda x: math.max(x, dim=-1)
        elif reduction == 'median':
            reduction = lambda x: math.median(x, dim=-1)
        elif not callable(reduction):
            raise ValueError(f'Unknown reduction {reduction}')
        pool_fn = lambda *a, **k: _pool(*a, **k, reduction=reduction)

    outputs = []
    if return_indices:
        tensor, ind = pool_fn(tensor, kernel_size, stride=stride)
        ind = utils.ind2sub(ind, stride)
        ind = utils.movedim(ind, 0, -1)
        outputs.append(ind)
    else:
        tensor = pool_fn(tensor, kernel_size, stride=stride)
        if return_indices0:
            outputs.append(None)

    # restore the batch dimensions
    spatial_out = tensor.shape[-dim:]
    tensor = tensor.reshape([*batch, *spatial_out])
    outputs = [tensor, *outputs]

    if affine is not None:
        affine, _ = affine_conv(affine,
                                spatial_in,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding0,
                                dilation=dilation)
        outputs.append(affine)

    return outputs[0] if len(outputs) == 1 else tuple(outputs)