def compute_topk_imgs(self, mode='mean'):
    """Tally the top-k images per unit and store the result on ``self.topk``.

    Args:
        mode: ``'mean'`` ranks images with ``self._mean_activations``;
            every other value ranks them with ``self._max_activations``.

    The tally is run by ``tally.tally_topk`` over ``self.ds`` with
    ``self.sample_size`` samples and a batch size of 10.
    """
    # Anything that is not 'mean' is treated as 'max'.
    if mode == 'mean':
        activation_fn = self._mean_activations
    else:
        activation_fn = self._max_activations
    self.topk = tally.tally_topk(
        activation_fn,
        dataset=self.ds,
        sample_size=self.sample_size,
        batch_size=10)
def main():
    """Run the dissection pipeline for one model/dataset/segmenter triple.

    Steps, in order:
      1. Parse arguments and claim an exclusive results directory whose name
         records the model, dataset, segmenter and any non-default layer,
         quantile or thumbnail size.
      2. Tally per-unit activation quantiles (``rq.npz``).
      3. Tally the top-k images per unit by spatial max (``topk.npz``).
      4. Render and save masked thumbnails of the top images per unit.
      5. Tally conditional means of the thresholded activations against the
         segmenter output (``condi99.npz``) and derive per-unit IoU.
      6. Write ``concepts_99.svg``, ``report.json`` and ``+report.html`` and
         mark the results directory as done.

    All intermediate tallies are cached through ``resfile(...)`` paths.
    """
    args = parseargs()

    # Results directory: only non-default options are appended to the name.
    resdir = 'results/%s-%s-%s' % (args.model, args.dataset, args.seg)
    if args.layer is not None:
        resdir += '-' + args.layer
    if args.quantile != 0.005:
        resdir += ('-%g' % (args.quantile * 1000))
    if args.thumbsize != 100:
        resdir += ('-t%d' % (args.thumbsize))
    resfile = pidfile.exclusive_dirfn(resdir)

    model = load_model(args)
    layername = instrumented_layername(args)
    model.retain_layer(layername)
    dataset = load_dataset(args, model=model.model)
    upfn = make_upfn(args, dataset, model, layername)
    sample_size = len(dataset)
    is_generator = (args.model == 'progan')
    percent_level = 1.0 - args.quantile
    iou_threshold = args.miniou
    image_row_width = 5
    torch.set_grad_enabled(False)

    # Tally rq.np (representation quantile, unconditional).
    pbar.descnext('rq')

    # The callbacks below name their catch-all parameter `_ignored` so that
    # it does not shadow the parsed command-line namespace `args`.
    def compute_samples(batch, *_ignored):
        """Return upsampled activations flattened to (samples, units)."""
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    rq = tally.tally_quantile(
        compute_samples, dataset,
        sample_size=sample_size,
        r=8192,
        num_workers=100,
        pin_memory=True,
        cachefile=resfile('rq.npz'))

    # Create visualizations - first we need to know the topk
    pbar.descnext('topk')

    def compute_image_max(batch, *_ignored):
        """Return, per image and unit, the max activation over positions."""
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        acts = acts.view(acts.shape[0], acts.shape[1], -1)
        acts = acts.max(2)[0]
        return acts

    topk = tally.tally_topk(
        compute_image_max, dataset,
        sample_size=sample_size,
        batch_size=50,
        num_workers=30,
        pin_memory=True,
        cachefile=resfile('topk.npz'))

    # Visualize top-activating patches of top-activating images.
    pbar.descnext('unit_images')
    image_size = None
    if is_generator:
        # A generator's image size is only known from its output.
        image_size = model(dataset[0][0].cuda()[None, ...]).shape[2:]
    # NOTE(review): the original also set `image_source = dataset` in the
    # non-generator case but never read it; `source=dataset` was always
    # passed. That behaviour is preserved here -- confirm whether
    # `source` was meant to be None for generators.
    iv = imgviz.ImageVisualizer(
        (args.thumbsize, args.thumbsize),
        image_size=image_size,
        source=dataset,
        quantiles=rq,
        level=rq.quantiles(percent_level))

    def compute_acts(data_batch, *ignored_class):
        """Return (activations, images); images are the model output for
        generators and the input batch otherwise."""
        data_batch = data_batch.cuda()
        out_batch = model(data_batch)
        acts_batch = model.retained_layer(layername)
        if is_generator:
            return (acts_batch, out_batch)
        else:
            return (acts_batch, data_batch)

    # The same cache file is the output of the tally and the declared source
    # of the saved image set, so build its path once.
    top_images_file = resfile('top%dimages.npz' % image_row_width)
    unit_images = iv.masked_images_for_topk(
        compute_acts, dataset, topk,
        k=image_row_width,
        num_workers=30,
        pin_memory=True,
        cachefile=top_images_file)
    pbar.descnext('saving images')
    imgsave.save_image_set(
        unit_images, resfile('image/unit%d.jpg'),
        sourcefile=top_images_file)

    # Compute IoU agreement between segmentation labels and every unit.
    # Threshold each unit at its `percent_level` quantile (the `_99` suffix
    # in the names below notwithstanding) and tally conditional means at
    # that level.
    level_at_99 = rq.quantiles(percent_level).cuda()[None, :, None, None]

    segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
    renorm = renormalize.renormalizer(dataset, target='zc')

    def compute_conditional_indicator(batch, *_ignored):
        """Return conditional samples of the above-threshold indicator,
        conditioned on the segmentation of the corresponding image."""
        data_batch = batch.cuda()
        out_batch = model(data_batch)
        image_batch = out_batch if is_generator else renorm(data_batch)
        seg = segmodel.segment_batch(image_batch, downsample=4)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float()  # indicator
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi99')
    condi99 = tally.tally_conditional_mean(
        compute_conditional_indicator, dataset,
        sample_size=sample_size,
        num_workers=3,
        pin_memory=True,
        cachefile=resfile('condi99.npz'))

    # Now summarize the iou stats and graph the units
    iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
    unit_label_99 = [
        (concept.item(), seglabels[concept], segcatlabels[concept],
         bestiou.item())
        for (bestiou, concept) in zip(*iou_99.max(0))]
    labelcat_list = [
        labelcat
        for concept, label, labelcat, iou in unit_label_99
        if iou > iou_threshold]
    save_conceptcat_graph(resfile('concepts_99.svg'), labelcat_list)
    dump_json_file(resfile('report.json'), dict(
        header=dict(
            name='%s %s %s' % (args.model, args.dataset, args.seg),
            image='concepts_99.svg'),
        units=[
            dict(image='image/unit%d.jpg' % u,
                 unit=u, iou=iou, label=label, cat=labelcat[1])
            for u, (concept, label, labelcat, iou)
            in enumerate(unit_label_99)]))
    copy_static_file('report.html', resfile('+report.html'))
    resfile.done()
