def prediction(self, test_path, dest_path):
    """
    Run the trained model on every image yielded for test_path and save the
    results under dest_path.

    Results go into ``<dest_path>/<self.train_time>/raw_pred``. If
    self.train_time is not set yet, it is set to the current local time
    (``%Y-%m-%d_%H-%M-%S``). For each test image two files are written:
    the resized raw prediction as ``<image stem>.npy``, and the same
    prediction scaled by 255 and cast to uint8, saved as an image under the
    source image's file name.

    :param test_path: the path where the test data resides
    :param dest_path: the destination path to save the prediction results
    :return: the folder the predictions were saved to
    """
    logger.info(f"prediction on files from {test_path}")
    # no training session stamped this model yet -> use a fresh timestamp
    if self.train_time is None:
        self.train_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    session_dir = data_functions.create_path(dest_path, self.train_time)
    save_path = data_functions.create_path(session_dir, 'raw_pred')
    logger.info(f"saving predictions to {save_path}")

    for img, img_entry, orig_shape in self.test_generator(test_path):
        entry_name = img_entry.name
        logger.info(f"getting prediction for {entry_name}")
        # predict a batch of one and keep its single result
        raw = self.model.predict(img, batch_size=1)[0]
        # NOTE(review): cv2.resize takes dsize as (width, height); confirm
        # that test_generator yields orig_shape in that order
        resized = cv2.resize(raw, orig_shape)
        stem = entry_name.rsplit('.', 1)[0]
        npy_target = os.path.join(save_path, stem + '.npy')
        np.save(npy_target, resized, allow_pickle=True)
        as_image = (255 * resized).astype(np.uint8)
        cv2.imwrite(os.path.join(save_path, entry_name), as_image)
    return save_path
def create_save_dest(self):
    """
    Create a time-stamped results folder and point self.save_path at it.

    The folder is built with data_functions.create_path as
    ``<self.save_path>/<self.img_raw_name>/<timestamp>``, where the timestamp
    is the current local time formatted as ``%Y-%m-%d_%H-%M-%S``.
    self.save_path is overwritten with the new folder, so a second call
    nests inside the first.
    :return: None
    """
    image_dir = data_functions.create_path(self.save_path, self.img_raw_name)
    stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    self.save_path = data_functions.create_path(image_dir, stamp)
def __init__(self, img_path, mask_path, save_path, threshold=0.5,
             use_thres_flag=True, hist_type='bgr', is_binary_mask=False):
    """
    :param img_path: path to the source images
    :param mask_path: path to the unet masks of the source images
    :param save_path: the root save path of the results
    :param threshold: the threshold used to get the binary mask; the binary
        mask is activated for pixels with values greater than the threshold
    :param use_thres_flag: if True, results are saved in a
        ``thres_<threshold>`` sub-folder of save_path (created here via
        data_functions.create_path); otherwise directly in save_path
    :param hist_type: the type of histogram to calculate. 'bgr' (default) is
        the normal color mode; 'hsv' is the hue, saturation and value
        color space
    :param is_binary_mask: whether the masks at mask_path are already binary
    """
    self.img_path, self.mask_path = img_path, mask_path
    self.threshold = threshold
    self.is_binary_mask = is_binary_mask
    self.thres_save_path = (
        data_functions.create_path(save_path, f"thres_{threshold}")
        if use_thres_flag
        else save_path
    )
    # output folders, set later by the methods that write into them
    self.threshold_masks_path = None
    self.cut_image_path = None
    self.ontop_path = None
    self.hist_type = hist_type
