예제 #1
0
    def _draw_batch_preds(harn, batch, outputs, lim=16):
        """
        Build a single visualization of predicted vs. true segmentations.

        For each of the first ``lim`` items in the batch:

        - the input image is converted from channel-first to channel-last;
        - it is resized to the spatial size of the predicted class map;
        - both the predicted and the true class-index maps are overlaid on it;
        - each overlay is labeled; the prediction panel is placed first.

        The per-item panels are then tiled into one grid image.

        Args:
            batch (dict): needs an ``'im'`` tensor and a ``'class_idxs'``
                tensor.
            outputs (dict): needs a ``'class_probs'`` tensor; its argmax over
                axis 1 is taken as the predicted class map.
            lim (int): maximum number of batch items to draw.

        Returns:
            ndarray: the tiled visualization image.

        Example:
            >>> # xdoctest: +REQUIRES(--slow)
            >>> kw = {'workers': 0, 'xpu': 'cpu', 'batch_size': 8}
            >>> harn = setup_harn(cmdline=False, **kw).initialize()
            >>> batch = harn._demo_batch(tag='train')
            >>> outputs, loss_parts = harn.run_batch(batch)
            >>> toshow = harn._draw_batch_preds(batch, outputs)
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(toshow)
        """
        import cv2

        def _to_numpy(tensor):
            # detach from the graph / device and hand back a numpy array
            return tensor.data.cpu().numpy()

        images = _to_numpy(batch['im'])
        true_maps = _to_numpy(batch['class_idxs'])
        pred_maps = _to_numpy(outputs['class_probs']).argmax(axis=1)

        tiles = []
        for idx in range(min(len(true_maps), lim)):
            pred_map = pred_maps[idx]
            true_map = true_maps[idx]

            # cv2.resize expects the target size in reversed (w, h) order
            dsize = tuple(int(dim) for dim in reversed(pred_map.shape))
            base = cv2.resize(images[idx].transpose(1, 2, 0), dsize)
            base = kwimage.ensure_alpha_channel(base)

            # TODO: scale up to original image size

            panels = []
            for label, idx_map in (('pred', pred_map), ('true', true_map)):
                heatmap = kwimage.Heatmap(class_idx=idx_map,
                                          classes=harn.classes)
                panel = heatmap.draw_on(base, channel='idx', with_alpha=.5)
                panel = kwimage.ensure_uint255(panel)
                panel = kwimage.draw_text_on_image(
                    panel, label, org=(0, 0), valign='top', color='blue')
                panels.append(panel)

            tiles.append(kwimage.stack_images(panels, axis=1))

        return kwimage.stack_images_grid(tiles, chunksize=2, overlap=-32)
