Ejemplo n.º 1
0
    def prediction(self, test_path, dest_path):
        """
        Run the trained model on every image yielded for ``test_path`` and
        store the results under ``<dest_path>/<train_time>/raw_pred``.

        For each image two files are written: the resized raw prediction as a
        ``.npy`` array, and an 8-bit image (prediction multiplied by 255) saved
        under the source image's file name.

        :param test_path: the path where the test data resides
        :param dest_path: the destination path to save the prediction results
        :return: the folder the predictions were written to
        """
        logger.info(f"prediction on files from {test_path}")

        # no training session set a timestamp -> create one now
        if self.train_time is None:
            self.train_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        session_path = data_functions.create_path(dest_path, self.train_time)
        save_path = data_functions.create_path(session_path, 'raw_pred')
        logger.info(f"saving predictions to {save_path}")

        for img, img_entry, orig_shape in self.test_generator(test_path):
            entry_name = img_entry.name
            logger.info(f"getting prediction for {entry_name}")
            # predict on a batch of one, keep the single result and resize it
            # back to the source size
            # NOTE(review): cv2.resize takes dsize as (width, height); confirm
            # test_generator yields orig_shape in that order
            prediction = cv2.resize(self.model.predict(img, batch_size=1)[0],
                                    orig_shape)

            stem = entry_name.rsplit('.', 1)[0]
            np.save(os.path.join(save_path, stem + '.npy'), prediction,
                    allow_pickle=True)

            as_uint8 = (255 * prediction).astype(np.uint8)
            cv2.imwrite(os.path.join(save_path, entry_name), as_uint8)

        return save_path
Ejemplo n.º 2
0
 def create_save_dest(self):
     """
     Create a time-stamped results folder and point ``self.save_path`` at it.

     The resulting layout is ``<save_path>/<img_raw_name>/<timestamp>``, with
     the timestamp formatted as ``YYYY-mm-dd_HH-MM-SS``; both levels are built
     with ``data_functions.create_path``.
     :return: None; the new folder is stored in ``self.save_path``
     """
     per_image_dir = data_functions.create_path(self.save_path,
                                                self.img_raw_name)
     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
     self.save_path = data_functions.create_path(per_image_dir, timestamp)
Ejemplo n.º 3
0
    def __init__(self, img_path, mask_path, save_path, threshold=0.5,
                 use_thres_flag=True, hist_type='bgr', is_binary_mask=False):
        """
        Store the input paths and settings used by the per-image operations.

        :param img_path: path to the source images
        :param mask_path: path to the unet masks of the source images
        :param save_path: base path where results are written
        :param threshold: threshold used to get the binary mask; the mask is
            activated for pixels whose values are greater than it
        :param use_thres_flag: if True, results go into a
            ``thres_<threshold>`` sub-folder of ``save_path`` (created via
            ``data_functions.create_path``); otherwise directly into
            ``save_path``
        :param hist_type: the color space of the histograms to calculate:
            'bgr' (default, the normal color mode) or 'hsv' (hue, saturation
            and value)
        :param is_binary_mask: whether the masks are binary
        """
        self.img_path = img_path
        self.mask_path = mask_path
        self.threshold = threshold
        self.is_binary_mask = is_binary_mask
        self.thres_save_path = (
            data_functions.create_path(save_path, f"thres_{threshold}")
            if use_thres_flag
            else save_path
        )

        # output folders; each is set by the method that produces its results
        self.threshold_masks_path = None
        self.cut_image_path = None
        self.ontop_path = None
        self.hist_type = hist_type
