def tally_topk(compute, dataset, sample_size=None, batch_size=10, k=100,
               cachefile=None, **kwargs):
    '''
    Tally the top-k samples of every unit over a (possibly large) dataset.

    `compute` maps one loader batch to a (sample, unit)-shaped tensor, and
    only the `k` largest samples per unit are retained.  When `cachefile`
    holds a state saved with the same sample_size and k, that state is
    reused instead of recomputing.  Returns a runningstats.RunningTopK.
    '''
    cache_args = dict(sample_size=sample_size, k=k)
    with torch.no_grad():
        saved = load_cached_state(cachefile, cache_args)
        if saved is not None:
            return runningstats.RunningTopK(state=saved)
        topk = runningstats.RunningTopK(k=k)
        batches = make_loader(dataset, sample_size, batch_size, **kwargs)
        for batch in pbar(batches):
            topk.add(call_compute(compute, batch))
        # Move the accumulated state to the CPU before caching/returning.
        topk.to_('cpu')
        save_cached_state(cachefile, topk, cache_args)
    return topk
def tally_bincount(compute, dataset, sample_size=None, batch_size=10,
                   multi_label_axis=None, cachefile=None, **kwargs):
    '''
    Computes bincount totals for a large data sample that can be
    computed from a dataset.  The compute function should return one
    batch of samples as a (sample, unit)-dimension tensor.

    multi_label_axis: when not None, sample.numel() divided by the length
        of this axis is passed as `size` to RunningBincount.add; otherwise
        `size` is None.  Axis 0 is a valid value.
    cachefile: optional file; a state cached with the same sample_size is
        reused instead of recomputing.

    Returns a runningstats.RunningBincount moved to the CPU.
    '''
    with torch.no_grad():
        # NOTE(review): multi_label_axis is not part of the cache key, so a
        # cache written with a different axis would be reused as-is.
        args = dict(sample_size=sample_size)
        cached_state = load_cached_state(cachefile, args)
        if cached_state is not None:
            return runningstats.RunningBincount(state=cached_state)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        rbc = runningstats.RunningBincount()
        for batch in pbar(loader):
            sample = call_compute(compute, batch)
            # Compare against None: a plain truthiness test would silently
            # ignore multi_label_axis=0.
            if multi_label_axis is not None:
                multilabel = sample.shape[multi_label_axis]
                size = sample.numel() // multilabel
            else:
                size = None
            rbc.add(sample, size=size)
        rbc.to_('cpu')
        save_cached_state(cachefile, rbc, args)
        return rbc
def compute_mean_present_features(args, corpus, cache_filename, model):
    '''
    Phase 1.5: average, per channel, the retained-layer features taken at
    the feature-map location recorded for each object-present sample.

    Stores the CPU result as corpus.mean_present_feature and, when
    cache_filename is truthy, saves the whole corpus with numpy.savez.
    Returns immediately if corpus already has 'mean_present_feature'.
    '''
    if 'mean_present_feature' in corpus:
        return
    with torch.no_grad():
        loader = torch.utils.data.DataLoader(
            TensorDataset(corpus.object_present_sample,
                          corpus.object_present_location),
            batch_size=args.inference_batch_size, num_workers=10,
            pin_memory=True)
        feature_total = 0
        for zbatch, featloc in pbar(loader, desc="Mean activations"):
            zbatch = zbatch.cuda()
            featloc = featloc.cuda()
            # Forward pass; its output is unused -- the features are read
            # back through retained_layer.
            model(zbatch)
            feat = model.retained_layer(args.layer)
            count, channels = feat.shape[0], feat.shape[1]
            flat = feat.view(count, channels, -1)
            rows = torch.arange(count).to(feat.device)
            # Advanced indices split by a slice -> result is (count, channels).
            feature_total = feature_total + flat[rows, :, featloc].sum(0)
        corpus.mean_present_feature = (
            feature_total / len(corpus.object_present_sample)).cpu()
    if cache_filename:
        numpy.savez(cache_filename, **corpus)
def tally_directory(directory, size=10000, seed=1):
    '''
    Segment a fixed random subset of the images under `directory` and return
    each image's fraction of pixels per segmentation class.

    Args:
        directory: image folder read with parallelfolder.ParallelImageFolders
            (resized/center-cropped to 256 and normalized to [-1, 1]).
        size: number of images to sample; also the number of result rows.
        seed: seed for FixedRandomSubsetSampler's subset selection.

    Returns:
        numpy float array of shape (size, NUM_OBJECTS), computed from the
        first segmentation channel.  Rows the sampler never reaches stay 0.
    '''
    dataset = parallelfolder.ParallelImageFolders(
        [directory],
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    loader = DataLoader(
        dataset,
        # Use the caller's seed (it used to be hard-coded to 1).
        sampler=FixedRandomSubsetSampler(dataset, end=size, seed=seed),
        batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float was an alias for the builtin float and was removed in
    # NumPy 1.24.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=float)
    batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            count = len(batch)
            seg_result = upp.segment_batch(batch.cuda())
            pixels = seg_result.shape[2] * seg_result.shape[3]
            for i in range(count):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() / pixels)
            # Copy only the rows filled for this batch: the final batch can
            # be shorter than loader.batch_size.
            result[batch_index:batch_index + count] = (
                batch_result[:count].cpu().numpy())
            batch_index += count
    return result
def tally_conditional_mean(compute, dataset, sample_size=None, batch_size=1,
                           cachefile=None, **kwargs):
    '''
    Tally per-condition running mean and variance over a dataset.

    `compute` maps one loader batch to an iterable of
    (condition, (sample, unit)-tensor) pairs, one for each condition
    relevant to the batch.  A state cached in `cachefile` with the same
    sample_size is reused when available.  Returns a
    runningstats.RunningConditionalVariance; a freshly computed one is
    moved to the CPU before it is cached and returned.
    '''
    cache_args = dict(sample_size=sample_size)
    with torch.no_grad():
        saved = load_cached_state(cachefile, cache_args)
        if saved is not None:
            return runningstats.RunningConditionalVariance(state=saved)
        batches = make_loader(dataset, sample_size, batch_size, **kwargs)
        stats = runningstats.RunningConditionalVariance()
        for batch in pbar(batches):
            for condition, sample in call_compute(compute, batch):
                stats.add(condition, sample)
        stats.to_('cpu')
        save_cached_state(cachefile, stats, cache_args)
    return stats
def compute_feature_quantiles(args, corpus, cache_filename, model,
                              full_sample):
    '''
    Phase 1.6: estimate per-channel quantiles of the retained-layer features
    over every spatial position of every image generated from `full_sample`.

    Stores the 0.1%, 1%, 10%, 50%, 90%, 99% and 99.9% quantiles on corpus as
    feature_001, feature_01, feature_10, feature_50, feature_90, feature_99
    and feature_999 (CPU tensors, one value per channel).  Saves the corpus
    with numpy.savez when cache_filename is truthy.  Returns immediately if
    all seven entries already exist.
    '''
    quantile_keys = [
        ('feature_001', 0.001), ('feature_01', 0.01), ('feature_10', 0.1),
        ('feature_50', 0.5), ('feature_90', 0.9), ('feature_99', 0.99),
        ('feature_999', 0.999)]
    # Check every key this phase writes, so a partially populated corpus is
    # recomputed rather than left missing entries.
    if all(key in corpus for key, _ in quantile_keys):
        return
    with torch.no_grad():
        rq = RunningQuantile(r=5000)  # 10x what's needed.
        loader = torch.utils.data.DataLoader(
            TensorDataset(full_sample), batch_size=args.inference_batch_size,
            num_workers=10, pin_memory=True)
        for [zbatch] in pbar(loader, desc="Calculating 0.999 quantile"):
            zbatch = zbatch.cuda()
            # Forward pass; features are read back through retained_layer.
            model(zbatch)
            feat = model.retained_layer(args.layer)
            # One row per (image, y, x) position, one column per channel.
            rq.add(feat.permute(0, 2, 3, 1).contiguous().view(
                -1, feat.shape[1]))
        result = rq.quantiles([q for _, q in quantile_keys])
        for column, (key, _) in enumerate(quantile_keys):
            setattr(corpus, key, result[:, column].cpu())
    # Like the other phases, only write a cache when a filename is given.
    if cache_filename:
        numpy.savez(cache_filename, **corpus)
def validate_and_checkpoint():
    '''
    Evaluate `model` on `val_loader`, save a checkpoint, and update the best
    validation accuracy.

    Relies on names from the enclosing scope: model, val_loader, criterion,
    optimizer, scheduler, iter_num, save_checkpoint, printstat and the
    `best` dict.  The checkpoint is flagged as best when the mean accuracy
    exceeds best['val_accuracy'].  Leaves the model in eval mode.
    '''
    model.eval()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()
    for images, labels in pbar(val_loader):
        images_gpu = images.cuda()
        labels_gpu = labels.cuda()
        count = images.size(0)
        with torch.no_grad():
            logits = model(images_gpu)
            batch_loss = criterion(logits, labels_gpu)
            predicted = logits.max(1)[1]
            hits = labels_gpu.eq(predicted).data.float().sum().item()
            batch_acc = hits / count
        loss_meter.update(batch_loss.data.item(), count)
        acc_meter.update(batch_acc, count)
        # Show running averages on the progress bar.
        pbar.post(l=loss_meter.avg, a=acc_meter.avg)
    checkpoint = {
        'iter': iter_num,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'accuracy': acc_meter.avg,
        'loss': loss_meter.avg,
    }
    is_best = acc_meter.avg > best['val_accuracy']
    save_checkpoint(checkpoint, is_best)
    best['val_accuracy'] = max(acc_meter.avg, best['val_accuracy'])
    printstat('Iteration %d val accuracy %.2f' %
              (iter_num, acc_meter.avg * 100.0))
def tally_quantile(compute, dataset, sample_size=None, batch_size=10,
                   r=4096, cachefile=None, **kwargs):
    '''
    Computes quantile sketch statistics for a large data sample that can
    be computed from a dataset.  The compute function should return one
    batch of samples as a (sample, unit)-dimension tensor.

    The underlying quantile sketch is an optimal KLL sorted sampler that
    retains at least r samples (where r is the specified resolution).
    A state cached in `cachefile` with the same sample_size and r is reused.
    Returns a runningstats.RunningQuantile moved to the CPU.
    '''
    with torch.no_grad():
        args = dict(sample_size=sample_size, r=r)
        cached_state = load_cached_state(cachefile, args)
        if cached_state is not None:
            return runningstats.RunningQuantile(state=cached_state)
        loader = make_loader(dataset, sample_size, batch_size, **kwargs)
        # Pass the requested resolution through.  Previously `r` only went
        # into the cache key and the sketch used its own default.
        rq = runningstats.RunningQuantile(r=r)
        for batch in pbar(loader):
            sample = call_compute(compute, batch)
            rq.add(sample)
        rq.to_('cpu')
        save_cached_state(cachefile, rq, args)
        return rq
def my_test_perclass(model, dataset, layername=None, ablated_units=None,
                     sample_size=None, cachefile=None, num_classes=101):
    '''
    Measure overall and per-class accuracy of a multi-frame classifier,
    optionally with some channels of `layername` zeroed.

    Each loaded batch is unpacked as (clips, labels), with clips shaped
    (batch, test_frame_size, channels, l, w).  Frames are scored one by
    one, and the scores are averaged per clip before taking the arg-max.

    Args:
        model: instrumented model supporting remove_edits() / edit_layer().
        dataset: torch Dataset of (clip, label) items.
        layername: layer whose channels are zeroed when ablated_units is set.
        ablated_units: channel indices to zero in that layer's output.
        sample_size, cachefile: accepted for compatibility; unused.
        num_classes: number of classes (default 101).

    Returns:
        (accuracy, per_class_accuracy).  accuracy is a float rounded to 4
        places (0.0 for an empty dataset).  per_class_accuracy is a numpy
        array of length num_classes (0.0 for classes with no test samples).
        Any ablation edit stays installed on the model after return.
    '''
    model.remove_edits()
    if ablated_units is not None:
        def ablate_the_units(x, *args):
            x[:, ablated_units] = 0
            return x
        model.edit_layer(layername, rule=ablate_the_units)
    with torch.no_grad():
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=16, num_workers=6, pin_memory=True)
        test_running_correct = 0
        total_count = 0
        correct_per_class = [0] * num_classes
        total_per_class = [0] * num_classes
        for image_batch, class_batch in pbar(loader):
            total_count += len(image_batch)
            image_batch = image_batch.cuda()
            class_batch = class_batch.cuda()
            batch_size, test_frame_size, channels, l, w = image_batch.shape
            # Score every frame independently, then average per clip.
            images = torch.reshape(
                image_batch, (batch_size * test_frame_size, channels, l, -1))
            scores = model(images)
            scores = torch.reshape(scores, (batch_size, test_frame_size, -1))
            scores = scores.mean(dim=1)
            _, preds = scores.max(1)
            test_running_correct += (preds == class_batch).sum().item()
            # Transfer labels/predictions once per batch, not per element.
            for label, pred in zip(class_batch.tolist(), preds.tolist()):
                if label == pred:
                    correct_per_class[label] += 1
                total_per_class[label] += 1
            torch.cuda.empty_cache()
        # Classes absent from the test set would otherwise divide by zero.
        acc_per_class = [round(corr / total, 4) if total else 0.0
                         for corr, total in zip(correct_per_class,
                                                total_per_class)]
        print("Predicted correctly %d" % (test_running_correct))
        print("Total data points: %d" % (total_count))
        accuracy = (round(test_running_correct / total_count, 4)
                    if total_count else 0.0)
    print('Test Acc: %.3f percent' % (100 * accuracy))
    return accuracy, np.asarray(acc_per_class)
