def _draw_batch_preds(harn, batch, outputs, lim=16):
    """
    Visualize predicted vs. true class-index maps for part of a batch.

    For each of the first ``lim`` items, the input image is resized to the
    prediction map's spatial size. Both the predicted and the true class
    indices are overlaid on it, each stamped with a text label. The two
    overlays are placed side by side (prediction left, truth right), and all
    pairs are tiled into a grid two items wide.

    Args:
        batch (dict): must provide ``'im'`` and ``'class_idxs'`` tensors.
        outputs (dict): must provide a ``'class_probs'`` tensor; the predicted
            class index is its argmax over axis 1.
        lim (int): maximum number of batch items to draw.

    Returns:
        ndarray: the tiled visualization image.

    Example:
        >>> # xdoctest: +REQUIRES(--slow)
        >>> kw = {'workers': 0, 'xpu': 'cpu', 'batch_size': 8}
        >>> harn = setup_harn(cmdline=False, **kw).initialize()
        >>> batch = harn._demo_batch(tag='train')
        >>> outputs, loss_parts = harn.run_batch(batch)
        >>> toshow = harn._draw_batch_preds(batch, outputs)
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> kwplot.imshow(toshow)
    """
    import cv2

    images = batch['im'].data.cpu().numpy()
    true_idxs = batch['class_idxs'].data.cpu().numpy()
    pred_idxs = outputs['class_probs'].data.cpu().numpy().argmax(axis=1)

    def _overlay(base, cidx, label):
        # Draw one class-index map over ``base`` as 8-bit and stamp ``label``.
        heatmap = kwimage.Heatmap(class_idx=cidx, classes=harn.classes)
        canvas = heatmap.draw_on(base, channel='idx', with_alpha=.5)
        canvas = kwimage.ensure_uint255(canvas)
        return kwimage.draw_text_on_image(
            canvas, label, org=(0, 0), valign='top', color='blue')

    num_items = min(len(true_idxs), lim)
    pairs = []
    for idx in range(num_items):
        pred = pred_idxs[idx]
        # cv2.resize expects dsize as (width, height), hence the reversed
        # (rows, cols) shape of the prediction map.
        dsize = tuple(int(d) for d in pred.shape[::-1])
        # transpose(1, 2, 0) moves axis 0 (presumably channels) to the end.
        base = cv2.resize(images[idx].transpose(1, 2, 0), dsize)
        base = kwimage.ensure_alpha_channel(base)

        # TODO: scale up to original image size
        pred_canvas = _overlay(base, pred, 'pred')
        true_canvas = _overlay(base, true_idxs[idx], 'true')
        pairs.append(kwimage.stack_images([pred_canvas, true_canvas], axis=1))

    return kwimage.stack_images_grid(pairs, chunksize=2, overlap=-32)
