def __init__(self,
             ndim,
             img_size=192,
             cps=5,
             svf=False,
             svf_steps=7,
             svf_scale=1):
    """
    Compute the dense displacement field of a Cubic B-spline FFD transformation model
    from input control point parameters.

    Args:
        ndim: (int) image dimension
        img_size: (int or tuple) size of the image
        cps: (int or tuple) control point spacing in number of intervals between pixel/voxel centres
        svf: (bool) stationary velocity field formulation if True
        svf_steps: (int) number of scaling-and-squaring integration steps (used if svf=True)
        svf_scale: (float) scaling applied to the velocity field (used if svf=True)
    """
    super(CubicBSplineFFDTransform, self).__init__(svf=svf,
                                                   svf_steps=svf_steps,
                                                   svf_scale=svf_scale)
    self.ndim = ndim
    self.img_size = param_ndim_setup(img_size, self.ndim)
    self.stride = param_ndim_setup(cps, self.ndim)

    self.kernels = self.set_kernel()
    self.padding = [(len(k) - 1) // 2
                    for k in self.kernels]  # the kernel size is always an odd number
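# Illustrative sketch (not part of the original module): how the transform above might be
# instantiated for a 3D image. Only attributes set in __init__ are inspected here; the forward
# pass that turns control point parameters into a dense displacement field is defined elsewhere
# in the class hierarchy.
def _example_ffd_transform():
    transform = CubicBSplineFFDTransform(ndim=3,
                                         img_size=(176, 192, 176),
                                         cps=(5, 5, 5),
                                         svf=False)
    print(transform.stride)   # control point spacing per dimension, e.g. (5, 5, 5)
    print(transform.padding)  # per-dimension padding derived from the odd-sized B-spline kernels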
def __init__(self,
             ndim,
             enc_channels=(16, 32, 32, 32, 32),
             dec_channels=(32, 32, 32, 32),
             resize_channels=(32, 32),
             cps=(5, 5, 5),
             img_size=(176, 192, 176)):
    """
    Network to parameterise the Cubic B-spline transformation
    """
    super(CubicBSplineNet, self).__init__(ndim=ndim,
                                          enc_channels=enc_channels,
                                          conv_before_out=False)

    # determine and set output control point sizes from image size and control point spacing
    img_size = param_ndim_setup(img_size, ndim)
    cps = param_ndim_setup(cps, ndim)
    for i, c in enumerate(cps):
        if c > 8 or c < 2:
            raise ValueError(f"Control point spacing ({c}) at dim ({i}) not supported, must be within [2, 8]")
    self.output_size = tuple([int(math.ceil((imsz - 1) / c) + 1 + 2)
                              for imsz, c in zip(img_size, cps)])

    # Network:
    # encoder: same as the U-Net encoder
    # decoder: the number of decoder layers (i.e. times of upsampling by 2) is decided by cps
    num_dec_layers = 4 - int(math.ceil(math.log2(min(cps))))
    self.dec = self.dec[:num_dec_layers]

    # conv layers following resizing
    self.resize_conv = nn.ModuleList()
    for i in range(len(resize_channels)):
        if i == 0:
            if num_dec_layers > 0:
                in_ch = dec_channels[num_dec_layers - 1] + enc_channels[-num_dec_layers]
            else:
                in_ch = enc_channels[-1]
        else:
            in_ch = resize_channels[i - 1]
        out_ch = resize_channels[i]
        self.resize_conv.append(nn.Sequential(convNd(ndim, in_ch, out_ch, a=0.2),
                                              nn.LeakyReLU(0.2)))

    # final prediction layer
    delattr(self, 'out_layers')  # remove U-Net output layers
    self.out_layer = convNd(ndim, resize_channels[-1], ndim)
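# Illustrative check (not part of the original module) of the control-point grid size and decoder
# depth formulas above, using the default img_size=(176, 192, 176) and cps=(5, 5, 5). Assumes the
# math module is imported at module level, as it is used in the code above.
#   output_size[d]  = ceil((img_size[d] - 1) / cps[d]) + 1 + 2  ->  (38, 42, 38)
#   num_dec_layers  = 4 - ceil(log2(min(cps)))                  ->  4 - 3 = 1
def _example_bspline_grid_size(img_size=(176, 192, 176), cps=(5, 5, 5)):
    output_size = tuple(int(math.ceil((imsz - 1) / c) + 1 + 2) for imsz, c in zip(img_size, cps))
    num_dec_layers = 4 - int(math.ceil(math.log2(min(cps))))
    assert output_size == (38, 42, 38) and num_dec_layers == 1
    return output_size, num_dec_layers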
def bbox_from_mask(mask, pad_ratio=0.2):
    """
    Find the bounding box indices of a mask (positive values > 0).
    The output indices can be directly used for slicing.
    - for 2D, find the largest bounding box out of the N masks
    - for 3D, find the bounding box of the volume mask

    Args:
        mask: (numpy.ndarray, shape (N, H, W) or (N, H, W, D))
        pad_ratio: (float or tuple) ratio of the distance between the mask bounding box
            and the image boundary to pad

    Returns:
        bbox: (list of tuples) [*(bbox_min_index, bbox_max_index)]
        bbox_mask: (numpy.ndarray, shape (N, mH, mW) or (N, mH, mW, mD)) binary mask of the bounding box
    """
    dim = mask.ndim - 1
    mask_shape = mask.shape[1:]
    pad_ratio = list(param_ndim_setup(pad_ratio, dim))  # as a mutable list so invalid values can be clamped

    # find non-zero locations in the mask
    nonzero_indices = np.nonzero(mask > 0)
    bbox = [(nonzero_indices[i + 1].min(), nonzero_indices[i + 1].max())
            for i in range(dim)]

    # pad pad_ratio of the minimum distance
    # from mask bounding box to the image boundaries (half each side)
    for i in range(dim):
        if pad_ratio[i] > 1:
            print(f"Invalid padding value (>1) on dimension {i}, set to 1")
            pad_ratio[i] = 1
    bbox_padding = [pad_ratio[i] * min(bbox[i][0], mask_shape[i] - bbox[i][1])
                    for i in range(dim)]

    # "pad" by modifying the bounding box indices
    bbox = [(bbox[i][0] - int(bbox_padding[i] / 2), bbox[i][1] + int(bbox_padding[i] / 2))
            for i in range(dim)]

    # bbox mask
    bbox_mask = np.zeros(mask.shape, dtype=np.float32)
    slicer = [slice(0, mask.shape[0])]  # all slices/batch
    for i in range(dim):
        slicer.append(slice(*bbox[i]))
    bbox_mask[tuple(slicer)] = 1.0
    return bbox, bbox_mask
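# Illustrative sketch (not part of the original module): bounding box of a small synthetic 2D
# mask. Assumes numpy is imported as np and param_ndim_setup is available, as in the code above.
def _example_bbox_from_mask():
    mask = np.zeros((1, 8, 8), dtype=np.float32)
    mask[0, 2:5, 3:6] = 1.0  # foreground block spanning rows 2-4, cols 3-5
    bbox, bbox_mask = bbox_from_mask(mask, pad_ratio=0.2)
    # with this small mask the padding rounds down to zero, so bbox comes back as
    # [(2, 4), (3, 5)] and bbox_mask is one inside bbox_mask[0, 2:4, 3:5]
    return bbox, bbox_mask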
def forward(self, tar, src):
    # products and squares
    tar2 = tar * tar
    src2 = src * src
    tar_src = tar * src

    # set window size
    ndim = tar.dim() - 2
    window_size = param_ndim_setup(self.window_size, ndim)

    # summation filter for convolution
    sum_filt = torch.ones(1, 1, *window_size).type_as(tar)

    # set stride and padding
    stride = (1,) * ndim
    padding = tuple([math.floor(window_size[i] / 2) for i in range(ndim)])

    # get the convolution function of the correct dimension
    conv_fn = getattr(F, f'conv{ndim}d')

    # summing over the window by convolution
    tar_sum = conv_fn(tar, sum_filt, stride=stride, padding=padding)
    src_sum = conv_fn(src, sum_filt, stride=stride, padding=padding)
    tar2_sum = conv_fn(tar2, sum_filt, stride=stride, padding=padding)
    src2_sum = conv_fn(src2, sum_filt, stride=stride, padding=padding)
    tar_src_sum = conv_fn(tar_src, sum_filt, stride=stride, padding=padding)

    window_num_points = np.prod(window_size)
    mu_tar = tar_sum / window_num_points
    mu_src = src_sum / window_num_points

    cov = tar_src_sum - mu_src * tar_sum - mu_tar * src_sum + mu_tar * mu_src * window_num_points
    tar_var = tar2_sum - 2 * mu_tar * tar_sum + mu_tar * mu_tar * window_num_points
    src_var = src2_sum - 2 * mu_src * src_sum + mu_src * mu_src * window_num_points

    lncc = cov * cov / (tar_var * src_var + 1e-5)

    return -torch.mean(lncc)
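# Illustrative sketch (not part of the original module): the all-ones convolution above computes
# per-voxel local sums, so dividing by the window volume yields local means. Assumes torch,
# torch.nn.functional as F, and numpy as np are imported as in the rest of this file.
def _example_windowed_mean(window=(7, 7, 7)):
    img = torch.rand(1, 1, 32, 32, 32)                 # (N, C, H, W, D)
    sum_filt = torch.ones(1, 1, *window).type_as(img)  # summation kernel
    padding = tuple(w // 2 for w in window)
    local_sum = F.conv3d(img, sum_filt, stride=(1, 1, 1), padding=padding)
    local_mean = local_sum / float(np.prod(window))    # same spatial shape as img
    return local_mean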
def forward(self, x, y):
    # products and squares
    xsq = x * x
    ysq = y * y
    xy = x * y

    # set window size
    ndim = x.dim() - 2
    window_size = param_ndim_setup(self.window_size, ndim)

    # summation filter for convolution
    sum_filt = torch.ones(1, 1, *window_size).type_as(x)

    # set stride and padding
    stride = (1,) * ndim
    padding = tuple([math.floor(window_size[i] / 2) for i in range(ndim)])

    # get the convolution function of the correct dimension
    conv_fn = getattr(F, f'conv{ndim}d')

    # summing over the window by convolution
    x_sum = conv_fn(x, sum_filt, stride=stride, padding=padding)
    y_sum = conv_fn(y, sum_filt, stride=stride, padding=padding)
    xsq_sum = conv_fn(xsq, sum_filt, stride=stride, padding=padding)
    ysq_sum = conv_fn(ysq, sum_filt, stride=stride, padding=padding)
    xy_sum = conv_fn(xy, sum_filt, stride=stride, padding=padding)

    window_num_points = np.prod(window_size)
    x_mu = x_sum / window_num_points
    y_mu = y_sum / window_num_points

    cov = xy_sum - y_mu * x_sum - x_mu * y_sum + x_mu * y_mu * window_num_points
    x_var = xsq_sum - 2 * x_mu * x_sum + x_mu * x_mu * window_num_points
    y_var = ysq_sum - 2 * y_mu * y_sum + y_mu * y_mu * window_num_points

    lncc = cov * cov / (x_var * y_var + 1e-5)

    return -torch.mean(lncc)
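# Illustrative sanity check (not part of the original module): for identical inputs the local
# squared correlation is ~1 everywhere, so the loss above should return a value close to -1.
# `loss_fn` stands in for an instance of whichever loss class owns this forward().
def _example_lncc_identity(loss_fn, image_size=(1, 1, 32, 32, 32)):
    x = torch.rand(*image_size)
    loss = loss_fn(x, x.clone())
    assert abs(loss.item() + 1.0) < 1e-2
    return loss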
def crop_and_pad(x, new_size=192, mode="constant", **kwargs): """ Crop and/or pad input to new size. (Adapted from DLTK: https://github.com/DLTK/DLTK/blob/master/dltk/io/preprocessing.py) Args: x: (np.ndarray) input array, shape (N, H, W) or (N, H, W, D) new_size: (int or tuple/list) new size excluding the batch size mode: (string) padding value filling mode for numpy.pad() (compulsory in Numpy v1.18) kwargs: additional arguments to be passed to np.pad Returns: (np.ndarray) cropped and/or padded input array """ assert isinstance(x, (np.ndarray, np.generic)) new_size = param_ndim_setup(new_size, ndim=x.ndim - 1) dim = x.ndim - 1 sizes = x.shape[1:] # Initialise padding and slicers to_padding = [[0, 0] for i in range(x.ndim)] slicer = [slice(0, x.shape[i]) for i in range(x.ndim)] # For each dimensions except the dim 0, set crop slicers or paddings for i in range(dim): if sizes[i] < new_size[i]: to_padding[i + 1][0] = (new_size[i] - sizes[i]) // 2 to_padding[i + 1][1] = new_size[i] - sizes[i] - to_padding[i + 1][0] else: # Create slicer object to crop each dimension crop_start = int(np.floor((sizes[i] - new_size[i]) / 2.)) crop_end = crop_start + new_size[i] slicer[i + 1] = slice(crop_start, crop_end) return np.pad(x[tuple(slicer)], to_padding, mode=mode, **kwargs)