Ejemplo n.º 4
0
    def set_model_checkpint(self):
        """
        Register the checkpoint and tensorboard keras callbacks for the
        current training session.

        Model checkpoints are written into ``self.curr_folder``; tensorboard
        logs go into ``<curr_folder>/<keras_logs_folder>``. The two new
        callbacks are placed in front of any callbacks already stored in
        ``self.callbacks``.
        :return: the keras logs folder of the current training session
        """
        keras_logs_path = data_functions.create_path(self.curr_folder,
                                                     self.keras_logs_folder)

        if self.save_weights_only:
            base_name = self.weights_file_name
        else:
            base_name = self.checkpoint_filename
        # the '{loss:.4f}' placeholder is left unformatted here for the
        # checkpoint callback to fill in
        checkpoint_pattern = os.path.join(
            self.curr_folder,
            base_name + PARAMS_UPDATE_FORMAT + 'loss_{loss:.4f}.hdf5')

        checkpoint_cb = CustomModelCheckpoint(
            checkpoint_pattern,
            monitor='loss',
            verbose=1,
            save_best_only=False,
            update_freq=self.weights_update_freq,
            batch_size=self.batch_size,
            save_weights_only=self.save_weights_only,
            samples_seen=self.samples_seen,
            model_params_path=self.params_filepath,
            session_n=self.session_number)

        # non-augmented validation batches for tensorboard visualizations
        _, val_gen_no_aug = self.clarifruit_train_val_generators(
            aug_flag=False)
        # TODO modify hardcoded values: one batch per multiple of 200 below
        # 1000, i.e. 5 batches in total
        v_data = [next(val_gen_no_aug) for _ in range(0, 1000, 200)]

        tensorboard_cb = CustomTensorboardCallback(
            log_dir=keras_logs_path,
            batch_size=self.batch_size,
            histogram_freq=0,
            write_graph=True,
            update_freq=self.tensorboard_update_freq,
            data=v_data,
            threshold=self.ontop_display_threshold,
            samples_seen=self.samples_seen)

        new_callbacks = [tensorboard_cb, checkpoint_cb]
        self.callbacks = (new_callbacks if self.callbacks is None
                          else new_callbacks + self.callbacks)

        return keras_logs_path
Ejemplo n.º 5
0
def create_ground_truth_objects(ground_path, mask_path, save_path, threshold,
                                hist_type, obj_type, use_thres_flag,
                                is_binary_mask, folder='train'):
    """
    Create objects for every class folder of a ground truth dataset split.

    For example, this can create hsv or bgr histograms for the test and train
    sets, or images of the masks overlaid on top of the source images.
    Results for class ``<cls>`` are written to
    ``<save_path>[/thres_<threshold>]/<folder>/<cls>``.

    :param ground_path: path to the ground truth test and train dataset
    :param mask_path: path to the masks which will be used
    :param save_path: the destination path to save the results
    :param threshold: float; if the masks are the results of a prediction,
        this is used to create binary images
    :param hist_type: str, optional; if creating a histogram, what type of
        histogram, 'bgr' or 'hsv'
    :param obj_type: str; what type of object to create: binary_images,
        histograms, stems (cutting the images with the masks)
    :param use_thres_flag: bool; whether to create a ``thres_<threshold>``
        folder in the destination path for the current instance
    :param is_binary_mask: bool; whether the masks are binary
    :param folder: str; whether this instance is for 'test' or 'train'
    :return: None
    """
    if use_thres_flag:
        dest_path = data_functions.create_path(save_path,
                                               f"thres_{threshold}")
    else:
        # no threshold sub-folder requested: write directly under save_path
        # (previously dest_path was left unbound here -> UnboundLocalError)
        dest_path = save_path
    dest_path = data_functions.create_path(dest_path, folder)
    ground_path = os.path.join(ground_path, folder)
    logger.info(f"getting {obj_type} objects for {folder}")

    # the context manager closes the directory handle even on errors
    with os.scandir(ground_path) as entries:
        for curr_class in entries:
            # only sub-directories are class folders; skip stray files
            if not curr_class.is_dir():
                continue
            logger.info(f"getting objects for {curr_class.name} class")
            curr_dest = data_functions.create_path(dest_path, curr_class.name)
            curr_ground_path = os.path.join(ground_path, curr_class.name)

            # the threshold folder (if any) was already created above
            create_object(img_path=curr_ground_path,
                          mask_path=mask_path,
                          save_path=curr_dest,
                          threshold=threshold,
                          hist_type=hist_type,
                          use_thres_flag=False,
                          obj_type=obj_type,
                          is_binary_mask=is_binary_mask)
Ejemplo n.º 6
0
    def fillter_via_color(self):
        """
        ---EXPERIMENTAL---
        Write the result of ``filter_cut_image()`` for every image into a
        ``filtered`` sub-folder of ``self.thres_save_path``, using each
        image's own file name.
        :return: None
        """
        out_path = data_functions.create_path(self.thres_save_path, 'filtered')
        logger.info(f"getting filterred images for threshold {self.threshold}")
        logger.info(f"creting filltered images in {out_path}")
        for image_obj in self.image_obj_iterator():
            filtered = image_obj.filter_cut_image()
            destination = os.path.join(out_path, image_obj.image_name)
            cv2.imwrite(destination, filtered)
