def __init__(self, image, regulariser, geodist, u_init=None,
             segmentation_threshold=0.5, c=None):
    """Set up the segmentation state for a single image.

    Parameters:
        'image': Input image. It must support ``np.array(image)`` and
            ``image.getbands()`` (presumably a PIL image - confirm). Pixel
            values are divided by 255.
        'regulariser' (callable): Regulariser applied to u during the
            gradient steps.
        'geodist' (array-like): Geodesic distance map, stored as 'D_M'.
        'u_init' (array-like, optional): Initial "characteristic function"
            u. If None, u is drawn uniformly at random from [0, 1).
        'segmentation_threshold' (float): Threshold used when computing the
            mean colours c from u.
        'c' (tuple, optional): Mean colours (inside, outside). If given, it
            is used as-is. If None, it is computed from 'u_init' when that
            is given, and defaults to (0, 1) otherwise.
    """
    # Divide by 255, i.e. map 8-bit pixel values to [0, 1].
    self._image_arr = torch.Tensor(np.array(image, dtype=float) / 255)
    self.channels = len(image.getbands())
    # Multi-band images carry a trailing channel axis in the array
    # (presumably [height, width, channels] - confirm). The spatial shape
    # used for u excludes it.
    full_shape = self._image_arr.shape
    self.image_shape = full_shape[:-1] if self.channels > 1 else full_shape
    self._dim = len(self.image_shape)
    self.segmentation_threshold = segmentation_threshold
    self.regulariser = regulariser

    if u_init is None:
        self.u = torch.rand(size=self.image_shape)
    else:
        self.u = torch.Tensor(u_init)

    # An explicitly passed 'c' always wins. Only derive it from u when the
    # caller supplied an initial u but no colours.
    if c is not None:
        self.c = c
    elif u_init is not None:
        self.c = cv.get_segmentation_mean_colours(
            self.u, self._image_arr, self.segmentation_threshold)
    else:
        self.c = (0, 1)

    self.D_M = torch.Tensor(geodist)
def update_c(self):
    """Recompute 'self.c', the mean colours inside the current segmentation
    of 'self.u' and in its complement.

    See 'get_segmentation_mean_colours' for how the colours are computed.
    If that computation raises a RuntimeError, 'self.c' is instead replaced
    by a pair of random values drawn uniformly from [0, 1).
    """
    try:
        mean_colours = cv.get_segmentation_mean_colours(
            self.u, self._image_arr, self.segmentation_threshold)
    except RuntimeError:
        # Best-effort fallback: reinitialise the colours at random rather
        # than aborting the optimisation.
        mean_colours = tuple(np.random.rand(2))
    self.c = mean_colours
def single_step(self, lmb_reg=1, epsilon=0.1, gamma=1):
    """Perform a single projected gradient-descent step for 'self.u'.

    The differentiated energy is

        CEN data-fitting term
        + gamma * sum(D_M * u)
        + lmb_reg * regulariser(u - 0.5),

    where the regulariser receives u with two leading singleton dimensions
    added. After the gradient step, u is clipped to lie in [0, 1].

    Parameters:
        'lmb_reg' (float): Weight for the regulariser.
        'epsilon' (float): Step size for gradient descent.
        'gamma' (float): Weight for the geodesic-distance term
            sum(D_M * u).
    """
    # Differentiate w.r.t. a fresh leaf copy of u. Setting requires_grad
    # directly on self.u would raise if it carried autograd history, and it
    # would also flip the flag on a tensor the caller may still hold.
    u = self.u.detach().requires_grad_(True)
    # D_M is a fixed weighting map; only u is optimised.
    self.D_M.requires_grad = False

    # Build the geodesic term as its own quantity and add it explicitly, so
    # it is unambiguously part of the data-fitting energy.
    geodesic_term = gamma * torch.sum(self.D_M * u)
    data_fitting = cv.CEN_data_fitting_energy(
        u, self.c[0], self.c[1], self._image_arr) + geodesic_term
    reg = self.regulariser(u.unsqueeze(0).unsqueeze(0) - 0.5)
    error = data_fitting + lmb_reg * reg

    (gradient,) = torch.autograd.grad(error, u)
    # Gradient step followed by projection onto [0, 1].
    self.u = torch.clamp(u.detach() - epsilon * gradient, min=0.0, max=1.0)