def set_model_checkpint(self):
    """
    Set up the keras callbacks for the current training session.

    Two callbacks are created and placed in front of any callbacks already
    stored in self.callbacks:

    * a CustomModelCheckpoint writing into self.curr_folder, saving the
      whole model or only its weights depending on self.save_weights_only,
      monitoring 'loss' and saving every update (save_best_only=False);
    * a CustomTensorboardCallback logging into the keras logs folder of the
      current session, fed with a few non-augmented validation batches for
      image visualizations.

    :return: the keras logs folder for the current training session
    """
    keras_logs_path = data_functions.create_path(self.curr_folder,
                                                 self.keras_logs_folder)

    file_name = (self.weights_file_name if self.save_weights_only
                 else self.checkpoint_filename)
    # '{loss:.4f}' is not formatted here; presumably CustomModelCheckpoint
    # fills it in when saving, as keras ModelCheckpoint does - confirm
    steps_file_name = file_name + PARAMS_UPDATE_FORMAT + 'loss_{loss:.4f}.hdf5'
    steps_out_model_path = os.path.join(self.curr_folder, steps_file_name)
    steps_model_checkpoint = CustomModelCheckpoint(
        steps_out_model_path,
        monitor='loss',
        verbose=1,
        save_best_only=False,
        update_freq=self.weights_update_freq,
        batch_size=self.batch_size,
        save_weights_only=self.save_weights_only,
        samples_seen=self.samples_seen,
        model_params_path=self.params_filepath,
        session_n=self.session_number)

    # get some non augmented images for tensorboard visualizations
    _, val_gen_no_aug = self.clarifruit_train_val_generators(aug_flag=False)
    # These are the first n_vis_batches *consecutive* batches, not a strided
    # sample: a filtered comprehension such as
    # `for i in range(1000) if i % 200 == 0` only calls next() for the
    # indices that pass the filter.
    # TODO modify hardcoded values
    n_vis_batches = 5
    v_data = [next(val_gen_no_aug) for _ in range(n_vis_batches)]

    image_history = CustomTensorboardCallback(
        log_dir=keras_logs_path,
        batch_size=self.batch_size,
        histogram_freq=0,
        write_graph=True,
        update_freq=self.tensorboard_update_freq,
        data=v_data,
        threshold=self.ontop_display_threshold,
        samples_seen=self.samples_seen)

    # the session's own callbacks go first, ahead of any pre-existing ones
    callbacks = [image_history, steps_model_checkpoint]
    if self.callbacks is None:
        self.callbacks = callbacks
    else:
        self.callbacks = callbacks + self.callbacks
    return keras_logs_path
def create_ground_truth_objects(ground_path, mask_path, save_path, threshold,
                                hist_type, obj_type, use_thres_flag,
                                is_binary_mask, folder='train'):
    """
    Create objects for every class folder of one split of a ground truth
    dataset, e.g. hsv or bgr histograms for the test and train sets, or
    images of the masks overlaid on top of the source images.

    The ground truth is read from ``<ground_path>/<folder>/<class name>``.
    For each entry there, create_object is called and its results are saved
    to ``<dest>/<folder>/<class name>``. Here <dest> is
    ``<save_path>/thres_<threshold>`` when use_thres_flag is set, otherwise
    save_path itself.

    :param ground_path: path to the ground truth test and train dataset
    :param mask_path: path to the masks which will be used
    :param save_path: the destination path to save the results
    :param threshold: float; if the masks are the results of a prediction,
        this is used to create binary images
    :param hist_type: str, optional; if creating a histogram, what type of
        histogram, 'bgr' or 'hsv'
    :param obj_type: str, what type of object to create: binary_images,
        histograms, stems (cutting the images with the masks)
    :param use_thres_flag: bool, whether to create a new ``thres_<threshold>``
        save folder in the destination path for the current instance
    :param is_binary_mask: bool, whether the masks are binary
    :param folder: str, whether this instance is for test or train
    :return: None
    """
    if use_thres_flag:
        dest_path = data_functions.create_path(save_path, f"thres_{threshold}")
    else:
        # without this branch dest_path was unbound (UnboundLocalError)
        dest_path = save_path
    dest_path = data_functions.create_path(dest_path, folder)
    ground_path = os.path.join(ground_path, folder)
    logger.info(f"getting {obj_type} objects for {folder}")
    # context manager closes the directory iterator even if create_object raises
    with os.scandir(ground_path) as class_entries:
        for curr_class in class_entries:
            logger.info(f"getting objects for {curr_class.name} class")
            curr_dest = data_functions.create_path(dest_path, curr_class.name)
            curr_ground_path = os.path.join(ground_path, curr_class.name)
            # the per-class folder already is the thresholded destination,
            # so create_object must not add another thres_ level
            create_object(img_path=curr_ground_path,
                          mask_path=mask_path,
                          save_path=curr_dest,
                          threshold=threshold,
                          hist_type=hist_type,
                          use_thres_flag=False,
                          obj_type=obj_type,
                          is_binary_mask=is_binary_mask)
def fillter_via_color(self):
    """
    ---EXPERIMENTAL---
    Write the color-filtered cut of every image (img.filter_cut_image()) to
    a ``filtered`` sub-folder of self.thres_save_path, under the image's name.
    :return: None
    """
    out_path = data_functions.create_path(self.thres_save_path, 'filtered')
    logger.info(f"getting filterred images for threshold {self.threshold}")
    logger.info(f"creting filltered images in {out_path}")
    for img_obj in self.image_obj_iterator():
        filtered = img_obj.filter_cut_image()
        target = os.path.join(out_path, img_obj.image_name)
        cv2.imwrite(target, filtered)
