def measure_segclasses_with_zeroed_units(zeroed_units, sample_size=100):
    """Tally mean segmentation-label pixel fractions with some units zeroed.

    Installs an edit on the module-level ``model`` that sets channels
    ``zeroed_units`` of layer ``layername`` to zero. It then generates images
    from batches of the module-level ``dataset`` and segments them with the
    module-level ``segmodel``. Finally it averages the per-image label
    fractions with ``tally.tally_mean``.

    Note: ``model``, ``layername``, ``segmodel`` and ``dataset`` are NOT
    parameters; they are read from the enclosing (module/notebook) scope.

    Args:
        zeroed_units: any index accepted by ``x[:, zeroed_units]``; selects the
            channels of the ``layername`` output that are zeroed.
        sample_size: forwarded to ``tally.tally_mean`` (presumably caps how many
            dataset samples are processed — confirm against ``tally``).

    Returns:
        Whatever ``tally.tally_mean`` returns for the per-image fraction
        vectors of length ``num_seglabels``.
    """
    # Start from a clean model in case an earlier edit is still installed.
    model.remove_edits()

    def zero_some_units(x, *args):
        # Edit rule: zero the selected channels of the layer output in place.
        x[:, zeroed_units] = 0
        return x

    num_seglabels = len(segmodel.get_label_and_category_names()[0])

    def seg_label_fractions(seg):
        """Return per-image label counts of ``seg`` divided by H * W."""
        # seg is used as 4-D (N, C, H, W): the offsets broadcast over dims 1-3
        # and the area uses dims 2 and 3.
        batch = seg.shape[0]
        # Shift image i's labels into bins [i*L, (i+1)*L) so a single bincount
        # tallies the whole batch. Assumes labels lie in [0, num_seglabels);
        # a larger label would spill into the next image's bins (TODO confirm).
        offsets = num_seglabels * torch.arange(
            batch, dtype=seg.dtype, device=seg.device)[:, None, None, None]
        counts = torch.bincount(
            (seg + offsets).view(-1),
            minlength=num_seglabels * batch).view(batch, -1)
        # NOTE(review): all C entries along dim 1 are counted but only H * W
        # divides, so rows sum to C (not 1) when C > 1.
        return counts.float() / (seg.shape[2] * seg.shape[3])

    def compute_mean_seg_in_images(batch_z, *args):
        # batch_z is presumably a batch of generator inputs (latents).
        img = model(batch_z.cuda())
        seg = segmodel.segment_batch(img, downsample=4)
        return seg_label_fractions(seg)

    try:
        model.edit_layer(layername, rule=zero_some_units)
        return tally.tally_mean(compute_mean_seg_in_images, dataset,
                                batch_size=30, sample_size=sample_size,
                                pin_memory=True)
    finally:
        # Always undo the edit: a failed or interrupted tally must not leave
        # the shared model with units silently zeroed.
        model.remove_edits()
def test_generator_segclass_delta_stats(model, dataset, segmodel,
                                        layername=None, zeroed_units=None,
                                        sample_size=None, cachefile=None):
    """Tally the mean change in segmentation-label fractions from zeroing units.

    For each batch from ``dataset``, images are generated twice from the same
    inputs. The first pass uses the unedited ``model``. The second pass uses an
    edit that zeroes channels ``zeroed_units`` of layer ``layername``. Both
    image sets are segmented with ``segmodel``. The per-image difference
    (edited minus baseline) in label fractions is averaged with
    ``tally.tally_mean``.

    Args:
        model: generator supporting ``remove_edits()`` and
            ``edit_layer(layername, rule=...)``.
        dataset: source of input batches passed to ``tally.tally_mean``.
        segmodel: segmenter providing ``get_label_and_category_names()`` and
            ``segment_batch(img, downsample=...)``.
        layername: layer whose output is edited. NOTE(review): the default
            ``None`` is passed straight to ``edit_layer``; confirm what that
            does.
        zeroed_units: any index accepted by ``x[:, zeroed_units]``.
            NOTE(review): with ``None``, ``x[:, None] = 0`` zeroes the
            entire output.
        sample_size: forwarded to ``tally.tally_mean``.
        cachefile: forwarded to ``tally.tally_mean``.

    Returns:
        Whatever ``tally.tally_mean`` returns for the per-image delta vectors
        of length ``num_seglabels``.
    """
    model.remove_edits()

    def zero_some_units(x, *args):
        # Edit rule: zero the selected channels of the layer output in place.
        x[:, zeroed_units] = 0
        return x

    num_seglabels = len(segmodel.get_label_and_category_names()[0])

    def seg_label_fractions(seg):
        """Return per-image label counts of ``seg`` divided by H * W."""
        # seg is used as 4-D (N, C, H, W): the offsets broadcast over dims 1-3
        # and the area uses dims 2 and 3.
        batch = seg.shape[0]
        # Shift image i's labels into bins [i*L, (i+1)*L) so a single bincount
        # tallies the whole batch. Assumes labels lie in [0, num_seglabels);
        # a larger label would spill into the next image's bins (TODO confirm).
        offsets = num_seglabels * torch.arange(
            batch, dtype=seg.dtype, device=seg.device)[:, None, None, None]
        counts = torch.bincount(
            (seg + offsets).view(-1),
            minlength=num_seglabels * batch).view(batch, -1)
        # NOTE(review): all C entries along dim 1 are counted but only H * W
        # divides, so rows sum to C (not 1) when C > 1.
        return counts.float() / (seg.shape[2] * seg.shape[3])

    def compute_mean_delta_seg_in_images(batch_z, *args):
        z = batch_z.cuda()
        # Baseline first. The previous batch left the edit installed, so
        # remove it before generating.
        model.remove_edits()
        seg = segmodel.segment_batch(model(z), downsample=4)
        seg_fracs = seg_label_fractions(seg)
        # Then the same inputs with the selected units zeroed. Each
        # segmentation is normalized by its own shape.
        model.edit_layer(layername, rule=zero_some_units)
        d_seg = segmodel.segment_batch(model(z), downsample=4)
        d_seg_fracs = seg_label_fractions(d_seg)
        return d_seg_fracs - seg_fracs

    try:
        return tally.tally_mean(compute_mean_delta_seg_in_images, dataset,
                                batch_size=25, sample_size=sample_size,
                                pin_memory=True, cachefile=cachefile)
    finally:
        # Always undo the edit: a failed or interrupted tally must not leave
        # the caller's model with units silently zeroed.
        model.remove_edits()
u_rq = qd.rq(layername) def intersect_99_fn(uimg, simg): s_99 = s_rq.quantiles(0.99)[None, :, None, None].cuda() u_99 = u_rq.quantiles(0.99)[None, :, None, None].cuda() with torch.no_grad(): ux, sx = uimg.cuda(), simg.cuda() inst_net(ux) ur = inst_net.retained_layer('features.' + layername) inst_net(sx) sr = inst_net.retained_layer('features.' + layername) return ((sr > s_99).float() * (ur > u_99).float()).permute( 0, 2, 3, 1).reshape(-1, ur.size(1)) intersect_99 = tally.tally_mean(intersect_99_fn, ds, cachefile=os.path.join( qd.dir(layername), 'intersect_99.npz')) def compute_image_max(batch): inst_net(batch.cuda()) return inst_net.retained_layer('features.' + layername).max(3)[0].max(2)[0] s_topk = tally.tally_topk(compute_image_max, sds, cachefile=os.path.join(qd.dir(layername), 's_topk.npz')) def compute_acts(image_batch): inst_net(image_batch.cuda()) acts_batch = inst_net.retained_layer('features.' + layername)