def __init__(self,
             path,
             batch_size=32,
             noise_config=None,
             shuffle=True,
             name="CleanDataset",
             n_channels=1,
             preprocessing=None):
    """Initialise the dataset and infer the image shape from the first batch.

    Parameters
    ----------
    path : str
        Dataset root; file names are listed from ``<path>/ref``.
    batch_size : int
        Number of images per batch.
    noise_config : dict, optional
        Mapping whose keys are stored in ``self.noise_functions`` and whose
        values are stored (in the same order) in ``self.noise_args``.
        Defaults to ``{gaussian_noise: [25]}``.
    shuffle : bool
        Forwarded to the parent class.
    name : str
        Dataset name, used as a prefix in log messages.
    n_channels : int
        Expected number of image channels.
    preprocessing : list of callables, optional
        Preprocessing pipeline; defaults to an empty list.
    """
    super().__init__(path, batch_size, shuffle, name, n_channels)
    if noise_config is None:
        noise_config = {gaussian_noise: [25]}
    # Snapshot keys and their arguments from the same iteration so the two
    # lists stay aligned even if the caller later mutates ``noise_config``
    # (a ``dict.keys()`` view would silently follow such mutations).
    self.noise_functions = list(noise_config)
    self.noise_args = [noise_config[func] for func in self.noise_functions]
    self.preprocessing = [] if preprocessing is None else preprocessing
    self.n_channels = n_channels
    ref_dir = os.path.join(self.path, "ref")
    # os.listdir returns entries in arbitrary order; sort them so the
    # unshuffled order is reproducible across machines and runs.
    self.filenames = np.array(sorted(os.listdir(ref_dir)))
    self.on_epoch_end()
    module_logger.info("Generating data from %s", ref_dir)
    # Load the first batch once to discover the per-sample input shape.
    self.image_shape = self[0][0].shape
    module_logger.debug("[%s] Image shape: %s", self.name, self.image_shape)
 def __getitem__(self, i):
     """Generates image batches from filenames.

     Parameters
     ----------
     i : int
         Batch index to get.

     Returns
     -------
     inp : :class:`numpy.ndarray`
         Batch of input images.
     ref : :class:`numpy.ndarray`
         Batch of target images.
     """
     # Slice out the file names of batch ``i``; Python slicing makes the last
     # batch shorter than ``batch_size`` when the file count is not a multiple.
     start = i * self.batch_size
     batch_filenames = self.filenames[start:start + self.batch_size]
     # Lazy %-style arguments: the (potentially long) file-name array is only
     # formatted when DEBUG logging is actually enabled.
     module_logger.debug("[%s] Got following batch names: %s", self,
                         batch_filenames)
     inp, ref = self.__data_generation(batch_filenames)
     return inp, ref
    def __data_generation(self, batch_filenames):
        """Load a batch of input images and derive their targets.

        Each file is read from ``<path>/in``, converted to float32, adjusted
        to ``self.n_channels``, passed through ``self.preprocessing`` and
        finally split into an (input, target) pair by ``self.target_fcn``.

        Parameters
        ----------
        batch_filenames : list
            List of strings containing filenames to read.

        Returns
        -------
        inp_batch : :class:`numpy.ndarray`
            Batch of input images, as returned by ``self.target_fcn``.
        ref_batch : :class:`numpy.ndarray`
            Batch of target images, as returned by ``self.target_fcn``.

        Raises
        ------
        ValueError
            If a grayscale image is read while ``self.n_channels == 3``.
        """
        inp_batch = []
        ref_batch = []

        for filename in batch_filenames:
            filepath = os.path.join(self.path, 'in', filename)
            # Lazy %-style argument: no string formatting per image unless
            # DEBUG logging is enabled.
            module_logger.debug("Loading image located on %s", filepath)
            inp = img_as_float32(imread(filepath))

            # NOTE(review): only 3-channel RGB is converted to gray; images
            # with other channel counts (e.g. RGBA) pass through unchanged
            # regardless of n_channels — confirm this is intended.
            if inp.ndim == 3 and inp.shape[-1] == 3 and self.n_channels == 1:
                # Converts RGB to Gray
                inp = rgb2gray(inp)
            if inp.ndim == 2 and self.n_channels == 1:
                # Expand last dim if image is grayscale
                inp = np.expand_dims(inp, axis=-1)
            elif inp.ndim == 2 and self.n_channels == 3:
                raise ValueError(
                    "Expected RGB image but got Grayscale (image shape: {})".
                    format(inp.shape))

            # Single-image preprocessing pipeline, applied in order.
            for func in self.preprocessing:
                inp = func(inp)

            # The target is generated from the (preprocessed) input itself.
            inp, ref = self.target_fcn(inp)
            inp_batch.append(inp)
            ref_batch.append(ref)

        inp_batch = np.array(inp_batch)
        ref_batch = np.array(ref_batch)
        module_logger.debug("Data shape: %s", inp_batch.shape)
        return inp_batch, ref_batch
    def __getitem__(self, i):
        """Generates image batches from filenames.

        Parameters
        ----------
        i : int
            Batch index to get.

        Returns
        -------
        inp : :class:`numpy.ndarray`
            Batch of noisy images.
        ref : :class:`numpy.ndarray`
            Batch of target images.
        """
        # Slice out the file names of batch ``i``; Python slicing makes the
        # last batch shorter than ``batch_size`` when the file count is not a
        # multiple of it.
        start = i * self.batch_size
        batch_filenames = self.filenames[start:start + self.batch_size]
        # Lazy %-style arguments: the (potentially long) file-name array is
        # only formatted when DEBUG logging is actually enabled.
        module_logger.debug("[%s] Got following batch names: %s", self,
                            batch_filenames)
        inp, ref = self.__data_generation(batch_filenames)
        return inp, ref