예제 #2
0
    def predict(self, path_or_image):
        """
        Segment a full image by stitching sliding-window model outputs.

        Args:
            path_or_image (str | ndarray): a path to an image, read in RGB
                space, or an already-loaded image array.

        Returns:
            kwimage.Heatmap: it holds the following entries.

            - ``class_probs``: a softmax over the class axis of the stitched
              per-window class energies.
            - ``class_idx``: the per-pixel argmax of ``class_probs``.
            - ``classes``: the model's classes.
        """
        self._preload()

        if not isinstance(path_or_image, six.string_types):
            full_rgb = path_or_image
        else:
            print('Reading {!r}'.format(path_or_image))
            full_rgb = kwimage.imread(path_or_image, space='rgb')

        # NOTE: the case where the image is smaller than the window is not
        # handled here.
        input_dims = full_rgb.shape[0:2]

        # Prefer the classes of an inner ``.module`` when the model has one,
        # otherwise use the model's own ``classes`` attribute.
        try:
            classes = self.model.module.classes
        except AttributeError:
            classes = self.model.classes

        config = self.config
        stitcher = nh.util.Stitcher((len(classes),) + tuple(input_dims))
        slider = nh.util.SlidingWindow(
            input_dims, config['window_dims'], overlap=0, keepbound=True,
            allow_overshoot=True)
        loader = torch.utils.data.DataLoader(
            PredSlidingWindowDataset(slider, full_rgb),
            batch_size=config['batch_size'],
            num_workers=config['workers'],
            shuffle=False,
            pin_memory=True,
        )

        with torch.no_grad():
            for raw_batch in ub.ProgIter(loader, desc='sliding window'):
                inputs = self.xpu.move(raw_batch['im'])
                energies = self.model(inputs)['class_energy'].data.cpu().numpy()

                # each item is a sequence of (start, stop) pairs, one per
                # spatial dimension of its window
                window_bounds = (
                    raw_batch['sl_st_dims'].data.cpu().long().numpy().tolist())

                for bounds, energy in zip(window_bounds, energies):
                    spatial = tuple(slice(start, stop)
                                    for start, stop in bounds)
                    # keep every class channel, place only the spatial window
                    stitcher.add((slice(None),) + spatial, energy)

            full_class_energy = stitcher.finalize()
            full_class_probs = (
                torch.FloatTensor(full_class_energy).softmax(dim=0).numpy())
            full_class_idx = full_class_probs.argmax(axis=0)

        return kwimage.Heatmap(
            class_probs=full_class_probs,
            class_idx=full_class_idx,
            classes=classes,
            datakeys=['class_idx'],
        )