def fillter_via_color_green_brown(self, save_flag=False):
    """
    ---EXPERIMENTAL---
    Sort the source images into label folders by their green/brown content.

    For each image, filter_cut_image_green_brown() yields a green and a
    brown value, get_label() turns them into a label, and the source image
    file is copied into ``<thres_save_path>/filtered/<label>``.

    :param save_flag: if True, also write the image's green part, brown
        part and threshold mask directly into the ``filtered`` folder
    :return: None
    """
    out_path = data_functions.create_path(self.thres_save_path, 'filtered')
    logger.info(f"creting filltered images in {out_path}")
    for img in self.image_obj_iterator():
        green_val, brown_val = img.filter_cut_image_green_brown()
        stem = img.image_name.rsplit('.', 1)[0]
        label = self.get_label(green_val, brown_val, img.image_name)
        label_dir = data_functions.create_path(out_path, label)
        shutil.copy(img.img_path, label_dir)
        if not save_flag:
            continue
        cv2.imwrite(os.path.join(out_path, f'{stem}_green.jpg'), img.green_part)
        cv2.imwrite(os.path.join(out_path, f'{stem}_brown.jpg'), img.brown_part)
        cv2.imwrite(os.path.join(out_path, img.image_name), img.threshold_mask)
def get_ontop_images(self):
    """
    Save the overlay of the thresholded mask on top of each source image.

    The overlays (img.get_ontop()) are written to an ``on_top`` sub-folder
    of self.thres_save_path, which is also stored in self.ontop_path, using
    the source image's file name.
    :return: None
    """
    self.ontop_path = data_functions.create_path(self.thres_save_path, 'on_top')
    logger.info(f"getting ontop images for threshold {self.threshold}")
    logger.info(f"creting ontop images in {self.ontop_path}")
    for img in self.image_obj_iterator():
        # image_name, as used by every other per-image writer here
        # (the previous `img_name` attribute is not used anywhere else)
        cv2.imwrite(os.path.join(self.ontop_path, img.image_name),
                    img.get_ontop())
def get_stems(self):
    """
    Extract the "stems" from each image: keep only the areas of the image
    that are activated in the thresholded mask, e.g. if pixel (156, 46) is
    on in the mask, it shows in the result.

    The cut images (img.cut_via_mask()) are written to a ``stems``
    sub-folder of self.thres_save_path, which is also stored in
    self.cut_image_path, using the source image's file name.
    :return: None
    """
    stems_dir = data_functions.create_path(self.thres_save_path, 'stems')
    self.cut_image_path = stems_dir
    logger.info(f"getting stems for threshold {self.threshold}")
    logger.info(f"creting stems in {self.cut_image_path}")
    for img in self.image_obj_iterator():
        cut = img.cut_via_mask()
        cv2.imwrite(os.path.join(stems_dir, img.image_name), cut)
def calc_hists(self):
    """
    Calculate a color histogram for each source image, using the
    segmentation mask to restrict the calculation to the segmented areas.

    Each result of img.get_hist_via_mask() is saved with np.save as
    ``<thres_save_path>/<hist_type>_histograms/<img_raw_name>.npy``.
    :return: None
    """
    hist_dir = data_functions.create_path(self.thres_save_path,
                                          f'{self.hist_type}_histograms')
    logger.info(f"getting {self.hist_type} histograms for threshold "
                f"{self.threshold}")
    logger.info(f"saving results at {hist_dir}")
    for img in self.image_obj_iterator():
        hist_file = os.path.join(hist_dir, f"{img.img_raw_name}.npy")
        np.save(hist_file, img.get_hist_via_mask())
def get_threshold_masks(self):
    """
    Save the binary mask of every image, where the pixel values of the
    source mask (float type) are greater than self.threshold.

    The masks (img.threshold_mask) are written to a ``binary`` sub-folder of
    self.thres_save_path, also stored in self.threshold_masks_path, using
    the source image's file name.
    :return: None
    """
    # space after "from" was missing, gluing the word onto the path
    logger.info(
        f"getting binary masks with threshold: {self.threshold} "
        f"from {self.mask_path}")
    self.threshold_masks_path = data_functions.create_path(
        self.thres_save_path, 'binary')
    logger.info(f"saving results at {self.threshold_masks_path}")
    for img in self.image_obj_iterator():
        img_save_path = os.path.join(self.threshold_masks_path, img.image_name)
        cv2.imwrite(img_save_path, img.threshold_mask)