def main(args):
    """Tally feature-covariance statistics for a MoCo ResNet-50 encoder.

    Computes two cached covariance tallies over features from layer
    ``args.layernum``: one over the whole chosen split, and one restricted
    to samples whose class index falls below ``selected_classes``.
    Results are cached as .npz files inside ``args.expdir``.
    """
    dataset = args.dataset
    layernum = args.layernum
    split = args.split
    selected_classes = args.selected_classes

    def ef(s):
        # Cache files live inside the experiment directory.
        return os.path.join(args.expdir, s)

    model_dir = "/data/vision/torralba/dissect/novelty/models"
    model_name = f"{dataset}_moco_resnet50.pth"
    model_path = os.path.join(model_dir, model_name)
    val_path = f"datasets/{dataset}/val"
    train_path = f"datasets/{dataset}/train"
    img_trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        renormalize.NORMALIZER['imagenet']
    ])
    dsv = parallelfolder.ParallelImageFolders([val_path],
                                             transform=img_trans,
                                             classification=True)
    dst = parallelfolder.ParallelImageFolders([train_path],
                                              transform=img_trans,
                                              classification=True)
    if selected_classes is None:
        # Default: treat the first half of the class indices as "selected".
        selected_classes = len(dst.classes) // 2
    dsm = dict(val=dsv, train=dst)
    dp_model = InsResNet50()
    checkpoint = torch.load(model_path)
    # Load the EMA weights from the MoCo checkpoint.
    dp_model.load_state_dict(checkpoint['model_ema'])
    # Unwrap the DataParallel wrapper around the encoder.
    model = dp_model.encoder.module
    model.cuda()

    def batch_features(imgbatch, cls):
        # Features at the requested layer; a 4-d conv feature map is
        # flattened so that every spatial location contributes one
        # feature vector of length channels.
        result = model(imgbatch.cuda(), layer=layernum)
        if len(result.shape) == 4:
            result = result.permute(0, 2, 3, 1).reshape(-1, result.shape[1])
        return result

    # Unconditional covariance over the full split.
    mcov = tally.tally_covariance(
        batch_features, dsm[split],
        num_workers=100,
        batch_size=args.batch_size,
        pin_memory=True,
        cachefile=ef(f'{dataset}-{split}-layer{layernum}-mcov.npz'))

    def selclass_features(imgbatch, cls):
        # Same as batch_features, but keeps only feature vectors coming
        # from samples whose class index is below selected_classes.
        result = model(imgbatch.cuda(), layer=layernum)
        if len(result.shape) == 4:
            # Broadcast each sample's class label across its spatial grid
            # so the mask lines up with the flattened feature vectors.
            cls = cls[:, None, None].expand(
                result.shape[0], result.shape[2],
                result.shape[3]).reshape(-1)
            result = result.permute(0, 2, 3, 1).reshape(-1, result.shape[1])
        selected = result[cls < selected_classes]
        return selected

    # Covariance restricted to the selected classes.
    selcov = tally.tally_covariance(
        selclass_features, dsm[split],
        num_workers=100,
        batch_size=args.batch_size,
        pin_memory=True,
        cachefile=ef(f'{dataset}-{split}-layer{layernum}' +
                     f'-sel{selected_classes}-mcov.npz'))
def load_dataset(domain, split=None, full=False, crop_size=None, download=True):
    """Build an image-folder dataset for a known domain.

    Supported domains are places, imagenet, and the four PACS domains;
    any other domain yields None.  For 'places', the split archive is
    downloaded and extracted on first use (when `download` is set).
    Images are resized to 256x256 then center-cropped to `crop_size`
    (default 224) and imagenet-normalized.
    """
    known_domains = [
        'places', 'imagenet', 'pacs-p', 'pacs-a', 'pacs-c', 'pacs-s'
    ]
    if domain not in known_domains:
        return None
    if split is None:
        split = 'val'
    dirname = 'datasets/%s/%s' % (domain, split)
    # Fetch the places archive only if it is not already on disk.
    if download and domain == 'places' and not os.path.exists(dirname):
        md5s = dict(val='593bbc21590cf7c396faac2e600cd30c',
                    train='d1db6ad3fc1d69b94da325ac08886a01')
        os.makedirs('datasets', exist_ok=True)
        torchvision.datasets.utils.download_and_extract_archive(
            'https://dissect.csail.mit.edu/datasets/' +
            'places_%s.zip' % split,
            'datasets',
            md5=md5s[split])
    xform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop(crop_size or 224),
        torchvision.transforms.ToTensor(),
        renormalize.NORMALIZER['imagenet']
    ])
    return parallelfolder.ParallelImageFolders([dirname],
                                               classification=True,
                                               shuffle=True,
                                               transform=xform)