def data_fitting_penalty(
    chanvese_batch,
    noisy_batch,
    lambda_chanvese=1,
    threshold=0.5,
    c1=None,
    c2=None,
    alpha=None,
):
    """Calculate the Chan-Esedoglu-Nikolova data-fitting term plus a penalty
    term that penalises values of u outside [0, 1].

    The penalty is max(0, 2|u - 1/2| - 1) per pixel, which is zero exactly
    on [0, 1].

    Parameters:
        'chanvese_batch' (Tensor): Minibatch of the "characteristic
            functions" ("u"). Expected shape is [batchsize, 1, height, width].
        'noisy_batch' (Tensor): Minibatch of original images. Expected shape
            is [batchsize, 1, height, width].
        'lambda_chanvese' (float): Weight of the Chan-Vese data term.
        'threshold' (float): Segmentation threshold, used when calculating
            c1, c2 from chanvese_batch.
        'c1', 'c2' (float or Tensor, optional): Mean colours inside/outside
            the segmentation. Each is either a scalar or one value per sample
            (shape [batchsize]). If either is None, both are estimated from
            chanvese_batch without gradient tracking.
        'alpha' (float or Tensor, optional): Positive constant controlling
            the strength of the penalty. It is either a scalar or one value
            per sample. If None, the per-sample supremum norm of the
            (detached) Chan-Vese term is used.

    Returns:
        Tensor of shape [batchsize]: the data-fitting term of each sample.
    """
    assert chanvese_batch.size() == noisy_batch.size(), \
        "chanvese_batch and noisy_batch must have the same shape"
    # Require greyscale images, i.e. only one channel.
    assert chanvese_batch.size(1) == 1, "expected a single channel"
    batchsize = chanvese_batch.size(0)

    def _per_sample(value):
        # Turn a scalar or a [batchsize] vector into a tensor that broadcasts
        # per sample against [batchsize, 1, height, width]. A plain
        # unsqueeze(1) would give [batchsize, 1], which aligns with
        # [height, width] instead.
        return torch.as_tensor(
            value, dtype=noisy_batch.dtype, device=noisy_batch.device
        ).reshape(-1, 1, 1, 1)

    if c1 is None or c2 is None:
        c1 = torch.zeros(batchsize, dtype=noisy_batch.dtype,
                         device=noisy_batch.device)
        c2 = torch.zeros_like(c1)
        # Estimate c1, c2 from u without backpropagating through them.
        # Copying differentiable values into c1/c2 would otherwise record
        # the copy in the autograd graph.
        with torch.no_grad():
            for i in range(batchsize):
                # NOTE(review): threshold is passed positionally, matching
                # the 3-argument get_segmentation_mean_colours call used
                # elsewhere; confirm the signature.
                c1[i], c2[i] = ChanVese.get_segmentation_mean_colours(
                    chanvese_batch[i], noisy_batch[i], threshold)

    c1, c2 = _per_sample(c1), _per_sample(c2)
    chanvese_term = lambda_chanvese * (
        (noisy_batch - c1).square() - (noisy_batch - c2).square()
    )  # [batchsize, 1, height, width]

    # REVIEW: Better to drop the [0, 1] penalty and just clip?
    # alpha is not meant to be backpropagated along when it is derived
    # implicitly, hence the .detach() below.
    if alpha is None:
        # Supremum norm of each sample (--> [batchsize]).
        alpha = chanvese_term.detach().abs().flatten(start_dim=1).max(dim=1)[0]
    alpha = _per_sample(alpha)

    # max(0, 2|u - 1/2| - 1): zero on [0, 1], linear growth outside.
    penalty_term = torch.relu(
        2 * (chanvese_batch - 0.5).abs() - 1)  # [batchsize, 1, height, width]

    # The integral over the domain is approximated by the mean. This only
    # rescales the term (presumably absorbed into the regularisation weight
    # by callers - confirm).
    return (chanvese_term * chanvese_batch
            + alpha * penalty_term).mean((1, 2, 3))  # [batchsize]