def crop_generator(generator, crop_length, val=False, weights_on_data=False):
    """Generate random crops of the images extracted from a generator. Uses
       :meth:`util.random_crop` to make the crop. 

       Parameters
       ----------
       generator : Keras ImageDataGenerator
           Generator to extract images from. 
        
       crop_length : int
           Size of the crop. Its shape will be square: 
           ``(crop_length, crop_length)``.
      
       val : bool, optional
           To advise the function that the given data is for validation, which
           ensures that the chosen crop is always the same: starting from the
           (0,0) coordinates.

       weights_on_data : bool, optional
           To advise the method that weights on data should be unpacked from the
           batches extracted from the generator.
     
       Yields
       ------
       out : 2 element tuple
           One batch of the generator.
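
       Examples
       --------
       A minimal usage sketch. The Keras ``flow`` iterator used below is only
       illustrative; any iterator yielding ``(x, y)`` batches can be wrapped in
       the same way::

           import numpy as np
           from tensorflow.keras.preprocessing.image import \
               ImageDataGenerator as KerasIDG

           X_train = np.ones((100, 256, 256, 1), dtype=np.float32)
           Y_train = np.ones((100, 256, 256, 1), dtype=np.float32)

           # A plain Keras flow is used here only as an example source of
           # (x, y) batches
           generator = KerasIDG().flow(X_train, Y_train, batch_size=6, seed=42)

           crop_gen = crop_generator(generator, 128)
           batch_x, batch_y = next(crop_gen)  # each with shape (6, 128, 128, 1)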
    """

    while True:
        batch_x, batch_y = next(generator)
        if weights_on_data:
            x, w = batch_x
            y = batch_y
            batch_crops_w = np.zeros((x.shape[0], crop_length, crop_length, 1))
        else:
            x = batch_x
            y = batch_y

        batch_crops_x = np.zeros((x.shape[0], crop_length, crop_length, 1))
        batch_crops_y = np.zeros((x.shape[0], crop_length, crop_length, 1))

        for i in range(x.shape[0]):
            if weights_on_data:
                batch_crops_x[i],\
                batch_crops_y[i],\
                batch_crops_w[i] = random_crop(
                    x[i], y[i], (crop_length, crop_length), val=val,
                    weight_map=w[i])
            else:
                batch_crops_x[i],\
                batch_crops_y[i] = random_crop(
                    x[i], y[i], (crop_length, crop_length), val=val)

        # Yield the batch only after every image in it has been cropped
        if weights_on_data:
            yield ([batch_crops_x, batch_crops_w], batch_crops_y)
        else:
            yield (batch_crops_x, batch_crops_y)
    def get_transformed_samples(self, num_examples, save_to_dir=False, 
                                out_dir='aug', save_prefix=None, train=True, 
                                random_images=True, force_full_images=False):
        """Apply selected transformations to a defined number of images from
           the dataset. 
            
           Parameters
           ----------
           num_examples : int
               Number of examples to generate.

           save_to_dir : bool, optional
               Save the generated images to disk. Useful to inspect the images
               produced by the data augmentation.

           out_dir : str, optional
               Name of the folder where the examples will be stored. If none is
               provided the examples will be stored under a folder named ``aug``.

           save_prefix : str, optional
               Prefix to add to the generated examples' name. 

           train : bool, optional
               If ``True``, no grid is drawn on the generated images. Set it to
               ``True`` when the samples will be used for training.

           random_images : bool, optional
               Randomly select images from the dataset. If ``False`` the examples
               will be generated from the start of the dataset. 

           force_full_images : bool, optional
               Force the usage of the entire images. Useful to generate extra
               images and override ``random_crops_in_DA`` functionality.


           Returns
           -------
           batch_x : 4D Numpy array
               Batch of data. E.g. ``(num_examples, x, y, channels)``.

           batch_y : 4D Numpy array
               Batch of data mask. E.g. ``(num_examples, x, y, channels)``.

          
           Examples
           --------
           ::

               # EXAMPLE 1
               # Generate 10 samples following example 1 of the class definition
               X_train = np.ones((1776, 256, 256, 1))                               
               Y_train = np.ones((1776, 256, 256, 1))                               
                                                                                
               data_gen_args = dict(                                                
                   X=X_train, Y=Y_train, batch_size=6, shape=(256, 256, 1),
                   shuffle=True, rotation_range=True, vflip=True, hflip=True)                     
                                                                                
               train_generator = ImageDataGenerator(**data_gen_args)                       

               train_generator.get_transformed_samples(                                
                   10, save_to_dir=True, train=False, out_dir='da_dir') 

               # EXAMPLE 2
               # If random crops in DA-time are chosen, as in example 2 of the class definition,
               # the call is the same but two more images will be stored: the image and the mask
               # of the extracted random crop. In them, a red point marks the pixel chosen as the
               # center of the random crop and a blue square delimits the crop boundaries

               train_prob = calculate_2D_volume_prob_map(                           
                   Y_train, 0.94, 0.06, save_file='prob_map.npy')
                                                                                
               data_gen_args = dict(                                                
                   X=X_train, Y=Y_train, batch_size=6, shape=(256, 256, 1), shuffle=True,
                   rotation_range=True, vflip=True, hflip=True, random_crops_in_DA=True,
                   prob_map=True, train_prob=train_prob)                            
               train_generator = ImageDataGenerator(**data_gen_args)
            
               train_generator.get_transformed_samples(
                   10, save_to_dir=True, train=False, out_dir='da_dir')
            

           Example 2 will store two additional images like the following:

           +--------------------------------------+-------------------------------------------+
           | .. figure:: img/rd_crop_2d.png       | .. figure:: img/rd_crop_mask_2d.png       |
           |   :width: 80%                        |   :width: 70%                             |
           |   :align: center                     |   :align: center                          |
           |                                      |                                           |
           |   Original crop                      |   Original crop mask                      |
           +--------------------------------------+-------------------------------------------+

           Together with these images another pair of images will be stored: the crop made and a
           transformed version of it, which is what the generator actually outputs.

           For instance, setting ``elastic=True`` the crop extracted above should be transformed as follows:
        
           +--------------------------------------+-------------------------------------------+
           | .. figure:: img/original_crop_2d.png | .. figure:: img/original_crop_mask_2d.png |
           |   :width: 80%                        |   :width: 70%                             |
           |   :align: center                     |   :align: center                          |
           |                                      |                                           |
           |   Original crop                      |   Original crop mask                      |
           +--------------------------------------+-------------------------------------------+
           | .. figure:: img/elastic_crop_2d.png  | .. figure:: img/elastic_crop_mask_2d.png  |
           |   :width: 80%                        |   :width: 70%                             |
           |   :align: center                     |   :align: center                          |
           |                                      |                                           |
           |   Elastic transformation of the crop |   Elastic transformation of the crop mask |
           +--------------------------------------+-------------------------------------------+

           The grid is only painted if ``train=False``, which should be used just to display the transformations made.
           Selecting random rotations between 0 and 180 degrees should generate the following:
            
           +---------------------------------------------+--------------------------------------------------+
           | .. figure:: img/original_rd_rot_crop_2d.png | .. figure:: img/original_rd_rot_crop_mask_2d.png |
           |   :width: 80%                               |   :width: 70%                                    |
           |   :align: center                            |   :align: center                                 |
           |                                             |                                                  |
           |   Original crop                             |   Original crop mask                             |
           +---------------------------------------------+--------------------------------------------------+
           | .. figure:: img/rd_rot_crop_2d.png          | .. figure:: img/rd_rot_crop_mask_2d.png          |
           |   :width: 80%                               |   :width: 70%                                    |
           |   :align: center                            |   :align: center                                 |
           |                                             |                                                  |
           |   Random rotation [0, 180] of the crop      |   Random rotation [0, 180] of the crop mask      |
           +---------------------------------------------+--------------------------------------------------+
        """

        print("### TR-SAMPLES ###")

        if self.random_crops_in_DA and not force_full_images:
            batch_x = np.zeros((num_examples,) + self.shape)
            batch_y = np.zeros((num_examples,) + self.shape[:2]+(1,), dtype=np.uint8)
        else:
            batch_x = np.zeros((num_examples,) + self.X.shape[1:])
            batch_y = np.zeros((num_examples,) + self.Y.shape[1:3]+(1,), dtype=np.uint8)

        if save_to_dir:
            p = '_' if save_prefix is None else str(save_prefix)
            os.makedirs(out_dir, exist_ok=True)
   
        grid = not train
                 
        # Generate the examples 
        print("0) Creating the examples of data augmentation . . .")
        for i in tqdm(range(num_examples)):
            if random_images:
                pos = random.randint(0, self.X.shape[0]-1)
            else:
                pos = i

            # Apply crops if selected
            if self.random_crops_in_DA and not force_full_images:
                batch_x[i], batch_y[i], ox, oy,\
                s_x, s_y = random_crop(self.X[pos], self.Y[pos], self.shape[:2], 
                    self.val, draw_prob_map_points=True,
                    img_prob=(self.train_prob[pos] if self.train_prob is not None else None))
            else:
                batch_x[i] = self.X[pos]
                batch_y[i] = self.Y[pos]

            if not train:
                self.__draw_grid(batch_x[i])
                self.__draw_grid(batch_y[i])

            if save_to_dir:
                if self.X.shape[-1] > 1:
                    o_x = np.copy(batch_x[i])                                 
                else:
                    o_x = np.copy(batch_x[i,...,0])
                o_y = np.copy(batch_y[i,...,0])

            # Apply transformations
            if self.da:                                                         
                segmap = SegmentationMapsOnImage(                               
                    batch_y[i], shape=batch_y[i].shape)                         
                t_img, t_mask = self.seq(                                       
                    image=batch_x[i], segmentation_maps=segmap)                 
                t_mask = t_mask.get_arr()                                       
                batch_x[i] = t_img                                              
                batch_y[i] = t_mask

            if save_to_dir:
                # Save original images
                self.__draw_grid(o_x)                                           
                self.__draw_grid(o_y)
                o_x = o_x*255
                o_y = o_y*255
                if self.X.shape[-1] > 1:
                    im = Image.fromarray(o_x, 'RGB')
                else:
                    im = Image.fromarray(o_x)
                    im = im.convert('L')                                            
                im.save(os.path.join(out_dir,str(pos)+'_orig_x'+self.t_made+".png"))
                mask = Image.fromarray(o_y)
                mask = mask.convert('L')                                        
                mask.save(os.path.join(out_dir,str(pos)+'_orig_y'+self.t_made+".png"))

                # Save transformed images
                if self.X.shape[-1] > 1:
                    im = Image.fromarray(batch_x[i]*255, 'RGB')
                else:
                    im = Image.fromarray(batch_x[i,:,:,0]*255)
                    im = im.convert('L')
                im.save(os.path.join(out_dir, str(pos)+p+'x'+self.t_made+".png"))
                mask = Image.fromarray(batch_y[i,:,:,0]*255)
                mask = mask.convert('L')
                mask.save(os.path.join(out_dir, str(pos)+p+'y'+self.t_made+".png"))

                if self.n_classes > 1:
                    h_mask = np.asarray(img_to_onehot_encoding(
                                            batch_y[i], self.n_classes))
                    # Save one image per class channel of the one-hot mask
                    for c in range(self.n_classes):
                        a = Image.fromarray(h_mask[...,c])
                        a = a.convert('L')
                        a.save(os.path.join(out_dir, str(pos)+"h_mask_"+str(c)+".png"))

                # Save the original images with a point that represents the 
                # selected coordinates to be the center of the crop
                if self.random_crops_in_DA and self.train_prob is not None\
                   and not force_full_images:
                    if self.X.shape[-1] > 1:
                        im = Image.fromarray(self.X[pos]*255, 'RGB') 
                    else:
                        im = Image.fromarray(self.X[pos,:,:,0]*255)
                    im = im.convert('RGB')                                                  
                    px = im.load()                                                          
                        
                    # Paint the selected point in red
                    p_size = 6
                    for col in range(oy-p_size, oy+p_size):
                        for row in range(ox-p_size, ox+p_size):
                            if 0 <= col < self.X.shape[1] and 0 <= row < self.X.shape[2]:
                                px[row, col] = (255, 0, 0)
                    
                    # Paint a blue square that represents the crop made 
                    for row in range(s_x, s_x+self.shape[0]):
                        px[row, s_y] = (0, 0, 255)
                        px[row, s_y+self.shape[0]-1] = (0, 0, 255)
                    for col in range(s_y, s_y+self.shape[0]):                    
                        px[s_x, col] = (0, 0, 255)
                        px[s_x+self.shape[0]-1, col] = (0, 0, 255)

                    im.save(os.path.join(out_dir, str(pos)+p+'mark_x'+self.t_made+'.png'))
                   
                    mask = Image.fromarray(self.Y[pos,:,:,0]) 
                    mask = mask.convert('RGB')                                      
                    px = mask.load()                                              
                        
                    # Paint the selected point in red
                    for col in range(oy-p_size, oy+p_size):
                        for row in range(ox-p_size, ox+p_size):
                            if 0 <= col < self.Y.shape[1] and 0 <= row < self.Y.shape[2]:
                                px[row, col] = (255, 0, 0)

                    # Paint a blue square that represents the crop made
                    for row in range(s_x, s_x+self.shape[0]):                
                        px[row, s_y] = (0, 0, 255)                          
                        px[row, s_y+self.shape[0]-1] = (0, 0, 255)       
                    for col in range(s_y, s_y+self.shape[0]):                
                        px[s_x, col] = (0, 0, 255)                          
                        px[s_x+self.shape[0]-1, col] = (0, 0, 255)

                    mask.save(os.path.join(out_dir, str(pos)+p+'mark_y'+self.t_made+'.png'))          
                
        print("### END TR-SAMPLES ###")
        return batch_x, batch_y
    def __getitem__(self, index):
        """Generation of one batch data. 

           Parameters
           ----------
           index : int
               Batch index counter.
            
           Returns
           -------
           batch_x : 4D Numpy array
               Corresponding X elements of the batch. 
               E.g. ``(batch_size, x, y, channels)``.

           batch_y : 4D Numpy array
               Corresponding Y elements of the batch.
               E.g. ``(batch_size, x, y, channels)``.
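
           Examples
           --------
           A minimal usage sketch, assuming a generator built as in the class
           examples (the arguments below are only illustrative)::

               X_train = np.ones((1776, 256, 256, 1))
               Y_train = np.ones((1776, 256, 256, 1))

               data_gen_args = dict(
                   X=X_train, Y=Y_train, batch_size=6, shape=(256, 256, 1),
                   shuffle=True, rotation_range=True, vflip=True, hflip=True)
               train_generator = ImageDataGenerator(**data_gen_args)

               # Retrieve the first batch of the epoch
               batch_x, batch_y = train_generator[0]
               # batch_x: (6, 256, 256, 1), batch_y: (6, 256, 256, 1)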
        """

        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        batch_x = np.zeros((len(indexes), *self.shape))
        batch_y = np.zeros((len(indexes), *self.shape[:2]+(1,)), dtype=np.uint8)

        for i, j in zip(range(len(indexes)), indexes):
            if self.random_crops_in_DA:
                batch_x[i], batch_y[i] = random_crop(
                    self.X[j], self.Y[j], self.shape[:2],
                    self.val, img_prob=(self.train_prob[j] if self.train_prob is not None else None))
            else:
                batch_x[i], batch_y[i] = self.X[j], self.Y[j]
            
            if self.da: 
                segmap = SegmentationMapsOnImage(
                    batch_y[i], shape=batch_y[i].shape)
                t_img, t_mask = self.seq(
                    image=batch_x[i], segmentation_maps=segmap)
                t_mask = t_mask.get_arr()
                batch_x[i] = t_img
                batch_y[i] = t_mask
                
        if self.n_classes > 1:
            batch_y_ = np.zeros((len(indexes),) + self.shape[:2] + (self.n_classes,), 
                                dtype=np.uint8)
            for i in range(len(indexes)):
                batch_y_[i] = np.asarray(img_to_onehot_encoding(
                                         batch_y[i], self.n_classes))
            batch_y = batch_y_

        # Need to change the datatype for the resize function
        batch_y = batch_y.astype(np.float32)
        if self.out_number == 1:
            return batch_x, batch_y
        else:
            # Multiple outputs: return the mask downscaled to several
            # resolutions together with the full-resolution mask
            out_masks = [np.array([resize(img, (s, s), order=0, anti_aliasing=False)
                                   for img in batch_y])
                         for s in (2, 4, 8, 16, 32, 64, 128)]
            return ([batch_x], out_masks + [batch_y])