def smooth_patches(img, d=64, h=32, sg=32, sl=16, mu=0.1, gamma=0.25,
                   clip=(0, 1)):
    """Extract smooth patches for GCBD [1]_ algorithm.

    Global ``d x d`` windows are taken from ``img`` with stride ``sg``. Every
    global window is scanned with ``h x h`` local windows (stride ``sl``);
    the global window is kept only if each local window satisfies::

        |mean(global) - mean(local)| <= mu * mean(local)
        |var(global)  - var(local)|  <= gamma * var(local)

    Kept windows are centred (``window - mean(window)``) to obtain noise
    patches.

    Parameters
    ----------
    img : :class:`numpy.ndarray`
        Noised image. Only the first two axes are scanned, so ``(H, W)`` and
        ``(H, W, C)`` arrays are both accepted.
    d : int
        Global patch size.
    h : int
        Local patch size.
    sg : int
        Global stride.
    sl : int
        Local stride.
    mu : float
        mean-smoothing hyper-parameter.
    gamma : float
        Variance-smoothing hyper-parameter.
    clip : tuple of (float, float) or None, optional
        ``(min, max)`` range passed to :func:`numpy.clip` on the result. The
        default ``(0, 1)`` keeps the historical behaviour; note that since
        the patches are zero-mean, it sets every negative noise value to 0.
        Pass ``None`` to return the centred noise unclipped.

    Returns
    -------
    patches : :class:`numpy.ndarray`
        Extracted patches from img, shape ``(N, d, d) + img.shape[2:]``
        (``N`` may be 0).

    References
    ----------
    .. [1] Chen, J., Chen, J., Chao, H., & Yang, M. (2018). Image blind denoising with generative adversarial network
           based noise modeling. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition
    """
    height, width = img.shape[:2]
    patches = []
    # "+ 1": the last start position whose window still fits entirely inside
    # the image must be included (e.g. a d x d image yields one window).
    for i in range(0, height - d + 1, sg):
        for j in range(0, width - d + 1, sg):
            # Global d x d patch p, with its mean and variance
            wg = img[i:i + d, j:j + d]
            mean_g = np.mean(wg)
            var_g = np.var(wg)
            if _gcbd_window_is_smooth(img, i, j, d, h, sl, mean_g, var_g, mu,
                                      gamma):
                # If smooth enough, extract the noise patch through
                # noise = patch - mean(patch)
                patches.append(wg - mean_g)
    if patches:
        patches = np.array(patches)
    else:
        # Keep the patch dimensions so an empty result can still be
        # concatenated with non-empty ones.
        patches = np.empty((0, d, d) + img.shape[2:])
    if clip is not None:
        patches = np.clip(patches, clip[0], clip[1])
    module_logger.debug("Extracted patches shape: %s", patches.shape)
    return patches