Ejemplo n.º 7
0
 def fillter_via_color_green_brown(self, save_flag=False):
     """
     ---EXPERIMENTAL---
     Label every image from the green/brown results of its cut image and copy
     the source image into ``<thres_save_path>/filtered/<label>``.

     :param save_flag: if True, also write each image's green part
         (``<name>_green.jpg``), brown part (``<name>_brown.jpg``) and
         threshold mask (original file name) into the ``filtered`` folder
     :return: None
     """
     out_path = data_functions.create_path(self.thres_save_path, 'filtered')
     logger.info(f"creting filltered images in {out_path}")
     for image_obj in self.image_obj_iterator():
         green_res, brown_res = image_obj.filter_cut_image_green_brown()
         label = self.get_label(green_res, brown_res, image_obj.image_name)
         label_dir = data_functions.create_path(out_path, label)
         shutil.copy(image_obj.img_path, label_dir)

         if not save_flag:
             continue

         stem = image_obj.image_name.rsplit('.', 1)[0]
         cv2.imwrite(os.path.join(out_path, f'{stem}_green.jpg'),
                     image_obj.green_part)
         cv2.imwrite(os.path.join(out_path, f'{stem}_brown.jpg'),
                     image_obj.brown_part)
         cv2.imwrite(os.path.join(out_path, image_obj.image_name),
                     image_obj.threshold_mask)
Ejemplo n.º 8
0
    def get_ontop_images(self):
        """
        Write the overlay of the thresholded mask on top of each source image
        into an ``on_top`` sub-folder of ``self.thres_save_path``.

        The overlay is produced by ``img.get_ontop()``; the output folder is
        stored in ``self.ontop_path``.
        :return: None
        """
        self.ontop_path = data_functions.create_path(self.thres_save_path,
                                                     'on_top')
        logger.info(f"getting ontop images for threshold {self.threshold}")
        logger.info(f"creting ontop images in {self.ontop_path}")
        for img in self.image_obj_iterator():
            # file name taken from ``image_name``, the attribute the other
            # per-image writers of this class use for their output files
            cv2.imwrite(os.path.join(self.ontop_path, img.image_name),
                        img.get_ontop())
Ejemplo n.º 9
0
    def get_stems(self):
        """
        Extract the "stems" of every image: the areas of the image that are
        activated in the thresholded mask. For example, if pixel (156, 46) is
        turned on in the mask image, it shows in the result.

        Results (from ``img.cut_via_mask()``) go into a ``stems`` sub-folder
        of ``self.thres_save_path``, which is stored in
        ``self.cut_image_path``.
        :return: None
        """
        stems_dir = data_functions.create_path(self.thres_save_path, 'stems')
        self.cut_image_path = stems_dir
        logger.info(f"getting stems for threshold {self.threshold}")
        logger.info(f"creting stems in {self.cut_image_path}")

        for image_obj in self.image_obj_iterator():
            cut = image_obj.cut_via_mask()
            cv2.imwrite(os.path.join(self.cut_image_path,
                                     image_obj.image_name), cut)
Ejemplo n.º 10
0
 def calc_hists(self):
     """
     Calculate a color histogram for every source image, restricted to the
     segmentation-mask areas, and save each as a ``.npy`` file.

     Histograms (from ``img.get_hist_via_mask()``) are written to
     ``<thres_save_path>/<hist_type>_histograms/<img_raw_name>.npy``.
     :return: None
     """
     hist_dir = data_functions.create_path(self.thres_save_path,
                                           f'{self.hist_type}_histograms')
     logger.info(f"getting {self.hist_type} histograms for threshold"
                 f" {self.threshold}")
     logger.info(f"saving results at {hist_dir}")
     for image_obj in self.image_obj_iterator():
         npy_path = os.path.join(hist_dir, f"{image_obj.img_raw_name}.npy")
         np.save(npy_path, image_obj.get_hist_via_mask())
Ejemplo n.º 11
0
    def get_threshold_masks(self):
        """
        Write the binary (thresholded) mask of every image into a ``binary``
        sub-folder of ``self.thres_save_path``.

        The image written is ``img.threshold_mask``, presumably the pixels of
        the float source mask greater than ``self.threshold`` (computed by
        the image object, not here). The output folder is stored in
        ``self.threshold_masks_path``.
        :return: None
        """
        # fixed: missing space between "from" and the mask path
        logger.info(
            f"getting binary masks with threshold: {self.threshold} "
            f"from {self.mask_path}")

        self.threshold_masks_path = data_functions.create_path(
            self.thres_save_path, 'binary')

        logger.info(f"saving results at {self.threshold_masks_path}")
        for img in self.image_obj_iterator():
            img_save_path = os.path.join(self.threshold_masks_path,
                                         img.image_name)
            cv2.imwrite(img_save_path, img.threshold_mask)