def show_seg_gt(self, num_samples=5):
    """Display the first ``num_samples`` dataset items with their ground-truth segmentation.

    Each displayed row holds three views of one item from ``self.ds``: the
    image rendered by ``self.iv.image``, the label map rendered by
    ``self.iv.segmentation``, and the key produced by
    ``self.iv.segment_key_with_lbls`` using ``self.seglabels``.

    Args:
        num_samples: number of items to fetch, starting at index 0 of
            ``self.ds``.
    """
    rows = []
    for i in range(num_samples):
        img, lbl = self.ds[i]
        rows.append((
            self.iv.image(img),
            self.iv.segmentation(lbl),
            self.iv.segment_key_with_lbls(lbl, self.seglabels),
        ))
    show(rows)
def show_idx(idxlist, val_dataset, source='imagenet'):
    """Display the dataset images located at the given indices.

    With exactly one index, that image is printed and shown on its own.
    Otherwise the whole index list is printed and the images are shown one
    per row.

    Args:
        idxlist: indices into ``val_dataset``; element ``[0]`` of each
            dataset item is the image that gets rendered.
        val_dataset: indexable dataset of items.
        source: forwarded as ``source=`` to ``renormalize.as_image``.
    """
    def _render(idx):
        # Split, then convert the stored tensor back into a displayable image.
        return renormalize.as_image(split_img(val_dataset[idx][0]),
                                    source=source)

    if len(idxlist) != 1:
        print("idx list", idxlist)
        show([[_render(idx)] for idx in idxlist])
        return
    print("image", idxlist[0])
    show(_render(idxlist[0]))
def show_top_activating_imgs_per_units_with_seg(self, units, top_num=1):
    """Show the top-activating dataset images for each unit, with segmentation.

    For every unit ``u`` in ``units`` and each of its first ``top_num``
    entries in the top-k index table, one row is displayed containing: the
    unit number, the image index, the prediction string from
    ``self._get_pred(i)``, the masked image, the unit heatmap, and the
    ground-truth segmentation of that image.

    Args:
        units: iterable of unit indices to display.
        top_num: how many top images to show per unit.
    """
    # Lazily build the top-k statistics on first use.
    if self.topk is None:
        self.compute_topk_imgs()
    # presumably the indices half of a (values, indices) top-k result,
    # indexed as [unit, rank] -- TODO confirm against compute_topk_imgs.
    top_indexes = self.topk.result()[1]
    # NOTE(review): retained_layer(...)[0] is read below without any forward
    # pass on image i inside this block. The activations match image i only if
    # self._get_pred(i) -- evaluated earlier in the same row literal -- runs
    # the model on that image. Otherwise they are whatever the last forward
    # pass left behind. Confirm before reordering these row elements.
    show([[
        'unit %d' % u,
        'img %d' % i,
        'pred: %s' % self._get_pred(i),
        [
            self.iv.masked_image(
                self.ds[i.item()][0],
                self.model.retained_layer(self.layername)[0],
                u)
        ],
        [self.iv.heatmap(self.model.retained_layer(self.layername)[0], u)],
        [self.iv.segmentation(self.ds[i.item()][1])]
    ] for u in units for i in top_indexes[u, :top_num]])
def show_seg_results(self, sample_size=20, num_units=20):
    """Match each unit of ``self.layername`` to its best segmentation label and show the top units.

    A unit is treated as "on" wherever its upsampled activation
    (``self.upfn``) exceeds that unit's 0.99 quantile from ``self.rq``.
    These indicator maps are tallied against the segmentation labels of
    ``self.ds`` via ``tally.tally_conditional_mean``. The result is turned
    into IoU scores with ``tally.iou_from_conditional_indicator_mean`` and
    stored in ``self.iou99``. The layer name is printed in bold. Then the
    ``num_units`` units whose best label has the highest IoU are shown, each
    with ``self.unit_images[unit]``.

    Args:
        sample_size: forwarded as ``sample_size`` to
            ``tally.tally_conditional_mean`` (previously fixed at 20).
        num_units: how many of the highest-IoU units to display
            (previously fixed at 20).
    """
    if self.unit_images is None:
        self.compute_top_unit_imgs()

    # Per-unit threshold; the [None, :, None, None] indexing places units on
    # axis 1 of a rank-4 tensor so it broadcasts against the upsampled
    # activations (presumably laid out as batch, unit, H, W -- TODO confirm).
    level_at_99 = self.rq.quantiles(0.99).cuda()[None, :, None, None]

    def compute_selected_segments(batch, *args):
        img, seg = batch
        image_batch = img.cuda()
        seg_batch = seg.cuda()
        # Run the model; retained_layer is read immediately afterwards.
        _ = self.model(image_batch)
        acts = self.model.retained_layer(self.layername)
        hacts = self.upfn(acts)
        # 1.0 where the activation exceeds the unit's 0.99 quantile, else 0.0.
        iacts = (hacts > level_at_99).float()
        return tally.conditional_samples(iacts, seg_batch)

    condi99 = tally.tally_conditional_mean(
        compute_selected_segments,
        dataset=self.ds,
        sample_size=sample_size,
        loader=self.ds_loader,
        pass_with_lbl=True,
    )
    self.iou99 = tally.iou_from_conditional_indicator_mean(condi99)

    # \033[1m ... \033[0m wraps the layer name in ANSI bold.
    bolded_string = "\033[1m" + self.layername + "\033[0m"
    print(bolded_string)

    # One entry per unit: (unit, label index, label name, best IoU).
    # iou99.max(0) reduces over dim 0 (presumably the label axis -- TODO
    # confirm). Sort by descending IoU.
    iou_unit_label_99 = sorted(
        [
            (unit, concept.item(), self.seglabels[int(concept)], bestiou.item())
            for unit, (bestiou, concept) in enumerate(zip(*self.iou99.max(0)))
        ],
        key=lambda x: -x[-1],
    )
    for unit, _concept, label, score in iou_unit_label_99[:num_units]:
        show([
            'unit %d; iou %g; label "%s"' % (unit, score, label),
            [self.unit_images[unit]],
        ])