def tally_directory(directory, size=10000, seed=1):
    """Tally per-image object-area fractions over a directory of images.

    Segments `size` randomly chosen images and returns an array where
    entry [i, c] is the fraction of image i's pixels assigned to
    segmentation class c.

    Args:
        directory: path to an image folder readable by ParallelImageFolders.
        size: number of images to sample.
        seed: random seed for the fixed subset sampler.

    Returns:
        numpy float64 array of shape (size, NUM_OBJECTS).
    """
    dataset = parallelfolder.ParallelImageFolders(
        [directory],
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]))
    loader = DataLoader(
        dataset,
        # Bug fix: the seed parameter was previously ignored (hard-coded 1).
        sampler=FixedRandomSubsetSampler(dataset, end=size, seed=seed),
        batch_size=10,
        pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, catnames = upp.get_label_and_category_names()
    # numpy.float was removed in NumPy 1.20+; float64 is the same dtype.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    # Reusable GPU buffer, sized for a full batch.
    batch_result = torch.zeros(loader.batch_size, NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            seg_result = upp.segment_batch(batch.cuda())
            for i in range(len(batch)):
                # Histogram of segmentation labels for image i,
                # normalized by its pixel count.
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                    (seg_result.shape[2] * seg_result.shape[3]))
            # Bug fix: slice to len(batch) so a final partial batch does
            # not raise a shape mismatch against the full-size buffer.
            result[batch_index:batch_index + len(batch)] = (
                batch_result[:len(batch)].cpu().numpy())
            batch_index += len(batch)
    return result
def load_dataset(domain, split=None, full=False, crop_size=None, download=True):
    """Load a dataset for dissection experiments.

    For 'places'/'imagenet' this returns a classification
    ParallelImageFolders over datasets/<domain>/<split> (downloading the
    places archive on first use); for any other domain it returns the
    UCF101 RGB-frame dataset used in the network-dissection experiment.

    Args:
        domain: dataset name; 'places' or 'imagenet' select the folder
            path, anything else selects the UCF101 RGB frames.
        split: 'train' or 'val'; defaults to 'val' for places/imagenet.
        full: unused; kept for interface compatibility.
        crop_size: center-crop size; defaults to 224 when None.
        download: allow downloading the places archive if missing.
    """
    if domain in ['places', 'imagenet']:
        if split is None:
            split = 'val'
        dirname = 'datasets/%s/%s' % (domain, split)
        if download and not os.path.exists(dirname) and domain == 'places':
            os.makedirs('datasets', exist_ok=True)
            torchvision.datasets.utils.download_and_extract_archive(
                'https://dissect.csail.mit.edu/datasets/' +
                'places_%s.zip' % split,
                'datasets',
                md5=dict(val='593bbc21590cf7c396faac2e600cd30c',
                         train='d1db6ad3fc1d69b94da325ac08886a01')[split])
        places_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((256, 256)),
            torchvision.transforms.CenterCrop(crop_size or 224),
            torchvision.transforms.ToTensor(),
            renormalize.NORMALIZER['imagenet']])
        return parallelfolder.ParallelImageFolders(
            [dirname], classification=True, shuffle=True,
            transform=places_transform)
    else:
        # UCF101 RGB frames for the network-dissection experiment.
        DATASET_DIR = os.path.abspath(
            os.path.join("/", "mnt", "sdb", "danielw"))
        RGB_DIR = os.path.join(DATASET_DIR, "jpegs_256")
        train_transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((256, 256)),
            # Bug fix: CenterCrop(None) raises when crop_size is left at
            # its default; fall back to 224 like the places branch.
            torchvision.transforms.CenterCrop(crop_size or 224),
            torchvision.transforms.ToTensor(),
            renormalize.NORMALIZER['imagenet']
        ])
        # The unused DataLoader that was built here has been removed;
        # callers receive the dataset and construct their own loaders.
        return RGB_Dissect(data_root=RGB_DIR, is_train=True,
                           transform=train_transform)