def measure_ablation(segmenter, loader, model, classnum, layer, ordering):
    '''
    Measure how much of class `classnum` remains as units of `layer` are
    cumulatively ablated, one feature-map location at a time.

    For each batch, the model output is segmented, and the class mask is
    average-pooled to the layer's spatial shape.  Each location with a
    nonzero pooled value is re-generated len(ordering) times.  On the r-th
    run, model.ablation[layer] holds a mask that selects units
    ordering[:r + 1] at that single location.  The pooled class value at
    the same location is then recorded.

    Args:
        segmenter: object providing segment_batch(images, downsample=2).
        loader: iterable of batches; batch[0] is the model input.
        model: instrumented model exposing `ablation` (dict keyed by layer
            name), `feature_shape`, and parameters() (used for the device).
        classnum: segmentation label to measure.
        layer: name of the layer to ablate.
        ordering: unit indices, in the order they join the ablation.

    Returns:
        Tensor of length len(ordering) + 1, created with torch.zeros (CPU by
        default).
        - Element 0 is the sum of the pooled class mask over all locations
          with no ablation.
        - Element r + 1 is the sum, over the originally positive locations,
          of the pooled class value after ablating ordering[:r + 1] there.
    '''
    # NOTE(review): total_bincount and data_size are never used.
    total_bincount = 0
    data_size = 0
    device = next(model.parameters()).device
    # Clear any ablation currently installed on any layer.
    for l in model.ablation:
        model.ablation[l] = None
    feature_units = model.feature_shape[layer][1]
    feature_shape = model.feature_shape[layer][2:]
    repeats = len(ordering)
    total_scores = torch.zeros(repeats + 1)
    for i, batch in enumerate(pbar(loader)):
        z_batch = batch[0]
        # Baseline pass with no ablation on this layer.
        model.ablation[layer] = None
        tensor_images = model(z_batch.to(device))
        seg = segmenter.segment_batch(tensor_images, downsample=2)
        # True wherever any dim-1 segmentation layer reports the class.
        mask = (seg == classnum).max(1)[0]
        downsampled_seg = torch.nn.functional.adaptive_avg_pool2d(
            mask.float()[:, None, :, :], feature_shape)[:, 0, :, :]
        total_scores[0] += downsampled_seg.sum().cpu()
        # Now we need to do an intervention for every location
        # that had a nonzero downsampled_seg, if any.
        interventions_needed = downsampled_seg.nonzero()
        location_count = len(interventions_needed)
        if location_count == 0:
            continue
        # Tile the (image, y, x) rows once per ablation step: row
        # r * location_count + k is location k at step r.
        interventions_needed = interventions_needed.repeat(repeats, 1)
        # NOTE(review): these indices live on `device`.  If batch[0] is a
        # CPU tensor and device is CUDA, some PyTorch versions reject
        # indexing a CPU tensor with CUDA indices -- confirm.
        inter_z = batch[0][interventions_needed[:, 0]].to(device)
        inter_chan = torch.zeros(repeats, location_count, feature_units,
                                 device=device)
        # Step r turns on units ordering[0..r] (cumulative ablation).
        for j, u in enumerate(ordering):
            inter_chan[j:, :, u] = 1
        # One channel-selection row per intervention, aligned with inter_z.
        inter_chan = inter_chan.view(len(inter_z), feature_units)
        inter_loc = interventions_needed[:, 1:]
        scores = torch.zeros(len(inter_z))
        # Run the interventions in chunks the size of the original batch.
        batch_size = len(batch[0])
        for j in range(0, len(inter_z), batch_size):
            ibz = inter_z[j:j + batch_size]
            ibl = inter_loc[j:j + batch_size].t()
            # One-hot spatial mask marking each intervention's location.
            imask = torch.zeros((len(ibz), ) + feature_shape,
                                device=ibz.device)
            imask[(torch.arange(len(ibz)), ) + tuple(ibl)] = 1
            ibc = inter_chan[j:j + batch_size]
            # (n, units, h, w): the selected units at the selected location.
            model.ablation[layer] = (imask.float()[:, None, :, :] *
                                     ibc[:, :, None, None])
            tensor_images = model(ibz)
            seg = segmenter.segment_batch(tensor_images, downsample=2)
            mask = (seg == classnum).max(1)[0]
            downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d(
                mask.float()[:, None, :, :], feature_shape)[:, 0, :, :]
            # Pooled class value at each intervention's own location.
            scores[j:j + batch_size] = downsampled_iseg[
                (torch.arange(len(ibz)), ) + tuple(ibl)]
        # Sum over locations for each ablation step.
        scores = scores.view(repeats, location_count).sum(1)
        total_scores[1:] += scores
    # NOTE(review): model.ablation[layer] is left set to the last
    # intervention mask on return.
    return total_scores
