Ejemplo n.º 1
0
    def __init__(self,
                 image,
                 regulariser,
                 geodist,
                 u_init=None,
                 segmentation_threshold=0.5,
                 c=None):
        """Set up the segmentation problem for 'image'.

        Parameters:

        'image': Input image. It is converted with np.array and scaled by
        1/255, and its channel count is read via 'getbands()' (presumably a
        PIL image -- confirm).

        'regulariser' (callable): Regulariser, stored as 'self.regulariser'.

        'geodist' (array-like): Geodesic distance map, stored as the tensor
        'self.D_M'.

        'u_init' (array-like or None): Initial segmentation function. If None,
        u is drawn uniformly at random from [0, 1).

        'segmentation_threshold' (float): Threshold used when computing the
        mean colours.

        'c' (tuple or None): Initial mean colours, defaulting to (0, 1).
        Ignored when 'u_init' is given, because c is then recomputed from it.
        """
        self._image_arr = torch.Tensor(np.array(image, dtype=float) / 255)
        self.image_shape = self._image_arr.shape
        self.channels = len(image.getbands())
        if self.channels > 1:
            # Drop the trailing channel axis so image_shape is spatial only.
            self.image_shape = self.image_shape[:-1]
        self._dim = len(self.image_shape)
        self.segmentation_threshold = segmentation_threshold
        self.c = (0, 1) if c is None else c

        self.regulariser = regulariser

        if u_init is None:
            self.u = torch.rand(size=self.image_shape)
        else:
            self.u = torch.Tensor(u_init)
            # With a given initial segmentation, derive c from it instead of
            # using the 'c' argument.
            self.c = cv.get_segmentation_mean_colours(
                self.u, self._image_arr, self.segmentation_threshold)

        # Stored on every path: the geodesic term needs D_M regardless of how
        # u was initialised (it used to be set only when u_init was given).
        self.D_M = torch.Tensor(geodist)
Ejemplo n.º 2
0
 def update_c(self):
     """Refresh 'self.c', the average colours inside the segmentation domain
     and inside its complement, from the current 'self.u'.

     The work is delegated to 'cv.get_segmentation_mean_colours' (see there
     for details). If that call raises RuntimeError, 'self.c' is instead set
     to a tuple of two random numbers drawn uniformly from [0, 1).
     """
     try:
         mean_colours = cv.get_segmentation_mean_colours(
             self.u, self._image_arr, self.segmentation_threshold)
     except RuntimeError:
         # Deliberate best-effort fallback: random colours instead of
         # propagating the error.
         mean_colours = tuple(np.random.rand(2))
     self.c = mean_colours
Ejemplo n.º 3
0
    def single_step(self, lmb_reg=1, epsilon=0.1, gamma=1):
        """Performs a single gradient descent step for 'self.u'.

        The energy is the CEN data-fitting term, plus the geodesic term
        'gamma * sum(D_M * u)', plus 'lmb_reg' times the regulariser term.
        After the gradient step, u is clipped in order to lie in [0,1].

        Parameters:

        'lmb_reg' (float): Weight for the regulariser.

        'epsilon' (float): Step size for gradient descent.

        'gamma' (float): Weight for the geodesic-distance term.
        """

        self.u.requires_grad = True
        # D_M is a fixed input; no gradient is needed for it.
        self.D_M.requires_grad = False

        data_fitting = cv.CEN_data_fitting_energy(self.u, self.c[0], self.c[1],
                                                  self._image_arr)
        # Computed as its own term and added explicitly below. Previously a
        # line break turned it into a discarded '+expr' statement, so gamma
        # had no effect.
        geodesic_term = gamma * torch.sum(self.D_M * self.u)
        # Adds two leading singleton axes and shifts u by 0.5 before calling
        # the regulariser (presumably it expects an [N, C, H, W] input centred
        # at 0 -- confirm).
        reg = self.regulariser(self.u.unsqueeze(0).unsqueeze(0) - 0.5)
        error = data_fitting + geodesic_term + lmb_reg * reg

        gradients = torch.autograd.grad(error, self.u)[0]
        self.u = (self.u - epsilon * gradients).detach()
        self.u = torch.clamp(self.u, min=0.0, max=1.0)
