salient_filters[ind] = 999999
            # if len(salient_filters) == 0:
            #     break

    return input_gradient


if __name__ == "__main__":
    # Script entry point.
    # NOTE(review): `args` and `defs` are not defined in the visible code;
    # presumably they are module-level objects (parsed command-line options and
    # training hyper-parameter definitions) created elsewhere in the file.

    # Choose GPU device and print status information:
    # `setup` is later splatted into torch.as_tensor(**setup); presumably it
    # holds device/dtype keyword arguments — confirm against system_startup.
    setup = inversefed.utils.system_startup(args)
    # Wall-clock start; presumably used later to report elapsed time.
    start_time = time.time()

    # Prepare for training

    # Get data:
    # Loss function plus train/validation loaders for `args.dataset`, read
    # from `args.data_path`.
    loss_fn, trainloader, validloader = inversefed.construct_dataloaders(
        args.dataset, defs, data_path=args.data_path)

    # Normalization constants looked up by name in inversefed.consts
    # (e.g. `imagenet_mean` / `imagenet_std` for args.dataset == 'ImageNet').
    # `[:, None, None]` appends two singleton axes so the 1-D constants
    # broadcast over image tensors, assumed (C, H, W) — TODO confirm.
    dm = torch.as_tensor(
        getattr(inversefed.consts, f'{args.dataset.lower()}_mean'),
        **setup)[:, None, None]
    ds = torch.as_tensor(
        getattr(inversefed.consts, f'{args.dataset.lower()}_std'),
        **setup)[:, None, None]

    # For ImageNet, choose a torchvision ResNet by name. Any args.model value
    # other than 'ResNet152' or 'ResNet50' falls back to ResNet18.
    # args.trained_model is forwarded as `pretrained` (whether to load
    # pretrained weights).
    # NOTE(review): no model is selected for other datasets in the visible
    # code; presumably that happens further down the (truncated) script.
    if args.dataset == 'ImageNet':
        if args.model == 'ResNet152':
            model = torchvision.models.resnet152(pretrained=args.trained_model)
        elif args.model == 'ResNet50':
            model = torchvision.models.resnet50(pretrained=args.trained_model)
        else:
            model = torchvision.models.resnet18(pretrained=args.trained_model)
# ---- Example No. 2 (rating: 0) ----
def _parse_policy_list(opt):
    """Return the augmentation policy parsed from ``opt.aug_list``.

    An empty (or ``None``) ``opt.aug_list`` yields ``[]``. Otherwise the value
    is handed to ``split``, which presumably parses a '-'-separated string of
    policy indices — confirm against ``split``.
    """
    aug_list = opt.aug_list
    return split(aug_list) if aug_list else []


def _make_loader(dataset, batch_size, shuffle):
    """Wrap *dataset* in the DataLoader configuration shared by preprocess()."""
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       drop_last=False,
                                       num_workers=4,
                                       pin_memory=True)


def _fashion_mnist_base_transform():
    """Return a fresh construction-time transform for a FashionMNIST split.

    Converts images to 3-channel grayscale, resizes to 32, converts to a tensor
    and normalizes with mean 0.1307 / std 0.3081.

    NOTE(review): preprocess() reassigns ``.transform`` on both FashionMNIST
    datasets right after construction, so this pipeline appears to be
    superseded. It is kept so that dataset construction is unchanged.
    """
    return transforms.Compose([
        lambda x: transforms.functional.to_grayscale(
            x, num_output_channels=3),
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])


def preprocess(opt, defs, valid=False):
    """Build the loss function and train/validation loaders for ``opt.data``.

    Args:
        opt: Options object. Reads ``opt.data``, which must be ``'cifar100'``
            or ``'FashionMinist'`` (sic — the exact string callers pass), and
            ``opt.aug_list``, the augmentation policy spec parsed with
            ``split``; an empty or ``None`` spec means no policy. ``opt`` is
            also forwarded to ``build_transform``.
        defs: Training definitions. ``defs.batch_size`` sets the batch size of
            both loaders, and ``defs`` is forwarded to
            ``inversefed.construct_dataloaders`` and ``build_transform``.
        valid: If False, the augmentation policy is applied to the training
            split; if True, it is applied to the validation split instead.

    Returns:
        ``(loss_fn, trainloader, validloader)``. The training loader shuffles
        and the validation loader does not. Both keep the last partial batch
        and use 4 worker processes with pinned memory.

    Raises:
        NotImplementedError: If ``opt.data`` is not a supported dataset name.
    """
    if opt.data == 'cifar100':
        # Only the loss function is kept. The loaders built by this call are
        # discarded in favour of loaders over our own re-transformed datasets.
        loss_fn, _, _ = inversefed.construct_dataloaders('CIFAR100', defs)
        trainset, validset = _build_cifar100('~/data/')
        policy_list = _parse_policy_list(opt)

        # Only the split that receives the policy is re-transformed; the other
        # keeps the transform _build_cifar100 assigned.
        if not valid:
            trainset.transform = build_transform(True, policy_list, opt, defs)
        trainloader = _make_loader(trainset, defs.batch_size, shuffle=True)

        if valid:
            validset.transform = build_transform(True, policy_list, opt, defs)
        validloader = _make_loader(validset, defs.batch_size, shuffle=False)

        return loss_fn, trainloader, validloader

    if opt.data == 'FashionMinist':
        # NOTE(review): CIFAR100 loaders are built here only to obtain a loss
        # function; presumably it is a dataset-agnostic classification loss —
        # confirm.
        loss_fn, _, _ = inversefed.construct_dataloaders('CIFAR100', defs)
        trainset = torchvision.datasets.FashionMNIST(
            '../data',
            train=True,
            download=True,
            transform=_fashion_mnist_base_transform())
        validset = torchvision.datasets.FashionMNIST(
            '../data',
            train=False,
            download=True,
            transform=_fashion_mnist_base_transform())
        policy_list = _parse_policy_list(opt)

        # Unlike the CIFAR-100 branch, both splits are re-transformed. The
        # split selected by `valid` gets the policy; the other gets an empty
        # policy list.
        if valid:
            train_policy, valid_policy = [], policy_list
        else:
            train_policy, valid_policy = policy_list, []

        trainset.transform = build_transform(True, train_policy, opt, defs)
        trainloader = _make_loader(trainset, defs.batch_size, shuffle=True)

        validset.transform = build_transform(True, valid_policy, opt, defs)
        validloader = _make_loader(validset, defs.batch_size, shuffle=False)

        return loss_fn, trainloader, validloader

    raise NotImplementedError(f'Unsupported dataset: {opt.data!r}')