def predict(self, path_or_image):
    """
    Segment a full image by running the model over sliding windows.

    Per-window class energies are accumulated into a full-size stitcher.
    A softmax over the class axis is then applied to the stitched energy,
    and the argmax gives a per-pixel class index.

    Args:
        path_or_image (str | ndarray): a path to an image on disk (read in
            RGB space) or an already-loaded image array.

    Returns:
        kwimage.Heatmap: with ``class_probs`` of shape (C, H, W) and
            ``class_idx`` of shape (H, W), where C is the number of classes
            and (H, W) are the first two dimensions of the input image.
    """
    # Presumably lazily sets up self.model / self.xpu (and is idempotent,
    # since callers may invoke it beforehand) -- verify.
    self._preload()

    if isinstance(path_or_image, six.string_types):
        print('Reading {!r}'.format(path_or_image))
        full_rgb = kwimage.imread(path_or_image, space='rgb')
    else:
        full_rgb = path_or_image

    # Note, this doesn't handle the case where the image is smaller than
    # the window
    input_dims = full_rgb.shape[0:2]

    # The network may be wrapped in a container exposing it as ``.module``
    # (presumably a data-parallel wrapper); fall back to the bare model.
    try:
        classes = self.model.module.classes
    except AttributeError:
        classes = self.model.classes

    # One energy channel per class, over the full image extent.
    stitcher = nh.util.Stitcher((len(classes),) + tuple(input_dims))

    slider = nh.util.SlidingWindow(input_dims, self.config['window_dims'],
                                   overlap=0, keepbound=True,
                                   allow_overshoot=True)
    slider_dset = PredSlidingWindowDataset(slider, full_rgb)
    slider_loader = torch.utils.data.DataLoader(
        slider_dset,
        batch_size=self.config['batch_size'],
        num_workers=self.config['workers'],
        shuffle=False,
        pin_memory=True,
    )

    prog = ub.ProgIter(slider_loader, desc='sliding window')
    with torch.no_grad():
        for raw_batch in prog:
            im = self.xpu.move(raw_batch['im'])
            outputs = self.model(im)

            class_energy = outputs['class_energy'].data.cpu().numpy()
            # Each item holds one [start, stop] pair per spatial dimension;
            # convert those into slice objects locating the window.
            batch_sl_st_dims = raw_batch['sl_st_dims'].data.cpu().long().numpy().tolist()
            batch_sl_dims = [tuple(slice(s, t) for s, t in item)
                             for item in batch_sl_st_dims]

            for sl_dims, energy in zip(batch_sl_dims, class_energy):
                # Leading full slice covers every class channel.
                slices = (slice(None),) + sl_dims
                stitcher.add(slices, energy)

    full_class_energy = stitcher.finalize()
    full_class_probs = torch.FloatTensor(full_class_energy).softmax(dim=0)
    full_class_probs = full_class_probs.numpy()
    full_class_idx = full_class_probs.argmax(axis=0)

    pred_heatmap = kwimage.Heatmap(
        class_probs=full_class_probs,
        class_idx=full_class_idx,
        classes=classes,
        datakeys=['class_idx'],
    )
    return pred_heatmap