from IPython.display import SVG
from matplotlib import pyplot as plt

# Analysis-only notebook chunk: no gradients needed anywhere below.
torch.set_grad_enabled(False)


def normalize_filename(n):
    # Reduce a filename to its 'Places365_<split>_<number>' stem —
    # presumably so the two parallel folders pair files that share the
    # same stem despite differing suffixes; TODO confirm against
    # ParallelImageFolders' pairing logic.
    return re.match(r'^(.*Places365_\w+_\d+)', n).group(1)


# Paired dataset: each item yields the original places image and its
# stylized counterpart, matched by normalized filename.
ds = parallelfolder.ParallelImageFolders(
    ['datasets/places/val', 'datasets/stylized-places/val'],
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        # transforms.CenterCrop(224),
        torchvision.transforms.CenterCrop(256),
        torchvision.transforms.ToTensor(),
        renormalize.NORMALIZER['imagenet'],
    ]),
    normalize_filename=normalize_filename,
    shuffle=True)

# Layer names to analyze (list continues beyond this chunk).
layers = [
    'conv5_3',
    'conv5_2',
    'conv5_1',
    'conv4_3',
    'conv4_2',
    'conv4_1',
    'conv3_3',
    'conv3_2',
# Segmentation model and its flat list of label names.
segmodel = segmenter.UnifiedParsingSegmenter(segsizes=[256])
seglabels = [l for l, c in segmodel.get_label_and_category_names()[0]]

# Load places dataset
from netdissect import parallelfolder, renormalize
from torchvision import transforms

# Standard imagenet-normalized 224 center crop.
center_crop = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    renormalize.NORMALIZER['imagenet']
])
dataset = parallelfolder.ParallelImageFolders(
    ['dataset/places/val'],
    transform=[center_crop],
    classification=True,
    shuffle=True)
train_dataset = parallelfolder.ParallelImageFolders(
    ['dataset/places/train'],
    transform=[center_crop],
    classification=True,
    shuffle=True)

# Collect unconditional quantiles
from netdissect import tally