def compute_present_locations(args, corpus, cache_filename, model, segmenter,
                              classnum, full_sample):
    '''
    Phase 1: for every image generated from `full_sample`, find the
    feature-map location covering the largest fraction of `classnum`
    pixels.  Then keep the args.train_size + args.eval_size images where
    that fraction is highest.

    The kept indices are sorted by index.  The first args.train_size form
    the training set; the last args.eval_size form the eval set.  Sets on
    `corpus`:
      - present_indices, object_present_sample, object_present_location:
        training images, their z vectors, and their best locations.
      - object_location_popularity: bincount of those locations.
      - weighted_mean_present_feature: retained-layer features summed over
        all images/locations, weighted by the class fraction there, and
        divided by the total weight.
      - eval_present_indices, eval_present_sample, eval_present_location:
        the same data for the eval images.
    Saves the corpus with numpy.savez when cache_filename is truthy.
    Returns immediately if every one of these keys is already present.
    '''
    produced_keys = [
        'present_indices', 'object_present_sample', 'object_present_location',
        'object_location_popularity', 'weighted_mean_present_feature',
        'eval_present_indices', 'eval_present_sample',
        'eval_present_location']
    if all(k in corpus for k in produced_keys):
        return
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    num_units = model.feature_shape[args.layer][1]
    with torch.no_grad():
        weighted_feature_sum = torch.zeros(num_units).cuda()
        object_presence_scores = []
        loader = torch.utils.data.DataLoader(
            TensorDataset(full_sample), batch_size=args.inference_batch_size,
            num_workers=10, pin_memory=True)
        for [zbatch] in pbar(loader, desc="Object pool"):
            zbatch = zbatch.cuda()
            tensor_image = model(zbatch)
            segmented_image = segmenter.segment_batch(tensor_image,
                                                      downsample=2)
            # True wherever any dim-1 segmentation layer reports the class.
            mask = (segmented_image == classnum).max(1)[0]
            # Fraction of class pixels under each feature-map location.
            score = torch.nn.functional.adaptive_avg_pool2d(
                mask.float(), feature_shape)
            object_presence_scores.append(score.cpu())
            feat = model.retained_layer(args.layer)
            weighted_feature_sum += (feat * score[:, None, :, :]).view(
                feat.shape[0], feat.shape[1], -1).sum(2).sum(0)
        object_presence_at_feature = torch.cat(object_presence_scores)
        object_presence_at_image, object_location_in_image = (
            object_presence_at_feature.view(args.search_size, -1).max(1))
        # Sorting the negated scores ranks images by descending presence.
        _, best_presence_images = torch.sort(-object_presence_at_image)
        all_present_indices = torch.sort(
            best_presence_images[:(args.train_size + args.eval_size)])[0]
        corpus.present_indices = all_present_indices[:args.train_size]
        corpus.object_present_sample = full_sample[corpus.present_indices]
        corpus.object_present_location = object_location_in_image[
            corpus.present_indices]
        corpus.object_location_popularity = torch.bincount(
            corpus.object_present_location, minlength=num_locations)
        corpus.weighted_mean_present_feature = (
            weighted_feature_sum.cpu() /
            (1e-20 + object_presence_at_feature.view(-1).sum()))
        # `[-eval_size:]` selects every index when eval_size == 0, so
        # compute the start explicitly.  For eval_size > 0 this matches
        # the old slice.
        eval_start = max(0, len(all_present_indices) - args.eval_size)
        corpus.eval_present_indices = all_present_indices[eval_start:]
        corpus.eval_present_sample = full_sample[corpus.eval_present_indices]
        corpus.eval_present_location = object_location_in_image[
            corpus.eval_present_indices]
    if cache_filename:
        numpy.savez(cache_filename, **corpus)
def tally_cat(compute, dataset, sample_size=None, batch_size=10, **kwargs):
    '''
    Concatenate along dim 0 the tensors that `compute` returns for each
    batch of a dataset.

    Each per-batch result is moved to the CPU before concatenation, and the
    concatenated CPU tensor is returned.
    '''
    with torch.no_grad():
        batches = make_loader(dataset, sample_size, batch_size, **kwargs)
        pieces = [call_compute(compute, batch).cpu()
                  for batch in pbar(batches)]
        return torch.cat(pieces)
def walk_image_files(rootdir, verbose=None):
    '''
    Return the sorted list of image (or .npy) file paths under `rootdir`.

    If a sibling index file '<rootdir>.txt' exists, its non-blank lines are
    treated as paths relative to the parent directory of `rootdir` and are
    returned sorted; the directory itself is not walked.  Otherwise
    `rootdir` is walked recursively.  Every file accepted by is_image_file
    or is_npy_file is returned, in sorted directory-then-filename order.

    `verbose` is accepted for compatibility and is unused.
    '''
    print("Walking image files ... ")
    indexfile = '%s.txt' % rootdir
    if os.path.isfile(indexfile):
        print("from index file: ", indexfile)
        basedir = os.path.dirname(rootdir)
        with open(indexfile) as f:
            names = [line.strip() for line in f]
        # Skip blank lines.  os.path.join(basedir, '') would otherwise add
        # the directory itself as a bogus image path.
        return sorted(os.path.join(basedir, name) for name in names if name)
    result = []
    walker = pbar(os.walk(rootdir),
                  desc='Walking %s' % os.path.basename(rootdir))
    for dirname, _, fnames in sorted(walker):
        for fname in sorted(fnames):
            if is_image_file(fname) or is_npy_file(fname):
                result.append(os.path.join(dirname, fname))
    return result
def visualize_training_locations(args, corpus, cachedir, model):
    '''
    Phase 2.5: save a JPEG of every 'present' and 'candidate' corpus image
    with its chosen feature-map location tinted 50% yellow.

    Files are written to <cachedir>/image/<group>_<index>.jpg through a
    WorkerPool of SaveImageWorker, which is joined before returning.
    '''
    feature_shape = model.feature_shape[args.layer][2:]
    with torch.no_grad():
        imagedir = os.path.join(cachedir, 'image')
        os.makedirs(imagedir, exist_ok=True)
        image_saver = WorkerPool(SaveImageWorker)
        groups = [
            ('present', corpus.object_present_sample,
             corpus.object_present_location, corpus.present_indices),
            ('candidate', corpus.candidate_sample,
             corpus.candidate_location, corpus.candidate_indices),
        ]
        for group, samples, locations, group_indices in groups:
            loader = torch.utils.data.DataLoader(
                TensorDataset(samples, locations, group_indices),
                batch_size=args.inference_batch_size, num_workers=10,
                pin_memory=True)
            for zbatch, featloc, batch_indices in pbar(
                    loader, desc="Visualize %s" % group):
                zbatch = zbatch.cuda()
                tensor_image = model(zbatch)
                n = len(zbatch)
                # One-hot mask of each image's location on the feature grid,
                # resized to image resolution with adaptive max pooling.
                feature_mask = torch.zeros((n, 1) + feature_shape)
                feature_mask.view(n, -1).scatter_(1, featloc[:, None], 1)
                feature_mask = torch.nn.functional.adaptive_max_pool2d(
                    feature_mask.float(), tensor_image.shape[-2:]).cuda()
                yellow = torch.Tensor(
                    [1.0, 1.0, -1.0])[None, :, None, None].cuda()
                tinted = tensor_image * (1 - 0.5 * feature_mask) + (
                    0.5 * feature_mask * yellow)
                # Map [-1, 1] to bytes laid out as (N, H, W, C).
                pixels = (((tinted + 1) / 2) * 255).clamp(0, 255).byte()
                pixels = pixels.permute(0, 2, 3, 1).cpu().numpy()
                for i, index in enumerate(batch_indices):
                    image_saver.add(
                        pixels[i],
                        os.path.join(imagedir, '%s_%d.jpg' % (group, index)))
        image_saver.join()
def tally_generated_objects(model, size=10000):
    '''
    Generate `size` images from `model` (z vectors come from
    zdataset.z_dataset_for_model) and return, per image, the fraction of
    pixels in each segmentation class.

    Returns a numpy float array of shape (size, NUM_OBJECTS), computed from
    the first segmentation channel of UnifiedParsingSegmenter.
    '''
    zds = zdataset.z_dataset_for_model(model, size)
    loader = DataLoader(zds, batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float was an alias for the builtin float and was removed in
    # NumPy 1.24.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=float)
    batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [zbatch] in pbar(loader):
            count = len(zbatch)
            img = model(zbatch.cuda())
            seg_result = upp.segment_batch(img)
            pixels = seg_result.shape[2] * seg_result.shape[3]
            for i in range(count):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() / pixels)
            # Copy only this batch's rows: the last batch may be short.
            result[batch_index:batch_index + count] = (
                batch_result[:count].cpu().numpy())
            batch_index += count
    return result
def tally_conditional_quantile(compute, dataset, sample_size=None,
                               batch_size=1, gpu_cache=64, r=1024,
                               cachefile=None, **kwargs):
    '''
    Tally per-condition quantile sketches (resolution `r`) over a dataset.

    `compute` maps one loader batch to an iterable of
    (condition, (sample, unit)-tensor) pairs, one for each condition
    relevant to the batch.  A state cached in `cachefile` with matching
    sample_size and r is reused.  Whenever the batch counter is a power of
    two, every condition outside the `gpu_cache` most common ones is moved
    to the CPU.  Returns a runningstats.RunningConditionalQuantile; a
    freshly computed one is moved to the CPU before caching.
    '''
    cache_args = dict(sample_size=sample_size, r=r)
    with torch.no_grad():
        saved = load_cached_state(cachefile, cache_args)
        if saved is not None:
            return runningstats.RunningConditionalQuantile(state=saved)
        batches = make_loader(dataset, sample_size, batch_size, **kwargs)
        cq = runningstats.RunningConditionalQuantile(r=r)
        # NOTE(review): this set is never updated, so every sample is moved
        # to the CPU before it is added.  Assigning the periodic
        # most_common_conditions(gpu_cache) result here would keep common
        # conditions on the GPU.  Confirm RunningConditionalQuantile handles
        # mixed devices before changing this.
        gpu_conditions = set()
        for step, batch in enumerate(pbar(batches)):
            for cond, sample in call_compute(compute, batch):
                if cond not in gpu_conditions:
                    sample = sample.cpu()
                cq.add(cond, sample)
            # step & (step - 1) == 0 exactly when step is a power of two.
            if step > 0 and step & (step - 1) == 0:
                keep = set(cq.most_common_conditions(gpu_cache))
                cq.to_('cpu', [key for key in cq.keys() if key not in keep])
        cq.to_('cpu')
        save_cached_state(cachefile, cq, cache_args)
    return cq
