salient_filters[ind] = 999999 # if len(salient_filters) == 0: # break return input_gradient if __name__ == "__main__": # Choose GPU device and print status information: setup = inversefed.utils.system_startup(args) start_time = time.time() # Prepare for training # Get data: loss_fn, trainloader, validloader = inversefed.construct_dataloaders( args.dataset, defs, data_path=args.data_path) dm = torch.as_tensor( getattr(inversefed.consts, f'{args.dataset.lower()}_mean'), **setup)[:, None, None] ds = torch.as_tensor( getattr(inversefed.consts, f'{args.dataset.lower()}_std'), **setup)[:, None, None] if args.dataset == 'ImageNet': if args.model == 'ResNet152': model = torchvision.models.resnet152(pretrained=args.trained_model) elif args.model == 'ResNet50': model = torchvision.models.resnet50(pretrained=args.trained_model) else: model = torchvision.models.resnet18(pretrained=args.trained_model)
def _aug_policy_list(opt):
    """Return the augmentation policy parsed from ``opt.aug_list``.

    A falsy ``opt.aug_list`` (e.g. empty) yields ``[]`` without calling
    ``split``; anything else is passed to ``split`` unchanged.
    """
    return split(opt.aug_list) if opt.aug_list else []


def _make_loader(dataset, defs, shuffle):
    """Wrap *dataset* in the DataLoader configuration shared by ``preprocess``.

    Batch size comes from ``defs.batch_size``; the last partial batch is kept
    (``drop_last=False``) and 4 workers with pinned memory are used.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=defs.batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )


def preprocess(opt, defs, valid=False):
    """Build ``(loss_fn, trainloader, validloader)`` with an augmentation policy.

    Args:
        opt: options object. ``opt.data`` selects the dataset and must be
            ``'cifar100'`` or ``'FashionMinist'`` (matched literally, including
            that spelling). ``opt.aug_list``, when non-empty, is parsed by
            ``split`` into the policy list handed to ``build_transform``.
        defs: training definitions. ``defs.batch_size`` sets the loader batch
            size; ``defs`` is also forwarded to
            ``inversefed.construct_dataloaders`` and ``build_transform``.
        valid: if False, the policy transform is applied to the training split;
            if True, it is applied to the validation split instead.

    Returns:
        Tuple ``(loss_fn, trainloader, validloader)``. Training loaders shuffle;
        validation loaders do not.

    Raises:
        NotImplementedError: if ``opt.data`` is not a supported dataset.
    """
    if opt.data == 'cifar100':
        # inversefed supplies the loss and default loaders; only the loader for
        # the split selected by ``valid`` is rebuilt with the policy transform.
        loss_fn, trainloader, validloader = inversefed.construct_dataloaders(
            'CIFAR100', defs)
        trainset, validset = _build_cifar100('~/data/')
        policy_list = _aug_policy_list(opt)
        if valid:
            validset.transform = build_transform(True, policy_list, opt, defs)
            validloader = _make_loader(validset, defs, shuffle=False)
        else:
            trainset.transform = build_transform(True, policy_list, opt, defs)
            trainloader = _make_loader(trainset, defs, shuffle=True)
        return loss_fn, trainloader, validloader

    if opt.data == 'FashionMinist':
        # Only the loss is taken from the CIFAR100 setup; its loaders are
        # discarded. NOTE(review): presumably the loss is dataset-independent
        # here -- confirm.
        loss_fn, _, _ = inversefed.construct_dataloaders('CIFAR100', defs)
        # NOTE(review): no grayscale->3-channel or resize step is applied to
        # these datasets; images reach the build_transform(...) pipeline as
        # loaded by torchvision. Confirm build_transform copes with
        # FashionMNIST images.
        trainset = torchvision.datasets.FashionMNIST(
            '../data', train=True, download=True)
        validset = torchvision.datasets.FashionMNIST(
            '../data', train=False, download=True)
        policy_list = _aug_policy_list(opt)
        # The policy goes to exactly one split; the other gets an empty policy.
        if valid:
            train_policy, valid_policy = [], policy_list
        else:
            train_policy, valid_policy = policy_list, []
        trainset.transform = build_transform(True, train_policy, opt, defs)
        trainloader = _make_loader(trainset, defs, shuffle=True)
        validset.transform = build_transform(True, valid_policy, opt, defs)
        validloader = _make_loader(validset, defs, shuffle=False)
        return loss_fn, trainloader, validloader

    raise NotImplementedError(f'unsupported dataset: {opt.data!r}')