def main():
    # Training settings
    def strpair(arg):
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(
        description='Ablation eval',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='dissect', required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes', type=str, nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric', type=str, default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount', type=int, default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter', type=str,
                        help='directory containing segmentation dataset')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size', type=int, default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size', type=int, default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    pbar.verbose(not args.quiet)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    # Instantiate model
    device = next(model.parameters()).device
    input_shape = model.input_shape

    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view(
        (args.size,) + input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)

    # Now do the actual work.
    labelnames, catnames = segmenter.get_label_and_category_names(dataset)
    label_category = [catnames.index(c) if c in catnames else 0
                      for l, c in labelnames]
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}
    segloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, num_workers=10,
        pin_memory=(device.type == 'cuda'))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # First, collect a baseline
    for l in model.ablation:
        model.ablation[l] = None

    # For each sort-order, do an ablation
    for classname in pbar(args.classes):
        pbar.post(c=classname)
        for layername in pbar(model.ablation):
            pbar.post(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            classnum = labelnum_from_name[classname]
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except StopIteration:
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)

            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            if os.path.isfile(os.path.join(ablationdir, '%s.json' % rankname)):
                with open(os.path.join(ablationdir,
                                       '%s.json' % rankname)) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.
                measurements = data.ablation_effects
            measurements = measure_ablation(segmenter, segloader, model,
                                            classnum, layername,
                                            ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(os.path.join(ablationdir,
                                   '%s.json' % rankname), 'w') as f:
                json.dump(dict(classname=classname,
                               classnum=classnum,
                               baseline=measurements[0],
                               layer=layername,
                               metric=args.metric,
                               ablation_units=ordering.tolist(),
                               ablation_effects=measurements[1:]), f)
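
# Illustrative sketch, not part of the original script: each run above leaves
# one JSON record per (class, layer) pair at
# <outdir>/<layername>/pixablation/<classname>-<metric>.json.  The helper
# below shows how such a record can be read back for analysis; it reuses the
# os and json imports assumed elsewhere in this file, and its argument values
# are placeholders.
def load_ablation_record(outdir, layername, classname, metric='iou'):
    rankname = '%s-%s' % (classname, metric)
    with open(os.path.join(outdir, layername, 'pixablation',
                           '%s.json' % rankname)) as f:
        record = json.load(f)
    # record['baseline'] is the unablated measurement; record['ablation_effects']
    # appears to track the measurement as successive units listed in
    # record['ablation_units'] are ablated.
    return record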
def main():
    args = parseargs()
    experiment_dir = 'results/decoupled-%d-%s-resnet' % (
        args.selected_classes, args.dataset)
    ds_dirname = dict(
        novelty='novelty/dataset_v1/known_classes/images',
        imagenet='imagenet')[args.dataset]
    training_dir = 'datasets/%s/train' % ds_dirname
    val_dir = 'datasets/%s/val' % ds_dirname
    os.makedirs(experiment_dir, exist_ok=True)
    with open(os.path.join(experiment_dir, 'args.txt'), 'w') as f:
        f.write(str(args) + '\n')

    def printstat(s):
        with open(os.path.join(experiment_dir, 'log.txt'), 'a') as f:
            f.write(str(s) + '\n')
        pbar.print(s)

    def filter_tuple(item):
        return item[1] < args.selected_classes

    # Imagenet has a couple of bad EXIF images.
    warnings.filterwarnings('ignore', message='.*orrupt EXIF.*')

    # Here's our data
    train_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [training_dir], classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64, shuffle=True, num_workers=48, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [val_dir], classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64, shuffle=False, num_workers=24, pin_memory=True)

    late_model = torchvision.models.resnet50(
        num_classes=args.selected_classes)
    for n, p in late_model.named_parameters():
        if 'bias' in n:
            torch.nn.init.zeros_(p)
        elif len(p.shape) <= 1:
            torch.nn.init.ones_(p)
        else:
            torch.nn.init.kaiming_normal_(p, nonlinearity='relu')
    late_model.train()
    late_model.cuda()
    model = late_model

    max_lr = 5e-3
    max_iter = args.training_iterations

    def criterion(logits, true_class):
        goal = torch.zeros_like(logits)
        goal.scatter_(1, true_class[:, None], value=1.0)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            logits, goal)

    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, total_steps=max_iter - 1)
    iter_num = 0
    best = dict(val_accuracy=0.0)

    # Resume training if we already have a checkpoint.
    checkpoint_filename = 'weights.pth'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(os.path.join(experiment_dir, best_filename))
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        filename = os.path.join(experiment_dir, checkpoint_filename)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        model.eval()
        val_loss, val_acc = AverageMeter(), AverageMeter()
        for input, target in pbar(val_loader):
            # Load data
            input_var, target_var = [d.cuda() for d in [input, target]]
            # Evaluate model
            with torch.no_grad():
                output = model(input_var)
                loss = criterion(output, target_var)
                _, pred = output.max(1)
                accuracy = (target_var.eq(pred)
                            ).data.float().sum().item() / input.size(0)
            val_loss.update(loss.data.item(), input.size(0))
            val_acc.update(accuracy, input.size(0))
            # Check accuracy
            pbar.post(l=val_loss.avg, a=val_acc.avg)
        # Save checkpoint
        save_checkpoint({
            'iter': iter_num,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'accuracy': val_acc.avg,
            'loss': val_loss.avg,
        }, val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        printstat('Iteration %d val accuracy %.2f' %
                  (iter_num, val_acc.avg * 100.0))

    # Here is our training loop.
    while iter_num < max_iter:
        for filtered_input, filtered_target in pbar(train_loader):
            # Track the average training loss/accuracy for each epoch.
            train_loss, train_acc = AverageMeter(), AverageMeter()
            # Load data
            input_var, target_var = [
                d.cuda() for d in [filtered_input, filtered_target]]
            # Evaluate model
            output = model(input_var)
            loss = criterion(output, target_var)
            train_loss.update(loss.data.item(), filtered_input.size(0))
            # Perform one step of SGD
            if iter_num > 0:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Learning rate schedule
                scheduler.step()
            # Also check training set accuracy
            _, pred = output.max(1)
            accuracy = (target_var.eq(pred)).data.float().sum().item() / (
                filtered_input.size(0))
            train_acc.update(accuracy)
            remaining = 1 - iter_num / float(max_iter)
            pbar.post(l=train_loss.avg, a=train_acc.avg,
                      v=best['val_accuracy'])
            # Occasionally check validation set accuracy and checkpoint
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
            # Advance
            iter_num += 1
            if iter_num >= max_iter:
                break
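
# Sketch of the AverageMeter helper the loops above rely on; its definition is
# not part of this excerpt, so this is an assumed minimal interface (an
# update(value, n) method and an .avg attribute) in the style of the PyTorch
# ImageNet example.
class AverageMeter:
    """Tracks a running weighted average of a statistic."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # Accumulate a batch statistic weighted by the batch size n.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count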
def train_ablation(args, corpus, cachefile, model, segmenter, classnum,
                   initial_ablation=None):
    cachedir = os.path.dirname(cachefile)
    snapdir = os.path.join(cachedir, 'snapshots')
    os.makedirs(snapdir, exist_ok=True)

    # high_replacement = corpus.feature_99[None,:,None,None].cuda()
    if '_h99' in args.variant:
        high_replacement = corpus.feature_99[None, :, None, None].cuda()
    elif '_tcm' in args.variant:
        # variant: top-conditional-mean
        high_replacement = (
            corpus.mean_present_feature[None, :, None, None].cuda())
    else:
        # default: weighted mean
        high_replacement = (
            corpus.weighted_mean_present_feature[None, :, None, None].cuda())
    fullimage_measurement = False
    ablation_only = False
    fullimage_ablation = False
    if '_fim' in args.variant:
        fullimage_measurement = True
    elif '_fia' in args.variant:
        fullimage_measurement = True
        ablation_only = True
        fullimage_ablation = True
    high_replacement.requires_grad = False
    for p in model.parameters():
        p.requires_grad = False
    ablation = torch.zeros(high_replacement.shape).cuda()
    if initial_ablation is not None:
        ablation.view(-1)[...] = initial_ablation
    ablation.requires_grad = True
    optimizer = torch.optim.Adam([ablation], lr=0.01)
    start_epoch = 0
    epoch = 0

    def eval_loss_and_reg():
        discrete_experiments = dict(
            # dpixel=dict(discrete_pixels=True),
            # dunits20=dict(discrete_units=20),
            # dumix20=dict(discrete_units=20, mixed_units=True),
            # dunits10=dict(discrete_units=10),
            # abonly=dict(ablation_only=True),
            # fimabl=dict(ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            dboth20=dict(discrete_units=20, discrete_pixels=True),
            # dbothm20=dict(discrete_units=20, mixed_units=True,
            #               discrete_pixels=True),
            # abdisc20=dict(discrete_units=20, discrete_pixels=True,
            #               ablation_only=True),
            # abdiscm20=dict(discrete_units=20, mixed_units=True,
            #                discrete_pixels=True,
            #                ablation_only=True),
            # fimadp=dict(discrete_pixels=True,
            #             ablation_only=True,
            #             fullimage_ablation=True,
            #             fullimage_measurement=True),
            # fimadu10=dict(discrete_units=10,
            #               ablation_only=True,
            #               fullimage_ablation=True,
            #               fullimage_measurement=True),
            # fimadb10=dict(discrete_units=10, discrete_pixels=True,
            #               ablation_only=True,
            #               fullimage_ablation=True,
            #               fullimage_measurement=True),
            fimadbm10=dict(discrete_units=10, mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True),
            # fimadu20=dict(discrete_units=20,
            #               ablation_only=True,
            #               fullimage_ablation=True,
            #               fullimage_measurement=True),
            # fimadb20=dict(discrete_units=20, discrete_pixels=True,
            #               ablation_only=True,
            #               fullimage_ablation=True,
            #               fullimage_measurement=True),
            fimadbm20=dict(discrete_units=20, mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True))
        with torch.no_grad():
            total_loss = 0
            discrete_losses = {k: 0 for k in discrete_experiments}
            for [pbatch, ploc, cbatch, cloc] in pbar(
                    torch.utils.data.DataLoader(
                        TensorDataset(corpus.eval_present_sample,
                                      corpus.eval_present_location,
                                      corpus.eval_candidate_sample,
                                      corpus.eval_candidate_location),
                        batch_size=args.inference_batch_size,
                        num_workers=10, shuffle=False, pin_memory=True),
                    desc="Eval"):
                # First, put in zeros for the selected units.
                # Loss is amount of remaining object.
                total_loss = total_loss + ace_loss(
                    segmenter, classnum, model, args.layer, high_replacement,
                    ablation, pbatch, ploc, cbatch, cloc,
                    run_backward=False,
                    ablation_only=ablation_only,
                    fullimage_measurement=fullimage_measurement)
                for k, config in discrete_experiments.items():
                    discrete_losses[k] = discrete_losses[k] + ace_loss(
                        segmenter, classnum, model, args.layer,
                        high_replacement, ablation,
                        pbatch, ploc, cbatch, cloc,
                        run_backward=False, **config)
        avg_loss = (total_loss / args.eval_size).item()
        avg_d_losses = {k: (d / args.eval_size).item()
                        for k, d in discrete_losses.items()}
        regularizer = (args.l2_lambda * ablation.pow(2).sum())
        pbar.print('Epoch %d Loss %g Regularizer %g' %
                   (epoch, avg_loss, regularizer))
        pbar.print(' '.join('%s: %g' % (k, d)
                            for k, d in avg_d_losses.items()))
        pbar.print(scale_summary(ablation.view(-1), 10, 3))
        return avg_loss, regularizer, avg_d_losses

    if args.eval_only:
        # For eval_only, just load each snapshot and re-run validation eval
        # pass on each one.
        for epoch in range(-1, args.train_epochs):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % epoch)
            if not os.path.exists(snapfile):
                data = {}
                if epoch >= 0:
                    print('No epoch %d' % epoch)
                    continue
            else:
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
            avg_loss, regularizer, new_extra = eval_loss_and_reg()
            # Keep old values, and update any new ones.
            extra = {k: v for k, v in data.items()
                     if k not in ['ablation', 'optimizer', 'avg_loss']}
            extra.update(new_extra)
            torch.save(dict(ablation=ablation,
                            optimizer=optimizer.state_dict(),
                            avg_loss=avg_loss, **extra),
                       os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        # Return loaded ablation.
        return ablation.view(-1).detach().cpu().numpy()

    if not args.no_cache:
        for start_epoch in reversed(range(args.train_epochs)):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % start_epoch)
            if os.path.exists(snapfile):
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
                start_epoch += 1
                break

    if start_epoch < args.train_epochs:
        epoch = start_epoch - 1
        avg_loss, regularizer, extra = eval_loss_and_reg()
        if epoch == -1:
            torch.save(dict(ablation=ablation,
                            optimizer=optimizer.state_dict(),
                            avg_loss=avg_loss, **extra),
                       os.path.join(snapdir, 'epoch-%d.pth' % epoch))

    update_size = args.train_update_freq * args.train_batch_size
    for epoch in range(start_epoch, args.train_epochs):
        candidate_shuffle = torch.randperm(len(corpus.candidate_sample))
        train_loss = 0
        for batch_num, [pbatch, ploc, cbatch, cloc] in enumerate(pbar(
                torch.utils.data.DataLoader(
                    TensorDataset(
                        corpus.object_present_sample,
                        corpus.object_present_location,
                        corpus.candidate_sample[candidate_shuffle],
                        corpus.candidate_location[candidate_shuffle]),
                    batch_size=args.train_batch_size, num_workers=10,
                    shuffle=True, pin_memory=True),
                desc="ACE opt epoch %d" % epoch)):
            if batch_num % args.train_update_freq == 0:
                optimizer.zero_grad()
            # First, put in zeros for the selected units.  Loss is amount
            # of remaining object.
            loss = ace_loss(segmenter, classnum, model, args.layer,
                            high_replacement, ablation,
                            pbatch, ploc, cbatch, cloc,
                            run_backward=True,
                            ablation_only=ablation_only,
                            fullimage_measurement=fullimage_measurement)
            with torch.no_grad():
                train_loss = train_loss + loss
            if (batch_num + 1) % args.train_update_freq == 0:
                # Third, add some L2 loss to encourage sparsity.
                regularizer = (args.l2_lambda * update_size *
                               ablation.pow(2).sum())
                regularizer.backward()
                optimizer.step()
                with torch.no_grad():
                    ablation.clamp_(0, 1)
                    pbar.post(l=(train_loss / update_size).item(),
                              r=(regularizer / update_size).item())
                    train_loss = 0

        avg_loss, regularizer, extra = eval_loss_and_reg()
        torch.save(dict(ablation=ablation,
                        optimizer=optimizer.state_dict(),
                        avg_loss=avg_loss, **extra),
                   os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch),
                   ablation.detach().cpu().numpy())

    # The output of this phase is this set of scores.
    return ablation.view(-1).detach().cpu().numpy()
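
# Illustrative sketch, not part of the original script: each epoch above also
# leaves a snapshot pair under <cachedir>/snapshots/.  The helper below shows
# how those files can be inspected after a run; the snapdir and epoch argument
# values are placeholders, and the os, numpy, and torch imports are the same
# ones assumed elsewhere in this file.
def load_ablation_snapshot(snapdir, epoch):
    # The .npy file holds the raw ablation tensor; the .pth file additionally
    # carries the optimizer state, the epoch's avg_loss, and the extra
    # evaluation metrics saved above.
    scores = numpy.load(os.path.join(snapdir, 'epoch-%d.npy' % epoch))
    snapshot = torch.load(os.path.join(snapdir, 'epoch-%d.pth' % epoch),
                          map_location='cpu')
    return scores, snapshot['avg_loss']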