# Upsampler mapping 7x7 feature maps to 56x56 for tallying.
upfn = upsample.upsampler(
    (56, 56),  # The target output shape
    (7, 7),
    source=dataset,
)
def main():
    """Train a ResNet-50 from scratch on the first selected_classes classes.

    Uses a per-class sigmoid (binary cross-entropy) loss rather than
    softmax, a one-cycle learning-rate schedule, and checkpoints the
    best model by validation accuracy under
    results/decoupled-<classes>-<dataset>-resnet.
    """
    args = parseargs()
    experiment_dir = 'results/decoupled-%d-%s-resnet' % (
        args.selected_classes, args.dataset)
    ds_dirname = dict(novelty='novelty/dataset_v1/known_classes/images',
                      imagenet='imagenet')[args.dataset]
    training_dir = 'datasets/%s/train' % ds_dirname
    val_dir = 'datasets/%s/val' % ds_dirname
    os.makedirs(experiment_dir, exist_ok=True)
    # Record the arguments used for this run.
    with open(os.path.join(experiment_dir, 'args.txt'), 'w') as f:
        f.write(str(args) + '\n')

    def printstat(s):
        # Append to the run log and echo to the progress display.
        with open(os.path.join(experiment_dir, 'log.txt'), 'a') as f:
            f.write(str(s) + '\n')
        pbar.print(s)

    def filter_tuple(item):
        # Keep only samples whose class index is below selected_classes.
        return item[1] < args.selected_classes

    # Imagenet has a couple bad exif images.
    warnings.filterwarnings('ignore', message='.*orrupt EXIF.*')
    # Here's our data: random-crop + flip augmentation for training...
    train_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [training_dir],
            classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64, shuffle=True, num_workers=48, pin_memory=True)
    # ...and a deterministic center crop for validation.
    val_loader = torch.utils.data.DataLoader(
        parallelfolder.ParallelImageFolders(
            [val_dir],
            classification=True,
            filter_tuples=filter_tuple,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                renormalize.NORMALIZER['imagenet'],
            ])),
        batch_size=64, shuffle=False, num_workers=24, pin_memory=True)
    late_model = torchvision.models.resnet50(
        num_classes=args.selected_classes)
    # Re-initialize everything: zeros for biases, ones for 1-d
    # (normalization) weights, Kaiming-normal for the rest.
    for n, p in late_model.named_parameters():
        if 'bias' in n:
            torch.nn.init.zeros_(p)
        elif len(p.shape) <= 1:
            torch.nn.init.ones_(p)
        else:
            torch.nn.init.kaiming_normal_(p, nonlinearity='relu')
    late_model.train()
    late_model.cuda()
    model = late_model

    max_lr = 5e-3
    max_iter = args.training_iterations

    def criterion(logits, true_class):
        # One-hot targets with per-class sigmoid cross-entropy
        # ("decoupled" multi-label formulation instead of softmax).
        goal = torch.zeros_like(logits)
        goal.scatter_(1, true_class[:, None], value=1.0)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            logits, goal)

    optimizer = torch.optim.Adam(model.parameters())
    # The optimizer/scheduler step is skipped at iteration 0 below, so
    # exactly max_iter - 1 scheduler steps are taken in total.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, total_steps=max_iter - 1)
    iter_num = 0
    best = dict(val_accuracy=0.0)
    # Resume training from the best checkpoint if one exists
    # (currently disabled by the flag below).
    checkpoint_filename = 'weights.pth'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(os.path.join(experiment_dir, best_filename))
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        # Always overwrite the latest checkpoint; copy it aside when it
        # is the best seen so far.
        filename = os.path.join(experiment_dir, checkpoint_filename)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        # Measure loss/accuracy over the full validation set, then save.
        model.eval()
        val_loss, val_acc = AverageMeter(), AverageMeter()
        for input, target in pbar(val_loader):
            # Load data
            input_var, target_var = [d.cuda() for d in [input, target]]
            # Evaluate model
            with torch.no_grad():
                output = model(input_var)
                loss = criterion(output, target_var)
                _, pred = output.max(1)
                accuracy = (target_var.eq(pred)
                            ).data.float().sum().item() / input.size(0)
                val_loss.update(loss.data.item(), input.size(0))
                val_acc.update(accuracy, input.size(0))
            # Check accuracy
            pbar.post(l=val_loss.avg, a=val_acc.avg)
        # Save checkpoint
        save_checkpoint(
            {
                'iter': iter_num,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'accuracy': val_acc.avg,
                'loss': val_loss.avg,
            }, val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        printstat('Iteration %d val accuracy %.2f' %
                  (iter_num, val_acc.avg * 100.0))

    # Here is our training loop.
    while iter_num < max_iter:
        for filtered_input, filtered_target in pbar(train_loader):
            # Track the average training loss/accuracy.
            # NOTE(review): these meters are re-created every batch, so
            # they reflect only the current batch, not a whole epoch.
            train_loss, train_acc = AverageMeter(), AverageMeter()
            # Load data
            input_var, target_var = [
                d.cuda() for d in [filtered_input, filtered_target]
            ]
            # Evaluate model
            output = model(input_var)
            loss = criterion(output, target_var)
            train_loss.update(loss.data.item(), filtered_input.size(0))
            # Perform one step of SGD (skipped at iteration 0).
            if iter_num > 0:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Learning rate schedule
                scheduler.step()
            # Also check training set accuracy
            _, pred = output.max(1)
            accuracy = (target_var.eq(pred)).data.float().sum().item() / (
                filtered_input.size(0))
            train_acc.update(accuracy)
            # NOTE(review): 'remaining' is computed but never used.
            remaining = 1 - iter_num / float(max_iter)
            pbar.post(l=train_loss.avg, a=train_acc.avg,
                      v=best['val_accuracy'])
            # Occasionally check validation set accuracy and checkpoint
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
            # Advance
            iter_num += 1
            if iter_num >= max_iter:
                break