def get_val_metrics(cnn, val_annot_dir, dataset_dir, in_w, out_w, bs):
    """
    Return the TP, FP, TN, FN, defined_sum, duration
    for the {cnn} on the validation set.

    Args:
        cnn: segmentation network; converted to half precision here.
        val_annot_dir: directory of annotation PNGs (channel 0 = foreground,
                       channel 1 = background).
        dataset_dir: directory containing the corresponding source images.
        in_w: network input tile width.
        out_w: network output tile width.
        bs: batch size used for segmentation.

    Returns:
        The result of get_metrics(tps, fps, tns, fns, defined_sum, duration).

    TODO - This is too similar to the train loop. Merge both and use flags.
    """
    start = time.time()
    fnames = ls(val_annot_dir)
    fnames = [a for a in fnames if im_utils.is_photo(a)]
    cnn.half()
    # TODO: In order to speed things up, be a bit smarter here
    # by only segmenting the parts of the image where we have
    # some annotation defined.
    # implement a 'partial segment' which excludes tiles with no
    # annotation defined.
    tps = 0
    fps = 0
    tns = 0
    fns = 0
    defined_sum = 0
    for fname in fnames:
        annot_path = os.path.join(val_annot_dir,
                                  os.path.splitext(fname)[0] + '.png')
        # Reading the image may throw an exception.
        # I suspect this is due to it being only partially written to disk.
        # Simply retry if this happens.
        try:
            annot = imread(annot_path)
        except Exception as ex:
            # FIX: message previously read "method.Will retry in 0.1 seconsds"
            # (missing space between the concatenated literals, and a typo).
            print('Exception reading annotation inside validation method. '
                  'Will retry in 0.1 seconds')
            print(fname, ex)
            time.sleep(0.1)
            annot = imread(annot_path)
        annot = np.array(annot)
        # Channel 0 marks foreground, channel 1 marks background;
        # any non-zero value counts as marked.
        foreground = annot[:, :, 0].astype(bool).astype(int)
        background = annot[:, :, 1].astype(bool).astype(int)
        image_path_part = os.path.join(dataset_dir, os.path.splitext(fname)[0])
        # The source image may have any extension; match by stem.
        image_path = glob.glob(image_path_part + '.*')[0]
        image = im_utils.load_image(image_path)
        predicted = unet_segment(cnn, image, bs, in_w, out_w, threshold=0.5)
        # mask defines which pixels are defined in the annotation.
        mask = foreground + background
        mask = mask.astype(bool).astype(int)
        predicted *= mask
        predicted = predicted.astype(bool).astype(int)
        # Restrict the metric computation to annotated (defined) pixels only.
        y_defined = mask.reshape(-1)
        y_pred = predicted.reshape(-1)[y_defined > 0]
        y_true = foreground.reshape(-1)[y_defined > 0]
        tps += np.sum(np.logical_and(y_pred == 1, y_true == 1))
        tns += np.sum(np.logical_and(y_pred == 0, y_true == 0))
        fps += np.sum(np.logical_and(y_pred == 1, y_true == 0))
        fns += np.sum(np.logical_and(y_pred == 0, y_true == 1))
        defined_sum += np.sum(y_defined > 0)
    duration = round(time.time() - start, 3)
    metrics = get_metrics(tps, fps, tns, fns, defined_sum, duration)
    return metrics
def save_im_pieces(im_path, target_dir, pieces_from_each_image, target_size):
    """
    Cut the image at {im_path} into pieces of {target_size} and save a
    random sample of at most {pieces_from_each_image} of them as JPEGs
    in {target_dir}.
    """
    all_pieces = get_file_pieces(im_utils.load_image(im_path), target_size)
    sample_size = min(pieces_from_each_image, len(all_pieces))
    sampled = random.sample(all_pieces, sample_size)
    base_name = os.path.splitext(os.path.basename(im_path))[0]
    for idx, piece in enumerate(sampled):
        # Drop the alpha channel before saving as JPEG.
        if piece.shape[-1] == 4:
            piece = rgba2rgb(piece)
        piece_fname = f"{base_name}_{str(idx).zfill(3)}.jpg"
        imsave(os.path.join(target_dir, piece_fname), piece,
               check_contrast=False)
def segment_file(self, in_dir, seg_dir, fname, model_paths, sync_save):
    """
    Segment the image {fname} found in {in_dir} with an ensemble of the
    models at {model_paths} and save the result to {seg_dir} as a PNG
    with an alpha channel.

    Skips silently-ish (with a printed message) when the segmentation
    already exists, the input file is missing/unreadable, or the image
    is smaller than the network input width.

    Args:
        in_dir: directory containing the input image.
        seg_dir: directory the segmentation PNG is written to.
        fname: input image filename.
        model_paths: paths of the models used by ensemble_segment.
        sync_save: if True, save synchronously (safe ordering w.r.t. the
                   segment instruction); otherwise save on a thread.
    """
    fpath = os.path.join(in_dir, fname)
    # Segmentations are always saved as PNG.
    out_path = os.path.join(seg_dir, os.path.splitext(fname)[0] + '.png')
    if os.path.isfile(out_path):
        print('Skip because found existing segmentation file')
        return
    if not os.path.isfile(fpath):
        print('Cannot segment as missing file', fpath)
    else:
        try:
            photo = load_image(fpath)
        except Exception as e:
            # Could be temporary issues reading the image.
            # its ok just skip it.
            print('Exception loading', fpath, e)
            return
        # if input is smaller than this, behaviour is unpredictable.
        if photo.shape[0] < self.in_w or photo.shape[1] < self.in_w:
            # skip images that are too small.
            message = (f"image {fname} too small to segment. Width "
                       f" and height must be at least {self.in_w}")
            print(message)
            self.log(message)
            self.write_message(message)
            # FIX: this return was missing — execution previously fell
            # through and segmented the too-small image anyway.
            return
        seg_start = time.time()
        segmented = ensemble_segment(model_paths, photo, self.bs,
                                     self.in_w, self.out_w)
        print(f'ensemble segment {fname}, dur',
              round(time.time() - seg_start, 2))
        # catch warnings as low contrast is ok here.
        with warnings.catch_warnings():
            # create a version with alpha channel
            warnings.simplefilter("ignore")
            seg_alpha = np.zeros(
                (segmented.shape[0], segmented.shape[1], 4))
            # Cyan with 70% opacity wherever the network predicted foreground.
            seg_alpha[segmented > 0] = [0, 1.0, 1.0, 0.7]
            # Convert to uint8 to save as png without warning
            seg_alpha = (seg_alpha * 255).astype(np.uint8)
        if sync_save:
            # otherwise do sync because we don't want to delete the segment
            # instruction too early.
            save_then_move(out_path, seg_alpha)
        else:
            # TODO find a cleaner way to do this.
            # if more than one file then optimize speed over stability.
            x = threading.Thread(target=save_then_move,
                                 args=(out_path, seg_alpha))
            x.start()