def tally_dataset_objects(dataset, size=10000):
    '''
    Segment a fixed random subset of `size` items from `dataset` and return,
    per item, the fraction of pixels in each segmentation class.

    Returns a numpy float array of shape (size, NUM_OBJECTS), computed from
    the first segmentation channel of UnifiedParsingSegmenter.  Rows the
    sampler never reaches stay 0.
    '''
    loader = DataLoader(dataset,
                        sampler=FixedRandomSubsetSampler(dataset, end=size),
                        batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float was an alias for the builtin float and was removed in
    # NumPy 1.24.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=float)
    batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            count = len(batch)
            seg_result = upp.segment_batch(batch.cuda())
            pixels = seg_result.shape[2] * seg_result.shape[3]
            for i in range(count):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() / pixels)
            # Copy only this batch's rows: the last batch may be short.
            result[batch_index:batch_index + count] = (
                batch_result[:count].cpu().numpy())
            batch_index += count
    return result
def tally_mean(compute, dataset, sample_size=None, batch_size=10,
               cachefile=None, **kwargs):
    '''
    Tally per-unit running mean and variance over a dataset.

    `compute` maps one loader batch to a (sample, unit)-shaped tensor.  A
    state cached in `cachefile` with the same sample_size is reused when
    available.  Returns a runningstats.RunningVariance; a freshly computed
    one is moved to the CPU before caching.
    '''
    cache_args = {'sample_size': sample_size}
    with torch.no_grad():
        saved = load_cached_state(cachefile, cache_args)
        if saved is not None:
            return runningstats.RunningVariance(state=saved)
        batches = make_loader(dataset, sample_size, batch_size, **kwargs)
        stats = runningstats.RunningVariance()
        for batch in pbar(batches):
            stats.add(call_compute(compute, batch))
        stats.to_('cpu')
        save_cached_state(cachefile, stats, cache_args)
    return stats
def eval_loss_and_reg():
    '''
    Evaluate the current `ablation` on the held-out eval corpus.

    Relies on names from the enclosing scope: corpus, args, model,
    segmenter, classnum, high_replacement, ablation, ablation_only,
    fullimage_measurement and epoch.  Averages ace_loss over
    args.eval_size examples, both for the current settings and for each
    discrete configuration below.  Also computes the L2 regularizer
    args.l2_lambda * sum(ablation ** 2) and prints a summary.

    Returns:
        (avg_loss, regularizer, avg_d_losses).  avg_loss is a float,
        regularizer is the tensor L2 term, and avg_d_losses maps each
        discrete configuration name to its mean loss.
    '''
    # Discrete configurations evaluated alongside the continuous ablation.
    # Earlier variants (dpixel, dunits10/20, dumix20, abonly, fimabl,
    # dbothm20, abdisc20, abdiscm20, fimadp, fimadu10/20, fimadb10/20) are
    # disabled.
    discrete_experiments = dict(
        dboth20=dict(discrete_units=20, discrete_pixels=True),
        fimadbm10=dict(discrete_units=10, mixed_units=True,
                       discrete_pixels=True, ablation_only=True,
                       fullimage_ablation=True, fullimage_measurement=True),
        fimadbm20=dict(discrete_units=20, mixed_units=True,
                       discrete_pixels=True, ablation_only=True,
                       fullimage_ablation=True, fullimage_measurement=True))
    with torch.no_grad():
        total_loss = 0
        discrete_losses = dict.fromkeys(discrete_experiments, 0)
        eval_loader = torch.utils.data.DataLoader(
            TensorDataset(corpus.eval_present_sample,
                          corpus.eval_present_location,
                          corpus.eval_candidate_sample,
                          corpus.eval_candidate_location),
            batch_size=args.inference_batch_size, num_workers=10,
            shuffle=False, pin_memory=True)
        for pbatch, ploc, cbatch, cloc in pbar(eval_loader, desc="Eval"):
            shared = (segmenter, classnum, model, args.layer,
                      high_replacement, ablation, pbatch, ploc, cbatch, cloc)
            # Loss is the amount of object remaining after the edit.
            total_loss = total_loss + ace_loss(
                *shared, run_backward=False, ablation_only=ablation_only,
                fullimage_measurement=fullimage_measurement)
            for name, config in discrete_experiments.items():
                discrete_losses[name] = discrete_losses[name] + ace_loss(
                    *shared, run_backward=False, **config)
        avg_loss = (total_loss / args.eval_size).item()
        avg_d_losses = {name: (loss / args.eval_size).item()
                        for name, loss in discrete_losses.items()}
        regularizer = args.l2_lambda * ablation.pow(2).sum()
        pbar.print('Epoch %d Loss %g Regularizer %g' %
                   (epoch, avg_loss, regularizer))
        pbar.print(' '.join('%s: %g' % (k, d)
                            for k, d in avg_d_losses.items()))
        pbar.print(scale_summary(ablation.view(-1), 10, 3))
        return avg_loss, regularizer, avg_d_losses
def compute_candidate_locations(args, corpus, cache_filename, model,
                                segmenter, classnum, second_sample):
    '''
    Phase 2: find images from `second_sample` where pasting a typical
    `classnum` feature vector into one feature-map location increases the
    number of class pixels.

    How each batch is scored:
      - Five distinct locations are probed, shared by every image in the
        batch.  They are drawn with probability proportional to the
        (smoothed) corpus.object_location_popularity.
      - At each probe, the layer is edited with corpus.mean_present_feature
        when '_tcm' is in args.variant, otherwise with
        corpus.weighted_mean_present_feature.
      - The gain in class pixels over the unedited image is recorded.
      - Locations already containing the class score 0.
    The best args.train_size + args.eval_size images, sorted by index, are
    split into the candidate_* (first args.train_size) and eval_candidate_*
    (last args.eval_size) entries on `corpus`.  The corpus is saved with
    numpy.savez when cache_filename is truthy.  Returns immediately if every
    key this phase writes is already present.
    '''
    produced_keys = [
        'candidate_indices', 'candidate_sample', 'candidate_score',
        'candidate_location', 'object_score_at_candidate',
        'candidate_location_popularity', 'eval_candidate_indices',
        'eval_candidate_sample', 'eval_candidate_location']
    if all(k in corpus for k in produced_keys):
        return
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    with torch.no_grad():
        # Simplify - just treat all locations as possible
        possible_locations = numpy.arange(num_locations)
        # Speed up search for locations, by weighting probed locations
        # according to observed distribution (plus a tenth of the mean as
        # smoothing).
        location_weights = (corpus.object_location_popularity).double()
        location_weights += (location_weights.mean()) / 10.0
        location_weights = location_weights / location_weights.sum()
        candidate_scores = []
        object_scores = []
        prng = numpy.random.RandomState(1)
        loader = torch.utils.data.DataLoader(
            TensorDataset(second_sample),
            batch_size=args.inference_batch_size,
            num_workers=10, pin_memory=True)
        for [zbatch] in pbar(loader, desc="Candidate pool"):
            batch_scores = torch.zeros((len(zbatch), ) + feature_shape).cuda()
            flat_batch_scores = batch_scores.view(len(zbatch), -1)
            zbatch = zbatch.cuda()
            tensor_image = model(zbatch)
            segmented_image = segmenter.segment_batch(tensor_image,
                                                      downsample=2)
            mask = (segmented_image == classnum).max(1)[0]
            object_score = torch.nn.functional.adaptive_avg_pool2d(
                mask.float(), feature_shape)
            baseline_presence = mask.float().view(mask.shape[0], -1).sum(1)
            edit_mask = torch.zeros((1, 1) + feature_shape).cuda()
            if '_tcm' in args.variant:
                # variant: top-conditional-mean
                replace_vec = (
                    corpus.mean_present_feature[None, :, None, None].cuda())
            else:
                # default: weighted mean
                replace_vec = (corpus.weighted_mean_present_feature[
                    None, :, None, None].cuda())
            # Probe 5 distinct random locations, shared by the whole batch.
            for loc in prng.choice(possible_locations, replace=False,
                                   p=location_weights, size=5):
                edit_mask.zero_()
                edit_mask.view(-1)[loc] = 1
                model.edit_layer(args.layer, ablation=edit_mask,
                                 replacement=replace_vec)
                tensor_image = model(zbatch)
                segmented_image = segmenter.segment_batch(tensor_image,
                                                          downsample=2)
                mask = (segmented_image == classnum).max(1)[0]
                modified_presence = mask.float().view(
                    mask.shape[0], -1).sum(1)
                flat_batch_scores[:, loc] = (modified_presence -
                                             baseline_presence)
            candidate_scores.append(batch_scores.cpu())
            object_scores.append(object_score.cpu())
        # NOTE(review): the last edit_layer edit stays installed on the
        # model after this loop -- confirm callers reset it if needed.
        object_scores = torch.cat(object_scores)
        candidate_scores = torch.cat(candidate_scores)
        # Eliminate candidates where the object is present.
        candidate_scores = candidate_scores * (object_scores == 0).float()
        candidate_score_at_image, candidate_location_in_image = (
            candidate_scores.view(args.search_size, -1).max(1))
        # Sorting the negated scores ranks images by descending gain.
        _, best_candidate_images = torch.sort(-candidate_score_at_image)
        all_candidate_indices = torch.sort(
            best_candidate_images[:(args.train_size + args.eval_size)])[0]
        corpus.candidate_indices = all_candidate_indices[:args.train_size]
        corpus.candidate_sample = second_sample[corpus.candidate_indices]
        corpus.candidate_location = candidate_location_in_image[
            corpus.candidate_indices]
        corpus.candidate_score = candidate_score_at_image[
            corpus.candidate_indices]
        corpus.object_score_at_candidate = object_scores.view(
            len(object_scores), -1)[corpus.candidate_indices,
                                    corpus.candidate_location]
        corpus.candidate_location_popularity = torch.bincount(
            corpus.candidate_location, minlength=num_locations)
        # `[-eval_size:]` selects every index when eval_size == 0; compute
        # the start explicitly (same result for eval_size > 0).
        eval_start = max(0, len(all_candidate_indices) - args.eval_size)
        corpus.eval_candidate_indices = all_candidate_indices[eval_start:]
        corpus.eval_candidate_sample = second_sample[
            corpus.eval_candidate_indices]
        corpus.eval_candidate_location = candidate_location_in_image[
            corpus.eval_candidate_indices]
    # Like the other phases, only write a cache when a filename is given.
    if cache_filename:
        numpy.savez(cache_filename, **corpus)