def _gcbd_window_is_smooth(img, i, j, d, h, sl, mean_g, var_g, mu, gamma):
    """Return True if every local h x h window of the d x d window at (i, j) passes the GCBD smoothness test."""
    # "+ 1" so the bottom/right local windows of the global patch are checked.
    for k in range(i, i + d - h + 1, sl):
        for col in range(j, j + d - h + 1, sl):
            # Local h x h patch q
            wl = img[k:k + h, col:col + h]
            module_logger.debug(
                "image shape %s, %s:%s, %s:%s, local patch shape: %s",
                img.shape, k, k + h, col, col + h, wl.shape)
            mean_l = np.mean(wl)
            var_l = np.var(wl)
            # Difference between local and global means/variances
            mean_diff = np.abs(mean_g - mean_l)
            var_diff = np.abs(var_g - var_l)
            # NOTE(review): the bounds are relative to the *local* statistics;
            # the GCBD paper appears to state them relative to the global
            # patch (mu * mean_g, gamma * var_g) — confirm which is intended.
            if mean_diff > mu * mean_l or var_diff > gamma * var_l:
                module_logger.debug(
                    "%s\tmean_g: %s, mean_l: %s, mean_diff: %s, bound: %s",
                    [i, j, k, col], mean_g, mean_l, mean_diff, mu * mean_l)
                module_logger.debug(
                    "%s\tvar_g: %s, var_l: %s, var_diff: %s, bound: %s\n",
                    [i, j, k, col], var_g, var_l, var_diff, gamma * var_l)
                # A single failing local patch makes the global patch
                # non-smooth; no need to check the remaining ones.
                return False
    return True
예제 #6
0
    def __data_generation(self, batch_filenames):
        """Load paired input/reference images for one batch.

        Parameters
        ----------
        batch_filenames : list
            List of strings containing filenames to read. Note that, for each noisy image filename there must be a clean
            image with same filename.

        Returns
        -------
        noisy_batch : :class:`numpy.ndarray`
            Batch of noisy (input) images, read from ``<path>/in``.
        clean_batch : :class:`numpy.ndarray`
            Batch of reference images, read from ``<path>/ref``.

        Raises
        ------
        IndexError
            If ``batch_filenames`` is empty.
        ValueError
            If a grayscale image is read while ``self.n_channels == 3``, or if
            an input image and its reference differ in shape after
            preprocessing.
        """
        if len(batch_filenames) == 0:
            # Indexing the empty result list used to fail with IndexError
            # here; keep that exception type (now explicit) so callers that
            # rely on it for out-of-range batch indices behave the same.
            raise IndexError(
                "Cannot generate a batch from an empty filename list")

        def match_channels(img):
            # Bring a freshly read image to the layout implied by n_channels.
            if img.ndim == 3 and img.shape[-1] == 3 and self.n_channels == 1:
                # Converts RGB to Gray
                img = rgb2gray(img)
            if img.ndim == 2 and self.n_channels == 1:
                # Expand last dim if image is grayscale
                img = np.expand_dims(img, axis=-1)
            elif img.ndim == 2 and self.n_channels == 3:
                raise ValueError(
                    "Expected RGB image but got Grayscale (image shape: {})".
                    format(img.shape))
            return img

        def stack(images):
            # Single (H, W, C) images get a new batch axis; anything else
            # (presumably patch stacks produced by preprocessing) is joined
            # along its existing first axis.
            if images[0].ndim == 3:
                return np.array(images)
            return np.concatenate(images, axis=0)

        inp_batch = []
        ref_batch = []
        for filename in batch_filenames:
            inp_filepath = os.path.join(self.path, "in", filename)
            ref_filepath = os.path.join(self.path, "ref", filename)
            inp = img_as_float32(imread(inp_filepath))
            ref = img_as_float32(imread(ref_filepath))
            inp = match_channels(inp)
            ref = match_channels(ref)

            # Paired preprocessing: each function takes and returns
            # (input, reference), in that order.
            for func in self.preprocessing:
                inp, ref = func(inp, ref)

            # Validate every pair, not only the last one of the batch.
            if inp.shape != ref.shape:
                raise ValueError(
                    "Expected {} to have same shape of {}, but got {} and {}".
                    format(inp_filepath, ref_filepath, inp.shape, ref.shape))

            inp_batch.append(inp)
            ref_batch.append(ref)

        inp_batch = stack(inp_batch)
        ref_batch = stack(ref_batch)
        module_logger.debug("Data shape: %s", inp_batch.shape)
        return inp_batch, ref_batch