Ejemplo n.º 4
0
def _broadcast_per_sample(value, reference):
    """Reshape a scalar or per-sample value so it broadcasts against 'reference'.

    'value' may be a Python number, a one-element tensor, or a tensor with one
    entry per sample. The result has shape [n, 1, ..., 1] (n = 1 or
    batchsize), with the dtype and device of 'reference'.
    """
    value = torch.as_tensor(value, dtype=reference.dtype,
                            device=reference.device)
    return value.reshape(-1, *([1] * (reference.dim() - 1)))


def data_fitting_penalty(
    chanvese_batch,
    noisy_batch,
    lambda_chanvese=1,
    threshold=0.5,
    c1=None,
    c2=None,
    alpha=None,
):
    """Calculates the data-fitting term for the Chan-Esedoglu-Nikolova functional
    and adds a penalty term, which penalises values outside [0,1].

    Parameters:

    'chanvese_batch' (Tensor): Minibatch of the "characteristic functions"
    ("u"). Expected shape is [batchsize, 1, height, width].

    'noisy_batch' (Tensor): Minibatch of original images.
    Expected shape is [batchsize, 1, height, width].

    'lambda_chanvese' (float): Weight of the Chan-Vese data term.

    'threshold' (float): Segmentation threshold, used when c1, c2 are
    estimated from chanvese_batch.

    'c1', 'c2' (float or Tensor of shape [batchsize], optional): Mean colours
    inside/outside the segmentation. If either is None, both are estimated
    per sample, without backpropagation through the estimate.

    'alpha' (float or Tensor of shape [batchsize], optional): Positive
    constant controlling the strength of the penalty. If None, the supremum
    norm of each sample's Chan-Vese term is used (detached).

    Returns:

    Tensor of shape [batchsize]: the data-fitting energy of each sample.
    """

    assert chanvese_batch.size() == noisy_batch.size(), (
        "chanvese_batch and noisy_batch must have the same shape")
    # require greyscale image, i.e. only one channel
    assert chanvese_batch.size(1) == 1, "expected a single (greyscale) channel"

    batchsize = chanvese_batch.size(0)

    # Estimate c1, c2 from u. They are treated as constants: no_grad() stops
    # the in-place writes from attaching them to the autograd graph.
    if c1 is None or c2 is None:
        c1 = noisy_batch.new_zeros(batchsize)
        c2 = noisy_batch.new_zeros(batchsize)
        with torch.no_grad():
            for i in range(batchsize):
                c1[i], c2[i] = ChanVese.get_segmentation_mean_colours(
                    chanvese_batch[i], noisy_batch[i], threshold)

    # Shape [B, 1, 1, 1] (or [1, 1, 1, 1] for scalars) so each sample uses its
    # own constants. A plain unsqueeze(1) gives [B, 1], which broadcasts
    # against the spatial axes instead.
    c1 = _broadcast_per_sample(c1, noisy_batch)
    c2 = _broadcast_per_sample(c2, noisy_batch)

    chanvese_term = lambda_chanvese * (
        (noisy_batch - c1).square() - (noisy_batch - c2).square()
    )  # [batchsize, 1, height, width]

    # alpha is a weight, not something to optimise through, hence .detach().
    if alpha is None:
        # Calculate supremum-norm of each sample (--> [batchsize])
        alpha = chanvese_term.detach().abs().flatten(start_dim=1).max(dim=1)[0]
    alpha = _broadcast_per_sample(alpha, chanvese_batch)

    # Exact penalty nu(u) = max(0, 2|u - 1/2| - 1): zero on [0, 1] and
    # positive outside it.
    penalty_term = torch.relu(
        2 * (chanvese_batch - 0.5).abs() - 1)  # [batchsize, 1, height, width]

    # The integral over the domain is approximated by the mean over
    # (C, H, W). This equals the integral up to the constant factor 1/(H*W).
    datafitting_term = (chanvese_term * chanvese_batch +
                        alpha * penalty_term).mean((1, 2, 3))

    return datafitting_term  # [batchsize]