sr = inst_net.retained_layer('features.' + layername) return ((sr > s_99).float() * (ur > u_99).float()).permute( 0, 2, 3, 1).reshape(-1, ur.size(1)) intersect_99 = tally.tally_mean(intersect_99_fn, ds, cachefile=os.path.join( qd.dir(layername), 'intersect_99.npz')) def compute_image_max(batch): inst_net(batch.cuda()) return inst_net.retained_layer('features.' + layername).max(3)[0].max(2)[0] s_topk = tally.tally_topk(compute_image_max, sds, cachefile=os.path.join(qd.dir(layername), 's_topk.npz')) def compute_acts(image_batch): inst_net(image_batch.cuda()) acts_batch = inst_net.retained_layer('features.' + layername) return (acts_batch, image_batch) iv = imgviz.ImageVisualizer(128, quantiles=s_rq, source=sds) unit_images = iv.masked_images_for_topk(compute_acts, sds, s_topk, k=5) os.makedirs(os.path.join(qd.dir(layername), 's_imgs'), exist_ok=True) imgsave.save_image_set( unit_images, os.path.join(qd.dir(layername), 's_imgs/unit%d.jpg')) iv = imgviz.ImageVisualizer(128, quantiles=u_rq, source=uds) unit_images = iv.masked_images_for_topk(compute_acts, uds, s_topk, k=5)
def _save_pickle(obj, path):
    """Pickle ``obj`` to ``path``, creating the parent directory if needed."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # The context manager guarantees the handle is flushed and closed.
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def main():
    """Dissect one layer of a MoCo ResNet-50 checkpoint.

    Steps, in order:
      1. Parse options; build checkpoint, data and web-output paths.
      2. Load the validation set and the checkpoint, wrap the model with
         ``nethook.InstrumentedModel`` and retain ``layername``.
      3. Tally activation quantiles and plot four of them.
      4. Tally the top-k images per unit by mean activation and save grids
         of the top 12 images for 10 randomly chosen units.
      5. Render masked top images per unit and pickle them.
      6. Tally conditional means of the above-0.99-quantile indicator against
         the ``netpqc`` segmenter, pickle the IoU table and plot the 20 units
         with the highest IoU.

    Tallies are cached under ``results/<dataset>/``; figures are written
    under ``web_path``.
    """
    # Load the arguments
    args = parse_option()
    dataset = args.dataset
    sample_size = args.sample_size
    layername = args.layer

    # Other values for places and imagenet MoCo model
    epoch = 240
    image_size = 224
    crop_padding = 32
    batch_size = 1
    num_workers = 24
    train_sampler = None
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    # Set appropriate paths
    folder_path = "/data/vision/torralba/ganprojects/yyou/CMC_data/{}_models".format(
        dataset)
    model_name = "/{}_MoCo0.999_softmax_16384_resnet50_lr_0.03".format(dataset) \
        + "_decay_0.0001_bsz_128_crop_0.2_aug_CJ"
    epoch_name = "/ckpt_epoch_{}.pth".format(epoch)
    my_path = folder_path + model_name + epoch_name
    data_path = "/data/vision/torralba/datasets/"
    web_path = "/data/vision/torralba/scratch/yyou/wednesday/dissection/"
    if dataset == "imagenet":
        data_path += "imagenet_pytorch"
        web_path += dataset + "/" + layername
    elif dataset == "places365":
        data_path += "places/places365_standard/places365standard_easyformat"
        web_path += dataset + "/" + layername

    # Create web path folder directory for this layer
    # (exist_ok avoids the check-then-create race).
    os.makedirs(web_path, exist_ok=True)

    # Load validation data loader
    val_folder = data_path + "/val"
    val_transform = transforms.Compose([
        transforms.Resize(image_size + crop_padding),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])
    ds = QuickImageFolder(val_folder, transform=val_transform,
                          shuffle=True, two_crop=False)
    ds_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        num_workers=num_workers,
        pin_memory=True,
        sampler=train_sampler)

    # Load model from checkpoint
    checkpoint = torch.load(my_path)
    model_checkpoint = {
        key.replace(".module", ""): val
        for key, val in checkpoint['model'].items()
    }
    model = InsResNet50(parallel=False)
    model.load_state_dict(model_checkpoint)
    model = nethook.InstrumentedModel(model)
    model.cuda()
    # Everything below is inference only. Switch to evaluation mode so that
    # layers such as batch normalisation use their stored statistics instead
    # of per-batch statistics.
    # NOTE(review): the original never called eval(); confirm that no
    # existing cache was intentionally produced in training mode.
    model.eval()

    # Renormalize RGB data from the statistical scaling in ds to [-1...1] range
    renorm = renormalize.renormalizer(source=ds, target='zc')

    # Retain desired layer with nethook and probe it with one batch to learn
    # the spatial size of its activations.
    model.retain_layer(layername)
    with torch.no_grad():
        probe_batch = next(iter(ds_loader))[0]
        model(probe_batch.cuda())
    # The original hard-coded data_shape=(7, 7) and left the probed
    # activations unused; taking the shape from the probe gives the same
    # value whenever the retained layer is 7x7 and stays correct otherwise.
    data_shape = tuple(model.retained_layer(layername).shape[2:])
    upfn = upsample.upsampler(
        target_shape=(56, 56),
        data_shape=data_shape,
    )

    # Callbacks name their catch-all parameter `_ignored` so that it does not
    # shadow the parsed options `args`.
    def flatten_activations(batch, *_ignored):
        """Return upsampled CPU activations flattened to (samples, units)."""
        image_batch = batch
        _ = model(image_batch.cuda())
        acts = model.retained_layer(layername).cpu()
        hacts = upfn(acts)
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    def tally_quantile_for_layer(layername):
        """Tally (or load from cache) activation quantiles for a layer."""
        rq = tally.tally_quantile(
            flatten_activations,
            dataset=ds,
            sample_size=sample_size,
            batch_size=100,
            cachefile='results/{}/{}_rq_cache.npz'.format(dataset, layername))
        return rq

    rq = tally_quantile_for_layer(layername)

    # Visualize range of activations (statistics of each filter over the
    # sample images)
    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    quantiles = [0.5, 0.8, 0.9, 0.99]
    for ax, quantile in zip(axs.flatten(), quantiles):
        ax.plot(rq.quantiles(quantile))
        ax.set_title("Rq quantiles ({})".format(quantile))
    fig.suptitle("{} - sample size of {}".format(dataset, sample_size))
    fig.savefig(web_path + "/rq_quantiles")
    plt.close(fig)  # release the figure once it is on disk

    # Set the image visualizer with the rq and percent level
    iv = imgviz.ImageVisualizer(224, source=ds, percent_level=0.95,
                                quantiles=rq)

    # Tally top k images that maximize the mean activation of the filter
    def max_activations(batch, *_ignored):
        """Per image and unit, the max activation over positions (unused
        below; kept as the alternative ranking function)."""
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        return acts.view(acts.shape[:2] + (-1, )).max(2)[0]

    def mean_activations(batch, *_ignored):
        """Per image and unit, the mean activation over positions."""
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        return acts.view(acts.shape[:2] + (-1, )).mean(2)

    topk = tally.tally_topk(
        mean_activations,
        dataset=ds,
        sample_size=sample_size,
        batch_size=100,
        cachefile='results/{}/{}_cache_mean_topk.npz'.format(
            dataset, layername))

    top_indexes = topk.result()[1]

    # Visualize top-activating images for a particular unit
    os.makedirs(web_path + "/top_activating_imgs", exist_ok=True)

    def top_activating_imgs(unit):
        """Save a 3x4 grid of the top images for ``unit`` with predictions."""
        img_ids = top_indexes[unit, :12].tolist()
        images, preds = [], []
        with torch.no_grad():
            for img_id in img_ids:
                img = ds[img_id][0]
                # Run the model on THIS image before reading the retained
                # layer. The original read the retained activations without
                # a forward pass, so every image was masked with whatever
                # batch the model had seen last.
                out = model(img[None].cuda())
                acts = model.retained_layer(layername)[0]
                images.append(iv.masked_image(img, acts, unit))
                preds.append(ds.classes[out.max(1)[1].item()])
        fig, axs = plt.subplots(3, 4, figsize=(16, 12))
        axs = axs.flatten()
        for ax in axs:
            ax.tick_params(axis='both', which='both', bottom=False,
                           left=False, labelbottom=False, labelleft=False)
        # zip() tolerates fewer than 12 available images.
        for ax, img_id, image, pred in zip(axs, img_ids, images, preds):
            ax.imshow(image)
            ax.set_title("img {} \n pred: {}".format(img_id, pred))
        fig.suptitle("unit {}".format(unit))
        fig.savefig(web_path + "/top_activating_imgs/unit_{}".format(unit))
        plt.close(fig)  # called in a loop: do not accumulate open figures

    for unit in np.random.randint(len(top_indexes), size=10):
        top_activating_imgs(unit)

    def compute_activations(image_batch, *_ignored):
        """Return the retained activations for a batch. Extra positional
        items (e.g. labels) are accepted and ignored, like the sibling
        callbacks."""
        image_batch = image_batch.cuda()
        _ = model(image_batch)
        acts_batch = model.retained_layer(layername)
        return acts_batch

    # NOTE(review): the cache name says "top10" while k=5; the name is kept
    # so existing caches are still found.
    unit_images = iv.masked_images_for_topk(
        compute_activations,
        ds,
        topk,
        k=5,
        num_workers=10,
        pin_memory=True,
        cachefile='results/{}/{}_cache_top10images.npz'.format(
            dataset, layername))

    # The original format string had a single placeholder but was given both
    # dataset and layername, so every layer wrote the same file. The layer
    # name is now part of the file name, matching the other outputs.
    _save_pickle(
        unit_images,
        "results/{}/{}_unit_images.pkl".format(dataset, layername))

    # Load a segmentation model
    segmodel, seglabels, _segcatlabels = setting.load_segmenter('netpqc')

    # Intersections between every unit's 99th activation
    # and every segmentation class identified
    level_at_99 = rq.quantiles(0.99).cuda()[None, :, None, None]

    def compute_selected_segments(batch, *_ignored):
        """Return conditional samples of the above-threshold indicator,
        conditioned on the segmentation of the renormalized images."""
        image_batch = batch.cuda()
        seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float()  # indicator where > 0.99 percentile.
        return tally.conditional_samples(iacts, seg)

    condi99 = tally.tally_conditional_mean(
        compute_selected_segments,
        dataset=ds,
        sample_size=sample_size,
        cachefile='results/{}/{}_cache_condi99.npz'.format(dataset, layername))

    iou99 = tally.iou_from_conditional_indicator_mean(condi99)
    _save_pickle(iou99, "results/{}/{}_iou99.pkl".format(dataset, layername))

    # Show units with best match to a segmentation class
    iou_unit_label_99 = sorted(
        [(unit, concept.item(), seglabels[concept], bestiou.item())
         for unit, (bestiou, concept) in enumerate(zip(*iou99.max(0)))],
        key=lambda x: -x[-1])

    fig, axs = plt.subplots(20, 1, figsize=(20, 80))
    for ax, (unit, concept, label, score) in zip(axs.flatten(),
                                                 iou_unit_label_99[:20]):
        ax.imshow(unit_images[unit])
        ax.set_title('unit %d; iou %g; label "%s"' % (unit, score, label))
        ax.set_xticks([])
        ax.set_yticks([])
    fig.savefig(web_path + "/best_unit_segmentation")
    plt.close(fig)