def main():
    '''
    Command-line entry point for the pixel-ablation evaluation.

    For each class in --classes and each layer in model.ablation, this
    ranks units with the dissection ranking '<class>-<metric>' from
    <outdir>/dissect.json.  It then measures the class area remaining as
    the first --unitcount units are ablated (see measure_ablation) and
    writes <outdir>/<layer>/pixablation/<class>-<metric>.json.  An existing
    result file is skipped if its unit ordering differs from the current
    one, or if it already has at least --unitcount effects.
    '''
    # Training settings
    def strpair(arg):
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(
        description='Ablation eval',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='dissect',
                        required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes', type=str, nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric', type=str, default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount', type=int, default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter', type=str,
                        help='directory containing segmentation dataset')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size', type=int, default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size', type=int, default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    pbar.verbose(not args.quiet)

    # Speed up pytorch (set unconditionally, so no separate CUDA-only set).
    torch.backends.cudnn.benchmark = True

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    # Instantiate model
    device = next(model.parameters()).device
    input_shape = model.input_shape

    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view(
        (args.size, ) + input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)

    # Now do the actual work.
    labelnames, _ = segmenter.get_label_and_category_names(dataset)
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, num_workers=10,
        pin_memory=(device.type == 'cuda'))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # First, collect a baseline
    for l in model.ablation:
        model.ablation[l] = None

    # For each sort-order, do an ablation
    for classname in pbar(args.classes):
        pbar.post(c=classname)
        for layername in pbar(model.ablation):
            pbar.post(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            classnum = labelnum_from_name[classname]
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except (KeyError, AttributeError, StopIteration):
                # Layer or ranking missing from dissect.json.  A narrow
                # except no longer swallows KeyboardInterrupt/SystemExit.
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)
            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            resultfile = os.path.join(ablationdir, '%s.json' % rankname)
            if os.path.isfile(resultfile):
                with open(resultfile) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong;
                # leave the existing file alone.
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.
            measurements = measure_ablation(segmenter, segloader, model,
                                            classnum, layername,
                                            ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(resultfile, 'w') as f:
                json.dump(
                    dict(classname=classname,
                         classnum=classnum,
                         baseline=measurements[0],
                         layer=layername,
                         metric=args.metric,
                         ablation_units=ordering.tolist(),
                         ablation_effects=measurements[1:]), f)
def main():
    """Dissect one layer of a model and run unit-ablation experiments.

    Driven entirely by the command-line ``args``:

    1. Tally per-unit activation quantiles of the instrumented layer over the
       whole dataset (cached as ``rq.npz``).
    2. Threshold activations at the ``1 - args.quantile`` level and tally,
       per segmentation class, the conditional mean of those indicators
       (``condi99.npz``). From that, compute unit/class IoU and label each
       unit with its best-IoU class.
    3. Measure segmentation statistics with no ablation and with each single
       unit zeroed, then with the top-IoU "tree" units zeroed. This goes
       through ``test_generator_segclass_stats`` /
       ``test_generator_segclass_delta_stats``, whose internals are not
       visible here.
    4. For the 20 best-IoU "door" units, force them to their 99.5th
       percentile value at one spatial location at a time. Tally the change
       in segmentation-class pixel counts of the model's output, conditioned
       on the segmentation class found at that location originally.

    Intermediate results are cached under ``resdir``; nothing is returned.
    """
    args = parseargs()
    # The results directory encodes model, dataset, segmenter, layer and the
    # quantile (in thousandths), so different settings use separate caches.
    resdir = 'results/%s-%s-%s-%s-%s' % (args.model, args.dataset, args.seg,
                                         args.layer,
                                         int(args.quantile * 1000))

    def resfile(f):
        # Path of a cache/result file inside resdir.
        return os.path.join(resdir, f)

    model = load_model(args)
    layername = instrumented_layername(args)
    model.retain_layer(layername)
    dataset = load_dataset(args, model=model.model)
    upfn = make_upfn(args, dataset, model, layername)
    sample_size = len(dataset)
    # For a generator, the segmenter is applied to the model's output;
    # otherwise to the (renormalized) input images (see
    # compute_conditional_indicator below).
    is_generator = (args.model == 'progan')
    # Quantile level used as the "unit is on" threshold.
    percent_level = 1.0 - args.quantile

    # Tally rq.np (representation quantile, unconditional).
    # NOTE: this disables autograd globally for the rest of the process,
    # not just for this function.
    torch.set_grad_enabled(False)
    pbar.descnext('rq')

    def compute_samples(batch, *args):
        # Run the model, pass the retained 4-D activations through upfn
        # (presumably an upsampling to segmentation resolution - confirm),
        # then flatten dims 0, 2, 3 so each row holds one location's value
        # for every unit: shape (-1, units).
        # NOTE: *args shadows the outer `args`; it is not used here.
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    rq = tally.tally_quantile(compute_samples, dataset,
                              sample_size=sample_size,
                              r=8192,
                              num_workers=100,
                              pin_memory=True,
                              cachefile=resfile('rq.npz'))

    # Grab the 99th percentile, and tally conditional means at that level.
    # NOTE(review): despite the "_99" names, this is the percent_level
    # (= 1 - args.quantile) quantile; it is the 99th percentile only when
    # args.quantile == 0.01. Indexed [None, :, None, None] so it broadcasts
    # per unit against 4-D activations.
    level_at_99 = rq.quantiles(percent_level).cuda()[None, :, None, None]
    segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
    renorm = renormalize.renormalizer(dataset, target='zc')

    def compute_conditional_indicator(batch, *args):
        # Segment either the generated output or the renormalized input,
        # threshold the (upfn-processed) activations at level_at_99, and
        # hand the 0/1 indicators plus the segmentation to
        # tally.conditional_samples for per-class conditioning.
        data_batch = batch.cuda()
        out_batch = model(data_batch)
        image_batch = out_batch if is_generator else renorm(data_batch)
        seg = segmodel.segment_batch(image_batch, downsample=4)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float()  # indicator
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi99')
    condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
                                           dataset, sample_size=sample_size,
                                           num_workers=3, pin_memory=True,
                                           cachefile=resfile('condi99.npz'))

    # Now summarize the iou stats and graph the units
    # iou_99 is indexed [segmentation class, unit]: max(0) picks, for each
    # unit, the class with the best IoU.
    iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
    unit_label_99 = [(concept.item(), seglabels[concept],
                      segcatlabels[concept], bestiou.item())
                     for (bestiou, concept) in zip(*iou_99.max(0))]

    def measure_segclasses_with_zeroed_units(zeroed_units, sample_size=100):
        # Mean per-image fraction of segmentation pixels in each class,
        # measured on the model's outputs with `zeroed_units` forced to 0.
        # NOTE(review): currently unused; it is only referenced from the
        # commented-out lines below.
        model.remove_edits()

        def zero_some_units(x, *args):
            x[:, zeroed_units] = 0
            return x

        model.edit_layer(layername, rule=zero_some_units)
        num_seglabels = len(segmodel.get_label_and_category_names()[0])

        def compute_mean_seg_in_images(batch_z, *args):
            img = model(batch_z.cuda())
            seg = segmodel.segment_batch(img, downsample=4)
            seg_area = seg.shape[2] * seg.shape[3]
            # Offset each image's labels by num_seglabels * image_index so a
            # single bincount yields separate per-image histograms.
            # NOTE(review): every label channel (dim 1) is counted but the
            # count is divided by H*W only, so per-image fractions can sum
            # to more than 1 if the segmenter emits several channels.
            seg_counts = torch.bincount(
                (seg + (num_seglabels *
                        torch.arange(seg.shape[0],
                                     dtype=seg.dtype,
                                     device=seg.device)[:, None, None, None]
                        )).view(-1),
                minlength=num_seglabels * seg.shape[0]).view(seg.shape[0], -1)
            seg_fracs = seg_counts.float() / seg_area
            return seg_fracs

        result = tally.tally_mean(compute_mean_seg_in_images, dataset,
                                  batch_size=30, sample_size=sample_size,
                                  pin_memory=True)
        model.remove_edits()
        return result

    # Intervention experiment here:
    # segs_baseline = measure_segclasses_with_zeroed_units([])
    # segs_without_treeunits = measure_segclasses_with_zeroed_units(tree_units)
    num_units = len(unit_label_99)
    baseline_segmean = test_generator_segclass_stats(
        model, dataset, segmodel,
        layername=layername,
        cachefile=resfile('segstats/baseline.npz')).mean()

    pbar.descnext('unit ablation')
    unit_ablation_segmean = torch.zeros(num_units, len(baseline_segmean))
    # Units are visited in random order (presumably so concurrent runs pick
    # different uncached units); each result is cached per unit.
    for unit in pbar(random.sample(range(num_units), num_units)):
        stats = test_generator_segclass_stats(
            model, dataset, segmodel,
            layername=layername, zeroed_units=[unit],
            cachefile=resfile('segstats/ablated_unit_%d.npz' % unit))
        unit_ablation_segmean[unit] = stats.mean()

    # Zero the top-k units by IoU with 'tree' for k in 0..29 and record the
    # mean delta statistic for the tree class.
    ablate_segclass_name = 'tree'
    ablate_segclass = seglabels.index(ablate_segclass_name)
    best_iou_units = iou_99[ablate_segclass, :].sort(0)[1].flip(0)
    byiou_unit_ablation_seg = torch.zeros(30)
    for unitcount in pbar(random.sample(range(0, 30), 30)):
        zero_units = best_iou_units[:unitcount].tolist()
        stats = test_generator_segclass_delta_stats(
            model, dataset, segmodel,
            layername=layername, zeroed_units=zero_units,
            cachefile=resfile('deltasegstats/ablated_best_%d_iou_%s.npz' %
                              (unitcount, ablate_segclass_name)))
        byiou_unit_ablation_seg[unitcount] = stats.mean()[ablate_segclass]

    # Generator context experiment.
    num_segclass = len(seglabels)
    door_segclass = seglabels.index('door')
    # The 20 units with highest IoU for 'door', and their 99.5th-percentile
    # activation values used as the "strongly on" replacement.
    door_units = iou_99[door_segclass].sort(0)[1].flip(0)[:20]
    door_high_values = rq.quantiles(0.995)[door_units].cuda()

    def compute_seg_impact(zbatch, *args):
        # Returns a generator of (segmentation class, delta pixel counts)
        # pairs for tally.tally_conditional_mean.
        zbatch = zbatch.cuda()
        model.remove_edits()
        orig_img = model(zbatch)
        orig_seg = segmodel.segment_batch(orig_img, downsample=4)
        orig_segcount = tally.batch_bincount(orig_seg, num_segclass)
        rep = model.retained_layer(layername).clone()
        # Size of the segmentation-pixel square covered by one feature cell.
        ysize = orig_seg.shape[2] // rep.shape[2]
        xsize = orig_seg.shape[3] // rep.shape[3]

        def gen_conditions():
            for y in range(rep.shape[2]):
                for x in range(rep.shape[3]):
                    # Take as the context location the segmentation
                    # labels at the center of the square.
                    selsegs = orig_seg[:, :, y * ysize + ysize // 2,
                                       x * xsize + xsize // 2]
                    # Turn the door units fully on at this one location and
                    # re-render.
                    changed_rep = rep.clone()
                    changed_rep[:, door_units, y, x] = (
                        door_high_values[None, :])
                    # NOTE: this edit stays installed until the next call's
                    # remove_edits(); after the last batch it is left in
                    # place.
                    model.edit_layer(layername, ablation=1.0,
                                     replacement=changed_rep)
                    changed_img = model(zbatch)
                    changed_seg = segmodel.segment_batch(
                        changed_img, downsample=4)
                    changed_segcount = tally.batch_bincount(
                        changed_seg, num_segclass)
                    delta_segcount = (changed_segcount
                                      - orig_segcount).float()
                    # Attribute each image's count change to every nonzero
                    # (i.e. non-background) class present at the location.
                    for sel, delta in zip(selsegs, delta_segcount):
                        for cond in torch.bincount(sel).nonzero()[:, 0]:
                            if cond == 0:
                                continue
                            yield (cond.item(), delta)

        return gen_conditions()

    cond_changes = tally.tally_conditional_mean(
        compute_seg_impact, dataset, sample_size=10000, batch_size=20,
        cachefile=resfile('big_door_cond_changes.npz'))
def main():
    """Train a freshly initialized ResNet-50 on a subset of classes.

    Only samples whose class index is below ``args.selected_classes`` are
    used. They come from ``datasets/<dir>/train`` and ``datasets/<dir>/val``,
    where ``<dir>`` is chosen by ``args.dataset`` ('novelty' or 'imagenet').

    Training details:
    - Every parameter is re-initialized: biases to 0, other 1-D parameters
      to 1, and the rest with Kaiming-normal.
    - The loss is a one-hot binary cross-entropy.
    - Optimization uses Adam with a one-cycle learning-rate schedule for
      ``args.training_iterations`` iterations.
    - Every 1000 iterations the model is validated and checkpointed into
      ``results/decoupled-<selected_classes>-<dataset>-resnet``.
    """
    args = parseargs()
    experiment_dir = 'results/decoupled-%d-%s-resnet' % (
        args.selected_classes, args.dataset)
    ds_dirname = dict(novelty='novelty/dataset_v1/known_classes/images',
                      imagenet='imagenet')[args.dataset]
    training_dir = 'datasets/%s/train' % ds_dirname
    val_dir = 'datasets/%s/val' % ds_dirname
    os.makedirs(experiment_dir, exist_ok=True)
    with open(os.path.join(experiment_dir, 'args.txt'), 'w') as f:
        f.write(str(args) + '\n')

    def printstat(s):
        # Append to the experiment log and echo to the progress bar.
        with open(os.path.join(experiment_dir, 'log.txt'), 'a') as f:
            f.write(str(s) + '\n')
        pbar.print(s)

    def filter_tuple(item):
        # Keep only (path, class_index) items from the selected classes.
        return item[1] < args.selected_classes

    # Imagenet has a couple bad exif images.
    warnings.filterwarnings('ignore', message='.*orrupt EXIF.*')
    # Here's our data
    train_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [training_dir],
            classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64, shuffle=True,
        num_workers=48, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [val_dir],
            classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64, shuffle=False,
        num_workers=24, pin_memory=True)

    late_model = torchvision.models.resnet50(
        num_classes=args.selected_classes)
    for n, p in late_model.named_parameters():
        if 'bias' in n:
            torch.nn.init.zeros_(p)
        elif len(p.shape) <= 1:
            torch.nn.init.ones_(p)
        else:
            torch.nn.init.kaiming_normal_(p, nonlinearity='relu')
    late_model.train()
    late_model.cuda()
    model = late_model

    max_lr = 5e-3
    max_iter = args.training_iterations

    def criterion(logits, true_class):
        # One-vs-all BCE against a one-hot target rather than softmax CE.
        goal = torch.zeros_like(logits)
        goal.scatter_(1, true_class[:, None], value=1.0)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            logits, goal)

    optimizer = torch.optim.Adam(model.parameters())
    # Iteration 0 takes no optimizer step (see the training loop), so the
    # schedule covers iterations 1 .. max_iter - 1.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, total_steps=max_iter - 1)
    iter_num = 0
    best = dict(val_accuracy=0.0)

    # Optionally resume from the best checkpoint (disabled by default).
    checkpoint_filename = 'weights.pth'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(best_checkpoint)
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        # Always overwrite the latest checkpoint; copy it aside when best.
        filename = os.path.join(experiment_dir, checkpoint_filename)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        """Evaluate on the validation set, checkpoint, and log accuracy.

        Leaves the model in eval mode; the caller switches it back.
        """
        model.eval()
        val_loss, val_acc = AverageMeter(), AverageMeter()
        for images, target in pbar(val_loader):
            count = images.size(0)
            input_var, target_var = [d.cuda() for d in [images, target]]
            with torch.no_grad():
                output = model(input_var)
                loss = criterion(output, target_var)
                _, pred = output.max(1)
                accuracy = target_var.eq(pred).float().sum().item() / count
            val_loss.update(loss.item(), count)
            val_acc.update(accuracy, count)
            pbar.post(l=val_loss.avg, a=val_acc.avg)
        save_checkpoint(
            {
                'iter': iter_num,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'accuracy': val_acc.avg,
                'loss': val_loss.avg,
            }, val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        printstat('Iteration %d val accuracy %.2f' %
                  (iter_num, val_acc.avg * 100.0))

    # Here is our training loop.
    while iter_num < max_iter:
        # Track the average training loss/accuracy for each epoch.  These
        # must be created once per pass over the data, not once per batch,
        # or the reported "average" is just the latest batch.
        train_loss, train_acc = AverageMeter(), AverageMeter()
        for filtered_input, filtered_target in pbar(train_loader):
            count = filtered_input.size(0)
            input_var, target_var = [
                d.cuda() for d in [filtered_input, filtered_target]
            ]
            output = model(input_var)
            loss = criterion(output, target_var)
            train_loss.update(loss.item(), count)
            # Iteration 0 only measures (and validates) the initial model.
            if iter_num > 0:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
            # Also check training set accuracy.
            _, pred = output.max(1)
            accuracy = target_var.eq(pred).float().sum().item() / count
            train_acc.update(accuracy, count)
            pbar.post(l=train_loss.avg, a=train_acc.avg,
                      v=best['val_accuracy'])
            # Occasionally check validation set accuracy and checkpoint.
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
            iter_num += 1
            if iter_num >= max_iter:
                break
def train_ablation(args, corpus, cachefile, model, segmenter, classnum,
                   initial_ablation=None):
    """Optimize a per-unit ablation mask for class ``classnum``.

    ``ablation`` has the same shape as ``high_replacement``. That is
    presumably (1, units, 1, 1), assuming the ``corpus`` feature vectors
    are 1-D per-unit - confirm.

    How the mask is trained:
    - It starts at zero, or at ``initial_ablation`` when given.
    - It is optimized with Adam (lr 0.01) against ``ace_loss`` on the
      ``corpus.object_present_*`` / ``corpus.candidate_*`` samples, plus an
      L2 penalty weighted by ``args.l2_lambda``.
    - Gradients are accumulated over ``args.train_update_freq`` batches, and
      the mask is clamped to [0, 1] after each step.

    After every epoch, the mask is evaluated on the ``corpus.eval_*``
    samples and snapshotted to ``<dirname(cachefile)>/snapshots/epoch-N.pth``
    (plus ``.npy``). Only the directory of ``cachefile`` is used.

    With ``args.eval_only`` the function only re-evaluates existing
    snapshots. Unless ``args.no_cache`` is set, training resumes after the
    newest existing snapshot.

    Returns:
        The final (or last loaded) ablation mask as a flat numpy array.
    """
    cachedir = os.path.dirname(cachefile)
    snapdir = os.path.join(cachedir, 'snapshots')
    os.makedirs(snapdir, exist_ok=True)

    # The value a fully-ablated unit is replaced with, selected by variant.
    # high_replacement = corpus.feature_99[None,:,None,None].cuda()
    if '_h99' in args.variant:
        high_replacement = corpus.feature_99[None, :, None, None].cuda()
    elif '_tcm' in args.variant:
        # variant: top-conditional-mean
        high_replacement = (
            corpus.mean_present_feature[None, :, None, None].cuda())
    else:  # default: weighted mean
        high_replacement = (
            corpus.weighted_mean_present_feature[None, :, None, None].cuda())
    fullimage_measurement = False
    ablation_only = False
    fullimage_ablation = False
    if '_fim' in args.variant:
        fullimage_measurement = True
    elif '_fia' in args.variant:
        fullimage_measurement = True
        ablation_only = True
        # NOTE(review): fullimage_ablation is never passed to ace_loss below
        # (only the discrete_experiments configs set it), so this flag has
        # no effect on training - confirm whether that is intended.
        fullimage_ablation = True
    # Only the ablation mask is optimized; the model and replacement values
    # stay fixed.
    high_replacement.requires_grad = False
    for p in model.parameters():
        p.requires_grad = False

    ablation = torch.zeros(high_replacement.shape).cuda()
    if initial_ablation is not None:
        ablation.view(-1)[...] = initial_ablation
    ablation.requires_grad = True
    optimizer = torch.optim.Adam([ablation], lr=0.01)
    start_epoch = 0
    # `epoch` is read by eval_loss_and_reg (closure) for its log line.
    epoch = 0

    def eval_loss_and_reg():
        """Evaluate the current mask on the eval split.

        Returns (avg_loss, regularizer, avg_d_losses). avg_d_losses maps each
        discrete_experiments name to its average ace_loss. Both averages
        divide summed losses by args.eval_size.
        """
        discrete_experiments = dict(
            # dpixel=dict(discrete_pixels=True),
            # dunits20=dict(discrete_units=20),
            # dumix20=dict(discrete_units=20, mixed_units=True),
            # dunits10=dict(discrete_units=10),
            # abonly=dict(ablation_only=True),
            # fimabl=dict(ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            dboth20=dict(discrete_units=20, discrete_pixels=True),
            # dbothm20=dict(discrete_units=20, mixed_units=True,
            #              discrete_pixels=True),
            # abdisc20=dict(discrete_units=20, discrete_pixels=True,
            #             ablation_only=True),
            # abdiscm20=dict(discrete_units=20, mixed_units=True,
            #             discrete_pixels=True,
            #             ablation_only=True),
            # fimadp=dict(discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadu10=dict(discrete_units=10,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadb10=dict(discrete_units=10, discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            fimadbm10=dict(discrete_units=10, mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True),
            # fimadu20=dict(discrete_units=20,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadb20=dict(discrete_units=20, discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            fimadbm20=dict(discrete_units=20, mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True))
        with torch.no_grad():
            total_loss = 0
            discrete_losses = {k: 0 for k in discrete_experiments}
            for [pbatch, ploc, cbatch, cloc] in pbar(
                    torch.utils.data.DataLoader(
                        TensorDataset(corpus.eval_present_sample,
                                      corpus.eval_present_location,
                                      corpus.eval_candidate_sample,
                                      corpus.eval_candidate_location),
                        batch_size=args.inference_batch_size,
                        num_workers=10,
                        shuffle=False,
                        pin_memory=True),
                    desc="Eval"):
                # First, put in zeros for the selected units.
                # Loss is amount of remaining object.
                total_loss = total_loss + ace_loss(
                    segmenter, classnum, model, args.layer,
                    high_replacement, ablation, pbatch, ploc, cbatch, cloc,
                    run_backward=False,
                    ablation_only=ablation_only,
                    fullimage_measurement=fullimage_measurement)
                for k, config in discrete_experiments.items():
                    discrete_losses[k] = discrete_losses[k] + ace_loss(
                        segmenter, classnum, model, args.layer,
                        high_replacement, ablation, pbatch, ploc, cbatch,
                        cloc, run_backward=False, **config)
            avg_loss = (total_loss / args.eval_size).item()
            avg_d_losses = {
                k: (d / args.eval_size).item()
                for k, d in discrete_losses.items()
            }
            regularizer = (args.l2_lambda * ablation.pow(2).sum())
            pbar.print('Epoch %d Loss %g Regularizer %g' %
                       (epoch, avg_loss, regularizer))
            pbar.print(' '.join('%s: %g' % (k, d)
                                for k, d in avg_d_losses.items()))
            pbar.print(scale_summary(ablation.view(-1), 10, 3))
            return avg_loss, regularizer, avg_d_losses

    if args.eval_only:
        # For eval_only, just load each snapshot and re-run validation eval
        # pass on each one.
        for epoch in range(-1, args.train_epochs):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % epoch)
            if not os.path.exists(snapfile):
                data = {}
                if epoch >= 0:
                    print('No epoch %d' % epoch)
                    continue
                # A missing epoch -1 snapshot is (re)created from the
                # current, initial ablation.
            else:
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
            avg_loss, regularizer, new_extra = eval_loss_and_reg()
            # Keep old values, and update any new ones.
            extra = {
                k: v
                for k, v in data.items()
                if k not in ['ablation', 'optimizer', 'avg_loss']
            }
            extra.update(new_extra)
            torch.save(
                dict(ablation=ablation,
                     optimizer=optimizer.state_dict(),
                     avg_loss=avg_loss,
                     **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        # Return loaded ablation.
        return ablation.view(-1).detach().cpu().numpy()

    if not args.no_cache:
        # Resume from the newest snapshot. This relies on the loop variable
        # leaking: when no snapshot exists, start_epoch is left at 0 (the
        # last value of the reversed range).
        for start_epoch in reversed(range(args.train_epochs)):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % start_epoch)
            if os.path.exists(snapfile):
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
                start_epoch += 1
                break

    if start_epoch < args.train_epochs:
        # Evaluate the starting point; when starting fresh, save it as the
        # epoch -1 snapshot.
        epoch = start_epoch - 1
        avg_loss, regularizer, extra = eval_loss_and_reg()
        if epoch == -1:
            torch.save(
                dict(ablation=ablation,
                     optimizer=optimizer.state_dict(),
                     avg_loss=avg_loss,
                     **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))

    # Number of samples contributing to each optimizer step.
    update_size = args.train_update_freq * args.train_batch_size
    for epoch in range(start_epoch, args.train_epochs):
        # Pair object-present samples with a freshly shuffled set of
        # candidate samples each epoch.
        candidate_shuffle = torch.randperm(len(corpus.candidate_sample))
        train_loss = 0
        for batch_num, [pbatch, ploc, cbatch, cloc] in enumerate(
                pbar(torch.utils.data.DataLoader(
                    TensorDataset(
                        corpus.object_present_sample,
                        corpus.object_present_location,
                        corpus.candidate_sample[candidate_shuffle],
                        corpus.candidate_location[candidate_shuffle]),
                    batch_size=args.train_batch_size,
                    num_workers=10,
                    shuffle=True,
                    pin_memory=True),
                     desc="ACE opt epoch %d" % epoch)):
            # Gradients accumulate over train_update_freq batches.
            # NOTE(review): if the batch count is not a multiple of
            # train_update_freq, the last partial accumulation is never
            # stepped; it is zeroed at the start of the next epoch.
            if batch_num % args.train_update_freq == 0:
                optimizer.zero_grad()
            # First, put in zeros for the selected units.  Loss is amount
            # of remaining object.
            loss = ace_loss(segmenter, classnum, model, args.layer,
                            high_replacement, ablation, pbatch, ploc, cbatch,
                            cloc, run_backward=True,
                            ablation_only=ablation_only,
                            fullimage_measurement=fullimage_measurement)
            with torch.no_grad():
                train_loss = train_loss + loss
            if (batch_num + 1) % args.train_update_freq == 0:
                # Third, add some L2 loss to encourage sparsity.
                regularizer = (args.l2_lambda * update_size *
                               ablation.pow(2).sum())
                regularizer.backward()
                optimizer.step()
                with torch.no_grad():
                    ablation.clamp_(0, 1)
                pbar.post(l=(train_loss / update_size).item(),
                          r=(regularizer / update_size).item())
                train_loss = 0

        avg_loss, regularizer, extra = eval_loss_and_reg()
        torch.save(
            dict(ablation=ablation,
                 optimizer=optimizer.state_dict(),
                 avg_loss=avg_loss,
                 **extra), os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch),
                   ablation.detach().cpu().numpy())

    # The output of this phase is this set of scores.
    return ablation.view(-1).detach().cpu().numpy()
def main():
    """Measure how ablating units of one layer changes per-class accuracy.

    Uses the classifier, dataset and layer named by the command-line args.
    All measurements go through ``test_perclass_pra`` and are cached under
    ``sharedfile(...)``. The steps are:

    1. Val set: baseline per-class balanced accuracy, and the same with each
       single unit zeroed.
    2. Training set: baseline and single-unit ablations again. These are
       used to rank units per class, so the ranking never sees the val set.
    3. For each class and each ``num_best``, zero the ``num_best`` units
       whose single-unit ablation lowered that class's training-set balanced
       accuracy the most. Separately, zero all remaining units. Report each
       result's change in val-set balanced accuracy for that class.
    """
    args = parseargs()
    model = setting.load_classifier(args.model)
    model = nethook.InstrumentedModel(model).cuda().eval()
    layername = args.layer
    model.retain_layer(layername)
    dataset = setting.load_dataset(args.dataset, crop_size=224)
    train_dataset = setting.load_dataset(args.dataset, crop_size=224,
                                         split='train')

    # Probe layer to get sizes
    model(dataset[0][0][None].cuda())
    num_units = model.retained_layer(layername).shape[1]
    classlabels = dataset.classes

    # Measure baseline classification accuracy on val set, and cache.
    pbar.descnext('baseline_pra')
    _, _, _, baseline_ba = test_perclass_pra(
        model, dataset,
        cachefile=sharedfile('pra-%s-%s/pra_baseline.npz' %
                             (args.model, args.dataset)))
    pbar.print('baseline acc', baseline_ba.mean().item())

    # Now erase each unit, one at a time, and retest accuracy.
    # (val_single_unit_ablation_ba is not used below; the loop mainly
    # populates the per-unit caches.)
    unit_list = random.sample(list(range(num_units)), num_units)
    val_single_unit_ablation_ba = torch.zeros(num_units, len(classlabels))
    for unit in pbar(unit_list):
        pbar.descnext('test unit %d' % unit)
        # Get binary accuracy of the model after ablating the unit.
        _, _, _, ablation_ba = test_perclass_pra(
            model, dataset,
            layername=layername,
            ablated_units=[unit],
            cachefile=sharedfile('pra-%s-%s/pra_ablate_unit_%d.npz' %
                                 (args.model, args.dataset, unit)))
        val_single_unit_ablation_ba[unit] = ablation_ba

    # For the purpose of ranking units by importance to a class, we
    # measure using the training set (to avoid training unit ordering
    # on the test set).
    sample_size = None
    # Measure baseline classification accuracy, and cache.
    pbar.descnext('train_baseline_pra')
    _, _, _, baseline_ba = test_perclass_pra(
        model, train_dataset,
        sample_size=sample_size,
        cachefile=sharedfile('ttv-pra-%s-%s/pra_train_baseline.npz' %
                             (args.model, args.dataset)))
    pbar.print('baseline acc', baseline_ba.mean().item())

    # Measure accuracy on the val set.
    pbar.descnext('val_baseline_pra')
    _, _, _, val_baseline_ba = test_perclass_pra(
        model, dataset,
        cachefile=sharedfile('ttv-pra-%s-%s/pra_val_baseline.npz' %
                             (args.model, args.dataset)))
    pbar.print('val baseline acc', val_baseline_ba.mean().item())

    # Do in shuffled order to allow multiprocessing.
    single_unit_ablation_ba = torch.zeros(num_units, len(classlabels))
    for unit in pbar(unit_list):
        pbar.descnext('test unit %d' % unit)
        _, _, _, ablation_ba = test_perclass_pra(
            model, train_dataset,
            layername=layername,
            ablated_units=[unit],
            sample_size=sample_size,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_train_ablate_unit_%d.npz'
                                 % (args.model, args.dataset, unit)))
        single_unit_ablation_ba[unit] = ablation_ba

    # Now for every class, remove a set of the N most-important
    # and N least-important units for that class, and measure accuracy.
    for classnum in pbar(
            random.sample(range(len(classlabels)), len(classlabels))):
        # For a few classes, let's chart the whole range of ablations.
        if classnum in [100, 169, 351, 304]:
            num_best_list = range(1, num_units)
        else:
            num_best_list = [1, 2, 3, 4, 5, 20, 64, 128, 256]
        # Units ordered from most to least important for this class: lowest
        # training-set balanced accuracy after single-unit ablation first.
        ranked_units = single_unit_ablation_ba[:, classnum].sort(0)[1]
        pbar.descnext('numbest')
        for num_best in pbar(random.sample(num_best_list,
                                           len(num_best_list))):
            best_units = ranked_units[:num_best]
            _, _, _, testba = test_perclass_pra(
                model, dataset,
                layername=layername,
                ablated_units=best_units,
                cachefile=sharedfile(
                    'ttv-pra-%s-%s/pra_val_ablate_classunits_%s_ba_%d.npz' %
                    (args.model, args.dataset, classlabels[classnum],
                     len(best_units))))
            # The complement of best_units. Slice from num_best rather than
            # [-(num_units - num_best):]: when num_best == num_units that
            # would be [-0:], i.e. ALL units instead of none.
            worst_units = ranked_units[num_best:]
            _, _, _, testba2 = test_perclass_pra(
                model, dataset,
                layername=layername,
                ablated_units=worst_units,
                cachefile=sharedfile(
                    'ttv-pra-%s-%s/pra_val_ablate_classunits_%s_worstba_%d.npz'
                    % (args.model, args.dataset, classlabels[classnum],
                       len(worst_units))))
            pbar.print('%s: best %d %.3f vs worst N %.3f' %
                       (classlabels[classnum], num_best,
                        testba[classnum] - val_baseline_ba[classnum],
                        testba2[classnum] - val_baseline_ba[classnum]))
