def tally_quantile_for_layer(layername):
    """Compute (or load from cache) the activation quantile tally for a layer.

    Delegates to ``tally.tally_quantile`` over ``ds`` with
    ``flatten_activations`` as the per-batch sample function, ``sample_size``
    samples and batches of 100.  The result is cached at
    ``results/<dataset>/<layername>_rq_cache.npz``.

    NOTE(review): ``flatten_activations``, ``ds``, ``sample_size`` and
    ``dataset`` are resolved from the enclosing scope.  ``layername`` is only
    used to build the cache path here, so the caller must make sure
    ``flatten_activations`` reads that same layer -- confirm.
    """
    cache_path = 'results/{}/{}_rq_cache.npz'.format(dataset, layername)
    quantile_summary = tally.tally_quantile(
        flatten_activations,
        dataset=ds,
        sample_size=sample_size,
        batch_size=100,
        cachefile=cache_path,
    )
    return quantile_summary
def main():
    """Dissect one layer of a model and write a JSON/HTML report.

    Stages (each cached under the results directory via ``resfile``):
      1. ``rq.npz``: per-unit quantiles of upsampled activations.
      2. ``topk.npz``: top-activating images per unit, ranked by the
         spatial max of each unit's activation.
      3. ``top5images.npz`` and ``image/unit%d.jpg``: masked thumbnails of
         the top images for every unit.
      4. ``condi99.npz``: per-segmentation-class mean of the indicator
         "activation above the ``percent_level`` quantile".
      5. Per-unit best IoU class, written to ``report.json`` and summarized
         in ``concepts_99.svg``; ``report.html`` is copied alongside.
    """
    args = parseargs()
    # The results directory name encodes model/dataset/segmenter, plus the
    # layer, quantile and thumbnail size when they differ from the defaults.
    resdir = 'results/%s-%s-%s' % (args.model, args.dataset, args.seg)
    if args.layer is not None:
        resdir += '-' + args.layer
    if args.quantile != 0.005:
        resdir += ('-%g' % (args.quantile * 1000))
    if args.thumbsize != 100:
        resdir += ('-t%d' % (args.thumbsize))
    # resfile(name) builds a path under resdir; resfile.done() is called at
    # the end.  Presumably exclusive_dirfn also guards resdir against a
    # concurrent run -- confirm in pidfile.
    resfile = pidfile.exclusive_dirfn(resdir)
    model = load_model(args)
    layername = instrumented_layername(args)
    # Keep this layer's output after every forward pass so the closures below
    # can read it with model.retained_layer(layername).
    model.retain_layer(layername)
    dataset = load_dataset(args, model=model.model)
    upfn = make_upfn(args, dataset, model, layername)
    sample_size = len(dataset)
    # For a generator the images are the model outputs; otherwise the images
    # are the model inputs (see compute_acts / compute_conditional_indicator).
    is_generator = (args.model == 'progan')
    # e.g. quantile=0.005 -> threshold at the 0.995 activation quantile.
    percent_level = 1.0 - args.quantile
    iou_threshold = args.miniou
    image_row_width = 5
    torch.set_grad_enabled(False)

    # Tally rq.np (representation quantile, unconditional).
    pbar.descnext('rq')
    # NOTE(review): the inner ``*args`` shadows the parsed ``args``; harmless
    # because the closure never reads it.
    def compute_samples(batch, *args):
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        # Move dim 1 last and flatten: one row per spatial location, one
        # column per unit (assumes hacts is (batch, units, H, W)).
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])
    rq = tally.tally_quantile(compute_samples, dataset,
                              sample_size=sample_size,
                              r=8192,
                              num_workers=100,
                              pin_memory=True,
                              cachefile=resfile('rq.npz'))

    # Create visualizations - first we need to know the topk
    pbar.descnext('topk')
    def compute_image_max(batch, *args):
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        # Flatten spatial dims and keep each unit's maximum per image.
        acts = acts.view(acts.shape[0], acts.shape[1], -1)
        acts = acts.max(2)[0]
        return acts
    topk = tally.tally_topk(compute_image_max, dataset, sample_size=sample_size,
            batch_size=50, num_workers=30, pin_memory=True,
            cachefile=resfile('topk.npz'))

    # Visualize top-activating patches of top-activating images.
    pbar.descnext('unit_images')
    image_size, image_source = None, None
    if is_generator:
        # Run one sample through the generator to learn the output image size.
        image_size = model(dataset[0][0].cuda()[None,...]).shape[2:]
    else:
        image_source = dataset
    # NOTE(review): image_source is assigned above but never used -- the
    # visualizer always receives source=dataset.  Confirm whether
    # source=image_source was intended before changing it, since it would
    # alter behavior for generators.
    iv = imgviz.ImageVisualizer((args.thumbsize, args.thumbsize),
        image_size=image_size,
        source=dataset,
        quantiles=rq,
        level=rq.quantiles(percent_level))
    def compute_acts(data_batch, *ignored_class):
        data_batch = data_batch.cuda()
        out_batch = model(data_batch)
        acts_batch = model.retained_layer(layername)
        # Pair the activations with the images they describe.
        if is_generator:
            return (acts_batch, out_batch)
        else:
            return (acts_batch, data_batch)
    unit_images = iv.masked_images_for_topk(
            compute_acts, dataset, topk,
            k=image_row_width, num_workers=30, pin_memory=True,
            cachefile=resfile('top%dimages.npz' % image_row_width))
    pbar.descnext('saving images')
    imgsave.save_image_set(unit_images, resfile('image/unit%d.jpg'),
            sourcefile=resfile('top%dimages.npz' % image_row_width))

    # Compute IoU agreement between segmentation labels and every unit
    # Grab the 99th percentile, and tally conditional means at that level.
    # NOTE: despite the name, the threshold is the percent_level quantile
    # (0.995 with the default quantile of 0.005), shaped for broadcasting
    # over (batch, unit, H, W).
    level_at_99 = rq.quantiles(percent_level).cuda()[None,:,None,None]

    segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
    renorm = renormalize.renormalizer(dataset, target='zc')
    def compute_conditional_indicator(batch, *args):
        data_batch = batch.cuda()
        out_batch = model(data_batch)
        # Segment generated images directly; renormalize dataset images
        # before handing them to the segmenter.
        image_batch = out_batch if is_generator else renorm(data_batch)
        seg = segmodel.segment_batch(image_batch, downsample=4)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float() # indicator
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi99')
    condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
            dataset, sample_size=sample_size,
            num_workers=3, pin_memory=True,
            cachefile=resfile('condi99.npz'))

    # Now summarize the iou stats and graph the units
    iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
    # iou_99.max(0) maximizes over dim 0 (segmentation classes), giving each
    # unit its best IoU and the class index achieving it.
    unit_label_99 = [
            (concept.item(), seglabels[concept],
                segcatlabels[concept], bestiou.item())
            for (bestiou, concept) in zip(*iou_99.max(0))]
    # Only units whose best IoU exceeds --miniou are counted in the graph.
    labelcat_list = [labelcat
            for concept, label, labelcat, iou in unit_label_99
            if iou > iou_threshold]
    save_conceptcat_graph(resfile('concepts_99.svg'), labelcat_list)
    # presumably labelcat is a (label, category) pair, so labelcat[1] is the
    # category name -- confirm against setting.load_segmenter.
    dump_json_file(resfile('report.json'), dict(
            header=dict(
                name='%s %s %s' % (args.model, args.dataset, args.seg),
                image='concepts_99.svg'),
            units=[
                dict(image='image/unit%d.jpg' % u,
                    unit=u, iou=iou, label=label, cat=labelcat[1])
                for u, (concept, label, labelcat, iou)
                in enumerate(unit_label_99)])
            )
    copy_static_file('report.html', resfile('+report.html'))
    resfile.done();
