Example #1
0
def get_val_metrics(cnn, val_annot_dir, dataset_dir, in_w, out_w, bs):
    """
    Return the TP, FP, TN, FN, defined_sum, duration
    for ``cnn`` on the validation set.

    For every file in ``val_annot_dir`` accepted by ``im_utils.is_photo``,
    the PNG annotation with the same stem is read from ``val_annot_dir``
    (channel 0 = foreground, channel 1 = background). The image with the same
    stem is loaded from ``dataset_dir`` and segmented with ``unet_segment``
    (threshold 0.5). Only pixels annotated as foreground or background
    contribute to the counts; unannotated pixels are ignored.

    Args:
        cnn: network passed to ``unet_segment``. NOTE: it is converted to
            half precision in place via ``cnn.half()``.
        val_annot_dir: directory holding the validation annotations.
        dataset_dir: directory holding the source images.
        in_w: input tile width forwarded to ``unet_segment``.
        out_w: output tile width forwarded to ``unet_segment``.
        bs: batch size forwarded to ``unet_segment``.

    Returns:
        The result of ``get_metrics(tps, fps, tns, fns, defined_sum, duration)``,
        where ``duration`` is elapsed wall-clock seconds rounded to 3 places.

    TODO - This is too similar to the train loop. Merge both and use flags.
    """
    start = time.time()
    fnames = [a for a in ls(val_annot_dir) if im_utils.is_photo(a)]
    # Side effect: the model stays in half precision after this call.
    cnn.half()
    # TODO: In order to speed things up, be a bit smarter here
    # by only segmenting the parts of the image where we have
    # some annotation defined.
    # implement a 'partial segment' which excludes tiles with no
    # annotation defined.
    tps = 0
    fps = 0
    tns = 0
    fns = 0
    defined_sum = 0
    for fname in fnames:
        stem = os.path.splitext(fname)[0]
        annot_path = os.path.join(val_annot_dir, stem + '.png')
        # reading the image may throw an exception.
        # I suspect this is due to it being only partially written to disk
        # simply retry (once) if this happens.
        try:
            annot = imread(annot_path)
        except Exception as ex:
            print('Exception reading annotation inside validation method. '
                  'Will retry in 0.1 seconds')
            print(fname, ex)
            time.sleep(0.1)
            annot = imread(annot_path)

        annot = np.array(annot)
        foreground = annot[:, :, 0].astype(bool)
        background = annot[:, :, 1].astype(bool)

        # Escape glob metacharacters ([, *, ?) that may appear in directory or
        # file names, and sort so the chosen file is deterministic when more
        # than one extension matches the stem.
        image_path_part = os.path.join(dataset_dir, stem)
        image_path = sorted(glob.glob(glob.escape(image_path_part) + '.*'))[0]
        image = im_utils.load_image(image_path)
        predicted = unet_segment(cnn, image, bs, in_w, out_w, threshold=0.5)

        # mask marks which pixels carry an annotation (foreground or background).
        y_defined = (foreground | background).reshape(-1)
        # Boolean-index to annotated pixels only. This makes masking the
        # prediction unnecessary, and avoids an in-place multiply that fails
        # when `predicted` has bool dtype.
        y_pred = predicted.astype(bool).reshape(-1)[y_defined]
        y_true = foreground.reshape(-1)[y_defined]
        tps += np.sum(y_pred & y_true)
        tns += np.sum(~y_pred & ~y_true)
        fps += np.sum(y_pred & ~y_true)
        fns += np.sum(~y_pred & y_true)
        defined_sum += np.sum(y_defined)
    duration = round(time.time() - start, 3)
    metrics = get_metrics(tps, fps, tns, fns, defined_sum, duration)
    return metrics