def test_perclass_pra(model, dataset, layername=None, ablated_units=None,
                      sample_size=None, cachefile=None):
    '''Classifier precision/recall/accuracy measurement.

    Disables a set of units in the specified layer, and then measures
    per-class precision, recall, accuracy and balanced (binary
    classification) accuracy for each class, compared to the ground truth
    in the given dataset.

    Args:
        model: instrumented classifier providing ``remove_edits`` and
            ``edit_layer``. Its output is treated as per-class scores, with
            the argmax over dim 1 taken as the prediction.
        dataset: yields ``(image, class_index)`` pairs and has a ``classes``
            attribute.
        layername: layer whose output is edited when ``ablated_units`` is
            given.
        ablated_units: indices along dim 1 of that layer's output to force
            to zero, or None for no ablation.
        sample_size: if given, only the first ``sample_size`` items are used.
        cachefile: optional ``.npz`` path. A complete cached result is loaded
            from it; otherwise the fresh result is written to it.

    Returns:
        ``(precision, recall, accuracy, balanced_accuracy)``, each a CPU
        float tensor with one entry per class. The divisions are unguarded:
        a class that is never predicted gets NaN precision, and a class
        that never occurs gets NaN recall.
    '''
    if cachefile is not None:
        try:
            # Close the NpzFile promptly instead of leaking the handle.
            with numpy.load(cachefile) as data:
                # Only files that include this key are treated as complete.
                data['true_negative_rate']
                result = tuple(
                    torch.tensor(data[key]) for key in
                    ['precision', 'recall', 'accuracy', 'balanced_accuracy'])
        except Exception:
            # Missing, partial or unreadable cache: fall through and
            # recompute (but never swallow KeyboardInterrupt/SystemExit).
            pass
        else:
            pbar.print('Loading cached %s' % cachefile)
            return result

    model.remove_edits()
    if ablated_units is not None:

        def ablate_the_units(x, *args):
            x[:, ablated_units] = 0
            return x

        model.edit_layer(layername, rule=ablate_the_units)
    num_classes = len(dataset.classes)
    try:
        with torch.no_grad():
            true_counts = torch.zeros(num_classes, dtype=torch.int64).cuda()
            pred_counts = torch.zeros(num_classes, dtype=torch.int64).cuda()
            correct_counts = torch.zeros(num_classes,
                                         dtype=torch.int64).cuda()
            total_count = 0
            sampler = None if sample_size is None else (
                FixedSubsetSampler(list(range(sample_size))))
            loader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=100,
                                                 num_workers=20,
                                                 sampler=sampler,
                                                 pin_memory=True)
            for image_batch, class_batch in pbar(loader):
                total_count += len(image_batch)
                image_batch, class_batch = [
                    d.cuda() for d in [image_batch, class_batch]
                ]
                scores = model(image_batch)
                preds = scores.max(1)[1]
                correct = (preds == class_batch)
                true_counts.add_(class_batch.bincount(minlength=num_classes))
                pred_counts.add_(preds.bincount(minlength=num_classes))
                # Per-class count of correctly classified samples.
                correct_counts.add_(
                    class_batch[correct].bincount(minlength=num_classes))
    finally:
        # Never leave the ablation installed, even if evaluation fails.
        model.remove_edits()

    # A sample is a true negative for class c when it is neither labeled c
    # nor predicted c.
    true_neg_counts = ((total_count - true_counts) -
                       (pred_counts - correct_counts))
    precision = (correct_counts.float() / pred_counts.float()).cpu()
    recall = (correct_counts.float() / true_counts.float()).cpu()
    accuracy = (correct_counts + true_neg_counts).float().cpu() / total_count
    true_neg_rate = (true_neg_counts.float() /
                     (total_count - true_counts).float()).cpu()
    balanced_accuracy = (recall + true_neg_rate) / 2
    if cachefile is not None:
        numpy.savez(cachefile,
                    precision=precision.numpy(),
                    recall=recall.numpy(),
                    accuracy=accuracy.numpy(),
                    true_negative_rate=true_neg_rate.numpy(),
                    balanced_accuracy=balanced_accuracy.numpy())
    return precision, recall, accuracy, balanced_accuracy