def _get_rq_vals(self):
    """Return the activation quantile tally over this object's dataset.

    Runs ``tally.tally_quantile`` with ``self._flatten_activations`` as the
    sample function over ``self.ds`` (``self.sample_size`` samples, batches
    of 10).  No cache file is passed, so the tally is presumably recomputed
    on every call -- confirm against ``tally.tally_quantile``.
    """
    sample_fn = self._flatten_activations
    return tally.tally_quantile(
        sample_fn,
        dataset=self.ds,
        sample_size=self.sample_size,
        batch_size=10,
    )
f.write('busy') print('working on', layername) inst_net = nethook.InstrumentedModel(copy.deepcopy(net)).cuda() inst_net.retain_layer('features.' + layername) inst_net(ds[0][0][None].cuda()) sample_act = inst_net.retained_layer('features.' + layername).cpu() upfn = upsample.upsampler((64, 64), sample_act.shape[2:]) def flat_acts(batch): inst_net(batch.cuda()) acts = upfn(inst_net.retained_layer('features.' + layername)) return acts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1]) s_rq = tally.tally_quantile(flat_acts, sds, cachefile=os.path.join(qd.dir(layername), 's_rq.npz')) u_rq = qd.rq(layername) def intersect_99_fn(uimg, simg): s_99 = s_rq.quantiles(0.99)[None, :, None, None].cuda() u_99 = u_rq.quantiles(0.99)[None, :, None, None].cuda() with torch.no_grad(): ux, sx = uimg.cuda(), simg.cuda() inst_net(ux) ur = inst_net.retained_layer('features.' + layername) inst_net(sx) sr = inst_net.retained_layer('features.' + layername) return ((sr > s_99).float() * (ur > u_99).float()).permute( 0, 2, 3, 1).reshape(-1, ur.size(1))
(56, 56), # The target output shape (7, 7), source=dataset, ) renorm = renormalize.renormalizer(dataset, mode='zc') def compute_samples(batch, *args): image_batch = batch.cuda() _ = model(image_batch) acts = model.retained_layer(layername) hacts = upfn(acts) return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1]) pbar.descnext('rq') rq = tally.tally_quantile(compute_samples, dataset, sample_size=sample_size, r=8192, cachefile=resfile('rq.npz')) if False: def compute_conditional_samples(batch, *args): image_batch = batch.cuda() _ = model(image_batch) acts = model.retained_layer(layername) seg = segmodel.segment_batch(renorm(image_batch), downsample=4) hacts = upfn(acts) return tally.conditional_samples(hacts, seg) pbar.descnext('condq') condq = tally.tally_conditional_quantile(compute_conditional_samples, dataset, sample_size=sample_size, cachefile=resfile('condq.npz')) level_at_99 = rq.quantiles(0.99).cuda()[None,:,None,None]
def main():
    """Dissect a layer, then run unit-ablation and door-insertion experiments.

    Stages (each cached under the results directory via ``resfile``):
      * ``rq.npz``: per-unit activation quantiles.
      * ``condi99.npz``: per-segmentation-class mean of the indicator
        "activation above the ``percent_level`` quantile", turned into a
        per-unit best-IoU label list.
      * ``segstats/``: generator segmentation stats for a baseline call (no
        ``zeroed_units`` argument) and with each single unit zeroed.
      * ``deltasegstats/``: segmentation deltas when the 0..29 units with the
        highest IoU for 'tree' are zeroed.
      * ``big_door_cond_changes.npz``: change in segmentation counts when the
        20 highest-IoU 'door' units are set to their 0.995 quantile at one
        feature location, conditioned on the class found at that location.
    """
    args = parseargs()
    # The layer and quantile (x1000, truncated to int) are always part of the
    # directory name; a None layer would appear as 'None'.
    resdir = 'results/%s-%s-%s-%s-%s' % (args.model, args.dataset, args.seg,
            args.layer, int(args.quantile * 1000))
    def resfile(f):
        return os.path.join(resdir, f)
    model = load_model(args)
    layername = instrumented_layername(args)
    # Keep this layer's output after each forward pass for the closures below.
    model.retain_layer(layername)
    dataset = load_dataset(args, model=model.model)
    upfn = make_upfn(args, dataset, model, layername)
    sample_size = len(dataset)
    # For a generator the images are the model outputs; otherwise the images
    # are the (renormalized) model inputs.
    is_generator = (args.model == 'progan')
    percent_level = 1.0 - args.quantile

    # Tally rq.np (representation quantile, unconditional).
    torch.set_grad_enabled(False)
    pbar.descnext('rq')
    def compute_samples(batch, *args):
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        # One row per spatial location, one column per unit
        # (assumes hacts is (batch, units, H, W)).
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])
    rq = tally.tally_quantile(compute_samples, dataset,
                              sample_size=sample_size,
                              r=8192,
                              num_workers=100,
                              pin_memory=True,
                              cachefile=resfile('rq.npz'))

    # Grab the 99th percentile, and tally conditional means at that level.
    # NOTE: the threshold is actually the percent_level quantile, shaped for
    # broadcasting over (batch, unit, H, W).
    level_at_99 = rq.quantiles(percent_level).cuda()[None, :, None, None]

    segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
    renorm = renormalize.renormalizer(dataset, target='zc')
    def compute_conditional_indicator(batch, *args):
        data_batch = batch.cuda()
        out_batch = model(data_batch)
        image_batch = out_batch if is_generator else renorm(data_batch)
        seg = segmodel.segment_batch(image_batch, downsample=4)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float() # indicator
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi99')
    condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
            dataset, sample_size=sample_size,
            num_workers=3, pin_memory=True,
            cachefile=resfile('condi99.npz'))

    # Now summarize the iou stats and graph the units
    iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
    # iou_99.max(0) maximizes over dim 0 (segmentation classes): each unit's
    # best IoU and the class index achieving it.
    unit_label_99 = [(concept.item(), seglabels[concept],
                      segcatlabels[concept], bestiou.item())
                     for (bestiou, concept) in zip(*iou_99.max(0))]

    # Mean per-image segmentation-class fractions of generated images with
    # the given units zeroed.  The ``sample_size`` parameter (default 100)
    # shadows the outer one, so only that many samples are tallied.
    # NOTE(review): only referenced from the comments below in visible code.
    # If tally_mean raises, the zeroing edit stays installed on the model;
    # consider try/finally around it.
    def measure_segclasses_with_zeroed_units(zeroed_units, sample_size=100):
        model.remove_edits()
        def zero_some_units(x, *args):
            x[:, zeroed_units] = 0
            return x
        model.edit_layer(layername, rule=zero_some_units)
        num_seglabels = len(segmodel.get_label_and_category_names()[0])
        def compute_mean_seg_in_images(batch_z, *args):
            img = model(batch_z.cuda())
            seg = segmodel.segment_batch(img, downsample=4)
            seg_area = seg.shape[2] * seg.shape[3]
            # Offset each image's labels by num_seglabels * image_index so a
            # single bincount yields per-image label counts.
            seg_counts = torch.bincount(
                (seg + (num_seglabels * torch.arange(
                    seg.shape[0], dtype=seg.dtype, device=seg.device
                    )[:, None, None, None])).view(-1),
                minlength=num_seglabels * seg.shape[0]).view(seg.shape[0], -1)
            # NOTE(review): if seg has more than one label channel
            # (seg.shape[1] > 1), the counts cover all channels, so the
            # fractions per image can sum above 1 -- confirm intended.
            seg_fracs = seg_counts.float() / seg_area
            return seg_fracs
        result = tally.tally_mean(compute_mean_seg_in_images, dataset,
                                  batch_size=30, sample_size=sample_size,
                                  pin_memory=True)
        model.remove_edits()
        return result

    # Intervention experiment here:
    # segs_baseline = measure_segclasses_with_zeroed_units([])
    # segs_without_treeunits = measure_segclasses_with_zeroed_units(tree_units)
    num_units = len(unit_label_99)
    baseline_segmean = test_generator_segclass_stats(
            model, dataset, segmodel,
            layername=layername,
            cachefile=resfile('segstats/baseline.npz')).mean()

    pbar.descnext('unit ablation')
    unit_ablation_segmean = torch.zeros(num_units, len(baseline_segmean))
    # Visit every unit once in shuffled order; each result is cached per
    # unit.  Presumably the shuffle lets concurrent runs spread out -- confirm.
    for unit in pbar(random.sample(range(num_units), num_units)):
        stats = test_generator_segclass_stats(model, dataset, segmodel,
            layername=layername, zeroed_units=[unit],
            cachefile=resfile('segstats/ablated_unit_%d.npz' % unit))
        unit_ablation_segmean[unit] = stats.mean()

    ablate_segclass_name = 'tree'
    ablate_segclass = seglabels.index(ablate_segclass_name)
    # Unit indices ordered by descending IoU with the ablated class.
    best_iou_units = iou_99[ablate_segclass, :].sort(0)[1].flip(0)
    byiou_unit_ablation_seg = torch.zeros(30)
    # Zero the top-`unitcount` units (0..29) and record the mean change in
    # the ablated class; unitcount == 0 zeroes nothing.
    for unitcount in pbar(random.sample(range(0, 30), 30)):
        zero_units = best_iou_units[:unitcount].tolist()
        stats = test_generator_segclass_delta_stats(
            model, dataset, segmodel,
            layername=layername, zeroed_units=zero_units,
            cachefile=resfile('deltasegstats/ablated_best_%d_iou_%s.npz' %
                        (unitcount, ablate_segclass_name)))
        byiou_unit_ablation_seg[unitcount] = stats.mean()[ablate_segclass]

    # Generator context experiment.
    num_segclass = len(seglabels)
    door_segclass = seglabels.index('door')
    # The 20 units with the highest IoU for 'door'.
    door_units = iou_99[door_segclass].sort(0)[1].flip(0)[:20]
    # NOTE(review): hard-coded 0.995 quantile, independent of --quantile.
    door_high_values = rq.quantiles(0.995)[door_units].cuda()

    # Returns a lazy generator of (segclass, delta_segcount) pairs: for every
    # feature location, force the door units high there and record how the
    # per-class segmentation counts change, keyed by the class(es) found at
    # the center of that location.  Forward passes run as the tally consumes
    # the generator.
    # NOTE(review): the last replacement edit stays installed after the
    # generator is exhausted; the next call's remove_edits() clears it, but
    # it is still in effect after the final tally -- confirm.
    def compute_seg_impact(zbatch, *args):
        zbatch = zbatch.cuda()
        model.remove_edits()
        orig_img = model(zbatch)
        orig_seg = segmodel.segment_batch(orig_img, downsample=4)
        orig_segcount = tally.batch_bincount(orig_seg, num_segclass)
        rep = model.retained_layer(layername).clone()
        # Segmentation pixels per feature-map cell along each axis.
        ysize = orig_seg.shape[2] // rep.shape[2]
        xsize = orig_seg.shape[3] // rep.shape[3]
        def gen_conditions():
            for y in range(rep.shape[2]):
                for x in range(rep.shape[3]):
                    # Take as the context location the segmentation
                    # labels at the center of the square.
                    selsegs = orig_seg[:, :, y * ysize + ysize // 2,
                            x * xsize + xsize // 2]
                    changed_rep = rep.clone()
                    changed_rep[:, door_units, y, x] = (
                            door_high_values[None, :])
                    model.edit_layer(layername,
                            ablation=1.0, replacement=changed_rep)
                    changed_img = model(zbatch)
                    changed_seg = segmodel.segment_batch(changed_img, downsample=4)
                    changed_segcount = tally.batch_bincount(
                            changed_seg, num_segclass)
                    delta_segcount = (changed_segcount - orig_segcount).float()
                    for sel, delta in zip(selsegs, delta_segcount):
                        for cond in torch.bincount(sel).nonzero()[:, 0]:
                            # Skip class 0 (presumably unlabeled -- confirm).
                            if cond == 0:
                                continue
                            yield (cond.item(), delta)
        return gen_conditions()
    cond_changes = tally.tally_conditional_mean(
            compute_seg_impact, dataset, sample_size=10000, batch_size=20,
            cachefile=resfile('big_door_cond_changes.npz'))