def jacobian(warp, bound='circular'):
    """Compute the jacobian of a 'vox' warp.

    This function estimates the field of Jacobian matrices of a deformation
    field using central finite differences: (next-previous)/2.

    Note that for Neumann boundary conditions, symmetric padding is usually
    used (symmetry w.r.t. voxel edge), when computing Jacobian fields,
    reflection padding is more adapted (symmetry w.r.t. voxel centre), so that
    derivatives are zero at the edges of the FOV.

    Note that voxel sizes are not considered here. The flow field should be
    expressed in voxels and so will the Jacobian.

    Args:
        warp (torch.Tensor): flow field (N, W, H, D, 3).
        bound (str, optional): Boundary conditions. One of
            'circular'/'fft', 'reflect1'/'dct1', 'reflect2'/'dct2' or
            'constant'/'zero'/'zeros'. Defaults to 'circular'.

    Returns:
        jac (torch.Tensor): Field of Jacobian matrices (N, W, H, D, 3, 3).
            jac[:,:,:,:,i,j] contains the derivative of the i-th component of
            the deformation field with respect to the j-th axis.

    Raises:
        ValueError: If `bound` is not one of the names listed above, or if
            the last dimension of `warp` is not 1, 2 or 3.

    """
    warp = torch.as_tensor(warp)
    shape = warp.size()
    dim = shape[-1]
    ker = kernels.imgrad(dim, device=warp.device, dtype=warp.dtype)
    ker = kernels.make_separable(ker, dim)
    warp = utils.last2channel(warp)

    # Non-zero boundary conditions are implemented by padding the field
    # explicitly; zero boundaries are delegated to the convolution itself.
    if bound in ('circular', 'fft'):
        warp = utils.pad(warp, (1, ) * dim, mode='circular', side='both')
        pad = 0
    elif bound in ('reflect1', 'dct1'):
        warp = utils.pad(warp, (1, ) * dim, mode='reflect1', side='both')
        pad = 0
    elif bound in ('reflect2', 'dct2'):
        warp = utils.pad(warp, (1, ) * dim, mode='reflect2', side='both')
        pad = 0
    elif bound in ('constant', 'zero', 'zeros'):
        pad = 1
    else:
        raise ValueError('Unknown bound {}.'.format(bound))

    if dim == 1:
        conv = _F.conv1d
    elif dim == 2:
        conv = _F.conv2d
    elif dim == 3:
        conv = _F.conv3d
    else:
        raise ValueError(
            'Warps must be of dimension 1, 2 or 3. Got {}.'.format(dim))

    jac = conv(warp, ker, padding=pad, groups=dim)
    # `shape` still ends with the component axis of the input field. The
    # permute below addresses exactly 3 + dim axes, i.e. the tensor must be
    # (N, dim, dim, *spatial) at this point, so only the spatial part of
    # `shape` may enter the reshape target.
    jac = jac.reshape((shape[0], dim, dim) + tuple(shape[1:-1]))
    # Move the two matrix axes to the end: (N, *spatial, dim, dim).
    jac = jac.permute((0, ) + tuple(range(3, 3 + dim)) + (1, 2))
    return jac
def pad_same(dim, tensor, kernel_size, dilation=1, bound='zero', value=0):
    """Pad a tensor so that a following convolution-like (i.e. moving
    window) operation preserves its spatial shape.

    Parameters
    ----------
    dim : int
        Number of trailing spatial dimensions.
    tensor : (..., *spatial) tensor
    kernel_size : [sequence of] int
    dilation : [sequence of] int, default=1
    bound : str, default='zero'
        Boundary condition, forwarded to `utils.pad` as `mode`.
    value : float, default=0
        Forwarded to `utils.pad` as `value`.

    Returns
    -------
    padded : (..., *spatial_out) tensor

    """
    window = make_list(kernel_size, dim)
    spacing = make_list(dilation, dim)
    spatial_shape = tensor.shape[-dim:]
    same_padding = _normalize_padding(
        compute_conv_padding(spatial_shape, window, 'same', spacing))
    # Prepend zeros so that the non-spatial dimensions are left untouched.
    # NOTE(review): the number of leading zeros is 2*ndim - dim; confirm
    # that this matches the layout `utils.pad` expects when `side` is unset.
    nb_leading = 2*tensor.dim()-dim
    full_padding = [0] * nb_leading + same_padding
    return utils.pad(tensor, full_padding, mode=bound, value=value)