Example #2
0
def save_im_pieces(im_path, target_dir, pieces_from_each_image, target_size):
    """Write a random selection of tiles cut from one image as ``.jpg`` files.

    The image at ``im_path`` is split with ``get_file_pieces(image,
    target_size)``. At most ``pieces_from_each_image`` of the resulting
    pieces are chosen with ``random.sample``. Each chosen piece is saved in
    ``target_dir`` as ``<image stem>_<index>.jpg``, with the index
    zero-padded to at least three digits. Pieces whose last dimension is 4
    are passed through ``rgba2rgb`` before saving.
    """
    image = im_utils.load_image(im_path)
    all_pieces = get_file_pieces(image, target_size)
    n_keep = min(pieces_from_each_image, len(all_pieces))
    chosen = random.sample(all_pieces, n_keep)
    stem, _ = os.path.splitext(os.path.basename(im_path))
    for idx, piece in enumerate(chosen):
        out_name = f"{stem}_{idx:03d}.jpg"
        rgb = rgba2rgb(piece) if piece.shape[-1] == 4 else piece
        imsave(os.path.join(target_dir, out_name), rgb, check_contrast=False)
Example #3
0
    def segment_file(self, in_dir, seg_dir, fname, model_paths, sync_save):
        """Segment ``in_dir/fname`` with a model ensemble and save an RGBA PNG.

        The output goes to ``seg_dir/<stem>.png``. Pixels where the ensemble
        output is > 0 get RGBA (0, 255, 255, 178); all others are zero.

        Returns None without segmenting when:
          * the output file already exists,
          * the input file is missing or fails to load,
          * the input's height or width is below ``self.in_w``. This case is
            also reported via ``self.log`` and ``self.write_message``.

        Args:
            in_dir: directory containing the image.
            seg_dir: directory where the segmentation PNG is written.
            fname: image file name inside ``in_dir``.
            model_paths: passed to ``ensemble_segment``.
            sync_save: if True, save before returning. Otherwise save on a
                background thread, so the file may not exist yet on return.
        """
        fpath = os.path.join(in_dir, fname)

        # Segmentations are always saved as PNG.
        out_path = os.path.join(seg_dir, os.path.splitext(fname)[0] + '.png')
        if os.path.isfile(out_path):
            print('Skip because found existing segmentation file')
            return
        if not os.path.isfile(fpath):
            print('Cannot segment as missing file', fpath)
            return
        try:
            photo = load_image(fpath)
        except Exception as e:
            # Could be temporary issues reading the image.
            # It's ok, just skip it.
            print('Exception loading', fpath, e)
            return
        # If the input is smaller than this, behaviour is unpredictable,
        # so report it and skip the image.
        if photo.shape[0] < self.in_w or photo.shape[1] < self.in_w:
            message = (f"image {fname} too small to segment. Width "
                       f"and height must be at least {self.in_w}")
            print(message)
            self.log(message)
            self.write_message(message)
            return
        seg_start = time.time()
        segmented = ensemble_segment(model_paths, photo, self.bs,
                                     self.in_w, self.out_w)
        print(f'ensemble segment {fname}, dur',
              round(time.time() - seg_start, 2))
        # catch warnings as low contrast is ok here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # create a version with alpha channel
            seg_alpha = np.zeros(
                (segmented.shape[0], segmented.shape[1], 4))
            seg_alpha[segmented > 0] = [0, 1.0, 1.0, 0.7]
            # Convert to uint8 to save as png without warning
            seg_alpha = (seg_alpha * 255).astype(np.uint8)

            if sync_save:
                # Save synchronously so the file exists before the caller
                # moves on (e.g. before the segment instruction is deleted).
                save_then_move(out_path, seg_alpha)
            else:
                # TODO find a cleaner way to do this.
                # if more than one file then optimize speed over stability.
                # NOTE(review): the thread may run after this `with` block has
                # restored the warning filters, so the "ignore" above does not
                # reliably cover the background save.
                x = threading.Thread(target=save_then_move,
                                     args=(out_path, seg_alpha))
                x.start()