def write_summaries(self, was_best):
    """
    Publish the collected example images.

    * If ``self.write_webpage`` is set, the HTML results page at
      ``self.webpage_fn`` is regenerated on every call, with one table per
      entry of ``self.imgs_to_webpage``.
    * If ``self.tensorboard`` is set and ``was_best`` is true, the rows in
      ``self.imgs_to_tensorboard`` are tiled into a single grid image and
      logged under the tag ``'imgs'`` for epoch ``cfg.EPOCH``. Nothing is
      logged when no rows were collected.

    :param was_best: whether the epoch that just finished was the best so far
    """
    if self.write_webpage:
        page = ResultsPage('prediction examples', self.webpage_fn)
        for table in self.imgs_to_webpage:
            page.add_table(table)
        page.write_page()

    # Tensorboard is only refreshed for best epochs.
    if not (self.tensorboard and was_best):
        return

    rows = self.imgs_to_tensorboard
    if len(rows) == 0:
        return

    # Grid width = length of the first row, so each row is laid out on its
    # own grid line (assuming all rows have the same length).
    row_len = len(rows[0])
    tiles = [tile for row in rows for tile in row]
    grid = vutils.make_grid(torch.stack(tiles, 0), nrow=row_len, padding=5)
    logx.add_image('imgs', grid, cfg.EPOCH)
def validate_topn(val_loader, net, criterion, optim, epoch, args):
    """
    Find the worst-case failures: per class, the images with the most
    mis-classified pixels.

    Only a validation batch size of 1 is supported (``args.bs_val`` must be 1),
    and only a single GPU for now.

    Pass 1: run every batch of ``val_loader`` through ``run_minibatch``,
        record per-image (FP, FN) pixel statistics from ``metrics_per_image``,
        and hand the accumulated IoU statistics to ``eval_metrics``.
    Ranking: for each class, keep the ``args.dump_topn`` images with the
        largest FP + FN count.
    Pass 2: re-run the network on the selected images (or on every image
        when ``args.dump_topn_all`` is set) and build per-class error masks.
        The actual image dump is currently disabled (see the FIXME below).
    Finally, an HTML page ``<args.result_dir>/best_images/topn_failures.html``
    is written that lists, per class and worst-first, tables referencing
    ``<img>_<class>_<kind>.png`` files. Because dumping is disabled, this
    function does not itself produce those PNG files.

    :returns: ``val_loss.avg`` accumulated during the first pass
    """
    assert args.bs_val == 1

    # ------------------------------------------------------------------
    # Pass 1: score every image
    # ------------------------------------------------------------------
    logx.msg('First pass')
    image_metrics = {}  # img_name -> (fp, fn) from metrics_per_image

    net.eval()
    val_loss = AverageMeter()
    iou_acc = 0

    for batch_idx, batch in enumerate(val_loader):
        _, batch_iou = run_minibatch(
            batch, net, criterion, val_loss, True, args, batch_idx)

        _, _, names, _ = batch
        fp, fn = metrics_per_image(batch_iou)
        image_metrics[names[0]] = (fp, fn)

        iou_acc += batch_iou

        if batch_idx % 20 == 0:
            logx.msg(f'validating[Iter: {batch_idx + 1} / {len(val_loader)}]')

        if batch_idx > 5 and args.test_mode:
            break

    eval_metrics(iou_acc, args, net, optim, val_loss, epoch)

    # ------------------------------------------------------------------
    # Rank: top-N images per class by failing (FP + FN) pixel count
    # ------------------------------------------------------------------
    from collections import defaultdict
    worst_images = defaultdict(dict)     # img_name -> {classid: fail_pixels}
    class_to_images = defaultdict(dict)  # classid -> {img_name: fail_pixels}
    for classid in range(cfg.DATASET.NUM_CLASSES):
        fail_counts = {name: fp[classid] + fn[classid]
                       for name, (fp, fn) in image_metrics.items()}
        ranked = sorted(fail_counts, key=fail_counts.get, reverse=True)
        for name in ranked[:args.dump_topn]:
            worst_images[name][classid] = fail_counts[name]
            class_to_images[classid][name] = fail_counts[name]
    logx.msg(str(worst_images))

    # TODO: for multi-GPU -- write per-gpu jsons, barrier, merge into a
    # single table.

    # ------------------------------------------------------------------
    # Pass 2: revisit the selected images
    # ------------------------------------------------------------------
    logx.msg('Second pass')
    attn_map = None
    for batch in val_loader:
        in_image, gt_image, names, _ = batch

        # Skip images that made no class's top-N list, unless every image
        # was requested.
        if not args.dump_topn_all and names[0] not in worst_images:
            continue

        with torch.no_grad():
            net_inputs = {'images': in_image.cuda(), 'gts': gt_image}
            if cfg.MODEL.MSCALE:
                output, attn_map = net(net_inputs)
            else:
                output = net(net_inputs)

        probs = torch.nn.functional.softmax(output, dim=1)
        prob_mask, predictions = probs.data.max(1)
        predictions = predictions.cpu()  # assumed shape [bs, h, w]

        img_name = names[0]
        for classid, error_pixels in worst_images[img_name].items():
            err_mask = calc_err_mask(predictions.numpy(),
                                     gt_image.numpy(),
                                     cfg.DATASET.NUM_CLASSES,
                                     classid)

            class_name = cfg.DATASET_INST.trainid_to_name[classid]
            logx.msg(f'{img_name} {class_name}: {error_pixels}')

            to_dump = {
                'gt_images': gt_image,
                'input_images': in_image,
                'predictions': predictions.numpy(),
                'err_mask': err_mask,
                'prob_mask': prob_mask,
                'img_names': [img_name + f'_{class_name}'],
            }
            if attn_map is not None:
                to_dump['attn_maps'] = attn_map

            # FIXME! dumping is disabled, so to_dump is currently unused:
            # do_dump_images([to_dump])

    # ------------------------------------------------------------------
    # HTML summary: one table per (class, image), worst first
    # ------------------------------------------------------------------
    html_fn = os.path.join(args.result_dir, 'best_images',
                           'topn_failures.html')
    from utils.results_page import ResultsPage
    page = ResultsPage('topn failures', html_fn)

    # (file-name suffix, column label) for each image kind in a table
    columns = (('prediction', 'pred'),
               ('gt', 'gt'),
               ('input', 'input'),
               ('err_mask', 'errors'),
               ('prob_mask', 'prob'))

    for classid, img_dict in class_to_images.items():
        class_name = cfg.DATASET_INST.trainid_to_name[classid]
        for img_name in sorted(img_dict, key=img_dict.get, reverse=True):
            img_cls = f'{img_name}_{class_name}'
            pairs = [(f'{img_cls}_{suffix}.png', label)
                     for suffix, label in columns]
            page.add_table(pairs,
                           table_heading=f'{class_name}-{img_dict[img_name]}')
    page.write_page()

    return val_loss.avg