예제 #3
0
def evaluate_network(sampler, eval_config):
    """
    Evaluate a deployed segmentation model on every image of a sampler.

    Each image is segmented with a sliding-window ``SegmentationPredictor``.
    The prediction is compared against the image's CamVid per-pixel truth
    mask using a ``SegmentationEvaluator``. Optionally, a prediction overlay
    is written to disk. The running estimate is printed before each image,
    and the final estimate is printed at the end.

    Args:
        sampler: must provide the following.

            - ``dset``: a coco-style dataset with an ``img_root``.
            - ``classes``
            - ``image_ids``
            - ``load_image_with_annots(gid)``

        eval_config (dict): uses the following keys.

            - ``'deployed'``: passed to the predictor as ``deployed``.
            - ``'xpu'``: passed to the predictor as ``xpu``.
            - ``'do_draw'``: if truthy, write one overlay image per input
              image.
            - ``'out_dpath'``: the directory for the overlays. It is only
              read when ``do_draw`` is truthy, and it is created if missing.

    Returns:
        None: results are only printed (and drawn, if requested).

    TODO:
        - [ ] set this up as its own script
        - [ ] find a way to generalize the Evaluator concept using the Predictor.

    Notes:
        scores to beat: http://mi.eng.cam.ac.uk/projects/segnet/tutorial.html
        basic pixel_acc=82.8% class_acc=62.3% mean_iou=46.3%
        best pixel_acc=88.6% class_acc=81.3% mean_iou=69.1%
        Segnet (3.5K dataset) 86.8%, 81.3%, 69.1%,

        ours pixel_acc=86.2%, class_acc=45.4% mean_iou=31.2%

    Ignore:
        >>> import sys, ubelt
        >>> sys.path.append(ubelt.expandpath('~/code/netharn/examples'))
        >>> from sseg_camvid import *  # NOQA
        >>> kw = {'workers': 0, 'xpu': 'auto', 'batch_size': 8}
        >>> harn = setup_harn(cmdline=False, **kw).initialize()
        >>> test_dset = harn.datasets['test']
        >>> sampler = test_dset.sampler
        >>> deployed = ub.expandpath('/home/joncrall/work/camvid/fit/nice/camvid_augment_weight_v2/deploy_UNet_otavqgrp_089_FOAUOG.zip')
        >>> out_dpath = ub.expandpath('/home/joncrall/work/camvid/fit/nice/monitor/test')
        >>> eval_config = {
        >>>     'deployed': deployed,
        >>>     'xpu': 'auto',
        >>>     'do_draw': True,
        >>>     'out_dpath': out_dpath,
        >>> }
        >>> evaluate_network(sampler, eval_config)
    """
    from netharn.data.grab_camvid import rgb_to_cid
    coco_dset = sampler.dset
    classes = sampler.classes

    # TODO: find a way to generalize the Evaluator concept using the
    # Predictor.
    pred_cfgs = {
        'workers': 0,
        'deployed': eval_config['deployed'],
        'xpu': eval_config['xpu'],
        # 'window_dims': (720, 960),
        'window_dims': (512, 512),
        'batch_size': 2,
    }
    segmenter = SegmentationPredictor(**pred_cfgs)
    segmenter._preload()

    evaluator = SegmentationEvaluator(classes)

    # Invert the class-id -> class-index mapping into a lookup array so that
    # predicted indices can be mapped back to ids with fancy indexing.
    cid_to_cx = classes.id_to_idx
    cx_to_cid = np.zeros(len(classes), dtype=np.int32)
    for cid, cx in cid_to_cx.items():
        cx_to_cid[cx] = cid

    def camvid_load_truth(img):
        """
        Load the RGB truth mask of ``img`` as a per-pixel class-index mask.

        Pixels whose color maps to no known class id keep ``ignore_idx``
        (0). NOTE(review): class index 0 is also a valid class here, so
        "unknown" and class 0 are indistinguishable in the result.
        """
        # TODO: better way to load per-pixel truth with sampler
        mask_fpath = join(coco_dset.img_root, img['segmentation'])
        rgb_mask = kwimage.imread(mask_fpath, space='rgb')
        # transpose so the channel axis comes first and can be unpacked
        r, g, b = rgb_mask.T.astype(np.int64)
        cid_mask = np.ascontiguousarray(rgb_to_cid(r, g, b).T)

        ignore_idx = 0
        cidx_mask = np.full_like(cid_mask, fill_value=ignore_idx)
        for cx, cid in enumerate(cx_to_cid):
            locs = (cid_mask == cid)
            cidx_mask[locs] = cx
        return cidx_mask

    prog = ub.ProgIter(sampler.image_ids, 'evaluating', clearline=False)
    for gid in prog:
        img, annots = sampler.load_image_with_annots(gid)
        prog.ensure_newline()
        print('Estimate: ' + ub.repr2(evaluator.estimate, nl=0, precision=3))

        full_rgb = img['imdata']
        pred_heatmap = segmenter.predict(full_rgb)

        # Prepare truth and predictions
        true_cidx = camvid_load_truth(img)
        true_heatmap = kwimage.Heatmap(class_idx=true_cidx, classes=classes)

        # Ensure predictions are comparable to the truth
        pred_cid = pred_heatmap.data['class_idx']
        pred_heatmap.data['class_cid'] = cx_to_cid[pred_cid]

        # Add truth and predictions to the evaluator
        img_results = evaluator.add(gid, true_heatmap, pred_heatmap)

        prog.set_extra('mean_iou_g={:.2f}% mean_iou_t={:.2f}%'.format(
            img_results['mean_iou'], evaluator.estimate['mean_iou'])
        )

        if eval_config['do_draw']:
            # Make sure the destination exists before writing into it.
            out_dpath = ub.ensuredir(eval_config['out_dpath'])
            canvas = pred_heatmap.draw_on(full_rgb, channel='idx', with_alpha=0.5)
            # draw_on output is converted to 8-bit (as _draw_batch_preds
            # does) so it can be written as a JPEG.
            canvas = kwimage.ensure_uint255(canvas)
            gpath = join(out_dpath, 'gid_{:04d}.jpg'.format(gid))
            kwimage.imwrite(gpath, canvas)

    print('Final: ' + ub.repr2(evaluator.estimate, nl=0, precision=3))