def unwrap(phase, dim=None, bound='dct2', max_iter=0, tol=1e-5):
    """Laplacian unwrapping of the phase

    Parameters
    ----------
    phase : tensor
        Wrapped phase, in radian
    dim : int, default=phase.dim()
        Number of spatial dimensions
    bound : str, default='dct2'
        Boundary mode used to pad the phase (by half its size on each
        side) before filtering. No padding is performed when `bound`
        is 'dct' or 'circulant'.
    max_iter : int, default=0
        Maximum number of unwrapping iterations.
        If 0, return the Laplacian filtered phase, which is not
        exactly equal to the input phase modulo 2 pi.
    tol : float, default=1e-5
        Tolerance for early stopping: iterations stop once the mean
        absolute correction applied to the phase falls below `tol`.

    Returns
    -------
    unwrapped : tensor

    References
    ----------
    .. "Fast phase unwrapping algorithm for interferometric applications"
       Marvin A. Schofield and Yimei Zhu
       Optics Letters (2003)

    """
    # TODO: would be nice to use DCT/DST rather than padding once they
    #   are available in PyTorch.

    dim = dim or phase.dim()
    dims = list(range(-dim, 0))
    shape = phase.shape[-dim:]

    # NOTE(review): 'dct' is possibly meant to read 'dft' -- confirm
    # against the bound names accepted by `utils.pad`.
    needs_padding = bound not in ('dct', 'circulant')
    if needs_padding:
        phase = utils.pad(phase, [d//2 for d in shape],
                          side='both', mode=bound)
    bigshape = phase.shape[-dim:]

    freq = _laplacian_freq(bigshape, **utils.backend(phase))
    phase = fft.ifftshift(phase, dim=dims)
    twopi = 2 * pymath.pi

    if max_iter == 0:
        phase = _laplacian_filter(phase, freq, dims)
    else:
        for n_iter in range(1, max_iter+1):
            # Round the difference between the filtered and current phase
            # to the nearest multiple of 2 pi and add it to the phase.
            correction = _laplacian_filter(phase, freq, dims)
            correction.sub_(phase).div_(twopi).round_().mul_(twopi)
            phase += correction
            # The correction is signed: compare its magnitude to the
            # tolerance, otherwise negative or mixed-sign corrections
            # would trigger the early stop while the phase still changes.
            if n_iter < max_iter and correction.abs().mean() < tol:
                break

    phase = fft.fftshift(phase, dim=dims)
    if needs_padding:
        # Crop the padded margins to recover the input field of view.
        slicer = [slice(d//2, d+d//2) for d in shape]
        phase = phase[(Ellipsis, *slicer)]
    return phase
def forward(self, q, k, v, **overload):
    """Local attention: weight the values of each window by the softmax
    of the query/key dot products.

    Parameters
    ----------
    q : (b, c, *spatial)
        Queries
    k : (b, c, *spatial)
        Keys
    v : (b, c, *spatial)
        Values
    **overload
        Per-call replacements for `kernel_size`, `stride`, `padding`
        and `padding_mode`.

    Returns
    -------
    x : (b, c, *spatial)

    """
    kernel_size = overload.pop('kernel_size', self.kernel_size)
    # NOTE(review): the default stride falls back to `self.kernel_size`;
    # confirm whether a dedicated stride attribute was intended.
    stride = overload.pop('stride', self.kernel_size)
    padding = overload.pop('padding', self.padding)
    padding_mode = overload.pop('padding_mode', self.padding_mode)

    dim = q.dim() - 2
    nb_lead = dim + 2

    # pad keys and values (queries are never padded)
    keys, values = k, v
    if padding == 'auto':
        keys = spatial.pad_same(dim, keys, kernel_size, bound=padding_mode)
        values = spatial.pad_same(dim, values, kernel_size,
                                  bound=padding_mode)
    elif padding:
        padding = [0] * 2 + py.make_list(padding, dim)
        keys = utils.pad(keys, padding, side='both', mode=padding_mode)
        values = utils.pad(values, padding, side='both', mode=padding_mode)

    kernel_size = py.make_list(kernel_size, dim)

    def windows(x):
        # extract sliding windows and flatten the window dimensions
        x = utils.unfold(x, kernel_size, stride)
        return x.reshape([*x.shape[:nb_lead], -1])

    # compute weights by query/key dot product
    keys = utils.movedim(windows(keys), 1, -1)
    queries = utils.movedim(q[..., None], 1, -1)
    weights = math.softmax(linalg.dot(keys, queries), dim=-1)
    weights = weights[:, None]  # add back channel dimension

    # compute new values by weight/value dot product
    return linalg.dot(weights, windows(values))
def forward(self, x, **overload):
    """Apply the convolution, with optional per-call parameter overrides.

    Parameters
    ----------
    x : tensor
        Input passed to the underlying convolution.
    **overload
        Per-call replacements for `stride`, `padding`, `padding_mode`,
        `output_padding` and `dilation`. Missing keys fall back to the
        values stored on the layer. `padding='auto'` is resolved to
        `((kernel_size - 1) * dilation) // 2` along each dimension.

    Returns
    -------
    x : tensor
        Output of the convolution, right-padded by `output_padding`
        when the layer is not transposed.

    """
    conv1 = self.conv
    # Work on shallow copies so that the overloaded parameters are set on
    # the clone rather than on `self`.
    clone = copy(self)
    clone.conv = copy(conv1)

    try:
        stride = overload.get('stride', clone.stride)
        padding = overload.get('padding', clone.padding)
        padding_mode = overload.get('padding_mode', clone.padding_mode)
        output_padding = overload.get('output_padding',
                                      clone.output_padding)
        dilation = overload.get('dilation', clone.dilation)

        kernel_size = make_tuple(clone.kernel_size, self.dim)
        stride = make_tuple(stride, self.dim)
        output_padding = make_tuple(output_padding, self.dim)
        dilation = make_tuple(dilation, self.dim)

        if padding == 'auto':
            padding = [((k - 1) * d) // 2
                       for k, d in zip(kernel_size, dilation)]
        padding = make_tuple(padding, self.dim)

        # perform pre-padding
        if padding_mode not in _native_padding_mode:
            x = utils.pad(x, padding, mode=padding_mode, side='both')
            padding = 0

        # call native convolution
        clone.stride = stride
        clone.padding = padding
        clone.padding_mode = padding_mode
        clone.output_padding = output_padding
        clone.dilation = dilation
        x = clone.conv(x)

        # perform post-padding
        if not clone.transposed and output_padding:
            x = utils.pad(x, output_padding, side='right')
    finally:
        # Always put the original convolution back, including when the
        # padding or the convolution raises.
        # NOTE(review): the restore suggests that assigning `clone.conv`
        # can be visible through `self` (presumably because the shallow
        # copy shares its sub-module storage) -- confirm.
        self.conv = conv1

    return x
def kernel_apply(kspace, patterns, kernel_size, kernels, inplace=False):
    """Apply a GRAPPA kernel to an accelerated k-space

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    kspace : ([*batch], coils, *freq)
        Accelerated k-space
    patterns : (*freq) tensor[long]
        Code of sampling pattern about each k-space location
    kernel_size : sequence of int
        GRAPPA kernel size
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        Dictionary of GRAPPA kernels (keys are pattern codes)
    inplace : bool, default=False
        If False, results are written into a clone of `kspace`;
        otherwise they are written into `kspace` itself.

    Returns
    -------
    kspace : ([*batch], coils, *freq)

    """
    ndim = patterns.dim()
    coils, *_ = kspace.shape[-ndim - 1:]
    batch = kspace.shape[:-ndim - 1]
    kernel_size = py.make_list(kernel_size, ndim)

    output = kspace if inplace else kspace.clone()

    # extract the neighbourhood of each k-space location
    half_widths = [(k - 1) // 2 for k in kernel_size]
    windows = utils.pad(kspace, half_widths, side='both')
    windows = utils.unfold(windows, kernel_size, stride=1)

    for code, kernel in kernels.items():
        pattern = code_to_pattern(code, kernel_size, device=windows.device)
        pattern_size = pattern.sum()
        mask = patterns == code

        # gather, for each location with this code, the neighbours
        # selected by the pattern
        sources = windows[..., mask, :, :][..., pattern]
        sources = sources.transpose(-2, -3)
        sources = sources.reshape([*batch, -1, coils * pattern_size])

        weights = kernel.reshape([*batch, coils, coils * pattern_size])
        targets = sources.matmul(weights.transpose(-1, -2))
        output[..., mask] = targets.transpose(-1, -2)

    return output
def _smooth_for_reg(dat, mat, samp):
    """Smoothing for image registration. FWHM is computed from voxel size
    and sub-sampling amount.

    Parameters
    ----------
    dat : (X, Y, Z) tensor_like
        3D image volume.
    mat : (4, 4) tensor_like
        Affine matrix.
    samp : float
        Amount of sub-sampling (in mm). If not strictly positive, `dat`
        is returned untouched.

    Returns
    -------
    dat : (Nx, Ny, Nz) tensor_like
        Smoothed 3D image volume.

    """
    if samp <= 0:
        return dat

    backend = dict(dtype=dat.dtype, device=dat.device)
    samp = torch.tensor((samp, ) * 3, **backend)

    # Build the smoothing kernels: fwhm = sqrt(max(samp^2 - vx^2, 0)) / vx
    vx = voxel_size(mat).to(dat.device).type(dat.dtype)
    squared_excess = torch.max(samp**2 - vx**2, torch.zeros(3, **backend))
    fwhm = torch.sqrt(squared_excess) / vx
    smo = smooth(('gauss', ) * 3, fwhm=fwhm, sep=True, **backend)

    # Half-width of each kernel along its own axis, so that the
    # un-padded convolutions below preserve the volume shape
    size_pad = tuple((smo[axis].shape[2 + axis] - 1) // 2
                     for axis in range(3))

    # Separable Gaussian smoothing, one axis at a time
    dat = pad(dat, size_pad, side='both')[None, None, ...]
    for axis in range(3):
        dat = F.conv3d(dat, smo[axis])
    return dat[0, 0, ...]
def get_pattern_codes(sampling_mask, kernel_size):
    """Compute the pattern's code about each voxel

    Parameters
    ----------
    sampling_mask : (*freq) tensor[bool]
    kernel_size : [sequence of] int

    Returns
    -------
    pattern_mask : (*freq) tensor[long]

    """
    ndim = sampling_mask.dim()
    kernel_size = py.make_list(kernel_size, ndim)
    half_widths = [(k - 1) // 2 for k in kernel_size]

    # pad by half a kernel on each side, then extract one window per voxel
    padded = utils.pad(sampling_mask.long(), half_widths, side='both')
    windows = utils.unfold(padded, kernel_size, stride=1)
    return pattern_to_code(windows, ndim)
def pool(dim, tensor, kernel_size=3, stride=None, dilation=1, padding=0,
         bound='constant', reduction='mean', ceil=False,
         return_indices=False, affine=None):
    """Perform a pooling

    Parameters
    ----------
    dim : {1, 2, 3}
        Number of spatial dimensions
    tensor : (*batch, *spatial_in) tensor
        Input tensor
    kernel_size : int or sequence[int], default=3
        Size of the pooling window
    stride : int or sequence[int], default=`kernel_size`
        Strides between output elements.
    dilation : int or sequence[int], default=1
        Strides between elements of the kernel.
    padding : 'same' or int or sequence[int], default=0
        Padding performed before the convolution.
        If 'same', the padding is chosen such that the shape of the
        output tensor is `floor(spatial_in / stride)` (or
        `ceil(spatial_in / stride)` if `ceil` is True).
    bound : str, default='constant'
        Boundary conditions used in the padding.
    reduction : {'mean', 'max', 'min', 'median', 'sum'} or callable, default='mean'
        Function to apply to the elements in a window.
        'avg' is accepted as an alias of 'mean'. A callable is applied
        to the windows through `_pool`, like the named reductions.
    ceil : bool, default=False
        Use ceil instead of floor to compute output shape
    return_indices : bool, default=False
        Return input index of the min/max/median element.
        For other types of reduction, return None.
    affine : (..., D+1, D+1) tensor, optional
        Input orientation matrix

    Returns
    -------
    pooled : (*batch, *spatial_out) tensor
    indices : (*batch, *spatial_out, dim) tensor, if `return_indices`
    affine : (..., D+1, D+1) tensor, if `affine`

    Raises
    ------
    ValueError
        If `reduction` is neither a known name nor a callable.

    """
    tensor = torch.as_tensor(tensor)

    # collapse all batch dimensions: (*batch, *spatial) -> (B, *spatial)
    batch = tensor.shape[:-dim]
    spatial_in = tensor.shape[-dim:]
    tensor = tensor.reshape([-1, *spatial_in])

    # compute padding
    kernel_size = make_list(kernel_size, dim)
    stride = make_list(stride or None, dim)
    stride = [st or ks for st, ks in zip(stride, kernel_size)]
    dilation = make_list(dilation or 1, dim)
    padding = compute_conv_padding(spatial_in, kernel_size, padding,
                                   dilation, stride, ceil)
    if ceil:
        # ceil mode cannot be obtained using unfold. we may need to
        # pad the input a bit more
        padding = _pad_for_ceil(spatial_in, kernel_size, padding,
                                stride, dilation)

    # The native poolings are only used for undilated mean/max in 1D-3D.
    use_torch = (reduction in ('mean', 'avg', 'max') and
                 dim in (1, 2, 3) and
                 dilation == [1] * dim)

    padding0 = padding  # kept to update the affine
    sum_padding = sum([sum(p) if isinstance(p, (list, tuple)) else p
                       for p in padding])
    if ((not use_torch) or
            (bound != 'zero' and sum_padding > 0) or
            any(isinstance(p, (list, tuple)) for p in padding)):
        # torch implementation -> handles zero-padding
        # our implementation -> needs explicit padding
        padding = _normalize_padding(padding)
        tensor = utils.pad(tensor, padding, bound, side='both',
                           value=_fill_value(reduction, tensor))
        padding = [0] * dim

    # `return_indices0` remembers that indices were requested for a
    # reduction that cannot provide them (None is returned instead).
    return_indices0 = False
    pool_fn = None

    if use_torch:
        if reduction in ('mean', 'avg'):
            return_indices0 = return_indices
            return_indices = False
            native = {1: F.avg_pool1d,
                      2: F.avg_pool2d,
                      3: F.avg_pool3d}[dim]
        else:
            native = {1: F.max_pool1d,
                      2: F.max_pool2d,
                      3: F.max_pool3d}[dim]
        # The native functions expect a channel dimension: insert a
        # singleton one before the call and drop it afterwards.
        pool_fn = lambda x, *a, **k: native(x[:, None], *a, **k,
                                            padding=padding)[:, 0]

    if pool_fn is None:
        # Generic implementation. Callables also take this route so that
        # they are applied to the elements of each window.
        if reduction not in ('min', 'max', 'median'):
            return_indices0 = return_indices
            return_indices = False
        if reduction in ('mean', 'avg'):
            reduction = lambda x: math.mean(x, dim=-1)
        elif reduction == 'sum':
            reduction = lambda x: math.sum(x, dim=-1)
        elif reduction == 'min':
            reduction = lambda x: math.min(x, dim=-1)
        elif reduction == 'max':
            reduction = lambda x: math.max(x, dim=-1)
        elif reduction == 'median':
            reduction = lambda x: math.median(x, dim=-1)
        elif not callable(reduction):
            raise ValueError(f'Unknown reduction {reduction}')
        pool_fn = lambda *a, **k: _pool(*a, **k, dilation=dilation,
                                        reduction=reduction)

    outputs = []
    if return_indices:
        tensor, ind = pool_fn(tensor, kernel_size, stride=stride)
        ind = utils.ind2sub(ind, stride)
        ind = utils.movedim(ind, 0, -1)
        outputs.append(ind)
    else:
        tensor = pool_fn(tensor, kernel_size, stride=stride)
        if return_indices0:
            outputs.append(None)

    # restore the batch dimensions
    spatial_out = tensor.shape[-dim:]
    tensor = tensor.reshape([*batch, *spatial_out])
    outputs = [tensor, *outputs]

    if affine is not None:
        affine, _ = affine_conv(affine, spatial_in,
                                kernel_size=kernel_size, stride=stride,
                                padding=padding0, dilation=dilation)
        outputs.append(affine)

    return outputs[0] if len(outputs) == 1 else tuple(outputs)
def _make_image(option, dim=None, device=None):
    """Load an image and build a Gaussian pyramid (if required)

    The image is loaded from `option.files`, then optionally masked,
    re-oriented, rescaled, padded and smoothed according to `option`,
    before being wrapped in a pyramid.

    Returns: ImagePyramid
    """

    def split_unit(values):
        # A trailing string is a unit; without one, the unit is 'vox'.
        if isinstance(values[-1], str):
            *values, unit = values
        else:
            unit = 'vox'
        return values, unit

    dat, mask, affine = _load_image(option.files, dim=dim, device=device,
                                    label=option.label)
    dim = dat.dim() - 1

    # --- explicit mask (combined with the mask of the image, if any) ---
    if option.mask:
        image_mask = mask
        mask, _, _ = _load_image([option.mask], dim=dim, device=device,
                                 label=option.label)
        if mask.shape[-dim:] != dat.shape[-dim:]:
            raise ValueError('Mask should have the same shape as the image. '
                             f'Got {mask.shape[-dim:]} and {dat.shape[-dim:]}')
        if image_mask is not None:
            mask = mask * image_mask
        del image_mask

    # --- orientation matrix ---
    if option.world:
        # overwrite orientation matrix
        affine = io.transforms.map(option.world).fdata().squeeze()
    for transform in (option.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        affine = spatial.affine_lmdiv(transform, affine)

    # --- intensity rescaling ---
    if not option.discretize and any(option.rescale):
        dat = _rescale_image(dat, option.rescale)

    # --- padding ---
    if option.pad:
        padsize, unit = split_unit(option.pad)
        if unit == 'mm':
            # convert millimetres to a whole number of voxels
            vx = spatial.voxel_size(affine)
            padsize = torch.as_tensor(padsize, **utils.backend(vx))
            padsize = padsize / vx
            padsize = padsize.floor().int().tolist()
        else:
            padsize = [int(p) for p in padsize]
        padsize = py.make_list(padsize, dim)
        if any(padsize):
            affine, _ = spatial.affine_pad(affine, dat.shape[-dim:], padsize,
                                           side='both')
            dat = utils.pad(dat, padsize, side='both', mode=option.bound)
            if mask is not None:
                mask = utils.pad(mask, padsize, side='both',
                                 mode=option.bound)

    # --- smoothing ---
    if option.fwhm:
        fwhm, unit = split_unit(option.fwhm)
        if unit == 'mm':
            vx = spatial.voxel_size(affine)
            fwhm = torch.as_tensor(fwhm, **utils.backend(vx))
            fwhm = fwhm / vx
        dat = spatial.smooth(dat, dim=dim, fwhm=fwhm, bound=option.bound)

    # --- pyramid ---
    image = objects.ImagePyramid(dat, levels=option.pyramid, affine=affine,
                                 dim=dim, bound=option.bound, mask=mask,
                                 extrapolate=option.extrapolate,
                                 method=option.pyramid_method)

    # --- quantization (the original data is kept as a preview) ---
    if getattr(option, 'soft_quantize', False) and len(image[0].dat) == 1:
        for level in image:
            level.preview = level.dat
            level.dat = _soft_quantize_image(level.dat, option.soft_quantize)
    elif not option.label and option.discretize:
        for level in image:
            level.preview = level.dat
            level.dat = _discretize_image(level.dat, option.discretize)
    return image
def _hist_2d(img0, img1, mx_int, fwhm):
    """Make 2D histogram, requires:
        * Images same size.
        * Images same min and max intensities (non-negative).

    Parameters
    ----------
    img0 : (X, Y, Z) tensor_like
        First image volume.
    img1 : (X, Y, Z) tensor_like
        Second image volume.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.

    Returns
    ----------
    H : (mx_int + 1, mx_int + 1) tensor_like
        Joint intensity histogram, smoothed, clamped to be non-negative
        and offset by 1e-7.

    Notes
    ----------
    Naive method for computing a 2D histogram:
    h = torch.zeros((mx_int + 1, mx_int + 1))
    for n in range(num_vox):
        h[img0[n], img1[n]] += 1

    The linear index used here is `img0 + (mx_int + 1) * img1`, so before
    smoothing, the count of the pair (img0[n], img1[n]) is stored at
    `H[img1[n], img0[n]]`.

    """
    fwhm = (fwhm, ) * 2
    nb_bins = mx_int + 1

    # Convert each 'coordinate' of intensities to an index
    # (replicates the sub2ind function of MATLAB)
    img0 = img0.flatten().floor()
    img1 = img1.flatten().floor()
    sub = torch.stack((img0, img1), dim=1)  # (num_vox, 2)
    to_ind = torch.tensor((1, nb_bins), dtype=sub.dtype,
                          device=img0.device)[..., None]  # (2, 1)
    ind = torch.tensordot(sub, to_ind, dims=([1], [0]))  # (nvox, 1)

    # Build histogram H by adding up counts according to the indices in ind
    H = torch.zeros(nb_bins, nb_bins, device=img0.device, dtype=ind.dtype)
    H.put_(ind.long(),
           torch.ones(1, device=img0.device, dtype=ind.dtype).expand_as(ind),
           accumulate=True)

    # Smoothing kernel. It must share the dtype of the histogram: the
    # convolutions below require input and weights of the same dtype.
    smo = smooth(('gauss', ) * 2, fwhm=fwhm,
                 device=img0.device, dtype=H.dtype, sep=True)

    # Pad by half a kernel so that the convolutions preserve the shape
    p = (smo[0].shape[2], smo[1].shape[3])
    p = (torch.tensor(p) - 1) // 2
    p = tuple(p.int().tolist())
    H = pad(H, p, side='both')

    # Smooth (separable convolution)
    H = H[None, None, ...]
    H = F.conv2d(H, smo[0])
    H = F.conv2d(H, smo[1])
    H = H[0, 0, ...]

    # Clamp
    H = H.clamp_min(0.0)

    # Add eps
    H = H + 1e-7

    return H
def pool(dim, tensor, kernel_size=3, stride=None, dilation=1, padding=0,
         bound='zero', reduction='mean', return_indices=False, affine=None):
    """Perform a pooling

    Parameters
    ----------
    dim : {1, 2, 3}
        Number of spatial dimensions
    tensor : (*batch, *spatial_in) tensor
        Input tensor
    kernel_size : int or sequence[int], default=3
        Size of the pooling window
    stride : int or sequence[int], default=`kernel_size`
        Strides between output elements.
    dilation : int or sequence[int], default=1
        Strides between elements of the kernel.
    padding : 'auto' or int or sequence[int], default=0
        Padding performed before the convolution.
        If 'auto', the padding is chosen such that the shape of the
        output tensor is `spatial_in // stride`.
    bound : str, default='zero'
        Boundary conditions used in the padding.
    reduction : {'mean', 'max', 'min', 'median', 'sum'} or callable, default='mean'
        Function to apply to the elements in a window.
        'avg' is accepted as an alias of 'mean'. A callable is applied
        to the windows through `_pool`, like the named reductions.
    return_indices : bool, default=False
        Return input index of the min/max/median element.
        For other types of reduction, return None.
    affine : (..., D+1, D+1) tensor, optional
        Input orientation matrix

    Returns
    -------
    pooled : (*batch, *spatial_out) tensor
    indices : (*batch, *spatial_out, dim) tensor, if `return_indices`
    affine : (..., D+1, D+1) tensor, if `affine`

    Raises
    ------
    ValueError
        If automatic padding is requested for an even-sized kernel, if
        native mean pooling is requested with a dilation different
        from one, or if `reduction` is neither a known name nor a
        callable.

    """
    tensor = torch.as_tensor(tensor)

    # collapse all batch dimensions: (*batch, *spatial) -> (B, *spatial)
    batch = tensor.shape[:-dim]
    spatial_in = tensor.shape[-dim:]
    tensor = tensor.reshape([-1, *spatial_in])

    # Perform padding
    kernel_size = make_list(kernel_size, dim)
    stride = make_list(stride or None, dim)
    stride = [st or ks for st, ks in zip(stride, kernel_size)]
    dilation = make_list(dilation or 1, dim)
    padding = make_list(padding, dim)
    # Kept to update the affine. This is the same list object, so 'auto'
    # entries resolved just below are also visible through `padding0`.
    padding0 = padding
    for i in range(dim):
        if isinstance(padding[i], str) and padding[i].lower() == 'auto':
            if kernel_size[i] % 2 == 0:
                raise ValueError('Cannot compute automatic padding '
                                 'for even-sized kernels.')
            padding[i] = ((kernel_size[i] - 1) * dilation[i] + 1) // 2

    use_torch = reduction in ('mean', 'avg', 'max') and dim in (1, 2, 3)
    if (not use_torch) or (bound != 'zero' and sum(padding) > 0):
        # torch implementation -> handles zero-padding
        # our implementation -> needs explicit padding
        tensor = utils.pad(tensor, padding, bound, side='both')
        padding = [0] * dim

    # `return_indices0` remembers that indices were requested for a
    # reduction that cannot provide them (None is returned instead).
    return_indices0 = False
    pool_fn = None

    if use_torch:
        # The native functions expect a channel dimension: insert a
        # singleton one before the call and drop it afterwards.
        if reduction in ('mean', 'avg'):
            # Native average pooling has no `dilation` argument: fail
            # explicitly rather than forwarding an unsupported keyword.
            if any(d != 1 for d in dilation):
                raise ValueError('Mean pooling does not support '
                                 'dilations other than 1.')
            return_indices0 = return_indices
            return_indices = False
            avg_pool = {1: F.avg_pool1d,
                        2: F.avg_pool2d,
                        3: F.avg_pool3d}[dim]
            pool_fn = lambda x, *a, **k: avg_pool(
                x[:, None], *a, **k, padding=padding)[:, 0]
        else:
            max_pool = {1: F.max_pool1d,
                        2: F.max_pool2d,
                        3: F.max_pool3d}[dim]
            pool_fn = lambda x, *a, **k: max_pool(
                x[:, None], *a, **k, padding=padding,
                dilation=dilation)[:, 0]

    if pool_fn is None:
        # Generic implementation. Callables also take this route so that
        # they are applied to the elements of each window.
        # NOTE(review): `dilation` is not forwarded to `_pool` here --
        # confirm whether `_pool` supports it.
        if reduction not in ('min', 'max', 'median'):
            return_indices0 = return_indices
            return_indices = False
        if reduction in ('mean', 'avg'):
            reduction = lambda x: math.mean(x, dim=-1)
        elif reduction == 'sum':
            reduction = lambda x: math.sum(x, dim=-1)
        elif reduction == 'min':
            reduction = lambda x: math.min(x, dim=-1)
        elif reduction == 'max':
            reduction = lambda x: math.max(x, dim=-1)
        elif reduction == 'median':
            reduction = lambda x: math.median(x, dim=-1)
        elif not callable(reduction):
            raise ValueError(f'Unknown reduction {reduction}')
        pool_fn = lambda *a, **k: _pool(*a, **k, reduction=reduction)

    outputs = []
    if return_indices:
        tensor, ind = pool_fn(tensor, kernel_size, stride=stride)
        ind = utils.ind2sub(ind, stride)
        ind = utils.movedim(ind, 0, -1)
        outputs.append(ind)
    else:
        tensor = pool_fn(tensor, kernel_size, stride=stride)
        if return_indices0:
            outputs.append(None)

    # restore the batch dimensions
    spatial_out = tensor.shape[-dim:]
    tensor = tensor.reshape([*batch, *spatial_out])
    outputs = [tensor, *outputs]

    if affine is not None:
        affine, _ = affine_conv(affine, spatial_in,
                                kernel_size=kernel_size, stride=stride,
                                padding=padding0, dilation=dilation)
        outputs.append(affine)

    return outputs[0] if len(outputs) == 1 else tuple(outputs)