def evaluate_network(sampler, eval_config):
    """
    Predict every image in ``sampler`` and accumulate segmentation metrics.

    Each image is segmented with a sliding-window ``SegmentationPredictor``.
    Its RGB-encoded truth mask is converted to class indices, and both are
    added to a ``SegmentationEvaluator``. The running and final estimates
    are printed.

    Args:
        sampler: provides ``dset`` (with ``img_root``), ``classes``,
            ``image_ids`` and ``load_image_with_annots``.
        eval_config (dict): configuration with keys

            * ``deployed``: passed to ``SegmentationPredictor``.
            * ``xpu``: device spec passed to ``SegmentationPredictor``.
            * ``do_draw`` (bool): if truthy, write a prediction overlay per
              image to ``out_dpath``.
            * ``out_dpath`` (str): output directory; only read when
              ``do_draw`` is truthy, and created if missing.
            * ``workers`` (int, optional): predictor workers, default 0.
            * ``window_dims`` (tuple, optional): sliding-window size,
              default (512, 512).
            * ``batch_size`` (int, optional): predictor batch size,
              default 2.

    Returns:
        The evaluator's final ``estimate`` (the same value printed as
        ``Final:``).

    TODO:
        - [ ] set this up as its own script
        - [ ] find a way to generalize the Evaluator concept using the
              Predictor.

    Notes:
        scores to beat:
            basic pixel_acc=82.8% class_acc=62.3% mean_iou=46.3%
            best pixel_acc=88.6% class_acc=81.3% mean_iou=69.1%
            Segnet (3.5K dataset) 86.8%, 81.3%, 69.1%,

        ours pixel_acc=86.2%, class_acc=45.4% mean_iou=31.2%

    Ignore:
        >>> import sys, ubelt
        >>> sys.path.append(ubelt.expandpath('~/code/netharn/examples'))
        >>> from sseg_camvid import *  # NOQA
        >>> kw = {'workers': 0, 'xpu': 'auto', 'batch_size': 8}
        >>> harn = setup_harn(cmdline=False, **kw).initialize()
        >>> test_dset = harn.datasets['test']
        >>> sampler = test_dset.sampler
        >>> deployed = ub.expandpath('/home/joncrall/work/camvid/fit/nice/camvid_augment_weight_v2/')
        >>> out_dpath = ub.expandpath('/home/joncrall/work/camvid/fit/nice/monitor/test')
        >>> eval_config = {
        >>>     'deployed': deployed,
        >>>     'xpu': 'auto',
        >>>     'do_draw': True,
        >>>     'out_dpath': out_dpath,
        >>> }
        >>> evaluate_network(sampler, eval_config)
    """
    from netharn.data.grab_camvid import rgb_to_cid
    coco_dset = sampler.dset
    classes = sampler.classes

    # TODO: find a way to generalize the Evaluator concept using the
    # Predictor.
    pred_cfgs = {
        'workers': eval_config.get('workers', 0),
        'deployed': eval_config['deployed'],
        'xpu': eval_config['xpu'],
        'window_dims': eval_config.get('window_dims', (512, 512)),
        'batch_size': eval_config.get('batch_size', 2),
    }
    segmenter = SegmentationPredictor(**pred_cfgs)
    segmenter._preload()

    evaluator = SegmentationEvaluator(classes)

    # Lookup table: contiguous class index (cx) -> category id (cid).
    cid_to_cx = classes.id_to_idx
    cx_to_cid = np.zeros(len(classes), dtype=np.int32)
    for cid, cx in cid_to_cx.items():
        cx_to_cid[cx] = cid

    def camvid_load_truth(img):
        """
        Load the RGB-encoded truth mask of ``img`` as a class-index mask.

        Pixels whose color maps to no known category id keep ``ignore_idx``.
        """
        # TODO: better way to load per-pixel truth with sampler
        mask_fpath = join(coco_dset.img_root, img['segmentation'])
        rgb_mask = kwimage.imread(mask_fpath, space='rgb')
        # Transposing moves the last (presumably color) axis first so it can
        # be unpacked into r, g, b; the second ``.T`` restores the original
        # spatial axis order.
        r, g, b = rgb_mask.T.astype(np.int64)
        cid_mask = np.ascontiguousarray(rgb_to_cid(r, g, b).T)
        # NOTE(review): 0 is also a valid class index, so unmatched pixels are
        # indistinguishable from class 0 -- presumably class 0 is a
        # background/ignore class; confirm.
        ignore_idx = 0
        cidx_mask = np.full_like(cid_mask, fill_value=ignore_idx)
        for cx, cid in enumerate(cx_to_cid):
            locs = (cid_mask == cid)
            cidx_mask[locs] = cx
        return cidx_mask

    do_draw = eval_config['do_draw']
    if do_draw:
        # Make sure the overlay directory exists before writing into it.
        out_dpath = ub.ensuredir(eval_config['out_dpath'])

    prog = ub.ProgIter(sampler.image_ids, 'evaluating', clearline=False)
    for gid in prog:
        img, _ = sampler.load_image_with_annots(gid)
        prog.ensure_newline()
        print('Estimate: ' + ub.repr2(evaluator.estimate, nl=0, precision=3))

        full_rgb = img['imdata']
        pred_heatmap = segmenter.predict(full_rgb)

        # Prepare truth and predictions
        true_cidx = camvid_load_truth(img)
        true_heatmap = kwimage.Heatmap(class_idx=true_cidx, classes=classes)

        # Ensure predictions are comparable to the truth: also expose the
        # predicted class *ids* next to the predicted class indices.
        pred_cidx = pred_heatmap.data['class_idx']
        pred_heatmap.data['class_cid'] = cx_to_cid[pred_cidx]

        # Add truth and predictions to the evaluator
        img_results = evaluator.add(gid, true_heatmap, pred_heatmap)

        prog.set_extra('mean_iou_g={:.2f}% mean_iou_t={:.2f}%'.format(
            img_results['mean_iou'],
            evaluator.estimate['mean_iou'])
        )

        if do_draw:
            canvas = pred_heatmap.draw_on(full_rgb, channel='idx',
                                          with_alpha=0.5)
            # Convert to 8-bit before writing: JPEG encoders expect uint8 data.
            canvas = kwimage.ensure_uint255(canvas)
            gpath = join(out_dpath, 'gid_{:04d}.jpg'.format(gid))
            kwimage.imwrite(gpath, canvas)

    final_estimate = evaluator.estimate
    print('Final: ' + ub.repr2(final_estimate, nl=0, precision=3))
    return final_estimate